Skip to content

Commit

Permalink
🎨 Use constant to define default values
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilbadyal committed Aug 18, 2024
1 parent 37528e8 commit 0809dfa
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ repos:
- id: ruff
args:
- "--config=pyproject.toml"
- "--fix"
- "--unsafe-fixes"

- repo: https://github.com/psf/black
rev: 23.12.1
Expand Down
77 changes: 55 additions & 22 deletions esxport/click_opt/cli_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import urllib3
from elastic_transport import SecurityWarning

from esxport.constant import default_config_fields

if TYPE_CHECKING:
from typing_extensions import Self

Expand All @@ -18,28 +20,59 @@
class CliOptions(object):
"""CLI options."""

def __init__(
self: Self,
myclass_kwargs: dict[str, Any],
) -> None:
self.query: dict[str, Any] = myclass_kwargs["query"]
self.output_file = myclass_kwargs["output_file"]
self.url = myclass_kwargs["url"]
self.user = myclass_kwargs["user"]
self.password = myclass_kwargs["password"]
self.index_prefixes = myclass_kwargs["index_prefixes"]
self.fields: list[str] = list(myclass_kwargs["fields"])
self.sort: list[dict[str, str]] = myclass_kwargs["sort"]
self.delimiter = myclass_kwargs["delimiter"]
self.max_results = int(myclass_kwargs["max_results"])
self.scroll_size = int(myclass_kwargs["scroll_size"])
self.meta_fields: list[str] = list(myclass_kwargs["meta_fields"])
self.verify_certs: bool = myclass_kwargs["verify_certs"]
self.ca_certs = myclass_kwargs["ca_certs"]
self.client_cert = myclass_kwargs["ca_certs"]
self.client_key = myclass_kwargs["ca_certs"]
self.debug: bool = myclass_kwargs["debug"]
self.format: str = "csv"
# Explicitly declare all attributes with their types
query: dict[str, Any]
output_file: str
url: str
user: str
password: str
index_prefixes: list[str]
fields: list[str]
sort: list[dict[str, str]]
delimiter: str
max_results: int
scroll_size: int
meta_fields: list[str]
verify_certs: bool
ca_certs: str
client_cert: str
client_key: str
debug: bool
export_format: str

def __init__(self: Self, myclass_kwargs: dict[str, Any]) -> None:
# All keys that you want to set as attributes
attrs_to_set = {
"query",
"output_file",
"url",
"user",
"password",
"index_prefixes",
"fields",
"sort",
"delimiter",
"max_results",
"scroll_size",
"meta_fields",
"verify_certs",
"ca_certs",
"client_cert",
"client_key",
"debug",
}

for attr in attrs_to_set:
setattr(self, attr, myclass_kwargs.get(attr, default_config_fields.get(attr)))

# Additional processing for certain attributes
self.fields: list[str] = list(self.fields)
self.index_prefixes: list[str] = list(self.index_prefixes)
self.meta_fields: list[str] = list(self.meta_fields)
self.sort: list[dict[str, str]] = self.sort
self.max_results = int(self.max_results)
self.scroll_size = int(self.scroll_size)
self.export_format: str = "csv"

def __str__(self: Self) -> str:
"""Print the class."""
Expand Down
16 changes: 16 additions & 0 deletions esxport/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,19 @@
TIMES_TO_TRY = 3
RETRY_DELAY = 60
META_FIELDS = ["_id", "_index", "_score"]
default_config_fields = {
"url": "https://localhost:9200",
"user": "elastic",
"index_prefixes": "",
"fields": ["_all"],
"sort": [],
"delimiter": ",",
"max_results": 10,
"scroll_size": 100,
"meta_fields": [],
"verify_certs": True,
"ca_certs": "",
"client_cert": "",
"client_key": "",
"debug": False,
}
2 changes: 1 addition & 1 deletion esxport/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
client_key=cli_options.client_key,
)

def indices_exists(self: Self, index: str) -> bool:
def indices_exists(self: Self, index: str | list[str] | tuple[str, ...]) -> bool:
"""Check if a given index exists."""
return bool(self.client.indices.exists(index=index))

Expand Down
2 changes: 1 addition & 1 deletion esxport/esxport.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _export(self: Self) -> None:
headers = self._extract_headers()
kwargs = {
"delimiter": self.opts.delimiter,
"output_format": self.opts.format,
"output_format": self.opts.export_format,
}
Writer.write(
headers=headers,
Expand Down
18 changes: 9 additions & 9 deletions esxport/esxport_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from esxport.__init__ import __version__
from esxport.click_opt.cli_options import CliOptions
from esxport.click_opt.click_custom import JSON, sort
from esxport.constant import META_FIELDS
from esxport.constant import META_FIELDS, default_config_fields
from esxport.elastic import ElasticsearchClient
from esxport.strings import cli_version

Expand Down Expand Up @@ -48,14 +48,14 @@ def print_version(ctx: Context, _: Parameter, value: bool) -> None: # noqa: FBT
"--url",
type=URL,
required=False,
default="https://localhost:9200",
default=default_config_fields["url"],
help="Elasticsearch host URL.",
)
@click.option(
"-U",
"--user",
required=False,
default="elastic",
default=default_config_fields["user"],
help="Elasticsearch basic authentication user.",
)
@click.password_option(
Expand All @@ -68,7 +68,7 @@ def print_version(ctx: Context, _: Parameter, value: bool) -> None: # noqa: FBT
@click.option(
"-f",
"--fields",
default=["_all"],
default=default_config_fields["fields"],
multiple=True,
help="List of _source fields to present be in output.",
)
Expand All @@ -82,28 +82,28 @@ def print_version(ctx: Context, _: Parameter, value: bool) -> None: # noqa: FBT
@click.option(
"-d",
"--delimiter",
default=",",
default=default_config_fields["delimiter"],
help="Delimiter to use in CSV file.",
)
@click.option(
"-m",
"--max-results",
default=10,
default=default_config_fields["max_results"],
type=int,
help="Maximum number of results to return.",
)
@click.option(
"-s",
"--scroll-size",
default=100,
default=default_config_fields["scroll_size"],
type=int,
help="Scroll size for each batch of results.",
)
@click.option(
"-e",
"--meta-fields",
type=click.Choice(META_FIELDS),
default=[],
default=default_config_fields["meta_fields"],
multiple=True,
help="Add meta-fields in output.",
)
Expand Down Expand Up @@ -138,7 +138,7 @@ def print_version(ctx: Context, _: Parameter, value: bool) -> None: # noqa: FBT
@click.option(
"--debug",
is_flag=True,
default=False,
default=default_config_fields["debug"],
help="Debug mode on.",
)
def cli( # noqa: PLR0913
Expand Down
2 changes: 1 addition & 1 deletion test/esxport/_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_export_invalid_format(
esxport_obj: EsXport,
) -> None:
"""Check if exception is raised when formatting is invalid."""
esxport_obj.opts.format = "invalid_format"
esxport_obj.opts.export_format = "invalid_format"
with patch.object(EsXport, "_extract_headers", return_value=[]), pytest.raises(NotImplementedError):
esxport_obj.export()
TestExport.rm_export_file(f"{inspect.stack()[0].function}.csv")
Expand Down

0 comments on commit 0809dfa

Please sign in to comment.