diff --git a/courses/api.py b/courses/api.py index 224272d934..d9bdfddf77 100644 --- a/courses/api.py +++ b/courses/api.py @@ -133,6 +133,23 @@ def get_user_relevant_program_course_run_qset( return enrollable_run_qset.order_by("enrollment_start") +def run_requires_payment_for_user(user, run) -> bool: + """ + Returns True if the course run requires payment and the user has not yet paid. + + A run requires payment when all of its enrollment modes have requires_payment=True + and at least one mode is configured. If the run has no enrollment modes, no + restriction is enforced and False is returned. + """ + enrollment_modes = list(run.enrollment_modes.all()) + if not enrollment_modes: + return False + has_free_mode = any(not mode.requires_payment for mode in enrollment_modes) + if has_free_mode: + return False + return not PaidCourseRun.fulfilled_paid_course_run_exists(user, run) + + def create_local_enrollment(user, run, *, mode=EDX_DEFAULT_ENROLLMENT_MODE): """ Creates a local-only CourseRunEnrollment record without calling the edX API. diff --git a/courses/serializers/v1/courses.py b/courses/serializers/v1/courses.py index b07b8305fc..e8d28464d6 100644 --- a/courses/serializers/v1/courses.py +++ b/courses/serializers/v1/courses.py @@ -11,7 +11,7 @@ from cms.serializers import CoursePageSerializer from courses import models -from courses.api import create_run_enrollments +from courses.api import create_run_enrollments, run_requires_payment_for_user from courses.serializers.v1.base import ( BaseCourseRunEnrollmentWithFlexiblePriceSerializer, BaseCourseRunSerializer, @@ -175,6 +175,11 @@ def create(self, validated_data): if run.b2b_contract is not None: raise ValidationError({"run_id": f"Invalid course run id: {run_id}"}) + if run_requires_payment_for_user(user, run): + raise ValidationError( + {"run_id": "Payment is required to enroll in this course run."} + ) + successful_enrollments, _ = create_run_enrollments( user, [run], diff --git a/courses/serializers/v1/courses_test.py b/courses/serializers/v1/courses_test.py index 325286c3e8..71aa216270 100644 --- a/courses/serializers/v1/courses_test.py +++ b/courses/serializers/v1/courses_test.py @@ -3,6 +3,7 @@ import pytest from django.contrib.auth.models import AnonymousUser from django.db.models import Prefetch +from rest_framework.exceptions import ValidationError from cms.factories import CoursePageFactory, FlexiblePricingFormFactory from cms.serializers import CoursePageSerializer @@ -11,8 +12,10 @@ CourseRunEnrollmentFactory, CourseRunFactory, CourseRunGradeFactory, + EnrollmentModeFactory, + UserFactory, ) -from courses.models import Course, CourseRun, Department +from courses.models import Course, CourseRun, Department, PaidCourseRun from courses.serializers.v1.base import BaseCourseSerializer, CourseRunGradeSerializer from courses.serializers.v1.courses import ( CourseRunEnrollmentSerializer, @@ -22,10 +25,13 @@ CourseWithCourseRunsSerializer, ) from courses.serializers.v1.programs import ProgramSerializer +from ecommerce.factories import OrderFactory +from ecommerce.models import OrderStatus from ecommerce.serializers.v0 import BaseProductSerializer from flexiblepricing.constants import FlexiblePriceStatus from flexiblepricing.factories import FlexiblePriceFactory from main.test_utils import assert_drf_json_equal, drf_datetime +from openedx.constants import EDX_ENROLLMENT_AUDIT_MODE, EDX_ENROLLMENT_VERIFIED_MODE pytestmark = [pytest.mark.django_db] @@ -263,3 +269,86 @@ def test_serialize_course_run_enrollments_with_grades(): "certificate": None, "grades": CourseRunGradeSerializer([grade], many=True).data, } + + +class TestCourseRunEnrollmentSerializerCreate: + """Tests for enrollment creation validation in CourseRunEnrollmentSerializer.""" + + def test_create_enrollment_allowed_when_audit_mode_available(self, mocker): + """Free enrollment is permitted when the run has a free (non-payment-required) mode.""" + user = UserFactory.create() + run = CourseRunFactory.create( + enrollment_modes=[ + EnrollmentModeFactory.create( + mode_slug=EDX_ENROLLMENT_AUDIT_MODE, requires_payment=False + ) + ] + ) + mocker.patch( + "courses.serializers.v1.courses.create_run_enrollments", + return_value=([mocker.Mock()], True), + ) + serializer = CourseRunEnrollmentSerializer( + data={"run_id": run.id}, context={"user": user} + ) + assert serializer.is_valid(), serializer.errors + result = serializer.save() + assert result is not None + + def test_create_enrollment_blocked_when_payment_required_and_not_paid(self): + """Enrollment is rejected when all modes require payment and the user has not paid.""" + user = UserFactory.create() + run = CourseRunFactory.create( + enrollment_modes=[ + EnrollmentModeFactory.create( + mode_slug=EDX_ENROLLMENT_VERIFIED_MODE, requires_payment=True + ) + ] + ) + serializer = CourseRunEnrollmentSerializer( + data={"run_id": run.id}, context={"user": user} + ) + assert serializer.is_valid(), serializer.errors + with pytest.raises(ValidationError) as exc_info: + serializer.save() + assert "run_id" in exc_info.value.detail + + def test_create_enrollment_allowed_when_payment_required_and_user_paid( + self, mocker + ): + """Enrollment is permitted when all modes require payment but the user has paid.""" + user = UserFactory.create() + run = CourseRunFactory.create( + enrollment_modes=[ + EnrollmentModeFactory.create( + mode_slug=EDX_ENROLLMENT_VERIFIED_MODE, requires_payment=True + ) + ] + ) + order = OrderFactory.create(purchaser=user, state=OrderStatus.FULFILLED) + PaidCourseRun.objects.create(user=user, course_run=run, order=order) + mocker.patch( + "courses.serializers.v1.courses.create_run_enrollments", + return_value=([mocker.Mock()], True), + ) + serializer = CourseRunEnrollmentSerializer( + data={"run_id": run.id}, context={"user": user} + ) + assert serializer.is_valid(), serializer.errors + result = serializer.save() + assert result is not None + + def test_create_enrollment_allowed_when_no_enrollment_modes(self, mocker): + """Enrollment is permitted when no enrollment modes are set (no restriction enforced).""" + user = UserFactory.create() + run = CourseRunFactory.create(enrollment_modes=[]) + mocker.patch( + "courses.serializers.v1.courses.create_run_enrollments", + return_value=([mocker.Mock()], True), + ) + serializer = CourseRunEnrollmentSerializer( + data={"run_id": run.id}, context={"user": user} + ) + assert serializer.is_valid(), serializer.errors + result = serializer.save() + assert result is not None diff --git a/courses/serializers/v2/courses.py b/courses/serializers/v2/courses.py index bd12ee484a..64616e3ef0 100644 --- a/courses/serializers/v2/courses.py +++ b/courses/serializers/v2/courses.py @@ -12,7 +12,7 @@ from cms.serializers import CoursePageSerializer from courses import models -from courses.api import create_run_enrollments +from courses.api import create_run_enrollments, run_requires_payment_for_user from courses.serializers.utils import get_topics_from_page from courses.serializers.v1.base import ( BaseCourseRunEnrollmentWithFlexiblePriceSerializer, @@ -398,6 +398,12 @@ def create(self, validated_data): if run.b2b_contract is not None: raise ValidationError({"run_id": f"Invalid course run id: {run_id}"}) + + if run_requires_payment_for_user(user, run): + raise ValidationError( + {"run_id": "Payment is required to enroll in this course run."} + ) + successful_enrollments, _ = create_run_enrollments( user, [run], diff --git a/courses/serializers/v2/courses_test.py b/courses/serializers/v2/courses_test.py index 6e2f4fb2fb..9054a3ee3e 100644 --- a/courses/serializers/v2/courses_test.py +++ b/courses/serializers/v2/courses_test.py @@ -414,3 +414,81 @@ def test_course_serializer_language_options(): # run_tag should be present in each option for opt in serializer.data["language_options"]: assert opt["run_tag"] == "1T2026" + + +class TestCourseRunEnrollmentSerializerV2PaymentGuard: + """Tests for payment guard in v2 CourseRunEnrollmentSerializer.create().""" + + def test_blocked_when_payment_required_and_not_paid(self): + """Enrollment is rejected when all modes require payment and the user has not paid.""" + from rest_framework.exceptions import ValidationError # noqa: PLC0415 + + from users.factories import UserFactory # noqa: PLC0415 + + user = UserFactory.create() + run = CourseRunFactory.create( + enrollment_modes=[ + EnrollmentModeFactory.create( + mode_slug=EDX_ENROLLMENT_VERIFIED_MODE, requires_payment=True + ) + ] + ) + serializer = CourseRunEnrollmentSerializer( + data={"run_id": run.id}, context={"user": user} + ) + assert serializer.is_valid(), serializer.errors + with pytest.raises(ValidationError) as exc_info: + serializer.save() + assert "run_id" in exc_info.value.detail + + def test_allowed_when_audit_mode_available(self, mocker): + """Free enrollment is permitted when the run has a free mode.""" + from openedx.constants import EDX_ENROLLMENT_AUDIT_MODE # noqa: PLC0415 + from users.factories import UserFactory # noqa: PLC0415 + + user = UserFactory.create() + run = CourseRunFactory.create( + enrollment_modes=[ + EnrollmentModeFactory.create( + mode_slug=EDX_ENROLLMENT_AUDIT_MODE, requires_payment=False + ) + ] + ) + mocker.patch( + "courses.serializers.v2.courses.create_run_enrollments", + return_value=([mocker.Mock()], True), + ) + serializer = CourseRunEnrollmentSerializer( + data={"run_id": run.id}, context={"user": user} + ) + assert serializer.is_valid(), serializer.errors + result = serializer.save() + assert result is not None + + def test_allowed_when_user_has_paid(self, mocker): + """Enrollment is permitted when all modes require payment but the user has paid.""" + from courses.models import PaidCourseRun # noqa: PLC0415 + from ecommerce.factories import OrderFactory # noqa: PLC0415 + from ecommerce.models import OrderStatus # noqa: PLC0415 + from users.factories import UserFactory # noqa: PLC0415 + + user = UserFactory.create() + run = CourseRunFactory.create( + enrollment_modes=[ + EnrollmentModeFactory.create( + mode_slug=EDX_ENROLLMENT_VERIFIED_MODE, requires_payment=True + ) + ] + ) + order = OrderFactory.create(purchaser=user, state=OrderStatus.FULFILLED) + PaidCourseRun.objects.create(user=user, course_run=run, order=order) + mocker.patch( + "courses.serializers.v2.courses.create_run_enrollments", + return_value=([mocker.Mock()], True), + ) + serializer = CourseRunEnrollmentSerializer( + data={"run_id": run.id}, context={"user": user} + ) + assert serializer.is_valid(), serializer.errors + result = serializer.save() + assert result is not None diff --git a/courses/serializers/v3/courses.py b/courses/serializers/v3/courses.py index 9f9dd2e28d..8c7d5f56ff 100644 --- a/courses/serializers/v3/courses.py +++ b/courses/serializers/v3/courses.py @@ -10,7 +10,7 @@ from rest_framework.exceptions import ValidationError from courses import models -from courses.api import create_run_enrollments +from courses.api import create_run_enrollments, run_requires_payment_for_user from courses.serializers.v1.base import ( BaseCourseRunEnrollmentSerializer, BaseCourseRunSerializer, @@ -113,6 +113,11 @@ def create(self, validated_data): if run is None or run.b2b_contract_id is not None: raise ValidationError({"run_id": f"Invalid course run id: {run_id}"}) + if run_requires_payment_for_user(user, run): + raise ValidationError( + {"run_id": "Payment is required to enroll in this course run."} + ) + successful_enrollments, _ = create_run_enrollments( user, [run], diff --git a/courses/serializers/v3/courses_test.py b/courses/serializers/v3/courses_test.py index b0fdd9e485..629059cb75 100644 --- a/courses/serializers/v3/courses_test.py +++ b/courses/serializers/v3/courses_test.py @@ -80,3 +80,95 @@ def test_serializer_upgrade_fields_null_when_not_eligible(self): assert serialized_data["run"]["upgrade_product_id"] is None assert serialized_data["run"]["upgrade_product_price"] is None assert serialized_data["run"]["upgrade_product_is_active"] is None + + +class TestCourseRunEnrollmentSerializerV3PaymentGuard: + """Tests for payment guard in v3 CourseRunEnrollmentSerializer.create().""" + + def test_blocked_when_payment_required_and_not_paid(self): + """Enrollment is rejected when all modes require payment and the user has not paid.""" + from rest_framework.exceptions import ValidationError # noqa: PLC0415 + + from courses.factories import ( # noqa: PLC0415 + CourseRunFactory, + EnrollmentModeFactory, + ) + from openedx.constants import EDX_ENROLLMENT_VERIFIED_MODE # noqa: PLC0415 + from users.factories import UserFactory # noqa: PLC0415 + + user = UserFactory.create() + run = CourseRunFactory.create( + enrollment_modes=[ + EnrollmentModeFactory.create( + mode_slug=EDX_ENROLLMENT_VERIFIED_MODE, requires_payment=True + ) + ] + ) + serializer = CourseRunEnrollmentSerializer( + data={"run_id": run.id}, context={"user": user} + ) + assert serializer.is_valid(), serializer.errors + with pytest.raises(ValidationError) as exc_info: + serializer.save() + assert "run_id" in exc_info.value.detail + + def test_allowed_when_audit_mode_available(self, mocker): + """Free enrollment is permitted when the run has a free mode.""" + from courses.factories import ( # noqa: PLC0415 + CourseRunFactory, + EnrollmentModeFactory, + ) + from openedx.constants import EDX_ENROLLMENT_AUDIT_MODE # noqa: PLC0415 + from users.factories import UserFactory # noqa: PLC0415 + + user = UserFactory.create() + run = CourseRunFactory.create( + enrollment_modes=[ + EnrollmentModeFactory.create( + mode_slug=EDX_ENROLLMENT_AUDIT_MODE, requires_payment=False + ) + ] + ) + mocker.patch( + "courses.serializers.v3.courses.create_run_enrollments", + return_value=([mocker.Mock()], True), + ) + serializer = CourseRunEnrollmentSerializer( + data={"run_id": run.id}, context={"user": user} + ) + assert serializer.is_valid(), serializer.errors + result = serializer.save() + assert result is not None + + def test_allowed_when_user_has_paid(self, mocker): + """Enrollment is permitted when all modes require payment but the user has paid.""" + from courses.factories import ( # noqa: PLC0415 + CourseRunFactory, + EnrollmentModeFactory, + ) + from courses.models import PaidCourseRun # noqa: PLC0415 + from ecommerce.factories import OrderFactory # noqa: PLC0415 + from ecommerce.models import OrderStatus # noqa: PLC0415 + from openedx.constants import EDX_ENROLLMENT_VERIFIED_MODE # noqa: PLC0415 + from users.factories import UserFactory # noqa: PLC0415 + + user = UserFactory.create() + run = CourseRunFactory.create( + enrollment_modes=[ + EnrollmentModeFactory.create( + mode_slug=EDX_ENROLLMENT_VERIFIED_MODE, requires_payment=True + ) + ] + ) + order = OrderFactory.create(purchaser=user, state=OrderStatus.FULFILLED) + PaidCourseRun.objects.create(user=user, course_run=run, order=order) + mocker.patch( + "courses.serializers.v3.courses.create_run_enrollments", + return_value=([mocker.Mock()], True), + ) + serializer = CourseRunEnrollmentSerializer( + data={"run_id": run.id}, context={"user": user} + ) + assert serializer.is_valid(), serializer.errors + result = serializer.save() + assert result is not None diff --git a/courses/views/v1/__init__.py b/courses/views/v1/__init__.py index 1834d51b94..2f1a736722 100644 --- a/courses/views/v1/__init__.py +++ b/courses/views/v1/__init__.py @@ -32,6 +32,7 @@ deactivate_run_enrollment, get_relevant_course_run_qset, get_user_relevant_program_course_run_qset, + run_requires_payment_for_user, ) from courses.constants import ENROLL_CHANGE_STATUS_UNENROLLED from courses.models import ( @@ -72,6 +73,7 @@ USER_MSG_TYPE_ENROLL_BLOCKED, USER_MSG_TYPE_ENROLL_DUPLICATED, USER_MSG_TYPE_ENROLL_FAILED, + USER_MSG_TYPE_ENROLL_PAYMENT_REQUIRED, USER_MSG_TYPE_ENROLLED, ) from main.utils import encode_json_cookie_value, redirect_with_user_message @@ -340,6 +342,18 @@ def _validate_enrollment_post_request( {"type": USER_MSG_TYPE_ENROLL_DUPLICATED}, ) return resp, None, None + if run_requires_payment_for_user(user, run): + resp = HttpResponseRedirect(request.headers["Referer"]) + resp.set_cookie( + key=USER_MSG_COOKIE_NAME, + value=encode_json_cookie_value( + { + "type": USER_MSG_TYPE_ENROLL_PAYMENT_REQUIRED, + } + ), + max_age=USER_MSG_COOKIE_MAX_AGE, + ) + return resp, None, None return None, user, run diff --git a/courses/views/v1/views_test.py b/courses/views/v1/views_test.py index 1b4c502465..40726f0b12 100644 --- a/courses/views/v1/views_test.py +++ b/courses/views/v1/views_test.py @@ -70,6 +70,7 @@ USER_MSG_COOKIE_NAME, USER_MSG_TYPE_ENROLL_BLOCKED, USER_MSG_TYPE_ENROLL_FAILED, + USER_MSG_TYPE_ENROLL_PAYMENT_REQUIRED, USER_MSG_TYPE_ENROLLED, ) from main.test_utils import assert_drf_json_equal, duplicate_queries_check @@ -733,6 +734,36 @@ def test_create_enrollments_blocked_country(user_client, user): ) +def test_create_enrollments_payment_required(user_client, user): + """ + Create enrollment view should redirect with a payment-required message if all enrollment + modes on the run require payment and the user has not paid. + """ + from courses.factories import EnrollmentModeFactory # noqa: PLC0415 + from openedx.constants import EDX_ENROLLMENT_VERIFIED_MODE # noqa: PLC0415 + + run = CourseRunFactory.create( + enrollment_modes=[ + EnrollmentModeFactory.create( + mode_slug=EDX_ENROLLMENT_VERIFIED_MODE, requires_payment=True + ) + ] + ) + resp = user_client.post( + reverse("create-enrollment-via-form"), + data={"run": str(run.id)}, + HTTP_REFERER=EXAMPLE_URL, + ) + assert resp.status_code == status.HTTP_302_FOUND + assert resp.url == EXAMPLE_URL + assert USER_MSG_COOKIE_NAME in resp.cookies + assert resp.cookies[USER_MSG_COOKIE_NAME].value == encode_json_cookie_value( + { + "type": USER_MSG_TYPE_ENROLL_PAYMENT_REQUIRED, + } + ) + + @pytest.mark.parametrize("receive_emails", [True, False]) def test_update_user_enrollment(mocker, user_drf_client, user, receive_emails): """The enrollment should update the course email subscriptions""" diff --git a/main/constants.py b/main/constants.py index be0f76bcb8..17e4dc7ac3 100644 --- a/main/constants.py +++ b/main/constants.py @@ -5,6 +5,7 @@ USER_MSG_TYPE_ENROLL_FAILED = "enroll-failed" USER_MSG_TYPE_ENROLL_BLOCKED = "enroll-blocked" USER_MSG_TYPE_ENROLL_DUPLICATED = "enroll-duplicated" +USER_MSG_TYPE_ENROLL_PAYMENT_REQUIRED = "enroll-payment-required" USER_MSG_TYPE_COMPLETED_AUTH = "completed-auth" USER_MSG_TYPE_COURSE_NON_UPGRADABLE = "course-non-upgradable" USER_MSG_TYPE_DISCOUNT_INVALID = "discount-invalid"