Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQLAlchemy 2.0 upgrades (part 2) #16724

Merged
merged 28 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
6b32c17
Fix SA2.0 ORM usage in datatypes
jdavcs Sep 22, 2023
597054f
Fix SA2.0 ORM usage in galaxy.controllers.page [partial]
jdavcs Sep 22, 2023
3fd5b02
Fix SA2.0 ORM usage in model.security
jdavcs Sep 25, 2023
b40deee
Fix bug (replace str with model attrs in join clauses), fix SA2.0 req…
jdavcs Sep 26, 2023
336a32c
Fix SA2.0 ORM usage in model.item_attrs
jdavcs Sep 27, 2023
581cc77
Fix SA 2.0 ORM usage in model.tags
jdavcs Sep 27, 2023
4a87976
Fix SA2.0 ORM usage in model.store
jdavcs Sep 27, 2023
7365b24
Fix SA2.0 ORM usage in model.metadata
jdavcs Sep 28, 2023
023646d
Fix SA2.0 ORM usage in model.HistoryAudit.prune()
jdavcs Sep 29, 2023
f5e3d37
Fix SA2.0 ORM usage in model.History.disk_size; simplify
jdavcs Sep 29, 2023
517a165
Fix SA2.0 usage in model.Job.paused_jobs
jdavcs Sep 29, 2023
325d87e
Fix SA2.0 usage in model.History.active_dataset_and_roles_query
jdavcs Sep 29, 2023
456171f
Fix SA2.0 usage in model.History.active_visible_dataset_collections
jdavcs Sep 29, 2023
30f2ba9
Fix SA2.0 usage in model.History.__filter_contents
jdavcs Sep 29, 2023
bf8d718
Optimize + fix SA2.0 usage in model.DatasetInstance.convert_dataset
jdavcs Oct 2, 2023
0545b3e
Fix SA2.0 usage in model.StoredWorkflow.show_in_tool_panel
jdavcs Oct 2, 2023
9508463
Fix SA2.0 usage in model.WorkflowInvocation.poll_unhandled...
jdavcs Oct 2, 2023
d577b17
Fix SA2.0 usage in model.PSAAssociation
jdavcs Oct 2, 2023
cb5bd5d
Fix SA2.0 usage in model.PSACode
jdavcs Oct 2, 2023
abd6724
Fix SA2.0 usage in PSANonce
jdavcs Oct 2, 2023
94625dd
Fix bug (needs commit) + SA2.0 usage in model.PSAPartial
jdavcs Oct 2, 2023
57f26a1
Fix SA2.0 usage in model.UserAuthnzToken
jdavcs Oct 2, 2023
80b5ada
Refactor + fix SA2.0 in model.UserAuthnzToken
jdavcs Oct 2, 2023
48fe6a7
Fix SA2.0 usage in managers.quotas
jdavcs Oct 3, 2023
bbb0253
Fix SA2.0 usage in managers.model_stores
jdavcs Oct 3, 2023
02a2ee9
Fix SA2.0 usage in managers.roles
jdavcs Oct 3, 2023
189a9b3
Fix SA2.0 usage in managers.users
jdavcs Oct 3, 2023
c320836
Refactor get user by email
jdavcs Oct 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/galaxy/datatypes/display_applications/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ def decode_dataset_user(trans, dataset_hash, user_hash):
# decode dataset id as usual
# decode user id using the dataset create time as the key
dataset_id = trans.security.decode_id(dataset_hash)
dataset = trans.sa_session.query(trans.app.model.HistoryDatasetAssociation).get(dataset_id)
dataset = trans.sa_session.get(trans.app.model.HistoryDatasetAssociation, dataset_id)
assert dataset, "Bad Dataset id provided to decode_dataset_user"
if user_hash in [None, "None"]:
user = None
else:
security = IdEncodingHelper(id_secret=dataset.create_time)
user_id = security.decode_id(user_hash)
user = trans.sa_session.query(trans.app.model.User).get(user_id)
user = trans.sa_session.get(trans.app.model.User, user_id)
assert user, "A Bad user id was passed to decode_dataset_user"
return dataset, user
18 changes: 9 additions & 9 deletions lib/galaxy/managers/model_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def setup_history_export_job(self, request: SetupHistoryExportJob):
include_deleted = request.include_deleted
store_directory = request.store_directory

