* [bitbake-devel][PATCH v4 01/22] asyncrpc: Abstract sockets
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 02/22] hashserv: Add websocket connection implementation Joshua Watt
` (23 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Rewrites the asyncrpc client and server code to make it possible to have
other transport backends that are not stream-based (e.g. websockets,
which are message-based). The connection handling classes are now shared
between both the client and server to make it easier to implement new
transport mechanisms.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/__init__.py | 32 +---
lib/bb/asyncrpc/client.py | 78 +++------
lib/bb/asyncrpc/connection.py | 95 +++++++++++
lib/bb/asyncrpc/exceptions.py | 17 ++
lib/bb/asyncrpc/serv.py | 298 +++++++++++++++++-----------------
lib/hashserv/__init__.py | 21 ---
lib/hashserv/client.py | 38 ++---
lib/hashserv/server.py | 115 ++++++-------
lib/prserv/client.py | 8 +-
lib/prserv/serv.py | 31 ++--
10 files changed, 380 insertions(+), 353 deletions(-)
create mode 100644 lib/bb/asyncrpc/connection.py
create mode 100644 lib/bb/asyncrpc/exceptions.py
diff --git a/lib/bb/asyncrpc/__init__.py b/lib/bb/asyncrpc/__init__.py
index 9a85e996..9f677eac 100644
--- a/lib/bb/asyncrpc/__init__.py
+++ b/lib/bb/asyncrpc/__init__.py
@@ -4,30 +4,12 @@
# SPDX-License-Identifier: GPL-2.0-only
#
-import itertools
-import json
-
-# The Python async server defaults to a 64K receive buffer, so we hardcode our
-# maximum chunk size. It would be better if the client and server reported to
-# each other what the maximum chunk sizes were, but that will slow down the
-# connection setup with a round trip delay so I'd rather not do that unless it
-# is necessary
-DEFAULT_MAX_CHUNK = 32 * 1024
-
-
-def chunkify(msg, max_chunk):
- if len(msg) < max_chunk - 1:
- yield ''.join((msg, "\n"))
- else:
- yield ''.join((json.dumps({
- 'chunk-stream': None
- }), "\n"))
-
- args = [iter(msg)] * (max_chunk - 1)
- for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
- yield ''.join(itertools.chain(m, "\n"))
- yield "\n"
-
from .client import AsyncClient, Client
-from .serv import AsyncServer, AsyncServerConnection, ClientError, ServerError
+from .serv import AsyncServer, AsyncServerConnection
+from .connection import DEFAULT_MAX_CHUNK
+from .exceptions import (
+ ClientError,
+ ServerError,
+ ConnectionClosedError,
+)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index fa042bbe..7f33099b 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -10,13 +10,13 @@ import json
import os
import socket
import sys
-from . import chunkify, DEFAULT_MAX_CHUNK
+from .connection import StreamConnection, DEFAULT_MAX_CHUNK
+from .exceptions import ConnectionClosedError
class AsyncClient(object):
def __init__(self, proto_name, proto_version, logger, timeout=30):
- self.reader = None
- self.writer = None
+ self.socket = None
self.max_chunk = DEFAULT_MAX_CHUNK
self.proto_name = proto_name
self.proto_version = proto_version
@@ -25,7 +25,8 @@ class AsyncClient(object):
async def connect_tcp(self, address, port):
async def connect_sock():
- return await asyncio.open_connection(address, port)
+ reader, writer = await asyncio.open_connection(address, port)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
@@ -40,27 +41,27 @@ class AsyncClient(object):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
sock.connect(os.path.basename(path))
finally:
- os.chdir(cwd)
- return await asyncio.open_unix_connection(sock=sock)
+ os.chdir(cwd)
+ reader, writer = await asyncio.open_unix_connection(sock=sock)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
async def setup_connection(self):
- s = '%s %s\n\n' % (self.proto_name, self.proto_version)
- self.writer.write(s.encode("utf-8"))
- await self.writer.drain()
+ # Send headers
+ await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
+ # End of headers
+ await self.socket.send("")
async def connect(self):
- if self.reader is None or self.writer is None:
- (self.reader, self.writer) = await self._connect_sock()
+ if self.socket is None:
+ self.socket = await self._connect_sock()
await self.setup_connection()
async def close(self):
- self.reader = None
-
- if self.writer is not None:
- self.writer.close()
- self.writer = None
+ if self.socket is not None:
+ await self.socket.close()
+ self.socket = None
async def _send_wrapper(self, proc):
count = 0
@@ -71,6 +72,7 @@ class AsyncClient(object):
except (
OSError,
ConnectionError,
+ ConnectionClosedError,
json.JSONDecodeError,
UnicodeDecodeError,
) as e:
@@ -82,49 +84,15 @@ class AsyncClient(object):
await self.close()
count += 1
- async def send_message(self, msg):
- async def get_line():
- try:
- line = await asyncio.wait_for(self.reader.readline(), self.timeout)
- except asyncio.TimeoutError:
- raise ConnectionError("Timed out waiting for server")
-
- if not line:
- raise ConnectionError("Connection closed")
-
- line = line.decode("utf-8")
-
- if not line.endswith("\n"):
- raise ConnectionError("Bad message %r" % (line))
-
- return line
-
+ async def invoke(self, msg):
async def proc():
- for c in chunkify(json.dumps(msg), self.max_chunk):
- self.writer.write(c.encode("utf-8"))
- await self.writer.drain()
-
- l = await get_line()
-
- m = json.loads(l)
- if m and "chunk-stream" in m:
- lines = []
- while True:
- l = (await get_line()).rstrip("\n")
- if not l:
- break
- lines.append(l)
-
- m = json.loads("".join(lines))
-
- return m
+ await self.socket.send_message(msg)
+ return await self.socket.recv_message()
return await self._send_wrapper(proc)
async def ping(self):
- return await self.send_message(
- {'ping': {}}
- )
+ return await self.invoke({"ping": {}})
class Client(object):
@@ -142,7 +110,7 @@ class Client(object):
# required (but harmless) with it.
asyncio.set_event_loop(self.loop)
- self._add_methods('connect_tcp', 'ping')
+ self._add_methods("connect_tcp", "ping")
@abc.abstractmethod
def _get_async_client(self):
diff --git a/lib/bb/asyncrpc/connection.py b/lib/bb/asyncrpc/connection.py
new file mode 100644
index 00000000..c4fd2475
--- /dev/null
+++ b/lib/bb/asyncrpc/connection.py
@@ -0,0 +1,95 @@
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+import asyncio
+import itertools
+import json
+from .exceptions import ClientError, ConnectionClosedError
+
+
+# The Python async server defaults to a 64K receive buffer, so we hardcode our
+# maximum chunk size. It would be better if the client and server reported to
+# each other what the maximum chunk sizes were, but that will slow down the
+# connection setup with a round trip delay so I'd rather not do that unless it
+# is necessary
+DEFAULT_MAX_CHUNK = 32 * 1024
+
+
+def chunkify(msg, max_chunk):
+ if len(msg) < max_chunk - 1:
+ yield "".join((msg, "\n"))
+ else:
+ yield "".join((json.dumps({"chunk-stream": None}), "\n"))
+
+ args = [iter(msg)] * (max_chunk - 1)
+ for m in map("".join, itertools.zip_longest(*args, fillvalue="")):
+ yield "".join(itertools.chain(m, "\n"))
+ yield "\n"
+
+
+class StreamConnection(object):
+ def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
+ self.reader = reader
+ self.writer = writer
+ self.timeout = timeout
+ self.max_chunk = max_chunk
+
+ @property
+ def address(self):
+ return self.writer.get_extra_info("peername")
+
+ async def send_message(self, msg):
+ for c in chunkify(json.dumps(msg), self.max_chunk):
+ self.writer.write(c.encode("utf-8"))
+ await self.writer.drain()
+
+ async def recv_message(self):
+ l = await self.recv()
+
+ m = json.loads(l)
+ if not m:
+ return m
+
+ if "chunk-stream" in m:
+ lines = []
+ while True:
+ l = await self.recv()
+ if not l:
+ break
+ lines.append(l)
+
+ m = json.loads("".join(lines))
+
+ return m
+
+ async def send(self, msg):
+ self.writer.write(("%s\n" % msg).encode("utf-8"))
+ await self.writer.drain()
+
+ async def recv(self):
+ if self.timeout < 0:
+ line = await self.reader.readline()
+ else:
+ try:
+ line = await asyncio.wait_for(self.reader.readline(), self.timeout)
+ except asyncio.TimeoutError:
+ raise ConnectionError("Timed out waiting for data")
+
+ if not line:
+ raise ConnectionClosedError("Connection closed")
+
+ line = line.decode("utf-8")
+
+ if not line.endswith("\n"):
+ raise ConnectionError("Bad message %r" % (line))
+
+ return line.rstrip()
+
+ async def close(self):
+ self.reader = None
+ if self.writer is not None:
+ self.writer.close()
+ self.writer = None
diff --git a/lib/bb/asyncrpc/exceptions.py b/lib/bb/asyncrpc/exceptions.py
new file mode 100644
index 00000000..a8942b4f
--- /dev/null
+++ b/lib/bb/asyncrpc/exceptions.py
@@ -0,0 +1,17 @@
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+
+class ClientError(Exception):
+ pass
+
+
+class ServerError(Exception):
+ pass
+
+
+class ConnectionClosedError(Exception):
+ pass
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index d2de4891..8d4da1e2 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -12,241 +12,242 @@ import signal
import socket
import sys
import multiprocessing
-from . import chunkify, DEFAULT_MAX_CHUNK
-
-
-class ClientError(Exception):
- pass
-
-
-class ServerError(Exception):
- pass
+from .connection import StreamConnection
+from .exceptions import ClientError, ServerError, ConnectionClosedError
class AsyncServerConnection(object):
- def __init__(self, reader, writer, proto_name, logger):
- self.reader = reader
- self.writer = writer
+ def __init__(self, socket, proto_name, logger):
+ self.socket = socket
self.proto_name = proto_name
- self.max_chunk = DEFAULT_MAX_CHUNK
self.handlers = {
- 'chunk-stream': self.handle_chunk,
- 'ping': self.handle_ping,
+ "ping": self.handle_ping,
}
self.logger = logger
+ async def close(self):
+ await self.socket.close()
+
async def process_requests(self):
try:
- self.addr = self.writer.get_extra_info('peername')
- self.logger.debug('Client %r connected' % (self.addr,))
+ self.logger.info("Client %r connected" % (self.socket.address,))
# Read protocol and version
- client_protocol = await self.reader.readline()
+ client_protocol = await self.socket.recv()
if not client_protocol:
return
- (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
+ (client_proto_name, client_proto_version) = client_protocol.split()
if client_proto_name != self.proto_name:
- self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name))
+ self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
return
- self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
+ self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
if not self.validate_proto_version():
- self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
+ self.logger.debug(
+ "Rejecting invalid protocol version %s" % (client_proto_version)
+ )
return
# Read headers. Currently, no headers are implemented, so look for
# an empty line to signal the end of the headers
while True:
- line = await self.reader.readline()
- if not line:
- return
-
- line = line.decode('utf-8').rstrip()
- if not line:
+ header = await self.socket.recv()
+ if not header:
break
# Handle messages
while True:
- d = await self.read_message()
+ d = await self.socket.recv_message()
if d is None:
break
- await self.dispatch_message(d)
- await self.writer.drain()
- except ClientError as e:
+ response = await self.dispatch_message(d)
+ await self.socket.send_message(response)
+ except ConnectionClosedError as e:
+ self.logger.info(str(e))
+ except (ClientError, ConnectionError) as e:
self.logger.error(str(e))
finally:
- self.writer.close()
+ await self.close()
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
- self.logger.debug('Handling %s' % k)
- await self.handlers[k](msg[k])
- return
+ self.logger.debug("Handling %s" % k)
+ return await self.handlers[k](msg[k])
raise ClientError("Unrecognized command %r" % msg)
- def write_message(self, msg):
- for c in chunkify(json.dumps(msg), self.max_chunk):
- self.writer.write(c.encode('utf-8'))
+ async def handle_ping(self, request):
+ return {"alive": True}
- async def read_message(self):
- l = await self.reader.readline()
- if not l:
- return None
- try:
- message = l.decode('utf-8')
+class StreamServer(object):
+ def __init__(self, handler, logger):
+ self.handler = handler
+ self.logger = logger
+ self.closed = False
- if not message.endswith('\n'):
- return None
+ async def handle_stream_client(self, reader, writer):
+ # writer.transport.set_write_buffer_limits(0)
+ socket = StreamConnection(reader, writer, -1)
+ if self.closed:
+ await socket.close()
+ return
+
+ await self.handler(socket)
+
+ async def stop(self):
+ self.closed = True
+
+
+class TCPStreamServer(StreamServer):
+ def __init__(self, host, port, handler, logger):
+ super().__init__(handler, logger)
+ self.host = host
+ self.port = port
+
+ def start(self, loop):
+ self.server = loop.run_until_complete(
+ asyncio.start_server(self.handle_stream_client, self.host, self.port)
+ )
+
+ for s in self.server.sockets:
+ self.logger.debug("Listening on %r" % (s.getsockname(),))
+ # Newer python does this automatically. Do it manually here for
+ # maximum compatibility
+ s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
+ s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
+
+ # Enable keep alives. This prevents broken client connections
+ # from persisting on the server for long periods of time.
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
+
+ name = self.server.sockets[0].getsockname()
+ if self.server.sockets[0].family == socket.AF_INET6:
+ self.address = "[%s]:%d" % (name[0], name[1])
+ else:
+ self.address = "%s:%d" % (name[0], name[1])
+
+ return [self.server.wait_closed()]
+
+ async def stop(self):
+ await super().stop()
+ self.server.close()
+
+ def cleanup(self):
+ pass
- return json.loads(message)
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- self.logger.error('Bad message from client: %r' % message)
- raise e
- async def handle_chunk(self, request):
- lines = []
- try:
- while True:
- l = await self.reader.readline()
- l = l.rstrip(b"\n").decode("utf-8")
- if not l:
- break
- lines.append(l)
+class UnixStreamServer(StreamServer):
+ def __init__(self, path, handler, logger):
+ super().__init__(handler, logger)
+ self.path = path
- msg = json.loads(''.join(lines))
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- self.logger.error('Bad message from client: %r' % lines)
- raise e
+ def start(self, loop):
+ cwd = os.getcwd()
+ try:
+ # Work around path length limits in AF_UNIX
+ os.chdir(os.path.dirname(self.path))
+ self.server = loop.run_until_complete(
+ asyncio.start_unix_server(
+ self.handle_stream_client, os.path.basename(self.path)
+ )
+ )
+ finally:
+ os.chdir(cwd)
- if 'chunk-stream' in msg:
- raise ClientError("Nested chunks are not allowed")
+ self.logger.debug("Listening on %r" % self.path)
+ self.address = "unix://%s" % os.path.abspath(self.path)
+ return [self.server.wait_closed()]
- await self.dispatch_message(msg)
+ async def stop(self):
+ await super().stop()
+ self.server.close()
- async def handle_ping(self, request):
- response = {'alive': True}
- self.write_message(response)
+ def cleanup(self):
+ os.unlink(self.path)
class AsyncServer(object):
def __init__(self, logger):
- self._cleanup_socket = None
self.logger = logger
- self.start = None
- self.address = None
self.loop = None
+ self.run_tasks = []
def start_tcp_server(self, host, port):
- def start_tcp():
- self.server = self.loop.run_until_complete(
- asyncio.start_server(self.handle_client, host, port)
- )
-
- for s in self.server.sockets:
- self.logger.debug('Listening on %r' % (s.getsockname(),))
- # Newer python does this automatically. Do it manually here for
- # maximum compatibility
- s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
- s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
-
- # Enable keep alives. This prevents broken client connections
- # from persisting on the server for long periods of time.
- s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
-
- name = self.server.sockets[0].getsockname()
- if self.server.sockets[0].family == socket.AF_INET6:
- self.address = "[%s]:%d" % (name[0], name[1])
- else:
- self.address = "%s:%d" % (name[0], name[1])
-
- self.start = start_tcp
+ self.server = TCPStreamServer(host, port, self._client_handler, self.logger)
def start_unix_server(self, path):
- def cleanup():
- os.unlink(path)
-
- def start_unix():
- cwd = os.getcwd()
- try:
- # Work around path length limits in AF_UNIX
- os.chdir(os.path.dirname(path))
- self.server = self.loop.run_until_complete(
- asyncio.start_unix_server(self.handle_client, os.path.basename(path))
- )
- finally:
- os.chdir(cwd)
-
- self.logger.debug('Listening on %r' % path)
-
- self._cleanup_socket = cleanup
- self.address = "unix://%s" % os.path.abspath(path)
-
- self.start = start_unix
-
- @abc.abstractmethod
- def accept_client(self, reader, writer):
- pass
+ self.server = UnixStreamServer(path, self._client_handler, self.logger)
- async def handle_client(self, reader, writer):
- # writer.transport.set_write_buffer_limits(0)
+ async def _client_handler(self, socket):
try:
- client = self.accept_client(reader, writer)
+ client = self.accept_client(socket)
await client.process_requests()
except Exception as e:
import traceback
- self.logger.error('Error from client: %s' % str(e), exc_info=True)
+
+ self.logger.error("Error from client: %s" % str(e), exc_info=True)
traceback.print_exc()
- writer.close()
- self.logger.debug('Client disconnected')
+ await socket.close()
+ self.logger.debug("Client disconnected")
- def run_loop_forever(self):
- try:
- self.loop.run_forever()
- except KeyboardInterrupt:
- pass
+ @abc.abstractmethod
+ def accept_client(self, socket):
+ pass
+
+ async def stop(self):
+ self.logger.debug("Stopping server")
+ await self.server.stop()
+
+ def start(self):
+ tasks = self.server.start(self.loop)
+ self.address = self.server.address
+ return tasks
def signal_handler(self):
self.logger.debug("Got exit signal")
- self.loop.stop()
+ self.loop.create_task(self.stop())
- def _serve_forever(self):
+ def _serve_forever(self, tasks):
try:
self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
+ self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
+ self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
- self.run_loop_forever()
- self.server.close()
+ self.loop.run_until_complete(asyncio.gather(*tasks))
- self.loop.run_until_complete(self.server.wait_closed())
- self.logger.debug('Server shutting down')
+ self.logger.debug("Server shutting down")
finally:
- if self._cleanup_socket is not None:
- self._cleanup_socket()
+ self.server.cleanup()
def serve_forever(self):
"""
Serve requests in the current process
"""
+ self._create_loop()
+ tasks = self.start()
+ self._serve_forever(tasks)
+ self.loop.close()
+
+ def _create_loop(self):
# Create loop and override any loop that may have existed in
# a parent process. It is possible that the usecases of
# serve_forever might be constrained enough to allow using
# get_event_loop here, but better safe than sorry for now.
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
- self.start()
- self._serve_forever()
def serve_as_process(self, *, prefunc=None, args=()):
"""
Serve requests in a child process
"""
+
def run(queue):
# Create loop and override any loop that may have existed
# in a parent process. Without doing this and instead
@@ -259,18 +260,19 @@ class AsyncServer(object):
# more general, though, as any potential use of asyncio in
# Cooker could create a loop that needs to replaced in this
# new process.
- self.loop = asyncio.new_event_loop()
- asyncio.set_event_loop(self.loop)
+ self._create_loop()
try:
- self.start()
+ self.address = None
+ tasks = self.start()
finally:
+ # Always put the server address to wake up the parent task
queue.put(self.address)
queue.close()
if prefunc is not None:
prefunc(self, *args)
- self._serve_forever()
+ self._serve_forever(tasks)
if sys.version_info >= (3, 6):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 9cb3fd57..3a401835 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -15,13 +15,6 @@ UNIX_PREFIX = "unix://"
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
-# The Python async server defaults to a 64K receive buffer, so we hardcode our
-# maximum chunk size. It would be better if the client and server reported to
-# each other what the maximum chunk sizes were, but that will slow down the
-# connection setup with a round trip delay so I'd rather not do that unless it
-# is necessary
-DEFAULT_MAX_CHUNK = 32 * 1024
-
UNIHASH_TABLE_DEFINITION = (
("method", "TEXT NOT NULL", "UNIQUE"),
("taskhash", "TEXT NOT NULL", "UNIQUE"),
@@ -102,20 +95,6 @@ def parse_address(addr):
return (ADDR_TYPE_TCP, (host, int(port)))
-def chunkify(msg, max_chunk):
- if len(msg) < max_chunk - 1:
- yield ''.join((msg, "\n"))
- else:
- yield ''.join((json.dumps({
- 'chunk-stream': None
- }), "\n"))
-
- args = [iter(msg)] * (max_chunk - 1)
- for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
- yield ''.join(itertools.chain(m, "\n"))
- yield "\n"
-
-
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
from . import server
db = setup_database(dbname, sync=sync)
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index f676d267..ebf1c361 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -28,24 +28,24 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
async def send_stream(self, msg):
async def proc():
- self.writer.write(("%s\n" % msg).encode("utf-8"))
- await self.writer.drain()
- l = await self.reader.readline()
- if not l:
- raise ConnectionError("Connection closed")
- return l.decode("utf-8").rstrip()
+ await self.socket.send(msg)
+ return await self.socket.recv()
return await self._send_wrapper(proc)
async def _set_mode(self, new_mode):
+ async def stream_to_normal():
+ await self.socket.send("END")
+ return await self.socket.recv_message()
+
if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
- r = await self.send_stream("END")
+ r = await self._send_wrapper(stream_to_normal)
if r != "ok":
- raise ConnectionError("Bad response from server %r" % r)
+ raise ConnectionError("Unable to transition to normal mode: Bad response from server %r" % r)
elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
- r = await self.send_message({"get-stream": None})
+ r = await self.invoke({"get-stream": None})
if r != "ok":
- raise ConnectionError("Bad response from server %r" % r)
+ raise ConnectionError("Unable to transition to stream mode: Bad response from server %r" % r)
elif new_mode != self.mode:
raise Exception(
"Undefined mode transition %r -> %r" % (self.mode, new_mode)
@@ -67,7 +67,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
m["method"] = method
m["outhash"] = outhash
m["unihash"] = unihash
- return await self.send_message({"report": m})
+ return await self.invoke({"report": m})
async def report_unihash_equiv(self, taskhash, method, unihash, extra={}):
await self._set_mode(self.MODE_NORMAL)
@@ -75,39 +75,39 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
m["taskhash"] = taskhash
m["method"] = method
m["unihash"] = unihash
- return await self.send_message({"report-equiv": m})
+ return await self.invoke({"report-equiv": m})
async def get_taskhash(self, method, taskhash, all_properties=False):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message(
+ return await self.invoke(
{"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
)
async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message(
+ return await self.invoke(
{"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method, "with_unihash": with_unihash}}
)
async def get_stats(self):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"get-stats": None})
+ return await self.invoke({"get-stats": None})
async def reset_stats(self):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"reset-stats": None})
+ return await self.invoke({"reset-stats": None})
async def backfill_wait(self):
await self._set_mode(self.MODE_NORMAL)
- return (await self.send_message({"backfill-wait": None}))["tasks"]
+ return (await self.invoke({"backfill-wait": None}))["tasks"]
async def remove(self, where):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"remove": {"where": where}})
+ return await self.invoke({"remove": {"where": where}})
async def clean_unused(self, max_age):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"clean-unused": {"max_age_seconds": max_age}})
+ return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})
class Client(bb.asyncrpc.Client):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 45bf476b..6d3a4751 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -165,8 +165,8 @@ class ServerCursor(object):
class ServerClient(bb.asyncrpc.AsyncServerConnection):
- def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
- super().__init__(reader, writer, 'OEHASHEQUIV', logger)
+ def __init__(self, socket, db, request_stats, backfill_queue, upstream, read_only):
+ super().__init__(socket, 'OEHASHEQUIV', logger)
self.db = db
self.request_stats = request_stats
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
@@ -209,12 +209,11 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if k in msg:
logger.debug('Handling %s' % k)
if 'stream' in k:
- await self.handlers[k](msg[k])
+ return await self.handlers[k](msg[k])
else:
with self.request_stats.start_sample() as self.request_sample, \
self.request_sample.measure():
- await self.handlers[k](msg[k])
- return
+ return await self.handlers[k](msg[k])
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
@@ -224,9 +223,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
fetch_all = request.get('all', False)
with closing(self.db.cursor()) as cursor:
- d = await self.get_unihash(cursor, method, taskhash, fetch_all)
-
- self.write_message(d)
+ return await self.get_unihash(cursor, method, taskhash, fetch_all)
async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
d = None
@@ -274,9 +271,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
with_unihash = request.get("with_unihash", True)
with closing(self.db.cursor()) as cursor:
- d = await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
-
- self.write_message(d)
+ return await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
async def get_outhash(self, cursor, method, outhash, taskhash, with_unihash=True):
d = None
@@ -334,14 +329,14 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
)
async def handle_get_stream(self, request):
- self.write_message('ok')
+ await self.socket.send_message("ok")
while True:
upstream = None
- l = await self.reader.readline()
+ l = await self.socket.recv()
if not l:
- return
+ break
try:
# This inner loop is very sensitive and must be as fast as
@@ -352,10 +347,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
request_measure = self.request_sample.measure()
request_measure.start()
- l = l.decode('utf-8').rstrip()
if l == 'END':
- self.writer.write('ok\n'.encode('utf-8'))
- return
+ break
(method, taskhash) = l.split()
#logger.debug('Looking up %s %s' % (method, taskhash))
@@ -366,29 +359,29 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
cursor.close()
if row is not None:
- msg = ('%s\n' % row['unihash']).encode('utf-8')
+ msg = row['unihash']
#logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
elif self.upstream_client is not None:
upstream = await self.upstream_client.get_unihash(method, taskhash)
if upstream:
- msg = ("%s\n" % upstream).encode("utf-8")
+ msg = upstream
else:
- msg = "\n".encode("utf-8")
+ msg = ""
else:
- msg = '\n'.encode('utf-8')
+ msg = ""
- self.writer.write(msg)
+ await self.socket.send(msg)
finally:
request_measure.end()
self.request_sample.end()
- await self.writer.drain()
-
# Post to the backfill queue after writing the result to minimize
# the turn around time on a request
if upstream is not None:
await self.backfill_queue.put((method, taskhash))
+ return "ok"
+
async def handle_report(self, data):
with closing(self.db.cursor()) as cursor:
outhash_data = {
@@ -468,7 +461,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
'unihash': unihash,
}
- self.write_message(d)
+ return d
async def handle_equivreport(self, data):
with closing(self.db.cursor()) as cursor:
@@ -491,30 +484,28 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
- self.write_message(d)
+ return d
async def handle_get_stats(self, request):
- d = {
+ return {
'requests': self.request_stats.todict(),
}
- self.write_message(d)
-
async def handle_reset_stats(self, request):
d = {
'requests': self.request_stats.todict(),
}
self.request_stats.reset()
- self.write_message(d)
+ return d
async def handle_backfill_wait(self, request):
d = {
'tasks': self.backfill_queue.qsize(),
}
await self.backfill_queue.join()
- self.write_message(d)
+ return d
async def handle_remove(self, request):
condition = request["where"]
@@ -541,7 +532,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
self.db.commit()
- self.write_message({"count": count})
+ return {"count": count}
async def handle_clean_unused(self, request):
max_age = request["max_age_seconds"]
@@ -558,7 +549,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
)
count = cursor.rowcount
- self.write_message({"count": count})
+ return {"count": count}
def query_equivalent(self, cursor, method, taskhash):
# This is part of the inner loop and must be as fast as possible
@@ -583,41 +574,33 @@ class Server(bb.asyncrpc.AsyncServer):
self.db = db
self.upstream = upstream
self.read_only = read_only
+ self.backfill_queue = None
- def accept_client(self, reader, writer):
- return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
+ def accept_client(self, socket):
+ return ServerClient(socket, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
- @contextmanager
- def _backfill_worker(self):
- async def backfill_worker_task():
- client = await create_async_client(self.upstream)
- try:
- while True:
- item = await self.backfill_queue.get()
- if item is None:
- self.backfill_queue.task_done()
- break
- method, taskhash = item
- await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ async def backfill_worker_task(self):
+ client = await create_async_client(self.upstream)
+ try:
+ while True:
+ item = await self.backfill_queue.get()
+ if item is None:
self.backfill_queue.task_done()
- finally:
- await client.close()
+ break
+ method, taskhash = item
+ await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ self.backfill_queue.task_done()
+ finally:
+ await client.close()
- async def join_worker(worker):
+ def start(self):
+ tasks = super().start()
+ if self.upstream:
+ self.backfill_queue = asyncio.Queue()
+ tasks += [self.backfill_worker_task()]
+ return tasks
+
+ async def stop(self):
+ if self.backfill_queue is not None:
await self.backfill_queue.put(None)
- await worker
-
- if self.upstream is not None:
- worker = asyncio.ensure_future(backfill_worker_task())
- try:
- yield
- finally:
- self.loop.run_until_complete(join_worker(worker))
- else:
- yield
-
- def run_loop_forever(self):
- self.backfill_queue = asyncio.Queue()
-
- with self._backfill_worker():
- super().run_loop_forever()
+ await super().stop()
diff --git a/lib/prserv/client.py b/lib/prserv/client.py
index 69ab7a4a..6b81356f 100644
--- a/lib/prserv/client.py
+++ b/lib/prserv/client.py
@@ -14,28 +14,28 @@ class PRAsyncClient(bb.asyncrpc.AsyncClient):
super().__init__('PRSERVICE', '1.0', logger)
async def getPR(self, version, pkgarch, checksum):
- response = await self.send_message(
+ response = await self.invoke(
{'get-pr': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum}}
)
if response:
return response['value']
async def importone(self, version, pkgarch, checksum, value):
- response = await self.send_message(
+ response = await self.invoke(
{'import-one': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum, 'value': value}}
)
if response:
return response['value']
async def export(self, version, pkgarch, checksum, colinfo):
- response = await self.send_message(
+ response = await self.invoke(
{'export': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum, 'colinfo': colinfo}}
)
if response:
return (response['metainfo'], response['datainfo'])
async def is_readonly(self):
- response = await self.send_message(
+ response = await self.invoke(
{'is-readonly': {}}
)
if response:
diff --git a/lib/prserv/serv.py b/lib/prserv/serv.py
index c686b206..ea793316 100644
--- a/lib/prserv/serv.py
+++ b/lib/prserv/serv.py
@@ -20,8 +20,8 @@ PIDPREFIX = "/tmp/PRServer_%s_%s.pid"
singleton = None
class PRServerClient(bb.asyncrpc.AsyncServerConnection):
- def __init__(self, reader, writer, table, read_only):
- super().__init__(reader, writer, 'PRSERVICE', logger)
+ def __init__(self, socket, table, read_only):
+ super().__init__(socket, 'PRSERVICE', logger)
self.handlers.update({
'get-pr': self.handle_get_pr,
'import-one': self.handle_import_one,
@@ -36,12 +36,12 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
async def dispatch_message(self, msg):
try:
- await super().dispatch_message(msg)
+ return await super().dispatch_message(msg)
except:
self.table.sync()
raise
-
- self.table.sync_if_dirty()
+ else:
+ self.table.sync_if_dirty()
async def handle_get_pr(self, request):
version = request['version']
@@ -57,7 +57,7 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
except sqlite3.Error as exc:
logger.error(str(exc))
- self.write_message(response)
+ return response
async def handle_import_one(self, request):
response = None
@@ -71,7 +71,7 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
if value is not None:
response = {'value': value}
- self.write_message(response)
+ return response
async def handle_export(self, request):
version = request['version']
@@ -85,12 +85,10 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
logger.error(str(exc))
metainfo = datainfo = None
- response = {'metainfo': metainfo, 'datainfo': datainfo}
- self.write_message(response)
+ return {'metainfo': metainfo, 'datainfo': datainfo}
async def handle_is_readonly(self, request):
- response = {'readonly': self.read_only}
- self.write_message(response)
+ return {'readonly': self.read_only}
class PRServer(bb.asyncrpc.AsyncServer):
def __init__(self, dbfile, read_only=False):
@@ -99,20 +97,23 @@ class PRServer(bb.asyncrpc.AsyncServer):
self.table = None
self.read_only = read_only
- def accept_client(self, reader, writer):
- return PRServerClient(reader, writer, self.table, self.read_only)
+ def accept_client(self, socket):
+ return PRServerClient(socket, self.table, self.read_only)
- def _serve_forever(self):
+ def start(self):
+ tasks = super().start()
self.db = prserv.db.PRData(self.dbfile, read_only=self.read_only)
self.table = self.db["PRMAIN"]
logger.info("Started PRServer with DBfile: %s, Address: %s, PID: %s" %
(self.dbfile, self.address, str(os.getpid())))
- super()._serve_forever()
+ return tasks
+ async def stop(self):
self.table.sync_if_dirty()
self.db.disconnect()
+ await super().stop()
def signal_handler(self):
super().signal_handler()
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 02/22] hashserv: Add websocket connection implementation
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 01/22] asyncrpc: Abstract sockets Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 03/22] asyncrpc: Add context manager API Joshua Watt
` (22 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds support to the hash equivalence client and server to communicate
over websockets. Since websockets are message orientated instead of
stream orientated, a new connection class is needed to handle them.
Note that websocket support does require the 3rd party websockets python
module be installed on the host, but it should not be required unless
websockets are actually being used.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/client.py | 11 +++++++-
lib/bb/asyncrpc/connection.py | 44 +++++++++++++++++++++++++++++
lib/bb/asyncrpc/serv.py | 53 ++++++++++++++++++++++++++++++++++-
lib/hashserv/__init__.py | 13 +++++++++
lib/hashserv/client.py | 1 +
lib/hashserv/tests.py | 17 +++++++++++
6 files changed, 137 insertions(+), 2 deletions(-)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index 7f33099b..802c07df 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -10,7 +10,7 @@ import json
import os
import socket
import sys
-from .connection import StreamConnection, DEFAULT_MAX_CHUNK
+from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
from .exceptions import ConnectionClosedError
@@ -47,6 +47,15 @@ class AsyncClient(object):
self._connect_sock = connect_sock
+ async def connect_websocket(self, uri):
+ import websockets
+
+ async def connect_sock():
+ websocket = await websockets.connect(uri, ping_interval=None)
+ return WebsocketConnection(websocket, self.timeout)
+
+ self._connect_sock = connect_sock
+
async def setup_connection(self):
# Send headers
await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
diff --git a/lib/bb/asyncrpc/connection.py b/lib/bb/asyncrpc/connection.py
index c4fd2475..a10628f7 100644
--- a/lib/bb/asyncrpc/connection.py
+++ b/lib/bb/asyncrpc/connection.py
@@ -93,3 +93,47 @@ class StreamConnection(object):
if self.writer is not None:
self.writer.close()
self.writer = None
+
+
+class WebsocketConnection(object):
+ def __init__(self, socket, timeout):
+ self.socket = socket
+ self.timeout = timeout
+
+ @property
+ def address(self):
+ return ":".join(str(s) for s in self.socket.remote_address)
+
+ async def send_message(self, msg):
+ await self.send(json.dumps(msg))
+
+ async def recv_message(self):
+ m = await self.recv()
+ return json.loads(m)
+
+ async def send(self, msg):
+ import websockets.exceptions
+
+ try:
+ await self.socket.send(msg)
+ except websockets.exceptions.ConnectionClosed:
+ raise ConnectionClosedError("Connection closed")
+
+ async def recv(self):
+ import websockets.exceptions
+
+ try:
+ if self.timeout < 0:
+ return await self.socket.recv()
+
+ try:
+ return await asyncio.wait_for(self.socket.recv(), self.timeout)
+ except asyncio.TimeoutError:
+ raise ConnectionError("Timed out waiting for data")
+ except websockets.exceptions.ConnectionClosed:
+ raise ConnectionClosedError("Connection closed")
+
+ async def close(self):
+ if self.socket is not None:
+ await self.socket.close()
+ self.socket = None
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index 8d4da1e2..3040ac91 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -12,7 +12,7 @@ import signal
import socket
import sys
import multiprocessing
-from .connection import StreamConnection
+from .connection import StreamConnection, WebsocketConnection
from .exceptions import ClientError, ServerError, ConnectionClosedError
@@ -172,6 +172,54 @@ class UnixStreamServer(StreamServer):
os.unlink(self.path)
+class WebsocketsServer(object):
+ def __init__(self, host, port, handler, logger):
+ self.host = host
+ self.port = port
+ self.handler = handler
+ self.logger = logger
+
+ def start(self, loop):
+ import websockets.server
+
+ self.server = loop.run_until_complete(
+ websockets.server.serve(
+ self.client_handler,
+ self.host,
+ self.port,
+ ping_interval=None,
+ )
+ )
+
+ for s in self.server.sockets:
+ self.logger.debug("Listening on %r" % (s.getsockname(),))
+
+ # Enable keep alives. This prevents broken client connections
+ # from persisting on the server for long periods of time.
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
+
+ name = self.server.sockets[0].getsockname()
+ if self.server.sockets[0].family == socket.AF_INET6:
+ self.address = "ws://[%s]:%d" % (name[0], name[1])
+ else:
+ self.address = "ws://%s:%d" % (name[0], name[1])
+
+ return [self.server.wait_closed()]
+
+ async def stop(self):
+ self.server.close()
+
+ def cleanup(self):
+ pass
+
+ async def client_handler(self, websocket):
+ socket = WebsocketConnection(websocket, -1)
+ await self.handler(socket)
+
+
class AsyncServer(object):
def __init__(self, logger):
self.logger = logger
@@ -184,6 +232,9 @@ class AsyncServer(object):
def start_unix_server(self, path):
self.server = UnixStreamServer(path, self._client_handler, self.logger)
+ def start_websocket_server(self, host, port):
+ self.server = WebsocketsServer(host, port, self._client_handler, self.logger)
+
async def _client_handler(self, socket):
try:
client = self.accept_client(socket)
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 3a401835..56b9c6bc 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -9,11 +9,15 @@ import re
import sqlite3
import itertools
import json
+from urllib.parse import urlparse
UNIX_PREFIX = "unix://"
+WS_PREFIX = "ws://"
+WSS_PREFIX = "wss://"
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
+ADDR_TYPE_WS = 2
UNIHASH_TABLE_DEFINITION = (
("method", "TEXT NOT NULL", "UNIQUE"),
@@ -84,6 +88,8 @@ def setup_database(database, sync=True):
def parse_address(addr):
if addr.startswith(UNIX_PREFIX):
return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
+ elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
+ return (ADDR_TYPE_WS, (addr,))
else:
m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
if m is not None:
@@ -103,6 +109,9 @@ def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
s.start_unix_server(*a)
+ elif typ == ADDR_TYPE_WS:
+ url = urlparse(a[0])
+ s.start_websocket_server(url.hostname, url.port)
else:
s.start_tcp_server(*a)
@@ -116,6 +125,8 @@ def create_client(addr):
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
c.connect_unix(*a)
+ elif typ == ADDR_TYPE_WS:
+ c.connect_websocket(*a)
else:
c.connect_tcp(*a)
@@ -128,6 +139,8 @@ async def create_async_client(addr):
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
await c.connect_unix(*a)
+ elif typ == ADDR_TYPE_WS:
+ await c.connect_websocket(*a)
else:
await c.connect_tcp(*a)
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index ebf1c361..ebb58f33 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -115,6 +115,7 @@ class Client(bb.asyncrpc.Client):
super().__init__()
self._add_methods(
"connect_tcp",
+ "connect_websocket",
"get_unihash",
"report_unihash",
"report_unihash_equiv",
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index f343c586..01ffd52c 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -483,3 +483,20 @@ class TestHashEquivalenceTCPServer(HashEquivalenceTestSetup, HashEquivalenceComm
# If IPv6 is enabled, it should be safe to use localhost directly, in general
# case it is more reliable to resolve the IP address explicitly.
return socket.gethostbyname("localhost") + ":0"
+
+
+class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
+ def setUp(self):
+ try:
+ import websockets
+ except ImportError as e:
+ self.skipTest(str(e))
+
+ super().setUp()
+
+ def get_server_addr(self, server_idx):
+ # Some hosts cause asyncio module to misbehave, when IPv6 is not enabled.
+ # If IPv6 is enabled, it should be safe to use localhost directly, in general
+ # case it is more reliable to resolve the IP address explicitly.
+ host = socket.gethostbyname("localhost")
+ return "ws://%s:0" % host
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 03/22] asyncrpc: Add context manager API
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 01/22] asyncrpc: Abstract sockets Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 02/22] hashserv: Add websocket connection implementation Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 04/22] hashserv: tests: Add external database tests Joshua Watt
` (21 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds a context manager API for the asyncrpc client class which allows
writing code that will automatically close the connection like so:
with hashserv.create_client(address) as client:
...
Rework the bitbake-hashclient tool and PR server to use this new API to
fix warnings about unclosed event loops when exiting
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 36 +++++++++++++++++-------------------
lib/bb/asyncrpc/client.py | 13 +++++++++++++
lib/prserv/serv.py | 6 +++---
3 files changed, 33 insertions(+), 22 deletions(-)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index 3f265e8f..a02a65b9 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -56,25 +56,24 @@ def main():
nonlocal missed_hashes
nonlocal max_time
- client = hashserv.create_client(args.address)
-
- for i in range(args.requests):
- taskhash = hashlib.sha256()
- taskhash.update(args.taskhash_seed.encode('utf-8'))
- taskhash.update(str(i).encode('utf-8'))
+ with hashserv.create_client(args.address) as client:
+ for i in range(args.requests):
+ taskhash = hashlib.sha256()
+ taskhash.update(args.taskhash_seed.encode('utf-8'))
+ taskhash.update(str(i).encode('utf-8'))
- start_time = time.perf_counter()
- l = client.get_unihash(METHOD, taskhash.hexdigest())
- elapsed = time.perf_counter() - start_time
+ start_time = time.perf_counter()
+ l = client.get_unihash(METHOD, taskhash.hexdigest())
+ elapsed = time.perf_counter() - start_time
- with lock:
- if l:
- found_hashes += 1
- else:
- missed_hashes += 1
+ with lock:
+ if l:
+ found_hashes += 1
+ else:
+ missed_hashes += 1
- max_time = max(elapsed, max_time)
- pbar.update()
+ max_time = max(elapsed, max_time)
+ pbar.update()
max_time = 0
found_hashes = 0
@@ -174,9 +173,8 @@ def main():
func = getattr(args, 'func', None)
if func:
- client = hashserv.create_client(args.address)
-
- return func(args, client)
+ with hashserv.create_client(args.address) as client:
+ return func(args, client)
return 0
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index 802c07df..009085c3 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -103,6 +103,12 @@ class AsyncClient(object):
async def ping(self):
return await self.invoke({"ping": {}})
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
class Client(object):
def __init__(self):
@@ -153,3 +159,10 @@ class Client(object):
if sys.version_info >= (3, 6):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+ return False
diff --git a/lib/prserv/serv.py b/lib/prserv/serv.py
index ea793316..6168eb18 100644
--- a/lib/prserv/serv.py
+++ b/lib/prserv/serv.py
@@ -345,9 +345,9 @@ def auto_shutdown():
def ping(host, port):
from . import client
- conn = client.PRClient()
- conn.connect_tcp(host, port)
- return conn.ping()
+ with client.PRClient() as conn:
+ conn.connect_tcp(host, port)
+ return conn.ping()
def connect(host, port):
from . import client
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 04/22] hashserv: tests: Add external database tests
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (2 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 03/22] asyncrpc: Add context manager API Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 05/22] asyncrpc: Prefix log messages with client info Joshua Watt
` (20 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds support for running the hash equivalence test suite against an
external hash equivalence implementation.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/tests.py | 54 +++++++++++++++++++++++++++++++++++--------
1 file changed, 44 insertions(+), 10 deletions(-)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 01ffd52c..4c98a280 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -51,13 +51,20 @@ class HashEquivalenceTestSetup(object):
server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
self.addCleanup(cleanup_server, server)
+ return server
+
+ def start_client(self, server_address):
def cleanup_client(client):
client.close()
- client = create_client(server.address)
+ client = create_client(server_address)
self.addCleanup(cleanup_client, client)
- return (client, server)
+ return client
+
+ def start_test_server(self):
+ server = self.start_server()
+ return server.address
def setUp(self):
if sys.version_info < (3, 5, 0):
@@ -66,7 +73,9 @@ class HashEquivalenceTestSetup(object):
self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv')
self.addCleanup(self.temp_dir.cleanup)
- (self.client, self.server) = self.start_server()
+ self.server_address = self.start_test_server()
+
+ self.client = self.start_client(self.server_address)
def assertClientGetHash(self, client, taskhash, unihash):
result = client.get_unihash(self.METHOD, taskhash)
@@ -206,7 +215,7 @@ class HashEquivalenceCommonTests(object):
def test_stress(self):
def query_server(failures):
- client = Client(self.server.address)
+ client = Client(self.server_address)
try:
for i in range(1000):
taskhash = hashlib.sha256()
@@ -245,8 +254,10 @@ class HashEquivalenceCommonTests(object):
# the side client. It also verifies that the results are pulled into
# the downstream database by checking that the downstream and side servers
# match after the downstream is done waiting for all backfill tasks
- (down_client, down_server) = self.start_server(upstream=self.server.address)
- (side_client, side_server) = self.start_server(dbpath=down_server.dbpath)
+ down_server = self.start_server(upstream=self.server_address)
+ down_client = self.start_client(down_server.address)
+ side_server = self.start_server(dbpath=down_server.dbpath)
+ side_client = self.start_client(side_server.address)
def check_hash(taskhash, unihash, old_sidehash):
nonlocal down_client
@@ -351,14 +362,18 @@ class HashEquivalenceCommonTests(object):
self.assertEqual(result['method'], self.METHOD)
def test_ro_server(self):
- (ro_client, ro_server) = self.start_server(dbpath=self.server.dbpath, read_only=True)
+ rw_server = self.start_server()
+ rw_client = self.start_client(rw_server.address)
+
+ ro_server = self.start_server(dbpath=rw_server.dbpath, read_only=True)
+ ro_client = self.start_client(ro_server.address)
# Report a hash via the read-write server
taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
- result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+ result = rw_client.report_unihash(taskhash, self.METHOD, outhash, unihash)
self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
# Check the hash via the read-only server
@@ -373,7 +388,7 @@ class HashEquivalenceCommonTests(object):
ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
# Ensure that the database was not modified
- self.assertClientGetHash(self.client, taskhash2, None)
+ self.assertClientGetHash(rw_client, taskhash2, None)
def test_slow_server_start(self):
@@ -393,7 +408,7 @@ class HashEquivalenceCommonTests(object):
old_signal = signal.signal(signal.SIGTERM, do_nothing)
self.addCleanup(signal.signal, signal.SIGTERM, old_signal)
- _, server = self.start_server(prefunc=prefunc)
+ server = self.start_server(prefunc=prefunc)
server.process.terminate()
time.sleep(30)
event.set()
@@ -500,3 +515,22 @@ class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalen
# case it is more reliable to resolve the IP address explicitly.
host = socket.gethostbyname("localhost")
return "ws://%s:0" % host
+
+
+class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
+ def start_test_server(self):
+ if 'BB_TEST_HASHSERV' not in os.environ:
+ self.skipTest('BB_TEST_HASHSERV not defined to test an external server')
+
+ return os.environ['BB_TEST_HASHSERV']
+
+ def start_server(self, *args, **kwargs):
+ self.skipTest('Cannot start local server when testing external servers')
+
+ def setUp(self):
+ super().setUp()
+ self.client.remove({"method": self.METHOD})
+
+ def tearDown(self):
+ self.client.remove({"method": self.METHOD})
+ super().tearDown()
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 05/22] asyncrpc: Prefix log messages with client info
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (3 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 04/22] hashserv: tests: Add external database tests Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 06/22] bitbake-hashserv: Allow arguments from environment Joshua Watt
` (19 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds a logging adaptor to the asyncrpc clients that prefixes log
messages with the client remote address to aid in debugging
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/serv.py | 20 +++++++++++++++++---
lib/hashserv/server.py | 10 +++++-----
2 files changed, 22 insertions(+), 8 deletions(-)
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index 3040ac91..7569ad6c 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -12,10 +12,20 @@ import signal
import socket
import sys
import multiprocessing
+import logging
from .connection import StreamConnection, WebsocketConnection
from .exceptions import ClientError, ServerError, ConnectionClosedError
+class ClientLoggerAdapter(logging.LoggerAdapter):
+ def __init__(self, logger, address):
+ super().__init__(logger)
+ self.address = address
+
+ def process(self, msg, kwargs):
+ return f"[Client {self.address}] {msg}", kwargs
+
+
class AsyncServerConnection(object):
def __init__(self, socket, proto_name, logger):
self.socket = socket
@@ -23,7 +33,7 @@ class AsyncServerConnection(object):
self.handlers = {
"ping": self.handle_ping,
}
- self.logger = logger
+ self.logger = ClientLoggerAdapter(logger, socket.address)
async def close(self):
await self.socket.close()
@@ -236,16 +246,20 @@ class AsyncServer(object):
self.server = WebsocketsServer(host, port, self._client_handler, self.logger)
async def _client_handler(self, socket):
+ address = socket.address
try:
client = self.accept_client(socket)
await client.process_requests()
except Exception as e:
import traceback
- self.logger.error("Error from client: %s" % str(e), exc_info=True)
+ self.logger.error(
+ "Error from client %s: %s" % (address, str(e)), exc_info=True
+ )
traceback.print_exc()
+ finally:
+ self.logger.debug("Client %s disconnected", address)
await socket.close()
- self.logger.debug("Client disconnected")
@abc.abstractmethod
def accept_client(self, socket):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 6d3a4751..928532c7 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -207,7 +207,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
- logger.debug('Handling %s' % k)
+ self.logger.debug('Handling %s' % k)
if 'stream' in k:
return await self.handlers[k](msg[k])
else:
@@ -351,7 +351,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
break
(method, taskhash) = l.split()
- #logger.debug('Looking up %s %s' % (method, taskhash))
+ #self.logger.debug('Looking up %s %s' % (method, taskhash))
cursor = self.db.cursor()
try:
row = self.query_equivalent(cursor, method, taskhash)
@@ -360,7 +360,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if row is not None:
msg = row['unihash']
- #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
+ #self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
elif self.upstream_client is not None:
upstream = await self.upstream_client.get_unihash(method, taskhash)
if upstream:
@@ -479,8 +479,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
row = self.query_equivalent(cursor, data['method'], data['taskhash'])
if row['unihash'] == data['unihash']:
- logger.info('Adding taskhash equivalence for %s with unihash %s',
- data['taskhash'], row['unihash'])
+ self.logger.info('Adding taskhash equivalence for %s with unihash %s',
+ data['taskhash'], row['unihash'])
d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 06/22] bitbake-hashserv: Allow arguments from environment
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (4 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 05/22] asyncrpc: Prefix log messages with client info Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 07/22] hashserv: Abstract database Joshua Watt
` (18 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Allows the arguments to the bitbake-hashserv command to be specified in
environment variables. This is a very common idiom when running services
in containers as it allows the arguments to be specified from different
sources as desired by the service administrator
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashserv | 80 +++++++++++++++++++++++++++++++++-----------
1 file changed, 60 insertions(+), 20 deletions(-)
diff --git a/bin/bitbake-hashserv b/bin/bitbake-hashserv
index 00af76b2..a916a90c 100755
--- a/bin/bitbake-hashserv
+++ b/bin/bitbake-hashserv
@@ -11,56 +11,96 @@ import logging
import argparse
import sqlite3
import warnings
+
warnings.simplefilter("default")
-sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib'))
+sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), "lib"))
import hashserv
VERSION = "1.0.0"
-DEFAULT_BIND = 'unix://./hashserve.sock'
+DEFAULT_BIND = "unix://./hashserve.sock"
def main():
- parser = argparse.ArgumentParser(description='Hash Equivalence Reference Server. Version=%s' % VERSION,
- epilog='''The bind address is the path to a unix domain socket if it is
- prefixed with "unix://". Otherwise, it is an IP address
- and port in form ADDRESS:PORT. To bind to all addresses, leave
- the ADDRESS empty, e.g. "--bind :8686". To bind to a specific
- IPv6 address, enclose the address in "[]", e.g.
- "--bind [::1]:8686"'''
- )
-
- parser.add_argument('-b', '--bind', default=DEFAULT_BIND, help='Bind address (default "%(default)s")')
- parser.add_argument('-d', '--database', default='./hashserv.db', help='Database file (default "%(default)s")')
- parser.add_argument('-l', '--log', default='WARNING', help='Set logging level')
- parser.add_argument('-u', '--upstream', help='Upstream hashserv to pull hashes from')
- parser.add_argument('-r', '--read-only', action='store_true', help='Disallow write operations from clients')
+ parser = argparse.ArgumentParser(
+ description="Hash Equivalence Reference Server. Version=%s" % VERSION,
+ formatter_class=argparse.RawTextHelpFormatter,
+ epilog="""
+The bind address may take one of the following formats:
+ unix://PATH - Bind to unix domain socket at PATH
+ ws://ADDRESS:PORT - Bind to websocket on ADDRESS:PORT
+ ADDRESS:PORT - Bind to raw TCP socket on ADDRESS:PORT
+
+To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
+"--bind ws://:8686". To bind to a specific IPv6 address, enclose the address in
+"[]", e.g. "--bind [::1]:8686" or "--bind ws://[::1]:8686"
+ """,
+ )
+
+ parser.add_argument(
+ "-b",
+ "--bind",
+ default=os.environ.get("HASHSERVER_BIND", DEFAULT_BIND),
+ help='Bind address (default $HASHSERVER_BIND, "%(default)s")',
+ )
+ parser.add_argument(
+ "-d",
+ "--database",
+ default=os.environ.get("HASHSERVER_DB", "./hashserv.db"),
+ help='Database file (default $HASHSERVER_DB, "%(default)s")',
+ )
+ parser.add_argument(
+ "-l",
+ "--log",
+ default=os.environ.get("HASHSERVER_LOG_LEVEL", "WARNING"),
+ help='Set logging level (default $HASHSERVER_LOG_LEVEL, "%(default)s")',
+ )
+ parser.add_argument(
+ "-u",
+ "--upstream",
+ default=os.environ.get("HASHSERVER_UPSTREAM", None),
+ help="Upstream hashserv to pull hashes from ($HASHSERVER_UPSTREAM)",
+ )
+ parser.add_argument(
+ "-r",
+ "--read-only",
+ action="store_true",
+ help="Disallow write operations from clients ($HASHSERVER_READ_ONLY)",
+ )
args = parser.parse_args()
- logger = logging.getLogger('hashserv')
+ logger = logging.getLogger("hashserv")
level = getattr(logging, args.log.upper(), None)
if not isinstance(level, int):
- raise ValueError('Invalid log level: %s' % args.log)
+ raise ValueError("Invalid log level: %s" % args.log)
logger.setLevel(level)
console = logging.StreamHandler()
console.setLevel(level)
logger.addHandler(console)
- server = hashserv.create_server(args.bind, args.database, upstream=args.upstream, read_only=args.read_only)
+ read_only = (os.environ.get("HASHSERVER_READ_ONLY", "0") == "1") or args.read_only
+
+ server = hashserv.create_server(
+ args.bind,
+ args.database,
+ upstream=args.upstream,
+ read_only=read_only,
+ )
server.serve_forever()
return 0
-if __name__ == '__main__':
+if __name__ == "__main__":
try:
ret = main()
except Exception:
ret = 1
import traceback
+
traceback.print_exc()
sys.exit(ret)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 07/22] hashserv: Abstract database
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (5 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 06/22] bitbake-hashserv: Allow arguments from environment Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 08/22] hashserv: Add SQLalchemy backend Joshua Watt
` (17 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Abstracts the way the database backend is accessed by the hash
equivalence server to make it possible to use other backends.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/__init__.py | 90 ++-----
lib/hashserv/server.py | 491 +++++++++++++--------------------------
lib/hashserv/sqlite.py | 259 +++++++++++++++++++++
3 files changed, 439 insertions(+), 401 deletions(-)
create mode 100644 lib/hashserv/sqlite.py
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 56b9c6bc..90d8cff1 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -6,7 +6,6 @@
import asyncio
from contextlib import closing
import re
-import sqlite3
import itertools
import json
from urllib.parse import urlparse
@@ -19,92 +18,34 @@ ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2
-UNIHASH_TABLE_DEFINITION = (
- ("method", "TEXT NOT NULL", "UNIQUE"),
- ("taskhash", "TEXT NOT NULL", "UNIQUE"),
- ("unihash", "TEXT NOT NULL", ""),
-)
-
-UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
-
-OUTHASH_TABLE_DEFINITION = (
- ("method", "TEXT NOT NULL", "UNIQUE"),
- ("taskhash", "TEXT NOT NULL", "UNIQUE"),
- ("outhash", "TEXT NOT NULL", "UNIQUE"),
- ("created", "DATETIME", ""),
-
- # Optional fields
- ("owner", "TEXT", ""),
- ("PN", "TEXT", ""),
- ("PV", "TEXT", ""),
- ("PR", "TEXT", ""),
- ("task", "TEXT", ""),
- ("outhash_siginfo", "TEXT", ""),
-)
-
-OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
-
-def _make_table(cursor, name, definition):
- cursor.execute('''
- CREATE TABLE IF NOT EXISTS {name} (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- {fields}
- UNIQUE({unique})
- )
- '''.format(
- name=name,
- fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
- unique=", ".join(name for name, _, flags in definition if "UNIQUE" in flags)
- ))
-
-
-def setup_database(database, sync=True):
- db = sqlite3.connect(database)
- db.row_factory = sqlite3.Row
-
- with closing(db.cursor()) as cursor:
- _make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
- _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
-
- cursor.execute('PRAGMA journal_mode = WAL')
- cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))
-
- # Drop old indexes
- cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
- cursor.execute('DROP INDEX IF EXISTS outhash_lookup')
- cursor.execute('DROP INDEX IF EXISTS taskhash_lookup_v2')
- cursor.execute('DROP INDEX IF EXISTS outhash_lookup_v2')
-
- # TODO: Upgrade from tasks_v2?
- cursor.execute('DROP TABLE IF EXISTS tasks_v2')
-
- # Create new indexes
- cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)')
- cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)')
-
- return db
-
def parse_address(addr):
if addr.startswith(UNIX_PREFIX):
- return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
+ return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
return (ADDR_TYPE_WS, (addr,))
else:
- m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
+ m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
if m is not None:
- host = m.group('host')
- port = m.group('port')
+ host = m.group("host")
+ port = m.group("port")
else:
- host, port = addr.split(':')
+ host, port = addr.split(":")
return (ADDR_TYPE_TCP, (host, int(port)))
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
+ def sqlite_engine():
+ from .sqlite import DatabaseEngine
+
+ return DatabaseEngine(dbname, sync)
+
from . import server
- db = setup_database(dbname, sync=sync)
- s = server.Server(db, upstream=upstream, read_only=read_only)
+
+ db_engine = sqlite_engine()
+
+ s = server.Server(db_engine, upstream=upstream, read_only=read_only)
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
@@ -120,6 +61,7 @@ def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
def create_client(addr):
from . import client
+
c = client.Client()
(typ, a) = parse_address(addr)
@@ -132,8 +74,10 @@ def create_client(addr):
return c
+
async def create_async_client(addr):
from . import client
+
c = client.AsyncClient()
(typ, a) = parse_address(addr)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 928532c7..12255cc2 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -3,18 +3,16 @@
# SPDX-License-Identifier: GPL-2.0-only
#
-from contextlib import closing, contextmanager
from datetime import datetime, timedelta
-import enum
import asyncio
import logging
import math
import time
-from . import create_async_client, UNIHASH_TABLE_COLUMNS, OUTHASH_TABLE_COLUMNS
+from . import create_async_client
import bb.asyncrpc
-logger = logging.getLogger('hashserv.server')
+logger = logging.getLogger("hashserv.server")
class Measurement(object):
@@ -104,229 +102,136 @@ class Stats(object):
return math.sqrt(self.s / (self.num - 1))
def todict(self):
- return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
-
-
-@enum.unique
-class Resolve(enum.Enum):
- FAIL = enum.auto()
- IGNORE = enum.auto()
- REPLACE = enum.auto()
-
-
-def insert_table(cursor, table, data, on_conflict):
- resolve = {
- Resolve.FAIL: "",
- Resolve.IGNORE: " OR IGNORE",
- Resolve.REPLACE: " OR REPLACE",
- }[on_conflict]
-
- keys = sorted(data.keys())
- query = 'INSERT{resolve} INTO {table} ({fields}) VALUES({values})'.format(
- resolve=resolve,
- table=table,
- fields=", ".join(keys),
- values=", ".join(":" + k for k in keys),
- )
- prevrowid = cursor.lastrowid
- cursor.execute(query, data)
- logging.debug(
- "Inserting %r into %s, %s",
- data,
- table,
- on_conflict
- )
- return (cursor.lastrowid, cursor.lastrowid != prevrowid)
-
-def insert_unihash(cursor, data, on_conflict):
- return insert_table(cursor, "unihashes_v2", data, on_conflict)
-
-def insert_outhash(cursor, data, on_conflict):
- return insert_table(cursor, "outhashes_v2", data, on_conflict)
-
-async def copy_unihash_from_upstream(client, db, method, taskhash):
- d = await client.get_taskhash(method, taskhash)
- if d is not None:
- with closing(db.cursor()) as cursor:
- insert_unihash(
- cursor,
- {k: v for k, v in d.items() if k in UNIHASH_TABLE_COLUMNS},
- Resolve.IGNORE,
- )
- db.commit()
- return d
-
-
-class ServerCursor(object):
- def __init__(self, db, cursor, upstream):
- self.db = db
- self.cursor = cursor
- self.upstream = upstream
+ return {
+ k: getattr(self, k)
+ for k in ("num", "total_time", "max_time", "average", "stdev")
+ }
class ServerClient(bb.asyncrpc.AsyncServerConnection):
- def __init__(self, socket, db, request_stats, backfill_queue, upstream, read_only):
- super().__init__(socket, 'OEHASHEQUIV', logger)
- self.db = db
+ def __init__(
+ self,
+ socket,
+ db_engine,
+ request_stats,
+ backfill_queue,
+ upstream,
+ read_only,
+ ):
+ super().__init__(socket, "OEHASHEQUIV", logger)
+ self.db_engine = db_engine
self.request_stats = request_stats
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
self.backfill_queue = backfill_queue
self.upstream = upstream
- self.handlers.update({
- 'get': self.handle_get,
- 'get-outhash': self.handle_get_outhash,
- 'get-stream': self.handle_get_stream,
- 'get-stats': self.handle_get_stats,
- })
+ self.handlers.update(
+ {
+ "get": self.handle_get,
+ "get-outhash": self.handle_get_outhash,
+ "get-stream": self.handle_get_stream,
+ "get-stats": self.handle_get_stats,
+ }
+ )
if not read_only:
- self.handlers.update({
- 'report': self.handle_report,
- 'report-equiv': self.handle_equivreport,
- 'reset-stats': self.handle_reset_stats,
- 'backfill-wait': self.handle_backfill_wait,
- 'remove': self.handle_remove,
- 'clean-unused': self.handle_clean_unused,
- })
+ self.handlers.update(
+ {
+ "report": self.handle_report,
+ "report-equiv": self.handle_equivreport,
+ "reset-stats": self.handle_reset_stats,
+ "backfill-wait": self.handle_backfill_wait,
+ "remove": self.handle_remove,
+ "clean-unused": self.handle_clean_unused,
+ }
+ )
def validate_proto_version(self):
- return (self.proto_version > (1, 0) and self.proto_version <= (1, 1))
+ return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
async def process_requests(self):
- if self.upstream is not None:
- self.upstream_client = await create_async_client(self.upstream)
- else:
- self.upstream_client = None
-
- await super().process_requests()
+ async with self.db_engine.connect(self.logger) as db:
+ self.db = db
+ if self.upstream is not None:
+ self.upstream_client = await create_async_client(self.upstream)
+ else:
+ self.upstream_client = None
- if self.upstream_client is not None:
- await self.upstream_client.close()
+ try:
+ await super().process_requests()
+ finally:
+ if self.upstream_client is not None:
+ await self.upstream_client.close()
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
- self.logger.debug('Handling %s' % k)
- if 'stream' in k:
+ self.logger.debug("Handling %s" % k)
+ if "stream" in k:
return await self.handlers[k](msg[k])
else:
- with self.request_stats.start_sample() as self.request_sample, \
- self.request_sample.measure():
+ with self.request_stats.start_sample() as self.request_sample, self.request_sample.measure():
return await self.handlers[k](msg[k])
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
async def handle_get(self, request):
- method = request['method']
- taskhash = request['taskhash']
- fetch_all = request.get('all', False)
+ method = request["method"]
+ taskhash = request["taskhash"]
+ fetch_all = request.get("all", False)
- with closing(self.db.cursor()) as cursor:
- return await self.get_unihash(cursor, method, taskhash, fetch_all)
+ return await self.get_unihash(method, taskhash, fetch_all)
- async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
+ async def get_unihash(self, method, taskhash, fetch_all=False):
d = None
if fetch_all:
- cursor.execute(
- '''
- SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
- INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
- WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- ''',
- {
- 'method': method,
- 'taskhash': taskhash,
- }
-
- )
- row = cursor.fetchone()
-
+ row = await self.db.get_unihash_by_taskhash_full(method, taskhash)
if row is not None:
d = {k: row[k] for k in row.keys()}
elif self.upstream_client is not None:
d = await self.upstream_client.get_taskhash(method, taskhash, True)
- self.update_unified(cursor, d)
- self.db.commit()
+ await self.update_unified(d)
else:
- row = self.query_equivalent(cursor, method, taskhash)
+ row = await self.db.get_equivalent(method, taskhash)
if row is not None:
d = {k: row[k] for k in row.keys()}
elif self.upstream_client is not None:
d = await self.upstream_client.get_taskhash(method, taskhash)
- d = {k: v for k, v in d.items() if k in UNIHASH_TABLE_COLUMNS}
- insert_unihash(cursor, d, Resolve.IGNORE)
- self.db.commit()
+ await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
return d
async def handle_get_outhash(self, request):
- method = request['method']
- outhash = request['outhash']
- taskhash = request['taskhash']
+ method = request["method"]
+ outhash = request["outhash"]
+ taskhash = request["taskhash"]
with_unihash = request.get("with_unihash", True)
- with closing(self.db.cursor()) as cursor:
- return await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
+ return await self.get_outhash(method, outhash, taskhash, with_unihash)
- async def get_outhash(self, cursor, method, outhash, taskhash, with_unihash=True):
+ async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
d = None
if with_unihash:
- cursor.execute(
- '''
- SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
- INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
- WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- ''',
- {
- 'method': method,
- 'outhash': outhash,
- }
- )
+ row = await self.db.get_unihash_by_outhash(method, outhash)
else:
- cursor.execute(
- """
- SELECT * FROM outhashes_v2
- WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- """,
- {
- 'method': method,
- 'outhash': outhash,
- }
- )
- row = cursor.fetchone()
+ row = await self.db.get_outhash(method, outhash)
if row is not None:
d = {k: row[k] for k in row.keys()}
elif self.upstream_client is not None:
d = await self.upstream_client.get_outhash(method, outhash, taskhash)
- self.update_unified(cursor, d)
- self.db.commit()
+ await self.update_unified(d)
return d
- def update_unified(self, cursor, data):
+ async def update_unified(self, data):
if data is None:
return
- insert_unihash(
- cursor,
- {k: v for k, v in data.items() if k in UNIHASH_TABLE_COLUMNS},
- Resolve.IGNORE
- )
- insert_outhash(
- cursor,
- {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS},
- Resolve.IGNORE
- )
+ await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+ await self.db.insert_outhash(data)
async def handle_get_stream(self, request):
await self.socket.send_message("ok")
@@ -347,20 +252,16 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
request_measure = self.request_sample.measure()
request_measure.start()
- if l == 'END':
+ if l == "END":
break
(method, taskhash) = l.split()
- #self.logger.debug('Looking up %s %s' % (method, taskhash))
- cursor = self.db.cursor()
- try:
- row = self.query_equivalent(cursor, method, taskhash)
- finally:
- cursor.close()
+ # self.logger.debug('Looking up %s %s' % (method, taskhash))
+ row = await self.db.get_equivalent(method, taskhash)
if row is not None:
- msg = row['unihash']
- #self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
+ msg = row["unihash"]
+ # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
elif self.upstream_client is not None:
upstream = await self.upstream_client.get_unihash(method, taskhash)
if upstream:
@@ -383,118 +284,81 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return "ok"
async def handle_report(self, data):
- with closing(self.db.cursor()) as cursor:
- outhash_data = {
- 'method': data['method'],
- 'outhash': data['outhash'],
- 'taskhash': data['taskhash'],
- 'created': datetime.now()
- }
+ outhash_data = {
+ "method": data["method"],
+ "outhash": data["outhash"],
+ "taskhash": data["taskhash"],
+ "created": datetime.now(),
+ }
- for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
- if k in data:
- outhash_data[k] = data[k]
-
- # Insert the new entry, unless it already exists
- (rowid, inserted) = insert_outhash(cursor, outhash_data, Resolve.IGNORE)
-
- if inserted:
- # If this row is new, check if it is equivalent to another
- # output hash
- cursor.execute(
- '''
- SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2
- INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
- -- Select any matching output hash except the one we just inserted
- WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
- -- Pick the oldest hash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- ''',
- {
- 'method': data['method'],
- 'outhash': data['outhash'],
- 'taskhash': data['taskhash'],
- }
- )
- row = cursor.fetchone()
+ for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"):
+ if k in data:
+ outhash_data[k] = data[k]
- if row is not None:
- # A matching output hash was found. Set our taskhash to the
- # same unihash since they are equivalent
- unihash = row['unihash']
- resolve = Resolve.IGNORE
- else:
- # No matching output hash was found. This is probably the
- # first outhash to be added.
- unihash = data['unihash']
- resolve = Resolve.IGNORE
-
- # Query upstream to see if it has a unihash we can use
- if self.upstream_client is not None:
- upstream_data = await self.upstream_client.get_outhash(data['method'], data['outhash'], data['taskhash'])
- if upstream_data is not None:
- unihash = upstream_data['unihash']
-
-
- insert_unihash(
- cursor,
- {
- 'method': data['method'],
- 'taskhash': data['taskhash'],
- 'unihash': unihash,
- },
- resolve
- )
-
- unihash_data = await self.get_unihash(cursor, data['method'], data['taskhash'])
- if unihash_data is not None:
- unihash = unihash_data['unihash']
- else:
- unihash = data['unihash']
-
- self.db.commit()
+ # Insert the new entry, unless it already exists
+ if await self.db.insert_outhash(outhash_data):
+ # If this row is new, check if it is equivalent to another
+ # output hash
+ row = await self.db.get_equivalent_for_outhash(
+ data["method"], data["outhash"], data["taskhash"]
+ )
- d = {
- 'taskhash': data['taskhash'],
- 'method': data['method'],
- 'unihash': unihash,
- }
+ if row is not None:
+ # A matching output hash was found. Set our taskhash to the
+ # same unihash since they are equivalent
+ unihash = row["unihash"]
+ else:
+ # No matching output hash was found. This is probably the
+ # first outhash to be added.
+ unihash = data["unihash"]
+
+ # Query upstream to see if it has a unihash we can use
+ if self.upstream_client is not None:
+ upstream_data = await self.upstream_client.get_outhash(
+ data["method"], data["outhash"], data["taskhash"]
+ )
+ if upstream_data is not None:
+ unihash = upstream_data["unihash"]
+
+ await self.db.insert_unihash(data["method"], data["taskhash"], unihash)
+
+ unihash_data = await self.get_unihash(data["method"], data["taskhash"])
+ if unihash_data is not None:
+ unihash = unihash_data["unihash"]
+ else:
+ unihash = data["unihash"]
- return d
+ return {
+ "taskhash": data["taskhash"],
+ "method": data["method"],
+ "unihash": unihash,
+ }
async def handle_equivreport(self, data):
- with closing(self.db.cursor()) as cursor:
- insert_data = {
- 'method': data['method'],
- 'taskhash': data['taskhash'],
- 'unihash': data['unihash'],
- }
- insert_unihash(cursor, insert_data, Resolve.IGNORE)
- self.db.commit()
-
- # Fetch the unihash that will be reported for the taskhash. If the
- # unihash matches, it means this row was inserted (or the mapping
- # was already valid)
- row = self.query_equivalent(cursor, data['method'], data['taskhash'])
-
- if row['unihash'] == data['unihash']:
- self.logger.info('Adding taskhash equivalence for %s with unihash %s',
- data['taskhash'], row['unihash'])
-
- d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
-
- return d
+ await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+
+ # Fetch the unihash that will be reported for the taskhash. If the
+ # unihash matches, it means this row was inserted (or the mapping
+ # was already valid)
+ row = await self.db.get_equivalent(data["method"], data["taskhash"])
+
+ if row["unihash"] == data["unihash"]:
+ self.logger.info(
+ "Adding taskhash equivalence for %s with unihash %s",
+ data["taskhash"],
+ row["unihash"],
+ )
+ return {k: row[k] for k in ("taskhash", "method", "unihash")}
async def handle_get_stats(self, request):
return {
- 'requests': self.request_stats.todict(),
+ "requests": self.request_stats.todict(),
}
async def handle_reset_stats(self, request):
d = {
- 'requests': self.request_stats.todict(),
+ "requests": self.request_stats.todict(),
}
self.request_stats.reset()
@@ -502,7 +366,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
async def handle_backfill_wait(self, request):
d = {
- 'tasks': self.backfill_queue.qsize(),
+ "tasks": self.backfill_queue.qsize(),
}
await self.backfill_queue.join()
return d
@@ -512,92 +376,63 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if not isinstance(condition, dict):
raise TypeError("Bad condition type %s" % type(condition))
- def do_remove(columns, table_name, cursor):
- nonlocal condition
- where = {}
- for c in columns:
- if c in condition and condition[c] is not None:
- where[c] = condition[c]
-
- if where:
- query = ('DELETE FROM %s WHERE ' % table_name) + ' AND '.join("%s=:%s" % (k, k) for k in where.keys())
- cursor.execute(query, where)
- return cursor.rowcount
-
- return 0
-
- count = 0
- with closing(self.db.cursor()) as cursor:
- count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
- count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
- self.db.commit()
-
- return {"count": count}
+ return {"count": await self.db.remove(condition)}
async def handle_clean_unused(self, request):
max_age = request["max_age_seconds"]
- with closing(self.db.cursor()) as cursor:
- cursor.execute(
- """
- DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
- SELECT unihashes_v2.id FROM unihashes_v2 WHERE unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash LIMIT 1
- )
- """,
- {
- "oldest": datetime.now() - timedelta(seconds=-max_age)
- }
- )
- count = cursor.rowcount
-
- return {"count": count}
-
- def query_equivalent(self, cursor, method, taskhash):
- # This is part of the inner loop and must be as fast as possible
- cursor.execute(
- 'SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash',
- {
- 'method': method,
- 'taskhash': taskhash,
- }
- )
- return cursor.fetchone()
+ oldest = datetime.now() - timedelta(seconds=-max_age)
+ return {"count": await self.db.clean_unused(oldest)}
class Server(bb.asyncrpc.AsyncServer):
- def __init__(self, db, upstream=None, read_only=False):
+ def __init__(self, db_engine, upstream=None, read_only=False):
if upstream and read_only:
- raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server")
+ raise bb.asyncrpc.ServerError(
+ "Read-only hashserv cannot pull from an upstream server"
+ )
super().__init__(logger)
self.request_stats = Stats()
- self.db = db
+ self.db_engine = db_engine
self.upstream = upstream
self.read_only = read_only
self.backfill_queue = None
def accept_client(self, socket):
- return ServerClient(socket, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
+ return ServerClient(
+ socket,
+ self.db_engine,
+ self.request_stats,
+ self.backfill_queue,
+ self.upstream,
+ self.read_only,
+ )
async def backfill_worker_task(self):
- client = await create_async_client(self.upstream)
- try:
+ async with await create_async_client(
+ self.upstream
+ ) as client, self.db_engine.connect(logger) as db:
while True:
item = await self.backfill_queue.get()
if item is None:
self.backfill_queue.task_done()
break
+
method, taskhash = item
- await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ d = await client.get_taskhash(method, taskhash)
+ if d is not None:
+ await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
self.backfill_queue.task_done()
- finally:
- await client.close()
def start(self):
tasks = super().start()
if self.upstream:
self.backfill_queue = asyncio.Queue()
tasks += [self.backfill_worker_task()]
+
+ self.loop.run_until_complete(self.db_engine.create())
+
return tasks
async def stop(self):
diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py
new file mode 100644
index 00000000..6809c537
--- /dev/null
+++ b/lib/hashserv/sqlite.py
@@ -0,0 +1,259 @@
+#! /usr/bin/env python3
+#
+# Copyright (C) 2023 Garmin Ltd.
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+import sqlite3
+import logging
+from contextlib import closing
+
+logger = logging.getLogger("hashserv.sqlite")
+
+UNIHASH_TABLE_DEFINITION = (
+ ("method", "TEXT NOT NULL", "UNIQUE"),
+ ("taskhash", "TEXT NOT NULL", "UNIQUE"),
+ ("unihash", "TEXT NOT NULL", ""),
+)
+
+UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
+
+OUTHASH_TABLE_DEFINITION = (
+ ("method", "TEXT NOT NULL", "UNIQUE"),
+ ("taskhash", "TEXT NOT NULL", "UNIQUE"),
+ ("outhash", "TEXT NOT NULL", "UNIQUE"),
+ ("created", "DATETIME", ""),
+ # Optional fields
+ ("owner", "TEXT", ""),
+ ("PN", "TEXT", ""),
+ ("PV", "TEXT", ""),
+ ("PR", "TEXT", ""),
+ ("task", "TEXT", ""),
+ ("outhash_siginfo", "TEXT", ""),
+)
+
+OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
+
+
+def _make_table(cursor, name, definition):
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS {name} (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ {fields}
+ UNIQUE({unique})
+ )
+ """.format(
+ name=name,
+ fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
+ unique=", ".join(
+ name for name, _, flags in definition if "UNIQUE" in flags
+ ),
+ )
+ )
+
+
+class DatabaseEngine(object):
+ def __init__(self, dbname, sync):
+ self.dbname = dbname
+ self.logger = logger
+ self.sync = sync
+
+ async def create(self):
+ db = sqlite3.connect(self.dbname)
+ db.row_factory = sqlite3.Row
+
+ with closing(db.cursor()) as cursor:
+ _make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
+ _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
+
+ cursor.execute("PRAGMA journal_mode = WAL")
+ cursor.execute(
+ "PRAGMA synchronous = %s" % ("NORMAL" if self.sync else "OFF")
+ )
+
+ # Drop old indexes
+ cursor.execute("DROP INDEX IF EXISTS taskhash_lookup")
+ cursor.execute("DROP INDEX IF EXISTS outhash_lookup")
+ cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v2")
+ cursor.execute("DROP INDEX IF EXISTS outhash_lookup_v2")
+
+ # TODO: Upgrade from tasks_v2?
+ cursor.execute("DROP TABLE IF EXISTS tasks_v2")
+
+ # Create new indexes
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)"
+ )
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)"
+ )
+
+ def connect(self, logger):
+ return Database(logger, self.dbname)
+
+
+class Database(object):
+ def __init__(self, logger, dbname, sync=True):
+ self.dbname = dbname
+ self.logger = logger
+
+ self.db = sqlite3.connect(self.dbname)
+ self.db.row_factory = sqlite3.Row
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
+ async def close(self):
+ self.db.close()
+
+ async def get_unihash_by_taskhash_full(self, method, taskhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
+ INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
+ WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "taskhash": taskhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_unihash_by_outhash(self, method, outhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
+ INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
+ WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "outhash": outhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_outhash(self, method, outhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT * FROM outhashes_v2
+ WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "outhash": outhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_equivalent_for_outhash(self, method, outhash, taskhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2
+ INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
+ -- Select any matching output hash except the one we just inserted
+ WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
+ -- Pick the oldest hash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "outhash": outhash,
+ "taskhash": taskhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_equivalent(self, method, taskhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ "SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash",
+ {
+ "method": method,
+ "taskhash": taskhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def remove(self, condition):
+ def do_remove(columns, table_name, cursor):
+ where = {}
+ for c in columns:
+ if c in condition and condition[c] is not None:
+ where[c] = condition[c]
+
+ if where:
+ query = ("DELETE FROM %s WHERE " % table_name) + " AND ".join(
+ "%s=:%s" % (k, k) for k in where.keys()
+ )
+ cursor.execute(query, where)
+ return cursor.rowcount
+
+ return 0
+
+ count = 0
+ with closing(self.db.cursor()) as cursor:
+ count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
+ count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
+ self.db.commit()
+
+ return count
+
+ async def clean_unused(self, oldest):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
+ SELECT unihashes_v2.id FROM unihashes_v2 WHERE unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash LIMIT 1
+ )
+ """,
+ {
+ "oldest": oldest,
+ },
+ )
+ return cursor.rowcount
+
+ async def insert_unihash(self, method, taskhash, unihash):
+ with closing(self.db.cursor()) as cursor:
+ prevrowid = cursor.lastrowid
+ cursor.execute(
+ """
+ INSERT OR IGNORE INTO unihashes_v2 (method, taskhash, unihash) VALUES(:method, :taskhash, :unihash)
+ """,
+ {
+ "method": method,
+ "taskhash": taskhash,
+ "unihash": unihash,
+ },
+ )
+ self.db.commit()
+ return cursor.lastrowid != prevrowid
+
+ async def insert_outhash(self, data):
+ data = {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS}
+ keys = sorted(data.keys())
+ query = "INSERT OR IGNORE INTO outhashes_v2 ({fields}) VALUES({values})".format(
+ fields=", ".join(keys),
+ values=", ".join(":" + k for k in keys),
+ )
+ with closing(self.db.cursor()) as cursor:
+ prevrowid = cursor.lastrowid
+ cursor.execute(query, data)
+ self.db.commit()
+ return cursor.lastrowid != prevrowid
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 08/22] hashserv: Add SQLalchemy backend
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (6 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 07/22] hashserv: Abstract database Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 09/22] hashserv: Implement read-only version of "report" RPC Joshua Watt
` (16 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an SQLAlchemy backend to the server. While this database backend is
slower than the more direct sqlite backend, it easily supports just
about any SQL server, which is useful for large scale deployments.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashserv | 12 ++
lib/bb/asyncrpc/connection.py | 11 +-
lib/hashserv/__init__.py | 21 ++-
lib/hashserv/sqlalchemy.py | 304 ++++++++++++++++++++++++++++++++++
lib/hashserv/tests.py | 19 ++-
5 files changed, 362 insertions(+), 5 deletions(-)
create mode 100644 lib/hashserv/sqlalchemy.py
diff --git a/bin/bitbake-hashserv b/bin/bitbake-hashserv
index a916a90c..59b8b07f 100755
--- a/bin/bitbake-hashserv
+++ b/bin/bitbake-hashserv
@@ -69,6 +69,16 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
action="store_true",
help="Disallow write operations from clients ($HASHSERVER_READ_ONLY)",
)
+ parser.add_argument(
+ "--db-username",
+ default=os.environ.get("HASHSERVER_DB_USERNAME", None),
+ help="Database username ($HASHSERVER_DB_USERNAME)",
+ )
+ parser.add_argument(
+ "--db-password",
+ default=os.environ.get("HASHSERVER_DB_PASSWORD", None),
+ help="Database password ($HASHSERVER_DB_PASSWORD)",
+ )
args = parser.parse_args()
@@ -90,6 +100,8 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
args.database,
upstream=args.upstream,
read_only=read_only,
+ db_username=args.db_username,
+ db_password=args.db_password,
)
server.serve_forever()
return 0
diff --git a/lib/bb/asyncrpc/connection.py b/lib/bb/asyncrpc/connection.py
index a10628f7..7f0cf6ba 100644
--- a/lib/bb/asyncrpc/connection.py
+++ b/lib/bb/asyncrpc/connection.py
@@ -7,6 +7,7 @@
import asyncio
import itertools
import json
+from datetime import datetime
from .exceptions import ClientError, ConnectionClosedError
@@ -30,6 +31,12 @@ def chunkify(msg, max_chunk):
yield "\n"
+def json_serialize(obj):
+ if isinstance(obj, datetime):
+ return obj.isoformat()
+ raise TypeError("Type %s not serializeable" % type(obj))
+
+
class StreamConnection(object):
def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
self.reader = reader
@@ -42,7 +49,7 @@ class StreamConnection(object):
return self.writer.get_extra_info("peername")
async def send_message(self, msg):
- for c in chunkify(json.dumps(msg), self.max_chunk):
+ for c in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
self.writer.write(c.encode("utf-8"))
await self.writer.drain()
@@ -105,7 +112,7 @@ class WebsocketConnection(object):
return ":".join(str(s) for s in self.socket.remote_address)
async def send_message(self, msg):
- await self.send(json.dumps(msg))
+ await self.send(json.dumps(msg, default=json_serialize))
async def recv_message(self):
m = await self.recv()
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 90d8cff1..9a8ee4e8 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -35,15 +35,32 @@ def parse_address(addr):
return (ADDR_TYPE_TCP, (host, int(port)))
-def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
+def create_server(
+ addr,
+ dbname,
+ *,
+ sync=True,
+ upstream=None,
+ read_only=False,
+ db_username=None,
+ db_password=None
+):
def sqlite_engine():
from .sqlite import DatabaseEngine
return DatabaseEngine(dbname, sync)
+ def sqlalchemy_engine():
+ from .sqlalchemy import DatabaseEngine
+
+ return DatabaseEngine(dbname, db_username, db_password)
+
from . import server
- db_engine = sqlite_engine()
+ if "://" in dbname:
+ db_engine = sqlalchemy_engine()
+ else:
+ db_engine = sqlite_engine()
s = server.Server(db_engine, upstream=upstream, read_only=read_only)
diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py
new file mode 100644
index 00000000..3216621f
--- /dev/null
+++ b/lib/hashserv/sqlalchemy.py
@@ -0,0 +1,304 @@
+#! /usr/bin/env python3
+#
+# Copyright (C) 2023 Garmin Ltd.
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+import logging
+from datetime import datetime
+
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.pool import NullPool
+from sqlalchemy import (
+ MetaData,
+ Column,
+ Table,
+ Text,
+ Integer,
+ UniqueConstraint,
+ DateTime,
+ Index,
+ select,
+ insert,
+ exists,
+ literal,
+ and_,
+ delete,
+)
+import sqlalchemy.engine
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.exc import IntegrityError
+
+logger = logging.getLogger("hashserv.sqlalchemy")
+
+Base = declarative_base()
+
+
+class UnihashesV2(Base):
+ __tablename__ = "unihashes_v2"
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ method = Column(Text, nullable=False)
+ taskhash = Column(Text, nullable=False)
+ unihash = Column(Text, nullable=False)
+
+ __table_args__ = (
+ UniqueConstraint("method", "taskhash"),
+ Index("taskhash_lookup_v3", "method", "taskhash"),
+ )
+
+
+class OuthashesV2(Base):
+ __tablename__ = "outhashes_v2"
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ method = Column(Text, nullable=False)
+ taskhash = Column(Text, nullable=False)
+ outhash = Column(Text, nullable=False)
+ created = Column(DateTime)
+ owner = Column(Text)
+ PN = Column(Text)
+ PV = Column(Text)
+ PR = Column(Text)
+ task = Column(Text)
+ outhash_siginfo = Column(Text)
+
+ __table_args__ = (
+ UniqueConstraint("method", "taskhash", "outhash"),
+ Index("outhash_lookup_v3", "method", "outhash"),
+ )
+
+
+class DatabaseEngine(object):
+ def __init__(self, url, username=None, password=None):
+ self.logger = logger
+ self.url = sqlalchemy.engine.make_url(url)
+
+ if username is not None:
+ self.url = self.url.set(username=username)
+
+ if password is not None:
+ self.url = self.url.set(password=password)
+
+ async def create(self):
+ self.logger.info("Using database %s", self.url)
+ self.engine = create_async_engine(self.url, poolclass=NullPool)
+
+ async with self.engine.begin() as conn:
+ # Create tables
+ logger.info("Creating tables...")
+ await conn.run_sync(Base.metadata.create_all)
+
+ def connect(self, logger):
+ return Database(self.engine, logger)
+
+
+def map_row(row):
+ if row is None:
+ return None
+ return dict(**row._mapping)
+
+
+class Database(object):
+ def __init__(self, engine, logger):
+ self.engine = engine
+ self.db = None
+ self.logger = logger
+
+ async def __aenter__(self):
+ self.db = await self.engine.connect()
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
+ async def close(self):
+ await self.db.close()
+ self.db = None
+
+ async def get_unihash_by_taskhash_full(self, method, taskhash):
+ statement = (
+ select(
+ OuthashesV2,
+ UnihashesV2.unihash.label("unihash"),
+ )
+ .join(
+ UnihashesV2,
+ and_(
+ UnihashesV2.method == OuthashesV2.method,
+ UnihashesV2.taskhash == OuthashesV2.taskhash,
+ ),
+ )
+ .where(
+ OuthashesV2.method == method,
+ OuthashesV2.taskhash == taskhash,
+ )
+ .order_by(
+ OuthashesV2.created.asc(),
+ )
+ .limit(1)
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def get_unihash_by_outhash(self, method, outhash):
+ statement = (
+ select(OuthashesV2, UnihashesV2.unihash.label("unihash"))
+ .join(
+ UnihashesV2,
+ and_(
+ UnihashesV2.method == OuthashesV2.method,
+ UnihashesV2.taskhash == OuthashesV2.taskhash,
+ ),
+ )
+ .where(
+ OuthashesV2.method == method,
+ OuthashesV2.outhash == outhash,
+ )
+ .order_by(
+ OuthashesV2.created.asc(),
+ )
+ .limit(1)
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def get_outhash(self, method, outhash):
+ statement = (
+ select(OuthashesV2)
+ .where(
+ OuthashesV2.method == method,
+ OuthashesV2.outhash == outhash,
+ )
+ .order_by(
+ OuthashesV2.created.asc(),
+ )
+ .limit(1)
+ )
+
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def get_equivalent_for_outhash(self, method, outhash, taskhash):
+ statement = (
+ select(
+ OuthashesV2.taskhash.label("taskhash"),
+ UnihashesV2.unihash.label("unihash"),
+ )
+ .join(
+ UnihashesV2,
+ and_(
+ UnihashesV2.method == OuthashesV2.method,
+ UnihashesV2.taskhash == OuthashesV2.taskhash,
+ ),
+ )
+ .where(
+ OuthashesV2.method == method,
+ OuthashesV2.outhash == outhash,
+ OuthashesV2.taskhash != taskhash,
+ )
+ .order_by(
+ OuthashesV2.created.asc(),
+ )
+ .limit(1)
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def get_equivalent(self, method, taskhash):
+ statement = select(
+ UnihashesV2.unihash,
+ UnihashesV2.method,
+ UnihashesV2.taskhash,
+ ).where(
+ UnihashesV2.method == method,
+ UnihashesV2.taskhash == taskhash,
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def remove(self, condition):
+ async def do_remove(table):
+ where = {}
+ for c in table.__table__.columns:
+ if c.key in condition and condition[c.key] is not None:
+ where[c] = condition[c.key]
+
+ if where:
+ statement = delete(table).where(*[(k == v) for k, v in where.items()])
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount
+
+ return 0
+
+ count = 0
+ count += await do_remove(UnihashesV2)
+ count += await do_remove(OuthashesV2)
+
+ return count
+
+ async def clean_unused(self, oldest):
+ statement = delete(OuthashesV2).where(
+ OuthashesV2.created < oldest,
+ ~(
+ select(UnihashesV2.id)
+ .where(
+ UnihashesV2.method == OuthashesV2.method,
+ UnihashesV2.taskhash == OuthashesV2.taskhash,
+ )
+ .limit(1)
+ .exists()
+ ),
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount
+
+ async def insert_unihash(self, method, taskhash, unihash):
+ statement = insert(UnihashesV2).values(
+ method=method,
+ taskhash=taskhash,
+ unihash=unihash,
+ )
+ self.logger.debug("%s", statement)
+ try:
+ async with self.db.begin():
+ await self.db.execute(statement)
+ return True
+ except IntegrityError:
+ logger.debug(
+ "%s, %s, %s already in unihash database", method, taskhash, unihash
+ )
+ return False
+
+ async def insert_outhash(self, data):
+ outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)
+
+ data = {k: v for k, v in data.items() if k in outhash_columns}
+
+ if "created" in data and not isinstance(data["created"], datetime):
+ data["created"] = datetime.fromisoformat(data["created"])
+
+ statement = insert(OuthashesV2).values(**data)
+ self.logger.debug("%s", statement)
+ try:
+ async with self.db.begin():
+ await self.db.execute(statement)
+ return True
+ except IntegrityError:
+ logger.debug(
+ "%s, %s already in outhash database", data["method"], data["outhash"]
+ )
+ return False
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 4c98a280..268b2700 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -33,7 +33,7 @@ class HashEquivalenceTestSetup(object):
def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc):
self.server_index += 1
if dbpath is None:
- dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
+ dbpath = self.make_dbpath()
def cleanup_server(server):
if server.process.exitcode is not None:
@@ -53,6 +53,9 @@ class HashEquivalenceTestSetup(object):
return server
+ def make_dbpath(self):
+ return os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
+
def start_client(self, server_address):
def cleanup_client(client):
client.close()
@@ -517,6 +520,20 @@ class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalen
return "ws://%s:0" % host
+class TestHashEquivalenceWebsocketsSQLAlchemyServer(TestHashEquivalenceWebsocketServer):
+ def setUp(self):
+ try:
+ import sqlalchemy
+ import aiosqlite
+ except ImportError as e:
+ self.skipTest(str(e))
+
+ super().setUp()
+
+ def make_dbpath(self):
+ return "sqlite+aiosqlite:///%s" % os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
+
+
class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def start_test_server(self):
if 'BB_TEST_HASHSERV' not in os.environ:
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 09/22] hashserv: Implement read-only version of "report" RPC
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (7 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 08/22] hashserv: Add SQLalchemy backend Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 10/22] asyncrpc: Add InvokeError Joshua Watt
` (15 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
When the hash equivalence server is in read-only mode, it should still
return a unihash for a given "report" call if there is one.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/server.py | 25 ++++++++++++++++++++++++-
lib/hashserv/tests.py | 4 ++--
2 files changed, 26 insertions(+), 3 deletions(-)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 12255cc2..2e6977cb 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -124,6 +124,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
self.backfill_queue = backfill_queue
self.upstream = upstream
+ self.read_only = read_only
self.handlers.update(
{
@@ -131,13 +132,15 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"get-outhash": self.handle_get_outhash,
"get-stream": self.handle_get_stream,
"get-stats": self.handle_get_stats,
+ # Not always read-only, but internally checks if the server is
+ # read-only
+ "report": self.handle_report,
}
)
if not read_only:
self.handlers.update(
{
- "report": self.handle_report,
"report-equiv": self.handle_equivreport,
"reset-stats": self.handle_reset_stats,
"backfill-wait": self.handle_backfill_wait,
@@ -283,7 +286,27 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return "ok"
+ async def report_readonly(self, data):
+ method = data["method"]
+ outhash = data["outhash"]
+ taskhash = data["taskhash"]
+
+ info = await self.get_outhash(method, outhash, taskhash)
+ if info:
+ unihash = info["unihash"]
+ else:
+ unihash = data["unihash"]
+
+ return {
+ "taskhash": taskhash,
+ "method": method,
+ "unihash": unihash,
+ }
+
async def handle_report(self, data):
+ if self.read_only:
+ return await self.report_readonly(data)
+
outhash_data = {
"method": data["method"],
"outhash": data["outhash"],
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 268b2700..e9a361dc 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -387,8 +387,8 @@ class HashEquivalenceCommonTests(object):
outhash2 = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44'
unihash2 = '90e9bc1d1f094c51824adca7f8ea79a048d68824'
- with self.assertRaises(ConnectionError):
- ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
+ result = ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
+ self.assertEqual(result['unihash'], unihash2)
# Ensure that the database was not modified
self.assertClientGetHash(rw_client, taskhash2, None)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 10/22] asyncrpc: Add InvokeError
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (8 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 09/22] hashserv: Implement read-only version of "report" RPC Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 11/22] asyncrpc: client: Prevent double closing of loop Joshua Watt
` (14 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds support for Invocation Errors (that is, errors raised by the actual
RPC call instead of at the protocol level) to propagate across the
connection. If a server RPC call raises an InvokeError, it will be sent
across the connection and then raised on the client side also. The
connection is still terminated on this error.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/__init__.py | 1 +
lib/bb/asyncrpc/client.py | 10 ++++++++--
lib/bb/asyncrpc/exceptions.py | 4 ++++
lib/bb/asyncrpc/serv.py | 11 +++++++++--
4 files changed, 22 insertions(+), 4 deletions(-)
diff --git a/lib/bb/asyncrpc/__init__.py b/lib/bb/asyncrpc/__init__.py
index 9f677eac..a4371643 100644
--- a/lib/bb/asyncrpc/__init__.py
+++ b/lib/bb/asyncrpc/__init__.py
@@ -12,4 +12,5 @@ from .exceptions import (
ClientError,
ServerError,
ConnectionClosedError,
+ InvokeError,
)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index 009085c3..d27dbf71 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -11,7 +11,7 @@ import os
import socket
import sys
from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
-from .exceptions import ConnectionClosedError
+from .exceptions import ConnectionClosedError, InvokeError
class AsyncClient(object):
@@ -93,12 +93,18 @@ class AsyncClient(object):
await self.close()
count += 1
+ def check_invoke_error(self, msg):
+ if isinstance(msg, dict) and "invoke-error" in msg:
+ raise InvokeError(msg["invoke-error"]["message"])
+
async def invoke(self, msg):
async def proc():
await self.socket.send_message(msg)
return await self.socket.recv_message()
- return await self._send_wrapper(proc)
+ result = await self._send_wrapper(proc)
+ self.check_invoke_error(result)
+ return result
async def ping(self):
return await self.invoke({"ping": {}})
diff --git a/lib/bb/asyncrpc/exceptions.py b/lib/bb/asyncrpc/exceptions.py
index a8942b4f..ae1043a3 100644
--- a/lib/bb/asyncrpc/exceptions.py
+++ b/lib/bb/asyncrpc/exceptions.py
@@ -9,6 +9,10 @@ class ClientError(Exception):
pass
+class InvokeError(Exception):
+ pass
+
+
class ServerError(Exception):
pass
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index 7569ad6c..1a7f9a88 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -14,7 +14,7 @@ import sys
import multiprocessing
import logging
from .connection import StreamConnection, WebsocketConnection
-from .exceptions import ClientError, ServerError, ConnectionClosedError
+from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
class ClientLoggerAdapter(logging.LoggerAdapter):
@@ -71,7 +71,14 @@ class AsyncServerConnection(object):
d = await self.socket.recv_message()
if d is None:
break
- response = await self.dispatch_message(d)
+ try:
+ response = await self.dispatch_message(d)
+ except InvokeError as e:
+ await self.socket.send_message(
+ {"invoke-error": {"message": str(e)}}
+ )
+ break
+
await self.socket.send_message(response)
except ConnectionClosedError as e:
self.logger.info(str(e))
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 11/22] asyncrpc: client: Prevent double closing of loop
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (9 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 10/22] asyncrpc: Add InvokeError Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 12/22] asyncrpc: client: Add disconnect API Joshua Watt
` (13 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Invalidate the loop in the client close() call so that it is not closed
twice (which is an error in the asyncio code)
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/client.py | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index d27dbf71..628b90ee 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -161,10 +161,12 @@ class Client(object):
self.client.max_chunk = value
def close(self):
- self.loop.run_until_complete(self.client.close())
- if sys.version_info >= (3, 6):
- self.loop.run_until_complete(self.loop.shutdown_asyncgens())
- self.loop.close()
+ if self.loop:
+ self.loop.run_until_complete(self.client.close())
+ if sys.version_info >= (3, 6):
+ self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+ self.loop.close()
+ self.loop = None
def __enter__(self):
return self
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 12/22] asyncrpc: client: Add disconnect API
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (10 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 11/22] asyncrpc: client: Prevent double closing of loop Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 13/22] hashserv: Add user permissions Joshua Watt
` (12 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an API to explicitly disconnect a client. This can be useful for
testing the auto-reconnect behavior of clients
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/client.py | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index 628b90ee..0d7cd857 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -67,11 +67,14 @@ class AsyncClient(object):
self.socket = await self._connect_sock()
await self.setup_connection()
- async def close(self):
+ async def disconnect(self):
if self.socket is not None:
await self.socket.close()
self.socket = None
+ async def close(self):
+ await self.disconnect()
+
async def _send_wrapper(self, proc):
count = 0
while True:
@@ -160,6 +163,9 @@ class Client(object):
def max_chunk(self, value):
self.client.max_chunk = value
+ def disconnect(self):
+ self.loop.run_until_complete(self.client.close())
+
def close(self):
if self.loop:
self.loop.run_until_complete(self.client.close())
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 13/22] hashserv: Add user permissions
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (11 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 12/22] asyncrpc: client: Add disconnect API Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 14/22] hashserv: Add become-user API Joshua Watt
` (11 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds support for the hashserver to have per-user permissions. User
management is done via a new "auth" RPC API where a client can
authenticate itself with the server using a randomly generated token.
The user can then be given permissions to read, report, manage the
database, or manage other users.
In addition to explicit user logins, the server supports anonymous users,
which is what all users start as before they make the "auth" RPC call.
Anonymous users can be assigned a set of permissions by the server,
making it unnecessary for users to authenticate to use the server. The
set of Anonymous permissions defines the default behavior of the server,
for example if set to "@read", Anonymous users are unable to report
equivalent hashes without authenticating. Similarly, setting the Anonymous
permissions to "@none" would require authentication for users to perform
any action.
User creation and management is entirely manual (although
bitbake-hashclient is very useful as a front end). There are many
different mechanisms that could be implemented to allow user
self-registration (e.g. OAuth, LDAP, etc.), and implementing these is
outside the scope of the server. Instead, it is recommended to
implement a registration service that validates users against the
necessary service, then adds them as a user in the hash equivalence
server.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 84 ++++++++-
bin/bitbake-hashserv | 37 ++++
lib/hashserv/__init__.py | 69 ++++---
lib/hashserv/client.py | 62 ++++++-
lib/hashserv/server.py | 357 ++++++++++++++++++++++++++++++++++++-
lib/hashserv/sqlalchemy.py | 111 +++++++++++-
lib/hashserv/sqlite.py | 105 +++++++++++
lib/hashserv/tests.py | 276 +++++++++++++++++++++++++++-
8 files changed, 1054 insertions(+), 47 deletions(-)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index a02a65b9..328c15cd 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -14,6 +14,7 @@ import sys
import threading
import time
import warnings
+import netrc
warnings.simplefilter("default")
try:
@@ -36,10 +37,18 @@ except ImportError:
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib'))
import hashserv
+import bb.asyncrpc
DEFAULT_ADDRESS = 'unix://./hashserve.sock'
METHOD = 'stress.test.method'
+def print_user(u):
+ print(f"Username: {u['username']}")
+ if "permissions" in u:
+ print("Permissions: " + " ".join(u["permissions"]))
+ if "token" in u:
+ print(f"Token: {u['token']}")
+
def main():
def handle_stats(args, client):
@@ -125,9 +134,39 @@ def main():
print("Removed %d rows" % (result["count"]))
return 0
+ def handle_refresh_token(args, client):
+ r = client.refresh_token(args.username)
+ print_user(r)
+
+ def handle_set_user_permissions(args, client):
+ r = client.set_user_perms(args.username, args.permissions)
+ print_user(r)
+
+ def handle_get_user(args, client):
+ r = client.get_user(args.username)
+ print_user(r)
+
+ def handle_get_all_users(args, client):
+ users = client.get_all_users()
+ print("{username:20}| {permissions}".format(username="Username", permissions="Permissions"))
+ print(("-" * 20) + "+" + ("-" * 20))
+ for u in users:
+ print("{username:20}| {permissions}".format(username=u["username"], permissions=" ".join(u["permissions"])))
+
+ def handle_new_user(args, client):
+ r = client.new_user(args.username, args.permissions)
+ print_user(r)
+
+ def handle_delete_user(args, client):
+ r = client.delete_user(args.username)
+ print_user(r)
+
parser = argparse.ArgumentParser(description='Hash Equivalence Client')
parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
parser.add_argument('--log', default='WARNING', help='Set logging level')
+ parser.add_argument('--login', '-l', metavar="USERNAME", help="Authenticate as USERNAME")
+ parser.add_argument('--password', '-p', metavar="TOKEN", help="Authenticate using token TOKEN")
+ parser.add_argument('--no-netrc', '-n', action="store_false", dest="netrc", help="Do not use .netrc")
subparsers = parser.add_subparsers()
@@ -158,6 +197,31 @@ def main():
clean_unused_parser.add_argument("max_age", metavar="SECONDS", type=int, help="Remove unused entries older than SECONDS old")
clean_unused_parser.set_defaults(func=handle_clean_unused)
+ refresh_token_parser = subparsers.add_parser('refresh-token', help="Refresh auth token")
+ refresh_token_parser.add_argument("--username", "-u", help="Refresh the token for another user (if authorized)")
+ refresh_token_parser.set_defaults(func=handle_refresh_token)
+
+ set_user_perms_parser = subparsers.add_parser('set-user-perms', help="Set new permissions for user")
+ set_user_perms_parser.add_argument("--username", "-u", help="Username", required=True)
+ set_user_perms_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions")
+ set_user_perms_parser.set_defaults(func=handle_set_user_permissions)
+
+ get_user_parser = subparsers.add_parser('get-user', help="Get user")
+ get_user_parser.add_argument("--username", "-u", help="Username")
+ get_user_parser.set_defaults(func=handle_get_user)
+
+ get_all_users_parser = subparsers.add_parser('get-all-users', help="List all users")
+ get_all_users_parser.set_defaults(func=handle_get_all_users)
+
+ new_user_parser = subparsers.add_parser('new-user', help="Create new user")
+ new_user_parser.add_argument("--username", "-u", help="Username", required=True)
+ new_user_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions")
+ new_user_parser.set_defaults(func=handle_new_user)
+
+ delete_user_parser = subparsers.add_parser('delete-user', help="Delete user")
+ delete_user_parser.add_argument("--username", "-u", help="Username", required=True)
+ delete_user_parser.set_defaults(func=handle_delete_user)
+
args = parser.parse_args()
logger = logging.getLogger('hashserv')
@@ -171,10 +235,26 @@ def main():
console.setLevel(level)
logger.addHandler(console)
+ login = args.login
+ password = args.password
+
+ if login is None and args.netrc:
+ try:
+ n = netrc.netrc()
+ auth = n.authenticators(args.address)
+ if auth is not None:
+ login, _, password = auth
+ except FileNotFoundError:
+ pass
+
func = getattr(args, 'func', None)
if func:
- with hashserv.create_client(args.address) as client:
- return func(args, client)
+ try:
+ with hashserv.create_client(args.address, login, password) as client:
+ return func(args, client)
+ except bb.asyncrpc.InvokeError as e:
+ print(f"ERROR: {e}")
+ return 1
return 0
diff --git a/bin/bitbake-hashserv b/bin/bitbake-hashserv
index 59b8b07f..1085d058 100755
--- a/bin/bitbake-hashserv
+++ b/bin/bitbake-hashserv
@@ -17,6 +17,7 @@ warnings.simplefilter("default")
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), "lib"))
import hashserv
+from hashserv.server import DEFAULT_ANON_PERMS
VERSION = "1.0.0"
@@ -36,6 +37,22 @@ The bind address may take one of the following formats:
To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
"--bind ws://:8686". To bind to a specific IPv6 address, enclose the address in
"[]", e.g. "--bind [::1]:8686" or "--bind ws://[::1]:8686"
+
+Note that the default Anonymous permissions are designed to not break existing
+server instances when upgrading, but are not particularly secure defaults. If
+you want to use authentication, it is recommended that you use "--anon-perms
+@read" to only give anonymous users read access, or "--anon-perms @none" to
+give un-authenticated users no access at all.
+
+Setting "--anon-perms @all" or "--anon-perms @user-admin" is not allowed, since
+this would allow anonymous users to manage all users accounts, which is a bad
+idea.
+
+If you are using user authentication, you should run your server in websockets
+mode with an SSL terminating load balancer in front of it (as this server does
+not implement SSL). Otherwise all usernames and passwords will be transmitted
+in the clear. When configured this way, clients can connect using a secure
+websocket, as in "wss://SERVER:PORT"
""",
)
@@ -79,6 +96,22 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
default=os.environ.get("HASHSERVER_DB_PASSWORD", None),
help="Database password ($HASHSERVER_DB_PASSWORD)",
)
+ parser.add_argument(
+ "--anon-perms",
+ metavar="PERM[,PERM[,...]]",
+ default=os.environ.get("HASHSERVER_ANON_PERMS", ",".join(DEFAULT_ANON_PERMS)),
+ help='Permissions to give anonymous users (default $HASHSERVER_ANON_PERMS, "%(default)s")',
+ )
+ parser.add_argument(
+ "--admin-user",
+ default=os.environ.get("HASHSERVER_ADMIN_USER", None),
+ help="Create default admin user with name ADMIN_USER ($HASHSERVER_ADMIN_USER)",
+ )
+ parser.add_argument(
+ "--admin-password",
+ default=os.environ.get("HASHSERVER_ADMIN_PASSWORD", None),
+ help="Create default admin user with password ADMIN_PASSWORD ($HASHSERVER_ADMIN_PASSWORD)",
+ )
args = parser.parse_args()
@@ -94,6 +127,7 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
logger.addHandler(console)
read_only = (os.environ.get("HASHSERVER_READ_ONLY", "0") == "1") or args.read_only
+ anon_perms = args.anon_perms.split(",")
server = hashserv.create_server(
args.bind,
@@ -102,6 +136,9 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
read_only=read_only,
db_username=args.db_username,
db_password=args.db_password,
+ anon_perms=anon_perms,
+ admin_username=args.admin_user,
+ admin_password=args.admin_password,
)
server.serve_forever()
return 0
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 9a8ee4e8..552a3327 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -8,6 +8,7 @@ from contextlib import closing
import re
import itertools
import json
+from collections import namedtuple
from urllib.parse import urlparse
UNIX_PREFIX = "unix://"
@@ -18,6 +19,8 @@ ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2
+User = namedtuple("User", ("username", "permissions"))
+
def parse_address(addr):
if addr.startswith(UNIX_PREFIX):
@@ -43,7 +46,10 @@ def create_server(
upstream=None,
read_only=False,
db_username=None,
- db_password=None
+ db_password=None,
+ anon_perms=None,
+ admin_username=None,
+ admin_password=None,
):
def sqlite_engine():
from .sqlite import DatabaseEngine
@@ -62,7 +68,17 @@ def create_server(
else:
db_engine = sqlite_engine()
- s = server.Server(db_engine, upstream=upstream, read_only=read_only)
+ if anon_perms is None:
+ anon_perms = server.DEFAULT_ANON_PERMS
+
+ s = server.Server(
+ db_engine,
+ upstream=upstream,
+ read_only=read_only,
+ anon_perms=anon_perms,
+ admin_username=admin_username,
+ admin_password=admin_password,
+ )
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
@@ -76,33 +92,40 @@ def create_server(
return s
-def create_client(addr):
+def create_client(addr, username=None, password=None):
from . import client
- c = client.Client()
-
- (typ, a) = parse_address(addr)
- if typ == ADDR_TYPE_UNIX:
- c.connect_unix(*a)
- elif typ == ADDR_TYPE_WS:
- c.connect_websocket(*a)
- else:
- c.connect_tcp(*a)
+ c = client.Client(username, password)
- return c
+ try:
+ (typ, a) = parse_address(addr)
+ if typ == ADDR_TYPE_UNIX:
+ c.connect_unix(*a)
+ elif typ == ADDR_TYPE_WS:
+ c.connect_websocket(*a)
+ else:
+ c.connect_tcp(*a)
+ return c
+ except Exception as e:
+ c.close()
+ raise e
-async def create_async_client(addr):
+async def create_async_client(addr, username=None, password=None):
from . import client
- c = client.AsyncClient()
+ c = client.AsyncClient(username, password)
- (typ, a) = parse_address(addr)
- if typ == ADDR_TYPE_UNIX:
- await c.connect_unix(*a)
- elif typ == ADDR_TYPE_WS:
- await c.connect_websocket(*a)
- else:
- await c.connect_tcp(*a)
+ try:
+ (typ, a) = parse_address(addr)
+ if typ == ADDR_TYPE_UNIX:
+ await c.connect_unix(*a)
+ elif typ == ADDR_TYPE_WS:
+ await c.connect_websocket(*a)
+ else:
+ await c.connect_tcp(*a)
- return c
+ return c
+ except Exception as e:
+ await c.close()
+ raise e
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index ebb58f33..5ed8d381 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -6,6 +6,7 @@
import logging
import socket
import bb.asyncrpc
+import json
from . import create_async_client
@@ -16,15 +17,19 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
MODE_NORMAL = 0
MODE_GET_STREAM = 1
- def __init__(self):
+ def __init__(self, username=None, password=None):
super().__init__('OEHASHEQUIV', '1.1', logger)
self.mode = self.MODE_NORMAL
+ self.username = username
+ self.password = password
async def setup_connection(self):
await super().setup_connection()
cur_mode = self.mode
self.mode = self.MODE_NORMAL
await self._set_mode(cur_mode)
+ if self.username:
+ await self.auth(self.username, self.password)
async def send_stream(self, msg):
async def proc():
@@ -41,6 +46,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
r = await self._send_wrapper(stream_to_normal)
if r != "ok":
+ self.check_invoke_error(r)
raise ConnectionError("Unable to transition to normal mode: Bad response from server %r" % r)
elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
r = await self.invoke({"get-stream": None})
@@ -109,9 +115,52 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
await self._set_mode(self.MODE_NORMAL)
return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})
+ async def auth(self, username, token):
+ await self._set_mode(self.MODE_NORMAL)
+ result = await self.invoke({"auth": {"username": username, "token": token}})
+ self.username = username
+ self.password = token
+ return result
+
+ async def refresh_token(self, username=None):
+ await self._set_mode(self.MODE_NORMAL)
+ m = {}
+ if username:
+ m["username"] = username
+ result = await self.invoke({"refresh-token": m})
+ if self.username and result["username"] == self.username:
+ self.password = result["token"]
+ return result
+
+ async def set_user_perms(self, username, permissions):
+ await self._set_mode(self.MODE_NORMAL)
+ return await self.invoke({"set-user-perms": {"username": username, "permissions": permissions}})
+
+ async def get_user(self, username=None):
+ await self._set_mode(self.MODE_NORMAL)
+ m = {}
+ if username:
+ m["username"] = username
+ return await self.invoke({"get-user": m})
+
+ async def get_all_users(self):
+ await self._set_mode(self.MODE_NORMAL)
+ return (await self.invoke({"get-all-users": {}}))["users"]
+
+ async def new_user(self, username, permissions):
+ await self._set_mode(self.MODE_NORMAL)
+ return await self.invoke({"new-user": {"username": username, "permissions": permissions}})
+
+ async def delete_user(self, username):
+ await self._set_mode(self.MODE_NORMAL)
+ return await self.invoke({"delete-user": {"username": username}})
+
class Client(bb.asyncrpc.Client):
- def __init__(self):
+ def __init__(self, username=None, password=None):
+ self.username = username
+ self.password = password
+
super().__init__()
self._add_methods(
"connect_tcp",
@@ -126,7 +175,14 @@ class Client(bb.asyncrpc.Client):
"backfill_wait",
"remove",
"clean_unused",
+ "auth",
+ "refresh_token",
+ "set_user_perms",
+ "get_user",
+ "get_all_users",
+ "new_user",
+ "delete_user",
)
def _get_async_client(self):
- return AsyncClient()
+ return AsyncClient(self.username, self.password)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 2e6977cb..5c70d81f 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -8,13 +8,48 @@ import asyncio
import logging
import math
import time
+import os
+import base64
+import hashlib
from . import create_async_client
import bb.asyncrpc
-
logger = logging.getLogger("hashserv.server")
+# This permission only exists to match nothing
+NONE_PERM = "@none"
+
+READ_PERM = "@read"
+REPORT_PERM = "@report"
+DB_ADMIN_PERM = "@db-admin"
+USER_ADMIN_PERM = "@user-admin"
+ALL_PERM = "@all"
+
+ALL_PERMISSIONS = {
+ READ_PERM,
+ REPORT_PERM,
+ DB_ADMIN_PERM,
+ USER_ADMIN_PERM,
+ ALL_PERM,
+}
+
+DEFAULT_ANON_PERMS = (
+ READ_PERM,
+ REPORT_PERM,
+ DB_ADMIN_PERM,
+)
+
+TOKEN_ALGORITHM = "sha256"
+
+# 48 bytes of random data will result in 64 characters when base64
+# encoded. This number also ensures that the base64 encoding won't have any
+# trailing '=' characters.
+TOKEN_SIZE = 48
+
+SALT_SIZE = 8
+
+
class Measurement(object):
def __init__(self, sample):
self.sample = sample
@@ -108,6 +143,85 @@ class Stats(object):
}
+token_refresh_semaphore = asyncio.Lock()
+
+
+async def new_token():
+ # Prevent malicious users from using this API to deduce the entropy
+ # pool on the server and thus be able to guess a token. *All* token
+ # refresh requests lock the same global semaphore and then sleep for a
+ # short time. This effectively rate limits the total number of requests
+ # that can be made across all clients to 10/second, which should be enough
+ # since you have to be an authenticated user to make the request in the
+ # first place
+ async with token_refresh_semaphore:
+ await asyncio.sleep(0.1)
+ raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK)
+
+ return base64.b64encode(raw, b"._").decode("utf-8")
+
+
+def new_salt():
+ return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex()
+
+
+def hash_token(algo, salt, token):
+ h = hashlib.new(algo)
+ h.update(salt.encode("utf-8"))
+ h.update(token.encode("utf-8"))
+ return ":".join([algo, salt, h.hexdigest()])
+
+
+def permissions(*permissions, allow_anon=True, allow_self_service=False):
+ """
+ Function decorator that can be used to decorate an RPC function call and
+ check that the current user's permissions match the required permissions.
+
+ If allow_anon is True, the user will also be allowed to make the RPC call
+ if the anonymous user permissions match the permissions.
+
+ If allow_self_service is True, and the "username" property in the request
+ is the currently logged in user, or not specified, the user will also be
+ allowed to make the request. This allows users to access normally privileged
+ APIs, as long as they are only modifying their own user properties (e.g.
+ users can be allowed to reset their own token without @user-admin
+ permissions, but not the token for any other user).
+ """
+
+ def wrapper(func):
+ async def wrap(self, request):
+ if allow_self_service and self.user is not None:
+ username = request.get("username", self.user.username)
+ if username == self.user.username:
+ request["username"] = self.user.username
+ return await func(self, request)
+
+ if not self.user_has_permissions(*permissions, allow_anon=allow_anon):
+ if not self.user:
+ username = "Anonymous user"
+ user_perms = self.anon_perms
+ else:
+ username = self.user.username
+ user_perms = self.user.permissions
+
+ self.logger.info(
+ "User %s with permissions %r denied from calling %s. Missing permissions(s) %r",
+ username,
+ ", ".join(user_perms),
+ func.__name__,
+ ", ".join(permissions),
+ )
+ raise bb.asyncrpc.InvokeError(
+ f"{username} is not allowed to access permissions(s) {', '.join(permissions)}"
+ )
+
+ return await func(self, request)
+
+ return wrap
+
+ return wrapper
+
+
class ServerClient(bb.asyncrpc.AsyncServerConnection):
def __init__(
self,
@@ -117,6 +231,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
backfill_queue,
upstream,
read_only,
+ anon_perms,
):
super().__init__(socket, "OEHASHEQUIV", logger)
self.db_engine = db_engine
@@ -125,6 +240,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
self.backfill_queue = backfill_queue
self.upstream = upstream
self.read_only = read_only
+ self.user = None
+ self.anon_perms = anon_perms
self.handlers.update(
{
@@ -135,6 +252,9 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
# Not always read-only, but internally checks if the server is
# read-only
"report": self.handle_report,
+ "auth": self.handle_auth,
+ "get-user": self.handle_get_user,
+ "get-all-users": self.handle_get_all_users,
}
)
@@ -146,9 +266,36 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"backfill-wait": self.handle_backfill_wait,
"remove": self.handle_remove,
"clean-unused": self.handle_clean_unused,
+ "refresh-token": self.handle_refresh_token,
+ "set-user-perms": self.handle_set_perms,
+ "new-user": self.handle_new_user,
+ "delete-user": self.handle_delete_user,
}
)
+ def raise_no_user_error(self, username):
+ raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists")
+
+ def user_has_permissions(self, *permissions, allow_anon=True):
+ permissions = set(permissions)
+ if allow_anon:
+ if ALL_PERM in self.anon_perms:
+ return True
+
+ if not permissions - self.anon_perms:
+ return True
+
+ if self.user is None:
+ return False
+
+ if ALL_PERM in self.user.permissions:
+ return True
+
+ if not permissions - self.user.permissions:
+ return True
+
+ return False
+
def validate_proto_version(self):
return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
@@ -178,6 +325,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
+ @permissions(READ_PERM)
async def handle_get(self, request):
method = request["method"]
taskhash = request["taskhash"]
@@ -206,6 +354,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return d
+ @permissions(READ_PERM)
async def handle_get_outhash(self, request):
method = request["method"]
outhash = request["outhash"]
@@ -236,6 +385,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
await self.db.insert_outhash(data)
+ @permissions(READ_PERM)
async def handle_get_stream(self, request):
await self.socket.send_message("ok")
@@ -303,8 +453,11 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"unihash": unihash,
}
+ # Since this can be called either read only or to report, the check to
+ # report is made inside the function
+ @permissions(READ_PERM)
async def handle_report(self, data):
- if self.read_only:
+ if self.read_only or not self.user_has_permissions(REPORT_PERM):
return await self.report_readonly(data)
outhash_data = {
@@ -357,6 +510,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"unihash": unihash,
}
+ @permissions(READ_PERM, REPORT_PERM)
async def handle_equivreport(self, data):
await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
@@ -374,11 +528,13 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return {k: row[k] for k in ("taskhash", "method", "unihash")}
+ @permissions(READ_PERM)
async def handle_get_stats(self, request):
return {
"requests": self.request_stats.todict(),
}
+ @permissions(DB_ADMIN_PERM)
async def handle_reset_stats(self, request):
d = {
"requests": self.request_stats.todict(),
@@ -387,6 +543,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
self.request_stats.reset()
return d
+ @permissions(READ_PERM)
async def handle_backfill_wait(self, request):
d = {
"tasks": self.backfill_queue.qsize(),
@@ -394,6 +551,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
await self.backfill_queue.join()
return d
+ @permissions(DB_ADMIN_PERM)
async def handle_remove(self, request):
condition = request["where"]
if not isinstance(condition, dict):
@@ -401,19 +559,178 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return {"count": await self.db.remove(condition)}
+ @permissions(DB_ADMIN_PERM)
async def handle_clean_unused(self, request):
max_age = request["max_age_seconds"]
oldest = datetime.now() - timedelta(seconds=-max_age)
return {"count": await self.db.clean_unused(oldest)}
+ # The authentication API is always allowed
+ async def handle_auth(self, request):
+ username = str(request["username"])
+ token = str(request["token"])
+
+ async def fail_auth():
+ nonlocal username
+ # Rate limit bad login attempts
+ await asyncio.sleep(1)
+ raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}")
+
+ user, db_token = await self.db.lookup_user_token(username)
+
+ if not user or not db_token:
+ await fail_auth()
+
+ try:
+ algo, salt, _ = db_token.split(":")
+ except ValueError:
+ await fail_auth()
+
+ if hash_token(algo, salt, token) != db_token:
+ await fail_auth()
+
+ self.user = user
+
+ self.logger.info("Authenticated as %s", username)
+
+ return {
+ "result": True,
+ "username": self.user.username,
+ "permissions": sorted(list(self.user.permissions)),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+ async def handle_refresh_token(self, request):
+ username = str(request["username"])
+
+ token = await new_token()
+
+ updated = await self.db.set_user_token(
+ username,
+ hash_token(TOKEN_ALGORITHM, new_salt(), token),
+ )
+ if not updated:
+ self.raise_no_user_error(username)
+
+ return {"username": username, "token": token}
+
+ def get_perm_arg(self, arg):
+ if not isinstance(arg, list):
+ raise bb.asyncrpc.InvokeError("Unexpected type for permissions")
+
+ arg = set(arg)
+ try:
+ arg.remove(NONE_PERM)
+ except KeyError:
+ pass
+
+ unknown_perms = arg - ALL_PERMISSIONS
+ if unknown_perms:
+ raise bb.asyncrpc.InvokeError(
+ "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms)))
+ )
+
+ return sorted(list(arg))
+
+ def return_perms(self, permissions):
+ if ALL_PERM in permissions:
+ return sorted(list(ALL_PERMISSIONS))
+ return sorted(list(permissions))
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_set_perms(self, request):
+ username = str(request["username"])
+ permissions = self.get_perm_arg(request["permissions"])
+
+ if not await self.db.set_user_perms(username, permissions):
+ self.raise_no_user_error(username)
+
+ return {
+ "username": username,
+ "permissions": self.return_perms(permissions),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+ async def handle_get_user(self, request):
+ username = str(request["username"])
+
+ user = await self.db.lookup_user(username)
+ if user is None:
+ return None
+
+ return {
+ "username": user.username,
+ "permissions": self.return_perms(user.permissions),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_get_all_users(self, request):
+ users = await self.db.get_all_users()
+ return {
+ "users": [
+ {
+ "username": u.username,
+ "permissions": self.return_perms(u.permissions),
+ }
+ for u in users
+ ]
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_new_user(self, request):
+ username = str(request["username"])
+ permissions = self.get_perm_arg(request["permissions"])
+
+ token = await new_token()
+
+ inserted = await self.db.new_user(
+ username,
+ permissions,
+ hash_token(TOKEN_ALGORITHM, new_salt(), token),
+ )
+ if not inserted:
+ raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'")
+
+ return {
+ "username": username,
+ "permissions": self.return_perms(permissions),
+ "token": token,
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_delete_user(self, request):
+ username = str(request["username"])
+
+ if not await self.db.delete_user(username):
+ self.raise_no_user_error(username)
+
+ return {"username": username}
+
class Server(bb.asyncrpc.AsyncServer):
- def __init__(self, db_engine, upstream=None, read_only=False):
+ def __init__(
+ self,
+ db_engine,
+ upstream=None,
+ read_only=False,
+ anon_perms=DEFAULT_ANON_PERMS,
+ admin_username=None,
+ admin_password=None,
+ ):
if upstream and read_only:
raise bb.asyncrpc.ServerError(
"Read-only hashserv cannot pull from an upstream server"
)
+ disallowed_perms = set(anon_perms) - set(
+ [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM]
+ )
+
+ if disallowed_perms:
+ raise bb.asyncrpc.ServerError(
+ f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users"
+ )
+
super().__init__(logger)
self.request_stats = Stats()
@@ -421,6 +738,13 @@ class Server(bb.asyncrpc.AsyncServer):
self.upstream = upstream
self.read_only = read_only
self.backfill_queue = None
+ self.anon_perms = set(anon_perms)
+ self.admin_username = admin_username
+ self.admin_password = admin_password
+
+ self.logger.info(
+ "Anonymous user permissions are: %s", ", ".join(self.anon_perms)
+ )
def accept_client(self, socket):
return ServerClient(
@@ -430,12 +754,34 @@ class Server(bb.asyncrpc.AsyncServer):
self.backfill_queue,
self.upstream,
self.read_only,
+ self.anon_perms,
)
+ async def create_admin_user(self):
+ admin_permissions = (ALL_PERM,)
+ async with self.db_engine.connect(self.logger) as db:
+ added = await db.new_user(
+ self.admin_username,
+ admin_permissions,
+ hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
+ )
+ if added:
+ self.logger.info("Created admin user '%s'", self.admin_username)
+ else:
+ await db.set_user_perms(
+ self.admin_username,
+ admin_permissions,
+ )
+ await db.set_user_token(
+ self.admin_username,
+ hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
+ )
+ self.logger.info("Admin user '%s' updated", self.admin_username)
+
async def backfill_worker_task(self):
async with await create_async_client(
self.upstream
- ) as client, self.db_engine.connect(logger) as db:
+ ) as client, self.db_engine.connect(self.logger) as db:
while True:
item = await self.backfill_queue.get()
if item is None:
@@ -456,6 +802,9 @@ class Server(bb.asyncrpc.AsyncServer):
self.loop.run_until_complete(self.db_engine.create())
+ if self.admin_username:
+ self.loop.run_until_complete(self.create_admin_user())
+
return tasks
async def stop(self):
diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py
index 3216621f..bfd8a844 100644
--- a/lib/hashserv/sqlalchemy.py
+++ b/lib/hashserv/sqlalchemy.py
@@ -7,6 +7,7 @@
import logging
from datetime import datetime
+from . import User
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool
@@ -25,13 +26,12 @@ from sqlalchemy import (
literal,
and_,
delete,
+ update,
)
import sqlalchemy.engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.exc import IntegrityError
-logger = logging.getLogger("hashserv.sqlalchemy")
-
Base = declarative_base()
@@ -68,9 +68,19 @@ class OuthashesV2(Base):
)
+class Users(Base):
+ __tablename__ = "users"
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ username = Column(Text, nullable=False)
+ token = Column(Text, nullable=False)
+ permissions = Column(Text)
+
+ __table_args__ = (UniqueConstraint("username"),)
+
+
class DatabaseEngine(object):
def __init__(self, url, username=None, password=None):
- self.logger = logger
+ self.logger = logging.getLogger("hashserv.sqlalchemy")
self.url = sqlalchemy.engine.make_url(url)
if username is not None:
@@ -85,7 +95,7 @@ class DatabaseEngine(object):
async with self.engine.begin() as conn:
# Create tables
- logger.info("Creating tables...")
+ self.logger.info("Creating tables...")
await conn.run_sync(Base.metadata.create_all)
def connect(self, logger):
@@ -98,6 +108,15 @@ def map_row(row):
return dict(**row._mapping)
+def map_user(row):
+ if row is None:
+ return None
+ return User(
+ username=row.username,
+ permissions=set(row.permissions.split()),
+ )
+
+
class Database(object):
def __init__(self, engine, logger):
self.engine = engine
@@ -278,7 +297,7 @@ class Database(object):
await self.db.execute(statement)
return True
except IntegrityError:
- logger.debug(
+ self.logger.debug(
"%s, %s, %s already in unihash database", method, taskhash, unihash
)
return False
@@ -298,7 +317,87 @@ class Database(object):
await self.db.execute(statement)
return True
except IntegrityError:
- logger.debug(
+ self.logger.debug(
"%s, %s already in outhash database", data["method"], data["outhash"]
)
return False
+
+ async def _get_user(self, username):
+ statement = select(
+ Users.username,
+ Users.permissions,
+ Users.token,
+ ).where(
+ Users.username == username,
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.first()
+
+ async def lookup_user_token(self, username):
+ row = await self._get_user(username)
+ if not row:
+ return None, None
+ return map_user(row), row.token
+
+ async def lookup_user(self, username):
+ return map_user(await self._get_user(username))
+
+ async def set_user_token(self, username, token):
+ statement = (
+ update(Users)
+ .where(
+ Users.username == username,
+ )
+ .values(
+ token=token,
+ )
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount != 0
+
+ async def set_user_perms(self, username, permissions):
+ statement = (
+ update(Users)
+ .where(Users.username == username)
+ .values(permissions=" ".join(permissions))
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount != 0
+
+ async def get_all_users(self):
+ statement = select(
+ Users.username,
+ Users.permissions,
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return [map_user(row) for row in result]
+
+ async def new_user(self, username, permissions, token):
+ statement = insert(Users).values(
+ username=username,
+ permissions=" ".join(permissions),
+ token=token,
+ )
+ self.logger.debug("%s", statement)
+ try:
+ async with self.db.begin():
+ await self.db.execute(statement)
+ return True
+ except IntegrityError as e:
+ self.logger.debug("Cannot create new user %s: %s", username, e)
+ return False
+
+ async def delete_user(self, username):
+ statement = delete(Users).where(Users.username == username)
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount != 0
diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py
index 6809c537..414ee8ff 100644
--- a/lib/hashserv/sqlite.py
+++ b/lib/hashserv/sqlite.py
@@ -7,6 +7,7 @@
import sqlite3
import logging
from contextlib import closing
+from . import User
logger = logging.getLogger("hashserv.sqlite")
@@ -34,6 +35,14 @@ OUTHASH_TABLE_DEFINITION = (
OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
+USERS_TABLE_DEFINITION = (
+ ("username", "TEXT NOT NULL", "UNIQUE"),
+ ("token", "TEXT NOT NULL", ""),
+ ("permissions", "TEXT NOT NULL", ""),
+)
+
+USERS_TABLE_COLUMNS = tuple(name for name, _, _ in USERS_TABLE_DEFINITION)
+
def _make_table(cursor, name, definition):
cursor.execute(
@@ -53,6 +62,15 @@ def _make_table(cursor, name, definition):
)
+def map_user(row):
+ if row is None:
+ return None
+ return User(
+ username=row["username"],
+ permissions=set(row["permissions"].split()),
+ )
+
+
class DatabaseEngine(object):
def __init__(self, dbname, sync):
self.dbname = dbname
@@ -66,6 +84,7 @@ class DatabaseEngine(object):
with closing(db.cursor()) as cursor:
_make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
_make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
+ _make_table(cursor, "users", USERS_TABLE_DEFINITION)
cursor.execute("PRAGMA journal_mode = WAL")
cursor.execute(
@@ -227,6 +246,7 @@ class Database(object):
"oldest": oldest,
},
)
+ self.db.commit()
return cursor.rowcount
async def insert_unihash(self, method, taskhash, unihash):
@@ -257,3 +277,88 @@ class Database(object):
cursor.execute(query, data)
self.db.commit()
return cursor.lastrowid != prevrowid
+
+ def _get_user(self, username):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT username, permissions, token FROM users WHERE username=:username
+ """,
+ {
+ "username": username,
+ },
+ )
+ return cursor.fetchone()
+
+ async def lookup_user_token(self, username):
+ row = self._get_user(username)
+ if row is None:
+ return None, None
+ return map_user(row), row["token"]
+
+ async def lookup_user(self, username):
+ return map_user(self._get_user(username))
+
+ async def set_user_token(self, username, token):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ UPDATE users SET token=:token WHERE username=:username
+ """,
+ {
+ "username": username,
+ "token": token,
+ },
+ )
+ self.db.commit()
+ return cursor.rowcount != 0
+
+ async def set_user_perms(self, username, permissions):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ UPDATE users SET permissions=:permissions WHERE username=:username
+ """,
+ {
+ "username": username,
+ "permissions": " ".join(permissions),
+ },
+ )
+ self.db.commit()
+ return cursor.rowcount != 0
+
+ async def get_all_users(self):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute("SELECT username, permissions FROM users")
+ return [map_user(r) for r in cursor.fetchall()]
+
+ async def new_user(self, username, permissions, token):
+ with closing(self.db.cursor()) as cursor:
+ try:
+ cursor.execute(
+ """
+ INSERT INTO users (username, token, permissions) VALUES (:username, :token, :permissions)
+ """,
+ {
+ "username": username,
+ "token": token,
+ "permissions": " ".join(permissions),
+ },
+ )
+ self.db.commit()
+ return True
+ except sqlite3.IntegrityError:
+ return False
+
+ async def delete_user(self, username):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ DELETE FROM users WHERE username=:username
+ """,
+ {
+ "username": username,
+ },
+ )
+ self.db.commit()
+ return cursor.rowcount != 0
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index e9a361dc..f92f37c4 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -6,6 +6,8 @@
#
from . import create_server, create_client
+from .server import DEFAULT_ANON_PERMS, ALL_PERMISSIONS
+from bb.asyncrpc import InvokeError
import hashlib
import logging
import multiprocessing
@@ -29,8 +31,9 @@ class HashEquivalenceTestSetup(object):
METHOD = 'TestMethod'
server_index = 0
+ client_index = 0
- def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc):
+ def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc, anon_perms=DEFAULT_ANON_PERMS, admin_username=None, admin_password=None):
self.server_index += 1
if dbpath is None:
dbpath = self.make_dbpath()
@@ -45,7 +48,10 @@ class HashEquivalenceTestSetup(object):
server = create_server(self.get_server_addr(self.server_index),
dbpath,
upstream=upstream,
- read_only=read_only)
+ read_only=read_only,
+ anon_perms=anon_perms,
+ admin_username=admin_username,
+ admin_password=admin_password)
server.dbpath = dbpath
server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
@@ -56,18 +62,31 @@ class HashEquivalenceTestSetup(object):
def make_dbpath(self):
return os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
- def start_client(self, server_address):
+ def start_client(self, server_address, username=None, password=None):
def cleanup_client(client):
client.close()
- client = create_client(server_address)
+ client = create_client(server_address, username=username, password=password)
self.addCleanup(cleanup_client, client)
return client
def start_test_server(self):
- server = self.start_server()
- return server.address
+ self.server = self.start_server()
+ return self.server.address
+
+ def start_auth_server(self):
+ self.auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password")
+ self.admin_client = self.start_client(self.auth_server.address, username="admin", password="password")
+ return self.admin_client
+
+ def auth_client(self, user):
+ return self.start_client(self.auth_server.address, user["username"], user["token"])
+
+ def auth_perms(self, *permissions):
+ self.client_index += 1
+ user = self.admin_client.new_user(f"user-{self.client_index}", permissions)
+ return self.auth_client(user)
def setUp(self):
if sys.version_info < (3, 5, 0):
@@ -86,18 +105,21 @@ class HashEquivalenceTestSetup(object):
class HashEquivalenceCommonTests(object):
- def test_create_hash(self):
+ def create_test_hash(self, client):
# Simple test that hashes can be created
taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
- self.assertClientGetHash(self.client, taskhash, None)
+ self.assertClientGetHash(client, taskhash, None)
- result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+ result = client.report_unihash(taskhash, self.METHOD, outhash, unihash)
self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
return taskhash, outhash, unihash
+ def test_create_hash(self):
+ return self.create_test_hash(self.client)
+
def test_create_equivalent(self):
# Tests that a second reported task with the same outhash will be
# assigned the same unihash
@@ -471,6 +493,242 @@ class HashEquivalenceCommonTests(object):
# shares a taskhash with Task 2
self.assertClientGetHash(self.client, taskhash2, unihash2)
+ def test_auth_read_perms(self):
+ admin_client = self.start_auth_server()
+
+ # Create hashes with non-authenticated server
+ taskhash, outhash, unihash = self.test_create_hash()
+
+ # Validate hash can be retrieved using authenticated client
+ with self.auth_perms("@read") as client:
+ self.assertClientGetHash(client, taskhash, unihash)
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ self.assertClientGetHash(client, taskhash, unihash)
+
+ def test_auth_report_perms(self):
+ admin_client = self.start_auth_server()
+
+ # Without read permission, the user is completely denied
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ self.create_test_hash(client)
+
+ # Read permission allows the call to succeed, but it doesn't record
+ # anythin in the database
+ with self.auth_perms("@read") as client:
+ taskhash, outhash, unihash = self.create_test_hash(client)
+ self.assertClientGetHash(client, taskhash, None)
+
+ # Report permission alone is insufficient
+ with self.auth_perms("@report") as client, self.assertRaises(InvokeError):
+ self.create_test_hash(client)
+
+ # Read and report permission actually modify the database
+ with self.auth_perms("@read", "@report") as client:
+ taskhash, outhash, unihash = self.create_test_hash(client)
+ self.assertClientGetHash(client, taskhash, unihash)
+
+ def test_auth_no_token_refresh_from_anon_user(self):
+ self.start_auth_server()
+
+ with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ client.refresh_token()
+
+ def assertUserCanAuth(self, user):
+ with self.start_client(self.auth_server.address) as client:
+ client.auth(user["username"], user["token"])
+
+ def assertUserCannotAuth(self, user):
+ with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ client.auth(user["username"], user["token"])
+
+ def test_auth_self_token_refresh(self):
+ admin_client = self.start_auth_server()
+
+ # Create a new user with no permissions
+ user = admin_client.new_user("test-user", [])
+
+ with self.auth_client(user) as client:
+ new_user = client.refresh_token()
+
+ self.assertEqual(user["username"], new_user["username"])
+ self.assertNotEqual(user["token"], new_user["token"])
+ self.assertUserCanAuth(new_user)
+ self.assertUserCannotAuth(user)
+
+ # Explicitly specifying with your own username is fine also
+ with self.auth_client(new_user) as client:
+ new_user2 = client.refresh_token(user["username"])
+
+ self.assertEqual(user["username"], new_user2["username"])
+ self.assertNotEqual(user["token"], new_user2["token"])
+ self.assertUserCanAuth(new_user2)
+ self.assertUserCannotAuth(new_user)
+ self.assertUserCannotAuth(user)
+
+ def test_auth_token_refresh(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.refresh_token(user["username"])
+
+ with self.auth_perms("@user-admin") as client:
+ new_user = client.refresh_token(user["username"])
+
+ self.assertEqual(user["username"], new_user["username"])
+ self.assertNotEqual(user["token"], new_user["token"])
+ self.assertUserCanAuth(new_user)
+ self.assertUserCannotAuth(user)
+
+ def test_auth_self_get_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+ user_info = user.copy()
+ del user_info["token"]
+
+ with self.auth_client(user) as client:
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ # Explicitly asking for your own username is fine also
+ info = client.get_user(user["username"])
+ self.assertEqual(info, user_info)
+
+ def test_auth_get_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+ user_info = user.copy()
+ del user_info["token"]
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.get_user(user["username"])
+
+ with self.auth_perms("@user-admin") as client:
+ info = client.get_user(user["username"])
+ self.assertEqual(info, user_info)
+
+ info = client.get_user("nonexist-user")
+ self.assertIsNone(info)
+
+ def test_auth_reconnect(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+ user_info = user.copy()
+ del user_info["token"]
+
+ with self.auth_client(user) as client:
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ client.disconnect()
+
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ def test_auth_delete_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+
+ # No self service
+ with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ client.delete_user(user["username"])
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.delete_user(user["username"])
+
+ with self.auth_perms("@user-admin") as client:
+ client.delete_user(user["username"])
+
+ # User doesn't exist, so even though the permission is correct, it's an
+ # error
+ with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
+ client.delete_user(user["username"])
+
+ def assertUserPerms(self, user, permissions):
+ with self.auth_client(user) as client:
+ info = client.get_user()
+ self.assertEqual(info, {
+ "username": user["username"],
+ "permissions": permissions,
+ })
+
+ def test_auth_set_user_perms(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+
+ self.assertUserPerms(user, [])
+
+ # No self service to change permissions
+ with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ client.set_user_perms(user["username"], ["@all"])
+ self.assertUserPerms(user, [])
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.set_user_perms(user["username"], ["@all"])
+ self.assertUserPerms(user, [])
+
+ with self.auth_perms("@user-admin") as client:
+ client.set_user_perms(user["username"], ["@all"])
+ self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS)))
+
+ # Bad permissions
+ with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
+ client.set_user_perms(user["username"], ["@this-is-not-a-permission"])
+ self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS)))
+
+ def test_auth_get_all_users(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+
+ with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ client.get_all_users()
+
+ # Give the test user the correct permission
+ admin_client.set_user_perms(user["username"], ["@user-admin"])
+
+ with self.auth_client(user) as client:
+ all_users = client.get_all_users()
+
+ # Convert to a dictionary for easier comparison
+ all_users = {u["username"]: u for u in all_users}
+
+ self.assertEqual(all_users,
+ {
+ "admin": {
+ "username": "admin",
+ "permissions": sorted(list(ALL_PERMISSIONS)),
+ },
+ "test-user": {
+ "username": "test-user",
+ "permissions": ["@user-admin"],
+ }
+ }
+ )
+
+ def test_auth_new_user(self):
+ self.start_auth_server()
+
+ permissions = ["@read", "@report", "@db-admin", "@user-admin"]
+ permissions.sort()
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.new_user("test-user", permissions)
+
+ with self.auth_perms("@user-admin") as client:
+ user = client.new_user("test-user", permissions)
+ self.assertIn("token", user)
+ self.assertEqual(user["username"], "test-user")
+ self.assertEqual(user["permissions"], permissions)
+
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v4 14/22] hashserv: Add become-user API
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (12 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 13/22] hashserv: Add user permissions Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 15/22] hashserv: Add db-usage API Joshua Watt
` (10 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an API that allows a user admin to impersonate another user in the
system. This makes it easier to write external services that have
external authentication, since they can use a common user account to
access the server, then impersonate the logged-in user.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 3 +++
lib/hashserv/client.py | 42 +++++++++++++++++++++++++++++++++++++-----
lib/hashserv/server.py | 18 ++++++++++++++++++
lib/hashserv/tests.py | 39 +++++++++++++++++++++++++++++++++++++++
4 files changed, 97 insertions(+), 5 deletions(-)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index 328c15cd..cfbc197e 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -166,6 +166,7 @@ def main():
parser.add_argument('--log', default='WARNING', help='Set logging level')
parser.add_argument('--login', '-l', metavar="USERNAME", help="Authenticate as USERNAME")
parser.add_argument('--password', '-p', metavar="TOKEN", help="Authenticate using token TOKEN")
+ parser.add_argument('--become', '-b', metavar="USERNAME", help="Impersonate user USERNAME (if allowed) when performing actions")
parser.add_argument('--no-netrc', '-n', action="store_false", dest="netrc", help="Do not use .netrc")
subparsers = parser.add_subparsers()
@@ -251,6 +252,8 @@ def main():
if func:
try:
with hashserv.create_client(args.address, login, password) as client:
+ if args.become:
+ client.become_user(args.become)
return func(args, client)
except bb.asyncrpc.InvokeError as e:
print(f"ERROR: {e}")
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index 5ed8d381..90f1dd71 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -18,10 +18,11 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
MODE_GET_STREAM = 1
def __init__(self, username=None, password=None):
- super().__init__('OEHASHEQUIV', '1.1', logger)
+ super().__init__("OEHASHEQUIV", "1.1", logger)
self.mode = self.MODE_NORMAL
self.username = username
self.password = password
+ self.saved_become_user = None
async def setup_connection(self):
await super().setup_connection()
@@ -29,8 +30,13 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
self.mode = self.MODE_NORMAL
await self._set_mode(cur_mode)
if self.username:
+ # Save off become user temporarily because auth() resets it
+ become = self.saved_become_user
await self.auth(self.username, self.password)
+ if become:
+ await self.become_user(become)
+
async def send_stream(self, msg):
async def proc():
await self.socket.send(msg)
@@ -92,7 +98,14 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
await self._set_mode(self.MODE_NORMAL)
return await self.invoke(
- {"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method, "with_unihash": with_unihash}}
+ {
+ "get-outhash": {
+ "outhash": outhash,
+ "taskhash": taskhash,
+ "method": method,
+ "with_unihash": with_unihash,
+ }
+ }
)
async def get_stats(self):
@@ -120,6 +133,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
result = await self.invoke({"auth": {"username": username, "token": token}})
self.username = username
self.password = token
+ self.saved_become_user = None
return result
async def refresh_token(self, username=None):
@@ -128,13 +142,19 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
if username:
m["username"] = username
result = await self.invoke({"refresh-token": m})
- if self.username and result["username"] == self.username:
+ if (
+ self.username
+ and not self.saved_become_user
+ and result["username"] == self.username
+ ):
self.password = result["token"]
return result
async def set_user_perms(self, username, permissions):
await self._set_mode(self.MODE_NORMAL)
- return await self.invoke({"set-user-perms": {"username": username, "permissions": permissions}})
+ return await self.invoke(
+ {"set-user-perms": {"username": username, "permissions": permissions}}
+ )
async def get_user(self, username=None):
await self._set_mode(self.MODE_NORMAL)
@@ -149,12 +169,23 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
async def new_user(self, username, permissions):
await self._set_mode(self.MODE_NORMAL)
- return await self.invoke({"new-user": {"username": username, "permissions": permissions}})
+ return await self.invoke(
+ {"new-user": {"username": username, "permissions": permissions}}
+ )
async def delete_user(self, username):
await self._set_mode(self.MODE_NORMAL)
return await self.invoke({"delete-user": {"username": username}})
+ async def become_user(self, username):
+ await self._set_mode(self.MODE_NORMAL)
+ result = await self.invoke({"become-user": {"username": username}})
+ if username == self.username:
+ self.saved_become_user = None
+ else:
+ self.saved_become_user = username
+ return result
+
class Client(bb.asyncrpc.Client):
def __init__(self, username=None, password=None):
@@ -182,6 +213,7 @@ class Client(bb.asyncrpc.Client):
"get_all_users",
"new_user",
"delete_user",
+ "become_user",
)
def _get_async_client(self):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 5c70d81f..d506088e 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -255,6 +255,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"auth": self.handle_auth,
"get-user": self.handle_get_user,
"get-all-users": self.handle_get_all_users,
+ "become-user": self.handle_become_user,
}
)
@@ -706,6 +707,23 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return {"username": username}
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_become_user(self, request):
+ username = str(request["username"])
+
+ user = await self.db.lookup_user(username)
+ if user is None:
+ raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist")
+
+ self.user = user
+
+ self.logger.info("Became user %s", username)
+
+ return {
+ "username": self.user.username,
+ "permissions": self.return_perms(self.user.permissions),
+ }
+
class Server(bb.asyncrpc.AsyncServer):
def __init__(
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index f92f37c4..311b7b77 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -728,6 +728,45 @@ class HashEquivalenceCommonTests(object):
self.assertEqual(user["username"], "test-user")
self.assertEqual(user["permissions"], permissions)
+ def test_auth_become_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read", "@report"])
+ user_info = user.copy()
+ del user_info["token"]
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.become_user(user["username"])
+
+ with self.auth_perms("@user-admin") as client:
+ become = client.become_user(user["username"])
+ self.assertEqual(become, user_info)
+
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ # Verify become user is preserved across disconnect
+ client.disconnect()
+
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ # test-user doesn't have become_user permissions, so this should
+ # not work
+ with self.assertRaises(InvokeError):
+ client.become_user(user["username"])
+
+ # No self-service of become
+ with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ client.become_user(user["username"])
+
+ # Give test user permissions to become
+ admin_client.set_user_perms(user["username"], ["@user-admin"])
+
+ # It's possible to become yourself (effectively a noop)
+ with self.auth_perms("@user-admin") as client:
+ become = client.become_user(client.username)
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v4 15/22] hashserv: Add db-usage API
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (13 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 14/22] hashserv: Add become-user API Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 16/22] hashserv: Add database column query API Joshua Watt
` (9 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an API to query the server for the usage of the database (e.g. how
many rows are present in each table).
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 16 ++++++++++++++++
lib/hashserv/client.py | 5 +++++
lib/hashserv/server.py | 5 +++++
lib/hashserv/sqlalchemy.py | 14 ++++++++++++++
lib/hashserv/sqlite.py | 20 ++++++++++++++++++++
lib/hashserv/tests.py | 9 +++++++++
6 files changed, 69 insertions(+)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index cfbc197e..5d65c7bc 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -161,6 +161,19 @@ def main():
r = client.delete_user(args.username)
print_user(r)
+ def handle_get_db_usage(args, client):
+ usage = client.get_db_usage()
+ print(usage)
+ tables = sorted(usage.keys())
+ print("{name:20}| {rows:20}".format(name="Table name", rows="Rows"))
+ print(("-" * 20) + "+" + ("-" * 20))
+ for t in tables:
+ print("{name:20}| {rows:<20}".format(name=t, rows=usage[t]["rows"]))
+ print()
+
+ total_rows = sum(t["rows"] for t in usage.values())
+ print(f"Total rows: {total_rows}")
+
parser = argparse.ArgumentParser(description='Hash Equivalence Client')
parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
parser.add_argument('--log', default='WARNING', help='Set logging level')
@@ -223,6 +236,9 @@ def main():
delete_user_parser.add_argument("--username", "-u", help="Username", required=True)
delete_user_parser.set_defaults(func=handle_delete_user)
+ db_usage_parser = subparsers.add_parser('get-db-usage', help="Database Usage")
+ db_usage_parser.set_defaults(func=handle_get_db_usage)
+
args = parser.parse_args()
logger = logging.getLogger('hashserv')
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index 90f1dd71..0c3f556a 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -186,6 +186,10 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
self.saved_become_user = username
return result
+ async def get_db_usage(self):
+ await self._set_mode(self.MODE_NORMAL)
+ return (await self.invoke({"get-db-usage": {}}))["usage"]
+
class Client(bb.asyncrpc.Client):
def __init__(self, username=None, password=None):
@@ -214,6 +218,7 @@ class Client(bb.asyncrpc.Client):
"new_user",
"delete_user",
"become_user",
+ "get_db_usage",
)
def _get_async_client(self):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index d506088e..4fec1556 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -249,6 +249,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"get-outhash": self.handle_get_outhash,
"get-stream": self.handle_get_stream,
"get-stats": self.handle_get_stats,
+ "get-db-usage": self.handle_get_db_usage,
# Not always read-only, but internally checks if the server is
# read-only
"report": self.handle_report,
@@ -566,6 +567,10 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
oldest = datetime.now() - timedelta(seconds=-max_age)
return {"count": await self.db.clean_unused(oldest)}
+ @permissions(DB_ADMIN_PERM)
+ async def handle_get_db_usage(self, request):
+ return {"usage": await self.db.get_usage()}
+
# The authentication API is always allowed
async def handle_auth(self, request):
username = str(request["username"])
diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py
index bfd8a844..818b5195 100644
--- a/lib/hashserv/sqlalchemy.py
+++ b/lib/hashserv/sqlalchemy.py
@@ -27,6 +27,7 @@ from sqlalchemy import (
and_,
delete,
update,
+ func,
)
import sqlalchemy.engine
from sqlalchemy.orm import declarative_base
@@ -401,3 +402,16 @@ class Database(object):
async with self.db.begin():
result = await self.db.execute(statement)
return result.rowcount != 0
+
+ async def get_usage(self):
+ usage = {}
+ async with self.db.begin() as session:
+ for name, table in Base.metadata.tables.items():
+ statement = select(func.count()).select_from(table)
+ self.logger.debug("%s", statement)
+ result = await self.db.execute(statement)
+ usage[name] = {
+ "rows": result.scalar(),
+ }
+
+ return usage
diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py
index 414ee8ff..e9ef38a1 100644
--- a/lib/hashserv/sqlite.py
+++ b/lib/hashserv/sqlite.py
@@ -362,3 +362,23 @@ class Database(object):
)
self.db.commit()
return cursor.rowcount != 0
+
+ async def get_usage(self):
+ usage = {}
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT name FROM sqlite_schema WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
+ """
+ )
+ for row in cursor.fetchall():
+ cursor.execute(
+ """
+ SELECT COUNT() FROM %s
+ """
+ % row["name"],
+ )
+ usage[row["name"]] = {
+ "rows": cursor.fetchone()[0],
+ }
+ return usage
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 311b7b77..9d5bec24 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -767,6 +767,15 @@ class HashEquivalenceCommonTests(object):
with self.auth_perms("@user-admin") as client:
become = client.become_user(client.username)
+ def test_get_db_usage(self):
+ usage = self.client.get_db_usage()
+
+ self.assertTrue(isinstance(usage, dict))
+ for name in usage.keys():
+ self.assertTrue(isinstance(usage[name], dict))
+ self.assertIn("rows", usage[name])
+ self.assertTrue(isinstance(usage[name]["rows"], int))
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v4 16/22] hashserv: Add database column query API
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (14 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 15/22] hashserv: Add db-usage API Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 17/22] hashserv: test: Add bitbake-hashclient tests Joshua Watt
` (8 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an API to retrieve the columns that can be queried from the
database backend. This saves front-end applications from needing to
hardcode the query columns.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 7 +++++++
lib/hashserv/client.py | 5 +++++
lib/hashserv/server.py | 5 +++++
lib/hashserv/sqlalchemy.py | 10 ++++++++++
lib/hashserv/sqlite.py | 7 +++++++
lib/hashserv/tests.py | 8 ++++++++
6 files changed, 42 insertions(+)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index 5d65c7bc..58aa02ee 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -174,6 +174,10 @@ def main():
total_rows = sum(t["rows"] for t in usage.values())
print(f"Total rows: {total_rows}")
+ def handle_get_db_query_columns(args, client):
+ columns = client.get_db_query_columns()
+ print("\n".join(sorted(columns)))
+
parser = argparse.ArgumentParser(description='Hash Equivalence Client')
parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
parser.add_argument('--log', default='WARNING', help='Set logging level')
@@ -239,6 +243,9 @@ def main():
db_usage_parser = subparsers.add_parser('get-db-usage', help="Database Usage")
db_usage_parser.set_defaults(func=handle_get_db_usage)
+ db_query_columns_parser = subparsers.add_parser('get-db-query-columns', help="Show columns that can be used in database queries")
+ db_query_columns_parser.set_defaults(func=handle_get_db_query_columns)
+
args = parser.parse_args()
logger = logging.getLogger('hashserv')
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index 0c3f556a..319da2d9 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -190,6 +190,10 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
await self._set_mode(self.MODE_NORMAL)
return (await self.invoke({"get-db-usage": {}}))["usage"]
+ async def get_db_query_columns(self):
+ await self._set_mode(self.MODE_NORMAL)
+ return (await self.invoke({"get-db-query-columns": {}}))["columns"]
+
class Client(bb.asyncrpc.Client):
def __init__(self, username=None, password=None):
@@ -219,6 +223,7 @@ class Client(bb.asyncrpc.Client):
"delete_user",
"become_user",
"get_db_usage",
+ "get_db_query_columns",
)
def _get_async_client(self):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 4fec1556..d2fd75df 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -250,6 +250,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"get-stream": self.handle_get_stream,
"get-stats": self.handle_get_stats,
"get-db-usage": self.handle_get_db_usage,
+ "get-db-query-columns": self.handle_get_db_query_columns,
# Not always read-only, but internally checks if the server is
# read-only
"report": self.handle_report,
@@ -571,6 +572,10 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
async def handle_get_db_usage(self, request):
return {"usage": await self.db.get_usage()}
+ @permissions(DB_ADMIN_PERM)
+ async def handle_get_db_query_columns(self, request):
+ return {"columns": await self.db.get_query_columns()}
+
# The authentication API is always allowed
async def handle_auth(self, request):
username = str(request["username"])
diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py
index 818b5195..cee04bff 100644
--- a/lib/hashserv/sqlalchemy.py
+++ b/lib/hashserv/sqlalchemy.py
@@ -415,3 +415,13 @@ class Database(object):
}
return usage
+
+ async def get_query_columns(self):
+ columns = set()
+ for table in (UnihashesV2, OuthashesV2):
+ for c in table.__table__.columns:
+ if not isinstance(c.type, Text):
+ continue
+ columns.add(c.key)
+
+ return list(columns)
diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py
index e9ef38a1..f5c451f4 100644
--- a/lib/hashserv/sqlite.py
+++ b/lib/hashserv/sqlite.py
@@ -382,3 +382,10 @@ class Database(object):
"rows": cursor.fetchone()[0],
}
return usage
+
+ async def get_query_columns(self):
+ columns = set()
+ for name, typ, _ in UNIHASH_TABLE_DEFINITION + OUTHASH_TABLE_DEFINITION:
+ if typ.startswith("TEXT"):
+ columns.add(name)
+ return list(columns)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 9d5bec24..fc69acaf 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -776,6 +776,14 @@ class HashEquivalenceCommonTests(object):
self.assertIn("rows", usage[name])
self.assertTrue(isinstance(usage[name]["rows"], int))
+ def test_get_db_query_columns(self):
+ columns = self.client.get_db_query_columns()
+
+ self.assertTrue(isinstance(columns, list))
+ self.assertTrue(len(columns) > 0)
+
+ for col in columns:
+ self.client.remove({col: ""})
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v4 17/22] hashserv: test: Add bitbake-hashclient tests
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (15 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 16/22] hashserv: Add database column query API Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 18/22] bitbake-hashclient: Output stats in JSON format Joshua Watt
` (7 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
The bitbake-hashclient command-line tool now has a lot more features
which should be tested, so add some tests for them.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/tests.py | 300 ++++++++++++++++++++++++++++++++++++++----
1 file changed, 277 insertions(+), 23 deletions(-)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index fc69acaf..a80ccd57 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -19,6 +19,14 @@ import unittest
import socket
import time
import signal
+import subprocess
+import json
+import re
+from pathlib import Path
+
+
+THIS_DIR = Path(__file__).parent
+BIN_DIR = THIS_DIR.parent.parent / "bin"
def server_prefunc(server, idx):
logging.basicConfig(level=logging.DEBUG, filename='bbhashserv-%d.log' % idx, filemode='w',
@@ -103,8 +111,22 @@ class HashEquivalenceTestSetup(object):
result = client.get_unihash(self.METHOD, taskhash)
self.assertEqual(result, unihash)
+ def assertUserPerms(self, user, permissions):
+ with self.auth_client(user) as client:
+ info = client.get_user()
+ self.assertEqual(info, {
+ "username": user["username"],
+ "permissions": permissions,
+ })
+
+ def assertUserCanAuth(self, user):
+ with self.start_client(self.auth_server.address) as client:
+ client.auth(user["username"], user["token"])
+
+ def assertUserCannotAuth(self, user):
+ with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ client.auth(user["username"], user["token"])
-class HashEquivalenceCommonTests(object):
def create_test_hash(self, client):
# Simple test that hashes can be created
taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
@@ -117,6 +139,24 @@ class HashEquivalenceCommonTests(object):
self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
return taskhash, outhash, unihash
+ def run_hashclient(self, args, **kwargs):
+ try:
+ p = subprocess.run(
+ [BIN_DIR / "bitbake-hashclient"] + args,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ encoding="utf-8",
+ **kwargs
+ )
+ except subprocess.CalledProcessError as e:
+ print(e.output)
+ raise e
+
+ print(p.stdout)
+ return p
+
+
+class HashEquivalenceCommonTests(object):
def test_create_hash(self):
return self.create_test_hash(self.client)
@@ -161,7 +201,7 @@ class HashEquivalenceCommonTests(object):
self.assertClientGetHash(self.client, taskhash, unihash)
def test_remove_taskhash(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
result = self.client.remove({"taskhash": taskhash})
self.assertGreater(result["count"], 0)
self.assertClientGetHash(self.client, taskhash, None)
@@ -170,13 +210,13 @@ class HashEquivalenceCommonTests(object):
self.assertIsNone(result_outhash)
def test_remove_unihash(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
result = self.client.remove({"unihash": unihash})
self.assertGreater(result["count"], 0)
self.assertClientGetHash(self.client, taskhash, None)
def test_remove_outhash(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
result = self.client.remove({"outhash": outhash})
self.assertGreater(result["count"], 0)
@@ -184,7 +224,7 @@ class HashEquivalenceCommonTests(object):
self.assertIsNone(result_outhash)
def test_remove_method(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
result = self.client.remove({"method": self.METHOD})
self.assertGreater(result["count"], 0)
self.assertClientGetHash(self.client, taskhash, None)
@@ -193,7 +233,7 @@ class HashEquivalenceCommonTests(object):
self.assertIsNone(result_outhash)
def test_clean_unused(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
# Clean the database, which should not remove anything because all hashes an in-use
result = self.client.clean_unused(0)
@@ -497,7 +537,7 @@ class HashEquivalenceCommonTests(object):
admin_client = self.start_auth_server()
# Create hashes with non-authenticated server
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
# Validate hash can be retrieved using authenticated client
with self.auth_perms("@read") as client:
@@ -534,14 +574,6 @@ class HashEquivalenceCommonTests(object):
with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
client.refresh_token()
- def assertUserCanAuth(self, user):
- with self.start_client(self.auth_server.address) as client:
- client.auth(user["username"], user["token"])
-
- def assertUserCannotAuth(self, user):
- with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
- client.auth(user["username"], user["token"])
-
def test_auth_self_token_refresh(self):
admin_client = self.start_auth_server()
@@ -650,14 +682,6 @@ class HashEquivalenceCommonTests(object):
with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
client.delete_user(user["username"])
- def assertUserPerms(self, user, permissions):
- with self.auth_client(user) as client:
- info = client.get_user()
- self.assertEqual(info, {
- "username": user["username"],
- "permissions": permissions,
- })
-
def test_auth_set_user_perms(self):
admin_client = self.start_auth_server()
@@ -785,6 +809,236 @@ class HashEquivalenceCommonTests(object):
for col in columns:
self.client.remove({col: ""})
+
+class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
+ def get_server_addr(self, server_idx):
+ return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
+
+ def test_stats(self):
+ self.run_hashclient(["--address", self.server_address, "stats"], check=True)
+
+ def test_stress(self):
+ self.run_hashclient(["--address", self.server_address, "stress"], check=True)
+
+ def test_remove_taskhash(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "taskhash", taskhash,
+ ], check=True)
+ self.assertClientGetHash(self.client, taskhash, None)
+
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
+ self.assertIsNone(result_outhash)
+
+ def test_remove_unihash(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "unihash", unihash,
+ ], check=True)
+ self.assertClientGetHash(self.client, taskhash, None)
+
+ def test_remove_outhash(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "outhash", outhash,
+ ], check=True)
+
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
+ self.assertIsNone(result_outhash)
+
+ def test_remove_method(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "method", self.METHOD,
+ ], check=True)
+ self.assertClientGetHash(self.client, taskhash, None)
+
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
+ self.assertIsNone(result_outhash)
+
+ def test_clean_unused(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+
+ # Clean the database, which should not remove anything because all hashes an in-use
+ self.run_hashclient([
+ "--address", self.server_address,
+ "clean-unused", "0",
+ ], check=True)
+ self.assertClientGetHash(self.client, taskhash, unihash)
+
+ # Remove the unihash. The row in the outhash table should still be present
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "unihash", unihash,
+ ], check=True)
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
+ self.assertIsNotNone(result_outhash)
+
+ # Now clean with no minimum age which will remove the outhash
+ self.run_hashclient([
+ "--address", self.server_address,
+ "clean-unused", "0",
+ ], check=True)
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
+ self.assertIsNone(result_outhash)
+
+ def test_refresh_token(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read", "@report"])
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", user["username"],
+ "--password", user["token"],
+ "refresh-token"
+ ], check=True)
+
+ new_token = None
+ for l in p.stdout.splitlines():
+ l = l.rstrip()
+ m = re.match(r'Token: +(.*)$', l)
+ if m is not None:
+ new_token = m.group(1)
+
+ self.assertTrue(new_token)
+
+ print("New token is %r" % new_token)
+
+ self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", user["username"],
+ "--password", new_token,
+ "get-user"
+ ], check=True)
+
+ def test_set_user_perms(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read"])
+
+ self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "set-user-perms",
+ "-u", user["username"],
+ "@read", "@report",
+ ], check=True)
+
+ new_user = admin_client.get_user(user["username"])
+
+ self.assertEqual(set(new_user["permissions"]), {"@read", "@report"})
+
+ def test_get_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read"])
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "get-user",
+ "-u", user["username"],
+ ], check=True)
+
+ self.assertIn("Username:", p.stdout)
+ self.assertIn("Permissions:", p.stdout)
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", user["username"],
+ "--password", user["token"],
+ "get-user",
+ ], check=True)
+
+ self.assertIn("Username:", p.stdout)
+ self.assertIn("Permissions:", p.stdout)
+
+ def test_get_all_users(self):
+ admin_client = self.start_auth_server()
+
+ admin_client.new_user("test-user1", ["@read"])
+ admin_client.new_user("test-user2", ["@read"])
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "get-all-users",
+ ], check=True)
+
+ self.assertIn("admin", p.stdout)
+ self.assertIn("test-user1", p.stdout)
+ self.assertIn("test-user2", p.stdout)
+
+ def test_new_user(self):
+ admin_client = self.start_auth_server()
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "new-user",
+ "-u", "test-user",
+ "@read", "@report",
+ ], check=True)
+
+ new_token = None
+ for l in p.stdout.splitlines():
+ l = l.rstrip()
+ m = re.match(r'Token: +(.*)$', l)
+ if m is not None:
+ new_token = m.group(1)
+
+ self.assertTrue(new_token)
+
+ user = {
+ "username": "test-user",
+ "token": new_token,
+ }
+
+ self.assertUserPerms(user, ["@read", "@report"])
+
+ def test_delete_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read"])
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "delete-user",
+ "-u", user["username"],
+ ], check=True)
+
+
+ self.assertIsNone(admin_client.get_user(user["username"]))
+
+ def test_get_db_usage(self):
+ p = self.run_hashclient([
+ "--address", self.server_address,
+ "get-db-usage",
+ ], check=True)
+
+ def test_get_db_query_columns(self):
+ p = self.run_hashclient([
+ "--address", self.server_address,
+ "get-db-query-columns",
+ ], check=True)
+
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 18/22] bitbake-hashclient: Output stats in JSON format
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (16 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 17/22] hashserv: test: Add bitbake-hashclient tests Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 19/22] bitbake-hashserver: Allow anonymous permissions to be space separated Joshua Watt
` (6 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Outputting the stats in JSON format makes more sense as it's easier for
a downstream tool to parse if desired.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 3 ++-
lib/hashserv/tests.py | 3 ++-
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index 58aa02ee..3ff7b763 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -15,6 +15,7 @@ import threading
import time
import warnings
import netrc
+import json
warnings.simplefilter("default")
try:
@@ -56,7 +57,7 @@ def main():
s = client.reset_stats()
else:
s = client.get_stats()
- pprint.pprint(s)
+ print(json.dumps(s, sort_keys=True, indent=4))
return 0
def handle_stress(args, client):
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index a80ccd57..2d78f9e9 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -815,7 +815,8 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
def test_stats(self):
- self.run_hashclient(["--address", self.server_address, "stats"], check=True)
+ p = self.run_hashclient(["--address", self.server_address, "stats"], check=True)
+ json.loads(p.stdout)
def test_stress(self):
self.run_hashclient(["--address", self.server_address, "stress"], check=True)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 19/22] bitbake-hashserver: Allow anonymous permissions to be space separated
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (17 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 18/22] bitbake-hashclient: Output stats in JSON format Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 20/22] hashserv: tests: Allow authentication for external server tests Joshua Watt
` (5 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Space separation is more natural when setting the value from an
environment variable, so allow that here for convenience.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashserv | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/bin/bitbake-hashserv b/bin/bitbake-hashserv
index 1085d058..c560b3e5 100755
--- a/bin/bitbake-hashserv
+++ b/bin/bitbake-hashserv
@@ -127,7 +127,10 @@ websocket, as in "wss://SERVER:PORT"
logger.addHandler(console)
read_only = (os.environ.get("HASHSERVER_READ_ONLY", "0") == "1") or args.read_only
- anon_perms = args.anon_perms.split(",")
+ if "," in args.anon_perms:
+ anon_perms = args.anon_perms.split(",")
+ else:
+ anon_perms = args.anon_perms.split()
server = hashserv.create_server(
args.bind,
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 20/22] hashserv: tests: Allow authentication for external server tests
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (18 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 19/22] bitbake-hashserver: Allow anonymous permissions to be space separated Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 21/22] hashserv: Allow self-service deletion Joshua Watt
` (4 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
If BB_TEST_HASHSERV_USERNAME and BB_TEST_HASHSERV_PASSWORD are provided
for a server admin user, the authentication tests for the external
hashserver will run. In addition, any users that get created will now be
deleted when the test finishes.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/tests.py | 109 ++++++++++++++++++++++++++++--------------
1 file changed, 74 insertions(+), 35 deletions(-)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 2d78f9e9..5d209ffb 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -84,17 +84,13 @@ class HashEquivalenceTestSetup(object):
return self.server.address
def start_auth_server(self):
- self.auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password")
- self.admin_client = self.start_client(self.auth_server.address, username="admin", password="password")
+ auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password")
+ self.auth_server_address = auth_server.address
+ self.admin_client = self.start_client(auth_server.address, username="admin", password="password")
return self.admin_client
def auth_client(self, user):
- return self.start_client(self.auth_server.address, user["username"], user["token"])
-
- def auth_perms(self, *permissions):
- self.client_index += 1
- user = self.admin_client.new_user(f"user-{self.client_index}", permissions)
- return self.auth_client(user)
+ return self.start_client(self.auth_server_address, user["username"], user["token"])
def setUp(self):
if sys.version_info < (3, 5, 0):
@@ -120,11 +116,11 @@ class HashEquivalenceTestSetup(object):
})
def assertUserCanAuth(self, user):
- with self.start_client(self.auth_server.address) as client:
+ with self.start_client(self.auth_server_address) as client:
client.auth(user["username"], user["token"])
def assertUserCannotAuth(self, user):
- with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError):
client.auth(user["username"], user["token"])
def create_test_hash(self, client):
@@ -157,6 +153,26 @@ class HashEquivalenceTestSetup(object):
class HashEquivalenceCommonTests(object):
+ def auth_perms(self, *permissions):
+ self.client_index += 1
+ user = self.create_user(f"user-{self.client_index}", permissions)
+ return self.auth_client(user)
+
+ def create_user(self, username, permissions, *, client=None):
+ def remove_user(username):
+ try:
+ self.admin_client.delete_user(username)
+ except bb.asyncrpc.InvokeError:
+ pass
+
+ if client is None:
+ client = self.admin_client
+
+ user = client.new_user(username, permissions)
+ self.addCleanup(remove_user, username)
+
+ return user
+
def test_create_hash(self):
return self.create_test_hash(self.client)
@@ -571,14 +587,14 @@ class HashEquivalenceCommonTests(object):
def test_auth_no_token_refresh_from_anon_user(self):
self.start_auth_server()
- with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError):
client.refresh_token()
def test_auth_self_token_refresh(self):
admin_client = self.start_auth_server()
# Create a new user with no permissions
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
with self.auth_client(user) as client:
new_user = client.refresh_token()
@@ -601,7 +617,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_token_refresh(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
with self.auth_perms() as client, self.assertRaises(InvokeError):
client.refresh_token(user["username"])
@@ -617,7 +633,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_self_get_user(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
user_info = user.copy()
del user_info["token"]
@@ -632,7 +648,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_get_user(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
user_info = user.copy()
del user_info["token"]
@@ -649,7 +665,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_reconnect(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
user_info = user.copy()
del user_info["token"]
@@ -665,7 +681,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_delete_user(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
# No self service
with self.auth_client(user) as client, self.assertRaises(InvokeError):
@@ -685,7 +701,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_set_user_perms(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
self.assertUserPerms(user, [])
@@ -710,7 +726,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_get_all_users(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
with self.auth_client(user) as client, self.assertRaises(InvokeError):
client.get_all_users()
@@ -744,10 +760,10 @@ class HashEquivalenceCommonTests(object):
permissions.sort()
with self.auth_perms() as client, self.assertRaises(InvokeError):
- client.new_user("test-user", permissions)
+ self.create_user("test-user", permissions, client=client)
with self.auth_perms("@user-admin") as client:
- user = client.new_user("test-user", permissions)
+ user = self.create_user("test-user", permissions, client=client)
self.assertIn("token", user)
self.assertEqual(user["username"], "test-user")
self.assertEqual(user["permissions"], permissions)
@@ -755,7 +771,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_become_user(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", ["@read", "@report"])
+ user = self.create_user("test-user", ["@read", "@report"])
user_info = user.copy()
del user_info["token"]
@@ -898,7 +914,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
user = admin_client.new_user("test-user", ["@read", "@report"])
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", user["username"],
"--password", user["token"],
"refresh-token"
@@ -916,7 +932,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
print("New token is %r" % new_token)
self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", user["username"],
"--password", new_token,
"get-user"
@@ -928,7 +944,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
user = admin_client.new_user("test-user", ["@read"])
self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"set-user-perms",
@@ -946,7 +962,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
user = admin_client.new_user("test-user", ["@read"])
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"get-user",
@@ -957,7 +973,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
self.assertIn("Permissions:", p.stdout)
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", user["username"],
"--password", user["token"],
"get-user",
@@ -973,7 +989,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
admin_client.new_user("test-user2", ["@read"])
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"get-all-users",
@@ -987,7 +1003,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
admin_client = self.start_auth_server()
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"new-user",
@@ -1017,14 +1033,13 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
user = admin_client.new_user("test-user", ["@read"])
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"delete-user",
"-u", user["username"],
], check=True)
-
self.assertIsNone(admin_client.get_user(user["username"]))
def test_get_db_usage(self):
@@ -1104,19 +1119,43 @@ class TestHashEquivalenceWebsocketsSQLAlchemyServer(TestHashEquivalenceWebsocket
class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
- def start_test_server(self):
- if 'BB_TEST_HASHSERV' not in os.environ:
- self.skipTest('BB_TEST_HASHSERV not defined to test an external server')
+ def get_env(self, name):
+ v = os.environ.get(name)
+ if not v:
+ self.skipTest(f'{name} not defined to test an external server')
+ return v
- return os.environ['BB_TEST_HASHSERV']
+ def start_test_server(self):
+ return self.get_env('BB_TEST_HASHSERV')
def start_server(self, *args, **kwargs):
self.skipTest('Cannot start local server when testing external servers')
+ def start_auth_server(self):
+
+ self.auth_server_address = self.server_address
+ self.admin_client = self.start_client(
+ self.server_address,
+ username=self.get_env('BB_TEST_HASHSERV_USERNAME'),
+ password=self.get_env('BB_TEST_HASHSERV_PASSWORD'),
+ )
+ return self.admin_client
+
def setUp(self):
super().setUp()
+ if "BB_TEST_HASHSERV_USERNAME" in os.environ:
+ self.client = self.start_client(
+ self.server_address,
+ username=os.environ["BB_TEST_HASHSERV_USERNAME"],
+ password=os.environ["BB_TEST_HASHSERV_PASSWORD"],
+ )
self.client.remove({"method": self.METHOD})
def tearDown(self):
self.client.remove({"method": self.METHOD})
super().tearDown()
+
+
+ def test_auth_get_all_users(self):
+ self.skipTest("Cannot test all users with external server")
+
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 21/22] hashserv: Allow self-service deletion
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (19 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 20/22] hashserv: tests: Allow authentication for external server tests Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 22/22] hashserv: server: Add owner if user is logged in Joshua Watt
` (3 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Allows self-service deletion of user accounts (that is, users can
delete their own accounts without needing any special
permissions).
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/server.py | 2 +-
lib/hashserv/tests.py | 7 +++++--
2 files changed, 6 insertions(+), 3 deletions(-)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index d2fd75df..6da56df7 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -708,7 +708,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"token": token,
}
- @permissions(USER_ADMIN_PERM, allow_anon=False)
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
async def handle_delete_user(self, request):
username = str(request["username"])
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 5d209ffb..f0be8679 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -683,10 +683,13 @@ class HashEquivalenceCommonTests(object):
user = self.create_user("test-user", [])
- # No self service
- with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ # self service
+ with self.auth_client(user) as client:
client.delete_user(user["username"])
+ self.assertIsNone(admin_client.get_user(user["username"]))
+ user = self.create_user("test-user", [])
+
with self.auth_perms() as client, self.assertRaises(InvokeError):
client.delete_user(user["username"])
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v4 22/22] hashserv: server: Add owner if user is logged in
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (20 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 21/22] hashserv: Allow self-service deletion Joshua Watt
@ 2023-10-31 17:21 ` Joshua Watt
2023-11-01 13:17 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Alexandre Belloni
` (2 subsequent siblings)
24 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-10-31 17:21 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
If a user is authenticated with the server, record their username as the
owner of the outhash data they report.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/server.py | 3 +++
lib/hashserv/tests.py | 9 +++++++++
2 files changed, 12 insertions(+)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 6da56df7..a9714b5b 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -474,6 +474,9 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if k in data:
outhash_data[k] = data[k]
+ if self.user:
+ outhash_data["owner"] = self.user.username
+
# Insert the new entry, unless it already exists
if await self.db.insert_outhash(outhash_data):
# If this row is new, check if it is equivalent to another
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index f0be8679..a9e6fdf9 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -828,6 +828,15 @@ class HashEquivalenceCommonTests(object):
for col in columns:
self.client.remove({col: ""})
+ def test_auth_is_owner(self):
+ admin_client = self.start_auth_server()
+
+ user = self.create_user("test-user", ["@read", "@report"])
+ with self.auth_client(user) as client:
+ taskhash, outhash, unihash = self.create_test_hash(client)
+ data = client.get_taskhash(self.METHOD, taskhash, True)
+ self.assertEqual(data["owner"], user["username"])
+
class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
def get_server_addr(self, server_idx):
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* Re: [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (21 preceding siblings ...)
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 22/22] hashserv: server: Add owner if user is logged in Joshua Watt
@ 2023-11-01 13:17 ` Alexandre Belloni
2023-11-01 13:29 ` Alexandre Belloni
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
24 siblings, 1 reply; 138+ messages in thread
From: Alexandre Belloni @ 2023-11-01 13:17 UTC (permalink / raw)
To: Joshua Watt; +Cc: bitbake-devel
Hello Joshua,
This is causing warnings on the AB:
WARNING: Error contacting Hash Equivalence Server hashserv.yocto.io:8686: Expecting value: line 1 column 1 (char 0)
https://autobuilder.yoctoproject.org/typhoon/#/builders/122/builds/3522/steps/12/logs/warnings
https://autobuilder.yoctoproject.org/typhoon/#/builders/122/builds/3523/steps/17/logs/warnings
https://autobuilder.yoctoproject.org/typhoon/#/builders/23/builds/8384/steps/14/logs/warnings
https://autobuilder.yoctoproject.org/typhoon/#/builders/73/builds/7990/steps/13/logs/warnings
https://autobuilder.yoctoproject.org/typhoon/#/builders/59/builds/7987/steps/12/logs/warnings
https://autobuilder.yoctoproject.org/typhoon/#/builders/73/builds/7991/steps/13/logs/warnings
https://autobuilder.yoctoproject.org/typhoon/#/builders/59/builds/7988/steps/12/logs/warnings
On 31/10/2023 11:21:16-0600, Joshua Watt wrote:
> This patch series reworks the bitbake asyncrpc API to add a WebSockets
> implementation for both the client and server. The hash equivalence
> server is updated to allow using this new API (the PR server can also be
> updated in the future if desired).
>
> In addition, the database backed for the hash equivalence server is
> abstracted so that sqlalchemy can optionally be used instead of sqlite.
> This allows using "big metal" databases as the backend, which allows the
> hash equivalence server to scale to a large number of queries.
>
> Note that both websockets and sqlalchemy require 3rd party python
> modules to function. However, these modules are optional unless the user
> desires to use the APIs.
>
> Also, user management is added. This allows user accounts to be
> registered with the server and users can be given permissions to do
> certain operations on the server. Users are not (necessarily) required
> to login to access the server, as permissions can granted to anonymous
> users. The default permissions will give anonymous users the same
> permissions that they would have before user accounts were added so as
> to retain backward compatibility, but server admins will likely want to
> change this.
>
> V3: Remove RFC status; patches are ready for review
> V4: Fixed protocol breakage with mixing older and newer clients/servers
>
> Joshua Watt (22):
> asyncrpc: Abstract sockets
> hashserv: Add websocket connection implementation
> asyncrpc: Add context manager API
> hashserv: tests: Add external database tests
> asyncrpc: Prefix log messages with client info
> bitbake-hashserv: Allow arguments from environment
> hashserv: Abstract database
> hashserv: Add SQLalchemy backend
> hashserv: Implement read-only version of "report" RPC
> asyncrpc: Add InvokeError
> asyncrpc: client: Prevent double closing of loop
> asyncrpc: client: Add disconnect API
> hashserv: Add user permissions
> hashserv: Add become-user API
> hashserv: Add db-usage API
> hashserv: Add database column query API
> hashserv: test: Add bitbake-hashclient tests
> bitbake-hashclient: Output stats in JSON format
> bitbake-hashserver: Allow anonymous permissions to be space separated
> hashserv: tests: Allow authentication for external server tests
> hashserv: Allow self-service deletion
> hashserv: server: Add owner if user is logged in
>
> bin/bitbake-hashclient | 145 +++++-
> bin/bitbake-hashserv | 132 ++++-
> lib/bb/asyncrpc/__init__.py | 33 +-
> lib/bb/asyncrpc/client.py | 120 ++---
> lib/bb/asyncrpc/connection.py | 146 ++++++
> lib/bb/asyncrpc/exceptions.py | 21 +
> lib/bb/asyncrpc/serv.py | 356 ++++++++-----
> lib/hashserv/__init__.py | 190 +++----
> lib/hashserv/client.py | 147 +++++-
> lib/hashserv/server.py | 951 +++++++++++++++++++++-------------
> lib/hashserv/sqlalchemy.py | 427 +++++++++++++++
> lib/hashserv/sqlite.py | 391 ++++++++++++++
> lib/hashserv/tests.py | 736 +++++++++++++++++++++++++-
> lib/prserv/client.py | 8 +-
> lib/prserv/serv.py | 37 +-
> 15 files changed, 3034 insertions(+), 806 deletions(-)
> create mode 100644 lib/bb/asyncrpc/connection.py
> create mode 100644 lib/bb/asyncrpc/exceptions.py
> create mode 100644 lib/hashserv/sqlalchemy.py
> create mode 100644 lib/hashserv/sqlite.py
>
> --
> 2.34.1
>
>
> -=-=-=-=-=-=-=-=-=-=-=-
> Links: You receive all messages sent to this group.
> View/Reply Online (#15350): https://lists.openembedded.org/g/bitbake-devel/message/15350
> Mute This Topic: https://lists.openembedded.org/mt/102302326/3617179
> Group Owner: bitbake-devel+owner@lists.openembedded.org
> Unsubscribe: https://lists.openembedded.org/g/bitbake-devel/unsub [alexandre.belloni@bootlin.com]
> -=-=-=-=-=-=-=-=-=-=-=-
>
--
Alexandre Belloni, co-owner and COO, Bootlin
Embedded Linux and Kernel engineering
https://bootlin.com
^ permalink raw reply [flat|nested] 138+ messages in thread* Re: [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management
2023-11-01 13:17 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Alexandre Belloni
@ 2023-11-01 13:29 ` Alexandre Belloni
0 siblings, 0 replies; 138+ messages in thread
From: Alexandre Belloni @ 2023-11-01 13:29 UTC (permalink / raw)
To: Joshua Watt; +Cc: bitbake-devel
Replying with more failures :( :
https://autobuilder.yoctoproject.org/typhoon/#/builders/79/builds/5983/steps/14/logs/stdio
ERROR: PRservice localhost:44747 not available
ERROR: Unable to start PR Server, exiting, check the bitbake-cookerdaemon.log
https://autobuilder.yoctoproject.org/typhoon/#/builders/79/builds/5983/steps/11/logs/stdio
WARNING: default:a1 do_packagedata: Error contacting Hash Equivalence Server unix:///tmp/runqueuetesteux28xpz/hashserve.sock: [Errno 104] Connection reset by peer
Error talking to server: [Errno 104] Connection reset by peer
Error talking to server: [Errno 104] Connection reset by peer
Error talking to server: [Errno 104] Connection reset by peer
RecursionError: maximum recursion depth exceeded
https://autobuilder.yoctoproject.org/typhoon/#/builders/80/builds/5933/steps/11/logs/stdio
https://autobuilder.yoctoproject.org/typhoon/#/builders/80/builds/5933/steps/14/logs/stdio
https://autobuilder.yoctoproject.org/typhoon/#/builders/127/builds/2357/steps/14/logs/stdio
https://autobuilder.yoctoproject.org/typhoon/#/builders/127/builds/2357/steps/11/logs/stdio
On 01/11/2023 14:17:07+0100, Alexandre Belloni wrote:
> Hello Joshua,
>
> This is causing warning on the AB:
>
> WARNING: Error contacting Hash Equivalence Server hashserv.yocto.io:8686: Expecting value: line 1 column 1 (char 0)
>
> https://autobuilder.yoctoproject.org/typhoon/#/builders/122/builds/3522/steps/12/logs/warnings
> https://autobuilder.yoctoproject.org/typhoon/#/builders/122/builds/3523/steps/17/logs/warnings
> https://autobuilder.yoctoproject.org/typhoon/#/builders/23/builds/8384/steps/14/logs/warnings
> https://autobuilder.yoctoproject.org/typhoon/#/builders/73/builds/7990/steps/13/logs/warnings
> https://autobuilder.yoctoproject.org/typhoon/#/builders/59/builds/7987/steps/12/logs/warnings
> https://autobuilder.yoctoproject.org/typhoon/#/builders/73/builds/7991/steps/13/logs/warnings
> https://autobuilder.yoctoproject.org/typhoon/#/builders/59/builds/7988/steps/12/logs/warnings
>
>
>
>
>
> On 31/10/2023 11:21:16-0600, Joshua Watt wrote:
> > This patch series reworks the bitbake asyncrpc API to add a WebSockets
> > implementation for both the client and server. The hash equivalence
> > server is updated to allow using this new API (the PR server can also be
> > updated in the future if desired).
> >
> > In addition, the database backed for the hash equivalence server is
> > abstracted so that sqlalchemy can optionally be used instead of sqlite.
> > This allows using "big metal" databases as the backend, which allows the
> > hash equivalence server to scale to a large number of queries.
> >
> > Note that both websockets and sqlalchemy require 3rd party python
> > modules to function. However, these modules are optional unless the user
> > desires to use the APIs.
> >
> > Also, user management is added. This allows user accounts to be
> > registered with the server and users can be given permissions to do
> > certain operations on the server. Users are not (necessarily) required
> > to login to access the server, as permissions can granted to anonymous
> > users. The default permissions will give anonymous users the same
> > permissions that they would have before user accounts were added so as
> > to retain backward compatibility, but server admins will likely want to
> > change this.
> >
> > V3: Remove RFC status; patches are ready for review
> > V4: Fixed protocol breakage with mixing older and newer clients/servers
> >
> > Joshua Watt (22):
> > asyncrpc: Abstract sockets
> > hashserv: Add websocket connection implementation
> > asyncrpc: Add context manager API
> > hashserv: tests: Add external database tests
> > asyncrpc: Prefix log messages with client info
> > bitbake-hashserv: Allow arguments from environment
> > hashserv: Abstract database
> > hashserv: Add SQLalchemy backend
> > hashserv: Implement read-only version of "report" RPC
> > asyncrpc: Add InvokeError
> > asyncrpc: client: Prevent double closing of loop
> > asyncrpc: client: Add disconnect API
> > hashserv: Add user permissions
> > hashserv: Add become-user API
> > hashserv: Add db-usage API
> > hashserv: Add database column query API
> > hashserv: test: Add bitbake-hashclient tests
> > bitbake-hashclient: Output stats in JSON format
> > bitbake-hashserver: Allow anonymous permissions to be space separated
> > hashserv: tests: Allow authentication for external server tests
> > hashserv: Allow self-service deletion
> > hashserv: server: Add owner if user is logged in
> >
> > bin/bitbake-hashclient | 145 +++++-
> > bin/bitbake-hashserv | 132 ++++-
> > lib/bb/asyncrpc/__init__.py | 33 +-
> > lib/bb/asyncrpc/client.py | 120 ++---
> > lib/bb/asyncrpc/connection.py | 146 ++++++
> > lib/bb/asyncrpc/exceptions.py | 21 +
> > lib/bb/asyncrpc/serv.py | 356 ++++++++-----
> > lib/hashserv/__init__.py | 190 +++----
> > lib/hashserv/client.py | 147 +++++-
> > lib/hashserv/server.py | 951 +++++++++++++++++++++-------------
> > lib/hashserv/sqlalchemy.py | 427 +++++++++++++++
> > lib/hashserv/sqlite.py | 391 ++++++++++++++
> > lib/hashserv/tests.py | 736 +++++++++++++++++++++++++-
> > lib/prserv/client.py | 8 +-
> > lib/prserv/serv.py | 37 +-
> > 15 files changed, 3034 insertions(+), 806 deletions(-)
> > create mode 100644 lib/bb/asyncrpc/connection.py
> > create mode 100644 lib/bb/asyncrpc/exceptions.py
> > create mode 100644 lib/hashserv/sqlalchemy.py
> > create mode 100644 lib/hashserv/sqlite.py
> >
> > --
> > 2.34.1
> >
>
> >
> > -=-=-=-=-=-=-=-=-=-=-=-
> > Links: You receive all messages sent to this group.
> > View/Reply Online (#15350): https://lists.openembedded.org/g/bitbake-devel/message/15350
> > Mute This Topic: https://lists.openembedded.org/mt/102302326/3617179
> > Group Owner: bitbake-devel+owner@lists.openembedded.org
> > Unsubscribe: https://lists.openembedded.org/g/bitbake-devel/unsub [alexandre.belloni@bootlin.com]
> > -=-=-=-=-=-=-=-=-=-=-=-
> >
>
>
> --
> Alexandre Belloni, co-owner and COO, Bootlin
> Embedded Linux and Kernel engineering
> https://bootlin.com
--
Alexandre Belloni, co-owner and COO, Bootlin
Embedded Linux and Kernel engineering
https://bootlin.com
^ permalink raw reply [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v5 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (22 preceding siblings ...)
2023-11-01 13:17 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Alexandre Belloni
@ 2023-11-01 15:41 ` Joshua Watt
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 01/22] asyncrpc: Abstract sockets Joshua Watt
` (21 more replies)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
24 siblings, 22 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:41 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
This patch series reworks the bitbake asyncrpc API to add a WebSockets
implementation for both the client and server. The hash equivalence
server is updated to allow using this new API (the PR server can also be
updated in the future if desired).
In addition, the database backend for the hash equivalence server is
abstracted so that sqlalchemy can optionally be used instead of sqlite.
This allows using "big metal" databases as the backend, which allows the
hash equivalence server to scale to a large number of queries.
Note that both websockets and sqlalchemy require 3rd party python
modules to function. However, these modules are optional unless the user
desires to use the APIs.
Also, user management is added. This allows user accounts to be
registered with the server and users can be given permissions to do
certain operations on the server. Users are not (necessarily) required
to log in to access the server, as permissions can be granted to anonymous
users. The default permissions will give anonymous users the same
permissions that they would have before user accounts were added so as
to retain backward compatibility, but server admins will likely want to
change this.
V3: Remove RFC status; patches are ready for review
V4: Fixed protocol breakage with mixing older and newer clients/servers
V5: Fixed compatibility with Python 3.8
Joshua Watt (22):
asyncrpc: Abstract sockets
hashserv: Add websocket connection implementation
asyncrpc: Add context manager API
hashserv: tests: Add external database tests
asyncrpc: Prefix log messages with client info
bitbake-hashserv: Allow arguments from environment
hashserv: Abstract database
hashserv: Add SQLalchemy backend
hashserv: Implement read-only version of "report" RPC
asyncrpc: Add InvokeError
asyncrpc: client: Prevent double closing of loop
asyncrpc: client: Add disconnect API
hashserv: Add user permissions
hashserv: Add become-user API
hashserv: Add db-usage API
hashserv: Add database column query API
hashserv: test: Add bitbake-hashclient tests
bitbake-hashclient: Output stats in JSON format
bitbake-hashserver: Allow anonymous permissions to be space separated
hashserv: tests: Allow authentication for external server tests
hashserv: Allow self-service deletion
hashserv: server: Add owner if user is logged in
bin/bitbake-hashclient | 145 +++++-
bin/bitbake-hashserv | 132 ++++-
lib/bb/asyncrpc/__init__.py | 33 +-
lib/bb/asyncrpc/client.py | 120 ++---
lib/bb/asyncrpc/connection.py | 146 ++++++
lib/bb/asyncrpc/exceptions.py | 21 +
lib/bb/asyncrpc/serv.py | 359 ++++++++-----
lib/hashserv/__init__.py | 190 +++----
lib/hashserv/client.py | 147 +++++-
lib/hashserv/server.py | 951 +++++++++++++++++++++-------------
lib/hashserv/sqlalchemy.py | 427 +++++++++++++++
lib/hashserv/sqlite.py | 408 +++++++++++++++
lib/hashserv/tests.py | 736 +++++++++++++++++++++++++-
lib/prserv/client.py | 8 +-
lib/prserv/serv.py | 37 +-
15 files changed, 3053 insertions(+), 807 deletions(-)
create mode 100644 lib/bb/asyncrpc/connection.py
create mode 100644 lib/bb/asyncrpc/exceptions.py
create mode 100644 lib/hashserv/sqlalchemy.py
create mode 100644 lib/hashserv/sqlite.py
--
2.34.1
^ permalink raw reply [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 01/22] asyncrpc: Abstract sockets
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
@ 2023-11-01 15:41 ` Joshua Watt
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 02/22] hashserv: Add websocket connection implementation Joshua Watt
` (20 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:41 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Rewrites the asyncrpc client and server code to make it possible to have
other transport backends that are not stream based (e.g. websockets
which are message based). The connection handling classes are now shared
between both the client and server to make it easier to implement new
transport mechanisms.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/__init__.py | 32 +---
lib/bb/asyncrpc/client.py | 78 +++------
lib/bb/asyncrpc/connection.py | 95 +++++++++++
lib/bb/asyncrpc/exceptions.py | 17 ++
lib/bb/asyncrpc/serv.py | 298 +++++++++++++++++-----------------
lib/hashserv/__init__.py | 21 ---
lib/hashserv/client.py | 38 ++---
lib/hashserv/server.py | 115 ++++++-------
lib/prserv/client.py | 8 +-
lib/prserv/serv.py | 31 ++--
10 files changed, 380 insertions(+), 353 deletions(-)
create mode 100644 lib/bb/asyncrpc/connection.py
create mode 100644 lib/bb/asyncrpc/exceptions.py
diff --git a/lib/bb/asyncrpc/__init__.py b/lib/bb/asyncrpc/__init__.py
index 9a85e996..9f677eac 100644
--- a/lib/bb/asyncrpc/__init__.py
+++ b/lib/bb/asyncrpc/__init__.py
@@ -4,30 +4,12 @@
# SPDX-License-Identifier: GPL-2.0-only
#
-import itertools
-import json
-
-# The Python async server defaults to a 64K receive buffer, so we hardcode our
-# maximum chunk size. It would be better if the client and server reported to
-# each other what the maximum chunk sizes were, but that will slow down the
-# connection setup with a round trip delay so I'd rather not do that unless it
-# is necessary
-DEFAULT_MAX_CHUNK = 32 * 1024
-
-
-def chunkify(msg, max_chunk):
- if len(msg) < max_chunk - 1:
- yield ''.join((msg, "\n"))
- else:
- yield ''.join((json.dumps({
- 'chunk-stream': None
- }), "\n"))
-
- args = [iter(msg)] * (max_chunk - 1)
- for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
- yield ''.join(itertools.chain(m, "\n"))
- yield "\n"
-
from .client import AsyncClient, Client
-from .serv import AsyncServer, AsyncServerConnection, ClientError, ServerError
+from .serv import AsyncServer, AsyncServerConnection
+from .connection import DEFAULT_MAX_CHUNK
+from .exceptions import (
+ ClientError,
+ ServerError,
+ ConnectionClosedError,
+)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index fa042bbe..7f33099b 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -10,13 +10,13 @@ import json
import os
import socket
import sys
-from . import chunkify, DEFAULT_MAX_CHUNK
+from .connection import StreamConnection, DEFAULT_MAX_CHUNK
+from .exceptions import ConnectionClosedError
class AsyncClient(object):
def __init__(self, proto_name, proto_version, logger, timeout=30):
- self.reader = None
- self.writer = None
+ self.socket = None
self.max_chunk = DEFAULT_MAX_CHUNK
self.proto_name = proto_name
self.proto_version = proto_version
@@ -25,7 +25,8 @@ class AsyncClient(object):
async def connect_tcp(self, address, port):
async def connect_sock():
- return await asyncio.open_connection(address, port)
+ reader, writer = await asyncio.open_connection(address, port)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
@@ -40,27 +41,27 @@ class AsyncClient(object):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
sock.connect(os.path.basename(path))
finally:
- os.chdir(cwd)
- return await asyncio.open_unix_connection(sock=sock)
+ os.chdir(cwd)
+ reader, writer = await asyncio.open_unix_connection(sock=sock)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
async def setup_connection(self):
- s = '%s %s\n\n' % (self.proto_name, self.proto_version)
- self.writer.write(s.encode("utf-8"))
- await self.writer.drain()
+ # Send headers
+ await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
+ # End of headers
+ await self.socket.send("")
async def connect(self):
- if self.reader is None or self.writer is None:
- (self.reader, self.writer) = await self._connect_sock()
+ if self.socket is None:
+ self.socket = await self._connect_sock()
await self.setup_connection()
async def close(self):
- self.reader = None
-
- if self.writer is not None:
- self.writer.close()
- self.writer = None
+ if self.socket is not None:
+ await self.socket.close()
+ self.socket = None
async def _send_wrapper(self, proc):
count = 0
@@ -71,6 +72,7 @@ class AsyncClient(object):
except (
OSError,
ConnectionError,
+ ConnectionClosedError,
json.JSONDecodeError,
UnicodeDecodeError,
) as e:
@@ -82,49 +84,15 @@ class AsyncClient(object):
await self.close()
count += 1
- async def send_message(self, msg):
- async def get_line():
- try:
- line = await asyncio.wait_for(self.reader.readline(), self.timeout)
- except asyncio.TimeoutError:
- raise ConnectionError("Timed out waiting for server")
-
- if not line:
- raise ConnectionError("Connection closed")
-
- line = line.decode("utf-8")
-
- if not line.endswith("\n"):
- raise ConnectionError("Bad message %r" % (line))
-
- return line
-
+ async def invoke(self, msg):
async def proc():
- for c in chunkify(json.dumps(msg), self.max_chunk):
- self.writer.write(c.encode("utf-8"))
- await self.writer.drain()
-
- l = await get_line()
-
- m = json.loads(l)
- if m and "chunk-stream" in m:
- lines = []
- while True:
- l = (await get_line()).rstrip("\n")
- if not l:
- break
- lines.append(l)
-
- m = json.loads("".join(lines))
-
- return m
+ await self.socket.send_message(msg)
+ return await self.socket.recv_message()
return await self._send_wrapper(proc)
async def ping(self):
- return await self.send_message(
- {'ping': {}}
- )
+ return await self.invoke({"ping": {}})
class Client(object):
@@ -142,7 +110,7 @@ class Client(object):
# required (but harmless) with it.
asyncio.set_event_loop(self.loop)
- self._add_methods('connect_tcp', 'ping')
+ self._add_methods("connect_tcp", "ping")
@abc.abstractmethod
def _get_async_client(self):
diff --git a/lib/bb/asyncrpc/connection.py b/lib/bb/asyncrpc/connection.py
new file mode 100644
index 00000000..c4fd2475
--- /dev/null
+++ b/lib/bb/asyncrpc/connection.py
@@ -0,0 +1,95 @@
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+import asyncio
+import itertools
+import json
+from .exceptions import ClientError, ConnectionClosedError
+
+
+# The Python async server defaults to a 64K receive buffer, so we hardcode our
+# maximum chunk size. It would be better if the client and server reported to
+# each other what the maximum chunk sizes were, but that will slow down the
+# connection setup with a round trip delay so I'd rather not do that unless it
+# is necessary
+DEFAULT_MAX_CHUNK = 32 * 1024
+
+
+def chunkify(msg, max_chunk):
+ if len(msg) < max_chunk - 1:
+ yield "".join((msg, "\n"))
+ else:
+ yield "".join((json.dumps({"chunk-stream": None}), "\n"))
+
+ args = [iter(msg)] * (max_chunk - 1)
+ for m in map("".join, itertools.zip_longest(*args, fillvalue="")):
+ yield "".join(itertools.chain(m, "\n"))
+ yield "\n"
+
+
+class StreamConnection(object):
+ def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
+ self.reader = reader
+ self.writer = writer
+ self.timeout = timeout
+ self.max_chunk = max_chunk
+
+ @property
+ def address(self):
+ return self.writer.get_extra_info("peername")
+
+ async def send_message(self, msg):
+ for c in chunkify(json.dumps(msg), self.max_chunk):
+ self.writer.write(c.encode("utf-8"))
+ await self.writer.drain()
+
+ async def recv_message(self):
+ l = await self.recv()
+
+ m = json.loads(l)
+ if not m:
+ return m
+
+ if "chunk-stream" in m:
+ lines = []
+ while True:
+ l = await self.recv()
+ if not l:
+ break
+ lines.append(l)
+
+ m = json.loads("".join(lines))
+
+ return m
+
+ async def send(self, msg):
+ self.writer.write(("%s\n" % msg).encode("utf-8"))
+ await self.writer.drain()
+
+ async def recv(self):
+ if self.timeout < 0:
+ line = await self.reader.readline()
+ else:
+ try:
+ line = await asyncio.wait_for(self.reader.readline(), self.timeout)
+ except asyncio.TimeoutError:
+ raise ConnectionError("Timed out waiting for data")
+
+ if not line:
+ raise ConnectionClosedError("Connection closed")
+
+ line = line.decode("utf-8")
+
+ if not line.endswith("\n"):
+ raise ConnectionError("Bad message %r" % (line))
+
+ return line.rstrip()
+
+ async def close(self):
+ self.reader = None
+ if self.writer is not None:
+ self.writer.close()
+ self.writer = None
diff --git a/lib/bb/asyncrpc/exceptions.py b/lib/bb/asyncrpc/exceptions.py
new file mode 100644
index 00000000..a8942b4f
--- /dev/null
+++ b/lib/bb/asyncrpc/exceptions.py
@@ -0,0 +1,17 @@
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+
+class ClientError(Exception):
+ pass
+
+
+class ServerError(Exception):
+ pass
+
+
+class ConnectionClosedError(Exception):
+ pass
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index d2de4891..8d4da1e2 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -12,241 +12,242 @@ import signal
import socket
import sys
import multiprocessing
-from . import chunkify, DEFAULT_MAX_CHUNK
-
-
-class ClientError(Exception):
- pass
-
-
-class ServerError(Exception):
- pass
+from .connection import StreamConnection
+from .exceptions import ClientError, ServerError, ConnectionClosedError
class AsyncServerConnection(object):
- def __init__(self, reader, writer, proto_name, logger):
- self.reader = reader
- self.writer = writer
+ def __init__(self, socket, proto_name, logger):
+ self.socket = socket
self.proto_name = proto_name
- self.max_chunk = DEFAULT_MAX_CHUNK
self.handlers = {
- 'chunk-stream': self.handle_chunk,
- 'ping': self.handle_ping,
+ "ping": self.handle_ping,
}
self.logger = logger
+ async def close(self):
+ await self.socket.close()
+
async def process_requests(self):
try:
- self.addr = self.writer.get_extra_info('peername')
- self.logger.debug('Client %r connected' % (self.addr,))
+ self.logger.info("Client %r connected" % (self.socket.address,))
# Read protocol and version
- client_protocol = await self.reader.readline()
+ client_protocol = await self.socket.recv()
if not client_protocol:
return
- (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
+ (client_proto_name, client_proto_version) = client_protocol.split()
if client_proto_name != self.proto_name:
- self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name))
+ self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
return
- self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
+ self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
if not self.validate_proto_version():
- self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
+ self.logger.debug(
+ "Rejecting invalid protocol version %s" % (client_proto_version)
+ )
return
# Read headers. Currently, no headers are implemented, so look for
# an empty line to signal the end of the headers
while True:
- line = await self.reader.readline()
- if not line:
- return
-
- line = line.decode('utf-8').rstrip()
- if not line:
+ header = await self.socket.recv()
+ if not header:
break
# Handle messages
while True:
- d = await self.read_message()
+ d = await self.socket.recv_message()
if d is None:
break
- await self.dispatch_message(d)
- await self.writer.drain()
- except ClientError as e:
+ response = await self.dispatch_message(d)
+ await self.socket.send_message(response)
+ except ConnectionClosedError as e:
+ self.logger.info(str(e))
+ except (ClientError, ConnectionError) as e:
self.logger.error(str(e))
finally:
- self.writer.close()
+ await self.close()
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
- self.logger.debug('Handling %s' % k)
- await self.handlers[k](msg[k])
- return
+ self.logger.debug("Handling %s" % k)
+ return await self.handlers[k](msg[k])
raise ClientError("Unrecognized command %r" % msg)
- def write_message(self, msg):
- for c in chunkify(json.dumps(msg), self.max_chunk):
- self.writer.write(c.encode('utf-8'))
+ async def handle_ping(self, request):
+ return {"alive": True}
- async def read_message(self):
- l = await self.reader.readline()
- if not l:
- return None
- try:
- message = l.decode('utf-8')
+class StreamServer(object):
+ def __init__(self, handler, logger):
+ self.handler = handler
+ self.logger = logger
+ self.closed = False
- if not message.endswith('\n'):
- return None
+ async def handle_stream_client(self, reader, writer):
+ # writer.transport.set_write_buffer_limits(0)
+ socket = StreamConnection(reader, writer, -1)
+ if self.closed:
+ await socket.close()
+ return
+
+ await self.handler(socket)
+
+ async def stop(self):
+ self.closed = True
+
+
+class TCPStreamServer(StreamServer):
+ def __init__(self, host, port, handler, logger):
+ super().__init__(handler, logger)
+ self.host = host
+ self.port = port
+
+ def start(self, loop):
+ self.server = loop.run_until_complete(
+ asyncio.start_server(self.handle_stream_client, self.host, self.port)
+ )
+
+ for s in self.server.sockets:
+ self.logger.debug("Listening on %r" % (s.getsockname(),))
+ # Newer python does this automatically. Do it manually here for
+ # maximum compatibility
+ s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
+ s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
+
+ # Enable keep alives. This prevents broken client connections
+ # from persisting on the server for long periods of time.
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
+
+ name = self.server.sockets[0].getsockname()
+ if self.server.sockets[0].family == socket.AF_INET6:
+ self.address = "[%s]:%d" % (name[0], name[1])
+ else:
+ self.address = "%s:%d" % (name[0], name[1])
+
+ return [self.server.wait_closed()]
+
+ async def stop(self):
+ await super().stop()
+ self.server.close()
+
+ def cleanup(self):
+ pass
- return json.loads(message)
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- self.logger.error('Bad message from client: %r' % message)
- raise e
- async def handle_chunk(self, request):
- lines = []
- try:
- while True:
- l = await self.reader.readline()
- l = l.rstrip(b"\n").decode("utf-8")
- if not l:
- break
- lines.append(l)
+class UnixStreamServer(StreamServer):
+ def __init__(self, path, handler, logger):
+ super().__init__(handler, logger)
+ self.path = path
- msg = json.loads(''.join(lines))
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- self.logger.error('Bad message from client: %r' % lines)
- raise e
+ def start(self, loop):
+ cwd = os.getcwd()
+ try:
+ # Work around path length limits in AF_UNIX
+ os.chdir(os.path.dirname(self.path))
+ self.server = loop.run_until_complete(
+ asyncio.start_unix_server(
+ self.handle_stream_client, os.path.basename(self.path)
+ )
+ )
+ finally:
+ os.chdir(cwd)
- if 'chunk-stream' in msg:
- raise ClientError("Nested chunks are not allowed")
+ self.logger.debug("Listening on %r" % self.path)
+ self.address = "unix://%s" % os.path.abspath(self.path)
+ return [self.server.wait_closed()]
- await self.dispatch_message(msg)
+ async def stop(self):
+ await super().stop()
+ self.server.close()
- async def handle_ping(self, request):
- response = {'alive': True}
- self.write_message(response)
+ def cleanup(self):
+ os.unlink(self.path)
class AsyncServer(object):
def __init__(self, logger):
- self._cleanup_socket = None
self.logger = logger
- self.start = None
- self.address = None
self.loop = None
+ self.run_tasks = []
def start_tcp_server(self, host, port):
- def start_tcp():
- self.server = self.loop.run_until_complete(
- asyncio.start_server(self.handle_client, host, port)
- )
-
- for s in self.server.sockets:
- self.logger.debug('Listening on %r' % (s.getsockname(),))
- # Newer python does this automatically. Do it manually here for
- # maximum compatibility
- s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
- s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
-
- # Enable keep alives. This prevents broken client connections
- # from persisting on the server for long periods of time.
- s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
-
- name = self.server.sockets[0].getsockname()
- if self.server.sockets[0].family == socket.AF_INET6:
- self.address = "[%s]:%d" % (name[0], name[1])
- else:
- self.address = "%s:%d" % (name[0], name[1])
-
- self.start = start_tcp
+ self.server = TCPStreamServer(host, port, self._client_handler, self.logger)
def start_unix_server(self, path):
- def cleanup():
- os.unlink(path)
-
- def start_unix():
- cwd = os.getcwd()
- try:
- # Work around path length limits in AF_UNIX
- os.chdir(os.path.dirname(path))
- self.server = self.loop.run_until_complete(
- asyncio.start_unix_server(self.handle_client, os.path.basename(path))
- )
- finally:
- os.chdir(cwd)
-
- self.logger.debug('Listening on %r' % path)
-
- self._cleanup_socket = cleanup
- self.address = "unix://%s" % os.path.abspath(path)
-
- self.start = start_unix
-
- @abc.abstractmethod
- def accept_client(self, reader, writer):
- pass
+ self.server = UnixStreamServer(path, self._client_handler, self.logger)
- async def handle_client(self, reader, writer):
- # writer.transport.set_write_buffer_limits(0)
+ async def _client_handler(self, socket):
try:
- client = self.accept_client(reader, writer)
+ client = self.accept_client(socket)
await client.process_requests()
except Exception as e:
import traceback
- self.logger.error('Error from client: %s' % str(e), exc_info=True)
+
+ self.logger.error("Error from client: %s" % str(e), exc_info=True)
traceback.print_exc()
- writer.close()
- self.logger.debug('Client disconnected')
+ await socket.close()
+ self.logger.debug("Client disconnected")
- def run_loop_forever(self):
- try:
- self.loop.run_forever()
- except KeyboardInterrupt:
- pass
+ @abc.abstractmethod
+ def accept_client(self, socket):
+ pass
+
+ async def stop(self):
+ self.logger.debug("Stopping server")
+ await self.server.stop()
+
+ def start(self):
+ tasks = self.server.start(self.loop)
+ self.address = self.server.address
+ return tasks
def signal_handler(self):
self.logger.debug("Got exit signal")
- self.loop.stop()
+ self.loop.create_task(self.stop())
- def _serve_forever(self):
+ def _serve_forever(self, tasks):
try:
self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
+ self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
+ self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
- self.run_loop_forever()
- self.server.close()
+ self.loop.run_until_complete(asyncio.gather(*tasks))
- self.loop.run_until_complete(self.server.wait_closed())
- self.logger.debug('Server shutting down')
+ self.logger.debug("Server shutting down")
finally:
- if self._cleanup_socket is not None:
- self._cleanup_socket()
+ self.server.cleanup()
def serve_forever(self):
"""
Serve requests in the current process
"""
+ self._create_loop()
+ tasks = self.start()
+ self._serve_forever(tasks)
+ self.loop.close()
+
+ def _create_loop(self):
# Create loop and override any loop that may have existed in
# a parent process. It is possible that the usecases of
# serve_forever might be constrained enough to allow using
# get_event_loop here, but better safe than sorry for now.
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
- self.start()
- self._serve_forever()
def serve_as_process(self, *, prefunc=None, args=()):
"""
Serve requests in a child process
"""
+
def run(queue):
# Create loop and override any loop that may have existed
# in a parent process. Without doing this and instead
@@ -259,18 +260,19 @@ class AsyncServer(object):
# more general, though, as any potential use of asyncio in
# Cooker could create a loop that needs to replaced in this
# new process.
- self.loop = asyncio.new_event_loop()
- asyncio.set_event_loop(self.loop)
+ self._create_loop()
try:
- self.start()
+ self.address = None
+ tasks = self.start()
finally:
+ # Always put the server address to wake up the parent task
queue.put(self.address)
queue.close()
if prefunc is not None:
prefunc(self, *args)
- self._serve_forever()
+ self._serve_forever(tasks)
if sys.version_info >= (3, 6):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 9cb3fd57..3a401835 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -15,13 +15,6 @@ UNIX_PREFIX = "unix://"
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
-# The Python async server defaults to a 64K receive buffer, so we hardcode our
-# maximum chunk size. It would be better if the client and server reported to
-# each other what the maximum chunk sizes were, but that will slow down the
-# connection setup with a round trip delay so I'd rather not do that unless it
-# is necessary
-DEFAULT_MAX_CHUNK = 32 * 1024
-
UNIHASH_TABLE_DEFINITION = (
("method", "TEXT NOT NULL", "UNIQUE"),
("taskhash", "TEXT NOT NULL", "UNIQUE"),
@@ -102,20 +95,6 @@ def parse_address(addr):
return (ADDR_TYPE_TCP, (host, int(port)))
-def chunkify(msg, max_chunk):
- if len(msg) < max_chunk - 1:
- yield ''.join((msg, "\n"))
- else:
- yield ''.join((json.dumps({
- 'chunk-stream': None
- }), "\n"))
-
- args = [iter(msg)] * (max_chunk - 1)
- for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
- yield ''.join(itertools.chain(m, "\n"))
- yield "\n"
-
-
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
from . import server
db = setup_database(dbname, sync=sync)
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index f676d267..ebf1c361 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -28,24 +28,24 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
async def send_stream(self, msg):
async def proc():
- self.writer.write(("%s\n" % msg).encode("utf-8"))
- await self.writer.drain()
- l = await self.reader.readline()
- if not l:
- raise ConnectionError("Connection closed")
- return l.decode("utf-8").rstrip()
+ await self.socket.send(msg)
+ return await self.socket.recv()
return await self._send_wrapper(proc)
async def _set_mode(self, new_mode):
+ async def stream_to_normal():
+ await self.socket.send("END")
+ return await self.socket.recv_message()
+
if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
- r = await self.send_stream("END")
+ r = await self._send_wrapper(stream_to_normal)
if r != "ok":
- raise ConnectionError("Bad response from server %r" % r)
+ raise ConnectionError("Unable to transition to normal mode: Bad response from server %r" % r)
elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
- r = await self.send_message({"get-stream": None})
+ r = await self.invoke({"get-stream": None})
if r != "ok":
- raise ConnectionError("Bad response from server %r" % r)
+ raise ConnectionError("Unable to transition to stream mode: Bad response from server %r" % r)
elif new_mode != self.mode:
raise Exception(
"Undefined mode transition %r -> %r" % (self.mode, new_mode)
@@ -67,7 +67,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
m["method"] = method
m["outhash"] = outhash
m["unihash"] = unihash
- return await self.send_message({"report": m})
+ return await self.invoke({"report": m})
async def report_unihash_equiv(self, taskhash, method, unihash, extra={}):
await self._set_mode(self.MODE_NORMAL)
@@ -75,39 +75,39 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
m["taskhash"] = taskhash
m["method"] = method
m["unihash"] = unihash
- return await self.send_message({"report-equiv": m})
+ return await self.invoke({"report-equiv": m})
async def get_taskhash(self, method, taskhash, all_properties=False):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message(
+ return await self.invoke(
{"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
)
async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message(
+ return await self.invoke(
{"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method, "with_unihash": with_unihash}}
)
async def get_stats(self):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"get-stats": None})
+ return await self.invoke({"get-stats": None})
async def reset_stats(self):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"reset-stats": None})
+ return await self.invoke({"reset-stats": None})
async def backfill_wait(self):
await self._set_mode(self.MODE_NORMAL)
- return (await self.send_message({"backfill-wait": None}))["tasks"]
+ return (await self.invoke({"backfill-wait": None}))["tasks"]
async def remove(self, where):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"remove": {"where": where}})
+ return await self.invoke({"remove": {"where": where}})
async def clean_unused(self, max_age):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"clean-unused": {"max_age_seconds": max_age}})
+ return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})
class Client(bb.asyncrpc.Client):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 45bf476b..6d3a4751 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -165,8 +165,8 @@ class ServerCursor(object):
class ServerClient(bb.asyncrpc.AsyncServerConnection):
- def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
- super().__init__(reader, writer, 'OEHASHEQUIV', logger)
+ def __init__(self, socket, db, request_stats, backfill_queue, upstream, read_only):
+ super().__init__(socket, 'OEHASHEQUIV', logger)
self.db = db
self.request_stats = request_stats
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
@@ -209,12 +209,11 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if k in msg:
logger.debug('Handling %s' % k)
if 'stream' in k:
- await self.handlers[k](msg[k])
+ return await self.handlers[k](msg[k])
else:
with self.request_stats.start_sample() as self.request_sample, \
self.request_sample.measure():
- await self.handlers[k](msg[k])
- return
+ return await self.handlers[k](msg[k])
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
@@ -224,9 +223,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
fetch_all = request.get('all', False)
with closing(self.db.cursor()) as cursor:
- d = await self.get_unihash(cursor, method, taskhash, fetch_all)
-
- self.write_message(d)
+ return await self.get_unihash(cursor, method, taskhash, fetch_all)
async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
d = None
@@ -274,9 +271,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
with_unihash = request.get("with_unihash", True)
with closing(self.db.cursor()) as cursor:
- d = await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
-
- self.write_message(d)
+ return await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
async def get_outhash(self, cursor, method, outhash, taskhash, with_unihash=True):
d = None
@@ -334,14 +329,14 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
)
async def handle_get_stream(self, request):
- self.write_message('ok')
+ await self.socket.send_message("ok")
while True:
upstream = None
- l = await self.reader.readline()
+ l = await self.socket.recv()
if not l:
- return
+ break
try:
# This inner loop is very sensitive and must be as fast as
@@ -352,10 +347,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
request_measure = self.request_sample.measure()
request_measure.start()
- l = l.decode('utf-8').rstrip()
if l == 'END':
- self.writer.write('ok\n'.encode('utf-8'))
- return
+ break
(method, taskhash) = l.split()
#logger.debug('Looking up %s %s' % (method, taskhash))
@@ -366,29 +359,29 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
cursor.close()
if row is not None:
- msg = ('%s\n' % row['unihash']).encode('utf-8')
+ msg = row['unihash']
#logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
elif self.upstream_client is not None:
upstream = await self.upstream_client.get_unihash(method, taskhash)
if upstream:
- msg = ("%s\n" % upstream).encode("utf-8")
+ msg = upstream
else:
- msg = "\n".encode("utf-8")
+ msg = ""
else:
- msg = '\n'.encode('utf-8')
+ msg = ""
- self.writer.write(msg)
+ await self.socket.send(msg)
finally:
request_measure.end()
self.request_sample.end()
- await self.writer.drain()
-
# Post to the backfill queue after writing the result to minimize
# the turn around time on a request
if upstream is not None:
await self.backfill_queue.put((method, taskhash))
+ return "ok"
+
async def handle_report(self, data):
with closing(self.db.cursor()) as cursor:
outhash_data = {
@@ -468,7 +461,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
'unihash': unihash,
}
- self.write_message(d)
+ return d
async def handle_equivreport(self, data):
with closing(self.db.cursor()) as cursor:
@@ -491,30 +484,28 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
- self.write_message(d)
+ return d
async def handle_get_stats(self, request):
- d = {
+ return {
'requests': self.request_stats.todict(),
}
- self.write_message(d)
-
async def handle_reset_stats(self, request):
d = {
'requests': self.request_stats.todict(),
}
self.request_stats.reset()
- self.write_message(d)
+ return d
async def handle_backfill_wait(self, request):
d = {
'tasks': self.backfill_queue.qsize(),
}
await self.backfill_queue.join()
- self.write_message(d)
+ return d
async def handle_remove(self, request):
condition = request["where"]
@@ -541,7 +532,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
self.db.commit()
- self.write_message({"count": count})
+ return {"count": count}
async def handle_clean_unused(self, request):
max_age = request["max_age_seconds"]
@@ -558,7 +549,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
)
count = cursor.rowcount
- self.write_message({"count": count})
+ return {"count": count}
def query_equivalent(self, cursor, method, taskhash):
# This is part of the inner loop and must be as fast as possible
@@ -583,41 +574,33 @@ class Server(bb.asyncrpc.AsyncServer):
self.db = db
self.upstream = upstream
self.read_only = read_only
+ self.backfill_queue = None
- def accept_client(self, reader, writer):
- return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
+ def accept_client(self, socket):
+ return ServerClient(socket, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
- @contextmanager
- def _backfill_worker(self):
- async def backfill_worker_task():
- client = await create_async_client(self.upstream)
- try:
- while True:
- item = await self.backfill_queue.get()
- if item is None:
- self.backfill_queue.task_done()
- break
- method, taskhash = item
- await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ async def backfill_worker_task(self):
+ client = await create_async_client(self.upstream)
+ try:
+ while True:
+ item = await self.backfill_queue.get()
+ if item is None:
self.backfill_queue.task_done()
- finally:
- await client.close()
+ break
+ method, taskhash = item
+ await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ self.backfill_queue.task_done()
+ finally:
+ await client.close()
- async def join_worker(worker):
+ def start(self):
+ tasks = super().start()
+ if self.upstream:
+ self.backfill_queue = asyncio.Queue()
+ tasks += [self.backfill_worker_task()]
+ return tasks
+
+ async def stop(self):
+ if self.backfill_queue is not None:
await self.backfill_queue.put(None)
- await worker
-
- if self.upstream is not None:
- worker = asyncio.ensure_future(backfill_worker_task())
- try:
- yield
- finally:
- self.loop.run_until_complete(join_worker(worker))
- else:
- yield
-
- def run_loop_forever(self):
- self.backfill_queue = asyncio.Queue()
-
- with self._backfill_worker():
- super().run_loop_forever()
+ await super().stop()
diff --git a/lib/prserv/client.py b/lib/prserv/client.py
index 69ab7a4a..6b81356f 100644
--- a/lib/prserv/client.py
+++ b/lib/prserv/client.py
@@ -14,28 +14,28 @@ class PRAsyncClient(bb.asyncrpc.AsyncClient):
super().__init__('PRSERVICE', '1.0', logger)
async def getPR(self, version, pkgarch, checksum):
- response = await self.send_message(
+ response = await self.invoke(
{'get-pr': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum}}
)
if response:
return response['value']
async def importone(self, version, pkgarch, checksum, value):
- response = await self.send_message(
+ response = await self.invoke(
{'import-one': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum, 'value': value}}
)
if response:
return response['value']
async def export(self, version, pkgarch, checksum, colinfo):
- response = await self.send_message(
+ response = await self.invoke(
{'export': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum, 'colinfo': colinfo}}
)
if response:
return (response['metainfo'], response['datainfo'])
async def is_readonly(self):
- response = await self.send_message(
+ response = await self.invoke(
{'is-readonly': {}}
)
if response:
diff --git a/lib/prserv/serv.py b/lib/prserv/serv.py
index c686b206..ea793316 100644
--- a/lib/prserv/serv.py
+++ b/lib/prserv/serv.py
@@ -20,8 +20,8 @@ PIDPREFIX = "/tmp/PRServer_%s_%s.pid"
singleton = None
class PRServerClient(bb.asyncrpc.AsyncServerConnection):
- def __init__(self, reader, writer, table, read_only):
- super().__init__(reader, writer, 'PRSERVICE', logger)
+ def __init__(self, socket, table, read_only):
+ super().__init__(socket, 'PRSERVICE', logger)
self.handlers.update({
'get-pr': self.handle_get_pr,
'import-one': self.handle_import_one,
@@ -36,12 +36,12 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
async def dispatch_message(self, msg):
try:
- await super().dispatch_message(msg)
+ return await super().dispatch_message(msg)
except:
self.table.sync()
raise
-
- self.table.sync_if_dirty()
+ else:
+ self.table.sync_if_dirty()
async def handle_get_pr(self, request):
version = request['version']
@@ -57,7 +57,7 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
except sqlite3.Error as exc:
logger.error(str(exc))
- self.write_message(response)
+ return response
async def handle_import_one(self, request):
response = None
@@ -71,7 +71,7 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
if value is not None:
response = {'value': value}
- self.write_message(response)
+ return response
async def handle_export(self, request):
version = request['version']
@@ -85,12 +85,10 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
logger.error(str(exc))
metainfo = datainfo = None
- response = {'metainfo': metainfo, 'datainfo': datainfo}
- self.write_message(response)
+ return {'metainfo': metainfo, 'datainfo': datainfo}
async def handle_is_readonly(self, request):
- response = {'readonly': self.read_only}
- self.write_message(response)
+ return {'readonly': self.read_only}
class PRServer(bb.asyncrpc.AsyncServer):
def __init__(self, dbfile, read_only=False):
@@ -99,20 +97,23 @@ class PRServer(bb.asyncrpc.AsyncServer):
self.table = None
self.read_only = read_only
- def accept_client(self, reader, writer):
- return PRServerClient(reader, writer, self.table, self.read_only)
+ def accept_client(self, socket):
+ return PRServerClient(socket, self.table, self.read_only)
- def _serve_forever(self):
+ def start(self):
+ tasks = super().start()
self.db = prserv.db.PRData(self.dbfile, read_only=self.read_only)
self.table = self.db["PRMAIN"]
logger.info("Started PRServer with DBfile: %s, Address: %s, PID: %s" %
(self.dbfile, self.address, str(os.getpid())))
- super()._serve_forever()
+ return tasks
+ async def stop(self):
self.table.sync_if_dirty()
self.db.disconnect()
+ await super().stop()
def signal_handler(self):
super().signal_handler()
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 02/22] hashserv: Add websocket connection implementation
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 01/22] asyncrpc: Abstract sockets Joshua Watt
@ 2023-11-01 15:41 ` Joshua Watt
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 03/22] asyncrpc: Add context manager API Joshua Watt
` (19 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:41 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds support to the hash equivalence client and server to communicate
over websockets. Since websockets are message orientated instead of
stream orientated, a new connection class is needed to handle them.
Note that websocket support does require the 3rd party websockets python
module be installed on the host, but it should not be required unless
websockets are actually being used.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/client.py | 11 +++++++-
lib/bb/asyncrpc/connection.py | 44 +++++++++++++++++++++++++++++
lib/bb/asyncrpc/serv.py | 53 ++++++++++++++++++++++++++++++++++-
lib/hashserv/__init__.py | 13 +++++++++
lib/hashserv/client.py | 1 +
lib/hashserv/tests.py | 17 +++++++++++
6 files changed, 137 insertions(+), 2 deletions(-)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index 7f33099b..802c07df 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -10,7 +10,7 @@ import json
import os
import socket
import sys
-from .connection import StreamConnection, DEFAULT_MAX_CHUNK
+from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
from .exceptions import ConnectionClosedError
@@ -47,6 +47,15 @@ class AsyncClient(object):
self._connect_sock = connect_sock
+ async def connect_websocket(self, uri):
+ import websockets
+
+ async def connect_sock():
+ websocket = await websockets.connect(uri, ping_interval=None)
+ return WebsocketConnection(websocket, self.timeout)
+
+ self._connect_sock = connect_sock
+
async def setup_connection(self):
# Send headers
await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
diff --git a/lib/bb/asyncrpc/connection.py b/lib/bb/asyncrpc/connection.py
index c4fd2475..a10628f7 100644
--- a/lib/bb/asyncrpc/connection.py
+++ b/lib/bb/asyncrpc/connection.py
@@ -93,3 +93,47 @@ class StreamConnection(object):
if self.writer is not None:
self.writer.close()
self.writer = None
+
+
+class WebsocketConnection(object):
+ def __init__(self, socket, timeout):
+ self.socket = socket
+ self.timeout = timeout
+
+ @property
+ def address(self):
+ return ":".join(str(s) for s in self.socket.remote_address)
+
+ async def send_message(self, msg):
+ await self.send(json.dumps(msg))
+
+ async def recv_message(self):
+ m = await self.recv()
+ return json.loads(m)
+
+ async def send(self, msg):
+ import websockets.exceptions
+
+ try:
+ await self.socket.send(msg)
+ except websockets.exceptions.ConnectionClosed:
+ raise ConnectionClosedError("Connection closed")
+
+ async def recv(self):
+ import websockets.exceptions
+
+ try:
+ if self.timeout < 0:
+ return await self.socket.recv()
+
+ try:
+ return await asyncio.wait_for(self.socket.recv(), self.timeout)
+ except asyncio.TimeoutError:
+ raise ConnectionError("Timed out waiting for data")
+ except websockets.exceptions.ConnectionClosed:
+ raise ConnectionClosedError("Connection closed")
+
+ async def close(self):
+ if self.socket is not None:
+ await self.socket.close()
+ self.socket = None
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index 8d4da1e2..3040ac91 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -12,7 +12,7 @@ import signal
import socket
import sys
import multiprocessing
-from .connection import StreamConnection
+from .connection import StreamConnection, WebsocketConnection
from .exceptions import ClientError, ServerError, ConnectionClosedError
@@ -172,6 +172,54 @@ class UnixStreamServer(StreamServer):
os.unlink(self.path)
+class WebsocketsServer(object):
+ def __init__(self, host, port, handler, logger):
+ self.host = host
+ self.port = port
+ self.handler = handler
+ self.logger = logger
+
+ def start(self, loop):
+ import websockets.server
+
+ self.server = loop.run_until_complete(
+ websockets.server.serve(
+ self.client_handler,
+ self.host,
+ self.port,
+ ping_interval=None,
+ )
+ )
+
+ for s in self.server.sockets:
+ self.logger.debug("Listening on %r" % (s.getsockname(),))
+
+ # Enable keep alives. This prevents broken client connections
+ # from persisting on the server for long periods of time.
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
+
+ name = self.server.sockets[0].getsockname()
+ if self.server.sockets[0].family == socket.AF_INET6:
+ self.address = "ws://[%s]:%d" % (name[0], name[1])
+ else:
+ self.address = "ws://%s:%d" % (name[0], name[1])
+
+ return [self.server.wait_closed()]
+
+ async def stop(self):
+ self.server.close()
+
+ def cleanup(self):
+ pass
+
+ async def client_handler(self, websocket):
+ socket = WebsocketConnection(websocket, -1)
+ await self.handler(socket)
+
+
class AsyncServer(object):
def __init__(self, logger):
self.logger = logger
@@ -184,6 +232,9 @@ class AsyncServer(object):
def start_unix_server(self, path):
self.server = UnixStreamServer(path, self._client_handler, self.logger)
+ def start_websocket_server(self, host, port):
+ self.server = WebsocketsServer(host, port, self._client_handler, self.logger)
+
async def _client_handler(self, socket):
try:
client = self.accept_client(socket)
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 3a401835..56b9c6bc 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -9,11 +9,15 @@ import re
import sqlite3
import itertools
import json
+from urllib.parse import urlparse
UNIX_PREFIX = "unix://"
+WS_PREFIX = "ws://"
+WSS_PREFIX = "wss://"
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
+ADDR_TYPE_WS = 2
UNIHASH_TABLE_DEFINITION = (
("method", "TEXT NOT NULL", "UNIQUE"),
@@ -84,6 +88,8 @@ def setup_database(database, sync=True):
def parse_address(addr):
if addr.startswith(UNIX_PREFIX):
return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
+ elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
+ return (ADDR_TYPE_WS, (addr,))
else:
m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
if m is not None:
@@ -103,6 +109,9 @@ def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
s.start_unix_server(*a)
+ elif typ == ADDR_TYPE_WS:
+ url = urlparse(a[0])
+ s.start_websocket_server(url.hostname, url.port)
else:
s.start_tcp_server(*a)
@@ -116,6 +125,8 @@ def create_client(addr):
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
c.connect_unix(*a)
+ elif typ == ADDR_TYPE_WS:
+ c.connect_websocket(*a)
else:
c.connect_tcp(*a)
@@ -128,6 +139,8 @@ async def create_async_client(addr):
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
await c.connect_unix(*a)
+ elif typ == ADDR_TYPE_WS:
+ await c.connect_websocket(*a)
else:
await c.connect_tcp(*a)
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index ebf1c361..ebb58f33 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -115,6 +115,7 @@ class Client(bb.asyncrpc.Client):
super().__init__()
self._add_methods(
"connect_tcp",
+ "connect_websocket",
"get_unihash",
"report_unihash",
"report_unihash_equiv",
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index f343c586..01ffd52c 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -483,3 +483,20 @@ class TestHashEquivalenceTCPServer(HashEquivalenceTestSetup, HashEquivalenceComm
# If IPv6 is enabled, it should be safe to use localhost directly, in general
# case it is more reliable to resolve the IP address explicitly.
return socket.gethostbyname("localhost") + ":0"
+
+
+class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
+ def setUp(self):
+ try:
+ import websockets
+ except ImportError as e:
+ self.skipTest(str(e))
+
+ super().setUp()
+
+ def get_server_addr(self, server_idx):
+ # Some hosts cause asyncio module to misbehave, when IPv6 is not enabled.
+ # If IPv6 is enabled, it should be safe to use localhost directly, in general
+ # case it is more reliable to resolve the IP address explicitly.
+ host = socket.gethostbyname("localhost")
+ return "ws://%s:0" % host
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 03/22] asyncrpc: Add context manager API
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 01/22] asyncrpc: Abstract sockets Joshua Watt
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 02/22] hashserv: Add websocket connection implementation Joshua Watt
@ 2023-11-01 15:41 ` Joshua Watt
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 04/22] hashserv: tests: Add external database tests Joshua Watt
` (18 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:41 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds a context manager API for the asyncrpc client class which allows
writing code that will automatically close the connection like so:
with hashserv.create_client(address) as client:
...
Rework the bitbake-hashclient tool and PR server to use this new API to
fix warnings about unclosed event loops when exiting
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 36 +++++++++++++++++-------------------
lib/bb/asyncrpc/client.py | 13 +++++++++++++
lib/prserv/serv.py | 6 +++---
3 files changed, 33 insertions(+), 22 deletions(-)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index 3f265e8f..a02a65b9 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -56,25 +56,24 @@ def main():
nonlocal missed_hashes
nonlocal max_time
- client = hashserv.create_client(args.address)
-
- for i in range(args.requests):
- taskhash = hashlib.sha256()
- taskhash.update(args.taskhash_seed.encode('utf-8'))
- taskhash.update(str(i).encode('utf-8'))
+ with hashserv.create_client(args.address) as client:
+ for i in range(args.requests):
+ taskhash = hashlib.sha256()
+ taskhash.update(args.taskhash_seed.encode('utf-8'))
+ taskhash.update(str(i).encode('utf-8'))
- start_time = time.perf_counter()
- l = client.get_unihash(METHOD, taskhash.hexdigest())
- elapsed = time.perf_counter() - start_time
+ start_time = time.perf_counter()
+ l = client.get_unihash(METHOD, taskhash.hexdigest())
+ elapsed = time.perf_counter() - start_time
- with lock:
- if l:
- found_hashes += 1
- else:
- missed_hashes += 1
+ with lock:
+ if l:
+ found_hashes += 1
+ else:
+ missed_hashes += 1
- max_time = max(elapsed, max_time)
- pbar.update()
+ max_time = max(elapsed, max_time)
+ pbar.update()
max_time = 0
found_hashes = 0
@@ -174,9 +173,8 @@ def main():
func = getattr(args, 'func', None)
if func:
- client = hashserv.create_client(args.address)
-
- return func(args, client)
+ with hashserv.create_client(args.address) as client:
+ return func(args, client)
return 0
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index 802c07df..009085c3 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -103,6 +103,12 @@ class AsyncClient(object):
async def ping(self):
return await self.invoke({"ping": {}})
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
class Client(object):
def __init__(self):
@@ -153,3 +159,10 @@ class Client(object):
if sys.version_info >= (3, 6):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+ return False
diff --git a/lib/prserv/serv.py b/lib/prserv/serv.py
index ea793316..6168eb18 100644
--- a/lib/prserv/serv.py
+++ b/lib/prserv/serv.py
@@ -345,9 +345,9 @@ def auto_shutdown():
def ping(host, port):
from . import client
- conn = client.PRClient()
- conn.connect_tcp(host, port)
- return conn.ping()
+ with client.PRClient() as conn:
+ conn.connect_tcp(host, port)
+ return conn.ping()
def connect(host, port):
from . import client
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 04/22] hashserv: tests: Add external database tests
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (2 preceding siblings ...)
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 03/22] asyncrpc: Add context manager API Joshua Watt
@ 2023-11-01 15:41 ` Joshua Watt
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 05/22] asyncrpc: Prefix log messages with client info Joshua Watt
` (17 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:41 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds support for running the hash equivalence test suite against an
external hash equivalence implementation.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/tests.py | 54 +++++++++++++++++++++++++++++++++++--------
1 file changed, 44 insertions(+), 10 deletions(-)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 01ffd52c..4c98a280 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -51,13 +51,20 @@ class HashEquivalenceTestSetup(object):
server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
self.addCleanup(cleanup_server, server)
+ return server
+
+ def start_client(self, server_address):
def cleanup_client(client):
client.close()
- client = create_client(server.address)
+ client = create_client(server_address)
self.addCleanup(cleanup_client, client)
- return (client, server)
+ return client
+
+ def start_test_server(self):
+ server = self.start_server()
+ return server.address
def setUp(self):
if sys.version_info < (3, 5, 0):
@@ -66,7 +73,9 @@ class HashEquivalenceTestSetup(object):
self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv')
self.addCleanup(self.temp_dir.cleanup)
- (self.client, self.server) = self.start_server()
+ self.server_address = self.start_test_server()
+
+ self.client = self.start_client(self.server_address)
def assertClientGetHash(self, client, taskhash, unihash):
result = client.get_unihash(self.METHOD, taskhash)
@@ -206,7 +215,7 @@ class HashEquivalenceCommonTests(object):
def test_stress(self):
def query_server(failures):
- client = Client(self.server.address)
+ client = Client(self.server_address)
try:
for i in range(1000):
taskhash = hashlib.sha256()
@@ -245,8 +254,10 @@ class HashEquivalenceCommonTests(object):
# the side client. It also verifies that the results are pulled into
# the downstream database by checking that the downstream and side servers
# match after the downstream is done waiting for all backfill tasks
- (down_client, down_server) = self.start_server(upstream=self.server.address)
- (side_client, side_server) = self.start_server(dbpath=down_server.dbpath)
+ down_server = self.start_server(upstream=self.server_address)
+ down_client = self.start_client(down_server.address)
+ side_server = self.start_server(dbpath=down_server.dbpath)
+ side_client = self.start_client(side_server.address)
def check_hash(taskhash, unihash, old_sidehash):
nonlocal down_client
@@ -351,14 +362,18 @@ class HashEquivalenceCommonTests(object):
self.assertEqual(result['method'], self.METHOD)
def test_ro_server(self):
- (ro_client, ro_server) = self.start_server(dbpath=self.server.dbpath, read_only=True)
+ rw_server = self.start_server()
+ rw_client = self.start_client(rw_server.address)
+
+ ro_server = self.start_server(dbpath=rw_server.dbpath, read_only=True)
+ ro_client = self.start_client(ro_server.address)
# Report a hash via the read-write server
taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
- result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+ result = rw_client.report_unihash(taskhash, self.METHOD, outhash, unihash)
self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
# Check the hash via the read-only server
@@ -373,7 +388,7 @@ class HashEquivalenceCommonTests(object):
ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
# Ensure that the database was not modified
- self.assertClientGetHash(self.client, taskhash2, None)
+ self.assertClientGetHash(rw_client, taskhash2, None)
def test_slow_server_start(self):
@@ -393,7 +408,7 @@ class HashEquivalenceCommonTests(object):
old_signal = signal.signal(signal.SIGTERM, do_nothing)
self.addCleanup(signal.signal, signal.SIGTERM, old_signal)
- _, server = self.start_server(prefunc=prefunc)
+ server = self.start_server(prefunc=prefunc)
server.process.terminate()
time.sleep(30)
event.set()
@@ -500,3 +515,22 @@ class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalen
# case it is more reliable to resolve the IP address explicitly.
host = socket.gethostbyname("localhost")
return "ws://%s:0" % host
+
+
+class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
+ def start_test_server(self):
+ if 'BB_TEST_HASHSERV' not in os.environ:
+ self.skipTest('BB_TEST_HASHSERV not defined to test an external server')
+
+ return os.environ['BB_TEST_HASHSERV']
+
+ def start_server(self, *args, **kwargs):
+ self.skipTest('Cannot start local server when testing external servers')
+
+ def setUp(self):
+ super().setUp()
+ self.client.remove({"method": self.METHOD})
+
+ def tearDown(self):
+ self.client.remove({"method": self.METHOD})
+ super().tearDown()
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 05/22] asyncrpc: Prefix log messages with client info
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (3 preceding siblings ...)
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 04/22] hashserv: tests: Add external database tests Joshua Watt
@ 2023-11-01 15:41 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 06/22] bitbake-hashserv: Allow arguments from environment Joshua Watt
` (16 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:41 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds a logging adapter to the asyncrpc server's client connections that
prefixes log messages with the client's remote address to aid in debugging.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/serv.py | 21 ++++++++++++++++++---
lib/hashserv/server.py | 10 +++++-----
2 files changed, 23 insertions(+), 8 deletions(-)
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index 3040ac91..a476cacd 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -12,10 +12,16 @@ import signal
import socket
import sys
import multiprocessing
+import logging
from .connection import StreamConnection, WebsocketConnection
from .exceptions import ClientError, ServerError, ConnectionClosedError
+class ClientLoggerAdapter(logging.LoggerAdapter):
+ def process(self, msg, kwargs):
+ return f"[Client {self.extra['address']}] {msg}", kwargs
+
+
class AsyncServerConnection(object):
def __init__(self, socket, proto_name, logger):
self.socket = socket
@@ -23,7 +29,12 @@ class AsyncServerConnection(object):
self.handlers = {
"ping": self.handle_ping,
}
- self.logger = logger
+ self.logger = ClientLoggerAdapter(
+ logger,
+ {
+ "address": socket.address,
+ },
+ )
async def close(self):
await self.socket.close()
@@ -236,16 +247,20 @@ class AsyncServer(object):
self.server = WebsocketsServer(host, port, self._client_handler, self.logger)
async def _client_handler(self, socket):
+ address = socket.address
try:
client = self.accept_client(socket)
await client.process_requests()
except Exception as e:
import traceback
- self.logger.error("Error from client: %s" % str(e), exc_info=True)
+ self.logger.error(
+ "Error from client %s: %s" % (address, str(e)), exc_info=True
+ )
traceback.print_exc()
+ finally:
+ self.logger.debug("Client %s disconnected", address)
await socket.close()
- self.logger.debug("Client disconnected")
@abc.abstractmethod
def accept_client(self, socket):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 6d3a4751..928532c7 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -207,7 +207,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
- logger.debug('Handling %s' % k)
+ self.logger.debug('Handling %s' % k)
if 'stream' in k:
return await self.handlers[k](msg[k])
else:
@@ -351,7 +351,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
break
(method, taskhash) = l.split()
- #logger.debug('Looking up %s %s' % (method, taskhash))
+ #self.logger.debug('Looking up %s %s' % (method, taskhash))
cursor = self.db.cursor()
try:
row = self.query_equivalent(cursor, method, taskhash)
@@ -360,7 +360,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if row is not None:
msg = row['unihash']
- #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
+ #self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
elif self.upstream_client is not None:
upstream = await self.upstream_client.get_unihash(method, taskhash)
if upstream:
@@ -479,8 +479,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
row = self.query_equivalent(cursor, data['method'], data['taskhash'])
if row['unihash'] == data['unihash']:
- logger.info('Adding taskhash equivalence for %s with unihash %s',
- data['taskhash'], row['unihash'])
+ self.logger.info('Adding taskhash equivalence for %s with unihash %s',
+ data['taskhash'], row['unihash'])
d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 06/22] bitbake-hashserv: Allow arguments from environment
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (4 preceding siblings ...)
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 05/22] asyncrpc: Prefix log messages with client info Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 07/22] hashserv: Abstract database Joshua Watt
` (15 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Allows the arguments to the bitbake-hashserv command to be specified in
environment variables. This is a very common idiom when running services
in containers as it allows the arguments to be specified from different
sources as desired by the service administrator.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashserv | 80 +++++++++++++++++++++++++++++++++-----------
1 file changed, 60 insertions(+), 20 deletions(-)
diff --git a/bin/bitbake-hashserv b/bin/bitbake-hashserv
index 00af76b2..a916a90c 100755
--- a/bin/bitbake-hashserv
+++ b/bin/bitbake-hashserv
@@ -11,56 +11,96 @@ import logging
import argparse
import sqlite3
import warnings
+
warnings.simplefilter("default")
-sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib'))
+sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), "lib"))
import hashserv
VERSION = "1.0.0"
-DEFAULT_BIND = 'unix://./hashserve.sock'
+DEFAULT_BIND = "unix://./hashserve.sock"
def main():
- parser = argparse.ArgumentParser(description='Hash Equivalence Reference Server. Version=%s' % VERSION,
- epilog='''The bind address is the path to a unix domain socket if it is
- prefixed with "unix://". Otherwise, it is an IP address
- and port in form ADDRESS:PORT. To bind to all addresses, leave
- the ADDRESS empty, e.g. "--bind :8686". To bind to a specific
- IPv6 address, enclose the address in "[]", e.g.
- "--bind [::1]:8686"'''
- )
-
- parser.add_argument('-b', '--bind', default=DEFAULT_BIND, help='Bind address (default "%(default)s")')
- parser.add_argument('-d', '--database', default='./hashserv.db', help='Database file (default "%(default)s")')
- parser.add_argument('-l', '--log', default='WARNING', help='Set logging level')
- parser.add_argument('-u', '--upstream', help='Upstream hashserv to pull hashes from')
- parser.add_argument('-r', '--read-only', action='store_true', help='Disallow write operations from clients')
+ parser = argparse.ArgumentParser(
+ description="Hash Equivalence Reference Server. Version=%s" % VERSION,
+ formatter_class=argparse.RawTextHelpFormatter,
+ epilog="""
+The bind address may take one of the following formats:
+ unix://PATH - Bind to unix domain socket at PATH
+ ws://ADDRESS:PORT - Bind to websocket on ADDRESS:PORT
+ ADDRESS:PORT - Bind to raw TCP socket on ADDRESS:PORT
+
+To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
+"--bind ws://:8686". To bind to a specific IPv6 address, enclose the address in
+"[]", e.g. "--bind [::1]:8686" or "--bind ws://[::1]:8686"
+ """,
+ )
+
+ parser.add_argument(
+ "-b",
+ "--bind",
+ default=os.environ.get("HASHSERVER_BIND", DEFAULT_BIND),
+ help='Bind address (default $HASHSERVER_BIND, "%(default)s")',
+ )
+ parser.add_argument(
+ "-d",
+ "--database",
+ default=os.environ.get("HASHSERVER_DB", "./hashserv.db"),
+ help='Database file (default $HASHSERVER_DB, "%(default)s")',
+ )
+ parser.add_argument(
+ "-l",
+ "--log",
+ default=os.environ.get("HASHSERVER_LOG_LEVEL", "WARNING"),
+ help='Set logging level (default $HASHSERVER_LOG_LEVEL, "%(default)s")',
+ )
+ parser.add_argument(
+ "-u",
+ "--upstream",
+ default=os.environ.get("HASHSERVER_UPSTREAM", None),
+ help="Upstream hashserv to pull hashes from ($HASHSERVER_UPSTREAM)",
+ )
+ parser.add_argument(
+ "-r",
+ "--read-only",
+ action="store_true",
+ help="Disallow write operations from clients ($HASHSERVER_READ_ONLY)",
+ )
args = parser.parse_args()
- logger = logging.getLogger('hashserv')
+ logger = logging.getLogger("hashserv")
level = getattr(logging, args.log.upper(), None)
if not isinstance(level, int):
- raise ValueError('Invalid log level: %s' % args.log)
+ raise ValueError("Invalid log level: %s" % args.log)
logger.setLevel(level)
console = logging.StreamHandler()
console.setLevel(level)
logger.addHandler(console)
- server = hashserv.create_server(args.bind, args.database, upstream=args.upstream, read_only=args.read_only)
+ read_only = (os.environ.get("HASHSERVER_READ_ONLY", "0") == "1") or args.read_only
+
+ server = hashserv.create_server(
+ args.bind,
+ args.database,
+ upstream=args.upstream,
+ read_only=read_only,
+ )
server.serve_forever()
return 0
-if __name__ == '__main__':
+if __name__ == "__main__":
try:
ret = main()
except Exception:
ret = 1
import traceback
+
traceback.print_exc()
sys.exit(ret)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 07/22] hashserv: Abstract database
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (5 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 06/22] bitbake-hashserv: Allow arguments from environment Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 08/22] hashserv: Add SQLalchemy backend Joshua Watt
` (14 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Abstracts the way the database backend is accessed by the hash
equivalence server to make it possible to use other backends
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/__init__.py | 90 ++-----
lib/hashserv/server.py | 491 +++++++++++++--------------------------
lib/hashserv/sqlite.py | 259 +++++++++++++++++++++
3 files changed, 439 insertions(+), 401 deletions(-)
create mode 100644 lib/hashserv/sqlite.py
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 56b9c6bc..90d8cff1 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -6,7 +6,6 @@
import asyncio
from contextlib import closing
import re
-import sqlite3
import itertools
import json
from urllib.parse import urlparse
@@ -19,92 +18,34 @@ ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2
-UNIHASH_TABLE_DEFINITION = (
- ("method", "TEXT NOT NULL", "UNIQUE"),
- ("taskhash", "TEXT NOT NULL", "UNIQUE"),
- ("unihash", "TEXT NOT NULL", ""),
-)
-
-UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
-
-OUTHASH_TABLE_DEFINITION = (
- ("method", "TEXT NOT NULL", "UNIQUE"),
- ("taskhash", "TEXT NOT NULL", "UNIQUE"),
- ("outhash", "TEXT NOT NULL", "UNIQUE"),
- ("created", "DATETIME", ""),
-
- # Optional fields
- ("owner", "TEXT", ""),
- ("PN", "TEXT", ""),
- ("PV", "TEXT", ""),
- ("PR", "TEXT", ""),
- ("task", "TEXT", ""),
- ("outhash_siginfo", "TEXT", ""),
-)
-
-OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
-
-def _make_table(cursor, name, definition):
- cursor.execute('''
- CREATE TABLE IF NOT EXISTS {name} (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- {fields}
- UNIQUE({unique})
- )
- '''.format(
- name=name,
- fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
- unique=", ".join(name for name, _, flags in definition if "UNIQUE" in flags)
- ))
-
-
-def setup_database(database, sync=True):
- db = sqlite3.connect(database)
- db.row_factory = sqlite3.Row
-
- with closing(db.cursor()) as cursor:
- _make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
- _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
-
- cursor.execute('PRAGMA journal_mode = WAL')
- cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))
-
- # Drop old indexes
- cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
- cursor.execute('DROP INDEX IF EXISTS outhash_lookup')
- cursor.execute('DROP INDEX IF EXISTS taskhash_lookup_v2')
- cursor.execute('DROP INDEX IF EXISTS outhash_lookup_v2')
-
- # TODO: Upgrade from tasks_v2?
- cursor.execute('DROP TABLE IF EXISTS tasks_v2')
-
- # Create new indexes
- cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)')
- cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)')
-
- return db
-
def parse_address(addr):
if addr.startswith(UNIX_PREFIX):
- return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
+ return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
return (ADDR_TYPE_WS, (addr,))
else:
- m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
+ m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
if m is not None:
- host = m.group('host')
- port = m.group('port')
+ host = m.group("host")
+ port = m.group("port")
else:
- host, port = addr.split(':')
+ host, port = addr.split(":")
return (ADDR_TYPE_TCP, (host, int(port)))
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
+ def sqlite_engine():
+ from .sqlite import DatabaseEngine
+
+ return DatabaseEngine(dbname, sync)
+
from . import server
- db = setup_database(dbname, sync=sync)
- s = server.Server(db, upstream=upstream, read_only=read_only)
+
+ db_engine = sqlite_engine()
+
+ s = server.Server(db_engine, upstream=upstream, read_only=read_only)
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
@@ -120,6 +61,7 @@ def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
def create_client(addr):
from . import client
+
c = client.Client()
(typ, a) = parse_address(addr)
@@ -132,8 +74,10 @@ def create_client(addr):
return c
+
async def create_async_client(addr):
from . import client
+
c = client.AsyncClient()
(typ, a) = parse_address(addr)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 928532c7..12255cc2 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -3,18 +3,16 @@
# SPDX-License-Identifier: GPL-2.0-only
#
-from contextlib import closing, contextmanager
from datetime import datetime, timedelta
-import enum
import asyncio
import logging
import math
import time
-from . import create_async_client, UNIHASH_TABLE_COLUMNS, OUTHASH_TABLE_COLUMNS
+from . import create_async_client
import bb.asyncrpc
-logger = logging.getLogger('hashserv.server')
+logger = logging.getLogger("hashserv.server")
class Measurement(object):
@@ -104,229 +102,136 @@ class Stats(object):
return math.sqrt(self.s / (self.num - 1))
def todict(self):
- return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
-
-
-@enum.unique
-class Resolve(enum.Enum):
- FAIL = enum.auto()
- IGNORE = enum.auto()
- REPLACE = enum.auto()
-
-
-def insert_table(cursor, table, data, on_conflict):
- resolve = {
- Resolve.FAIL: "",
- Resolve.IGNORE: " OR IGNORE",
- Resolve.REPLACE: " OR REPLACE",
- }[on_conflict]
-
- keys = sorted(data.keys())
- query = 'INSERT{resolve} INTO {table} ({fields}) VALUES({values})'.format(
- resolve=resolve,
- table=table,
- fields=", ".join(keys),
- values=", ".join(":" + k for k in keys),
- )
- prevrowid = cursor.lastrowid
- cursor.execute(query, data)
- logging.debug(
- "Inserting %r into %s, %s",
- data,
- table,
- on_conflict
- )
- return (cursor.lastrowid, cursor.lastrowid != prevrowid)
-
-def insert_unihash(cursor, data, on_conflict):
- return insert_table(cursor, "unihashes_v2", data, on_conflict)
-
-def insert_outhash(cursor, data, on_conflict):
- return insert_table(cursor, "outhashes_v2", data, on_conflict)
-
-async def copy_unihash_from_upstream(client, db, method, taskhash):
- d = await client.get_taskhash(method, taskhash)
- if d is not None:
- with closing(db.cursor()) as cursor:
- insert_unihash(
- cursor,
- {k: v for k, v in d.items() if k in UNIHASH_TABLE_COLUMNS},
- Resolve.IGNORE,
- )
- db.commit()
- return d
-
-
-class ServerCursor(object):
- def __init__(self, db, cursor, upstream):
- self.db = db
- self.cursor = cursor
- self.upstream = upstream
+ return {
+ k: getattr(self, k)
+ for k in ("num", "total_time", "max_time", "average", "stdev")
+ }
class ServerClient(bb.asyncrpc.AsyncServerConnection):
- def __init__(self, socket, db, request_stats, backfill_queue, upstream, read_only):
- super().__init__(socket, 'OEHASHEQUIV', logger)
- self.db = db
+ def __init__(
+ self,
+ socket,
+ db_engine,
+ request_stats,
+ backfill_queue,
+ upstream,
+ read_only,
+ ):
+ super().__init__(socket, "OEHASHEQUIV", logger)
+ self.db_engine = db_engine
self.request_stats = request_stats
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
self.backfill_queue = backfill_queue
self.upstream = upstream
- self.handlers.update({
- 'get': self.handle_get,
- 'get-outhash': self.handle_get_outhash,
- 'get-stream': self.handle_get_stream,
- 'get-stats': self.handle_get_stats,
- })
+ self.handlers.update(
+ {
+ "get": self.handle_get,
+ "get-outhash": self.handle_get_outhash,
+ "get-stream": self.handle_get_stream,
+ "get-stats": self.handle_get_stats,
+ }
+ )
if not read_only:
- self.handlers.update({
- 'report': self.handle_report,
- 'report-equiv': self.handle_equivreport,
- 'reset-stats': self.handle_reset_stats,
- 'backfill-wait': self.handle_backfill_wait,
- 'remove': self.handle_remove,
- 'clean-unused': self.handle_clean_unused,
- })
+ self.handlers.update(
+ {
+ "report": self.handle_report,
+ "report-equiv": self.handle_equivreport,
+ "reset-stats": self.handle_reset_stats,
+ "backfill-wait": self.handle_backfill_wait,
+ "remove": self.handle_remove,
+ "clean-unused": self.handle_clean_unused,
+ }
+ )
def validate_proto_version(self):
- return (self.proto_version > (1, 0) and self.proto_version <= (1, 1))
+ return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
async def process_requests(self):
- if self.upstream is not None:
- self.upstream_client = await create_async_client(self.upstream)
- else:
- self.upstream_client = None
-
- await super().process_requests()
+ async with self.db_engine.connect(self.logger) as db:
+ self.db = db
+ if self.upstream is not None:
+ self.upstream_client = await create_async_client(self.upstream)
+ else:
+ self.upstream_client = None
- if self.upstream_client is not None:
- await self.upstream_client.close()
+ try:
+ await super().process_requests()
+ finally:
+ if self.upstream_client is not None:
+ await self.upstream_client.close()
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
- self.logger.debug('Handling %s' % k)
- if 'stream' in k:
+ self.logger.debug("Handling %s" % k)
+ if "stream" in k:
return await self.handlers[k](msg[k])
else:
- with self.request_stats.start_sample() as self.request_sample, \
- self.request_sample.measure():
+ with self.request_stats.start_sample() as self.request_sample, self.request_sample.measure():
return await self.handlers[k](msg[k])
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
async def handle_get(self, request):
- method = request['method']
- taskhash = request['taskhash']
- fetch_all = request.get('all', False)
+ method = request["method"]
+ taskhash = request["taskhash"]
+ fetch_all = request.get("all", False)
- with closing(self.db.cursor()) as cursor:
- return await self.get_unihash(cursor, method, taskhash, fetch_all)
+ return await self.get_unihash(method, taskhash, fetch_all)
- async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
+ async def get_unihash(self, method, taskhash, fetch_all=False):
d = None
if fetch_all:
- cursor.execute(
- '''
- SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
- INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
- WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- ''',
- {
- 'method': method,
- 'taskhash': taskhash,
- }
-
- )
- row = cursor.fetchone()
-
+ row = await self.db.get_unihash_by_taskhash_full(method, taskhash)
if row is not None:
d = {k: row[k] for k in row.keys()}
elif self.upstream_client is not None:
d = await self.upstream_client.get_taskhash(method, taskhash, True)
- self.update_unified(cursor, d)
- self.db.commit()
+ await self.update_unified(d)
else:
- row = self.query_equivalent(cursor, method, taskhash)
+ row = await self.db.get_equivalent(method, taskhash)
if row is not None:
d = {k: row[k] for k in row.keys()}
elif self.upstream_client is not None:
d = await self.upstream_client.get_taskhash(method, taskhash)
- d = {k: v for k, v in d.items() if k in UNIHASH_TABLE_COLUMNS}
- insert_unihash(cursor, d, Resolve.IGNORE)
- self.db.commit()
+ await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
return d
async def handle_get_outhash(self, request):
- method = request['method']
- outhash = request['outhash']
- taskhash = request['taskhash']
+ method = request["method"]
+ outhash = request["outhash"]
+ taskhash = request["taskhash"]
with_unihash = request.get("with_unihash", True)
- with closing(self.db.cursor()) as cursor:
- return await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
+ return await self.get_outhash(method, outhash, taskhash, with_unihash)
- async def get_outhash(self, cursor, method, outhash, taskhash, with_unihash=True):
+ async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
d = None
if with_unihash:
- cursor.execute(
- '''
- SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
- INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
- WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- ''',
- {
- 'method': method,
- 'outhash': outhash,
- }
- )
+ row = await self.db.get_unihash_by_outhash(method, outhash)
else:
- cursor.execute(
- """
- SELECT * FROM outhashes_v2
- WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- """,
- {
- 'method': method,
- 'outhash': outhash,
- }
- )
- row = cursor.fetchone()
+ row = await self.db.get_outhash(method, outhash)
if row is not None:
d = {k: row[k] for k in row.keys()}
elif self.upstream_client is not None:
d = await self.upstream_client.get_outhash(method, outhash, taskhash)
- self.update_unified(cursor, d)
- self.db.commit()
+ await self.update_unified(d)
return d
- def update_unified(self, cursor, data):
+ async def update_unified(self, data):
if data is None:
return
- insert_unihash(
- cursor,
- {k: v for k, v in data.items() if k in UNIHASH_TABLE_COLUMNS},
- Resolve.IGNORE
- )
- insert_outhash(
- cursor,
- {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS},
- Resolve.IGNORE
- )
+ await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+ await self.db.insert_outhash(data)
async def handle_get_stream(self, request):
await self.socket.send_message("ok")
@@ -347,20 +252,16 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
request_measure = self.request_sample.measure()
request_measure.start()
- if l == 'END':
+ if l == "END":
break
(method, taskhash) = l.split()
- #self.logger.debug('Looking up %s %s' % (method, taskhash))
- cursor = self.db.cursor()
- try:
- row = self.query_equivalent(cursor, method, taskhash)
- finally:
- cursor.close()
+ # self.logger.debug('Looking up %s %s' % (method, taskhash))
+ row = await self.db.get_equivalent(method, taskhash)
if row is not None:
- msg = row['unihash']
- #self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
+ msg = row["unihash"]
+ # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
elif self.upstream_client is not None:
upstream = await self.upstream_client.get_unihash(method, taskhash)
if upstream:
@@ -383,118 +284,81 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return "ok"
async def handle_report(self, data):
- with closing(self.db.cursor()) as cursor:
- outhash_data = {
- 'method': data['method'],
- 'outhash': data['outhash'],
- 'taskhash': data['taskhash'],
- 'created': datetime.now()
- }
+ outhash_data = {
+ "method": data["method"],
+ "outhash": data["outhash"],
+ "taskhash": data["taskhash"],
+ "created": datetime.now(),
+ }
- for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
- if k in data:
- outhash_data[k] = data[k]
-
- # Insert the new entry, unless it already exists
- (rowid, inserted) = insert_outhash(cursor, outhash_data, Resolve.IGNORE)
-
- if inserted:
- # If this row is new, check if it is equivalent to another
- # output hash
- cursor.execute(
- '''
- SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2
- INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
- -- Select any matching output hash except the one we just inserted
- WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
- -- Pick the oldest hash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- ''',
- {
- 'method': data['method'],
- 'outhash': data['outhash'],
- 'taskhash': data['taskhash'],
- }
- )
- row = cursor.fetchone()
+ for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"):
+ if k in data:
+ outhash_data[k] = data[k]
- if row is not None:
- # A matching output hash was found. Set our taskhash to the
- # same unihash since they are equivalent
- unihash = row['unihash']
- resolve = Resolve.IGNORE
- else:
- # No matching output hash was found. This is probably the
- # first outhash to be added.
- unihash = data['unihash']
- resolve = Resolve.IGNORE
-
- # Query upstream to see if it has a unihash we can use
- if self.upstream_client is not None:
- upstream_data = await self.upstream_client.get_outhash(data['method'], data['outhash'], data['taskhash'])
- if upstream_data is not None:
- unihash = upstream_data['unihash']
-
-
- insert_unihash(
- cursor,
- {
- 'method': data['method'],
- 'taskhash': data['taskhash'],
- 'unihash': unihash,
- },
- resolve
- )
-
- unihash_data = await self.get_unihash(cursor, data['method'], data['taskhash'])
- if unihash_data is not None:
- unihash = unihash_data['unihash']
- else:
- unihash = data['unihash']
-
- self.db.commit()
+ # Insert the new entry, unless it already exists
+ if await self.db.insert_outhash(outhash_data):
+ # If this row is new, check if it is equivalent to another
+ # output hash
+ row = await self.db.get_equivalent_for_outhash(
+ data["method"], data["outhash"], data["taskhash"]
+ )
- d = {
- 'taskhash': data['taskhash'],
- 'method': data['method'],
- 'unihash': unihash,
- }
+ if row is not None:
+ # A matching output hash was found. Set our taskhash to the
+ # same unihash since they are equivalent
+ unihash = row["unihash"]
+ else:
+ # No matching output hash was found. This is probably the
+ # first outhash to be added.
+ unihash = data["unihash"]
+
+ # Query upstream to see if it has a unihash we can use
+ if self.upstream_client is not None:
+ upstream_data = await self.upstream_client.get_outhash(
+ data["method"], data["outhash"], data["taskhash"]
+ )
+ if upstream_data is not None:
+ unihash = upstream_data["unihash"]
+
+ await self.db.insert_unihash(data["method"], data["taskhash"], unihash)
+
+ unihash_data = await self.get_unihash(data["method"], data["taskhash"])
+ if unihash_data is not None:
+ unihash = unihash_data["unihash"]
+ else:
+ unihash = data["unihash"]
- return d
+ return {
+ "taskhash": data["taskhash"],
+ "method": data["method"],
+ "unihash": unihash,
+ }
async def handle_equivreport(self, data):
- with closing(self.db.cursor()) as cursor:
- insert_data = {
- 'method': data['method'],
- 'taskhash': data['taskhash'],
- 'unihash': data['unihash'],
- }
- insert_unihash(cursor, insert_data, Resolve.IGNORE)
- self.db.commit()
-
- # Fetch the unihash that will be reported for the taskhash. If the
- # unihash matches, it means this row was inserted (or the mapping
- # was already valid)
- row = self.query_equivalent(cursor, data['method'], data['taskhash'])
-
- if row['unihash'] == data['unihash']:
- self.logger.info('Adding taskhash equivalence for %s with unihash %s',
- data['taskhash'], row['unihash'])
-
- d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
-
- return d
+ await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+
+ # Fetch the unihash that will be reported for the taskhash. If the
+ # unihash matches, it means this row was inserted (or the mapping
+ # was already valid)
+ row = await self.db.get_equivalent(data["method"], data["taskhash"])
+
+ if row["unihash"] == data["unihash"]:
+ self.logger.info(
+ "Adding taskhash equivalence for %s with unihash %s",
+ data["taskhash"],
+ row["unihash"],
+ )
+ return {k: row[k] for k in ("taskhash", "method", "unihash")}
async def handle_get_stats(self, request):
return {
- 'requests': self.request_stats.todict(),
+ "requests": self.request_stats.todict(),
}
async def handle_reset_stats(self, request):
d = {
- 'requests': self.request_stats.todict(),
+ "requests": self.request_stats.todict(),
}
self.request_stats.reset()
@@ -502,7 +366,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
async def handle_backfill_wait(self, request):
d = {
- 'tasks': self.backfill_queue.qsize(),
+ "tasks": self.backfill_queue.qsize(),
}
await self.backfill_queue.join()
return d
@@ -512,92 +376,63 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if not isinstance(condition, dict):
raise TypeError("Bad condition type %s" % type(condition))
- def do_remove(columns, table_name, cursor):
- nonlocal condition
- where = {}
- for c in columns:
- if c in condition and condition[c] is not None:
- where[c] = condition[c]
-
- if where:
- query = ('DELETE FROM %s WHERE ' % table_name) + ' AND '.join("%s=:%s" % (k, k) for k in where.keys())
- cursor.execute(query, where)
- return cursor.rowcount
-
- return 0
-
- count = 0
- with closing(self.db.cursor()) as cursor:
- count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
- count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
- self.db.commit()
-
- return {"count": count}
+ return {"count": await self.db.remove(condition)}
async def handle_clean_unused(self, request):
max_age = request["max_age_seconds"]
- with closing(self.db.cursor()) as cursor:
- cursor.execute(
- """
- DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
- SELECT unihashes_v2.id FROM unihashes_v2 WHERE unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash LIMIT 1
- )
- """,
- {
- "oldest": datetime.now() - timedelta(seconds=-max_age)
- }
- )
- count = cursor.rowcount
-
- return {"count": count}
-
- def query_equivalent(self, cursor, method, taskhash):
- # This is part of the inner loop and must be as fast as possible
- cursor.execute(
- 'SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash',
- {
- 'method': method,
- 'taskhash': taskhash,
- }
- )
- return cursor.fetchone()
+ oldest = datetime.now() - timedelta(seconds=-max_age)
+ return {"count": await self.db.clean_unused(oldest)}
class Server(bb.asyncrpc.AsyncServer):
- def __init__(self, db, upstream=None, read_only=False):
+ def __init__(self, db_engine, upstream=None, read_only=False):
if upstream and read_only:
- raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server")
+ raise bb.asyncrpc.ServerError(
+ "Read-only hashserv cannot pull from an upstream server"
+ )
super().__init__(logger)
self.request_stats = Stats()
- self.db = db
+ self.db_engine = db_engine
self.upstream = upstream
self.read_only = read_only
self.backfill_queue = None
def accept_client(self, socket):
- return ServerClient(socket, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
+ return ServerClient(
+ socket,
+ self.db_engine,
+ self.request_stats,
+ self.backfill_queue,
+ self.upstream,
+ self.read_only,
+ )
async def backfill_worker_task(self):
- client = await create_async_client(self.upstream)
- try:
+ async with await create_async_client(
+ self.upstream
+ ) as client, self.db_engine.connect(logger) as db:
while True:
item = await self.backfill_queue.get()
if item is None:
self.backfill_queue.task_done()
break
+
method, taskhash = item
- await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ d = await client.get_taskhash(method, taskhash)
+ if d is not None:
+ await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
self.backfill_queue.task_done()
- finally:
- await client.close()
def start(self):
tasks = super().start()
if self.upstream:
self.backfill_queue = asyncio.Queue()
tasks += [self.backfill_worker_task()]
+
+ self.loop.run_until_complete(self.db_engine.create())
+
return tasks
async def stop(self):
diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py
new file mode 100644
index 00000000..6809c537
--- /dev/null
+++ b/lib/hashserv/sqlite.py
@@ -0,0 +1,259 @@
+#! /usr/bin/env python3
+#
+# Copyright (C) 2023 Garmin Ltd.
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+import sqlite3
+import logging
+from contextlib import closing
+
+logger = logging.getLogger("hashserv.sqlite")
+
+UNIHASH_TABLE_DEFINITION = (
+ ("method", "TEXT NOT NULL", "UNIQUE"),
+ ("taskhash", "TEXT NOT NULL", "UNIQUE"),
+ ("unihash", "TEXT NOT NULL", ""),
+)
+
+UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
+
+OUTHASH_TABLE_DEFINITION = (
+ ("method", "TEXT NOT NULL", "UNIQUE"),
+ ("taskhash", "TEXT NOT NULL", "UNIQUE"),
+ ("outhash", "TEXT NOT NULL", "UNIQUE"),
+ ("created", "DATETIME", ""),
+ # Optional fields
+ ("owner", "TEXT", ""),
+ ("PN", "TEXT", ""),
+ ("PV", "TEXT", ""),
+ ("PR", "TEXT", ""),
+ ("task", "TEXT", ""),
+ ("outhash_siginfo", "TEXT", ""),
+)
+
+OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
+
+
+def _make_table(cursor, name, definition):
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS {name} (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ {fields}
+ UNIQUE({unique})
+ )
+ """.format(
+ name=name,
+ fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
+ unique=", ".join(
+ name for name, _, flags in definition if "UNIQUE" in flags
+ ),
+ )
+ )
+
+
+class DatabaseEngine(object):
+ def __init__(self, dbname, sync):
+ self.dbname = dbname
+ self.logger = logger
+ self.sync = sync
+
+ async def create(self):
+ db = sqlite3.connect(self.dbname)
+ db.row_factory = sqlite3.Row
+
+ with closing(db.cursor()) as cursor:
+ _make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
+ _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
+
+ cursor.execute("PRAGMA journal_mode = WAL")
+ cursor.execute(
+ "PRAGMA synchronous = %s" % ("NORMAL" if self.sync else "OFF")
+ )
+
+ # Drop old indexes
+ cursor.execute("DROP INDEX IF EXISTS taskhash_lookup")
+ cursor.execute("DROP INDEX IF EXISTS outhash_lookup")
+ cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v2")
+ cursor.execute("DROP INDEX IF EXISTS outhash_lookup_v2")
+
+ # TODO: Upgrade from tasks_v2?
+ cursor.execute("DROP TABLE IF EXISTS tasks_v2")
+
+ # Create new indexes
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)"
+ )
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)"
+ )
+
+ def connect(self, logger):
+ return Database(logger, self.dbname)
+
+
+class Database(object):
+ def __init__(self, logger, dbname, sync=True):
+ self.dbname = dbname
+ self.logger = logger
+
+ self.db = sqlite3.connect(self.dbname)
+ self.db.row_factory = sqlite3.Row
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
+ async def close(self):
+ self.db.close()
+
+ async def get_unihash_by_taskhash_full(self, method, taskhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
+ INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
+ WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "taskhash": taskhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_unihash_by_outhash(self, method, outhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
+ INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
+ WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "outhash": outhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_outhash(self, method, outhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT * FROM outhashes_v2
+ WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "outhash": outhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_equivalent_for_outhash(self, method, outhash, taskhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2
+ INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
+ -- Select any matching output hash except the one we just inserted
+ WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
+ -- Pick the oldest hash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "outhash": outhash,
+ "taskhash": taskhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_equivalent(self, method, taskhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ "SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash",
+ {
+ "method": method,
+ "taskhash": taskhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def remove(self, condition):
+ def do_remove(columns, table_name, cursor):
+ where = {}
+ for c in columns:
+ if c in condition and condition[c] is not None:
+ where[c] = condition[c]
+
+ if where:
+ query = ("DELETE FROM %s WHERE " % table_name) + " AND ".join(
+ "%s=:%s" % (k, k) for k in where.keys()
+ )
+ cursor.execute(query, where)
+ return cursor.rowcount
+
+ return 0
+
+ count = 0
+ with closing(self.db.cursor()) as cursor:
+ count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
+ count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
+ self.db.commit()
+
+ return count
+
+ async def clean_unused(self, oldest):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
+ SELECT unihashes_v2.id FROM unihashes_v2 WHERE unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash LIMIT 1
+ )
+ """,
+ {
+ "oldest": oldest,
+ },
+ )
+ return cursor.rowcount
+
+ async def insert_unihash(self, method, taskhash, unihash):
+ with closing(self.db.cursor()) as cursor:
+ prevrowid = cursor.lastrowid
+ cursor.execute(
+ """
+ INSERT OR IGNORE INTO unihashes_v2 (method, taskhash, unihash) VALUES(:method, :taskhash, :unihash)
+ """,
+ {
+ "method": method,
+ "taskhash": taskhash,
+ "unihash": unihash,
+ },
+ )
+ self.db.commit()
+ return cursor.lastrowid != prevrowid
+
+ async def insert_outhash(self, data):
+ data = {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS}
+ keys = sorted(data.keys())
+ query = "INSERT OR IGNORE INTO outhashes_v2 ({fields}) VALUES({values})".format(
+ fields=", ".join(keys),
+ values=", ".join(":" + k for k in keys),
+ )
+ with closing(self.db.cursor()) as cursor:
+ prevrowid = cursor.lastrowid
+ cursor.execute(query, data)
+ self.db.commit()
+ return cursor.lastrowid != prevrowid
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v5 08/22] hashserv: Add SQLalchemy backend
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (6 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 07/22] hashserv: Abstract database Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 09/22] hashserv: Implement read-only version of "report" RPC Joshua Watt
` (13 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an SQLAlchemy backend to the server. While this database backend is
slower than the more direct sqlite backend, it easily supports just
about any SQL server, which is useful for large scale deployments.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashserv | 12 ++
lib/bb/asyncrpc/connection.py | 11 +-
lib/hashserv/__init__.py | 21 ++-
lib/hashserv/sqlalchemy.py | 304 ++++++++++++++++++++++++++++++++++
lib/hashserv/tests.py | 19 ++-
5 files changed, 362 insertions(+), 5 deletions(-)
create mode 100644 lib/hashserv/sqlalchemy.py
diff --git a/bin/bitbake-hashserv b/bin/bitbake-hashserv
index a916a90c..59b8b07f 100755
--- a/bin/bitbake-hashserv
+++ b/bin/bitbake-hashserv
@@ -69,6 +69,16 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
action="store_true",
help="Disallow write operations from clients ($HASHSERVER_READ_ONLY)",
)
+ parser.add_argument(
+ "--db-username",
+ default=os.environ.get("HASHSERVER_DB_USERNAME", None),
+ help="Database username ($HASHSERVER_DB_USERNAME)",
+ )
+ parser.add_argument(
+ "--db-password",
+ default=os.environ.get("HASHSERVER_DB_PASSWORD", None),
+ help="Database password ($HASHSERVER_DB_PASSWORD)",
+ )
args = parser.parse_args()
@@ -90,6 +100,8 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
args.database,
upstream=args.upstream,
read_only=read_only,
+ db_username=args.db_username,
+ db_password=args.db_password,
)
server.serve_forever()
return 0
diff --git a/lib/bb/asyncrpc/connection.py b/lib/bb/asyncrpc/connection.py
index a10628f7..7f0cf6ba 100644
--- a/lib/bb/asyncrpc/connection.py
+++ b/lib/bb/asyncrpc/connection.py
@@ -7,6 +7,7 @@
import asyncio
import itertools
import json
+from datetime import datetime
from .exceptions import ClientError, ConnectionClosedError
@@ -30,6 +31,12 @@ def chunkify(msg, max_chunk):
yield "\n"
+def json_serialize(obj):
+ if isinstance(obj, datetime):
+ return obj.isoformat()
+ raise TypeError("Type %s not serializeable" % type(obj))
+
+
class StreamConnection(object):
def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
self.reader = reader
@@ -42,7 +49,7 @@ class StreamConnection(object):
return self.writer.get_extra_info("peername")
async def send_message(self, msg):
- for c in chunkify(json.dumps(msg), self.max_chunk):
+ for c in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
self.writer.write(c.encode("utf-8"))
await self.writer.drain()
@@ -105,7 +112,7 @@ class WebsocketConnection(object):
return ":".join(str(s) for s in self.socket.remote_address)
async def send_message(self, msg):
- await self.send(json.dumps(msg))
+ await self.send(json.dumps(msg, default=json_serialize))
async def recv_message(self):
m = await self.recv()
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 90d8cff1..9a8ee4e8 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -35,15 +35,32 @@ def parse_address(addr):
return (ADDR_TYPE_TCP, (host, int(port)))
-def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
+def create_server(
+ addr,
+ dbname,
+ *,
+ sync=True,
+ upstream=None,
+ read_only=False,
+ db_username=None,
+ db_password=None
+):
def sqlite_engine():
from .sqlite import DatabaseEngine
return DatabaseEngine(dbname, sync)
+ def sqlalchemy_engine():
+ from .sqlalchemy import DatabaseEngine
+
+ return DatabaseEngine(dbname, db_username, db_password)
+
from . import server
- db_engine = sqlite_engine()
+ if "://" in dbname:
+ db_engine = sqlalchemy_engine()
+ else:
+ db_engine = sqlite_engine()
s = server.Server(db_engine, upstream=upstream, read_only=read_only)
diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py
new file mode 100644
index 00000000..3216621f
--- /dev/null
+++ b/lib/hashserv/sqlalchemy.py
@@ -0,0 +1,304 @@
+#! /usr/bin/env python3
+#
+# Copyright (C) 2023 Garmin Ltd.
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+import logging
+from datetime import datetime
+
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.pool import NullPool
+from sqlalchemy import (
+ MetaData,
+ Column,
+ Table,
+ Text,
+ Integer,
+ UniqueConstraint,
+ DateTime,
+ Index,
+ select,
+ insert,
+ exists,
+ literal,
+ and_,
+ delete,
+)
+import sqlalchemy.engine
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.exc import IntegrityError
+
+logger = logging.getLogger("hashserv.sqlalchemy")
+
+Base = declarative_base()
+
+
+class UnihashesV2(Base):
+ __tablename__ = "unihashes_v2"
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ method = Column(Text, nullable=False)
+ taskhash = Column(Text, nullable=False)
+ unihash = Column(Text, nullable=False)
+
+ __table_args__ = (
+ UniqueConstraint("method", "taskhash"),
+ Index("taskhash_lookup_v3", "method", "taskhash"),
+ )
+
+
+class OuthashesV2(Base):
+ __tablename__ = "outhashes_v2"
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ method = Column(Text, nullable=False)
+ taskhash = Column(Text, nullable=False)
+ outhash = Column(Text, nullable=False)
+ created = Column(DateTime)
+ owner = Column(Text)
+ PN = Column(Text)
+ PV = Column(Text)
+ PR = Column(Text)
+ task = Column(Text)
+ outhash_siginfo = Column(Text)
+
+ __table_args__ = (
+ UniqueConstraint("method", "taskhash", "outhash"),
+ Index("outhash_lookup_v3", "method", "outhash"),
+ )
+
+
+class DatabaseEngine(object):
+ def __init__(self, url, username=None, password=None):
+ self.logger = logger
+ self.url = sqlalchemy.engine.make_url(url)
+
+ if username is not None:
+ self.url = self.url.set(username=username)
+
+ if password is not None:
+ self.url = self.url.set(password=password)
+
+ async def create(self):
+ self.logger.info("Using database %s", self.url)
+ self.engine = create_async_engine(self.url, poolclass=NullPool)
+
+ async with self.engine.begin() as conn:
+ # Create tables
+ logger.info("Creating tables...")
+ await conn.run_sync(Base.metadata.create_all)
+
+ def connect(self, logger):
+ return Database(self.engine, logger)
+
+
+def map_row(row):
+ if row is None:
+ return None
+ return dict(**row._mapping)
+
+
+class Database(object):
+ def __init__(self, engine, logger):
+ self.engine = engine
+ self.db = None
+ self.logger = logger
+
+ async def __aenter__(self):
+ self.db = await self.engine.connect()
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
+ async def close(self):
+ await self.db.close()
+ self.db = None
+
+ async def get_unihash_by_taskhash_full(self, method, taskhash):
+ statement = (
+ select(
+ OuthashesV2,
+ UnihashesV2.unihash.label("unihash"),
+ )
+ .join(
+ UnihashesV2,
+ and_(
+ UnihashesV2.method == OuthashesV2.method,
+ UnihashesV2.taskhash == OuthashesV2.taskhash,
+ ),
+ )
+ .where(
+ OuthashesV2.method == method,
+ OuthashesV2.taskhash == taskhash,
+ )
+ .order_by(
+ OuthashesV2.created.asc(),
+ )
+ .limit(1)
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def get_unihash_by_outhash(self, method, outhash):
+ statement = (
+ select(OuthashesV2, UnihashesV2.unihash.label("unihash"))
+ .join(
+ UnihashesV2,
+ and_(
+ UnihashesV2.method == OuthashesV2.method,
+ UnihashesV2.taskhash == OuthashesV2.taskhash,
+ ),
+ )
+ .where(
+ OuthashesV2.method == method,
+ OuthashesV2.outhash == outhash,
+ )
+ .order_by(
+ OuthashesV2.created.asc(),
+ )
+ .limit(1)
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def get_outhash(self, method, outhash):
+ statement = (
+ select(OuthashesV2)
+ .where(
+ OuthashesV2.method == method,
+ OuthashesV2.outhash == outhash,
+ )
+ .order_by(
+ OuthashesV2.created.asc(),
+ )
+ .limit(1)
+ )
+
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def get_equivalent_for_outhash(self, method, outhash, taskhash):
+ statement = (
+ select(
+ OuthashesV2.taskhash.label("taskhash"),
+ UnihashesV2.unihash.label("unihash"),
+ )
+ .join(
+ UnihashesV2,
+ and_(
+ UnihashesV2.method == OuthashesV2.method,
+ UnihashesV2.taskhash == OuthashesV2.taskhash,
+ ),
+ )
+ .where(
+ OuthashesV2.method == method,
+ OuthashesV2.outhash == outhash,
+ OuthashesV2.taskhash != taskhash,
+ )
+ .order_by(
+ OuthashesV2.created.asc(),
+ )
+ .limit(1)
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def get_equivalent(self, method, taskhash):
+ statement = select(
+ UnihashesV2.unihash,
+ UnihashesV2.method,
+ UnihashesV2.taskhash,
+ ).where(
+ UnihashesV2.method == method,
+ UnihashesV2.taskhash == taskhash,
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def remove(self, condition):
+ async def do_remove(table):
+ where = {}
+ for c in table.__table__.columns:
+ if c.key in condition and condition[c.key] is not None:
+ where[c] = condition[c.key]
+
+ if where:
+ statement = delete(table).where(*[(k == v) for k, v in where.items()])
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount
+
+ return 0
+
+ count = 0
+ count += await do_remove(UnihashesV2)
+ count += await do_remove(OuthashesV2)
+
+ return count
+
+ async def clean_unused(self, oldest):
+ statement = delete(OuthashesV2).where(
+ OuthashesV2.created < oldest,
+ ~(
+ select(UnihashesV2.id)
+ .where(
+ UnihashesV2.method == OuthashesV2.method,
+ UnihashesV2.taskhash == OuthashesV2.taskhash,
+ )
+ .limit(1)
+ .exists()
+ ),
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount
+
+ async def insert_unihash(self, method, taskhash, unihash):
+ statement = insert(UnihashesV2).values(
+ method=method,
+ taskhash=taskhash,
+ unihash=unihash,
+ )
+ self.logger.debug("%s", statement)
+ try:
+ async with self.db.begin():
+ await self.db.execute(statement)
+ return True
+ except IntegrityError:
+ logger.debug(
+ "%s, %s, %s already in unihash database", method, taskhash, unihash
+ )
+ return False
+
+ async def insert_outhash(self, data):
+ outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)
+
+ data = {k: v for k, v in data.items() if k in outhash_columns}
+
+ if "created" in data and not isinstance(data["created"], datetime):
+ data["created"] = datetime.fromisoformat(data["created"])
+
+ statement = insert(OuthashesV2).values(**data)
+ self.logger.debug("%s", statement)
+ try:
+ async with self.db.begin():
+ await self.db.execute(statement)
+ return True
+ except IntegrityError:
+ logger.debug(
+ "%s, %s already in outhash database", data["method"], data["outhash"]
+ )
+ return False
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 4c98a280..268b2700 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -33,7 +33,7 @@ class HashEquivalenceTestSetup(object):
def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc):
self.server_index += 1
if dbpath is None:
- dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
+ dbpath = self.make_dbpath()
def cleanup_server(server):
if server.process.exitcode is not None:
@@ -53,6 +53,9 @@ class HashEquivalenceTestSetup(object):
return server
+ def make_dbpath(self):
+ return os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
+
def start_client(self, server_address):
def cleanup_client(client):
client.close()
@@ -517,6 +520,20 @@ class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalen
return "ws://%s:0" % host
+class TestHashEquivalenceWebsocketsSQLAlchemyServer(TestHashEquivalenceWebsocketServer):
+ def setUp(self):
+ try:
+ import sqlalchemy
+ import aiosqlite
+ except ImportError as e:
+ self.skipTest(str(e))
+
+ super().setUp()
+
+ def make_dbpath(self):
+ return "sqlite+aiosqlite:///%s" % os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
+
+
class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def start_test_server(self):
if 'BB_TEST_HASHSERV' not in os.environ:
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v5 09/22] hashserv: Implement read-only version of "report" RPC
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (7 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 08/22] hashserv: Add SQLalchemy backend Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 10/22] asyncrpc: Add InvokeError Joshua Watt
` (12 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
When the hash equivalence server is in read-only mode, it should still
return a unihash for a given "report" call if there is one.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/server.py | 25 ++++++++++++++++++++++++-
lib/hashserv/tests.py | 4 ++--
2 files changed, 26 insertions(+), 3 deletions(-)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 12255cc2..2e6977cb 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -124,6 +124,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
self.backfill_queue = backfill_queue
self.upstream = upstream
+ self.read_only = read_only
self.handlers.update(
{
@@ -131,13 +132,15 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"get-outhash": self.handle_get_outhash,
"get-stream": self.handle_get_stream,
"get-stats": self.handle_get_stats,
+ # Not always read-only, but internally checks if the server is
+ # read-only
+ "report": self.handle_report,
}
)
if not read_only:
self.handlers.update(
{
- "report": self.handle_report,
"report-equiv": self.handle_equivreport,
"reset-stats": self.handle_reset_stats,
"backfill-wait": self.handle_backfill_wait,
@@ -283,7 +286,27 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return "ok"
+ async def report_readonly(self, data):
+ method = data["method"]
+ outhash = data["outhash"]
+ taskhash = data["taskhash"]
+
+ info = await self.get_outhash(method, outhash, taskhash)
+ if info:
+ unihash = info["unihash"]
+ else:
+ unihash = data["unihash"]
+
+ return {
+ "taskhash": taskhash,
+ "method": method,
+ "unihash": unihash,
+ }
+
async def handle_report(self, data):
+ if self.read_only:
+ return await self.report_readonly(data)
+
outhash_data = {
"method": data["method"],
"outhash": data["outhash"],
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 268b2700..e9a361dc 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -387,8 +387,8 @@ class HashEquivalenceCommonTests(object):
outhash2 = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44'
unihash2 = '90e9bc1d1f094c51824adca7f8ea79a048d68824'
- with self.assertRaises(ConnectionError):
- ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
+ result = ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
+ self.assertEqual(result['unihash'], unihash2)
# Ensure that the database was not modified
self.assertClientGetHash(rw_client, taskhash2, None)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 10/22] asyncrpc: Add InvokeError
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (8 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 09/22] hashserv: Implement read-only version of "report" RPC Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 11/22] asyncrpc: client: Prevent double closing of loop Joshua Watt
` (11 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds support for Invocation Errors (that is, errors raised by the actual
RPC call instead of at the protocol level) to propagate across the
connection. If a server RPC call raises an InvokeError, it will be sent
across the connection and then raised on the client side also. The
connection is still terminated on this error.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/__init__.py | 1 +
lib/bb/asyncrpc/client.py | 10 ++++++++--
lib/bb/asyncrpc/exceptions.py | 4 ++++
lib/bb/asyncrpc/serv.py | 11 +++++++++--
4 files changed, 22 insertions(+), 4 deletions(-)
diff --git a/lib/bb/asyncrpc/__init__.py b/lib/bb/asyncrpc/__init__.py
index 9f677eac..a4371643 100644
--- a/lib/bb/asyncrpc/__init__.py
+++ b/lib/bb/asyncrpc/__init__.py
@@ -12,4 +12,5 @@ from .exceptions import (
ClientError,
ServerError,
ConnectionClosedError,
+ InvokeError,
)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index 009085c3..d27dbf71 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -11,7 +11,7 @@ import os
import socket
import sys
from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
-from .exceptions import ConnectionClosedError
+from .exceptions import ConnectionClosedError, InvokeError
class AsyncClient(object):
@@ -93,12 +93,18 @@ class AsyncClient(object):
await self.close()
count += 1
+ def check_invoke_error(self, msg):
+ if isinstance(msg, dict) and "invoke-error" in msg:
+ raise InvokeError(msg["invoke-error"]["message"])
+
async def invoke(self, msg):
async def proc():
await self.socket.send_message(msg)
return await self.socket.recv_message()
- return await self._send_wrapper(proc)
+ result = await self._send_wrapper(proc)
+ self.check_invoke_error(result)
+ return result
async def ping(self):
return await self.invoke({"ping": {}})
diff --git a/lib/bb/asyncrpc/exceptions.py b/lib/bb/asyncrpc/exceptions.py
index a8942b4f..ae1043a3 100644
--- a/lib/bb/asyncrpc/exceptions.py
+++ b/lib/bb/asyncrpc/exceptions.py
@@ -9,6 +9,10 @@ class ClientError(Exception):
pass
+class InvokeError(Exception):
+ pass
+
+
class ServerError(Exception):
pass
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index a476cacd..fccbb196 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -14,7 +14,7 @@ import sys
import multiprocessing
import logging
from .connection import StreamConnection, WebsocketConnection
-from .exceptions import ClientError, ServerError, ConnectionClosedError
+from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
class ClientLoggerAdapter(logging.LoggerAdapter):
@@ -72,7 +72,14 @@ class AsyncServerConnection(object):
d = await self.socket.recv_message()
if d is None:
break
- response = await self.dispatch_message(d)
+ try:
+ response = await self.dispatch_message(d)
+ except InvokeError as e:
+ await self.socket.send_message(
+ {"invoke-error": {"message": str(e)}}
+ )
+ break
+
await self.socket.send_message(response)
except ConnectionClosedError as e:
self.logger.info(str(e))
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 11/22] asyncrpc: client: Prevent double closing of loop
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (9 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 10/22] asyncrpc: Add InvokeError Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 12/22] asyncrpc: client: Add disconnect API Joshua Watt
` (10 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Invalidate the loop in the client close() call so that it is not closed
twice (which is an error in the asyncio code)
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/client.py | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index d27dbf71..628b90ee 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -161,10 +161,12 @@ class Client(object):
self.client.max_chunk = value
def close(self):
- self.loop.run_until_complete(self.client.close())
- if sys.version_info >= (3, 6):
- self.loop.run_until_complete(self.loop.shutdown_asyncgens())
- self.loop.close()
+ if self.loop:
+ self.loop.run_until_complete(self.client.close())
+ if sys.version_info >= (3, 6):
+ self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+ self.loop.close()
+ self.loop = None
def __enter__(self):
return self
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 12/22] asyncrpc: client: Add disconnect API
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (10 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 11/22] asyncrpc: client: Prevent double closing of loop Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 13/22] hashserv: Add user permissions Joshua Watt
` (9 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an API to explicitly disconnect a client. This can be useful for
testing the auto-reconnect behavior of clients
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/client.py | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index 628b90ee..0d7cd857 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -67,11 +67,14 @@ class AsyncClient(object):
self.socket = await self._connect_sock()
await self.setup_connection()
- async def close(self):
+ async def disconnect(self):
if self.socket is not None:
await self.socket.close()
self.socket = None
+ async def close(self):
+ await self.disconnect()
+
async def _send_wrapper(self, proc):
count = 0
while True:
@@ -160,6 +163,9 @@ class Client(object):
def max_chunk(self, value):
self.client.max_chunk = value
+ def disconnect(self):
+ self.loop.run_until_complete(self.client.close())
+
def close(self):
if self.loop:
self.loop.run_until_complete(self.client.close())
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 13/22] hashserv: Add user permissions
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (11 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 12/22] asyncrpc: client: Add disconnect API Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 14/22] hashserv: Add become-user API Joshua Watt
` (8 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds support for the hashserver to have per-user permissions. User
management is done via a new "auth" RPC API where a client can
authenticate itself with the server using a randomly generated token.
The user can then be given permissions to read, report, manage the
database, or manage other users.
In addition to explicit user logins, the server supports anonymous users
which is what all users start as before they make the "auth" RPC call.
Anonymous users can be assigned a set of permissions by the server,
making it unnecessary for users to authenticate to use the server. The
set of Anonymous permissions defines the default behavior of the server,
for example if set to "@read", Anonymous users are unable to report
equivalent hashes without authenticating. Similarly, setting the Anonymous
permissions to "@none" would require authentication for users to perform
any action.
User creation and management is entirely manual (although
bitbake-hashclient is very useful as a front end). There are many
different mechanisms that could be implemented to allow user
self-registration (e.g. OAuth, LDAP, etc.), and implementing these is
outside the scope of the server. Instead, it is recommended to
implement a registration service that validates users against the
necessary service, then adds them as a user in the hash equivalence
server.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 84 ++++++++-
bin/bitbake-hashserv | 37 ++++
lib/hashserv/__init__.py | 69 ++++---
lib/hashserv/client.py | 62 ++++++-
lib/hashserv/server.py | 357 ++++++++++++++++++++++++++++++++++++-
lib/hashserv/sqlalchemy.py | 111 +++++++++++-
lib/hashserv/sqlite.py | 105 +++++++++++
lib/hashserv/tests.py | 276 +++++++++++++++++++++++++++-
8 files changed, 1054 insertions(+), 47 deletions(-)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index a02a65b9..328c15cd 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -14,6 +14,7 @@ import sys
import threading
import time
import warnings
+import netrc
warnings.simplefilter("default")
try:
@@ -36,10 +37,18 @@ except ImportError:
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib'))
import hashserv
+import bb.asyncrpc
DEFAULT_ADDRESS = 'unix://./hashserve.sock'
METHOD = 'stress.test.method'
+def print_user(u):
+ print(f"Username: {u['username']}")
+ if "permissions" in u:
+ print("Permissions: " + " ".join(u["permissions"]))
+ if "token" in u:
+ print(f"Token: {u['token']}")
+
def main():
def handle_stats(args, client):
@@ -125,9 +134,39 @@ def main():
print("Removed %d rows" % (result["count"]))
return 0
+ def handle_refresh_token(args, client):
+ r = client.refresh_token(args.username)
+ print_user(r)
+
+ def handle_set_user_permissions(args, client):
+ r = client.set_user_perms(args.username, args.permissions)
+ print_user(r)
+
+ def handle_get_user(args, client):
+ r = client.get_user(args.username)
+ print_user(r)
+
+ def handle_get_all_users(args, client):
+ users = client.get_all_users()
+ print("{username:20}| {permissions}".format(username="Username", permissions="Permissions"))
+ print(("-" * 20) + "+" + ("-" * 20))
+ for u in users:
+ print("{username:20}| {permissions}".format(username=u["username"], permissions=" ".join(u["permissions"])))
+
+ def handle_new_user(args, client):
+ r = client.new_user(args.username, args.permissions)
+ print_user(r)
+
+ def handle_delete_user(args, client):
+ r = client.delete_user(args.username)
+ print_user(r)
+
parser = argparse.ArgumentParser(description='Hash Equivalence Client')
parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
parser.add_argument('--log', default='WARNING', help='Set logging level')
+ parser.add_argument('--login', '-l', metavar="USERNAME", help="Authenticate as USERNAME")
+ parser.add_argument('--password', '-p', metavar="TOKEN", help="Authenticate using token TOKEN")
+ parser.add_argument('--no-netrc', '-n', action="store_false", dest="netrc", help="Do not use .netrc")
subparsers = parser.add_subparsers()
@@ -158,6 +197,31 @@ def main():
clean_unused_parser.add_argument("max_age", metavar="SECONDS", type=int, help="Remove unused entries older than SECONDS old")
clean_unused_parser.set_defaults(func=handle_clean_unused)
+ refresh_token_parser = subparsers.add_parser('refresh-token', help="Refresh auth token")
+ refresh_token_parser.add_argument("--username", "-u", help="Refresh the token for another user (if authorized)")
+ refresh_token_parser.set_defaults(func=handle_refresh_token)
+
+ set_user_perms_parser = subparsers.add_parser('set-user-perms', help="Set new permissions for user")
+ set_user_perms_parser.add_argument("--username", "-u", help="Username", required=True)
+ set_user_perms_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions")
+ set_user_perms_parser.set_defaults(func=handle_set_user_permissions)
+
+ get_user_parser = subparsers.add_parser('get-user', help="Get user")
+ get_user_parser.add_argument("--username", "-u", help="Username")
+ get_user_parser.set_defaults(func=handle_get_user)
+
+ get_all_users_parser = subparsers.add_parser('get-all-users', help="List all users")
+ get_all_users_parser.set_defaults(func=handle_get_all_users)
+
+ new_user_parser = subparsers.add_parser('new-user', help="Create new user")
+ new_user_parser.add_argument("--username", "-u", help="Username", required=True)
+ new_user_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions")
+ new_user_parser.set_defaults(func=handle_new_user)
+
+ delete_user_parser = subparsers.add_parser('delete-user', help="Delete user")
+ delete_user_parser.add_argument("--username", "-u", help="Username", required=True)
+ delete_user_parser.set_defaults(func=handle_delete_user)
+
args = parser.parse_args()
logger = logging.getLogger('hashserv')
@@ -171,10 +235,26 @@ def main():
console.setLevel(level)
logger.addHandler(console)
+ login = args.login
+ password = args.password
+
+ if login is None and args.netrc:
+ try:
+ n = netrc.netrc()
+ auth = n.authenticators(args.address)
+ if auth is not None:
+ login, _, password = auth
+ except FileNotFoundError:
+ pass
+
func = getattr(args, 'func', None)
if func:
- with hashserv.create_client(args.address) as client:
- return func(args, client)
+ try:
+ with hashserv.create_client(args.address, login, password) as client:
+ return func(args, client)
+ except bb.asyncrpc.InvokeError as e:
+ print(f"ERROR: {e}")
+ return 1
return 0
diff --git a/bin/bitbake-hashserv b/bin/bitbake-hashserv
index 59b8b07f..1085d058 100755
--- a/bin/bitbake-hashserv
+++ b/bin/bitbake-hashserv
@@ -17,6 +17,7 @@ warnings.simplefilter("default")
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), "lib"))
import hashserv
+from hashserv.server import DEFAULT_ANON_PERMS
VERSION = "1.0.0"
@@ -36,6 +37,22 @@ The bind address may take one of the following formats:
To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
"--bind ws://:8686". To bind to a specific IPv6 address, enclose the address in
"[]", e.g. "--bind [::1]:8686" or "--bind ws://[::1]:8686"
+
+Note that the default Anonymous permissions are designed to not break existing
+server instances when upgrading, but are not particularly secure defaults. If
+you want to use authentication, it is recommended that you use "--anon-perms
+@read" to only give anonymous users read access, or "--anon-perms @none" to
+give un-authenticated users no access at all.
+
+Setting "--anon-perms @all" or "--anon-perms @user-admin" is not allowed, since
+this would allow anonymous users to manage all users accounts, which is a bad
+idea.
+
+If you are using user authentication, you should run your server in websockets
+mode with an SSL terminating load balancer in front of it (as this server does
+not implement SSL). Otherwise all usernames and passwords will be transmitted
+in the clear. When configured this way, clients can connect using a secure
+websocket, as in "wss://SERVER:PORT"
""",
)
@@ -79,6 +96,22 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
default=os.environ.get("HASHSERVER_DB_PASSWORD", None),
help="Database password ($HASHSERVER_DB_PASSWORD)",
)
+ parser.add_argument(
+ "--anon-perms",
+ metavar="PERM[,PERM[,...]]",
+ default=os.environ.get("HASHSERVER_ANON_PERMS", ",".join(DEFAULT_ANON_PERMS)),
+ help='Permissions to give anonymous users (default $HASHSERVER_ANON_PERMS, "%(default)s")',
+ )
+ parser.add_argument(
+ "--admin-user",
+ default=os.environ.get("HASHSERVER_ADMIN_USER", None),
+ help="Create default admin user with name ADMIN_USER ($HASHSERVER_ADMIN_USER)",
+ )
+ parser.add_argument(
+ "--admin-password",
+ default=os.environ.get("HASHSERVER_ADMIN_PASSWORD", None),
+ help="Create default admin user with password ADMIN_PASSWORD ($HASHSERVER_ADMIN_PASSWORD)",
+ )
args = parser.parse_args()
@@ -94,6 +127,7 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
logger.addHandler(console)
read_only = (os.environ.get("HASHSERVER_READ_ONLY", "0") == "1") or args.read_only
+ anon_perms = args.anon_perms.split(",")
server = hashserv.create_server(
args.bind,
@@ -102,6 +136,9 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
read_only=read_only,
db_username=args.db_username,
db_password=args.db_password,
+ anon_perms=anon_perms,
+ admin_username=args.admin_user,
+ admin_password=args.admin_password,
)
server.serve_forever()
return 0
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 9a8ee4e8..552a3327 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -8,6 +8,7 @@ from contextlib import closing
import re
import itertools
import json
+from collections import namedtuple
from urllib.parse import urlparse
UNIX_PREFIX = "unix://"
@@ -18,6 +19,8 @@ ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2
+User = namedtuple("User", ("username", "permissions"))
+
def parse_address(addr):
if addr.startswith(UNIX_PREFIX):
@@ -43,7 +46,10 @@ def create_server(
upstream=None,
read_only=False,
db_username=None,
- db_password=None
+ db_password=None,
+ anon_perms=None,
+ admin_username=None,
+ admin_password=None,
):
def sqlite_engine():
from .sqlite import DatabaseEngine
@@ -62,7 +68,17 @@ def create_server(
else:
db_engine = sqlite_engine()
- s = server.Server(db_engine, upstream=upstream, read_only=read_only)
+ if anon_perms is None:
+ anon_perms = server.DEFAULT_ANON_PERMS
+
+ s = server.Server(
+ db_engine,
+ upstream=upstream,
+ read_only=read_only,
+ anon_perms=anon_perms,
+ admin_username=admin_username,
+ admin_password=admin_password,
+ )
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
@@ -76,33 +92,40 @@ def create_server(
return s
-def create_client(addr):
+def create_client(addr, username=None, password=None):
from . import client
- c = client.Client()
-
- (typ, a) = parse_address(addr)
- if typ == ADDR_TYPE_UNIX:
- c.connect_unix(*a)
- elif typ == ADDR_TYPE_WS:
- c.connect_websocket(*a)
- else:
- c.connect_tcp(*a)
+ c = client.Client(username, password)
- return c
+ try:
+ (typ, a) = parse_address(addr)
+ if typ == ADDR_TYPE_UNIX:
+ c.connect_unix(*a)
+ elif typ == ADDR_TYPE_WS:
+ c.connect_websocket(*a)
+ else:
+ c.connect_tcp(*a)
+ return c
+ except Exception as e:
+ c.close()
+ raise e
-async def create_async_client(addr):
+async def create_async_client(addr, username=None, password=None):
from . import client
- c = client.AsyncClient()
+ c = client.AsyncClient(username, password)
- (typ, a) = parse_address(addr)
- if typ == ADDR_TYPE_UNIX:
- await c.connect_unix(*a)
- elif typ == ADDR_TYPE_WS:
- await c.connect_websocket(*a)
- else:
- await c.connect_tcp(*a)
+ try:
+ (typ, a) = parse_address(addr)
+ if typ == ADDR_TYPE_UNIX:
+ await c.connect_unix(*a)
+ elif typ == ADDR_TYPE_WS:
+ await c.connect_websocket(*a)
+ else:
+ await c.connect_tcp(*a)
- return c
+ return c
+ except Exception as e:
+ await c.close()
+ raise e
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index ebb58f33..5ed8d381 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -6,6 +6,7 @@
import logging
import socket
import bb.asyncrpc
+import json
from . import create_async_client
@@ -16,15 +17,19 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
MODE_NORMAL = 0
MODE_GET_STREAM = 1
- def __init__(self):
+ def __init__(self, username=None, password=None):
super().__init__('OEHASHEQUIV', '1.1', logger)
self.mode = self.MODE_NORMAL
+ self.username = username
+ self.password = password
async def setup_connection(self):
await super().setup_connection()
cur_mode = self.mode
self.mode = self.MODE_NORMAL
await self._set_mode(cur_mode)
+ if self.username:
+ await self.auth(self.username, self.password)
async def send_stream(self, msg):
async def proc():
@@ -41,6 +46,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
r = await self._send_wrapper(stream_to_normal)
if r != "ok":
+ self.check_invoke_error(r)
raise ConnectionError("Unable to transition to normal mode: Bad response from server %r" % r)
elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
r = await self.invoke({"get-stream": None})
@@ -109,9 +115,52 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
await self._set_mode(self.MODE_NORMAL)
return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})
+ async def auth(self, username, token):
+ await self._set_mode(self.MODE_NORMAL)
+ result = await self.invoke({"auth": {"username": username, "token": token}})
+ self.username = username
+ self.password = token
+ return result
+
+ async def refresh_token(self, username=None):
+ await self._set_mode(self.MODE_NORMAL)
+ m = {}
+ if username:
+ m["username"] = username
+ result = await self.invoke({"refresh-token": m})
+ if self.username and result["username"] == self.username:
+ self.password = result["token"]
+ return result
+
+ async def set_user_perms(self, username, permissions):
+ await self._set_mode(self.MODE_NORMAL)
+ return await self.invoke({"set-user-perms": {"username": username, "permissions": permissions}})
+
+ async def get_user(self, username=None):
+ await self._set_mode(self.MODE_NORMAL)
+ m = {}
+ if username:
+ m["username"] = username
+ return await self.invoke({"get-user": m})
+
+ async def get_all_users(self):
+ await self._set_mode(self.MODE_NORMAL)
+ return (await self.invoke({"get-all-users": {}}))["users"]
+
+ async def new_user(self, username, permissions):
+ await self._set_mode(self.MODE_NORMAL)
+ return await self.invoke({"new-user": {"username": username, "permissions": permissions}})
+
+ async def delete_user(self, username):
+ await self._set_mode(self.MODE_NORMAL)
+ return await self.invoke({"delete-user": {"username": username}})
+
class Client(bb.asyncrpc.Client):
- def __init__(self):
+ def __init__(self, username=None, password=None):
+ self.username = username
+ self.password = password
+
super().__init__()
self._add_methods(
"connect_tcp",
@@ -126,7 +175,14 @@ class Client(bb.asyncrpc.Client):
"backfill_wait",
"remove",
"clean_unused",
+ "auth",
+ "refresh_token",
+ "set_user_perms",
+ "get_user",
+ "get_all_users",
+ "new_user",
+ "delete_user",
)
def _get_async_client(self):
- return AsyncClient()
+ return AsyncClient(self.username, self.password)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 2e6977cb..5c70d81f 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -8,13 +8,48 @@ import asyncio
import logging
import math
import time
+import os
+import base64
+import hashlib
from . import create_async_client
import bb.asyncrpc
-
logger = logging.getLogger("hashserv.server")
+# This permission only exists to match nothing
+NONE_PERM = "@none"
+
+READ_PERM = "@read"
+REPORT_PERM = "@report"
+DB_ADMIN_PERM = "@db-admin"
+USER_ADMIN_PERM = "@user-admin"
+ALL_PERM = "@all"
+
+ALL_PERMISSIONS = {
+ READ_PERM,
+ REPORT_PERM,
+ DB_ADMIN_PERM,
+ USER_ADMIN_PERM,
+ ALL_PERM,
+}
+
+DEFAULT_ANON_PERMS = (
+ READ_PERM,
+ REPORT_PERM,
+ DB_ADMIN_PERM,
+)
+
+TOKEN_ALGORITHM = "sha256"
+
+# 48 bytes of random data will result in 64 characters when base64
+# encoded. This number also ensures that the base64 encoding won't have any
+# trailing '=' characters.
+TOKEN_SIZE = 48
+
+SALT_SIZE = 8
+
+
class Measurement(object):
def __init__(self, sample):
self.sample = sample
@@ -108,6 +143,85 @@ class Stats(object):
}
+token_refresh_semaphore = asyncio.Lock()
+
+
+async def new_token():
+ # Prevent malicious users from using this API to deduce the entropy
+ # pool on the server and thus be able to guess a token. *All* token
+ # refresh requests lock the same global semaphore and then sleep for a
+ # short time. This effectively rate limits the total number of requests
+ # that can be made across all clients to 10/second, which should be enough
+ # since you have to be an authenticated user to make the request in the
+ # first place
+ async with token_refresh_semaphore:
+ await asyncio.sleep(0.1)
+ raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK)
+
+ return base64.b64encode(raw, b"._").decode("utf-8")
+
+
+def new_salt():
+ return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex()
+
+
+def hash_token(algo, salt, token):
+ h = hashlib.new(algo)
+ h.update(salt.encode("utf-8"))
+ h.update(token.encode("utf-8"))
+ return ":".join([algo, salt, h.hexdigest()])
+
+
+def permissions(*permissions, allow_anon=True, allow_self_service=False):
+ """
+ Function decorator that can be used to decorate an RPC function call and
+ check that the current user's permissions match the required permissions.
+
+ If allow_anon is True, the user will also be allowed to make the RPC call
+ if the anonymous user permissions match the permissions.
+
+ If allow_self_service is True, and the "username" property in the request
+ is the currently logged in user, or not specified, the user will also be
+ allowed to make the request. This allows users to access normally privileged
+ APIs, as long as they are only modifying their own user properties (e.g.
+ users can be allowed to reset their own token without @user-admin
+ permissions, but not the token for any other user).
+ """
+
+ def wrapper(func):
+ async def wrap(self, request):
+ if allow_self_service and self.user is not None:
+ username = request.get("username", self.user.username)
+ if username == self.user.username:
+ request["username"] = self.user.username
+ return await func(self, request)
+
+ if not self.user_has_permissions(*permissions, allow_anon=allow_anon):
+ if not self.user:
+ username = "Anonymous user"
+ user_perms = self.anon_perms
+ else:
+ username = self.user.username
+ user_perms = self.user.permissions
+
+ self.logger.info(
+ "User %s with permissions %r denied from calling %s. Missing permissions(s) %r",
+ username,
+ ", ".join(user_perms),
+ func.__name__,
+ ", ".join(permissions),
+ )
+ raise bb.asyncrpc.InvokeError(
+ f"{username} is not allowed to access permissions(s) {', '.join(permissions)}"
+ )
+
+ return await func(self, request)
+
+ return wrap
+
+ return wrapper
+
+
class ServerClient(bb.asyncrpc.AsyncServerConnection):
def __init__(
self,
@@ -117,6 +231,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
backfill_queue,
upstream,
read_only,
+ anon_perms,
):
super().__init__(socket, "OEHASHEQUIV", logger)
self.db_engine = db_engine
@@ -125,6 +240,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
self.backfill_queue = backfill_queue
self.upstream = upstream
self.read_only = read_only
+ self.user = None
+ self.anon_perms = anon_perms
self.handlers.update(
{
@@ -135,6 +252,9 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
# Not always read-only, but internally checks if the server is
# read-only
"report": self.handle_report,
+ "auth": self.handle_auth,
+ "get-user": self.handle_get_user,
+ "get-all-users": self.handle_get_all_users,
}
)
@@ -146,9 +266,36 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"backfill-wait": self.handle_backfill_wait,
"remove": self.handle_remove,
"clean-unused": self.handle_clean_unused,
+ "refresh-token": self.handle_refresh_token,
+ "set-user-perms": self.handle_set_perms,
+ "new-user": self.handle_new_user,
+ "delete-user": self.handle_delete_user,
}
)
+ def raise_no_user_error(self, username):
+ raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists")
+
+ def user_has_permissions(self, *permissions, allow_anon=True):
+ permissions = set(permissions)
+ if allow_anon:
+ if ALL_PERM in self.anon_perms:
+ return True
+
+ if not permissions - self.anon_perms:
+ return True
+
+ if self.user is None:
+ return False
+
+ if ALL_PERM in self.user.permissions:
+ return True
+
+ if not permissions - self.user.permissions:
+ return True
+
+ return False
+
def validate_proto_version(self):
return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
@@ -178,6 +325,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
+ @permissions(READ_PERM)
async def handle_get(self, request):
method = request["method"]
taskhash = request["taskhash"]
@@ -206,6 +354,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return d
+ @permissions(READ_PERM)
async def handle_get_outhash(self, request):
method = request["method"]
outhash = request["outhash"]
@@ -236,6 +385,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
await self.db.insert_outhash(data)
+ @permissions(READ_PERM)
async def handle_get_stream(self, request):
await self.socket.send_message("ok")
@@ -303,8 +453,11 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"unihash": unihash,
}
+ # Since this can be called either read only or to report, the check to
+ # report is made inside the function
+ @permissions(READ_PERM)
async def handle_report(self, data):
- if self.read_only:
+ if self.read_only or not self.user_has_permissions(REPORT_PERM):
return await self.report_readonly(data)
outhash_data = {
@@ -357,6 +510,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"unihash": unihash,
}
+ @permissions(READ_PERM, REPORT_PERM)
async def handle_equivreport(self, data):
await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
@@ -374,11 +528,13 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return {k: row[k] for k in ("taskhash", "method", "unihash")}
+ @permissions(READ_PERM)
async def handle_get_stats(self, request):
return {
"requests": self.request_stats.todict(),
}
+ @permissions(DB_ADMIN_PERM)
async def handle_reset_stats(self, request):
d = {
"requests": self.request_stats.todict(),
@@ -387,6 +543,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
self.request_stats.reset()
return d
+ @permissions(READ_PERM)
async def handle_backfill_wait(self, request):
d = {
"tasks": self.backfill_queue.qsize(),
@@ -394,6 +551,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
await self.backfill_queue.join()
return d
+ @permissions(DB_ADMIN_PERM)
async def handle_remove(self, request):
condition = request["where"]
if not isinstance(condition, dict):
@@ -401,19 +559,178 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return {"count": await self.db.remove(condition)}
+ @permissions(DB_ADMIN_PERM)
async def handle_clean_unused(self, request):
max_age = request["max_age_seconds"]
oldest = datetime.now() - timedelta(seconds=-max_age)
return {"count": await self.db.clean_unused(oldest)}
+ # The authentication API is always allowed
+ async def handle_auth(self, request):
+ username = str(request["username"])
+ token = str(request["token"])
+
+ async def fail_auth():
+ nonlocal username
+ # Rate limit bad login attempts
+ await asyncio.sleep(1)
+ raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}")
+
+ user, db_token = await self.db.lookup_user_token(username)
+
+ if not user or not db_token:
+ await fail_auth()
+
+ try:
+ algo, salt, _ = db_token.split(":")
+ except ValueError:
+ await fail_auth()
+
+ if hash_token(algo, salt, token) != db_token:
+ await fail_auth()
+
+ self.user = user
+
+ self.logger.info("Authenticated as %s", username)
+
+ return {
+ "result": True,
+ "username": self.user.username,
+ "permissions": sorted(list(self.user.permissions)),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+ async def handle_refresh_token(self, request):
+ username = str(request["username"])
+
+ token = await new_token()
+
+ updated = await self.db.set_user_token(
+ username,
+ hash_token(TOKEN_ALGORITHM, new_salt(), token),
+ )
+ if not updated:
+ self.raise_no_user_error(username)
+
+ return {"username": username, "token": token}
+
+ def get_perm_arg(self, arg):
+ if not isinstance(arg, list):
+ raise bb.asyncrpc.InvokeError("Unexpected type for permissions")
+
+ arg = set(arg)
+ try:
+ arg.remove(NONE_PERM)
+ except KeyError:
+ pass
+
+ unknown_perms = arg - ALL_PERMISSIONS
+ if unknown_perms:
+ raise bb.asyncrpc.InvokeError(
+ "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms)))
+ )
+
+ return sorted(list(arg))
+
+ def return_perms(self, permissions):
+ if ALL_PERM in permissions:
+ return sorted(list(ALL_PERMISSIONS))
+ return sorted(list(permissions))
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_set_perms(self, request):
+ username = str(request["username"])
+ permissions = self.get_perm_arg(request["permissions"])
+
+ if not await self.db.set_user_perms(username, permissions):
+ self.raise_no_user_error(username)
+
+ return {
+ "username": username,
+ "permissions": self.return_perms(permissions),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+ async def handle_get_user(self, request):
+ username = str(request["username"])
+
+ user = await self.db.lookup_user(username)
+ if user is None:
+ return None
+
+ return {
+ "username": user.username,
+ "permissions": self.return_perms(user.permissions),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_get_all_users(self, request):
+ users = await self.db.get_all_users()
+ return {
+ "users": [
+ {
+ "username": u.username,
+ "permissions": self.return_perms(u.permissions),
+ }
+ for u in users
+ ]
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_new_user(self, request):
+ username = str(request["username"])
+ permissions = self.get_perm_arg(request["permissions"])
+
+ token = await new_token()
+
+ inserted = await self.db.new_user(
+ username,
+ permissions,
+ hash_token(TOKEN_ALGORITHM, new_salt(), token),
+ )
+ if not inserted:
+ raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'")
+
+ return {
+ "username": username,
+ "permissions": self.return_perms(permissions),
+ "token": token,
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_delete_user(self, request):
+ username = str(request["username"])
+
+ if not await self.db.delete_user(username):
+ self.raise_no_user_error(username)
+
+ return {"username": username}
+
class Server(bb.asyncrpc.AsyncServer):
- def __init__(self, db_engine, upstream=None, read_only=False):
+ def __init__(
+ self,
+ db_engine,
+ upstream=None,
+ read_only=False,
+ anon_perms=DEFAULT_ANON_PERMS,
+ admin_username=None,
+ admin_password=None,
+ ):
if upstream and read_only:
raise bb.asyncrpc.ServerError(
"Read-only hashserv cannot pull from an upstream server"
)
+ disallowed_perms = set(anon_perms) - set(
+ [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM]
+ )
+
+ if disallowed_perms:
+ raise bb.asyncrpc.ServerError(
+ f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users"
+ )
+
super().__init__(logger)
self.request_stats = Stats()
@@ -421,6 +738,13 @@ class Server(bb.asyncrpc.AsyncServer):
self.upstream = upstream
self.read_only = read_only
self.backfill_queue = None
+ self.anon_perms = set(anon_perms)
+ self.admin_username = admin_username
+ self.admin_password = admin_password
+
+ self.logger.info(
+ "Anonymous user permissions are: %s", ", ".join(self.anon_perms)
+ )
def accept_client(self, socket):
return ServerClient(
@@ -430,12 +754,34 @@ class Server(bb.asyncrpc.AsyncServer):
self.backfill_queue,
self.upstream,
self.read_only,
+ self.anon_perms,
)
+ async def create_admin_user(self):
+ admin_permissions = (ALL_PERM,)
+ async with self.db_engine.connect(self.logger) as db:
+ added = await db.new_user(
+ self.admin_username,
+ admin_permissions,
+ hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
+ )
+ if added:
+ self.logger.info("Created admin user '%s'", self.admin_username)
+ else:
+ await db.set_user_perms(
+ self.admin_username,
+ admin_permissions,
+ )
+ await db.set_user_token(
+ self.admin_username,
+ hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
+ )
+ self.logger.info("Admin user '%s' updated", self.admin_username)
+
async def backfill_worker_task(self):
async with await create_async_client(
self.upstream
- ) as client, self.db_engine.connect(logger) as db:
+ ) as client, self.db_engine.connect(self.logger) as db:
while True:
item = await self.backfill_queue.get()
if item is None:
@@ -456,6 +802,9 @@ class Server(bb.asyncrpc.AsyncServer):
self.loop.run_until_complete(self.db_engine.create())
+ if self.admin_username:
+ self.loop.run_until_complete(self.create_admin_user())
+
return tasks
async def stop(self):
diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py
index 3216621f..bfd8a844 100644
--- a/lib/hashserv/sqlalchemy.py
+++ b/lib/hashserv/sqlalchemy.py
@@ -7,6 +7,7 @@
import logging
from datetime import datetime
+from . import User
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool
@@ -25,13 +26,12 @@ from sqlalchemy import (
literal,
and_,
delete,
+ update,
)
import sqlalchemy.engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.exc import IntegrityError
-logger = logging.getLogger("hashserv.sqlalchemy")
-
Base = declarative_base()
@@ -68,9 +68,19 @@ class OuthashesV2(Base):
)
+class Users(Base):
+ __tablename__ = "users"
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ username = Column(Text, nullable=False)
+ token = Column(Text, nullable=False)
+ permissions = Column(Text)
+
+ __table_args__ = (UniqueConstraint("username"),)
+
+
class DatabaseEngine(object):
def __init__(self, url, username=None, password=None):
- self.logger = logger
+ self.logger = logging.getLogger("hashserv.sqlalchemy")
self.url = sqlalchemy.engine.make_url(url)
if username is not None:
@@ -85,7 +95,7 @@ class DatabaseEngine(object):
async with self.engine.begin() as conn:
# Create tables
- logger.info("Creating tables...")
+ self.logger.info("Creating tables...")
await conn.run_sync(Base.metadata.create_all)
def connect(self, logger):
@@ -98,6 +108,15 @@ def map_row(row):
return dict(**row._mapping)
+def map_user(row):
+ if row is None:
+ return None
+ return User(
+ username=row.username,
+ permissions=set(row.permissions.split()),
+ )
+
+
class Database(object):
def __init__(self, engine, logger):
self.engine = engine
@@ -278,7 +297,7 @@ class Database(object):
await self.db.execute(statement)
return True
except IntegrityError:
- logger.debug(
+ self.logger.debug(
"%s, %s, %s already in unihash database", method, taskhash, unihash
)
return False
@@ -298,7 +317,87 @@ class Database(object):
await self.db.execute(statement)
return True
except IntegrityError:
- logger.debug(
+ self.logger.debug(
"%s, %s already in outhash database", data["method"], data["outhash"]
)
return False
+
+ async def _get_user(self, username):
+ statement = select(
+ Users.username,
+ Users.permissions,
+ Users.token,
+ ).where(
+ Users.username == username,
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.first()
+
+ async def lookup_user_token(self, username):
+ row = await self._get_user(username)
+ if not row:
+ return None, None
+ return map_user(row), row.token
+
+ async def lookup_user(self, username):
+ return map_user(await self._get_user(username))
+
+ async def set_user_token(self, username, token):
+ statement = (
+ update(Users)
+ .where(
+ Users.username == username,
+ )
+ .values(
+ token=token,
+ )
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount != 0
+
+ async def set_user_perms(self, username, permissions):
+ statement = (
+ update(Users)
+ .where(Users.username == username)
+ .values(permissions=" ".join(permissions))
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount != 0
+
+ async def get_all_users(self):
+ statement = select(
+ Users.username,
+ Users.permissions,
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return [map_user(row) for row in result]
+
+ async def new_user(self, username, permissions, token):
+ statement = insert(Users).values(
+ username=username,
+ permissions=" ".join(permissions),
+ token=token,
+ )
+ self.logger.debug("%s", statement)
+ try:
+ async with self.db.begin():
+ await self.db.execute(statement)
+ return True
+ except IntegrityError as e:
+ self.logger.debug("Cannot create new user %s: %s", username, e)
+ return False
+
+ async def delete_user(self, username):
+ statement = delete(Users).where(Users.username == username)
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount != 0
diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py
index 6809c537..414ee8ff 100644
--- a/lib/hashserv/sqlite.py
+++ b/lib/hashserv/sqlite.py
@@ -7,6 +7,7 @@
import sqlite3
import logging
from contextlib import closing
+from . import User
logger = logging.getLogger("hashserv.sqlite")
@@ -34,6 +35,14 @@ OUTHASH_TABLE_DEFINITION = (
OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
+USERS_TABLE_DEFINITION = (
+ ("username", "TEXT NOT NULL", "UNIQUE"),
+ ("token", "TEXT NOT NULL", ""),
+ ("permissions", "TEXT NOT NULL", ""),
+)
+
+USERS_TABLE_COLUMNS = tuple(name for name, _, _ in USERS_TABLE_DEFINITION)
+
def _make_table(cursor, name, definition):
cursor.execute(
@@ -53,6 +62,15 @@ def _make_table(cursor, name, definition):
)
+def map_user(row):
+ if row is None:
+ return None
+ return User(
+ username=row["username"],
+ permissions=set(row["permissions"].split()),
+ )
+
+
class DatabaseEngine(object):
def __init__(self, dbname, sync):
self.dbname = dbname
@@ -66,6 +84,7 @@ class DatabaseEngine(object):
with closing(db.cursor()) as cursor:
_make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
_make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
+ _make_table(cursor, "users", USERS_TABLE_DEFINITION)
cursor.execute("PRAGMA journal_mode = WAL")
cursor.execute(
@@ -227,6 +246,7 @@ class Database(object):
"oldest": oldest,
},
)
+ self.db.commit()
return cursor.rowcount
async def insert_unihash(self, method, taskhash, unihash):
@@ -257,3 +277,88 @@ class Database(object):
cursor.execute(query, data)
self.db.commit()
return cursor.lastrowid != prevrowid
+
+ def _get_user(self, username):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT username, permissions, token FROM users WHERE username=:username
+ """,
+ {
+ "username": username,
+ },
+ )
+ return cursor.fetchone()
+
+ async def lookup_user_token(self, username):
+ row = self._get_user(username)
+ if row is None:
+ return None, None
+ return map_user(row), row["token"]
+
+ async def lookup_user(self, username):
+ return map_user(self._get_user(username))
+
+ async def set_user_token(self, username, token):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ UPDATE users SET token=:token WHERE username=:username
+ """,
+ {
+ "username": username,
+ "token": token,
+ },
+ )
+ self.db.commit()
+ return cursor.rowcount != 0
+
+ async def set_user_perms(self, username, permissions):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ UPDATE users SET permissions=:permissions WHERE username=:username
+ """,
+ {
+ "username": username,
+ "permissions": " ".join(permissions),
+ },
+ )
+ self.db.commit()
+ return cursor.rowcount != 0
+
+ async def get_all_users(self):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute("SELECT username, permissions FROM users")
+ return [map_user(r) for r in cursor.fetchall()]
+
+ async def new_user(self, username, permissions, token):
+ with closing(self.db.cursor()) as cursor:
+ try:
+ cursor.execute(
+ """
+ INSERT INTO users (username, token, permissions) VALUES (:username, :token, :permissions)
+ """,
+ {
+ "username": username,
+ "token": token,
+ "permissions": " ".join(permissions),
+ },
+ )
+ self.db.commit()
+ return True
+ except sqlite3.IntegrityError:
+ return False
+
+ async def delete_user(self, username):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ DELETE FROM users WHERE username=:username
+ """,
+ {
+ "username": username,
+ },
+ )
+ self.db.commit()
+ return cursor.rowcount != 0
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index e9a361dc..f92f37c4 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -6,6 +6,8 @@
#
from . import create_server, create_client
+from .server import DEFAULT_ANON_PERMS, ALL_PERMISSIONS
+from bb.asyncrpc import InvokeError
import hashlib
import logging
import multiprocessing
@@ -29,8 +31,9 @@ class HashEquivalenceTestSetup(object):
METHOD = 'TestMethod'
server_index = 0
+ client_index = 0
- def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc):
+ def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc, anon_perms=DEFAULT_ANON_PERMS, admin_username=None, admin_password=None):
self.server_index += 1
if dbpath is None:
dbpath = self.make_dbpath()
@@ -45,7 +48,10 @@ class HashEquivalenceTestSetup(object):
server = create_server(self.get_server_addr(self.server_index),
dbpath,
upstream=upstream,
- read_only=read_only)
+ read_only=read_only,
+ anon_perms=anon_perms,
+ admin_username=admin_username,
+ admin_password=admin_password)
server.dbpath = dbpath
server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
@@ -56,18 +62,31 @@ class HashEquivalenceTestSetup(object):
def make_dbpath(self):
return os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
- def start_client(self, server_address):
+ def start_client(self, server_address, username=None, password=None):
def cleanup_client(client):
client.close()
- client = create_client(server_address)
+ client = create_client(server_address, username=username, password=password)
self.addCleanup(cleanup_client, client)
return client
def start_test_server(self):
- server = self.start_server()
- return server.address
+ self.server = self.start_server()
+ return self.server.address
+
+ def start_auth_server(self):
+ self.auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password")
+ self.admin_client = self.start_client(self.auth_server.address, username="admin", password="password")
+ return self.admin_client
+
+ def auth_client(self, user):
+ return self.start_client(self.auth_server.address, user["username"], user["token"])
+
+ def auth_perms(self, *permissions):
+ self.client_index += 1
+ user = self.admin_client.new_user(f"user-{self.client_index}", permissions)
+ return self.auth_client(user)
def setUp(self):
if sys.version_info < (3, 5, 0):
@@ -86,18 +105,21 @@ class HashEquivalenceTestSetup(object):
class HashEquivalenceCommonTests(object):
- def test_create_hash(self):
+ def create_test_hash(self, client):
# Simple test that hashes can be created
taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
- self.assertClientGetHash(self.client, taskhash, None)
+ self.assertClientGetHash(client, taskhash, None)
- result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+ result = client.report_unihash(taskhash, self.METHOD, outhash, unihash)
self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
return taskhash, outhash, unihash
+ def test_create_hash(self):
+ return self.create_test_hash(self.client)
+
def test_create_equivalent(self):
# Tests that a second reported task with the same outhash will be
# assigned the same unihash
@@ -471,6 +493,242 @@ class HashEquivalenceCommonTests(object):
# shares a taskhash with Task 2
self.assertClientGetHash(self.client, taskhash2, unihash2)
+ def test_auth_read_perms(self):
+ admin_client = self.start_auth_server()
+
+ # Create hashes with non-authenticated server
+ taskhash, outhash, unihash = self.test_create_hash()
+
+ # Validate hash can be retrieved using authenticated client
+ with self.auth_perms("@read") as client:
+ self.assertClientGetHash(client, taskhash, unihash)
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ self.assertClientGetHash(client, taskhash, unihash)
+
+ def test_auth_report_perms(self):
+ admin_client = self.start_auth_server()
+
+ # Without read permission, the user is completely denied
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ self.create_test_hash(client)
+
+ # Read permission allows the call to succeed, but it doesn't record
+ # anything in the database
+ with self.auth_perms("@read") as client:
+ taskhash, outhash, unihash = self.create_test_hash(client)
+ self.assertClientGetHash(client, taskhash, None)
+
+ # Report permission alone is insufficient
+ with self.auth_perms("@report") as client, self.assertRaises(InvokeError):
+ self.create_test_hash(client)
+
+ # Read and report permission actually modify the database
+ with self.auth_perms("@read", "@report") as client:
+ taskhash, outhash, unihash = self.create_test_hash(client)
+ self.assertClientGetHash(client, taskhash, unihash)
+
+ def test_auth_no_token_refresh_from_anon_user(self):
+ self.start_auth_server()
+
+ with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ client.refresh_token()
+
+ def assertUserCanAuth(self, user):
+ with self.start_client(self.auth_server.address) as client:
+ client.auth(user["username"], user["token"])
+
+ def assertUserCannotAuth(self, user):
+ with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ client.auth(user["username"], user["token"])
+
+ def test_auth_self_token_refresh(self):
+ admin_client = self.start_auth_server()
+
+ # Create a new user with no permissions
+ user = admin_client.new_user("test-user", [])
+
+ with self.auth_client(user) as client:
+ new_user = client.refresh_token()
+
+ self.assertEqual(user["username"], new_user["username"])
+ self.assertNotEqual(user["token"], new_user["token"])
+ self.assertUserCanAuth(new_user)
+ self.assertUserCannotAuth(user)
+
+ # Explicitly specifying with your own username is fine also
+ with self.auth_client(new_user) as client:
+ new_user2 = client.refresh_token(user["username"])
+
+ self.assertEqual(user["username"], new_user2["username"])
+ self.assertNotEqual(user["token"], new_user2["token"])
+ self.assertUserCanAuth(new_user2)
+ self.assertUserCannotAuth(new_user)
+ self.assertUserCannotAuth(user)
+
+ def test_auth_token_refresh(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.refresh_token(user["username"])
+
+ with self.auth_perms("@user-admin") as client:
+ new_user = client.refresh_token(user["username"])
+
+ self.assertEqual(user["username"], new_user["username"])
+ self.assertNotEqual(user["token"], new_user["token"])
+ self.assertUserCanAuth(new_user)
+ self.assertUserCannotAuth(user)
+
+ def test_auth_self_get_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+ user_info = user.copy()
+ del user_info["token"]
+
+ with self.auth_client(user) as client:
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ # Explicitly asking for your own username is fine also
+ info = client.get_user(user["username"])
+ self.assertEqual(info, user_info)
+
+ def test_auth_get_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+ user_info = user.copy()
+ del user_info["token"]
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.get_user(user["username"])
+
+ with self.auth_perms("@user-admin") as client:
+ info = client.get_user(user["username"])
+ self.assertEqual(info, user_info)
+
+ info = client.get_user("nonexist-user")
+ self.assertIsNone(info)
+
+ def test_auth_reconnect(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+ user_info = user.copy()
+ del user_info["token"]
+
+ with self.auth_client(user) as client:
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ client.disconnect()
+
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ def test_auth_delete_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+
+ # No self service
+ with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ client.delete_user(user["username"])
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.delete_user(user["username"])
+
+ with self.auth_perms("@user-admin") as client:
+ client.delete_user(user["username"])
+
+ # User doesn't exist, so even though the permission is correct, it's an
+ # error
+ with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
+ client.delete_user(user["username"])
+
+ def assertUserPerms(self, user, permissions):
+ with self.auth_client(user) as client:
+ info = client.get_user()
+ self.assertEqual(info, {
+ "username": user["username"],
+ "permissions": permissions,
+ })
+
+ def test_auth_set_user_perms(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+
+ self.assertUserPerms(user, [])
+
+ # No self service to change permissions
+ with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ client.set_user_perms(user["username"], ["@all"])
+ self.assertUserPerms(user, [])
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.set_user_perms(user["username"], ["@all"])
+ self.assertUserPerms(user, [])
+
+ with self.auth_perms("@user-admin") as client:
+ client.set_user_perms(user["username"], ["@all"])
+ self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS)))
+
+ # Bad permissions
+ with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
+ client.set_user_perms(user["username"], ["@this-is-not-a-permission"])
+ self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS)))
+
+ def test_auth_get_all_users(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+
+ with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ client.get_all_users()
+
+ # Give the test user the correct permission
+ admin_client.set_user_perms(user["username"], ["@user-admin"])
+
+ with self.auth_client(user) as client:
+ all_users = client.get_all_users()
+
+ # Convert to a dictionary for easier comparison
+ all_users = {u["username"]: u for u in all_users}
+
+ self.assertEqual(all_users,
+ {
+ "admin": {
+ "username": "admin",
+ "permissions": sorted(list(ALL_PERMISSIONS)),
+ },
+ "test-user": {
+ "username": "test-user",
+ "permissions": ["@user-admin"],
+ }
+ }
+ )
+
+ def test_auth_new_user(self):
+ self.start_auth_server()
+
+ permissions = ["@read", "@report", "@db-admin", "@user-admin"]
+ permissions.sort()
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.new_user("test-user", permissions)
+
+ with self.auth_perms("@user-admin") as client:
+ user = client.new_user("test-user", permissions)
+ self.assertIn("token", user)
+ self.assertEqual(user["username"], "test-user")
+ self.assertEqual(user["permissions"], permissions)
+
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 14/22] hashserv: Add become-user API
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (12 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 13/22] hashserv: Add user permissions Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 15/22] hashserv: Add db-usage API Joshua Watt
` (7 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an API that allows a user admin to impersonate another user in the
system. This makes it easier to write external services that have
external authentication, since they can use a common user account to
access the server, then impersonate the logged in user.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 3 +++
lib/hashserv/client.py | 42 +++++++++++++++++++++++++++++++++++++-----
lib/hashserv/server.py | 18 ++++++++++++++++++
lib/hashserv/tests.py | 39 +++++++++++++++++++++++++++++++++++++++
4 files changed, 97 insertions(+), 5 deletions(-)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index 328c15cd..cfbc197e 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -166,6 +166,7 @@ def main():
parser.add_argument('--log', default='WARNING', help='Set logging level')
parser.add_argument('--login', '-l', metavar="USERNAME", help="Authenticate as USERNAME")
parser.add_argument('--password', '-p', metavar="TOKEN", help="Authenticate using token TOKEN")
+ parser.add_argument('--become', '-b', metavar="USERNAME", help="Impersonate user USERNAME (if allowed) when performing actions")
parser.add_argument('--no-netrc', '-n', action="store_false", dest="netrc", help="Do not use .netrc")
subparsers = parser.add_subparsers()
@@ -251,6 +252,8 @@ def main():
if func:
try:
with hashserv.create_client(args.address, login, password) as client:
+ if args.become:
+ client.become_user(args.become)
return func(args, client)
except bb.asyncrpc.InvokeError as e:
print(f"ERROR: {e}")
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index 5ed8d381..90f1dd71 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -18,10 +18,11 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
MODE_GET_STREAM = 1
def __init__(self, username=None, password=None):
- super().__init__('OEHASHEQUIV', '1.1', logger)
+ super().__init__("OEHASHEQUIV", "1.1", logger)
self.mode = self.MODE_NORMAL
self.username = username
self.password = password
+ self.saved_become_user = None
async def setup_connection(self):
await super().setup_connection()
@@ -29,8 +30,13 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
self.mode = self.MODE_NORMAL
await self._set_mode(cur_mode)
if self.username:
+ # Save off become user temporarily because auth() resets it
+ become = self.saved_become_user
await self.auth(self.username, self.password)
+ if become:
+ await self.become_user(become)
+
async def send_stream(self, msg):
async def proc():
await self.socket.send(msg)
@@ -92,7 +98,14 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
await self._set_mode(self.MODE_NORMAL)
return await self.invoke(
- {"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method, "with_unihash": with_unihash}}
+ {
+ "get-outhash": {
+ "outhash": outhash,
+ "taskhash": taskhash,
+ "method": method,
+ "with_unihash": with_unihash,
+ }
+ }
)
async def get_stats(self):
@@ -120,6 +133,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
result = await self.invoke({"auth": {"username": username, "token": token}})
self.username = username
self.password = token
+ self.saved_become_user = None
return result
async def refresh_token(self, username=None):
@@ -128,13 +142,19 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
if username:
m["username"] = username
result = await self.invoke({"refresh-token": m})
- if self.username and result["username"] == self.username:
+ if (
+ self.username
+ and not self.saved_become_user
+ and result["username"] == self.username
+ ):
self.password = result["token"]
return result
async def set_user_perms(self, username, permissions):
await self._set_mode(self.MODE_NORMAL)
- return await self.invoke({"set-user-perms": {"username": username, "permissions": permissions}})
+ return await self.invoke(
+ {"set-user-perms": {"username": username, "permissions": permissions}}
+ )
async def get_user(self, username=None):
await self._set_mode(self.MODE_NORMAL)
@@ -149,12 +169,23 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
async def new_user(self, username, permissions):
await self._set_mode(self.MODE_NORMAL)
- return await self.invoke({"new-user": {"username": username, "permissions": permissions}})
+ return await self.invoke(
+ {"new-user": {"username": username, "permissions": permissions}}
+ )
async def delete_user(self, username):
await self._set_mode(self.MODE_NORMAL)
return await self.invoke({"delete-user": {"username": username}})
+ async def become_user(self, username):
+ await self._set_mode(self.MODE_NORMAL)
+ result = await self.invoke({"become-user": {"username": username}})
+ if username == self.username:
+ self.saved_become_user = None
+ else:
+ self.saved_become_user = username
+ return result
+
class Client(bb.asyncrpc.Client):
def __init__(self, username=None, password=None):
@@ -182,6 +213,7 @@ class Client(bb.asyncrpc.Client):
"get_all_users",
"new_user",
"delete_user",
+ "become_user",
)
def _get_async_client(self):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 5c70d81f..d506088e 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -255,6 +255,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"auth": self.handle_auth,
"get-user": self.handle_get_user,
"get-all-users": self.handle_get_all_users,
+ "become-user": self.handle_become_user,
}
)
@@ -706,6 +707,23 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return {"username": username}
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_become_user(self, request):
+ username = str(request["username"])
+
+ user = await self.db.lookup_user(username)
+ if user is None:
+ raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist")
+
+ self.user = user
+
+ self.logger.info("Became user %s", username)
+
+ return {
+ "username": self.user.username,
+ "permissions": self.return_perms(self.user.permissions),
+ }
+
class Server(bb.asyncrpc.AsyncServer):
def __init__(
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index f92f37c4..311b7b77 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -728,6 +728,45 @@ class HashEquivalenceCommonTests(object):
self.assertEqual(user["username"], "test-user")
self.assertEqual(user["permissions"], permissions)
+ def test_auth_become_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read", "@report"])
+ user_info = user.copy()
+ del user_info["token"]
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.become_user(user["username"])
+
+ with self.auth_perms("@user-admin") as client:
+ become = client.become_user(user["username"])
+ self.assertEqual(become, user_info)
+
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ # Verify become user is preserved across disconnect
+ client.disconnect()
+
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ # test-user doesn't have become_user permissions, so this should
+ # not work
+ with self.assertRaises(InvokeError):
+ client.become_user(user["username"])
+
+ # No self-service of become
+ with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ client.become_user(user["username"])
+
+ # Give test user permissions to become
+ admin_client.set_user_perms(user["username"], ["@user-admin"])
+
+ # It's possible to become yourself (effectively a noop)
+ with self.auth_perms("@user-admin") as client:
+ become = client.become_user(client.username)
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 15/22] hashserv: Add db-usage API
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (13 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 14/22] hashserv: Add become-user API Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 16/22] hashserv: Add database column query API Joshua Watt
` (6 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an API to query the server for the usage of the database (e.g. how
many rows are present in each table)
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 16 ++++++++++++++++
lib/hashserv/client.py | 5 +++++
lib/hashserv/server.py | 5 +++++
lib/hashserv/sqlalchemy.py | 14 ++++++++++++++
lib/hashserv/sqlite.py | 37 +++++++++++++++++++++++++++++++++++++
lib/hashserv/tests.py | 9 +++++++++
6 files changed, 86 insertions(+)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index cfbc197e..5d65c7bc 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -161,6 +161,19 @@ def main():
r = client.delete_user(args.username)
print_user(r)
+ def handle_get_db_usage(args, client):
+ usage = client.get_db_usage()
+ print(usage)
+ tables = sorted(usage.keys())
+ print("{name:20}| {rows:20}".format(name="Table name", rows="Rows"))
+ print(("-" * 20) + "+" + ("-" * 20))
+ for t in tables:
+ print("{name:20}| {rows:<20}".format(name=t, rows=usage[t]["rows"]))
+ print()
+
+ total_rows = sum(t["rows"] for t in usage.values())
+ print(f"Total rows: {total_rows}")
+
parser = argparse.ArgumentParser(description='Hash Equivalence Client')
parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
parser.add_argument('--log', default='WARNING', help='Set logging level')
@@ -223,6 +236,9 @@ def main():
delete_user_parser.add_argument("--username", "-u", help="Username", required=True)
delete_user_parser.set_defaults(func=handle_delete_user)
+ db_usage_parser = subparsers.add_parser('get-db-usage', help="Database Usage")
+ db_usage_parser.set_defaults(func=handle_get_db_usage)
+
args = parser.parse_args()
logger = logging.getLogger('hashserv')
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index 90f1dd71..0c3f556a 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -186,6 +186,10 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
self.saved_become_user = username
return result
+ async def get_db_usage(self):
+ await self._set_mode(self.MODE_NORMAL)
+ return (await self.invoke({"get-db-usage": {}}))["usage"]
+
class Client(bb.asyncrpc.Client):
def __init__(self, username=None, password=None):
@@ -214,6 +218,7 @@ class Client(bb.asyncrpc.Client):
"new_user",
"delete_user",
"become_user",
+ "get_db_usage",
)
def _get_async_client(self):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index d506088e..4fec1556 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -249,6 +249,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"get-outhash": self.handle_get_outhash,
"get-stream": self.handle_get_stream,
"get-stats": self.handle_get_stats,
+ "get-db-usage": self.handle_get_db_usage,
# Not always read-only, but internally checks if the server is
# read-only
"report": self.handle_report,
@@ -566,6 +567,10 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
oldest = datetime.now() - timedelta(seconds=-max_age)
return {"count": await self.db.clean_unused(oldest)}
+ @permissions(DB_ADMIN_PERM)
+ async def handle_get_db_usage(self, request):
+ return {"usage": await self.db.get_usage()}
+
# The authentication API is always allowed
async def handle_auth(self, request):
username = str(request["username"])
diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py
index bfd8a844..818b5195 100644
--- a/lib/hashserv/sqlalchemy.py
+++ b/lib/hashserv/sqlalchemy.py
@@ -27,6 +27,7 @@ from sqlalchemy import (
and_,
delete,
update,
+ func,
)
import sqlalchemy.engine
from sqlalchemy.orm import declarative_base
@@ -401,3 +402,16 @@ class Database(object):
async with self.db.begin():
result = await self.db.execute(statement)
return result.rowcount != 0
+
+ async def get_usage(self):
+ usage = {}
+ async with self.db.begin() as session:
+ for name, table in Base.metadata.tables.items():
+ statement = select(func.count()).select_from(table)
+ self.logger.debug("%s", statement)
+ result = await self.db.execute(statement)
+ usage[name] = {
+ "rows": result.scalar(),
+ }
+
+ return usage
diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py
index 414ee8ff..dfdccbba 100644
--- a/lib/hashserv/sqlite.py
+++ b/lib/hashserv/sqlite.py
@@ -120,6 +120,18 @@ class Database(object):
self.db = sqlite3.connect(self.dbname)
self.db.row_factory = sqlite3.Row
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute("SELECT sqlite_version()")
+
+ version = []
+ for v in cursor.fetchone()[0].split("."):
+ try:
+ version.append(int(v))
+ except ValueError:
+ version.append(v)
+
+ self.sqlite_version = tuple(version)
+
async def __aenter__(self):
return self
@@ -362,3 +374,28 @@ class Database(object):
)
self.db.commit()
return cursor.rowcount != 0
+
+ async def get_usage(self):
+ usage = {}
+ with closing(self.db.cursor()) as cursor:
+ if self.sqlite_version >= (3, 33):
+ table_name = "sqlite_schema"
+ else:
+ table_name = "sqlite_master"
+
+ cursor.execute(
+ f"""
+ SELECT name FROM {table_name} WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
+ """
+ )
+ for row in cursor.fetchall():
+ cursor.execute(
+ """
+ SELECT COUNT() FROM %s
+ """
+ % row["name"],
+ )
+ usage[row["name"]] = {
+ "rows": cursor.fetchone()[0],
+ }
+ return usage
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 311b7b77..9d5bec24 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -767,6 +767,15 @@ class HashEquivalenceCommonTests(object):
with self.auth_perms("@user-admin") as client:
become = client.become_user(client.username)
+ def test_get_db_usage(self):
+ usage = self.client.get_db_usage()
+
+ self.assertTrue(isinstance(usage, dict))
+ for name in usage.keys():
+ self.assertTrue(isinstance(usage[name], dict))
+ self.assertIn("rows", usage[name])
+ self.assertTrue(isinstance(usage[name]["rows"], int))
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 16/22] hashserv: Add database column query API
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (14 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 15/22] hashserv: Add db-usage API Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 17/22] hashserv: test: Add bitbake-hashclient tests Joshua Watt
` (5 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an API to retrieve the columns that can be queried from the
database backend. This prevents front-end applications from needing to
hardcode the query columns.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 7 +++++++
lib/hashserv/client.py | 5 +++++
lib/hashserv/server.py | 5 +++++
lib/hashserv/sqlalchemy.py | 10 ++++++++++
lib/hashserv/sqlite.py | 7 +++++++
lib/hashserv/tests.py | 8 ++++++++
6 files changed, 42 insertions(+)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index 5d65c7bc..58aa02ee 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -174,6 +174,10 @@ def main():
total_rows = sum(t["rows"] for t in usage.values())
print(f"Total rows: {total_rows}")
+ def handle_get_db_query_columns(args, client):
+ columns = client.get_db_query_columns()
+ print("\n".join(sorted(columns)))
+
parser = argparse.ArgumentParser(description='Hash Equivalence Client')
parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
parser.add_argument('--log', default='WARNING', help='Set logging level')
@@ -239,6 +243,9 @@ def main():
db_usage_parser = subparsers.add_parser('get-db-usage', help="Database Usage")
db_usage_parser.set_defaults(func=handle_get_db_usage)
+ db_query_columns_parser = subparsers.add_parser('get-db-query-columns', help="Show columns that can be used in database queries")
+ db_query_columns_parser.set_defaults(func=handle_get_db_query_columns)
+
args = parser.parse_args()
logger = logging.getLogger('hashserv')
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index 0c3f556a..319da2d9 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -190,6 +190,10 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
await self._set_mode(self.MODE_NORMAL)
return (await self.invoke({"get-db-usage": {}}))["usage"]
+ async def get_db_query_columns(self):
+ await self._set_mode(self.MODE_NORMAL)
+ return (await self.invoke({"get-db-query-columns": {}}))["columns"]
+
class Client(bb.asyncrpc.Client):
def __init__(self, username=None, password=None):
@@ -219,6 +223,7 @@ class Client(bb.asyncrpc.Client):
"delete_user",
"become_user",
"get_db_usage",
+ "get_db_query_columns",
)
def _get_async_client(self):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 4fec1556..d2fd75df 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -250,6 +250,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"get-stream": self.handle_get_stream,
"get-stats": self.handle_get_stats,
"get-db-usage": self.handle_get_db_usage,
+ "get-db-query-columns": self.handle_get_db_query_columns,
# Not always read-only, but internally checks if the server is
# read-only
"report": self.handle_report,
@@ -571,6 +572,10 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
async def handle_get_db_usage(self, request):
return {"usage": await self.db.get_usage()}
+ @permissions(DB_ADMIN_PERM)
+ async def handle_get_db_query_columns(self, request):
+ return {"columns": await self.db.get_query_columns()}
+
# The authentication API is always allowed
async def handle_auth(self, request):
username = str(request["username"])
diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py
index 818b5195..cee04bff 100644
--- a/lib/hashserv/sqlalchemy.py
+++ b/lib/hashserv/sqlalchemy.py
@@ -415,3 +415,13 @@ class Database(object):
}
return usage
+
+ async def get_query_columns(self):
+ columns = set()
+ for table in (UnihashesV2, OuthashesV2):
+ for c in table.__table__.columns:
+ if not isinstance(c.type, Text):
+ continue
+ columns.add(c.key)
+
+ return list(columns)
diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py
index dfdccbba..f65036be 100644
--- a/lib/hashserv/sqlite.py
+++ b/lib/hashserv/sqlite.py
@@ -399,3 +399,10 @@ class Database(object):
"rows": cursor.fetchone()[0],
}
return usage
+
+ async def get_query_columns(self):
+ columns = set()
+ for name, typ, _ in UNIHASH_TABLE_DEFINITION + OUTHASH_TABLE_DEFINITION:
+ if typ.startswith("TEXT"):
+ columns.add(name)
+ return list(columns)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 9d5bec24..fc69acaf 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -776,6 +776,14 @@ class HashEquivalenceCommonTests(object):
self.assertIn("rows", usage[name])
self.assertTrue(isinstance(usage[name]["rows"], int))
+ def test_get_db_query_columns(self):
+ columns = self.client.get_db_query_columns()
+
+ self.assertTrue(isinstance(columns, list))
+ self.assertTrue(len(columns) > 0)
+
+ for col in columns:
+ self.client.remove({col: ""})
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 17/22] hashserv: test: Add bitbake-hashclient tests
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (15 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 16/22] hashserv: Add database column query API Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 18/22] bitbake-hashclient: Output stats in JSON format Joshua Watt
` (4 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
The bitbake-hashclient command-line tool now has a lot more features
which should be tested, so add some tests for them.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/tests.py | 300 ++++++++++++++++++++++++++++++++++++++----
1 file changed, 277 insertions(+), 23 deletions(-)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index fc69acaf..a80ccd57 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -19,6 +19,14 @@ import unittest
import socket
import time
import signal
+import subprocess
+import json
+import re
+from pathlib import Path
+
+
+THIS_DIR = Path(__file__).parent
+BIN_DIR = THIS_DIR.parent.parent / "bin"
def server_prefunc(server, idx):
logging.basicConfig(level=logging.DEBUG, filename='bbhashserv-%d.log' % idx, filemode='w',
@@ -103,8 +111,22 @@ class HashEquivalenceTestSetup(object):
result = client.get_unihash(self.METHOD, taskhash)
self.assertEqual(result, unihash)
+ def assertUserPerms(self, user, permissions):
+ with self.auth_client(user) as client:
+ info = client.get_user()
+ self.assertEqual(info, {
+ "username": user["username"],
+ "permissions": permissions,
+ })
+
+ def assertUserCanAuth(self, user):
+ with self.start_client(self.auth_server.address) as client:
+ client.auth(user["username"], user["token"])
+
+ def assertUserCannotAuth(self, user):
+ with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ client.auth(user["username"], user["token"])
-class HashEquivalenceCommonTests(object):
def create_test_hash(self, client):
# Simple test that hashes can be created
taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
@@ -117,6 +139,24 @@ class HashEquivalenceCommonTests(object):
self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
return taskhash, outhash, unihash
+ def run_hashclient(self, args, **kwargs):
+ try:
+ p = subprocess.run(
+ [BIN_DIR / "bitbake-hashclient"] + args,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ encoding="utf-8",
+ **kwargs
+ )
+ except subprocess.CalledProcessError as e:
+ print(e.output)
+ raise e
+
+ print(p.stdout)
+ return p
+
+
+class HashEquivalenceCommonTests(object):
def test_create_hash(self):
return self.create_test_hash(self.client)
@@ -161,7 +201,7 @@ class HashEquivalenceCommonTests(object):
self.assertClientGetHash(self.client, taskhash, unihash)
def test_remove_taskhash(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
result = self.client.remove({"taskhash": taskhash})
self.assertGreater(result["count"], 0)
self.assertClientGetHash(self.client, taskhash, None)
@@ -170,13 +210,13 @@ class HashEquivalenceCommonTests(object):
self.assertIsNone(result_outhash)
def test_remove_unihash(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
result = self.client.remove({"unihash": unihash})
self.assertGreater(result["count"], 0)
self.assertClientGetHash(self.client, taskhash, None)
def test_remove_outhash(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
result = self.client.remove({"outhash": outhash})
self.assertGreater(result["count"], 0)
@@ -184,7 +224,7 @@ class HashEquivalenceCommonTests(object):
self.assertIsNone(result_outhash)
def test_remove_method(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
result = self.client.remove({"method": self.METHOD})
self.assertGreater(result["count"], 0)
self.assertClientGetHash(self.client, taskhash, None)
@@ -193,7 +233,7 @@ class HashEquivalenceCommonTests(object):
self.assertIsNone(result_outhash)
def test_clean_unused(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
# Clean the database, which should not remove anything because all hashes an in-use
result = self.client.clean_unused(0)
@@ -497,7 +537,7 @@ class HashEquivalenceCommonTests(object):
admin_client = self.start_auth_server()
# Create hashes with non-authenticated server
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
# Validate hash can be retrieved using authenticated client
with self.auth_perms("@read") as client:
@@ -534,14 +574,6 @@ class HashEquivalenceCommonTests(object):
with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
client.refresh_token()
- def assertUserCanAuth(self, user):
- with self.start_client(self.auth_server.address) as client:
- client.auth(user["username"], user["token"])
-
- def assertUserCannotAuth(self, user):
- with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
- client.auth(user["username"], user["token"])
-
def test_auth_self_token_refresh(self):
admin_client = self.start_auth_server()
@@ -650,14 +682,6 @@ class HashEquivalenceCommonTests(object):
with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
client.delete_user(user["username"])
- def assertUserPerms(self, user, permissions):
- with self.auth_client(user) as client:
- info = client.get_user()
- self.assertEqual(info, {
- "username": user["username"],
- "permissions": permissions,
- })
-
def test_auth_set_user_perms(self):
admin_client = self.start_auth_server()
@@ -785,6 +809,236 @@ class HashEquivalenceCommonTests(object):
for col in columns:
self.client.remove({col: ""})
+
+class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
+ def get_server_addr(self, server_idx):
+ return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
+
+ def test_stats(self):
+ self.run_hashclient(["--address", self.server_address, "stats"], check=True)
+
+ def test_stress(self):
+ self.run_hashclient(["--address", self.server_address, "stress"], check=True)
+
+ def test_remove_taskhash(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "taskhash", taskhash,
+ ], check=True)
+ self.assertClientGetHash(self.client, taskhash, None)
+
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
+ self.assertIsNone(result_outhash)
+
+ def test_remove_unihash(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "unihash", unihash,
+ ], check=True)
+ self.assertClientGetHash(self.client, taskhash, None)
+
+ def test_remove_outhash(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "outhash", outhash,
+ ], check=True)
+
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
+ self.assertIsNone(result_outhash)
+
+ def test_remove_method(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "method", self.METHOD,
+ ], check=True)
+ self.assertClientGetHash(self.client, taskhash, None)
+
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
+ self.assertIsNone(result_outhash)
+
+ def test_clean_unused(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+
+ # Clean the database, which should not remove anything because all hashes are in-use
+ self.run_hashclient([
+ "--address", self.server_address,
+ "clean-unused", "0",
+ ], check=True)
+ self.assertClientGetHash(self.client, taskhash, unihash)
+
+ # Remove the unihash. The row in the outhash table should still be present
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "unihash", unihash,
+ ], check=True)
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
+ self.assertIsNotNone(result_outhash)
+
+ # Now clean with no minimum age which will remove the outhash
+ self.run_hashclient([
+ "--address", self.server_address,
+ "clean-unused", "0",
+ ], check=True)
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
+ self.assertIsNone(result_outhash)
+
+ def test_refresh_token(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read", "@report"])
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", user["username"],
+ "--password", user["token"],
+ "refresh-token"
+ ], check=True)
+
+ new_token = None
+ for l in p.stdout.splitlines():
+ l = l.rstrip()
+ m = re.match(r'Token: +(.*)$', l)
+ if m is not None:
+ new_token = m.group(1)
+
+ self.assertTrue(new_token)
+
+ print("New token is %r" % new_token)
+
+ self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", user["username"],
+ "--password", new_token,
+ "get-user"
+ ], check=True)
+
+ def test_set_user_perms(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read"])
+
+ self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "set-user-perms",
+ "-u", user["username"],
+ "@read", "@report",
+ ], check=True)
+
+ new_user = admin_client.get_user(user["username"])
+
+ self.assertEqual(set(new_user["permissions"]), {"@read", "@report"})
+
+ def test_get_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read"])
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "get-user",
+ "-u", user["username"],
+ ], check=True)
+
+ self.assertIn("Username:", p.stdout)
+ self.assertIn("Permissions:", p.stdout)
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", user["username"],
+ "--password", user["token"],
+ "get-user",
+ ], check=True)
+
+ self.assertIn("Username:", p.stdout)
+ self.assertIn("Permissions:", p.stdout)
+
+ def test_get_all_users(self):
+ admin_client = self.start_auth_server()
+
+ admin_client.new_user("test-user1", ["@read"])
+ admin_client.new_user("test-user2", ["@read"])
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "get-all-users",
+ ], check=True)
+
+ self.assertIn("admin", p.stdout)
+ self.assertIn("test-user1", p.stdout)
+ self.assertIn("test-user2", p.stdout)
+
+ def test_new_user(self):
+ admin_client = self.start_auth_server()
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "new-user",
+ "-u", "test-user",
+ "@read", "@report",
+ ], check=True)
+
+ new_token = None
+ for l in p.stdout.splitlines():
+ l = l.rstrip()
+ m = re.match(r'Token: +(.*)$', l)
+ if m is not None:
+ new_token = m.group(1)
+
+ self.assertTrue(new_token)
+
+ user = {
+ "username": "test-user",
+ "token": new_token,
+ }
+
+ self.assertUserPerms(user, ["@read", "@report"])
+
+ def test_delete_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read"])
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "delete-user",
+ "-u", user["username"],
+ ], check=True)
+
+
+ self.assertIsNone(admin_client.get_user(user["username"]))
+
+ def test_get_db_usage(self):
+ p = self.run_hashclient([
+ "--address", self.server_address,
+ "get-db-usage",
+ ], check=True)
+
+ def test_get_db_query_columns(self):
+ p = self.run_hashclient([
+ "--address", self.server_address,
+ "get-db-query-columns",
+ ], check=True)
+
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 18/22] bitbake-hashclient: Output stats in JSON format
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (16 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 17/22] hashserv: test: Add bitbake-hashclient tests Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 19/22] bitbake-hashserver: Allow anonymous permissions to be space separated Joshua Watt
` (3 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Outputting the stats in JSON format makes more sense as it's easier for
a downstream tool to parse if desired.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 3 ++-
lib/hashserv/tests.py | 3 ++-
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index 58aa02ee..3ff7b763 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -15,6 +15,7 @@ import threading
import time
import warnings
import netrc
+import json
warnings.simplefilter("default")
try:
@@ -56,7 +57,7 @@ def main():
s = client.reset_stats()
else:
s = client.get_stats()
- pprint.pprint(s)
+ print(json.dumps(s, sort_keys=True, indent=4))
return 0
def handle_stress(args, client):
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index a80ccd57..2d78f9e9 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -815,7 +815,8 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
def test_stats(self):
- self.run_hashclient(["--address", self.server_address, "stats"], check=True)
+ p = self.run_hashclient(["--address", self.server_address, "stats"], check=True)
+ json.loads(p.stdout)
def test_stress(self):
self.run_hashclient(["--address", self.server_address, "stress"], check=True)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 19/22] bitbake-hashserver: Allow anonymous permissions to be space separated
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (17 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 18/22] bitbake-hashclient: Output stats in JSON format Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 20/22] hashserv: tests: Allow authentication for external server tests Joshua Watt
` (2 subsequent siblings)
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Space separation is more natural when setting the value from an
environment variable, so allow that here for convenience.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashserv | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/bin/bitbake-hashserv b/bin/bitbake-hashserv
index 1085d058..c560b3e5 100755
--- a/bin/bitbake-hashserv
+++ b/bin/bitbake-hashserv
@@ -127,7 +127,10 @@ websocket, as in "wss://SERVER:PORT"
logger.addHandler(console)
read_only = (os.environ.get("HASHSERVER_READ_ONLY", "0") == "1") or args.read_only
- anon_perms = args.anon_perms.split(",")
+ if "," in args.anon_perms:
+ anon_perms = args.anon_perms.split(",")
+ else:
+ anon_perms = args.anon_perms.split()
server = hashserv.create_server(
args.bind,
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v5 20/22] hashserv: tests: Allow authentication for external server tests
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (18 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 19/22] bitbake-hashserver: Allow anonymous permissions to be space separated Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 21/22] hashserv: Allow self-service deletion Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 22/22] hashserv: server: Add owner if user is logged in Joshua Watt
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
If BB_TEST_HASHSERV_USERNAME and BB_TEST_HASHSERV_PASSWORD are provided
for a server admin user, the authentication tests for the external
hashserver will run. In addition, any users that get created will now be
deleted when the test finishes.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/tests.py | 109 ++++++++++++++++++++++++++++--------------
1 file changed, 74 insertions(+), 35 deletions(-)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 2d78f9e9..5d209ffb 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -84,17 +84,13 @@ class HashEquivalenceTestSetup(object):
return self.server.address
def start_auth_server(self):
- self.auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password")
- self.admin_client = self.start_client(self.auth_server.address, username="admin", password="password")
+ auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password")
+ self.auth_server_address = auth_server.address
+ self.admin_client = self.start_client(auth_server.address, username="admin", password="password")
return self.admin_client
def auth_client(self, user):
- return self.start_client(self.auth_server.address, user["username"], user["token"])
-
- def auth_perms(self, *permissions):
- self.client_index += 1
- user = self.admin_client.new_user(f"user-{self.client_index}", permissions)
- return self.auth_client(user)
+ return self.start_client(self.auth_server_address, user["username"], user["token"])
def setUp(self):
if sys.version_info < (3, 5, 0):
@@ -120,11 +116,11 @@ class HashEquivalenceTestSetup(object):
})
def assertUserCanAuth(self, user):
- with self.start_client(self.auth_server.address) as client:
+ with self.start_client(self.auth_server_address) as client:
client.auth(user["username"], user["token"])
def assertUserCannotAuth(self, user):
- with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError):
client.auth(user["username"], user["token"])
def create_test_hash(self, client):
@@ -157,6 +153,26 @@ class HashEquivalenceTestSetup(object):
class HashEquivalenceCommonTests(object):
+ def auth_perms(self, *permissions):
+ self.client_index += 1
+ user = self.create_user(f"user-{self.client_index}", permissions)
+ return self.auth_client(user)
+
+ def create_user(self, username, permissions, *, client=None):
+ def remove_user(username):
+ try:
+ self.admin_client.delete_user(username)
+ except bb.asyncrpc.InvokeError:
+ pass
+
+ if client is None:
+ client = self.admin_client
+
+ user = client.new_user(username, permissions)
+ self.addCleanup(remove_user, username)
+
+ return user
+
def test_create_hash(self):
return self.create_test_hash(self.client)
@@ -571,14 +587,14 @@ class HashEquivalenceCommonTests(object):
def test_auth_no_token_refresh_from_anon_user(self):
self.start_auth_server()
- with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError):
client.refresh_token()
def test_auth_self_token_refresh(self):
admin_client = self.start_auth_server()
# Create a new user with no permissions
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
with self.auth_client(user) as client:
new_user = client.refresh_token()
@@ -601,7 +617,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_token_refresh(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
with self.auth_perms() as client, self.assertRaises(InvokeError):
client.refresh_token(user["username"])
@@ -617,7 +633,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_self_get_user(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
user_info = user.copy()
del user_info["token"]
@@ -632,7 +648,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_get_user(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
user_info = user.copy()
del user_info["token"]
@@ -649,7 +665,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_reconnect(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
user_info = user.copy()
del user_info["token"]
@@ -665,7 +681,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_delete_user(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
# No self service
with self.auth_client(user) as client, self.assertRaises(InvokeError):
@@ -685,7 +701,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_set_user_perms(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
self.assertUserPerms(user, [])
@@ -710,7 +726,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_get_all_users(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
with self.auth_client(user) as client, self.assertRaises(InvokeError):
client.get_all_users()
@@ -744,10 +760,10 @@ class HashEquivalenceCommonTests(object):
permissions.sort()
with self.auth_perms() as client, self.assertRaises(InvokeError):
- client.new_user("test-user", permissions)
+ self.create_user("test-user", permissions, client=client)
with self.auth_perms("@user-admin") as client:
- user = client.new_user("test-user", permissions)
+ user = self.create_user("test-user", permissions, client=client)
self.assertIn("token", user)
self.assertEqual(user["username"], "test-user")
self.assertEqual(user["permissions"], permissions)
@@ -755,7 +771,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_become_user(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", ["@read", "@report"])
+ user = self.create_user("test-user", ["@read", "@report"])
user_info = user.copy()
del user_info["token"]
@@ -898,7 +914,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
user = admin_client.new_user("test-user", ["@read", "@report"])
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", user["username"],
"--password", user["token"],
"refresh-token"
@@ -916,7 +932,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
print("New token is %r" % new_token)
self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", user["username"],
"--password", new_token,
"get-user"
@@ -928,7 +944,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
user = admin_client.new_user("test-user", ["@read"])
self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"set-user-perms",
@@ -946,7 +962,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
user = admin_client.new_user("test-user", ["@read"])
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"get-user",
@@ -957,7 +973,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
self.assertIn("Permissions:", p.stdout)
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", user["username"],
"--password", user["token"],
"get-user",
@@ -973,7 +989,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
admin_client.new_user("test-user2", ["@read"])
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"get-all-users",
@@ -987,7 +1003,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
admin_client = self.start_auth_server()
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"new-user",
@@ -1017,14 +1033,13 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
user = admin_client.new_user("test-user", ["@read"])
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"delete-user",
"-u", user["username"],
], check=True)
-
self.assertIsNone(admin_client.get_user(user["username"]))
def test_get_db_usage(self):
@@ -1104,19 +1119,43 @@ class TestHashEquivalenceWebsocketsSQLAlchemyServer(TestHashEquivalenceWebsocket
class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
- def start_test_server(self):
- if 'BB_TEST_HASHSERV' not in os.environ:
- self.skipTest('BB_TEST_HASHSERV not defined to test an external server')
+ def get_env(self, name):
+ v = os.environ.get(name)
+ if not v:
+ self.skipTest(f'{name} not defined to test an external server')
+ return v
- return os.environ['BB_TEST_HASHSERV']
+ def start_test_server(self):
+ return self.get_env('BB_TEST_HASHSERV')
def start_server(self, *args, **kwargs):
self.skipTest('Cannot start local server when testing external servers')
+ def start_auth_server(self):
+
+ self.auth_server_address = self.server_address
+ self.admin_client = self.start_client(
+ self.server_address,
+ username=self.get_env('BB_TEST_HASHSERV_USERNAME'),
+ password=self.get_env('BB_TEST_HASHSERV_PASSWORD'),
+ )
+ return self.admin_client
+
def setUp(self):
super().setUp()
+ if "BB_TEST_HASHSERV_USERNAME" in os.environ:
+ self.client = self.start_client(
+ self.server_address,
+ username=os.environ["BB_TEST_HASHSERV_USERNAME"],
+ password=os.environ["BB_TEST_HASHSERV_PASSWORD"],
+ )
self.client.remove({"method": self.METHOD})
def tearDown(self):
self.client.remove({"method": self.METHOD})
super().tearDown()
+
+
+ def test_auth_get_all_users(self):
+ self.skipTest("Cannot test all users with external server")
+
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v5 21/22] hashserv: Allow self-service deletion
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (19 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 20/22] hashserv: tests: Allow authentication for external server tests Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 22/22] hashserv: server: Add owner if user is logged in Joshua Watt
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Allows users to self-service deletion of their own user accounts
(meaning, they can delete their own accounts without special
permissions).
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/server.py | 2 +-
lib/hashserv/tests.py | 7 +++++--
2 files changed, 6 insertions(+), 3 deletions(-)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index d2fd75df..6da56df7 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -708,7 +708,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"token": token,
}
- @permissions(USER_ADMIN_PERM, allow_anon=False)
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
async def handle_delete_user(self, request):
username = str(request["username"])
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 5d209ffb..f0be8679 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -683,10 +683,13 @@ class HashEquivalenceCommonTests(object):
user = self.create_user("test-user", [])
- # No self service
- with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ # self service
+ with self.auth_client(user) as client:
client.delete_user(user["username"])
+ self.assertIsNone(admin_client.get_user(user["username"]))
+ user = self.create_user("test-user", [])
+
with self.auth_perms() as client, self.assertRaises(InvokeError):
client.delete_user(user["username"])
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v5 22/22] hashserv: server: Add owner if user is logged in
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
` (20 preceding siblings ...)
2023-11-01 15:42 ` [bitbake-devel][PATCH v5 21/22] hashserv: Allow self-service deletion Joshua Watt
@ 2023-11-01 15:42 ` Joshua Watt
21 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-01 15:42 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
If a user is authenticated with the server, report them as the owner of
a report
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/server.py | 3 +++
lib/hashserv/tests.py | 9 +++++++++
2 files changed, 12 insertions(+)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 6da56df7..a9714b5b 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -474,6 +474,9 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if k in data:
outhash_data[k] = data[k]
+ if self.user:
+ outhash_data["owner"] = self.user.username
+
# Insert the new entry, unless it already exists
if await self.db.insert_outhash(outhash_data):
# If this row is new, check if it is equivalent to another
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index f0be8679..a9e6fdf9 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -828,6 +828,15 @@ class HashEquivalenceCommonTests(object):
for col in columns:
self.client.remove({col: ""})
+ def test_auth_is_owner(self):
+ admin_client = self.start_auth_server()
+
+ user = self.create_user("test-user", ["@read", "@report"])
+ with self.auth_client(user) as client:
+ taskhash, outhash, unihash = self.create_test_hash(client)
+ data = client.get_taskhash(self.METHOD, taskhash, True)
+ self.assertEqual(data["owner"], user["username"])
+
class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
def get_server_addr(self, server_idx):
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management
2023-10-31 17:21 ` [bitbake-devel][PATCH v4 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (23 preceding siblings ...)
2023-11-01 15:41 ` [bitbake-devel][PATCH v5 " Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 01/22] asyncrpc: Abstract sockets Joshua Watt
` (22 more replies)
24 siblings, 23 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
This patch series reworks the bitbake asyncrpc API to add a WebSockets
implementation for both the client and server. The hash equivalence
server is updated to allow using this new API (the PR server can also be
updated in the future if desired).
In addition, the database backend for the hash equivalence server is
abstracted so that sqlalchemy can optionally be used instead of sqlite.
This allows using "big metal" databases as the backend, which allows the
hash equivalence server to scale to a large number of queries.
Note that both websockets and sqlalchemy require 3rd party python
modules to function. However, these modules are optional unless the user
desires to use the APIs.
Also, user management is added. This allows user accounts to be
registered with the server and users can be given permissions to do
certain operations on the server. Users are not (necessarily) required
to log in to access the server, as permissions can be granted to anonymous
users. The default permissions will give anonymous users the same
permissions that they would have before user accounts were added so as
to retain backward compatibility, but server admins will likely want to
change this.
V3: Remove RFC status; patches are ready for review
V4: Fixed protocol breakage with mixing older and newer clients/servers
V5: Fixed compatibility with Python 3.8
V6: Fixed protocol incompatibility when exiting stream state that broke
mixing older and newer clients/servers
Joshua Watt (22):
asyncrpc: Abstract sockets
hashserv: Add websocket connection implementation
asyncrpc: Add context manager API
hashserv: tests: Add external database tests
asyncrpc: Prefix log messages with client info
bitbake-hashserv: Allow arguments from environment
hashserv: Abstract database
hashserv: Add SQLalchemy backend
hashserv: Implement read-only version of "report" RPC
asyncrpc: Add InvokeError
asyncrpc: client: Prevent double closing of loop
asyncrpc: client: Add disconnect API
hashserv: Add user permissions
hashserv: Add become-user API
hashserv: Add db-usage API
hashserv: Add database column query API
hashserv: test: Add bitbake-hashclient tests
bitbake-hashclient: Output stats in JSON format
bitbake-hashserver: Allow anonymous permissions to be space separated
hashserv: tests: Allow authentication for external server tests
hashserv: Allow self-service deletion
hashserv: server: Add owner if user is logged in
bin/bitbake-hashclient | 145 +++++-
bin/bitbake-hashserv | 132 ++++-
lib/bb/asyncrpc/__init__.py | 33 +-
lib/bb/asyncrpc/client.py | 120 ++---
lib/bb/asyncrpc/connection.py | 146 ++++++
lib/bb/asyncrpc/exceptions.py | 21 +
lib/bb/asyncrpc/serv.py | 365 ++++++++-----
lib/hashserv/__init__.py | 190 +++----
lib/hashserv/client.py | 147 +++++-
lib/hashserv/server.py | 952 +++++++++++++++++++++-------------
lib/hashserv/sqlalchemy.py | 427 +++++++++++++++
lib/hashserv/sqlite.py | 408 +++++++++++++++
lib/hashserv/tests.py | 736 +++++++++++++++++++++++++-
lib/prserv/client.py | 8 +-
lib/prserv/serv.py | 37 +-
15 files changed, 3060 insertions(+), 807 deletions(-)
create mode 100644 lib/bb/asyncrpc/connection.py
create mode 100644 lib/bb/asyncrpc/exceptions.py
create mode 100644 lib/hashserv/sqlalchemy.py
create mode 100644 lib/hashserv/sqlite.py
--
2.34.1
^ permalink raw reply [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 01/22] asyncrpc: Abstract sockets
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 02/22] hashserv: Add websocket connection implementation Joshua Watt
` (21 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Rewrites the asyncrpc client and server code to make it possible to have
other transport backends that are not stream based (e.g. websockets
which are message based). The connection handling classes are now shared
between both the client and server to make it easier to implement new
transport mechanisms.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/__init__.py | 32 +---
lib/bb/asyncrpc/client.py | 78 +++------
lib/bb/asyncrpc/connection.py | 95 +++++++++++
lib/bb/asyncrpc/exceptions.py | 17 ++
lib/bb/asyncrpc/serv.py | 304 +++++++++++++++++-----------------
lib/hashserv/__init__.py | 21 ---
lib/hashserv/client.py | 38 ++---
lib/hashserv/server.py | 116 ++++++-------
lib/prserv/client.py | 8 +-
lib/prserv/serv.py | 31 ++--
10 files changed, 387 insertions(+), 353 deletions(-)
create mode 100644 lib/bb/asyncrpc/connection.py
create mode 100644 lib/bb/asyncrpc/exceptions.py
diff --git a/lib/bb/asyncrpc/__init__.py b/lib/bb/asyncrpc/__init__.py
index 9a85e996..9f677eac 100644
--- a/lib/bb/asyncrpc/__init__.py
+++ b/lib/bb/asyncrpc/__init__.py
@@ -4,30 +4,12 @@
# SPDX-License-Identifier: GPL-2.0-only
#
-import itertools
-import json
-
-# The Python async server defaults to a 64K receive buffer, so we hardcode our
-# maximum chunk size. It would be better if the client and server reported to
-# each other what the maximum chunk sizes were, but that will slow down the
-# connection setup with a round trip delay so I'd rather not do that unless it
-# is necessary
-DEFAULT_MAX_CHUNK = 32 * 1024
-
-
-def chunkify(msg, max_chunk):
- if len(msg) < max_chunk - 1:
- yield ''.join((msg, "\n"))
- else:
- yield ''.join((json.dumps({
- 'chunk-stream': None
- }), "\n"))
-
- args = [iter(msg)] * (max_chunk - 1)
- for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
- yield ''.join(itertools.chain(m, "\n"))
- yield "\n"
-
from .client import AsyncClient, Client
-from .serv import AsyncServer, AsyncServerConnection, ClientError, ServerError
+from .serv import AsyncServer, AsyncServerConnection
+from .connection import DEFAULT_MAX_CHUNK
+from .exceptions import (
+ ClientError,
+ ServerError,
+ ConnectionClosedError,
+)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index fa042bbe..7f33099b 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -10,13 +10,13 @@ import json
import os
import socket
import sys
-from . import chunkify, DEFAULT_MAX_CHUNK
+from .connection import StreamConnection, DEFAULT_MAX_CHUNK
+from .exceptions import ConnectionClosedError
class AsyncClient(object):
def __init__(self, proto_name, proto_version, logger, timeout=30):
- self.reader = None
- self.writer = None
+ self.socket = None
self.max_chunk = DEFAULT_MAX_CHUNK
self.proto_name = proto_name
self.proto_version = proto_version
@@ -25,7 +25,8 @@ class AsyncClient(object):
async def connect_tcp(self, address, port):
async def connect_sock():
- return await asyncio.open_connection(address, port)
+ reader, writer = await asyncio.open_connection(address, port)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
@@ -40,27 +41,27 @@ class AsyncClient(object):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
sock.connect(os.path.basename(path))
finally:
- os.chdir(cwd)
- return await asyncio.open_unix_connection(sock=sock)
+ os.chdir(cwd)
+ reader, writer = await asyncio.open_unix_connection(sock=sock)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
async def setup_connection(self):
- s = '%s %s\n\n' % (self.proto_name, self.proto_version)
- self.writer.write(s.encode("utf-8"))
- await self.writer.drain()
+ # Send headers
+ await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
+ # End of headers
+ await self.socket.send("")
async def connect(self):
- if self.reader is None or self.writer is None:
- (self.reader, self.writer) = await self._connect_sock()
+ if self.socket is None:
+ self.socket = await self._connect_sock()
await self.setup_connection()
async def close(self):
- self.reader = None
-
- if self.writer is not None:
- self.writer.close()
- self.writer = None
+ if self.socket is not None:
+ await self.socket.close()
+ self.socket = None
async def _send_wrapper(self, proc):
count = 0
@@ -71,6 +72,7 @@ class AsyncClient(object):
except (
OSError,
ConnectionError,
+ ConnectionClosedError,
json.JSONDecodeError,
UnicodeDecodeError,
) as e:
@@ -82,49 +84,15 @@ class AsyncClient(object):
await self.close()
count += 1
- async def send_message(self, msg):
- async def get_line():
- try:
- line = await asyncio.wait_for(self.reader.readline(), self.timeout)
- except asyncio.TimeoutError:
- raise ConnectionError("Timed out waiting for server")
-
- if not line:
- raise ConnectionError("Connection closed")
-
- line = line.decode("utf-8")
-
- if not line.endswith("\n"):
- raise ConnectionError("Bad message %r" % (line))
-
- return line
-
+ async def invoke(self, msg):
async def proc():
- for c in chunkify(json.dumps(msg), self.max_chunk):
- self.writer.write(c.encode("utf-8"))
- await self.writer.drain()
-
- l = await get_line()
-
- m = json.loads(l)
- if m and "chunk-stream" in m:
- lines = []
- while True:
- l = (await get_line()).rstrip("\n")
- if not l:
- break
- lines.append(l)
-
- m = json.loads("".join(lines))
-
- return m
+ await self.socket.send_message(msg)
+ return await self.socket.recv_message()
return await self._send_wrapper(proc)
async def ping(self):
- return await self.send_message(
- {'ping': {}}
- )
+ return await self.invoke({"ping": {}})
class Client(object):
@@ -142,7 +110,7 @@ class Client(object):
# required (but harmless) with it.
asyncio.set_event_loop(self.loop)
- self._add_methods('connect_tcp', 'ping')
+ self._add_methods("connect_tcp", "ping")
@abc.abstractmethod
def _get_async_client(self):
diff --git a/lib/bb/asyncrpc/connection.py b/lib/bb/asyncrpc/connection.py
new file mode 100644
index 00000000..c4fd2475
--- /dev/null
+++ b/lib/bb/asyncrpc/connection.py
@@ -0,0 +1,95 @@
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+import asyncio
+import itertools
+import json
+from .exceptions import ClientError, ConnectionClosedError
+
+
+# The Python async server defaults to a 64K receive buffer, so we hardcode our
+# maximum chunk size. It would be better if the client and server reported to
+# each other what the maximum chunk sizes were, but that will slow down the
+# connection setup with a round trip delay so I'd rather not do that unless it
+# is necessary
+DEFAULT_MAX_CHUNK = 32 * 1024
+
+
+def chunkify(msg, max_chunk):
+ if len(msg) < max_chunk - 1:
+ yield "".join((msg, "\n"))
+ else:
+ yield "".join((json.dumps({"chunk-stream": None}), "\n"))
+
+ args = [iter(msg)] * (max_chunk - 1)
+ for m in map("".join, itertools.zip_longest(*args, fillvalue="")):
+ yield "".join(itertools.chain(m, "\n"))
+ yield "\n"
+
+
+class StreamConnection(object):
+ def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
+ self.reader = reader
+ self.writer = writer
+ self.timeout = timeout
+ self.max_chunk = max_chunk
+
+ @property
+ def address(self):
+ return self.writer.get_extra_info("peername")
+
+ async def send_message(self, msg):
+ for c in chunkify(json.dumps(msg), self.max_chunk):
+ self.writer.write(c.encode("utf-8"))
+ await self.writer.drain()
+
+ async def recv_message(self):
+ l = await self.recv()
+
+ m = json.loads(l)
+ if not m:
+ return m
+
+ if "chunk-stream" in m:
+ lines = []
+ while True:
+ l = await self.recv()
+ if not l:
+ break
+ lines.append(l)
+
+ m = json.loads("".join(lines))
+
+ return m
+
+ async def send(self, msg):
+ self.writer.write(("%s\n" % msg).encode("utf-8"))
+ await self.writer.drain()
+
+ async def recv(self):
+ if self.timeout < 0:
+ line = await self.reader.readline()
+ else:
+ try:
+ line = await asyncio.wait_for(self.reader.readline(), self.timeout)
+ except asyncio.TimeoutError:
+ raise ConnectionError("Timed out waiting for data")
+
+ if not line:
+ raise ConnectionClosedError("Connection closed")
+
+ line = line.decode("utf-8")
+
+ if not line.endswith("\n"):
+ raise ConnectionError("Bad message %r" % (line))
+
+ return line.rstrip()
+
+ async def close(self):
+ self.reader = None
+ if self.writer is not None:
+ self.writer.close()
+ self.writer = None
diff --git a/lib/bb/asyncrpc/exceptions.py b/lib/bb/asyncrpc/exceptions.py
new file mode 100644
index 00000000..a8942b4f
--- /dev/null
+++ b/lib/bb/asyncrpc/exceptions.py
@@ -0,0 +1,17 @@
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+
+class ClientError(Exception):
+ pass
+
+
+class ServerError(Exception):
+ pass
+
+
+class ConnectionClosedError(Exception):
+ pass
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index d2de4891..3e0d0632 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -12,241 +12,248 @@ import signal
import socket
import sys
import multiprocessing
-from . import chunkify, DEFAULT_MAX_CHUNK
-
-
-class ClientError(Exception):
- pass
-
-
-class ServerError(Exception):
- pass
+from .connection import StreamConnection
+from .exceptions import ClientError, ServerError, ConnectionClosedError
class AsyncServerConnection(object):
- def __init__(self, reader, writer, proto_name, logger):
- self.reader = reader
- self.writer = writer
+ # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
+ # return message will be automatically be sent back to the client
+ NO_RESPONSE = object()
+
+ def __init__(self, socket, proto_name, logger):
+ self.socket = socket
self.proto_name = proto_name
- self.max_chunk = DEFAULT_MAX_CHUNK
self.handlers = {
- 'chunk-stream': self.handle_chunk,
- 'ping': self.handle_ping,
+ "ping": self.handle_ping,
}
self.logger = logger
+ async def close(self):
+ await self.socket.close()
+
async def process_requests(self):
try:
- self.addr = self.writer.get_extra_info('peername')
- self.logger.debug('Client %r connected' % (self.addr,))
+ self.logger.info("Client %r connected" % (self.socket.address,))
# Read protocol and version
- client_protocol = await self.reader.readline()
+ client_protocol = await self.socket.recv()
if not client_protocol:
return
- (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
+ (client_proto_name, client_proto_version) = client_protocol.split()
if client_proto_name != self.proto_name:
- self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name))
+ self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
return
- self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
+ self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
if not self.validate_proto_version():
- self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
+ self.logger.debug(
+ "Rejecting invalid protocol version %s" % (client_proto_version)
+ )
return
# Read headers. Currently, no headers are implemented, so look for
# an empty line to signal the end of the headers
while True:
- line = await self.reader.readline()
- if not line:
- return
-
- line = line.decode('utf-8').rstrip()
- if not line:
+ header = await self.socket.recv()
+ if not header:
break
# Handle messages
while True:
- d = await self.read_message()
+ d = await self.socket.recv_message()
if d is None:
break
- await self.dispatch_message(d)
- await self.writer.drain()
- except ClientError as e:
+ response = await self.dispatch_message(d)
+ if response is not self.NO_RESPONSE:
+ await self.socket.send_message(response)
+
+ except ConnectionClosedError as e:
+ self.logger.info(str(e))
+ except (ClientError, ConnectionError) as e:
self.logger.error(str(e))
finally:
- self.writer.close()
+ await self.close()
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
- self.logger.debug('Handling %s' % k)
- await self.handlers[k](msg[k])
- return
+ self.logger.debug("Handling %s" % k)
+ return await self.handlers[k](msg[k])
raise ClientError("Unrecognized command %r" % msg)
- def write_message(self, msg):
- for c in chunkify(json.dumps(msg), self.max_chunk):
- self.writer.write(c.encode('utf-8'))
+ async def handle_ping(self, request):
+ return {"alive": True}
- async def read_message(self):
- l = await self.reader.readline()
- if not l:
- return None
- try:
- message = l.decode('utf-8')
+class StreamServer(object):
+ def __init__(self, handler, logger):
+ self.handler = handler
+ self.logger = logger
+ self.closed = False
- if not message.endswith('\n'):
- return None
+ async def handle_stream_client(self, reader, writer):
+ # writer.transport.set_write_buffer_limits(0)
+ socket = StreamConnection(reader, writer, -1)
+ if self.closed:
+ await socket.close()
+ return
+
+ await self.handler(socket)
+
+ async def stop(self):
+ self.closed = True
+
+
+class TCPStreamServer(StreamServer):
+ def __init__(self, host, port, handler, logger):
+ super().__init__(handler, logger)
+ self.host = host
+ self.port = port
+
+ def start(self, loop):
+ self.server = loop.run_until_complete(
+ asyncio.start_server(self.handle_stream_client, self.host, self.port)
+ )
+
+ for s in self.server.sockets:
+ self.logger.debug("Listening on %r" % (s.getsockname(),))
+ # Newer python does this automatically. Do it manually here for
+ # maximum compatibility
+ s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
+ s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
+
+ # Enable keep alives. This prevents broken client connections
+ # from persisting on the server for long periods of time.
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
+
+ name = self.server.sockets[0].getsockname()
+ if self.server.sockets[0].family == socket.AF_INET6:
+ self.address = "[%s]:%d" % (name[0], name[1])
+ else:
+ self.address = "%s:%d" % (name[0], name[1])
+
+ return [self.server.wait_closed()]
+
+ async def stop(self):
+ await super().stop()
+ self.server.close()
+
+ def cleanup(self):
+ pass
- return json.loads(message)
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- self.logger.error('Bad message from client: %r' % message)
- raise e
- async def handle_chunk(self, request):
- lines = []
- try:
- while True:
- l = await self.reader.readline()
- l = l.rstrip(b"\n").decode("utf-8")
- if not l:
- break
- lines.append(l)
+class UnixStreamServer(StreamServer):
+ def __init__(self, path, handler, logger):
+ super().__init__(handler, logger)
+ self.path = path
- msg = json.loads(''.join(lines))
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- self.logger.error('Bad message from client: %r' % lines)
- raise e
+ def start(self, loop):
+ cwd = os.getcwd()
+ try:
+ # Work around path length limits in AF_UNIX
+ os.chdir(os.path.dirname(self.path))
+ self.server = loop.run_until_complete(
+ asyncio.start_unix_server(
+ self.handle_stream_client, os.path.basename(self.path)
+ )
+ )
+ finally:
+ os.chdir(cwd)
- if 'chunk-stream' in msg:
- raise ClientError("Nested chunks are not allowed")
+ self.logger.debug("Listening on %r" % self.path)
+ self.address = "unix://%s" % os.path.abspath(self.path)
+ return [self.server.wait_closed()]
- await self.dispatch_message(msg)
+ async def stop(self):
+ await super().stop()
+ self.server.close()
- async def handle_ping(self, request):
- response = {'alive': True}
- self.write_message(response)
+ def cleanup(self):
+ os.unlink(self.path)
class AsyncServer(object):
def __init__(self, logger):
- self._cleanup_socket = None
self.logger = logger
- self.start = None
- self.address = None
self.loop = None
+ self.run_tasks = []
def start_tcp_server(self, host, port):
- def start_tcp():
- self.server = self.loop.run_until_complete(
- asyncio.start_server(self.handle_client, host, port)
- )
-
- for s in self.server.sockets:
- self.logger.debug('Listening on %r' % (s.getsockname(),))
- # Newer python does this automatically. Do it manually here for
- # maximum compatibility
- s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
- s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
-
- # Enable keep alives. This prevents broken client connections
- # from persisting on the server for long periods of time.
- s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
-
- name = self.server.sockets[0].getsockname()
- if self.server.sockets[0].family == socket.AF_INET6:
- self.address = "[%s]:%d" % (name[0], name[1])
- else:
- self.address = "%s:%d" % (name[0], name[1])
-
- self.start = start_tcp
+ self.server = TCPStreamServer(host, port, self._client_handler, self.logger)
def start_unix_server(self, path):
- def cleanup():
- os.unlink(path)
-
- def start_unix():
- cwd = os.getcwd()
- try:
- # Work around path length limits in AF_UNIX
- os.chdir(os.path.dirname(path))
- self.server = self.loop.run_until_complete(
- asyncio.start_unix_server(self.handle_client, os.path.basename(path))
- )
- finally:
- os.chdir(cwd)
-
- self.logger.debug('Listening on %r' % path)
+ self.server = UnixStreamServer(path, self._client_handler, self.logger)
- self._cleanup_socket = cleanup
- self.address = "unix://%s" % os.path.abspath(path)
-
- self.start = start_unix
-
- @abc.abstractmethod
- def accept_client(self, reader, writer):
- pass
-
- async def handle_client(self, reader, writer):
- # writer.transport.set_write_buffer_limits(0)
+ async def _client_handler(self, socket):
try:
- client = self.accept_client(reader, writer)
+ client = self.accept_client(socket)
await client.process_requests()
except Exception as e:
import traceback
- self.logger.error('Error from client: %s' % str(e), exc_info=True)
+
+ self.logger.error("Error from client: %s" % str(e), exc_info=True)
traceback.print_exc()
- writer.close()
- self.logger.debug('Client disconnected')
+ await socket.close()
+ self.logger.debug("Client disconnected")
- def run_loop_forever(self):
- try:
- self.loop.run_forever()
- except KeyboardInterrupt:
- pass
+ @abc.abstractmethod
+ def accept_client(self, socket):
+ pass
+
+ async def stop(self):
+ self.logger.debug("Stopping server")
+ await self.server.stop()
+
+ def start(self):
+ tasks = self.server.start(self.loop)
+ self.address = self.server.address
+ return tasks
def signal_handler(self):
self.logger.debug("Got exit signal")
- self.loop.stop()
+ self.loop.create_task(self.stop())
- def _serve_forever(self):
+ def _serve_forever(self, tasks):
try:
self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
+ self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
+ self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
- self.run_loop_forever()
- self.server.close()
+ self.loop.run_until_complete(asyncio.gather(*tasks))
- self.loop.run_until_complete(self.server.wait_closed())
- self.logger.debug('Server shutting down')
+ self.logger.debug("Server shutting down")
finally:
- if self._cleanup_socket is not None:
- self._cleanup_socket()
+ self.server.cleanup()
def serve_forever(self):
"""
Serve requests in the current process
"""
+ self._create_loop()
+ tasks = self.start()
+ self._serve_forever(tasks)
+ self.loop.close()
+
+ def _create_loop(self):
# Create loop and override any loop that may have existed in
# a parent process. It is possible that the usecases of
# serve_forever might be constrained enough to allow using
# get_event_loop here, but better safe than sorry for now.
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
- self.start()
- self._serve_forever()
def serve_as_process(self, *, prefunc=None, args=()):
"""
Serve requests in a child process
"""
+
def run(queue):
# Create loop and override any loop that may have existed
# in a parent process. Without doing this and instead
@@ -259,18 +266,19 @@ class AsyncServer(object):
# more general, though, as any potential use of asyncio in
# Cooker could create a loop that needs to replaced in this
# new process.
- self.loop = asyncio.new_event_loop()
- asyncio.set_event_loop(self.loop)
+ self._create_loop()
try:
- self.start()
+ self.address = None
+ tasks = self.start()
finally:
+ # Always put the server address to wake up the parent task
queue.put(self.address)
queue.close()
if prefunc is not None:
prefunc(self, *args)
- self._serve_forever()
+ self._serve_forever(tasks)
if sys.version_info >= (3, 6):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 9cb3fd57..3a401835 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -15,13 +15,6 @@ UNIX_PREFIX = "unix://"
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
-# The Python async server defaults to a 64K receive buffer, so we hardcode our
-# maximum chunk size. It would be better if the client and server reported to
-# each other what the maximum chunk sizes were, but that will slow down the
-# connection setup with a round trip delay so I'd rather not do that unless it
-# is necessary
-DEFAULT_MAX_CHUNK = 32 * 1024
-
UNIHASH_TABLE_DEFINITION = (
("method", "TEXT NOT NULL", "UNIQUE"),
("taskhash", "TEXT NOT NULL", "UNIQUE"),
@@ -102,20 +95,6 @@ def parse_address(addr):
return (ADDR_TYPE_TCP, (host, int(port)))
-def chunkify(msg, max_chunk):
- if len(msg) < max_chunk - 1:
- yield ''.join((msg, "\n"))
- else:
- yield ''.join((json.dumps({
- 'chunk-stream': None
- }), "\n"))
-
- args = [iter(msg)] * (max_chunk - 1)
- for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
- yield ''.join(itertools.chain(m, "\n"))
- yield "\n"
-
-
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
from . import server
db = setup_database(dbname, sync=sync)
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index f676d267..5f7d22ab 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -28,24 +28,24 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
async def send_stream(self, msg):
async def proc():
- self.writer.write(("%s\n" % msg).encode("utf-8"))
- await self.writer.drain()
- l = await self.reader.readline()
- if not l:
- raise ConnectionError("Connection closed")
- return l.decode("utf-8").rstrip()
+ await self.socket.send(msg)
+ return await self.socket.recv()
return await self._send_wrapper(proc)
async def _set_mode(self, new_mode):
+ async def stream_to_normal():
+ await self.socket.send("END")
+ return await self.socket.recv()
+
if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
- r = await self.send_stream("END")
+ r = await self._send_wrapper(stream_to_normal)
if r != "ok":
- raise ConnectionError("Bad response from server %r" % r)
+ raise ConnectionError("Unable to transition to normal mode: Bad response from server %r" % r)
elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
- r = await self.send_message({"get-stream": None})
+ r = await self.invoke({"get-stream": None})
if r != "ok":
- raise ConnectionError("Bad response from server %r" % r)
+ raise ConnectionError("Unable to transition to stream mode: Bad response from server %r" % r)
elif new_mode != self.mode:
raise Exception(
"Undefined mode transition %r -> %r" % (self.mode, new_mode)
@@ -67,7 +67,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
m["method"] = method
m["outhash"] = outhash
m["unihash"] = unihash
- return await self.send_message({"report": m})
+ return await self.invoke({"report": m})
async def report_unihash_equiv(self, taskhash, method, unihash, extra={}):
await self._set_mode(self.MODE_NORMAL)
@@ -75,39 +75,39 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
m["taskhash"] = taskhash
m["method"] = method
m["unihash"] = unihash
- return await self.send_message({"report-equiv": m})
+ return await self.invoke({"report-equiv": m})
async def get_taskhash(self, method, taskhash, all_properties=False):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message(
+ return await self.invoke(
{"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
)
async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message(
+ return await self.invoke(
{"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method, "with_unihash": with_unihash}}
)
async def get_stats(self):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"get-stats": None})
+ return await self.invoke({"get-stats": None})
async def reset_stats(self):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"reset-stats": None})
+ return await self.invoke({"reset-stats": None})
async def backfill_wait(self):
await self._set_mode(self.MODE_NORMAL)
- return (await self.send_message({"backfill-wait": None}))["tasks"]
+ return (await self.invoke({"backfill-wait": None}))["tasks"]
async def remove(self, where):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"remove": {"where": where}})
+ return await self.invoke({"remove": {"where": where}})
async def clean_unused(self, max_age):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"clean-unused": {"max_age_seconds": max_age}})
+ return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})
class Client(bb.asyncrpc.Client):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 45bf476b..13b75480 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -165,8 +165,8 @@ class ServerCursor(object):
class ServerClient(bb.asyncrpc.AsyncServerConnection):
- def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
- super().__init__(reader, writer, 'OEHASHEQUIV', logger)
+ def __init__(self, socket, db, request_stats, backfill_queue, upstream, read_only):
+ super().__init__(socket, 'OEHASHEQUIV', logger)
self.db = db
self.request_stats = request_stats
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
@@ -209,12 +209,11 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if k in msg:
logger.debug('Handling %s' % k)
if 'stream' in k:
- await self.handlers[k](msg[k])
+ return await self.handlers[k](msg[k])
else:
with self.request_stats.start_sample() as self.request_sample, \
self.request_sample.measure():
- await self.handlers[k](msg[k])
- return
+ return await self.handlers[k](msg[k])
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
@@ -224,9 +223,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
fetch_all = request.get('all', False)
with closing(self.db.cursor()) as cursor:
- d = await self.get_unihash(cursor, method, taskhash, fetch_all)
-
- self.write_message(d)
+ return await self.get_unihash(cursor, method, taskhash, fetch_all)
async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
d = None
@@ -274,9 +271,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
with_unihash = request.get("with_unihash", True)
with closing(self.db.cursor()) as cursor:
- d = await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
-
- self.write_message(d)
+ return await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
async def get_outhash(self, cursor, method, outhash, taskhash, with_unihash=True):
d = None
@@ -334,14 +329,14 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
)
async def handle_get_stream(self, request):
- self.write_message('ok')
+ await self.socket.send_message("ok")
while True:
upstream = None
- l = await self.reader.readline()
+ l = await self.socket.recv()
if not l:
- return
+ break
try:
# This inner loop is very sensitive and must be as fast as
@@ -352,10 +347,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
request_measure = self.request_sample.measure()
request_measure.start()
- l = l.decode('utf-8').rstrip()
if l == 'END':
- self.writer.write('ok\n'.encode('utf-8'))
- return
+ break
(method, taskhash) = l.split()
#logger.debug('Looking up %s %s' % (method, taskhash))
@@ -366,29 +359,30 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
cursor.close()
if row is not None:
- msg = ('%s\n' % row['unihash']).encode('utf-8')
+ msg = row['unihash']
#logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
elif self.upstream_client is not None:
upstream = await self.upstream_client.get_unihash(method, taskhash)
if upstream:
- msg = ("%s\n" % upstream).encode("utf-8")
+ msg = upstream
else:
- msg = "\n".encode("utf-8")
+ msg = ""
else:
- msg = '\n'.encode('utf-8')
+ msg = ""
- self.writer.write(msg)
+ await self.socket.send(msg)
finally:
request_measure.end()
self.request_sample.end()
- await self.writer.drain()
-
# Post to the backfill queue after writing the result to minimize
# the turn around time on a request
if upstream is not None:
await self.backfill_queue.put((method, taskhash))
+ await self.socket.send("ok")
+ return self.NO_RESPONSE
+
async def handle_report(self, data):
with closing(self.db.cursor()) as cursor:
outhash_data = {
@@ -468,7 +462,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
'unihash': unihash,
}
- self.write_message(d)
+ return d
async def handle_equivreport(self, data):
with closing(self.db.cursor()) as cursor:
@@ -491,30 +485,28 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
- self.write_message(d)
+ return d
async def handle_get_stats(self, request):
- d = {
+ return {
'requests': self.request_stats.todict(),
}
- self.write_message(d)
-
async def handle_reset_stats(self, request):
d = {
'requests': self.request_stats.todict(),
}
self.request_stats.reset()
- self.write_message(d)
+ return d
async def handle_backfill_wait(self, request):
d = {
'tasks': self.backfill_queue.qsize(),
}
await self.backfill_queue.join()
- self.write_message(d)
+ return d
async def handle_remove(self, request):
condition = request["where"]
@@ -541,7 +533,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
self.db.commit()
- self.write_message({"count": count})
+ return {"count": count}
async def handle_clean_unused(self, request):
max_age = request["max_age_seconds"]
@@ -558,7 +550,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
)
count = cursor.rowcount
- self.write_message({"count": count})
+ return {"count": count}
def query_equivalent(self, cursor, method, taskhash):
# This is part of the inner loop and must be as fast as possible
@@ -583,41 +575,33 @@ class Server(bb.asyncrpc.AsyncServer):
self.db = db
self.upstream = upstream
self.read_only = read_only
+ self.backfill_queue = None
- def accept_client(self, reader, writer):
- return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
+ def accept_client(self, socket):
+ return ServerClient(socket, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
- @contextmanager
- def _backfill_worker(self):
- async def backfill_worker_task():
- client = await create_async_client(self.upstream)
- try:
- while True:
- item = await self.backfill_queue.get()
- if item is None:
- self.backfill_queue.task_done()
- break
- method, taskhash = item
- await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ async def backfill_worker_task(self):
+ client = await create_async_client(self.upstream)
+ try:
+ while True:
+ item = await self.backfill_queue.get()
+ if item is None:
self.backfill_queue.task_done()
- finally:
- await client.close()
+ break
+ method, taskhash = item
+ await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ self.backfill_queue.task_done()
+ finally:
+ await client.close()
- async def join_worker(worker):
+ def start(self):
+ tasks = super().start()
+ if self.upstream:
+ self.backfill_queue = asyncio.Queue()
+ tasks += [self.backfill_worker_task()]
+ return tasks
+
+ async def stop(self):
+ if self.backfill_queue is not None:
await self.backfill_queue.put(None)
- await worker
-
- if self.upstream is not None:
- worker = asyncio.ensure_future(backfill_worker_task())
- try:
- yield
- finally:
- self.loop.run_until_complete(join_worker(worker))
- else:
- yield
-
- def run_loop_forever(self):
- self.backfill_queue = asyncio.Queue()
-
- with self._backfill_worker():
- super().run_loop_forever()
+ await super().stop()
diff --git a/lib/prserv/client.py b/lib/prserv/client.py
index 69ab7a4a..6b81356f 100644
--- a/lib/prserv/client.py
+++ b/lib/prserv/client.py
@@ -14,28 +14,28 @@ class PRAsyncClient(bb.asyncrpc.AsyncClient):
super().__init__('PRSERVICE', '1.0', logger)
async def getPR(self, version, pkgarch, checksum):
- response = await self.send_message(
+ response = await self.invoke(
{'get-pr': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum}}
)
if response:
return response['value']
async def importone(self, version, pkgarch, checksum, value):
- response = await self.send_message(
+ response = await self.invoke(
{'import-one': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum, 'value': value}}
)
if response:
return response['value']
async def export(self, version, pkgarch, checksum, colinfo):
- response = await self.send_message(
+ response = await self.invoke(
{'export': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum, 'colinfo': colinfo}}
)
if response:
return (response['metainfo'], response['datainfo'])
async def is_readonly(self):
- response = await self.send_message(
+ response = await self.invoke(
{'is-readonly': {}}
)
if response:
diff --git a/lib/prserv/serv.py b/lib/prserv/serv.py
index c686b206..ea793316 100644
--- a/lib/prserv/serv.py
+++ b/lib/prserv/serv.py
@@ -20,8 +20,8 @@ PIDPREFIX = "/tmp/PRServer_%s_%s.pid"
singleton = None
class PRServerClient(bb.asyncrpc.AsyncServerConnection):
- def __init__(self, reader, writer, table, read_only):
- super().__init__(reader, writer, 'PRSERVICE', logger)
+ def __init__(self, socket, table, read_only):
+ super().__init__(socket, 'PRSERVICE', logger)
self.handlers.update({
'get-pr': self.handle_get_pr,
'import-one': self.handle_import_one,
@@ -36,12 +36,12 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
async def dispatch_message(self, msg):
try:
- await super().dispatch_message(msg)
+ return await super().dispatch_message(msg)
except:
self.table.sync()
raise
-
- self.table.sync_if_dirty()
+ else:
+ self.table.sync_if_dirty()
async def handle_get_pr(self, request):
version = request['version']
@@ -57,7 +57,7 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
except sqlite3.Error as exc:
logger.error(str(exc))
- self.write_message(response)
+ return response
async def handle_import_one(self, request):
response = None
@@ -71,7 +71,7 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
if value is not None:
response = {'value': value}
- self.write_message(response)
+ return response
async def handle_export(self, request):
version = request['version']
@@ -85,12 +85,10 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
logger.error(str(exc))
metainfo = datainfo = None
- response = {'metainfo': metainfo, 'datainfo': datainfo}
- self.write_message(response)
+ return {'metainfo': metainfo, 'datainfo': datainfo}
async def handle_is_readonly(self, request):
- response = {'readonly': self.read_only}
- self.write_message(response)
+ return {'readonly': self.read_only}
class PRServer(bb.asyncrpc.AsyncServer):
def __init__(self, dbfile, read_only=False):
@@ -99,20 +97,23 @@ class PRServer(bb.asyncrpc.AsyncServer):
self.table = None
self.read_only = read_only
- def accept_client(self, reader, writer):
- return PRServerClient(reader, writer, self.table, self.read_only)
+ def accept_client(self, socket):
+ return PRServerClient(socket, self.table, self.read_only)
- def _serve_forever(self):
+ def start(self):
+ tasks = super().start()
self.db = prserv.db.PRData(self.dbfile, read_only=self.read_only)
self.table = self.db["PRMAIN"]
logger.info("Started PRServer with DBfile: %s, Address: %s, PID: %s" %
(self.dbfile, self.address, str(os.getpid())))
- super()._serve_forever()
+ return tasks
+ async def stop(self):
self.table.sync_if_dirty()
self.db.disconnect()
+ await super().stop()
def signal_handler(self):
super().signal_handler()
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v6 02/22] hashserv: Add websocket connection implementation
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 01/22] asyncrpc: Abstract sockets Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-10 12:03 ` Matthias Schnelte
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 03/22] asyncrpc: Add context manager API Joshua Watt
` (20 subsequent siblings)
22 siblings, 1 reply; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds support to the hash equivalence client and server to communicate
over websockets. Since websockets are message orientated instead of
stream orientated, a new connection class is needed to handle them.
Note that websocket support requires the third-party websockets Python
module to be installed on the host, but it should not be required unless
websockets are actually being used.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/client.py | 11 +++++++-
lib/bb/asyncrpc/connection.py | 44 +++++++++++++++++++++++++++++
lib/bb/asyncrpc/serv.py | 53 ++++++++++++++++++++++++++++++++++-
lib/hashserv/__init__.py | 13 +++++++++
lib/hashserv/client.py | 1 +
lib/hashserv/tests.py | 17 +++++++++++
6 files changed, 137 insertions(+), 2 deletions(-)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index 7f33099b..802c07df 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -10,7 +10,7 @@ import json
import os
import socket
import sys
-from .connection import StreamConnection, DEFAULT_MAX_CHUNK
+from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
from .exceptions import ConnectionClosedError
@@ -47,6 +47,15 @@ class AsyncClient(object):
self._connect_sock = connect_sock
+ async def connect_websocket(self, uri):
+ import websockets
+
+ async def connect_sock():
+ websocket = await websockets.connect(uri, ping_interval=None)
+ return WebsocketConnection(websocket, self.timeout)
+
+ self._connect_sock = connect_sock
+
async def setup_connection(self):
# Send headers
await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
diff --git a/lib/bb/asyncrpc/connection.py b/lib/bb/asyncrpc/connection.py
index c4fd2475..a10628f7 100644
--- a/lib/bb/asyncrpc/connection.py
+++ b/lib/bb/asyncrpc/connection.py
@@ -93,3 +93,47 @@ class StreamConnection(object):
if self.writer is not None:
self.writer.close()
self.writer = None
+
+
+class WebsocketConnection(object):
+ def __init__(self, socket, timeout):
+ self.socket = socket
+ self.timeout = timeout
+
+ @property
+ def address(self):
+ return ":".join(str(s) for s in self.socket.remote_address)
+
+ async def send_message(self, msg):
+ await self.send(json.dumps(msg))
+
+ async def recv_message(self):
+ m = await self.recv()
+ return json.loads(m)
+
+ async def send(self, msg):
+ import websockets.exceptions
+
+ try:
+ await self.socket.send(msg)
+ except websockets.exceptions.ConnectionClosed:
+ raise ConnectionClosedError("Connection closed")
+
+ async def recv(self):
+ import websockets.exceptions
+
+ try:
+ if self.timeout < 0:
+ return await self.socket.recv()
+
+ try:
+ return await asyncio.wait_for(self.socket.recv(), self.timeout)
+ except asyncio.TimeoutError:
+ raise ConnectionError("Timed out waiting for data")
+ except websockets.exceptions.ConnectionClosed:
+ raise ConnectionClosedError("Connection closed")
+
+ async def close(self):
+ if self.socket is not None:
+ await self.socket.close()
+ self.socket = None
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index 3e0d0632..dfb03773 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -12,7 +12,7 @@ import signal
import socket
import sys
import multiprocessing
-from .connection import StreamConnection
+from .connection import StreamConnection, WebsocketConnection
from .exceptions import ClientError, ServerError, ConnectionClosedError
@@ -178,6 +178,54 @@ class UnixStreamServer(StreamServer):
os.unlink(self.path)
+class WebsocketsServer(object):
+ def __init__(self, host, port, handler, logger):
+ self.host = host
+ self.port = port
+ self.handler = handler
+ self.logger = logger
+
+ def start(self, loop):
+ import websockets.server
+
+ self.server = loop.run_until_complete(
+ websockets.server.serve(
+ self.client_handler,
+ self.host,
+ self.port,
+ ping_interval=None,
+ )
+ )
+
+ for s in self.server.sockets:
+ self.logger.debug("Listening on %r" % (s.getsockname(),))
+
+ # Enable keep alives. This prevents broken client connections
+ # from persisting on the server for long periods of time.
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
+
+ name = self.server.sockets[0].getsockname()
+ if self.server.sockets[0].family == socket.AF_INET6:
+ self.address = "ws://[%s]:%d" % (name[0], name[1])
+ else:
+ self.address = "ws://%s:%d" % (name[0], name[1])
+
+ return [self.server.wait_closed()]
+
+ async def stop(self):
+ self.server.close()
+
+ def cleanup(self):
+ pass
+
+ async def client_handler(self, websocket):
+ socket = WebsocketConnection(websocket, -1)
+ await self.handler(socket)
+
+
class AsyncServer(object):
def __init__(self, logger):
self.logger = logger
@@ -190,6 +238,9 @@ class AsyncServer(object):
def start_unix_server(self, path):
self.server = UnixStreamServer(path, self._client_handler, self.logger)
+ def start_websocket_server(self, host, port):
+ self.server = WebsocketsServer(host, port, self._client_handler, self.logger)
+
async def _client_handler(self, socket):
try:
client = self.accept_client(socket)
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 3a401835..56b9c6bc 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -9,11 +9,15 @@ import re
import sqlite3
import itertools
import json
+from urllib.parse import urlparse
UNIX_PREFIX = "unix://"
+WS_PREFIX = "ws://"
+WSS_PREFIX = "wss://"
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
+ADDR_TYPE_WS = 2
UNIHASH_TABLE_DEFINITION = (
("method", "TEXT NOT NULL", "UNIQUE"),
@@ -84,6 +88,8 @@ def setup_database(database, sync=True):
def parse_address(addr):
if addr.startswith(UNIX_PREFIX):
return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
+ elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
+ return (ADDR_TYPE_WS, (addr,))
else:
m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
if m is not None:
@@ -103,6 +109,9 @@ def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
s.start_unix_server(*a)
+ elif typ == ADDR_TYPE_WS:
+ url = urlparse(a[0])
+ s.start_websocket_server(url.hostname, url.port)
else:
s.start_tcp_server(*a)
@@ -116,6 +125,8 @@ def create_client(addr):
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
c.connect_unix(*a)
+ elif typ == ADDR_TYPE_WS:
+ c.connect_websocket(*a)
else:
c.connect_tcp(*a)
@@ -128,6 +139,8 @@ async def create_async_client(addr):
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
await c.connect_unix(*a)
+ elif typ == ADDR_TYPE_WS:
+ await c.connect_websocket(*a)
else:
await c.connect_tcp(*a)
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index 5f7d22ab..9542d72f 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -115,6 +115,7 @@ class Client(bb.asyncrpc.Client):
super().__init__()
self._add_methods(
"connect_tcp",
+ "connect_websocket",
"get_unihash",
"report_unihash",
"report_unihash_equiv",
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index f343c586..01ffd52c 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -483,3 +483,20 @@ class TestHashEquivalenceTCPServer(HashEquivalenceTestSetup, HashEquivalenceComm
# If IPv6 is enabled, it should be safe to use localhost directly, in general
# case it is more reliable to resolve the IP address explicitly.
return socket.gethostbyname("localhost") + ":0"
+
+
+class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
+ def setUp(self):
+ try:
+ import websockets
+ except ImportError as e:
+ self.skipTest(str(e))
+
+ super().setUp()
+
+ def get_server_addr(self, server_idx):
+ # Some hosts cause asyncio module to misbehave, when IPv6 is not enabled.
+ # If IPv6 is enabled, it should be safe to use localhost directly, in general
+ # case it is more reliable to resolve the IP address explicitly.
+ host = socket.gethostbyname("localhost")
+ return "ws://%s:0" % host
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* Re: [bitbake-devel][PATCH v6 02/22] hashserv: Add websocket connection implementation
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 02/22] hashserv: Add websocket connection implementation Joshua Watt
@ 2023-11-10 12:03 ` Matthias Schnelte
2023-11-10 14:11 ` Joshua Watt
0 siblings, 1 reply; 138+ messages in thread
From: Matthias Schnelte @ 2023-11-10 12:03 UTC (permalink / raw)
To: Joshua Watt, bitbake-devel
Hi Joshua,
thanks for this change! Being able to use websockets instead of a plain TCP
connection would help a lot in corporate setups, which are often
restricted to only http(s) ports and enforce the use of a corporate proxy.
Unfortunately the websocket library you are using seems not to support
websockets over http proxy. At least that is what I understood.
Would it be possible to use another client lib for websockets in order
to support connection through proxy?
This library seems to support it:
https://websocket-client.readthedocs.io/en/latest/examples.html#connecting-through-a-proxy
Matthias
On 03.11.23 15:26, Joshua Watt wrote:
> Adds support to the hash equivalence client and server to communicate
> over websockets. Since websockets are message orientated instead of
> stream orientated, and new connection class is needed to handle them.
>
> Note that websocket support does require the 3rd party websockets python
> module be installed on the host, but it should not be required unless
> websockets are actually being used.
>
> Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
> ---
> lib/bb/asyncrpc/client.py | 11 +++++++-
> lib/bb/asyncrpc/connection.py | 44 +++++++++++++++++++++++++++++
> lib/bb/asyncrpc/serv.py | 53 ++++++++++++++++++++++++++++++++++-
> lib/hashserv/__init__.py | 13 +++++++++
> lib/hashserv/client.py | 1 +
> lib/hashserv/tests.py | 17 +++++++++++
> 6 files changed, 137 insertions(+), 2 deletions(-)
>
> diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
> index 7f33099b..802c07df 100644
> --- a/lib/bb/asyncrpc/client.py
> +++ b/lib/bb/asyncrpc/client.py
> @@ -10,7 +10,7 @@ import json
> import os
> import socket
> import sys
> -from .connection import StreamConnection, DEFAULT_MAX_CHUNK
> +from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
> from .exceptions import ConnectionClosedError
>
>
> @@ -47,6 +47,15 @@ class AsyncClient(object):
>
> self._connect_sock = connect_sock
>
> + async def connect_websocket(self, uri):
> + import websockets
> +
> + async def connect_sock():
> + websocket = await websockets.connect(uri, ping_interval=None)
> + return WebsocketConnection(websocket, self.timeout)
> +
> + self._connect_sock = connect_sock
> +
> async def setup_connection(self):
> # Send headers
> await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
> diff --git a/lib/bb/asyncrpc/connection.py b/lib/bb/asyncrpc/connection.py
> index c4fd2475..a10628f7 100644
> --- a/lib/bb/asyncrpc/connection.py
> +++ b/lib/bb/asyncrpc/connection.py
> @@ -93,3 +93,47 @@ class StreamConnection(object):
> if self.writer is not None:
> self.writer.close()
> self.writer = None
> +
> +
> +class WebsocketConnection(object):
> + def __init__(self, socket, timeout):
> + self.socket = socket
> + self.timeout = timeout
> +
> + @property
> + def address(self):
> + return ":".join(str(s) for s in self.socket.remote_address)
> +
> + async def send_message(self, msg):
> + await self.send(json.dumps(msg))
> +
> + async def recv_message(self):
> + m = await self.recv()
> + return json.loads(m)
> +
> + async def send(self, msg):
> + import websockets.exceptions
> +
> + try:
> + await self.socket.send(msg)
> + except websockets.exceptions.ConnectionClosed:
> + raise ConnectionClosedError("Connection closed")
> +
> + async def recv(self):
> + import websockets.exceptions
> +
> + try:
> + if self.timeout < 0:
> + return await self.socket.recv()
> +
> + try:
> + return await asyncio.wait_for(self.socket.recv(), self.timeout)
> + except asyncio.TimeoutError:
> + raise ConnectionError("Timed out waiting for data")
> + except websockets.exceptions.ConnectionClosed:
> + raise ConnectionClosedError("Connection closed")
> +
> + async def close(self):
> + if self.socket is not None:
> + await self.socket.close()
> + self.socket = None
> diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
> index 3e0d0632..dfb03773 100644
> --- a/lib/bb/asyncrpc/serv.py
> +++ b/lib/bb/asyncrpc/serv.py
> @@ -12,7 +12,7 @@ import signal
> import socket
> import sys
> import multiprocessing
> -from .connection import StreamConnection
> +from .connection import StreamConnection, WebsocketConnection
> from .exceptions import ClientError, ServerError, ConnectionClosedError
>
>
> @@ -178,6 +178,54 @@ class UnixStreamServer(StreamServer):
> os.unlink(self.path)
>
>
> +class WebsocketsServer(object):
> + def __init__(self, host, port, handler, logger):
> + self.host = host
> + self.port = port
> + self.handler = handler
> + self.logger = logger
> +
> + def start(self, loop):
> + import websockets.server
> +
> + self.server = loop.run_until_complete(
> + websockets.server.serve(
> + self.client_handler,
> + self.host,
> + self.port,
> + ping_interval=None,
> + )
> + )
> +
> + for s in self.server.sockets:
> + self.logger.debug("Listening on %r" % (s.getsockname(),))
> +
> + # Enable keep alives. This prevents broken client connections
> + # from persisting on the server for long periods of time.
> + s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
> + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
> + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
> + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
> +
> + name = self.server.sockets[0].getsockname()
> + if self.server.sockets[0].family == socket.AF_INET6:
> + self.address = "ws://[%s]:%d" % (name[0], name[1])
> + else:
> + self.address = "ws://%s:%d" % (name[0], name[1])
> +
> + return [self.server.wait_closed()]
> +
> + async def stop(self):
> + self.server.close()
> +
> + def cleanup(self):
> + pass
> +
> + async def client_handler(self, websocket):
> + socket = WebsocketConnection(websocket, -1)
> + await self.handler(socket)
> +
> +
> class AsyncServer(object):
> def __init__(self, logger):
> self.logger = logger
> @@ -190,6 +238,9 @@ class AsyncServer(object):
> def start_unix_server(self, path):
> self.server = UnixStreamServer(path, self._client_handler, self.logger)
>
> + def start_websocket_server(self, host, port):
> + self.server = WebsocketsServer(host, port, self._client_handler, self.logger)
> +
> async def _client_handler(self, socket):
> try:
> client = self.accept_client(socket)
> diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
> index 3a401835..56b9c6bc 100644
> --- a/lib/hashserv/__init__.py
> +++ b/lib/hashserv/__init__.py
> @@ -9,11 +9,15 @@ import re
> import sqlite3
> import itertools
> import json
> +from urllib.parse import urlparse
>
> UNIX_PREFIX = "unix://"
> +WS_PREFIX = "ws://"
> +WSS_PREFIX = "wss://"
>
> ADDR_TYPE_UNIX = 0
> ADDR_TYPE_TCP = 1
> +ADDR_TYPE_WS = 2
>
> UNIHASH_TABLE_DEFINITION = (
> ("method", "TEXT NOT NULL", "UNIQUE"),
> @@ -84,6 +88,8 @@ def setup_database(database, sync=True):
> def parse_address(addr):
> if addr.startswith(UNIX_PREFIX):
> return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
> + elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
> + return (ADDR_TYPE_WS, (addr,))
> else:
> m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
> if m is not None:
> @@ -103,6 +109,9 @@ def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
> (typ, a) = parse_address(addr)
> if typ == ADDR_TYPE_UNIX:
> s.start_unix_server(*a)
> + elif typ == ADDR_TYPE_WS:
> + url = urlparse(a[0])
> + s.start_websocket_server(url.hostname, url.port)
> else:
> s.start_tcp_server(*a)
>
> @@ -116,6 +125,8 @@ def create_client(addr):
> (typ, a) = parse_address(addr)
> if typ == ADDR_TYPE_UNIX:
> c.connect_unix(*a)
> + elif typ == ADDR_TYPE_WS:
> + c.connect_websocket(*a)
> else:
> c.connect_tcp(*a)
>
> @@ -128,6 +139,8 @@ async def create_async_client(addr):
> (typ, a) = parse_address(addr)
> if typ == ADDR_TYPE_UNIX:
> await c.connect_unix(*a)
> + elif typ == ADDR_TYPE_WS:
> + await c.connect_websocket(*a)
> else:
> await c.connect_tcp(*a)
>
> diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
> index 5f7d22ab..9542d72f 100644
> --- a/lib/hashserv/client.py
> +++ b/lib/hashserv/client.py
> @@ -115,6 +115,7 @@ class Client(bb.asyncrpc.Client):
> super().__init__()
> self._add_methods(
> "connect_tcp",
> + "connect_websocket",
> "get_unihash",
> "report_unihash",
> "report_unihash_equiv",
> diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
> index f343c586..01ffd52c 100644
> --- a/lib/hashserv/tests.py
> +++ b/lib/hashserv/tests.py
> @@ -483,3 +483,20 @@ class TestHashEquivalenceTCPServer(HashEquivalenceTestSetup, HashEquivalenceComm
> # If IPv6 is enabled, it should be safe to use localhost directly, in general
> # case it is more reliable to resolve the IP address explicitly.
> return socket.gethostbyname("localhost") + ":0"
> +
> +
> +class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
> + def setUp(self):
> + try:
> + import websockets
> + except ImportError as e:
> + self.skipTest(str(e))
> +
> + super().setUp()
> +
> + def get_server_addr(self, server_idx):
> + # Some hosts cause asyncio module to misbehave, when IPv6 is not enabled.
> + # If IPv6 is enabled, it should be safe to use localhost directly, in general
> + # case it is more reliable to resolve the IP address explicitly.
> + host = socket.gethostbyname("localhost")
> + return "ws://%s:0" % host
>
> -=-=-=-=-=-=-=-=-=-=-=-
> Links: You receive all messages sent to this group.
> View/Reply Online (#15423): https://lists.openembedded.org/g/bitbake-devel/message/15423
> Mute This Topic: https://lists.openembedded.org/mt/102364905/7851872
> Group Owner: bitbake-devel+owner@lists.openembedded.org
> Unsubscribe: https://lists.openembedded.org/g/bitbake-devel/unsub [develop@schnelte.de]
> -=-=-=-=-=-=-=-=-=-=-=-
>
^ permalink raw reply [flat|nested] 138+ messages in thread* Re: [bitbake-devel][PATCH v6 02/22] hashserv: Add websocket connection implementation
2023-11-10 12:03 ` Matthias Schnelte
@ 2023-11-10 14:11 ` Joshua Watt
2023-11-15 7:44 ` Matthias Schnelte
0 siblings, 1 reply; 138+ messages in thread
From: Joshua Watt @ 2023-11-10 14:11 UTC (permalink / raw)
To: Matthias Schnelte; +Cc: bitbake-devel
[-- Attachment #1: Type: text/plain, Size: 12250 bytes --]
On Fri, Nov 10, 2023, 5:03 AM Matthias Schnelte <develop@schnelte.de> wrote:
> Hi Joshua,
>
> thanks for this change! Being able to use websockets instead of some tcp
> connection would help a lot in cooperate setups which are often
> restricted to only http(s) ports and enforce the use of a cooperate proxy.
>
> Unfortunately the websocket library you are using seems not to support
> websockets over http proxy. At least that is what I understood.
>
> Would it be possible to use another client lib for websockets in order
> to support connection through proxy?
>
> This library seems to support it:
>
> https://websocket-client.readthedocs.io/en/latest/examples.html#connecting-through-a-proxy
I'm not sure that's going to work. We need a library that supports asyncio,
and has very minimal dependencies, which the current library satisfies (it
only depends on core Python)
Maybe there is another solution for proxying?
>
> Matthias
>
> On 03.11.23 15:26, Joshua Watt wrote:
> > Adds support to the hash equivalence client and server to communicate
> > over websockets. Since websockets are message orientated instead of
> > stream orientated, and new connection class is needed to handle them.
> >
> > Note that websocket support does require the 3rd party websockets python
> > module be installed on the host, but it should not be required unless
> > websockets are actually being used.
> >
> > Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
> > ---
> > lib/bb/asyncrpc/client.py | 11 +++++++-
> > lib/bb/asyncrpc/connection.py | 44 +++++++++++++++++++++++++++++
> > lib/bb/asyncrpc/serv.py | 53 ++++++++++++++++++++++++++++++++++-
> > lib/hashserv/__init__.py | 13 +++++++++
> > lib/hashserv/client.py | 1 +
> > lib/hashserv/tests.py | 17 +++++++++++
> > 6 files changed, 137 insertions(+), 2 deletions(-)
> >
> > diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
> > index 7f33099b..802c07df 100644
> > --- a/lib/bb/asyncrpc/client.py
> > +++ b/lib/bb/asyncrpc/client.py
> > @@ -10,7 +10,7 @@ import json
> > import os
> > import socket
> > import sys
> > -from .connection import StreamConnection, DEFAULT_MAX_CHUNK
> > +from .connection import StreamConnection, WebsocketConnection,
> DEFAULT_MAX_CHUNK
> > from .exceptions import ConnectionClosedError
> >
> >
> > @@ -47,6 +47,15 @@ class AsyncClient(object):
> >
> > self._connect_sock = connect_sock
> >
> > + async def connect_websocket(self, uri):
> > + import websockets
> > +
> > + async def connect_sock():
> > + websocket = await websockets.connect(uri,
> ping_interval=None)
> > + return WebsocketConnection(websocket, self.timeout)
> > +
> > + self._connect_sock = connect_sock
> > +
> > async def setup_connection(self):
> > # Send headers
> > await self.socket.send("%s %s" % (self.proto_name,
> self.proto_version))
> > diff --git a/lib/bb/asyncrpc/connection.py
> b/lib/bb/asyncrpc/connection.py
> > index c4fd2475..a10628f7 100644
> > --- a/lib/bb/asyncrpc/connection.py
> > +++ b/lib/bb/asyncrpc/connection.py
> > @@ -93,3 +93,47 @@ class StreamConnection(object):
> > if self.writer is not None:
> > self.writer.close()
> > self.writer = None
> > +
> > +
> > +class WebsocketConnection(object):
> > + def __init__(self, socket, timeout):
> > + self.socket = socket
> > + self.timeout = timeout
> > +
> > + @property
> > + def address(self):
> > + return ":".join(str(s) for s in self.socket.remote_address)
> > +
> > + async def send_message(self, msg):
> > + await self.send(json.dumps(msg))
> > +
> > + async def recv_message(self):
> > + m = await self.recv()
> > + return json.loads(m)
> > +
> > + async def send(self, msg):
> > + import websockets.exceptions
> > +
> > + try:
> > + await self.socket.send(msg)
> > + except websockets.exceptions.ConnectionClosed:
> > + raise ConnectionClosedError("Connection closed")
> > +
> > + async def recv(self):
> > + import websockets.exceptions
> > +
> > + try:
> > + if self.timeout < 0:
> > + return await self.socket.recv()
> > +
> > + try:
> > + return await asyncio.wait_for(self.socket.recv(),
> self.timeout)
> > + except asyncio.TimeoutError:
> > + raise ConnectionError("Timed out waiting for data")
> > + except websockets.exceptions.ConnectionClosed:
> > + raise ConnectionClosedError("Connection closed")
> > +
> > + async def close(self):
> > + if self.socket is not None:
> > + await self.socket.close()
> > + self.socket = None
> > diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
> > index 3e0d0632..dfb03773 100644
> > --- a/lib/bb/asyncrpc/serv.py
> > +++ b/lib/bb/asyncrpc/serv.py
> > @@ -12,7 +12,7 @@ import signal
> > import socket
> > import sys
> > import multiprocessing
> > -from .connection import StreamConnection
> > +from .connection import StreamConnection, WebsocketConnection
> > from .exceptions import ClientError, ServerError, ConnectionClosedError
> >
> >
> > @@ -178,6 +178,54 @@ class UnixStreamServer(StreamServer):
> > os.unlink(self.path)
> >
> >
> > +class WebsocketsServer(object):
> > + def __init__(self, host, port, handler, logger):
> > + self.host = host
> > + self.port = port
> > + self.handler = handler
> > + self.logger = logger
> > +
> > + def start(self, loop):
> > + import websockets.server
> > +
> > + self.server = loop.run_until_complete(
> > + websockets.server.serve(
> > + self.client_handler,
> > + self.host,
> > + self.port,
> > + ping_interval=None,
> > + )
> > + )
> > +
> > + for s in self.server.sockets:
> > + self.logger.debug("Listening on %r" % (s.getsockname(),))
> > +
> > + # Enable keep alives. This prevents broken client
> connections
> > + # from persisting on the server for long periods of time.
> > + s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
> > + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
> > + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
> > + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
> > +
> > + name = self.server.sockets[0].getsockname()
> > + if self.server.sockets[0].family == socket.AF_INET6:
> > + self.address = "ws://[%s]:%d" % (name[0], name[1])
> > + else:
> > + self.address = "ws://%s:%d" % (name[0], name[1])
> > +
> > + return [self.server.wait_closed()]
> > +
> > + async def stop(self):
> > + self.server.close()
> > +
> > + def cleanup(self):
> > + pass
> > +
> > + async def client_handler(self, websocket):
> > + socket = WebsocketConnection(websocket, -1)
> > + await self.handler(socket)
> > +
> > +
> > class AsyncServer(object):
> > def __init__(self, logger):
> > self.logger = logger
> > @@ -190,6 +238,9 @@ class AsyncServer(object):
> > def start_unix_server(self, path):
> > self.server = UnixStreamServer(path, self._client_handler,
> self.logger)
> >
> > + def start_websocket_server(self, host, port):
> > + self.server = WebsocketsServer(host, port,
> self._client_handler, self.logger)
> > +
> > async def _client_handler(self, socket):
> > try:
> > client = self.accept_client(socket)
> > diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
> > index 3a401835..56b9c6bc 100644
> > --- a/lib/hashserv/__init__.py
> > +++ b/lib/hashserv/__init__.py
> > @@ -9,11 +9,15 @@ import re
> > import sqlite3
> > import itertools
> > import json
> > +from urllib.parse import urlparse
> >
> > UNIX_PREFIX = "unix://"
> > +WS_PREFIX = "ws://"
> > +WSS_PREFIX = "wss://"
> >
> > ADDR_TYPE_UNIX = 0
> > ADDR_TYPE_TCP = 1
> > +ADDR_TYPE_WS = 2
> >
> > UNIHASH_TABLE_DEFINITION = (
> > ("method", "TEXT NOT NULL", "UNIQUE"),
> > @@ -84,6 +88,8 @@ def setup_database(database, sync=True):
> > def parse_address(addr):
> > if addr.startswith(UNIX_PREFIX):
> > return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
> > + elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
> > + return (ADDR_TYPE_WS, (addr,))
> > else:
> > m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
> > if m is not None:
> > @@ -103,6 +109,9 @@ def create_server(addr, dbname, *, sync=True,
> upstream=None, read_only=False):
> > (typ, a) = parse_address(addr)
> > if typ == ADDR_TYPE_UNIX:
> > s.start_unix_server(*a)
> > + elif typ == ADDR_TYPE_WS:
> > + url = urlparse(a[0])
> > + s.start_websocket_server(url.hostname, url.port)
> > else:
> > s.start_tcp_server(*a)
> >
> > @@ -116,6 +125,8 @@ def create_client(addr):
> > (typ, a) = parse_address(addr)
> > if typ == ADDR_TYPE_UNIX:
> > c.connect_unix(*a)
> > + elif typ == ADDR_TYPE_WS:
> > + c.connect_websocket(*a)
> > else:
> > c.connect_tcp(*a)
> >
> > @@ -128,6 +139,8 @@ async def create_async_client(addr):
> > (typ, a) = parse_address(addr)
> > if typ == ADDR_TYPE_UNIX:
> > await c.connect_unix(*a)
> > + elif typ == ADDR_TYPE_WS:
> > + await c.connect_websocket(*a)
> > else:
> > await c.connect_tcp(*a)
> >
> > diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
> > index 5f7d22ab..9542d72f 100644
> > --- a/lib/hashserv/client.py
> > +++ b/lib/hashserv/client.py
> > @@ -115,6 +115,7 @@ class Client(bb.asyncrpc.Client):
> > super().__init__()
> > self._add_methods(
> > "connect_tcp",
> > + "connect_websocket",
> > "get_unihash",
> > "report_unihash",
> > "report_unihash_equiv",
> > diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
> > index f343c586..01ffd52c 100644
> > --- a/lib/hashserv/tests.py
> > +++ b/lib/hashserv/tests.py
> > @@ -483,3 +483,20 @@ class
> TestHashEquivalenceTCPServer(HashEquivalenceTestSetup, HashEquivalenceComm
> > # If IPv6 is enabled, it should be safe to use localhost
> directly, in general
> > # case it is more reliable to resolve the IP address
> explicitly.
> > return socket.gethostbyname("localhost") + ":0"
> > +
> > +
> > +class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup,
> HashEquivalenceCommonTests, unittest.TestCase):
> > + def setUp(self):
> > + try:
> > + import websockets
> > + except ImportError as e:
> > + self.skipTest(str(e))
> > +
> > + super().setUp()
> > +
> > + def get_server_addr(self, server_idx):
> > + # Some hosts cause asyncio module to misbehave, when IPv6 is
> not enabled.
> > + # If IPv6 is enabled, it should be safe to use localhost
> directly, in general
> > + # case it is more reliable to resolve the IP address explicitly.
> > + host = socket.gethostbyname("localhost")
> > + return "ws://%s:0" % host
> >
> > -=-=-=-=-=-=-=-=-=-=-=-
> > Links: You receive all messages sent to this group.
> > View/Reply Online (#15423):
> https://lists.openembedded.org/g/bitbake-devel/message/15423
> > Mute This Topic: https://lists.openembedded.org/mt/102364905/7851872
> > Group Owner: bitbake-devel+owner@lists.openembedded.org
> > Unsubscribe: https://lists.openembedded.org/g/bitbake-devel/unsub [
> develop@schnelte.de]
> > -=-=-=-=-=-=-=-=-=-=-=-
> >
>
[-- Attachment #2: Type: text/html, Size: 16062 bytes --]
^ permalink raw reply [flat|nested] 138+ messages in thread* Re: [bitbake-devel][PATCH v6 02/22] hashserv: Add websocket connection implementation
2023-11-10 14:11 ` Joshua Watt
@ 2023-11-15 7:44 ` Matthias Schnelte
0 siblings, 0 replies; 138+ messages in thread
From: Matthias Schnelte @ 2023-11-15 7:44 UTC (permalink / raw)
To: bitbake-devel
[-- Attachment #1: Type: text/plain, Size: 14681 bytes --]
Hi Joshua,
what we are currently using is an HTTP tunnel to be able to connect to
the hashserver, which is running on Azure infrastructure. The hashserver
connection in the local.conf then connects to the localhost port of
the HTTP tunnel.
The tunnel is started whenever bitbake is called. This is done by
sourcing our own environment, which replaces the bitbake command with a
small script that opens the tunnel and then calls the original bitbake. I
just came across 'addhandler' - maybe this would have been a better
integrated solution.
But this is some project-specific plumbing - not sure if one could make
this into a more generic solution.
On 10.11.23 15:11, Joshua Watt wrote:
>
>
> On Fri, Nov 10, 2023, 5:03 AM Matthias Schnelte <develop@schnelte.de>
> wrote:
>
> Hi Joshua,
>
> thanks for this change! Being able to use websockets instead of
> some tcp
> connection would help a lot in cooperate setups which are often
> restricted to only http(s) ports and enforce the use of a
> cooperate proxy.
>
> Unfortunately the websocket library you are using seems not to
> support
> websockets over http proxy. At least that is what I understood.
>
> Would it be possible to use another client lib for websockets in
> order
> to support connection through proxy?
>
> This library seems to support it:
> https://websocket-client.readthedocs.io/en/latest/examples.html#connecting-through-a-proxy
>
>
> I'm not sure that's going to work. We need a library that supports
> asyncio, and has very minimal dependencies, which the current library
> satisfies (it only depends on core Python)
>
>
> Maybe there is another solution for proxying?
>
>
>
>
> Matthias
>
> On 03.11.23 15:26, Joshua Watt wrote:
> > Adds support to the hash equivalence client and server to
> communicate
> > over websockets. Since websockets are message orientated instead of
> > stream orientated, and new connection class is needed to handle
> them.
> >
> > Note that websocket support does require the 3rd party
> websockets python
> > module be installed on the host, but it should not be required
> unless
> > websockets are actually being used.
> >
> > Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
> > ---
> > lib/bb/asyncrpc/client.py | 11 +++++++-
> > lib/bb/asyncrpc/connection.py | 44 +++++++++++++++++++++++++++++
> > lib/bb/asyncrpc/serv.py | 53
> ++++++++++++++++++++++++++++++++++-
> > lib/hashserv/__init__.py | 13 +++++++++
> > lib/hashserv/client.py | 1 +
> > lib/hashserv/tests.py | 17 +++++++++++
> > 6 files changed, 137 insertions(+), 2 deletions(-)
> >
> > diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
> > index 7f33099b..802c07df 100644
> > --- a/lib/bb/asyncrpc/client.py
> > +++ b/lib/bb/asyncrpc/client.py
> > @@ -10,7 +10,7 @@ import json
> > import os
> > import socket
> > import sys
> > -from .connection import StreamConnection, DEFAULT_MAX_CHUNK
> > +from .connection import StreamConnection, WebsocketConnection,
> DEFAULT_MAX_CHUNK
> > from .exceptions import ConnectionClosedError
> >
> >
> > @@ -47,6 +47,15 @@ class AsyncClient(object):
> >
> > self._connect_sock = connect_sock
> >
> > + async def connect_websocket(self, uri):
> > + import websockets
> > +
> > + async def connect_sock():
> > + websocket = await websockets.connect(uri,
> ping_interval=None)
> > + return WebsocketConnection(websocket, self.timeout)
> > +
> > + self._connect_sock = connect_sock
> > +
> > async def setup_connection(self):
> > # Send headers
> > await self.socket.send("%s %s" % (self.proto_name,
> self.proto_version))
> > diff --git a/lib/bb/asyncrpc/connection.py
> b/lib/bb/asyncrpc/connection.py
> > index c4fd2475..a10628f7 100644
> > --- a/lib/bb/asyncrpc/connection.py
> > +++ b/lib/bb/asyncrpc/connection.py
> > @@ -93,3 +93,47 @@ class StreamConnection(object):
> > if self.writer is not None:
> > self.writer.close()
> > self.writer = None
> > +
> > +
> > +class WebsocketConnection(object):
> > + def __init__(self, socket, timeout):
> > + self.socket = socket
> > + self.timeout = timeout
> > +
> > + @property
> > + def address(self):
> > + return ":".join(str(s) for s in self.socket.remote_address)
> > +
> > + async def send_message(self, msg):
> > + await self.send(json.dumps(msg))
> > +
> > + async def recv_message(self):
> > + m = await self.recv()
> > + return json.loads(m)
> > +
> > + async def send(self, msg):
> > + import websockets.exceptions
> > +
> > + try:
> > + await self.socket.send(msg)
> > + except websockets.exceptions.ConnectionClosed:
> > + raise ConnectionClosedError("Connection closed")
> > +
> > + async def recv(self):
> > + import websockets.exceptions
> > +
> > + try:
> > + if self.timeout < 0:
> > + return await self.socket.recv()
> > +
> > + try:
> > + return await
> asyncio.wait_for(self.socket.recv(), self.timeout)
> > + except asyncio.TimeoutError:
> > + raise ConnectionError("Timed out waiting for data")
> > + except websockets.exceptions.ConnectionClosed:
> > + raise ConnectionClosedError("Connection closed")
> > +
> > + async def close(self):
> > + if self.socket is not None:
> > + await self.socket.close()
> > + self.socket = None
> > diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
> > index 3e0d0632..dfb03773 100644
> > --- a/lib/bb/asyncrpc/serv.py
> > +++ b/lib/bb/asyncrpc/serv.py
> > @@ -12,7 +12,7 @@ import signal
> > import socket
> > import sys
> > import multiprocessing
> > -from .connection import StreamConnection
> > +from .connection import StreamConnection, WebsocketConnection
> > from .exceptions import ClientError, ServerError,
> ConnectionClosedError
> >
> >
> > @@ -178,6 +178,54 @@ class UnixStreamServer(StreamServer):
> > os.unlink(self.path)
> >
> >
> > +class WebsocketsServer(object):
> > + def __init__(self, host, port, handler, logger):
> > + self.host = host
> > + self.port = port
> > + self.handler = handler
> > + self.logger = logger
> > +
> > + def start(self, loop):
> > + import websockets.server
> > +
> > + self.server = loop.run_until_complete(
> > + websockets.server.serve(
> > + self.client_handler,
> > + self.host,
> > + self.port,
> > + ping_interval=None,
> > + )
> > + )
> > +
> > + for s in self.server.sockets:
> > + self.logger.debug("Listening on %r" %
> (s.getsockname(),))
> > +
> > + # Enable keep alives. This prevents broken client
> connections
> > + # from persisting on the server for long periods of
> time.
> > + s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
> > + s.setsockopt(socket.IPPROTO_TCP,
> socket.TCP_KEEPIDLE, 30)
> > + s.setsockopt(socket.IPPROTO_TCP,
> socket.TCP_KEEPINTVL, 15)
> > + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
> > +
> > + name = self.server.sockets[0].getsockname()
> > + if self.server.sockets[0].family == socket.AF_INET6:
> > + self.address = "ws://[%s]:%d" % (name[0], name[1])
> > + else:
> > + self.address = "ws://%s:%d" % (name[0], name[1])
> > +
> > + return [self.server.wait_closed()]
> > +
> > + async def stop(self):
> > + self.server.close()
> > +
> > + def cleanup(self):
> > + pass
> > +
> > + async def client_handler(self, websocket):
> > + socket = WebsocketConnection(websocket, -1)
> > + await self.handler(socket)
> > +
> > +
> > class AsyncServer(object):
> > def __init__(self, logger):
> > self.logger = logger
> > @@ -190,6 +238,9 @@ class AsyncServer(object):
> > def start_unix_server(self, path):
> > self.server = UnixStreamServer(path,
> self._client_handler, self.logger)
> >
> > + def start_websocket_server(self, host, port):
> > + self.server = WebsocketsServer(host, port,
> self._client_handler, self.logger)
> > +
> > async def _client_handler(self, socket):
> > try:
> > client = self.accept_client(socket)
> > diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
> > index 3a401835..56b9c6bc 100644
> > --- a/lib/hashserv/__init__.py
> > +++ b/lib/hashserv/__init__.py
> > @@ -9,11 +9,15 @@ import re
> > import sqlite3
> > import itertools
> > import json
> > +from urllib.parse import urlparse
> >
> > UNIX_PREFIX = "unix://"
> > +WS_PREFIX = "ws://"
> > +WSS_PREFIX = "wss://"
> >
> > ADDR_TYPE_UNIX = 0
> > ADDR_TYPE_TCP = 1
> > +ADDR_TYPE_WS = 2
> >
> > UNIHASH_TABLE_DEFINITION = (
> > ("method", "TEXT NOT NULL", "UNIQUE"),
> > @@ -84,6 +88,8 @@ def setup_database(database, sync=True):
> > def parse_address(addr):
> > if addr.startswith(UNIX_PREFIX):
> > return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
> > + elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
> > + return (ADDR_TYPE_WS, (addr,))
> > else:
> > m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
> > if m is not None:
> > @@ -103,6 +109,9 @@ def create_server(addr, dbname, *,
> sync=True, upstream=None, read_only=False):
> > (typ, a) = parse_address(addr)
> > if typ == ADDR_TYPE_UNIX:
> > s.start_unix_server(*a)
> > + elif typ == ADDR_TYPE_WS:
> > + url = urlparse(a[0])
> > + s.start_websocket_server(url.hostname, url.port)
> > else:
> > s.start_tcp_server(*a)
> >
> > @@ -116,6 +125,8 @@ def create_client(addr):
> > (typ, a) = parse_address(addr)
> > if typ == ADDR_TYPE_UNIX:
> > c.connect_unix(*a)
> > + elif typ == ADDR_TYPE_WS:
> > + c.connect_websocket(*a)
> > else:
> > c.connect_tcp(*a)
> >
> > @@ -128,6 +139,8 @@ async def create_async_client(addr):
> > (typ, a) = parse_address(addr)
> > if typ == ADDR_TYPE_UNIX:
> > await c.connect_unix(*a)
> > + elif typ == ADDR_TYPE_WS:
> > + await c.connect_websocket(*a)
> > else:
> > await c.connect_tcp(*a)
> >
> > diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
> > index 5f7d22ab..9542d72f 100644
> > --- a/lib/hashserv/client.py
> > +++ b/lib/hashserv/client.py
> > @@ -115,6 +115,7 @@ class Client(bb.asyncrpc.Client):
> > super().__init__()
> > self._add_methods(
> > "connect_tcp",
> > + "connect_websocket",
> > "get_unihash",
> > "report_unihash",
> > "report_unihash_equiv",
> > diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
> > index f343c586..01ffd52c 100644
> > --- a/lib/hashserv/tests.py
> > +++ b/lib/hashserv/tests.py
> > @@ -483,3 +483,20 @@ class
> TestHashEquivalenceTCPServer(HashEquivalenceTestSetup,
> HashEquivalenceComm
> > # If IPv6 is enabled, it should be safe to use
> localhost directly, in general
> > # case it is more reliable to resolve the IP address
> explicitly.
> > return socket.gethostbyname("localhost") + ":0"
> > +
> > +
> > +class
> TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup,
> HashEquivalenceCommonTests, unittest.TestCase):
> > + def setUp(self):
> > + try:
> > + import websockets
> > + except ImportError as e:
> > + self.skipTest(str(e))
> > +
> > + super().setUp()
> > +
> > + def get_server_addr(self, server_idx):
> > + # Some hosts cause asyncio module to misbehave, when
> IPv6 is not enabled.
> > + # If IPv6 is enabled, it should be safe to use
> localhost directly, in general
> > + # case it is more reliable to resolve the IP address
> explicitly.
> > + host = socket.gethostbyname("localhost")
> > + return "ws://%s:0" % host
> >
> >
> >
>
>
> -=-=-=-=-=-=-=-=-=-=-=-
> Links: You receive all messages sent to this group.
> View/Reply Online (#15501):https://lists.openembedded.org/g/bitbake-devel/message/15501
> Mute This Topic:https://lists.openembedded.org/mt/102364905/7851872
> Group Owner:bitbake-devel+owner@lists.openembedded.org
> Unsubscribe:https://lists.openembedded.org/g/bitbake-devel/unsub [develop@schnelte.de]
> -=-=-=-=-=-=-=-=-=-=-=-
>
[-- Attachment #2: Type: text/html, Size: 22382 bytes --]
^ permalink raw reply [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 03/22] asyncrpc: Add context manager API
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 01/22] asyncrpc: Abstract sockets Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 02/22] hashserv: Add websocket connection implementation Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 04/22] hashserv: tests: Add external database tests Joshua Watt
` (19 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds context manager API for the asyncrpc client class which allows
writing code that will automatically close the connection like so:
with hashserv.create_client(address) as client:
...
Rework the bitbake-hashclient tool and PR server to use this new API to
fix warnings about unclosed event loops when exiting
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 36 +++++++++++++++++-------------------
lib/bb/asyncrpc/client.py | 13 +++++++++++++
lib/prserv/serv.py | 6 +++---
3 files changed, 33 insertions(+), 22 deletions(-)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index 3f265e8f..a02a65b9 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -56,25 +56,24 @@ def main():
nonlocal missed_hashes
nonlocal max_time
- client = hashserv.create_client(args.address)
-
- for i in range(args.requests):
- taskhash = hashlib.sha256()
- taskhash.update(args.taskhash_seed.encode('utf-8'))
- taskhash.update(str(i).encode('utf-8'))
+ with hashserv.create_client(args.address) as client:
+ for i in range(args.requests):
+ taskhash = hashlib.sha256()
+ taskhash.update(args.taskhash_seed.encode('utf-8'))
+ taskhash.update(str(i).encode('utf-8'))
- start_time = time.perf_counter()
- l = client.get_unihash(METHOD, taskhash.hexdigest())
- elapsed = time.perf_counter() - start_time
+ start_time = time.perf_counter()
+ l = client.get_unihash(METHOD, taskhash.hexdigest())
+ elapsed = time.perf_counter() - start_time
- with lock:
- if l:
- found_hashes += 1
- else:
- missed_hashes += 1
+ with lock:
+ if l:
+ found_hashes += 1
+ else:
+ missed_hashes += 1
- max_time = max(elapsed, max_time)
- pbar.update()
+ max_time = max(elapsed, max_time)
+ pbar.update()
max_time = 0
found_hashes = 0
@@ -174,9 +173,8 @@ def main():
func = getattr(args, 'func', None)
if func:
- client = hashserv.create_client(args.address)
-
- return func(args, client)
+ with hashserv.create_client(args.address) as client:
+ return func(args, client)
return 0
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index 802c07df..009085c3 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -103,6 +103,12 @@ class AsyncClient(object):
async def ping(self):
return await self.invoke({"ping": {}})
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
class Client(object):
def __init__(self):
@@ -153,3 +159,10 @@ class Client(object):
if sys.version_info >= (3, 6):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+ return False
diff --git a/lib/prserv/serv.py b/lib/prserv/serv.py
index ea793316..6168eb18 100644
--- a/lib/prserv/serv.py
+++ b/lib/prserv/serv.py
@@ -345,9 +345,9 @@ def auto_shutdown():
def ping(host, port):
from . import client
- conn = client.PRClient()
- conn.connect_tcp(host, port)
- return conn.ping()
+ with client.PRClient() as conn:
+ conn.connect_tcp(host, port)
+ return conn.ping()
def connect(host, port):
from . import client
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 04/22] hashserv: tests: Add external database tests
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (2 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 03/22] asyncrpc: Add context manager API Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 05/22] asyncrpc: Prefix log messages with client info Joshua Watt
` (18 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds support for running the hash equivalence test suite against an
external hash equivalence implementation.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/tests.py | 54 +++++++++++++++++++++++++++++++++++--------
1 file changed, 44 insertions(+), 10 deletions(-)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 01ffd52c..4c98a280 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -51,13 +51,20 @@ class HashEquivalenceTestSetup(object):
server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
self.addCleanup(cleanup_server, server)
+ return server
+
+ def start_client(self, server_address):
def cleanup_client(client):
client.close()
- client = create_client(server.address)
+ client = create_client(server_address)
self.addCleanup(cleanup_client, client)
- return (client, server)
+ return client
+
+ def start_test_server(self):
+ server = self.start_server()
+ return server.address
def setUp(self):
if sys.version_info < (3, 5, 0):
@@ -66,7 +73,9 @@ class HashEquivalenceTestSetup(object):
self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv')
self.addCleanup(self.temp_dir.cleanup)
- (self.client, self.server) = self.start_server()
+ self.server_address = self.start_test_server()
+
+ self.client = self.start_client(self.server_address)
def assertClientGetHash(self, client, taskhash, unihash):
result = client.get_unihash(self.METHOD, taskhash)
@@ -206,7 +215,7 @@ class HashEquivalenceCommonTests(object):
def test_stress(self):
def query_server(failures):
- client = Client(self.server.address)
+ client = Client(self.server_address)
try:
for i in range(1000):
taskhash = hashlib.sha256()
@@ -245,8 +254,10 @@ class HashEquivalenceCommonTests(object):
# the side client. It also verifies that the results are pulled into
# the downstream database by checking that the downstream and side servers
# match after the downstream is done waiting for all backfill tasks
- (down_client, down_server) = self.start_server(upstream=self.server.address)
- (side_client, side_server) = self.start_server(dbpath=down_server.dbpath)
+ down_server = self.start_server(upstream=self.server_address)
+ down_client = self.start_client(down_server.address)
+ side_server = self.start_server(dbpath=down_server.dbpath)
+ side_client = self.start_client(side_server.address)
def check_hash(taskhash, unihash, old_sidehash):
nonlocal down_client
@@ -351,14 +362,18 @@ class HashEquivalenceCommonTests(object):
self.assertEqual(result['method'], self.METHOD)
def test_ro_server(self):
- (ro_client, ro_server) = self.start_server(dbpath=self.server.dbpath, read_only=True)
+ rw_server = self.start_server()
+ rw_client = self.start_client(rw_server.address)
+
+ ro_server = self.start_server(dbpath=rw_server.dbpath, read_only=True)
+ ro_client = self.start_client(ro_server.address)
# Report a hash via the read-write server
taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
- result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+ result = rw_client.report_unihash(taskhash, self.METHOD, outhash, unihash)
self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
# Check the hash via the read-only server
@@ -373,7 +388,7 @@ class HashEquivalenceCommonTests(object):
ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
# Ensure that the database was not modified
- self.assertClientGetHash(self.client, taskhash2, None)
+ self.assertClientGetHash(rw_client, taskhash2, None)
def test_slow_server_start(self):
@@ -393,7 +408,7 @@ class HashEquivalenceCommonTests(object):
old_signal = signal.signal(signal.SIGTERM, do_nothing)
self.addCleanup(signal.signal, signal.SIGTERM, old_signal)
- _, server = self.start_server(prefunc=prefunc)
+ server = self.start_server(prefunc=prefunc)
server.process.terminate()
time.sleep(30)
event.set()
@@ -500,3 +515,22 @@ class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalen
# case it is more reliable to resolve the IP address explicitly.
host = socket.gethostbyname("localhost")
return "ws://%s:0" % host
+
+
+class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
+ def start_test_server(self):
+ if 'BB_TEST_HASHSERV' not in os.environ:
+ self.skipTest('BB_TEST_HASHSERV not defined to test an external server')
+
+ return os.environ['BB_TEST_HASHSERV']
+
+ def start_server(self, *args, **kwargs):
+ self.skipTest('Cannot start local server when testing external servers')
+
+ def setUp(self):
+ super().setUp()
+ self.client.remove({"method": self.METHOD})
+
+ def tearDown(self):
+ self.client.remove({"method": self.METHOD})
+ super().tearDown()
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 05/22] asyncrpc: Prefix log messages with client info
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (3 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 04/22] hashserv: tests: Add external database tests Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 06/22] bitbake-hashserv: Allow arguments from environment Joshua Watt
` (17 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds a logging adaptor to the asyncrpc clients that prefixes log
messages with the client remote address to aid in debugging
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/serv.py | 21 ++++++++++++++++++---
lib/hashserv/server.py | 10 +++++-----
2 files changed, 23 insertions(+), 8 deletions(-)
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index dfb03773..c99add4d 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -12,10 +12,16 @@ import signal
import socket
import sys
import multiprocessing
+import logging
from .connection import StreamConnection, WebsocketConnection
from .exceptions import ClientError, ServerError, ConnectionClosedError
+class ClientLoggerAdapter(logging.LoggerAdapter):
+ def process(self, msg, kwargs):
+ return f"[Client {self.extra['address']}] {msg}", kwargs
+
+
class AsyncServerConnection(object):
# If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
# return message will be automatically be sent back to the client
@@ -27,7 +33,12 @@ class AsyncServerConnection(object):
self.handlers = {
"ping": self.handle_ping,
}
- self.logger = logger
+ self.logger = ClientLoggerAdapter(
+ logger,
+ {
+ "address": socket.address,
+ },
+ )
async def close(self):
await self.socket.close()
@@ -242,16 +253,20 @@ class AsyncServer(object):
self.server = WebsocketsServer(host, port, self._client_handler, self.logger)
async def _client_handler(self, socket):
+ address = socket.address
try:
client = self.accept_client(socket)
await client.process_requests()
except Exception as e:
import traceback
- self.logger.error("Error from client: %s" % str(e), exc_info=True)
+ self.logger.error(
+ "Error from client %s: %s" % (address, str(e)), exc_info=True
+ )
traceback.print_exc()
+ finally:
+ self.logger.debug("Client %s disconnected", address)
await socket.close()
- self.logger.debug("Client disconnected")
@abc.abstractmethod
def accept_client(self, socket):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 13b75480..e6a3f405 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -207,7 +207,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
- logger.debug('Handling %s' % k)
+ self.logger.debug('Handling %s' % k)
if 'stream' in k:
return await self.handlers[k](msg[k])
else:
@@ -351,7 +351,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
break
(method, taskhash) = l.split()
- #logger.debug('Looking up %s %s' % (method, taskhash))
+ #self.logger.debug('Looking up %s %s' % (method, taskhash))
cursor = self.db.cursor()
try:
row = self.query_equivalent(cursor, method, taskhash)
@@ -360,7 +360,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if row is not None:
msg = row['unihash']
- #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
+ #self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
elif self.upstream_client is not None:
upstream = await self.upstream_client.get_unihash(method, taskhash)
if upstream:
@@ -480,8 +480,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
row = self.query_equivalent(cursor, data['method'], data['taskhash'])
if row['unihash'] == data['unihash']:
- logger.info('Adding taskhash equivalence for %s with unihash %s',
- data['taskhash'], row['unihash'])
+ self.logger.info('Adding taskhash equivalence for %s with unihash %s',
+ data['taskhash'], row['unihash'])
d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 06/22] bitbake-hashserv: Allow arguments from environment
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (4 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 05/22] asyncrpc: Prefix log messages with client info Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 07/22] hashserv: Abstract database Joshua Watt
` (16 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Allows the arguments to the bitbake-hashserv command to be specified in
environment variables. This is a very common idiom when running services
in containers as it allows the arguments to be specified from different
sources as desired by the service administrator
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashserv | 80 +++++++++++++++++++++++++++++++++-----------
1 file changed, 60 insertions(+), 20 deletions(-)
diff --git a/bin/bitbake-hashserv b/bin/bitbake-hashserv
index 00af76b2..a916a90c 100755
--- a/bin/bitbake-hashserv
+++ b/bin/bitbake-hashserv
@@ -11,56 +11,96 @@ import logging
import argparse
import sqlite3
import warnings
+
warnings.simplefilter("default")
-sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib'))
+sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), "lib"))
import hashserv
VERSION = "1.0.0"
-DEFAULT_BIND = 'unix://./hashserve.sock'
+DEFAULT_BIND = "unix://./hashserve.sock"
def main():
- parser = argparse.ArgumentParser(description='Hash Equivalence Reference Server. Version=%s' % VERSION,
- epilog='''The bind address is the path to a unix domain socket if it is
- prefixed with "unix://". Otherwise, it is an IP address
- and port in form ADDRESS:PORT. To bind to all addresses, leave
- the ADDRESS empty, e.g. "--bind :8686". To bind to a specific
- IPv6 address, enclose the address in "[]", e.g.
- "--bind [::1]:8686"'''
- )
-
- parser.add_argument('-b', '--bind', default=DEFAULT_BIND, help='Bind address (default "%(default)s")')
- parser.add_argument('-d', '--database', default='./hashserv.db', help='Database file (default "%(default)s")')
- parser.add_argument('-l', '--log', default='WARNING', help='Set logging level')
- parser.add_argument('-u', '--upstream', help='Upstream hashserv to pull hashes from')
- parser.add_argument('-r', '--read-only', action='store_true', help='Disallow write operations from clients')
+ parser = argparse.ArgumentParser(
+ description="Hash Equivalence Reference Server. Version=%s" % VERSION,
+ formatter_class=argparse.RawTextHelpFormatter,
+ epilog="""
+The bind address may take one of the following formats:
+ unix://PATH - Bind to unix domain socket at PATH
+ ws://ADDRESS:PORT - Bind to websocket on ADDRESS:PORT
+ ADDRESS:PORT - Bind to raw TCP socket on ADDRESS:PORT
+
+To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
+"--bind ws://:8686". To bind to a specific IPv6 address, enclose the address in
+"[]", e.g. "--bind [::1]:8686" or "--bind ws://[::1]:8686"
+ """,
+ )
+
+ parser.add_argument(
+ "-b",
+ "--bind",
+ default=os.environ.get("HASHSERVER_BIND", DEFAULT_BIND),
+ help='Bind address (default $HASHSERVER_BIND, "%(default)s")',
+ )
+ parser.add_argument(
+ "-d",
+ "--database",
+ default=os.environ.get("HASHSERVER_DB", "./hashserv.db"),
+ help='Database file (default $HASHSERVER_DB, "%(default)s")',
+ )
+ parser.add_argument(
+ "-l",
+ "--log",
+ default=os.environ.get("HASHSERVER_LOG_LEVEL", "WARNING"),
+ help='Set logging level (default $HASHSERVER_LOG_LEVEL, "%(default)s")',
+ )
+ parser.add_argument(
+ "-u",
+ "--upstream",
+ default=os.environ.get("HASHSERVER_UPSTREAM", None),
+ help="Upstream hashserv to pull hashes from ($HASHSERVER_UPSTREAM)",
+ )
+ parser.add_argument(
+ "-r",
+ "--read-only",
+ action="store_true",
+ help="Disallow write operations from clients ($HASHSERVER_READ_ONLY)",
+ )
args = parser.parse_args()
- logger = logging.getLogger('hashserv')
+ logger = logging.getLogger("hashserv")
level = getattr(logging, args.log.upper(), None)
if not isinstance(level, int):
- raise ValueError('Invalid log level: %s' % args.log)
+ raise ValueError("Invalid log level: %s" % args.log)
logger.setLevel(level)
console = logging.StreamHandler()
console.setLevel(level)
logger.addHandler(console)
- server = hashserv.create_server(args.bind, args.database, upstream=args.upstream, read_only=args.read_only)
+ read_only = (os.environ.get("HASHSERVER_READ_ONLY", "0") == "1") or args.read_only
+
+ server = hashserv.create_server(
+ args.bind,
+ args.database,
+ upstream=args.upstream,
+ read_only=read_only,
+ )
server.serve_forever()
return 0
-if __name__ == '__main__':
+if __name__ == "__main__":
try:
ret = main()
except Exception:
ret = 1
import traceback
+
traceback.print_exc()
sys.exit(ret)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 07/22] hashserv: Abstract database
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (5 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 06/22] bitbake-hashserv: Allow arguments from environment Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 08/22] hashserv: Add SQLalchemy backend Joshua Watt
` (15 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Abstracts the way the database backend is accessed by the hash
equivalence server to make it possible to use other backends
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/__init__.py | 90 ++-----
lib/hashserv/server.py | 491 +++++++++++++--------------------------
lib/hashserv/sqlite.py | 259 +++++++++++++++++++++
3 files changed, 439 insertions(+), 401 deletions(-)
create mode 100644 lib/hashserv/sqlite.py
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 56b9c6bc..90d8cff1 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -6,7 +6,6 @@
import asyncio
from contextlib import closing
import re
-import sqlite3
import itertools
import json
from urllib.parse import urlparse
@@ -19,92 +18,34 @@ ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2
-UNIHASH_TABLE_DEFINITION = (
- ("method", "TEXT NOT NULL", "UNIQUE"),
- ("taskhash", "TEXT NOT NULL", "UNIQUE"),
- ("unihash", "TEXT NOT NULL", ""),
-)
-
-UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
-
-OUTHASH_TABLE_DEFINITION = (
- ("method", "TEXT NOT NULL", "UNIQUE"),
- ("taskhash", "TEXT NOT NULL", "UNIQUE"),
- ("outhash", "TEXT NOT NULL", "UNIQUE"),
- ("created", "DATETIME", ""),
-
- # Optional fields
- ("owner", "TEXT", ""),
- ("PN", "TEXT", ""),
- ("PV", "TEXT", ""),
- ("PR", "TEXT", ""),
- ("task", "TEXT", ""),
- ("outhash_siginfo", "TEXT", ""),
-)
-
-OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
-
-def _make_table(cursor, name, definition):
- cursor.execute('''
- CREATE TABLE IF NOT EXISTS {name} (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- {fields}
- UNIQUE({unique})
- )
- '''.format(
- name=name,
- fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
- unique=", ".join(name for name, _, flags in definition if "UNIQUE" in flags)
- ))
-
-
-def setup_database(database, sync=True):
- db = sqlite3.connect(database)
- db.row_factory = sqlite3.Row
-
- with closing(db.cursor()) as cursor:
- _make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
- _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
-
- cursor.execute('PRAGMA journal_mode = WAL')
- cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))
-
- # Drop old indexes
- cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
- cursor.execute('DROP INDEX IF EXISTS outhash_lookup')
- cursor.execute('DROP INDEX IF EXISTS taskhash_lookup_v2')
- cursor.execute('DROP INDEX IF EXISTS outhash_lookup_v2')
-
- # TODO: Upgrade from tasks_v2?
- cursor.execute('DROP TABLE IF EXISTS tasks_v2')
-
- # Create new indexes
- cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)')
- cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)')
-
- return db
-
def parse_address(addr):
if addr.startswith(UNIX_PREFIX):
- return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
+ return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
return (ADDR_TYPE_WS, (addr,))
else:
- m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
+ m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
if m is not None:
- host = m.group('host')
- port = m.group('port')
+ host = m.group("host")
+ port = m.group("port")
else:
- host, port = addr.split(':')
+ host, port = addr.split(":")
return (ADDR_TYPE_TCP, (host, int(port)))
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
+ def sqlite_engine():
+ from .sqlite import DatabaseEngine
+
+ return DatabaseEngine(dbname, sync)
+
from . import server
- db = setup_database(dbname, sync=sync)
- s = server.Server(db, upstream=upstream, read_only=read_only)
+
+ db_engine = sqlite_engine()
+
+ s = server.Server(db_engine, upstream=upstream, read_only=read_only)
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
@@ -120,6 +61,7 @@ def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
def create_client(addr):
from . import client
+
c = client.Client()
(typ, a) = parse_address(addr)
@@ -132,8 +74,10 @@ def create_client(addr):
return c
+
async def create_async_client(addr):
from . import client
+
c = client.AsyncClient()
(typ, a) = parse_address(addr)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index e6a3f405..84cf4f22 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -3,18 +3,16 @@
# SPDX-License-Identifier: GPL-2.0-only
#
-from contextlib import closing, contextmanager
from datetime import datetime, timedelta
-import enum
import asyncio
import logging
import math
import time
-from . import create_async_client, UNIHASH_TABLE_COLUMNS, OUTHASH_TABLE_COLUMNS
+from . import create_async_client
import bb.asyncrpc
-logger = logging.getLogger('hashserv.server')
+logger = logging.getLogger("hashserv.server")
class Measurement(object):
@@ -104,229 +102,136 @@ class Stats(object):
return math.sqrt(self.s / (self.num - 1))
def todict(self):
- return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
-
-
-@enum.unique
-class Resolve(enum.Enum):
- FAIL = enum.auto()
- IGNORE = enum.auto()
- REPLACE = enum.auto()
-
-
-def insert_table(cursor, table, data, on_conflict):
- resolve = {
- Resolve.FAIL: "",
- Resolve.IGNORE: " OR IGNORE",
- Resolve.REPLACE: " OR REPLACE",
- }[on_conflict]
-
- keys = sorted(data.keys())
- query = 'INSERT{resolve} INTO {table} ({fields}) VALUES({values})'.format(
- resolve=resolve,
- table=table,
- fields=", ".join(keys),
- values=", ".join(":" + k for k in keys),
- )
- prevrowid = cursor.lastrowid
- cursor.execute(query, data)
- logging.debug(
- "Inserting %r into %s, %s",
- data,
- table,
- on_conflict
- )
- return (cursor.lastrowid, cursor.lastrowid != prevrowid)
-
-def insert_unihash(cursor, data, on_conflict):
- return insert_table(cursor, "unihashes_v2", data, on_conflict)
-
-def insert_outhash(cursor, data, on_conflict):
- return insert_table(cursor, "outhashes_v2", data, on_conflict)
-
-async def copy_unihash_from_upstream(client, db, method, taskhash):
- d = await client.get_taskhash(method, taskhash)
- if d is not None:
- with closing(db.cursor()) as cursor:
- insert_unihash(
- cursor,
- {k: v for k, v in d.items() if k in UNIHASH_TABLE_COLUMNS},
- Resolve.IGNORE,
- )
- db.commit()
- return d
-
-
-class ServerCursor(object):
- def __init__(self, db, cursor, upstream):
- self.db = db
- self.cursor = cursor
- self.upstream = upstream
+ return {
+ k: getattr(self, k)
+ for k in ("num", "total_time", "max_time", "average", "stdev")
+ }
class ServerClient(bb.asyncrpc.AsyncServerConnection):
- def __init__(self, socket, db, request_stats, backfill_queue, upstream, read_only):
- super().__init__(socket, 'OEHASHEQUIV', logger)
- self.db = db
+ def __init__(
+ self,
+ socket,
+ db_engine,
+ request_stats,
+ backfill_queue,
+ upstream,
+ read_only,
+ ):
+ super().__init__(socket, "OEHASHEQUIV", logger)
+ self.db_engine = db_engine
self.request_stats = request_stats
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
self.backfill_queue = backfill_queue
self.upstream = upstream
- self.handlers.update({
- 'get': self.handle_get,
- 'get-outhash': self.handle_get_outhash,
- 'get-stream': self.handle_get_stream,
- 'get-stats': self.handle_get_stats,
- })
+ self.handlers.update(
+ {
+ "get": self.handle_get,
+ "get-outhash": self.handle_get_outhash,
+ "get-stream": self.handle_get_stream,
+ "get-stats": self.handle_get_stats,
+ }
+ )
if not read_only:
- self.handlers.update({
- 'report': self.handle_report,
- 'report-equiv': self.handle_equivreport,
- 'reset-stats': self.handle_reset_stats,
- 'backfill-wait': self.handle_backfill_wait,
- 'remove': self.handle_remove,
- 'clean-unused': self.handle_clean_unused,
- })
+ self.handlers.update(
+ {
+ "report": self.handle_report,
+ "report-equiv": self.handle_equivreport,
+ "reset-stats": self.handle_reset_stats,
+ "backfill-wait": self.handle_backfill_wait,
+ "remove": self.handle_remove,
+ "clean-unused": self.handle_clean_unused,
+ }
+ )
def validate_proto_version(self):
- return (self.proto_version > (1, 0) and self.proto_version <= (1, 1))
+ return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
async def process_requests(self):
- if self.upstream is not None:
- self.upstream_client = await create_async_client(self.upstream)
- else:
- self.upstream_client = None
-
- await super().process_requests()
+ async with self.db_engine.connect(self.logger) as db:
+ self.db = db
+ if self.upstream is not None:
+ self.upstream_client = await create_async_client(self.upstream)
+ else:
+ self.upstream_client = None
- if self.upstream_client is not None:
- await self.upstream_client.close()
+ try:
+ await super().process_requests()
+ finally:
+ if self.upstream_client is not None:
+ await self.upstream_client.close()
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
- self.logger.debug('Handling %s' % k)
- if 'stream' in k:
+ self.logger.debug("Handling %s" % k)
+ if "stream" in k:
return await self.handlers[k](msg[k])
else:
- with self.request_stats.start_sample() as self.request_sample, \
- self.request_sample.measure():
+ with self.request_stats.start_sample() as self.request_sample, self.request_sample.measure():
return await self.handlers[k](msg[k])
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
async def handle_get(self, request):
- method = request['method']
- taskhash = request['taskhash']
- fetch_all = request.get('all', False)
+ method = request["method"]
+ taskhash = request["taskhash"]
+ fetch_all = request.get("all", False)
- with closing(self.db.cursor()) as cursor:
- return await self.get_unihash(cursor, method, taskhash, fetch_all)
+ return await self.get_unihash(method, taskhash, fetch_all)
- async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
+ async def get_unihash(self, method, taskhash, fetch_all=False):
d = None
if fetch_all:
- cursor.execute(
- '''
- SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
- INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
- WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- ''',
- {
- 'method': method,
- 'taskhash': taskhash,
- }
-
- )
- row = cursor.fetchone()
-
+ row = await self.db.get_unihash_by_taskhash_full(method, taskhash)
if row is not None:
d = {k: row[k] for k in row.keys()}
elif self.upstream_client is not None:
d = await self.upstream_client.get_taskhash(method, taskhash, True)
- self.update_unified(cursor, d)
- self.db.commit()
+ await self.update_unified(d)
else:
- row = self.query_equivalent(cursor, method, taskhash)
+ row = await self.db.get_equivalent(method, taskhash)
if row is not None:
d = {k: row[k] for k in row.keys()}
elif self.upstream_client is not None:
d = await self.upstream_client.get_taskhash(method, taskhash)
- d = {k: v for k, v in d.items() if k in UNIHASH_TABLE_COLUMNS}
- insert_unihash(cursor, d, Resolve.IGNORE)
- self.db.commit()
+ await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
return d
async def handle_get_outhash(self, request):
- method = request['method']
- outhash = request['outhash']
- taskhash = request['taskhash']
+ method = request["method"]
+ outhash = request["outhash"]
+ taskhash = request["taskhash"]
with_unihash = request.get("with_unihash", True)
- with closing(self.db.cursor()) as cursor:
- return await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
+ return await self.get_outhash(method, outhash, taskhash, with_unihash)
- async def get_outhash(self, cursor, method, outhash, taskhash, with_unihash=True):
+ async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
d = None
if with_unihash:
- cursor.execute(
- '''
- SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
- INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
- WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- ''',
- {
- 'method': method,
- 'outhash': outhash,
- }
- )
+ row = await self.db.get_unihash_by_outhash(method, outhash)
else:
- cursor.execute(
- """
- SELECT * FROM outhashes_v2
- WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- """,
- {
- 'method': method,
- 'outhash': outhash,
- }
- )
- row = cursor.fetchone()
+ row = await self.db.get_outhash(method, outhash)
if row is not None:
d = {k: row[k] for k in row.keys()}
elif self.upstream_client is not None:
d = await self.upstream_client.get_outhash(method, outhash, taskhash)
- self.update_unified(cursor, d)
- self.db.commit()
+ await self.update_unified(d)
return d
- def update_unified(self, cursor, data):
+ async def update_unified(self, data):
if data is None:
return
- insert_unihash(
- cursor,
- {k: v for k, v in data.items() if k in UNIHASH_TABLE_COLUMNS},
- Resolve.IGNORE
- )
- insert_outhash(
- cursor,
- {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS},
- Resolve.IGNORE
- )
+ await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+ await self.db.insert_outhash(data)
async def handle_get_stream(self, request):
await self.socket.send_message("ok")
@@ -347,20 +252,16 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
request_measure = self.request_sample.measure()
request_measure.start()
- if l == 'END':
+ if l == "END":
break
(method, taskhash) = l.split()
- #self.logger.debug('Looking up %s %s' % (method, taskhash))
- cursor = self.db.cursor()
- try:
- row = self.query_equivalent(cursor, method, taskhash)
- finally:
- cursor.close()
+ # self.logger.debug('Looking up %s %s' % (method, taskhash))
+ row = await self.db.get_equivalent(method, taskhash)
if row is not None:
- msg = row['unihash']
- #self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
+ msg = row["unihash"]
+ # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
elif self.upstream_client is not None:
upstream = await self.upstream_client.get_unihash(method, taskhash)
if upstream:
@@ -384,118 +285,81 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return self.NO_RESPONSE
async def handle_report(self, data):
- with closing(self.db.cursor()) as cursor:
- outhash_data = {
- 'method': data['method'],
- 'outhash': data['outhash'],
- 'taskhash': data['taskhash'],
- 'created': datetime.now()
- }
+ outhash_data = {
+ "method": data["method"],
+ "outhash": data["outhash"],
+ "taskhash": data["taskhash"],
+ "created": datetime.now(),
+ }
- for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
- if k in data:
- outhash_data[k] = data[k]
-
- # Insert the new entry, unless it already exists
- (rowid, inserted) = insert_outhash(cursor, outhash_data, Resolve.IGNORE)
-
- if inserted:
- # If this row is new, check if it is equivalent to another
- # output hash
- cursor.execute(
- '''
- SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2
- INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
- -- Select any matching output hash except the one we just inserted
- WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
- -- Pick the oldest hash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- ''',
- {
- 'method': data['method'],
- 'outhash': data['outhash'],
- 'taskhash': data['taskhash'],
- }
- )
- row = cursor.fetchone()
+ for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"):
+ if k in data:
+ outhash_data[k] = data[k]
- if row is not None:
- # A matching output hash was found. Set our taskhash to the
- # same unihash since they are equivalent
- unihash = row['unihash']
- resolve = Resolve.IGNORE
- else:
- # No matching output hash was found. This is probably the
- # first outhash to be added.
- unihash = data['unihash']
- resolve = Resolve.IGNORE
-
- # Query upstream to see if it has a unihash we can use
- if self.upstream_client is not None:
- upstream_data = await self.upstream_client.get_outhash(data['method'], data['outhash'], data['taskhash'])
- if upstream_data is not None:
- unihash = upstream_data['unihash']
-
-
- insert_unihash(
- cursor,
- {
- 'method': data['method'],
- 'taskhash': data['taskhash'],
- 'unihash': unihash,
- },
- resolve
- )
-
- unihash_data = await self.get_unihash(cursor, data['method'], data['taskhash'])
- if unihash_data is not None:
- unihash = unihash_data['unihash']
- else:
- unihash = data['unihash']
-
- self.db.commit()
+ # Insert the new entry, unless it already exists
+ if await self.db.insert_outhash(outhash_data):
+ # If this row is new, check if it is equivalent to another
+ # output hash
+ row = await self.db.get_equivalent_for_outhash(
+ data["method"], data["outhash"], data["taskhash"]
+ )
- d = {
- 'taskhash': data['taskhash'],
- 'method': data['method'],
- 'unihash': unihash,
- }
+ if row is not None:
+ # A matching output hash was found. Set our taskhash to the
+ # same unihash since they are equivalent
+ unihash = row["unihash"]
+ else:
+ # No matching output hash was found. This is probably the
+ # first outhash to be added.
+ unihash = data["unihash"]
+
+ # Query upstream to see if it has a unihash we can use
+ if self.upstream_client is not None:
+ upstream_data = await self.upstream_client.get_outhash(
+ data["method"], data["outhash"], data["taskhash"]
+ )
+ if upstream_data is not None:
+ unihash = upstream_data["unihash"]
+
+ await self.db.insert_unihash(data["method"], data["taskhash"], unihash)
+
+ unihash_data = await self.get_unihash(data["method"], data["taskhash"])
+ if unihash_data is not None:
+ unihash = unihash_data["unihash"]
+ else:
+ unihash = data["unihash"]
- return d
+ return {
+ "taskhash": data["taskhash"],
+ "method": data["method"],
+ "unihash": unihash,
+ }
async def handle_equivreport(self, data):
- with closing(self.db.cursor()) as cursor:
- insert_data = {
- 'method': data['method'],
- 'taskhash': data['taskhash'],
- 'unihash': data['unihash'],
- }
- insert_unihash(cursor, insert_data, Resolve.IGNORE)
- self.db.commit()
-
- # Fetch the unihash that will be reported for the taskhash. If the
- # unihash matches, it means this row was inserted (or the mapping
- # was already valid)
- row = self.query_equivalent(cursor, data['method'], data['taskhash'])
-
- if row['unihash'] == data['unihash']:
- self.logger.info('Adding taskhash equivalence for %s with unihash %s',
- data['taskhash'], row['unihash'])
-
- d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
-
- return d
+ await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+
+ # Fetch the unihash that will be reported for the taskhash. If the
+ # unihash matches, it means this row was inserted (or the mapping
+ # was already valid)
+ row = await self.db.get_equivalent(data["method"], data["taskhash"])
+
+ if row["unihash"] == data["unihash"]:
+ self.logger.info(
+ "Adding taskhash equivalence for %s with unihash %s",
+ data["taskhash"],
+ row["unihash"],
+ )
+ return {k: row[k] for k in ("taskhash", "method", "unihash")}
async def handle_get_stats(self, request):
return {
- 'requests': self.request_stats.todict(),
+ "requests": self.request_stats.todict(),
}
async def handle_reset_stats(self, request):
d = {
- 'requests': self.request_stats.todict(),
+ "requests": self.request_stats.todict(),
}
self.request_stats.reset()
@@ -503,7 +367,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
async def handle_backfill_wait(self, request):
d = {
- 'tasks': self.backfill_queue.qsize(),
+ "tasks": self.backfill_queue.qsize(),
}
await self.backfill_queue.join()
return d
@@ -513,92 +377,63 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if not isinstance(condition, dict):
raise TypeError("Bad condition type %s" % type(condition))
- def do_remove(columns, table_name, cursor):
- nonlocal condition
- where = {}
- for c in columns:
- if c in condition and condition[c] is not None:
- where[c] = condition[c]
-
- if where:
- query = ('DELETE FROM %s WHERE ' % table_name) + ' AND '.join("%s=:%s" % (k, k) for k in where.keys())
- cursor.execute(query, where)
- return cursor.rowcount
-
- return 0
-
- count = 0
- with closing(self.db.cursor()) as cursor:
- count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
- count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
- self.db.commit()
-
- return {"count": count}
+ return {"count": await self.db.remove(condition)}
async def handle_clean_unused(self, request):
max_age = request["max_age_seconds"]
- with closing(self.db.cursor()) as cursor:
- cursor.execute(
- """
- DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
- SELECT unihashes_v2.id FROM unihashes_v2 WHERE unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash LIMIT 1
- )
- """,
- {
- "oldest": datetime.now() - timedelta(seconds=-max_age)
- }
- )
- count = cursor.rowcount
-
- return {"count": count}
-
- def query_equivalent(self, cursor, method, taskhash):
- # This is part of the inner loop and must be as fast as possible
- cursor.execute(
- 'SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash',
- {
- 'method': method,
- 'taskhash': taskhash,
- }
- )
- return cursor.fetchone()
+ oldest = datetime.now() - timedelta(seconds=-max_age)
+ return {"count": await self.db.clean_unused(oldest)}
class Server(bb.asyncrpc.AsyncServer):
- def __init__(self, db, upstream=None, read_only=False):
+ def __init__(self, db_engine, upstream=None, read_only=False):
if upstream and read_only:
- raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server")
+ raise bb.asyncrpc.ServerError(
+ "Read-only hashserv cannot pull from an upstream server"
+ )
super().__init__(logger)
self.request_stats = Stats()
- self.db = db
+ self.db_engine = db_engine
self.upstream = upstream
self.read_only = read_only
self.backfill_queue = None
def accept_client(self, socket):
- return ServerClient(socket, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
+ return ServerClient(
+ socket,
+ self.db_engine,
+ self.request_stats,
+ self.backfill_queue,
+ self.upstream,
+ self.read_only,
+ )
async def backfill_worker_task(self):
- client = await create_async_client(self.upstream)
- try:
+ async with await create_async_client(
+ self.upstream
+ ) as client, self.db_engine.connect(logger) as db:
while True:
item = await self.backfill_queue.get()
if item is None:
self.backfill_queue.task_done()
break
+
method, taskhash = item
- await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ d = await client.get_taskhash(method, taskhash)
+ if d is not None:
+ await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
self.backfill_queue.task_done()
- finally:
- await client.close()
def start(self):
tasks = super().start()
if self.upstream:
self.backfill_queue = asyncio.Queue()
tasks += [self.backfill_worker_task()]
+
+ self.loop.run_until_complete(self.db_engine.create())
+
return tasks
async def stop(self):
diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py
new file mode 100644
index 00000000..6809c537
--- /dev/null
+++ b/lib/hashserv/sqlite.py
@@ -0,0 +1,259 @@
+#! /usr/bin/env python3
+#
+# Copyright (C) 2023 Garmin Ltd.
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+import sqlite3
+import logging
+from contextlib import closing
+
+logger = logging.getLogger("hashserv.sqlite")
+
+UNIHASH_TABLE_DEFINITION = (
+ ("method", "TEXT NOT NULL", "UNIQUE"),
+ ("taskhash", "TEXT NOT NULL", "UNIQUE"),
+ ("unihash", "TEXT NOT NULL", ""),
+)
+
+UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
+
+OUTHASH_TABLE_DEFINITION = (
+ ("method", "TEXT NOT NULL", "UNIQUE"),
+ ("taskhash", "TEXT NOT NULL", "UNIQUE"),
+ ("outhash", "TEXT NOT NULL", "UNIQUE"),
+ ("created", "DATETIME", ""),
+ # Optional fields
+ ("owner", "TEXT", ""),
+ ("PN", "TEXT", ""),
+ ("PV", "TEXT", ""),
+ ("PR", "TEXT", ""),
+ ("task", "TEXT", ""),
+ ("outhash_siginfo", "TEXT", ""),
+)
+
+OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
+
+
+def _make_table(cursor, name, definition):
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS {name} (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ {fields}
+ UNIQUE({unique})
+ )
+ """.format(
+ name=name,
+ fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
+ unique=", ".join(
+ name for name, _, flags in definition if "UNIQUE" in flags
+ ),
+ )
+ )
+
+
+class DatabaseEngine(object):
+ def __init__(self, dbname, sync):
+ self.dbname = dbname
+ self.logger = logger
+ self.sync = sync
+
+ async def create(self):
+ db = sqlite3.connect(self.dbname)
+ db.row_factory = sqlite3.Row
+
+ with closing(db.cursor()) as cursor:
+ _make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
+ _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
+
+ cursor.execute("PRAGMA journal_mode = WAL")
+ cursor.execute(
+ "PRAGMA synchronous = %s" % ("NORMAL" if self.sync else "OFF")
+ )
+
+ # Drop old indexes
+ cursor.execute("DROP INDEX IF EXISTS taskhash_lookup")
+ cursor.execute("DROP INDEX IF EXISTS outhash_lookup")
+ cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v2")
+ cursor.execute("DROP INDEX IF EXISTS outhash_lookup_v2")
+
+ # TODO: Upgrade from tasks_v2?
+ cursor.execute("DROP TABLE IF EXISTS tasks_v2")
+
+ # Create new indexes
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)"
+ )
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)"
+ )
+
+ def connect(self, logger):
+ return Database(logger, self.dbname)
+
+
+class Database(object):
+ def __init__(self, logger, dbname, sync=True):
+ self.dbname = dbname
+ self.logger = logger
+
+ self.db = sqlite3.connect(self.dbname)
+ self.db.row_factory = sqlite3.Row
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
+ async def close(self):
+ self.db.close()
+
+ async def get_unihash_by_taskhash_full(self, method, taskhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
+ INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
+ WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "taskhash": taskhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_unihash_by_outhash(self, method, outhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
+ INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
+ WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "outhash": outhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_outhash(self, method, outhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT * FROM outhashes_v2
+ WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "outhash": outhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_equivalent_for_outhash(self, method, outhash, taskhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2
+ INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
+ -- Select any matching output hash except the one we just inserted
+ WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
+ -- Pick the oldest hash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "outhash": outhash,
+ "taskhash": taskhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_equivalent(self, method, taskhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ "SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash",
+ {
+ "method": method,
+ "taskhash": taskhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def remove(self, condition):
+ def do_remove(columns, table_name, cursor):
+ where = {}
+ for c in columns:
+ if c in condition and condition[c] is not None:
+ where[c] = condition[c]
+
+ if where:
+ query = ("DELETE FROM %s WHERE " % table_name) + " AND ".join(
+ "%s=:%s" % (k, k) for k in where.keys()
+ )
+ cursor.execute(query, where)
+ return cursor.rowcount
+
+ return 0
+
+ count = 0
+ with closing(self.db.cursor()) as cursor:
+ count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
+ count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
+ self.db.commit()
+
+ return count
+
+ async def clean_unused(self, oldest):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
+ SELECT unihashes_v2.id FROM unihashes_v2 WHERE unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash LIMIT 1
+ )
+ """,
+ {
+ "oldest": oldest,
+ },
+ )
+ return cursor.rowcount
+
+ async def insert_unihash(self, method, taskhash, unihash):
+ with closing(self.db.cursor()) as cursor:
+ prevrowid = cursor.lastrowid
+ cursor.execute(
+ """
+ INSERT OR IGNORE INTO unihashes_v2 (method, taskhash, unihash) VALUES(:method, :taskhash, :unihash)
+ """,
+ {
+ "method": method,
+ "taskhash": taskhash,
+ "unihash": unihash,
+ },
+ )
+ self.db.commit()
+ return cursor.lastrowid != prevrowid
+
+ async def insert_outhash(self, data):
+ data = {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS}
+ keys = sorted(data.keys())
+ query = "INSERT OR IGNORE INTO outhashes_v2 ({fields}) VALUES({values})".format(
+ fields=", ".join(keys),
+ values=", ".join(":" + k for k in keys),
+ )
+ with closing(self.db.cursor()) as cursor:
+ prevrowid = cursor.lastrowid
+ cursor.execute(query, data)
+ self.db.commit()
+ return cursor.lastrowid != prevrowid
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 08/22] hashserv: Add SQLalchemy backend
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (6 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 07/22] hashserv: Abstract database Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 09/22] hashserv: Implement read-only version of "report" RPC Joshua Watt
` (14 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an SQLAlchemy backend to the server. While this database backend is
slower than the more direct sqlite backend, it easily supports just
about any SQL server, which is useful for large scale deployments.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashserv | 12 ++
lib/bb/asyncrpc/connection.py | 11 +-
lib/hashserv/__init__.py | 21 ++-
lib/hashserv/sqlalchemy.py | 304 ++++++++++++++++++++++++++++++++++
lib/hashserv/tests.py | 19 ++-
5 files changed, 362 insertions(+), 5 deletions(-)
create mode 100644 lib/hashserv/sqlalchemy.py
diff --git a/bin/bitbake-hashserv b/bin/bitbake-hashserv
index a916a90c..59b8b07f 100755
--- a/bin/bitbake-hashserv
+++ b/bin/bitbake-hashserv
@@ -69,6 +69,16 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
action="store_true",
help="Disallow write operations from clients ($HASHSERVER_READ_ONLY)",
)
+ parser.add_argument(
+ "--db-username",
+ default=os.environ.get("HASHSERVER_DB_USERNAME", None),
+ help="Database username ($HASHSERVER_DB_USERNAME)",
+ )
+ parser.add_argument(
+ "--db-password",
+ default=os.environ.get("HASHSERVER_DB_PASSWORD", None),
+ help="Database password ($HASHSERVER_DB_PASSWORD)",
+ )
args = parser.parse_args()
@@ -90,6 +100,8 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
args.database,
upstream=args.upstream,
read_only=read_only,
+ db_username=args.db_username,
+ db_password=args.db_password,
)
server.serve_forever()
return 0
diff --git a/lib/bb/asyncrpc/connection.py b/lib/bb/asyncrpc/connection.py
index a10628f7..7f0cf6ba 100644
--- a/lib/bb/asyncrpc/connection.py
+++ b/lib/bb/asyncrpc/connection.py
@@ -7,6 +7,7 @@
import asyncio
import itertools
import json
+from datetime import datetime
from .exceptions import ClientError, ConnectionClosedError
@@ -30,6 +31,12 @@ def chunkify(msg, max_chunk):
yield "\n"
+def json_serialize(obj):
+ if isinstance(obj, datetime):
+ return obj.isoformat()
+ raise TypeError("Type %s not serializeable" % type(obj))
+
+
class StreamConnection(object):
def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
self.reader = reader
@@ -42,7 +49,7 @@ class StreamConnection(object):
return self.writer.get_extra_info("peername")
async def send_message(self, msg):
- for c in chunkify(json.dumps(msg), self.max_chunk):
+ for c in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
self.writer.write(c.encode("utf-8"))
await self.writer.drain()
@@ -105,7 +112,7 @@ class WebsocketConnection(object):
return ":".join(str(s) for s in self.socket.remote_address)
async def send_message(self, msg):
- await self.send(json.dumps(msg))
+ await self.send(json.dumps(msg, default=json_serialize))
async def recv_message(self):
m = await self.recv()
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 90d8cff1..9a8ee4e8 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -35,15 +35,32 @@ def parse_address(addr):
return (ADDR_TYPE_TCP, (host, int(port)))
-def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
+def create_server(
+ addr,
+ dbname,
+ *,
+ sync=True,
+ upstream=None,
+ read_only=False,
+ db_username=None,
+ db_password=None
+):
def sqlite_engine():
from .sqlite import DatabaseEngine
return DatabaseEngine(dbname, sync)
+ def sqlalchemy_engine():
+ from .sqlalchemy import DatabaseEngine
+
+ return DatabaseEngine(dbname, db_username, db_password)
+
from . import server
- db_engine = sqlite_engine()
+ if "://" in dbname:
+ db_engine = sqlalchemy_engine()
+ else:
+ db_engine = sqlite_engine()
s = server.Server(db_engine, upstream=upstream, read_only=read_only)
diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py
new file mode 100644
index 00000000..3216621f
--- /dev/null
+++ b/lib/hashserv/sqlalchemy.py
@@ -0,0 +1,304 @@
+#! /usr/bin/env python3
+#
+# Copyright (C) 2023 Garmin Ltd.
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+import logging
+from datetime import datetime
+
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.pool import NullPool
+from sqlalchemy import (
+ MetaData,
+ Column,
+ Table,
+ Text,
+ Integer,
+ UniqueConstraint,
+ DateTime,
+ Index,
+ select,
+ insert,
+ exists,
+ literal,
+ and_,
+ delete,
+)
+import sqlalchemy.engine
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.exc import IntegrityError
+
+logger = logging.getLogger("hashserv.sqlalchemy")
+
+Base = declarative_base()
+
+
+class UnihashesV2(Base):
+ __tablename__ = "unihashes_v2"
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ method = Column(Text, nullable=False)
+ taskhash = Column(Text, nullable=False)
+ unihash = Column(Text, nullable=False)
+
+ __table_args__ = (
+ UniqueConstraint("method", "taskhash"),
+ Index("taskhash_lookup_v3", "method", "taskhash"),
+ )
+
+
+class OuthashesV2(Base):
+ __tablename__ = "outhashes_v2"
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ method = Column(Text, nullable=False)
+ taskhash = Column(Text, nullable=False)
+ outhash = Column(Text, nullable=False)
+ created = Column(DateTime)
+ owner = Column(Text)
+ PN = Column(Text)
+ PV = Column(Text)
+ PR = Column(Text)
+ task = Column(Text)
+ outhash_siginfo = Column(Text)
+
+ __table_args__ = (
+ UniqueConstraint("method", "taskhash", "outhash"),
+ Index("outhash_lookup_v3", "method", "outhash"),
+ )
+
+
+class DatabaseEngine(object):
+ def __init__(self, url, username=None, password=None):
+ self.logger = logger
+ self.url = sqlalchemy.engine.make_url(url)
+
+ if username is not None:
+ self.url = self.url.set(username=username)
+
+ if password is not None:
+ self.url = self.url.set(password=password)
+
+ async def create(self):
+ self.logger.info("Using database %s", self.url)
+ self.engine = create_async_engine(self.url, poolclass=NullPool)
+
+ async with self.engine.begin() as conn:
+ # Create tables
+ logger.info("Creating tables...")
+ await conn.run_sync(Base.metadata.create_all)
+
+ def connect(self, logger):
+ return Database(self.engine, logger)
+
+
+def map_row(row):
+ if row is None:
+ return None
+ return dict(**row._mapping)
+
+
+class Database(object):
+ def __init__(self, engine, logger):
+ self.engine = engine
+ self.db = None
+ self.logger = logger
+
+ async def __aenter__(self):
+ self.db = await self.engine.connect()
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
+ async def close(self):
+ await self.db.close()
+ self.db = None
+
+ async def get_unihash_by_taskhash_full(self, method, taskhash):
+ statement = (
+ select(
+ OuthashesV2,
+ UnihashesV2.unihash.label("unihash"),
+ )
+ .join(
+ UnihashesV2,
+ and_(
+ UnihashesV2.method == OuthashesV2.method,
+ UnihashesV2.taskhash == OuthashesV2.taskhash,
+ ),
+ )
+ .where(
+ OuthashesV2.method == method,
+ OuthashesV2.taskhash == taskhash,
+ )
+ .order_by(
+ OuthashesV2.created.asc(),
+ )
+ .limit(1)
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def get_unihash_by_outhash(self, method, outhash):
+ statement = (
+ select(OuthashesV2, UnihashesV2.unihash.label("unihash"))
+ .join(
+ UnihashesV2,
+ and_(
+ UnihashesV2.method == OuthashesV2.method,
+ UnihashesV2.taskhash == OuthashesV2.taskhash,
+ ),
+ )
+ .where(
+ OuthashesV2.method == method,
+ OuthashesV2.outhash == outhash,
+ )
+ .order_by(
+ OuthashesV2.created.asc(),
+ )
+ .limit(1)
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def get_outhash(self, method, outhash):
+ statement = (
+ select(OuthashesV2)
+ .where(
+ OuthashesV2.method == method,
+ OuthashesV2.outhash == outhash,
+ )
+ .order_by(
+ OuthashesV2.created.asc(),
+ )
+ .limit(1)
+ )
+
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def get_equivalent_for_outhash(self, method, outhash, taskhash):
+ statement = (
+ select(
+ OuthashesV2.taskhash.label("taskhash"),
+ UnihashesV2.unihash.label("unihash"),
+ )
+ .join(
+ UnihashesV2,
+ and_(
+ UnihashesV2.method == OuthashesV2.method,
+ UnihashesV2.taskhash == OuthashesV2.taskhash,
+ ),
+ )
+ .where(
+ OuthashesV2.method == method,
+ OuthashesV2.outhash == outhash,
+ OuthashesV2.taskhash != taskhash,
+ )
+ .order_by(
+ OuthashesV2.created.asc(),
+ )
+ .limit(1)
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def get_equivalent(self, method, taskhash):
+ statement = select(
+ UnihashesV2.unihash,
+ UnihashesV2.method,
+ UnihashesV2.taskhash,
+ ).where(
+ UnihashesV2.method == method,
+ UnihashesV2.taskhash == taskhash,
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return map_row(result.first())
+
+ async def remove(self, condition):
+ async def do_remove(table):
+ where = {}
+ for c in table.__table__.columns:
+ if c.key in condition and condition[c.key] is not None:
+ where[c] = condition[c.key]
+
+ if where:
+ statement = delete(table).where(*[(k == v) for k, v in where.items()])
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount
+
+ return 0
+
+ count = 0
+ count += await do_remove(UnihashesV2)
+ count += await do_remove(OuthashesV2)
+
+ return count
+
+ async def clean_unused(self, oldest):
+ statement = delete(OuthashesV2).where(
+ OuthashesV2.created < oldest,
+ ~(
+ select(UnihashesV2.id)
+ .where(
+ UnihashesV2.method == OuthashesV2.method,
+ UnihashesV2.taskhash == OuthashesV2.taskhash,
+ )
+ .limit(1)
+ .exists()
+ ),
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount
+
+ async def insert_unihash(self, method, taskhash, unihash):
+ statement = insert(UnihashesV2).values(
+ method=method,
+ taskhash=taskhash,
+ unihash=unihash,
+ )
+ self.logger.debug("%s", statement)
+ try:
+ async with self.db.begin():
+ await self.db.execute(statement)
+ return True
+ except IntegrityError:
+ logger.debug(
+ "%s, %s, %s already in unihash database", method, taskhash, unihash
+ )
+ return False
+
+ async def insert_outhash(self, data):
+ outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)
+
+ data = {k: v for k, v in data.items() if k in outhash_columns}
+
+ if "created" in data and not isinstance(data["created"], datetime):
+ data["created"] = datetime.fromisoformat(data["created"])
+
+ statement = insert(OuthashesV2).values(**data)
+ self.logger.debug("%s", statement)
+ try:
+ async with self.db.begin():
+ await self.db.execute(statement)
+ return True
+ except IntegrityError:
+ logger.debug(
+ "%s, %s already in outhash database", data["method"], data["outhash"]
+ )
+ return False
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 4c98a280..268b2700 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -33,7 +33,7 @@ class HashEquivalenceTestSetup(object):
def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc):
self.server_index += 1
if dbpath is None:
- dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
+ dbpath = self.make_dbpath()
def cleanup_server(server):
if server.process.exitcode is not None:
@@ -53,6 +53,9 @@ class HashEquivalenceTestSetup(object):
return server
+ def make_dbpath(self):
+ return os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
+
def start_client(self, server_address):
def cleanup_client(client):
client.close()
@@ -517,6 +520,20 @@ class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalen
return "ws://%s:0" % host
+class TestHashEquivalenceWebsocketsSQLAlchemyServer(TestHashEquivalenceWebsocketServer):
+ def setUp(self):
+ try:
+ import sqlalchemy
+ import aiosqlite
+ except ImportError as e:
+ self.skipTest(str(e))
+
+ super().setUp()
+
+ def make_dbpath(self):
+ return "sqlite+aiosqlite:///%s" % os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
+
+
class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def start_test_server(self):
if 'BB_TEST_HASHSERV' not in os.environ:
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 09/22] hashserv: Implement read-only version of "report" RPC
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (7 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 08/22] hashserv: Add SQLalchemy backend Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 10/22] asyncrpc: Add InvokeError Joshua Watt
` (13 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
When the hash equivalence server is in read-only mode, it should still
return a unihash for a given "report" call if there is one.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/server.py | 25 ++++++++++++++++++++++++-
lib/hashserv/tests.py | 4 ++--
2 files changed, 26 insertions(+), 3 deletions(-)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 84cf4f22..c691df76 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -124,6 +124,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
self.backfill_queue = backfill_queue
self.upstream = upstream
+ self.read_only = read_only
self.handlers.update(
{
@@ -131,13 +132,15 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"get-outhash": self.handle_get_outhash,
"get-stream": self.handle_get_stream,
"get-stats": self.handle_get_stats,
+ # Not always read-only, but internally checks if the server is
+ # read-only
+ "report": self.handle_report,
}
)
if not read_only:
self.handlers.update(
{
- "report": self.handle_report,
"report-equiv": self.handle_equivreport,
"reset-stats": self.handle_reset_stats,
"backfill-wait": self.handle_backfill_wait,
@@ -284,7 +287,27 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
await self.socket.send("ok")
return self.NO_RESPONSE
+ async def report_readonly(self, data):
+ method = data["method"]
+ outhash = data["outhash"]
+ taskhash = data["taskhash"]
+
+ info = await self.get_outhash(method, outhash, taskhash)
+ if info:
+ unihash = info["unihash"]
+ else:
+ unihash = data["unihash"]
+
+ return {
+ "taskhash": taskhash,
+ "method": method,
+ "unihash": unihash,
+ }
+
async def handle_report(self, data):
+ if self.read_only:
+ return await self.report_readonly(data)
+
outhash_data = {
"method": data["method"],
"outhash": data["outhash"],
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 268b2700..e9a361dc 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -387,8 +387,8 @@ class HashEquivalenceCommonTests(object):
outhash2 = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44'
unihash2 = '90e9bc1d1f094c51824adca7f8ea79a048d68824'
- with self.assertRaises(ConnectionError):
- ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
+ result = ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
+ self.assertEqual(result['unihash'], unihash2)
# Ensure that the database was not modified
self.assertClientGetHash(rw_client, taskhash2, None)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v6 10/22] asyncrpc: Add InvokeError
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (8 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 09/22] hashserv: Implement read-only version of "report" RPC Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 11/22] asyncrpc: client: Prevent double closing of loop Joshua Watt
` (12 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds support for Invocation Errors (that is, errors raised by the actual
RPC call instead of at the protocol level) to propagate across the
connection. If a server RPC call raises an InvokeError, it will be sent
across the connection and then raised on the client side also. The
connection is still terminated on this error.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/__init__.py | 1 +
lib/bb/asyncrpc/client.py | 10 ++++++++--
lib/bb/asyncrpc/exceptions.py | 4 ++++
lib/bb/asyncrpc/serv.py | 11 +++++++++--
4 files changed, 22 insertions(+), 4 deletions(-)
diff --git a/lib/bb/asyncrpc/__init__.py b/lib/bb/asyncrpc/__init__.py
index 9f677eac..a4371643 100644
--- a/lib/bb/asyncrpc/__init__.py
+++ b/lib/bb/asyncrpc/__init__.py
@@ -12,4 +12,5 @@ from .exceptions import (
ClientError,
ServerError,
ConnectionClosedError,
+ InvokeError,
)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index 009085c3..d27dbf71 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -11,7 +11,7 @@ import os
import socket
import sys
from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
-from .exceptions import ConnectionClosedError
+from .exceptions import ConnectionClosedError, InvokeError
class AsyncClient(object):
@@ -93,12 +93,18 @@ class AsyncClient(object):
await self.close()
count += 1
+ def check_invoke_error(self, msg):
+ if isinstance(msg, dict) and "invoke-error" in msg:
+ raise InvokeError(msg["invoke-error"]["message"])
+
async def invoke(self, msg):
async def proc():
await self.socket.send_message(msg)
return await self.socket.recv_message()
- return await self._send_wrapper(proc)
+ result = await self._send_wrapper(proc)
+ self.check_invoke_error(result)
+ return result
async def ping(self):
return await self.invoke({"ping": {}})
diff --git a/lib/bb/asyncrpc/exceptions.py b/lib/bb/asyncrpc/exceptions.py
index a8942b4f..ae1043a3 100644
--- a/lib/bb/asyncrpc/exceptions.py
+++ b/lib/bb/asyncrpc/exceptions.py
@@ -9,6 +9,10 @@ class ClientError(Exception):
pass
+class InvokeError(Exception):
+ pass
+
+
class ServerError(Exception):
pass
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index c99add4d..5fed1730 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -14,7 +14,7 @@ import sys
import multiprocessing
import logging
from .connection import StreamConnection, WebsocketConnection
-from .exceptions import ClientError, ServerError, ConnectionClosedError
+from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
class ClientLoggerAdapter(logging.LoggerAdapter):
@@ -76,7 +76,14 @@ class AsyncServerConnection(object):
d = await self.socket.recv_message()
if d is None:
break
- response = await self.dispatch_message(d)
+ try:
+ response = await self.dispatch_message(d)
+ except InvokeError as e:
+ await self.socket.send_message(
+ {"invoke-error": {"message": str(e)}}
+ )
+ break
+
if response is not self.NO_RESPONSE:
await self.socket.send_message(response)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v6 11/22] asyncrpc: client: Prevent double closing of loop
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (9 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 10/22] asyncrpc: Add InvokeError Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 12/22] asyncrpc: client: Add disconnect API Joshua Watt
` (11 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Invalidate the loop in the client close() call so that it is not closed
twice (which is an error in the asyncio code)
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/client.py | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index d27dbf71..628b90ee 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -161,10 +161,12 @@ class Client(object):
self.client.max_chunk = value
def close(self):
- self.loop.run_until_complete(self.client.close())
- if sys.version_info >= (3, 6):
- self.loop.run_until_complete(self.loop.shutdown_asyncgens())
- self.loop.close()
+ if self.loop:
+ self.loop.run_until_complete(self.client.close())
+ if sys.version_info >= (3, 6):
+ self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+ self.loop.close()
+ self.loop = None
def __enter__(self):
return self
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v6 12/22] asyncrpc: client: Add disconnect API
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (10 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 11/22] asyncrpc: client: Prevent double closing of loop Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 13/22] hashserv: Add user permissions Joshua Watt
` (10 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an API to explicitly disconnect a client. This can be useful for
testing the auto-reconnect behavior of clients
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/bb/asyncrpc/client.py | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index 628b90ee..0d7cd857 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -67,11 +67,14 @@ class AsyncClient(object):
self.socket = await self._connect_sock()
await self.setup_connection()
- async def close(self):
+ async def disconnect(self):
if self.socket is not None:
await self.socket.close()
self.socket = None
+ async def close(self):
+ await self.disconnect()
+
async def _send_wrapper(self, proc):
count = 0
while True:
@@ -160,6 +163,9 @@ class Client(object):
def max_chunk(self, value):
self.client.max_chunk = value
+ def disconnect(self):
+ self.loop.run_until_complete(self.client.close())
+
def close(self):
if self.loop:
self.loop.run_until_complete(self.client.close())
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v6 13/22] hashserv: Add user permissions
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (11 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 12/22] asyncrpc: client: Add disconnect API Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 14/22] hashserv: Add become-user API Joshua Watt
` (9 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds support for the hashserver to have per-user permissions. User
management is done via a new "auth" RPC API where a client can
authenticate itself with the server using a randomly generated token.
The user can then be given permissions to read, report, manage the
database, or manage other users.
In addition to explicit user logins, the server supports anonymous users
which is what all users start as before they make the "auth" RPC call.
Anonymous users can be assigned a set of permissions by the server,
making it unnecessary for users to authenticate to use the server. The
set of Anonymous permissions defines the default behavior of the server,
for example if set to "@read", Anonymous users are unable to report
equivalent hashes without authenticating. Similarly, setting the Anonymous
permissions to "@none" would require authentication for users to perform
any action.
User creation and management is entirely manual (although
bitbake-hashclient is very useful as a front end). There are many
different mechanisms that could be implemented to allow user
self-registration (e.g. OAuth, LDAP, etc.), and implementing these is
outside the scope of the server. Instead, it is recommended to
implement a registration service that validates users against the
necessary service, then adds them as a user in the hash equivalence
server.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 84 ++++++++-
bin/bitbake-hashserv | 37 ++++
lib/hashserv/__init__.py | 69 ++++---
lib/hashserv/client.py | 62 ++++++-
lib/hashserv/server.py | 357 ++++++++++++++++++++++++++++++++++++-
lib/hashserv/sqlalchemy.py | 111 +++++++++++-
lib/hashserv/sqlite.py | 105 +++++++++++
lib/hashserv/tests.py | 276 +++++++++++++++++++++++++++-
8 files changed, 1054 insertions(+), 47 deletions(-)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index a02a65b9..328c15cd 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -14,6 +14,7 @@ import sys
import threading
import time
import warnings
+import netrc
warnings.simplefilter("default")
try:
@@ -36,10 +37,18 @@ except ImportError:
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib'))
import hashserv
+import bb.asyncrpc
DEFAULT_ADDRESS = 'unix://./hashserve.sock'
METHOD = 'stress.test.method'
+def print_user(u):
+ print(f"Username: {u['username']}")
+ if "permissions" in u:
+ print("Permissions: " + " ".join(u["permissions"]))
+ if "token" in u:
+ print(f"Token: {u['token']}")
+
def main():
def handle_stats(args, client):
@@ -125,9 +134,39 @@ def main():
print("Removed %d rows" % (result["count"]))
return 0
+ def handle_refresh_token(args, client):
+ r = client.refresh_token(args.username)
+ print_user(r)
+
+ def handle_set_user_permissions(args, client):
+ r = client.set_user_perms(args.username, args.permissions)
+ print_user(r)
+
+ def handle_get_user(args, client):
+ r = client.get_user(args.username)
+ print_user(r)
+
+ def handle_get_all_users(args, client):
+ users = client.get_all_users()
+ print("{username:20}| {permissions}".format(username="Username", permissions="Permissions"))
+ print(("-" * 20) + "+" + ("-" * 20))
+ for u in users:
+ print("{username:20}| {permissions}".format(username=u["username"], permissions=" ".join(u["permissions"])))
+
+ def handle_new_user(args, client):
+ r = client.new_user(args.username, args.permissions)
+ print_user(r)
+
+ def handle_delete_user(args, client):
+ r = client.delete_user(args.username)
+ print_user(r)
+
parser = argparse.ArgumentParser(description='Hash Equivalence Client')
parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
parser.add_argument('--log', default='WARNING', help='Set logging level')
+ parser.add_argument('--login', '-l', metavar="USERNAME", help="Authenticate as USERNAME")
+ parser.add_argument('--password', '-p', metavar="TOKEN", help="Authenticate using token TOKEN")
+ parser.add_argument('--no-netrc', '-n', action="store_false", dest="netrc", help="Do not use .netrc")
subparsers = parser.add_subparsers()
@@ -158,6 +197,31 @@ def main():
clean_unused_parser.add_argument("max_age", metavar="SECONDS", type=int, help="Remove unused entries older than SECONDS old")
clean_unused_parser.set_defaults(func=handle_clean_unused)
+ refresh_token_parser = subparsers.add_parser('refresh-token', help="Refresh auth token")
+ refresh_token_parser.add_argument("--username", "-u", help="Refresh the token for another user (if authorized)")
+ refresh_token_parser.set_defaults(func=handle_refresh_token)
+
+ set_user_perms_parser = subparsers.add_parser('set-user-perms', help="Set new permissions for user")
+ set_user_perms_parser.add_argument("--username", "-u", help="Username", required=True)
+ set_user_perms_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions")
+ set_user_perms_parser.set_defaults(func=handle_set_user_permissions)
+
+ get_user_parser = subparsers.add_parser('get-user', help="Get user")
+ get_user_parser.add_argument("--username", "-u", help="Username")
+ get_user_parser.set_defaults(func=handle_get_user)
+
+ get_all_users_parser = subparsers.add_parser('get-all-users', help="List all users")
+ get_all_users_parser.set_defaults(func=handle_get_all_users)
+
+ new_user_parser = subparsers.add_parser('new-user', help="Create new user")
+ new_user_parser.add_argument("--username", "-u", help="Username", required=True)
+ new_user_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions")
+ new_user_parser.set_defaults(func=handle_new_user)
+
+ delete_user_parser = subparsers.add_parser('delete-user', help="Delete user")
+ delete_user_parser.add_argument("--username", "-u", help="Username", required=True)
+ delete_user_parser.set_defaults(func=handle_delete_user)
+
args = parser.parse_args()
logger = logging.getLogger('hashserv')
@@ -171,10 +235,26 @@ def main():
console.setLevel(level)
logger.addHandler(console)
+ login = args.login
+ password = args.password
+
+ if login is None and args.netrc:
+ try:
+ n = netrc.netrc()
+ auth = n.authenticators(args.address)
+ if auth is not None:
+ login, _, password = auth
+ except FileNotFoundError:
+ pass
+
func = getattr(args, 'func', None)
if func:
- with hashserv.create_client(args.address) as client:
- return func(args, client)
+ try:
+ with hashserv.create_client(args.address, login, password) as client:
+ return func(args, client)
+ except bb.asyncrpc.InvokeError as e:
+ print(f"ERROR: {e}")
+ return 1
return 0
diff --git a/bin/bitbake-hashserv b/bin/bitbake-hashserv
index 59b8b07f..1085d058 100755
--- a/bin/bitbake-hashserv
+++ b/bin/bitbake-hashserv
@@ -17,6 +17,7 @@ warnings.simplefilter("default")
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), "lib"))
import hashserv
+from hashserv.server import DEFAULT_ANON_PERMS
VERSION = "1.0.0"
@@ -36,6 +37,22 @@ The bind address may take one of the following formats:
To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
"--bind ws://:8686". To bind to a specific IPv6 address, enclose the address in
"[]", e.g. "--bind [::1]:8686" or "--bind ws://[::1]:8686"
+
+Note that the default Anonymous permissions are designed to not break existing
+server instances when upgrading, but are not particularly secure defaults. If
+you want to use authentication, it is recommended that you use "--anon-perms
+@read" to only give anonymous users read access, or "--anon-perms @none" to
+give un-authenticated users no access at all.
+
+Setting "--anon-perms @all" or "--anon-perms @user-admin" is not allowed, since
+this would allow anonymous users to manage all users accounts, which is a bad
+idea.
+
+If you are using user authentication, you should run your server in websockets
+mode with an SSL terminating load balancer in front of it (as this server does
+not implement SSL). Otherwise all usernames and passwords will be transmitted
+in the clear. When configured this way, clients can connect using a secure
+websocket, as in "wss://SERVER:PORT"
""",
)
@@ -79,6 +96,22 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
default=os.environ.get("HASHSERVER_DB_PASSWORD", None),
help="Database password ($HASHSERVER_DB_PASSWORD)",
)
+ parser.add_argument(
+ "--anon-perms",
+ metavar="PERM[,PERM[,...]]",
+ default=os.environ.get("HASHSERVER_ANON_PERMS", ",".join(DEFAULT_ANON_PERMS)),
+ help='Permissions to give anonymous users (default $HASHSERVER_ANON_PERMS, "%(default)s")',
+ )
+ parser.add_argument(
+ "--admin-user",
+ default=os.environ.get("HASHSERVER_ADMIN_USER", None),
+ help="Create default admin user with name ADMIN_USER ($HASHSERVER_ADMIN_USER)",
+ )
+ parser.add_argument(
+ "--admin-password",
+ default=os.environ.get("HASHSERVER_ADMIN_PASSWORD", None),
+ help="Create default admin user with password ADMIN_PASSWORD ($HASHSERVER_ADMIN_PASSWORD)",
+ )
args = parser.parse_args()
@@ -94,6 +127,7 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
logger.addHandler(console)
read_only = (os.environ.get("HASHSERVER_READ_ONLY", "0") == "1") or args.read_only
+ anon_perms = args.anon_perms.split(",")
server = hashserv.create_server(
args.bind,
@@ -102,6 +136,9 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
read_only=read_only,
db_username=args.db_username,
db_password=args.db_password,
+ anon_perms=anon_perms,
+ admin_username=args.admin_user,
+ admin_password=args.admin_password,
)
server.serve_forever()
return 0
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 9a8ee4e8..552a3327 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -8,6 +8,7 @@ from contextlib import closing
import re
import itertools
import json
+from collections import namedtuple
from urllib.parse import urlparse
UNIX_PREFIX = "unix://"
@@ -18,6 +19,8 @@ ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2
+User = namedtuple("User", ("username", "permissions"))
+
def parse_address(addr):
if addr.startswith(UNIX_PREFIX):
@@ -43,7 +46,10 @@ def create_server(
upstream=None,
read_only=False,
db_username=None,
- db_password=None
+ db_password=None,
+ anon_perms=None,
+ admin_username=None,
+ admin_password=None,
):
def sqlite_engine():
from .sqlite import DatabaseEngine
@@ -62,7 +68,17 @@ def create_server(
else:
db_engine = sqlite_engine()
- s = server.Server(db_engine, upstream=upstream, read_only=read_only)
+ if anon_perms is None:
+ anon_perms = server.DEFAULT_ANON_PERMS
+
+ s = server.Server(
+ db_engine,
+ upstream=upstream,
+ read_only=read_only,
+ anon_perms=anon_perms,
+ admin_username=admin_username,
+ admin_password=admin_password,
+ )
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
@@ -76,33 +92,40 @@ def create_server(
return s
-def create_client(addr):
+def create_client(addr, username=None, password=None):
from . import client
- c = client.Client()
-
- (typ, a) = parse_address(addr)
- if typ == ADDR_TYPE_UNIX:
- c.connect_unix(*a)
- elif typ == ADDR_TYPE_WS:
- c.connect_websocket(*a)
- else:
- c.connect_tcp(*a)
+ c = client.Client(username, password)
- return c
+ try:
+ (typ, a) = parse_address(addr)
+ if typ == ADDR_TYPE_UNIX:
+ c.connect_unix(*a)
+ elif typ == ADDR_TYPE_WS:
+ c.connect_websocket(*a)
+ else:
+ c.connect_tcp(*a)
+ return c
+ except Exception as e:
+ c.close()
+ raise e
-async def create_async_client(addr):
+async def create_async_client(addr, username=None, password=None):
from . import client
- c = client.AsyncClient()
+ c = client.AsyncClient(username, password)
- (typ, a) = parse_address(addr)
- if typ == ADDR_TYPE_UNIX:
- await c.connect_unix(*a)
- elif typ == ADDR_TYPE_WS:
- await c.connect_websocket(*a)
- else:
- await c.connect_tcp(*a)
+ try:
+ (typ, a) = parse_address(addr)
+ if typ == ADDR_TYPE_UNIX:
+ await c.connect_unix(*a)
+ elif typ == ADDR_TYPE_WS:
+ await c.connect_websocket(*a)
+ else:
+ await c.connect_tcp(*a)
- return c
+ return c
+ except Exception as e:
+ await c.close()
+ raise e
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index 9542d72f..82400fe5 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -6,6 +6,7 @@
import logging
import socket
import bb.asyncrpc
+import json
from . import create_async_client
@@ -16,15 +17,19 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
MODE_NORMAL = 0
MODE_GET_STREAM = 1
- def __init__(self):
+ def __init__(self, username=None, password=None):
super().__init__('OEHASHEQUIV', '1.1', logger)
self.mode = self.MODE_NORMAL
+ self.username = username
+ self.password = password
async def setup_connection(self):
await super().setup_connection()
cur_mode = self.mode
self.mode = self.MODE_NORMAL
await self._set_mode(cur_mode)
+ if self.username:
+ await self.auth(self.username, self.password)
async def send_stream(self, msg):
async def proc():
@@ -41,6 +46,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
r = await self._send_wrapper(stream_to_normal)
if r != "ok":
+ self.check_invoke_error(r)
raise ConnectionError("Unable to transition to normal mode: Bad response from server %r" % r)
elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
r = await self.invoke({"get-stream": None})
@@ -109,9 +115,52 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
await self._set_mode(self.MODE_NORMAL)
return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})
+ async def auth(self, username, token):
+ await self._set_mode(self.MODE_NORMAL)
+ result = await self.invoke({"auth": {"username": username, "token": token}})
+ self.username = username
+ self.password = token
+ return result
+
+ async def refresh_token(self, username=None):
+ await self._set_mode(self.MODE_NORMAL)
+ m = {}
+ if username:
+ m["username"] = username
+ result = await self.invoke({"refresh-token": m})
+ if self.username and result["username"] == self.username:
+ self.password = result["token"]
+ return result
+
+ async def set_user_perms(self, username, permissions):
+ await self._set_mode(self.MODE_NORMAL)
+ return await self.invoke({"set-user-perms": {"username": username, "permissions": permissions}})
+
+ async def get_user(self, username=None):
+ await self._set_mode(self.MODE_NORMAL)
+ m = {}
+ if username:
+ m["username"] = username
+ return await self.invoke({"get-user": m})
+
+ async def get_all_users(self):
+ await self._set_mode(self.MODE_NORMAL)
+ return (await self.invoke({"get-all-users": {}}))["users"]
+
+ async def new_user(self, username, permissions):
+ await self._set_mode(self.MODE_NORMAL)
+ return await self.invoke({"new-user": {"username": username, "permissions": permissions}})
+
+ async def delete_user(self, username):
+ await self._set_mode(self.MODE_NORMAL)
+ return await self.invoke({"delete-user": {"username": username}})
+
class Client(bb.asyncrpc.Client):
- def __init__(self):
+ def __init__(self, username=None, password=None):
+ self.username = username
+ self.password = password
+
super().__init__()
self._add_methods(
"connect_tcp",
@@ -126,7 +175,14 @@ class Client(bb.asyncrpc.Client):
"backfill_wait",
"remove",
"clean_unused",
+ "auth",
+ "refresh_token",
+ "set_user_perms",
+ "get_user",
+ "get_all_users",
+ "new_user",
+ "delete_user",
)
def _get_async_client(self):
- return AsyncClient()
+ return AsyncClient(self.username, self.password)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index c691df76..f5baa6be 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -8,13 +8,48 @@ import asyncio
import logging
import math
import time
+import os
+import base64
+import hashlib
from . import create_async_client
import bb.asyncrpc
-
logger = logging.getLogger("hashserv.server")
+# This permission only exists to match nothing
+NONE_PERM = "@none"
+
+READ_PERM = "@read"
+REPORT_PERM = "@report"
+DB_ADMIN_PERM = "@db-admin"
+USER_ADMIN_PERM = "@user-admin"
+ALL_PERM = "@all"
+
+ALL_PERMISSIONS = {
+ READ_PERM,
+ REPORT_PERM,
+ DB_ADMIN_PERM,
+ USER_ADMIN_PERM,
+ ALL_PERM,
+}
+
+DEFAULT_ANON_PERMS = (
+ READ_PERM,
+ REPORT_PERM,
+ DB_ADMIN_PERM,
+)
+
+TOKEN_ALGORITHM = "sha256"
+
+# 48 bytes of random data will result in 64 characters when base64
+# encoded. This number also ensures that the base64 encoding won't have any
+# trailing '=' characters.
+TOKEN_SIZE = 48
+
+SALT_SIZE = 8
+
+
class Measurement(object):
def __init__(self, sample):
self.sample = sample
@@ -108,6 +143,85 @@ class Stats(object):
}
+token_refresh_semaphore = asyncio.Lock()
+
+
+async def new_token():
+ # Prevent malicious users from using this API to deduce the entropy
+ # pool on the server and thus be able to guess a token. *All* token
+ # refresh requests lock the same global semaphore and then sleep for a
+ # short time. This effectively rate limits the total number of requests
+ # that can be made across all clients to 10/second, which should be enough
+ # since you have to be an authenticated user to make the request in the
+ # first place
+ async with token_refresh_semaphore:
+ await asyncio.sleep(0.1)
+ raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK)
+
+ return base64.b64encode(raw, b"._").decode("utf-8")
+
+
+def new_salt():
+ return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex()
+
+
+def hash_token(algo, salt, token):
+ h = hashlib.new(algo)
+ h.update(salt.encode("utf-8"))
+ h.update(token.encode("utf-8"))
+ return ":".join([algo, salt, h.hexdigest()])
+
+
+def permissions(*permissions, allow_anon=True, allow_self_service=False):
+ """
+ Function decorator that can be used to decorate an RPC function call and
+ check that the current user's permissions match the required permissions.
+
+ If allow_anon is True, the user will also be allowed to make the RPC call
+ if the anonymous user permissions match the permissions.
+
+ If allow_self_service is True, and the "username" property in the request
+ is the currently logged in user, or not specified, the user will also be
+ allowed to make the request. This allows users to access normally privileged
+ APIs, as long as they are only modifying their own user properties (e.g.
+ users can be allowed to reset their own token without @user-admin
+ permissions, but not the token for any other user).
+ """
+
+ def wrapper(func):
+ async def wrap(self, request):
+ if allow_self_service and self.user is not None:
+ username = request.get("username", self.user.username)
+ if username == self.user.username:
+ request["username"] = self.user.username
+ return await func(self, request)
+
+ if not self.user_has_permissions(*permissions, allow_anon=allow_anon):
+ if not self.user:
+ username = "Anonymous user"
+ user_perms = self.anon_perms
+ else:
+ username = self.user.username
+ user_perms = self.user.permissions
+
+ self.logger.info(
+ "User %s with permissions %r denied from calling %s. Missing permissions(s) %r",
+ username,
+ ", ".join(user_perms),
+ func.__name__,
+ ", ".join(permissions),
+ )
+ raise bb.asyncrpc.InvokeError(
+ f"{username} is not allowed to access permissions(s) {', '.join(permissions)}"
+ )
+
+ return await func(self, request)
+
+ return wrap
+
+ return wrapper
+
+
class ServerClient(bb.asyncrpc.AsyncServerConnection):
def __init__(
self,
@@ -117,6 +231,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
backfill_queue,
upstream,
read_only,
+ anon_perms,
):
super().__init__(socket, "OEHASHEQUIV", logger)
self.db_engine = db_engine
@@ -125,6 +240,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
self.backfill_queue = backfill_queue
self.upstream = upstream
self.read_only = read_only
+ self.user = None
+ self.anon_perms = anon_perms
self.handlers.update(
{
@@ -135,6 +252,9 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
# Not always read-only, but internally checks if the server is
# read-only
"report": self.handle_report,
+ "auth": self.handle_auth,
+ "get-user": self.handle_get_user,
+ "get-all-users": self.handle_get_all_users,
}
)
@@ -146,9 +266,36 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"backfill-wait": self.handle_backfill_wait,
"remove": self.handle_remove,
"clean-unused": self.handle_clean_unused,
+ "refresh-token": self.handle_refresh_token,
+ "set-user-perms": self.handle_set_perms,
+ "new-user": self.handle_new_user,
+ "delete-user": self.handle_delete_user,
}
)
+ def raise_no_user_error(self, username):
+ raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists")
+
+ def user_has_permissions(self, *permissions, allow_anon=True):
+ permissions = set(permissions)
+ if allow_anon:
+ if ALL_PERM in self.anon_perms:
+ return True
+
+ if not permissions - self.anon_perms:
+ return True
+
+ if self.user is None:
+ return False
+
+ if ALL_PERM in self.user.permissions:
+ return True
+
+ if not permissions - self.user.permissions:
+ return True
+
+ return False
+
def validate_proto_version(self):
return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
@@ -178,6 +325,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
+ @permissions(READ_PERM)
async def handle_get(self, request):
method = request["method"]
taskhash = request["taskhash"]
@@ -206,6 +354,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return d
+ @permissions(READ_PERM)
async def handle_get_outhash(self, request):
method = request["method"]
outhash = request["outhash"]
@@ -236,6 +385,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
await self.db.insert_outhash(data)
+ @permissions(READ_PERM)
async def handle_get_stream(self, request):
await self.socket.send_message("ok")
@@ -304,8 +454,11 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"unihash": unihash,
}
+ # Since this can be called either read only or to report, the check to
+ # report is made inside the function
+ @permissions(READ_PERM)
async def handle_report(self, data):
- if self.read_only:
+ if self.read_only or not self.user_has_permissions(REPORT_PERM):
return await self.report_readonly(data)
outhash_data = {
@@ -358,6 +511,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"unihash": unihash,
}
+ @permissions(READ_PERM, REPORT_PERM)
async def handle_equivreport(self, data):
await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
@@ -375,11 +529,13 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return {k: row[k] for k in ("taskhash", "method", "unihash")}
+ @permissions(READ_PERM)
async def handle_get_stats(self, request):
return {
"requests": self.request_stats.todict(),
}
+ @permissions(DB_ADMIN_PERM)
async def handle_reset_stats(self, request):
d = {
"requests": self.request_stats.todict(),
@@ -388,6 +544,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
self.request_stats.reset()
return d
+ @permissions(READ_PERM)
async def handle_backfill_wait(self, request):
d = {
"tasks": self.backfill_queue.qsize(),
@@ -395,6 +552,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
await self.backfill_queue.join()
return d
+ @permissions(DB_ADMIN_PERM)
async def handle_remove(self, request):
condition = request["where"]
if not isinstance(condition, dict):
@@ -402,19 +560,178 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return {"count": await self.db.remove(condition)}
+ @permissions(DB_ADMIN_PERM)
async def handle_clean_unused(self, request):
max_age = request["max_age_seconds"]
oldest = datetime.now() - timedelta(seconds=-max_age)
return {"count": await self.db.clean_unused(oldest)}
+ # The authentication API is always allowed
+ async def handle_auth(self, request):
+ username = str(request["username"])
+ token = str(request["token"])
+
+ async def fail_auth():
+ nonlocal username
+ # Rate limit bad login attempts
+ await asyncio.sleep(1)
+ raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}")
+
+ user, db_token = await self.db.lookup_user_token(username)
+
+ if not user or not db_token:
+ await fail_auth()
+
+ try:
+ algo, salt, _ = db_token.split(":")
+ except ValueError:
+ await fail_auth()
+
+ if hash_token(algo, salt, token) != db_token:
+ await fail_auth()
+
+ self.user = user
+
+ self.logger.info("Authenticated as %s", username)
+
+ return {
+ "result": True,
+ "username": self.user.username,
+ "permissions": sorted(list(self.user.permissions)),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+ async def handle_refresh_token(self, request):
+ username = str(request["username"])
+
+ token = await new_token()
+
+ updated = await self.db.set_user_token(
+ username,
+ hash_token(TOKEN_ALGORITHM, new_salt(), token),
+ )
+ if not updated:
+ self.raise_no_user_error(username)
+
+ return {"username": username, "token": token}
+
+ def get_perm_arg(self, arg):
+ if not isinstance(arg, list):
+ raise bb.asyncrpc.InvokeError("Unexpected type for permissions")
+
+ arg = set(arg)
+ try:
+ arg.remove(NONE_PERM)
+ except KeyError:
+ pass
+
+ unknown_perms = arg - ALL_PERMISSIONS
+ if unknown_perms:
+ raise bb.asyncrpc.InvokeError(
+ "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms)))
+ )
+
+ return sorted(list(arg))
+
+ def return_perms(self, permissions):
+ if ALL_PERM in permissions:
+ return sorted(list(ALL_PERMISSIONS))
+ return sorted(list(permissions))
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_set_perms(self, request):
+ username = str(request["username"])
+ permissions = self.get_perm_arg(request["permissions"])
+
+ if not await self.db.set_user_perms(username, permissions):
+ self.raise_no_user_error(username)
+
+ return {
+ "username": username,
+ "permissions": self.return_perms(permissions),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+ async def handle_get_user(self, request):
+ username = str(request["username"])
+
+ user = await self.db.lookup_user(username)
+ if user is None:
+ return None
+
+ return {
+ "username": user.username,
+ "permissions": self.return_perms(user.permissions),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_get_all_users(self, request):
+ users = await self.db.get_all_users()
+ return {
+ "users": [
+ {
+ "username": u.username,
+ "permissions": self.return_perms(u.permissions),
+ }
+ for u in users
+ ]
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_new_user(self, request):
+ username = str(request["username"])
+ permissions = self.get_perm_arg(request["permissions"])
+
+ token = await new_token()
+
+ inserted = await self.db.new_user(
+ username,
+ permissions,
+ hash_token(TOKEN_ALGORITHM, new_salt(), token),
+ )
+ if not inserted:
+ raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'")
+
+ return {
+ "username": username,
+ "permissions": self.return_perms(permissions),
+ "token": token,
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_delete_user(self, request):
+ username = str(request["username"])
+
+ if not await self.db.delete_user(username):
+ self.raise_no_user_error(username)
+
+ return {"username": username}
+
class Server(bb.asyncrpc.AsyncServer):
- def __init__(self, db_engine, upstream=None, read_only=False):
+ def __init__(
+ self,
+ db_engine,
+ upstream=None,
+ read_only=False,
+ anon_perms=DEFAULT_ANON_PERMS,
+ admin_username=None,
+ admin_password=None,
+ ):
if upstream and read_only:
raise bb.asyncrpc.ServerError(
"Read-only hashserv cannot pull from an upstream server"
)
+ disallowed_perms = set(anon_perms) - set(
+ [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM]
+ )
+
+ if disallowed_perms:
+ raise bb.asyncrpc.ServerError(
+ f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users"
+ )
+
super().__init__(logger)
self.request_stats = Stats()
@@ -422,6 +739,13 @@ class Server(bb.asyncrpc.AsyncServer):
self.upstream = upstream
self.read_only = read_only
self.backfill_queue = None
+ self.anon_perms = set(anon_perms)
+ self.admin_username = admin_username
+ self.admin_password = admin_password
+
+ self.logger.info(
+ "Anonymous user permissions are: %s", ", ".join(self.anon_perms)
+ )
def accept_client(self, socket):
return ServerClient(
@@ -431,12 +755,34 @@ class Server(bb.asyncrpc.AsyncServer):
self.backfill_queue,
self.upstream,
self.read_only,
+ self.anon_perms,
)
+ async def create_admin_user(self):
+ admin_permissions = (ALL_PERM,)
+ async with self.db_engine.connect(self.logger) as db:
+ added = await db.new_user(
+ self.admin_username,
+ admin_permissions,
+ hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
+ )
+ if added:
+ self.logger.info("Created admin user '%s'", self.admin_username)
+ else:
+ await db.set_user_perms(
+ self.admin_username,
+ admin_permissions,
+ )
+ await db.set_user_token(
+ self.admin_username,
+ hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
+ )
+ self.logger.info("Admin user '%s' updated", self.admin_username)
+
async def backfill_worker_task(self):
async with await create_async_client(
self.upstream
- ) as client, self.db_engine.connect(logger) as db:
+ ) as client, self.db_engine.connect(self.logger) as db:
while True:
item = await self.backfill_queue.get()
if item is None:
@@ -457,6 +803,9 @@ class Server(bb.asyncrpc.AsyncServer):
self.loop.run_until_complete(self.db_engine.create())
+ if self.admin_username:
+ self.loop.run_until_complete(self.create_admin_user())
+
return tasks
async def stop(self):
diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py
index 3216621f..bfd8a844 100644
--- a/lib/hashserv/sqlalchemy.py
+++ b/lib/hashserv/sqlalchemy.py
@@ -7,6 +7,7 @@
import logging
from datetime import datetime
+from . import User
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool
@@ -25,13 +26,12 @@ from sqlalchemy import (
literal,
and_,
delete,
+ update,
)
import sqlalchemy.engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.exc import IntegrityError
-logger = logging.getLogger("hashserv.sqlalchemy")
-
Base = declarative_base()
@@ -68,9 +68,19 @@ class OuthashesV2(Base):
)
+class Users(Base):
+ __tablename__ = "users"
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ username = Column(Text, nullable=False)
+ token = Column(Text, nullable=False)
+ permissions = Column(Text)
+
+ __table_args__ = (UniqueConstraint("username"),)
+
+
class DatabaseEngine(object):
def __init__(self, url, username=None, password=None):
- self.logger = logger
+ self.logger = logging.getLogger("hashserv.sqlalchemy")
self.url = sqlalchemy.engine.make_url(url)
if username is not None:
@@ -85,7 +95,7 @@ class DatabaseEngine(object):
async with self.engine.begin() as conn:
# Create tables
- logger.info("Creating tables...")
+ self.logger.info("Creating tables...")
await conn.run_sync(Base.metadata.create_all)
def connect(self, logger):
@@ -98,6 +108,15 @@ def map_row(row):
return dict(**row._mapping)
+def map_user(row):
+ if row is None:
+ return None
+ return User(
+ username=row.username,
+ permissions=set(row.permissions.split()),
+ )
+
+
class Database(object):
def __init__(self, engine, logger):
self.engine = engine
@@ -278,7 +297,7 @@ class Database(object):
await self.db.execute(statement)
return True
except IntegrityError:
- logger.debug(
+ self.logger.debug(
"%s, %s, %s already in unihash database", method, taskhash, unihash
)
return False
@@ -298,7 +317,87 @@ class Database(object):
await self.db.execute(statement)
return True
except IntegrityError:
- logger.debug(
+ self.logger.debug(
"%s, %s already in outhash database", data["method"], data["outhash"]
)
return False
+
+ async def _get_user(self, username):
+ statement = select(
+ Users.username,
+ Users.permissions,
+ Users.token,
+ ).where(
+ Users.username == username,
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.first()
+
+ async def lookup_user_token(self, username):
+ row = await self._get_user(username)
+ if not row:
+ return None, None
+ return map_user(row), row.token
+
+ async def lookup_user(self, username):
+ return map_user(await self._get_user(username))
+
+ async def set_user_token(self, username, token):
+ statement = (
+ update(Users)
+ .where(
+ Users.username == username,
+ )
+ .values(
+ token=token,
+ )
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount != 0
+
+ async def set_user_perms(self, username, permissions):
+ statement = (
+ update(Users)
+ .where(Users.username == username)
+ .values(permissions=" ".join(permissions))
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount != 0
+
+ async def get_all_users(self):
+ statement = select(
+ Users.username,
+ Users.permissions,
+ )
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return [map_user(row) for row in result]
+
+ async def new_user(self, username, permissions, token):
+ statement = insert(Users).values(
+ username=username,
+ permissions=" ".join(permissions),
+ token=token,
+ )
+ self.logger.debug("%s", statement)
+ try:
+ async with self.db.begin():
+ await self.db.execute(statement)
+ return True
+ except IntegrityError as e:
+ self.logger.debug("Cannot create new user %s: %s", username, e)
+ return False
+
+ async def delete_user(self, username):
+ statement = delete(Users).where(Users.username == username)
+ self.logger.debug("%s", statement)
+ async with self.db.begin():
+ result = await self.db.execute(statement)
+ return result.rowcount != 0
diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py
index 6809c537..414ee8ff 100644
--- a/lib/hashserv/sqlite.py
+++ b/lib/hashserv/sqlite.py
@@ -7,6 +7,7 @@
import sqlite3
import logging
from contextlib import closing
+from . import User
logger = logging.getLogger("hashserv.sqlite")
@@ -34,6 +35,14 @@ OUTHASH_TABLE_DEFINITION = (
OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
+USERS_TABLE_DEFINITION = (
+ ("username", "TEXT NOT NULL", "UNIQUE"),
+ ("token", "TEXT NOT NULL", ""),
+ ("permissions", "TEXT NOT NULL", ""),
+)
+
+USERS_TABLE_COLUMNS = tuple(name for name, _, _ in USERS_TABLE_DEFINITION)
+
def _make_table(cursor, name, definition):
cursor.execute(
@@ -53,6 +62,15 @@ def _make_table(cursor, name, definition):
)
+def map_user(row):
+ if row is None:
+ return None
+ return User(
+ username=row["username"],
+ permissions=set(row["permissions"].split()),
+ )
+
+
class DatabaseEngine(object):
def __init__(self, dbname, sync):
self.dbname = dbname
@@ -66,6 +84,7 @@ class DatabaseEngine(object):
with closing(db.cursor()) as cursor:
_make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
_make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
+ _make_table(cursor, "users", USERS_TABLE_DEFINITION)
cursor.execute("PRAGMA journal_mode = WAL")
cursor.execute(
@@ -227,6 +246,7 @@ class Database(object):
"oldest": oldest,
},
)
+ self.db.commit()
return cursor.rowcount
async def insert_unihash(self, method, taskhash, unihash):
@@ -257,3 +277,88 @@ class Database(object):
cursor.execute(query, data)
self.db.commit()
return cursor.lastrowid != prevrowid
+
+ def _get_user(self, username):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT username, permissions, token FROM users WHERE username=:username
+ """,
+ {
+ "username": username,
+ },
+ )
+ return cursor.fetchone()
+
+ async def lookup_user_token(self, username):
+ row = self._get_user(username)
+ if row is None:
+ return None, None
+ return map_user(row), row["token"]
+
+ async def lookup_user(self, username):
+ return map_user(self._get_user(username))
+
+ async def set_user_token(self, username, token):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ UPDATE users SET token=:token WHERE username=:username
+ """,
+ {
+ "username": username,
+ "token": token,
+ },
+ )
+ self.db.commit()
+ return cursor.rowcount != 0
+
+ async def set_user_perms(self, username, permissions):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ UPDATE users SET permissions=:permissions WHERE username=:username
+ """,
+ {
+ "username": username,
+ "permissions": " ".join(permissions),
+ },
+ )
+ self.db.commit()
+ return cursor.rowcount != 0
+
+ async def get_all_users(self):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute("SELECT username, permissions FROM users")
+ return [map_user(r) for r in cursor.fetchall()]
+
+ async def new_user(self, username, permissions, token):
+ with closing(self.db.cursor()) as cursor:
+ try:
+ cursor.execute(
+ """
+ INSERT INTO users (username, token, permissions) VALUES (:username, :token, :permissions)
+ """,
+ {
+ "username": username,
+ "token": token,
+ "permissions": " ".join(permissions),
+ },
+ )
+ self.db.commit()
+ return True
+ except sqlite3.IntegrityError:
+ return False
+
+ async def delete_user(self, username):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ DELETE FROM users WHERE username=:username
+ """,
+ {
+ "username": username,
+ },
+ )
+ self.db.commit()
+ return cursor.rowcount != 0
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index e9a361dc..f92f37c4 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -6,6 +6,8 @@
#
from . import create_server, create_client
+from .server import DEFAULT_ANON_PERMS, ALL_PERMISSIONS
+from bb.asyncrpc import InvokeError
import hashlib
import logging
import multiprocessing
@@ -29,8 +31,9 @@ class HashEquivalenceTestSetup(object):
METHOD = 'TestMethod'
server_index = 0
+ client_index = 0
- def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc):
+ def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc, anon_perms=DEFAULT_ANON_PERMS, admin_username=None, admin_password=None):
self.server_index += 1
if dbpath is None:
dbpath = self.make_dbpath()
@@ -45,7 +48,10 @@ class HashEquivalenceTestSetup(object):
server = create_server(self.get_server_addr(self.server_index),
dbpath,
upstream=upstream,
- read_only=read_only)
+ read_only=read_only,
+ anon_perms=anon_perms,
+ admin_username=admin_username,
+ admin_password=admin_password)
server.dbpath = dbpath
server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
@@ -56,18 +62,31 @@ class HashEquivalenceTestSetup(object):
def make_dbpath(self):
return os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
- def start_client(self, server_address):
+ def start_client(self, server_address, username=None, password=None):
def cleanup_client(client):
client.close()
- client = create_client(server_address)
+ client = create_client(server_address, username=username, password=password)
self.addCleanup(cleanup_client, client)
return client
def start_test_server(self):
- server = self.start_server()
- return server.address
+ self.server = self.start_server()
+ return self.server.address
+
+ def start_auth_server(self):
+ self.auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password")
+ self.admin_client = self.start_client(self.auth_server.address, username="admin", password="password")
+ return self.admin_client
+
+ def auth_client(self, user):
+ return self.start_client(self.auth_server.address, user["username"], user["token"])
+
+ def auth_perms(self, *permissions):
+ self.client_index += 1
+ user = self.admin_client.new_user(f"user-{self.client_index}", permissions)
+ return self.auth_client(user)
def setUp(self):
if sys.version_info < (3, 5, 0):
@@ -86,18 +105,21 @@ class HashEquivalenceTestSetup(object):
class HashEquivalenceCommonTests(object):
- def test_create_hash(self):
+ def create_test_hash(self, client):
# Simple test that hashes can be created
taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
- self.assertClientGetHash(self.client, taskhash, None)
+ self.assertClientGetHash(client, taskhash, None)
- result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+ result = client.report_unihash(taskhash, self.METHOD, outhash, unihash)
self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
return taskhash, outhash, unihash
+ def test_create_hash(self):
+ return self.create_test_hash(self.client)
+
def test_create_equivalent(self):
# Tests that a second reported task with the same outhash will be
# assigned the same unihash
@@ -471,6 +493,242 @@ class HashEquivalenceCommonTests(object):
# shares a taskhash with Task 2
self.assertClientGetHash(self.client, taskhash2, unihash2)
+ def test_auth_read_perms(self):
+ admin_client = self.start_auth_server()
+
+ # Create hashes with non-authenticated server
+ taskhash, outhash, unihash = self.test_create_hash()
+
+ # Validate hash can be retrieved using authenticated client
+ with self.auth_perms("@read") as client:
+ self.assertClientGetHash(client, taskhash, unihash)
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ self.assertClientGetHash(client, taskhash, unihash)
+
+ def test_auth_report_perms(self):
+ admin_client = self.start_auth_server()
+
+ # Without read permission, the user is completely denied
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ self.create_test_hash(client)
+
+ # Read permission allows the call to succeed, but it doesn't record
+ # anything in the database
+ with self.auth_perms("@read") as client:
+ taskhash, outhash, unihash = self.create_test_hash(client)
+ self.assertClientGetHash(client, taskhash, None)
+
+ # Report permission alone is insufficient
+ with self.auth_perms("@report") as client, self.assertRaises(InvokeError):
+ self.create_test_hash(client)
+
+ # Read and report permission actually modify the database
+ with self.auth_perms("@read", "@report") as client:
+ taskhash, outhash, unihash = self.create_test_hash(client)
+ self.assertClientGetHash(client, taskhash, unihash)
+
+ def test_auth_no_token_refresh_from_anon_user(self):
+ self.start_auth_server()
+
+ with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ client.refresh_token()
+
+ def assertUserCanAuth(self, user):
+ with self.start_client(self.auth_server.address) as client:
+ client.auth(user["username"], user["token"])
+
+ def assertUserCannotAuth(self, user):
+ with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ client.auth(user["username"], user["token"])
+
+ def test_auth_self_token_refresh(self):
+ admin_client = self.start_auth_server()
+
+ # Create a new user with no permissions
+ user = admin_client.new_user("test-user", [])
+
+ with self.auth_client(user) as client:
+ new_user = client.refresh_token()
+
+ self.assertEqual(user["username"], new_user["username"])
+ self.assertNotEqual(user["token"], new_user["token"])
+ self.assertUserCanAuth(new_user)
+ self.assertUserCannotAuth(user)
+
+ # Explicitly specifying with your own username is fine also
+ with self.auth_client(new_user) as client:
+ new_user2 = client.refresh_token(user["username"])
+
+ self.assertEqual(user["username"], new_user2["username"])
+ self.assertNotEqual(user["token"], new_user2["token"])
+ self.assertUserCanAuth(new_user2)
+ self.assertUserCannotAuth(new_user)
+ self.assertUserCannotAuth(user)
+
+ def test_auth_token_refresh(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.refresh_token(user["username"])
+
+ with self.auth_perms("@user-admin") as client:
+ new_user = client.refresh_token(user["username"])
+
+ self.assertEqual(user["username"], new_user["username"])
+ self.assertNotEqual(user["token"], new_user["token"])
+ self.assertUserCanAuth(new_user)
+ self.assertUserCannotAuth(user)
+
+ def test_auth_self_get_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+ user_info = user.copy()
+ del user_info["token"]
+
+ with self.auth_client(user) as client:
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ # Explicitly asking for your own username is fine also
+ info = client.get_user(user["username"])
+ self.assertEqual(info, user_info)
+
+ def test_auth_get_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+ user_info = user.copy()
+ del user_info["token"]
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.get_user(user["username"])
+
+ with self.auth_perms("@user-admin") as client:
+ info = client.get_user(user["username"])
+ self.assertEqual(info, user_info)
+
+ info = client.get_user("nonexist-user")
+ self.assertIsNone(info)
+
+ def test_auth_reconnect(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+ user_info = user.copy()
+ del user_info["token"]
+
+ with self.auth_client(user) as client:
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ client.disconnect()
+
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ def test_auth_delete_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+
+ # No self service
+ with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ client.delete_user(user["username"])
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.delete_user(user["username"])
+
+ with self.auth_perms("@user-admin") as client:
+ client.delete_user(user["username"])
+
+ # User doesn't exist, so even though the permission is correct, it's an
+ # error
+ with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
+ client.delete_user(user["username"])
+
+ def assertUserPerms(self, user, permissions):
+ with self.auth_client(user) as client:
+ info = client.get_user()
+ self.assertEqual(info, {
+ "username": user["username"],
+ "permissions": permissions,
+ })
+
+ def test_auth_set_user_perms(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+
+ self.assertUserPerms(user, [])
+
+ # No self service to change permissions
+ with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ client.set_user_perms(user["username"], ["@all"])
+ self.assertUserPerms(user, [])
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.set_user_perms(user["username"], ["@all"])
+ self.assertUserPerms(user, [])
+
+ with self.auth_perms("@user-admin") as client:
+ client.set_user_perms(user["username"], ["@all"])
+ self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS)))
+
+ # Bad permissions
+ with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
+ client.set_user_perms(user["username"], ["@this-is-not-a-permission"])
+ self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS)))
+
+ def test_auth_get_all_users(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", [])
+
+ with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ client.get_all_users()
+
+ # Give the test user the correct permission
+ admin_client.set_user_perms(user["username"], ["@user-admin"])
+
+ with self.auth_client(user) as client:
+ all_users = client.get_all_users()
+
+ # Convert to a dictionary for easier comparison
+ all_users = {u["username"]: u for u in all_users}
+
+ self.assertEqual(all_users,
+ {
+ "admin": {
+ "username": "admin",
+ "permissions": sorted(list(ALL_PERMISSIONS)),
+ },
+ "test-user": {
+ "username": "test-user",
+ "permissions": ["@user-admin"],
+ }
+ }
+ )
+
+ def test_auth_new_user(self):
+ self.start_auth_server()
+
+ permissions = ["@read", "@report", "@db-admin", "@user-admin"]
+ permissions.sort()
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.new_user("test-user", permissions)
+
+ with self.auth_perms("@user-admin") as client:
+ user = client.new_user("test-user", permissions)
+ self.assertIn("token", user)
+ self.assertEqual(user["username"], "test-user")
+ self.assertEqual(user["permissions"], permissions)
+
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 14/22] hashserv: Add become-user API
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (12 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 13/22] hashserv: Add user permissions Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 15/22] hashserv: Add db-usage API Joshua Watt
` (8 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an API that allows a user admin to impersonate another user in the
system. This makes it easier to write external services that have
external authentication, since they can use a common user account to
access the server, then impersonate the logged-in user.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 3 +++
lib/hashserv/client.py | 42 +++++++++++++++++++++++++++++++++++++-----
lib/hashserv/server.py | 18 ++++++++++++++++++
lib/hashserv/tests.py | 39 +++++++++++++++++++++++++++++++++++++++
4 files changed, 97 insertions(+), 5 deletions(-)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index 328c15cd..cfbc197e 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -166,6 +166,7 @@ def main():
parser.add_argument('--log', default='WARNING', help='Set logging level')
parser.add_argument('--login', '-l', metavar="USERNAME", help="Authenticate as USERNAME")
parser.add_argument('--password', '-p', metavar="TOKEN", help="Authenticate using token TOKEN")
+ parser.add_argument('--become', '-b', metavar="USERNAME", help="Impersonate user USERNAME (if allowed) when performing actions")
parser.add_argument('--no-netrc', '-n', action="store_false", dest="netrc", help="Do not use .netrc")
subparsers = parser.add_subparsers()
@@ -251,6 +252,8 @@ def main():
if func:
try:
with hashserv.create_client(args.address, login, password) as client:
+ if args.become:
+ client.become_user(args.become)
return func(args, client)
except bb.asyncrpc.InvokeError as e:
print(f"ERROR: {e}")
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index 82400fe5..4457f8e5 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -18,10 +18,11 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
MODE_GET_STREAM = 1
def __init__(self, username=None, password=None):
- super().__init__('OEHASHEQUIV', '1.1', logger)
+ super().__init__("OEHASHEQUIV", "1.1", logger)
self.mode = self.MODE_NORMAL
self.username = username
self.password = password
+ self.saved_become_user = None
async def setup_connection(self):
await super().setup_connection()
@@ -29,8 +30,13 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
self.mode = self.MODE_NORMAL
await self._set_mode(cur_mode)
if self.username:
+ # Save off become user temporarily because auth() resets it
+ become = self.saved_become_user
await self.auth(self.username, self.password)
+ if become:
+ await self.become_user(become)
+
async def send_stream(self, msg):
async def proc():
await self.socket.send(msg)
@@ -92,7 +98,14 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
await self._set_mode(self.MODE_NORMAL)
return await self.invoke(
- {"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method, "with_unihash": with_unihash}}
+ {
+ "get-outhash": {
+ "outhash": outhash,
+ "taskhash": taskhash,
+ "method": method,
+ "with_unihash": with_unihash,
+ }
+ }
)
async def get_stats(self):
@@ -120,6 +133,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
result = await self.invoke({"auth": {"username": username, "token": token}})
self.username = username
self.password = token
+ self.saved_become_user = None
return result
async def refresh_token(self, username=None):
@@ -128,13 +142,19 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
if username:
m["username"] = username
result = await self.invoke({"refresh-token": m})
- if self.username and result["username"] == self.username:
+ if (
+ self.username
+ and not self.saved_become_user
+ and result["username"] == self.username
+ ):
self.password = result["token"]
return result
async def set_user_perms(self, username, permissions):
await self._set_mode(self.MODE_NORMAL)
- return await self.invoke({"set-user-perms": {"username": username, "permissions": permissions}})
+ return await self.invoke(
+ {"set-user-perms": {"username": username, "permissions": permissions}}
+ )
async def get_user(self, username=None):
await self._set_mode(self.MODE_NORMAL)
@@ -149,12 +169,23 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
async def new_user(self, username, permissions):
await self._set_mode(self.MODE_NORMAL)
- return await self.invoke({"new-user": {"username": username, "permissions": permissions}})
+ return await self.invoke(
+ {"new-user": {"username": username, "permissions": permissions}}
+ )
async def delete_user(self, username):
await self._set_mode(self.MODE_NORMAL)
return await self.invoke({"delete-user": {"username": username}})
+ async def become_user(self, username):
+ await self._set_mode(self.MODE_NORMAL)
+ result = await self.invoke({"become-user": {"username": username}})
+ if username == self.username:
+ self.saved_become_user = None
+ else:
+ self.saved_become_user = username
+ return result
+
class Client(bb.asyncrpc.Client):
def __init__(self, username=None, password=None):
@@ -182,6 +213,7 @@ class Client(bb.asyncrpc.Client):
"get_all_users",
"new_user",
"delete_user",
+ "become_user",
)
def _get_async_client(self):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index f5baa6be..ca419a1a 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -255,6 +255,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"auth": self.handle_auth,
"get-user": self.handle_get_user,
"get-all-users": self.handle_get_all_users,
+ "become-user": self.handle_become_user,
}
)
@@ -707,6 +708,23 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return {"username": username}
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_become_user(self, request):
+ username = str(request["username"])
+
+ user = await self.db.lookup_user(username)
+ if user is None:
+ raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist")
+
+ self.user = user
+
+ self.logger.info("Became user %s", username)
+
+ return {
+ "username": self.user.username,
+ "permissions": self.return_perms(self.user.permissions),
+ }
+
class Server(bb.asyncrpc.AsyncServer):
def __init__(
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index f92f37c4..311b7b77 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -728,6 +728,45 @@ class HashEquivalenceCommonTests(object):
self.assertEqual(user["username"], "test-user")
self.assertEqual(user["permissions"], permissions)
+ def test_auth_become_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read", "@report"])
+ user_info = user.copy()
+ del user_info["token"]
+
+ with self.auth_perms() as client, self.assertRaises(InvokeError):
+ client.become_user(user["username"])
+
+ with self.auth_perms("@user-admin") as client:
+ become = client.become_user(user["username"])
+ self.assertEqual(become, user_info)
+
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ # Verify become user is preserved across disconnect
+ client.disconnect()
+
+ info = client.get_user()
+ self.assertEqual(info, user_info)
+
+ # test-user doesn't have become_user permissions, so this should
+ # not work
+ with self.assertRaises(InvokeError):
+ client.become_user(user["username"])
+
+ # No self-service of become
+ with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ client.become_user(user["username"])
+
+ # Give test user permissions to become
+ admin_client.set_user_perms(user["username"], ["@user-admin"])
+
+ # It's possible to become yourself (effectively a noop)
+ with self.auth_perms("@user-admin") as client:
+ become = client.become_user(client.username)
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 15/22] hashserv: Add db-usage API
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (13 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 14/22] hashserv: Add become-user API Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 16/22] hashserv: Add database column query API Joshua Watt
` (7 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an API to query the server for the usage of the database (e.g. how
many rows are present in each table).
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 16 ++++++++++++++++
lib/hashserv/client.py | 5 +++++
lib/hashserv/server.py | 5 +++++
lib/hashserv/sqlalchemy.py | 14 ++++++++++++++
lib/hashserv/sqlite.py | 37 +++++++++++++++++++++++++++++++++++++
lib/hashserv/tests.py | 9 +++++++++
6 files changed, 86 insertions(+)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index cfbc197e..5d65c7bc 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -161,6 +161,19 @@ def main():
r = client.delete_user(args.username)
print_user(r)
+ def handle_get_db_usage(args, client):
+ usage = client.get_db_usage()
+ print(usage)
+ tables = sorted(usage.keys())
+ print("{name:20}| {rows:20}".format(name="Table name", rows="Rows"))
+ print(("-" * 20) + "+" + ("-" * 20))
+ for t in tables:
+ print("{name:20}| {rows:<20}".format(name=t, rows=usage[t]["rows"]))
+ print()
+
+ total_rows = sum(t["rows"] for t in usage.values())
+ print(f"Total rows: {total_rows}")
+
parser = argparse.ArgumentParser(description='Hash Equivalence Client')
parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
parser.add_argument('--log', default='WARNING', help='Set logging level')
@@ -223,6 +236,9 @@ def main():
delete_user_parser.add_argument("--username", "-u", help="Username", required=True)
delete_user_parser.set_defaults(func=handle_delete_user)
+ db_usage_parser = subparsers.add_parser('get-db-usage', help="Database Usage")
+ db_usage_parser.set_defaults(func=handle_get_db_usage)
+
args = parser.parse_args()
logger = logging.getLogger('hashserv')
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index 4457f8e5..5e0a462b 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -186,6 +186,10 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
self.saved_become_user = username
return result
+ async def get_db_usage(self):
+ await self._set_mode(self.MODE_NORMAL)
+ return (await self.invoke({"get-db-usage": {}}))["usage"]
+
class Client(bb.asyncrpc.Client):
def __init__(self, username=None, password=None):
@@ -214,6 +218,7 @@ class Client(bb.asyncrpc.Client):
"new_user",
"delete_user",
"become_user",
+ "get_db_usage",
)
def _get_async_client(self):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index ca419a1a..c5b9797e 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -249,6 +249,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"get-outhash": self.handle_get_outhash,
"get-stream": self.handle_get_stream,
"get-stats": self.handle_get_stats,
+ "get-db-usage": self.handle_get_db_usage,
# Not always read-only, but internally checks if the server is
# read-only
"report": self.handle_report,
@@ -567,6 +568,10 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
oldest = datetime.now() - timedelta(seconds=-max_age)
return {"count": await self.db.clean_unused(oldest)}
+ @permissions(DB_ADMIN_PERM)
+ async def handle_get_db_usage(self, request):
+ return {"usage": await self.db.get_usage()}
+
# The authentication API is always allowed
async def handle_auth(self, request):
username = str(request["username"])
diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py
index bfd8a844..818b5195 100644
--- a/lib/hashserv/sqlalchemy.py
+++ b/lib/hashserv/sqlalchemy.py
@@ -27,6 +27,7 @@ from sqlalchemy import (
and_,
delete,
update,
+ func,
)
import sqlalchemy.engine
from sqlalchemy.orm import declarative_base
@@ -401,3 +402,16 @@ class Database(object):
async with self.db.begin():
result = await self.db.execute(statement)
return result.rowcount != 0
+
+ async def get_usage(self):
+ usage = {}
+ async with self.db.begin() as session:
+ for name, table in Base.metadata.tables.items():
+ statement = select(func.count()).select_from(table)
+ self.logger.debug("%s", statement)
+ result = await self.db.execute(statement)
+ usage[name] = {
+ "rows": result.scalar(),
+ }
+
+ return usage
diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py
index 414ee8ff..dfdccbba 100644
--- a/lib/hashserv/sqlite.py
+++ b/lib/hashserv/sqlite.py
@@ -120,6 +120,18 @@ class Database(object):
self.db = sqlite3.connect(self.dbname)
self.db.row_factory = sqlite3.Row
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute("SELECT sqlite_version()")
+
+ version = []
+ for v in cursor.fetchone()[0].split("."):
+ try:
+ version.append(int(v))
+ except ValueError:
+ version.append(v)
+
+ self.sqlite_version = tuple(version)
+
async def __aenter__(self):
return self
@@ -362,3 +374,28 @@ class Database(object):
)
self.db.commit()
return cursor.rowcount != 0
+
+ async def get_usage(self):
+ usage = {}
+ with closing(self.db.cursor()) as cursor:
+ if self.sqlite_version >= (3, 33):
+ table_name = "sqlite_schema"
+ else:
+ table_name = "sqlite_master"
+
+ cursor.execute(
+ f"""
+ SELECT name FROM {table_name} WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
+ """
+ )
+ for row in cursor.fetchall():
+ cursor.execute(
+ """
+ SELECT COUNT() FROM %s
+ """
+ % row["name"],
+ )
+ usage[row["name"]] = {
+ "rows": cursor.fetchone()[0],
+ }
+ return usage
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 311b7b77..9d5bec24 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -767,6 +767,15 @@ class HashEquivalenceCommonTests(object):
with self.auth_perms("@user-admin") as client:
become = client.become_user(client.username)
+ def test_get_db_usage(self):
+ usage = self.client.get_db_usage()
+
+ self.assertTrue(isinstance(usage, dict))
+ for name in usage.keys():
+ self.assertTrue(isinstance(usage[name], dict))
+ self.assertIn("rows", usage[name])
+ self.assertTrue(isinstance(usage[name]["rows"], int))
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 16/22] hashserv: Add database column query API
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (14 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 15/22] hashserv: Add db-usage API Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 17/22] hashserv: test: Add bitbake-hashclient tests Joshua Watt
` (6 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Adds an API to retrieve the columns that can be queried from the
database backend. This prevents front-end applications from needing to
hardcode the query columns.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 7 +++++++
lib/hashserv/client.py | 5 +++++
lib/hashserv/server.py | 5 +++++
lib/hashserv/sqlalchemy.py | 10 ++++++++++
lib/hashserv/sqlite.py | 7 +++++++
lib/hashserv/tests.py | 8 ++++++++
6 files changed, 42 insertions(+)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index 5d65c7bc..58aa02ee 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -174,6 +174,10 @@ def main():
total_rows = sum(t["rows"] for t in usage.values())
print(f"Total rows: {total_rows}")
+ def handle_get_db_query_columns(args, client):
+ columns = client.get_db_query_columns()
+ print("\n".join(sorted(columns)))
+
parser = argparse.ArgumentParser(description='Hash Equivalence Client')
parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
parser.add_argument('--log', default='WARNING', help='Set logging level')
@@ -239,6 +243,9 @@ def main():
db_usage_parser = subparsers.add_parser('get-db-usage', help="Database Usage")
db_usage_parser.set_defaults(func=handle_get_db_usage)
+ db_query_columns_parser = subparsers.add_parser('get-db-query-columns', help="Show columns that can be used in database queries")
+ db_query_columns_parser.set_defaults(func=handle_get_db_query_columns)
+
args = parser.parse_args()
logger = logging.getLogger('hashserv')
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index 5e0a462b..35a97687 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -190,6 +190,10 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
await self._set_mode(self.MODE_NORMAL)
return (await self.invoke({"get-db-usage": {}}))["usage"]
+ async def get_db_query_columns(self):
+ await self._set_mode(self.MODE_NORMAL)
+ return (await self.invoke({"get-db-query-columns": {}}))["columns"]
+
class Client(bb.asyncrpc.Client):
def __init__(self, username=None, password=None):
@@ -219,6 +223,7 @@ class Client(bb.asyncrpc.Client):
"delete_user",
"become_user",
"get_db_usage",
+ "get_db_query_columns",
)
def _get_async_client(self):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index c5b9797e..8c3d20b6 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -250,6 +250,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"get-stream": self.handle_get_stream,
"get-stats": self.handle_get_stats,
"get-db-usage": self.handle_get_db_usage,
+ "get-db-query-columns": self.handle_get_db_query_columns,
# Not always read-only, but internally checks if the server is
# read-only
"report": self.handle_report,
@@ -572,6 +573,10 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
async def handle_get_db_usage(self, request):
return {"usage": await self.db.get_usage()}
+ @permissions(DB_ADMIN_PERM)
+ async def handle_get_db_query_columns(self, request):
+ return {"columns": await self.db.get_query_columns()}
+
# The authentication API is always allowed
async def handle_auth(self, request):
username = str(request["username"])
diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py
index 818b5195..cee04bff 100644
--- a/lib/hashserv/sqlalchemy.py
+++ b/lib/hashserv/sqlalchemy.py
@@ -415,3 +415,13 @@ class Database(object):
}
return usage
+
+ async def get_query_columns(self):
+ columns = set()
+ for table in (UnihashesV2, OuthashesV2):
+ for c in table.__table__.columns:
+ if not isinstance(c.type, Text):
+ continue
+ columns.add(c.key)
+
+ return list(columns)
diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py
index dfdccbba..f65036be 100644
--- a/lib/hashserv/sqlite.py
+++ b/lib/hashserv/sqlite.py
@@ -399,3 +399,10 @@ class Database(object):
"rows": cursor.fetchone()[0],
}
return usage
+
+ async def get_query_columns(self):
+ columns = set()
+ for name, typ, _ in UNIHASH_TABLE_DEFINITION + OUTHASH_TABLE_DEFINITION:
+ if typ.startswith("TEXT"):
+ columns.add(name)
+ return list(columns)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 9d5bec24..fc69acaf 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -776,6 +776,14 @@ class HashEquivalenceCommonTests(object):
self.assertIn("rows", usage[name])
self.assertTrue(isinstance(usage[name]["rows"], int))
+ def test_get_db_query_columns(self):
+ columns = self.client.get_db_query_columns()
+
+ self.assertTrue(isinstance(columns, list))
+ self.assertTrue(len(columns) > 0)
+
+ for col in columns:
+ self.client.remove({col: ""})
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 17/22] hashserv: test: Add bitbake-hashclient tests
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (15 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 16/22] hashserv: Add database column query API Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 18/22] bitbake-hashclient: Output stats in JSON format Joshua Watt
` (5 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
The bitbake-hashclient command-line tool now has a lot more features
which should be tested, so add some tests for them.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/tests.py | 300 ++++++++++++++++++++++++++++++++++++++----
1 file changed, 277 insertions(+), 23 deletions(-)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index fc69acaf..a80ccd57 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -19,6 +19,14 @@ import unittest
import socket
import time
import signal
+import subprocess
+import json
+import re
+from pathlib import Path
+
+
+THIS_DIR = Path(__file__).parent
+BIN_DIR = THIS_DIR.parent.parent / "bin"
def server_prefunc(server, idx):
logging.basicConfig(level=logging.DEBUG, filename='bbhashserv-%d.log' % idx, filemode='w',
@@ -103,8 +111,22 @@ class HashEquivalenceTestSetup(object):
result = client.get_unihash(self.METHOD, taskhash)
self.assertEqual(result, unihash)
+ def assertUserPerms(self, user, permissions):
+ with self.auth_client(user) as client:
+ info = client.get_user()
+ self.assertEqual(info, {
+ "username": user["username"],
+ "permissions": permissions,
+ })
+
+ def assertUserCanAuth(self, user):
+ with self.start_client(self.auth_server.address) as client:
+ client.auth(user["username"], user["token"])
+
+ def assertUserCannotAuth(self, user):
+ with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ client.auth(user["username"], user["token"])
-class HashEquivalenceCommonTests(object):
def create_test_hash(self, client):
# Simple test that hashes can be created
taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
@@ -117,6 +139,24 @@ class HashEquivalenceCommonTests(object):
self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
return taskhash, outhash, unihash
+ def run_hashclient(self, args, **kwargs):
+ try:
+ p = subprocess.run(
+ [BIN_DIR / "bitbake-hashclient"] + args,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ encoding="utf-8",
+ **kwargs
+ )
+ except subprocess.CalledProcessError as e:
+ print(e.output)
+ raise e
+
+ print(p.stdout)
+ return p
+
+
+class HashEquivalenceCommonTests(object):
def test_create_hash(self):
return self.create_test_hash(self.client)
@@ -161,7 +201,7 @@ class HashEquivalenceCommonTests(object):
self.assertClientGetHash(self.client, taskhash, unihash)
def test_remove_taskhash(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
result = self.client.remove({"taskhash": taskhash})
self.assertGreater(result["count"], 0)
self.assertClientGetHash(self.client, taskhash, None)
@@ -170,13 +210,13 @@ class HashEquivalenceCommonTests(object):
self.assertIsNone(result_outhash)
def test_remove_unihash(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
result = self.client.remove({"unihash": unihash})
self.assertGreater(result["count"], 0)
self.assertClientGetHash(self.client, taskhash, None)
def test_remove_outhash(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
result = self.client.remove({"outhash": outhash})
self.assertGreater(result["count"], 0)
@@ -184,7 +224,7 @@ class HashEquivalenceCommonTests(object):
self.assertIsNone(result_outhash)
def test_remove_method(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
result = self.client.remove({"method": self.METHOD})
self.assertGreater(result["count"], 0)
self.assertClientGetHash(self.client, taskhash, None)
@@ -193,7 +233,7 @@ class HashEquivalenceCommonTests(object):
self.assertIsNone(result_outhash)
def test_clean_unused(self):
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
# Clean the database, which should not remove anything because all hashes an in-use
result = self.client.clean_unused(0)
@@ -497,7 +537,7 @@ class HashEquivalenceCommonTests(object):
admin_client = self.start_auth_server()
# Create hashes with non-authenticated server
- taskhash, outhash, unihash = self.test_create_hash()
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
# Validate hash can be retrieved using authenticated client
with self.auth_perms("@read") as client:
@@ -534,14 +574,6 @@ class HashEquivalenceCommonTests(object):
with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
client.refresh_token()
- def assertUserCanAuth(self, user):
- with self.start_client(self.auth_server.address) as client:
- client.auth(user["username"], user["token"])
-
- def assertUserCannotAuth(self, user):
- with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
- client.auth(user["username"], user["token"])
-
def test_auth_self_token_refresh(self):
admin_client = self.start_auth_server()
@@ -650,14 +682,6 @@ class HashEquivalenceCommonTests(object):
with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
client.delete_user(user["username"])
- def assertUserPerms(self, user, permissions):
- with self.auth_client(user) as client:
- info = client.get_user()
- self.assertEqual(info, {
- "username": user["username"],
- "permissions": permissions,
- })
-
def test_auth_set_user_perms(self):
admin_client = self.start_auth_server()
@@ -785,6 +809,236 @@ class HashEquivalenceCommonTests(object):
for col in columns:
self.client.remove({col: ""})
+
+class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
+ def get_server_addr(self, server_idx):
+ return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
+
+ def test_stats(self):
+ self.run_hashclient(["--address", self.server_address, "stats"], check=True)
+
+ def test_stress(self):
+ self.run_hashclient(["--address", self.server_address, "stress"], check=True)
+
+ def test_remove_taskhash(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "taskhash", taskhash,
+ ], check=True)
+ self.assertClientGetHash(self.client, taskhash, None)
+
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
+ self.assertIsNone(result_outhash)
+
+ def test_remove_unihash(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "unihash", unihash,
+ ], check=True)
+ self.assertClientGetHash(self.client, taskhash, None)
+
+ def test_remove_outhash(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "outhash", outhash,
+ ], check=True)
+
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
+ self.assertIsNone(result_outhash)
+
+ def test_remove_method(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "method", self.METHOD,
+ ], check=True)
+ self.assertClientGetHash(self.client, taskhash, None)
+
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
+ self.assertIsNone(result_outhash)
+
+ def test_clean_unused(self):
+ taskhash, outhash, unihash = self.create_test_hash(self.client)
+
+ # Clean the database, which should not remove anything because all hashes are in-use
+ self.run_hashclient([
+ "--address", self.server_address,
+ "clean-unused", "0",
+ ], check=True)
+ self.assertClientGetHash(self.client, taskhash, unihash)
+
+ # Remove the unihash. The row in the outhash table should still be present
+ self.run_hashclient([
+ "--address", self.server_address,
+ "remove",
+ "--where", "unihash", unihash,
+ ], check=True)
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
+ self.assertIsNotNone(result_outhash)
+
+ # Now clean with no minimum age which will remove the outhash
+ self.run_hashclient([
+ "--address", self.server_address,
+ "clean-unused", "0",
+ ], check=True)
+ result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
+ self.assertIsNone(result_outhash)
+
+ def test_refresh_token(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read", "@report"])
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", user["username"],
+ "--password", user["token"],
+ "refresh-token"
+ ], check=True)
+
+ new_token = None
+ for l in p.stdout.splitlines():
+ l = l.rstrip()
+ m = re.match(r'Token: +(.*)$', l)
+ if m is not None:
+ new_token = m.group(1)
+
+ self.assertTrue(new_token)
+
+ print("New token is %r" % new_token)
+
+ self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", user["username"],
+ "--password", new_token,
+ "get-user"
+ ], check=True)
+
+ def test_set_user_perms(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read"])
+
+ self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "set-user-perms",
+ "-u", user["username"],
+ "@read", "@report",
+ ], check=True)
+
+ new_user = admin_client.get_user(user["username"])
+
+ self.assertEqual(set(new_user["permissions"]), {"@read", "@report"})
+
+ def test_get_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read"])
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "get-user",
+ "-u", user["username"],
+ ], check=True)
+
+ self.assertIn("Username:", p.stdout)
+ self.assertIn("Permissions:", p.stdout)
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", user["username"],
+ "--password", user["token"],
+ "get-user",
+ ], check=True)
+
+ self.assertIn("Username:", p.stdout)
+ self.assertIn("Permissions:", p.stdout)
+
+ def test_get_all_users(self):
+ admin_client = self.start_auth_server()
+
+ admin_client.new_user("test-user1", ["@read"])
+ admin_client.new_user("test-user2", ["@read"])
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "get-all-users",
+ ], check=True)
+
+ self.assertIn("admin", p.stdout)
+ self.assertIn("test-user1", p.stdout)
+ self.assertIn("test-user2", p.stdout)
+
+ def test_new_user(self):
+ admin_client = self.start_auth_server()
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "new-user",
+ "-u", "test-user",
+ "@read", "@report",
+ ], check=True)
+
+ new_token = None
+ for l in p.stdout.splitlines():
+ l = l.rstrip()
+ m = re.match(r'Token: +(.*)$', l)
+ if m is not None:
+ new_token = m.group(1)
+
+ self.assertTrue(new_token)
+
+ user = {
+ "username": "test-user",
+ "token": new_token,
+ }
+
+ self.assertUserPerms(user, ["@read", "@report"])
+
+ def test_delete_user(self):
+ admin_client = self.start_auth_server()
+
+ user = admin_client.new_user("test-user", ["@read"])
+
+ p = self.run_hashclient([
+ "--address", self.auth_server.address,
+ "--login", admin_client.username,
+ "--password", admin_client.password,
+ "delete-user",
+ "-u", user["username"],
+ ], check=True)
+
+
+ self.assertIsNone(admin_client.get_user(user["username"]))
+
+ def test_get_db_usage(self):
+ p = self.run_hashclient([
+ "--address", self.server_address,
+ "get-db-usage",
+ ], check=True)
+
+ def test_get_db_query_columns(self):
+ p = self.run_hashclient([
+ "--address", self.server_address,
+ "get-db-query-columns",
+ ], check=True)
+
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 18/22] bitbake-hashclient: Output stats in JSON format
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (16 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 17/22] hashserv: test: Add bitbake-hashclient tests Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 19/22] bitbake-hashserver: Allow anonymous permissions to be space separated Joshua Watt
` (4 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Outputting the stats in JSON format makes more sense as it's easier for
a downstream tool to parse if desired.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashclient | 3 ++-
lib/hashserv/tests.py | 3 ++-
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/bin/bitbake-hashclient b/bin/bitbake-hashclient
index 58aa02ee..3ff7b763 100755
--- a/bin/bitbake-hashclient
+++ b/bin/bitbake-hashclient
@@ -15,6 +15,7 @@ import threading
import time
import warnings
import netrc
+import json
warnings.simplefilter("default")
try:
@@ -56,7 +57,7 @@ def main():
s = client.reset_stats()
else:
s = client.get_stats()
- pprint.pprint(s)
+ print(json.dumps(s, sort_keys=True, indent=4))
return 0
def handle_stress(args, client):
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index a80ccd57..2d78f9e9 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -815,7 +815,8 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
def test_stats(self):
- self.run_hashclient(["--address", self.server_address, "stats"], check=True)
+ p = self.run_hashclient(["--address", self.server_address, "stats"], check=True)
+ json.loads(p.stdout)
def test_stress(self):
self.run_hashclient(["--address", self.server_address, "stress"], check=True)
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 19/22] bitbake-hashserver: Allow anonymous permissions to be space separated
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (17 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 18/22] bitbake-hashclient: Output stats in JSON format Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 20/22] hashserv: tests: Allow authentication for external server tests Joshua Watt
` (3 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Space separation is more natural when setting the value from an
environment variable, so allow that here for convenience.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
bin/bitbake-hashserv | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/bin/bitbake-hashserv b/bin/bitbake-hashserv
index 1085d058..c560b3e5 100755
--- a/bin/bitbake-hashserv
+++ b/bin/bitbake-hashserv
@@ -127,7 +127,10 @@ websocket, as in "wss://SERVER:PORT"
logger.addHandler(console)
read_only = (os.environ.get("HASHSERVER_READ_ONLY", "0") == "1") or args.read_only
- anon_perms = args.anon_perms.split(",")
+ if "," in args.anon_perms:
+ anon_perms = args.anon_perms.split(",")
+ else:
+ anon_perms = args.anon_perms.split()
server = hashserv.create_server(
args.bind,
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread
* [bitbake-devel][PATCH v6 20/22] hashserv: tests: Allow authentication for external server tests
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (18 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 19/22] bitbake-hashserver: Allow anonymous permissions to be space separated Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 21/22] hashserv: Allow self-service deletion Joshua Watt
` (2 subsequent siblings)
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
If BB_TEST_HASHSERV_USERNAME and BB_TEST_HASHSERV_PASSWORD are provided
for a server admin user, the authentication tests for the external
hashserver will run. In addition, any users that get created will now be
deleted when the test finishes.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/tests.py | 109 ++++++++++++++++++++++++++++--------------
1 file changed, 74 insertions(+), 35 deletions(-)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 2d78f9e9..5d209ffb 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -84,17 +84,13 @@ class HashEquivalenceTestSetup(object):
return self.server.address
def start_auth_server(self):
- self.auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password")
- self.admin_client = self.start_client(self.auth_server.address, username="admin", password="password")
+ auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password")
+ self.auth_server_address = auth_server.address
+ self.admin_client = self.start_client(auth_server.address, username="admin", password="password")
return self.admin_client
def auth_client(self, user):
- return self.start_client(self.auth_server.address, user["username"], user["token"])
-
- def auth_perms(self, *permissions):
- self.client_index += 1
- user = self.admin_client.new_user(f"user-{self.client_index}", permissions)
- return self.auth_client(user)
+ return self.start_client(self.auth_server_address, user["username"], user["token"])
def setUp(self):
if sys.version_info < (3, 5, 0):
@@ -120,11 +116,11 @@ class HashEquivalenceTestSetup(object):
})
def assertUserCanAuth(self, user):
- with self.start_client(self.auth_server.address) as client:
+ with self.start_client(self.auth_server_address) as client:
client.auth(user["username"], user["token"])
def assertUserCannotAuth(self, user):
- with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError):
client.auth(user["username"], user["token"])
def create_test_hash(self, client):
@@ -157,6 +153,26 @@ class HashEquivalenceTestSetup(object):
class HashEquivalenceCommonTests(object):
+ def auth_perms(self, *permissions):
+ self.client_index += 1
+ user = self.create_user(f"user-{self.client_index}", permissions)
+ return self.auth_client(user)
+
+ def create_user(self, username, permissions, *, client=None):
+ def remove_user(username):
+ try:
+ self.admin_client.delete_user(username)
+ except bb.asyncrpc.InvokeError:
+ pass
+
+ if client is None:
+ client = self.admin_client
+
+ user = client.new_user(username, permissions)
+ self.addCleanup(remove_user, username)
+
+ return user
+
def test_create_hash(self):
return self.create_test_hash(self.client)
@@ -571,14 +587,14 @@ class HashEquivalenceCommonTests(object):
def test_auth_no_token_refresh_from_anon_user(self):
self.start_auth_server()
- with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError):
+ with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError):
client.refresh_token()
def test_auth_self_token_refresh(self):
admin_client = self.start_auth_server()
# Create a new user with no permissions
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
with self.auth_client(user) as client:
new_user = client.refresh_token()
@@ -601,7 +617,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_token_refresh(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
with self.auth_perms() as client, self.assertRaises(InvokeError):
client.refresh_token(user["username"])
@@ -617,7 +633,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_self_get_user(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
user_info = user.copy()
del user_info["token"]
@@ -632,7 +648,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_get_user(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
user_info = user.copy()
del user_info["token"]
@@ -649,7 +665,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_reconnect(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
user_info = user.copy()
del user_info["token"]
@@ -665,7 +681,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_delete_user(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
# No self service
with self.auth_client(user) as client, self.assertRaises(InvokeError):
@@ -685,7 +701,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_set_user_perms(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
self.assertUserPerms(user, [])
@@ -710,7 +726,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_get_all_users(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", [])
+ user = self.create_user("test-user", [])
with self.auth_client(user) as client, self.assertRaises(InvokeError):
client.get_all_users()
@@ -744,10 +760,10 @@ class HashEquivalenceCommonTests(object):
permissions.sort()
with self.auth_perms() as client, self.assertRaises(InvokeError):
- client.new_user("test-user", permissions)
+ self.create_user("test-user", permissions, client=client)
with self.auth_perms("@user-admin") as client:
- user = client.new_user("test-user", permissions)
+ user = self.create_user("test-user", permissions, client=client)
self.assertIn("token", user)
self.assertEqual(user["username"], "test-user")
self.assertEqual(user["permissions"], permissions)
@@ -755,7 +771,7 @@ class HashEquivalenceCommonTests(object):
def test_auth_become_user(self):
admin_client = self.start_auth_server()
- user = admin_client.new_user("test-user", ["@read", "@report"])
+ user = self.create_user("test-user", ["@read", "@report"])
user_info = user.copy()
del user_info["token"]
@@ -898,7 +914,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
user = admin_client.new_user("test-user", ["@read", "@report"])
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", user["username"],
"--password", user["token"],
"refresh-token"
@@ -916,7 +932,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
print("New token is %r" % new_token)
self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", user["username"],
"--password", new_token,
"get-user"
@@ -928,7 +944,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
user = admin_client.new_user("test-user", ["@read"])
self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"set-user-perms",
@@ -946,7 +962,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
user = admin_client.new_user("test-user", ["@read"])
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"get-user",
@@ -957,7 +973,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
self.assertIn("Permissions:", p.stdout)
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", user["username"],
"--password", user["token"],
"get-user",
@@ -973,7 +989,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
admin_client.new_user("test-user2", ["@read"])
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"get-all-users",
@@ -987,7 +1003,7 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
admin_client = self.start_auth_server()
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"new-user",
@@ -1017,14 +1033,13 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
user = admin_client.new_user("test-user", ["@read"])
p = self.run_hashclient([
- "--address", self.auth_server.address,
+ "--address", self.auth_server_address,
"--login", admin_client.username,
"--password", admin_client.password,
"delete-user",
"-u", user["username"],
], check=True)
-
self.assertIsNone(admin_client.get_user(user["username"]))
def test_get_db_usage(self):
@@ -1104,19 +1119,43 @@ class TestHashEquivalenceWebsocketsSQLAlchemyServer(TestHashEquivalenceWebsocket
class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
- def start_test_server(self):
- if 'BB_TEST_HASHSERV' not in os.environ:
- self.skipTest('BB_TEST_HASHSERV not defined to test an external server')
+ def get_env(self, name):
+ v = os.environ.get(name)
+ if not v:
+ self.skipTest(f'{name} not defined to test an external server')
+ return v
- return os.environ['BB_TEST_HASHSERV']
+ def start_test_server(self):
+ return self.get_env('BB_TEST_HASHSERV')
def start_server(self, *args, **kwargs):
self.skipTest('Cannot start local server when testing external servers')
+ def start_auth_server(self):
+
+ self.auth_server_address = self.server_address
+ self.admin_client = self.start_client(
+ self.server_address,
+ username=self.get_env('BB_TEST_HASHSERV_USERNAME'),
+ password=self.get_env('BB_TEST_HASHSERV_PASSWORD'),
+ )
+ return self.admin_client
+
def setUp(self):
super().setUp()
+ if "BB_TEST_HASHSERV_USERNAME" in os.environ:
+ self.client = self.start_client(
+ self.server_address,
+ username=os.environ["BB_TEST_HASHSERV_USERNAME"],
+ password=os.environ["BB_TEST_HASHSERV_PASSWORD"],
+ )
self.client.remove({"method": self.METHOD})
def tearDown(self):
self.client.remove({"method": self.METHOD})
super().tearDown()
+
+
+ def test_auth_get_all_users(self):
+ self.skipTest("Cannot test all users with external server")
+
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v6 21/22] hashserv: Allow self-service deletion
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (19 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 20/22] hashserv: tests: Allow authentication for external server tests Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 22/22] hashserv: server: Add owner if user is logged in Joshua Watt
2023-11-09 10:23 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Alexandre Belloni
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
Allows users to self-service deletion of their own user accounts
(meaning, they can delete their own accounts without special
permissions).
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/server.py | 2 +-
lib/hashserv/tests.py | 7 +++++--
2 files changed, 6 insertions(+), 3 deletions(-)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 8c3d20b6..439962f7 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -709,7 +709,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"token": token,
}
- @permissions(USER_ADMIN_PERM, allow_anon=False)
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
async def handle_delete_user(self, request):
username = str(request["username"])
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index 5d209ffb..f0be8679 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -683,10 +683,13 @@ class HashEquivalenceCommonTests(object):
user = self.create_user("test-user", [])
- # No self service
- with self.auth_client(user) as client, self.assertRaises(InvokeError):
+ # self service
+ with self.auth_client(user) as client:
client.delete_user(user["username"])
+ self.assertIsNone(admin_client.get_user(user["username"]))
+ user = self.create_user("test-user", [])
+
with self.auth_perms() as client, self.assertRaises(InvokeError):
client.delete_user(user["username"])
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* [bitbake-devel][PATCH v6 22/22] hashserv: server: Add owner if user is logged in
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (20 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 21/22] hashserv: Allow self-service deletion Joshua Watt
@ 2023-11-03 14:26 ` Joshua Watt
2023-11-09 10:23 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Alexandre Belloni
22 siblings, 0 replies; 138+ messages in thread
From: Joshua Watt @ 2023-11-03 14:26 UTC (permalink / raw)
To: bitbake-devel; +Cc: Joshua Watt
If a user is authenticated with the server, report them as the owner of
a report.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
lib/hashserv/server.py | 3 +++
lib/hashserv/tests.py | 9 +++++++++
2 files changed, 12 insertions(+)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 439962f7..a8650783 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -475,6 +475,9 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if k in data:
outhash_data[k] = data[k]
+ if self.user:
+ outhash_data["owner"] = self.user.username
+
# Insert the new entry, unless it already exists
if await self.db.insert_outhash(outhash_data):
# If this row is new, check if it is equivalent to another
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index f0be8679..a9e6fdf9 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -828,6 +828,15 @@ class HashEquivalenceCommonTests(object):
for col in columns:
self.client.remove({col: ""})
+ def test_auth_is_owner(self):
+ admin_client = self.start_auth_server()
+
+ user = self.create_user("test-user", ["@read", "@report"])
+ with self.auth_client(user) as client:
+ taskhash, outhash, unihash = self.create_test_hash(client)
+ data = client.get_taskhash(self.METHOD, taskhash, True)
+ self.assertEqual(data["owner"], user["username"])
+
class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
def get_server_addr(self, server_idx):
--
2.34.1
^ permalink raw reply related [flat|nested] 138+ messages in thread* Re: [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 00/22] Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management Joshua Watt
` (21 preceding siblings ...)
2023-11-03 14:26 ` [bitbake-devel][PATCH v6 22/22] hashserv: server: Add owner if user is logged in Joshua Watt
@ 2023-11-09 10:23 ` Alexandre Belloni
22 siblings, 0 replies; 138+ messages in thread
From: Alexandre Belloni @ 2023-11-09 10:23 UTC (permalink / raw)
To: Joshua Watt; +Cc: bitbake-devel
Hello,
I missed it earlier but I got this:
https://autobuilder.yoctoproject.org/typhoon/#/builders/82/builds/5673/steps/12/logs/stdio
On 03/11/2023 08:26:18-0600, Joshua Watt wrote:
> This patch series reworks the bitbake asyncrpc API to add a WebSockets
> implementation for both the client and server. The hash equivalence
> server is updated to allow using this new API (the PR server can also be
> updated in the future if desired).
>
> In addition, the database backend for the hash equivalence server is
> abstracted so that sqlalchemy can optionally be used instead of sqlite.
> This allows using "big metal" databases as the backend, which allows the
> hash equivalence server to scale to a large number of queries.
>
> Note that both websockets and sqlalchemy require 3rd party python
> modules to function. However, these modules are optional unless the user
> desires to use the APIs.
>
> Also, user management is added. This allows user accounts to be
> registered with the server and users can be given permissions to do
> certain operations on the server. Users are not (necessarily) required
> to log in to access the server, as permissions can be granted to anonymous
> users. The default permissions will give anonymous users the same
> permissions that they would have before user accounts were added so as
> to retain backward compatibility, but server admins will likely want to
> change this.
>
> V3: Remove RFC status; patches are ready for review
> V4: Fixed protocol breakage with mixing older and newer clients/servers
> V5: Fixed compatibility with Python 3.8
> V6: Fixed protocol incompatibility when exiting stream state that broke
> mixing older and new clients/servers
>
> Joshua Watt (22):
> asyncrpc: Abstract sockets
> hashserv: Add websocket connection implementation
> asyncrpc: Add context manager API
> hashserv: tests: Add external database tests
> asyncrpc: Prefix log messages with client info
> bitbake-hashserv: Allow arguments from environment
> hashserv: Abstract database
> hashserv: Add SQLalchemy backend
> hashserv: Implement read-only version of "report" RPC
> asyncrpc: Add InvokeError
> asyncrpc: client: Prevent double closing of loop
> asyncrpc: client: Add disconnect API
> hashserv: Add user permissions
> hashserv: Add become-user API
> hashserv: Add db-usage API
> hashserv: Add database column query API
> hashserv: test: Add bitbake-hashclient tests
> bitbake-hashclient: Output stats in JSON format
> bitbake-hashserver: Allow anonymous permissions to be space separated
> hashserv: tests: Allow authentication for external server tests
> hashserv: Allow self-service deletion
> hashserv: server: Add owner if user is logged in
>
> bin/bitbake-hashclient | 145 +++++-
> bin/bitbake-hashserv | 132 ++++-
> lib/bb/asyncrpc/__init__.py | 33 +-
> lib/bb/asyncrpc/client.py | 120 ++---
> lib/bb/asyncrpc/connection.py | 146 ++++++
> lib/bb/asyncrpc/exceptions.py | 21 +
> lib/bb/asyncrpc/serv.py | 365 ++++++++-----
> lib/hashserv/__init__.py | 190 +++----
> lib/hashserv/client.py | 147 +++++-
> lib/hashserv/server.py | 952 +++++++++++++++++++++-------------
> lib/hashserv/sqlalchemy.py | 427 +++++++++++++++
> lib/hashserv/sqlite.py | 408 +++++++++++++++
> lib/hashserv/tests.py | 736 +++++++++++++++++++++++++-
> lib/prserv/client.py | 8 +-
> lib/prserv/serv.py | 37 +-
> 15 files changed, 3060 insertions(+), 807 deletions(-)
> create mode 100644 lib/bb/asyncrpc/connection.py
> create mode 100644 lib/bb/asyncrpc/exceptions.py
> create mode 100644 lib/hashserv/sqlalchemy.py
> create mode 100644 lib/hashserv/sqlite.py
>
> --
> 2.34.1
>
>
> -=-=-=-=-=-=-=-=-=-=-=-
> Links: You receive all messages sent to this group.
> View/Reply Online (#15421): https://lists.openembedded.org/g/bitbake-devel/message/15421
> Mute This Topic: https://lists.openembedded.org/mt/102364903/3617179
> Group Owner: bitbake-devel+owner@lists.openembedded.org
> Unsubscribe: https://lists.openembedded.org/g/bitbake-devel/unsub [alexandre.belloni@bootlin.com]
> -=-=-=-=-=-=-=-=-=-=-=-
>
--
Alexandre Belloni, co-owner and COO, Bootlin
Embedded Linux and Kernel engineering
https://bootlin.com
^ permalink raw reply [flat|nested] 138+ messages in thread