Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions grace/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@ class Application:
__base: DeclarativeMeta = declarative_base()

def __init__(self):
database_config_path: Path = Path("config/database.cfg")

if not database_config_path.exists():
raise ConfigError("Unable to find the 'database.cfg' file.")

self.__token: str = self.config.get("discord", "token")
self.__engine: Union[Engine, None] = None

Expand Down Expand Up @@ -78,6 +73,20 @@ def config(self) -> Config:
def client(self) -> SectionProxy:
return self.config.client

@property
def database(self) -> SectionProxy:
return self.config.database

@property
def database_infos(self) -> Dict[str, str]:
if not self.database:
return {}

return {
"dialect": self.session.bind.dialect.name,
"database": self.session.bind.url.database
}

@property
def extension_modules(self) -> Generator[str, Any, None]:
"""Generate the extensions modules"""
Expand All @@ -90,17 +99,6 @@ def extension_modules(self) -> Generator[str, Any, None]:
continue
yield module

@property
def database_infos(self) -> Dict[str, str]:
return {
"dialect": self.session.bind.dialect.name,
"database": self.session.bind.url.database
}

@property
def database_exists(self):
return database_exists(self.config.database_uri)

def get_extension_module(self, extension_name) -> Union[str, None]:
"""Return the extension from the given extension name"""

Expand Down Expand Up @@ -150,17 +148,23 @@ def load_logs(self):

def load_database(self):
"""Loads and connects to the database using the loaded config"""
if not self.database:
return None

self.__engine = create_engine(
self.config.database_uri,
echo=self.config.environment.getboolean("sqlalchemy_echo")
)

if self.database_exists:
if database_exists(self.config.database_uri):
try:
self.__engine.connect()
except OperationalError as e:
critical(f"Unable to load the 'database': {e}")
else:
self.create_database()
self.create_tables()


def unload_database(self):
"""Unloads the current database"""
Expand All @@ -179,24 +183,32 @@ def reload_database(self):

def create_database(self):
"""Creates the database for the current loaded config"""
if not self.database:
return None

self.load_database()
create_database(self.config.database_uri)

def drop_database(self):
"""Drops the database for the current loaded config"""
if not self.database:
return None

self.load_database()
drop_database(self.config.database_uri)

def create_tables(self):
"""Creates all the tables for the current loaded database"""
if not self.database:
return None

self.load_database()
self.base.metadata.create_all(self.__engine)

def drop_tables(self):
"""Drops all the tables for the current loaded database"""
if not self.database:
return None

self.load_database()
self.base.metadata.drop_all(self.__engine)
24 changes: 12 additions & 12 deletions grace/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
| PID: {pid}
| Environment: {env}
| Syncing command: {command_sync}
""".rstrip()

DB_INFO = """
| Using database: {database} with {dialect}
""".rstrip()

Expand All @@ -29,9 +32,7 @@ def generate():

@cli.command()
@argument("name")
# This database option is currently disabled since the application and config
# does not currently support it.
# @option("--database/--no-database", default=True)
@option("--database/--no-database", default=True)
@pass_context
def new(ctx, name, database=True):
cmd = generate.get_command(ctx, 'project')
Expand All @@ -48,7 +49,6 @@ def run(environment=None, sync=None):
from bot import app, run

_loading_application(app, environment, sync)
_load_database(app)
_show_application_info(app)

run()
Expand All @@ -60,19 +60,19 @@ def _loading_application(app, environment, command_sync):
app.load(environment, command_sync=command_sync)


def _load_database(app):
if not app.database_exists:
app.create_database()
app.create_tables()

def _show_application_info(app):
info(APP_INFO.format(
info_message = APP_INFO

if app.database:
info_message += DB_INFO

info(info_message.format(
discord_version=discord.__version__,
env=app.config.current_environment,
pid=getpid(),
command_sync=app.command_sync,
database=app.database_infos["database"],
dialect=app.database_infos["dialect"],
database=app.database_infos.get("database"),
dialect=app.database_infos.get("dialect"),
))


Expand Down
13 changes: 11 additions & 2 deletions grace/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,19 @@ def __init__(self):
interpolation=EnvironmentInterpolation()
)

self.read("config/settings.cfg")
self.read("config/database.cfg")
self.read("config/config.cfg")
self.read("config/environment.cfg")

self.__database_config = self.get("database", "config")

if self.__database_config:
self.read(f"config/{self.__database_config}")

@property
def database_uri(self) -> Union[str, URL]:
if not self.database:
return None

if self.database.get("url"):
return self.database.get("url")

Expand All @@ -95,6 +102,8 @@ def database_uri(self) -> Union[str, URL]:

@property
def database(self) -> SectionProxy:
if not self.__database_config:
return None
return self.__config[f"database.{self.__environment}"]

@property
Expand Down
21 changes: 16 additions & 5 deletions grace/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ def generator() -> Generator:
from grace.importer import import_package_modules
from grace.exceptions import GeneratorError, ValidationError, NoTemplateError
from cookiecutter.main import cookiecutter
from jinja2 import Environment, PackageLoader, Template
from jinja2 import Environment, PackageLoader


def register_generators(command_group: Group):
"""Registers generator commands to the given Click command group.