history = self._sa_session.query(model.History).get(history_id)
history = self._sa_session.get(model.History, history_id)
# symlink files on export, on worker files will tarred up in a dereferenced manner.
with DirectoryModelExportStore(store_directory, app=self._app, export_files="symlink") as export_store:
export_store.export_history(history, include_hidden=include_hidden, include_deleted=include_deleted)
job = self._sa_session.query(model.Job).get(job_id)
job = self._sa_session.get(model.Job, job_id)
job.state = model.Job.states.NEW
with transaction(self._sa_session):
self._sa_session.commit()
Expand Down Expand Up @@ -137,10 +137,10 @@ def prepare_history_content_download(self, request: GenerateHistoryContentDownlo
short_term_storage_target.path
) as export_store:
if request.content_type == HistoryContentType.dataset:
hda = self._sa_session.query(model.HistoryDatasetAssociation).get(request.content_id)
hda = self._sa_session.get(model.HistoryDatasetAssociation, request.content_id)
export_store.add_dataset(hda)
else:
hdca = self._sa_session.query(model.HistoryDatasetCollectionAssociation).get(request.content_id)
hdca = self._sa_session.get(model.HistoryDatasetCollectionAssociation, request.content_id)
export_store.export_collection(
hdca, include_hidden=request.include_hidden, include_deleted=request.include_deleted
)
Expand All @@ -157,7 +157,7 @@ def prepare_invocation_download(self, request: GenerateInvocationDownload):
export_files=export_files,
bco_export_options=self._bco_export_options(request),
)(short_term_storage_target.path) as export_store:
invocation = self._sa_session.query(model.WorkflowInvocation).get(request.invocation_id)
invocation = self._sa_session.get(model.WorkflowInvocation, request.invocation_id)
export_store.export_workflow_invocation(
invocation, include_hidden=request.include_hidden, include_deleted=request.include_deleted
)
Expand All @@ -174,7 +174,7 @@ def write_invocation_to(self, request: WriteInvocationTo):
bco_export_options=self._bco_export_options(request),
user_context=user_context,
)(target_uri) as export_store:
invocation = self._sa_session.query(model.WorkflowInvocation).get(request.invocation_id)
invocation = self._sa_session.get(model.WorkflowInvocation, request.invocation_id)
export_store.export_workflow_invocation(
invocation, include_hidden=request.include_hidden, include_deleted=request.include_deleted
)
Expand All @@ -199,10 +199,10 @@ def write_history_content_to(self, request: WriteHistoryContentTo):
self._app, model_store_format, export_files=export_files, user_context=user_context
)(target_uri) as export_store:
if request.content_type == HistoryContentType.dataset:
hda = self._sa_session.query(model.HistoryDatasetAssociation).get(request.content_id)
hda = self._sa_session.get(model.HistoryDatasetAssociation, request.content_id)
export_store.add_dataset(hda)
else:
hdca = self._sa_session.query(model.HistoryDatasetCollectionAssociation).get(request.content_id)
hdca = self._sa_session.get(model.HistoryDatasetCollectionAssociation, request.content_id)
export_store.export_collection(
hdca, include_hidden=request.include_hidden, include_deleted=request.include_deleted
)
Expand Down Expand Up @@ -267,7 +267,7 @@ def import_model_store(self, request: ImportModelStoreTaskRequest):
)
history_id = request.history_id
if history_id:
history = self._sa_session.query(model.History).get(history_id)
history = self._sa_session.get(model.History, history_id)
else:
history = None
user_context = self._build_user_context(request.user.user_id)
Expand Down
33 changes: 32 additions & 1 deletion lib/galaxy/managers/pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)

