diff --git a/backend/app/models.py b/backend/app/models.py index 8b96c88..3431fc8 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -5,6 +5,13 @@ from datetime import datetime from werkzeug.security import generate_password_hash, check_password_hash +enrollment = db.Table( + "enrollment", + sa.Column("user_id", sa.ForeignKey("user.id"), primary_key=True), + sa.Column("course_id", sa.ForeignKey("course.id"), primary_key=True), +) + + class User(UserMixin, db.Model): id = sa.Column(sa.Integer, primary_key=True) username = sa.Column(sa.String(64), index=True, unique=True) @@ -13,24 +20,39 @@ class User(UserMixin, db.Model): password_hash = sa.Column(sa.String(128)) last_seen = sa.Column(sa.DateTime, default=datetime.utcnow) token = sa.Column(sa.String(32), index=True, unique=True) + enrolled_courses = db.relationship( + "Course", + secondary=enrollment, + backref=db.backref("students", lazy="dynamic"), + lazy="dynamic", + ) - def __repr__(self): + def __repr__(self) -> str: return f"" - def set_password(self, password): + def set_password(self, password) -> None: self.password_hash = generate_password_hash(password) - def check_password(self, password): + def check_password(self, password) -> bool: return check_password_hash(self.password_hash, password) - def to_dict(self): + def is_enrolled(self, c) -> bool: + return self.enrolled_courses.filter(enrollment.c.course_id == c.id).count() > 0 + + def enroll(self, c) -> bool: + if not self.is_enrolled(c): + self.enrolled_courses.append(c) + return True + return False + + def to_dict(self) -> dict: return { "id": self.id, "username": self.username, "email": self.email, } - def from_dict(self, data, new_user=False): + def from_dict(self, data, new_user=False) -> None: for field in ["role", "username", "email"]: if field in data: setattr(self, field, data[field]) @@ -45,7 +67,10 @@ class Course(db.Model): instructor = sa.Column(sa.ForeignKey(User.id), index=True) created_at = sa.Column(sa.DateTime) - def from_dict(self, data): + def __repr__(self) -> str: + return f"" + + def from_dict(self, data) -> None: for field in ["name", "description", "instructor"]: if field in data: setattr(self, field, data[field]) @@ -53,8 +78,10 @@ class Course(db.Model): if not self.created_at: self.created_at = datetime.now() - def to_dict(self): + def to_dict(self) -> dict: d = {} - for f in ["id", "name", "description", "instructor", "created_at"]: + for f in ["id", "name", "description", "created_at"]: d[f] = getattr(self, f) + + d["instructor"] = User.query.get(self.instructor).username return d diff --git a/backend/app/routes.py b/backend/app/routes.py index 5ee41b6..fd86958 100644 --- a/backend/app/routes.py +++ b/backend/app/routes.py @@ -95,3 +95,12 @@ def create_course(): return jsonify(c.to_dict()) + +@bp.route("/user//courses", methods=["GET"]) +def get_courses(id): + u = User.query.get(id) + d = {"courses": []} + for c in u.enrolled_courses.all(): + d["courses"].append(c.to_dict()) + resp = jsonify(d) + return resp diff --git a/backend/migrations/versions/093a66f0b581_.py b/backend/migrations/versions/093a66f0b581_.py new file mode 100644 index 0000000..b68afa1 --- /dev/null +++ b/backend/migrations/versions/093a66f0b581_.py @@ -0,0 +1,34 @@ +"""Create enrollment table + +Revision ID: 093a66f0b581 +Revises: 471b4225837e +Create Date: 2023-04-06 16:14:21.262823 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '093a66f0b581' +down_revision = '471b4225837e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('enrollment', + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('course_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['course_id'], ['course.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), + sa.PrimaryKeyConstraint('user_id', 'course_id') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('enrollment') + # ### end Alembic commands ###