From 8b07b03ccda2733a1fc30cb73d154b4478d2826b Mon Sep 17 00:00:00 2001 From: Julian Lobbes Date: Fri, 12 May 2023 14:11:32 +0200 Subject: [PATCH] refactor(backend): subfolders and docstrings --- backend/config.py | 19 ++- backend/crud.py | 114 ------------------ backend/crud/__init__.py | 0 backend/crud/users.py | 144 +++++++++++++++++++++++ backend/database.py | 13 -- backend/database/__init__.py | 0 backend/database/engine.py | 19 +++ backend/main.py | 48 ++++---- backend/models/__init__.py | 0 backend/{models.py => models/users.py} | 17 ++- backend/schemas/__init__.py | 0 backend/{schemas.py => schemas/users.py} | 41 ++++++- 12 files changed, 255 insertions(+), 160 deletions(-) delete mode 100644 backend/crud.py create mode 100644 backend/crud/__init__.py create mode 100644 backend/crud/users.py delete mode 100644 backend/database.py create mode 100644 backend/database/__init__.py create mode 100644 backend/database/engine.py create mode 100644 backend/models/__init__.py rename backend/{models.py => models/users.py} (77%) create mode 100644 backend/schemas/__init__.py rename backend/{schemas.py => schemas/users.py} (71%) diff --git a/backend/config.py b/backend/config.py index b459467..70bf0a1 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,3 +1,9 @@ +"""This module provides global application settings. + +All settings are read from environment variables, but defaults are provided below +if the respective envvar is unset. +""" + import os from urllib.parse import quote_plus as url_encode from functools import lru_cache @@ -15,16 +21,17 @@ class Settings(BaseSettings): app_name: str = os.getenv("APP_NAME", "MEDWingS") admin_email: str = os.getenv("ADMIN_EMAIL", "admin@example.com") + # Debug mode has the following effects: + # - logs SQL operations debug_mode: bool = False if os.getenv("DEBUG_MODE", "false").lower() == "true": debug_mode = True - _pg_hostname = os.getenv("POSTGRES_HOST", "db") - _pg_port = os.getenv("POSTGRES_PORT", "5432") - _pg_dbname = os.getenv("POSTGRES_DB", "medwings") - _pg_user = url_encode(os.getenv("POSTGRES_USER", "medwings")) - _pg_password = url_encode(os.getenv("POSTGRES_PASSWORD", "medwings")) - pg_dsn: PostgresDsn = f"postgresql://{_pg_user}:{_pg_password}@{_pg_hostname}:{_pg_port}/{_pg_dbname}" + pg_hostname = os.getenv("POSTGRES_HOST", "db") + pg_port = os.getenv("POSTGRES_PORT", "5432") + pg_dbname = os.getenv("POSTGRES_DB", "medwings") + pg_user = url_encode(os.getenv("POSTGRES_USER", "medwings")) + pg_password = url_encode(os.getenv("POSTGRES_PASSWORD", "medwings")) @lru_cache diff --git a/backend/crud.py b/backend/crud.py deleted file mode 100644 index 6f5e889..0000000 --- a/backend/crud.py +++ /dev/null @@ -1,114 +0,0 @@ -import logging -from datetime import datetime - -from sqlalchemy.orm import Session - -from .import models, schemas - -log = logging.getLogger() - - -def hash_password(password: str) -> str: - # TODO actually hash the password! - return password - - -def _fill_missing_user_fields(db_user: models.User) -> schemas.User: - full_user = schemas.User.from_orm(db_user) - if db_user.patient: - full_user.gender = db_user.patient.gender - full_user.date_of_birth = db_user.patient.date_of_birth - full_user.is_patient = True - full_user.is_admin = False - else: - full_user.is_patient = False - full_user.is_admin = True - - return full_user - - -def create_user(db: Session, user: schemas.UserCreate): - """Creates a new user as either a patient or an administrator.""" - - db_user = models.User( - email=user.email, - first_name=user.first_name, - last_name=user.last_name, - password=hash_password(user.password), - ) - - # Add user to database - if user.is_patient: - db_patient = models.Patient( - user=db_user, - gender=user.gender, - date_of_birth=user.date_of_birth, - ) - db.add(db_patient) - else: - db_administrator = models.Administrator( - user=db_user, - ) - db.add(db_administrator) - - db.commit() - - # Construct the updated user to return - db.refresh(db_user) - return _fill_missing_user_fields(db_user) - - -def read_user(db: Session, id: int): - db_user = db.query(models.User).filter(models.User.id == id).first() - if not db_user: - return None - - return _fill_missing_user_fields(db_user) - - -def read_user_by_email(db: Session, email: str): - db_user = db.query(models.User).filter(models.User.email == email).first() - if not db_user: - return None - - return _fill_missing_user_fields(db_user) - - -def read_users(db: Session, skip: int = 0, limit: int = 100): - db_users = db.query(models.User).offset(skip).limit(limit).all() - - full_users = [] - for db_user in db_users: - full_users.append(_fill_missing_user_fields(db_user)) - return full_users - - -def update_user(db: Session, user: schemas.UserUpdate, id: int): - db_user = db.query(models.User).filter(models.User.id == id).first() - current_user = _fill_missing_user_fields(db_user) - - for key in ['gender', 'date_of_birth']: - value = getattr(user, key) - if value is not None: - setattr(db_user.patient, key, value) - for key in ['email', 'first_name', 'last_name']: - value = getattr(user, key) - if value is not None: - setattr(db_user, key, value) - if user.password is not None: - db_user.password = hash_password(user.password) - - db.commit() - db.refresh(db_user) - return _fill_missing_user_fields(db_user) - - -def delete_user(db: Session, id: int): - db_user = db.query(models.User).filter(models.User.id == id).first() - user_copy = _fill_missing_user_fields(db_user) - - db.delete(db_user) - db.commit() - - user_copy.updated = datetime.now(user_copy.updated.tzinfo) - return user_copy diff --git a/backend/crud/__init__.py b/backend/crud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/crud/users.py b/backend/crud/users.py new file mode 100644 index 0000000..ad2f4b2 --- /dev/null +++ b/backend/crud/users.py @@ -0,0 +1,144 @@ +"""This module handles CRUD operations for users in the database, based on pydanctic schemas.""" + +from datetime import datetime + +from sqlalchemy.orm import Session + +from backend.models import users as usermodel +from backend.schemas import users as userschema + + +def hash_password(password: str) -> str: + """This is a placeholder for a secure password hashing algorithm. + + It will convert a plaintext password into a secure, salted hash, for storage + in the database. + """ + + # TODO actually hash the password! + return password + + +def _fill_missing_user_fields(db_user: usermodel.User) -> userschema.User: + """Fills all the fields of an instance of userschema.User that cannot be filled by pydantic. + + This function is necessary because the userschema is not a one-to-one reflection + of the database data model. I did not want the 'patient' and 'administrator' + database table to be encoded as their own top level JSON keys in serialized + user object. Instead, the user schema combines all fields from all user types. + This function fills the optional fields, depending on what type of user is + passed in. + """ + + + full_user = userschema.User.from_orm(db_user) + if db_user.patient: + full_user.gender = db_user.patient.gender + full_user.date_of_birth = db_user.patient.date_of_birth + full_user.is_patient = True + full_user.is_admin = False + else: + full_user.is_patient = False + full_user.is_admin = True + + return full_user + + +def create_user(db: Session, user: userschema.UserCreate) -> userschema.User: + """Creates the specified user in the database.""" + + db_user = usermodel.User( + email=user.email, + first_name=user.first_name, + last_name=user.last_name, + password=hash_password(user.password), + ) + + # Add user to database + if user.is_patient: + db_patient = usermodel.Patient( + user=db_user, + gender=user.gender, + date_of_birth=user.date_of_birth, + ) + db.add(db_patient) + else: + db_administrator = usermodel.Administrator( + user=db_user, + ) + db.add(db_administrator) + + db.commit() + + # Construct the updated user to return + db.refresh(db_user) + return _fill_missing_user_fields(db_user) + + +def read_user(db: Session, id: int) -> userschema.User | None: + """Queries the db for a user with the specified id and returns them if they exist.""" + + db_user = db.query(usermodel.User).filter(usermodel.User.id == id).first() + if not db_user: + return None + + return _fill_missing_user_fields(db_user) + + +def read_user_by_email(db: Session, email: str) -> userschema.User | None: + """Queries the db for a user with the specified email and returns them if they exist.""" + + db_user = db.query(usermodel.User).filter(usermodel.User.email == email).first() + if not db_user: + return None + + return _fill_missing_user_fields(db_user) + + +def read_users(db: Session, skip: int = 0, limit: int = 100) -> list[userschema.User]: + """Returns an unfiltered range (by id) of users in the database.""" + + db_users = db.query(usermodel.User).offset(skip).limit(limit).all() + + full_users = [] + for db_user in db_users: + full_users.append(_fill_missing_user_fields(db_user)) + return full_users + + +def update_user(db: Session, user: userschema.UserUpdate, id: int) -> userschema.User: + """Updates the user with the provided id with all non-None fields from the input user.""" + + db_user = db.query(usermodel.User).filter(usermodel.User.id == id).first() + if not db_user: + raise RuntimeError("Query returned no user.") # should be checked by caller + + for key in ['gender', 'date_of_birth']: + value = getattr(user, key) + if value is not None: + setattr(db_user.patient, key, value) + for key in ['email', 'first_name', 'last_name']: + value = getattr(user, key) + if value is not None: + setattr(db_user, key, value) + if user.password is not None: + db_user.password = hash_password(user.password) + + db.commit() + db.refresh(db_user) + return _fill_missing_user_fields(db_user) + + +def delete_user(db: Session, id: int) -> userschema.User: + """Deletes the user with the provided id from the db.""" + + db_user = db.query(usermodel.User).filter(usermodel.User.id == id).first() + if not db_user: + raise RuntimeError("Query returned no user.") # should be checked by caller + user_copy = _fill_missing_user_fields(db_user) + + db.delete(db_user) + db.commit() + + user_copy.updated = datetime.now(user_copy.updated.tzinfo) + return user_copy diff --git a/backend/database.py b/backend/database.py deleted file mode 100644 index 419a3d8..0000000 --- a/backend/database.py +++ /dev/null @@ -1,13 +0,0 @@ -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.ext.declarative import declarative_base - -from .config import get_settings - -engine = create_engine( - get_settings().pg_dsn, # Get connection string from global settings - echo=get_settings().debug_mode # Get debugmode status from global settings -) -SessionLocal = sessionmaker(engine) - -Base = declarative_base() diff --git a/backend/database/__init__.py b/backend/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/database/engine.py b/backend/database/engine.py new file mode 100644 index 0000000..3ce5cc9 --- /dev/null +++ b/backend/database/engine.py @@ -0,0 +1,19 @@ +"""This module configures and provides the sqlalchemy session factory and base model.""" + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.declarative import declarative_base + +from backend.config import get_settings + + +s = get_settings() + +# The SQL driver is specified by the DSN-prefix below. +_pg_dsn = f"postgresql+psycopg2://{s.pg_user}:{s.pg_password}@{s.pg_hostname}:{s.pg_port}/{s.pg_dbname}" +engine = create_engine(_pg_dsn, echo=s.debug_mode) + +# SQLalchemy session factory +SessionLocal = sessionmaker(engine) +# SQLalchemy base model +Base = declarative_base() diff --git a/backend/main.py b/backend/main.py index b3c2674..a701baf 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,16 +1,17 @@ -import logging +"""Main entry point for the MEDWingS backend. + +This module defines the API routes provided by the backend. +""" from fastapi import Depends, FastAPI, HTTPException from sqlalchemy.orm import Session -from . import crud, models, schemas -from.database import engine, SessionLocal +import backend.models.users as usermodel +import backend.schemas.users as userschema +import backend.crud.users as usercrud +from backend.database.engine import SessionLocal -log = logging.getLogger() - -models.Base.metadata.create_all(bind=engine) - app = FastAPI() @@ -24,43 +25,44 @@ def get_db(): @app.get("/hello/") def hello(): + """Placeholder for a proper healthcheck endpoint.""" + return "Hello World!" -@app.post("/users/", response_model=schemas.User) -def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)): - existing_user = crud.read_user_by_email(db, email=user.email) +@app.post("/users/", response_model=userschema.User) +def create_user(user: userschema.UserCreate, db: Session = Depends(get_db)): + existing_user = usercrud.read_user_by_email(db, email=user.email) if existing_user: raise HTTPException(status_code=400, detail="A user with this email address is already registered.") - - return crud.create_user(db=db, user=user) + return usercrud.create_user(db=db, user=user) -@app.get("/users/{id}", response_model=schemas.User) +@app.get("/users/{id}", response_model=userschema.User) def read_user(id: int, db: Session = Depends(get_db)): - user = crud.read_user(db=db, id=id) + user = usercrud.read_user(db=db, id=id) if not user: raise HTTPException(status_code=404, detail=f"No user with id '{id}' found.") return user -@app.get("/users/", response_model=list[schemas.User]) +@app.get("/users/", response_model=list[userschema.User]) def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): - users = crud.read_users(db=db, skip=skip, limit=limit) + users = usercrud.read_users(db=db, skip=skip, limit=limit) return users -@app.patch("/users/{id}", response_model=schemas.User) -def update_user(id: int, user: schemas.UserUpdate, db: Session = Depends(get_db)): - current_user = crud.read_user(db=db, id=id) +@app.patch("/users/{id}", response_model=userschema.User) +def update_user(id: int, user: userschema.UserUpdate, db: Session = Depends(get_db)): + current_user = usercrud.read_user(db=db, id=id) if not current_user: raise HTTPException(status_code=404, detail=f"No user with id '{id}' found.") - return crud.update_user(db=db, user=user, id=id) + return usercrud.update_user(db=db, user=user, id=id) -@app.delete("/users/{id}", response_model=schemas.User) +@app.delete("/users/{id}", response_model=userschema.User) def delete_user(id: int, db: Session = Depends(get_db)): - user = crud.read_user(db=db, id=id) + user = usercrud.read_user(db=db, id=id) if not user: raise HTTPException(status_code=404, detail=f"No user with id '{id}' found.") - return crud.delete_user(db=db, id=id) + return usercrud.delete_user(db=db, id=id) diff --git a/backend/models/__init__.py b/backend/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/models.py b/backend/models/users.py similarity index 77% rename from backend/models.py rename to backend/models/users.py index 2029665..de3cfd4 100644 --- a/backend/models.py +++ b/backend/models/users.py @@ -1,13 +1,20 @@ +"""This module defines the SQL user model for users. + +All users are either Patients or Administrators. +""" + import enum from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Date, Enum, CheckConstraint from sqlalchemy.sql.functions import now from sqlalchemy.orm import relationship -from .database import Base +from backend.database.engine import Base class User(Base): + """Model for the users table. Contains user info common to all user classes.""" + __tablename__ = "users" id = Column(Integer, primary_key=True, autoincrement=True, index=True) @@ -20,11 +27,11 @@ class User(Base): administrator = relationship("Administrator", back_populates="user", uselist=False, cascade="all, delete") patient = relationship("Patient", back_populates="user", uselist=False, cascade="all, delete") - #patient = Column(Integer, ForeignKey('patients.id'), nullable=True) - #CheckConstraint("(administrator=NULL AND patient!=NULL) OR (administrator!=NULL AND patient=NULL)") class Administrator(Base): + """Model for the administrators table. Contains user info specific to administrators.""" + __tablename__ = "administrators" user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), primary_key=True,) @@ -32,11 +39,15 @@ class Administrator(Base): class Gender(enum.Enum): + """Gender (as assigned at birth) of a patient.""" + male = 'm' female = 'f' class Patient(Base): + """Model for the patients table. Contains user info specific to patients.""" + __tablename__ = "patients" user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), primary_key=True) diff --git a/backend/schemas/__init__.py b/backend/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/schemas.py b/backend/schemas/users.py similarity index 71% rename from backend/schemas.py rename to backend/schemas/users.py index 2cf7051..f874483 100644 --- a/backend/schemas.py +++ b/backend/schemas/users.py @@ -1,13 +1,25 @@ +"""This module declared the pydantic schema representation for users. + +Note that it is not a direct representation of how users are modeled in the +database. Instead, the User schema class contains all attributes from all user classes +as optional attributes. + +I haven't figured out a smart way to do this with pydantic yet, so behold the +inheritance hellhole below. +""" + from datetime import datetime, date from abc import ABC from typing import Optional from pydantic import BaseModel, validator -from .models import Gender +from backend.models.users import Gender class AbstractUserInfoValidation(BaseModel, ABC): + """Base class providing common field validators.""" + @validator('email', check_fields=False) def assert_email_is_valid(cls, email): if email is not None: @@ -37,7 +49,14 @@ class AbstractUserInfoValidation(BaseModel, ABC): raise ValueError("Date of birth cannot be in the future.") return dob + class AbstractUser(AbstractUserInfoValidation, ABC): + """Base class for attributes common to user creation and user representation. + + A user must be either a patient or an administrator. If a user is a patient, + they must specify valid 'date_of_birth' and 'gender' attributes. + """ + email: str first_name: str last_name: str @@ -50,6 +69,8 @@ class AbstractUser(AbstractUserInfoValidation, ABC): @validator('is_admin') def assert_tegridy(cls, is_admin, values): + """Ensures logical model integrity when optional fields are set.""" + if values['is_patient']: if is_admin: raise ValueError('User cannot be both patient and admin.') @@ -62,6 +83,8 @@ class AbstractUser(AbstractUserInfoValidation, ABC): class UserCreate(AbstractUser): + """Scheme for user creation.""" + password: str password_confirmation: str @@ -76,6 +99,16 @@ class UserCreate(AbstractUser): class UserUpdate(AbstractUserInfoValidation): + """Scheme for user updates. + + All fields here are optional, but passwords must match if at least one was + provided. + Note that even administrator updates can specify 'gender' and 'date_of_birth' + fields, the function inserting the update into the db should handle this (and + just ignore the fields). + Switching user types is prohibited. + """ + email: Optional[str] first_name: Optional[str] last_name: Optional[str] @@ -99,6 +132,12 @@ class UserUpdate(AbstractUserInfoValidation): class User(AbstractUser): + """Final representation of all types of users, wrapped into one User schema. + + The id, created and updated fields are filled by the db during creation, so + they are not needed in the parent classes. + """ + id: int created: datetime updated: datetime