Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 45 additions & 9 deletions openedx_tagging/core/tagging/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
"""
Utilities for tagging and taxonomy models
"""
from django.db import connection as db_connection
from django.db.models import Aggregate, CharField
from django.db.models.expressions import Func
from django.db.models.expressions import Combinable, Func

RESERVED_TAG_CHARS = [
'\t', # Used in the database to separate tag levels in the "lineage" field
# e.g. lineage="Earth\tNorth America\tMexico\tMexico City"
' > ', # Used in the search index and Instantsearch frontend to separate tag levels
# e.g. tags_level3="Earth > North America > Mexico > Mexico City"
';', # Used in CSV exports to separate multiple tags from the same taxonomy
';', # Used in CSV exports to separate multiple tags from the same taxonomy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
';', # Used in CSV exports to separate multiple tags from the same taxonomy
';', # Used in CSV exports to separate multiple tags from the same taxonomy

Not sure why this ellipsis character was added here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

# e.g. languages-v1: en;es;fr
]
TAGS_CSV_SEPARATOR = RESERVED_TAG_CHARS[2]
Expand All @@ -34,21 +35,56 @@ def as_sqlite(self, compiler, connection, **extra_context):
)


class StringAgg(Aggregate): # pylint: disable=abstract-method
class StringAgg(Aggregate, Combinable):
"""
Aggregate function that collects the values of some column across all rows,
and creates a string by concatenating those values, with "," as a separator.
and creates a string by concatenating those values, with a specified separator.

This is the same as Django's django.contrib.postgres.aggregates.StringAgg,
but this version works with MySQL and SQLite.
This version supports PostgreSQL (STRING_AGG), MySQL (GROUP_CONCAT), and SQLite.
"""
# Default function is for MySQL (GROUP_CONCAT)
function = 'GROUP_CONCAT'
template = '%(function)s(%(distinct)s%(expressions)s)'

def __init__(self, expression, distinct=False, **extra):
def __init__(self, expression, distinct=False, delimiter=',', **extra):
self.delimiter = delimiter
# Handle the distinct option and output type
distinct_str = 'DISTINCT ' if distinct else ''

extra.update({
'distinct': distinct_str,
'output_field': CharField(),
})

# Check the database backend (PostgreSQL, MySQL, or SQLite)
if 'postgresql' in db_connection.vendor.lower():
self.function = 'STRING_AGG'
self.template = '%(function)s(%(distinct)s%(expressions)s, %(delimiter)s)'
extra.update({"delimiter": delimiter})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If delimiter cannot be customized for MySQL or SQLite, then I don't think it should be exposed as an argument to the constructor here.

Can we just hardcode it to ',' for postgres?

Suggested change
extra.update({"delimiter": delimiter})
extra.update({"delimiter": ","})

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And is it possible to avoid overriding as_sql() below by providing the same output_field that postgres StringAgg does?

Suggested change
extra.update({"delimiter": delimiter})
extra.update({
"delimiter": ",",
"output_field": TextField(),
})

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's a great idea

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pomegranited I just updated the PR as per your recommendations, thank you for the great suggestion


# Initialize the parent class with the necessary parameters
super().__init__(
expression,
distinct='DISTINCT ' if distinct else '',
output_field=CharField(),
**extra,
)

def as_sql(self, compiler, connection, **extra_context):
# If PostgreSQL, we use STRING_AGG with a separator
if 'postgresql' in connection.vendor.lower():
# Ensure that expressions are cast to TEXT for PostgreSQL
expressions_sql, params = compiler.compile(self.source_expressions[0])
expressions_sql = f"({expressions_sql})::TEXT" # Cast to TEXT for PostgreSQL
return f"{self.function}({expressions_sql}, {self.delimiter!r})", params
else:
# MySQL/SQLite handles GROUP_CONCAT with SEPARATOR
return super().as_sql(compiler, connection, **extra_context)

# Implementing abstract methods from Combinable
def __rand__(self, other):
return self._combine(other, 'AND', False)

def __ror__(self, other):
return self._combine(other, 'OR', False)

def __rxor__(self, other):
return self._combine(other, 'XOR', False)