diff --git a/lms/djangoapps/course_home_api/course_metadata/tests/test_views.py b/lms/djangoapps/course_home_api/course_metadata/tests/test_views.py index 3e58ecd6464d..ea8ba99c16e9 100644 --- a/lms/djangoapps/course_home_api/course_metadata/tests/test_views.py +++ b/lms/djangoapps/course_home_api/course_metadata/tests/test_views.py @@ -5,18 +5,22 @@ import ddt import mock from django.urls import reverse - from edx_toggles.toggles.testutils import override_waffle_flag + from common.djangoapps.course_modes.models import CourseMode +from common.djangoapps.student.models import CourseEnrollment from common.djangoapps.student.roles import CourseInstructorRole +from common.djangoapps.student.tests.factories import UserFactory +from lms.djangoapps.course_home_api.tests.utils import BaseCourseHomeTests from lms.djangoapps.courseware.toggles import ( + COURSEWARE_MFE_MILESTONES_STREAK_DISCOUNT, COURSEWARE_MICROFRONTEND_PROGRESS_MILESTONES, - COURSEWARE_MICROFRONTEND_PROGRESS_MILESTONES_STREAK_CELEBRATION, + COURSEWARE_MICROFRONTEND_PROGRESS_MILESTONES_STREAK_CELEBRATION +) +from openedx.features.enterprise_support.tests.factories import ( + EnterpriseCourseEnrollmentFactory, + EnterpriseCustomerUserFactory ) -from common.djangoapps.student.models import CourseEnrollment -from common.djangoapps.student.tests.factories import UserFactory -from lms.djangoapps.course_home_api.tests.utils import BaseCourseHomeTests -from lms.djangoapps.courseware.toggles import COURSEWARE_MFE_MILESTONES_STREAK_DISCOUNT @ddt.ddt @@ -82,6 +86,16 @@ def test_get_unknown_course(self): response = self.client.get(url) assert response.status_code == 404 + def _assert_course_access_response(self, response, expect_course_access, expected_error_code): + """ + Responsible to asset the course_access response with expected values. + """ + assert response.status_code == 200 + assert response.data['course_access']['has_access'] == expect_course_access + assert response.data['course_access']['error_code'] == expected_error_code + # Start date is used when handling some errors, so make sure it is present too + assert response.data['start'] == self.course.start.isoformat() + 'Z' + def test_streak_data_in_response(self): """ Test that metadata endpoint returns data for the streak celebration """ CourseEnrollment.enroll(self.user, self.course.id, 'audit') @@ -138,6 +152,15 @@ def test_streak_data_in_response(self): 'dsc_required': True, 'expect_course_access': False, 'error_code': 'data_sharing_access_required' + }, + { + # Data sharing Consent required staff should Not have access. + 'enroll_user': True, + 'instructor_role': True, + 'masquerade_role': None, + 'dsc_required': True, + 'expect_course_access': False, + 'error_code': 'data_sharing_access_required' } ) @ddt.unpack @@ -159,8 +182,43 @@ def test_course_access( with mock.patch('openedx.features.enterprise_support.api.get_enterprise_consent_url', return_value=consent_url): response = self.client.get(self.url) - assert response.status_code == 200 - assert response.data['course_access']['has_access'] == expect_course_access - assert response.data['course_access']['error_code'] == error_code - # Start date is used when handling some errors, so make sure it is present too - assert response.data['start'] == self.course.start.isoformat() + 'Z' + self._assert_course_access_response(response, expect_course_access, error_code) + + @ddt.data(True, False) + def test_course_access_with_correct_active_enterprise(self, instructor_role): + """ + Test that course_access is calculated correctly based on + access to MFE and access to the course itself. + """ + if instructor_role: + CourseInstructorRole(self.course.id).add_users(self.user) + + # Test with no EnterpriseCourseEnrollment + course_enrollment = CourseEnrollment.enroll(self.user, self.course.id, 'audit') + response = self.client.get(self.url) + self._assert_course_access_response(response, True, None) + + # Test with EnterpriseCourseEnrollment and having correct active enterprise + course = course_enrollment.course + enterprise_customer_user = EnterpriseCustomerUserFactory(user_id=self.user.id) + EnterpriseCourseEnrollmentFactory(enterprise_customer_user=enterprise_customer_user, course_id=course.id) + response = self.client.get(self.url) + self._assert_course_access_response(response, True, None) + + # Test with incorrect active enterprise + enterprise_customer_user_2 = EnterpriseCustomerUserFactory(user_id=self.user.id, active=True) + enterprise_customer_user.refresh_from_db() + assert not enterprise_customer_user.active + assert enterprise_customer_user_2.active + response = self.client.get(self.url) + self._assert_course_access_response(response, False, 'incorrect_active_enterprise') + + # test when no active enterprise at all (ideally this should never happen) + enterprise_customer_user_2.active = False + enterprise_customer_user_2.save() + enterprise_customer_user.refresh_from_db() + enterprise_customer_user_2.refresh_from_db() + assert not enterprise_customer_user.active + assert not enterprise_customer_user_2.active + response = self.client.get(self.url) + self._assert_course_access_response(response, False, 'incorrect_active_enterprise') diff --git a/lms/djangoapps/course_home_api/course_metadata/views.py b/lms/djangoapps/course_home_api/course_metadata/views.py index 23de19fe5643..e534d03ebbf7 100644 --- a/lms/djangoapps/course_home_api/course_metadata/views.py +++ b/lms/djangoapps/course_home_api/course_metadata/views.py @@ -88,7 +88,7 @@ def get(self, request, *args, **kwargs): 'load', check_if_enrolled=True, check_if_authenticated=True, - check_if_dsc_required=True, + apply_enterprise_checks=True, ) _, request.user = setup_masquerade( diff --git a/lms/djangoapps/courseware/access_response.py b/lms/djangoapps/courseware/access_response.py index 9885b4169f22..abfdf61db2c6 100644 --- a/lms/djangoapps/courseware/access_response.py +++ b/lms/djangoapps/courseware/access_response.py @@ -227,6 +227,22 @@ def __init__(self): super().__init__(error_code, developer_message, user_message) +class IncorrectActiveEnterpriseAccessError(AccessError): + """ + Access denied because the user must login with correct enterprise. + """ + def __init__(self, enrollment_enterprise_name, active_enterprise_name): + error_code = "incorrect_active_enterprise" + developer_message = "User active enterprise should be same as EnterpriseCourseEnrollment enterprise." + user_message = _("You are enrolled in this course with '{enrollment_enterprise_name}'. However, you are " + "currently logged in as a '{active_enterprise_name}' user. Please log in with " + "'{enrollment_enterprise_name}' to access this course.") + user_message = user_message.format( + enrollment_enterprise_name=enrollment_enterprise_name, active_enterprise_name=active_enterprise_name + ) + super().__init__(error_code, developer_message, user_message) + + class DataSharingConsentRequiredAccessError(AccessError): """ Access denied because the user must give Data sharing consent before access it. diff --git a/lms/djangoapps/courseware/access_utils.py b/lms/djangoapps/courseware/access_utils.py index ab82296716f9..18f4bb27a8a1 100644 --- a/lms/djangoapps/courseware/access_utils.py +++ b/lms/djangoapps/courseware/access_utils.py @@ -9,6 +9,7 @@ from crum import get_current_request from django.conf import settings +from enterprise.models import EnterpriseCourseEnrollment, EnterpriseCustomerUser from pytz import UTC from common.djangoapps.student.models import CourseEnrollment @@ -18,6 +19,7 @@ AuthenticationRequiredAccessError, DataSharingConsentRequiredAccessError, EnrollmentRequiredAccessError, + IncorrectActiveEnterpriseAccessError, StartDateError ) from lms.djangoapps.courseware.masquerade import get_course_masquerade, is_masquerading_as_student @@ -178,7 +180,7 @@ def check_data_sharing_consent(course_id): from openedx.features.enterprise_support.api import get_enterprise_consent_url consent_url = get_enterprise_consent_url( request=get_current_request(), - course_id=course_id, + course_id=str(course_id), return_to='courseware', enrollment_exists=True, source='CoursewareAccess' @@ -186,3 +188,47 @@ def check_data_sharing_consent(course_id): if consent_url: return DataSharingConsentRequiredAccessError(consent_url=consent_url) return ACCESS_GRANTED + + +def check_correct_active_enterprise_customer(user, course_id): + """ + Grants access if the user's active enterprise customer is same as EnterpriseCourseEnrollment's Enterprise. + Also, Grant access if enrollment is not Enterprise + + Returns: + AccessResponse: Either ACCESS_GRANTED or IncorrectActiveEnterpriseAccessError + """ + enterprise_enrollments = EnterpriseCourseEnrollment.objects.filter( + course_id=course_id, enterprise_customer_user__user_id=user.id + ) + if not enterprise_enrollments.exists(): + return ACCESS_GRANTED + + try: + active_enterprise_customer_user = EnterpriseCustomerUser.objects.get(user_id=user.id, active=True) + if enterprise_enrollments.filter(enterprise_customer_user=active_enterprise_customer_user).exists(): + return ACCESS_GRANTED + + active_enterprise_name = active_enterprise_customer_user.enterprise_customer.name + except (EnterpriseCustomerUser.DoesNotExist, EnterpriseCustomerUser.MultipleObjectsReturned): + # Ideally this should not happen. As there should be only 1 active enterprise customer in our system + log.error("Multiple or No Active Enterprise found for the user %s.", user.id) + active_enterprise_name = 'Incorrect' + + enrollment_enterprise_name = enterprise_enrollments.first().enterprise_customer_user.enterprise_customer.name + return IncorrectActiveEnterpriseAccessError(enrollment_enterprise_name, active_enterprise_name) + + +def is_priority_access_error(access_error): + """ + Check if given access error is a priority Access Error or not. + Priority Access Error can not be bypassed by staff users. + """ + priority_access_errors = [ + DataSharingConsentRequiredAccessError, + IncorrectActiveEnterpriseAccessError, + ] + for priority_access_error in priority_access_errors: + if isinstance(access_error, priority_access_error): + return True + return False diff --git a/lms/djangoapps/courseware/courses.py b/lms/djangoapps/courseware/courses.py index 87f5a54f63bd..c7a9c981a516 100644 --- a/lms/djangoapps/courseware/courses.py +++ b/lms/djangoapps/courseware/courses.py @@ -33,7 +33,8 @@ OldMongoAccessError, StartDateError ) -from lms.djangoapps.courseware.access_utils import check_authentication, check_data_sharing_consent, check_enrollment +from lms.djangoapps.courseware.access_utils import check_authentication, check_data_sharing_consent, check_enrollment, \ + check_correct_active_enterprise_customer, is_priority_access_error from lms.djangoapps.courseware.courseware_access_exception import CoursewareAccessException from lms.djangoapps.courseware.date_summary import ( CertificateAvailableDate, @@ -138,7 +139,7 @@ def check_course_access( check_if_enrolled=False, check_survey_complete=True, check_if_authenticated=False, - check_if_dsc_required=False, + apply_enterprise_checks=False, ): """ Check that the user has the access to perform the specified action @@ -165,8 +166,12 @@ def _check_nonstaff_access(): if not enrollment_access_response: return enrollment_access_response - if check_if_dsc_required: - data_sharing_consent_response = check_data_sharing_consent(course) + if apply_enterprise_checks: + correct_active_enterprise_response = check_correct_active_enterprise_customer(user, course.id) + if not correct_active_enterprise_response: + return correct_active_enterprise_response + + data_sharing_consent_response = check_data_sharing_consent(course.id) if not data_sharing_consent_response: return data_sharing_consent_response @@ -179,15 +184,18 @@ def _check_nonstaff_access(): # This access_response will be ACCESS_GRANTED return access_response + non_staff_access_response = _check_nonstaff_access() + + # User has course access OR access error is a priority error + if non_staff_access_response or is_priority_access_error(non_staff_access_response): + return non_staff_access_response + # Allow staff full access to the course even if other checks fail - nonstaff_access_response = _check_nonstaff_access() - if not nonstaff_access_response: - staff_access_response = has_access(user, 'staff', course.id) - if staff_access_response: - return staff_access_response - - # This access_response will be ACCESS_GRANTED - return nonstaff_access_response + staff_access_response = has_access(user, 'staff', course.id) + if staff_access_response: + return staff_access_response + + return non_staff_access_response def check_course_access_with_redirect(course, user, action, check_if_enrolled=False, check_survey_complete=True, check_if_authenticated=False): # lint-amnesty, pylint: disable=line-too-long