fix core daos

This commit is contained in:
Kenneth Kehl
2024-10-11 11:30:35 -07:00
parent ef6e4048c2
commit 4f20bfe2db

View File

@@ -4,7 +4,7 @@ from secrets import randbelow
import sqlalchemy import sqlalchemy
from flask import current_app from flask import current_app
from sqlalchemy import func, text from sqlalchemy import delete, func, select, text
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from app import db from app import db
@@ -37,8 +37,8 @@ def get_login_gov_user(login_uuid, email_address):
login.gov uuids are. Eventually the code that checks by email address login.gov uuids are. Eventually the code that checks by email address
should be removed. should be removed.
""" """
stmt = select(User).filter_by(login_uuid=login_uuid)
user = User.query.filter_by(login_uuid=login_uuid).first() user = db.session.execute(stmt).scalars().first()
if user: if user:
if user.email_address != email_address: if user.email_address != email_address:
try: try:
@@ -54,7 +54,8 @@ def get_login_gov_user(login_uuid, email_address):
return user return user
# Remove this 1 July 2025, all users should have login.gov uuids by now # Remove this 1 July 2025, all users should have login.gov uuids by now
user = User.query.filter(User.email_address.ilike(email_address)).first() stmt = select(User).filter(User.email_address.ilike(email_address))
user = db.session.execute(stmt).scalars().first()
if user: if user:
save_user_attribute(user, {"login_uuid": login_uuid}) save_user_attribute(user, {"login_uuid": login_uuid})
@@ -102,24 +103,27 @@ def create_user_code(user, code, code_type):
def get_user_code(user, code, code_type): def get_user_code(user, code, code_type):
# Get the most recent codes to try and reduce the # Get the most recent codes to try and reduce the
# time searching for the correct code. # time searching for the correct code.
codes = VerifyCode.query.filter_by(user=user, code_type=code_type).order_by( stmt = (
VerifyCode.created_at.desc() select(VerifyCode)
.filter_by(user=user, code_type=code_type)
.order_by(VerifyCode.created_at.desc())
) )
codes = db.session.execute(stmt).scalars().all()
return next((x for x in codes if x.check_code(code)), None) return next((x for x in codes if x.check_code(code)), None)
def delete_codes_older_created_more_than_a_day_ago(): def delete_codes_older_created_more_than_a_day_ago():
deleted = ( stmt = delete(VerifyCode).filter(
db.session.query(VerifyCode) VerifyCode.created_at < utc_now() - timedelta(hours=24)
.filter(VerifyCode.created_at < utc_now() - timedelta(hours=24))
.delete()
) )
deleted = db.session.execute(stmt)
db.session.commit() db.session.commit()
return deleted return deleted
def use_user_code(id): def use_user_code(id):
verify_code = VerifyCode.query.get(id) verify_code = db.session.get(VerifyCode, id)
verify_code.code_used = True verify_code.code_used = True
db.session.add(verify_code) db.session.add(verify_code)
db.session.commit() db.session.commit()
@@ -131,36 +135,42 @@ def delete_model_user(user):
def delete_user_verify_codes(user): def delete_user_verify_codes(user):
VerifyCode.query.filter_by(user=user).delete() stmt = delete(VerifyCode).filter_by(user=user)
db.session.execute(stmt)
db.session.commit() db.session.commit()
def count_user_verify_codes(user): def count_user_verify_codes(user):
query = VerifyCode.query.filter( stmt = select(func.count(VerifyCode.id)).filter(
VerifyCode.user == user, VerifyCode.user == user,
VerifyCode.expiry_datetime > utc_now(), VerifyCode.expiry_datetime > utc_now(),
VerifyCode.code_used.is_(False), VerifyCode.code_used.is_(False),
) )
return query.count() result = db.session.execute(stmt)
return result.rowcount
def get_user_by_id(user_id=None): def get_user_by_id(user_id=None):
if user_id: if user_id:
return User.query.filter_by(id=user_id).one() stmt = select(User).filter_by(id=user_id)
return User.query.filter_by().all() return db.session.execute(stmt).scalars().one()
return get_users()
def get_users(): def get_users():
return User.query.all() stmt = select(User)
return db.session.execute(stmt).scalars().all()
def get_user_by_email(email): def get_user_by_email(email):
return User.query.filter(func.lower(User.email_address) == func.lower(email)).one() stmt = select(User).filter(func.lower(User.email_address) == func.lower(email))
return db.session.execute(stmt).scalars().one()
def get_users_by_partial_email(email): def get_users_by_partial_email(email):
email = escape_special_characters(email) email = escape_special_characters(email)
return User.query.filter(User.email_address.ilike("%{}%".format(email))).all() stmt = select(User).filter(User.email_address.ilike("%{}%".format(email)))
return db.session.execute(stmt).scalars().all()
def increment_failed_login_count(user): def increment_failed_login_count(user):
@@ -188,16 +198,17 @@ def get_user_and_accounts(user_id):
# TODO: With sqlalchemy 2.0 change as below because of the breaking change # TODO: With sqlalchemy 2.0 change as below because of the breaking change
# at User.organizations.services, we need to verify that the below subqueryload # at User.organizations.services, we need to verify that the below subqueryload
# that we have put is functionally doing the same thing as before # that we have put is functionally doing the same thing as before
return ( stmt = (
User.query.filter(User.id == user_id) select(User)
.filter(User.id == user_id)
.options( .options(
# eagerly load the user's services and organizations, and also the service's org and vice versa # eagerly load the user's services and organizations, and also the service's org and vice versa
# (so we can see if the user knows about it) # (so we can see if the user knows about it)
joinedload(User.services).joinedload(Service.organization), joinedload(User.services).joinedload(Service.organization),
joinedload(User.organizations).subqueryload(Organization.services), joinedload(User.organizations).subqueryload(Organization.services),
) )
.one()
) )
return db.session.execute(stmt).scalars().one()
@autocommit @autocommit