Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make entry_point() async #688

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/gallia/cli/gallia.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# SPDX-FileCopyrightText: AISEC Pentesting Team
#
# SPDX-License-Identifier: Apache-2.0

import argparse
import asyncio
import json
import os
import sys
Expand Down Expand Up @@ -182,7 +184,7 @@ def __call__(
logger_name="", # Take over the root logger
)

sys.exit(get_command(config).entry_point())
sys.exit(asyncio.run(get_command(config).entry_point()))


def version() -> None:
Expand Down
3 changes: 1 addition & 2 deletions src/gallia/command/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0

from gallia.command.base import AsyncScript, BaseCommand, Scanner, Script
from gallia.command.base import AsyncScript, BaseCommand, Scanner
from gallia.command.uds import UDSDiscoveryScanner, UDSScanner

__all__ = [
"BaseCommand",
"AsyncScript",
"Script",
"Scanner",
"UDSScanner",
"UDSDiscoveryScanner",
Expand Down
70 changes: 11 additions & 59 deletions src/gallia/command/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _open_lockfile(self, path: Path) -> int | None:
logger.notice("opening lockfile…")
return os.open(path, os.O_RDONLY)

def _aquire_flock(self: Flockable) -> None:
async def _aquire_flock(self: Flockable) -> None:
assert self._lock_file_fd is not None

try:
Expand All @@ -87,7 +87,7 @@ def _aquire_flock(self: Flockable) -> None:
fcntl.flock(self._lock_file_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
except BlockingIOError:
logger.notice("waiting for flock…")
fcntl.flock(self._lock_file_fd, fcntl.LOCK_EX)
await asyncio.to_thread(fcntl.flock, self._lock_file_fd, fcntl.LOCK_EX)
logger.info("Acquired lock. Continuing…")

def _release_flock(self: Flockable) -> None:
Expand All @@ -103,7 +103,7 @@ def _open_lockfile(self, path: Path) -> int | None:
logger.warn("lockfile in windows is not supported")
return None

def _aquire_flock(self) -> None:
async def _aquire_flock(self) -> None:
pass

def _release_flock(self) -> None:
Expand Down Expand Up @@ -201,7 +201,7 @@ def __init__(self, config: BaseCommandConfig) -> None:
self.log_file_handlers = []

@abstractmethod
def run(self) -> int: ...
async def run(self) -> int: ...

def run_hook(self, variant: HookVariant, exit_code: int | None = None) -> None:
script = self.config.pre_hook if variant == HookVariant.PRE else self.config.post_hook
Expand Down Expand Up @@ -239,8 +239,6 @@ def run_hook(self, variant: HookVariant, exit_code: int | None = None) -> None:
logger.info(p.stderr.strip(), extra={"tags": [hook_id, "stderr"]})

async def _db_insert_run_meta(self) -> None:
# TODO: This function has to call `connect` and `disconnect` in order to be safely run in an own event loop

if self.config.db is not None:
self.db_handler = DBHandler(self.config.db)
await self.db_handler.connect()
Expand All @@ -252,15 +250,9 @@ async def _db_insert_run_meta(self) -> None:
path=self.artifacts_dir,
)

await self.db_handler.disconnect()

async def _db_finish_run_meta(self) -> None:
# TODO: This function has to call `connect` and `disconnect` in order to be safely run in an own event loop

if self.db_handler is not None:
if self.db_handler is not None and self.db_handler.connection is not None:
if self.db_handler.meta is not None:
await self.db_handler.connect()

try:
await self.db_handler.complete_run_meta(
datetime.now(UTC).astimezone(), self.run_meta.exit_code, self.artifacts_dir
Expand Down Expand Up @@ -323,11 +315,11 @@ def prepare_artifactsdir(

raise ValueError("base_dir or force_path must be different from None")

def entry_point(self) -> int:
async def entry_point(self) -> int:
if (p := self.config.lock_file) is not None:
try:
self._lock_file_fd = self._open_lockfile(p)
self._aquire_flock()
await self._aquire_flock()
except OSError as e:
logger.critical(f"Unable to lock {p}: {e}")
return exitcodes.OSFILE
Expand All @@ -347,11 +339,11 @@ def entry_point(self) -> int:
if self.config.hooks:
self.run_hook(HookVariant.PRE)

asyncio.run(self._db_insert_run_meta())
await self._db_insert_run_meta()

exit_code = 0
try:
exit_code = self.run()
exit_code = await self.run()
except KeyboardInterrupt:
exit_code = 128 + signal.SIGINT
# Ensure that META.json gets written in the case a
Expand All @@ -377,7 +369,7 @@ def entry_point(self) -> int:
self.run_meta.exit_code = exit_code
self.run_meta.end_time = datetime.now(tz).isoformat()

asyncio.run(self._db_finish_run_meta())
await self._db_finish_run_meta()

if self.HAS_ARTIFACTS_DIR:
self.artifacts_dir.joinpath(FileNames.META.value).write_text(
Expand All @@ -403,39 +395,6 @@ def entry_point(self) -> int:
return exit_code


class ScriptConfig(
BaseCommandConfig,
ABC,
cli_group=BaseCommandConfig._cli_group,
config_section=BaseCommandConfig._config_section,
):
pass


class Script(BaseCommand, ABC):
"""Script is a base class for a synchronous gallia command.
To implement a script, create a subclass and implement the
.main() method."""

GROUP = "script"

def setup(self) -> None: ...

@abstractmethod
def main(self) -> None: ...

def teardown(self) -> None: ...

def run(self) -> int:
self.setup()
try:
self.main()
finally:
self.teardown()

return exitcodes.OK


class AsyncScriptConfig(
BaseCommandConfig,
ABC,
Expand All @@ -459,15 +418,12 @@ async def main(self) -> None: ...

async def teardown(self) -> None: ...

async def _run(self) -> None:
async def run(self) -> int:
await self.setup()
try:
await self.main()
finally:
await self.teardown()

def run(self) -> int:
asyncio.run(self._run())
return exitcodes.OK


Expand Down Expand Up @@ -539,10 +495,6 @@ async def main(self) -> None: ...
async def setup(self) -> None:
from gallia.plugins.plugin import load_transport

if self.db_handler is not None:
# Open the DB handler that will be closed in `teardown`
await self.db_handler.connect()

if self.config.power_supply is not None:
self.power_supply = await PowerSupply.connect(self.config.power_supply)
if self.config.power_cycle is True:
Expand Down
10 changes: 5 additions & 5 deletions src/gallia/commands/script/flexray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import pickle
import sys

from gallia.command.base import AsyncScriptConfig, ScriptConfig
from gallia.command.base import AsyncScriptConfig
from gallia.command.config import AutoInt, Field, Ranges

assert sys.platform == "win32"

from gallia.command import AsyncScript, Script
from gallia.command import AsyncScript
from gallia.transports._ctypes_vector_xl_wrapper import FlexRayCtypesBackend
from gallia.transports.flexray_vector import FlexRayFrame, RawFlexRayTransport, parse_frame_type

Expand Down Expand Up @@ -71,12 +71,12 @@ async def main(self) -> None:
print(f"slot_id: {frame.slot_id:03d}; data: {frame.data.hex()}")


class FRConfigDumpConfig(ScriptConfig):
class FRConfigDumpConfig(AsyncScriptConfig):
channel: int | None = Field(description="the channel number of the flexray device")
pretty: bool = Field(False, description="pretty print the configuration", short="p")


class FRConfigDump(Script):
class FRConfigDump(AsyncScript):
"""Dump the flexray configuration as base64"""

CONFIG_TYPE = FRConfigDumpConfig
Expand All @@ -86,7 +86,7 @@ def __init__(self, config: FRConfigDumpConfig):
super().__init__(config)
self.config: FRConfigDumpConfig = config

def main(self) -> None:
async def main(self) -> None:
backend = FlexRayCtypesBackend.create(self.config.channel)
raw_config = backend.get_configuration()
config = pickle.dumps(raw_config)
Expand Down
20 changes: 10 additions & 10 deletions src/gallia/commands/script/rerun.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
import importlib
import json
import sys
Expand All @@ -14,14 +13,14 @@
from pydantic import model_validator

from gallia.command import BaseCommand
from gallia.command.base import Script, ScriptConfig
from gallia.command.base import AsyncScript, AsyncScriptConfig
from gallia.command.config import Field
from gallia.log import get_logger

logger = get_logger(__name__)


class RerunnerConfig(ScriptConfig):
class RerunnerConfig(AsyncScriptConfig):
id: int | None = Field(None, description="The id of the run_meta entry in the db")
file: Path | None = Field(None, description="The path of the META.json in the logs")

Expand All @@ -40,20 +39,21 @@ def check_meta_source(self) -> Self:
return self


class Rerunner(Script):
class Rerunner(AsyncScript):
CONFIG_TYPE = RerunnerConfig
SHORT_HELP = "Rerun a previous gallia command based on its run_meta in the database"

def __init__(self, config: RerunnerConfig):
super().__init__(config)
self.config: RerunnerConfig = config

def main(self) -> None:
async def main(self) -> None:
if self.config.id is not None:
script, config = self.db()
script, config = await self.db()
else:
script, config = self.file()

# TODO: Make a interface for this to avoid error-prone parsing.
script_parts = script.split(".")
module = ".".join(script_parts[:-1])
class_name = script_parts[-1]
Expand All @@ -63,9 +63,9 @@ def main(self) -> None:
gallia_class: type[BaseCommand] = getattr(importlib.import_module(module), class_name)
command = gallia_class(gallia_class.CONFIG_TYPE(**config))

sys.exit(command.entry_point())
sys.exit(await command.entry_point())

def db(self) -> tuple[str, Mapping[str, Any]]:
async def db(self) -> tuple[str, Mapping[str, Any]]:
assert self.config.id is not None

query = "SELECT script, config FROM run_meta WHERE id = ?"
Expand All @@ -77,8 +77,8 @@ def db(self) -> tuple[str, Mapping[str, Any]]:

assert connection is not None

cursor: aiosqlite.Cursor = asyncio.run(connection.execute(query, parameters))
row = asyncio.run(cursor.fetchone())
cursor: aiosqlite.Cursor = await connection.execute(query, parameters)
row = await cursor.fetchone()

if row is None:
logger.error(f"There id no run_meta entry with the id {self.config.id}")
Expand Down
10 changes: 10 additions & 0 deletions tests/bats/002-scans.bats
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ teardown() {
gallia scan uds services --sessions 1 2 --check-session
}

@test "scan services with database" {
local db_file="${BATS_TEST_NAME}.sqlite"

gallia scan uds services --db "$db_file" --sessions 1 2 --check-session


# TODO: This is not finished; check here that the database contains expected fields
sqlite3 "$db_file" "SELECT * FROM run_meta;" >&3
}

@test "scan sessions" {
gallia scan uds sessions
}
Expand Down
Loading