from sqlalchemy import (
desc,
false,
or_,
select,
Expand All @@ -42,7 +43,12 @@
ready_galaxy_markdown_for_export,
ready_galaxy_markdown_for_import,
)
from galaxy.model import PageRevision
from galaxy.model import (
Page,
PageRevision,
PageUserShareAssociation,
User,
)
from galaxy.model.base import transaction
from galaxy.model.index_filter_util import (
append_user_filter,
Expand Down Expand Up @@ -631,3 +637,28 @@ def placeholderRenderForSave(trans: ProvidesHistoryContext, item_class, item_id,
def get_page_revision(session: Session, page_id: int):
stmt = select(PageRevision).filter_by(page_id=page_id)
return session.scalars(stmt)


def get_shared_pages(session: Session, user: User):
stmt = (
select(PageUserShareAssociation)
.where(PageUserShareAssociation.user == user)
.join(Page)
.where(Page.deleted == false())
.order_by(desc(Page.update_time))
)
return session.scalars(stmt)


def get_page(session: Session, user: User, slug: str):
stmt = _build_page_query(select(Page), user, slug)
return session.scalar(stmt).first()


def page_exists(session: Session, user: User, slug: str) -> bool:
stmt = _build_page_query(select(Page.id), user, slug)
return session.scalar(stmt).first() is not None


def _build_page_query(select_clause, user: User, slug: str):
return select_clause.where(Page.user == user).where(Page.slug == slug).where(Page.deleted == false()).limit(1)
27 changes: 15 additions & 12 deletions lib/galaxy/managers/quotas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,19 @@
Union,
)

from sqlalchemy import select

from galaxy import (
model,
util,
)
from galaxy.exceptions import ActionInputError
from galaxy.managers import base
from galaxy.model import (
Group,
Quota,
User,
)
from galaxy.model.base import transaction
from galaxy.quota import DatabaseQuotaAgent
from galaxy.quota._schema import (
Expand Down Expand Up @@ -46,7 +53,8 @@ def quota_agent(self) -> DatabaseQuotaAgent:
def create_quota(self, payload: dict, decode_id=None) -> Tuple[model.Quota, str]:
params = CreateQuotaParams.parse_obj(payload)
create_amount = self._parse_amount(params.amount)
if self.sa_session.query(model.Quota).filter(model.Quota.name == params.name).first():
stmt = select(Quota).where(Quota.name == params.name).limit(1)
if self.sa_session.scalars(stmt).first():
raise ActionInputError(
"Quota names must be unique and a quota with that name already exists, please choose another name."
)
Expand Down Expand Up @@ -74,12 +82,10 @@ def create_quota(self, payload: dict, decode_id=None) -> Tuple[model.Quota, str]
else:
# Create the UserQuotaAssociations
in_users = [
self.sa_session.query(model.User).get(decode_id(x) if decode_id else x)
for x in util.listify(params.in_users)
self.sa_session.get(User, decode_id(x) if decode_id else x) for x in util.listify(params.in_users)
]
in_groups = [
self.sa_session.query(model.Group).get(decode_id(x) if decode_id else x)
for x in util.listify(params.in_groups)
self.sa_session.get(Group, decode_id(x) if decode_id else x) for x in util.listify(params.in_groups)
]
if None in in_users:
raise ActionInputError("One or more invalid user id has been provided.")
Expand Down Expand Up @@ -108,12 +114,10 @@ def _parse_amount(self, amount: str) -> Optional[Union[int, bool]]:
return False

def rename_quota(self, quota, params) -> str:
stmt = select(Quota).where(Quota.name == params.name).limit(1)
if not params.name:
raise ActionInputError("Enter a valid name.")
elif (
params.name != quota.name
and self.sa_session.query(model.Quota).filter(model.Quota.name == params.name).first()
):
elif params.name != quota.name and self.sa_session.scalars(stmt).first():
raise ActionInputError("A quota with that name already exists.")
else:
old_name = quota.name
Expand All @@ -131,13 +135,12 @@ def manage_users_and_groups_for_quota(self, quota, params, decode_id=None) -> st
raise ActionInputError("Default quotas cannot be associated with specific users and groups.")
else:
in_users = [
self.sa_session.query(model.User).get(decode_id(x) if decode_id else x)
for x in util.listify(params.in_users)
self.sa_session.get(model.User, decode_id(x) if decode_id else x) for x in util.listify(params.in_users)
]
if None in in_users:
raise ActionInputError("One or more invalid user id has been provided.")
in_groups = [
self.sa_session.query(model.Group).get(decode_id(x) if decode_id else x)
self.sa_session.get(model.Group, decode_id(x) if decode_id else x)
for x in util.listify(params.in_groups)
]
if None in in_groups:
Expand Down
18 changes: 12 additions & 6 deletions lib/galaxy/managers/roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import logging
from typing import List

from sqlalchemy import false
from sqlalchemy import (
false,
select,
)
from sqlalchemy.orm import exc as sqlalchemy_exceptions

import galaxy.exceptions
Expand Down Expand Up @@ -44,7 +47,8 @@ def get(self, trans: ProvidesUserContext, role_id: int) -> model.Role:
:raises: InconsistentDatabase, RequestParameterInvalidException, InternalServerError
"""
try:
role = self.session().query(self.model_class).filter(self.model_class.id == role_id).one()
stmt = select(self.model_class).where(self.model_class.id == role_id)
role = self.session().execute(stmt).scalar_one()
except sqlalchemy_exceptions.MultipleResultsFound:
raise galaxy.exceptions.InconsistentDatabase("Multiple roles found with the same id.")
except sqlalchemy_exceptions.NoResultFound:
Expand All @@ -59,7 +63,8 @@ def get(self, trans: ProvidesUserContext, role_id: int) -> model.Role:

def list_displayable_roles(self, trans: ProvidesUserContext) -> List[Role]:
roles = []
for role in trans.sa_session.query(Role).filter(Role.deleted == false()):
stmt = select(Role).where(Role.deleted == false())
for role in trans.sa_session.scalars(stmt):
if trans.user_is_admin or trans.app.security_agent.ok_to_display(trans.user, role):
roles.append(role)
return roles
Expand All @@ -70,15 +75,16 @@ def create_role(self, trans: ProvidesUserContext, role_definition_model: RoleDef
user_ids = role_definition_model.user_ids or []
group_ids = role_definition_model.group_ids or []

if trans.sa_session.query(Role).filter(Role.name == name).first():
stmt = select(Role).where(Role.name == name).limit(1)
if trans.sa_session.scalars(stmt).first():
raise RequestParameterInvalidException(f"A role with that name already exists [{name}]")

role_type = Role.types.ADMIN # TODO: allow non-admins to create roles

role = Role(name=name, description=description, type=role_type)
trans.sa_session.add(role)
users = [trans.sa_session.query(model.User).get(i) for i in user_ids]
groups = [trans.sa_session.query(model.Group).get(i) for i in group_ids]
users = [trans.sa_session.get(model.User, i) for i in user_ids]
groups = [trans.sa_session.get(model.Group, i) for i in group_ids]

# Create the UserRoleAssociations
for user in users:
Expand Down
51 changes: 23 additions & 28 deletions lib/galaxy/managers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from galaxy.model import (
User,
UserAddress,
UserQuotaUsage,
)
from galaxy.model.base import transaction
Expand Down Expand Up @@ -233,13 +234,8 @@ def purge(self, user, flush=True):
user.username = uname_hash
# Redact user addresses as well
if self.app.config.redact_user_address_during_deletion:
user_addresses = (
self.session()
.query(self.app.model.UserAddress)
.filter(self.app.model.UserAddress.user_id == user.id)
.all()
)
for addr in user_addresses:
stmt = select(UserAddress).where(UserAddress.user_id == user.id)
for addr in self.session().scalars(stmt):
addr.desc = new_secure_hash_v2(addr.desc + pseudorandom_value)
addr.name = new_secure_hash_v2(addr.name + pseudorandom_value)
addr.institution = new_secure_hash_v2(addr.institution + pseudorandom_value)
Expand All @@ -264,7 +260,7 @@ def _error_on_duplicate_email(self, email: str) -> None:
raise exceptions.Conflict("Email must be unique", email=email)

def by_id(self, user_id: int) -> model.User:
return self.app.model.session.query(self.model_class).get(user_id)
return self.app.model.session.get(self.model_class, user_id)

# ---- filters
def by_email(self, email: str, filters=None, **kwargs) -> Optional[model.User]:
Expand All @@ -286,7 +282,8 @@ def by_api_key(self, api_key: str, sa_session=None):
return schema.BootstrapAdminUser()
sa_session = sa_session or self.app.model.session
try:
provided_key = sa_session.query(self.app.model.APIKeys).filter_by(key=api_key, deleted=False).one()
stmt = select(self.app.model.APIKeys).filter_by(key=api_key, deleted=False)
provided_key = sa_session.execute(stmt).scalar_one()
except NoResultFound:
raise exceptions.AuthenticationFailed("Provided API key is not valid.")
if provided_key.user.deleted:
Expand Down Expand Up @@ -363,12 +360,7 @@ def get_user_by_identity(self, identity):
user = get_user_by_email(self.session(), identity, self.model_class)
if not user:
# Try a case-insensitive match on the email
user = (
self.session()
.query(self.model_class)
.filter(func.lower(self.model_class.table.c.email) == identity.lower())
.first()
)
user = self._get_user_by_email_case_insensitive(self.session(), identity)
else:
user = get_user_by_username(self.session(), identity, self.model_class)
return user
Expand Down Expand Up @@ -445,7 +437,7 @@ def change_password(self, trans, password=None, confirm=None, token=None, id=Non
if not token and not id:
return None, "Please provide a token or a user and password."
if token:
token_result = trans.sa_session.query(self.app.model.PasswordResetToken).get(token)
token_result = trans.sa_session.get(self.app.model.PasswordResetToken, token)
if not token_result or not token_result.expiration_time > datetime.utcnow():
return None, "Invalid or expired password reset token, please request a new one."
user = token_result.user
Expand Down Expand Up @@ -483,13 +475,14 @@ def __set_password(self, trans, user, password, confirm):
user.set_password_cleartext(password)
# Invalidate all other sessions
if trans.galaxy_session:
for other_galaxy_session in trans.sa_session.query(self.app.model.GalaxySession).filter(
stmt = select(self.app.model.GalaxySession).where(
and_(
self.app.model.GalaxySession.table.c.user_id == user.id,
self.app.model.GalaxySession.table.c.is_valid == true(),
self.app.model.GalaxySession.table.c.id != trans.galaxy_session.id,
self.app.model.GalaxySession.user_id == user.id,
self.app.model.GalaxySession.is_valid == true(),
self.app.model.GalaxySession.id != trans.galaxy_session.id,
)
):
)
for other_galaxy_session in trans.sa_session.scalars(stmt):
other_galaxy_session.is_valid = False
trans.sa_session.add(other_galaxy_session)
trans.sa_session.add(user)
Expand Down Expand Up @@ -581,11 +574,7 @@ def send_reset_email(self, trans, payload, **kwd):
def get_reset_token(self, trans, email):
reset_user = get_user_by_email(trans.sa_session, email, self.app.model.User)
if not reset_user and email != email.lower():
reset_user = (
trans.sa_session.query(self.app.model.User)
.filter(func.lower(self.app.model.User.table.c.email) == email.lower())
.first()
)
reset_user = self._get_user_by_email_case_insensitive(trans.sa_session, email)
if reset_user:
prt = self.app.model.PasswordResetToken(reset_user)
trans.sa_session.add(prt)
Expand Down Expand Up @@ -644,9 +633,11 @@ def get_or_create_remote_user(self, remote_user_email):
for char in [x for x in username if x not in f"{string.ascii_lowercase + string.digits}-."]:
username = username.replace(char, "-")
# Find a unique username - user can change it later
if self.session().query(self.app.model.User).filter_by(username=username).first():
stmt = select(self.app.model.User).filter_by(username=username).limit(1)
if self.session().scalars(stmt).first():
i = 1
while self.session().query(self.app.model.User).filter_by(username=f"{username}-{str(i)}").first():
stmt = select(self.app.model.User).filter_by(username=f"{username}-{str(i)}").limit(1)
while self.session().scalars(stmt).first():
i += 1
username += f"-{str(i)}"
user.username = username
Expand All @@ -660,6 +651,10 @@ def get_or_create_remote_user(self, remote_user_email):
# self.log_event( "Automatically created account '%s'", user.email )
return user

def _get_user_by_email_case_insensitive(self, session, email):
stmt = select(self.app.model.User).where(func.lower(self.app.model.User.email) == email.lower()).limit(1)
return session.scalars(stmt).first()


class UserSerializer(base.ModelSerializer, deletable.PurgableSerializerMixin):
model_manager_class = UserManager
Expand Down
Loading