Skip to content

Commit

Permalink
fixup! wip
Browse files Browse the repository at this point in the history
  • Loading branch information
rumpelsepp committed Feb 28, 2025
1 parent ca62258 commit b2df6fb
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 57 deletions.
4 changes: 3 additions & 1 deletion src/gallia/cli/gallia.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# SPDX-FileCopyrightText: AISEC Pentesting Team
#
# SPDX-License-Identifier: Apache-2.0

import argparse
import asyncio
import json
import os
import sys
Expand Down Expand Up @@ -182,7 +184,7 @@ def __call__(
logger_name="", # Take over the root logger
)

sys.exit(get_command(config).entry_point())
sys.exit(asyncio.run(get_command(config).entry_point()))


def version() -> None:
Expand Down
3 changes: 1 addition & 2 deletions src/gallia/command/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0

from gallia.command.base import AsyncScript, BaseCommand, Scanner, Script
from gallia.command.base import AsyncScript, BaseCommand, Scanner
from gallia.command.uds import UDSDiscoveryScanner, UDSScanner

__all__ = [
"BaseCommand",
"AsyncScript",
"Script",
"Scanner",
"UDSScanner",
"UDSDiscoveryScanner",
Expand Down
42 changes: 3 additions & 39 deletions src/gallia/command/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __init__(self, config: BaseCommandConfig) -> None:
self.log_file_handlers = []

@abstractmethod
def run(self) -> int: ...
async def run(self) -> int: ...

def run_hook(self, variant: HookVariant, exit_code: int | None = None) -> None:
script = self.config.pre_hook if variant == HookVariant.PRE else self.config.post_hook
Expand Down Expand Up @@ -343,7 +343,7 @@ async def entry_point(self) -> int:

exit_code = 0
try:
exit_code = self.run()
exit_code = await self.run()
except KeyboardInterrupt:
exit_code = 128 + signal.SIGINT
# Ensure that META.json gets written in the case a
Expand Down Expand Up @@ -395,39 +395,6 @@ async def entry_point(self) -> int:
return exit_code


class ScriptConfig(
BaseCommandConfig,
ABC,
cli_group=BaseCommandConfig._cli_group,
config_section=BaseCommandConfig._config_section,
):
pass


class Script(BaseCommand, ABC):
"""Script is a base class for a synchronous gallia command.
To implement a script, create a subclass and implement the
.main() method."""

GROUP = "script"

def setup(self) -> None: ...

@abstractmethod
def main(self) -> None: ...

def teardown(self) -> None: ...

def run(self) -> int:
self.setup()
try:
self.main()
finally:
self.teardown()

return exitcodes.OK


class AsyncScriptConfig(
BaseCommandConfig,
ABC,
Expand All @@ -451,15 +418,12 @@ async def main(self) -> None: ...

async def teardown(self) -> None: ...

async def _run(self) -> None:
async def run(self) -> int:
await self.setup()
try:
await self.main()
finally:
await self.teardown()

def run(self) -> int:
asyncio.run(self._run())
return exitcodes.OK


Expand Down
10 changes: 5 additions & 5 deletions src/gallia/commands/script/flexray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import pickle
import sys

from gallia.command.base import AsyncScriptConfig, ScriptConfig
from gallia.command.base import AsyncScriptConfig
from gallia.command.config import AutoInt, Field, Ranges

assert sys.platform == "win32"

from gallia.command import AsyncScript, Script
from gallia.command import AsyncScript
from gallia.transports._ctypes_vector_xl_wrapper import FlexRayCtypesBackend
from gallia.transports.flexray_vector import FlexRayFrame, RawFlexRayTransport, parse_frame_type

Expand Down Expand Up @@ -71,12 +71,12 @@ async def main(self) -> None:
print(f"slot_id: {frame.slot_id:03d}; data: {frame.data.hex()}")


class FRConfigDumpConfig(ScriptConfig):
class FRConfigDumpConfig(AsyncScriptConfig):
channel: int | None = Field(description="the channel number of the flexray device")
pretty: bool = Field(False, description="pretty print the configuration", short="p")


class FRConfigDump(Script):
class FRConfigDump(AsyncScript):
"""Dump the flexray configuration as base64"""

CONFIG_TYPE = FRConfigDumpConfig
Expand All @@ -86,7 +86,7 @@ def __init__(self, config: FRConfigDumpConfig):
super().__init__(config)
self.config: FRConfigDumpConfig = config

def main(self) -> None:
async def main(self) -> None:
backend = FlexRayCtypesBackend.create(self.config.channel)
raw_config = backend.get_configuration()
config = pickle.dumps(raw_config)
Expand Down
20 changes: 10 additions & 10 deletions src/gallia/commands/script/rerun.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
import importlib
import json
import sys
Expand All @@ -14,14 +13,14 @@
from pydantic import model_validator

from gallia.command import BaseCommand
from gallia.command.base import Script, ScriptConfig
from gallia.command.base import AsyncScript, AsyncScriptConfig
from gallia.command.config import Field
from gallia.log import get_logger

logger = get_logger(__name__)


class RerunnerConfig(ScriptConfig):
class RerunnerConfig(AsyncScriptConfig):
id: int | None = Field(None, description="The id of the run_meta entry in the db")
file: Path | None = Field(None, description="The path of the META.json in the logs")

Expand All @@ -40,20 +39,21 @@ def check_meta_source(self) -> Self:
return self


class Rerunner(Script):
class Rerunner(AsyncScript):
CONFIG_TYPE = RerunnerConfig
SHORT_HELP = "Rerun a previous gallia command based on its run_meta in the database"

def __init__(self, config: RerunnerConfig):
super().__init__(config)
self.config: RerunnerConfig = config

def main(self) -> None:
async def main(self) -> None:
if self.config.id is not None:
script, config = self.db()
script, config = await self.db()
else:
script, config = self.file()

# TODO: Make a interface for this to avoid error-prone parsing.
script_parts = script.split(".")
module = ".".join(script_parts[:-1])
class_name = script_parts[-1]
Expand All @@ -63,9 +63,9 @@ def main(self) -> None:
gallia_class: type[BaseCommand] = getattr(importlib.import_module(module), class_name)
command = gallia_class(gallia_class.CONFIG_TYPE(**config))

sys.exit(command.entry_point())
sys.exit(await command.entry_point())

def db(self) -> tuple[str, Mapping[str, Any]]:
async def db(self) -> tuple[str, Mapping[str, Any]]:
assert self.config.id is not None

query = "SELECT script, config FROM run_meta WHERE id = ?"
Expand All @@ -77,8 +77,8 @@ def db(self) -> tuple[str, Mapping[str, Any]]:

assert connection is not None

cursor: aiosqlite.Cursor = asyncio.run(connection.execute(query, parameters))
row = asyncio.run(cursor.fetchone())
cursor: aiosqlite.Cursor = await connection.execute(query, parameters)
row = await cursor.fetchone()

if row is None:
logger.error(f"There id no run_meta entry with the id {self.config.id}")
Expand Down

0 comments on commit b2df6fb

Please sign in to comment.