This function dynamically imports all modules in the `grace.generators` package
and registers each module's `generator` command to the provided `command_group`.
This function dynamically imports all modules in the `grace.generators`
package and registers each module's `generator` command to the provided
`command_group`.

:param command_group: The Click command group to register the generators to.
:type command_group: Group
Expand Down Expand Up @@ -108,7 +109,12 @@ def validate(self, *args, **kwargs):
"""Validates the arguments passed to the command."""
return True

def generate_template(self, template_dir: str, variables: dict[str, any] = {}):
def generate_template(
self,
template_dir: str,
variables: dict[str, any] = {},
output_dir: str = ""
):
"""Generates a template using Cookiecutter.

:param template_dir: The name of the template to generate.
Expand All @@ -118,7 +124,12 @@ def generate_template(self, template_dir: str, variables: dict[str, any] = {}):
:type variables: dict[str, any]
"""
template = str(self.templates_path / template_dir)
cookiecutter(template, extra_context=variables, no_input=True)
cookiecutter(
template,
extra_context=variables,
no_input=True,
output_dir=output_dir
)

def generate_file(
self,
Expand Down
21 changes: 21 additions & 0 deletions grace/generators/database_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from grace.generator import Generator
from re import match
from logging import info


class DatabaseGenerator(Generator):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add optional instructions when added later on

Add the following to your `config.py`:
    
    [database]
    config = database.cfg

NAME = 'database'
OPTIONS = {}

def generate(self, output_dir: str = ""):
info(f"Creating database at {output_dir}")
self.generate_template(self.NAME, variables={
"output_dir": output_dir
})

def validate(self, *_args, **_kwargs) -> bool:
return True


def generator() -> Generator:
return DatabaseGenerator()
14 changes: 10 additions & 4 deletions grace/generators/project_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from grace.generator import Generator
from grace.generators.database_generator import generator as db_generator
from re import match
from logging import info

Expand All @@ -12,12 +13,17 @@ class ProjectGenerator(Generator):
def generate(self, name: str, database: bool = True):
info(f"Creating '{name}'")

self.generate_template(self.NAME, values={
self.generate_template(self.NAME, variables={
"project_name": name,
"project_description": "",
"database": "yes" if database else "no"
"project_description": "",
"database": database
})

if database:
# Should probably be moved into its own generator so we can
# generate add the database later on.
db_generator().generate(output_dir=name)

def validate(self, name: str, **_kwargs) -> bool:
"""Validate the project name.

Expand All @@ -39,4 +45,4 @@ def validate(self, name: str, **_kwargs) -> bool:


def generator() -> Generator:
return ProjectGenerator()
return ProjectGenerator()
3 changes: 3 additions & 0 deletions grace/generators/templates/database/cookiecutter.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"__database_slug": "db"
}
2 changes: 1 addition & 1 deletion grace/generators/templates/project/cookiecutter.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
"__project_slug": "{{ cookiecutter.project_name|lower|replace('-', '_') }}",
"__project_class": "{{ cookiecutter.project_name|title|replace('-', '') }}",
"project_description": "{{ cookiecutter.project_description }}",
"database": ["yes", "no"]
"database": [true, false]
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ guild_id = ${GUILD_ID}
; Although it is possible to set directly your discord token here, we recommend, for security reasons, that you set
; your discord token as an environment variable called 'DISCORD_TOKEN'.
token = ${DISCORD_TOKEN}
{% if cookiecutter.database %}
[database]
config = database.cfg
{% endif %}
25 changes: 17 additions & 8 deletions tests/generators/test_project_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from unittest.mock import call
from grace.generator import Generator
from grace.generators.project_generator import ProjectGenerator

Expand All @@ -12,14 +13,22 @@ def test_generate_project_with_database(mocker, generator):
"""Test if the generate method creates the correct template with a database."""
mock_generate_template = mocker.patch.object(Generator, 'generate_template')
name = "example-project"

generator.generate(name, database=True)

mock_generate_template.assert_called_once_with('project', values={
'project_name': name,
'project_description': '',
'database': 'yes'
})
expected_calls = [
call('project', variables={
'project_name': name,
'project_description': '',
'database': True
}),
call('database', variables={
'output_dir': name
})
]

mock_generate_template.assert_has_calls(expected_calls)
assert mock_generate_template.call_count == 2


def test_generate_project_without_database(mocker, generator):
Expand All @@ -29,10 +38,10 @@ def test_generate_project_without_database(mocker, generator):

generator.generate(name, database=False)

mock_generate_template.assert_called_once_with('project', values={
mock_generate_template.assert_called_once_with('project', variables={
'project_name': name,
'project_description': '',
'database': 'no'
'database': False
})


Expand Down
7 changes: 6 additions & 1 deletion tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ def test_generate_template(generator):
with patch('grace.generator.cookiecutter') as cookiecutter:
generator.generate_template('project', variables={})
template_path = str(generator.templates_path / 'project')
cookiecutter.assert_called_once_with(template_path, extra_context={}, no_input=True)
cookiecutter.assert_called_once_with(
template_path,
extra_context={},
no_input=True,
output_dir=''
)


def test_generate(generator):
Expand Down