diff --git a/apps/analysis/apps.py b/apps/analysis/apps.py index 126ab0e58f..9980ea0351 100644 --- a/apps/analysis/apps.py +++ b/apps/analysis/apps.py @@ -2,4 +2,4 @@ class AnalysisConfig(AppConfig): - name = 'analysis' + name = "analysis" diff --git a/apps/analysis/dataloaders.py b/apps/analysis/dataloaders.py index 76d6eb8da2..34cab8ca51 100644 --- a/apps/analysis/dataloaders.py +++ b/apps/analysis/dataloaders.py @@ -1,8 +1,8 @@ -from promise import Promise from collections import defaultdict -from django.utils.functional import cached_property from django.db import models +from django.utils.functional import cached_property +from promise import Promise from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin @@ -10,28 +10,30 @@ Analysis, AnalysisPillar, AnalysisReport, + AnalysisReportContainer, + AnalysisReportContainerData, + AnalysisReportSnapshot, + AnalysisReportUpload, AnalyticalStatement, AnalyticalStatementEntry, DiscardedEntry, TopicModelCluster, - AnalysisReportUpload, - AnalysisReportContainerData, - AnalysisReportContainer, - AnalysisReportSnapshot, ) class AnalysisPublicationDatesLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - qs = AnalyticalStatementEntry.objects.filter( - analytical_statement__analysis_pillar__analysis__in=keys, - ).order_by().values('analytical_statement__analysis_pillar__analysis').annotate( - published_on_min=models.Min('entry__lead__published_on'), - published_on_max=models.Max('entry__lead__published_on'), - ).values_list( - 'published_on_min', - 'published_on_max', - 'analytical_statement__analysis_pillar__analysis' + qs = ( + AnalyticalStatementEntry.objects.filter( + analytical_statement__analysis_pillar__analysis__in=keys, + ) + .order_by() + .values("analytical_statement__analysis_pillar__analysis") + .annotate( + published_on_min=models.Min("entry__lead__published_on"), + published_on_max=models.Max("entry__lead__published_on"), + ) + .values_list("published_on_min", "published_on_max", "analytical_statement__analysis_pillar__analysis") ) _map = {} for start_date, end_date, _id in qs: @@ -45,17 +47,13 @@ def batch_load_fn(self, keys): class AnalysisAnalyzedEntriesLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - _map = Analysis.get_analyzed_entries([ - Analysis(id=key) for key in keys - ]) + _map = Analysis.get_analyzed_entries([Analysis(id=key) for key in keys]) return Promise.resolve([_map.get(key, 0) for key in keys]) class AnalysisAnalyzedLeadsLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - _map = Analysis.get_analyzed_sources([ - Analysis(id=key) for key in keys - ]) + _map = Analysis.get_analyzed_sources([Analysis(id=key) for key in keys]) return Promise.resolve([_map.get(key, 0) for key in keys]) @@ -88,57 +86,71 @@ def batch_load_fn(self, keys): class AnalysisPillarsAnalyzedEntriesLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - qs = AnalysisPillar.objects\ - .filter(id__in=keys)\ + qs = ( + AnalysisPillar.objects.filter(id__in=keys) .annotate( dragged_entries=models.functions.Coalesce( models.Subquery( - AnalyticalStatement.objects.filter( - analysis_pillar=models.OuterRef('pk') - ).order_by().values('analysis_pillar').annotate(count=models.Count( - 'entries', - distinct=True, - filter=models.Q(entries__lead__published_on__lte=models.OuterRef('analysis__end_date')))) - .values('count')[:1], + AnalyticalStatement.objects.filter(analysis_pillar=models.OuterRef("pk")) + .order_by() + .values("analysis_pillar") + .annotate( + count=models.Count( + "entries", + distinct=True, + filter=models.Q(entries__lead__published_on__lte=models.OuterRef("analysis__end_date")), + ) + ) + .values("count")[:1], output_field=models.IntegerField(), - ), 0), + ), + 0, + ), discarded_entries=models.functions.Coalesce( models.Subquery( - DiscardedEntry.objects.filter( - analysis_pillar=models.OuterRef('pk') - ).order_by().values('analysis_pillar__analysis').annotate(count=models.Count( - 'entry', - distinct=True, - filter=models.Q(entry__lead__published_on__lte=models.OuterRef('analysis__end_date')))) - .values('count')[:1], + DiscardedEntry.objects.filter(analysis_pillar=models.OuterRef("pk")) + .order_by() + .values("analysis_pillar__analysis") + .annotate( + count=models.Count( + "entry", + distinct=True, + filter=models.Q(entry__lead__published_on__lte=models.OuterRef("analysis__end_date")), + ) + ) + .values("count")[:1], output_field=models.IntegerField(), - ), 0), - analyzed_entries=models.F('dragged_entries') + models.F('discarded_entries'), - ).values_list('id', 'analyzed_entries') - _map = { - _id: count - for _id, count in qs - } + ), + 0, + ), + analyzed_entries=models.F("dragged_entries") + models.F("discarded_entries"), + ) + .values_list("id", "analyzed_entries") + ) + _map = {_id: count for _id, count in qs} return Promise.resolve([_map.get(key, 0) for key in keys]) class AnalysisStatementAnalyzedEntriesLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - qs = AnalyticalStatement.objects.filter(id__in=keys).annotate( - count=models.Count('entries', distinct=True) - ).values('id', 'count') - _map = { - _id: count - for _id, count in qs - } + qs = ( + AnalyticalStatement.objects.filter(id__in=keys) + .annotate(count=models.Count("entries", distinct=True)) + .values("id", "count") + ) + _map = {_id: count for _id, count in qs} return Promise.resolve([_map.get(key, 0) for key in keys]) class AnalysisTopicModelClusterEntryLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - qs = TopicModelCluster.entries.through.objects.filter( - topicmodelcluster__in=keys, - ).select_related('entry').order_by('topicmodelcluster', 'entry_id') + qs = ( + TopicModelCluster.entries.through.objects.filter( + topicmodelcluster__in=keys, + ) + .select_related("entry") + .order_by("topicmodelcluster", "entry_id") + ) _map = defaultdict(list) for cluster_entry in qs: _map[cluster_entry.topicmodelcluster_id].append(cluster_entry.entry) @@ -151,10 +163,7 @@ def batch_load_fn(self, keys): qs = AnalysisReportUpload.objects.filter( id__in=keys, ) - _map = { - item.pk: item - for item in qs - } + _map = {item.pk: item for item in qs} return Promise.resolve([_map.get(key, []) for key in keys]) @@ -173,7 +182,7 @@ class OrganizationByAnalysisReportLoader(DataLoaderWithContext): def batch_load_fn(self, keys): qs = AnalysisReport.organizations.through.objects.filter( analysisreport__in=keys, - ).select_related('organization') + ).select_related("organization") _map = defaultdict(list) for item in qs: _map[item.analysisreport_id].append(item.organization) @@ -204,13 +213,14 @@ def batch_load_fn(self, keys): class LatestReportSnapshotByAnalysisReportLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - qs = AnalysisReportSnapshot.objects.filter( - report__in=keys, - ).order_by('report_id', '-published_on').distinct('report_id') - _map = { - snapshot.report_id: snapshot - for snapshot in qs - } + qs = ( + AnalysisReportSnapshot.objects.filter( + report__in=keys, + ) + .order_by("report_id", "-published_on") + .distinct("report_id") + ) + _map = {snapshot.report_id: snapshot for snapshot in qs} return Promise.resolve([_map.get(key) for key in keys]) diff --git a/apps/analysis/enums.py b/apps/analysis/enums.py index 41111c31a8..35203f0419 100644 --- a/apps/analysis/enums.py +++ b/apps/analysis/enums.py @@ -4,79 +4,88 @@ ) from .models import ( + AnalysisReportContainer, + AnalysisReportUpload, + AnalyticalStatementGeoTask, + AnalyticalStatementNGram, + AutomaticSummary, DiscardedEntry, TopicModel, - AutomaticSummary, - AnalyticalStatementNGram, - AnalyticalStatementGeoTask, - AnalysisReportUpload, - AnalysisReportContainer, ) from .serializers import ( - ReportEnum, - AnalysisReportVariableSerializer, - AnalysisReportTextStyleSerializer, + AnalysisReportBarChartConfigurationSerializer, AnalysisReportBorderStyleSerializer, - AnalysisReportImageContentStyleSerializer, + AnalysisReportCategoricalLegendStyleSerializer, AnalysisReportHeadingConfigurationSerializer, AnalysisReportHorizontalAxisSerializer, - AnalysisReportVerticalAxisSerializer, - AnalysisReportBarChartConfigurationSerializer, - AnalysisReportCategoricalLegendStyleSerializer, + AnalysisReportImageContentStyleSerializer, + AnalysisReportLineLayerStyleSerializer, AnalysisReportMapLayerConfigurationSerializer, AnalysisReportSymbolLayerConfigurationSerializer, - AnalysisReportLineLayerStyleSerializer, + AnalysisReportTextStyleSerializer, + AnalysisReportVariableSerializer, + AnalysisReportVerticalAxisSerializer, + ReportEnum, ) +DiscardedEntryTagTypeEnum = convert_enum_to_graphene_enum(DiscardedEntry.TagType, name="DiscardedEntryTagTypeEnum") -DiscardedEntryTagTypeEnum = convert_enum_to_graphene_enum(DiscardedEntry.TagType, name='DiscardedEntryTagTypeEnum') - -TopicModelStatusEnum = convert_enum_to_graphene_enum(TopicModel.Status, name='TopicModelStatusEnum') +TopicModelStatusEnum = convert_enum_to_graphene_enum(TopicModel.Status, name="TopicModelStatusEnum") -AutomaticSummaryStatusEnum = convert_enum_to_graphene_enum(AutomaticSummary.Status, name='AutomaticSummaryStatusEnum') +AutomaticSummaryStatusEnum = convert_enum_to_graphene_enum(AutomaticSummary.Status, name="AutomaticSummaryStatusEnum") AnalyticalStatementNGramStatusEnum = convert_enum_to_graphene_enum( - AnalyticalStatementNGram.Status, name='AnalyticalStatementNGramStatusEnum') + AnalyticalStatementNGram.Status, name="AnalyticalStatementNGramStatusEnum" +) AnalyticalStatementGeoTaskStatusEnum = convert_enum_to_graphene_enum( - AnalyticalStatementGeoTask.Status, name='AnalyticalStatementGeoTaskStatusEnum') + AnalyticalStatementGeoTask.Status, name="AnalyticalStatementGeoTaskStatusEnum" +) # Analysis Report -AnalysisReportUploadTypeEnum = convert_enum_to_graphene_enum(AnalysisReportUpload.Type, name='AnalysisReportUploadTypeEnum') +AnalysisReportUploadTypeEnum = convert_enum_to_graphene_enum(AnalysisReportUpload.Type, name="AnalysisReportUploadTypeEnum") AnalysisReportContainerContentTypeEnum = convert_enum_to_graphene_enum( - AnalysisReportContainer.ContentType, name='AnalysisReportContainerContentTypeEnum') + AnalysisReportContainer.ContentType, name="AnalysisReportContainerContentTypeEnum" +) # Client Side Enums -AnalysisReportVariableTypeEnum = convert_enum_to_graphene_enum( - ReportEnum.VariableType, name='AnalysisReportVariableTypeEnum') +AnalysisReportVariableTypeEnum = convert_enum_to_graphene_enum(ReportEnum.VariableType, name="AnalysisReportVariableTypeEnum") AnalysisReportTextStyleAlignEnum = convert_enum_to_graphene_enum( - ReportEnum.TextStyleAlign, name='AnalysisReportTextStyleAlignEnum') + ReportEnum.TextStyleAlign, name="AnalysisReportTextStyleAlignEnum" +) AnalysisReportBorderStyleStyleEnum = convert_enum_to_graphene_enum( - ReportEnum.BorderStyleStyle, name='AnalysisReportBorderStyleStyleEnum') + ReportEnum.BorderStyleStyle, name="AnalysisReportBorderStyleStyleEnum" +) AnalysisReportImageContentStyleFitEnum = convert_enum_to_graphene_enum( - ReportEnum.ImageContentStyleFit, name='AnalysisReportImageContentStyleFitEnum') + ReportEnum.ImageContentStyleFit, name="AnalysisReportImageContentStyleFitEnum" +) AnalysisReportHeadingConfigurationVariantEnum = convert_enum_to_graphene_enum( - ReportEnum.HeadingConfigurationVariant, name='AnalysisReportHeadingConfigurationVariantEnum') + ReportEnum.HeadingConfigurationVariant, name="AnalysisReportHeadingConfigurationVariantEnum" +) AnalysisReportHorizontalAxisTypeEnum = convert_enum_to_graphene_enum( - ReportEnum.HorizontalAxisType, name='AnalysisReportHorizontalAxisTypeEnum') -AnalysisReportBarChartTypeEnum = convert_enum_to_graphene_enum( - ReportEnum.BarChartType, name='AnalysisReportBarChartTypeEnum') + ReportEnum.HorizontalAxisType, name="AnalysisReportHorizontalAxisTypeEnum" +) +AnalysisReportBarChartTypeEnum = convert_enum_to_graphene_enum(ReportEnum.BarChartType, name="AnalysisReportBarChartTypeEnum") AnalysisReportBarChartDirectionEnum = convert_enum_to_graphene_enum( - ReportEnum.BarChartDirection, name='AnalysisReportBarChartDirectionEnum') + ReportEnum.BarChartDirection, name="AnalysisReportBarChartDirectionEnum" +) AnalysisReportLegendPositionEnum = convert_enum_to_graphene_enum( - ReportEnum.LegendPosition, name='AnalysisReportLegendPositionEnum') + ReportEnum.LegendPosition, name="AnalysisReportLegendPositionEnum" +) AnalysisReportLegendDotShapeEnum = convert_enum_to_graphene_enum( - ReportEnum.LegendDotShape, name='AnalysisReportLegendDotShapeEnum') + ReportEnum.LegendDotShape, name="AnalysisReportLegendDotShapeEnum" +) AnalysisReportAggregationTypeEnum = convert_enum_to_graphene_enum( - ReportEnum.AggregationType, name='AnalysisReportAggregationTypeEnum') -AnalysisReportMapLayerTypeEnum = convert_enum_to_graphene_enum( - ReportEnum.MapLayerType, name='AnalysisReportMapLayerTypeEnum') -AnalysisReportScaleTypeEnum = convert_enum_to_graphene_enum( - ReportEnum.ScaleType, name='AnalysisReportScaleTypeEnum') + ReportEnum.AggregationType, name="AnalysisReportAggregationTypeEnum" +) +AnalysisReportMapLayerTypeEnum = convert_enum_to_graphene_enum(ReportEnum.MapLayerType, name="AnalysisReportMapLayerTypeEnum") +AnalysisReportScaleTypeEnum = convert_enum_to_graphene_enum(ReportEnum.ScaleType, name="AnalysisReportScaleTypeEnum") AnalysisReportScalingTechniqueEnum = convert_enum_to_graphene_enum( - ReportEnum.ScalingTechnique, name='AnalysisReportScalingTechniqueEnum') + ReportEnum.ScalingTechnique, name="AnalysisReportScalingTechniqueEnum" +) AnalysisReportLineLayerStrokeTypeEnum = convert_enum_to_graphene_enum( - ReportEnum.LineLayerStrokeType, name='AnalysisReportLineLayerStrokeTypeEnum') + ReportEnum.LineLayerStrokeType, name="AnalysisReportLineLayerStrokeTypeEnum" +) # Model field mapping enum_map = { @@ -94,23 +103,25 @@ } # Serializers field mapping -enum_map.update({ - get_enum_name_from_django_field(serializer().fields[field]): enum - for serializer, field, enum in [ - (AnalysisReportVariableSerializer, 'type', AnalysisReportVariableTypeEnum), - (AnalysisReportTextStyleSerializer, 'align', AnalysisReportTextStyleAlignEnum), - (AnalysisReportBorderStyleSerializer, 'style', AnalysisReportBorderStyleStyleEnum), - (AnalysisReportImageContentStyleSerializer, 'fit', AnalysisReportImageContentStyleFitEnum), - (AnalysisReportHeadingConfigurationSerializer, 'variant', AnalysisReportHeadingConfigurationVariantEnum), - (AnalysisReportHorizontalAxisSerializer, 'type', AnalysisReportHorizontalAxisTypeEnum), - (AnalysisReportBarChartConfigurationSerializer, 'type', AnalysisReportBarChartTypeEnum), - (AnalysisReportBarChartConfigurationSerializer, 'direction', AnalysisReportBarChartDirectionEnum), - (AnalysisReportCategoricalLegendStyleSerializer, 'position', AnalysisReportLegendPositionEnum), - (AnalysisReportCategoricalLegendStyleSerializer, 'shape', AnalysisReportLegendDotShapeEnum), - (AnalysisReportVerticalAxisSerializer, 'aggregation_type', AnalysisReportAggregationTypeEnum), - (AnalysisReportMapLayerConfigurationSerializer, 'type', AnalysisReportMapLayerTypeEnum), - (AnalysisReportSymbolLayerConfigurationSerializer, 'scale_type', AnalysisReportScaleTypeEnum), - (AnalysisReportSymbolLayerConfigurationSerializer, 'scaling_technique', AnalysisReportScalingTechniqueEnum), - (AnalysisReportLineLayerStyleSerializer, 'stroke_type', AnalysisReportLineLayerStrokeTypeEnum), - ] -}) +enum_map.update( + { + get_enum_name_from_django_field(serializer().fields[field]): enum + for serializer, field, enum in [ + (AnalysisReportVariableSerializer, "type", AnalysisReportVariableTypeEnum), + (AnalysisReportTextStyleSerializer, "align", AnalysisReportTextStyleAlignEnum), + (AnalysisReportBorderStyleSerializer, "style", AnalysisReportBorderStyleStyleEnum), + (AnalysisReportImageContentStyleSerializer, "fit", AnalysisReportImageContentStyleFitEnum), + (AnalysisReportHeadingConfigurationSerializer, "variant", AnalysisReportHeadingConfigurationVariantEnum), + (AnalysisReportHorizontalAxisSerializer, "type", AnalysisReportHorizontalAxisTypeEnum), + (AnalysisReportBarChartConfigurationSerializer, "type", AnalysisReportBarChartTypeEnum), + (AnalysisReportBarChartConfigurationSerializer, "direction", AnalysisReportBarChartDirectionEnum), + (AnalysisReportCategoricalLegendStyleSerializer, "position", AnalysisReportLegendPositionEnum), + (AnalysisReportCategoricalLegendStyleSerializer, "shape", AnalysisReportLegendDotShapeEnum), + (AnalysisReportVerticalAxisSerializer, "aggregation_type", AnalysisReportAggregationTypeEnum), + (AnalysisReportMapLayerConfigurationSerializer, "type", AnalysisReportMapLayerTypeEnum), + (AnalysisReportSymbolLayerConfigurationSerializer, "scale_type", AnalysisReportScaleTypeEnum), + (AnalysisReportSymbolLayerConfigurationSerializer, "scaling_technique", AnalysisReportScalingTechniqueEnum), + (AnalysisReportLineLayerStyleSerializer, "stroke_type", AnalysisReportLineLayerStrokeTypeEnum), + ] + } +) diff --git a/apps/analysis/factories.py b/apps/analysis/factories.py index 7586925284..43a858d102 100644 --- a/apps/analysis/factories.py +++ b/apps/analysis/factories.py @@ -1,36 +1,36 @@ import factory from factory.django import DjangoModelFactory - from gallery.factories import FileFactory + from .models import ( Analysis, AnalysisPillar, + AnalysisReport, + AnalysisReportUpload, AnalyticalStatement, AnalyticalStatementEntry, DiscardedEntry, - AnalysisReport, - AnalysisReportUpload, ) class AnalysisFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'Analysis-{n}') + title = factory.Sequence(lambda n: f"Analysis-{n}") class Meta: model = Analysis class AnalysisPillarFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'Analysis-Pillar-{n}') - main_statement = factory.Faker('sentence', nb_words=20) - information_gap = factory.Faker('sentence', nb_words=20) + title = factory.Sequence(lambda n: f"Analysis-Pillar-{n}") + main_statement = factory.Faker("sentence", nb_words=20) + information_gap = factory.Faker("sentence", nb_words=20) class Meta: model = AnalysisPillar class AnalyticalStatementFactory(DjangoModelFactory): - statement = factory.Faker('sentence', nb_words=20) + statement = factory.Faker("sentence", nb_words=20) order = factory.Sequence(lambda n: n) class Meta: diff --git a/apps/analysis/filter_set.py b/apps/analysis/filter_set.py index 67794b52c9..a8731095dd 100644 --- a/apps/analysis/filter_set.py +++ b/apps/analysis/filter_set.py @@ -1,46 +1,47 @@ import django_filters from django.db import models from django.db.models.functions import Coalesce +from entry.filter_set import EntryGQFilterSet +from user_resource.filters import UserResourceGqlFilterSet from utils.graphene.filters import IDListFilter, MultipleInputFilter -from user_resource.filters import UserResourceGqlFilterSet -from entry.filter_set import EntryGQFilterSet +from .enums import AnalysisReportUploadTypeEnum, DiscardedEntryTagTypeEnum from .models import ( Analysis, AnalysisPillar, - DiscardedEntry, - AnalyticalStatement, AnalysisReport, - AnalysisReportUpload, AnalysisReportSnapshot, -) -from .enums import ( - DiscardedEntryTagTypeEnum, - AnalysisReportUploadTypeEnum, + AnalysisReportUpload, + AnalyticalStatement, + DiscardedEntry, ) class AnalysisFilterSet(django_filters.FilterSet): created_at__lt = django_filters.DateTimeFilter( - field_name='created_at', lookup_expr='lt', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + lookup_expr="lt", + input_formats=["%Y-%m-%d%z"], ) created_at__gt = django_filters.DateTimeFilter( - field_name='created_at', lookup_expr='gt', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + lookup_expr="gt", + input_formats=["%Y-%m-%d%z"], ) created_at__lte = django_filters.DateTimeFilter( - field_name='created_at', lookup_expr='lte', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + lookup_expr="lte", + input_formats=["%Y-%m-%d%z"], ) created_at__gte = django_filters.DateTimeFilter( - field_name='created_at', lookup_expr='gte', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + lookup_expr="gte", + input_formats=["%Y-%m-%d%z"], ) created_at = django_filters.DateTimeFilter( - field_name='created_at', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + input_formats=["%Y-%m-%d%z"], ) class Meta: @@ -51,7 +52,7 @@ class Meta: class DiscardedEntryFilterSet(django_filters.FilterSet): tag = django_filters.MultipleChoiceFilter( choices=DiscardedEntry.TagType.choices, - lookup_expr='in', + lookup_expr="in", widget=django_filters.widgets.CSVWidget, ) @@ -68,7 +69,7 @@ class Meta: class AnalysisPillarGQFilterSet(UserResourceGqlFilterSet): - analyses = IDListFilter(field_name='analysis') + analyses = IDListFilter(field_name="analysis") class Meta: model = AnalysisPillar @@ -76,8 +77,8 @@ class Meta: class AnalysisPillarEntryGQFilterSet(EntryGQFilterSet): - discarded = django_filters.BooleanFilter(method='filter_discarded') - exclude_entries = IDListFilter(method='filter_exclude_entries') + discarded = django_filters.BooleanFilter(method="filter_discarded") + exclude_entries = IDListFilter(method="filter_exclude_entries") def filter_discarded(self, queryset, *_): # NOTE: This is only for argument, filter is done in AnalysisPillarType.resolve_entries @@ -96,7 +97,7 @@ class Meta: class AnalysisPillarDiscardedEntryGqlFilterSet(django_filters.FilterSet): - tags = MultipleInputFilter(DiscardedEntryTagTypeEnum, field_name='tag') + tags = MultipleInputFilter(DiscardedEntryTagTypeEnum, field_name="tag") class Meta: model = DiscardedEntry @@ -104,10 +105,10 @@ class Meta: class AnalysisReportGQFilterSet(django_filters.FilterSet): - search = django_filters.CharFilter(method='search_filter') - analyses = IDListFilter(field_name='analysis') - is_public = django_filters.BooleanFilter(method='filter_discarded') - organizations = IDListFilter(method='organizations_filter') + search = django_filters.CharFilter(method="search_filter") + analyses = IDListFilter(field_name="analysis") + is_public = django_filters.BooleanFilter(method="filter_discarded") + organizations = IDListFilter(method="organizations_filter") class Meta: model = AnalysisReport @@ -115,24 +116,23 @@ class Meta: def organizations_filter(self, qs, _, value): if value: - qs = qs.annotate( - authoring_organizations=Coalesce('authors__parent_id', 'authors__id') - ).filter(authoring_organizations__in=value).distinct() + qs = ( + qs.annotate(authoring_organizations=Coalesce("authors__parent_id", "authors__id")) + .filter(authoring_organizations__in=value) + .distinct() + ) return qs def search_filter(self, qs, _, value): if value: - qs = qs.filter( - models.Q(slug__icontains=value) | - models.Q(title__icontains=value) - ).distinct() + qs = qs.filter(models.Q(slug__icontains=value) | models.Q(title__icontains=value)).distinct() return qs class AnalysisReportUploadGQFilterSet(django_filters.FilterSet): - search = django_filters.CharFilter(method='search_filter') - report = IDListFilter(field_name='report') - types = MultipleInputFilter(AnalysisReportUploadTypeEnum, field_name='type') + search = django_filters.CharFilter(method="search_filter") + report = IDListFilter(field_name="report") + types = MultipleInputFilter(AnalysisReportUploadTypeEnum, field_name="type") class Meta: model = AnalysisReportUpload @@ -140,14 +140,12 @@ class Meta: def search_filter(self, qs, _, value): if value: - qs = qs.filter( - models.Q(file__title__icontains=value) - ).distinct() + qs = qs.filter(models.Q(file__title__icontains=value)).distinct() return qs class AnalysisReportSnapshotGQFilterSet(django_filters.FilterSet): - report = IDListFilter(field_name='report') + report = IDListFilter(field_name="report") class Meta: model = AnalysisReportSnapshot diff --git a/apps/analysis/models.py b/apps/analysis/models.py index b29aff4068..d5a370b149 100644 --- a/apps/analysis/models.py +++ b/apps/analysis/models.py @@ -2,26 +2,26 @@ import json from datetime import timedelta +from deepl_integration.models import DeeplTrackBaseModel +from django.contrib.postgres.fields import ArrayField +from django.core.validators import MaxValueValidator, MinValueValidator +from django.db import connection as django_db_connection from django.db import models from django.db.models.functions import JSONObject -from django.db import connection as django_db_connection -from django.utils.translation import gettext_lazy as _ from django.utils import timezone -from django.contrib.postgres.fields import ArrayField -from django.core.validators import MaxValueValidator, MinValueValidator - -from utils.common import generate_sha256 -from deep.number_generator import client_id_generator -from deep.filter_set import get_dummy_request -from project.mixins import ProjectEntityMixin -from user.models import User -from project.models import Project +from django.utils.translation import gettext_lazy as _ from entry.models import Entry +from gallery.models import File from lead.models import Lead -from user_resource.models import UserResource from organization.models import Organization -from gallery.models import File -from deepl_integration.models import DeeplTrackBaseModel +from project.mixins import ProjectEntityMixin +from project.models import Project +from user.models import User +from user_resource.models import UserResource + +from deep.filter_set import get_dummy_request +from deep.number_generator import client_id_generator +from utils.common import generate_sha256 class Analysis(UserResource, ProjectEntityMixin): @@ -37,11 +37,7 @@ class Analysis(UserResource, ProjectEntityMixin): start_date = models.DateField(null=True, blank=True) end_date = models.DateField() # added to keep the track of cloned analysis - cloned_from = models.ForeignKey( - 'Analysis', - on_delete=models.SET_NULL, - null=True, blank=True - ) + cloned_from = models.ForeignKey("Analysis", on_delete=models.SET_NULL, null=True, blank=True) def __str__(self): return self.title @@ -82,13 +78,9 @@ def _get_clone_discarded_entry(obj, analysis_pillar_id): analysis_cloned.save() # Clone pillars cloned_pillars = [ - _get_clone_pillar(analysis_pillar, analysis_cloned.pk) - for analysis_pillar in self.analysispillar_set.all() + _get_clone_pillar(analysis_pillar, analysis_cloned.pk) for analysis_pillar in self.analysispillar_set.all() ] - cloned_pillar_id_map = { - pillar.cloned_from_id: pillar.id - for pillar in AnalysisPillar.objects.bulk_create(cloned_pillars) - } + cloned_pillar_id_map = {pillar.cloned_from_id: pillar.id for pillar in AnalysisPillar.objects.bulk_create(cloned_pillars)} # Clone discarded entries DiscardedEntry.objects.bulk_create( @@ -110,8 +102,7 @@ def _get_clone_discarded_entry(obj, analysis_pillar_id): for statement in AnalyticalStatement.objects.filter(analysis_pillar__analysis=self) ] cloned_statement_id_map = { - statement.cloned_from_id: statement.id - for statement in AnalyticalStatement.objects.bulk_create(cloned_statements) + statement.cloned_from_id: statement.id for statement in AnalyticalStatement.objects.bulk_create(cloned_statements) } # Clone statement entries @@ -120,9 +111,7 @@ def _get_clone_discarded_entry(obj, analysis_pillar_id): statement_entry, cloned_statement_id_map[statement_entry.analytical_statement_id], # Use newly cloned statement id ) - for statement_entry in AnalyticalStatementEntry.objects.filter( - analytical_statement__analysis_pillar__analysis=self - ) + for statement_entry in AnalyticalStatementEntry.objects.filter(analytical_statement__analysis_pillar__analysis=self) ] AnalyticalStatementEntry.objects.bulk_create(cloned_statement_entries) @@ -135,28 +124,29 @@ def get_analyzed_sources(cls, analysis_list): if len(analysis_ids) == 0: return {} - leads_dragged = AnalyticalStatement.objects\ - .filter(analysis_pillar__analysis__in=analysis_ids)\ - .order_by().values('analysis_pillar__analysis', 'entries__lead_id') - leads_discarded = DiscardedEntry.objects\ - .filter(analysis_pillar__analysis__in=analysis_ids)\ - .order_by().values('analysis_pillar__analysis', 'entry__lead_id') + leads_dragged = ( + AnalyticalStatement.objects.filter(analysis_pillar__analysis__in=analysis_ids) + .order_by() + .values("analysis_pillar__analysis", "entries__lead_id") + ) + leads_discarded = ( + DiscardedEntry.objects.filter(analysis_pillar__analysis__in=analysis_ids) + .order_by() + .values("analysis_pillar__analysis", "entry__lead_id") + ) union_query = leads_dragged.union(leads_discarded).query # NOTE: Django ORM union isn't allowed inside annotation with django_db_connection.cursor() as cursor: - raw_sql = f''' + raw_sql = f""" SELECT u.analysis_id, COUNT(DISTINCT(u.lead_id)) FROM ({union_query}) as u GROUP BY u.analysis_id - ''' + """ cursor.execute(raw_sql) - return { - analysis_id: lead_count - for analysis_id, lead_count in cursor.fetchall() - } + return {analysis_id: lead_count for analysis_id, lead_count in cursor.fetchall()} @classmethod def get_analyzed_entries(cls, analysis_list): @@ -165,34 +155,35 @@ def get_analyzed_entries(cls, analysis_list): if len(analysis_ids) == 0: return {} - entries_dragged = AnalyticalStatementEntry.objects\ - .filter( + entries_dragged = ( + AnalyticalStatementEntry.objects.filter( analytical_statement__analysis_pillar__analysis__in=analysis_ids, - entry__lead__published_on__lte=models.F('analytical_statement__analysis_pillar__analysis__end_date') - )\ - .order_by().values('analytical_statement__analysis_pillar__analysis', 'entry') - entries_discarded = DiscardedEntry.objects\ - .filter( + entry__lead__published_on__lte=models.F("analytical_statement__analysis_pillar__analysis__end_date"), + ) + .order_by() + .values("analytical_statement__analysis_pillar__analysis", "entry") + ) + entries_discarded = ( + DiscardedEntry.objects.filter( analysis_pillar__analysis__in=analysis_ids, - entry__lead__published_on__lte=models.F('analysis_pillar__analysis__end_date') - )\ - .order_by().values('analysis_pillar__analysis', 'entry') + entry__lead__published_on__lte=models.F("analysis_pillar__analysis__end_date"), + ) + .order_by() + .values("analysis_pillar__analysis", "entry") + ) union_query = entries_dragged.union(entries_discarded).query # NOTE: Django ORM union isn't allowed inside annotation with django_db_connection.cursor() as cursor: - raw_sql = f''' + raw_sql = f""" SELECT u.analysis_id, COUNT(DISTINCT(u.entry_id)) FROM ({union_query}) as u GROUP BY u.analysis_id - ''' + """ cursor.execute(raw_sql) - return { - analysis_id: entry_count - for analysis_id, entry_count in cursor.fetchall() - } + return {analysis_id: entry_count for analysis_id, entry_count in cursor.fetchall()} @classmethod def annotate_for_analysis_summary(cls, project_id, queryset, user): @@ -200,71 +191,95 @@ def annotate_for_analysis_summary(cls, project_id, queryset, user): This is used by AnalysisSummarySerializer and AnalysisViewSet.get_summary """ # NOTE: Using the entries and lead in the project for total entries and leads in analysis level - total_sources = Lead.objects\ - .filter(project=project_id)\ - .annotate(entries_count=models.Count('entry'))\ - .filter(entries_count__gt=0)\ + total_sources = ( + Lead.objects.filter(project=project_id) + .annotate(entries_count=models.Count("entry")) + .filter(entries_count__gt=0) .count() + ) total_entries = Entry.objects.filter(project=project_id).count() # Prefetch for AnalysisSummaryPillarSerializer. analysispillar_prefetch = models.Prefetch( - 'analysispillar_set', + "analysispillar_set", queryset=( AnalysisPillar.objects.select_related( - 'assignee', - 'assignee__profile', + "assignee", + "assignee__profile", ).annotate( - dragged_entries=models.functions.Coalesce(models.Subquery( - AnalyticalStatement.objects.filter( - analysis_pillar=models.OuterRef('pk') - ).order_by().values('analysis_pillar').annotate(count=models.Count( - 'entries', - distinct=True, - filter=models.Q(entries__lead__published_on__lte=models.OuterRef('analysis__end_date')))) - .values('count')[:1], - output_field=models.IntegerField(), - ), 0), - discarded_entries=models.functions.Coalesce(models.Subquery( - DiscardedEntry.objects.filter( - analysis_pillar=models.OuterRef('pk') - ).order_by().values('analysis_pillar').annotate(count=models.Count( - 'entry', - distinct=True, - filter=models.Q(entry__lead__published_on__lte=models.OuterRef('analysis__end_date')))) - .values('count')[:1], - output_field=models.IntegerField(), - ), 0), - analyzed_entries=models.F('dragged_entries') + models.F('discarded_entries') + dragged_entries=models.functions.Coalesce( + models.Subquery( + AnalyticalStatement.objects.filter(analysis_pillar=models.OuterRef("pk")) + .order_by() + .values("analysis_pillar") + .annotate( + count=models.Count( + "entries", + distinct=True, + filter=models.Q(entries__lead__published_on__lte=models.OuterRef("analysis__end_date")), + ) + ) + .values("count")[:1], + output_field=models.IntegerField(), + ), + 0, + ), + discarded_entries=models.functions.Coalesce( + models.Subquery( + DiscardedEntry.objects.filter(analysis_pillar=models.OuterRef("pk")) + .order_by() + .values("analysis_pillar") + .annotate( + count=models.Count( + "entry", + distinct=True, + filter=models.Q(entry__lead__published_on__lte=models.OuterRef("analysis__end_date")), + ) + ) + .values("count")[:1], + output_field=models.IntegerField(), + ), + 0, + ), + analyzed_entries=models.F("dragged_entries") + models.F("discarded_entries"), ) ), ) publication_date_subquery = models.Subquery( AnalyticalStatementEntry.objects.filter( - analytical_statement__analysis_pillar__analysis=models.OuterRef('pk'), - ).order_by().values('analytical_statement__analysis_pillar__analysis').annotate( - published_on_min=models.Min('entry__lead__published_on'), - published_on_max=models.Max('entry__lead__published_on'), - ).annotate( + analytical_statement__analysis_pillar__analysis=models.OuterRef("pk"), + ) + .order_by() + .values("analytical_statement__analysis_pillar__analysis") + .annotate( + published_on_min=models.Min("entry__lead__published_on"), + published_on_max=models.Max("entry__lead__published_on"), + ) + .annotate( publication_date=JSONObject( - start_date=models.F('published_on_min'), - end_date=models.F('published_on_max'), + start_date=models.F("published_on_min"), + end_date=models.F("published_on_max"), ) - ).values('publication_date')[:1], + ) + .values("publication_date")[:1], output_field=models.JSONField(), ) - return queryset.select_related( - 'team_lead', - 'team_lead__profile', - ).prefetch_related( - analysispillar_prefetch, - ).annotate( - team_lead_name=models.F('team_lead__username'), - total_entries=models.Value(total_entries, output_field=models.IntegerField()), - total_sources=models.Value(total_sources, output_field=models.IntegerField()), - publication_date=publication_date_subquery, + return ( + queryset.select_related( + "team_lead", + "team_lead__profile", + ) + .prefetch_related( + analysispillar_prefetch, + ) + .annotate( + team_lead_name=models.F("team_lead__username"), + total_entries=models.Value(total_entries, output_field=models.IntegerField()), + total_sources=models.Value(total_sources, output_field=models.IntegerField()), + publication_date=publication_date_subquery, + ) ) @@ -273,20 +288,10 @@ class AnalysisPillar(UserResource): main_statement = models.TextField(blank=True) information_gap = models.TextField(blank=True) filters = models.JSONField(blank=True, null=True, default=None) - assignee = models.ForeignKey( - User, - on_delete=models.CASCADE - ) - analysis = models.ForeignKey( - Analysis, - on_delete=models.CASCADE - ) + assignee = models.ForeignKey(User, on_delete=models.CASCADE) + analysis = models.ForeignKey(Analysis, on_delete=models.CASCADE) # added to keep the track of cloned analysispillar - cloned_from = models.ForeignKey( - 'AnalysisPillar', - on_delete=models.SET_NULL, - null=True, blank=True - ) + cloned_from = models.ForeignKey("AnalysisPillar", on_delete=models.SET_NULL, null=True, blank=True) def __str__(self): return self.title @@ -305,7 +310,7 @@ def get_entries_qs(self, queryset=None, only_discarded=False): project=self.analysis.project_id, lead__published_on__lte=self.analysis.end_date, ) - discarded_entries_qs = DiscardedEntry.objects.filter(analysis_pillar=self).values('entry') + discarded_entries_qs = DiscardedEntry.objects.filter(analysis_pillar=self).values("entry") if only_discarded: return _queryset.filter(id__in=discarded_entries_qs) return _queryset.exclude(id__in=discarded_entries_qs) @@ -313,65 +318,66 @@ def get_entries_qs(self, queryset=None, only_discarded=False): @classmethod def annotate_for_analysis_pillar_summary(cls, qs): analytical_statement_prefech = models.Prefetch( - 'analyticalstatement_set', - queryset=( - AnalyticalStatement.objects.annotate( - entries_count=models.Count('entries', distinct=True) - ) - ) + "analyticalstatement_set", + queryset=(AnalyticalStatement.objects.annotate(entries_count=models.Count("entries", distinct=True))), ) - return qs\ - .prefetch_related(analytical_statement_prefech)\ - .annotate( - dragged_entries=models.functions.Coalesce( - models.Subquery( - AnalyticalStatement.objects.filter( - analysis_pillar=models.OuterRef('pk') - ).order_by().values('analysis_pillar').annotate(count=models.Count( - 'entries', + return qs.prefetch_related(analytical_statement_prefech).annotate( + dragged_entries=models.functions.Coalesce( + models.Subquery( + AnalyticalStatement.objects.filter(analysis_pillar=models.OuterRef("pk")) + .order_by() + .values("analysis_pillar") + .annotate( + count=models.Count( + "entries", distinct=True, - filter=models.Q(entries__lead__published_on__lte=models.OuterRef('analysis__end_date')))) - .values('count')[:1], - output_field=models.IntegerField(), - ), 0), - discarded_entries=models.functions.Coalesce( - models.Subquery( - DiscardedEntry.objects.filter( - analysis_pillar=models.OuterRef('pk') - ).order_by().values('analysis_pillar__analysis').annotate(count=models.Count( - 'entry', + filter=models.Q(entries__lead__published_on__lte=models.OuterRef("analysis__end_date")), + ) + ) + .values("count")[:1], + output_field=models.IntegerField(), + ), + 0, + ), + discarded_entries=models.functions.Coalesce( + models.Subquery( + DiscardedEntry.objects.filter(analysis_pillar=models.OuterRef("pk")) + .order_by() + .values("analysis_pillar__analysis") + .annotate( + count=models.Count( + "entry", distinct=True, - filter=models.Q(entry__lead__published_on__lte=models.OuterRef('analysis__end_date')))) - .values('count')[:1], - output_field=models.IntegerField(), - ), 0), - analyzed_entries=models.F('dragged_entries') + models.F('discarded_entries'), - ) + filter=models.Q(entry__lead__published_on__lte=models.OuterRef("analysis__end_date")), + ) + ) + .values("count")[:1], + output_field=models.IntegerField(), + ), + 0, + ), + analyzed_entries=models.F("dragged_entries") + models.F("discarded_entries"), + ) class DiscardedEntry(models.Model): """ Discarded entries for AnalysisPillar """ + class TagType(models.IntegerChoices): - REDUNDANT = 0, _('Redundant') - TOO_OLD = 1, _('Too old') - ANECDOTAL = 2, _('Anecdotal') - OUTLIER = 3, _('Outlier') - - analysis_pillar = models.ForeignKey( - AnalysisPillar, - on_delete=models.CASCADE - ) - entry = models.ForeignKey( - Entry, - on_delete=models.CASCADE - ) + REDUNDANT = 0, _("Redundant") + TOO_OLD = 1, _("Too old") + ANECDOTAL = 2, _("Anecdotal") + OUTLIER = 3, _("Outlier") + + analysis_pillar = models.ForeignKey(AnalysisPillar, on_delete=models.CASCADE) + entry = models.ForeignKey(Entry, on_delete=models.CASCADE) tag = models.IntegerField(choices=TagType.choices) class Meta: - unique_together = ('entry', 'analysis_pillar') + unique_together = ("entry", "analysis_pillar") def can_get(self, user): return self.analysis_pillar.can_get(user) @@ -383,7 +389,7 @@ def can_delete(self, user): return self.can_modify(user) def __str__(self): - return f'{self.analysis_pillar} - {self.entry}' + return f"{self.analysis_pillar} - {self.entry}" class AnalyticalStatement(UserResource): @@ -391,27 +397,20 @@ class AnalyticalStatement(UserResource): title = models.CharField(max_length=150, blank=True, null=True) entries = models.ManyToManyField( Entry, - through='AnalyticalStatementEntry', - through_fields=('analytical_statement', 'entry'), + through="AnalyticalStatementEntry", + through_fields=("analytical_statement", "entry"), blank=True, ) - analysis_pillar = models.ForeignKey( - AnalysisPillar, - on_delete=models.CASCADE - ) + analysis_pillar = models.ForeignKey(AnalysisPillar, on_delete=models.CASCADE) include_in_report = models.BooleanField(default=False) order = models.IntegerField() report_text = models.TextField(blank=True) information_gaps = models.TextField(blank=True) # added to keep the track of cloned analysisstatement - cloned_from = models.ForeignKey( - 'AnalyticalStatement', - on_delete=models.SET_NULL, - null=True, blank=True - ) + cloned_from = models.ForeignKey("AnalyticalStatement", on_delete=models.SET_NULL, null=True, blank=True) class Meta: - ordering = ('order',) + ordering = ("order",) def can_get(self, user): return self.analysis_pillar.can_get(user) @@ -424,14 +423,8 @@ def __str__(self): class AnalyticalStatementEntry(UserResource): - entry = models.ForeignKey( - Entry, - on_delete=models.CASCADE - ) - analytical_statement = models.ForeignKey( - AnalyticalStatement, - on_delete=models.CASCADE - ) + entry = models.ForeignKey(Entry, on_delete=models.CASCADE) + analytical_statement = models.ForeignKey(AnalyticalStatement, on_delete=models.CASCADE) order = models.IntegerField() def can_get(self, user): @@ -441,12 +434,12 @@ def can_modify(self, user): return self.analytical_statement.can_modify(user) class Meta: - ordering = ('order',) + ordering = ("order",) # NLP Trigger Model -- Used as cache and tracking async data calculation def entries_file_upload_to(instance, filename: str) -> str: - return f'analysis/{type(instance).__name__.lower()}/entries/{filename}' + return f"analysis/{type(instance).__name__.lower()}/entries/{filename}" class TopicModel(UserResource, DeeplTrackBaseModel): @@ -456,12 +449,13 @@ class TopicModel(UserResource, DeeplTrackBaseModel): additional_filters = models.JSONField(default=dict) widget_tags = ArrayField(models.CharField(max_length=100), default=list) - topicmodelcluster_set: models.QuerySet['TopicModelCluster'] + topicmodelcluster_set: models.QuerySet["TopicModelCluster"] @staticmethod def _get_entries_qs(analysis_pillar, entry_filters): # Loading here to make sure models are loaded before filters from entry.filter_set import EntryGQFilterSet + dummy_request = get_dummy_request(active_project=analysis_pillar.analysis.project) return EntryGQFilterSet( queryset=analysis_pillar.get_entries_qs(), # Queryset from AnalysisPillar @@ -483,7 +477,7 @@ class TopicModelCluster(models.Model): class EntriesCollectionNlpTriggerBase(UserResource, DeeplTrackBaseModel): project = models.ForeignKey(Project, on_delete=models.CASCADE) entries_id = ArrayField(models.IntegerField()) - entries_hash = models.CharField(max_length=256, db_index=True) # Generated using entries_id + entries_hash = models.CharField(max_length=256, db_index=True) # Generated using entries_id entries_file = models.FileField(upload_to=entries_file_upload_to, max_length=255) CACHE_THRESHOLD_HOURS = 3 @@ -495,16 +489,20 @@ class Meta: def get_existing(cls, entries_id): threshold = timezone.now() - timedelta(hours=cls.CACHE_THRESHOLD_HOURS) entries_hash = cls.get_entry_hash(entries_id) - return cls.objects.filter( - entries_hash=entries_hash, - created_at__gte=threshold, - ).exclude( - status__in=[ - cls.Status.STARTED, - cls.Status.FAILED, - cls.Status.SEND_FAILED, - ], - ).first() + return ( + cls.objects.filter( + entries_hash=entries_hash, + created_at__gte=threshold, + ) + .exclude( + status__in=[ + cls.Status.STARTED, + cls.Status.FAILED, + cls.Status.SEND_FAILED, + ], + ) + .first() + ) @staticmethod def get_valid_entries_id(project_id, entries_id): @@ -512,7 +510,9 @@ def get_valid_entries_id(project_id, entries_id): Entry.objects.filter( project=project_id, id__in=entries_id, - ).order_by('id').values_list('id', flat=True) + ) + .order_by("id") + .values_list("id", flat=True) ) @staticmethod @@ -544,7 +544,7 @@ class AnalyticalStatementGeoEntry(models.Model): task = models.ForeignKey( AnalyticalStatementGeoTask, on_delete=models.CASCADE, - related_name='entry_geos', + related_name="entry_geos", ) entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="+") data = models.JSONField(default=list) @@ -554,7 +554,7 @@ class AnalyticalStatementGeoEntry(models.Model): class AnalysisReport(UserResource): analysis = models.ForeignKey(Analysis, on_delete=models.CASCADE) is_public = models.BooleanField( - help_text="A report should be public for \"shareable link\" to be accessible by", + help_text='A report should be public for "shareable link" to be accessible by', default=False, ) slug = models.CharField( @@ -580,18 +580,18 @@ def get_latest_snapshot(slug=None, report_id=None): queryset = queryset.filter(report__slug=slug) if report_id is not None: queryset = queryset.filter(report_id=report_id) - return queryset.order_by('-published_on').first() + return queryset.order_by("-published_on").first() class AnalysisReportUpload(models.Model): class Type(models.IntegerChoices): - CSV = 1, 'CSV' - XLSX = 2, 'XLSX' - GEOJSON = 3, 'GeoJson' - IMAGE = 4, 'Image' + CSV = 1, "CSV" + XLSX = 2, "XLSX" + GEOJSON = 3, "GeoJson" + IMAGE = 4, "Image" report = models.ForeignKey(AnalysisReport, on_delete=models.CASCADE) - file = models.ForeignKey(File, on_delete=models.PROTECT, related_name='+') + file = models.ForeignKey(File, on_delete=models.PROTECT, related_name="+") # NOTE: No validation required. Client will send this information type = models.SmallIntegerField(choices=Type.choices) metadata = models.JSONField(default=dict) @@ -599,15 +599,15 @@ class Type(models.IntegerChoices): class AnalysisReportContainer(models.Model): class ContentType(models.IntegerChoices): - TEXT = 1, 'Text' - HEADING = 2, 'Heading' - IMAGE = 3, 'Image' - URL = 4, 'URL' - TIMELINE_CHART = 5, 'Timeline Chart' - KPI = 6, 'KPIs' - BAR_CHART = 7, 'Bar Chart' - MAP = 8, 'Map' - LINE_CHART = 9, 'Line Chart' + TEXT = 1, "Text" + HEADING = 2, "Heading" + IMAGE = 3, "Image" + URL = 4, "URL" + TIMELINE_CHART = 5, "Timeline Chart" + KPI = 6, "KPIs" + BAR_CHART = 7, "Bar Chart" + MAP = 8, "Map" + LINE_CHART = 9, "Line Chart" report = models.ForeignKey(AnalysisReport, on_delete=models.CASCADE) row = models.SmallIntegerField() @@ -636,4 +636,4 @@ class AnalysisReportSnapshot(UserResource): report = models.ForeignKey(AnalysisReport, on_delete=models.CASCADE) published_by = models.ForeignKey(User, on_delete=models.SET_NULL, null=True) published_on = models.DateTimeField(auto_now_add=True) - report_data_file = models.FileField(upload_to='analysis_report_snapshot/') + report_data_file = models.FileField(upload_to="analysis_report_snapshot/") diff --git a/apps/analysis/mutation.py b/apps/analysis/mutation.py index de5bd03c19..bf65f2d97a 100644 --- a/apps/analysis/mutation.py +++ b/apps/analysis/mutation.py @@ -1,112 +1,111 @@ import graphene +from deep.permissions import ProjectPermissions as PP from utils.graphene.mutation import ( - generate_input_type_for_serializer, - PsGrapheneMutation, PsDeleteMutation, + PsGrapheneMutation, + generate_input_type_for_serializer, ) -from deep.permissions import ProjectPermissions as PP from .models import ( AnalysisPillar, - DiscardedEntry, - TopicModel, - AutomaticSummary, - AnalyticalStatementNGram, - AnalyticalStatementGeoTask, AnalysisReport, - AnalysisReportUpload, AnalysisReportSnapshot, + AnalysisReportUpload, + AnalyticalStatementGeoTask, + AnalyticalStatementNGram, + AutomaticSummary, + DiscardedEntry, + TopicModel, ) from .schema import ( - get_analysis_pillar_qs, - get_analysis_report_qs, - get_analysis_report_upload_qs, - AnalysisPillarType, - AnalysisPillarDiscardedEntryType, - AnalysisTopicModelType, AnalysisAutomaticSummaryType, - AnalyticalStatementNGramType, - AnalyticalStatementGeoTaskType, + AnalysisPillarDiscardedEntryType, + AnalysisPillarType, + AnalysisReportSnapshotType, AnalysisReportType, AnalysisReportUploadType, - AnalysisReportSnapshotType, + AnalysisTopicModelType, + AnalyticalStatementGeoTaskType, + AnalyticalStatementNGramType, + get_analysis_pillar_qs, + get_analysis_report_qs, + get_analysis_report_upload_qs, ) from .serializers import ( - AnalysisPillarGqlSerializer, - DiscardedEntryGqlSerializer, - AnalysisTopicModelSerializer, AnalysisAutomaticSummarySerializer, - AnalyticalStatementNGramSerializer, - AnalyticalStatementGeoTaskSerializer, + AnalysisPillarGqlSerializer, AnalysisReportSerializer, AnalysisReportSnapshotSerializer, AnalysisReportUploadSerializer, + AnalysisTopicModelSerializer, + AnalyticalStatementGeoTaskSerializer, + AnalyticalStatementNGramSerializer, + DiscardedEntryGqlSerializer, ) - AnalysisPillarUpdateInputType = generate_input_type_for_serializer( - 'AnalysisPillarUpdateInputType', + "AnalysisPillarUpdateInputType", serializer_class=AnalysisPillarGqlSerializer, partial=True, ) DiscardedEntryCreateInputType = generate_input_type_for_serializer( - 'DiscardedEntryCreateInputType', + "DiscardedEntryCreateInputType", serializer_class=DiscardedEntryGqlSerializer, ) DiscardedEntryUpdateInputType = generate_input_type_for_serializer( - 'DiscardedEntryUpdateInputType', + "DiscardedEntryUpdateInputType", serializer_class=DiscardedEntryGqlSerializer, partial=True, ) AnalysisTopicModelCreateInputType = generate_input_type_for_serializer( - 'AnalysisTopicModelCreateInputType', + "AnalysisTopicModelCreateInputType", serializer_class=AnalysisTopicModelSerializer, ) AnalysisAutomaticSummaryCreateInputType = generate_input_type_for_serializer( - 'AnalysisAutomaticSummaryCreateInputType', + "AnalysisAutomaticSummaryCreateInputType", serializer_class=AnalysisAutomaticSummarySerializer, ) AnalyticalStatementNGramCreateInputType = generate_input_type_for_serializer( - 'AnalyticalStatementNGramCreateInputType', + "AnalyticalStatementNGramCreateInputType", serializer_class=AnalyticalStatementNGramSerializer, ) AnalyticalStatementGeoTaskInputType = generate_input_type_for_serializer( - 'AnalyticalStatementGeoTaskInputType', + "AnalyticalStatementGeoTaskInputType", serializer_class=AnalyticalStatementGeoTaskSerializer, ) # Analysi Report AnalysisReportInputType = generate_input_type_for_serializer( - 'AnalysisReportInputType', + "AnalysisReportInputType", serializer_class=AnalysisReportSerializer, ) AnalysisReportInputUpdateType = generate_input_type_for_serializer( - 'AnalysisReportInputUpdateType', + "AnalysisReportInputUpdateType", serializer_class=AnalysisReportSerializer, partial=True, ) AnalysisReportSnapshotInputType = generate_input_type_for_serializer( - 'AnalysisReportSnapshotInputType', + "AnalysisReportSnapshotInputType", serializer_class=AnalysisReportSnapshotSerializer, ) AnalysisReportUploadInputType = generate_input_type_for_serializer( - 'AnalysisReportUploadInputType', + "AnalysisReportUploadInputType", serializer_class=AnalysisReportUploadSerializer, ) -class RequiredPermissionMixin(): +class RequiredPermissionMixin: permissions = [ PP.Permission.VIEW_ENTRY, PP.Permission.CREATE_ANALYSIS_MODULE, @@ -143,6 +142,7 @@ class UpdateAnalysisPillar(AnalysisPillarMutationMixin, PsGrapheneMutation): class Arguments: id = graphene.ID(required=True) data = AnalysisPillarUpdateInputType(required=True) + model = AnalysisPillar serializer_class = AnalysisPillarGqlSerializer result = graphene.Field(AnalysisPillarType) @@ -151,13 +151,14 @@ class Arguments: def get_serializer_context(cls, instance, context): return { **context, - 'analysis_end_date': instance.analysis.end_date, + "analysis_end_date": instance.analysis.end_date, } class CreateAnalysisPillarDiscardedEntry(RequiredPermissionMixin, PsGrapheneMutation): class Arguments: data = DiscardedEntryCreateInputType(required=True) + model = DiscardedEntry serializer_class = DiscardedEntryGqlSerializer result = graphene.Field(AnalysisPillarDiscardedEntryType) @@ -167,6 +168,7 @@ class UpdateAnalysisPillarDiscardedEntry(DiscardedEntriesMutationMixin, PsGraphe class Arguments: id = graphene.ID(required=True) data = DiscardedEntryUpdateInputType(required=True) + model = DiscardedEntry serializer_class = DiscardedEntryGqlSerializer result = graphene.Field(AnalysisPillarDiscardedEntryType) @@ -175,6 +177,7 @@ class Arguments: class DeleteAnalysisPillarDiscardedEntry(DiscardedEntriesMutationMixin, PsDeleteMutation): class Arguments: id = graphene.ID(required=True) + model = DiscardedEntry result = graphene.Field(AnalysisPillarDiscardedEntryType) @@ -183,6 +186,7 @@ class Arguments: class TriggerAnalysisTopicModel(RequiredPermissionMixin, PsGrapheneMutation): class Arguments: data = AnalysisTopicModelCreateInputType(required=True) + model = TopicModel serializer_class = AnalysisTopicModelSerializer result = graphene.Field(AnalysisTopicModelType) @@ -191,6 +195,7 @@ class Arguments: class TriggerAnalysisAutomaticSummary(RequiredPermissionMixin, PsGrapheneMutation): class Arguments: data = AnalysisAutomaticSummaryCreateInputType(required=True) + model = AutomaticSummary serializer_class = AnalysisAutomaticSummarySerializer result = graphene.Field(AnalysisAutomaticSummaryType) @@ -199,6 +204,7 @@ class Arguments: class TriggerAnalysisAnalyticalStatementNGram(RequiredPermissionMixin, PsGrapheneMutation): class Arguments: data = AnalyticalStatementNGramCreateInputType(required=True) + model = AnalyticalStatementNGram serializer_class = AnalyticalStatementNGramSerializer result = graphene.Field(AnalyticalStatementNGramType) @@ -207,6 +213,7 @@ class Arguments: class TriggerAnalysisAnalyticalGeoTask(RequiredPermissionMixin, PsGrapheneMutation): class Arguments: data = AnalyticalStatementGeoTaskInputType(required=True) + model = AnalyticalStatementGeoTask serializer_class = AnalyticalStatementGeoTaskSerializer result = graphene.Field(AnalyticalStatementGeoTaskType) @@ -216,6 +223,7 @@ class Arguments: class CreateAnalysisReport(AnalysisReportMutationMixin, PsGrapheneMutation): class Arguments: data = AnalysisReportInputType(required=True) + model = AnalysisReport serializer_class = AnalysisReportSerializer result = graphene.Field(AnalysisReportType) @@ -225,6 +233,7 @@ class UpdateAnalysisReport(AnalysisReportMutationMixin, PsGrapheneMutation): class Arguments: id = graphene.ID(required=True) data = AnalysisReportInputUpdateType(required=True) + model = AnalysisReport serializer_class = AnalysisReportSerializer result = graphene.Field(AnalysisReportType) @@ -233,13 +242,14 @@ class Arguments: def get_serializer_context(cls, instance, context): return { **context, - 'report': instance, + "report": instance, } class DeleteAnalysisReport(AnalysisReportMutationMixin, PsDeleteMutation): class Arguments: id = graphene.ID(required=True) + model = AnalysisReport result = graphene.Field(AnalysisReportType) @@ -248,6 +258,7 @@ class Arguments: class CreateAnalysisReportSnapshot(RequiredPermissionMixin, PsGrapheneMutation): class Arguments: data = AnalysisReportSnapshotInputType(required=True) + model = AnalysisReportSnapshot serializer_class = AnalysisReportSnapshotSerializer result = graphene.Field(AnalysisReportSnapshotType) @@ -257,6 +268,7 @@ class Arguments: class CreateAnalysisReportUpload(AnalysisReportUploadMutationMixin, PsGrapheneMutation): class Arguments: data = AnalysisReportUploadInputType(required=True) + model = AnalysisReportUpload serializer_class = AnalysisReportUploadSerializer result = graphene.Field(AnalysisReportUploadType) @@ -265,11 +277,12 @@ class Arguments: class DeleteAnalysisReportUpload(AnalysisReportUploadMutationMixin, PsDeleteMutation): class Arguments: id = graphene.ID(required=True) + model = AnalysisReportUpload result = graphene.Field(AnalysisReportUploadType) -class Mutation(): +class Mutation: # Analysis Pillar analysis_pillar_update = UpdateAnalysisPillar.Field() # Discarded Entry diff --git a/apps/analysis/public_schema.py b/apps/analysis/public_schema.py index b03fc412ed..e4f33114fe 100644 --- a/apps/analysis/public_schema.py +++ b/apps/analysis/public_schema.py @@ -1,8 +1,9 @@ import typing + import graphene -from .schema import AnalysisReportSnapshotType from .models import AnalysisReport, AnalysisReportSnapshot +from .schema import AnalysisReportSnapshotType class Query: diff --git a/apps/analysis/schema.py b/apps/analysis/schema.py index d81d6bcc8c..b6342ea768 100644 --- a/apps/analysis/schema.py +++ b/apps/analysis/schema.py @@ -1,80 +1,87 @@ import graphene - +from analysis_framework.models import Widget from django.db import models from django.db.models.functions import Cast +from entry.filter_set import EntriesFilterDataType +from entry.models import Attribute, Entry +from entry.schema import EntryType, get_entry_qs +from gallery.models import File as GalleryFile +from gallery.schema import GalleryFileType +from geo.models import GeoArea from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField, PageGraphqlPagination - -from utils.graphene.types import CustomDjangoListObjectType, ClientIdMixin, FileFieldType -from utils.graphene.fields import DjangoPaginatedListObjectField, generate_type_for_serializer -from utils.graphene.enums import EnumDescription -from utils.graphene.geo_scalars import PointScalar -from utils.common import has_select_related -from utils.db.functions import IsEmpty -from deep.permissions import ProjectPermissions as PP -from user_resource.schema import UserResourceMixin, resolve_user_field - from lead.models import Lead -from analysis_framework.models import Widget -from geo.models import GeoArea -from entry.models import Entry, Attribute -from entry.schema import get_entry_qs, EntryType -from entry.filter_set import EntriesFilterDataType -from user.schema import UserType from organization.schema import OrganizationType -from gallery.schema import GalleryFileType -from gallery.models import File as GalleryFile +from user.schema import UserType +from user_resource.schema import UserResourceMixin, resolve_user_field -from .models import ( - Analysis, - AnalysisPillar, - AnalyticalStatement, - AnalyticalStatementEntry, - DiscardedEntry, - TopicModel, - TopicModelCluster, - AutomaticSummary, - AnalyticalStatementNGram, - AnalyticalStatementGeoTask, - AnalyticalStatementGeoEntry, - AnalysisReport, - AnalysisReportUpload, - AnalysisReportContainer, - AnalysisReportContainerData, - AnalysisReportSnapshot, +from deep.permissions import ProjectPermissions as PP +from utils.common import has_select_related +from utils.db.functions import IsEmpty +from utils.graphene.enums import EnumDescription +from utils.graphene.fields import ( + DjangoPaginatedListObjectField, + generate_type_for_serializer, ) +from utils.graphene.geo_scalars import PointScalar +from utils.graphene.types import ( + ClientIdMixin, + CustomDjangoListObjectType, + FileFieldType, +) + from .enums import ( - DiscardedEntryTagTypeEnum, - TopicModelStatusEnum, - AutomaticSummaryStatusEnum, - AnalyticalStatementNGramStatusEnum, AnalysisReportContainerContentTypeEnum, - AnalyticalStatementGeoTaskStatusEnum, AnalysisReportUploadTypeEnum, + AnalyticalStatementGeoTaskStatusEnum, + AnalyticalStatementNGramStatusEnum, + AutomaticSummaryStatusEnum, + DiscardedEntryTagTypeEnum, + TopicModelStatusEnum, ) from .filter_set import ( AnalysisGQFilterSet, - AnalysisPillarGQFilterSet, - AnalysisPillarEntryGQFilterSet, - AnalyticalStatementGQFilterSet, AnalysisPillarDiscardedEntryGqlFilterSet, + AnalysisPillarEntryGQFilterSet, + AnalysisPillarGQFilterSet, AnalysisReportGQFilterSet, - AnalysisReportUploadGQFilterSet, AnalysisReportSnapshotGQFilterSet, + AnalysisReportUploadGQFilterSet, + AnalyticalStatementGQFilterSet, +) +from .models import ( + Analysis, + AnalysisPillar, + AnalysisReport, + AnalysisReportContainer, + AnalysisReportContainerData, + AnalysisReportSnapshot, + AnalysisReportUpload, + AnalyticalStatement, + AnalyticalStatementEntry, + AnalyticalStatementGeoEntry, + AnalyticalStatementGeoTask, + AnalyticalStatementNGram, + AutomaticSummary, + DiscardedEntry, + TopicModel, + TopicModelCluster, ) from .serializers import ( AnalysisReportConfigurationSerializer, - AnalysisReportUploadMetadataSerializer, AnalysisReportContainerContentConfigurationSerializer, AnalysisReportContainerStyleSerializer, + AnalysisReportUploadMetadataSerializer, ) def _get_qs(model, info, project_field): - qs = model.objects.filter(**{ - # Filter by project - project_field: info.context.active_project, - }) + qs = model.objects.filter( + **{ + # Filter by project + project_field: info.context.active_project, + } + ) # Generate queryset according to permission if PP.check_permission(info, PP.Permission.VIEW_ENTRY): return qs @@ -82,45 +89,45 @@ def _get_qs(model, info, project_field): def get_analysis_qs(info): - return _get_qs(Analysis, info, 'project') + return _get_qs(Analysis, info, "project") def get_analysis_pillar_qs(info): - return _get_qs(AnalysisPillar, info, 'analysis__project') + return _get_qs(AnalysisPillar, info, "analysis__project") def get_analytical_statement_qs(info): - return _get_qs(AnalyticalStatement, info, 'analysis_pillar__analysis__project') + return _get_qs(AnalyticalStatement, info, "analysis_pillar__analysis__project") def get_analysis_report_qs(info): - return _get_qs(AnalysisReport, info, 'analysis__project') + return _get_qs(AnalysisReport, info, "analysis__project") def get_analysis_report_upload_qs(info): - return _get_qs(AnalysisReportUpload, info, 'report__analysis__project') + return _get_qs(AnalysisReportUpload, info, "report__analysis__project") def get_analysis_report_snaphost_qs(info): - return _get_qs(AnalysisReportSnapshot, info, 'report__analysis__project') + return _get_qs(AnalysisReportSnapshot, info, "report__analysis__project") class AnalyticalStatementEntryType(ClientIdMixin, DjangoObjectType): class Meta: model = AnalyticalStatementEntry only_fields = ( - 'id', - 'order', + "id", + "order", ) entry = graphene.Field(EntryType, required=True) entry_id = graphene.ID(required=True) - analytical_statement = graphene.ID(source='analytical_statement_id', required=True) + analytical_statement = graphene.ID(source="analytical_statement_id", required=True) @staticmethod def resolve_entry(root, info, **_): - if has_select_related(root, 'entry'): - return getattr(root, 'entry') + if has_select_related(root, "entry"): + return getattr(root, "entry") # Use Dataloader to load the data return info.context.dl.entry.entry.load(root.entry_id) @@ -129,16 +136,16 @@ class AnalyticalStatementType(UserResourceMixin, ClientIdMixin, DjangoObjectType class Meta: model = AnalyticalStatement only_fields = ( - 'title', - 'id', - 'statement', - 'report_text', - 'information_gaps', - 'include_in_report', - 'order', + "title", + "id", + "statement", + "report_text", + "information_gaps", + "include_in_report", + "order", ) - cloned_from = graphene.ID(source='cloned_from_id') + cloned_from = graphene.ID(source="cloned_from_id") entries_count = graphene.Int(required=True) # XXX: N+1 and No pagination @@ -160,18 +167,18 @@ def resolve_entries(root, info, **_): class AnalysisPillarDiscardedEntryType(DjangoObjectType): class Meta: model = DiscardedEntry - only_fields = ('id',) + only_fields = ("id",) - analysis_pillar = graphene.ID(source='analysis_pillar_id') + analysis_pillar = graphene.ID(source="analysis_pillar_id") entry = graphene.Field(EntryType, required=True) entry_id = graphene.ID(required=True) tag = graphene.Field(DiscardedEntryTagTypeEnum, required=True) - tag_display = EnumDescription(source='get_tag_display', required=True) + tag_display = EnumDescription(source="get_tag_display", required=True) @staticmethod def resolve_entry(root, info, **_): - if has_select_related(root, 'entry'): - return getattr(root, 'entry') + if has_select_related(root, "entry"): + return getattr(root, "entry") # Use Dataloader to load the data return info.context.dl.entry.entry.load(root.entry_id) @@ -192,33 +199,27 @@ class AnalysisPillarType(ClientIdMixin, UserResourceMixin, DjangoObjectType): class Meta: model = AnalysisPillar only_fields = ( - 'id', - 'title', - 'main_statement', - 'information_gap', - 'filters', + "id", + "title", + "main_statement", + "information_gap", + "filters", ) assignee = graphene.Field(UserType, required=True) - analysis = graphene.ID(source='analysis_id', required=True) - cloned_from = graphene.ID(source='cloned_from_id') + analysis = graphene.ID(source="analysis_id", required=True) + cloned_from = graphene.ID(source="cloned_from_id") analyzed_entries_count = graphene.Int(required=True) # XXX: N+1 and No pagination statements = graphene.List(graphene.NonNull(AnalyticalStatementType)) discarded_entries = DjangoPaginatedListObjectField( - AnalysisPillarDiscardedEntryListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + AnalysisPillarDiscardedEntryListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) # Generated entries = DjangoPaginatedListObjectField( - AnalysisPillarEntryListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + AnalysisPillarEntryListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) @staticmethod @@ -227,7 +228,7 @@ def get_custom_queryset(queryset, info, **_): @staticmethod def resolve_assignee(root, info, **_): - return resolve_user_field(root, info, 'assignee') + return resolve_user_field(root, info, "assignee") @staticmethod def resolve_analyzed_entries_count(root, info, **_): @@ -247,7 +248,7 @@ def resolve_entries(root, info, **kwargs): # filtering out the entries whose lead published_on date is less than analysis end_date return root.get_entries_qs( queryset=get_entry_qs(info), - only_discarded=kwargs.get('discarded'), # NOTE: From AnalysisPillarEntryGQFilterSet.discarded + only_discarded=kwargs.get("discarded"), # NOTE: From AnalysisPillarEntryGQFilterSet.discarded ) @@ -255,19 +256,23 @@ class AnalysisType(UserResourceMixin, DjangoObjectType): class Meta: model = Analysis only_fields = ( - 'id', - 'title', - 'start_date', - 'end_date', + "id", + "title", + "start_date", + "end_date", ) - cloned_from = graphene.ID(source='cloned_from_id') + cloned_from = graphene.ID(source="cloned_from_id") team_lead = graphene.Field(UserType, required=True) publication_date = graphene.Field( - type('AnalysisPublicationDateType', (graphene.ObjectType,), { - 'start_date': graphene.Date(required=True), - 'end_date': graphene.Date(required=True), - }) + type( + "AnalysisPublicationDateType", + (graphene.ObjectType,), + { + "start_date": graphene.Date(required=True), + "end_date": graphene.Date(required=True), + }, + ) ) analyzed_entries_count = graphene.Int(required=True) analyzed_leads_count = graphene.Int(required=True) @@ -281,7 +286,7 @@ def get_custom_queryset(queryset, info, **_): @staticmethod def resolve_team_lead(root, info, **_): - return resolve_user_field(root, info, 'team_lead') + return resolve_user_field(root, info, "team_lead") @staticmethod def resolve_publication_date(root, info, **_): @@ -306,13 +311,19 @@ class AnalysisOverviewType(graphene.ObjectType): analyzed_entries_count = graphene.Int(required=True) analyzed_leads_count = graphene.Int(required=True) - authoring_organizations = graphene.List(graphene.NonNull( - type('AnalysisOverviewOrganizationType', (graphene.ObjectType,), { - 'id': graphene.ID(required=True), - 'title': graphene.String(required=True), - 'count': graphene.Int(required=True), - }) - )) + authoring_organizations = graphene.List( + graphene.NonNull( + type( + "AnalysisOverviewOrganizationType", + (graphene.ObjectType,), + { + "id": graphene.ID(required=True), + "title": graphene.String(required=True), + "count": graphene.Int(required=True), + }, + ) + ) + ) # analysis_list': analysis_list, # analysis_list = Analysis.objects.filter(project=project_id).values('id', 'title', 'created_at') @@ -322,64 +333,79 @@ def resolve_total_entries_count(root, info, **_): @staticmethod def resolve_total_leads_count(root, info, **_): - return Lead.objects\ - .filter(project=info.context.active_project)\ - .annotate(entries_count=models.Count('entry'))\ - .filter(entries_count__gt=0)\ + return ( + Lead.objects.filter(project=info.context.active_project) + .annotate(entries_count=models.Count("entry")) + .filter(entries_count__gt=0) .count() + ) @staticmethod def resolve_analyzed_entries_count(root, info, **_): project = info.context.active_project - entries_dragged = AnalyticalStatementEntry.objects\ - .filter(analytical_statement__analysis_pillar__analysis__project=project)\ - .order_by().values('entry').distinct() - entries_discarded = DiscardedEntry.objects\ - .filter(analysis_pillar__analysis__project=project)\ - .order_by().values('entry').distinct() + entries_dragged = ( + AnalyticalStatementEntry.objects.filter(analytical_statement__analysis_pillar__analysis__project=project) + .order_by() + .values("entry") + .distinct() + ) + entries_discarded = ( + DiscardedEntry.objects.filter(analysis_pillar__analysis__project=project).order_by().values("entry").distinct() + ) return entries_discarded.union(entries_dragged).count() @staticmethod def resolve_analyzed_leads_count(root, info, **_): project = info.context.active_project - sources_discarded = DiscardedEntry.objects\ - .filter(analysis_pillar__analysis__project=project)\ - .order_by().values('entry__lead_id').distinct() - sources_dragged = AnalyticalStatementEntry.objects\ - .filter(analytical_statement__analysis_pillar__analysis__project=project)\ - .order_by().values('entry__lead_id').distinct() + sources_discarded = ( + DiscardedEntry.objects.filter(analysis_pillar__analysis__project=project) + .order_by() + .values("entry__lead_id") + .distinct() + ) + sources_dragged = ( + AnalyticalStatementEntry.objects.filter(analytical_statement__analysis_pillar__analysis__project=project) + .order_by() + .values("entry__lead_id") + .distinct() + ) return sources_dragged.union(sources_discarded).count() @staticmethod def resolve_authoring_organizations(root, info, **_): - lead_qs = Lead.objects\ - .filter( + lead_qs = ( + Lead.objects.filter( project=info.context.active_project, authors__organization_type__isnull=False, - )\ + ) .annotate( - entries_count=models.functions.Coalesce(models.Subquery( - AnalyticalStatementEntry.objects.filter( - entry__lead_id=models.OuterRef('pk') - ).order_by().values('entry__lead_id').annotate(count=models.Count('*')) - .values('count')[:1], - output_field=models.IntegerField(), - ), 0) - ).filter(entries_count__gt=0) - qs = Lead.objects\ - .filter(id__in=lead_qs)\ - .order_by('authors__organization_type').values('authors__organization_type')\ + entries_count=models.functions.Coalesce( + models.Subquery( + AnalyticalStatementEntry.objects.filter(entry__lead_id=models.OuterRef("pk")) + .order_by() + .values("entry__lead_id") + .annotate(count=models.Count("*")) + .values("count")[:1], + output_field=models.IntegerField(), + ), + 0, + ) + ) + .filter(entries_count__gt=0) + ) + qs = ( + Lead.objects.filter(id__in=lead_qs) + .order_by("authors__organization_type") + .values("authors__organization_type") .annotate( - count=models.Count('id'), + count=models.Count("id"), organization_type_title=models.functions.Coalesce( - models.F('authors__organization_type__title'), - models.Value(''), - ) - ).values_list( - 'count', - 'organization_type_title', - models.F('authors__organization_type__id') + models.F("authors__organization_type__title"), + models.Value(""), + ), ) + .values_list("count", "organization_type_title", models.F("authors__organization_type__id")) + ) return [ dict( id=_id, @@ -415,9 +441,7 @@ class AnalysisTopicModelClusterType(DjangoObjectType): class Meta: model = TopicModelCluster - only_fields = ( - 'id', - ) + only_fields = ("id",) @staticmethod def resolve_entries(root: TopicModelCluster, info, **_): @@ -428,17 +452,15 @@ class AnalysisTopicModelType(UserResourceMixin, DjangoObjectType): status = graphene.Field(TopicModelStatusEnum, required=True) clusters = graphene.List(AnalysisTopicModelClusterType, required=True) additional_filters = graphene.Field(EntriesFilterDataType) - analysis_pillar = graphene.ID(source='analysis_pillar_id', required=True) + analysis_pillar = graphene.ID(source="analysis_pillar_id", required=True) class Meta: model = TopicModel - only_fields = ( - 'id', - ) + only_fields = ("id",) @staticmethod def get_custom_queryset(queryset, info, **_): - return _get_qs(TopicModel, info, 'analysis_pillar__analysis__project') + return _get_qs(TopicModel, info, "analysis_pillar__analysis__project") @staticmethod def resolve_clusters(root: TopicModel, info, **_): @@ -449,23 +471,21 @@ class AnalysisAutomaticSummaryType(UserResourceMixin, DjangoObjectType): class Meta: model = AutomaticSummary only_fields = ( - 'id', - 'summary', + "id", + "summary", ) status = graphene.Field(AutomaticSummaryStatusEnum, required=True) @staticmethod def get_custom_queryset(queryset, info, **_): - return _get_qs(AutomaticSummary, info, 'project') + return _get_qs(AutomaticSummary, info, "project") class AnalyticalStatementNGramType(UserResourceMixin, DjangoObjectType): class Meta: model = AnalyticalStatementNGram - only_fields = ( - 'id', - ) + only_fields = ("id",) class AnalyticalStatementNGramDataType(graphene.ObjectType): word = graphene.String(required=True) @@ -480,14 +500,11 @@ class AnalyticalStatementNGramDataType(graphene.ObjectType): @staticmethod def get_custom_queryset(queryset, info, **_): - return _get_qs(AnalyticalStatementNGram, info, 'project') + return _get_qs(AnalyticalStatementNGram, info, "project") @staticmethod def render_grams(dict_value): - return [ - dict(word=word, count=count) - for word, count in dict_value.items() - ] + return [dict(word=word, count=count) for word, count in dict_value.items()] @classmethod def resolve_unigrams(cls, root: AnalyticalStatementNGram, info, **_): @@ -505,29 +522,38 @@ def resolve_trigrams(cls, root: AnalyticalStatementNGram, info, **_): class AnalyticalStatementEntryGeoType(DjangoObjectType): class Meta: model = AnalyticalStatementGeoEntry - only_fields = ( - 'id', - ) + only_fields = ("id",) entry = graphene.Field(EntryType, required=True) entry_id = graphene.ID(required=True) - data = graphene.List(graphene.NonNull( - type('AnalyticalStatementEntryGeoDataType', (graphene.ObjectType,), { - 'entity': graphene.String(), - 'meta': graphene.NonNull( - type('AnalyticalStatementEntryGeoMetaDataType', (graphene.ObjectType,), { - 'latitude': graphene.Float(), - 'longitude': graphene.Float(), - 'offset_start': graphene.Int(), - 'offset_end': graphene.Int(), - })) - }) - )) + data = graphene.List( + graphene.NonNull( + type( + "AnalyticalStatementEntryGeoDataType", + (graphene.ObjectType,), + { + "entity": graphene.String(), + "meta": graphene.NonNull( + type( + "AnalyticalStatementEntryGeoMetaDataType", + (graphene.ObjectType,), + { + "latitude": graphene.Float(), + "longitude": graphene.Float(), + "offset_start": graphene.Int(), + "offset_end": graphene.Int(), + }, + ) + ), + }, + ) + ) + ) @staticmethod def resolve_entry(root, info, **_): - if has_select_related(root, 'entry'): - return getattr(root, 'entry') + if has_select_related(root, "entry"): + return getattr(root, "entry") # Use Dataloader to load the data return info.context.dl.entry.entry.load(root.entry_id) @@ -535,16 +561,14 @@ def resolve_entry(root, info, **_): class AnalyticalStatementGeoTaskType(UserResourceMixin, DjangoObjectType): class Meta: model = AnalyticalStatementGeoTask - only_fields = ( - 'id', - ) + only_fields = ("id",) status = graphene.Field(AnalyticalStatementGeoTaskStatusEnum, required=True) entry_geo = graphene.List(AnalyticalStatementEntryGeoType, required=True) @staticmethod def get_custom_queryset(queryset, info, **_): - return _get_qs(AnalyticalStatementGeoTask, info, 'project') + return _get_qs(AnalyticalStatementGeoTask, info, "project") @staticmethod def resolve_entry_geo(root, info, **_): @@ -560,16 +584,18 @@ class AnalysisReportUploadType(DjangoObjectType): class Meta: model = AnalysisReportUpload only_fields = ( - 'id', - 'file', + "id", + "file", ) - report = graphene.ID(source='report_id', required=True) + report = graphene.ID(source="report_id", required=True) type = graphene.Field(AnalysisReportUploadTypeEnum, required=True) - metadata = graphene.Field(generate_type_for_serializer( - 'AnalysisReportUploadMetadataType', - serializer_class=AnalysisReportUploadMetadataSerializer, - )) + metadata = graphene.Field( + generate_type_for_serializer( + "AnalysisReportUploadMetadataType", + serializer_class=AnalysisReportUploadMetadataSerializer, + ) + ) @staticmethod def get_custom_queryset(queryset, info, **_): @@ -584,10 +610,10 @@ class AnalysisReportContainerDataType(ClientIdMixin, DjangoObjectType): class Meta: model = AnalysisReportContainerData only_fields = ( - 'id', - 'upload', # AnalysisReportUploadType - 'data', # NOTE: This is Generic for now - 'client_reference_id', + "id", + "upload", # AnalysisReportUploadType + "data", # NOTE: This is Generic for now + "client_reference_id", ) @staticmethod @@ -599,19 +625,19 @@ class AnalysisReportContainerType(ClientIdMixin, DjangoObjectType): class Meta: model = AnalysisReportContainer only_fields = ( - 'id', - 'row', - 'column', - 'width', - 'height', + "id", + "row", + "column", + "width", + "height", ) content_type = graphene.Field(AnalysisReportContainerContentTypeEnum, required=True) - report = graphene.ID(source='report_id', required=True) + report = graphene.ID(source="report_id", required=True) style = graphene.Field( generate_type_for_serializer( - 'AnalysisReportContainerStyleType', + "AnalysisReportContainerStyleType", serializer_class=AnalysisReportContainerStyleSerializer, update_cache=True, ) @@ -619,7 +645,7 @@ class Meta: # Content metadata content_configuration = graphene.Field( generate_type_for_serializer( - 'AnalysisReportContainerContentConfigurationType', + "AnalysisReportContainerContentConfigurationType", serializer_class=AnalysisReportContainerContentConfigurationSerializer, ) ) @@ -634,11 +660,11 @@ class AnalysisReportSnapshotType(DjangoObjectType): class Meta: model = AnalysisReportSnapshot only_fields = ( - 'id', - 'published_on', + "id", + "published_on", ) - report = graphene.ID(source='report_id', required=True) + report = graphene.ID(source="report_id", required=True) published_by = graphene.Field(UserType, required=True) report_data_file = graphene.Field(FileFieldType) files = graphene.List(graphene.NonNull(GalleryFileType), required=True) @@ -649,17 +675,15 @@ def get_custom_queryset(queryset, info, **_): @staticmethod def resolve_published_by(root, info, **_): - return resolve_user_field(root, info, 'published_by') + return resolve_user_field(root, info, "published_by") @staticmethod def resolve_files(root, info, **_): # For now # - organization logos # - report uploads - related_file_id = ( - root.report.analysisreportupload_set.values_list('file').union( - root.report.organizations.values_list('logo') - ) + related_file_id = root.report.analysisreportupload_set.values_list("file").union( + root.report.organizations.values_list("logo") ) return GalleryFile.objects.filter(id__in=related_file_id).all() @@ -668,25 +692,22 @@ class AnalysisReportType(UserResourceMixin, DjangoObjectType): class Meta: model = AnalysisReport only_fields = ( - 'id', - 'is_public', - 'slug', - 'title', - 'sub_title', + "id", + "is_public", + "slug", + "title", + "sub_title", ) - analysis = graphene.ID(source='analysis_id', required=True) - configuration = graphene.Field(generate_type_for_serializer( - 'AnalysisReportConfigurationType', - serializer_class=AnalysisReportConfigurationSerializer, - )) - - containers = graphene.List( - graphene.NonNull( - AnalysisReportContainerType - ), - required=True + analysis = graphene.ID(source="analysis_id", required=True) + configuration = graphene.Field( + generate_type_for_serializer( + "AnalysisReportConfigurationType", + serializer_class=AnalysisReportConfigurationSerializer, + ) ) + + containers = graphene.List(graphene.NonNull(AnalysisReportContainerType), required=True) organizations = graphene.List(graphene.NonNull(OrganizationType), required=True) uploads = graphene.List(graphene.NonNull(AnalysisReportUploadType), required=True) latest_snapshot = graphene.Field(AnalysisReportSnapshotType, required=False) @@ -734,28 +755,19 @@ class Query: analysis_overview = graphene.Field(AnalysisOverviewType) analysis = DjangoObjectField(AnalysisType) analyses = DjangoPaginatedListObjectField( - AnalysisListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + AnalysisListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) # Pillar analysis_pillar = DjangoObjectField(AnalysisPillarType) analysis_pillars = DjangoPaginatedListObjectField( - AnalysisPillarListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + AnalysisPillarListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) # Statement analytical_statement = DjangoObjectField(AnalyticalStatementType) analytical_statements = DjangoPaginatedListObjectField( - AnalyticalStatementListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + AnalyticalStatementListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) # Custom entry nodes @@ -776,24 +788,15 @@ class Query: # Report analysis_report = DjangoObjectField(AnalysisReportType) analysis_reports = DjangoPaginatedListObjectField( - AnalysisReportListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + AnalysisReportListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) analysis_report_upload = DjangoObjectField(AnalysisReportUploadType) analysis_report_uploads = DjangoPaginatedListObjectField( - AnalysisReportUploadListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + AnalysisReportUploadListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) analysis_report_snapshot = DjangoObjectField(AnalysisReportSnapshotType) analysis_report_snapshots = DjangoPaginatedListObjectField( - AnalysisReportSnapshotListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + AnalysisReportSnapshotListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) @staticmethod @@ -829,17 +832,17 @@ def resolve_entries_geo_data(_, info, entries_id): ).values( geo_area_id=Cast( models.Func( - models.F('data__value'), - function='jsonb_array_elements_text', + models.F("data__value"), + function="jsonb_array_elements_text", ), output_field=models.IntegerField(), ) ) geo_area_centroids_map = { geo_area_id: centroid - for geo_area_id, centroid in GeoArea.objects.filter( - id__in=entry_geo_area_id_qs.values('geo_area_id') - ).exclude(IsEmpty('centroid')).values_list('id', 'centroid') + for geo_area_id, centroid in GeoArea.objects.filter(id__in=entry_geo_area_id_qs.values("geo_area_id")) + .exclude(IsEmpty("centroid")) + .values_list("id", "centroid") if centroid is not None } return [ @@ -848,10 +851,10 @@ def resolve_entries_geo_data(_, info, entries_id): count=count, ) for geo_area_id, count in ( - entry_geo_area_id_qs - .order_by().values('geo_area_id') - .annotate(count=models.Count('*')) - .values_list('geo_area_id', 'count') + entry_geo_area_id_qs.order_by() + .values("geo_area_id") + .annotate(count=models.Count("*")) + .values_list("geo_area_id", "count") ) if geo_area_id in geo_area_centroids_map ] diff --git a/apps/analysis/serializers.py b/apps/analysis/serializers.py index fbc4a41654..e06f04371d 100644 --- a/apps/analysis/serializers.py +++ b/apps/analysis/serializers.py @@ -1,76 +1,78 @@ import logging from typing import Callable + +from commons.schema_snapshots import SnapshotQuery, generate_query_snapshot from django.conf import settings +from django.db import models, transaction from django.shortcuts import get_object_or_404 - -from rest_framework import serializers from drf_dynamic_fields import DynamicFieldsMixin -from drf_writable_nested import UniqueFieldsMixin, NestedCreateMixin -from django.db import transaction, models +from drf_writable_nested import NestedCreateMixin, UniqueFieldsMixin +from entry.filter_set import EntriesFilterDataInputType, EntryGQFilterSet +from entry.serializers import SimpleEntrySerializer +from rest_framework import serializers +from user.serializers import NanoUserSerializer +from user_resource.serializers import UserResourceSerializer from deep.graphene_context import GQLContext -from utils.graphene.fields import generate_serializer_field_class -from commons.schema_snapshots import generate_query_snapshot, SnapshotQuery -from deep.writable_nested_serializers import NestedUpdateMixin as CustomNestedUpdateMixin from deep.serializers import ( + GraphqlSupportDrfSerializerJSONField, + IdListField, + IntegerIDField, + ProjectPropertySerializerMixin, RemoveNullFieldsMixin, StringListField, TempClientIdMixin, - IntegerIDField, - IdListField, - GraphqlSupportDrfSerializerJSONField, - ProjectPropertySerializerMixin, ) -from user_resource.serializers import UserResourceSerializer -from user.serializers import NanoUserSerializer -from entry.serializers import SimpleEntrySerializer -from entry.filter_set import EntryGQFilterSet, EntriesFilterDataInputType +from deep.writable_nested_serializers import ( + NestedUpdateMixin as CustomNestedUpdateMixin, +) +from utils.graphene.fields import generate_serializer_field_class -from .models import ( +from .models import ( # Report Analysis, AnalysisPillar, + AnalysisReport, + AnalysisReportContainer, + AnalysisReportContainerData, + AnalysisReportSnapshot, + AnalysisReportUpload, AnalyticalStatement, AnalyticalStatementEntry, + AnalyticalStatementGeoTask, + AnalyticalStatementNGram, + AutomaticSummary, DiscardedEntry, - TopicModel, EntriesCollectionNlpTriggerBase, - AutomaticSummary, - AnalyticalStatementNGram, - AnalyticalStatementGeoTask, - # Report - AnalysisReport, - AnalysisReportUpload, - AnalysisReportContainerData, - AnalysisReportContainer, - AnalysisReportSnapshot, + TopicModel, ) from .tasks import ( - trigger_topic_model, - trigger_automatic_summary, trigger_automatic_ngram, + trigger_automatic_summary, trigger_geo_location, + trigger_topic_model, ) - logger = logging.getLogger(__name__) class AnalyticalEntriesSerializer(UniqueFieldsMixin, UserResourceSerializer): class Meta: model = AnalyticalStatementEntry - fields = ('id', 'client_id', 'order', 'entry') - read_only_fields = ('analytical_statement',) + fields = ("id", "client_id", "order", "entry") + read_only_fields = ("analytical_statement",) def validate(self, data): - analysis_id = self.context['view'].kwargs.get('analysis_id') + analysis_id = self.context["view"].kwargs.get("analysis_id") analysis = get_object_or_404(Analysis, id=analysis_id) analysis_end_date = analysis.end_date - entry = data.get('entry') + entry = data.get("entry") lead_published = entry.lead.published_on if analysis_end_date and lead_published and lead_published > analysis_end_date: - raise serializers.ValidationError({ - 'entry': f'Entry {entry.id} lead published_on cannot be greater than analysis end_date {analysis_end_date}', - }) + raise serializers.ValidationError( + { + "entry": f"Entry {entry.id} lead published_on cannot be greater than analysis end_date {analysis_end_date}", + } + ) return data @@ -82,47 +84,47 @@ class AnalyticalStatementSerializer( # XXX: This is a custom mixin where we delete first and then create to avoid duplicate key value CustomNestedUpdateMixin, ): - analytical_entries = AnalyticalEntriesSerializer(source='analyticalstatemententry_set', many=True, required=False) + analytical_entries = AnalyticalEntriesSerializer(source="analyticalstatemententry_set", many=True, required=False) class Meta: model = AnalyticalStatement - fields = '__all__' - read_only_fields = ('analysis_pillar',) + fields = "__all__" + read_only_fields = ("analysis_pillar",) def validate(self, data): - analysis_pillar_id = self.context['view'].kwargs.get('analysis_pillar_id', None) + analysis_pillar_id = self.context["view"].kwargs.get("analysis_pillar_id", None) if analysis_pillar_id: - data['analysis_pillar_id'] = int(analysis_pillar_id) + data["analysis_pillar_id"] = int(analysis_pillar_id) # Validate the analytical_entries - entries = data.get('analyticalstatemententry_set') + entries = data.get("analyticalstatemententry_set") if entries and len(entries) > settings.ANALYTICAL_ENTRIES_COUNT: - raise serializers.ValidationError( - f'Analytical entires count must be less than {settings.ANALYTICAL_ENTRIES_COUNT}' - ) + raise serializers.ValidationError(f"Analytical entires count must be less than {settings.ANALYTICAL_ENTRIES_COUNT}") return data class DiscardedEntrySerializer(serializers.ModelSerializer): - tag_display = serializers.CharField(source='get_tag_display', read_only=True) - entry_details = SimpleEntrySerializer(source='entry', read_only=True) + tag_display = serializers.CharField(source="get_tag_display", read_only=True) + entry_details = SimpleEntrySerializer(source="entry", read_only=True) class Meta: model = DiscardedEntry - fields = '__all__' - read_only_fields = ['analysis_pillar'] + fields = "__all__" + read_only_fields = ["analysis_pillar"] def validate(self, data): - data['analysis_pillar_id'] = int(self.context['analysis_pillar_id']) - analysis_pillar = get_object_or_404(AnalysisPillar, id=data['analysis_pillar_id']) - entry = data.get('entry') + data["analysis_pillar_id"] = int(self.context["analysis_pillar_id"]) + analysis_pillar = get_object_or_404(AnalysisPillar, id=data["analysis_pillar_id"]) + entry = data.get("entry") if entry.project != analysis_pillar.analysis.project: - raise serializers.ValidationError('Analysis pillar project doesnot match Entry project') + raise serializers.ValidationError("Analysis pillar project doesnot match Entry project") # validating the entry for the lead published_on greater than analysis end date analysis_end_date = analysis_pillar.analysis.end_date if entry.lead.published_on > analysis_end_date: - raise serializers.ValidationError({ - 'entry': f'Entry {entry.id} lead published_on cannot be greater than analysis end_date {analysis_end_date}', - }) + raise serializers.ValidationError( + { + "entry": f"Entry {entry.id} lead published_on cannot be greater than analysis end_date {analysis_end_date}", + } + ) return data @@ -131,24 +133,24 @@ class AnalysisPillarSerializer( DynamicFieldsMixin, UserResourceSerializer, ): - assignee_details = NanoUserSerializer(source='assignee', read_only=True) - analysis_title = serializers.CharField(source='analysis.title', read_only=True) - analytical_statements = AnalyticalStatementSerializer(many=True, source='analyticalstatement_set', required=False) + assignee_details = NanoUserSerializer(source="assignee", read_only=True) + analysis_title = serializers.CharField(source="analysis.title", read_only=True) + analytical_statements = AnalyticalStatementSerializer(many=True, source="analyticalstatement_set", required=False) class Meta: model = AnalysisPillar - fields = '__all__' - read_only_fields = ('analysis',) + fields = "__all__" + read_only_fields = ("analysis",) def validate(self, data): - analysis_id = self.context['view'].kwargs.get('analysis_id', None) + analysis_id = self.context["view"].kwargs.get("analysis_id", None) if analysis_id: - data['analysis_id'] = int(analysis_id) + data["analysis_id"] = int(analysis_id) # validate analysis_statement - analytical_statement = data.get('analyticalstatement_set') + analytical_statement = data.get("analyticalstatement_set") if analytical_statement and len(analytical_statement) > settings.ANALYTICAL_STATEMENT_COUNT: raise serializers.ValidationError( - f'Analytical statement count must be less than {settings.ANALYTICAL_STATEMENT_COUNT}' + f"Analytical statement count must be less than {settings.ANALYTICAL_STATEMENT_COUNT}" ) return data @@ -158,23 +160,21 @@ class AnalysisSerializer( DynamicFieldsMixin, UserResourceSerializer, ): - analysis_pillar = AnalysisPillarSerializer(many=True, source='analysispillar_set', required=False) - team_lead_details = NanoUserSerializer(source='team_lead', read_only=True) + analysis_pillar = AnalysisPillarSerializer(many=True, source="analysispillar_set", required=False) + team_lead_details = NanoUserSerializer(source="team_lead", read_only=True) start_date = serializers.DateField(required=False, allow_null=True) class Meta: model = Analysis - fields = '__all__' - read_only_fields = ('project',) + fields = "__all__" + read_only_fields = ("project",) def validate(self, data): - data['project_id'] = int(self.context['view'].kwargs['project_id']) - start_date = data.get('start_date') - end_date = data.get('end_date') + data["project_id"] = int(self.context["view"].kwargs["project_id"]) + start_date = data.get("start_date") + end_date = data.get("end_date") if start_date and start_date > end_date: - raise serializers.ValidationError( - {'end_date': 'End date must occur after start date'} - ) + raise serializers.ValidationError({"end_date": "End date must occur after start date"}) return data @@ -184,53 +184,61 @@ class AnalysisCloneInputSerializer(serializers.Serializer): end_date = serializers.DateField(required=True, write_only=True) def validate(self, data): - start_date = data.get('start_date') - end_date = data.get('end_date') + start_date = data.get("start_date") + end_date = data.get("end_date") if start_date and start_date > end_date: - raise serializers.ValidationError( - {'end_date': 'End date must occur after start date'} - ) + raise serializers.ValidationError({"end_date": "End date must occur after start date"}) return data class AnalysisSummaryPillarSerializer(serializers.ModelSerializer): analyzed_entries = serializers.IntegerField() - assignee_details = NanoUserSerializer(source='assignee') + assignee_details = NanoUserSerializer(source="assignee") class Meta: model = AnalysisPillar - fields = ('id', 'title', 'analyzed_entries', 'assignee_details') + fields = ("id", "title", "analyzed_entries", "assignee_details") class AnalysisSummarySerializer(serializers.ModelSerializer): """ Used with Analysis.annotate_for_analysis_summary """ + total_entries = serializers.IntegerField() total_sources = serializers.IntegerField() analyzed_entries = serializers.SerializerMethodField() publication_date = serializers.JSONField() - team_lead_details = NanoUserSerializer(source='team_lead', read_only=True) - pillars = AnalysisSummaryPillarSerializer(source='analysispillar_set', many=True, read_only=True) + team_lead_details = NanoUserSerializer(source="team_lead", read_only=True) + pillars = AnalysisSummaryPillarSerializer(source="analysispillar_set", many=True, read_only=True) analyzed_sources = serializers.SerializerMethodField() class Meta: model = Analysis fields = ( - 'id', 'title', 'team_lead', 'team_lead_details', - 'publication_date', 'pillars', - 'end_date', 'start_date', - 'analyzed_entries', 'analyzed_sources', 'total_entries', - 'total_sources', 'created_at', 'modified_at', + "id", + "title", + "team_lead", + "team_lead_details", + "publication_date", + "pillars", + "end_date", + "start_date", + "analyzed_entries", + "analyzed_sources", + "total_entries", + "total_sources", + "created_at", + "modified_at", ) def get_analyzed_sources(self, analysis): - return self.context['analyzed_sources'].get(analysis.pk) + return self.context["analyzed_sources"].get(analysis.pk) def get_analyzed_entries(self, analysis): - return self.context['analyzed_entries'].get(analysis.pk) + return self.context["analyzed_entries"].get(analysis.pk) class AnalysisPillarSummaryAnalyticalStatementSerializer(serializers.ModelSerializer): @@ -238,52 +246,51 @@ class AnalysisPillarSummaryAnalyticalStatementSerializer(serializers.ModelSerial class Meta: model = AnalyticalStatement - fields = ('id', 'statement', 'entries_count') + fields = ("id", "statement", "entries_count") class AnalysisPillarSummarySerializer(serializers.ModelSerializer): - assignee_details = NanoUserSerializer(source='assignee', read_only=True) + assignee_details = NanoUserSerializer(source="assignee", read_only=True) analytical_statements = AnalysisPillarSummaryAnalyticalStatementSerializer( - source='analyticalstatement_set', many=True, read_only=True) + source="analyticalstatement_set", many=True, read_only=True + ) analyzed_entries = serializers.IntegerField(read_only=True) class Meta: model = AnalysisPillar - fields = ( - 'id', 'title', 'assignee', 'created_at', - 'assignee_details', - 'analytical_statements', - 'analyzed_entries' - ) + fields = ("id", "title", "assignee", "created_at", "assignee_details", "analytical_statements", "analyzed_entries") # ------ GRAPHQL ------------ + class AnalyticalEntriesGqlSerializer(TempClientIdMixin, UniqueFieldsMixin, UserResourceSerializer): id = IntegerIDField(required=False) class Meta: model = AnalyticalStatementEntry fields = ( - 'id', - 'order', - 'entry', - 'client_id', + "id", + "order", + "entry", + "client_id", ) def validate_entry(self, entry): - if entry.project != self.context['request'].active_project: - raise serializers.ValidationError('Invalid entry') + if entry.project != self.context["request"].active_project: + raise serializers.ValidationError("Invalid entry") return entry def validate(self, data): - analysis_end_date = self.context['analysis_end_date'] # Passed by UpdateAnalysisPillar Mutation - entry = data.get('entry') + analysis_end_date = self.context["analysis_end_date"] # Passed by UpdateAnalysisPillar Mutation + entry = data.get("entry") lead_published = entry.lead.published_on if analysis_end_date and lead_published and lead_published > analysis_end_date: - raise serializers.ValidationError({ - 'entry': f'Entry {entry.id} lead published_on cannot be greater than analysis end_date {analysis_end_date}', - }) + raise serializers.ValidationError( + { + "entry": f"Entry {entry.id} lead published_on cannot be greater than analysis end_date {analysis_end_date}", + } + ) return data @@ -295,22 +302,22 @@ class AnalyticalStatementGqlSerializer( CustomNestedUpdateMixin, ): id = IntegerIDField(required=False) - entries = AnalyticalEntriesGqlSerializer(source='analyticalstatemententry_set', many=True, required=False) + entries = AnalyticalEntriesGqlSerializer(source="analyticalstatemententry_set", many=True, required=False) class Meta: model = AnalyticalStatement fields = ( - 'title', - 'id', - 'statement', - 'report_text', - 'information_gaps', - 'include_in_report', - 'order', - 'cloned_from', + "title", + "id", + "statement", + "report_text", + "information_gaps", + "include_in_report", + "order", + "cloned_from", # Custom - 'entries', - 'client_id', + "entries", + "client_id", ) # NOTE: This is a custom function (apps/user_resource/serializers.py::UserResourceSerializer) @@ -322,30 +329,28 @@ def _get_prefetch_related_instances_qs(self, qs): def validate(self, data): # Validate the analytical_entries - entries = data.get('analyticalstatemententry_set') + entries = data.get("analyticalstatemententry_set") if entries and len(entries) > settings.ANALYTICAL_ENTRIES_COUNT: - raise serializers.ValidationError( - f'Analytical entires count must be less than {settings.ANALYTICAL_ENTRIES_COUNT}' - ) + raise serializers.ValidationError(f"Analytical entires count must be less than {settings.ANALYTICAL_ENTRIES_COUNT}") return data class AnalysisPillarGqlSerializer(TempClientIdMixin, UserResourceSerializer): - statements = AnalyticalStatementGqlSerializer(many=True, source='analyticalstatement_set', required=False) + statements = AnalyticalStatementGqlSerializer(many=True, source="analyticalstatement_set", required=False) class Meta: model = AnalysisPillar fields = ( - 'title', - 'main_statement', - 'information_gap', - 'filters', - 'assignee', - 'analysis', - 'cloned_from', + "title", + "main_statement", + "information_gap", + "filters", + "assignee", + "analysis", + "cloned_from", # Custom - 'statements', - 'client_id', + "statements", + "client_id", ) # NOTE: This is a custom function (apps/user_resource/serializers.py::UserResourceSerializer) @@ -357,16 +362,16 @@ def _get_prefetch_related_instances_qs(self, qs): return qs.none() # On create throw error if existing id is provided def validate_analysis(self, analysis): - if analysis.project != self.context['request'].active_project: - raise serializers.ValidationError('Invalid analysis') + if analysis.project != self.context["request"].active_project: + raise serializers.ValidationError("Invalid analysis") return analysis def validate(self, data): # validate analysis_statement - analytical_statement = data.get('analyticalstatement_set') + analytical_statement = data.get("analyticalstatement_set") if analytical_statement and len(analytical_statement) > settings.ANALYTICAL_STATEMENT_COUNT: raise serializers.ValidationError( - f'Analytical statement count must be less than {settings.ANALYTICAL_STATEMENT_COUNT}' + f"Analytical statement count must be less than {settings.ANALYTICAL_STATEMENT_COUNT}" ) return data @@ -377,67 +382,64 @@ class DiscardedEntryGqlSerializer(serializers.ModelSerializer): class Meta: model = DiscardedEntry fields = ( - 'id', - 'analysis_pillar', - 'entry', - 'tag', + "id", + "analysis_pillar", + "entry", + "tag", ) def validate_analysis_pillar(self, analysis_pillar): - if analysis_pillar.analysis.project != self.context['request'].active_project: - raise serializers.ValidationError('Invalid analysis_pillar') + if analysis_pillar.analysis.project != self.context["request"].active_project: + raise serializers.ValidationError("Invalid analysis_pillar") return analysis_pillar def validate(self, data): # Validate entry data but analysis_pillar is required to do so - entry = data.get('entry') + entry = data.get("entry") if entry: - analysis_pillar = ( - self.instance.analysis_pillar if self.instance - else data['analysis_pillar'] - ) + analysis_pillar = self.instance.analysis_pillar if self.instance else data["analysis_pillar"] if entry.project != analysis_pillar.analysis.project: - raise serializers.ValidationError('Analysis pillar project doesnot match Entry project') + raise serializers.ValidationError("Analysis pillar project doesnot match Entry project") # validating the entry for the lead published_on greater than analysis end date analysis_end_date = analysis_pillar.analysis.end_date if entry.lead.published_on > analysis_end_date: - raise serializers.ValidationError({ - 'entry': ( - f'Entry {entry.id} lead published_on cannot be greater than analysis end_date {analysis_end_date}' - ), - }) + raise serializers.ValidationError( + { + "entry": ( + f"Entry {entry.id} lead published_on cannot be greater than analysis end_date {analysis_end_date}" + ), + } + ) return data class AnalysisGqlSerializer(UserResourceSerializer): id = IntegerIDField(required=False) - analysis_pillar = AnalysisPillarGqlSerializer(many=True, source='analysispillar_set', required=False) + analysis_pillar = AnalysisPillarGqlSerializer(many=True, source="analysispillar_set", required=False) start_date = serializers.DateField(required=False, allow_null=True) class Meta: model = Analysis fields = ( - 'id', - 'title', - 'team_lead', - 'project', - 'start_date', - 'end_date', - 'cloned_from', + "id", + "title", + "team_lead", + "project", + "start_date", + "end_date", + "cloned_from", ) def validate_project(self, project): - if project != self.context['request'].active_project: - raise serializers.ValidationError('Invalid project') + if project != self.context["request"].active_project: + raise serializers.ValidationError("Invalid project") return project def validate(self, data): - start_date = data.get('start_date') - end_date = data.get('end_date') + start_date = data.get("start_date") + end_date = data.get("end_date") if start_date and start_date > end_date: - raise serializers.ValidationError( - {'end_date': 'End date must occur after start date'} - ) + raise serializers.ValidationError({"end_date": "End date must occur after start date"}) return data @@ -454,32 +456,30 @@ class AnalysisTopicModelSerializer(UserResourceSerializer, serializers.ModelSeri class Meta: model = TopicModel fields = ( - 'analysis_pillar', - 'additional_filters', - 'widget_tags', + "analysis_pillar", + "additional_filters", + "widget_tags", ) def validate_analysis_pillar(self, analysis_pillar): - if analysis_pillar.analysis.project != self.context['request'].active_project: - raise serializers.ValidationError('Invalid analysis pillar') + if analysis_pillar.analysis.project != self.context["request"].active_project: + raise serializers.ValidationError("Invalid analysis pillar") return analysis_pillar def validate_additional_filters(self, additional_filters): - filter_set = EntryGQFilterSet(data=additional_filters, request=self.context['request']) + filter_set = EntryGQFilterSet(data=additional_filters, request=self.context["request"]) if not filter_set.is_valid(): raise serializers.ValidationError(filter_set.errors) return additional_filters def create(self, data): if not TopicModel._get_entries_qs( - data['analysis_pillar'], - data.get('additional_filters') or {}, + data["analysis_pillar"], + data.get("additional_filters") or {}, ).exists(): - raise serializers.ValidationError('No entries found to process') + raise serializers.ValidationError("No entries found to process") instance = super().create(data) - transaction.on_commit( - lambda: trigger_topic_model.delay(instance.pk) - ) + transaction.on_commit(lambda: trigger_topic_model.delay(instance.pk)) return instance @@ -489,32 +489,25 @@ class EntriesCollectionNlpTriggerBaseSerializer(UserResourceSerializer, serializ class Meta: model = EntriesCollectionNlpTriggerBase - fields = ( - 'entries_id', - ) + fields = ("entries_id",) def validate_entries_id(self, entries_id): - entries_id = self.Meta.model.get_valid_entries_id( - self.context['request'].active_project.id, - entries_id - ) + entries_id = self.Meta.model.get_valid_entries_id(self.context["request"].active_project.id, entries_id) if not entries_id: - raise serializers.ValidationError('No entries found to process') + raise serializers.ValidationError("No entries found to process") return entries_id def create(self, data): - data['project'] = self.context['request'].active_project - existing_instance = self.Meta.model.get_existing(data['entries_id']) + data["project"] = self.context["request"].active_project + existing_instance = self.Meta.model.get_existing(data["entries_id"]) if existing_instance: return existing_instance instance = super().create(data) - transaction.on_commit( - lambda: self.trigger_task_func.delay(instance.pk) - ) + transaction.on_commit(lambda: self.trigger_task_func.delay(instance.pk)) return instance def update(self, _): - raise serializers.ValidationError('Not allowed using this serializer.') + raise serializers.ValidationError("Not allowed using this serializer.") class AnalysisAutomaticSummarySerializer(EntriesCollectionNlpTriggerBaseSerializer): @@ -524,8 +517,8 @@ class AnalysisAutomaticSummarySerializer(EntriesCollectionNlpTriggerBaseSerializ class Meta: model = AutomaticSummary fields = ( - 'entries_id', - 'widget_tags', + "entries_id", + "widget_tags", ) @@ -534,9 +527,7 @@ class AnalyticalStatementNGramSerializer(EntriesCollectionNlpTriggerBaseSerializ class Meta: model = AnalyticalStatementNGram - fields = ( - 'entries_id', - ) + fields = ("entries_id",) class AnalyticalStatementGeoTaskSerializer(EntriesCollectionNlpTriggerBaseSerializer): @@ -544,97 +535,95 @@ class AnalyticalStatementGeoTaskSerializer(EntriesCollectionNlpTriggerBaseSerial class Meta: model = AnalyticalStatementGeoTask - fields = ( - 'entries_id', - ) + fields = ("entries_id",) # -------------------------- ReportModule -------------------------------- class ReportEnum: class VariableType(models.TextChoices): - TEXT = 'text' - NUMBER = 'number' - DATE = 'date' - BOOLEAN = 'boolean' + TEXT = "text" + NUMBER = "number" + DATE = "date" + BOOLEAN = "boolean" class TextStyleAlign(models.TextChoices): - START = 'start' - END = 'end' - CENTER = 'center' - JUSTIFIED = 'justified' + START = "start" + END = "end" + CENTER = "center" + JUSTIFIED = "justified" class BorderStyleStyle(models.TextChoices): - DOTTED = 'dotted' - DASHED = 'dashed' - SOLID = 'solid' - DOUBLE = 'double' - NONE = 'none' + DOTTED = "dotted" + DASHED = "dashed" + SOLID = "solid" + DOUBLE = "double" + NONE = "none" class ImageContentStyleFit(models.TextChoices): - FILL = 'fill' - CONTAIN = 'contain' - COVER = 'cover' - SCALE_DOWN = 'scale-down' - NONE = 'none' + FILL = "fill" + CONTAIN = "contain" + COVER = "cover" + SCALE_DOWN = "scale-down" + NONE = "none" class HeadingConfigurationVariant(models.TextChoices): - H1 = 'h1' - H2 = 'h2' - H3 = 'h3' - H4 = 'h4' + H1 = "h1" + H2 = "h2" + H3 = "h3" + H4 = "h4" class HorizontalAxisType(models.TextChoices): - CATEGORICAL = 'categorical' - NUMERIC = 'numeric' - DATE = 'date' + CATEGORICAL = "categorical" + NUMERIC = "numeric" + DATE = "date" class BarChartType(models.TextChoices): - SIDE_BY_SIDE = 'side-by-side' - STACKED = 'stacked' + SIDE_BY_SIDE = "side-by-side" + STACKED = "stacked" class BarChartDirection(models.TextChoices): - VERTICAL = 'vertical' - HORIZONTAL = 'horizontal' + VERTICAL = "vertical" + HORIZONTAL = "horizontal" class LegendPosition(models.TextChoices): - TOP = 'top' - LEFT = 'left' - BOTTOM = 'bottom' - RIGHT = 'right' + TOP = "top" + LEFT = "left" + BOTTOM = "bottom" + RIGHT = "right" class LegendDotShape(models.TextChoices): - CIRCLE = 'circle' - TRIANGLE = 'triangle' - SQUARE = 'square' - DIAMOND = 'diamond' + CIRCLE = "circle" + TRIANGLE = "triangle" + SQUARE = "square" + DIAMOND = "diamond" class AggregationType(models.TextChoices): - COUNT = 'count' - SUM = 'sum' - MEAN = 'mean' - MEDIAN = 'median' - MIN = 'min' - MAX = 'max' + COUNT = "count" + SUM = "sum" + MEAN = "mean" + MEDIAN = "median" + MIN = "min" + MAX = "max" class ScaleType(models.TextChoices): - FIXED = 'fixed' - PROPORTIONAL = 'proportional' + FIXED = "fixed" + PROPORTIONAL = "proportional" class ScalingTechnique(models.TextChoices): - ABSOLUTE = 'absolute' - FLANNERY = 'flannery' + ABSOLUTE = "absolute" + FLANNERY = "flannery" class MapLayerType(models.TextChoices): - OSM_LAYER = 'OSM Layer' - MAPBOX_LAYER = 'Mapbox Layer' - SYMBOL_LAYER = 'Symbol Layer' - POLYGON_LAYER = 'Polygon Layer' - LINE_LAYER = 'Line Layer' - HEAT_MAP_LAYER = 'Heatmap Layer' + OSM_LAYER = "OSM Layer" + MAPBOX_LAYER = "Mapbox Layer" + SYMBOL_LAYER = "Symbol Layer" + POLYGON_LAYER = "Polygon Layer" + LINE_LAYER = "Line Layer" + HEAT_MAP_LAYER = "Heatmap Layer" class LineLayerStrokeType(models.TextChoices): - DASH = 'dash' - SOLID = 'solid' + DASH = "dash" + SOLID = "solid" class AnalysisReportVariableSerializer(serializers.Serializer): @@ -1048,23 +1037,19 @@ class AnalysisReportContainerDataSerializer(TempClientIdMixin, serializers.Model class Meta: model = AnalysisReportContainerData fields = ( - 'id', - 'client_id', - 'client_reference_id', - 'upload', - 'data', + "id", + "client_id", + "client_reference_id", + "upload", + "data", ) def validate_upload(self, upload): - report = self.context.get('report') + report = self.context.get("report") if report is None: - raise serializers.ValidationError( - 'Report needs to be created before assigning uploads to container' - ) + raise serializers.ValidationError("Report needs to be created before assigning uploads to container") if report.id != upload.report_id: - raise serializers.ValidationError( - 'Upload within report are only allowed' - ) + raise serializers.ValidationError("Upload within report are only allowed") return upload @@ -1074,26 +1059,25 @@ class AnalysisReportContainerSerializer(TempClientIdMixin, UserResourceSerialize class Meta: model = AnalysisReportContainer fields = ( - 'id', - 'client_id', - 'row', - 'column', - 'width', - 'height', - 'content_type', + "id", + "client_id", + "row", + "column", + "width", + "height", + "content_type", # Custom - 'style', - 'content_configuration', - 'content_data', + "style", + "content_configuration", + "content_data", ) style = AnalysisReportContainerStyleSerializer(required=False, allow_null=True) # Content metadata - content_configuration = AnalysisReportContainerContentConfigurationSerializer( - required=False, allow_null=True) + content_configuration = AnalysisReportContainerContentConfigurationSerializer(required=False, allow_null=True) - content_data = AnalysisReportContainerDataSerializer(many=True, source='analysisreportcontainerdata_set') + content_data = AnalysisReportContainerDataSerializer(many=True, source="analysisreportcontainerdata_set") # NOTE: This is a custom function (apps/user_resource/serializers.py::UserResourceSerializer) # This makes sure only scoped (individual Analysis Report) instances (container data) are updated. @@ -1107,19 +1091,19 @@ class AnalysisReportSerializer(ProjectPropertySerializerMixin, UserResourceSeria class Meta: model = AnalysisReport fields = ( - 'analysis', - 'slug', - 'title', - 'sub_title', - 'is_public', - 'organizations', + "analysis", + "slug", + "title", + "sub_title", + "is_public", + "organizations", # Custom - 'configuration', - 'containers', + "configuration", + "containers", ) configuration = AnalysisReportConfigurationSerializer(required=False, allow_null=True) - containers = AnalysisReportContainerSerializer(many=True, source='analysisreportcontainer_set') + containers = AnalysisReportContainerSerializer(many=True, source="analysisreportcontainer_set") # NOTE: This is a custom function (apps/user_resource/serializers.py::UserResourceSerializer) # This makes sure only scoped (individual Analysis Report) instances (containers) are updated. @@ -1131,11 +1115,8 @@ def _get_prefetch_related_instances_qs(self, qs): def validate_analysis(self, analysis): existing_analysis_id = self.instance and self.instance.analysis_id # NOTE: if changed, make sure user have access to that analysis - if ( - analysis.id != existing_analysis_id and - analysis.project_id != self.project.id - ): - raise serializers.ValidationError('You need access to analysis') + if analysis.id != existing_analysis_id and analysis.project_id != self.project.id: + raise serializers.ValidationError("You need access to analysis") return analysis @@ -1143,46 +1124,44 @@ def validate_analysis(self, analysis): class AnalysisReportSnapshotSerializer(ProjectPropertySerializerMixin, serializers.ModelSerializer): class Meta: model = AnalysisReportSnapshot - fields = ( - 'report', - ) + fields = ("report",) serializers.FileField() def validate_report(self, report): if self.project.id != report.analysis.project_id: - raise serializers.ValidationError('Invalid report') + raise serializers.ValidationError("Invalid report") return report def validate(self, data): - report = data['report'] + report = data["report"] snaphost_file, errors = generate_query_snapshot( SnapshotQuery.AnalysisReport.Snapshot, { - 'projectID': str(self.project.id), - 'reportID': str(report.id), + "projectID": str(self.project.id), + "reportID": str(report.id), }, - data_callback=lambda x: x['project']['analysisReport'], - context=GQLContext(self.context['request']), + data_callback=lambda x: x["project"]["analysisReport"], + context=GQLContext(self.context["request"]), ) if snaphost_file is None: logger.error( - f'Failed to generate snapshot for report-pk: {report.id}', - extra={'data': {'errors': errors}}, + f"Failed to generate snapshot for report-pk: {report.id}", + extra={"data": {"errors": errors}}, ) - raise serializers.ValidationError('Failed to generate snapshot') - data['report_data_file'] = snaphost_file - data['published_by'] = self.context['request'].user + raise serializers.ValidationError("Failed to generate snapshot") + data["report_data_file"] = snaphost_file + data["published_by"] = self.context["request"].user return data def create(self, data): instance = super().create(data) # Save file - instance.report_data_file.save(f'{instance.report.id}-{instance.report.slug}.json', data['report_data_file']) + instance.report_data_file.save(f"{instance.report.id}-{instance.report.slug}.json", data["report_data_file"]) return instance def update(self, _): - raise Exception('Not implemented') + raise Exception("Not implemented") # -- Uploads @@ -1217,12 +1196,12 @@ class AnalysisReportUploadSerializer(ProjectPropertySerializerMixin, serializers class Meta: model = AnalysisReportUpload fields = ( - 'id', - 'report', - 'file', - 'type', + "id", + "report", + "file", + "type", # Custom - 'metadata', + "metadata", ) metadata = AnalysisReportUploadMetadataSerializer() @@ -1230,19 +1209,13 @@ class Meta: def validate_file(self, file): existing_file_id = self.instance and self.instance.file_id # NOTE: if changed, make sure only owner can assign files - if ( - file.id != existing_file_id and - file.created_by != self.context['request'].user - ): - raise serializers.ValidationError('Only owner can assign file') + if file.id != existing_file_id and file.created_by != self.context["request"].user: + raise serializers.ValidationError("Only owner can assign file") return file def validate_report(self, report): existing_report_id = self.instance and self.instance.report_id # NOTE: if changed, make sure user have access to that report - if ( - report.id != existing_report_id and - report.analysis.project_id != self.project.id - ): - raise serializers.ValidationError('You need access to report') + if report.id != existing_report_id and report.analysis.project_id != self.project.id: + raise serializers.ValidationError("You need access to report") return report diff --git a/apps/analysis/tasks.py b/apps/analysis/tasks.py index 59722db9eb..2fb3a7fad6 100644 --- a/apps/analysis/tasks.py +++ b/apps/analysis/tasks.py @@ -1,22 +1,22 @@ import logging from celery import shared_task -from django.db import models - -from utils.files import generate_json_file_for_upload from deepl_integration.handlers import ( - AnalysisTopicModelHandler, AnalysisAutomaticSummaryHandler, - AnalyticalStatementNGramHandler, + AnalysisTopicModelHandler, AnalyticalStatementGeoHandler, + AnalyticalStatementNGramHandler, ) - +from django.db import models from entry.models import Entry + +from utils.files import generate_json_file_for_upload + from .models import ( - TopicModel, - AutomaticSummary, - AnalyticalStatementNGram, AnalyticalStatementGeoTask, + AnalyticalStatementNGram, + AutomaticSummary, + TopicModel, ) logger = logging.getLogger(__name__) @@ -27,20 +27,18 @@ def trigger_topic_model(_id): topic_model = TopicModel.objects.get(pk=_id) # Generate entries data file entries_id_qs = list( - topic_model - .get_entries_qs() - .exclude(excerpt='') + topic_model.get_entries_qs().exclude(excerpt="") # TODO: Use original? dropped_excerpt # This is the format which deepl expects # https://docs.google.com/document/d/1NmjOO5sOrhJU6b4QXJBrGAVk57_NW87mLJ9wzeY_NZI/edit#heading=h.cif9hh69nfvz - .values('excerpt', entry_id=models.F('id')) + .values("excerpt", entry_id=models.F("id")) ) payload = { - 'data': entries_id_qs, - 'tags': topic_model.widget_tags, + "data": entries_id_qs, + "tags": topic_model.widget_tags, } topic_model.entries_file.save( - f'{topic_model.id}.json', + f"{topic_model.id}.json", generate_json_file_for_upload(payload), ) # Send trigger request @@ -55,14 +53,14 @@ def trigger_automatic_summary(_id): Entry.objects.filter( project=a_summary.project, id__in=a_summary.entries_id, - ).values('excerpt', entry_id=models.F('id')) + ).values("excerpt", entry_id=models.F("id")) ) payload = { - 'data': entries_data, - 'tags': a_summary.widget_tags, + "data": entries_data, + "tags": a_summary.widget_tags, } a_summary.entries_file.save( - f'{a_summary.id}.json', + f"{a_summary.id}.json", generate_json_file_for_upload(payload), ) AnalysisAutomaticSummaryHandler.send_trigger_request_to_extractor(a_summary) @@ -75,10 +73,10 @@ def trigger_automatic_ngram(_id): Entry.objects.filter( project=a_ngram.project, id__in=a_ngram.entries_id, - ).values('excerpt', entry_id=models.F('id')) + ).values("excerpt", entry_id=models.F("id")) ) a_ngram.entries_file.save( - f'{a_ngram.id}.json', + f"{a_ngram.id}.json", generate_json_file_for_upload(entries_data), ) AnalyticalStatementNGramHandler.send_trigger_request_to_extractor(a_ngram) @@ -91,10 +89,10 @@ def trigger_geo_location(_id): Entry.objects.filter( project=geo_location_task.project, id__in=geo_location_task.entries_id, - ).values('excerpt', entry_id=models.F('id')) + ).values("excerpt", entry_id=models.F("id")) ) geo_location_task.entries_file.save( - f'{geo_location_task.id}.json', + f"{geo_location_task.id}.json", generate_json_file_for_upload(entries_data), ) AnalyticalStatementGeoHandler.send_trigger_request_to_extractor(geo_location_task) diff --git a/apps/analysis/tests/test_apis.py b/apps/analysis/tests/test_apis.py index 023fdab15a..9c57c6c4f7 100644 --- a/apps/analysis/tests/test_apis.py +++ b/apps/analysis/tests/test_apis.py @@ -1,25 +1,21 @@ -from dateutil.relativedelta import relativedelta from unittest.mock import patch -from django.utils import timezone -from django.conf import settings - -from rest_framework.exceptions import ErrorDetail - -from deep.tests import TestCase -from deep.number_generator import client_id_generator -from entry.models import Entry from analysis.models import ( Analysis, AnalysisPillar, AnalyticalStatement, AnalyticalStatementEntry, - DiscardedEntry -) -from organization.models import ( - Organization, - OrganizationType + DiscardedEntry, ) +from dateutil.relativedelta import relativedelta +from django.conf import settings +from django.utils import timezone +from entry.models import Entry +from organization.models import Organization, OrganizationType +from rest_framework.exceptions import ErrorDetail + +from deep.number_generator import client_id_generator +from deep.tests import TestCase class TestAnalysisAPIs(TestCase): @@ -30,30 +26,30 @@ def test_create_analysis_without_pillar(self): project = self.create_project() project.add_member(user) now = timezone.now() - url = f'/api/v1/projects/{project.id}/analysis/' + url = f"/api/v1/projects/{project.id}/analysis/" data = { - 'title': 'Test Analysis', - 'team_lead': user.id, - 'start_date': (now + relativedelta(days=2)).date(), - 'end_date': (now + relativedelta(days=22)).date(), + "title": "Test Analysis", + "team_lead": user.id, + "start_date": (now + relativedelta(days=2)).date(), + "end_date": (now + relativedelta(days=22)).date(), } self.authenticate(user) response = self.client.post(url, data) self.assert_201(response) self.assertEqual(Analysis.objects.count(), analysis_count + 1) r_data = response.json() - self.assertEqual(r_data['title'], data['title']) - self.assertEqual(r_data['teamLead'], user.id) + self.assertEqual(r_data["title"], data["title"]) + self.assertEqual(r_data["teamLead"], user.id) def test_create_analysis_with_user_not_project_member(self): user = self.create_user() user2 = self.create_user() project = self.create_project() project.add_member(user) - url = f'/api/v1/projects/{project.id}/analysis/' + url = f"/api/v1/projects/{project.id}/analysis/" data = { - 'title': 'Test Analysis', - 'team_lead': user.id, + "title": "Test Analysis", + "team_lead": user.id, } self.authenticate(user2) response = self.client.post(url, data) @@ -67,18 +63,20 @@ def test_create_pillar_from_analysis_api(self): project = self.create_project() project.add_member(user) now = timezone.now() - url = f'/api/v1/projects/{project.id}/analysis/' + url = f"/api/v1/projects/{project.id}/analysis/" data = { - 'title': 'Test Analysis', - 'team_lead': user.id, - 'start_date': (now + relativedelta(days=2)).date(), - 'end_date': (now + relativedelta(days=22)).date(), - 'analysis_pillar': [{ - 'main_statement': 'Some main statement', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title' - }] + "title": "Test Analysis", + "team_lead": user.id, + "start_date": (now + relativedelta(days=2)).date(), + "end_date": (now + relativedelta(days=22)).date(), + "analysis_pillar": [ + { + "main_statement": "Some main statement", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", + } + ], } self.authenticate(user) response = self.client.post(url, data) @@ -91,18 +89,18 @@ def test_create_pillar_from_analysis(self): user = self.create_user() project = self.create_project() project.add_member(user) - analysis = self.create(Analysis, title='Test Analysis', project=project, created_by=user) - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/' + analysis = self.create(Analysis, title="Test Analysis", project=project, created_by=user) + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/" data = { - 'main_statement': 'Some main statement', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title' + "main_statement": "Some main statement", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", } self.authenticate(user) response = self.client.post(url, data) self.assert_201(response) - self.assertEqual(response.data['created_by'], user.id) + self.assertEqual(response.data["created_by"], user.id) self.assertEqual(AnalysisPillar.objects.count(), pillar_count + 1) def test_create_pillar_along_with_statement(self): @@ -114,14 +112,14 @@ def test_create_pillar_along_with_statement(self): project.add_member(user) entry1 = self.create_entry(project=project) entry2 = self.create_entry(project=project) - analysis = self.create(Analysis, project=project, title='Test Analysis') - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/' + analysis = self.create(Analysis, project=project, title="Test Analysis") + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/" data = { - 'main_statement': 'Some main statement', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title', - 'analytical_statements': [ + "main_statement": "Some main statement", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", + "analytical_statements": [ { "statement": "coffee", "order": 1, @@ -132,11 +130,7 @@ def test_create_pillar_along_with_statement(self): "client_id": "1-1", "entry": entry1.id, }, - { - "order": 2, - "client_id": "1-2", - "entry": entry2.id - } + {"order": 2, "client_id": "1-2", "entry": entry2.id}, ], }, { @@ -150,24 +144,23 @@ def test_create_pillar_along_with_statement(self): "entry": entry1.id, } ], - } - ] + }, + ], } self.authenticate(user) response = self.client.post(url, data) self.assert_201(response) self.assertEqual(AnalysisPillar.objects.count(), pillar_count + 1) - self.assertEqual(AnalyticalStatement.objects.filter( - analysis_pillar__analysis=analysis).count(), statement_count + 2) + self.assertEqual(AnalyticalStatement.objects.filter(analysis_pillar__analysis=analysis).count(), statement_count + 2) # try to edit - response_id = response.data['id'] + response_id = response.data["id"] data = { - 'main_statement': 'HELLO FROM MARS', - 'analytical_statements': [ + "main_statement": "HELLO FROM MARS", + "analytical_statements": [ { - 'statement': "tea", - 'order': 1, + "statement": "tea", + "order": 1, "client_id": "2-1", "analytical_entries": [ { @@ -175,31 +168,28 @@ def test_create_pillar_along_with_statement(self): "client_id": "2-1-1", "entry": entry1.id, }, - { - "order": 2, - "client_id": "2-1-2", - "entry": entry2.id - } + {"order": 2, "client_id": "2-1-2", "entry": entry2.id}, ], }, - ] + ], } self.authenticate(user) - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/{response_id}/' + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/{response_id}/" response = self.client.patch(url, data) self.assert_200(response) - self.assertEqual(response.data['main_statement'], data['main_statement']) - self.assertEqual(response.data['analytical_statements'][0]['statement'], - data['analytical_statements'][0]['statement']) + self.assertEqual(response.data["main_statement"], data["main_statement"]) + self.assertEqual(response.data["analytical_statements"][0]["statement"], data["analytical_statements"][0]["statement"]) # not passing all the resources the data must be deleted from the database - self.assertEqual(AnalyticalStatement.objects.filter( - analysis_pillar__analysis=analysis).count(), statement_count + 1) - self.assertIn(response.data['analytical_statements'][0]['id'], - list(AnalyticalStatement.objects.filter( - analysis_pillar__analysis=analysis).values_list('id', flat=True)),) + self.assertEqual(AnalyticalStatement.objects.filter(analysis_pillar__analysis=analysis).count(), statement_count + 1) + self.assertIn( + response.data["analytical_statements"][0]["id"], + list(AnalyticalStatement.objects.filter(analysis_pillar__analysis=analysis).values_list("id", flat=True)), + ) # checking for the entries - self.assertEqual(AnalyticalStatementEntry.objects.filter( - analytical_statement__analysis_pillar__analysis=analysis).count(), entry_count + 2) + self.assertEqual( + AnalyticalStatementEntry.objects.filter(analytical_statement__analysis_pillar__analysis=analysis).count(), + entry_count + 2, + ) def test_end_date_analysis_greater_than_lead_published_on(self): """ @@ -211,14 +201,14 @@ def test_end_date_analysis_greater_than_lead_published_on(self): now = timezone.now() lead = self.create_lead(project=project, published_on=now + relativedelta(days=6)) entry = self.create_entry(project=project, lead=lead) - analysis = self.create(Analysis, project=project, title='Test Analysis', end_date=now + relativedelta(days=4)) - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/' + analysis = self.create(Analysis, project=project, title="Test Analysis", end_date=now + relativedelta(days=4)) + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/" data = { - 'main_statement': 'Some main statement', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title', - 'analytical_statements': [ + "main_statement": "Some main statement", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", + "analytical_statements": [ { "statement": "coffee", "order": 1, @@ -229,20 +219,20 @@ def test_end_date_analysis_greater_than_lead_published_on(self): "client_id": "1", "entry": entry.id, }, - ] + ], } - ] + ], } self.authenticate(user) response = self.client.post(url, data) self.assert_400(response) self.assertEqual( - response.data['errors']['analytical_statements'][0]['analytical_entries'][0]['entry'][0], + response.data["errors"]["analytical_statements"][0]["analytical_entries"][0]["entry"][0], ErrorDetail( string=( - f'Entry {entry.id} lead published_on cannot be greater than analysis end_date {analysis.end_date.date()}' + f"Entry {entry.id} lead published_on cannot be greater than analysis end_date {analysis.end_date.date()}" ), - code='invalid', + code="invalid", ), ) @@ -253,14 +243,14 @@ def test_analysis_end_date_change(self): now = timezone.now() lead = self.create_lead(project=project, published_on=now + relativedelta(days=2)) entry = self.create_entry(project=project, lead=lead) - analysis = self.create(Analysis, project=project, title='Test Analysis', end_date=now + relativedelta(days=4)) - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/' + analysis = self.create(Analysis, project=project, title="Test Analysis", end_date=now + relativedelta(days=4)) + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/" data = { - 'main_statement': 'Some main statement', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title', - 'analytical_statements': [ + "main_statement": "Some main statement", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", + "analytical_statements": [ { "statement": "coffee", "order": 1, @@ -271,9 +261,9 @@ def test_analysis_end_date_change(self): "client_id": "1", "entry": entry.id, }, - ] + ], } - ] + ], } self.authenticate(user) response = self.client.post(url, data) @@ -281,14 +271,14 @@ def test_analysis_end_date_change(self): # try to change the analysis end_date and try to patch at the pillar analysis.end_date = now + relativedelta(days=1) analysis.save() - pillar_id = response.data['id'] - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/{pillar_id}/' + pillar_id = response.data["id"] + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/{pillar_id}/" data = { - 'main_statement': 'Some main statement', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title', - 'analytical_statements': [ + "main_statement": "Some main statement", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", + "analytical_statements": [ { "statement": "coffee", "order": 1, @@ -299,9 +289,9 @@ def test_analysis_end_date_change(self): "client_id": "1", "entry": entry.id, }, - ] + ], } - ] + ], } self.authenticate(user) response = self.client.patch(url, data) @@ -313,29 +303,22 @@ def test_create_analytical_statement(self): project = self.create_project() project.add_member(user) entry = self.create(Entry) - analysis = self.create(Analysis, title='Test Analysis', project=project) + analysis = self.create(Analysis, title="Test Analysis", project=project) pillar = self.create(AnalysisPillar, analysis=analysis) - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/{pillar.id}/analytical-statement/' + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/{pillar.id}/analytical-statement/" data = { - "analytical_entries": [ - { - "order": 1, - "client_id": "1", - "entry": entry.id - } - ], + "analytical_entries": [{"order": 1, "client_id": "1", "entry": entry.id}], "statement": "test statement", "order": 1, "client_id": "1", - "analysisPillar": pillar.id + "analysisPillar": pillar.id, } self.authenticate(user) response = self.client.post(url, data) self.assert_201(response) - self.assertEqual(AnalyticalStatement.objects.filter( - analysis_pillar__analysis=analysis).count(), statement_count + 1) + self.assertEqual(AnalyticalStatement.objects.filter(analysis_pillar__analysis=analysis).count(), statement_count + 1) r_data = response.json() - self.assertEqual(r_data['statement'], data['statement']) + self.assertEqual(r_data["statement"], data["statement"]) def test_create_analytical_statement_greater_than_30_api_level(self): user = self.create_user() @@ -345,11 +328,11 @@ def test_create_analytical_statement_greater_than_30_api_level(self): analysis = self.create(Analysis, project=project) data = { - 'main_statement': 'Some main statement', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title', - 'analytical_statements': [ + "main_statement": "Some main statement", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", + "analytical_statements": [ { "statement": "coffee", "order": 1, @@ -360,22 +343,23 @@ def test_create_analytical_statement_greater_than_30_api_level(self): "client_id": f"client-id-{index}", "entry": entry.id, } - ] - } for index in range(settings.ANALYTICAL_STATEMENT_COUNT) - ] + ], + } + for index in range(settings.ANALYTICAL_STATEMENT_COUNT) + ], } - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/' + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/" self.authenticate(user) response = self.client.post(url, data) self.assert_201(response) # posting statement greater than `ANALYTICAL_STATEMENT_COUNT` data = { - 'main_statement': 'Some main statement', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title', - 'analytical_statements': [ + "main_statement": "Some main statement", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", + "analytical_statements": [ { "statement": "coffee", "order": 1, @@ -386,15 +370,16 @@ def test_create_analytical_statement_greater_than_30_api_level(self): "client_id": f"client-id-{index}-new", "entry": entry.id, } - ] - } for index in range(settings.ANALYTICAL_STATEMENT_COUNT + 1) - ] + ], + } + for index in range(settings.ANALYTICAL_STATEMENT_COUNT + 1) + ], } - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/' + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/" self.authenticate(user) response = self.client.post(url, data) self.assert_400(response) - self.assertIn('non_field_errors', response.data['errors']) + self.assertIn("non_field_errors", response.data["errors"]) def test_create_analytical_entries_greater_than_50_api_level(self): user = self.create_user() @@ -405,11 +390,11 @@ def test_create_analytical_entries_greater_than_50_api_level(self): analysis = self.create(Analysis, project=project) data = { - 'main_statement': 'Some main statement', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title', - 'analytical_statements': [ + "main_statement": "Some main statement", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", + "analytical_statements": [ { "statement": "coffee", "order": 1, @@ -419,23 +404,24 @@ def test_create_analytical_entries_greater_than_50_api_level(self): "order": 1, "client_id": str(entry.id), "entry": entry.id, - } for entry in entries_list - ] + } + for entry in entries_list + ], } - ] + ], } - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/' + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/" self.authenticate(user) response = self.client.post(url, data) self.assert_201(response) # try posting for entries less than `ANALYTICAL_ENTRIES_COUNT + 1` data = { - 'main_statement': 'Some main statement', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title', - 'analytical_statements': [ + "main_statement": "Some main statement", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", + "analytical_statements": [ { "statement": "coffee", "order": 1, @@ -445,12 +431,13 @@ def test_create_analytical_entries_greater_than_50_api_level(self): "order": 1, "client_id": str(entry.id), "entry": entry.id, - } for entry in entries_list_one_more - ] + } + for entry in entries_list_one_more + ], } - ] + ], } - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/' + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/" self.authenticate(user) response = self.client.post(url, data) self.assert_400(response) @@ -461,47 +448,38 @@ def test_version_change_upon_changes_in_analytical_statement(self): project.add_member(user) self.create_entry(project=project) self.create_entry(project=project) - analysis = self.create(Analysis, title='Test Analysis', project=project) - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/' + analysis = self.create(Analysis, title="Test Analysis", project=project) + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/" data = { - 'main_statement': 'Some main statement', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title', - 'analytical_statements': [ - { - "statement": "coffee", - "order": 1, - "client_id": "1" - }, - ] + "main_statement": "Some main statement", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", + "analytical_statements": [ + {"statement": "coffee", "order": 1, "client_id": "1"}, + ], } self.authenticate(user) response = self.client.post(url, data) self.assert_201(response) - id = response.data['id'] - statement_id = response.data['analytical_statements'][0]['id'] - self.assertEqual(response.data['version_id'], 1) + id = response.data["id"] + statement_id = response.data["analytical_statements"][0]["id"] + self.assertEqual(response.data["version_id"], 1) # try to patch some changes in analytical_statements data = { - 'main_statement': 'Some main statement', - 'information_gap': 'Some not information gap', - 'assignee': user.id, - 'title': 'Some title', - 'analytical_statements': [ - { - 'id': statement_id, - "statement": "tea", - "order": 1, - "client_id": "123" - }, - ] + "main_statement": "Some main statement", + "information_gap": "Some not information gap", + "assignee": user.id, + "title": "Some title", + "analytical_statements": [ + {"id": statement_id, "statement": "tea", "order": 1, "client_id": "123"}, + ], } - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/{id}/' + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/{id}/" response = self.client.patch(url, data) self.assert_200(response) # after the sucessfull patch the version should change - self.assertEqual(response.data['version_id'], 2) + self.assertEqual(response.data["version_id"], 2) def test_nested_entry_validation(self): user = self.create_user() @@ -509,14 +487,14 @@ def test_nested_entry_validation(self): project.add_member(user) entry1 = self.create_entry(project=project) entry2 = self.create_entry(project=project) - analysis = self.create(Analysis, title='Test Analysis', project=project) - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/' + analysis = self.create(Analysis, title="Test Analysis", project=project) + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/" data = { - 'main_statement': 'Some main statement', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title', - 'analytical_statements': [ + "main_statement": "Some main statement", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", + "analytical_statements": [ { "statement": "coffee", "order": 1, @@ -527,11 +505,7 @@ def test_nested_entry_validation(self): "client_id": "1", "entry": entry1.id, }, - { - "order": 2, - "client_id": "2", - "entry": entry2.id - } + {"order": 2, "client_id": "2", "entry": entry2.id}, ], }, { @@ -545,23 +519,23 @@ def test_nested_entry_validation(self): "entry": entry1.id, } ], - } - ] + }, + ], } self.authenticate(user) response = self.client.post(url, data) self.assert_201(response) - response_id = response.data['id'] + response_id = response.data["id"] # now try to delete an entry Entry.objects.filter(id=entry2.id).delete() # try to patch data = { - 'main_statement': 'Some main change', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title', - 'analytical_statements': [ + "main_statement": "Some main change", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", + "analytical_statements": [ { "statement": "coffee", "order": 1, @@ -572,11 +546,7 @@ def test_nested_entry_validation(self): "client_id": "1", "entry": entry1.id, }, - { - "order": 2, - "client_id": "2", - "entry": entry2.id - } + {"order": 2, "client_id": "2", "entry": entry2.id}, ], }, { @@ -590,15 +560,15 @@ def test_nested_entry_validation(self): "entry": entry1.id, } ], - } - ] + }, + ], } - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/{response_id}/' + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/{response_id}/" response = self.client.patch(url, data) self.assert_400(response) self.assertEqual( - response.data['errors']['analytical_statements'][0]['analytical_entries'][1]['entry'][0], - ErrorDetail(string=f'Invalid pk "{entry2.id}" - object does not exist.', code='does_not_exist'), + response.data["errors"]["analytical_statements"][0]["analytical_entries"][1]["entry"][0], + ErrorDetail(string=f'Invalid pk "{entry2.id}" - object does not exist.', code="does_not_exist"), ) # TODO: Make sure the error is structured for client # self.assertEqual( @@ -635,24 +605,16 @@ def test_summary_for_analysis(self): entry10 = self.create_entry(lead=lead8, project=project) analysis1 = self.create( - Analysis, - title='Test Analysis', - team_lead=user, - project=project, - end_date=now + relativedelta(days=4) + Analysis, title="Test Analysis", team_lead=user, project=project, end_date=now + relativedelta(days=4) ) analysis2 = self.create( - Analysis, - title='Not for test', - team_lead=user, - project=project, - end_date=now + relativedelta(days=7) + Analysis, title="Not for test", team_lead=user, project=project, end_date=now + relativedelta(days=7) ) - pillar1 = self.create(AnalysisPillar, analysis=analysis1, title='title1', assignee=user) - pillar2 = self.create(AnalysisPillar, analysis=analysis1, title='title2', assignee=user) - pillar3 = self.create(AnalysisPillar, analysis=analysis1, title='title3', assignee=user2) + pillar1 = self.create(AnalysisPillar, analysis=analysis1, title="title1", assignee=user) + pillar2 = self.create(AnalysisPillar, analysis=analysis1, title="title2", assignee=user) + pillar3 = self.create(AnalysisPillar, analysis=analysis1, title="title3", assignee=user2) - pillar4 = self.create(AnalysisPillar, analysis=analysis2, title='title3', assignee=user2) + pillar4 = self.create(AnalysisPillar, analysis=analysis2, title="title3", assignee=user2) # lets analyze all the entries here analytical_statement1 = self.create(AnalyticalStatement, analysis_pillar=pillar1) @@ -665,86 +627,64 @@ def test_summary_for_analysis(self): self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement1, entry=entry8) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement1, entry=entry10) # lets discard some entry here - DiscardedEntry.objects.create( - analysis_pillar=pillar1, - entry=entry3, - tag=DiscardedEntry.TagType.REDUNDANT - ) - DiscardedEntry.objects.create( - analysis_pillar=pillar1, - entry=entry9, - tag=DiscardedEntry.TagType.REDUNDANT - ) + DiscardedEntry.objects.create(analysis_pillar=pillar1, entry=entry3, tag=DiscardedEntry.TagType.REDUNDANT) + DiscardedEntry.objects.create(analysis_pillar=pillar1, entry=entry9, tag=DiscardedEntry.TagType.REDUNDANT) analytical_statement2 = self.create(AnalyticalStatement, analysis_pillar=pillar2) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement2, entry=entry4) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement2, entry=entry8) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement2, entry=entry10) - DiscardedEntry.objects.create( - analysis_pillar=pillar2, - entry=entry5, - tag=DiscardedEntry.TagType.REDUNDANT - ) + DiscardedEntry.objects.create(analysis_pillar=pillar2, entry=entry5, tag=DiscardedEntry.TagType.REDUNDANT) analytical_statement3 = self.create(AnalyticalStatement, analysis_pillar=pillar3) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement3, entry=entry5) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement3, entry=entry6) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement3, entry=entry8) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement3, entry=entry10) - DiscardedEntry.objects.create( - analysis_pillar=pillar3, - entry=entry2, - tag=DiscardedEntry.TagType.REDUNDANT - ) + DiscardedEntry.objects.create(analysis_pillar=pillar3, entry=entry2, tag=DiscardedEntry.TagType.REDUNDANT) analytical_statement4 = self.create(AnalyticalStatement, analysis_pillar=pillar4) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement4, entry=entry) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement4, entry=entry1) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement4, entry=entry2) - DiscardedEntry.objects.create( - analysis_pillar=pillar4, - entry=entry8, - tag=DiscardedEntry.TagType.REDUNDANT - ) + DiscardedEntry.objects.create(analysis_pillar=pillar4, entry=entry8, tag=DiscardedEntry.TagType.REDUNDANT) - url = f'/api/v1/projects/{project.id}/analysis/summary/' + url = f"/api/v1/projects/{project.id}/analysis/summary/" self.authenticate(user) response = self.client.get(url) self.assert_200(response) - data = response.data['results'] - self.assertEqual(data[1]['team_lead'], user.id) - self.assertEqual(data[1]['end_date'], analysis1.end_date.strftime('%Y-%m-%d')) - self.assertEqual(data[1]['team_lead_details']['id'], user.id) - self.assertEqual(data[1]['team_lead_details']['display_name'], user.profile.get_display_name()) - self.assertEqual(data[1]['pillars'][0]['id'], pillar3.id) - self.assertEqual(data[1]['pillars'][1]['title'], pillar2.title) - self.assertEqual( - data[1]['pillars'][2]['assignee_details']['display_name'], pillar1.assignee.profile.get_display_name() - ) + data = response.data["results"] + self.assertEqual(data[1]["team_lead"], user.id) + self.assertEqual(data[1]["end_date"], analysis1.end_date.strftime("%Y-%m-%d")) + self.assertEqual(data[1]["team_lead_details"]["id"], user.id) + self.assertEqual(data[1]["team_lead_details"]["display_name"], user.profile.get_display_name()) + self.assertEqual(data[1]["pillars"][0]["id"], pillar3.id) + self.assertEqual(data[1]["pillars"][1]["title"], pillar2.title) + self.assertEqual(data[1]["pillars"][2]["assignee_details"]["display_name"], pillar1.assignee.profile.get_display_name()) self.assertEqual( - data[1]['publication_date']['start_date'], lead6.published_on.strftime('%Y-%m-%d') + data[1]["publication_date"]["start_date"], lead6.published_on.strftime("%Y-%m-%d") ) # since we use lead that has entry created for - self.assertEqual(data[1]['publication_date']['end_date'], lead5.published_on.strftime('%Y-%m-%d')) - self.assertEqual(data[1]['pillars'][0]['analyzed_entries'], 5) # discrded + analyzed entry - self.assertEqual(data[1]['pillars'][1]['analyzed_entries'], 4) # discrded + analyzed entry + self.assertEqual(data[1]["publication_date"]["end_date"], lead5.published_on.strftime("%Y-%m-%d")) + self.assertEqual(data[1]["pillars"][0]["analyzed_entries"], 5) # discrded + analyzed entry + self.assertEqual(data[1]["pillars"][1]["analyzed_entries"], 4) # discrded + analyzed entry # here considering the entry whose lead published date less than analysis end_date # also when analyzed all entries in ceratin pillar and not all in next pillar - self.assertEqual(data[1]['analyzed_entries'], 10) - self.assertEqual(data[1]['analyzed_sources'], 8) # have `distinct=True` - self.assertEqual(data[1]['total_entries'], 10) - self.assertEqual(data[1]['total_sources'], 8) # taking lead that has entry more than one - self.assertEqual(data[0]['team_lead'], user.id) - self.assertEqual(data[0]['team_lead_details']['id'], user.id) - self.assertEqual(data[0]['team_lead_details']['display_name'], user.profile.get_display_name()) - self.assertEqual(data[0]['pillars'][0]['id'], pillar4.id) - self.assertEqual(data[0]['analyzed_entries'], 4) - self.assertEqual(data[0]['analyzed_sources'], 4) + self.assertEqual(data[1]["analyzed_entries"], 10) + self.assertEqual(data[1]["analyzed_sources"], 8) # have `distinct=True` + self.assertEqual(data[1]["total_entries"], 10) + self.assertEqual(data[1]["total_sources"], 8) # taking lead that has entry more than one + self.assertEqual(data[0]["team_lead"], user.id) + self.assertEqual(data[0]["team_lead_details"]["id"], user.id) + self.assertEqual(data[0]["team_lead_details"]["display_name"], user.profile.get_display_name()) + self.assertEqual(data[0]["pillars"][0]["id"], pillar4.id) + self.assertEqual(data[0]["analyzed_entries"], 4) + self.assertEqual(data[0]["analyzed_sources"], 4) # Should be same in each analysis of the project - self.assertEqual(data[1]['total_entries'], 10) - self.assertEqual(data[1]['total_sources'], 8) + self.assertEqual(data[1]["total_entries"], 10) + self.assertEqual(data[1]["total_sources"], 8) # try to post to api - data = {'team_lead': user.id} + data = {"team_lead": user.id} self.authenticate(user) response = self.client.post(url, data) self.assert_405(response) @@ -754,7 +694,7 @@ def test_summary_for_analysis(self): response = self.client.get(url) self.assert_403(response) - @patch('analysis.models.client_id_generator', side_effect=client_id_generator) + @patch("analysis.models.client_id_generator", side_effect=client_id_generator) def test_clone_analysis(self, client_id_mock_func): user = self.create_user() user2 = self.create_user() @@ -763,85 +703,46 @@ def test_clone_analysis(self, client_id_mock_func): entry = self.create_entry(project=project) entry1 = self.create_entry(project=project) analysis = self.create(Analysis, project=project, title="Test Clone") - pillar = self.create(AnalysisPillar, analysis=analysis, title='title1', assignee=user) + pillar = self.create(AnalysisPillar, analysis=analysis, title="title1", assignee=user) analytical_statement = self.create( - AnalyticalStatement, - analysis_pillar=pillar, - statement='Hello from here', - client_id='1' - ) - self.create( - AnalyticalStatementEntry, - analytical_statement=analytical_statement, - entry=entry, - order=1, - client_id='1' - ) - self.create( - DiscardedEntry, - entry=entry1, - analysis_pillar=pillar, - tag=DiscardedEntry.TagType.REDUNDANT + AnalyticalStatement, analysis_pillar=pillar, statement="Hello from here", client_id="1" ) + self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement, entry=entry, order=1, client_id="1") + self.create(DiscardedEntry, entry=entry1, analysis_pillar=pillar, tag=DiscardedEntry.TagType.REDUNDANT) - url = url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/clone/' + url = url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/clone/" # try to post with no end_date data = { - 'title': 'cloned_title', + "title": "cloned_title", } self.authenticate(user) response = self.client.post(url, data) self.assert_400(response) - assert 'end_date' in response.data - self.assertEqual( - response.data['end_date'], - [ErrorDetail(string='This field is required.', code='required')] - ) + assert "end_date" in response.data + self.assertEqual(response.data["end_date"], [ErrorDetail(string="This field is required.", code="required")]) # try to post with start_date greater than end_date - data = { - 'title': 'cloned_title', - 'end_date': '2020-10-01', - 'start_date': '2020-10-20' - } + data = {"title": "cloned_title", "end_date": "2020-10-01", "start_date": "2020-10-20"} self.authenticate(user) response = self.client.post(url, data) self.assert_400(response) - self.assertEqual( - response.data['end_date'], - [ErrorDetail(string='End date must occur after start date', code='invalid')] - ) + self.assertEqual(response.data["end_date"], [ErrorDetail(string="End date must occur after start date", code="invalid")]) - data.pop('start_date') + data.pop("start_date") self.authenticate(user) response = self.client.post(url, data) self.assert_201(response) - self.assertNotEqual(response.data['id'], analysis.id) - self.assertEqual(response.data['title'], data['title']) - self.assertEqual(response.data['cloned_from'], analysis.id) - self.assertEqual(response.data['analysis_pillar'][0]['cloned_from'], pillar.id) + self.assertNotEqual(response.data["id"], analysis.id) + self.assertEqual(response.data["title"], data["title"]) + self.assertEqual(response.data["cloned_from"], analysis.id) + self.assertEqual(response.data["analysis_pillar"][0]["cloned_from"], pillar.id) assert client_id_mock_func.called # test if the nested fields are cloned or not - self.assertEqual( - Analysis.objects.count(), - 2 - ) # need to be cloned and created by user - self.assertEqual( - AnalysisPillar.objects.count(), - 2 - ) - self.assertEqual( - AnalyticalStatement.objects.count(), - 2 - ) - self.assertEqual( - AnalyticalStatementEntry.objects.count(), - 2 - ) - self.assertEqual( - DiscardedEntry.objects.count(), - 2 - ) + self.assertEqual(Analysis.objects.count(), 2) # need to be cloned and created by user + self.assertEqual(AnalysisPillar.objects.count(), 2) + self.assertEqual(AnalyticalStatement.objects.count(), 2) + self.assertEqual(AnalyticalStatementEntry.objects.count(), 2) + self.assertEqual(DiscardedEntry.objects.count(), 2) # authenticating with user that is not project member self.authenticate(user2) response = self.client.post(url, data) @@ -854,45 +755,30 @@ def test_patch_analytical_statement(self): entry1 = self.create(Entry, project=project) entry2 = self.create(Entry, project=project) analysis = self.create(Analysis, project=project) - pillar = self.create(AnalysisPillar, analysis=analysis, title='title1', assignee=user) + pillar = self.create(AnalysisPillar, analysis=analysis, title="title1", assignee=user) analytical_statement = self.create( - AnalyticalStatement, - analysis_pillar=pillar, - statement='Hello from here', - client_id='1' + AnalyticalStatement, analysis_pillar=pillar, statement="Hello from here", client_id="1" ) statement_entry1 = self.create( - AnalyticalStatementEntry, - analytical_statement=analytical_statement, - entry=entry1, - order=1, - client_id='1' + AnalyticalStatementEntry, analytical_statement=analytical_statement, entry=entry1, order=1, client_id="1" ) statement_entry2 = self.create( - AnalyticalStatementEntry, - analytical_statement=analytical_statement, - entry=entry2, - order=2, - client_id='2' + AnalyticalStatementEntry, analytical_statement=analytical_statement, entry=entry2, order=2, client_id="2" ) - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/{pillar.id}/' + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/{pillar.id}/" data = { - 'analytical_statements': [ + "analytical_statements": [ { - 'id': analytical_statement.id, + "id": analytical_statement.id, "client_id": str(analytical_statement.id), - 'statement': 'Hello from there', + "statement": "Hello from there", "analytical_entries": [ { "id": statement_entry1.pk, "order": 1, "entry": entry1.id, }, - { - "id": statement_entry2.pk, - "order": 3, - "entry": entry2.id - } + {"id": statement_entry2.pk, "order": 3, "entry": entry2.id}, ], } ] @@ -900,13 +786,13 @@ def test_patch_analytical_statement(self): self.authenticate(user) response = self.client.patch(url, data) self.assert_200(response) - self.assertEqual(response.data['analytical_statements'][0]['id'], analytical_statement.id) + self.assertEqual(response.data["analytical_statements"][0]["id"], analytical_statement.id) self.assertEqual( - response.data['analytical_statements'][0]['analytical_entries'][0]['entry'], + response.data["analytical_statements"][0]["analytical_entries"][0]["entry"], statement_entry1.entry.id, ) self.assertEqual( - response.data['analytical_statements'][0]['analytical_entries'][1]['entry'], + response.data["analytical_statements"][0]["analytical_entries"][1]["entry"], statement_entry2.entry.id, ) @@ -919,9 +805,9 @@ def test_pillar_overview_in_analysis(self): entry2 = self.create_entry(project=project) project.add_member(user) - analysis1 = self.create(Analysis, title='Test Analysis', team_lead=user, project=project) - pillar1 = self.create(AnalysisPillar, analysis=analysis1, title='title1', assignee=user) - pillar2 = self.create(AnalysisPillar, analysis=analysis1, title='title2', assignee=user) + analysis1 = self.create(Analysis, title="Test Analysis", team_lead=user, project=project) + pillar1 = self.create(AnalysisPillar, analysis=analysis1, title="title1", assignee=user) + pillar2 = self.create(AnalysisPillar, analysis=analysis1, title="title2", assignee=user) analytical_statement1 = self.create(AnalyticalStatement, analysis_pillar=pillar1) analytical_statement2 = self.create(AnalyticalStatement, analysis_pillar=pillar1) @@ -932,15 +818,15 @@ def test_pillar_overview_in_analysis(self): analytical_statement3 = self.create(AnalyticalStatement, analysis_pillar=pillar2) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement3) - url = f'/api/v1/projects/{project.id}/analysis/{analysis1.id}/pillars/summary/' + url = f"/api/v1/projects/{project.id}/analysis/{analysis1.id}/pillars/summary/" self.authenticate(user) response = self.client.get(url) self.assert_200(response) - data = response.data['results'] - self.assertEqual(data[0]['title'], pillar2.title) - self.assertEqual(len(data[0]['analytical_statements']), 1) - self.assertEqual(data[0]['analytical_statements'][0]['entries_count'], 1) - self.assertEqual(len(data[1]['analytical_statements']), 2) + data = response.data["results"] + self.assertEqual(data[0]["title"], pillar2.title) + self.assertEqual(len(data[0]["analytical_statements"]), 1) + self.assertEqual(data[0]["analytical_statements"][0]["entries_count"], 1) + self.assertEqual(len(data[1]["analytical_statements"]), 2) # try get pillar-overview by user that is not member of project self.authenticate(user2) @@ -953,17 +839,17 @@ def test_analysis_overview_in_project(self): project = self.create_project() project.add_member(user) - organization_type1 = self.create(OrganizationType, title='OrgA') - organization_type2 = self.create(OrganizationType, title='Orgb') + organization_type1 = self.create(OrganizationType, title="OrgA") + organization_type2 = self.create(OrganizationType, title="Orgb") - organization1 = self.create(Organization, title='UN', organization_type=organization_type1) - organization2 = self.create(Organization, title='RED CROSS', organization_type=organization_type2) - organization3 = self.create(Organization, title='ToggleCorp', organization_type=organization_type1) + organization1 = self.create(Organization, title="UN", organization_type=organization_type1) + organization2 = self.create(Organization, title="RED CROSS", organization_type=organization_type2) + organization3 = self.create(Organization, title="ToggleCorp", organization_type=organization_type1) - lead1 = self.create_lead(authors=[organization1], project=project, title='TESTA') - lead2 = self.create_lead(authors=[organization2, organization3], project=project, title='TESTB') - lead3 = self.create_lead(authors=[organization3], project=project, title='TESTC') - self.create_lead(authors=[organization2], project=project, title='TESTD') + lead1 = self.create_lead(authors=[organization1], project=project, title="TESTA") + lead2 = self.create_lead(authors=[organization2, organization3], project=project, title="TESTB") + lead3 = self.create_lead(authors=[organization3], project=project, title="TESTC") + self.create_lead(authors=[organization2], project=project, title="TESTD") entry1 = self.create_entry(lead=lead1, project=project) entry2 = self.create_entry(lead=lead2, project=project) @@ -971,43 +857,36 @@ def test_analysis_overview_in_project(self): self.create_entry(lead=lead3, project=project) entry4 = self.create_entry(lead=lead2, project=project) - analysis1 = self.create(Analysis, title='Test Analysis', team_lead=user, project=project) - analysis2 = self.create(Analysis, title='Test Analysis New', team_lead=user, project=project) - pillar1 = self.create(AnalysisPillar, analysis=analysis1, title='title1') - pillar2 = self.create(AnalysisPillar, analysis=analysis2, title='title2') + analysis1 = self.create(Analysis, title="Test Analysis", team_lead=user, project=project) + analysis2 = self.create(Analysis, title="Test Analysis New", team_lead=user, project=project) + pillar1 = self.create(AnalysisPillar, analysis=analysis1, title="title1") + pillar2 = self.create(AnalysisPillar, analysis=analysis2, title="title2") analytical_statement1 = self.create(AnalyticalStatement, analysis_pillar=pillar1) analytical_statement2 = self.create(AnalyticalStatement, analysis_pillar=pillar1) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement1, entry=entry1) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement2, entry=entry2) - DiscardedEntry.objects.create( - analysis_pillar=pillar1, - entry=entry4, - tag=DiscardedEntry.TagType.REDUNDANT - ) + DiscardedEntry.objects.create(analysis_pillar=pillar1, entry=entry4, tag=DiscardedEntry.TagType.REDUNDANT) analytical_statement3 = self.create(AnalyticalStatement, analysis_pillar=pillar2) self.create(AnalyticalStatementEntry, analytical_statement=analytical_statement3, entry=entry3) - url = f'/api/v1/projects/{project.id}/analysis-overview/' + url = f"/api/v1/projects/{project.id}/analysis-overview/" self.authenticate(user) response = self.client.get(url) self.assert_200(response) data = response.data - self.assertEqual(len(data['analysis_list']), 2) - self.assertEqual(data['analysis_list'][1]['title'], analysis1.title) - self.assertEqual(data['entries_total'], 5) - self.assertEqual(data['sources_total'], 3) # since we take only that lead which entry has been created - self.assertEqual(data['analyzed_source_count'], 3) # since we take entry - self.assertEqual(data['analyzed_entries_count'], 4) # discarded + analyzed - self.assertEqual(len(data['authoring_organizations']), 2) - self.assertIn(organization_type1.id, [item['organization_type_id'] for item in data['authoring_organizations']]) - self.assertIn( - organization_type1.title, - [item['organization_type_title'] for item in data['authoring_organizations']] - ) - self.assertEqual(set([item['count'] for item in data['authoring_organizations']]), set([1, 3])) + self.assertEqual(len(data["analysis_list"]), 2) + self.assertEqual(data["analysis_list"][1]["title"], analysis1.title) + self.assertEqual(data["entries_total"], 5) + self.assertEqual(data["sources_total"], 3) # since we take only that lead which entry has been created + self.assertEqual(data["analyzed_source_count"], 3) # since we take entry + self.assertEqual(data["analyzed_entries_count"], 4) # discarded + analyzed + self.assertEqual(len(data["authoring_organizations"]), 2) + self.assertIn(organization_type1.id, [item["organization_type_id"] for item in data["authoring_organizations"]]) + self.assertIn(organization_type1.title, [item["organization_type_title"] for item in data["authoring_organizations"]]) + self.assertEqual(set([item["count"] for item in data["authoring_organizations"]]), set([1, 3])) # authenticate with user that is not project member self.authenticate(user2) @@ -1024,18 +903,15 @@ def test_post_discarded_entries_in_analysis_pillar(self): entry = self.create_entry(project=project, lead=lead) analysis = self.create(Analysis, project=project, end_date=now + relativedelta(days=2)) pillar1 = self.create(AnalysisPillar, analysis=analysis) - data = { - 'entry': entry.id, - 'tag': DiscardedEntry.TagType.REDUNDANT - } - url = f'/api/v1/analysis-pillar/{pillar1.id}/discarded-entries/' + data = {"entry": entry.id, "tag": DiscardedEntry.TagType.REDUNDANT} + url = f"/api/v1/analysis-pillar/{pillar1.id}/discarded-entries/" self.authenticate(user) response = self.client.post(url, data) self.assert_201(response) - self.assertEqual(response.data['analysis_pillar'], pillar1.id) - self.assertEqual(response.data['entry'], entry.id) - self.assertIn('entry_details', response.data) - self.assertEqual(response.data['entry_details']['id'], entry.id) + self.assertEqual(response.data["analysis_pillar"], pillar1.id) + self.assertEqual(response.data["entry"], entry.id) + self.assertIn("entry_details", response.data) + self.assertEqual(response.data["entry_details"]["id"], entry.id) # try to authenticate with user that is not project member user2 = self.create_user() @@ -1044,11 +920,8 @@ def test_post_discarded_entries_in_analysis_pillar(self): self.assert_403(response) entry1 = self.create_entry(project=project, lead=lead) - data = { - 'entry': entry1.id, - 'tag': DiscardedEntry.TagType.REDUNDANT - } - url = f'/api/v1/analysis-pillar/{pillar1.id}/discarded-entries/' + data = {"entry": entry1.id, "tag": DiscardedEntry.TagType.REDUNDANT} + url = f"/api/v1/analysis-pillar/{pillar1.id}/discarded-entries/" self.authenticate(user2) response = self.client.post(url, data) self.assert_403(response) @@ -1059,21 +932,18 @@ def test_post_discarded_entries_in_analysis_pillar(self): project2.add_member(user2) entry = self.create_entry(project=project2) data = { - 'entry': entry.id, - 'tag': DiscardedEntry.TagType.REDUNDANT, + "entry": entry.id, + "tag": DiscardedEntry.TagType.REDUNDANT, } - url = f'/api/v1/analysis-pillar/{pillar1.id}/discarded-entries/' + url = f"/api/v1/analysis-pillar/{pillar1.id}/discarded-entries/" self.authenticate(user) response = self.client.post(url, data) self.assert_400(response) # try to post the entry with lead published date greater than the analysis end_date entry2 = self.create_entry(project=project, lead=lead1) - data = { - 'entry': entry2.id, - 'tag': DiscardedEntry.TagType.REDUNDANT - } - url = f'/api/v1/analysis-pillar/{pillar1.id}/discarded-entries/' + data = {"entry": entry2.id, "tag": DiscardedEntry.TagType.REDUNDANT} + url = f"/api/v1/analysis-pillar/{pillar1.id}/discarded-entries/" self.authenticate(user) response = self.client.post(url, data) self.assert_400(response) @@ -1089,11 +959,11 @@ def test_discarded_entries_tag_filter(self): self.create(DiscardedEntry, analysis_pillar=pillar, tag=DiscardedEntry.TagType.TOO_OLD) self.create(DiscardedEntry, analysis_pillar=pillar, tag=DiscardedEntry.TagType.OUTLIER) - url = f'/api/v1/analysis-pillar/{pillar.id}/discarded-entries/?tag={DiscardedEntry.TagType.TOO_OLD.value}' + url = f"/api/v1/analysis-pillar/{pillar.id}/discarded-entries/?tag={DiscardedEntry.TagType.TOO_OLD.value}" self.authenticate(user) response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 2) # Two discarded entries be present + self.assertEqual(len(response.data["results"]), 2) # Two discarded entries be present # filter by member that is not project member user2 = self.create_user() @@ -1110,9 +980,9 @@ def test_all_entries_in_analysis_pillar(self): now = timezone.now() analysis = self.create(Analysis, project=project, end_date=now) pillar = self.create(AnalysisPillar, analysis=analysis) - lead1 = self.create_lead(project=project, title='TESTA', published_on=now + relativedelta(days=2)) - lead2 = self.create_lead(project=project, title='TESTA', published_on=now + relativedelta(days=-4)) - lead3 = self.create_lead(project=project, title='TESTA', published_on=now + relativedelta(days=-2)) + lead1 = self.create_lead(project=project, title="TESTA", published_on=now + relativedelta(days=2)) + lead2 = self.create_lead(project=project, title="TESTA", published_on=now + relativedelta(days=-4)) + lead3 = self.create_lead(project=project, title="TESTA", published_on=now + relativedelta(days=-2)) entry1 = self.create(Entry, project=project, lead=lead2) entry2 = self.create(Entry, project=project, lead=lead2) entry3 = self.create(Entry, project=project, lead=lead3) @@ -1120,61 +990,46 @@ def test_all_entries_in_analysis_pillar(self): self.create(Entry, project=project2, lead=lead3) # Check the entry count - analysis_pillar_entries_url = f'/api/v1/analysis-pillar/{pillar.id}/entries/' + analysis_pillar_entries_url = f"/api/v1/analysis-pillar/{pillar.id}/entries/" self.authenticate(user) response = self.client.post(analysis_pillar_entries_url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 3) # this should list all the entries present + self.assertEqual(len(response.data["results"]), 3) # this should list all the entries present # now try to discard the entry from the discarded entries api - data = { - 'entry': entry1.id, - 'tag': DiscardedEntry.TagType.REDUNDANT - } + data = {"entry": entry1.id, "tag": DiscardedEntry.TagType.REDUNDANT} self.authenticate(user) - response = self.client.post(f'/api/v1/analysis-pillar/{pillar.id}/discarded-entries/', data) + response = self.client.post(f"/api/v1/analysis-pillar/{pillar.id}/discarded-entries/", data) self.assert_201(response) # try checking the entries that are discarded self.authenticate(user) - response = self.post_filter_test(analysis_pillar_entries_url, {'discarded': True}, count=1) - response_id = [res['id'] for res in response.data['results']] + response = self.post_filter_test(analysis_pillar_entries_url, {"discarded": True}, count=1) + response_id = [res["id"] for res in response.data["results"]] self.assertIn(entry1.id, response_id) # try checking the entries that are not discarded self.authenticate(user) - response = self.post_filter_test(analysis_pillar_entries_url, {'discarded': False}, count=2) - response_id = [res['id'] for res in response.data['results']] + response = self.post_filter_test(analysis_pillar_entries_url, {"discarded": False}, count=2) + response_id = [res["id"] for res in response.data["results"]] self.assertNotIn(entry1.id, response_id) # try to exclude some entries self.authenticate(user) - data = { - 'exclude_entries': [entry2.id, entry3.id] - } + data = {"exclude_entries": [entry2.id, entry3.id]} response = self.post_filter_test(analysis_pillar_entries_url, data, count=0) self.assert_200(response) def test_discardedentry_options(self): - url = '/api/v1/discarded-entry-options/' + url = "/api/v1/discarded-entry-options/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual( - response.data[0]['key'], - DiscardedEntry.TagType.REDUNDANT) - self.assertEqual( - response.data[0]['value'], - DiscardedEntry.TagType.REDUNDANT.label - ) - self.assertEqual( - response.data[1]['key'], - DiscardedEntry.TagType.TOO_OLD) - self.assertEqual( - response.data[1]['value'], - DiscardedEntry.TagType.TOO_OLD.label - ) + self.assertEqual(response.data[0]["key"], DiscardedEntry.TagType.REDUNDANT) + self.assertEqual(response.data[0]["value"], DiscardedEntry.TagType.REDUNDANT.label) + self.assertEqual(response.data[1]["key"], DiscardedEntry.TagType.TOO_OLD) + self.assertEqual(response.data[1]["value"], DiscardedEntry.TagType.TOO_OLD.label) def test_add_same_entries_in_multiple_analytical_statements(self): user = self.create_user() @@ -1184,14 +1039,14 @@ def test_add_same_entries_in_multiple_analytical_statements(self): entry2 = self.create_entry(project=project) project.add_member(user) - analysis = self.create(Analysis, title='Test Analysis', team_lead=user, project=project) + analysis = self.create(Analysis, title="Test Analysis", team_lead=user, project=project) data = { - 'main_statement': 'Some main statement', - 'information_gap': 'Some information gap', - 'assignee': user.id, - 'title': 'Some title', - 'analytical_statements': [ + "main_statement": "Some main statement", + "information_gap": "Some information gap", + "assignee": user.id, + "title": "Some title", + "analytical_statements": [ { "statement": "coffee", "order": 1, @@ -1207,7 +1062,7 @@ def test_add_same_entries_in_multiple_analytical_statements(self): "client_id": "2", "entry": entry2.id, }, - ] + ], }, { "statement": "tea", @@ -1224,11 +1079,11 @@ def test_add_same_entries_in_multiple_analytical_statements(self): "client_id": "5", "entry": entry2.id, }, - ] + ], }, - ] + ], } - url = f'/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/' + url = f"/api/v1/projects/{project.id}/analysis/{analysis.id}/pillars/" self.authenticate(user) response = self.client.post(url, data) self.assert_201(response) diff --git a/apps/analysis/tests/test_mutations.py b/apps/analysis/tests/test_mutations.py index da7295854f..39a9e30a57 100644 --- a/apps/analysis/tests/test_mutations.py +++ b/apps/analysis/tests/test_mutations.py @@ -1,33 +1,32 @@ -import os import datetime import json +import os from unittest import mock -from utils.graphene.tests import GraphQLTestCase - -from deepl_integration.handlers import AnalysisAutomaticSummaryHandler -from deepl_integration.serializers import DeeplServerBaseCallbackSerializer -from commons.schema_snapshots import SnapshotQuery -from user.factories import UserFactory -from project.factories import ProjectFactory -from lead.factories import LeadFactory -from entry.factories import EntryFactory -from analysis_framework.factories import AnalysisFrameworkFactory from analysis.factories import ( AnalysisFactory, AnalysisPillarFactory, AnalysisReportFactory, AnalysisReportUploadFactory, ) - from analysis.models import ( + AnalysisReportSnapshot, + AnalyticalStatementGeoTask, + AnalyticalStatementNGram, + AutomaticSummary, TopicModel, TopicModelCluster, - AutomaticSummary, - AnalyticalStatementNGram, - AnalyticalStatementGeoTask, - AnalysisReportSnapshot, ) +from analysis_framework.factories import AnalysisFrameworkFactory +from commons.schema_snapshots import SnapshotQuery +from deepl_integration.handlers import AnalysisAutomaticSummaryHandler +from deepl_integration.serializers import DeeplServerBaseCallbackSerializer +from entry.factories import EntryFactory +from lead.factories import LeadFactory +from project.factories import ProjectFactory +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLTestCase class TestAnalysisNlpMutationSchema(GraphQLTestCase): @@ -41,7 +40,7 @@ class TestAnalysisNlpMutationSchema(GraphQLTestCase): ENABLE_NOW_PATCHER = True - TRIGGER_TOPIC_MODEL = ''' + TRIGGER_TOPIC_MODEL = """ mutation MyMutation ($projectId: ID!, $input: AnalysisTopicModelCreateInputType!) { project(id: $projectId) { triggerAnalysisTopicModel(data: $input) { @@ -61,9 +60,9 @@ class TestAnalysisNlpMutationSchema(GraphQLTestCase): } } } - ''' + """ - QUERY_TOPIC_MODEL = ''' + QUERY_TOPIC_MODEL = """ query MyQuery ($projectId: ID!, $topicModelID: ID!) { project(id: $projectId) { analysisTopicModel(id: $topicModelID) { @@ -78,9 +77,9 @@ class TestAnalysisNlpMutationSchema(GraphQLTestCase): } } } - ''' + """ - TRIGGER_AUTOMATIC_SUMMARY = ''' + TRIGGER_AUTOMATIC_SUMMARY = """ mutation MyMutation ($projectId: ID!, $input: AnalysisAutomaticSummaryCreateInputType!) { project(id: $projectId) { triggerAnalysisAutomaticSummary(data: $input) { @@ -94,9 +93,9 @@ class TestAnalysisNlpMutationSchema(GraphQLTestCase): } } } - ''' + """ - QUERY_AUTOMATIC_SUMMARY = ''' + QUERY_AUTOMATIC_SUMMARY = """ query MyQuery ($projectId: ID!, $summaryID: ID!) { project(id: $projectId) { analysisAutomaticSummary(id: $summaryID) { @@ -106,9 +105,9 @@ class TestAnalysisNlpMutationSchema(GraphQLTestCase): } } } - ''' + """ - TRIGGER_AUTOMATIC_NGRAM = ''' + TRIGGER_AUTOMATIC_NGRAM = """ mutation MyMutation ($projectId: ID!, $input: AnalyticalStatementNGramCreateInputType!) { project(id: $projectId) { triggerAnalysisAutomaticNgram(data: $input) { @@ -133,9 +132,9 @@ class TestAnalysisNlpMutationSchema(GraphQLTestCase): } } } - ''' + """ - QUERY_AUTOMATIC_NGRAM = ''' + QUERY_AUTOMATIC_NGRAM = """ query MyQuery ($projectId: ID!, $ngramID: ID!) { project(id: $projectId) { analysisAutomaticNgram(id: $ngramID) { @@ -156,9 +155,9 @@ class TestAnalysisNlpMutationSchema(GraphQLTestCase): } } } - ''' + """ - TRIGGER_GEOLOCATION = ''' + TRIGGER_GEOLOCATION = """ mutation MyMutation ($projectId: ID!, $input: AnalyticalStatementGeoTaskInputType!) { project(id: $projectId) { triggerAnalysisGeoLocation(data: $input) { @@ -183,9 +182,9 @@ class TestAnalysisNlpMutationSchema(GraphQLTestCase): } } } - ''' + """ - QUERY_GEOLOCATION = ''' + QUERY_GEOLOCATION = """ query MyQuery ($projectId: ID!, $ID: ID!) { project(id: $projectId) { analysisGeoTask(id: $ID) { @@ -206,7 +205,7 @@ class TestAnalysisNlpMutationSchema(GraphQLTestCase): } } } - ''' + """ def setUp(self): super().setUp() @@ -223,8 +222,8 @@ def _check_status(self, obj, status): obj.refresh_from_db() self.assertEqual(obj.status, status) - @mock.patch('deepl_integration.handlers.RequestHelper') - @mock.patch('deepl_integration.handlers.requests') + @mock.patch("deepl_integration.handlers.RequestHelper") + @mock.patch("deepl_integration.handlers.requests") def test_topic_model(self, trigger_results_mock, RequestHelperMock): analysis = AnalysisFactory.create( project=self.project, @@ -259,10 +258,10 @@ def nlp_validator_mock(url, data=None, json=None, **kwargs): # Get payload from file payload = self.get_json_media_file( - json['entries_url'].split('http://testserver/media/')[1], + json["entries_url"].split("http://testserver/media/")[1], ) # TODO: Need to check the Child fields of data and File payload as well - expected_keys = ['data', 'tags'] + expected_keys = ["data", "tags"] if set(payload.keys()) != set(expected_keys): return mock.MagicMock(status_code=400) return mock.MagicMock(status_code=202) @@ -274,33 +273,29 @@ def nlp_fail_mock(*args, **kwargs): def _mutation_check(minput, **kwargs): return self.query_check( - self.TRIGGER_TOPIC_MODEL, - minput=minput, - mnested=['project'], - variables={'projectId': self.project.id}, - **kwargs + self.TRIGGER_TOPIC_MODEL, minput=minput, mnested=["project"], variables={"projectId": self.project.id}, **kwargs ) def _query_check(_id): return self.query_check( self.QUERY_TOPIC_MODEL, minput=minput, - variables={'projectId': self.project.id, 'topicModelID': _id}, + variables={"projectId": self.project.id, "topicModelID": _id}, ) minput = dict( - analysisPillar='0', # Non existing ID + analysisPillar="0", # Non existing ID additionalFilters=dict( filterableData=[ dict( - filterKey='random-key', - value='random-value', + filterKey="random-key", + value="random-value", ) ], ), widgetTags=[ - 'tag1', - 'tag2', + "tag1", + "tag2", ], ) @@ -320,95 +315,83 @@ def _query_check(_id): _mutation_check(minput, okay=False) # Valid data - minput['analysisPillar'] = str(analysis_pillar.id) + minput["analysisPillar"] = str(analysis_pillar.id) # --- member user (All good) with self.captureOnCommitCallbacks(execute=True): response = _mutation_check(minput, okay=True) - a_summary_id = response['data']['project']['triggerAnalysisTopicModel']['result']['id'] - assert _query_check(a_summary_id)['data']['project']['analysisTopicModel']['status'] ==\ - self.genum(TopicModel.Status.STARTED) + a_summary_id = response["data"]["project"]["triggerAnalysisTopicModel"]["result"]["id"] + assert _query_check(a_summary_id)["data"]["project"]["analysisTopicModel"]["status"] == self.genum( + TopicModel.Status.STARTED + ) # -- Bad status code from NLP on trigger request trigger_results_mock.post.side_effect = nlp_fail_mock with self.captureOnCommitCallbacks(execute=True): response = _mutation_check(minput, okay=True) - a_summary_id = response['data']['project']['triggerAnalysisTopicModel']['result']['id'] - assert _query_check(a_summary_id)['data']['project']['analysisTopicModel']['status'] ==\ - self.genum(TopicModel.Status.SEND_FAILED) + a_summary_id = response["data"]["project"]["triggerAnalysisTopicModel"]["result"]["id"] + assert _query_check(a_summary_id)["data"]["project"]["analysisTopicModel"]["status"] == self.genum( + TopicModel.Status.SEND_FAILED + ) topic_model = TopicModel.objects.get(pk=a_summary_id) # Check if generated entries are within the project - assert list(topic_model.get_entries_qs().values_list('id', flat=True)) == [ - entry.id - for entry in lead2_entries - ] + assert list(topic_model.get_entries_qs().values_list("id", flat=True)) == [entry.id for entry in lead2_entries] # -- Callback test (Mocking NLP part) SAMPLE_TOPIC_MODEL_RESPONSE = { - 'cluster_1': { - "entry_id": [ - entry.id - for entry in lead2_entries[:1] - ], - 'label': "Label 1", + "cluster_1": { + "entry_id": [entry.id for entry in lead2_entries[:1]], + "label": "Label 1", }, - 'cluster_2': { - "entry_id": [ - entry.id - for entry in lead2_entries[1:] - ], - 'label': "Label 2" - } + "cluster_2": {"entry_id": [entry.id for entry in lead2_entries[1:]], "label": "Label 2"}, } RequestHelperMock.return_value.json.return_value = SAMPLE_TOPIC_MODEL_RESPONSE - callback_url = '/api/v1/callback/analysis-topic-model/' + callback_url = "/api/v1/callback/analysis-topic-model/" data = { - 'client_id': 'invalid-id', - 'presigned_s3_url': 'https://random-domain.com/random-url.json', - 'status': DeeplServerBaseCallbackSerializer.Status.SUCCESS.value, + "client_id": "invalid-id", + "presigned_s3_url": "https://random-domain.com/random-url.json", + "status": DeeplServerBaseCallbackSerializer.Status.SUCCESS.value, } response = self.client.post(callback_url, data) self.assert_400(response) # With valid client_id - data['client_id'] = AnalysisAutomaticSummaryHandler.get_client_id(topic_model) + data["client_id"] = AnalysisAutomaticSummaryHandler.get_client_id(topic_model) response = self.client.post(callback_url, data) self.assert_200(response) topic_model.refresh_from_db() assert topic_model.status == TopicModel.Status.SUCCESS.value assert [ - { - 'entries_id': list(cluster.entries.order_by('id').values_list('id', flat=True)) - } + {"entries_id": list(cluster.entries.order_by("id").values_list("id", flat=True))} for cluster in TopicModelCluster.objects.filter(topic_model=topic_model) ] == [ - {'entries_id': [entry.id for entry in lead2_entries[:1]]}, - {'entries_id': [entry.id for entry in lead2_entries[1:]]}, + {"entries_id": [entry.id for entry in lead2_entries[:1]]}, + {"entries_id": [entry.id for entry in lead2_entries[1:]]}, ] # -- Check query data after mock callback - response_result = _query_check(a_summary_id)['data']['project']['analysisTopicModel'] - assert response_result['status'] == self.genum(TopicModel.Status.SUCCESS) - assert response_result['clusters'] == [ - {'entries': [dict(id=str(entry.id), excerpt=entry.excerpt) for entry in lead2_entries[:1]]}, - {'entries': [dict(id=str(entry.id), excerpt=entry.excerpt) for entry in lead2_entries[1:]]}, + response_result = _query_check(a_summary_id)["data"]["project"]["analysisTopicModel"] + assert response_result["status"] == self.genum(TopicModel.Status.SUCCESS) + assert response_result["clusters"] == [ + {"entries": [dict(id=str(entry.id), excerpt=entry.excerpt) for entry in lead2_entries[:1]]}, + {"entries": [dict(id=str(entry.id), excerpt=entry.excerpt) for entry in lead2_entries[1:]]}, ] # With failed status - data['status'] = DeeplServerBaseCallbackSerializer.Status.FAILED.value + data["status"] = DeeplServerBaseCallbackSerializer.Status.FAILED.value response = self.client.post(callback_url, data) self.assert_200(response) topic_model.refresh_from_db() assert topic_model.status == TopicModel.Status.FAILED - @mock.patch('deepl_integration.handlers.RequestHelper') - @mock.patch('deepl_integration.handlers.requests') + @mock.patch("deepl_integration.handlers.RequestHelper") + @mock.patch("deepl_integration.handlers.requests") def test_automatic_summary(self, trigger_results_mock, RequestHelperMock): lead1 = LeadFactory.create(project=self.project) lead2 = LeadFactory.create(project=self.project) @@ -423,10 +406,10 @@ def nlp_validator_mock(url, data=None, json=None, **kwargs): # Get payload from file payload = self.get_json_media_file( - json['entries_url'].split('http://testserver/media/')[1], + json["entries_url"].split("http://testserver/media/")[1], ) # TODO: Need to check the Child fields of data and File payload as well - expected_keys = ['data', 'tags'] + expected_keys = ["data", "tags"] if set(payload.keys()) != set(expected_keys): return mock.MagicMock(status_code=400) return mock.MagicMock(status_code=202) @@ -440,16 +423,16 @@ def _mutation_check(minput, **kwargs): return self.query_check( self.TRIGGER_AUTOMATIC_SUMMARY, minput=minput, - mnested=['project'], - variables={'projectId': self.project.id}, - **kwargs + mnested=["project"], + variables={"projectId": self.project.id}, + **kwargs, ) def _query_check(_id): return self.query_check( self.QUERY_AUTOMATIC_SUMMARY, minput=minput, - variables={'projectId': self.project.id, 'summaryID': _id}, + variables={"projectId": self.project.id, "summaryID": _id}, ) minput = dict(entriesId=[]) @@ -469,26 +452,21 @@ def _query_check(_id): self.force_login(self.member_user) _mutation_check(minput, okay=False) - minput['entriesId'] = [ - str(entry.id) - for entries in [ - lead1_entries, - lead2_entries, - another_lead_entries - ] - for entry in entries + minput["entriesId"] = [ + str(entry.id) for entries in [lead1_entries, lead2_entries, another_lead_entries] for entry in entries ] - minput['widgetTags'] = [ - 'tag1', - 'tag2', + minput["widgetTags"] = [ + "tag1", + "tag2", ] # --- member user (All good) with self.captureOnCommitCallbacks(execute=True): response = _mutation_check(minput, okay=True) - a_summary_id = response['data']['project']['triggerAnalysisAutomaticSummary']['result']['id'] - assert _query_check(a_summary_id)['data']['project']['analysisAutomaticSummary']['status'] ==\ - self.genum(AutomaticSummary.Status.STARTED) + a_summary_id = response["data"]["project"]["triggerAnalysisAutomaticSummary"]["result"]["id"] + assert _query_check(a_summary_id)["data"]["project"]["analysisAutomaticSummary"]["status"] == self.genum( + AutomaticSummary.Status.STARTED + ) # Clear out AutomaticSummary.objects.get(pk=a_summary_id).delete() @@ -498,9 +476,10 @@ def _query_check(_id): with self.captureOnCommitCallbacks(execute=True): response = _mutation_check(minput, okay=True) - a_summary_id = response['data']['project']['triggerAnalysisAutomaticSummary']['result']['id'] - assert _query_check(a_summary_id)['data']['project']['analysisAutomaticSummary']['status'] ==\ - self.genum(AutomaticSummary.Status.SEND_FAILED) + a_summary_id = response["data"]["project"]["triggerAnalysisAutomaticSummary"]["result"]["id"] + assert _query_check(a_summary_id)["data"]["project"]["analysisAutomaticSummary"]["status"] == self.genum( + AutomaticSummary.Status.SEND_FAILED + ) a_summary = AutomaticSummary.objects.get(pk=a_summary_id) # Check if generated entries are within the project @@ -514,21 +493,21 @@ def _query_check(_id): ] # -- Callback test (Mocking NLP part) - SAMPLE_SUMMARY_TEXT = 'SAMPLE SUMMARY TEXT' + SAMPLE_SUMMARY_TEXT = "SAMPLE SUMMARY TEXT" RequestHelperMock.return_value.get_text.return_value = SAMPLE_SUMMARY_TEXT - callback_url = '/api/v1/callback/analysis-automatic-summary/' + callback_url = "/api/v1/callback/analysis-automatic-summary/" data = { - 'client_id': 'invalid-id', - 'presigned_s3_url': 'https://random-domain.com/random-url.txt', - 'status': DeeplServerBaseCallbackSerializer.Status.SUCCESS.value, + "client_id": "invalid-id", + "presigned_s3_url": "https://random-domain.com/random-url.txt", + "status": DeeplServerBaseCallbackSerializer.Status.SUCCESS.value, } response = self.client.post(callback_url, data) self.assert_400(response) # With valid client_id - data['client_id'] = AnalysisAutomaticSummaryHandler.get_client_id(a_summary) + data["client_id"] = AnalysisAutomaticSummaryHandler.get_client_id(a_summary) response = self.client.post(callback_url, data) self.assert_200(response) @@ -537,28 +516,27 @@ def _query_check(_id): assert a_summary.summary == SAMPLE_SUMMARY_TEXT # -- Check existing instance if provided until threshold is over - response_result = _mutation_check(minput, okay=True)['data']['project']['triggerAnalysisAutomaticSummary']['result'] - assert response_result['id'] == a_summary_id - assert response_result['summary'] == SAMPLE_SUMMARY_TEXT + response_result = _mutation_check(minput, okay=True)["data"]["project"]["triggerAnalysisAutomaticSummary"]["result"] + assert response_result["id"] == a_summary_id + assert response_result["summary"] == SAMPLE_SUMMARY_TEXT - a_summary.created_at = self.PATCHER_NOW_VALUE -\ - datetime.timedelta(hours=AutomaticSummary.CACHE_THRESHOLD_HOURS + 1) + a_summary.created_at = self.PATCHER_NOW_VALUE - datetime.timedelta(hours=AutomaticSummary.CACHE_THRESHOLD_HOURS + 1) a_summary.save() - response_result = _mutation_check(minput, okay=True)['data']['project']['triggerAnalysisAutomaticSummary']['result'] - assert response_result['id'] != a_summary_id - assert response_result['summary'] != SAMPLE_SUMMARY_TEXT + response_result = _mutation_check(minput, okay=True)["data"]["project"]["triggerAnalysisAutomaticSummary"]["result"] + assert response_result["id"] != a_summary_id + assert response_result["summary"] != SAMPLE_SUMMARY_TEXT # With failed status - data['status'] = DeeplServerBaseCallbackSerializer.Status.FAILED.value + data["status"] = DeeplServerBaseCallbackSerializer.Status.FAILED.value response = self.client.post(callback_url, data) self.assert_200(response) a_summary.refresh_from_db() assert a_summary.status == AutomaticSummary.Status.FAILED - @mock.patch('deepl_integration.handlers.RequestHelper') - @mock.patch('deepl_integration.handlers.requests') + @mock.patch("deepl_integration.handlers.RequestHelper") + @mock.patch("deepl_integration.handlers.requests") def test_automatic_ngram(self, trigger_results_mock, RequestHelperMock): lead1 = LeadFactory.create(project=self.project) lead2 = LeadFactory.create(project=self.project) @@ -573,16 +551,16 @@ def _mutation_check(minput, **kwargs): return self.query_check( self.TRIGGER_AUTOMATIC_NGRAM, minput=minput, - mnested=['project'], - variables={'projectId': self.project.id}, - **kwargs + mnested=["project"], + variables={"projectId": self.project.id}, + **kwargs, ) def _query_check(_id): return self.query_check( self.QUERY_AUTOMATIC_NGRAM, minput=minput, - variables={'projectId': self.project.id, 'ngramID': _id}, + variables={"projectId": self.project.id, "ngramID": _id}, ) minput = dict(entriesId=[]) @@ -602,22 +580,17 @@ def _query_check(_id): self.force_login(self.member_user) _mutation_check(minput, okay=False) - minput['entriesId'] = [ - str(entry.id) - for entries in [ - lead1_entries, - lead2_entries, - another_lead_entries - ] - for entry in entries + minput["entriesId"] = [ + str(entry.id) for entries in [lead1_entries, lead2_entries, another_lead_entries] for entry in entries ] # --- member user (All good) with self.captureOnCommitCallbacks(execute=True): response = _mutation_check(minput, okay=True) - a_ngram_id = response['data']['project']['triggerAnalysisAutomaticNgram']['result']['id'] - assert _query_check(a_ngram_id)['data']['project']['analysisAutomaticNgram']['status'] ==\ - self.genum(AnalyticalStatementNGram.Status.STARTED) + a_ngram_id = response["data"]["project"]["triggerAnalysisAutomaticNgram"]["result"]["id"] + assert _query_check(a_ngram_id)["data"]["project"]["analysisAutomaticNgram"]["status"] == self.genum( + AnalyticalStatementNGram.Status.STARTED + ) # Clear out AnalyticalStatementNGram.objects.get(pk=a_ngram_id).delete() @@ -627,9 +600,10 @@ def _query_check(_id): with self.captureOnCommitCallbacks(execute=True): response = _mutation_check(minput, okay=True) - a_ngram_id = response['data']['project']['triggerAnalysisAutomaticNgram']['result']['id'] - assert _query_check(a_ngram_id)['data']['project']['analysisAutomaticNgram']['status'] ==\ - self.genum(AnalyticalStatementNGram.Status.SEND_FAILED) + a_ngram_id = response["data"]["project"]["triggerAnalysisAutomaticNgram"]["result"]["id"] + assert _query_check(a_ngram_id)["data"]["project"]["analysisAutomaticNgram"]["status"] == self.genum( + AnalyticalStatementNGram.Status.SEND_FAILED + ) a_ngram = AnalyticalStatementNGram.objects.get(pk=a_ngram_id) # Check if generated entries are within the project @@ -644,73 +618,70 @@ def _query_check(_id): # -- Callback test (Mocking NLP part) SAMPLE_NGRAM_RESPONSE = { - 'unigrams': { - 'unigrams-word-1': 1, - 'unigrams-word-2': 1, - 'unigrams-word-5': 3, + "unigrams": { + "unigrams-word-1": 1, + "unigrams-word-2": 1, + "unigrams-word-5": 3, }, - 'bigrams': { - 'bigrams-word-2': 1, - 'bigrams-word-3': 0, - 'bigrams-word-4': 2, + "bigrams": { + "bigrams-word-2": 1, + "bigrams-word-3": 0, + "bigrams-word-4": 2, }, } RequestHelperMock.return_value.json.return_value = SAMPLE_NGRAM_RESPONSE - callback_url = '/api/v1/callback/analysis-automatic-ngram/' + callback_url = "/api/v1/callback/analysis-automatic-ngram/" data = { - 'client_id': 'invalid-id', - 'presigned_s3_url': 'https://random-domain.com/random-url.json', - 'status': DeeplServerBaseCallbackSerializer.Status.SUCCESS.value, + "client_id": "invalid-id", + "presigned_s3_url": "https://random-domain.com/random-url.json", + "status": DeeplServerBaseCallbackSerializer.Status.SUCCESS.value, } response = self.client.post(callback_url, data) self.assert_400(response) # With valid client_id - data['client_id'] = AnalysisAutomaticSummaryHandler.get_client_id(a_ngram) + data["client_id"] = AnalysisAutomaticSummaryHandler.get_client_id(a_ngram) response = self.client.post(callback_url, data) self.assert_200(response) a_ngram.refresh_from_db() assert a_ngram.status == AnalyticalStatementNGram.Status.SUCCESS.value - assert a_ngram.unigrams == SAMPLE_NGRAM_RESPONSE['unigrams'] - assert a_ngram.bigrams == SAMPLE_NGRAM_RESPONSE['bigrams'] + assert a_ngram.unigrams == SAMPLE_NGRAM_RESPONSE["unigrams"] + assert a_ngram.bigrams == SAMPLE_NGRAM_RESPONSE["bigrams"] assert a_ngram.trigrams == {} # -- Check existing instance if provided until threshold is over - response_result = _mutation_check(minput, okay=True)['data']['project']['triggerAnalysisAutomaticNgram']['result'] - assert response_result['id'] == a_ngram_id - assert response_result['unigrams'] == [ - dict(word=word, count=count) - for word, count in SAMPLE_NGRAM_RESPONSE['unigrams'].items() + response_result = _mutation_check(minput, okay=True)["data"]["project"]["triggerAnalysisAutomaticNgram"]["result"] + assert response_result["id"] == a_ngram_id + assert response_result["unigrams"] == [ + dict(word=word, count=count) for word, count in SAMPLE_NGRAM_RESPONSE["unigrams"].items() ] - assert response_result['bigrams'] == [ - dict(word=word, count=count) - for word, count in SAMPLE_NGRAM_RESPONSE['bigrams'].items() + assert response_result["bigrams"] == [ + dict(word=word, count=count) for word, count in SAMPLE_NGRAM_RESPONSE["bigrams"].items() ] - assert response_result['trigrams'] == [] + assert response_result["trigrams"] == [] - a_ngram.created_at = self.PATCHER_NOW_VALUE -\ - datetime.timedelta(hours=AnalyticalStatementNGram.CACHE_THRESHOLD_HOURS + 1) + a_ngram.created_at = self.PATCHER_NOW_VALUE - datetime.timedelta(hours=AnalyticalStatementNGram.CACHE_THRESHOLD_HOURS + 1) a_ngram.save() - response_result = _mutation_check(minput, okay=True)['data']['project']['triggerAnalysisAutomaticNgram']['result'] - assert response_result['id'] != a_ngram_id - assert response_result['unigrams'] == [] - assert response_result['bigrams'] == [] - assert response_result['trigrams'] == [] + response_result = _mutation_check(minput, okay=True)["data"]["project"]["triggerAnalysisAutomaticNgram"]["result"] + assert response_result["id"] != a_ngram_id + assert response_result["unigrams"] == [] + assert response_result["bigrams"] == [] + assert response_result["trigrams"] == [] # With failed status - data['status'] = DeeplServerBaseCallbackSerializer.Status.FAILED.value + data["status"] = DeeplServerBaseCallbackSerializer.Status.FAILED.value response = self.client.post(callback_url, data) self.assert_200(response) a_ngram.refresh_from_db() assert a_ngram.status == AnalyticalStatementNGram.Status.FAILED - @mock.patch('deepl_integration.handlers.RequestHelper') - @mock.patch('deepl_integration.handlers.requests') + @mock.patch("deepl_integration.handlers.RequestHelper") + @mock.patch("deepl_integration.handlers.requests") def test_geo_location(self, trigger_results_mock, RequestHelperMock): lead1 = LeadFactory.create(project=self.project) lead2 = LeadFactory.create(project=self.project) @@ -723,18 +694,14 @@ def test_geo_location(self, trigger_results_mock, RequestHelperMock): def _mutation_check(minput, **kwargs): return self.query_check( - self.TRIGGER_GEOLOCATION, - minput=minput, - mnested=['project'], - variables={'projectId': self.project.id}, - **kwargs + self.TRIGGER_GEOLOCATION, minput=minput, mnested=["project"], variables={"projectId": self.project.id}, **kwargs ) def _query_check(_id): return self.query_check( self.QUERY_GEOLOCATION, minput=minput, - variables={'projectId': self.project.id, 'ID': _id}, + variables={"projectId": self.project.id, "ID": _id}, ) minput = dict(entriesId=[]) @@ -754,22 +721,17 @@ def _query_check(_id): self.force_login(self.member_user) _mutation_check(minput, okay=False) - minput['entriesId'] = [ - str(entry.id) - for entries in [ - lead1_entries, - lead2_entries, - another_lead_entries - ] - for entry in entries + minput["entriesId"] = [ + str(entry.id) for entries in [lead1_entries, lead2_entries, another_lead_entries] for entry in entries ] # --- member user (All good) with self.captureOnCommitCallbacks(execute=True): response = _mutation_check(minput, okay=True) - geo_task_id = response['data']['project']['triggerAnalysisGeoLocation']['result']['id'] - assert _query_check(geo_task_id)['data']['project']['analysisGeoTask']['status'] ==\ - self.genum(AnalyticalStatementGeoTask.Status.STARTED) + geo_task_id = response["data"]["project"]["triggerAnalysisGeoLocation"]["result"]["id"] + assert _query_check(geo_task_id)["data"]["project"]["analysisGeoTask"]["status"] == self.genum( + AnalyticalStatementGeoTask.Status.STARTED + ) # Clear out AnalyticalStatementGeoTask.objects.get(pk=geo_task_id).delete() @@ -779,9 +741,10 @@ def _query_check(_id): with self.captureOnCommitCallbacks(execute=True): response = _mutation_check(minput, okay=True) - geo_task_id = response['data']['project']['triggerAnalysisGeoLocation']['result']['id'] - assert _query_check(geo_task_id)['data']['project']['analysisGeoTask']['status'] ==\ - self.genum(AnalyticalStatementGeoTask.Status.SEND_FAILED) + geo_task_id = response["data"]["project"]["triggerAnalysisGeoLocation"]["result"]["id"] + assert _query_check(geo_task_id)["data"]["project"]["analysisGeoTask"]["status"] == self.genum( + AnalyticalStatementGeoTask.Status.SEND_FAILED + ) geo_task = AnalyticalStatementGeoTask.objects.get(pk=geo_task_id) # Check if generated entries are within the project @@ -798,20 +761,18 @@ def _query_check(_id): CALLBACK_ENTRIES = lead1_entries SAMPLE_GEO_DATA_RESPONSE = [ { - 'entry_id': str(entry.id), - 'locations': [ + "entry_id": str(entry.id), + "locations": [ { - 'entity': 'test', - 'meta': - { - 'latitude': 11, - 'longitude': 11, - 'offset_start': 0, - 'offset_end': 3, - } - + "entity": "test", + "meta": { + "latitude": 11, + "longitude": 11, + "offset_start": 0, + "offset_end": 3, + }, } - ] + ], } for entry in [ *CALLBACK_ENTRIES, @@ -821,56 +782,55 @@ def _query_check(_id): RequestHelperMock.return_value.json.return_value = SAMPLE_GEO_DATA_RESPONSE - callback_url = '/api/v1/callback/analysis-geo/' + callback_url = "/api/v1/callback/analysis-geo/" data = { - 'client_id': 'invalid-id', - 'presigned_s3_url': 'https://random-domain.com/random-url.json', - 'status': DeeplServerBaseCallbackSerializer.Status.SUCCESS.value, + "client_id": "invalid-id", + "presigned_s3_url": "https://random-domain.com/random-url.json", + "status": DeeplServerBaseCallbackSerializer.Status.SUCCESS.value, } response = self.client.post(callback_url, data) self.assert_400(response) # With valid client_id - data['client_id'] = AnalysisAutomaticSummaryHandler.get_client_id(geo_task) + data["client_id"] = AnalysisAutomaticSummaryHandler.get_client_id(geo_task) response = self.client.post(callback_url, data) self.assert_200(response) geo_task.refresh_from_db() assert geo_task.status == AnalyticalStatementGeoTask.Status.SUCCESS.value - assert _query_check(geo_task.id)['data']['project']['analysisGeoTask']['entryGeo'] == [ + assert _query_check(geo_task.id)["data"]["project"]["analysisGeoTask"]["entryGeo"] == [ { - 'data': - [ + "data": [ { - 'entity': 'test', - 'meta': - { - 'latitude': 11, - 'longitude': 11, - 'offsetStart': 0, - 'offsetEnd': 3, - }, + "entity": "test", + "meta": { + "latitude": 11, + "longitude": 11, + "offsetStart": 0, + "offsetEnd": 3, + }, }, ], - 'entryId': str(entry.id), + "entryId": str(entry.id), } for entry in CALLBACK_ENTRIES ] # -- Check existing instance if provided until threshold is over (CACHE check) - response_result = _mutation_check(minput, okay=True)['data']['project']['triggerAnalysisGeoLocation']['result'] - assert response_result['id'] == geo_task_id + response_result = _mutation_check(minput, okay=True)["data"]["project"]["triggerAnalysisGeoLocation"]["result"] + assert response_result["id"] == geo_task_id - geo_task.created_at = self.PATCHER_NOW_VALUE -\ - datetime.timedelta(hours=AnalyticalStatementGeoTask.CACHE_THRESHOLD_HOURS + 1) + geo_task.created_at = self.PATCHER_NOW_VALUE - datetime.timedelta( + hours=AnalyticalStatementGeoTask.CACHE_THRESHOLD_HOURS + 1 + ) geo_task.save() - response_result = _mutation_check(minput, okay=True)['data']['project']['triggerAnalysisGeoLocation']['result'] - assert response_result['id'] != geo_task_id + response_result = _mutation_check(minput, okay=True)["data"]["project"]["triggerAnalysisGeoLocation"]["result"] + assert response_result["id"] != geo_task_id # With failed status - data['status'] = DeeplServerBaseCallbackSerializer.Status.FAILED.value + data["status"] = DeeplServerBaseCallbackSerializer.Status.FAILED.value response = self.client.post(callback_url, data) self.assert_200(response) @@ -883,7 +843,7 @@ class TestAnalysisReportQueryAndMutationSchema(GraphQLTestCase): AnalysisReportUploadFactory, ] - REPORT_SNAPSHOT_FRAGMENT = ''' + REPORT_SNAPSHOT_FRAGMENT = """ fragment AnalysisReportSnapshotResponse on AnalysisReportSnapshotType { id publishedOn @@ -906,11 +866,11 @@ class TestAnalysisReportQueryAndMutationSchema(GraphQLTestCase): } } } - ''' + """ CREATE_REPORT = ( - SnapshotQuery.AnalysisReport.SnapshotFragment + - '''\n + SnapshotQuery.AnalysisReport.SnapshotFragment + + """\n mutation CreateReport($projectId: ID!, $input: AnalysisReportInputType!) { project(id: $projectId) { analysisReportCreate(data: $input) { @@ -922,12 +882,12 @@ class TestAnalysisReportQueryAndMutationSchema(GraphQLTestCase): } } } - ''' + """ ) CREATE_REPORT_SNAPSHOT = ( - REPORT_SNAPSHOT_FRAGMENT + - '''\n + REPORT_SNAPSHOT_FRAGMENT + + """\n mutation CreateReportSnapshot($projectId: ID!, $input: AnalysisReportSnapshotInputType!) { project(id: $projectId) { analysisReportSnapshotCreate(data: $input) { @@ -939,12 +899,12 @@ class TestAnalysisReportQueryAndMutationSchema(GraphQLTestCase): } } } - ''' + """ ) UPDATE_REPORT = ( - SnapshotQuery.AnalysisReport.SnapshotFragment + - '''\n + SnapshotQuery.AnalysisReport.SnapshotFragment + + """\n mutation UpdateReport($projectId: ID!, $reportId: ID!, $input: AnalysisReportInputUpdateType!) { project(id: $projectId) { analysisReportUpdate(id: $reportId, data: $input) { @@ -956,12 +916,12 @@ class TestAnalysisReportQueryAndMutationSchema(GraphQLTestCase): } } } - ''' + """ ) QUERY_REPORT = ( - SnapshotQuery.AnalysisReport.SnapshotFragment + - '''\n + SnapshotQuery.AnalysisReport.SnapshotFragment + + """\n query Report($projectId: ID!, $reportId: ID!) { project(id: $projectId) { analysisReport(id: $reportId) { @@ -969,12 +929,12 @@ class TestAnalysisReportQueryAndMutationSchema(GraphQLTestCase): } } } - ''' + """ ) QUERY_REPORT_SNAPSHOT = ( - REPORT_SNAPSHOT_FRAGMENT + - '''\n + REPORT_SNAPSHOT_FRAGMENT + + """\n query QueryReportSnapshot($projectId: ID!, $snapshotId: ID!) { project(id: $projectId) { analysisReportSnapshot(id: $snapshotId) { @@ -982,18 +942,18 @@ class TestAnalysisReportQueryAndMutationSchema(GraphQLTestCase): } } } - ''' + """ ) QUERY_PUBLIC_REPORT_SNAPSHOT = ( - REPORT_SNAPSHOT_FRAGMENT + - '''\n + REPORT_SNAPSHOT_FRAGMENT + + """\n query QueryPublicReportSnapshot($slug: String!) { publicAnalysisReportSnapshot(slug: $slug) { ...AnalysisReportSnapshotResponse } } - ''' + """ ) def setUp(self): @@ -1015,38 +975,29 @@ def test_mutation_and_query(self): def _create_mutation_check(minput, **kwargs): return self.query_check( - self.CREATE_REPORT, - minput=minput, - mnested=['project'], - variables={'projectId': self.project.id}, - **kwargs + self.CREATE_REPORT, minput=minput, mnested=["project"], variables={"projectId": self.project.id}, **kwargs ) def _create_snapshot_mutation_check(minput, **kwargs): return self.query_check( self.CREATE_REPORT_SNAPSHOT, minput=minput, - mnested=['project'], - variables={'projectId': self.project.id}, - **kwargs + mnested=["project"], + variables={"projectId": self.project.id}, + **kwargs, ) def _query_snapshot_check(snapshot_id, **kwargs): return self.query_check( self.QUERY_REPORT_SNAPSHOT, - variables={ - 'projectId': self.project.id, - 'snapshotId': snapshot_id - }, + variables={"projectId": self.project.id, "snapshotId": snapshot_id}, **kwargs, ) def _query_public_snapshot_check(slug, **kwargs): return self.query_check( self.QUERY_PUBLIC_REPORT_SNAPSHOT, - variables={ - 'slug': slug - }, + variables={"slug": slug}, **kwargs, ) @@ -1054,42 +1005,40 @@ def _update_mutation_check(_id, minput, **kwargs): return self.query_check( self.UPDATE_REPORT, minput=minput, - mnested=['project'], + mnested=["project"], variables={ - 'projectId': self.project.id, - 'reportId': _id, + "projectId": self.project.id, + "reportId": _id, }, - **kwargs + **kwargs, ) def _query_check(_id, **kwargs): return self.query_check( self.QUERY_REPORT, - variables={ - 'projectId': self.project.id, - 'reportId': _id - }, + variables={"projectId": self.project.id, "reportId": _id}, **kwargs, ) test_data_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), - 'analysis_report', + "analysis_report", ) - with\ - open(os.path.join(test_data_dir, 'data.json'), 'r') as test_data_file, \ - open(os.path.join(test_data_dir, 'error1.json'), 'r') as test_error1_file, \ - open(os.path.join(test_data_dir, 'error2.json'), 'r') as test_error2_file: + with ( + open(os.path.join(test_data_dir, "data.json"), "r") as test_data_file, + open(os.path.join(test_data_dir, "error1.json"), "r") as test_error1_file, + open(os.path.join(test_data_dir, "error2.json"), "r") as test_error2_file, + ): test_data = json.load(test_data_file) error_1_data = json.load(test_error1_file) error_2_data = json.load(test_error2_file) minput = { - 'isPublic': False, - 'analysis': str(analysis.pk), - 'slug': 'analysis-test-1001', - 'title': 'Test 2', - 'subTitle': 'Test 2', + "isPublic": False, + "analysis": str(analysis.pk), + "slug": "analysis-test-1001", + "title": "Test 2", + "subTitle": "Test 2", **test_data, } @@ -1107,38 +1056,48 @@ def _query_check(_id, **kwargs): # --- member user (All good) self.force_login(self.member_user) response = _create_mutation_check(minput, okay=True) - created_report1_data = response['data']['project']['analysisReportCreate']['result'] - report1_id = created_report1_data['id'] + created_report1_data = response["data"]["project"]["analysisReportCreate"]["result"] + report1_id = created_report1_data["id"] report1_upload1, report1_upload2 = AnalysisReportUploadFactory.create_batch(2, report_id=report1_id) - minput['containers'][0]['contentData'] = [{ - 'clientReferenceId': 'upload-1-id', - 'upload': str(report1_upload1.pk), - }] + minput["containers"][0]["contentData"] = [ + { + "clientReferenceId": "upload-1-id", + "upload": str(report1_upload1.pk), + } + ] # -- Validation check errors = _create_mutation_check( minput, okay=False, - )['data']['project']['analysisReportCreate']['errors'] + )[ + "data" + ]["project"][ + "analysisReportCreate" + ]["errors"] assert errors == error_1_data del errors - minput['containers'][0]['contentData'] = [] - minput['slug'] = 'analysis-test-1002' + minput["containers"][0]["contentData"] = [] + minput["slug"] = "analysis-test-1002" created_report2_data = _create_mutation_check( minput, okay=True, - )['data']['project']['analysisReportCreate']['result'] - report2_id = created_report2_data['id'] + )[ + "data" + ]["project"][ + "analysisReportCreate" + ]["result"] + report2_id = created_report2_data["id"] # Update # -- -- Report 1 minput = { **created_report1_data, } - minput.pop('id') + minput.pop("id") # -- Without login self.logout() _update_mutation_check(report1_id, minput, assert_for_error=True) @@ -1153,34 +1112,42 @@ def _query_check(_id, **kwargs): # --- member user (error since input is empty) self.force_login(self.member_user) response = _update_mutation_check(report1_id, minput, okay=True) - updated_report_data = response['data']['project']['analysisReportUpdate']['result'] + updated_report_data = response["data"]["project"]["analysisReportUpdate"]["result"] assert updated_report_data == created_report1_data del updated_report_data # -- -- Report 2 minput = { **created_report2_data, } - minput.pop('id') + minput.pop("id") # Invalid data - minput['containers'][0]['contentData'] = [{ - 'clientReferenceId': 'upload-2-id', - 'upload': str(report1_upload2.pk), - }] + minput["containers"][0]["contentData"] = [ + { + "clientReferenceId": "upload-2-id", + "upload": str(report1_upload2.pk), + } + ] errors = _update_mutation_check( report2_id, minput, okay=False, - )['data']['project']['analysisReportUpdate']['errors'] + )[ + "data" + ]["project"][ + "analysisReportUpdate" + ]["errors"] assert errors == error_2_data report2_upload1 = AnalysisReportUploadFactory.create(report_id=report2_id) - minput['containers'][0]['contentData'] = [{ - 'clientReferenceId': 'upload-1-id', - 'upload': str(report2_upload1.pk), - }] + minput["containers"][0]["contentData"] = [ + { + "clientReferenceId": "upload-1-id", + "upload": str(report2_upload1.pk), + } + ] response = _update_mutation_check(report2_id, minput, okay=True) - updated_report_data = response['data']['project']['analysisReportUpdate']['result'] + updated_report_data = response["data"]["project"]["analysisReportUpdate"]["result"] assert updated_report_data != created_report2_data # Basic query check @@ -1190,19 +1157,16 @@ def _query_check(_id, **kwargs): _query_check(report2_id, assert_for_error=True) # -- With login (non-member) self.force_login(self.non_member_user) - assert _query_check(report1_id)['data']['project']['analysisReport'] is None - assert _query_check(report2_id)['data']['project']['analysisReport'] is None + assert _query_check(report1_id)["data"]["project"]["analysisReport"] is None + assert _query_check(report2_id)["data"]["project"]["analysisReport"] is None # --- member user - for user in [ - self.readonly_member_user, - self.member_user - ]: + for user in [self.readonly_member_user, self.member_user]: self.force_login(user) - assert _query_check(report1_id)['data']['project']['analysisReport'] is not None - assert _query_check(report2_id)['data']['project']['analysisReport'] is not None + assert _query_check(report1_id)["data"]["project"]["analysisReport"] is not None + assert _query_check(report2_id)["data"]["project"]["analysisReport"] is not None # Snapshot Mutation - minput = {'report': str(report1_id)} + minput = {"report": str(report1_id)} self.logout() _create_snapshot_mutation_check(minput, assert_for_error=True) # -- With login (non-member) @@ -1217,10 +1181,14 @@ def _query_check(_id, **kwargs): snapshot_data = _create_snapshot_mutation_check( minput, okay=True, - )['data']['project']['analysisReportSnapshotCreate']['result'] - snapshot_id = snapshot_data['id'] - assert snapshot_data['report'] == minput['report'] - assert snapshot_data['reportDataFile']['url'] not in ['', None] + )[ + "data" + ]["project"][ + "analysisReportSnapshotCreate" + ]["result"] + snapshot_id = snapshot_data["id"] + assert snapshot_data["report"] == minput["report"] + assert snapshot_data["reportDataFile"]["url"] not in ["", None] another_report = AnalysisReportFactory.create( analysis=AnalysisFactory.create( @@ -1229,7 +1197,7 @@ def _query_check(_id, **kwargs): end_date=datetime.date(2022, 4, 1), ) ) - minput = {'report': str(another_report.pk)} + minput = {"report": str(another_report.pk)} _create_snapshot_mutation_check(minput, okay=False) # Snapshot Query @@ -1237,19 +1205,34 @@ def _query_check(_id, **kwargs): _query_snapshot_check(snapshot_id, assert_for_error=True) # -- With login (non-member) self.force_login(self.non_member_user) - assert _query_snapshot_check( - snapshot_id, - )['data']['project']['analysisReportSnapshot'] is None + assert ( + _query_snapshot_check( + snapshot_id, + )["data"][ + "project" + ]["analysisReportSnapshot"] + is None + ) # --- member user (read-only) self.force_login(self.readonly_member_user) - assert _query_snapshot_check( - snapshot_id, - )['data']['project']['analysisReportSnapshot'] is not None + assert ( + _query_snapshot_check( + snapshot_id, + )["data"][ + "project" + ]["analysisReportSnapshot"] + is not None + ) # --- member user self.force_login(self.member_user) - assert _query_snapshot_check( - snapshot_id, - )['data']['project']['analysisReportSnapshot'] is not None + assert ( + _query_snapshot_check( + snapshot_id, + )["data"][ + "project" + ]["analysisReportSnapshot"] + is not None + ) # Snapshot Public Query snapshot = AnalysisReportSnapshot.objects.get(pk=snapshot_id) @@ -1268,7 +1251,7 @@ def _query_check(_id, **kwargs): self.logout() else: self.force_login(user) - assert _query_public_snapshot_check(snapshot_slug)['data']['publicAnalysisReportSnapshot'] is None + assert _query_public_snapshot_check(snapshot_slug)["data"]["publicAnalysisReportSnapshot"] is None # -- Public [Not enabled in project] snapshot.report.is_public = True @@ -1283,10 +1266,10 @@ def _query_check(_id, **kwargs): self.logout() else: self.force_login(user) - assert _query_public_snapshot_check(snapshot_slug)['data']['publicAnalysisReportSnapshot'] is None + assert _query_public_snapshot_check(snapshot_slug)["data"]["publicAnalysisReportSnapshot"] is None self.project.enable_publicly_viewable_analysis_report_snapshot = True - self.project.save(update_fields=('enable_publicly_viewable_analysis_report_snapshot',)) + self.project.save(update_fields=("enable_publicly_viewable_analysis_report_snapshot",)) # -- Not Public [Enabled in project] snapshot.report.is_public = False snapshot.report.save() @@ -1300,7 +1283,7 @@ def _query_check(_id, **kwargs): self.logout() else: self.force_login(user) - assert _query_public_snapshot_check(snapshot_slug)['data']['publicAnalysisReportSnapshot'] is None + assert _query_public_snapshot_check(snapshot_slug)["data"]["publicAnalysisReportSnapshot"] is None # -- Public [Enabled in project] snapshot.report.is_public = True @@ -1315,4 +1298,4 @@ def _query_check(_id, **kwargs): self.logout() else: self.force_login(user) - assert _query_public_snapshot_check(snapshot_slug)['data']['publicAnalysisReportSnapshot'] is not None + assert _query_public_snapshot_check(snapshot_slug)["data"]["publicAnalysisReportSnapshot"] is not None diff --git a/apps/analysis/tests/test_schemas.py b/apps/analysis/tests/test_schemas.py index 86c9435545..764d903c5a 100644 --- a/apps/analysis/tests/test_schemas.py +++ b/apps/analysis/tests/test_schemas.py @@ -1,20 +1,19 @@ import datetime -from utils.graphene.tests import GraphQLTestCase - -from user.factories import UserFactory -from project.factories import ProjectFactory - from analysis.factories import AnalysisFactory, AnalysisPillarFactory from analysis_framework.factories import AnalysisFrameworkFactory -from lead.factories import LeadFactory from entry.factories import EntryFactory +from lead.factories import LeadFactory +from project.factories import ProjectFactory +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLTestCase class TestAnalysisQuerySchema(GraphQLTestCase): def test_analyses_and_analysis_pillars_query(self): # Permission checks - query = ''' + query = """ query MyQuery ($projectId: ID!) { project(id: $projectId) { analyses { @@ -33,7 +32,7 @@ def test_analyses_and_analysis_pillars_query(self): } } } - ''' + """ member_user = UserFactory.create() non_member_user = UserFactory.create() @@ -44,7 +43,7 @@ def test_analyses_and_analysis_pillars_query(self): AnalysisPillarFactory.create_batch(5, analysis=analysis, assignee=member_user) def _query_check(**kwargs): - return self.query_check(query, variables={'projectId': project.id}, **kwargs) + return self.query_check(query, variables={"projectId": project.id}, **kwargs) # -- Without login _query_check(assert_for_error=True) @@ -52,21 +51,21 @@ def _query_check(**kwargs): # --- With login self.force_login(non_member_user) content = _query_check() - self.assertEqual(content['data']['project']['analyses']['totalCount'], 0, content) - self.assertEqual(len(content['data']['project']['analyses']['results']), 0, content) - self.assertEqual(content['data']['project']['analysisPillars']['totalCount'], 0, content) - self.assertEqual(len(content['data']['project']['analysisPillars']['results']), 0, content) + self.assertEqual(content["data"]["project"]["analyses"]["totalCount"], 0, content) + self.assertEqual(len(content["data"]["project"]["analyses"]["results"]), 0, content) + self.assertEqual(content["data"]["project"]["analysisPillars"]["totalCount"], 0, content) + self.assertEqual(len(content["data"]["project"]["analysisPillars"]["results"]), 0, content) self.force_login(member_user) content = _query_check() - self.assertEqual(content['data']['project']['analyses']['totalCount'], 2, content) - self.assertEqual(len(content['data']['project']['analyses']['results']), 2, content) - self.assertEqual(content['data']['project']['analysisPillars']['totalCount'], 10, content) - self.assertEqual(len(content['data']['project']['analysisPillars']['results']), 10, content) + self.assertEqual(content["data"]["project"]["analyses"]["totalCount"], 2, content) + self.assertEqual(len(content["data"]["project"]["analyses"]["results"]), 2, content) + self.assertEqual(content["data"]["project"]["analysisPillars"]["totalCount"], 10, content) + self.assertEqual(len(content["data"]["project"]["analysisPillars"]["results"]), 10, content) def test_analysis_and_analysis_pillar_query(self): # Permission checks - query = ''' + query = """ query MyQuery ($projectId: ID!, $analysisId: ID!, $analysisPillarId: ID!) { project(id: $projectId) { analysis (id: $analysisId) { @@ -79,7 +78,7 @@ def test_analysis_and_analysis_pillar_query(self): } } } - ''' + """ member_user = UserFactory.create() non_member_user = UserFactory.create() @@ -92,11 +91,12 @@ def _query_check(**kwargs): return self.query_check( query, variables={ - 'projectId': project.id, - 'analysisId': analysis.id, - 'analysisPillarId': analysis_pillar.id, + "projectId": project.id, + "analysisId": analysis.id, + "analysisPillarId": analysis_pillar.id, }, - **kwargs) + **kwargs, + ) # -- Without login _query_check(assert_for_error=True) @@ -104,16 +104,16 @@ def _query_check(**kwargs): # --- With login self.force_login(non_member_user) content = _query_check() - self.assertEqual(content['data']['project']['analysis'], None, content) - self.assertEqual(content['data']['project']['analysisPillar'], None, content) + self.assertEqual(content["data"]["project"]["analysis"], None, content) + self.assertEqual(content["data"]["project"]["analysisPillar"], None, content) self.force_login(member_user) content = _query_check() - self.assertNotEqual(content['data']['project']['analysis'], None, content) - self.assertNotEqual(content['data']['project']['analysisPillar'], None, content) + self.assertNotEqual(content["data"]["project"]["analysis"], None, content) + self.assertNotEqual(content["data"]["project"]["analysisPillar"], None, content) def test_analysis_pillars_entries_query(self): - query = ''' + query = """ query MyQuery ($projectId: ID!, $analysisPillarId: ID!) { project(id: $projectId) { analysisPillar (id: $analysisPillarId) { @@ -128,7 +128,7 @@ def test_analysis_pillars_entries_query(self): } } } - ''' + """ now = datetime.datetime.now() member_user = UserFactory.create() @@ -143,7 +143,7 @@ def test_analysis_pillars_entries_query(self): def _query_check(**kwargs): return self.query_check( query, - variables={'projectId': project.id, 'analysisPillarId': analysis_pillar.pk}, + variables={"projectId": project.id, "analysisPillarId": analysis_pillar.pk}, **kwargs, ) @@ -153,12 +153,12 @@ def _query_check(**kwargs): # --- With login self.force_login(non_member_user) content = _query_check() - self.assertEqual(content['data']['project']['analysisPillar'], None, content) + self.assertEqual(content["data"]["project"]["analysisPillar"], None, content) self.force_login(member_user) content = _query_check() - self.assertEqual(content['data']['project']['analysisPillar']['entries']['totalCount'], 0, content) - self.assertEqual(len(content['data']['project']['analysisPillar']['entries']['results']), 0, content) + self.assertEqual(content["data"]["project"]["analysisPillar"]["entries"]["totalCount"], 0, content) + self.assertEqual(len(content["data"]["project"]["analysisPillar"]["entries"]["results"]), 0, content) # Let's add some entries lead_published_on = now - datetime.timedelta(days=1) # To fit within analysis end_date @@ -166,5 +166,5 @@ def _query_check(**kwargs): EntryFactory.create_batch(8, lead=LeadFactory.create(project=another_project, published_on=lead_published_on)) content = _query_check() - self.assertEqual(content['data']['project']['analysisPillar']['entries']['totalCount'], 10, content) - self.assertEqual(len(content['data']['project']['analysisPillar']['entries']['results']), 10, content) + self.assertEqual(content["data"]["project"]["analysisPillar"]["entries"]["totalCount"], 10, content) + self.assertEqual(len(content["data"]["project"]["analysisPillar"]["entries"]["results"]), 10, content) diff --git a/apps/analysis/views.py b/apps/analysis/views.py index e2f974869c..1113eedb06 100644 --- a/apps/analysis/views.py +++ b/apps/analysis/views.py @@ -1,35 +1,20 @@ from django.shortcuts import get_object_or_404 - +from entry.views import EntryFilterView +from rest_framework import permissions, response, status, views, viewsets from rest_framework.decorators import action -from rest_framework import ( - permissions, - views, - response, - viewsets, - status -) from deep.permissions import IsProjectMember, ModifyPermission -from entry.views import EntryFilterView -from .models import ( - Analysis, - AnalysisPillar, - AnalyticalStatement, - DiscardedEntry, -) +from .filter_set import AnalysisFilterSet, DiscardedEntryFilterSet +from .models import Analysis, AnalysisPillar, AnalyticalStatement, DiscardedEntry from .serializers import ( - AnalysisSerializer, + AnalysisCloneInputSerializer, AnalysisPillarSerializer, - AnalyticalStatementSerializer, - AnalysisSummarySerializer, AnalysisPillarSummarySerializer, + AnalysisSerializer, + AnalysisSummarySerializer, + AnalyticalStatementSerializer, DiscardedEntrySerializer, - AnalysisCloneInputSerializer -) -from .filter_set import ( - AnalysisFilterSet, - DiscardedEntryFilterSet, ) @@ -39,52 +24,42 @@ class AnalysisViewSet(viewsets.ModelViewSet): filterset_class = AnalysisFilterSet def get_queryset(self): - return Analysis.objects.filter(project=self.kwargs['project_id']).select_related( - 'project', - 'team_lead', + return Analysis.objects.filter(project=self.kwargs["project_id"]).select_related( + "project", + "team_lead", ) - @action( - detail=False, - url_path='summary' - ) + @action(detail=False, url_path="summary") def get_summary(self, request, project_id, pk=None, version=None): queryset = self.filter_queryset(self.get_queryset()) queryset = Analysis.annotate_for_analysis_summary(project_id, queryset, self.request.user) page = self.paginate_queryset(queryset) # NOTE: Calculating here and passing as context since we can't calculate union in subquery in Django for now context = { - 'analyzed_sources': Analysis.get_analyzed_sources(page), - 'analyzed_entries': Analysis.get_analyzed_entries(page) + "analyzed_sources": Analysis.get_analyzed_sources(page), + "analyzed_entries": Analysis.get_analyzed_entries(page), } serializer = AnalysisSummarySerializer(page, many=True, context=context, partial=True) return self.get_paginated_response(serializer.data) - @action( - detail=True, - url_path='clone', - methods=['post'] - ) + @action(detail=True, url_path="clone", methods=["post"]) def clone_analysis(self, request, project_id, pk=None, version=None): analysis = self.get_object() input_serializer = AnalysisCloneInputSerializer(data=request.data) if input_serializer.is_valid(): - title = input_serializer.validated_data['title'] - end_date = input_serializer.validated_data['end_date'] + title = input_serializer.validated_data["title"] + end_date = input_serializer.validated_data["end_date"] new_analysis = analysis.clone_analysis(title, end_date) serializer = AnalysisSerializer( new_analysis, - context={'request': request}, + context={"request": request}, ) return response.Response( serializer.data, status=status.HTTP_201_CREATED, ) else: - return response.Response( - input_serializer.errors, - status=status.HTTP_400_BAD_REQUEST - ) + return response.Response(input_serializer.errors, status=status.HTTP_400_BAD_REQUEST) class AnalysisPillarViewSet(viewsets.ModelViewSet): @@ -92,15 +67,14 @@ class AnalysisPillarViewSet(viewsets.ModelViewSet): permission_classes = [permissions.IsAuthenticated, IsProjectMember, ModifyPermission] def get_queryset(self): - return AnalysisPillar.objects\ - .filter( - analysis=self.kwargs['analysis_id'], - analysis__project=self.kwargs['project_id'], - ).select_related('analysis', 'assignee', 'assignee__profile') + return AnalysisPillar.objects.filter( + analysis=self.kwargs["analysis_id"], + analysis__project=self.kwargs["project_id"], + ).select_related("analysis", "assignee", "assignee__profile") @action( detail=False, - url_path='summary', + url_path="summary", ) def get_summary(self, request, **kwargs): queryset = self.filter_queryset(self.get_queryset()) @@ -116,12 +90,12 @@ class AnalysisPillarDiscardedEntryViewSet(viewsets.ModelViewSet): filterset_class = DiscardedEntryFilterSet def get_queryset(self): - return DiscardedEntry.objects.filter(analysis_pillar=self.kwargs['analysis_pillar_id']) + return DiscardedEntry.objects.filter(analysis_pillar=self.kwargs["analysis_pillar_id"]) def get_serializer_context(self): return { **super().get_serializer_context(), - 'analysis_pillar_id': self.kwargs.get('analysis_pillar_id'), + "analysis_pillar_id": self.kwargs.get("analysis_pillar_id"), } @@ -131,18 +105,17 @@ class AnalysisPillarEntryViewSet(EntryFilterView): def get_queryset(self): queryset = super().get_queryset() filters = self.get_entries_filters() - analysis_pillar_id = self.kwargs['analysis_pillar_id'] - analysis_pillar = get_object_or_404(AnalysisPillar, id=self.kwargs['analysis_pillar_id']) + analysis_pillar_id = self.kwargs["analysis_pillar_id"] + analysis_pillar = get_object_or_404(AnalysisPillar, id=self.kwargs["analysis_pillar_id"]) # filtering out the entries whose lead published_on date is less than analysis end_date queryset = queryset.filter( - project=analysis_pillar.analysis.project, - lead__published_on__lte=analysis_pillar.analysis.end_date + project=analysis_pillar.analysis.project, lead__published_on__lte=analysis_pillar.analysis.end_date ) - discarded_entries_qs = DiscardedEntry.objects.filter(analysis_pillar=analysis_pillar_id).values('entry') - if filters.get('discarded'): + discarded_entries_qs = DiscardedEntry.objects.filter(analysis_pillar=analysis_pillar_id).values("entry") + if filters.get("discarded"): return queryset.filter(id__in=discarded_entries_qs) queryset = queryset.exclude(id__in=discarded_entries_qs) - exclude_entries = filters.get('exclude_entries') + exclude_entries = filters.get("exclude_entries") if exclude_entries: queryset = queryset.exclude(id__in=exclude_entries) return queryset @@ -153,11 +126,15 @@ class AnalyticalStatementViewSet(viewsets.ModelViewSet): permissions_classes = [permissions.IsAuthenticated, IsProjectMember, ModifyPermission] def get_queryset(self): - return AnalyticalStatement.objects.filter(analysis_pillar=self.kwargs['analysis_pillar_id']).select_related( - 'analysis_pillar', - ).prefetch_related( - 'entries', - 'analyticalstatemententry_set', + return ( + AnalyticalStatement.objects.filter(analysis_pillar=self.kwargs["analysis_pillar_id"]) + .select_related( + "analysis_pillar", + ) + .prefetch_related( + "entries", + "analyticalstatemententry_set", + ) ) @@ -167,8 +144,9 @@ class DiscardedEntryOptionsView(views.APIView): def get(self, request, version=None): options = [ { - 'key': tag.value, - 'value': tag.label, - } for tag in DiscardedEntry.TagType + "key": tag.value, + "value": tag.label, + } + for tag in DiscardedEntry.TagType ] return response.Response(options) diff --git a/apps/analysis_framework/admin.py b/apps/analysis_framework/admin.py index 9c50badab6..96dfb68fd4 100644 --- a/apps/analysis_framework/admin.py +++ b/apps/analysis_framework/admin.py @@ -1,32 +1,28 @@ from django.contrib import admin -from deep.admin import linkify -from questionnaire.models import ( - FrameworkQuestion, -) +from questionnaire.models import FrameworkQuestion -from deep.admin import ( - VersionAdmin, - StackedInline, - query_buttons, - ModelAdmin as JFModelAdmin, -) +from deep.admin import ModelAdmin as JFModelAdmin +from deep.admin import StackedInline, VersionAdmin, linkify, query_buttons from .models import ( AnalysisFramework, - AnalysisFrameworkTag, - AnalysisFrameworkRole, AnalysisFrameworkMembership, + AnalysisFrameworkRole, + AnalysisFrameworkTag, + Exportable, + Filter, Section, Widget, - Filter, - Exportable, ) class AnalysisFrameworkMemebershipInline(admin.TabularInline): model = AnalysisFrameworkMembership extra = 0 - autocomplete_fields = ('added_by', 'member',) + autocomplete_fields = ( + "added_by", + "member", + ) class WidgetInline(StackedInline): @@ -50,14 +46,18 @@ class SectionInline(StackedInline): class AFRelatedAdmin(JFModelAdmin): - search_fields = ('analysis_framework__title', 'title',) + search_fields = ( + "analysis_framework__title", + "title", + ) list_display = ( - '__str__', linkify('analysis_framework'), + "__str__", + linkify("analysis_framework"), ) - autocomplete_fields = ('analysis_framework',) + autocomplete_fields = ("analysis_framework",) def get_queryset(self, request): - return super().get_queryset(request).prefetch_related('analysis_framework') + return super().get_queryset(request).prefetch_related("analysis_framework") def has_add_permission(self, request, obj=None): return False @@ -69,27 +69,33 @@ def has_add_permission(self, request, obj=None): @admin.register(AnalysisFramework) class AnalysisFrameworkAdmin(VersionAdmin): - readonly_fields = ['is_private'] + readonly_fields = ["is_private"] inlines = [AnalysisFrameworkMemebershipInline, SectionInline, WidgetInline] - search_fields = ('title',) - list_filter = ('is_private', 'assisted_tagging_enabled',) + search_fields = ("title",) + list_filter = ( + "is_private", + "assisted_tagging_enabled", + ) custom_inlines = [ - ('filter', FilterInline), - ('exportable', ExportableInline), - ('framework_question', FrameworkQuestionInline), + ("filter", FilterInline), + ("exportable", ExportableInline), + ("framework_question", FrameworkQuestionInline), ] list_display = [ - 'title', # 'project_count', - 'created_at', - 'created_by', - query_buttons('View', [inline[0] for inline in custom_inlines]), + "title", # 'project_count', + "created_at", + "created_by", + query_buttons("View", [inline[0] for inline in custom_inlines]), ] - autocomplete_fields = ('created_by', 'modified_by',) + autocomplete_fields = ( + "created_by", + "modified_by", + ) def get_inline_instances(self, request, obj=None): inlines = super().get_inline_instances(request, obj) for name, inline in self.custom_inlines: - if request.GET.get(f'show_{name}', 'False').lower() == 'true': + if request.GET.get(f"show_{name}", "False").lower() == "true": inlines.append(inline(self.model, self.admin_site)) return inlines @@ -100,7 +106,7 @@ def get_formsets_with_inlines(self, request, obj=None): widget_queryset = Widget.objects.filter(analysis_framework=obj) for inline in self.get_inline_instances(request, obj): formset = inline.get_formset(request, obj) - for field in ['widget', 'parent_widget', 'conditional_parent_widget']: + for field in ["widget", "parent_widget", "conditional_parent_widget"]: if field not in formset.form.base_fields: continue formset.form.base_fields[field].queryset = widget_queryset @@ -109,8 +115,8 @@ def get_formsets_with_inlines(self, request, obj=None): @admin.register(AnalysisFrameworkRole) class AnalysisFrameworkRoleAdmin(admin.ModelAdmin): - list_display = ('id', 'title', 'type', 'is_default_role') - readonly_fields = ['is_private_role'] + list_display = ("id", "title", "type", "is_default_role") + readonly_fields = ["is_private_role"] def has_add_permission(self, request, obj=None): return False @@ -118,4 +124,7 @@ def has_add_permission(self, request, obj=None): @admin.register(AnalysisFrameworkTag) class AnalysisFrameworkTagAdmin(admin.ModelAdmin): - list_display = ('id', 'title',) + list_display = ( + "id", + "title", + ) diff --git a/apps/analysis_framework/apps.py b/apps/analysis_framework/apps.py index 9b0e2d59fd..026b8aee16 100644 --- a/apps/analysis_framework/apps.py +++ b/apps/analysis_framework/apps.py @@ -2,4 +2,4 @@ class AnalysisFrameworkConfig(AppConfig): - name = 'analysis_framework' + name = "analysis_framework" diff --git a/apps/analysis_framework/dataloaders.py b/apps/analysis_framework/dataloaders.py index c4508144a7..169c2eeb5b 100644 --- a/apps/analysis_framework/dataloaders.py +++ b/apps/analysis_framework/dataloaders.py @@ -1,56 +1,58 @@ from collections import defaultdict -from promise import Promise -from django.utils.functional import cached_property from django.db import models +from django.utils.functional import cached_property from project.models import Project +from promise import Promise from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin from .models import ( - Widget, - Section, - Filter, - Exportable, - AnalysisFrameworkMembership, AnalysisFramework, + AnalysisFrameworkMembership, + Exportable, + Filter, + Section, + Widget, ) class WidgetLoader(DataLoaderWithContext): @staticmethod def load_widgets(keys, parent, **filters): - qs = Widget.objects\ - .filter( + qs = ( + Widget.objects.filter( **{ - f'{parent}__in': keys, + f"{parent}__in": keys, **filters, } - ).exclude(widget_id__in=Widget.DEPRECATED_TYPES)\ - .annotate(conditional_parent_widget_type=models.F('conditional_parent_widget__widget_id'))\ - .order_by('order', 'id') + ) + .exclude(widget_id__in=Widget.DEPRECATED_TYPES) + .annotate(conditional_parent_widget_type=models.F("conditional_parent_widget__widget_id")) + .order_by("order", "id") + ) _map = defaultdict(list) for widget in qs: - _map[getattr(widget, f'{parent}_id')].append(widget) + _map[getattr(widget, f"{parent}_id")].append(widget) return Promise.resolve([_map[key] for key in keys]) def batch_load_fn(self, keys): - return self.load_widgets(keys, 'analysis_framework') + return self.load_widgets(keys, "analysis_framework") class SecondaryWidgetLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - return WidgetLoader.load_widgets(keys, 'analysis_framework', section__isnull=True) + return WidgetLoader.load_widgets(keys, "analysis_framework", section__isnull=True) class SectionWidgetLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - return WidgetLoader.load_widgets(keys, 'section') + return WidgetLoader.load_widgets(keys, "section") class SectionLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - qs = Section.objects.filter(analysis_framework__in=keys).order_by('order', 'id') + qs = Section.objects.filter(analysis_framework__in=keys).order_by("order", "id") _map = defaultdict(list) for section in qs: _map[section.analysis_framework_id].append(section) @@ -83,9 +85,7 @@ def batch_load_fn(self, keys): class MembershipLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - qs = AnalysisFrameworkMembership.objects\ - .filter(framework__in=keys)\ - .select_related('role', 'member', 'added_by') + qs = AnalysisFrameworkMembership.objects.filter(framework__in=keys).select_related("role", "member", "added_by") _map = defaultdict(list) for section in qs: _map[section.framework_id].append(section) @@ -96,7 +96,7 @@ class AnalysisFrameworkTagsLoader(DataLoaderWithContext): def batch_load_fn(self, keys): qs = AnalysisFramework.tags.through.objects.filter( analysisframework__in=keys, - ).select_related('analysisframeworktag') + ).select_related("analysisframeworktag") _map = defaultdict(list) for row in qs: _map[row.analysisframework_id].append(row.analysisframeworktag) @@ -105,32 +105,26 @@ def batch_load_fn(self, keys): class AnalysisFrameworkProjectCountLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - stat_qs = Project.objects\ - .filter(analysis_framework__in=keys)\ - .order_by('analysis_framework').values('analysis_framework')\ + stat_qs = ( + Project.objects.filter(analysis_framework__in=keys) + .order_by("analysis_framework") + .values("analysis_framework") .annotate( project_count=models.functions.Coalesce( - models.Count( - 'id', - filter=models.Q(is_test=False) - ), + models.Count("id", filter=models.Q(is_test=False)), 0, ), test_project_count=models.functions.Coalesce( - models.Count( - 'id', - filter=models.Q(is_test=True) - ), + models.Count("id", filter=models.Q(is_test=True)), 0, ), - ).values('analysis_framework', 'project_count', 'test_project_count') - _map = { - stat.pop('analysis_framework'): stat - for stat in stat_qs - } + ) + .values("analysis_framework", "project_count", "test_project_count") + ) + _map = {stat.pop("analysis_framework"): stat for stat in stat_qs} _dummy = { - 'project_count': 0, - 'test_project_count': 0, + "project_count": 0, + "test_project_count": 0, } return Promise.resolve([_map.get(key, _dummy) for key in keys]) diff --git a/apps/analysis_framework/enums.py b/apps/analysis_framework/enums.py index acc914886e..de94fe300a 100644 --- a/apps/analysis_framework/enums.py +++ b/apps/analysis_framework/enums.py @@ -3,18 +3,12 @@ get_enum_name_from_django_field, ) -from .models import ( - Widget, - Filter, - AnalysisFrameworkRole, -) - +from .models import AnalysisFrameworkRole, Filter, Widget -AnalysisFrameworkRoleTypeEnum = convert_enum_to_graphene_enum( - AnalysisFrameworkRole.Type, name='AnalysisFrameworkRoleTypeEnum') -WidgetWidgetTypeEnum = convert_enum_to_graphene_enum(Widget.WidgetType, name='WidgetWidgetTypeEnum') -WidgetWidthTypeEnum = convert_enum_to_graphene_enum(Widget.WidthType, name='WidgetWidthTypeEnum') -WidgetFilterTypeEnum = convert_enum_to_graphene_enum(Filter.FilterType, name='WidgetFilterTypeEnum') +AnalysisFrameworkRoleTypeEnum = convert_enum_to_graphene_enum(AnalysisFrameworkRole.Type, name="AnalysisFrameworkRoleTypeEnum") +WidgetWidgetTypeEnum = convert_enum_to_graphene_enum(Widget.WidgetType, name="WidgetWidgetTypeEnum") +WidgetWidthTypeEnum = convert_enum_to_graphene_enum(Widget.WidthType, name="WidgetWidthTypeEnum") +WidgetFilterTypeEnum = convert_enum_to_graphene_enum(Filter.FilterType, name="WidgetFilterTypeEnum") enum_map = { get_enum_name_from_django_field(field): enum diff --git a/apps/analysis_framework/export.py b/apps/analysis_framework/export.py index 1167b198ab..abdaf00e59 100644 --- a/apps/analysis_framework/export.py +++ b/apps/analysis_framework/export.py @@ -2,11 +2,11 @@ class ExportColumn: - TITLE = 'Title' - PILLAR = 'Pillar' - SUB_PILLAR = 'Sub pillar' - COLUMN_2D = '2D column' - SUB_COLUMN_2D = '2D sub column' + TITLE = "Title" + PILLAR = "Pillar" + SUB_PILLAR = "Sub pillar" + COLUMN_2D = "2D column" + SUB_COLUMN_2D = "2D sub column" AF_EXPORT_COLUMNS = [ @@ -27,29 +27,33 @@ def export_af_to_csv(af, file): writer = csv.DictWriter(file, fieldnames=AF_EXPORT_COLUMNS) writer.writeheader() - for widget in af.widget_set.order_by('widget_id'): + for widget in af.widget_set.order_by("widget_id"): w_type = widget.widget_id w_title = widget.title widget_prop = widget.properties or {} - if w_type == 'matrix1dWidget': - for row in widget_prop['rows']: - for cell in row['cells']: - writer.writerow({ - ExportColumn.TITLE: w_title, - ExportColumn.PILLAR: row['label'], - ExportColumn.SUB_PILLAR: cell['label'], - }) - - elif w_type == 'matrix2dWidget': - for row in widget_prop['rows']: - for sub_row in row['subRows']: - for column in widget_prop['columns']: - for sub_column in column['subColumns'] or [{'label': ''}]: - writer.writerow({ - ExportColumn.TITLE: w_title, - ExportColumn.PILLAR: row['label'], - ExportColumn.SUB_PILLAR: sub_row['label'], - ExportColumn.COLUMN_2D: column['label'], - ExportColumn.SUB_COLUMN_2D: sub_column['label'], - }) + if w_type == "matrix1dWidget": + for row in widget_prop["rows"]: + for cell in row["cells"]: + writer.writerow( + { + ExportColumn.TITLE: w_title, + ExportColumn.PILLAR: row["label"], + ExportColumn.SUB_PILLAR: cell["label"], + } + ) + + elif w_type == "matrix2dWidget": + for row in widget_prop["rows"]: + for sub_row in row["subRows"]: + for column in widget_prop["columns"]: + for sub_column in column["subColumns"] or [{"label": ""}]: + writer.writerow( + { + ExportColumn.TITLE: w_title, + ExportColumn.PILLAR: row["label"], + ExportColumn.SUB_PILLAR: sub_row["label"], + ExportColumn.COLUMN_2D: column["label"], + ExportColumn.SUB_COLUMN_2D: sub_column["label"], + } + ) diff --git a/apps/analysis_framework/factories.py b/apps/analysis_framework/factories.py index d708160508..35a0ddba83 100644 --- a/apps/analysis_framework/factories.py +++ b/apps/analysis_framework/factories.py @@ -1,27 +1,17 @@ import factory +from django.core.files.base import ContentFile from factory import fuzzy from factory.django import DjangoModelFactory -from django.core.files.base import ContentFile -from .models import ( - AnalysisFramework, - AnalysisFrameworkTag, - Section, - Widget, - Filter, -) +from .models import AnalysisFramework, AnalysisFrameworkTag, Filter, Section, Widget from .widgets.store import widget_store class AnalysisFrameworkTagFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'AF-Tag-{n}') - description = factory.Faker('sentence', nb_words=20) + title = factory.Sequence(lambda n: f"AF-Tag-{n}") + description = factory.Faker("sentence", nb_words=20) icon = factory.LazyAttribute( - lambda n: ContentFile( - factory.django.ImageField()._make_data( - {'width': 100, 'height': 100} - ), f'example_{n.title}.png' - ) + lambda n: ContentFile(factory.django.ImageField()._make_data({"width": 100, "height": 100}), f"example_{n.title}.png") ) class Meta: @@ -29,8 +19,8 @@ class Meta: class AnalysisFrameworkFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'AF-{n}') - description = factory.Faker('sentence', nb_words=20) + title = factory.Sequence(lambda n: f"AF-{n}") + description = factory.Faker("sentence", nb_words=20) class Meta: model = AnalysisFramework @@ -45,15 +35,15 @@ def tags(self, create, extracted, **_): class SectionFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'Section-{n}') + title = factory.Sequence(lambda n: f"Section-{n}") class Meta: model = Section class WidgetFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'Widget-{n}') - key = factory.Sequence(lambda n: f'widget-key-{n}') + title = factory.Sequence(lambda n: f"Widget-{n}") + key = factory.Sequence(lambda n: f"widget-key-{n}") widget_id = fuzzy.FuzzyChoice(widget_store.keys()) properties = {} version = 1 @@ -63,8 +53,8 @@ class Meta: class AfFilterFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'Widget-filter-{n}') - key = factory.Sequence(lambda n: f'widget-filter-key-{n}') + title = factory.Sequence(lambda n: f"Widget-filter-{n}") + key = factory.Sequence(lambda n: f"widget-filter-key-{n}") properties = {} class Meta: diff --git a/apps/analysis_framework/filter_set.py b/apps/analysis_framework/filter_set.py index 8186f1c798..e1ab919bbb 100644 --- a/apps/analysis_framework/filter_set.py +++ b/apps/analysis_framework/filter_set.py @@ -1,32 +1,32 @@ -from django.db import models +from datetime import timedelta + import django_filters +from django.db import models +from django.utils import timezone +from entry.models import Entry +from user_resource.filters import UserResourceFilterSet, UserResourceGqlFilterSet -from user_resource.filters import ( - UserResourceFilterSet, - UserResourceGqlFilterSet, -) from utils.graphene.filters import IDListFilter -from .models import ( - AnalysisFramework, - AnalysisFrameworkTag, -) -from entry.models import Entry -from django.utils import timezone -from datetime import timedelta +from .models import AnalysisFramework, AnalysisFrameworkTag class AnalysisFrameworkFilterSet(UserResourceFilterSet): class Meta: model = AnalysisFramework - fields = ('id', 'title', 'description', 'created_at',) + fields = ( + "id", + "title", + "description", + "created_at", + ) filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda _: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda _: { + "lookup_expr": "icontains", }, }, } @@ -34,34 +34,30 @@ class Meta: # ----------------------------- Graphql Filters --------------------------------------- class AnalysisFrameworkTagGqFilterSet(django_filters.FilterSet): - search = django_filters.CharFilter(method='search_filter') + search = django_filters.CharFilter(method="search_filter") class Meta: model = AnalysisFrameworkTag - fields = ['id'] + fields = ["id"] def search_filter(self, qs, _, value): if value: - return qs.filter( - models.Q(title__icontains=value) | - models.Q(description__icontains=value) - ) + return qs.filter(models.Q(title__icontains=value) | models.Q(description__icontains=value)) return qs class AnalysisFrameworkGqFilterSet(UserResourceGqlFilterSet): - search = django_filters.CharFilter(method='search_filter') - is_current_user_member = django_filters.BooleanFilter( - field_name='is_current_user_member', method='filter_with_membership') + search = django_filters.CharFilter(method="search_filter") + is_current_user_member = django_filters.BooleanFilter(field_name="is_current_user_member", method="filter_with_membership") recently_used = django_filters.BooleanFilter( - method='filter_recently_used', - label='Recently Used', + method="filter_recently_used", + label="Recently Used", ) tags = IDListFilter(distinct=True) class Meta: model = AnalysisFramework - fields = ['id'] + fields = ["id"] def search_filter(self, qs, _, value): if value: @@ -84,5 +80,5 @@ def filter_recently_used(self, queryset, name, value): # Calculate the date for "recent" usage (e.g., within the last 6 months or 180 days) recent_usage_cutoff = timezone.now() - timedelta(days=180) entries_qs = Entry.objects.filter(modified_at__gte=recent_usage_cutoff) - return queryset.filter(id__in=entries_qs.values('analysis_framework')) + return queryset.filter(id__in=entries_qs.values("analysis_framework")) return queryset diff --git a/apps/analysis_framework/management/commands/add_af_owner_roles.py b/apps/analysis_framework/management/commands/add_af_owner_roles.py index 1934360e8b..e60a411910 100644 --- a/apps/analysis_framework/management/commands/add_af_owner_roles.py +++ b/apps/analysis_framework/management/commands/add_af_owner_roles.py @@ -1,14 +1,11 @@ +from analysis_framework.models import AnalysisFramework as AF +from analysis_framework.models import AnalysisFrameworkMembership as AFMembership +from analysis_framework.models import AnalysisFrameworkRole as AFRole from django.core.management.base import BaseCommand -from analysis_framework.models import ( - AnalysisFramework as AF, - AnalysisFrameworkRole as AFRole, - AnalysisFrameworkMembership as AFMembership, -) - class Command(BaseCommand): - help = 'Add framework owner membership for all creators of frameworks' + help = "Add framework owner membership for all creators of frameworks" def handle(self, *args, **options): add_owner_memberships_to_existing_frameworks() @@ -20,7 +17,7 @@ def add_owner_memberships_to_existing_frameworks(): is_private_role=True, can_clone_framework=False, can_edit_framework=True, - can_use_in_other_projects=True + can_use_in_other_projects=True, ) pub_role = AFRole.objects.get( @@ -28,16 +25,11 @@ def add_owner_memberships_to_existing_frameworks(): is_private_role=False, can_clone_framework=True, can_edit_framework=True, - can_use_in_other_projects=True + can_use_in_other_projects=True, ) for af in AF.objects.all(): if not AFMembership.objects.filter(framework=af, member=af.created_by).exists(): # Means creator's membership does not exist, So create one owner_role = priv_role if af.is_private else pub_role - AFMembership.objects.create( - framework=af, - member=af.created_by, - joined_at=af.created_at, - role=owner_role - ) + AFMembership.objects.create(framework=af, member=af.created_by, joined_at=af.created_at, role=owner_role) diff --git a/apps/analysis_framework/models.py b/apps/analysis_framework/models.py index 88a583d120..a3217814ba 100644 --- a/apps/analysis_framework/models.py +++ b/apps/analysis_framework/models.py @@ -1,20 +1,21 @@ import copy from typing import Union -from django.db import models from django.core.exceptions import ValidationError +from django.db import models +from organization.models import Organization +from user.models import User +from user_resource.models import UserResource from utils.common import get_enum_display -from user_resource.models import UserResource -from user.models import User -from organization.models import Organization + from .widgets import store as widgets_store class AnalysisFrameworkTag(models.Model): title = models.CharField(max_length=255) description = models.TextField(blank=True) - icon = models.FileField(upload_to='af-tag-icon/', max_length=255) + icon = models.FileField(upload_to="af-tag-icon/", max_length=255) def __str__(self): return self.title @@ -26,17 +27,16 @@ class AnalysisFramework(UserResource): Analysis is done to create entries out of leads. """ + title = models.CharField(max_length=255) description = models.TextField(blank=True, null=True) - tags = models.ManyToManyField(AnalysisFrameworkTag, related_name='+', blank=True) + tags = models.ManyToManyField(AnalysisFrameworkTag, related_name="+", blank=True) is_private = models.BooleanField(default=False) assisted_tagging_enabled = models.BooleanField(default=False) members = models.ManyToManyField( - User, blank=True, - through_fields=('framework', 'member'), - through='AnalysisFrameworkMembership' + User, blank=True, through_fields=("framework", "member"), through="AnalysisFrameworkMembership" ) properties = models.JSONField(default=dict, blank=True, null=True) @@ -44,15 +44,15 @@ class AnalysisFramework(UserResource): organization = models.ForeignKey(Organization, on_delete=models.SET_NULL, blank=True, null=True) # Image is provided by user as a reference. preview_image = models.FileField( - upload_to='af-preview-image/', max_length=255, null=True, blank=True, default=None, + upload_to="af-preview-image/", + max_length=255, + null=True, + blank=True, + default=None, ) - export = models.FileField(upload_to='af-exports/', max_length=255, null=True, blank=True, default=None) + export = models.FileField(upload_to="af-exports/", max_length=255, null=True, blank=True, default=None) # added to keep the track of cloned analysisframework - cloned_from = models.ForeignKey( - 'AnalysisFramework', - on_delete=models.SET_NULL, - null=True, blank=True - ) + cloned_from = models.ForeignKey("AnalysisFramework", on_delete=models.SET_NULL, null=True, blank=True) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -68,10 +68,8 @@ def clone(self, user, overrides={}): Clone analysis framework along with all widgets, filters and exportables """ - title = overrides.get( - 'title', '{} (cloned)'.format(self.title[:230]) - ) # Strip off extra chars from title - description = overrides.get('description', '') + title = overrides.get("title", "{} (cloned)".format(self.title[:230])) # Strip off extra chars from title + description = overrides.get("description", "") clone_analysis_framework = AnalysisFramework( title=title, description=description, @@ -99,34 +97,30 @@ def clone(self, user, overrides={}): # For widgets with conditional assigned. for widget in widgets_with_conditional: widget.conditional_parent_widget_id = old_new_widgets_map[widget.conditional_parent_widget_id] - widget.save(update_fields=('conditional_parent_widget_id',)) + widget.save(update_fields=("conditional_parent_widget_id",)) return clone_analysis_framework @staticmethod def get_for(user): return AnalysisFramework.objects.all().exclude( - models.Q(is_private=True) & ~models.Q(members=user) & - ~models.Q(project__members=user) + models.Q(is_private=True) & ~models.Q(members=user) & ~models.Q(project__members=user) ) @classmethod def get_for_gq(cls, user, only_member=False): - visible_afs = cls.objects\ - .annotate( - # NOTE: This is used by permission module - current_user_role=models.Subquery( - AnalysisFrameworkMembership.objects.filter( - framework=models.OuterRef('pk'), - member=user, - ).order_by('role__type').values('role__type')[:1], - output_field=models.CharField() + visible_afs = cls.objects.annotate( + # NOTE: This is used by permission module + current_user_role=models.Subquery( + AnalysisFrameworkMembership.objects.filter( + framework=models.OuterRef("pk"), + member=user, ) - # NOTE: Exclude if af is private + user is not a member and user is not member of project using af - ).exclude( - models.Q(is_private=True) & - models.Q(current_user_role__isnull=True) & - ~models.Q(project__members=user) + .order_by("role__type") + .values("role__type")[:1], + output_field=models.CharField(), ) + # NOTE: Exclude if af is private + user is not a member and user is not member of project using af + ).exclude(models.Q(is_private=True) & models.Q(current_user_role__isnull=True) & ~models.Q(project__members=user)) if only_member: return visible_afs.filter(current_user_role__isnull=False) return visible_afs @@ -135,14 +129,14 @@ def get_current_user_role(self, user): """ Return current_user_role from instance (if get_for_gq is used or generate) """ - if hasattr(self, 'current_user_role'): + if hasattr(self, "current_user_role"): self.current_user_role: Union[str, None] return self.current_user_role # If not available generate self.current_user_role = None - self.current_user_role = AnalysisFrameworkMembership.objects\ - .filter(framework=self, member=user)\ - .values_list('role__type', flat=True).first() + self.current_user_role = ( + AnalysisFrameworkMembership.objects.filter(framework=self, member=user).values_list("role__type", flat=True).first() + ) return self.current_user_role def can_get(self, _: User): @@ -155,18 +149,12 @@ def can_modify(self, user: User): * user is super user, or * the framework belongs to a project where the user is admin """ - return ( - AnalysisFrameworkMembership.objects.filter( - member=user, - framework=self, - role__can_edit_framework=True - ).exists() - ) + return AnalysisFrameworkMembership.objects.filter(member=user, framework=self, role__can_edit_framework=True).exists() def can_clone(self, user): return ( - not self.is_private or - AnalysisFrameworkMembership.objects.filter( + not self.is_private + or AnalysisFrameworkMembership.objects.filter( member=user, framework=self, role__can_clone_framework=True, @@ -175,43 +163,36 @@ def can_clone(self, user): def get_entries_count(self): from entry.models import Entry + return Entry.objects.filter(analysis_framework=self).count() def get_or_create_owner_role(self): permission_fields = self.get_owner_permissions() - privacy_label = 'Private' if self.is_private else 'Public' + privacy_label = "Private" if self.is_private else "Public" role, _ = AnalysisFrameworkRole.objects.get_or_create( - **permission_fields, - is_private_role=self.is_private, - defaults={ - 'title': f'Owner ({privacy_label})' - } + **permission_fields, is_private_role=self.is_private, defaults={"title": f"Owner ({privacy_label})"} ) return role def get_or_create_editor_role(self): permission_fields = self.get_editor_permissions() - privacy_label = 'Private' if self.is_private else 'Public' + privacy_label = "Private" if self.is_private else "Public" role, _ = AnalysisFrameworkRole.objects.get_or_create( - **permission_fields, - is_private_role=self.is_private, - defaults={ - 'title': f'Editor ({privacy_label})' - } + **permission_fields, is_private_role=self.is_private, defaults={"title": f"Editor ({privacy_label})"} ) return role def get_or_create_default_role(self): permission_fields = self.get_default_permissions() - privacy_label = 'Private' if self.is_private else 'Public' + privacy_label = "Private" if self.is_private else "Public" role, _ = AnalysisFrameworkRole.objects.get_or_create( is_default_role=True, is_private_role=self.is_private, defaults={ **permission_fields, - 'title': f'Default({privacy_label})', - } + "title": f"Default({privacy_label})", + }, ) return role @@ -222,7 +203,7 @@ def add_member(self, user, role=None, added_by=None): framework=self, role=role, defaults={ - 'added_by': added_by, + "added_by": added_by, }, ) @@ -255,7 +236,7 @@ def get_owner_permissions(self): return permission_fields def get_active_filters(self): - current_widgets_key = self.widget_set.values_list('key', flat=True) + current_widgets_key = self.widget_set.values_list("key", flat=True) return self.filter_set.filter(widget_key__in=current_widgets_key).all() @@ -263,6 +244,7 @@ class Section(models.Model): """ Section to group widgets """ + analysis_framework_id: Union[int, None] analysis_framework = models.ForeignKey(AnalysisFramework, on_delete=models.CASCADE) title = models.CharField(max_length=100) @@ -270,7 +252,7 @@ class Section(models.Model): tooltip = models.TextField(blank=True, null=True) def __str__(self): - return f'{self.analysis_framework_id}#{self.title}' + return f"{self.analysis_framework_id}#{self.title}" def clone_to(self, analysis_framework): section_clone = copy.deepcopy(self) @@ -284,6 +266,7 @@ class Widget(models.Model): """ Widget inserted into a framework """ + class WidgetType(models.TextChoices): DATE = widgets_store.date_widget.WIDGET_ID DATE_RANGE = widgets_store.date_range_widget.WIDGET_ID @@ -300,7 +283,7 @@ class WidgetType(models.TextChoices): NUMBER_MATRIX = widgets_store.number_matrix_widget.WIDGET_ID CONDITIONAL = widgets_store.conditional_widget.WIDGET_ID TEXT = widgets_store.text_widget.WIDGET_ID - EXCERPT = 'excerptWidget', 'Excerpt #DEPRICATED' # TODO:DEPRICATED + EXCERPT = "excerptWidget", "Excerpt #DEPRICATED" # TODO:DEPRICATED DEPRECATED_TYPES = [ WidgetType.EXCERPT, @@ -309,8 +292,8 @@ class WidgetType(models.TextChoices): ] class WidthType(models.TextChoices): - FULL = 'full', 'Full' - HALF = 'half', 'Half' + FULL = "full", "Full" + HALF = "half", "Half" analysis_framework = models.ForeignKey(AnalysisFramework, on_delete=models.CASCADE) # FIXME: key shouldn't be null (Filter/Exportable have non-nullable key) @@ -327,7 +310,8 @@ class WidthType(models.TextChoices): version = models.SmallIntegerField(null=True, blank=True) # Conditional conditional_parent_widget = models.ForeignKey( - 'Widget', related_name='child_widget_conditionals', on_delete=models.SET_NULL, null=True, blank=True) + "Widget", related_name="child_widget_conditionals", on_delete=models.SET_NULL, null=True, blank=True + ) conditional_conditions = models.JSONField(default=list, blank=True) # For typing @@ -339,10 +323,11 @@ def save(self, *args, **kwargs): self.analysis_framework_id = self.section.analysis_framework_id super().save(*args, **kwargs) from .utils import update_widget + update_widget(self) def __str__(self): - return '{}:: {}:{} ({})'.format(self.analysis_framework_id, self.title, self.pk, self.widget_id) + return "{}:: {}:{} ({})".format(self.analysis_framework_id, self.title, self.pk, self.widget_id) def clone_to(self, analysis_framework, section_id): widget_clone = copy.deepcopy(self) @@ -359,9 +344,9 @@ def get_for(user): AnalysisFramework which has access to it's project """ return Widget.objects.filter( - models.Q(analysis_framework__project=None) | - models.Q(analysis_framework__project__members=user) | - models.Q(analysis_framework__project__user_groups__members=user) + models.Q(analysis_framework__project=None) + | models.Q(analysis_framework__project__members=user) + | models.Q(analysis_framework__project__user_groups__members=user) ).distinct() def can_get(self, user): @@ -375,14 +360,16 @@ class Filter(models.Model): """ A filter for a widget in an analysis framework """ + class FilterType(models.TextChoices): - TEXT = 'text', 'Text' - NUMBER = 'number', 'Number' - LIST = 'list', 'List' - INTERSECTS = 'intersects', 'Intersection between two numbers' + TEXT = "text", "Text" + NUMBER = "number", "Number" + LIST = "list", "List" + INTERSECTS = "intersects", "Intersection between two numbers" analysis_framework = models.ForeignKey( - AnalysisFramework, on_delete=models.CASCADE, + AnalysisFramework, + on_delete=models.CASCADE, ) key = models.CharField(max_length=100, db_index=True) widget_key = models.CharField(max_length=100) @@ -391,13 +378,13 @@ class FilterType(models.TextChoices): filter_type = models.CharField(max_length=20, choices=FilterType.choices, default=FilterType.LIST) class Meta: - ordering = ['title', 'widget_key', 'key'] + ordering = ["title", "widget_key", "key"] def __str__(self): - return '{} ({})'.format(self.title, self.key) + return "{} ({})".format(self.title, self.key) def get_widget_type_display(self): - widget_type = getattr(self, 'widget_type') # Included when qs_with_widget_type is used + widget_type = getattr(self, "widget_type") # Included when qs_with_widget_type is used if widget_type: return get_enum_display(Widget.WidgetType, widget_type) @@ -408,9 +395,10 @@ def qs_with_widget_type(cls): return cls.objects.annotate( widget_type=models.Subquery( Widget.objects.filter( - key=models.OuterRef('widget_key'), - analysis_framework=models.OuterRef('analysis_framework'), - ).values('widget_id')[:1], output_field=models.CharField() + key=models.OuterRef("widget_key"), + analysis_framework=models.OuterRef("analysis_framework"), + ).values("widget_id")[:1], + output_field=models.CharField(), ) ) @@ -436,9 +424,9 @@ def get_for(cls, user, with_widget_type=False): if with_widget_type: qs = cls.qs_with_widget_type() return qs.filter( - models.Q(analysis_framework__project=None) | - models.Q(analysis_framework__project__members=user) | - models.Q(analysis_framework__project__user_groups__members=user) + models.Q(analysis_framework__project=None) + | models.Q(analysis_framework__project__members=user) + | models.Q(analysis_framework__project__user_groups__members=user) ).distinct() def can_get(self, user): @@ -452,8 +440,10 @@ class Exportable(models.Model): """ Export data for given widget """ + analysis_framework = models.ForeignKey( - AnalysisFramework, on_delete=models.CASCADE, + AnalysisFramework, + on_delete=models.CASCADE, ) widget_key = models.CharField(max_length=100, db_index=True) inline = models.BooleanField(default=False) @@ -461,10 +451,10 @@ class Exportable(models.Model): data = models.JSONField(default=None, blank=True, null=True) def __str__(self): - return 'Exportable ({})'.format(self.widget_key) + return "Exportable ({})".format(self.widget_key) class Meta: - ordering = ['order'] + ordering = ["order"] @classmethod def qs_with_widget_type(cls): @@ -473,9 +463,10 @@ def qs_with_widget_type(cls): return cls.objects.annotate( widget_type=models.Subquery( Widget.objects.filter( - key=models.OuterRef('widget_key'), - analysis_framework=models.OuterRef('analysis_framework'), - ).values('widget_id')[:1], output_field=models.CharField() + key=models.OuterRef("widget_key"), + analysis_framework=models.OuterRef("analysis_framework"), + ).values("widget_id")[:1], + output_field=models.CharField(), ) ) @@ -495,13 +486,13 @@ def get_for(user): AnalysisFramework which has access to it's project """ return Exportable.objects.filter( - models.Q(analysis_framework__project=None) | - models.Q(analysis_framework__project__members=user) | - models.Q(analysis_framework__project__user_groups__members=user) + models.Q(analysis_framework__project=None) + | models.Q(analysis_framework__project__members=user) + | models.Q(analysis_framework__project__user_groups__members=user) ).distinct() def get_widget_type_display(self): - widget_type = getattr(self, 'widget_type') # Included when qs_with_widget_type is used + widget_type = getattr(self, "widget_type") # Included when qs_with_widget_type is used if widget_type: return get_enum_display(Widget.WidgetType, widget_type) @@ -516,14 +507,15 @@ class AnalysisFrameworkRole(models.Model): """ Roles for AnalysisFramework """ + class Type(models.TextChoices): - EDITOR = 'editor', 'Editor' - OWNER = 'owner', 'Owner' - DEFAULT = 'default', 'default' - PRIVATE_EDITOR = 'private_editor', 'Private Editor' - PRIVATE_OWNER = 'private_owner', 'Private Owner' - PRIVATE_VIEWER = 'private_viewer', 'Private Viewer' - UNKNOWN = 'unknown', 'Unknown' + EDITOR = "editor", "Editor" + OWNER = "owner", "Owner" + DEFAULT = "default", "default" + PRIVATE_EDITOR = "private_editor", "Private Editor" + PRIVATE_OWNER = "private_owner", "Private Owner" + PRIVATE_VIEWER = "private_viewer", "Private Viewer" + UNKNOWN = "unknown", "Unknown" PRIVATE_TYPES = [ Type.PRIVATE_EDITOR, @@ -531,10 +523,10 @@ class Type(models.TextChoices): Type.PRIVATE_VIEWER, ] - CAN_ADD_USER = 'can_add_user' - CAN_CLONE_FRAMEWORK = 'can_clone_framework' - CAN_EDIT_FRAMEWORK = 'can_edit_framework' - CAN_USE_IN_OTHER_PROJECTS = 'can_use_in_other_projects' + CAN_ADD_USER = "can_add_user" + CAN_CLONE_FRAMEWORK = "can_clone_framework" + CAN_EDIT_FRAMEWORK = "can_edit_framework" + CAN_USE_IN_OTHER_PROJECTS = "can_use_in_other_projects" PERMISSION_FIELDS = ( CAN_ADD_USER, @@ -561,12 +553,12 @@ class Type(models.TextChoices): class Meta: unique_together = ( - 'can_add_user', - 'can_clone_framework', - 'can_edit_framework', - 'can_use_in_other_projects', - 'is_default_role', - 'is_private_role', + "can_add_user", + "can_clone_framework", + "can_edit_framework", + "can_use_in_other_projects", + "is_default_role", + "is_private_role", ) def __str__(self): @@ -574,28 +566,26 @@ def __str__(self): @property def permissions(self): - return { - x: self.__dict__[x] - for x in AnalysisFrameworkRole.PERMISSION_FIELDS - } + return {x: self.__dict__[x] for x in AnalysisFrameworkRole.PERMISSION_FIELDS} def clean(self): if self.is_private_role: if self.type not in self.PRIVATE_TYPES: - raise ValidationError({ - 'type': f'{self.type} is not allowed for Private Roles.', - }) + raise ValidationError( + { + "type": f"{self.type} is not allowed for Private Roles.", + } + ) elif self.type in self.PRIVATE_TYPES: - raise ValidationError({ - 'type': f'{self.type} is not allowed for Public Roles.', - }) + raise ValidationError( + { + "type": f"{self.type} is not allowed for Public Roles.", + } + ) class AnalysisFrameworkMembership(models.Model): - member = models.ForeignKey( - User, on_delete=models.CASCADE, - related_name='framework_membership' - ) + member = models.ForeignKey(User, on_delete=models.CASCADE, related_name="framework_membership") framework = models.ForeignKey(AnalysisFramework, on_delete=models.CASCADE) role = models.ForeignKey( AnalysisFrameworkRole, @@ -603,19 +593,18 @@ class AnalysisFrameworkMembership(models.Model): ) joined_at = models.DateTimeField(auto_now_add=True) added_by = models.ForeignKey( - User, on_delete=models.CASCADE, - null=True, blank=True, default=None, + User, + on_delete=models.CASCADE, + null=True, + blank=True, + default=None, ) class Meta: - unique_together = ('member', 'framework') + unique_together = ("member", "framework") @staticmethod def get_for(user): return AnalysisFrameworkMembership.objects.filter( - ( - models.Q(member=user) & - models.Q(role__can_add_user=True) - ) | - models.Q(framework__members=user), + (models.Q(member=user) & models.Q(role__can_add_user=True)) | models.Q(framework__members=user), ) diff --git a/apps/analysis_framework/mutation.py b/apps/analysis_framework/mutation.py index eb8d99ea4f..bc81c4043e 100644 --- a/apps/analysis_framework/mutation.py +++ b/apps/analysis_framework/mutation.py @@ -1,38 +1,30 @@ import graphene +from django.core.exceptions import PermissionDenied from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField -from django.core.exceptions import PermissionDenied from deep.permissions import AnalysisFrameworkPermissions as AfP - from utils.graphene.mutation import ( - generate_input_type_for_serializer, - GrapheneMutation, - AfGrapheneMutation, AfBulkGrapheneMutation, + AfGrapheneMutation, + GrapheneMutation, + generate_input_type_for_serializer, ) -from .models import ( - AnalysisFramework, - AnalysisFrameworkMembership, -) +from .models import AnalysisFramework, AnalysisFrameworkMembership +from .schema import AnalysisFrameworkDetailType, AnalysisFrameworkMembershipType +from .serializers import AnalysisFrameworkGqlSerializer as AnalysisFrameworkSerializer from .serializers import ( - AnalysisFrameworkGqlSerializer as AnalysisFrameworkSerializer, AnalysisFrameworkMembershipGqlSerializer as AnalysisFrameworkMembershipSerializer, ) -from .schema import ( - AnalysisFrameworkDetailType, - AnalysisFrameworkMembershipType, -) - AnalysisFrameworkInputType = generate_input_type_for_serializer( - 'AnalysisFrameworkInputType', + "AnalysisFrameworkInputType", serializer_class=AnalysisFrameworkSerializer, ) AnalysisFrameworkMembershipInputType = generate_input_type_for_serializer( - 'AnalysisFrameworkMembershipInputType', + "AnalysisFrameworkMembershipInputType", serializer_class=AnalysisFrameworkMembershipSerializer, ) @@ -64,7 +56,7 @@ class Arguments: @classmethod def perform_mutate(cls, root, info, **kwargs): - kwargs['id'] = info.context.active_af.id + kwargs["id"] = info.context.active_af.id return super().perform_mutate(root, info, **kwargs) @@ -100,13 +92,14 @@ class AnalysisFrameworkMutationType(DjangoObjectType): """ This mutation is for other scoped objects """ + analysis_framework_update = UpdateAnalysisFramework.Field() analysis_framework_membership_bulk = BulkUpdateAnalysisFrameworkMembership.Field() class Meta: model = AnalysisFramework skip_registry = True - fields = ('id', 'title') + fields = ("id", "title") @staticmethod def get_custom_node(_, info, id): diff --git a/apps/analysis_framework/permissions.py b/apps/analysis_framework/permissions.py index 33a9dff251..6f42eceb64 100644 --- a/apps/analysis_framework/permissions.py +++ b/apps/analysis_framework/permissions.py @@ -4,14 +4,12 @@ class FrameworkMembershipModifyPermission(permissions.BasePermission): def has_object_permission(self, request, view, obj): from .models import AnalysisFrameworkMembership + if request.method in permissions.SAFE_METHODS: return True framework = obj.framework - membership = AnalysisFrameworkMembership.objects.filter( - framework=framework, - member=request.user - ).first() + membership = AnalysisFrameworkMembership.objects.filter(framework=framework, member=request.user).first() user_role = membership and membership.role if not user_role: diff --git a/apps/analysis_framework/public_schema.py b/apps/analysis_framework/public_schema.py index b2e1d48e64..6765dc0106 100644 --- a/apps/analysis_framework/public_schema.py +++ b/apps/analysis_framework/public_schema.py @@ -2,18 +2,15 @@ from utils.graphene.types import CustomDjangoListObjectType -from .models import AnalysisFramework from .filter_set import AnalysisFrameworkGqFilterSet +from .models import AnalysisFramework class PublicAnalysisFramework(DjangoObjectType): class Meta: model = AnalysisFramework skip_registry = True - fields = ( - 'id', - 'title' - ) + fields = ("id", "title") class PublicAnalysisFrameworkListType(CustomDjangoListObjectType): diff --git a/apps/analysis_framework/schema.py b/apps/analysis_framework/schema.py index 322686a275..16d91316a7 100644 --- a/apps/analysis_framework/schema.py +++ b/apps/analysis_framework/schema.py @@ -1,41 +1,48 @@ from typing import Union import graphene -from graphene_django import DjangoObjectType, DjangoListField -from graphene_django_extras import DjangoObjectField, PageGraphqlPagination -from graphene.types.generic import GenericScalar +from assisted_tagging.models import PredictionTagAnalysisFrameworkWidgetMapping from django.db.models import QuerySet +from graphene.types.generic import GenericScalar +from graphene_django import DjangoListField, DjangoObjectType +from graphene_django_extras import DjangoObjectField, PageGraphqlPagination +from project.models import ProjectMembership +from project.schema import AnalysisFrameworkVisibleProjectType -from utils.graphene.enums import EnumDescription -from utils.graphene.types import CustomDjangoListObjectType, ClientIdMixin, FileFieldType -from utils.graphene.fields import DjangoPaginatedListObjectField, generate_type_for_serializer from deep.permissions import AnalysisFrameworkPermissions as AfP -from project.schema import AnalysisFrameworkVisibleProjectType -from project.models import ProjectMembership -from assisted_tagging.models import PredictionTagAnalysisFrameworkWidgetMapping -from .models import ( - AnalysisFramework, - AnalysisFrameworkTag, - Section, - Widget, - Filter, - Exportable, - AnalysisFrameworkMembership, - AnalysisFrameworkRole, +from utils.graphene.enums import EnumDescription +from utils.graphene.fields import ( + DjangoPaginatedListObjectField, + generate_type_for_serializer, +) +from utils.graphene.types import ( + ClientIdMixin, + CustomDjangoListObjectType, + FileFieldType, ) + from .enums import ( + AnalysisFrameworkRoleTypeEnum, + WidgetFilterTypeEnum, WidgetWidgetTypeEnum, WidgetWidthTypeEnum, - WidgetFilterTypeEnum, - AnalysisFrameworkRoleTypeEnum, ) -from .serializers import AnalysisFrameworkPropertiesGqlSerializer from .filter_set import AnalysisFrameworkGqFilterSet, AnalysisFrameworkTagGqFilterSet +from .models import ( + AnalysisFramework, + AnalysisFrameworkMembership, + AnalysisFrameworkRole, + AnalysisFrameworkTag, + Exportable, + Filter, + Section, + Widget, +) from .public_schema import PublicAnalysisFrameworkListType - +from .serializers import AnalysisFrameworkPropertiesGqlSerializer AnalysisFrameworkPropertiesType = generate_type_for_serializer( - 'AnalysisFrameworkPropertiesType', + "AnalysisFrameworkPropertiesType", serializer_class=AnalysisFrameworkPropertiesGqlSerializer, ) @@ -50,13 +57,17 @@ class WidgetType(ClientIdMixin, DjangoObjectType): class Meta: model = Widget only_fields = ( - 'id', 'title', 'order', 'properties', 'version', + "id", + "title", + "order", + "properties", + "version", ) widget_id = graphene.Field(WidgetWidgetTypeEnum, required=True) - widget_id_display = EnumDescription(source='get_widget_id_display', required=True) + widget_id_display = EnumDescription(source="get_widget_id_display", required=True) width = graphene.Field(WidgetWidthTypeEnum, required=True) - width_display = EnumDescription(source='get_width_display', required=True) + width_display = EnumDescription(source="get_width_display", required=True) key = graphene.String(required=True) version = graphene.Int(required=True) conditional = graphene.Field(WidgetConditionalType) @@ -77,7 +88,10 @@ class SectionType(ClientIdMixin, DjangoObjectType): class Meta: model = Section only_fields = ( - 'id', 'title', 'order', 'tooltip', + "id", + "title", + "order", + "tooltip", ) @staticmethod @@ -89,9 +103,9 @@ class AnalysisFrameworkTagType(DjangoObjectType): class Meta: model = AnalysisFrameworkTag only_fields = ( - 'id', - 'title', - 'description', + "id", + "title", + "description", ) icon = graphene.Field(FileFieldType, required=False) @@ -107,14 +121,22 @@ class AnalysisFrameworkType(DjangoObjectType): class Meta: model = AnalysisFramework only_fields = ( - 'id', 'title', 'description', 'is_private', 'assisted_tagging_enabled', 'organization', - 'created_by', 'created_at', 'modified_by', 'modified_at', + "id", + "title", + "description", + "is_private", + "assisted_tagging_enabled", + "organization", + "created_by", + "created_at", + "modified_by", + "modified_at", ) current_user_role = graphene.Field(AnalysisFrameworkRoleTypeEnum) preview_image = graphene.Field(FileFieldType) export = graphene.Field(FileFieldType) - cloned_from = graphene.ID(source='cloned_from_id') + cloned_from = graphene.ID(source="cloned_from_id") allowed_permissions = graphene.List( graphene.NonNull( graphene.Enum.from_enum(AfP.Permission), @@ -162,10 +184,10 @@ class AnalysisFrameworkRoleType(DjangoObjectType): class Meta: model = AnalysisFrameworkRole only_fields = ( - 'id', - 'title', - 'is_private_role', - 'is_default_role', + "id", + "title", + "is_private_role", + "is_default_role", ) type = graphene.Field(AnalysisFrameworkRoleTypeEnum) @@ -174,13 +196,18 @@ class Meta: class AnalysisFrameworkFilterType(DjangoObjectType): class Meta: model = Filter - only_fields = ('id', 'title', 'properties', 'widget_key',) + only_fields = ( + "id", + "title", + "properties", + "widget_key", + ) key = graphene.String(required=True) widget_type = graphene.Field(WidgetWidgetTypeEnum, required=True) - widget_type_display = EnumDescription(source='get_widget_type_display', required=True) + widget_type_display = EnumDescription(source="get_widget_type_display", required=True) filter_type = graphene.Field(WidgetFilterTypeEnum, required=True) - filter_type_display = EnumDescription(source='get_filter_type_display', required=True) + filter_type_display = EnumDescription(source="get_filter_type_display", required=True) @staticmethod def resolve_widget_type(root, info, **kwargs): @@ -190,11 +217,16 @@ def resolve_widget_type(root, info, **kwargs): class AnalysisFrameworkExportableType(DjangoObjectType): class Meta: model = Exportable - only_fields = ('id', 'inline', 'order', 'data',) + only_fields = ( + "id", + "inline", + "order", + "data", + ) widget_key = graphene.String(required=True) widget_type = graphene.Field(WidgetWidgetTypeEnum, required=True) - widget_type_display = EnumDescription(source='get_widget_type_display', required=True) + widget_type_display = EnumDescription(source="get_widget_type_display", required=True) @staticmethod def resolve_widget_type(root, info, **kwargs): @@ -204,21 +236,21 @@ def resolve_widget_type(root, info, **kwargs): class AnalysisFrameworkMembershipType(ClientIdMixin, DjangoObjectType): class Meta: model = AnalysisFrameworkMembership - only_fields = ('id', 'member', 'role', 'joined_at', 'added_by') + only_fields = ("id", "member", "role", "joined_at", "added_by") class AnalysisFrameworkPredictionMappingType(ClientIdMixin, DjangoObjectType): - widget = graphene.ID(source='widget_id', required=True) + widget = graphene.ID(source="widget_id", required=True) widget_type = graphene.Field(WidgetWidgetTypeEnum, required=True) - tag = graphene.ID(source='tag_id') + tag = graphene.ID(source="tag_id") class Meta: model = PredictionTagAnalysisFrameworkWidgetMapping only_fields = ( - 'id', - 'widget', - 'tag', - 'association', + "id", + "widget", + "tag", + "association", ) @staticmethod @@ -246,8 +278,16 @@ class Meta: model = AnalysisFramework skip_registry = True only_fields = ( - 'id', 'title', 'description', 'is_private', 'assisted_tagging_enabled', 'organization', - 'created_by', 'created_at', 'modified_by', 'modified_at', + "id", + "title", + "description", + "is_private", + "assisted_tagging_enabled", + "organization", + "created_by", + "created_at", + "modified_by", + "modified_at", ) @staticmethod @@ -274,11 +314,10 @@ def resolve_members(root, info): @staticmethod def resolve_prediction_tags_mapping(root, info): - project_membership_qs = ProjectMembership.objects\ - .filter( - project__analysis_framework=root, - member=info.context.request.user, - ) + project_membership_qs = ProjectMembership.objects.filter( + project__analysis_framework=root, + member=info.context.request.user, + ) if root.get_current_user_role(info.context.request.user) is not None or project_membership_qs.exists(): return PredictionTagAnalysisFrameworkWidgetMapping.objects.filter( widget__analysis_framework=root, @@ -301,22 +340,13 @@ class Meta: class Query: analysis_framework = DjangoObjectField(AnalysisFrameworkDetailType) analysis_frameworks = DjangoPaginatedListObjectField( - AnalysisFrameworkListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + AnalysisFrameworkListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) public_analysis_frameworks = DjangoPaginatedListObjectField( - PublicAnalysisFrameworkListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + PublicAnalysisFrameworkListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) analysis_framework_tags = DjangoPaginatedListObjectField( - AnalysisFrameworkTagListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + AnalysisFrameworkTagListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) analysis_framework_roles = graphene.List(graphene.NonNull(AnalysisFrameworkRoleType), required=True) diff --git a/apps/analysis_framework/serializers.py b/apps/analysis_framework/serializers.py index a322bc1c85..14c8059611 100644 --- a/apps/analysis_framework/serializers.py +++ b/apps/analysis_framework/serializers.py @@ -1,143 +1,137 @@ +from assisted_tagging.models import PredictionTagAnalysisFrameworkWidgetMapping +from assisted_tagging.serializers import PredictionTagAnalysisFrameworkMapSerializer +from django.db import models, transaction from django.utils.functional import cached_property - from drf_dynamic_fields import DynamicFieldsMixin -from rest_framework import serializers, exceptions from drf_writable_nested.serializers import WritableNestedModelSerializer -from django.db import models -from django.db import transaction - -from deep.serializers import RemoveNullFieldsMixin, TempClientIdMixin, IntegerIDField -from user_resource.serializers import UserResourceSerializer +from organization.serializers import SimpleOrganizationSerializer +from project.change_log import ProjectChangeManager +from project.models import Project from questionnaire.serializers import FrameworkQuestionSerializer -from user.models import User, Feature +from rest_framework import exceptions, serializers +from user.models import Feature, User from user.serializers import SimpleUserSerializer -from project.models import Project -from project.change_log import ProjectChangeManager -from assisted_tagging.models import PredictionTagAnalysisFrameworkWidgetMapping -from organization.serializers import SimpleOrganizationSerializer -from assisted_tagging.serializers import PredictionTagAnalysisFrameworkMapSerializer +from user_resource.serializers import UserResourceSerializer + +from deep.serializers import IntegerIDField, RemoveNullFieldsMixin, TempClientIdMixin from .models import ( AnalysisFramework, - AnalysisFrameworkRole, AnalysisFrameworkMembership, - Widget, - Section, - Filter, + AnalysisFrameworkRole, Exportable, + Filter, + Section, + Widget, ) from .tasks import export_af_to_csv_task -class WidgetSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, serializers.ModelSerializer): +class WidgetSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): """ Widget Model Serializer """ class Meta: model = Widget - fields = ('__all__') + fields = "__all__" # Validations def validate_analysis_framework(self, analysis_framework): - if not analysis_framework.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid Analysis Framework') + if not analysis_framework.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid Analysis Framework") return analysis_framework -class FilterSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, serializers.ModelSerializer): +class FilterSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): """ Filter data Serializer """ class Meta: model = Filter - fields = ('__all__') + fields = "__all__" # Validations def validate_analysis_framework(self, analysis_framework): - if not analysis_framework.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid Analysis Framework') + if not analysis_framework.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid Analysis Framework") return analysis_framework -class ExportableSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, serializers.ModelSerializer): +class ExportableSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): """ Export data Serializer """ class Meta: model = Exportable - fields = ('__all__') + fields = "__all__" # Validations def validate_analysis_framework(self, analysis_framework): - if not analysis_framework.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid Analysis Framework') + if not analysis_framework.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid Analysis Framework") return analysis_framework class SimpleWidgetSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): class Meta: model = Widget - fields = ('id', 'key', 'widget_id', 'title', 'properties', 'order', 'section') + fields = ("id", "key", "widget_id", "title", "properties", "order", "section") -class SimpleFilterSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): +class SimpleFilterSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): class Meta: model = Filter - fields = ('id', 'key', 'widget_key', 'title', - 'properties', 'filter_type') + fields = ("id", "key", "widget_key", "title", "properties", "filter_type") -class SimpleExportableSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): +class SimpleExportableSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): class Meta: model = Exportable - fields = ('id', 'widget_key', 'inline', 'order', 'data') + fields = ("id", "widget_key", "inline", "order", "data") class AnalysisFrameworkRoleSerializer( - RemoveNullFieldsMixin, serializers.ModelSerializer, + RemoveNullFieldsMixin, + serializers.ModelSerializer, ): class Meta: model = AnalysisFrameworkRole - fields = ('__all__') + fields = "__all__" class AnalysisFrameworkMembershipSerializer( - RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer, + RemoveNullFieldsMixin, + DynamicFieldsMixin, + serializers.ModelSerializer, ): - member_details = SimpleUserSerializer(read_only=True, source='member') + member_details = SimpleUserSerializer(read_only=True, source="member") role = serializers.PrimaryKeyRelatedField( required=False, queryset=AnalysisFrameworkRole.objects.all(), ) - added_by_details = SimpleUserSerializer(read_only=True, source='added_by') - role_details = AnalysisFrameworkRoleSerializer(read_only=True, source='role') + added_by_details = SimpleUserSerializer(read_only=True, source="added_by") + role_details = AnalysisFrameworkRoleSerializer(read_only=True, source="role") class Meta: model = AnalysisFrameworkMembership - fields = ('__all__') + fields = "__all__" def create(self, validated_data): - user = self.context['request'].user - framework = validated_data.get('framework') + user = self.context["request"].user + framework = validated_data.get("framework") # NOTE: Default role is different for private and public framework # For public, two sorts of default role, one for non members and one while adding # member to af, which is editor role - default_role = framework.get_or_create_default_role() if framework.is_private else\ - framework.get_or_create_editor_role() + default_role = framework.get_or_create_default_role() if framework.is_private else framework.get_or_create_editor_role() - role = validated_data.get('role') or default_role + role = validated_data.get("role") or default_role if framework is None: - raise serializers.ValidationError('Analysis Framework does not exist') + raise serializers.ValidationError("Analysis Framework does not exist") membership = AnalysisFrameworkMembership.objects.filter( member=user, @@ -156,39 +150,34 @@ def create(self, validated_data): raise exceptions.PermissionDenied() if role.is_private_role and not framework.is_private: - raise exceptions.PermissionDenied( - {'message': 'Public framework cannot have private role'} - ) + raise exceptions.PermissionDenied({"message": "Public framework cannot have private role"}) if not role.is_private_role and framework.is_private: - raise exceptions.PermissionDenied( - {'message': 'Private framework cannot have public role'} - ) + raise exceptions.PermissionDenied({"message": "Private framework cannot have public role"}) - validated_data['role'] = role # Just in case role is not provided, add default role - validated_data['added_by'] = user # make request user to be added_by by default + validated_data["role"] = role # Just in case role is not provided, add default role + validated_data["added_by"] = user # make request user to be added_by by default return super().create(validated_data) -class AnalysisFrameworkSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, - UserResourceSerializer): +class AnalysisFrameworkSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): """ Analysis Framework Model Serializer """ - widgets = SimpleWidgetSerializer(source='widget_set', many=True, required=False) - filters = SimpleFilterSerializer(source='get_active_filters', many=True, read_only=True) - exportables = SimpleExportableSerializer(source='exportable_set', many=True, read_only=True) - questions = FrameworkQuestionSerializer(source='frameworkquestion_set', many=True, required=False, read_only=True) + + widgets = SimpleWidgetSerializer(source="widget_set", many=True, required=False) + filters = SimpleFilterSerializer(source="get_active_filters", many=True, read_only=True) + exportables = SimpleExportableSerializer(source="exportable_set", many=True, read_only=True) + questions = FrameworkQuestionSerializer(source="frameworkquestion_set", many=True, required=False, read_only=True) entries_count = serializers.IntegerField( - source='get_entries_count', + source="get_entries_count", read_only=True, ) is_admin = serializers.SerializerMethodField() users_with_add_permission = serializers.SerializerMethodField() visible_projects = serializers.SerializerMethodField() - all_projects_count = serializers.IntegerField(source='project_set.count', read_only=True) + all_projects_count = serializers.IntegerField(source="project_set.count", read_only=True) project = serializers.IntegerField( write_only=True, @@ -196,17 +185,18 @@ class AnalysisFrameworkSerializer(RemoveNullFieldsMixin, ) role = serializers.SerializerMethodField() - organization_details = SimpleOrganizationSerializer(source='organization', read_only=True) + organization_details = SimpleOrganizationSerializer(source="organization", read_only=True) class Meta: model = AnalysisFramework - fields = ('__all__') + fields = "__all__" def get_visible_projects(self, obj): from project.serializers import SimpleProjectSerializer + user = None - if 'request' in self.context: - user = self.context['request'].user + if "request" in self.context: + user = self.context["request"].user projects = obj.project_set.exclude(models.Q(is_private=True) & ~models.Q(members=user)) return SimpleProjectSerializer(projects, context=self.context, many=True, read_only=True).data @@ -216,18 +206,15 @@ def get_users_with_add_permission(self, obj): """ return SimpleUserSerializer( User.objects.filter( - id__in=obj.analysisframeworkmembership_set.filter(role__can_add_user=True).values('member'), + id__in=obj.analysisframeworkmembership_set.filter(role__can_add_user=True).values("member"), ).all(), context=self.context, many=True, ).data def get_role(self, obj): - user = self.context['request'].user - membership = AnalysisFrameworkMembership.objects.filter( - framework=obj, - member=user - ).first() + user = self.context["request"].user + membership = AnalysisFrameworkMembership.objects.filter(framework=obj, member=user).first() role = None if not membership and not obj.is_private: @@ -243,28 +230,22 @@ def validate_project(self, project): try: project = Project.objects.get(id=project) except Project.DoesNotExist: - raise serializers.ValidationError( - 'Project matching query does not exist' - ) + raise serializers.ValidationError("Project matching query does not exist") - if not project.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid project') + if not project.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid project") return project.id def create(self, validated_data): - project = validated_data.pop('project', None) - private = validated_data.get('is_private', False) + project = validated_data.pop("project", None) + private = validated_data.get("is_private", False) # Check if user has access to private project feature - user = self.context['request'].user - private_access = user.profile.get_accessible_features().filter( - key=Feature.FeatureKey.PRIVATE_PROJECT - ).exists() + user = self.context["request"].user + private_access = user.profile.get_accessible_features().filter(key=Feature.FeatureKey.PRIVATE_PROJECT).exists() if private and not private_access: - raise exceptions.PermissionDenied({ - "message": "You don't have permission to create private framework" - }) + raise exceptions.PermissionDenied({"message": "You don't have permission to create private framework"}) af = super().create(validated_data) @@ -275,32 +256,30 @@ def create(self, validated_data): project.save() owner_role = af.get_or_create_owner_role() - af.add_member(self.context['request'].user, owner_role) + af.add_member(self.context["request"].user, owner_role) return af def update(self, instance, validated_data): - if 'is_private' not in validated_data: + if "is_private" not in validated_data: return super().update(instance, validated_data) - if instance.is_private != validated_data['is_private']: - raise exceptions.PermissionDenied({ - "message": "You don't have permission to change framework's privacy" - }) + if instance.is_private != validated_data["is_private"]: + raise exceptions.PermissionDenied({"message": "You don't have permission to change framework's privacy"}) return super().update(instance, validated_data) def get_is_admin(self, analysis_framework): - return analysis_framework.can_modify(self.context['request'].user) + return analysis_framework.can_modify(self.context["request"].user) # ------------------ Graphql seriazliers ----------------------------------- -def validate_items_limit(items, limit, error_message='Only %d items are allowed. Provided: %d'): +def validate_items_limit(items, limit, error_message="Only %d items are allowed. Provided: %d"): if items: count = len(items) if count > limit: raise serializers.ValidationError(error_message % (limit, count)) -class AfWidgetLimit(): +class AfWidgetLimit: MAX_SECTIONS_ALLOWED = 5 MAX_WIDGETS_ALLOWED_PER_SECTION = 10 MAX_WIDGETS_ALLOWED_IN_SECONDARY_TAGGING = 100 @@ -320,22 +299,29 @@ class WidgetGqlSerializer(TempClientIdMixin, serializers.ModelSerializer): class Meta: model = Widget fields = ( - 'id', 'key', 'widget_id', 'title', 'order', 'width', 'version', - 'properties', 'conditional', - 'client_id', + "id", + "key", + "widget_id", + "title", + "order", + "width", + "version", + "properties", + "conditional", + "client_id", ) @cached_property def framework(self): - framework = self.context['request'].active_af + framework = self.context["request"].active_af # This is a rare case, just to make sure this is validated if self.instance and self.instance.analysis_framework != framework: - raise serializers.ValidationError('Invalid access') + raise serializers.ValidationError("Invalid access") return framework def validate_widget_id(self, widget_type): if widget_type in Widget.DEPRECATED_TYPES: - raise serializers.ValidationError(f'Widget Type {widget_type} is not supported anymore!!') + raise serializers.ValidationError(f"Widget Type {widget_type} is not supported anymore!!") return widget_type def validate_conditional(self, conditional): @@ -345,46 +331,46 @@ def validate_conditional(self, conditional): ) if self.framework is None: raise serializers.ValidationError("Conditional isn't supported in creation of AF.") - parent_widget = conditional['parent_widget'] - conditions = conditional['conditions'] + parent_widget = conditional["parent_widget"] + conditions = conditional["conditions"] if parent_widget.analysis_framework_id != self.framework.id: - raise serializers.ValidationError('Parent widget should be of same AF') + raise serializers.ValidationError("Parent widget should be of same AF") return dict( conditional_parent_widget=parent_widget, conditional_conditions=conditions, ) def validate(self, data): - if 'conditional' in data: - data.update(data.pop('conditional')) + if "conditional" in data: + data.update(data.pop("conditional")) return data # TODO: Using WritableNestedModelSerializer here, let's use this everywhere instead of using custom serializer. class SectionGqlSerializer(TempClientIdMixin, WritableNestedModelSerializer): id = IntegerIDField(required=False) - widgets = WidgetGqlSerializer(source='widget_set', many=True, required=False) + widgets = WidgetGqlSerializer(source="widget_set", many=True, required=False) class Meta: model = Section fields = ( - 'id', 'title', 'order', 'tooltip', - 'widgets', - 'client_id', + "id", + "title", + "order", + "tooltip", + "widgets", + "client_id", ) # NOTE: Overriding perform_nested_delete_or_update to have custom behaviour for section->widgets on delete def perform_nested_delete_or_update(self, pks_to_delete, model_class, instance, related_field, field_source): if model_class != Widget: - return super().perform_nested_delete_or_update( - pks_to_delete, model_class, instance, related_field, field_source - ) + return super().perform_nested_delete_or_update(pks_to_delete, model_class, instance, related_field, field_source) # Ignore on_delete, just delete the widgets if removed from Section instead of # just removing section from widget which is the default behaviour for WritableNestedModelSerializer # https://github.com/beda-software/drf-writable-nested/blob/master/drf_writable_nested/mixins.py#L302-L308 qs = Widget.objects.filter( - section=self.instance, # NOTE: Adding this additional filter just to make sure - pk__in=pks_to_delete + section=self.instance, pk__in=pks_to_delete # NOTE: Adding this additional filter just to make sure ) qs.delete() @@ -399,9 +385,7 @@ def _get_prefetch_related_instances_qs(self, qs): def validate_widgets(self, items): # Check max limit for widgets validate_items_limit( - items, - AfWidgetLimit.MAX_WIDGETS_ALLOWED_PER_SECTION, - error_message='Only %d widgets are allowed. Provided: %d' + items, AfWidgetLimit.MAX_WIDGETS_ALLOWED_PER_SECTION, error_message="Only %d widgets are allowed. Provided: %d" ) return items @@ -426,12 +410,12 @@ def _validate_widget_with_widget_type(data, widget_type, many=False): if many: return [] if many: - ids = [item['pk'] for item in data] + ids = [item["pk"] for item in data] widgets = list(Widget.objects.filter(pk__in=ids)) widgets_type = list(set([widget.widget_id for widget in widgets])) if widgets_type and widgets_type != [widget_type]: raise serializers.ValidationError( - f'Different widget type was provided. Required: {widget_type} Provided: {widgets_type}', + f"Different widget type was provided. Required: {widget_type} Provided: {widgets_type}", ) return [ # Only return available widgets. Make sure to follow AnalysisFrameworkPropertiesStatsConfigIdGqlSerializer @@ -439,7 +423,7 @@ def _validate_widget_with_widget_type(data, widget_type, many=False): for widget in widgets ] # For single widget - pk = data['pk'] + pk = data["pk"] try: widget = Widget.objects.get(pk=pk) except Widget.DoesNotExist: @@ -448,7 +432,7 @@ def _validate_widget_with_widget_type(data, widget_type, many=False): ) if widget.widget_id != widget_type: raise serializers.ValidationError( - f'Different widget type was provided. Required: {widget_type} Provided: {widget.widget_id}', + f"Different widget type was provided. Required: {widget_type} Provided: {widget.widget_id}", ) return data @@ -479,7 +463,7 @@ class AnalysisFrameworkPropertiesGqlSerializer(serializers.Serializer): class AnalysisFrameworkGqlSerializer(UserResourceSerializer): - primary_tagging = SectionGqlSerializer(source='section_set', many=True, required=False) + primary_tagging = SectionGqlSerializer(source="section_set", many=True, required=False) secondary_tagging = WidgetGqlSerializer(many=True, write_only=False, required=False) prediction_tags_mapping = PredictionTagAnalysisFrameworkMapSerializer(many=True, write_only=False, required=False) properties = AnalysisFrameworkPropertiesGqlSerializer(required=False, allow_null=True) @@ -488,10 +472,20 @@ class AnalysisFrameworkGqlSerializer(UserResourceSerializer): class Meta: model = AnalysisFramework fields = ( - 'title', 'description', 'is_private', 'properties', 'organization', 'preview_image', - 'created_at', 'created_by', 'modified_at', 'modified_by', - 'primary_tagging', 'secondary_tagging', - 'prediction_tags_mapping', 'assisted_tagging_enabled', + "title", + "description", + "is_private", + "properties", + "organization", + "preview_image", + "created_at", + "created_by", + "modified_at", + "modified_by", + "primary_tagging", + "secondary_tagging", + "prediction_tags_mapping", + "assisted_tagging_enabled", ) # NOTE: This is a custom function (apps/user_resource/serializers.py::UserResourceSerializer) @@ -506,23 +500,17 @@ def validate_is_private(self, value): # Changing AF Privacy is not allowed (Existing AF) if self.instance: if self.instance.is_private != value: - raise exceptions.PermissionDenied({ - "is_private": "You don't have permission to change framework's privacy" - }) + raise exceptions.PermissionDenied({"is_private": "You don't have permission to change framework's privacy"}) return value # Requires feature access for Private project (New AF) - if value and not self.context['request'].user.have_feature_access(Feature.FeatureKey.PRIVATE_PROJECT): - raise exceptions.PermissionDenied({ - "is_private": "You don't have permission to create/update private framework" - }) + if value and not self.context["request"].user.have_feature_access(Feature.FeatureKey.PRIVATE_PROJECT): + raise exceptions.PermissionDenied({"is_private": "You don't have permission to create/update private framework"}) return value def validate_primary_tagging(self, items): # Check max limit for sections validate_items_limit( - items, - AfWidgetLimit.MAX_SECTIONS_ALLOWED, - error_message='Only %d sections are allowed. Provided: %d' + items, AfWidgetLimit.MAX_SECTIONS_ALLOWED, error_message="Only %d sections are allowed. Provided: %d" ) return items @@ -531,7 +519,7 @@ def validate_secondary_tagging(self, items): validate_items_limit( items, AfWidgetLimit.MAX_WIDGETS_ALLOWED_IN_SECONDARY_TAGGING, - error_message='Only %d widgets are allowed. Provided: %d' + error_message="Only %d widgets are allowed. Provided: %d", ) return items @@ -541,32 +529,25 @@ def validate_prediction_tags_mapping(self, prediction_tags_mapping): raise serializers.ValidationError("Can't create prediction tag mapping for new framework. Save first!") if not prediction_tags_mapping: return prediction_tags_mapping - widget_qs = Widget.objects.filter( - id__in=[ - _map['widget'].pk - for _map in prediction_tags_mapping - ] - ) - if list(widget_qs.values_list('analysis_framework', flat=True).distinct()) != [framework.pk]: - raise serializers.ValidationError('Found widgets from another Analysis Framework') + widget_qs = Widget.objects.filter(id__in=[_map["widget"].pk for _map in prediction_tags_mapping]) + if list(widget_qs.values_list("analysis_framework", flat=True).distinct()) != [framework.pk]: + raise serializers.ValidationError("Found widgets from another Analysis Framework") return prediction_tags_mapping def _delete_old_secondary_taggings(self, af, secondary_tagging): - current_ids = [ - widget_data['id'] for widget_data in secondary_tagging - if 'id' in widget_data - ] - qs_to_delete = Widget.objects\ - .filter( - analysis_framework=af, - section__isnull=True, # NOTE: section are null for secondary taggings - ).exclude(pk__in=current_ids) # Exclude current provided widgets + current_ids = [widget_data["id"] for widget_data in secondary_tagging if "id" in widget_data] + qs_to_delete = Widget.objects.filter( + analysis_framework=af, + section__isnull=True, # NOTE: section are null for secondary taggings + ).exclude( + pk__in=current_ids + ) # Exclude current provided widgets qs_to_delete.delete() def _save_secondary_taggings(self, af, secondary_tagging): # Create secondary tagging widgets (Primary/Section widgets are created using WritableNestedModelSerializer) for widget_data in secondary_tagging: - id = widget_data.get('id') + id = widget_data.get("id") widget = None if id: widget = Widget.objects.filter(analysis_framework=af, pk=id).first() @@ -581,26 +562,23 @@ def _save_secondary_taggings(self, af, secondary_tagging): serializer.save(analysis_framework=af) def _delete_old_prediction_tags_mapping(self, af, prediction_tags_mapping): - current_ids = [ - mapping['id'] - for mapping in prediction_tags_mapping - if 'id' in mapping - ] - qs_to_delete = PredictionTagAnalysisFrameworkWidgetMapping.objects\ - .filter( - widget__analysis_framework=af, - ).exclude(pk__in=current_ids) # Exclude current provided widgets + current_ids = [mapping["id"] for mapping in prediction_tags_mapping if "id" in mapping] + qs_to_delete = PredictionTagAnalysisFrameworkWidgetMapping.objects.filter( + widget__analysis_framework=af, + ).exclude( + pk__in=current_ids + ) # Exclude current provided widgets qs_to_delete.delete() def _save_prediction_tags_mapping(self, af, prediction_tags_mapping): # Create secondary tagging widgets (Primary/Section widgets are created using WritableNestedModelSerializer) for prediction_tag_mapping in prediction_tags_mapping: - id = prediction_tag_mapping.get('id') + id = prediction_tag_mapping.get("id") mapping = None if id: mapping = PredictionTagAnalysisFrameworkWidgetMapping.objects.filter( widget__analysis_framework=af, - widget=prediction_tag_mapping['widget'], + widget=prediction_tag_mapping["widget"], pk=id, ).first() serializer = PredictionTagAnalysisFrameworkMapSerializer( @@ -614,15 +592,13 @@ def _save_prediction_tags_mapping(self, af, prediction_tags_mapping): serializer.save() def _post_save(self, instance): - transaction.on_commit( - lambda: export_af_to_csv_task.delay(instance.pk) - ) + transaction.on_commit(lambda: export_af_to_csv_task.delay(instance.pk)) def create(self, validated_data): - validated_data.pop('secondary_tagging', None) - validated_data.pop('prediction_tags_mapping', None) - secondary_tagging = self.initial_data.get('secondary_tagging', None) - prediction_tags_mapping = self.initial_data.get('prediction_tags_mapping', None) + validated_data.pop("secondary_tagging", None) + validated_data.pop("prediction_tags_mapping", None) + secondary_tagging = self.initial_data.get("secondary_tagging", None) + prediction_tags_mapping = self.initial_data.get("prediction_tags_mapping", None) # Create AF instance = super().create(validated_data) if prediction_tags_mapping: @@ -632,17 +608,17 @@ def create(self, validated_data): # TODO: Check if there are any recursive conditionals # Create a owner role owner_role = instance.get_or_create_owner_role() - instance.add_member(self.context['request'].user, owner_role) + instance.add_member(self.context["request"].user, owner_role) # NOTE: Set current_user_role value. (get_current_user_role) instance.current_user_role = owner_role.type self._post_save(instance) return instance def update(self, instance, validated_data): - validated_data.pop('secondary_tagging', None) - validated_data.pop('prediction_tags_mapping', None) - secondary_tagging = self.initial_data.get('secondary_tagging', None) - prediction_tags_mapping = self.initial_data.get('prediction_tags_mapping', None) + validated_data.pop("secondary_tagging", None) + validated_data.pop("prediction_tags_mapping", None) + secondary_tagging = self.initial_data.get("secondary_tagging", None) + prediction_tags_mapping = self.initial_data.get("prediction_tags_mapping", None) # Update AF instance = super().update(instance, validated_data) # Update secondary_tagging @@ -656,7 +632,7 @@ def update(self, instance, validated_data): if instance.created_by_id and not instance.members.filter(id=instance.created_by_id).exists(): owner_role = instance.get_or_create_owner_role() instance.add_member(instance.created_by, owner_role) - ProjectChangeManager.log_framework_update(instance.pk, self.context['request'].user) + ProjectChangeManager.log_framework_update(instance.pk, self.context["request"].user) self._post_save(instance) return instance @@ -667,17 +643,14 @@ class AnalysisFrameworkMembershipGqlSerializer(TempClientIdMixin, serializers.Mo class Meta: model = AnalysisFrameworkMembership - fields = ( - 'id', 'member', 'role', - 'client_id' - ) + fields = ("id", "member", "role", "client_id") @cached_property def framework(self): - framework = self.context['request'].active_af + framework = self.context["request"].active_af # This is a rare case, just to make sure this is validated if self.instance and self.instance.framework != framework: - raise serializers.ValidationError('Invalid access') + raise serializers.ValidationError("Invalid access") return framework def _get_default_role(self): @@ -692,20 +665,20 @@ def _get_default_role(self): def validate_member(self, member): current_members = AnalysisFrameworkMembership.objects.filter(framework=self.framework, member=member) if current_members.exclude(pk=self.instance and self.instance.pk).exists(): - raise serializers.ValidationError('User is already a member!') + raise serializers.ValidationError("User is already a member!") return member def validate_role(self, role): if role.is_private_role and not self.framework.is_private: - raise serializers.ValidationError('Public framework cannot have private role') + raise serializers.ValidationError("Public framework cannot have private role") if not role.is_private_role and self.framework.is_private: - raise serializers.ValidationError('Private framework cannot have public role') + raise serializers.ValidationError("Private framework cannot have public role") return role def create(self, validated_data): # use default role if not provided on creation. - validated_data['role'] = validated_data.get('role', self._get_default_role()) + validated_data["role"] = validated_data.get("role", self._get_default_role()) # make request user to be added_by by default - validated_data['framework'] = self.framework - validated_data['added_by'] = self.context['request'].user + validated_data["framework"] = self.framework + validated_data["added_by"] = self.context["request"].user return super().create(validated_data) diff --git a/apps/analysis_framework/tasks.py b/apps/analysis_framework/tasks.py index 7d50645697..d70956c705 100644 --- a/apps/analysis_framework/tasks.py +++ b/apps/analysis_framework/tasks.py @@ -1,31 +1,31 @@ import logging +from analysis_framework.export import export_af_to_csv from celery import shared_task from django.utils import timezone -from utils.common import redis_lock, get_temp_file +from utils.common import get_temp_file, redis_lock from utils.files import generate_file_for_upload from .models import AnalysisFramework -from analysis_framework.export import export_af_to_csv logger = logging.getLogger(__name__) @shared_task -@redis_lock('af_export__{0}') +@redis_lock("af_export__{0}") def export_af_to_csv_task(af_id): try: af = AnalysisFramework.objects.get(id=af_id) - with get_temp_file(suffix='.csv', mode='w+') as file: + with get_temp_file(suffix=".csv", mode="w+") as file: export_af_to_csv(af, file) - time_str = timezone.now().strftime('%Y-%m-%d%z') + time_str = timezone.now().strftime("%Y-%m-%d%z") file.seek(0) af.export.save( - f'AF_Export_{af.id}_{time_str}.csv', + f"AF_Export_{af.id}_{time_str}.csv", generate_file_for_upload(file), ) except Exception: - logger.error(f'Failed to export AF: {af_id}', exc_info=True) + logger.error(f"Failed to export AF: {af_id}", exc_info=True) return False return True diff --git a/apps/analysis_framework/tests/test_apis.py b/apps/analysis_framework/tests/test_apis.py index 0d42d84bbf..4a1a5cf00c 100644 --- a/apps/analysis_framework/tests/test_apis.py +++ b/apps/analysis_framework/tests/test_apis.py @@ -1,16 +1,12 @@ import os +from analysis_framework.models import AnalysisFramework, AnalysisFrameworkMembership from django.conf import settings - -from deep.tests import TestCase -from analysis_framework.models import ( - AnalysisFramework, - AnalysisFrameworkMembership, -) - +from organization.models import Organization from project.models import Project from user.models import User -from organization.models import Organization + +from deep.tests import TestCase class AnalysisFrameworkTests(TestCase): @@ -19,22 +15,22 @@ def test_get_private_analysis_framework_not_member(self): private_framework = self.create(AnalysisFramework, is_private=True) public_framework = self.create(AnalysisFramework, is_private=False) - url = '/api/v1/analysis-frameworks/' + url = "/api/v1/analysis-frameworks/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 1) - self.assertEqual(response.data['results'][0]['id'], public_framework.id) + self.assertEqual(len(response.data["results"]), 1) + self.assertEqual(response.data["results"][0]["id"], public_framework.id) # Now get a particular private framework - url = f'/api/v1/analysis-frameworks/{private_framework.id}/' + url = f"/api/v1/analysis-frameworks/{private_framework.id}/" self.authenticate() response = self.client.get(url) self.assert_404(response) # Now get a particular public framework, should be 200 - url = f'/api/v1/analysis-frameworks/{public_framework.id}/' + url = f"/api/v1/analysis-frameworks/{public_framework.id}/" self.authenticate() response = self.client.get(url) self.assert_200(response) @@ -51,24 +47,24 @@ def test_get_private_analysis_framework_not_member_but_same_project(self): # Add self.user to the project, but not to framework project.add_member(self.user) - url = '/api/v1/analysis-frameworks/' + url = "/api/v1/analysis-frameworks/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 2) - framework_ids = [x['id'] for x in response.data['results']] + self.assertEqual(len(response.data["results"]), 2) + framework_ids = [x["id"] for x in response.data["results"]] assert private_framework.id in framework_ids assert public_framework.id in framework_ids # Now get a particular private framework - url = f'/api/v1/analysis-frameworks/{private_framework.id}/' + url = f"/api/v1/analysis-frameworks/{private_framework.id}/" self.authenticate() response = self.client.get(url) self.assert_200(response) # Now get a particular public framework, should be 200 - url = f'/api/v1/analysis-frameworks/{public_framework.id}/' + url = f"/api/v1/analysis-frameworks/{public_framework.id}/" self.authenticate() response = self.client.get(url) self.assert_200(response) @@ -84,14 +80,14 @@ def test_get_related_to_me_frameworks(self): public_af.add_member(self.user) public_af2 = self.create(AnalysisFramework, is_private=False) # noqa - url = '/api/v1/analysis-frameworks/?relatedToMe=true' + url = "/api/v1/analysis-frameworks/?relatedToMe=true" self.authenticate() resp = self.client.get(url) self.assert_200(resp) - afs = resp.data['results'] + afs = resp.data["results"] assert len(afs) == 2, "Two frameworks are related to user" - af_ids = [x['id'] for x in afs] + af_ids = [x["id"] for x in afs] assert private_af2.id in af_ids assert public_af.id in af_ids @@ -100,65 +96,62 @@ def test_get_private_analysis_framework_by_member(self): private_framework = self.create(AnalysisFramework, is_private=True) public_framework = self.create(AnalysisFramework, is_private=False) - private_framework.add_member( - self.user, - private_framework.get_or_create_owner_role() - ) + private_framework.add_member(self.user, private_framework.get_or_create_owner_role()) public_framework.add_member(self.user) - url = '/api/v1/analysis-frameworks/' + url = "/api/v1/analysis-frameworks/" self.authenticate() response = self.client.get(url) - self.assertEqual(len(response.data['results']), 2) - for framework in response.data['results']: - assert 'role' in framework - assert isinstance(framework['role'], dict) + self.assertEqual(len(response.data["results"]), 2) + for framework in response.data["results"]: + assert "role" in framework + assert isinstance(framework["role"], dict) # Now get a particular private framework - url = f'/api/v1/analysis-frameworks/{private_framework.id}/' + url = f"/api/v1/analysis-frameworks/{private_framework.id}/" self.authenticate() response = self.client.get(url) self.assert_200(response) - assert 'role' in response.data - assert isinstance(response.data['role'], dict) - self.check_owner_roles_present(private_framework, response.data['role']) + assert "role" in response.data + assert isinstance(response.data["role"], dict) + self.check_owner_roles_present(private_framework, response.data["role"]) # Now get a particular public framework, should be 200 - url = f'/api/v1/analysis-frameworks/{public_framework.id}/' + url = f"/api/v1/analysis-frameworks/{public_framework.id}/" self.authenticate() response = self.client.get(url) self.assert_200(response) - assert 'role' in response.data - assert isinstance(response.data['role'], dict) - self.check_default_roles_present(public_framework, response.data['role']) + assert "role" in response.data + assert isinstance(response.data["role"], dict) + self.check_default_roles_present(public_framework, response.data["role"]) def test_get_public_framework_with_roles(self): public_framework = self.create(AnalysisFramework, is_private=False) - url = f'/api/v1/analysis-frameworks/{public_framework.id}/' + url = f"/api/v1/analysis-frameworks/{public_framework.id}/" self.authenticate() response = self.client.get(url) self.assert_200(response) - assert 'role' in response.data - assert isinstance(response.data['role'], dict) - self.check_default_roles_present(public_framework, response.data['role']) + assert "role" in response.data + assert isinstance(response.data["role"], dict) + self.check_default_roles_present(public_framework, response.data["role"]) def test_get_memberships(self): framework = self.create(AnalysisFramework) framework.add_member(self.user) - url = f'/api/v1/analysis-frameworks/{framework.id}/memberships/' + url = f"/api/v1/analysis-frameworks/{framework.id}/memberships/" self.authenticate() resp = self.client.get(url) self.assert_200(resp) - data = resp.data['results'] + data = resp.data["results"] assert len(data) == 1 - assert isinstance(data[0]['member_details'], dict), "Check if member field is expanded" - assert data[0]['member'] == self.user.id - assert 'member_details' in data[0] - assert data[0]['framework'] == framework.id + assert isinstance(data[0]["member_details"], dict), "Check if member field is expanded" + assert data[0]["member"] == self.user.id + assert "member_details" in data[0] + assert data[0]["framework"] == framework.id def test_get_more_memberships_data(self): user1 = self.create_user() @@ -166,139 +159,127 @@ def test_get_more_memberships_data(self): user3 = self.create_user() user4 = self.create_user() framework = self.create(AnalysisFramework) - framework.add_member( - user=user1, - role=framework.get_or_create_owner_role(), - added_by=user2 - ) + framework.add_member(user=user1, role=framework.get_or_create_owner_role(), added_by=user2) - url = f'/api/v1/analysis-frameworks/{framework.id}/memberships/' + url = f"/api/v1/analysis-frameworks/{framework.id}/memberships/" self.authenticate() response = self.client.get(url) self.assert_200(response) - data = response.data['results'] - assert 'added_by_details' in data[0] - self.assertEqual(data[0]['added_by_details']['id'], user2.id) - assert 'role_details' in data[0] + data = response.data["results"] + assert "added_by_details" in data[0] + self.assertEqual(data[0]["added_by_details"]["id"], user2.id) + assert "role_details" in data[0] # test for the pagination support in memberships framework.add_member(user2) framework.add_member(user3) framework.add_member(user4) - url = f'/api/v1/analysis-frameworks/{framework.id}/memberships/?limit=2' + url = f"/api/v1/analysis-frameworks/{framework.id}/memberships/?limit=2" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 2) + self.assertEqual(len(response.data["results"]), 2) def test_create_analysis_framework(self): project = self.create(Project, role=self.admin_role) organization = self.create(Organization) - preview_image_sample = os.path.join(settings.BASE_DIR, 'apps/static/image/drop-icon.png') + preview_image_sample = os.path.join(settings.BASE_DIR, "apps/static/image/drop-icon.png") - url = '/api/v1/analysis-frameworks/' + url = "/api/v1/analysis-frameworks/" data = { - 'title': 'Test AnalysisFramework Title', - 'project': project.id, - 'organization': organization.id, - 'preview_image': open(preview_image_sample, 'rb'), + "title": "Test AnalysisFramework Title", + "project": project.id, + "organization": organization.id, + "preview_image": open(preview_image_sample, "rb"), } self.authenticate() - response = self.client.post(url, data, format='multipart') + response = self.client.post(url, data, format="multipart") project.refresh_from_db() self.assert_201(response) - self.assertEqual(response.data['title'], data['title']) - self.assertEqual(response.data['organization'], data['organization']) - self.assertEqual(response.data['organization_details']['id'], organization.id) - self.assertIsNotNone(response.data['preview_image']) - self.assertEqual(project.analysis_framework_id, response.data['id']) + self.assertEqual(response.data["title"], data["title"]) + self.assertEqual(response.data["organization"], data["organization"]) + self.assertEqual(response.data["organization_details"]["id"], organization.id) + self.assertIsNotNone(response.data["preview_image"]) + self.assertEqual(project.analysis_framework_id, response.data["id"]) # test Group Membership created or not - assert AnalysisFrameworkMembership.objects.filter( - framework_id=response.data['id'], - member=self.user, - role=project.analysis_framework.get_or_create_owner_role(), - ).first() is not None, "Membership Should be created" + assert ( + AnalysisFrameworkMembership.objects.filter( + framework_id=response.data["id"], + member=self.user, + role=project.analysis_framework.get_or_create_owner_role(), + ).first() + is not None + ), "Membership Should be created" def test_clone_analysis_framework_without_name(self): analysis_framework = self.create(AnalysisFramework) - project = self.create( - Project, analysis_framework=analysis_framework, - role=self.admin_role - ) - - url = '/api/v1/clone-analysis-framework/{}/'.format( - analysis_framework.id - ) + project = self.create(Project, analysis_framework=analysis_framework, role=self.admin_role) + + url = "/api/v1/clone-analysis-framework/{}/".format(analysis_framework.id) data = { - 'project': project.id, + "project": project.id, } self.authenticate() response = self.client.post(url, data) self.assert_400(response) - assert 'title' in response.data['errors'] + assert "title" in response.data["errors"] def test_clone_analysis_framework(self): """This is relevant only to public frameworks""" analysis_framework = self.create(AnalysisFramework, is_private=False) - project = self.create( - Project, analysis_framework=analysis_framework, - role=self.admin_role - ) + project = self.create(Project, analysis_framework=analysis_framework, role=self.admin_role) # Add self.user as member to analysis framework, to check if owner membership created or not default_membership, _ = analysis_framework.add_member(self.user) # Add owner user, but this should not be in the cloned framework user = self.create(User) owner_membership, _ = analysis_framework.add_member(user, analysis_framework.get_or_create_owner_role()) - url = '/api/v1/clone-analysis-framework/{}/'.format( - analysis_framework.id - ) - cloned_title = 'Cloned AF' + url = "/api/v1/clone-analysis-framework/{}/".format(analysis_framework.id) + cloned_title = "Cloned AF" data = { - 'project': project.id, - 'title': cloned_title, - 'description': 'New Description', + "project": project.id, + "title": cloned_title, + "description": "New Description", } self.authenticate() response = self.client.post(url, data) self.assert_201(response) - self.assertNotEqual(response.data['id'], analysis_framework.id) - self.assertEqual( - response.data['title'], - cloned_title) + self.assertNotEqual(response.data["id"], analysis_framework.id) + self.assertEqual(response.data["title"], cloned_title) project = Project.objects.get(id=project.id) new_af = project.analysis_framework self.assertNotEqual(new_af.id, analysis_framework.id) - self.assertEqual(project.analysis_framework.id, response.data['id']) + self.assertEqual(project.analysis_framework.id, response.data["id"]) # Check if description updated - assert new_af.description == data['description'], "Description should be updated" - assert new_af.title == data['title'], "Title should be updated" + assert new_af.description == data["description"], "Description should be updated" + assert new_af.title == data["title"], "Title should be updated" # Test permissions cloned # Only the requester should be the owner of the new framework assert new_af.members.all().count() == 1, "The cloned framework should have only one owner" assert AnalysisFrameworkMembership.objects.filter( - framework=new_af, role=owner_membership.role, + framework=new_af, + role=owner_membership.role, member=self.user, ).exists() def test_create_private_framework_unauthorized(self): project = self.create(Project, role=self.admin_role) - url = '/api/v1/analysis-frameworks/' + url = "/api/v1/analysis-frameworks/" data = { - 'title': 'Test AnalysisFramework Title', - 'project': project.id, - 'is_private': True, + "title": "Test AnalysisFramework Title", + "project": project.id, + "is_private": True, } self.authenticate() @@ -309,14 +290,8 @@ def test_change_is_private_field(self): """Even the owner should be unable to change privacy""" private_framework = self.create(AnalysisFramework, is_private=True) public_framework = self.create(AnalysisFramework, is_private=False) - private_framework.add_member( - self.user, - private_framework.get_or_create_owner_role() - ) - public_framework.add_member( - self.user, - public_framework.get_or_create_owner_role() - ) + private_framework.add_member(self.user, private_framework.get_or_create_owner_role()) + public_framework.add_member(self.user, public_framework.get_or_create_owner_role()) self._change_framework_privacy(public_framework, 403) self._change_framework_privacy(private_framework, 403) @@ -325,22 +300,22 @@ def test_change_other_fields(self): framework = self.create(AnalysisFramework) framework.add_member(self.user, framework.get_or_create_owner_role()) - url = f'/api/v1/analysis-frameworks/{framework.id}/' + url = f"/api/v1/analysis-frameworks/{framework.id}/" put_data = { - 'title': framework.title[:-12] + '(Modified)', - 'is_private': framework.is_private, + "title": framework.title[:-12] + "(Modified)", + "is_private": framework.is_private, } self.authenticate() response = self.client.put(url, put_data) self.assert_200(response) def test_get_membersips(self): - url = '/api/v1/framework-memberships/' + url = "/api/v1/framework-memberships/" self.authenticate() response = self.client.get(url) self.assert_200(response) - for membership in response.data['results']: - self.assertEqual(membership['member'], self.user.id) + for membership in response.data["results"]: + self.assertEqual(membership["member"], self.user.id) def test_post_framework_memberships(self): user = self.create_user() @@ -348,69 +323,60 @@ def test_post_framework_memberships(self): framework = self.create(AnalysisFramework) framework.add_member(user, framework.get_or_create_owner_role()) - data = { - 'role': framework.get_or_create_owner_role().id, - 'member': user2.id, - 'framework': framework.id - } + data = {"role": framework.get_or_create_owner_role().id, "member": user2.id, "framework": framework.id} self.authenticate(user) - url = '/api/v1/framework-memberships/' + url = "/api/v1/framework-memberships/" response = self.client.post(url, data) self.assert_201(response) - self.assertEqual(response.data['added_by'], user.id) # set request user to be added_by + self.assertEqual(response.data["added_by"], user.id) # set request user to be added_by def test_add_roles_to_public_framework_non_member(self): framework = self.create(AnalysisFramework, is_private=False) add_member_data = { - 'framework': framework.id, - 'member': self.user.id, - 'role': framework.get_or_create_editor_role().id, # Just an arbritrary role + "framework": framework.id, + "member": self.user.id, + "role": framework.get_or_create_editor_role().id, # Just an arbritrary role } self.authenticate() - url = '/api/v1/framework-memberships/' + url = "/api/v1/framework-memberships/" response = self.client.post(url, add_member_data) self.assert_403(response) def test_project_analysis_framework(self): analysis_framework = self.create(AnalysisFramework) - project = self.create( - Project, analysis_framework=analysis_framework, - role=self.admin_role - ) + project = self.create(Project, analysis_framework=analysis_framework, role=self.admin_role) - url = '/api/v1/projects/{}/analysis-framework/'.format( - project.id - ) + url = "/api/v1/projects/{}/analysis-framework/".format(project.id) self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['id'], analysis_framework.id) - self.assertEqual(response.data['title'], analysis_framework.title) + self.assertEqual(response.data["id"], analysis_framework.id) + self.assertEqual(response.data["title"], analysis_framework.title) def test_filter_analysis_framework(self): - url = '/api/v1/analysis-frameworks/' + url = "/api/v1/analysis-frameworks/" self.authenticate() - response = self.client.get(f'{url}?activity=active&relatedToMe=True') + response = self.client.get(f"{url}?activity=active&relatedToMe=True") self.assert_200(response) def test_search_users_excluding_framework_members(self): - user1 = self.create(User, email='testuser1@tc.com') - user2 = self.create(User, email='testuser2@tc.com') - user3 = self.create(User, email='testuser3@tc.com') + user1 = self.create(User, email="testuser1@tc.com") + user2 = self.create(User, email="testuser2@tc.com") + user3 = self.create(User, email="testuser3@tc.com") framework = self.create(AnalysisFramework) framework.add_member(user1) - url = f'/api/v1/users/?members_exclude_framework={framework.id}&search=test' + url = f"/api/v1/users/?members_exclude_framework={framework.id}&search=test" self.authenticate() resp = self.client.get(url) self.assert_200(resp) data = resp.data - ids = [x['id'] for x in data['results']] + ids = [x["id"] for x in data["results"]] assert user1.id not in ids assert user2.id in ids assert user3.id in ids @@ -421,14 +387,14 @@ def test_af_project_api(self): self.create_project(is_private=False, analysis_framework=framework, role=None) private_project = self.create_project(is_private=True, analysis_framework=framework, role=None) - url = f'/api/v1/analysis-frameworks/{framework.id}/?fields=all_projects_count,visible_projects' + url = f"/api/v1/analysis-frameworks/{framework.id}/?fields=all_projects_count,visible_projects" self.authenticate(self.user) response = self.client.get(url) rjson = response.json() self.assert_200(response) - self.assertEqual(rjson['allProjectsCount'], 2) - self.assertEqual(len(rjson['visibleProjects']), 1) + self.assertEqual(rjson["allProjectsCount"], 2) + self.assertEqual(len(rjson["visibleProjects"]), 1) # Now add user to the private project private_project.add_member(self.user) @@ -436,26 +402,26 @@ def test_af_project_api(self): response = self.client.get(url) rjson = response.json() self.assert_200(response) - self.assertEqual(rjson['allProjectsCount'], 2) - self.assertEqual(len(rjson['visibleProjects']), 2) + self.assertEqual(rjson["allProjectsCount"], 2) + self.assertEqual(len(rjson["visibleProjects"]), 2) def check_owner_roles_present(self, framework, permissions): owner_permissions = framework.get_owner_permissions() for perm, val in owner_permissions.items(): - assert val == permissions[perm], f'Should match for {perm}' + assert val == permissions[perm], f"Should match for {perm}" def check_default_roles_present(self, framework, permissions): default_permissions = framework.get_default_permissions() for perm, val in default_permissions.items(): - assert val == permissions[perm], f'Should match for {perm}' + assert val == permissions[perm], f"Should match for {perm}" def _change_framework_privacy(self, framework, status=403, user=None): - url = f'/api/v1/analysis-frameworks/{framework.id}/' + url = f"/api/v1/analysis-frameworks/{framework.id}/" changed_privacy = not framework.is_private put_data = { - 'title': framework.title, - 'is_private': changed_privacy, + "title": framework.title, + "is_private": changed_privacy, # Other fields we don't care } self.authenticate(user) @@ -463,6 +429,6 @@ def _change_framework_privacy(self, framework, status=403, user=None): self.assertEqual(response.status_code, status) # Try patching, should give 403 as well - patch_data = {'is_private': changed_privacy} + patch_data = {"is_private": changed_privacy} response = self.client.patch(url, patch_data) self.assertEqual(response.status_code, status) diff --git a/apps/analysis_framework/tests/test_filters.py b/apps/analysis_framework/tests/test_filters.py index 1e34fc32f1..458ae4c669 100644 --- a/apps/analysis_framework/tests/test_filters.py +++ b/apps/analysis_framework/tests/test_filters.py @@ -1,12 +1,16 @@ from datetime import timedelta from unittest.mock import patch -from utils.graphene.tests import GraphQLTestCase +from analysis_framework.factories import ( + AnalysisFrameworkFactory, + AnalysisFrameworkTagFactory, +) from analysis_framework.filter_set import AnalysisFrameworkGqFilterSet -from analysis_framework.factories import AnalysisFrameworkFactory, AnalysisFrameworkTagFactory from entry.factories import EntryFactory from lead.factories import LeadFactory +from utils.graphene.tests import GraphQLTestCase + class TestAnalysisFrameworkFilter(GraphQLTestCase): def setUp(self) -> None: @@ -14,19 +18,14 @@ def setUp(self) -> None: self.filter_class = AnalysisFrameworkGqFilterSet def test_search_filter(self): - AnalysisFrameworkFactory.create(title='one') - af2 = AnalysisFrameworkFactory.create(title='two') - af3 = AnalysisFrameworkFactory.create(title='twoo') - obtained = self.filter_class(data=dict( - search='tw' - )).qs + AnalysisFrameworkFactory.create(title="one") + af2 = AnalysisFrameworkFactory.create(title="two") + af3 = AnalysisFrameworkFactory.create(title="twoo") + obtained = self.filter_class(data=dict(search="tw")).qs expected = [af2, af3] - self.assertQuerySetIdEqual( - expected, - obtained - ) + self.assertQuerySetIdEqual(expected, obtained) - @patch('django.utils.timezone.now') + @patch("django.utils.timezone.now") def test_filter_recently_used(self, now_patch): now = self.PATCHER_NOW_VALUE now_patch.side_effect = lambda: now - timedelta(days=90) @@ -49,26 +48,19 @@ def test_filter_recently_used(self, now_patch): ) # Make sure we only get af1, af2 now_patch.side_effect = lambda: now - obtained = set(list( - self - .filter_class(data={'recently_used': True}) - .qs - .values_list('id', flat=True) - )) + obtained = set(list(self.filter_class(data={"recently_used": True}).qs.values_list("id", flat=True))) expected = set([af1.pk, af2.pk]) self.assertEqual(obtained, expected) def test_tags_filter(self): tag1, tag2, _ = AnalysisFrameworkTagFactory.create_batch(3) - af1 = AnalysisFrameworkFactory.create(title='one', tags=[tag1]) - af2 = AnalysisFrameworkFactory.create(title='two', tags=[tag1, tag2]) - AnalysisFrameworkFactory.create(title='twoo') + af1 = AnalysisFrameworkFactory.create(title="one", tags=[tag1]) + af2 = AnalysisFrameworkFactory.create(title="two", tags=[tag1, tag2]) + AnalysisFrameworkFactory.create(title="twoo") for tags, expected in [ ([tag1, tag2], [af1, af2]), ([tag1], [af1, af2]), ([tag2], [af2]), ]: - obtained = self.filter_class(data=dict( - tags=[tag.id for tag in tags] - )).qs + obtained = self.filter_class(data=dict(tags=[tag.id for tag in tags])).qs self.assertQuerySetIdEqual(expected, obtained) diff --git a/apps/analysis_framework/tests/test_mutations.py b/apps/analysis_framework/tests/test_mutations.py index b630820fa6..e7575ddd8d 100644 --- a/apps/analysis_framework/tests/test_mutations.py +++ b/apps/analysis_framework/tests/test_mutations.py @@ -2,22 +2,20 @@ import json from unittest import mock -from django.core.files.temp import NamedTemporaryFile -from utils.graphene.tests import GraphQLTestCase, GraphQLSnapShotTestCase -from user.factories import UserFactory -from graphene_file_upload.django.testing import GraphQLFileUploadTestCase - -from analysis_framework.models import Widget - -from project.models import ProjectChangeLog -from project.factories import ProjectFactory -from organization.factories import OrganizationFactory -from analysis_framework.models import AnalysisFramework from analysis_framework.factories import ( AnalysisFrameworkFactory, SectionFactory, WidgetFactory, ) +from analysis_framework.models import AnalysisFramework, Widget +from django.core.files.temp import NamedTemporaryFile +from graphene_file_upload.django.testing import GraphQLFileUploadTestCase +from organization.factories import OrganizationFactory +from project.factories import ProjectFactory +from project.models import ProjectChangeLog +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLSnapShotTestCase, GraphQLTestCase class TestPreviewImage(GraphQLFileUploadTestCase, GraphQLTestCase): @@ -51,49 +49,42 @@ def setUp(self) -> None: } } """ - self.variables = { - "data": {"title": 'test', "previewImage": None} - } + self.variables = {"data": {"title": "test", "previewImage": None}} self.force_login(self.user) def test_upload_preview_image(self): - file_text = b'preview image text' - with NamedTemporaryFile(suffix='.png') as t_file: + file_text = b"preview image text" + with NamedTemporaryFile(suffix=".png") as t_file: t_file.write(file_text) t_file.seek(0) response = self._client.post( - '/graphql', + "/graphql", data={ - 'operations': json.dumps({ - 'query': self.upload_mutation, - 'variables': self.variables - }), - 't_file': t_file, - 'map': json.dumps({ - 't_file': ['variables.data.previewImage'] - }) - } + "operations": json.dumps({"query": self.upload_mutation, "variables": self.variables}), + "t_file": t_file, + "map": json.dumps({"t_file": ["variables.data.previewImage"]}), + }, ) content = response.json() self.assertResponseNoErrors(response) # Test can upload image - af_id = content['data']['analysisFrameworkCreate']['result']['id'] - self.assertTrue(content['data']['analysisFrameworkCreate']['ok'], content) - self.assertTrue(content['data']['analysisFrameworkCreate']['result']['previewImage']["name"]) - preview_image_name = content['data']['analysisFrameworkCreate']['result']['previewImage']["name"] - preview_image_url = content['data']['analysisFrameworkCreate']['result']['previewImage']["url"] - self.assertTrue(preview_image_name.endswith('.png')) + af_id = content["data"]["analysisFrameworkCreate"]["result"]["id"] + self.assertTrue(content["data"]["analysisFrameworkCreate"]["ok"], content) + self.assertTrue(content["data"]["analysisFrameworkCreate"]["result"]["previewImage"]["name"]) + preview_image_name = content["data"]["analysisFrameworkCreate"]["result"]["previewImage"]["name"] + preview_image_url = content["data"]["analysisFrameworkCreate"]["result"]["previewImage"]["url"] + self.assertTrue(preview_image_name.endswith(".png")) self.assertTrue(preview_image_url.endswith(preview_image_name)) # Test can retrive image response = self.query(self.retrieve_af_query % af_id) self.assertResponseNoErrors(response) content = response.json() - self.assertTrue(content['data']['analysisFramework']['previewImage']["name"]) - preview_image_name = content['data']['analysisFramework']['previewImage']["name"] - preview_image_url = content['data']['analysisFramework']['previewImage']["url"] - self.assertTrue(preview_image_name.endswith('.png')) + self.assertTrue(content["data"]["analysisFramework"]["previewImage"]["name"]) + preview_image_name = content["data"]["analysisFramework"]["previewImage"]["name"] + preview_image_url = content["data"]["analysisFramework"]["previewImage"]["url"] + self.assertTrue(preview_image_name.endswith(".png")) self.assertTrue(preview_image_url.endswith(preview_image_name)) @@ -102,7 +93,7 @@ class TestAnalysisFrameworkMutationSnapShotTestCase(GraphQLSnapShotTestCase): def setUp(self): super().setUp() - self.create_query = ''' + self.create_query = """ mutation MyMutation ($input: AnalysisFrameworkInputType!) { __typename analysisFrameworkCreate(data: $input) { @@ -144,63 +135,63 @@ def setUp(self): } } } - ''' + """ self.organization1 = OrganizationFactory.create() self.invalid_minput = dict( - title='', - description='Af description', + title="", + description="Af description", isPrivate=False, organization=str(self.organization1.id), # previewImage='', primaryTagging=[ dict( - title='', - clientId='section-101', + title="", + clientId="section-101", order=2, - tooltip='Tooltip for section 101', + tooltip="Tooltip for section 101", widgets=[ dict( - clientId='section-text-101-client-id', - title='', + clientId="section-text-101-client-id", + title="", widgetId=self.genum(Widget.WidgetType.TEXT), version=1, - key='section-text-101', + key="section-text-101", order=1, properties=dict(), ), dict( - clientId='section-text-102-client-id', - title='', + clientId="section-text-102-client-id", + title="", widgetId=self.genum(Widget.WidgetType.TEXT), version=1, - key='section-text-102', + key="section-text-102", order=2, properties=dict(), ), ], ), dict( - title='', - clientId='section-102', + title="", + clientId="section-102", order=1, - tooltip='Tooltip for section 102', + tooltip="Tooltip for section 102", widgets=[ dict( - clientId='section-2-text-101-client-id', - title='Section-2-Text-101', + clientId="section-2-text-101-client-id", + title="Section-2-Text-101", widgetId=self.genum(Widget.WidgetType.TEXT), version=1, - key='section-2-text-101', + key="section-2-text-101", order=1, properties=dict(), ), dict( - clientId='section-2-text-102-client-id', - title='Section-2-Text-102', + clientId="section-2-text-102-client-id", + title="Section-2-Text-102", widgetId=self.genum(Widget.WidgetType.TEXT), version=1, - key='section-2-text-102', + key="section-2-text-102", order=2, properties=dict(), ), @@ -209,20 +200,20 @@ def setUp(self): ], secondaryTagging=[ dict( - clientId='select-widget-101-client-id', - title='', + clientId="select-widget-101-client-id", + title="", widgetId=self.genum(Widget.WidgetType.SELECT), version=1, - key='select-widget-101-key', + key="select-widget-101-key", order=1, properties=dict(), ), dict( - clientId='multi-select-widget-102-client-id', - title='multi-select-Widget-2', + clientId="multi-select-widget-102-client-id", + title="multi-select-Widget-2", widgetId=self.genum(Widget.WidgetType.MULTISELECT), version=1, - key='multi-select-widget-102-key', + key="multi-select-widget-102-key", order=2, properties=dict(), ), @@ -230,97 +221,96 @@ def setUp(self): ) self.valid_minput = dict( - title='AF (TEST)', - description='Af description', + title="AF (TEST)", + description="Af description", isPrivate=False, organization=str(self.organization1.id), properties=dict(), # previewImage='', primaryTagging=[ dict( - title='Section 101', - clientId='section-101', + title="Section 101", + clientId="section-101", order=2, - tooltip='Tooltip for section 101', + tooltip="Tooltip for section 101", widgets=[ dict( - clientId='section-text-101-client-id', - title='Section-Text-101', + clientId="section-text-101-client-id", + title="Section-Text-101", widgetId=self.genum(Widget.WidgetType.MATRIX1D), version=1, - key='section-text-101', + key="section-text-101", order=1, properties=dict( rows=[ dict( - key='row-key-1', - label='Row Label 1', + key="row-key-1", + label="Row Label 1", cells=[ - dict(key='cell-key-1.1', label='Cell Label 1.1'), - dict(key='cell-key-1.2', label='Cell Label 1.2'), - dict(key='cell-key-1.3', label='Cell Label 1.3'), + dict(key="cell-key-1.1", label="Cell Label 1.1"), + dict(key="cell-key-1.2", label="Cell Label 1.2"), + dict(key="cell-key-1.3", label="Cell Label 1.3"), ], ), dict( - key='row-key-2', - label='Row Label 2', + key="row-key-2", + label="Row Label 2", cells=[ - dict(key='cell-key-2.1', label='Cell Label 2.1'), - dict(key='cell-key-2.2', label='Cell Label 2.2'), + dict(key="cell-key-2.1", label="Cell Label 2.1"), + dict(key="cell-key-2.2", label="Cell Label 2.2"), ], ), - ], ), ), dict( - clientId='section-text-102-client-id', - title='Section-Text-102', + clientId="section-text-102-client-id", + title="Section-Text-102", widgetId=self.genum(Widget.WidgetType.MATRIX2D), version=1, - key='section-text-102', + key="section-text-102", order=2, properties=dict( rows=[ dict( - key='row-key-1', - label='Row Label 1', + key="row-key-1", + label="Row Label 1", subRows=[ - dict(key='sub-row-key-1.1', label='SubRow Label 1.1'), - dict(key='sub-row-key-1.2', label='SubRow Label 1.2'), - dict(key='sub-row-key-1.3', label='SubRow Label 1.3'), + dict(key="sub-row-key-1.1", label="SubRow Label 1.1"), + dict(key="sub-row-key-1.2", label="SubRow Label 1.2"), + dict(key="sub-row-key-1.3", label="SubRow Label 1.3"), ], ), dict( - key='row-key-2', - label='Row Label 2', + key="row-key-2", + label="Row Label 2", subRows=[ - dict(key='sub-row-key-2.1', label='SubRow Label 2.1'), - dict(key='sub-row-key-2.2', label='SubRow Label 2.2'), + dict(key="sub-row-key-2.1", label="SubRow Label 2.1"), + dict(key="sub-row-key-2.2", label="SubRow Label 2.2"), ], ), ], columns=[ dict( - key='column-key-1', - label='Column Label 1', + key="column-key-1", + label="Column Label 1", subColumns=[ - dict(key='sub-column-key-1.1', label='SubColumn Label 1.1'), - dict(key='sub-column-key-1.2', label='SubColumn Label 1.2'), - dict(key='sub-column-key-1.3', label='SubColumn Label 1.3'), + dict(key="sub-column-key-1.1", label="SubColumn Label 1.1"), + dict(key="sub-column-key-1.2", label="SubColumn Label 1.2"), + dict(key="sub-column-key-1.3", label="SubColumn Label 1.3"), ], ), dict( - key='column-key-2', - label='Column Label 2', + key="column-key-2", + label="Column Label 2", subColumns=[ - dict(key='sub-column-key-2.1', label='SubColumn Label 2.1'), - dict(key='sub-column-key-2.2', label='SubColumn Label 2.2'), + dict(key="sub-column-key-2.1", label="SubColumn Label 2.1"), + dict(key="sub-column-key-2.2", label="SubColumn Label 2.2"), ], ), dict( - key='column-key-3', - label='Column Label 3', + key="column-key-3", + label="Column Label 3", subColumns=[], ), ], @@ -329,26 +319,26 @@ def setUp(self): ], ), dict( - title='Section 102', - clientId='section-102', + title="Section 102", + clientId="section-102", order=1, - tooltip='Tooltip for section 102', + tooltip="Tooltip for section 102", widgets=[ dict( - clientId='section-2-text-101-client-id', - title='Section-2-Text-101', + clientId="section-2-text-101-client-id", + title="Section-2-Text-101", widgetId=self.genum(Widget.WidgetType.TEXT), version=1, - key='section-2-text-101', + key="section-2-text-101", order=1, properties=dict(), ), dict( - clientId='section-2-text-102-client-id', - title='Section-2-Text-102', + clientId="section-2-text-102-client-id", + title="Section-2-Text-102", widgetId=self.genum(Widget.WidgetType.TEXT), version=1, - key='section-2-text-102', + key="section-2-text-102", order=2, properties=dict(), ), @@ -357,20 +347,20 @@ def setUp(self): ], secondaryTagging=[ dict( - clientId='select-widget-101-client-id', - title='Select-Widget-1', + clientId="select-widget-101-client-id", + title="Select-Widget-1", widgetId=self.genum(Widget.WidgetType.SELECT), version=1, - key='select-widget-101-key', + key="select-widget-101-key", order=1, properties=dict(), ), dict( - clientId='multi-select-widget-102-client-id', - title='multi-select-Widget-2', + clientId="multi-select-widget-102-client-id", + title="multi-select-Widget-2", widgetId=self.genum(Widget.WidgetType.MULTISELECT), version=1, - key='multi-select-widget-102-key', + key="multi-select-widget-102-key", order=2, properties=dict(), ), @@ -390,17 +380,17 @@ def _query_check(minput, **kwargs): self.force_login(user) response = _query_check(self.invalid_minput, okay=False) - self.assertMatchSnapshot(response, 'errors') + self.assertMatchSnapshot(response, "errors") with self.captureOnCommitCallbacks(execute=True): response = _query_check(self.valid_minput, okay=True) - self.assertMatchSnapshot(response, 'success') + self.assertMatchSnapshot(response, "success") # Export test - new_af = AnalysisFramework.objects.get(pk=response['data']['analysisFrameworkCreate']['result']['id']) - self.assertMatchSnapshot(new_af.export.file.read().decode('utf-8'), 'success-af-export') + new_af = AnalysisFramework.objects.get(pk=response["data"]["analysisFrameworkCreate"]["result"]["id"]) + self.assertMatchSnapshot(new_af.export.file.read().decode("utf-8"), "success-af-export") def test_analysis_framework_update(self): - query = ''' + query = """ mutation MyMutation ($id: ID! $input: AnalysisFrameworkInputType!) { __typename analysisFramework (id: $id ) { @@ -479,7 +469,7 @@ def test_analysis_framework_update(self): } } } - ''' + """ user = UserFactory.create() project1, project2, project3 = ProjectFactory.create_batch(3) @@ -488,8 +478,8 @@ def _query_check(id, minput, **kwargs): return self.query_check( query, minput=minput, - mnested=['analysisFramework'], - variables={'id': id}, + mnested=["analysisFramework"], + variables={"id": id}, **kwargs, ) @@ -497,53 +487,52 @@ def _query_check(id, minput, **kwargs): valid_minput = copy.deepcopy(self.valid_minput) new_widgets = [ dict( - clientId='geo-widget-103-client-id', - title='Geo', + clientId="geo-widget-103-client-id", + title="Geo", widgetId=self.genum(Widget.WidgetType.GEO), version=1, - key='geo-widget-103-key', + key="geo-widget-103-key", order=3, properties=dict(), ), dict( - clientId='scale-widget-104-client-id', - title='Scale', + clientId="scale-widget-104-client-id", + title="Scale", widgetId=self.genum(Widget.WidgetType.SCALE), version=1, - key='scale-widget-104-key', + key="scale-widget-104-key", order=4, properties=dict(), ), dict( - clientId='organigram-widget-104-client-id', - title='Organigram', + clientId="organigram-widget-104-client-id", + title="Organigram", widgetId=self.genum(Widget.WidgetType.ORGANIGRAM), version=1, - key='organigram-widget-104-key', + key="organigram-widget-104-key", order=5, properties=dict(), ), ] - valid_minput['secondaryTagging'].extend(new_widgets) + valid_minput["secondaryTagging"].extend(new_widgets) _query_check(0, valid_minput, assert_for_error=True) # ---------- With login self.force_login(user) # ---------- Let's create a new AF (Using create test data) - new_af_response = self.query_check( - self.create_query, minput=valid_minput)['data']['analysisFrameworkCreate']['result'] - self.assertMatchSnapshot(copy.deepcopy(new_af_response), 'created') + new_af_response = self.query_check(self.create_query, minput=valid_minput)["data"]["analysisFrameworkCreate"]["result"] + self.assertMatchSnapshot(copy.deepcopy(new_af_response), "created") - new_af_id = new_af_response['id'] + new_af_id = new_af_response["id"] for project in [project1, project2]: project.analysis_framework_id = new_af_id - project.save(update_fields=('analysis_framework_id',)) + project.save(update_fields=("analysis_framework_id",)) # ---------------- Remove invalid attributes - new_af_response.pop('currentUserRole') - new_af_response.pop('id') + new_af_response.pop("currentUserRole") + new_af_response.pop("id") # ---------- Let's change some attributes (for validation errors) - new_af_response['title'] = '' - new_af_response['primaryTagging'][0]['title'] = '' + new_af_response["title"] = "" + new_af_response["primaryTagging"][0]["title"] = "" # ----------------- Let's try to update # ---- Add stats_config as well. @@ -552,24 +541,17 @@ def _query_check(id, minput, **kwargs): def _get_widget_ID(_type): widget = widget_qs.filter(widget_id=_type).first() if widget: - return dict( - pk=str(widget.id) - ) + return dict(pk=str(widget.id)) def _get_multiple_widget_ID(_type): - return [ - dict( - pk=str(widget.id) - ) - for widget in widget_qs.filter(widget_id=_type) - ] + return [dict(pk=str(widget.id)) for widget in widget_qs.filter(widget_id=_type)] - new_af_response['properties'] = dict( + new_af_response["properties"] = dict( statsConfig=dict( # Invalid IDS geoWidget=_get_widget_ID(Widget.WidgetType.MULTISELECT), severityWidget=_get_widget_ID(Widget.WidgetType.MULTISELECT), - reliabilityWidget=dict(pk='10000001'), + reliabilityWidget=dict(pk="10000001"), # widget1d=_get_multiple_widget_ID(Widget.WidgetType.MULTISELECT), widget1d=_get_multiple_widget_ID(Widget.WidgetType.MULTISELECT), widget2d=_get_multiple_widget_ID(Widget.WidgetType.MULTISELECT), @@ -578,18 +560,18 @@ def _get_multiple_widget_ID(_type): ), ) response = _query_check(new_af_id, new_af_response, okay=False) - self.assertMatchSnapshot(response, 'errors') + self.assertMatchSnapshot(response, "errors") # ---------- Let's change some attributes (for success change) - new_af_response['title'] = 'Updated AF (TEST)' - new_af_response['description'] = 'Updated Af description' - new_af_response['primaryTagging'][0]['title'] = 'Updated Section 102' - new_af_response['primaryTagging'][0]['widgets'][0].pop('id') # Remove/Create a widget - new_af_response['primaryTagging'][0]['widgets'][1]['title'] = 'Updated-Section-2-Text-101' # Remove a widget - new_af_response['primaryTagging'][1].pop('id') # Remove/Create second ordered section (but use current widgets) - new_af_response['secondaryTagging'].pop(0) # Remove another widget - new_af_response['secondaryTagging'][0].pop('id') # Remove/Create another widget + new_af_response["title"] = "Updated AF (TEST)" + new_af_response["description"] = "Updated Af description" + new_af_response["primaryTagging"][0]["title"] = "Updated Section 102" + new_af_response["primaryTagging"][0]["widgets"][0].pop("id") # Remove/Create a widget + new_af_response["primaryTagging"][0]["widgets"][1]["title"] = "Updated-Section-2-Text-101" # Remove a widget + new_af_response["primaryTagging"][1].pop("id") # Remove/Create second ordered section (but use current widgets) + new_af_response["secondaryTagging"].pop(0) # Remove another widget + new_af_response["secondaryTagging"][0].pop("id") # Remove/Create another widget # ----------------- Let's try to update - new_af_response['properties'] = dict( + new_af_response["properties"] = dict( statsConfig=dict( # Invalid IDS geoWidget=_get_widget_ID(Widget.WidgetType.GEO), @@ -603,16 +585,16 @@ def _get_multiple_widget_ID(_type): ) with self.captureOnCommitCallbacks(execute=True): response = _query_check(new_af_id, new_af_response, okay=True) - self.assertMatchSnapshot(response, 'success') + self.assertMatchSnapshot(response, "success") new_af = AnalysisFramework.objects.get(pk=new_af_id) - self.assertMatchSnapshot(new_af.export.file.read().decode('utf-8'), 'success-af-export') + self.assertMatchSnapshot(new_af.export.file.read().decode("utf-8"), "success-af-export") # Check with conditionals other_af_widget = WidgetFactory.create(analysis_framework=AnalysisFrameworkFactory.create()) af_widget = Widget.objects.filter(analysis_framework_id=new_af_id).first() af_widget_pk = af_widget and af_widget.pk # Some with conditionals - new_af_response['primaryTagging'][0]['widgets'][1]['conditional'] = dict( + new_af_response["primaryTagging"][0]["widgets"][1]["conditional"] = dict( parentWidget=other_af_widget.pk, conditions=[], ) @@ -620,22 +602,22 @@ def _get_multiple_widget_ID(_type): response = _query_check(new_af_id, new_af_response, okay=False) # Success Add - new_af_response['primaryTagging'][0]['widgets'][1]['conditional'] = dict( + new_af_response["primaryTagging"][0]["widgets"][1]["conditional"] = dict( parentWidget=af_widget_pk, conditions=[], ) - new_af_response['secondaryTagging'][0]['conditional'] = dict( + new_af_response["secondaryTagging"][0]["conditional"] = dict( parentWidget=af_widget_pk, conditions=[], ) response = _query_check(new_af_id, new_af_response, okay=True) - self.assertMatchSnapshot(response, 'with-conditionals-add') + self.assertMatchSnapshot(response, "with-conditionals-add") # Success Remove - new_af_response['primaryTagging'][0]['widgets'][1].pop('conditional') - new_af_response['secondaryTagging'][0]['conditional'] = None # Should remove this only + new_af_response["primaryTagging"][0]["widgets"][1].pop("conditional") + new_af_response["secondaryTagging"][0]["conditional"] = None # Should remove this only response = _query_check(new_af_id, new_af_response, okay=True) - self.assertMatchSnapshot(response, 'with-conditionals-remove') + self.assertMatchSnapshot(response, "with-conditionals-remove") # With another user (Access denied) another_user = UserFactory.create() @@ -644,18 +626,22 @@ def _get_multiple_widget_ID(_type): # Project Log Check def _get_project_logs_qs(project): - return ProjectChangeLog.objects.filter(project=project).order_by('id') + return ProjectChangeLog.objects.filter(project=project).order_by("id") assert _get_project_logs_qs(project3).count() == 0 for project in [project1, project2]: project_log_qs = _get_project_logs_qs(project) assert project_log_qs.count() == 3 - assert list(project_log_qs.values_list('diff', flat=True)) == [ - dict(framework=dict(updated=True)), - ] * 3 + assert ( + list(project_log_qs.values_list("diff", flat=True)) + == [ + dict(framework=dict(updated=True)), + ] + * 3 + ) def test_analysis_framework_membership_bulk(self): - query = ''' + query = """ mutation MyMutation( $id: ID!, $afMembership: [BulkAnalysisFrameworkMembershipInputType!]!, @@ -704,7 +690,7 @@ def test_analysis_framework_membership_bulk(self): } } } - ''' + """ creater_user = UserFactory.create() user = UserFactory.create() low_permission_user = UserFactory.create() @@ -762,10 +748,11 @@ def test_analysis_framework_membership_bulk(self): def _query_check(**kwargs): return self.query_check( query, - mnested=['analysisFramework'], - variables={'id': af.id, **minput}, + mnested=["analysisFramework"], + variables={"id": af.id, **minput}, **kwargs, ) + # ---------- Without login _query_check(assert_for_error=True) # ---------- With login (with non-member) @@ -777,16 +764,16 @@ def _query_check(**kwargs): # ---------- With login (with higher permission) self.force_login(user) # ----------------- Some Invalid input - response = _query_check()['data']['analysisFramework']['analysisFrameworkMembershipBulk'] - self.assertMatchSnapshot(response, 'try 1') + response = _query_check()["data"]["analysisFramework"]["analysisFrameworkMembershipBulk"] + self.assertMatchSnapshot(response, "try 1") # ----------------- All valid input - minput['afMembership'].pop(1) - response = _query_check()['data']['analysisFramework']['analysisFrameworkMembershipBulk'] - self.assertMatchSnapshot(response, 'try 2') + minput["afMembership"].pop(1) + response = _query_check()["data"]["analysisFramework"]["analysisFrameworkMembershipBulk"] + self.assertMatchSnapshot(response, "try 2") - @mock.patch('analysis_framework.serializers.AfWidgetLimit') + @mock.patch("analysis_framework.serializers.AfWidgetLimit") def test_widgets_limit(self, AfWidgetLimitMock): - query = ''' + query = """ mutation MyMutation ($input: AnalysisFrameworkInputType!) { __typename analysisFrameworkCreate(data: $input) { @@ -797,38 +784,39 @@ def test_widgets_limit(self, AfWidgetLimitMock): } } } - ''' + """ user = UserFactory.create() minput = dict( - title='AF (TEST)', + title="AF (TEST)", primaryTagging=[ dict( - title=f'Section {i}', - clientId=f'section-{i}', + title=f"Section {i}", + clientId=f"section-{i}", order=i, - tooltip=f'Tooltip for section {i}', + tooltip=f"Tooltip for section {i}", widgets=[ dict( - clientId=f'section-text-{j}-client-id', - title=f'Section-Text-{j}', + clientId=f"section-text-{j}-client-id", + title=f"Section-Text-{j}", widgetId=self.genum(Widget.WidgetType.TEXT), version=1, - key=f'section-text-{j}', + key=f"section-text-{j}", order=j, ) for j in range(0, 4) ], - ) for i in range(0, 2) + ) + for i in range(0, 2) ], secondaryTagging=[ dict( - clientId=f'section-text-{j}-client-id', - title=f'Section-Text-{j}', + clientId=f"section-text-{j}-client-id", + title=f"Section-Text-{j}", widgetId=self.genum(Widget.WidgetType.TEXT), version=1, - key=f'section-text-{j}', + key=f"section-text-{j}", order=j, ) for j in range(0, 4) @@ -838,19 +826,19 @@ def test_widgets_limit(self, AfWidgetLimitMock): self.force_login(user) def _query_check(**kwargs): - return self.query_check(query, minput=minput, **kwargs)['data']['analysisFrameworkCreate'] + return self.query_check(query, minput=minput, **kwargs)["data"]["analysisFrameworkCreate"] # Let's change the limit to lower value for easy testing :P AfWidgetLimitMock.MAX_SECTIONS_ALLOWED = 1 AfWidgetLimitMock.MAX_WIDGETS_ALLOWED_PER_SECTION = 2 AfWidgetLimitMock.MAX_WIDGETS_ALLOWED_IN_SECONDARY_TAGGING = 2 response = _query_check(okay=False) - self.assertMatchSnapshot(response, 'failure-widget-level') + self.assertMatchSnapshot(response, "failure-widget-level") # Let's change the limit to lower value for easy testing :P AfWidgetLimitMock.MAX_WIDGETS_ALLOWED_IN_SECONDARY_TAGGING = 10 AfWidgetLimitMock.MAX_WIDGETS_ALLOWED_PER_SECTION = 10 response = _query_check(okay=False) - self.assertMatchSnapshot(response, 'failure-section-level') + self.assertMatchSnapshot(response, "failure-section-level") # Let's change the limit to higher value # Let's change the limit to higher value AfWidgetLimitMock.MAX_SECTIONS_ALLOWED = 5 @@ -861,7 +849,7 @@ class TestAnalysisFrameworkCreateUpdate(GraphQLTestCase): def setUp(self): super().setUp() self.user = UserFactory.create() - self.create_mutation = ''' + self.create_mutation = """ mutation Mutation($input: AnalysisFrameworkInputType!) { analysisFrameworkCreate(data: $input) { ok @@ -874,8 +862,8 @@ def setUp(self): } } } - ''' - self.update_mutation = ''' + """ + self.update_mutation = """ mutation UpdateMutation($input: AnalysisFrameworkInputType!, $id: ID!) { analysisFramework (id: $id ) { analysisFrameworkUpdate(data: $input) { @@ -890,26 +878,18 @@ def setUp(self): } } } - ''' + """ def test_create_analysis_framework(self): - self.input = dict( - title='new title' - ) + self.input = dict(title="new title") self.force_login(self.user) - response = self.query( - self.create_mutation, - input_data=self.input - ) + response = self.query(self.create_mutation, input_data=self.input) self.assertResponseNoErrors(response) content = response.json() - self.assertTrue(content['data']['analysisFrameworkCreate']['ok'], content) - self.assertEqual( - content['data']['analysisFrameworkCreate']['result']['title'], - self.input['title'] - ) + self.assertTrue(content["data"]["analysisFrameworkCreate"]["ok"], content) + self.assertEqual(content["data"]["analysisFrameworkCreate"]["result"]["title"], self.input["title"]) # TODO: MOVE THIS TO PROJECT TEST # def test_create_private_framework_unauthorized(self): @@ -956,93 +936,75 @@ def test_change_is_private_field(self): private_framework = AnalysisFrameworkFactory.create(is_private=True) public_framework = AnalysisFrameworkFactory.create(is_private=False) user = self.user - private_framework.add_member( - user, - private_framework.get_or_create_owner_role() - ) - public_framework.add_member( - user, - public_framework.get_or_create_owner_role() - ) + private_framework.add_member(user, private_framework.get_or_create_owner_role()) + public_framework.add_member(user, public_framework.get_or_create_owner_role()) content = self._change_framework_privacy(public_framework, user) - self.assertIsNotNone(content['errors'][0]['message']) - self.assertIn('permission', content['errors'][0]['message']) + self.assertIsNotNone(content["errors"][0]["message"]) + self.assertIn("permission", content["errors"][0]["message"]) content = self._change_framework_privacy(private_framework, user) - self.assertIsNotNone(content['errors'][0]['message']) - self.assertIn('permission', content['errors'][0]['message']) + self.assertIsNotNone(content["errors"][0]["message"]) + self.assertIn("permission", content["errors"][0]["message"]) def test_change_other_fields(self): private_framework = AnalysisFrameworkFactory.create(is_private=True) public_framework = AnalysisFrameworkFactory.create(is_private=False) user = self.user - private_framework.add_member( - user, - private_framework.get_or_create_owner_role() - ) - public_framework.add_member( - user, - public_framework.get_or_create_owner_role() - ) + private_framework.add_member(user, private_framework.get_or_create_owner_role()) + public_framework.add_member(user, public_framework.get_or_create_owner_role()) self.force_login(user) # private framework update self.input = dict( - title='new title updated', + title="new title updated", isPrivate=private_framework.is_private, ) response = self.query( self.update_mutation, input_data=self.input, - variables={'id': private_framework.id}, + variables={"id": private_framework.id}, ) private_framework.refresh_from_db() content = response.json() - self.assertNotEqual(content['data']['analysisFramework']['analysisFrameworkUpdate'], None, content) - self.assertTrue(content['data']['analysisFramework']['analysisFrameworkUpdate']['ok'], content) - self.assertEqual( - private_framework.title, - self.input['title'] - ) + self.assertNotEqual(content["data"]["analysisFramework"]["analysisFrameworkUpdate"], None, content) + self.assertTrue(content["data"]["analysisFramework"]["analysisFrameworkUpdate"]["ok"], content) + self.assertEqual(private_framework.title, self.input["title"]) # public framework update self.input = dict( - title='public title updated', + title="public title updated", isPrivate=public_framework.is_private, ) response = self.query( self.update_mutation, input_data=self.input, - variables={'id': public_framework.id}, + variables={"id": public_framework.id}, ) public_framework.refresh_from_db() content = response.json() - self.assertNotEqual(content['data']['analysisFramework']['analysisFrameworkUpdate'], None, content) - self.assertTrue(content['data']['analysisFramework']['analysisFrameworkUpdate']['ok'], content) - self.assertEqual( - public_framework.title, - self.input['title'] - ) + self.assertNotEqual(content["data"]["analysisFramework"]["analysisFrameworkUpdate"], None, content) + self.assertTrue(content["data"]["analysisFramework"]["analysisFrameworkUpdate"]["ok"], content) + self.assertEqual(public_framework.title, self.input["title"]) def _change_framework_privacy(self, framework, user): self.force_login(user) changed_privacy = not framework.is_private self.input = dict( - title='new title', + title="new title", isPrivate=changed_privacy, # other fields not cared for now ) response = self.query( self.update_mutation, input_data=self.input, - variables={'id': framework.id}, + variables={"id": framework.id}, ) content = response.json() return content def test_af_modified_at(self): - create_mutation = ''' + create_mutation = """ mutation Mutation($input: AnalysisFrameworkInputType!) { analysisFrameworkCreate(data: $input) { ok @@ -1056,8 +1018,8 @@ def test_af_modified_at(self): } } } - ''' - update_mutation = ''' + """ + update_mutation = """ mutation UpdateMutation($input: AnalysisFrameworkInputType!, $id: ID!) { analysisFramework (id: $id ) { analysisFrameworkUpdate(data: $input) { @@ -1073,20 +1035,20 @@ def test_af_modified_at(self): } } } - ''' + """ self.force_login(self.user) # Create - minput = dict(title='new title') - af_response = self.query_check(create_mutation, minput=minput)['data']['analysisFrameworkCreate']['result'] - af_id = af_response['id'] - af_modified_at = af_response['modifiedAt'] + minput = dict(title="new title") + af_response = self.query_check(create_mutation, minput=minput)["data"]["analysisFrameworkCreate"]["result"] + af_id = af_response["id"] + af_modified_at = af_response["modifiedAt"] # Update - minput = dict(title='new updated title') - updated_af_response = self.query_check( - update_mutation, minput=minput, variables={'id': af_id} - )['data']['analysisFramework']['analysisFrameworkUpdate']['result'] + minput = dict(title="new updated title") + updated_af_response = self.query_check(update_mutation, minput=minput, variables={"id": af_id})["data"][ + "analysisFramework" + ]["analysisFrameworkUpdate"]["result"] # Make sure modifiedAt is higher now - assert updated_af_response['modifiedAt'] > af_modified_at + assert updated_af_response["modifiedAt"] > af_modified_at diff --git a/apps/analysis_framework/tests/test_roles_api.py b/apps/analysis_framework/tests/test_roles_api.py index ff10591962..2997185882 100644 --- a/apps/analysis_framework/tests/test_roles_api.py +++ b/apps/analysis_framework/tests/test_roles_api.py @@ -1,17 +1,18 @@ -from deep.tests import TestCase - from analysis_framework.models import ( - AnalysisFramework, Widget, + AnalysisFramework, AnalysisFrameworkMembership, + Widget, ) from project.models import Project from user.models import User +from deep.tests import TestCase + class TestAnalysisFrameworkRoles(TestCase): """Test cases for analysis framework roles""" - fixtures = ['apps/analysis_framework/fixtures/af_roles.json'] + fixtures = ["apps/analysis_framework/fixtures/af_roles.json"] def setUp(self): super().setUp() @@ -19,72 +20,68 @@ def setUp(self): self.project = self.create(Project, role=self.admin_role) # Create private and public frameworks self.private_framework = AnalysisFramework.objects.create( - title='Private Framework', + title="Private Framework", project=self.project, is_private=True, ) self.public_framework = AnalysisFramework.objects.create( - title='Public Framework', + title="Public Framework", project=self.project, is_private=False, created_by=self.user, ) # Add widgets self.private_widget = self.create( - Widget, analysis_framework=self.private_framework, + Widget, + analysis_framework=self.private_framework, widget_id=Widget.WidgetType.TEXT, - key='text-widget-001', + key="text-widget-001", ) self.public_widget = self.create( - Widget, analysis_framework=self.public_framework, + Widget, + analysis_framework=self.public_framework, widget_id=Widget.WidgetType.TEXT, - key='text-widget-002', + key="text-widget-002", ) def test_get_private_roles(self): - url = '/api/v1/private-framework-roles/' + url = "/api/v1/private-framework-roles/" self.authenticate() response = self.client.get(url) self.assert_200(response) data = response.data - for role in data['results']: - assert role['is_private_role'] is True, "Must be a private role" + for role in data["results"]: + assert role["is_private_role"] is True, "Must be a private role" def test_get_public_roles_all(self): - url = '/api/v1/public-framework-roles/' + url = "/api/v1/public-framework-roles/" self.authenticate() response = self.client.get(url) self.assert_200(response) data = response.data - for role in data['results']: - assert role['is_private_role'] is not True, "Must be a public role" + for role in data["results"]: + assert role["is_private_role"] is not True, "Must be a public role" - assert any(x['is_default_role'] for x in data['results']), "A default role should be present" + assert any(x["is_default_role"] for x in data["results"]), "A default role should be present" def test_get_public_roles_no_default(self): - url = '/api/v1/public-framework-roles/?is_default_role=false' + url = "/api/v1/public-framework-roles/?is_default_role=false" self.authenticate() response = self.client.get(url) self.assert_200(response) data = response.data - for role in data['results']: - assert role['is_private_role'] is not True, "Must be a public role" + for role in data["results"]: + assert role["is_private_role"] is not True, "Must be a public role" - print([x['is_default_role'] for x in data['results']]) - assert not any(x['is_default_role'] for x in data['results']), "No default role should be present" + print([x["is_default_role"] for x in data["results"]]) + assert not any(x["is_default_role"] for x in data["results"]), "No default role should be present" def test_owner_role(self): - self.private_framework.add_member( - self.user, - self.private_framework.get_or_create_owner_role() - ) - self.public_framework.add_member( - self.user, - self.public_framework.get_or_create_owner_role() - ) + self.private_framework.add_member(self.user, self.private_framework.get_or_create_owner_role()) + self.public_framework.add_member(self.user, self.public_framework.get_or_create_owner_role()) # CLONING THE FRAMEWORK response = self._clone_framework_test(self.private_framework) self.assert_403(response) @@ -110,10 +107,10 @@ def test_patch_membership(self): user = self.create(User) membership, _ = self.private_framework.add_member(user) - url = f'/api/v1/framework-memberships/{membership.id}/' + url = f"/api/v1/framework-memberships/{membership.id}/" patch_data = { - 'role': editor.id, + "role": editor.id, } self.authenticate() @@ -129,7 +126,7 @@ def test_get_membership(self): user = self.create(User) membership, _ = self.private_framework.add_member(user) - url = f'/api/v1/framework-memberships/{membership.id}/' + url = f"/api/v1/framework-memberships/{membership.id}/" self.authenticate() resp = self.client.get(url) @@ -138,14 +135,8 @@ def test_get_membership(self): def test_editor_role(self): editor_user = self.create(User) - self.private_framework.add_member( - editor_user, - self.private_framework.get_or_create_editor_role() - ) - self.public_framework.add_member( - editor_user, - self.public_framework.get_or_create_editor_role() - ) + self.private_framework.add_member(editor_user, self.private_framework.get_or_create_editor_role()) + self.public_framework.add_member(editor_user, self.public_framework.get_or_create_editor_role()) # CLONING FRAMEWORK response = self._clone_framework_test(self.private_framework, editor_user) @@ -185,12 +176,8 @@ def test_add_user_with_public_role_to_private_framework(self): public_role = public_framework.get_or_create_editor_role() private_framework.add_member(self.user, private_framework.get_or_create_owner_role()) - url = '/api/v1/framework-memberships/' - post_data = { - 'framework': private_framework.id, - 'member': user.id, - 'role': public_role.id - } + url = "/api/v1/framework-memberships/" + post_data = {"framework": private_framework.id, "member": user.id, "role": public_role.id} self.authenticate() resp = self.client.post(url, post_data) self.assert_403(resp) @@ -203,12 +190,8 @@ def test_add_user_with_private_role_to_public_framework(self): private_role = private_framework.get_or_create_editor_role() public_framework.add_member(self.user, public_framework.get_or_create_owner_role()) - url = '/api/v1/framework-memberships/' - post_data = { - 'framework': public_framework.id, - 'member': user.id, - 'role': private_role.id - } + url = "/api/v1/framework-memberships/" + post_data = {"framework": public_framework.id, "member": user.id, "role": private_role.id} self.authenticate() resp = self.client.post(url, post_data) self.assert_403(resp) @@ -219,10 +202,10 @@ def test_default_role_private_framework(self): user = self.create(User) private_framework.add_member(self.user, private_framework.get_or_create_owner_role()) - url = '/api/v1/framework-memberships/' + url = "/api/v1/framework-memberships/" post_data = { - 'framework': private_framework.id, - 'member': user.id, + "framework": private_framework.id, + "member": user.id, } self.authenticate() resp = self.client.post(url, post_data) @@ -230,13 +213,13 @@ def test_default_role_private_framework(self): # Now check if user has default_role memship = AnalysisFrameworkMembership.objects.filter( - member=user, framework=private_framework, + member=user, + framework=private_framework, ).first() assert memship is not None, "Membership should be created" permissions = memship.role.permissions - assert permissions == private_framework.get_default_permissions(), \ - "The permissions should be the default permissions" + assert permissions == private_framework.get_default_permissions(), "The permissions should be the default permissions" def test_default_role_public_framework(self): """When not sent role field, default role will be added""" @@ -244,10 +227,10 @@ def test_default_role_public_framework(self): user = self.create(User) public_framework.add_member(self.user, public_framework.get_or_create_owner_role()) - url = '/api/v1/framework-memberships/' + url = "/api/v1/framework-memberships/" post_data = { - 'framework': public_framework.id, - 'member': user.id, + "framework": public_framework.id, + "member": user.id, } self.authenticate() resp = self.client.post(url, post_data) @@ -255,20 +238,22 @@ def test_default_role_public_framework(self): # Now check if user has default_role memship = AnalysisFrameworkMembership.objects.filter( - member=user, framework=public_framework, + member=user, + framework=public_framework, ).first() assert memship is not None, "Membership should be created" permissions = memship.role.permissions - assert permissions == public_framework.get_editor_permissions(), \ - "The default member permissions should be the editor permissions" + assert ( + permissions == public_framework.get_editor_permissions() + ), "The default member permissions should be the editor permissions" def test_owner_cannot_delete_himself(self): framework = self.create(AnalysisFramework) owner_role = framework.get_or_create_owner_role() membership, _ = framework.add_member(self.user, owner_role) - url = f'/api/v1/framework-memberships/{membership.id}/' + url = f"/api/v1/framework-memberships/{membership.id}/" self.authenticate() resp = self.client.delete(url) @@ -276,30 +261,26 @@ def test_owner_cannot_delete_himself(self): def _edit_framework_test(self, framework, user=None, status=200): # Private framework - edit_data = { - 'title': framework.title + '-edited', - 'is_private': framework.is_private, - 'widgets': [] - } + edit_data = {"title": framework.title + "-edited", "is_private": framework.is_private, "widgets": []} self.authenticate(user) - url = f'/api/v1/analysis-frameworks/{framework.id}/' + url = f"/api/v1/analysis-frameworks/{framework.id}/" response = self.client.put(url, edit_data) self.assertEqual(response.status_code, status) def _clone_framework_test(self, framework, user=None): - clone_url = f'/api/v1/clone-analysis-framework/{framework.id}/' + clone_url = f"/api/v1/clone-analysis-framework/{framework.id}/" self.authenticate(user) - data = {'title': 'Cloned'} + data = {"title": "Cloned"} return self.client.post(clone_url, data=data) def _add_user_test(self, framework, user, status=201, role=None): - add_user_url = '/api/v1/framework-memberships/' - role = (role and role.id) or framework.get_or_create_editor_role().id, + add_user_url = "/api/v1/framework-memberships/" + role = ((role and role.id) or framework.get_or_create_editor_role().id,) new_user = self.create(User) add_member_data = { - 'framework': framework.id, - 'member': new_user.id, - 'role': framework.get_or_create_editor_role().id, # Just an arbritrary role + "framework": framework.id, + "member": new_user.id, + "role": framework.get_or_create_editor_role().id, # Just an arbritrary role } self.authenticate(user) response = self.client.post(add_user_url, add_member_data) diff --git a/apps/analysis_framework/tests/test_schemas.py b/apps/analysis_framework/tests/test_schemas.py index fd1d423e88..54c02b48cb 100644 --- a/apps/analysis_framework/tests/test_schemas.py +++ b/apps/analysis_framework/tests/test_schemas.py @@ -1,18 +1,16 @@ import factory - -from utils.graphene.tests import GraphQLSnapShotTestCase - -from analysis_framework.models import AnalysisFrameworkRole - -from user.factories import UserFactory -from project.factories import ProjectFactory -from lead.factories import LeadFactory from analysis_framework.factories import ( AnalysisFrameworkFactory, AnalysisFrameworkTagFactory, SectionFactory, WidgetFactory, ) +from analysis_framework.models import AnalysisFrameworkRole +from lead.factories import LeadFactory +from project.factories import ProjectFactory +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLSnapShotTestCase class TestAnalysisFrameworkQuery(GraphQLSnapShotTestCase): @@ -26,7 +24,7 @@ class TestAnalysisFrameworkQuery(GraphQLSnapShotTestCase): ] def test_analysis_framework_list(self): - query = ''' + query = """ query MyQuery { analysisFrameworks (ordering: "id") { page @@ -50,7 +48,7 @@ def test_analysis_framework_list(self): } } } - ''' + """ user = UserFactory.create() tag1, tag2, _ = AnalysisFrameworkTagFactory.create_batch(3) @@ -66,31 +64,31 @@ def test_analysis_framework_list(self): self.force_login(user) content = self.query_check(query) - results = content['data']['analysisFrameworks']['results'] - self.assertEqual(content['data']['analysisFrameworks']['totalCount'], 2) - self.assertIdEqual(results[0]['id'], normal_af.id) - self.assertIdEqual(results[1]['id'], member_af.id) - self.assertNotIn(str(private_af.id), [d['id'] for d in results]) # Can't see private project. - self.assertMatchSnapshot(results, 'response-01') + results = content["data"]["analysisFrameworks"]["results"] + self.assertEqual(content["data"]["analysisFrameworks"]["totalCount"], 2) + self.assertIdEqual(results[0]["id"], normal_af.id) + self.assertIdEqual(results[1]["id"], member_af.id) + self.assertNotIn(str(private_af.id), [d["id"] for d in results]) # Can't see private project. + self.assertMatchSnapshot(results, "response-01") project = ProjectFactory.create(analysis_framework=private_af) # It shouldn't list private AF after adding to a project. content = self.query_check(query) - results = content['data']['analysisFrameworks']['results'] - self.assertEqual(content['data']['analysisFrameworks']['totalCount'], 2) - self.assertNotIn(str(private_af.id), [d['id'] for d in results]) # Can't see private project. - self.assertMatchSnapshot(results, 'response-02') + results = content["data"]["analysisFrameworks"]["results"] + self.assertEqual(content["data"]["analysisFrameworks"]["totalCount"], 2) + self.assertNotIn(str(private_af.id), [d["id"] for d in results]) # Can't see private project. + self.assertMatchSnapshot(results, "response-02") project.add_member(user) # It should list private AF after user is member of the project. content = self.query_check(query) - results = content['data']['analysisFrameworks']['results'] - self.assertEqual(content['data']['analysisFrameworks']['totalCount'], 3) - self.assertIn(str(private_af.id), [d['id'] for d in results]) # Can see private project now. - self.assertMatchSnapshot(results, 'response-03') + results = content["data"]["analysisFrameworks"]["results"] + self.assertEqual(content["data"]["analysisFrameworks"]["totalCount"], 3) + self.assertIn(str(private_af.id), [d["id"] for d in results]) # Can see private project now. + self.assertMatchSnapshot(results, "response-03") def test_public_analysis_framework(self): - query = ''' + query = """ query MyQuery { publicAnalysisFrameworks (ordering: "id") { page @@ -102,14 +100,14 @@ def test_public_analysis_framework(self): } } } - ''' + """ AnalysisFrameworkFactory.create_batch(4, is_private=False) AnalysisFrameworkFactory.create_batch(5, is_private=True) content = self.query_check(query) - self.assertEqual(content['data']['publicAnalysisFrameworks']['totalCount'], 4, content) + self.assertEqual(content["data"]["publicAnalysisFrameworks"]["totalCount"], 4, content) def test_analysis_framework(self): - query = ''' + query = """ query MyQuery ($id: ID!) { analysisFramework(id: $id) { id @@ -120,7 +118,7 @@ def test_analysis_framework(self): clonedFrom } } - ''' + """ user = UserFactory.create() private_af = AnalysisFrameworkFactory.create(is_private=True) @@ -128,36 +126,36 @@ def test_analysis_framework(self): member_af = AnalysisFrameworkFactory.create(cloned_from=normal_af) member_af.add_member(user) # Without login - self.query_check(query, assert_for_error=True, variables={'id': normal_af.pk}) + self.query_check(query, assert_for_error=True, variables={"id": normal_af.pk}) # With login self.force_login(user) # Should work for normal AF - response = self.query_check(query, variables={'id': normal_af.pk})['data']['analysisFramework'] - self.assertIdEqual(response['id'], normal_af.id, response) - self.assertEqual(response['isPrivate'], False, response) + response = self.query_check(query, variables={"id": normal_af.pk})["data"]["analysisFramework"] + self.assertIdEqual(response["id"], normal_af.id, response) + self.assertEqual(response["isPrivate"], False, response) # Should work for member AF - response = self.query_check(query, variables={'id': member_af.pk})['data']['analysisFramework'] - self.assertIdEqual(response['id'], member_af.id, response) - self.assertEqual(response['isPrivate'], False, response) - self.assertEqual(response['clonedFrom'], str(normal_af.id), response) + response = self.query_check(query, variables={"id": member_af.pk})["data"]["analysisFramework"] + self.assertIdEqual(response["id"], member_af.id, response) + self.assertEqual(response["isPrivate"], False, response) + self.assertEqual(response["clonedFrom"], str(normal_af.id), response) # Shouldn't work for non-member private AF - response = self.query_check(query, variables={'id': private_af.pk})['data']['analysisFramework'] + response = self.query_check(query, variables={"id": private_af.pk})["data"]["analysisFramework"] self.assertEqual(response, None, response) # Shouldn't work for non-member private AF even if there is a project attached project = ProjectFactory.create(analysis_framework=private_af) - response = self.query_check(query, variables={'id': private_af.pk})['data']['analysisFramework'] + response = self.query_check(query, variables={"id": private_af.pk})["data"]["analysisFramework"] self.assertEqual(response, None, response) # Should work for member private AF project.add_member(user) - response = self.query_check(query, variables={'id': private_af.pk})['data']['analysisFramework'] - self.assertIdEqual(response['id'], private_af.id, response) - self.assertEqual(response['isPrivate'], True, response) + response = self.query_check(query, variables={"id": private_af.pk})["data"]["analysisFramework"] + self.assertIdEqual(response["id"], private_af.id, response) + self.assertEqual(response["isPrivate"], True, response) def test_analysis_framework_detail_query(self): - query = ''' + query = """ query MyQuery ($id: ID!) { analysisFramework(id: $id) { id @@ -208,7 +206,7 @@ def test_analysis_framework_detail_query(self): } } } - ''' + """ user = UserFactory.create() another_user = UserFactory.create() @@ -216,7 +214,7 @@ def test_analysis_framework_detail_query(self): af.add_member(another_user) def _query_check(**kwargs): - return self.query_check(query, variables={'id': af.pk}, **kwargs) + return self.query_check(query, variables={"id": af.pk}, **kwargs) # Without login _query_check(assert_for_error=True) @@ -225,18 +223,18 @@ def _query_check(**kwargs): self.force_login(user) # Should work for normal AF - response = _query_check()['data']['analysisFramework'] - self.assertEqual(len(response['secondaryTagging']), 0, response) - self.assertEqual(len(response['primaryTagging']), 0, response) + response = _query_check()["data"]["analysisFramework"] + self.assertEqual(len(response["secondaryTagging"]), 0, response) + self.assertEqual(len(response["primaryTagging"]), 0, response) # Let's add some widgets and sections sequence = factory.Sequence(lambda n: n) rsequence = factory.Sequence(lambda n: 20 - n) # Primary Tagging for order, widget_count, tooltip, _sequence in ( - (3, 2, 'Some tooltip info 101', sequence), - (1, 3, 'Some tooltip info 102', rsequence), - (2, 4, 'Some tooltip info 103', sequence), + (3, 2, "Some tooltip info 101", sequence), + (1, 3, "Some tooltip info 102", rsequence), + (2, 4, "Some tooltip info 103", sequence), ): section = SectionFactory.create(analysis_framework=af, order=order, tooltip=tooltip) WidgetFactory.create_batch(widget_count, analysis_framework=af, section=section, order=_sequence) @@ -246,16 +244,16 @@ def _query_check(**kwargs): # Let's save/compare snapshot (without membership) response = _query_check() - self.assertMatchSnapshot(response, 'without-membership') + self.assertMatchSnapshot(response, "without-membership") # Let's save/compare snapshot (with membership) af.add_member(user) response = _query_check() - self.assertMatchSnapshot(response, 'with-membership') + self.assertMatchSnapshot(response, "with-membership") def test_recent_analysis_framework(self): # NOTE: This test includes the recent_analysis_framework based on project and source - query = ''' + query = """ query MyQuery { projectExploreStats { topActiveFrameworks { @@ -266,7 +264,7 @@ def test_recent_analysis_framework(self): } } } - ''' + """ # lets create some analysis_framework ( @@ -301,33 +299,28 @@ def test_recent_analysis_framework(self): content = self.query_check(query) - self.assertEqual(len(content['data']['projectExploreStats']['topActiveFrameworks']), 5, content) + self.assertEqual(len(content["data"]["projectExploreStats"]["topActiveFrameworks"]), 5, content) self.assertEqual( - content['data']['projectExploreStats']['topActiveFrameworks'][0]['analysisFrameworkId'], - str(analysis_framework1.id) + content["data"]["projectExploreStats"]["topActiveFrameworks"][0]["analysisFrameworkId"], str(analysis_framework1.id) ) - self.assertEqual(content['data']['projectExploreStats']['topActiveFrameworks'][0]['projectCount'], 3) - self.assertEqual(content['data']['projectExploreStats']['topActiveFrameworks'][0]['sourceCount'], 65) + self.assertEqual(content["data"]["projectExploreStats"]["topActiveFrameworks"][0]["projectCount"], 3) + self.assertEqual(content["data"]["projectExploreStats"]["topActiveFrameworks"][0]["sourceCount"], 65) self.assertEqual( - content['data']['projectExploreStats']['topActiveFrameworks'][1]['analysisFrameworkId'], - str(analysis_framework3.id) + content["data"]["projectExploreStats"]["topActiveFrameworks"][1]["analysisFrameworkId"], str(analysis_framework3.id) ) - self.assertEqual(content['data']['projectExploreStats']['topActiveFrameworks'][1]['projectCount'], 2) + self.assertEqual(content["data"]["projectExploreStats"]["topActiveFrameworks"][1]["projectCount"], 2) self.assertEqual( - content['data']['projectExploreStats']['topActiveFrameworks'][2]['analysisFrameworkId'], - str(analysis_framework5.id) + content["data"]["projectExploreStats"]["topActiveFrameworks"][2]["analysisFrameworkId"], str(analysis_framework5.id) ) self.assertEqual( - content['data']['projectExploreStats']['topActiveFrameworks'][3]['analysisFrameworkId'], - str(analysis_framework6.id) + content["data"]["projectExploreStats"]["topActiveFrameworks"][3]["analysisFrameworkId"], str(analysis_framework6.id) ) self.assertEqual( - content['data']['projectExploreStats']['topActiveFrameworks'][4]['analysisFrameworkId'], - str(analysis_framework4.id) + content["data"]["projectExploreStats"]["topActiveFrameworks"][4]["analysisFrameworkId"], str(analysis_framework4.id) ) def test_analysis_framework_roles(self): - query = ''' + query = """ query MyQuery { analysisFrameworkRoles { title @@ -337,7 +330,7 @@ def test_analysis_framework_roles(self): isDefaultRole } } - ''' + """ user = UserFactory.create() # without login self.query_check(query, assert_for_error=True) @@ -346,4 +339,4 @@ def test_analysis_framework_roles(self): self.force_login(user) content = self.query_check(query) af_roles_count = AnalysisFrameworkRole.objects.all().count() - self.assertEqual(len(content['data']['analysisFrameworkRoles']), af_roles_count) + self.assertEqual(len(content["data"]["analysisFrameworkRoles"]), af_roles_count) diff --git a/apps/analysis_framework/utils.py b/apps/analysis_framework/utils.py index b436801eec..1aebbd16ed 100644 --- a/apps/analysis_framework/utils.py +++ b/apps/analysis_framework/utils.py @@ -1,4 +1,5 @@ -from analysis_framework.models import Widget, Filter, Exportable +from analysis_framework.models import Exportable, Filter, Widget + from .widgets.store import widget_store @@ -7,15 +8,15 @@ def update_widget(widget): widget_module = widget_store.get(widget.widget_id) if widget_module is None: - raise Exception(f'Unknown widget type: {widget.widget_id}') + raise Exception(f"Unknown widget type: {widget.widget_id}") new_filter_keys = [] - if hasattr(widget_module, 'get_filters'): + if hasattr(widget_module, "get_filters"): filters = widget_module.get_filters(widget, widget_properties) or [] for filter in filters: - filter_key = filter.get('key', widget.key) + filter_key = filter.get("key", widget.key) new_filter_keys.append(filter_key) - filter['title'] = filter.get('title', widget.title) + filter["title"] = filter.get("title", widget.title) Filter.objects.update_or_create( analysis_framework=widget.analysis_framework, widget_key=widget.key, @@ -24,7 +25,7 @@ def update_widget(widget): ) new_exportable_keys = [] - if hasattr(widget_module, 'get_exportable'): + if hasattr(widget_module, "get_exportable"): exportable = widget_module.get_exportable(widget, widget_properties) if exportable: new_exportable_keys.append(widget.key) @@ -32,7 +33,7 @@ def update_widget(widget): analysis_framework=widget.analysis_framework, widget_key=widget.key, defaults={ - 'data': exportable, + "data": exportable, }, ) diff --git a/apps/analysis_framework/views.py b/apps/analysis_framework/views.py index 3caf3f4418..5bbe96a0e8 100644 --- a/apps/analysis_framework/views.py +++ b/apps/analysis_framework/views.py @@ -1,36 +1,42 @@ -from django.utils import timezone from datetime import timedelta + import django_filters from django.db import models +from django.utils import timezone +from entry.models import Entry +from project.models import Project from rest_framework import ( exceptions, + filters, permissions, response, status, - filters, views, viewsets, ) from rest_framework.decorators import action -from deep.permissions import ModifyPermission + from deep.paginations import SmallSizeSetPagination +from deep.permissions import ModifyPermission -from project.models import Project -from entry.models import Entry +from .filter_set import AnalysisFrameworkFilterSet from .models import ( - AnalysisFramework, Widget, Filter, Exportable, + AnalysisFramework, AnalysisFrameworkMembership, AnalysisFrameworkRole, + Exportable, + Filter, + Widget, ) +from .permissions import FrameworkMembershipModifyPermission from .serializers import ( - AnalysisFrameworkSerializer, - WidgetSerializer, - FilterSerializer, ExportableSerializer, AnalysisFrameworkMembershipSerializer, AnalysisFrameworkRoleSerializer, + AnalysisFrameworkSerializer, + ExportableSerializer, + FilterSerializer, + WidgetSerializer, ) -from .filter_set import AnalysisFrameworkFilterSet -from .permissions import FrameworkMembershipModifyPermission class AnalysisFrameworkViewSet(viewsets.ModelViewSet): @@ -38,49 +44,49 @@ class AnalysisFrameworkViewSet(viewsets.ModelViewSet): permission_classes = [permissions.IsAuthenticated, ModifyPermission] filter_backends = ( django_filters.rest_framework.DjangoFilterBackend, - filters.SearchFilter, filters.OrderingFilter, + filters.SearchFilter, + filters.OrderingFilter, ) filterset_class = AnalysisFrameworkFilterSet - search_fields = ('title', 'description',) + search_fields = ( + "title", + "description", + ) def get_queryset(self): query_params = self.request.query_params - queryset = AnalysisFramework.get_for(self.request.user).select_related('organization') + queryset = AnalysisFramework.get_for(self.request.user).select_related("organization") month_ago = timezone.now() - timedelta(days=30) - activity_param = query_params.get('activity') + activity_param = query_params.get("activity") # Active/Inactive Filter - if activity_param in ['active', 'inactive']: + if activity_param in ["active", "inactive"]: queryset = queryset.annotate( recent_entry_exists=models.Exists( Entry.objects.filter( - analysis_framework_id=models.OuterRef('id'), + analysis_framework_id=models.OuterRef("id"), modified_at__date__gt=month_ago, ) ), ).filter( - recent_entry_exists=activity_param.lower() == 'active', + recent_entry_exists=activity_param.lower() == "active", ) # Owner Filter - if query_params.get('relatedToMe', 'false').lower() == 'true': + if query_params.get("relatedToMe", "false").lower() == "true": queryset = queryset.filter(members=self.request.user) return queryset @action( detail=True, - url_path='memberships', - methods=['get'], + url_path="memberships", + methods=["get"], ) def get_memberships(self, request, pk=None, version=None): framework = self.get_object() - memberships = AnalysisFrameworkMembership.objects.filter(framework=framework).select_related( - 'member', 'role', 'added_by' - ) + memberships = AnalysisFrameworkMembership.objects.filter(framework=framework).select_related("member", "role", "added_by") serializer = AnalysisFrameworkMembershipSerializer( - self.paginate_queryset(memberships), - context={'request': request}, - many=True + self.paginate_queryset(memberships), context={"request": request}, many=True ) return self.get_paginated_response(serializer.data) @@ -89,22 +95,20 @@ class AnalysisFrameworkCloneView(views.APIView): permission_classes = [permissions.IsAuthenticated] def post(self, request, af_id, version=None): - if not AnalysisFramework.objects.filter( - id=af_id - ).exists(): + if not AnalysisFramework.objects.filter(id=af_id).exists(): raise exceptions.NotFound() - analysis_framework = AnalysisFramework.objects.get( - id=af_id - ) + analysis_framework = AnalysisFramework.objects.get(id=af_id) if not analysis_framework.can_clone(request.user): raise exceptions.PermissionDenied() - cloned_title = request.data.get('title') + cloned_title = request.data.get("title") if not cloned_title: - raise exceptions.ValidationError({ - 'title': 'Title should be present', - }) + raise exceptions.ValidationError( + { + "title": "Title should be present", + } + ) new_af = analysis_framework.clone( request.user, @@ -115,16 +119,18 @@ def post(self, request, af_id, version=None): serializer = AnalysisFrameworkSerializer( new_af, - context={'request': request}, + context={"request": request}, ) - project = request.data.get('project') + project = request.data.get("project") if project: project = Project.objects.get(id=project) if not project.can_modify(request.user): - raise exceptions.ValidationError({ - 'project': 'Invalid project', - }) + raise exceptions.ValidationError( + { + "project": "Invalid project", + } + ) project.analysis_framework = new_af project.modified_by = request.user project.save() @@ -137,8 +143,7 @@ def post(self, request, af_id, version=None): class WidgetViewSet(viewsets.ModelViewSet): serializer_class = WidgetSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_queryset(self): return Widget.get_for(self.request.user) @@ -146,8 +151,7 @@ def get_queryset(self): class FilterViewSet(viewsets.ModelViewSet): serializer_class = FilterSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_queryset(self): return Filter.get_for(self.request.user) @@ -155,8 +159,7 @@ def get_queryset(self): class ExportableViewSet(viewsets.ModelViewSet): serializer_class = ExportableSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_queryset(self): return Exportable.get_for(self.request.user) @@ -164,21 +167,18 @@ def get_queryset(self): class AnalysisFrameworkMembershipViewSet(viewsets.ModelViewSet): serializer_class = AnalysisFrameworkMembershipSerializer - permission_classes = [permissions.IsAuthenticated, - FrameworkMembershipModifyPermission] + permission_classes = [permissions.IsAuthenticated, FrameworkMembershipModifyPermission] pagination_class = SmallSizeSetPagination def get_queryset(self): - return AnalysisFrameworkMembership.get_for(self.request.user).select_related( - 'member', 'role', 'added_by' - ) + return AnalysisFrameworkMembership.get_for(self.request.user).select_related("member", "role", "added_by") def destroy(self, request, *args, **kwargs): instance = self.get_object() # Don't let user delete him/herself if request.user == instance.member: return response.Response( - {'message': 'You cannot remove yourself from framework'}, + {"message": "You cannot remove yourself from framework"}, status=status.HTTP_403_FORBIDDEN, ) @@ -199,8 +199,8 @@ class PublicAnalysisFrameworkRoleViewSet(viewsets.ReadOnlyModelViewSet): permission_classes = [permissions.IsAuthenticated] def get_queryset(self): - no_default_role = self.request.query_params.get('is_default_role', 'true') == 'false' - extra = {} if not no_default_role else {'is_default_role': False} + no_default_role = self.request.query_params.get("is_default_role", "true") == "false" + extra = {} if not no_default_role else {"is_default_role": False} return AnalysisFrameworkRole.objects.filter( is_private_role=False, diff --git a/apps/analysis_framework/widgets/conditional_widget.py b/apps/analysis_framework/widgets/conditional_widget.py index 6282da8c1e..ffbdca3411 100644 --- a/apps/analysis_framework/widgets/conditional_widget.py +++ b/apps/analysis_framework/widgets/conditional_widget.py @@ -1,4 +1,4 @@ -WIDGET_ID = 'conditionalWidget' +WIDGET_ID = "conditionalWidget" class DummyWidget: @@ -9,75 +9,70 @@ def __init__(self, kwargs): def get_nested_filters(original_widget, widget): from analysis_framework.widgets.store import widget_store - widget_data = widget.properties and widget.properties.get('data') + + widget_data = widget.properties and widget.properties.get("data") widget_module = widget_store.get(widget.widget_id) - if hasattr(widget_module, 'get_filters'): - filters = widget_module.get_filters( - widget, - widget_data or {}, - ) or [] - - return [{ - **filter, - 'title': '{} - {}'.format( - original_widget.title, - filter.get('title', widget.title) - ), - 'key': '{}-{}'.format( - original_widget.key, - filter.get('key', widget.key), - ), - } for filter in filters] + if hasattr(widget_module, "get_filters"): + filters = ( + widget_module.get_filters( + widget, + widget_data or {}, + ) + or [] + ) + + return [ + { + **filter, + "title": "{} - {}".format(original_widget.title, filter.get("title", widget.title)), + "key": "{}-{}".format( + original_widget.key, + filter.get("key", widget.key), + ), + } + for filter in filters + ] return [] def get_filters(original_widget, data): - widgets = data.get('widgets') or [] + widgets = data.get("widgets") or [] filters = [] for w in widgets: - widget = DummyWidget(w.get('widget')) + widget = DummyWidget(w.get("widget")) filters = filters + get_nested_filters(original_widget, widget) return filters def get_nested_exportable(widget): from analysis_framework.widgets.store import widget_store - widget_data = widget.properties and widget.properties.get('data') + + widget_data = widget.properties and widget.properties.get("data") widget_module = widget_store.get(widget.widget_id) - if hasattr(widget_module, 'get_exportable'): - return widget_module.get_exportable( - widget, - widget_data or {}, - ) or {} + if hasattr(widget_module, "get_exportable"): + return ( + widget_module.get_exportable( + widget, + widget_data or {}, + ) + or {} + ) return {} def get_exportable(widget, data): - widgets = data.get('widgets') or [] + widgets = data.get("widgets") or [] exportables = [] for w in widgets: - widget = DummyWidget(w.get('widget')) + widget = DummyWidget(w.get("widget")) exportables.append(get_nested_exportable(widget)) return { - 'excel': { - 'type': 'nested', - 'title': widget.title, - 'children': [ - e.get('excel') - for e in exportables - if e.get('excel') - ] + "excel": {"type": "nested", "title": widget.title, "children": [e.get("excel") for e in exportables if e.get("excel")]}, + "report": { + "levels": [level for e in exportables if e.get("report", {}).get("levels") for level in e["report"]["levels"]] }, - 'report': { - 'levels': [ - level - for e in exportables - if e.get('report', {}).get('levels') - for level in e['report']['levels'] - ] - } } diff --git a/apps/analysis_framework/widgets/date_range_widget.py b/apps/analysis_framework/widgets/date_range_widget.py index a865b63910..253ccbcc8b 100644 --- a/apps/analysis_framework/widgets/date_range_widget.py +++ b/apps/analysis_framework/widgets/date_range_widget.py @@ -1,28 +1,30 @@ -WIDGET_ID = 'dateRangeWidget' +WIDGET_ID = "dateRangeWidget" def get_filters(widget, properties): from analysis_framework.models import Filter # To avoid circular import - return [{ - 'filter_type': Filter.FilterType.INTERSECTS, - 'properties': { - 'type': 'date', - }, - }] + return [ + { + "filter_type": Filter.FilterType.INTERSECTS, + "properties": { + "type": "date", + }, + } + ] def get_exportable(widget, properties): return { - 'excel': { - 'type': 'multiple', - 'titles': [ - '{} (From)'.format(widget.title), - '{} (To)'.format(widget.title), + "excel": { + "type": "multiple", + "titles": [ + "{} (From)".format(widget.title), + "{} (To)".format(widget.title), ], - 'col_type': [ - 'date', - 'date', + "col_type": [ + "date", + "date", ], }, } diff --git a/apps/analysis_framework/widgets/date_widget.py b/apps/analysis_framework/widgets/date_widget.py index 7ab18bc8a3..5fbff5d766 100644 --- a/apps/analysis_framework/widgets/date_widget.py +++ b/apps/analysis_framework/widgets/date_widget.py @@ -1,20 +1,23 @@ -WIDGET_ID = 'dateWidget' +WIDGET_ID = "dateWidget" def get_filters(widget, properties): from analysis_framework.models import Filter # To avoid circular import - return [{ - 'filter_type': Filter.FilterType.NUMBER, - 'properties': { - 'type': 'date', - }, - }] + + return [ + { + "filter_type": Filter.FilterType.NUMBER, + "properties": { + "type": "date", + }, + } + ] def get_exportable(widget, properties): return { - 'excel': { - 'title': widget.title, - 'col_type': 'date', + "excel": { + "title": widget.title, + "col_type": "date", }, } diff --git a/apps/analysis_framework/widgets/geo_widget.py b/apps/analysis_framework/widgets/geo_widget.py index c7ac8174b9..d26b762aee 100644 --- a/apps/analysis_framework/widgets/geo_widget.py +++ b/apps/analysis_framework/widgets/geo_widget.py @@ -1,20 +1,23 @@ -WIDGET_ID = 'geoWidget' +WIDGET_ID = "geoWidget" def get_filters(widget, properties): from analysis_framework.models import Filter # To avoid circular import - return [{ - 'filter_type': Filter.FilterType.LIST, - 'properties': { - 'type': 'geo', - }, - }] + + return [ + { + "filter_type": Filter.FilterType.LIST, + "properties": { + "type": "geo", + }, + } + ] def get_exportable(widget, properties): return { - 'excel': { - 'type': 'geo', - 'title': widget.title, + "excel": { + "type": "geo", + "title": widget.title, }, } diff --git a/apps/analysis_framework/widgets/matrix1d_widget.py b/apps/analysis_framework/widgets/matrix1d_widget.py index afb367e53a..989a067d78 100644 --- a/apps/analysis_framework/widgets/matrix1d_widget.py +++ b/apps/analysis_framework/widgets/matrix1d_widget.py @@ -1,4 +1,4 @@ -WIDGET_ID = 'matrix1dWidget' +WIDGET_ID = "matrix1dWidget" """ @@ -23,59 +23,67 @@ def get_filters(widget, properties): from analysis_framework.models import Filter # To avoid circular import - rows = properties.get('rows', []) + rows = properties.get("rows", []) filter_options = [] for row in rows: - filter_options.append({ - 'label': row.get('label'), - 'key': row.get('key'), - }) - cells = row.get('cells', []) + filter_options.append( + { + "label": row.get("label"), + "key": row.get("key"), + } + ) + cells = row.get("cells", []) for cell in cells: - filter_options.append({ - 'label': '{} / {}'.format( - row.get('label'), - cell.get('label'), - ), - 'key': cell.get('key'), - }) + filter_options.append( + { + "label": "{} / {}".format( + row.get("label"), + cell.get("label"), + ), + "key": cell.get("key"), + } + ) - return [{ - 'filter_type': Filter.FilterType.LIST, - 'properties': { - 'type': 'multiselect', - 'options': filter_options, - }, - }] + return [ + { + "filter_type": Filter.FilterType.LIST, + "properties": { + "type": "multiselect", + "options": filter_options, + }, + } + ] def get_exportable(widget, properties): - rows = properties.get('rows', []) + rows = properties.get("rows", []) excel = { - 'type': 'multiple', - 'titles': [ - '{} - Dimension'.format(widget.title), - '{} - Subdimension'.format(widget.title), + "type": "multiple", + "titles": [ + "{} - Dimension".format(widget.title), + "{} - Subdimension".format(widget.title), ], } report = { - 'levels': [ + "levels": [ { - 'id': row.get('key'), - 'title': row.get('label'), - 'sublevels': [ + "id": row.get("key"), + "title": row.get("label"), + "sublevels": [ { - 'id': '{}-{}'.format(row.get('key'), cell.get('key')), - 'title': cell.get('label'), - } for cell in row.get('cells', []) + "id": "{}-{}".format(row.get("key"), cell.get("key")), + "title": cell.get("label"), + } + for cell in row.get("cells", []) ], - } for row in rows + } + for row in rows ], } return { - 'excel': excel, - 'report': report, + "excel": excel, + "report": report, } diff --git a/apps/analysis_framework/widgets/matrix2d_widget.py b/apps/analysis_framework/widgets/matrix2d_widget.py index 78e342c648..cbca77cfcb 100644 --- a/apps/analysis_framework/widgets/matrix2d_widget.py +++ b/apps/analysis_framework/widgets/matrix2d_widget.py @@ -1,4 +1,4 @@ -WIDGET_ID = 'matrix2dWidget' +WIDGET_ID = "matrix2dWidget" """ @@ -43,103 +43,116 @@ def get_filters(widget, properties): from analysis_framework.models import Filter # To avoid circular import row_options = [] - rows = properties.get('rows', []) + rows = properties.get("rows", []) for row in rows: - row_options.append({ - 'label': row.get('label'), - 'key': row.get('key'), - }) + row_options.append( + { + "label": row.get("label"), + "key": row.get("key"), + } + ) - sub_rows = row.get('subRows', []) + sub_rows = row.get("subRows", []) for sub_row in sub_rows: - row_options.append({ - 'label': '{} / {}'.format( - row.get('label'), - sub_row.get('label'), - ), - 'key': sub_row.get('key'), - }) + row_options.append( + { + "label": "{} / {}".format( + row.get("label"), + sub_row.get("label"), + ), + "key": sub_row.get("key"), + } + ) column_options = [] - columns = properties.get('columns', []) + columns = properties.get("columns", []) for column in columns: - column_options.append({ - 'label': column.get('label'), - 'key': column.get('key'), - }) + column_options.append( + { + "label": column.get("label"), + "key": column.get("key"), + } + ) - subcolumns = column.get('subColumns', []) + subcolumns = column.get("subColumns", []) for subcolumn in subcolumns: - column_options.append({ - 'label': '{} / {}'.format( - column.get('label'), - subcolumn.get('label'), - ), - 'key': subcolumn.get('key'), - }) + column_options.append( + { + "label": "{} / {}".format( + column.get("label"), + subcolumn.get("label"), + ), + "key": subcolumn.get("key"), + } + ) - return [{ - 'title': '{} Rows'.format(widget.title), - 'filter_type': Filter.FilterType.LIST, - 'key': '{}-rows'.format(widget.key), - 'properties': { - 'type': 'multiselect', - 'options': row_options, + return [ + { + "title": "{} Rows".format(widget.title), + "filter_type": Filter.FilterType.LIST, + "key": "{}-rows".format(widget.key), + "properties": { + "type": "multiselect", + "options": row_options, + }, }, - }, { - 'title': '{} Columns'.format(widget.title), - 'filter_type': Filter.FilterType.LIST, - 'key': '{}-columns'.format(widget.key), - 'properties': { - 'type': 'multiselect', - 'options': column_options, + { + "title": "{} Columns".format(widget.title), + "filter_type": Filter.FilterType.LIST, + "key": "{}-columns".format(widget.key), + "properties": { + "type": "multiselect", + "options": column_options, + }, }, - }] + ] def get_exportable(widget, properties): excel = { - 'type': 'multiple', - 'titles': [ - '{} - Row'.format(widget.title), - '{} - SubRow'.format(widget.title), - '{} - Column'.format(widget.title), - '{} - SubColumns'.format(widget.title), + "type": "multiple", + "titles": [ + "{} - Row".format(widget.title), + "{} - SubRow".format(widget.title), + "{} - Column".format(widget.title), + "{} - SubColumns".format(widget.title), ], } report = { - 'levels': [ + "levels": [ { - 'id': column.get('key'), - 'title': column.get('label'), - 'sublevels': [ + "id": column.get("key"), + "title": column.get("label"), + "sublevels": [ { - 'id': '{}-{}'.format( - column.get('key'), - row.get('key'), + "id": "{}-{}".format( + column.get("key"), + row.get("key"), ), - 'title': row.get('label'), - 'sublevels': [ + "title": row.get("label"), + "sublevels": [ { - 'id': '{}-{}-{}'.format( - column.get('key'), - row.get('key'), - sub_row.get('key'), + "id": "{}-{}-{}".format( + column.get("key"), + row.get("key"), + sub_row.get("key"), ), - 'title': sub_row.get('label'), - } for sub_row - in row.get('subRows', []) - ] - } for row in properties.get('rows', []) + "title": sub_row.get("label"), + } + for sub_row in row.get("subRows", []) + ], + } + for row in properties.get("rows", []) ], - } for column in properties.get('columns', []) + } + for column in properties.get("columns", []) ], } return { - 'excel': excel, - 'report': report, + "excel": excel, + "report": report, } diff --git a/apps/analysis_framework/widgets/multiselect_widget.py b/apps/analysis_framework/widgets/multiselect_widget.py index 1a346ff73f..0472539a39 100644 --- a/apps/analysis_framework/widgets/multiselect_widget.py +++ b/apps/analysis_framework/widgets/multiselect_widget.py @@ -1,4 +1,4 @@ -WIDGET_ID = 'multiselectWidget' +WIDGET_ID = "multiselectWidget" """ properties: @@ -16,23 +16,25 @@ def get_filters(widget, properties): filter_options = [ { - 'key': option['key'], - 'label': option['label'], + "key": option["key"], + "label": option["label"], + } + for option in properties.get("options", []) + ] + return [ + { + "filter_type": Filter.FilterType.LIST, + "properties": { + "type": "multiselect", + "options": filter_options, + }, } - for option in properties.get('options', []) ] - return [{ - 'filter_type': Filter.FilterType.LIST, - 'properties': { - 'type': 'multiselect', - 'options': filter_options, - }, - }] def get_exportable(widget, properties): return { - 'excel': { - 'title': widget.title, + "excel": { + "title": widget.title, }, } diff --git a/apps/analysis_framework/widgets/number_matrix_widget.py b/apps/analysis_framework/widgets/number_matrix_widget.py index 1163849763..d9258fca12 100644 --- a/apps/analysis_framework/widgets/number_matrix_widget.py +++ b/apps/analysis_framework/widgets/number_matrix_widget.py @@ -1,37 +1,42 @@ -WIDGET_ID = 'numberMatrixWidget' +WIDGET_ID = "numberMatrixWidget" # NOTE: THIS IS REMOVED FROM NEW UI + def get_filters(widget, properties): from analysis_framework.models import Filter # To avoid circular import - return [{ - 'filter_type': Filter.FilterType.NUMBER, - 'properties': { - 'type': 'number-2d', - }, - }] + return [ + { + "filter_type": Filter.FilterType.NUMBER, + "properties": { + "type": "number-2d", + }, + } + ] def get_exportable(widget, properties): titles = [] - row_headers = properties.get('row_headers', []) + row_headers = properties.get("row_headers", []) for row_header in row_headers: - column_headers = properties.get('column_headers', []) + column_headers = properties.get("column_headers", []) for column_header in column_headers: - titles.append('{} - {}'.format( - row_header.get('title'), - column_header.get('title'), - )) + titles.append( + "{} - {}".format( + row_header.get("title"), + column_header.get("title"), + ) + ) - titles.append('{} - Matches'.format(row_header.get('title'))) + titles.append("{} - Matches".format(row_header.get("title"))) return { - 'excel': { - 'type': 'multiple', - 'titles': titles, + "excel": { + "type": "multiple", + "titles": titles, # TODO: col_type to list full of 'number' }, } diff --git a/apps/analysis_framework/widgets/number_widget.py b/apps/analysis_framework/widgets/number_widget.py index 2975468ac8..0fa4c434b1 100644 --- a/apps/analysis_framework/widgets/number_widget.py +++ b/apps/analysis_framework/widgets/number_widget.py @@ -1,21 +1,23 @@ -WIDGET_ID = 'numberWidget' +WIDGET_ID = "numberWidget" def get_filters(widget, properties): from analysis_framework.models import Filter # To avoid circular import - return [{ - 'filter_type': Filter.FilterType.NUMBER, - 'properties': { - 'type': 'number', - }, - }] + return [ + { + "filter_type": Filter.FilterType.NUMBER, + "properties": { + "type": "number", + }, + } + ] def get_exportable(widget, properties): return { - 'excel': { - 'title': widget.title, - 'col_type': 'number', + "excel": { + "title": widget.title, + "col_type": "number", }, } diff --git a/apps/analysis_framework/widgets/organigram_widget.py b/apps/analysis_framework/widgets/organigram_widget.py index 49fbf41bde..77177dcc73 100644 --- a/apps/analysis_framework/widgets/organigram_widget.py +++ b/apps/analysis_framework/widgets/organigram_widget.py @@ -1,4 +1,4 @@ -WIDGET_ID = 'organigramWidget' +WIDGET_ID = "organigramWidget" """ properties: @@ -12,19 +12,19 @@ def get_values_for_organ(organ, parent_label=None): - label = organ.get('label', '') + label = organ.get("label", "") if parent_label: - label = '{} / {}'.format(parent_label, label) + label = "{} / {}".format(parent_label, label) - values = [{ - 'key': organ.get('key'), - 'label': label, - }] + values = [ + { + "key": organ.get("key"), + "label": label, + } + ] - for organ in organ.get('children') or []: - values.extend( - get_values_for_organ(organ, label) - ) + for organ in organ.get("children") or []: + values.extend(get_values_for_organ(organ, label)) return values @@ -33,39 +33,34 @@ def get_filters(widget, properties): from analysis_framework.models import Filter # To avoid circular import options = [] - raw_options = properties and properties.get('options') + raw_options = properties and properties.get("options") if raw_options: options = get_values_for_organ(raw_options, None) - return [{ - 'filter_type': Filter.FilterType.LIST, - 'properties': { - 'type': 'multiselect', - 'options': options, - }, - }] + return [ + { + "filter_type": Filter.FilterType.LIST, + "properties": { + "type": "multiselect", + "options": options, + }, + } + ] def get_exportable(widget, properties): def _get_depth(organ, level=1): - child_organs = organ.get('children') or [] + child_organs = organ.get("children") or [] if len(child_organs) == 0: return level depths = [] for c_organ in child_organs: - depths.append( - _get_depth(c_organ, level=level + 1) - ) + depths.append(_get_depth(c_organ, level=level + 1)) return max(depths) - options = (properties and properties.get('options')) or {} + options = (properties and properties.get("options")) or {} return { - 'excel': { - 'type': 'multiple', - 'titles': [ - f'{widget.title} - Level {level}' - for level in range( - _get_depth(options) - ) - ], + "excel": { + "type": "multiple", + "titles": [f"{widget.title} - Level {level}" for level in range(_get_depth(options))], }, } diff --git a/apps/analysis_framework/widgets/scale_widget.py b/apps/analysis_framework/widgets/scale_widget.py index a7f97fc5e9..132c63c262 100644 --- a/apps/analysis_framework/widgets/scale_widget.py +++ b/apps/analysis_framework/widgets/scale_widget.py @@ -1,4 +1,4 @@ -WIDGET_ID = 'scaleWidget' +WIDGET_ID = "scaleWidget" """ @@ -19,22 +19,22 @@ def get_filters(widget, properties): filter_options = [ { - 'key': option['key'], - 'label': option['label'], - } for option in properties.get('options', []) + "key": option["key"], + "label": option["label"], + } + for option in properties.get("options", []) + ] + return [ + { + "filter_type": Filter.FilterType.LIST, + "properties": {"type": "multiselect-range", "options": filter_options}, + } ] - return [{ - 'filter_type': Filter.FilterType.LIST, - 'properties': { - 'type': 'multiselect-range', - 'options': filter_options - }, - }] def get_exportable(widget, data): return { - 'excel': { - 'title': widget.title, + "excel": { + "title": widget.title, }, } diff --git a/apps/analysis_framework/widgets/select_widget.py b/apps/analysis_framework/widgets/select_widget.py index 2e4d48e269..87764c943e 100644 --- a/apps/analysis_framework/widgets/select_widget.py +++ b/apps/analysis_framework/widgets/select_widget.py @@ -1,4 +1,4 @@ -WIDGET_ID = 'selectWidget' +WIDGET_ID = "selectWidget" """ properties: @@ -17,23 +17,25 @@ def get_filters(widget, properties): filter_options = [ { - 'key': option['key'], - 'label': option['label'], + "key": option["key"], + "label": option["label"], + } + for option in properties.get("options", []) + ] + return [ + { + "filter_type": Filter.FilterType.LIST, + "properties": { + "type": "multiselect", + "options": filter_options, + }, } - for option in properties.get('options', []) ] - return [{ - 'filter_type': Filter.FilterType.LIST, - 'properties': { - 'type': 'multiselect', - 'options': filter_options, - }, - }] def get_exportable(widget, properties): return { - 'excel': { - 'title': widget.title, + "excel": { + "title": widget.title, }, } diff --git a/apps/analysis_framework/widgets/store.py b/apps/analysis_framework/widgets/store.py index ceda527609..b9dec2301a 100644 --- a/apps/analysis_framework/widgets/store.py +++ b/apps/analysis_framework/widgets/store.py @@ -1,22 +1,21 @@ from . import ( - date_widget, + conditional_widget, date_range_widget, - time_widget, - time_range_widget, - number_widget, - scale_widget, - select_widget, - multiselect_widget, + date_widget, geo_widget, - organigram_widget, matrix1d_widget, matrix2d_widget, + multiselect_widget, number_matrix_widget, - conditional_widget, + number_widget, + organigram_widget, + scale_widget, + select_widget, text_widget, + time_range_widget, + time_widget, ) - widget_store = { widget.WIDGET_ID: widget for widget in ( diff --git a/apps/analysis_framework/widgets/text_widget.py b/apps/analysis_framework/widgets/text_widget.py index 871b8cec6c..55d7ce9246 100644 --- a/apps/analysis_framework/widgets/text_widget.py +++ b/apps/analysis_framework/widgets/text_widget.py @@ -1,20 +1,22 @@ -WIDGET_ID = 'textWidget' +WIDGET_ID = "textWidget" def get_filters(widget, properties): from analysis_framework.models import Filter # To avoid circular import - return [{ - 'filter_type': Filter.FilterType.TEXT, - 'properties': { - 'type': 'text', - }, - }] + return [ + { + "filter_type": Filter.FilterType.TEXT, + "properties": { + "type": "text", + }, + } + ] def get_exportable(widget, properties): return { - 'excel': { - 'title': widget.title, + "excel": { + "title": widget.title, }, } diff --git a/apps/analysis_framework/widgets/time_range_widget.py b/apps/analysis_framework/widgets/time_range_widget.py index 0420070202..9315b7e574 100644 --- a/apps/analysis_framework/widgets/time_range_widget.py +++ b/apps/analysis_framework/widgets/time_range_widget.py @@ -1,28 +1,30 @@ -WIDGET_ID = 'timeRangeWidget' +WIDGET_ID = "timeRangeWidget" def get_filters(widget, properties): from analysis_framework.models import Filter # To avoid circular import - return [{ - 'filter_type': Filter.FilterType.INTERSECTS, - 'properties': { - 'type': 'time', - }, - }] + return [ + { + "filter_type": Filter.FilterType.INTERSECTS, + "properties": { + "type": "time", + }, + } + ] def get_exportable(widget, properties): return { - 'excel': { - 'type': 'multiple', - 'titles': [ - '{} (From)'.format(widget.title), - '{} (To)'.format(widget.title), + "excel": { + "type": "multiple", + "titles": [ + "{} (From)".format(widget.title), + "{} (To)".format(widget.title), ], - 'col_type': [ - 'time', - 'time', + "col_type": [ + "time", + "time", ], }, } diff --git a/apps/analysis_framework/widgets/time_widget.py b/apps/analysis_framework/widgets/time_widget.py index 167d6811cb..6932778b09 100644 --- a/apps/analysis_framework/widgets/time_widget.py +++ b/apps/analysis_framework/widgets/time_widget.py @@ -1,21 +1,23 @@ -WIDGET_ID = 'timeWidget' +WIDGET_ID = "timeWidget" def get_filters(widget, properties): from analysis_framework.models import Filter # To avoid circular import - return [{ - 'filter_type': Filter.FilterType.NUMBER, - 'properties': { - 'type': 'time', - }, - }] + return [ + { + "filter_type": Filter.FilterType.NUMBER, + "properties": { + "type": "time", + }, + } + ] def get_exportable(widget, properties): return { - 'excel': { - 'title': widget.title, - 'col_type': 'time', + "excel": { + "title": widget.title, + "col_type": "time", }, } diff --git a/apps/ary/admin.py b/apps/ary/admin.py index da51c4428a..93e0b09b14 100644 --- a/apps/ary/admin.py +++ b/apps/ary/admin.py @@ -1,45 +1,38 @@ from django.contrib import admin -from django.urls import path from django.http import HttpResponse +from django.urls import path -from deep.admin import linkify, ModelAdmin, VersionAdmin +from deep.admin import ModelAdmin, VersionAdmin, linkify from .management.commands.export_ary_template import export_ary_fixture from .models import ( + AffectedGroup, + AffectedLocation, + Assessment, AssessmentTemplate, - + Focus, MetadataField, MetadataGroup, MetadataOption, - MethodologyField, MethodologyGroup, MethodologyOption, - - Sector, - Focus, - AffectedGroup, - UnderlyingFactor, - - PrioritySector, PriorityIssue, - SpecificNeedGroup, - AffectedLocation, - + PrioritySector, ScoreBucket, - ScorePillar, - ScoreQuestion, - ScoreScale, + ScoreMatrixColumn, ScoreMatrixPillar, ScoreMatrixRow, - ScoreMatrixColumn, ScoreMatrixScale, - + ScorePillar, + ScoreQuestion, + ScoreQuestionnaire, ScoreQuestionnaireSector, ScoreQuestionnaireSubSector, - ScoreQuestionnaire, - - Assessment, + ScoreScale, + Sector, + SpecificNeedGroup, + UnderlyingFactor, ) @@ -50,23 +43,23 @@ class ScoreBucketInline(admin.TabularInline): @admin.register(AssessmentTemplate) class AnalysisFrameworkTemplateAdmin(VersionAdmin): - change_list_template = 'ary/ary_change_list.html' - search_fields = ('title',) + change_list_template = "ary/ary_change_list.html" + search_fields = ("title",) inlines = [ScoreBucketInline] - autocomplete_fields = ('created_by', 'modified_by',) + autocomplete_fields = ( + "created_by", + "modified_by", + ) def get_urls(self): info = self.model._meta.app_label, self.model._meta.model_name return [ - path( - 'export/', self.admin_site.admin_view(self.export_ary), - name='{}_{}_export'.format(*info) - ), + path("export/", self.admin_site.admin_view(self.export_ary), name="{}_{}_export".format(*info)), ] + super().get_urls() def export_ary(self, request): content = export_ary_fixture() - return HttpResponse(content, content_type='application/json') + return HttpResponse(content, content_type="application/json") class MetadataOptionInline(admin.TabularInline): @@ -102,13 +95,13 @@ class ScoreMatrixScaleInline(admin.TabularInline): @admin.register(ScorePillar) class ScorePillarAdmin(ModelAdmin): inlines = [ScoreQuestionInline] - list_display = ('title', linkify('template'), 'order', 'weight') + list_display = ("title", linkify("template"), "order", "weight") @admin.register(ScoreMatrixPillar) class ScoreMatrixPillarAdmin(ModelAdmin): inlines = [ScoreMatrixRowInline, ScoreMatrixColumnInline, ScoreMatrixScaleInline] - list_display = ('title', linkify('template'), 'order', 'weight') + list_display = ("title", linkify("template"), "order", "weight") class ScoreQuestionnaireSubSectorInline(admin.TabularInline): @@ -123,23 +116,32 @@ class ScoreQuestionnaireInline(admin.TabularInline): @admin.register(ScoreQuestionnaireSector) class ScoreQuestionnaireSectorAdmin(ModelAdmin): - list_display = ('title', 'order', 'method', 'sub_method', linkify('template')) + list_display = ("title", "order", "method", "sub_method", linkify("template")) inlines = [ScoreQuestionnaireSubSectorInline] @admin.register(ScoreQuestionnaireSubSector) class ScoreQuestionnaireSubSectorAdmin(ModelAdmin): - list_display = ('title', 'order', linkify('sector'), linkify('sector.template')) + list_display = ("title", "order", linkify("sector"), linkify("sector.template")) inlines = [ScoreQuestionnaireInline] @admin.register(AffectedGroup) class AffectedGroupAdmin(ModelAdmin): - list_display = ('title', 'order', linkify('template'),) + list_display = ( + "title", + "order", + linkify("template"), + ) -class FieldAdminMixin(): - list_display = ('title', 'id', 'order', linkify('group'),) +class FieldAdminMixin: + list_display = ( + "title", + "id", + "order", + linkify("group"), + ) @admin.register(MetadataField) @@ -152,10 +154,13 @@ class MethodologyFieldAdmin(FieldAdminMixin, ModelAdmin): inlines = [MethodologyOptionInline] -class TemplateGroupAdminMixin(): - search_fields = ('title', 'template__title') - list_display = ('title', linkify('template'),) - list_filter = ('template',) +class TemplateGroupAdminMixin: + search_fields = ("title", "template__title") + list_display = ( + "title", + linkify("template"), + ) + list_filter = ("template",) @admin.register(Focus) @@ -210,6 +215,9 @@ class ScoreScaleAdmin(TemplateGroupAdminMixin, ModelAdmin): @admin.register(Assessment) class AssessmentAdmin(VersionAdmin): - search_fields = ('lead__title',) - list_display = ('lead', linkify('project'),) - autocomplete_fields = ('lead', 'project', 'created_by', 'modified_by', 'lead_group') + search_fields = ("lead__title",) + list_display = ( + "lead", + linkify("project"), + ) + autocomplete_fields = ("lead", "project", "created_by", "modified_by", "lead_group") diff --git a/apps/ary/apps.py b/apps/ary/apps.py index 56fd16acd1..cbfca50978 100644 --- a/apps/ary/apps.py +++ b/apps/ary/apps.py @@ -2,4 +2,4 @@ class AryConfig(AppConfig): - name = 'ary' + name = "ary" diff --git a/apps/ary/enums.py b/apps/ary/enums.py index 5d745a8d58..3bab21e61d 100644 --- a/apps/ary/enums.py +++ b/apps/ary/enums.py @@ -2,10 +2,10 @@ from .models import MethodologyProtectionInfo - AssessmentMethodologyProtectionInfoEnum = convert_enum_to_graphene_enum( - MethodologyProtectionInfo, name='AssessmentMethodologyProtectionInfoEnum') + MethodologyProtectionInfo, name="AssessmentMethodologyProtectionInfoEnum" +) enum_map = { - 'UnusedAssessmentMethodologyProtectionInfo': AssessmentMethodologyProtectionInfoEnum, + "UnusedAssessmentMethodologyProtectionInfo": AssessmentMethodologyProtectionInfoEnum, } diff --git a/apps/ary/export/__init__.py b/apps/ary/export/__init__.py index 21e2b97ddc..98b7190fbd 100644 --- a/apps/ary/export/__init__.py +++ b/apps/ary/export/__init__.py @@ -1,26 +1,16 @@ from functools import reduce -from .common import ( - get_assessment_meta, - default_values as common_defaults, -) -from .stakeholders_info import ( - get_stakeholders_info, - default_values as stakeholders_defaults -) -from .locations_info import ( - get_locations_info, - default_values as locations_defaults -) -from .data_collection_techniques_info import ( - get_data_collection_techniques_info, - default_values as collection_defaults -) -from .affected_groups_info import ( - get_affected_groups_info, - default_values as affected_defaults -) +from .affected_groups_info import default_values as affected_defaults +from .affected_groups_info import get_affected_groups_info +from .common import default_values as common_defaults +from .common import get_assessment_meta +from .data_collection_techniques_info import default_values as collection_defaults +from .data_collection_techniques_info import get_data_collection_techniques_info +from .locations_info import default_values as locations_defaults +from .locations_info import get_locations_info from .questionaire import get_questionaire +from .stakeholders_info import default_values as stakeholders_defaults +from .stakeholders_info import get_stakeholders_info def get_export_data(assessment): @@ -28,26 +18,23 @@ def get_export_data(assessment): questionaire_dict = get_questionaire(assessment) return { - 'data_collection_technique': { + "data_collection_technique": { **meta_data, **get_data_collection_techniques_info(assessment), }, - 'stakeholders': { + "stakeholders": { **meta_data, **get_stakeholders_info(assessment), }, - 'locations': { + "locations": { **meta_data, **get_locations_info(assessment), }, - 'affected_groups': { + "affected_groups": { **meta_data, **get_affected_groups_info(assessment), }, - 'cna': { - **meta_data, - **(questionaire_dict or {}) - } + "cna": {**meta_data, **(questionaire_dict or {})}, } @@ -72,39 +59,39 @@ def normalize_assessment(assessment_export_data): # Summary need not be normalized # Normalize stakeholders - stakeholders_sheet = assessment_export_data['stakeholders'] - new_stakeholders_sheet = replicate_other_col_groups(stakeholders_sheet, 'stakeholders') + stakeholders_sheet = assessment_export_data["stakeholders"] + new_stakeholders_sheet = replicate_other_col_groups(stakeholders_sheet, "stakeholders") # Normalize Locations - locations_sheet = assessment_export_data['locations'] - new_locations_sheet = replicate_other_col_groups(locations_sheet, 'locations') + locations_sheet = assessment_export_data["locations"] + new_locations_sheet = replicate_other_col_groups(locations_sheet, "locations") # Normalize Affected groups - affected_sheet = assessment_export_data['affected_groups'] - new_affected_sheet = replicate_other_col_groups(affected_sheet, 'affected_groups_info') + affected_sheet = assessment_export_data["affected_groups"] + new_affected_sheet = replicate_other_col_groups(affected_sheet, "affected_groups_info") assessment_data = { - 'stakeholders': new_stakeholders_sheet, - 'affected_groups': new_affected_sheet, - 'locations': new_locations_sheet, + "stakeholders": new_stakeholders_sheet, + "affected_groups": new_affected_sheet, + "locations": new_locations_sheet, } # Normailze Data Collection Techniques - techniques_sheet = assessment_export_data['data_collection_technique'] - new_techniques_sheet = replicate_other_col_groups(techniques_sheet, 'data_collection_technique') + techniques_sheet = assessment_export_data["data_collection_technique"] + new_techniques_sheet = replicate_other_col_groups(techniques_sheet, "data_collection_technique") return { **assessment_data, - 'data_collection_technique': new_techniques_sheet, - 'cna': {k: [v] for k, v in assessment_export_data['cna'].items()}, + "data_collection_technique": new_techniques_sheet, + "cna": {k: [v] for k, v in assessment_export_data["cna"].items()}, } DEFAULTS = { - 'stakeholders': stakeholders_defaults, - 'data_collection_technique': collection_defaults, - 'locations': locations_defaults, - 'affected_groups': affected_defaults, + "stakeholders": stakeholders_defaults, + "data_collection_technique": collection_defaults, + "locations": locations_defaults, + "affected_groups": affected_defaults, } for k, v in DEFAULTS.items(): v.update(common_defaults) @@ -128,16 +115,14 @@ def add_assessment_to_rows(sheets, assessment, planned_assessment=False): NOTE: If assessment has new column name inside grouped cols, the column is added to all existing data with None value """ + def add_new_keys(keys, data, default=None): if not keys: return data if isinstance(data, dict): return {**data, **{x: default for x in keys}} elif isinstance(data, list): - return [ - {**(x or {}), **{k: default for k in keys}} - for x in data - ] + return [{**(x or {}), **{k: default for k in keys}} for x in data] return data normalized_assessment = normalize_assessment(get_export_data(assessment)) @@ -167,31 +152,27 @@ def add_new_keys(keys, data, default=None): columns_data = [columns_data] if not isinstance(columns_data, list) else columns_data - assessment_col_data = [assessment_col_data]\ - if not isinstance(assessment_col_data, list) else assessment_col_data + assessment_col_data = [assessment_col_data] if not isinstance(assessment_col_data, list) else assessment_col_data if isinstance(columns_data[0], dict): # if assessment data empty, add empty dict if not assessment_col_data: assessment_col_data = [{}] - assessment_row_keys = set((assessment_col_data[0] or {}).keys())\ - if assessment_col_data else set() + assessment_row_keys = set((assessment_col_data[0] or {}).keys()) if assessment_col_data else set() sheet_row_keys = set(columns_data[0].keys()) new_ass_keys = assessment_row_keys.difference(sheet_row_keys) new_sheet_keys = sheet_row_keys.difference(assessment_row_keys) default_sheet = DEFAULTS.get(sheet) - default = default_sheet and default_sheet.get(col, default_sheet.get('*')) + default = default_sheet and default_sheet.get(col, default_sheet.get("*")) if new_ass_keys: # Add the key to each row in column data columns_data = add_new_keys(new_ass_keys, columns_data, default) if new_sheet_keys: # Add new keys to assessment data - assessment_col_data = add_new_keys( - new_sheet_keys, assessment_col_data, default - ) + assessment_col_data = add_new_keys(new_sheet_keys, assessment_col_data, default) # Now all the data is normalized(have same keys) # Append assessment data to col data columns_data.extend(assessment_col_data) @@ -211,10 +192,7 @@ def add_new_keys(keys, data, default=None): if not isinstance(coldata[0], dict): newcols_data[newcol] = [*[None] * (sheet_data_len), *coldata] else: - empty_data = { - key: None - for key in coldata[0].keys() - } + empty_data = {key: None for key in coldata[0].keys()} newcols_data[newcol] = [dict(empty_data) for _ in range(sheet_data_len)] newcols_data[newcol].extend(coldata) diff --git a/apps/ary/export/affected_groups_info.py b/apps/ary/export/affected_groups_info.py index e9ddca71a0..adb3ee7f36 100644 --- a/apps/ary/export/affected_groups_info.py +++ b/apps/ary/export/affected_groups_info.py @@ -1,15 +1,13 @@ from assessment_registry.models import AssessmentRegistry - -default_values = { -} +default_values = {} def get_affected_groups_info(assessment): affected_group_type_dict = {choice.value: choice.label for choice in AssessmentRegistry.AffectedGroupType} affected_groups = [affected_group_type_dict.get(group) for group in assessment.affected_groups if group] - max_level = max([len(v.split('/')) for k, v in AssessmentRegistry.AffectedGroupType.choices]) - levels = [f'Level {i+1}' for i in range(max_level)] + max_level = max([len(v.split("/")) for k, v in AssessmentRegistry.AffectedGroupType.choices]) + levels = [f"Level {i+1}" for i in range(max_level)] affected_grp_list = [] for group in affected_groups: group = group.split("/") @@ -22,5 +20,5 @@ def get_affected_groups_info(assessment): affected_grp_list.append(group_dict) return { - 'affected_groups_info': affected_grp_list, + "affected_groups_info": affected_grp_list, } diff --git a/apps/ary/export/common.py b/apps/ary/export/common.py index fcdd596812..a3440a43d7 100644 --- a/apps/ary/export/common.py +++ b/apps/ary/export/common.py @@ -1,18 +1,15 @@ from datetime import datetime -from utils.common import combine_dicts as _combine_dicts, deep_date_format + from assessment_registry.models import AssessmentRegistry -ISO_FORMAT = '%Y-%m-%d' +from utils.common import combine_dicts as _combine_dicts +from utils.common import deep_date_format + +ISO_FORMAT = "%Y-%m-%d" def combine_dicts(dict_list): - return _combine_dicts( - [ - { - _dict['schema']['name']: _dict - } - for _dict in dict_list] - ) + return _combine_dicts([{_dict["schema"]["name"]: _dict} for _dict in dict_list]) def str_to_dmy_date(datestr): @@ -23,27 +20,20 @@ def str_to_dmy_date(datestr): def get_value(d, key, default=None): - return d.get(key, {}).get('value', default) + return d.get(key, {}).get("value", default) def get_name_values(data_dict, keys): if not isinstance(keys, list): keys = [keys] - return { - x['schema']['name']: x['value'] - for key in keys - for x in data_dict.get(key, []) - } + return {x["schema"]["name"]: x["value"] for key in keys for x in data_dict.get(key, [])} def get_name_values_options(data_dict, keys): if not isinstance(keys, list): keys = [keys] return { - x['schema']['name']: { - 'value': x['value'], - 'options': x['schema']['options'] - } + x["schema"]["name"]: {"value": x["value"], "options": x["schema"]["options"]} for key in keys for x in data_dict.get(key, []) } @@ -51,20 +41,12 @@ def get_name_values_options(data_dict, keys): def populate_with_all_values(d, key, default=None): """This gets options and returns dict containing {value: count}""" - options = { - v: 0 for k, v in d.get(key, {}).get('options', {}).items() - } - return { - **options, - **{ - x: 1 - for x in d.get(key, {}).get('value', default) - } - } + options = {v: 0 for k, v in d.get(key, {}).get("options", {}).items()} + return {**options, **{x: 1 for x in d.get(key, {}).get("value", default)}} default_values = { - 'language': 0, + "language": 0, } @@ -87,42 +69,48 @@ def get_assessment_meta(assessment): ) return { - 'lead': { - 'date_of_lead_publication': deep_date_format(lead.published_on), - 'unique_assessment_id': assessment.id, - 'imported_by': ', '.join([user.username for user in lead.assignee.all()]), - 'lead_title': lead.title, - 'url': lead.url, - 'source': lead.get_source_display(), + "lead": { + "date_of_lead_publication": deep_date_format(lead.published_on), + "unique_assessment_id": assessment.id, + "imported_by": ", ".join([user.username for user in lead.assignee.all()]), + "lead_title": lead.title, + "url": lead.url, + "source": lead.get_source_display(), }, - - 'background': { - 'country': ','.join(admin_levels), - 'crisis_type': assessment.get_bg_crisis_type_display(), - 'crisis_start_date': assessment.bg_crisis_start_date.strftime("%d-%m-%Y") if - assessment.bg_crisis_start_date else assessment.bg_crisis_start_date, - 'preparedness': assessment.get_bg_preparedness_display(), - 'external_support': assessment.get_external_support_display(), - 'coordination': assessment.get_coordinated_joint_display(), - 'cost_estimates_in_USD': assessment.cost_estimates_usd, + "background": { + "country": ",".join(admin_levels), + "crisis_type": assessment.get_bg_crisis_type_display(), + "crisis_start_date": ( + assessment.bg_crisis_start_date.strftime("%d-%m-%Y") + if assessment.bg_crisis_start_date + else assessment.bg_crisis_start_date + ), + "preparedness": assessment.get_bg_preparedness_display(), + "external_support": assessment.get_external_support_display(), + "coordination": assessment.get_coordinated_joint_display(), + "cost_estimates_in_USD": assessment.cost_estimates_usd, }, - - 'details': { - 'type': assessment.get_details_type_display(), - 'family': assessment.get_family_display(), - 'frequency': assessment.get_frequency_display(), - 'confidentiality': assessment.get_confidentiality_display(), - 'number_of_pages': assessment.no_of_pages, + "details": { + "type": assessment.get_details_type_display(), + "family": assessment.get_family_display(), + "frequency": assessment.get_frequency_display(), + "confidentiality": assessment.get_confidentiality_display(), + "number_of_pages": assessment.no_of_pages, }, - - 'language': get_languages(assessment), - - 'dates': { - 'data_collection_start_date': assessment.data_collection_start_date.strftime("%d-%m-%Y") if - assessment.data_collection_start_date else assessment.data_collection_start_date, - 'data_collection_end_date': assessment.data_collection_end_date.strftime("%d-%m-%Y") if - assessment.data_collection_end_date else assessment.data_collection_end_date, - 'publication_date': assessment.publication_date.strftime("%d-%m-%Y") if - assessment.publication_date else assessment.publication_date + "language": get_languages(assessment), + "dates": { + "data_collection_start_date": ( + assessment.data_collection_start_date.strftime("%d-%m-%Y") + if assessment.data_collection_start_date + else assessment.data_collection_start_date + ), + "data_collection_end_date": ( + assessment.data_collection_end_date.strftime("%d-%m-%Y") + if assessment.data_collection_end_date + else assessment.data_collection_end_date + ), + "publication_date": ( + assessment.publication_date.strftime("%d-%m-%Y") if assessment.publication_date else assessment.publication_date + ), }, } diff --git a/apps/ary/export/data_collection_techniques_info.py b/apps/ary/export/data_collection_techniques_info.py index 6597a17c4d..cd4e6b0300 100644 --- a/apps/ary/export/data_collection_techniques_info.py +++ b/apps/ary/export/data_collection_techniques_info.py @@ -1,21 +1,18 @@ from assessment_registry.models import MethodologyAttribute -default_values = { -} +default_values = {} def format_value(val): if isinstance(val, list): - return ','.join(val) + return ",".join(val) if val is None: - val = '' + val = "" return str(val) def get_data_collection_techniques_info(assessment): - attributes = MethodologyAttribute.objects.filter( - assessment_registry=assessment - ) + attributes = MethodologyAttribute.objects.filter(assessment_registry=assessment) data = [ { "Data Collection Technique": attr.get_data_collection_technique_display(), @@ -24,8 +21,9 @@ def get_data_collection_techniques_info(assessment): "Proximity": attr.get_proximity_display(), "Unit of Analysis": attr.get_unit_of_analysis_display(), "Unit of reporting": attr.get_unit_of_reporting_display(), - }for attr in attributes + } + for attr in attributes ] return { - 'data_collection_technique': data, + "data_collection_technique": data, } diff --git a/apps/ary/export/locations_info.py b/apps/ary/export/locations_info.py index 93e829b624..822a8ec5c5 100644 --- a/apps/ary/export/locations_info.py +++ b/apps/ary/export/locations_info.py @@ -1,21 +1,16 @@ -default_values = { -} +default_values = {} def is_point_data(x): - return isinstance(x, dict) and x['geo_json']['geometry']['type'] == 'Point' + return isinstance(x, dict) and x["geo_json"]["geometry"]["type"] == "Point" def is_polygon_data(x): - return isinstance(x, dict) and x['geo_json']['geometry']['type'] == 'Polygon' + return isinstance(x, dict) and x["geo_json"]["geometry"]["type"] == "Polygon" def get_title_from_geo_json_data(x): - return ( - x.get('geo_json') and - x['geo_json'].get('properties') and - x['geo_json']['properties'].get('title') - ) + return x.get("geo_json") and x["geo_json"].get("properties") and x["geo_json"]["properties"].get("title") def get_locations_info(assessment): @@ -23,25 +18,25 @@ def get_locations_info(assessment): data = [] if not geo_areas: - return {'locations': data} + return {"locations": data} # Region is the region of the first geo area region = geo_areas[0].admin_level.region - region_geos = {x['key']: x for x in region.geo_options} + region_geos = {x["key"]: x for x in region.geo_options} for area in geo_areas: geo_info = region_geos.get(str(area.id)) if geo_info is None: continue - level = geo_info['admin_level'] - key = f'Admin {level}' + level = geo_info["admin_level"] + key = f"Admin {level}" - admin_levels = {f'Admin {x}': None for x in range(7)} + admin_levels = {f"Admin {x}": None for x in range(7)} admin_levels[key] = area.title # Now add parents as well while level - 1: level -= 1 - parent_id = geo_info['parent'] + parent_id = geo_info["parent"] if parent_id is None: break @@ -50,11 +45,11 @@ def get_locations_info(assessment): if not geo_info: break - key = f'Admin {level}' - admin_levels[key] = geo_info['title'] + key = f"Admin {level}" + admin_levels[key] = geo_info["title"] data.append(admin_levels) return { - 'locations': data, + "locations": data, } diff --git a/apps/ary/export/questionaire.py b/apps/ary/export/questionaire.py index d97dddae5e..3ace4000c8 100644 --- a/apps/ary/export/questionaire.py +++ b/apps/ary/export/questionaire.py @@ -9,7 +9,8 @@ def get_questionaire(assessment): for sub_sector in sub_sector_list_set: questionaire_dict[sub_sector] = { answer.question.question: 1 if answer.answer else 0 - for answer in answers if answer.question.get_sub_sector_display() == sub_sector + for answer in answers + if answer.question.get_sub_sector_display() == sub_sector } return questionaire_dict diff --git a/apps/ary/export/scoring.py b/apps/ary/export/scoring.py index 09584b5c6f..42c48be496 100644 --- a/apps/ary/export/scoring.py +++ b/apps/ary/export/scoring.py @@ -1,41 +1,37 @@ default_values = { - 'Final Score': 0, - '*': 0, + "Final Score": 0, + "*": 0, } def get_scoring(assessment): scoring_data = assessment.get_score_json() pillars_final_scores = { - '{} Final Score'.format(title): score - for title, score in (scoring_data.get('final_pillars_score') or {}).items() + "{} Final Score".format(title): score for title, score in (scoring_data.get("final_pillars_score") or {}).items() } - matrix_pillars_final_scores = scoring_data.get('matrix_pillars_final_score') or {} + matrix_pillars_final_scores = scoring_data.get("matrix_pillars_final_score") or {} pillars = { - pillar: { - sub_pillar: sp_data['value'] - for sub_pillar, sp_data in pillar_data.items() - } - for pillar, pillar_data in (scoring_data.get('pillars') or {}).items() + pillar: {sub_pillar: sp_data["value"] for sub_pillar, sp_data in pillar_data.items()} + for pillar, pillar_data in (scoring_data.get("pillars") or {}).items() } matrix_pillars_scores = {} - for title, pillars_score in (scoring_data.get('matrix_pillars') or {}).items(): - col_key = '{} Score'.format(title) + for title, pillars_score in (scoring_data.get("matrix_pillars") or {}).items(): + col_key = "{} Score".format(title) matrix_pillars_scores[col_key] = {} for sector, data in pillars_score.items(): - matrix_pillars_scores[col_key][sector] = data['value'] + matrix_pillars_scores[col_key][sector] = data["value"] return { **pillars, **matrix_pillars_scores, - 'final_scores': { + "final_scores": { **pillars_final_scores, **matrix_pillars_final_scores, }, - '': { - 'Final Score': scoring_data.get('final_score'), - } + "": { + "Final Score": scoring_data.get("final_score"), + }, } diff --git a/apps/ary/export/stakeholders_info.py b/apps/ary/export/stakeholders_info.py index 2de1a3fb46..ce7e62c818 100644 --- a/apps/ary/export/stakeholders_info.py +++ b/apps/ary/export/stakeholders_info.py @@ -1,19 +1,16 @@ from assessment_registry.models import AssessmentRegistryOrganization - default_values = { - 'stakeholders': None, + "stakeholders": None, } def get_stakeholders_info(assessment): stakeholders_info = [ - { - 'name': org.organization.title, - 'type': org.get_organization_type_display() - }for org in AssessmentRegistryOrganization.objects.filter(assessment_registry=assessment) + {"name": org.organization.title, "type": org.get_organization_type_display()} + for org in AssessmentRegistryOrganization.objects.filter(assessment_registry=assessment) ] # TODO : Add Dataloaders return { - 'stakeholders': stakeholders_info, + "stakeholders": stakeholders_info, } diff --git a/apps/ary/export/summary.py b/apps/ary/export/summary.py index 79946b6843..615eeea63a 100644 --- a/apps/ary/export/summary.py +++ b/apps/ary/export/summary.py @@ -1,28 +1,27 @@ -from apps.entry.widgets.geo_widget import get_valid_geo_ids -from geo.models import GeoArea - from ary.models import ( + AffectedGroup, + Focus, + MetadataField, MethodologyGroup, MethodologyOption, MethodologyProtectionInfo, - MetadataField, - Focus, Sector, - AffectedGroup, ) +from geo.models import GeoArea -from .scoring import get_scoring +from apps.entry.widgets.geo_widget import get_valid_geo_ids +from .scoring import get_scoring # Default values for column groups # Add other default values as required default_values = { - 'location': 0, - 'additional_documents': 0, - 'focuses': 0, - 'sectors': 0, - 'affected_groups': 0, - 'methodology_content': 0 + "location": 0, + "additional_documents": 0, + "focuses": 0, + "sectors": 0, + "affected_groups": 0, + "methodology_content": 0, } @@ -31,18 +30,14 @@ def get_methodology_summary(assessment): groups = MethodologyGroup.objects.filter(template=assessment.project.assessment_template) attributes = {} - groups_options = { - group.title: MethodologyOption.objects.filter(field__in=group.fields.all()) - for group in groups - } + groups_options = {group.title: MethodologyOption.objects.filter(field__in=group.fields.all()) for group in groups} - for attr in methodology.get('Attributes') or []: + for attr in methodology.get("Attributes") or []: for group, options in groups_options.items(): data = attr.get(group) or [{}] attr_data = attributes.get(group) or {} for option in options: - attr_data[option.title] = attr_data.get(option.title, 0) +\ - (1 if data[0].get('value') == option.title else 0) + attr_data[option.title] = attr_data.get(option.title, 0) + (1 if data[0].get("value") == option.title else 0) attributes[group] = attr_data return attributes @@ -53,19 +48,19 @@ def get_assessment_export_summary(assessment, planned_assessment=False): """ template = assessment.project.assessment_template - additional_documents = (assessment.metadata or {}).get('additional_documents') or {} + additional_documents = (assessment.metadata or {}).get("additional_documents") or {} metadata = assessment.get_metadata_json() methodology = assessment.get_methodology_json() focuses = [x.title for x in Focus.objects.filter(template=template)] - selected_focuses = set(methodology.get('Focuses') or []) + selected_focuses = set(methodology.get("Focuses") or []) sectors = [x.title for x in Sector.objects.filter(template=template)] - selected_sectors = set(methodology.get('Sectors') or []) + selected_sectors = set(methodology.get("Sectors") or []) methodology_protection_informations = [label for _, label in MethodologyProtectionInfo.choices] - selected_methodology_protection_informations = set(methodology.get('Protection Info') or []) + selected_methodology_protection_informations = set(methodology.get("Protection Info") or []) root_affected_group = AffectedGroup.objects.filter(template=template, parent=None).first() all_affected_groups = root_affected_group.get_children_list() if root_affected_group else [] @@ -73,75 +68,61 @@ def get_assessment_export_summary(assessment, planned_assessment=False): methodology_summary = get_methodology_summary(assessment) # All affected groups does not have title, so generate title from parents - processed_affected_groups = [ - { - 'title': ' and '.join(x['parents'][::-1]), - 'id': x['id'] - } for x in all_affected_groups - ] + processed_affected_groups = [{"title": " and ".join(x["parents"][::-1]), "id": x["id"]} for x in all_affected_groups] - selected_affected_groups_ids = {x['key'] for x in (methodology.get('Affected Groups') or [])} + selected_affected_groups_ids = {x["key"] for x in (methodology.get("Affected Groups") or [])} - locations = get_valid_geo_ids((assessment.methodology or {}).get('locations') or {}) + locations = get_valid_geo_ids((assessment.methodology or {}).get("locations") or {}) - geo_areas = GeoArea.objects.filter(id__in=locations).prefetch_related('admin_level') - admin_levels = {f'Admin {x}': 0 for x in range(7)} + geo_areas = GeoArea.objects.filter(id__in=locations).prefetch_related("admin_level") + admin_levels = {f"Admin {x}": 0 for x in range(7)} for geo in geo_areas: level = geo.admin_level.level - key = f'Admin {level}' + key = f"Admin {level}" admin_levels[key] = admin_levels.get(key, 0) + 1 planned_assessment_info = { **methodology_summary, - 'location': admin_levels, - 'focuses': { - x: 1 if x in selected_focuses else 0 - for x in focuses - }, - 'sectors': { - x: 1 if x in selected_sectors else 0 - for x in sectors - }, - 'protection_information_management': { - x: 1 if x in selected_methodology_protection_informations else 0 - for x in methodology_protection_informations - }, - 'affected_groups': { - x['title']: 1 if x['id'] in selected_affected_groups_ids else 0 - for x in processed_affected_groups + "location": admin_levels, + "focuses": {x: 1 if x in selected_focuses else 0 for x in focuses}, + "sectors": {x: 1 if x in selected_sectors else 0 for x in sectors}, + "protection_information_management": { + x: 1 if x in selected_methodology_protection_informations else 0 for x in methodology_protection_informations }, + "affected_groups": {x["title"]: 1 if x["id"] in selected_affected_groups_ids else 0 for x in processed_affected_groups}, } if planned_assessment: - return {'title': assessment.title, **planned_assessment_info} + return {"title": assessment.title, **planned_assessment_info} - stakeholders = metadata.get('Stakeholders') or [] + stakeholders = metadata.get("Stakeholders") or [] lead_org = stakeholders and stakeholders[0] other_orgs = [x for x in stakeholders[1:]] data = { - 'methodology_content': { - 'objectives': 1 if methodology.get('Objectives') else 0, - 'data_collection_techniques': 1 if methodology.get('Data Collection Techniques') else 0, - 'sampling': 1 if methodology.get('Sampling') else 0, - 'limitations': 1 if methodology.get('Limitations') else 0, + "methodology_content": { + "objectives": 1 if methodology.get("Objectives") else 0, + "data_collection_techniques": 1 if methodology.get("Data Collection Techniques") else 0, + "sampling": 1 if methodology.get("Sampling") else 0, + "limitations": 1 if methodology.get("Limitations") else 0, }, - 'stakeholders': { - lead_org['schema']['name']: lead_org['value'][0]['name'] if lead_org['value'] else '', - **{ - x['schema']['name']: len(x['value']) - if x['schema']['type'] == MetadataField.MULTISELECT else 1 - for x in other_orgs + "stakeholders": ( + { + lead_org["schema"]["name"]: lead_org["value"][0]["name"] if lead_org["value"] else "", + **{ + x["schema"]["name"]: len(x["value"]) if x["schema"]["type"] == MetadataField.MULTISELECT else 1 + for x in other_orgs + }, } - } if stakeholders else {}, - - 'additional_documents': { - 'Executive Summary': 1 if additional_documents.get('executive_summary') else 0, - 'Assessment Database': 1 if additional_documents.get('assessment_data') else 0, - 'Questionnaire': 1 if additional_documents.get('questionnaire') else 0, - 'Miscellaneous': 1 if additional_documents.get('misc') else 0, + if stakeholders + else {} + ), + "additional_documents": { + "Executive Summary": 1 if additional_documents.get("executive_summary") else 0, + "Assessment Database": 1 if additional_documents.get("assessment_data") else 0, + "Questionnaire": 1 if additional_documents.get("questionnaire") else 0, + "Miscellaneous": 1 if additional_documents.get("misc") else 0, }, - **planned_assessment_info, } diff --git a/apps/ary/factories.py b/apps/ary/factories.py index b8bde5b330..ab4b226508 100644 --- a/apps/ary/factories.py +++ b/apps/ary/factories.py @@ -1,9 +1,6 @@ from factory.django import DjangoModelFactory -from .models import ( - AssessmentTemplate, - Assessment, -) +from .models import Assessment, AssessmentTemplate class AssessmentFactory(DjangoModelFactory): diff --git a/apps/ary/filters.py b/apps/ary/filters.py index d9a24da693..cb92ff630d 100644 --- a/apps/ary/filters.py +++ b/apps/ary/filters.py @@ -1,22 +1,17 @@ import django_filters from django.db import models - -from user_resource.filters import UserResourceFilterSet -from user.models import User -from project.models import Project from lead.models import Lead, LeadGroup -from user_resource.filters import UserResourceGqlFilterSet +from project.models import Project +from user.models import User +from user_resource.filters import UserResourceFilterSet, UserResourceGqlFilterSet -from .models import ( - Assessment, - PlannedAssessment, -) +from .models import Assessment, PlannedAssessment class AssessmentFilterSet(UserResourceFilterSet): project = django_filters.ModelMultipleChoiceFilter( queryset=Project.objects.all(), - field_name='lead__project', + field_name="lead__project", ) lead = django_filters.ModelMultipleChoiceFilter( queryset=Lead.objects.all(), @@ -31,13 +26,13 @@ class AssessmentFilterSet(UserResourceFilterSet): class Meta: model = Assessment - fields = ['id', 'lead__title', 'lead_group__title'] + fields = ["id", "lead__title", "lead_group__title"] filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda f: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda f: { + "lookup_expr": "icontains", }, }, } @@ -46,7 +41,7 @@ class Meta: class PlannedAssessmentFilterSet(UserResourceFilterSet): project = django_filters.ModelMultipleChoiceFilter( queryset=Project.objects.all(), - field_name='project', + field_name="project", ) created_by = django_filters.ModelMultipleChoiceFilter( queryset=User.objects.all(), @@ -55,13 +50,13 @@ class PlannedAssessmentFilterSet(UserResourceFilterSet): class Meta: model = PlannedAssessment - fields = ['id', 'title'] + fields = ["id", "title"] filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda f: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda f: { + "lookup_expr": "icontains", }, }, } @@ -69,7 +64,7 @@ class Meta: # -------------------- Graphql Filters ----------------------------------- class AssessmentGQFilterSet(UserResourceGqlFilterSet): - search = django_filters.CharFilter(method='filter_title') + search = django_filters.CharFilter(method="filter_title") class Meta: model = Assessment @@ -78,7 +73,4 @@ class Meta: def filter_title(self, qs, name, value): if not value: return qs - return qs.filter( - models.Q(lead__title__icontains=value) | - models.Q(lead_group__title__icontains=value) - ).distinct() + return qs.filter(models.Q(lead__title__icontains=value) | models.Q(lead_group__title__icontains=value)).distinct() diff --git a/apps/ary/management/commands/export_ary_template.py b/apps/ary/management/commands/export_ary_template.py index 306a254f6b..283fcc9c70 100644 --- a/apps/ary/management/commands/export_ary_template.py +++ b/apps/ary/management/commands/export_ary_template.py @@ -1,20 +1,19 @@ +from django.apps import apps from django.core.management import call_command -from django.core.serializers import serialize from django.core.management.base import BaseCommand -from django.apps import apps +from django.core.serializers import serialize from fixture_magic.utils import ( - add_to_serialize_list, serialize_me, serialize_fully, seen, + add_to_serialize_list, + seen, + serialize_fully, + serialize_me, ) - -IGNORE_MODELS = [ - f'ary.{model}' - for model in ['Assessment'] -] +IGNORE_MODELS = [f"ary.{model}" for model in ["Assessment"]] def export_ary_fixture(): - for model in apps.get_app_config('ary').get_models(): + for model in apps.get_app_config("ary").get_models(): label = model._meta.label if label in IGNORE_MODELS: continue @@ -22,14 +21,14 @@ def export_ary_fixture(): add_to_serialize_list(objs) serialize_fully() data = serialize( - 'json', + "json", sorted( [o for o in serialize_me if o is not None], - key=lambda x: (f'{x._meta.app_label}.{x._meta.model_name}', x.pk), + key=lambda x: (f"{x._meta.app_label}.{x._meta.model_name}", x.pk), ), indent=2, use_natural_foreign_keys=False, - use_natural_primary_keys=False + use_natural_primary_keys=False, ) del serialize_me[:] seen.clear() @@ -40,12 +39,12 @@ class Command(BaseCommand): def handle(self, *args, **options): export_files = [] - for model in apps.get_app_config('ary').get_models(): + for model in apps.get_app_config("ary").get_models(): label = model._meta.label if label in IGNORE_MODELS: continue - filename = f'/tmp/{label}.json' + filename = f"/tmp/{label}.json" export_files.append(filename) - with open(filename, 'w+') as f: - call_command('dump_object', label, '*', stdout=f) - call_command('merge_fixtures', *export_files) + with open(filename, "w+") as f: + call_command("dump_object", label, "*", stdout=f) + call_command("merge_fixtures", *export_files) diff --git a/apps/ary/models.py b/apps/ary/models.py index ec7dc28c1b..59eb3d490a 100644 --- a/apps/ary/models.py +++ b/apps/ary/models.py @@ -1,22 +1,22 @@ from collections import OrderedDict -from django.db import models -from django.core.exceptions import ValidationError -from user_resource.models import UserResource -from deep.models import Field, FieldOption +from django.core.exceptions import ValidationError +from django.db import models from lead.models import Lead, LeadGroup from project.mixins import ProjectEntityMixin +from user_resource.models import UserResource + +from deep.models import Field, FieldOption +from utils.common import identity, underscore_to_title from .utils import ( FIELDS_KEYS_VALUE_EXTRACTORS, - get_title_or_none, + get_integer_enum_title, get_location_title, get_model_attrs_or_empty_dict, - get_integer_enum_title, + get_title_or_none, ) -from utils.common import identity, underscore_to_title - class AssessmentTemplate(UserResource): title = models.CharField(max_length=255) @@ -53,11 +53,11 @@ class BasicEntity(models.Model): order = models.IntegerField(default=1) def __str__(self): - return '{}'.format(self.title) + return "{}".format(self.title) class Meta: abstract = True - ordering = ['order'] + ordering = ["order"] class BasicTemplateEntity(models.Model): @@ -66,11 +66,11 @@ class BasicTemplateEntity(models.Model): order = models.IntegerField(default=1) def __str__(self): - return '{} ({})'.format(self.title, self.template) + return "{} ({})".format(self.title, self.template) class Meta: abstract = True - ordering = ['order'] + ordering = ["order"] class MetadataGroup(BasicTemplateEntity): @@ -79,30 +79,34 @@ class MetadataGroup(BasicTemplateEntity): class MetadataField(Field): group = models.ForeignKey( - MetadataGroup, related_name='fields', on_delete=models.CASCADE, + MetadataGroup, + related_name="fields", + on_delete=models.CASCADE, ) tooltip = models.TextField(blank=True) order = models.IntegerField(default=1) show_in_planned_assessment = models.BooleanField(default=False) def __str__(self): - return '{} ({})'.format(self.title, self.group.template) + return "{} ({})".format(self.title, self.group.template) class Meta(Field.Meta): - ordering = ['order'] + ordering = ["order"] class MetadataOption(FieldOption): field = models.ForeignKey( - MetadataField, related_name='options', on_delete=models.CASCADE, + MetadataField, + related_name="options", + on_delete=models.CASCADE, ) order = models.IntegerField(default=1) def __str__(self): - return 'Option {} for {}'.format(self.title, self.field) + return "Option {} for {}".format(self.title, self.field) class Meta(FieldOption.Meta): - ordering = ['order'] + ordering = ["order"] class MethodologyGroup(BasicTemplateEntity): @@ -110,42 +114,46 @@ class MethodologyGroup(BasicTemplateEntity): class MethodologyProtectionInfo(models.IntegerChoices): - PROTECTION_MONITORING = 1, 'Protection Monitoring' - PROTECTION_NEEDS_ASSESSMENT = 2, 'Protection Needs Assessment' - CASE_MANAGEMENT = 3, 'Case Management' - POPULATION_DATA = 4, 'Population Data' - PROTECTION_RESPONSE = 5, 'Protection Response M&E' - COMMUNICATING_WITH_OR_IN_AFFECTED_COMMUNITIES = 6, 'Communicating with(in) Affected Communities' - SECURITY_AND_SITUATIONAL_AWARENESS = 7, 'Security & Situational Awareness' - SECTORAL_SYSTEMS_OR_OTHER = 8, 'Sectoral Systems/Other' + PROTECTION_MONITORING = 1, "Protection Monitoring" + PROTECTION_NEEDS_ASSESSMENT = 2, "Protection Needs Assessment" + CASE_MANAGEMENT = 3, "Case Management" + POPULATION_DATA = 4, "Population Data" + PROTECTION_RESPONSE = 5, "Protection Response M&E" + COMMUNICATING_WITH_OR_IN_AFFECTED_COMMUNITIES = 6, "Communicating with(in) Affected Communities" + SECURITY_AND_SITUATIONAL_AWARENESS = 7, "Security & Situational Awareness" + SECTORAL_SYSTEMS_OR_OTHER = 8, "Sectoral Systems/Other" class MethodologyField(Field): group = models.ForeignKey( - MethodologyGroup, related_name='fields', on_delete=models.CASCADE, + MethodologyGroup, + related_name="fields", + on_delete=models.CASCADE, ) tooltip = models.TextField(blank=True) order = models.IntegerField(default=1) show_in_planned_assessment = models.BooleanField(default=False) def __str__(self): - return '{} ({})'.format(self.title, self.group.template) + return "{} ({})".format(self.title, self.group.template) class Meta(Field.Meta): - ordering = ['order'] + ordering = ["order"] class MethodologyOption(FieldOption): field = models.ForeignKey( - MethodologyField, related_name='options', on_delete=models.CASCADE, + MethodologyField, + related_name="options", + on_delete=models.CASCADE, ) order = models.IntegerField(default=1) def __str__(self): - return 'Option {} for {}'.format(self.title, self.field) + return "Option {} for {}".format(self.title, self.field) class Meta(FieldOption.Meta): - ordering = ['order'] + ordering = ["order"] class Sector(BasicTemplateEntity): @@ -154,14 +162,17 @@ class Sector(BasicTemplateEntity): class Focus(BasicTemplateEntity): class Meta(BasicTemplateEntity.Meta): - verbose_name_plural = 'focuses' + verbose_name_plural = "focuses" class AffectedGroup(BasicTemplateEntity): parent = models.ForeignKey( - 'AffectedGroup', - related_name='children', on_delete=models.CASCADE, - default=None, null=True, blank=True, + "AffectedGroup", + related_name="children", + on_delete=models.CASCADE, + default=None, + null=True, + blank=True, ) def get_children_list(self): @@ -175,73 +186,74 @@ def get_children_list(self): ] """ # TODO: cache, but very careful - nodes_list = [ - { - 'title': self.title, - 'parents': [self.title], # includes self as well - 'id': self.id - } - ] + nodes_list = [{"title": self.title, "parents": [self.title], "id": self.id}] # includes self as well children = self.children.all() if not children: return nodes_list for child in children: - nodes_list.extend([ - { - 'title': f'{self.title} - {x["title"]}', - 'parents': [*x['parents'], self.title], - 'id': x['id'] - } - for x in child.get_children_list() - ]) + nodes_list.extend( + [ + {"title": f'{self.title} - {x["title"]}', "parents": [*x["parents"], self.title], "id": x["id"]} + for x in child.get_children_list() + ] + ) return nodes_list class PrioritySector(BasicTemplateEntity): parent = models.ForeignKey( - 'PrioritySector', - related_name='children', on_delete=models.CASCADE, - default=None, null=True, blank=True, + "PrioritySector", + related_name="children", + on_delete=models.CASCADE, + default=None, + null=True, + blank=True, ) class Meta(BasicTemplateEntity.Meta): - verbose_name = 'sector with most unmet need' - verbose_name_plural = 'sectors with most unmet need' + verbose_name = "sector with most unmet need" + verbose_name_plural = "sectors with most unmet need" class PriorityIssue(BasicTemplateEntity): parent = models.ForeignKey( - 'PriorityIssue', - related_name='children', on_delete=models.CASCADE, - default=None, null=True, blank=True, + "PriorityIssue", + related_name="children", + on_delete=models.CASCADE, + default=None, + null=True, + blank=True, ) class Meta(BasicTemplateEntity.Meta): - verbose_name = 'priority humanitarian access issue' + verbose_name = "priority humanitarian access issue" class UnderlyingFactor(BasicTemplateEntity): parent = models.ForeignKey( - 'UnderlyingFactor', - related_name='children', on_delete=models.CASCADE, - default=None, null=True, blank=True, + "UnderlyingFactor", + related_name="children", + on_delete=models.CASCADE, + default=None, + null=True, + blank=True, ) class Meta(BasicTemplateEntity.Meta): - verbose_name = 'main sectoral underlying factor' + verbose_name = "main sectoral underlying factor" class SpecificNeedGroup(BasicTemplateEntity): class Meta(BasicTemplateEntity.Meta): - verbose_name = 'priority group with specific need' - verbose_name_plural = 'priority groups with specific need' + verbose_name = "priority group with specific need" + verbose_name_plural = "priority groups with specific need" # TODO: Remove / This is text field now and is not required anymore class AffectedLocation(BasicTemplateEntity): class Meta(BasicTemplateEntity.Meta): - verbose_name = 'setting facing most humanitarian access issues' - verbose_name_plural = 'settings facing most humanitarian access issues' + verbose_name = "setting facing most humanitarian access issues" + verbose_name_plural = "settings facing most humanitarian access issues" class ScoreBucket(models.Model): @@ -251,7 +263,7 @@ class ScoreBucket(models.Model): score = models.FloatField(default=1) def __str__(self): - return '{} <= x < {} : {} ({})'.format( + return "{} <= x < {} : {} ({})".format( self.min_value, self.max_value, self.score, @@ -259,7 +271,7 @@ def __str__(self): ) class Meta: - ordering = ['min_value'] + ordering = ["min_value"] class ScorePillar(BasicTemplateEntity): @@ -268,7 +280,9 @@ class ScorePillar(BasicTemplateEntity): class ScoreQuestion(BasicEntity): pillar = models.ForeignKey( - ScorePillar, on_delete=models.CASCADE, related_name='questions', + ScorePillar, + on_delete=models.CASCADE, + related_name="questions", ) description = models.TextField(blank=True) @@ -281,7 +295,7 @@ class ScoreScale(models.Model): default = models.BooleanField(default=False) def __str__(self): - return '{} ({} : {}) - ({})'.format( + return "{} ({} : {}) - ({})".format( self.title, self.value, self.color, @@ -289,7 +303,7 @@ def __str__(self): ) class Meta: - ordering = ['value'] + ordering = ["value"] class ScoreMatrixPillar(BasicTemplateEntity): @@ -298,19 +312,25 @@ class ScoreMatrixPillar(BasicTemplateEntity): class ScoreMatrixRow(BasicEntity): pillar = models.ForeignKey( - ScoreMatrixPillar, on_delete=models.CASCADE, related_name='rows', + ScoreMatrixPillar, + on_delete=models.CASCADE, + related_name="rows", ) class ScoreMatrixColumn(BasicEntity): pillar = models.ForeignKey( - ScoreMatrixPillar, on_delete=models.CASCADE, related_name='columns', + ScoreMatrixPillar, + on_delete=models.CASCADE, + related_name="columns", ) class ScoreMatrixScale(models.Model): pillar = models.ForeignKey( - ScoreMatrixPillar, on_delete=models.CASCADE, related_name='scales', + ScoreMatrixPillar, + on_delete=models.CASCADE, + related_name="scales", ) row = models.ForeignKey(ScoreMatrixRow, on_delete=models.CASCADE) column = models.ForeignKey(ScoreMatrixColumn, on_delete=models.CASCADE) @@ -318,28 +338,27 @@ class ScoreMatrixScale(models.Model): default = models.BooleanField(default=False) def __str__(self): - return '{}-{} : {}'.format(str(self.row), str(self.column), - str(self.value)) + return "{}-{} : {}".format(str(self.row), str(self.column), str(self.value)) class Meta: - ordering = ['value'] + ordering = ["value"] class ScoreQuestionnaireSector(BasicTemplateEntity): - HNO = 'hno' - CNA = 'cna' + HNO = "hno" + CNA = "cna" - CRITERIA = 'criteria' - ETHOS = 'ethos' + CRITERIA = "criteria" + ETHOS = "ethos" METHOD_CHOICES = ( - (HNO, 'HNO'), - (CNA, 'CNA'), + (HNO, "HNO"), + (CNA, "CNA"), ) SUB_METHOD_CHOICES = ( - (CRITERIA, 'Criteria'), - (ETHOS, 'Ethos'), + (CRITERIA, "Criteria"), + (ETHOS, "Ethos"), ) method = models.CharField(max_length=10, choices=METHOD_CHOICES) sub_method = models.CharField(max_length=10, choices=SUB_METHOD_CHOICES) @@ -360,13 +379,21 @@ class Assessment(UserResource, ProjectEntityMixin): """ Assessment belonging to a lead """ + lead = models.OneToOneField( - Lead, default=None, blank=True, null=True, on_delete=models.CASCADE, + Lead, + default=None, + blank=True, + null=True, + on_delete=models.CASCADE, ) - project = models.ForeignKey('project.Project', on_delete=models.CASCADE) + project = models.ForeignKey("project.Project", on_delete=models.CASCADE) lead_group = models.OneToOneField( - LeadGroup, on_delete=models.CASCADE, - default=None, blank=True, null=True, + LeadGroup, + on_delete=models.CASCADE, + default=None, + blank=True, + null=True, ) metadata = models.JSONField(default=None, blank=True, null=True) methodology = models.JSONField(default=None, blank=True, null=True) @@ -379,13 +406,9 @@ def __str__(self): def clean(self): if not self.lead and not self.lead_group: - raise ValidationError( - 'Neither `lead` nor `lead_group` defined' - ) + raise ValidationError("Neither `lead` nor `lead_group` defined") if self.lead and self.lead_group: - raise ValidationError( - 'Assessment cannot have both `lead` and `lead_group` defined' - ) + raise ValidationError("Assessment cannot have both `lead` and `lead_group` defined") return super().clean() def save(self, *args, **kwargs): @@ -395,36 +418,32 @@ def save(self, *args, **kwargs): def create_schema_for_group(self, GroupClass): schema = {} assessment_template = self.lead.project.assessment_template - groups = GroupClass.objects.filter(template=assessment_template).prefetch_related('fields') + groups = GroupClass.objects.filter(template=assessment_template).prefetch_related("fields") schema = { group.title: [ { - 'id': field.id, - 'name': field.title, - 'type': field.field_type, - 'source_type': field.source_type, - 'options': { - x['key']: x['title'] for x in field.get_options() - } + "id": field.id, + "name": field.title, + "type": field.field_type, + "source_type": field.source_type, + "options": {x["key"]: x["title"] for x in field.get_options()}, } for field in group.fields.all() - ] for group in groups + ] + for group in groups } return schema @staticmethod def get_actual_value(schema, value): - value_function = FIELDS_KEYS_VALUE_EXTRACTORS.get(schema['name'], identity) - if schema['type'] == Field.SELECT: + value_function = FIELDS_KEYS_VALUE_EXTRACTORS.get(schema["name"], identity) + if schema["type"] == Field.SELECT: # value should not be list but just in case it is a list - value = value[0] if isinstance(value, list) and len(value) > 0 else value or '' - actual_value = schema['options'].get(value, value) - elif schema['type'] == Field.MULTISELECT: + value = value[0] if isinstance(value, list) and len(value) > 0 else value or "" + actual_value = schema["options"].get(value, value) + elif schema["type"] == Field.MULTISELECT: value = value or [] - actual_value = [ - value_function(schema['options'].get(x, x)) - for x in value - ] + actual_value = [value_function(schema["options"].get(x, x)) for x in value] else: actual_value = value return actual_value @@ -434,22 +453,18 @@ def get_data_from_schema(schema, raw_data): if not raw_data: return {} - if 'id' in schema: - key = str(schema['id']) - value = raw_data.get(key, '') + if "id" in schema: + key = str(schema["id"]) + value = raw_data.get(key, "") return { - 'schema': schema, - 'value': Assessment.get_actual_value(schema, value), - 'key': value, + "schema": schema, + "value": Assessment.get_actual_value(schema, value), + "key": value, } if isinstance(schema, dict): - data = { - k: Assessment.get_data_from_schema(v, raw_data) - for k, v in schema.items() - } + data = {k: Assessment.get_data_from_schema(v, raw_data) for k, v in schema.items()} elif isinstance(schema, list): - data = [Assessment.get_data_from_schema(x, raw_data) - for x in schema] + data = [Assessment.get_data_from_schema(x, raw_data) for x in schema] else: raise Exception("Something that could not be parsed from schema") return data @@ -457,7 +472,7 @@ def get_data_from_schema(schema, raw_data): def get_metadata_json(self): metadata_schema = self.create_schema_for_group(MetadataGroup) metadata_raw = self.metadata or {} - metadata_raw = metadata_raw.get('basic_information', {}) + metadata_raw = metadata_raw.get("basic_information", {}) metadata = self.get_data_from_schema(metadata_schema, metadata_raw) return metadata @@ -466,27 +481,20 @@ def get_methodology_json(self): methodology_raw = self.methodology or {} mapping = { - 'attributes': lambda x: self.get_data_from_schema( - methodology_sch, x - ), - 'sectors': get_title_or_none(Sector), - 'focuses': get_title_or_none(Focus), - 'affected_groups': lambda x: { - 'key': x, - **get_model_attrs_or_empty_dict(AffectedGroup, ['title', 'order'])(x) - }, - 'locations': get_location_title, - 'objectives': identity, - 'sampling': identity, - 'limitations': identity, - 'data_collection_techniques': identity, - 'protection_info': get_integer_enum_title(MethodologyProtectionInfo), + "attributes": lambda x: self.get_data_from_schema(methodology_sch, x), + "sectors": get_title_or_none(Sector), + "focuses": get_title_or_none(Focus), + "affected_groups": lambda x: {"key": x, **get_model_attrs_or_empty_dict(AffectedGroup, ["title", "order"])(x)}, + "locations": get_location_title, + "objectives": identity, + "sampling": identity, + "limitations": identity, + "data_collection_techniques": identity, + "protection_info": get_integer_enum_title(MethodologyProtectionInfo), } return { - underscore_to_title(k): - v if not isinstance(v, list) - else [mapping[k](y) for y in v] + underscore_to_title(k): v if not isinstance(v, list) else [mapping[k](y) for y in v] for k, v in methodology_raw.items() } @@ -494,11 +502,11 @@ def get_summary_json(self): # Formatting of underscored keywords, by default is upper case as given # by default_format() function below formatting = { - 'priority_sectors': lambda x: 'Most Unmet Needs Sectors', - 'affected_location': lambda x: 'Settings Facing Most Humanitarian Issues' # noqa + "priority_sectors": lambda x: "Most Unmet Needs Sectors", + "affected_location": lambda x: "Settings Facing Most Humanitarian Issues", # noqa } - default_format = underscore_to_title # function + default_format = underscore_to_title # function summary_raw = self.summary if not summary_raw: @@ -511,7 +519,7 @@ def get_summary_json(self): # Add sectors data first for sectorname, sector_data in summary_raw.items(): try: - _, sec_id = sectorname.split('-') + _, sec_id = sectorname.split("-") sector = Sector.objects.get(id=sec_id).title # Exception because, we have cross_sector and humanitarian_access # in addition to "sector-" keys @@ -529,11 +537,8 @@ def get_summary_json(self): for rank, data in group_data.items(): for colname, colval in data.items(): col_f = formatting.get(colname, default_format)(colname) - group_col_data = parsed_group_data.get( - col_f, - [None] * numrows - ) - rankvalue = int(rank.replace('rank', '')) # rank + group_col_data = parsed_group_data.get(col_f, [None] * numrows) + rankvalue = int(rank.replace("rank", "")) # rank group_col_data[rankvalue - 1] = colval parsed_group_data[col_f] = group_col_data @@ -549,7 +554,7 @@ def get_summary_json(self): for sector, data in summary_data.items(): for group, groupdata in data.items(): for col, coldata in groupdata.items(): - key = '{} - {} - {}'.format(sector, group, col) + key = "{} - {} - {}".format(sector, group, col) new_summary_data[key] = coldata return new_summary_data @@ -557,12 +562,9 @@ def get_score_json(self): if not self.score: return {} - pillars_raw = self.score['pillars'] or {} - matrix_pillars_raw = self.score['matrix_pillars'] or {} - matrix_pillars_final_raw = { - x: self.score[x] - for x in self.score.keys() if 'matrix-score' in x - } + pillars_raw = self.score["pillars"] or {} + matrix_pillars_raw = self.score["matrix_pillars"] or {} + matrix_pillars_final_raw = {x: self.score[x] for x in self.score.keys() if "matrix-score" in x} matrix_pillars_final_score = {} @@ -573,17 +575,17 @@ def get_score_json(self): data = {} for qid, sid in pdata.items(): q = get_title_or_none(ScoreQuestion)(qid) - data[q] = get_model_attrs_or_empty_dict(ScoreScale, ['title', 'value'])(sid) + data[q] = get_model_attrs_or_empty_dict(ScoreScale, ["title", "value"])(sid) pillars[pillar_title] = data - final_pillars_score[pillar_title] = self.score.get('{}-score'.format(pid)) + final_pillars_score[pillar_title] = self.score.get("{}-score".format(pid)) matrix_pillars = {} for mpid, mpdata in matrix_pillars_raw.items(): mpillar_title = get_title_or_none(ScoreMatrixPillar)(mpid) data = {} - matrix_final_data = matrix_pillars_final_raw.get(f'{mpid}-matrix-score') or '' - matrix_pillars_final_score[f'{mpillar_title}_final_score'] = matrix_final_data + matrix_final_data = matrix_pillars_final_raw.get(f"{mpid}-matrix-score") or "" + matrix_pillars_final_score[f"{mpillar_title}_final_score"] = matrix_final_data for sector in Sector.objects.filter(template=self.project.assessment_template): scale = None @@ -591,17 +593,17 @@ def get_score_json(self): if mpdata is not None and sector_id in mpdata: scale = ScoreMatrixScale.objects.filter(id=mpdata[sector_id]).first() data[sector.title] = { - 'value': scale.value if scale else '', - 'title': f'{scale.row.title} / {scale.column.title}' if scale else '' + "value": scale.value if scale else "", + "title": f"{scale.row.title} / {scale.column.title}" if scale else "", } matrix_pillars[mpillar_title] = data return { - 'final_score': self.score.get('final_score'), - 'final_pillars_score': final_pillars_score, - 'pillars': pillars, - 'matrix_pillars': matrix_pillars, - 'matrix_pillars_final_score': matrix_pillars_final_score, + "final_score": self.score.get("final_score"), + "final_pillars_score": final_pillars_score, + "pillars": pillars, + "matrix_pillars": matrix_pillars, + "matrix_pillars_final_score": matrix_pillars_final_score, } def get_questionnaire_json(self, questionnaire_subsectors=None): @@ -610,14 +612,12 @@ def get_questionnaire_json(self, questionnaire_subsectors=None): template = self.project.assessment_template raw_questionnaire = self.questionnaire or {} - questionnaire_subsectors = ScoreQuestionnaireSubSector.objects.filter( - sector__template=template - ).prefetch_related('sector', 'scorequestionnaire_set') - - questionnaire_sectors = ScoreQuestionnaireSector.objects.filter( - template=template + questionnaire_subsectors = ScoreQuestionnaireSubSector.objects.filter(sector__template=template).prefetch_related( + "sector", "scorequestionnaire_set" ) + questionnaire_sectors = ScoreQuestionnaireSector.objects.filter(template=template) + questionnaire_json = {} for subsector in questionnaire_subsectors: @@ -635,15 +635,14 @@ def get_questionnaire_json(self, questionnaire_subsectors=None): # Add Method summaries for method in methods: raw_data = raw_questionnaire.get(method) or {} - questionnaire_json[method][f'{method}_score'] = { - 'all_quality_criteria': raw_data.get('all-quality-criteria', {}).get('value'), - 'minimum_requirement': raw_data.get('minimum-requirements', {}).get('value'), - 'use': raw_data.get('use-criteria', {}).get('value'), + questionnaire_json[method][f"{method}_score"] = { + "all_quality_criteria": raw_data.get("all-quality-criteria", {}).get("value"), + "minimum_requirement": raw_data.get("minimum-requirements", {}).get("value"), + "use": raw_data.get("use-criteria", {}).get("value"), } - questionnaire_json[method]['breakdown_of_quality_criteria'] = { - x.title: raw_data.get(f'sector-{x.id}') - for x in questionnaire_sectors + questionnaire_json[method]["breakdown_of_quality_criteria"] = { + x.title: raw_data.get(f"sector-{x.id}") for x in questionnaire_sectors } return questionnaire_json @@ -658,19 +657,15 @@ def to_exportable_json(self): summary = self.get_summary_json() # score score = self.get_score_json() - return OrderedDict(( - ('metadata', metadata), - ('methodology', methodology), - ('summary', summary), - ('score', score) - )) + return OrderedDict((("metadata", metadata), ("methodology", methodology), ("summary", summary), ("score", score))) class PlannedAssessment(UserResource, ProjectEntityMixin): """ Planned Assessment belonging to a lead """ - project = models.ForeignKey('project.Project', on_delete=models.CASCADE) + + project = models.ForeignKey("project.Project", on_delete=models.CASCADE) title = models.CharField(max_length=255) metadata = models.JSONField(default=None, blank=True, null=True) methodology = models.JSONField(default=None, blank=True, null=True) @@ -681,39 +676,39 @@ def __str__(self): def create_schema_for_group(self, GroupClass): schema = {} assessment_template = self.project.assessment_template - groups = GroupClass.objects.filter( - template=assessment_template, - fields__show_in_planned_assessment=True, - ).prefetch_related('fields').distinct() + groups = ( + GroupClass.objects.filter( + template=assessment_template, + fields__show_in_planned_assessment=True, + ) + .prefetch_related("fields") + .distinct() + ) schema = { group.title: [ { - 'id': field.id, - 'name': field.title, - 'type': field.field_type, - 'source_type': field.source_type, - 'options': { - x['key']: x['title'] for x in field.get_options() - } + "id": field.id, + "name": field.title, + "type": field.field_type, + "source_type": field.source_type, + "options": {x["key"]: x["title"] for x in field.get_options()}, } for field in group.fields.all() - ] for group in groups + ] + for group in groups } return schema @staticmethod def get_actual_value(schema, value): - value_function = FIELDS_KEYS_VALUE_EXTRACTORS.get(schema['name'], identity) - if schema['type'] == Field.SELECT: + value_function = FIELDS_KEYS_VALUE_EXTRACTORS.get(schema["name"], identity) + if schema["type"] == Field.SELECT: # value should not be list but just in case it is a list - value = value[0] if isinstance(value, list) and len(value) > 0 else value or '' - actual_value = schema['options'].get(value, value) - elif schema['type'] == Field.MULTISELECT: + value = value[0] if isinstance(value, list) and len(value) > 0 else value or "" + actual_value = schema["options"].get(value, value) + elif schema["type"] == Field.MULTISELECT: value = value or [] - actual_value = [ - value_function(schema['options'].get(x, x)) - for x in value - ] + actual_value = [value_function(schema["options"].get(x, x)) for x in value] else: actual_value = value return actual_value @@ -723,22 +718,18 @@ def get_data_from_schema(schema, raw_data): if not raw_data: return {} - if 'id' in schema: - key = str(schema['id']) - value = raw_data.get(key, '') + if "id" in schema: + key = str(schema["id"]) + value = raw_data.get(key, "") return { - 'schema': schema, - 'value': Assessment.get_actual_value(schema, value), - 'key': value, + "schema": schema, + "value": Assessment.get_actual_value(schema, value), + "key": value, } if isinstance(schema, dict): - data = { - k: Assessment.get_data_from_schema(v, raw_data) - for k, v in schema.items() - } + data = {k: Assessment.get_data_from_schema(v, raw_data) for k, v in schema.items()} elif isinstance(schema, list): - data = [Assessment.get_data_from_schema(x, raw_data) - for x in schema] + data = [Assessment.get_data_from_schema(x, raw_data) for x in schema] else: raise Exception("Something that could not be parsed from schema") return data @@ -753,8 +744,10 @@ def to_exportable_json(self): metadata = self.get_metadata_json() # for methodology methodology = self.get_methodology_json() - return OrderedDict(( - ('title', self.title), - ('metadata', metadata), - ('methodology', methodology), - )) + return OrderedDict( + ( + ("title", self.title), + ("metadata", metadata), + ("methodology", methodology), + ) + ) diff --git a/apps/ary/mutation.py b/apps/ary/mutation.py index 7f260888e8..c9b703fccd 100644 --- a/apps/ary/mutation.py +++ b/apps/ary/mutation.py @@ -1,12 +1,12 @@ import graphene - -from utils.graphene.mutation import PsDeleteMutation -from deep.permissions import ProjectPermissions as PP from ary.models import Assessment from ary.schema import AssessmentType +from deep.permissions import ProjectPermissions as PP +from utils.graphene.mutation import PsDeleteMutation + -class AssessmentMutationMixin(): +class AssessmentMutationMixin: @classmethod def filter_queryset(cls, qs, info): return qs.filter(project=info.context.active_project) @@ -15,10 +15,11 @@ def filter_queryset(cls, qs, info): class DeleteAssessment(AssessmentMutationMixin, PsDeleteMutation): class Arguments: id = graphene.ID(required=True) + model = Assessment result = graphene.Field(AssessmentType) permissions = [PP.Permission.DELETE_LEAD] -class Mutation(): +class Mutation: assessment_delete = DeleteAssessment.Field() diff --git a/apps/ary/schema.py b/apps/ary/schema.py index af3ef740fc..0dd59f6628 100644 --- a/apps/ary/schema.py +++ b/apps/ary/schema.py @@ -1,15 +1,14 @@ +from ary.filters import AssessmentGQFilterSet +from ary.models import Assessment from django.db.models import QuerySet from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField, PageGraphqlPagination - -from utils.graphene.types import CustomDjangoListObjectType -from utils.graphene.fields import DjangoPaginatedListObjectField -from deep.permissions import ProjectPermissions as PP +from lead.models import Lead from user_resource.schema import UserResourceMixin -from lead.models import Lead -from ary.models import Assessment -from ary.filters import AssessmentGQFilterSet +from deep.permissions import ProjectPermissions as PP +from utils.graphene.fields import DjangoPaginatedListObjectField +from utils.graphene.types import CustomDjangoListObjectType def get_assessment_qs(info): @@ -26,9 +25,15 @@ class AssessmentType(UserResourceMixin, DjangoObjectType): class Meta: model = Assessment only_fields = ( - 'id', 'lead', 'project', 'lead_group', - 'metadata', 'methodology', 'summary', - 'score', 'questionnaire', + "id", + "lead", + "project", + "lead_group", + "metadata", + "methodology", + "summary", + "score", + "questionnaire", ) @staticmethod @@ -45,10 +50,7 @@ class Meta: class Query: assessment = DjangoObjectField(AssessmentType) assessments = DjangoPaginatedListObjectField( - AssessmentListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + AssessmentListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) @staticmethod diff --git a/apps/ary/serializers.py b/apps/ary/serializers.py index 0fcf8197cc..cbae3bcad7 100644 --- a/apps/ary/serializers.py +++ b/apps/ary/serializers.py @@ -1,88 +1,77 @@ -from django.shortcuts import get_object_or_404 from django.db.models import Q +from django.shortcuts import get_object_or_404 from drf_dynamic_fields import DynamicFieldsMixin -from rest_framework import serializers - -from deep.serializers import ( - RemoveNullFieldsMixin, - RecursiveSerializer, -) - -from project.models import Project from gallery.models import File -from user_resource.serializers import UserResourceSerializer -from lead.serializers import SimpleLeadSerializer -from lead.models import Lead, LeadGroup -from deep.models import Field +from gallery.serializers import SimpleFileSerializer from geo.models import Region +from lead.models import Lead, LeadGroup +from lead.serializers import SimpleLeadSerializer from organization.models import Organization, OrganizationType from organization.serializers import ( ArySourceOrganizationSerializer, OrganizationTypeSerializer, ) -from gallery.serializers import SimpleFileSerializer +from project.models import Project +from rest_framework import serializers +from user_resource.serializers import UserResourceSerializer + +from deep.models import Field +from deep.serializers import RecursiveSerializer, RemoveNullFieldsMixin from .models import ( - AssessmentTemplate, Assessment, + AssessmentTemplate, PlannedAssessment, + ScoreQuestionnaire, ScoreQuestionnaireSector, ScoreQuestionnaireSubSector, - ScoreQuestionnaire, ) -class AssessmentSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, UserResourceSerializer): - lead_title = serializers.CharField(source='lead.title', - read_only=True) - lead_group_title = serializers.CharField(source='lead_group.title', - read_only=True) - project = serializers.PrimaryKeyRelatedField( - required=False, - queryset=Project.objects.all() - ) +class AssessmentSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): + lead_title = serializers.CharField(source="lead.title", read_only=True) + lead_group_title = serializers.CharField(source="lead_group.title", read_only=True) + project = serializers.PrimaryKeyRelatedField(required=False, queryset=Project.objects.all()) class Meta: model = Assessment - fields = ('__all__') + fields = "__all__" def create(self, data): - if data.get('project') is None: - if data.get('lead') is None: - data['project'] = data['lead_group'].project + if data.get("project") is None: + if data.get("lead") is None: + data["project"] = data["lead_group"].project else: - data['project'] = data['lead'].project + data["project"] = data["lead"].project return super().create(data) -class PlannedAssessmentSerializer( - RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): +class PlannedAssessmentSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): class Meta: model = PlannedAssessment - fields = '__all__' + fields = "__all__" class LeadAssessmentSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): - lead_title = serializers.CharField(source='lead.title', read_only=True) + lead_title = serializers.CharField(source="lead.title", read_only=True) gallery_files_details = serializers.SerializerMethodField() class Meta: model = Assessment - fields = ('__all__') - read_only_fields = ('lead', 'lead_group', 'project') + fields = "__all__" + read_only_fields = ("lead", "lead_group", "project") def get_gallery_files_details(self, assessment): # Right now gallery files are only used in additional_documents - additional_documents = (assessment.metadata or {}).get('additional_documents') + additional_documents = (assessment.metadata or {}).get("additional_documents") if not additional_documents: return files_id = [] for items in additional_documents.values(): for item in items or []: - if item.get('id') and item.get('type') == 'file': - files_id.append(item['id']) + if item.get("id") and item.get("type") == "file": + files_id.append(item["id"]) # TODO: qs = File.objects.filter(id__in=files_id).all() return SimpleFileSerializer(qs, context=self.context, many=True).data @@ -90,40 +79,36 @@ def get_gallery_files_details(self, assessment): def create(self, validated_data): # If this assessment is being created for the first time, # we want to set lead to the one which has its id in the url - lead = get_object_or_404(Lead, pk=self.initial_data['lead']) - assessment = super().create({ - **validated_data, - 'lead': lead, - 'project': lead.project, - }) + lead = get_object_or_404(Lead, pk=self.initial_data["lead"]) + assessment = super().create( + { + **validated_data, + "lead": lead, + "project": lead.project, + } + ) assessment.save() return assessment -class LeadGroupAssessmentSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, - UserResourceSerializer): - lead_group_title = serializers.CharField(source='lead_group.title', - read_only=True) - leads = SimpleLeadSerializer(source='lead_group.lead_set', - many=True, - read_only=True) +class LeadGroupAssessmentSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): + lead_group_title = serializers.CharField(source="lead_group.title", read_only=True) + leads = SimpleLeadSerializer(source="lead_group.lead_set", many=True, read_only=True) class Meta: model = Assessment - fields = ('__all__') - read_only_fields = ('lead', 'lead_group') + fields = "__all__" + read_only_fields = ("lead", "lead_group") def create(self, validated_data): # If this assessment is being created for the first time, # we want to set lead group to the one which has its id in the url - assessment = super().create({ - **validated_data, - 'lead_group': get_object_or_404( - LeadGroup, - pk=self.initial_data['lead_group'] - ), - }) + assessment = super().create( + { + **validated_data, + "lead_group": get_object_or_404(LeadGroup, pk=self.initial_data["lead_group"]), + } + ) assessment.save() return assessment @@ -141,7 +126,7 @@ class TreeSerializer(serializers.Serializer): class OptionSerializer(serializers.Serializer): key = serializers.CharField() - label = serializers.CharField(source='title') + label = serializers.CharField(source="title") class FieldSerializer(serializers.Serializer): @@ -151,12 +136,11 @@ class FieldSerializer(serializers.Serializer): tooltip = serializers.CharField() field_type = serializers.CharField() source_type = serializers.CharField() - options = OptionSerializer(source='get_options', - many=True, read_only=True) + options = OptionSerializer(source="get_options", many=True, read_only=True) show_in_planned_assessment = serializers.BooleanField() class Meta: - ref_name = 'AryFieldSerializer' + ref_name = "AryFieldSerializer" class GroupSerializer(serializers.Serializer): @@ -227,59 +211,50 @@ def get_scales(self, pillar): class ScoreQuestionnaireSerializer(serializers.ModelSerializer): class Meta: model = ScoreQuestionnaire - fields = '__all__' + fields = "__all__" class ScoreQuestionnaireSubSectorSerializer(serializers.ModelSerializer): questions = ScoreQuestionnaireSerializer( - source='scorequestionnaire_set', many=True, read_only=True, + source="scorequestionnaire_set", + many=True, + read_only=True, ) class Meta: model = ScoreQuestionnaireSubSector - fields = '__all__' + fields = "__all__" class ScoreQuestionnaireSectorSerializer(serializers.ModelSerializer): sub_sectors = ScoreQuestionnaireSubSectorSerializer( - source='scorequestionnairesubsector_set', many=True, read_only=True, + source="scorequestionnairesubsector_set", + many=True, + read_only=True, ) class Meta: model = ScoreQuestionnaireSector - fields = '__all__' - - -class AssessmentTemplateSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, UserResourceSerializer): - metadata_groups = GroupSerializer(source='metadatagroup_set', - many=True, read_only=True) - methodology_groups = GroupSerializer(source='methodologygroup_set', - many=True, read_only=True) - sectors = ItemSerializer(source='sector_set', - many=True, read_only=True) - focuses = ItemSerializer(source='focus_set', - many=True, read_only=True) - underlying_factors = TreeSerializer(source='get_parent_underlying_factors', - many=True, read_only=True) - affected_groups = TreeSerializer(source='get_parent_affected_groups', - many=True, read_only=True) - - priority_sectors = TreeSerializer(source='get_parent_priority_sectors', - many=True, read_only=True) - priority_issues = TreeSerializer(source='get_parent_priority_issues', - many=True, read_only=True) - specific_need_groups = ItemSerializer(source='specificneedgroup_set', - many=True, read_only=True) - affected_locations = ItemSerializer(source='affectedlocation_set', - many=True, read_only=True) - - score_scales = ScoreScaleSerializer(source='scorescale_set', - many=True, read_only=True) - score_pillars = ScorePillarSerializer(source='scorepillar_set', - many=True, read_only=True) + fields = "__all__" + + +class AssessmentTemplateSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): + metadata_groups = GroupSerializer(source="metadatagroup_set", many=True, read_only=True) + methodology_groups = GroupSerializer(source="methodologygroup_set", many=True, read_only=True) + sectors = ItemSerializer(source="sector_set", many=True, read_only=True) + focuses = ItemSerializer(source="focus_set", many=True, read_only=True) + underlying_factors = TreeSerializer(source="get_parent_underlying_factors", many=True, read_only=True) + affected_groups = TreeSerializer(source="get_parent_affected_groups", many=True, read_only=True) + + priority_sectors = TreeSerializer(source="get_parent_priority_sectors", many=True, read_only=True) + priority_issues = TreeSerializer(source="get_parent_priority_issues", many=True, read_only=True) + specific_need_groups = ItemSerializer(source="specificneedgroup_set", many=True, read_only=True) + affected_locations = ItemSerializer(source="affectedlocation_set", many=True, read_only=True) + + score_scales = ScoreScaleSerializer(source="scorescale_set", many=True, read_only=True) + score_pillars = ScorePillarSerializer(source="scorepillar_set", many=True, read_only=True) score_matrix_pillars = ScoreMatrixPillarSerializer( - source='scorematrixpillar_set', + source="scorematrixpillar_set", many=True, read_only=True, ) @@ -287,44 +262,54 @@ class AssessmentTemplateSerializer(RemoveNullFieldsMixin, score_buckets = serializers.SerializerMethodField() sources = serializers.SerializerMethodField() questionnaire_sector = ScoreQuestionnaireSectorSerializer( - source='scorequestionnairesector_set', many=True, read_only=True, + source="scorequestionnairesector_set", + many=True, + read_only=True, ) class Meta: model = AssessmentTemplate - fields = ('__all__') + fields = "__all__" def get_score_buckets(self, template): buckets = template.scorebucket_set.all() - return [ - [b.min_value, b.max_value, b.score] - for b in buckets - ] + return [[b.min_value, b.max_value, b.score] for b in buckets] def get_sources(self, instance): def have_source(source_type): return AssessmentTemplate.objects.filter( - Q(metadatagroup__fields__source_type=source_type) | - Q(methodologygroup__fields__source_type=source_type), + Q(metadatagroup__fields__source_type=source_type) | Q(methodologygroup__fields__source_type=source_type), pk=instance.pk, ).exists() return { - 'countries': Region.objects.filter(public=True).extra( - select={ - 'key': 'id', - 'label': 'title', - } - ).values('key', 'label') if have_source(Field.COUNTRIES) else [], - 'organizations': ArySourceOrganizationSerializer( - Organization.objects.all(), - many=True, - context=self.context, - ).data - if have_source(Field.ORGANIZATIONS or Field.DONORS) else [], - 'organization_type': OrganizationTypeSerializer( - OrganizationType.objects.all(), - many=True, - ).data - if have_source(Field.ORGANIZATIONS or Field.DONORS) else [], + "countries": ( + Region.objects.filter(public=True) + .extra( + select={ + "key": "id", + "label": "title", + } + ) + .values("key", "label") + if have_source(Field.COUNTRIES) + else [] + ), + "organizations": ( + ArySourceOrganizationSerializer( + Organization.objects.all(), + many=True, + context=self.context, + ).data + if have_source(Field.ORGANIZATIONS or Field.DONORS) + else [] + ), + "organization_type": ( + OrganizationTypeSerializer( + OrganizationType.objects.all(), + many=True, + ).data + if have_source(Field.ORGANIZATIONS or Field.DONORS) + else [] + ), } diff --git a/apps/ary/stats.py b/apps/ary/stats.py index 2d4c9a3e0b..4661f1a527 100644 --- a/apps/ary/stats.py +++ b/apps/ary/stats.py @@ -1,29 +1,23 @@ -from django.utils import timezone -from django.db.models import F, Count -from entry.stats import ( - _get_project_geoareas, - _get_lead_data, - get_project_entries_stats -) -from apps.entry.widgets.geo_widget import get_valid_geo_ids - -from organization.models import OrganizationType, Organization -from lead.models import Lead from ary.models import ( + AffectedGroup, Assessment, + Focus, MetadataField, MethodologyField, - Focus, - Sector, - AffectedGroup, - - ScorePillar, - ScoreScale, - + MethodologyProtectionInfo, ScoreMatrixPillar, ScoreMatrixScale, - MethodologyProtectionInfo, + ScorePillar, + ScoreScale, + Sector, ) +from django.db.models import Count, F +from django.utils import timezone +from entry.stats import _get_lead_data, _get_project_geoareas, get_project_entries_stats +from lead.models import Lead +from organization.models import Organization, OrganizationType + +from apps.entry.widgets.geo_widget import get_valid_geo_ids def _get_integer_array(array): @@ -33,25 +27,19 @@ def _get_integer_array(array): def _get_ary_field_options(config): - pk = config['pk'] - field_type = config['type'] - if field_type == 'metadatafield': - return list( - MetadataField.objects.get(pk=pk).options.values('key', 'title').values(id=F('key'), name=F('title')) - ) - elif field_type == 'methodologyfield': - return list( - MethodologyField.objects.get(pk=pk).options.values('key', 'title').values(id=F('key'), name=F('title')) - ) - elif field_type == 'scorepillar': - return list( - ScorePillar.objects.get(pk=pk).questions.values('id', name=F('title')) - ) - elif field_type == 'scorematrixpillar': + pk = config["pk"] + field_type = config["type"] + if field_type == "metadatafield": + return list(MetadataField.objects.get(pk=pk).options.values("key", "title").values(id=F("key"), name=F("title"))) + elif field_type == "methodologyfield": + return list(MethodologyField.objects.get(pk=pk).options.values("key", "title").values(id=F("key"), name=F("title"))) + elif field_type == "scorepillar": + return list(ScorePillar.objects.get(pk=pk).questions.values("id", name=F("title"))) + elif field_type == "scorematrixpillar": return { - 'scale': list(ScoreMatrixScale.objects.filter(pillar=pk).values('id', 'row', 'column', 'value')), + "scale": list(ScoreMatrixScale.objects.filter(pillar=pk).values("id", "row", "column", "value")), } - raise Exception(f'Unknown field type provided {field_type}') + raise Exception(f"Unknown field type provided {field_type}") def get_project_ary_entry_stats(project): @@ -60,175 +48,161 @@ def get_project_ary_entry_stats(project): """ # Sample config [Should work what modification if fixture is used to load] dynamic_fields = { - 'assessment_type': { - 'pk': 20, - 'type': 'metadatafield', + "assessment_type": { + "pk": 20, + "type": "metadatafield", }, - 'language': { - 'pk': 18, - 'type': 'metadatafield', + "language": { + "pk": 18, + "type": "metadatafield", }, - 'coordination': { - 'pk': 6, - 'type': 'metadatafield', + "coordination": { + "pk": 6, + "type": "metadatafield", }, - 'frequency': { - 'pk': 16, - 'type': 'metadatafield', + "frequency": { + "pk": 16, + "type": "metadatafield", }, - - 'units': { - 'pk': 6, # Unit of reporting - 'type': 'methodologyfield', + "units": { + "pk": 6, # Unit of reporting + "type": "methodologyfield", }, - 'type_of_unit_of_analysis': { - 'pk': 5, # Unit of analysis - 'type': 'methodologyfield', + "type_of_unit_of_analysis": { + "pk": 5, # Unit of analysis + "type": "methodologyfield", }, - 'sampling_approach': { - 'pk': 2, # Sampling Approach - 'type': 'methodologyfield', + "sampling_approach": { + "pk": 2, # Sampling Approach + "type": "methodologyfield", }, - 'data_collection_technique': { - 'pk': 1, # data collection technique - 'type': 'methodologyfield', + "data_collection_technique": { + "pk": 1, # data collection technique + "type": "methodologyfield", }, - - 'fit_for_purpose_array': { - 'pk': 1, - 'type': 'scorepillar', + "fit_for_purpose_array": { + "pk": 1, + "type": "scorepillar", }, - 'trustworthiness_array': { - 'pk': 2, - 'type': 'scorepillar', + "trustworthiness_array": { + "pk": 2, + "type": "scorepillar", }, - 'analytical_rigor_array': { - 'pk': 3, - 'type': 'scorepillar', + "analytical_rigor_array": { + "pk": 3, + "type": "scorepillar", }, - 'analytical_writing_array': { - 'pk': 4, - 'type': 'scorepillar', + "analytical_writing_array": { + "pk": 4, + "type": "scorepillar", }, - 'analytical_density': { - 'pk': 1, - 'type': 'scorematrixpillar', + "analytical_density": { + "pk": 1, + "type": "scorematrixpillar", }, } # Stakeholder is MetaField Group [Used to identify organizations field in the stakeholder group] stakeholder_pk = 2 - stakeholder_fields_id = MetadataField.objects.filter(group=stakeholder_pk).values_list('id', flat=True) + stakeholder_fields_id = MetadataField.objects.filter(group=stakeholder_pk).values_list("id", flat=True) # Used to generate data for individuals and households sum methodology_attributes_fields = { - 'sampling_size_field_pk': 3, + "sampling_size_field_pk": 3, # Both are key of methodology field(unit_of_analysis)'s options - 'households': 4, - 'individuals': 5, + "households": 4, + "individuals": 5, } - dynamic_meta = { - key: _get_ary_field_options(value) - for key, value in dynamic_fields.items() if value.get('pk') - } + dynamic_meta = {key: _get_ary_field_options(value) for key, value in dynamic_fields.items() if value.get("pk")} - analytical_density_scale = { - scale['id']: scale['value'] - for scale in dynamic_meta['analytical_density']['scale'] - } + analytical_density_scale = {scale["id"]: scale["value"] for scale in dynamic_meta["analytical_density"]["scale"]} static_meta = { - 'focus_array': list(Focus.objects.values('id', name=F('title'))), - 'protection_info_management_array': [ + "focus_array": list(Focus.objects.values("id", name=F("title"))), + "protection_info_management_array": [ { - 'id': _id, - 'title': title, + "id": _id, + "title": title, } for _id, title in MethodologyProtectionInfo.choices ], - 'sector_array': list(Sector.objects.values('id', name=F('title'))), - 'affected_groups_array': list(AffectedGroup.objects.values('id', name=F('title'))), - 'organization_type': list( + "sector_array": list(Sector.objects.values("id", name=F("title"))), + "affected_groups_array": list(AffectedGroup.objects.values("id", name=F("title"))), + "organization_type": list( OrganizationType.objects.annotate( - organization_count=Count('organization', distinct=True), - ).values( - 'id', - 'organization_count', - name=F('title') - ) + organization_count=Count("organization", distinct=True), + ).values("id", "organization_count", name=F("title")) ), - 'organization': [ + "organization": [ { - 'id': org.id, - 'name': org.title, - 'short_name': org.short_name, - 'long_name': org.long_name, - 'organization_type_id': org.organization_type_id, - 'parent': org.parent_id, + "id": org.id, + "name": org.title, + "short_name": org.short_name, + "long_name": org.long_name, + "organization_type_id": org.organization_type_id, + "parent": org.parent_id, } for org in Organization.objects.all() ], # scale used by score_pillar - 'scorepillar_scale': list(ScoreScale.objects.values('id', 'color', 'value', name=F('title'))), - 'final_scores_array': { - 'score_pillar': list(ScorePillar.objects.values('id', name=F('title'))), - 'score_matrix_pillar': list(ScoreMatrixPillar.objects.values('id', name=F('title'))), + "scorepillar_scale": list(ScoreScale.objects.values("id", "color", "value", name=F("title"))), + "final_scores_array": { + "score_pillar": list(ScorePillar.objects.values("id", name=F("title"))), + "score_matrix_pillar": list(ScoreMatrixPillar.objects.values("id", name=F("title"))), }, - # NOTE: Is defined in client - 'additional_documentation_array': [ - {'id': 1, 'name': 'Executive Summary'}, - {'id': 2, 'name': 'Assessment Database'}, - {'id': 3, 'name': 'Questionnaire'}, - {'id': 4, 'name': 'Miscellaneous'}, + "additional_documentation_array": [ + {"id": 1, "name": "Executive Summary"}, + {"id": 2, "name": "Assessment Database"}, + {"id": 3, "name": "Questionnaire"}, + {"id": 4, "name": "Miscellaneous"}, ], - 'methodology_content': [ - {'id': 1, 'name': 'Objectives'}, - {'id': 2, 'name': 'Data Collection Techniques'}, - {'id': 3, 'name': 'Sampling'}, - {'id': 4, 'name': 'Limitations'}, + "methodology_content": [ + {"id": 1, "name": "Objectives"}, + {"id": 2, "name": "Data Collection Techniques"}, + {"id": 3, "name": "Sampling"}, + {"id": 4, "name": "Limitations"}, ], } # Used to retrive organization type ID using Organiztion ID organization_type_map = { # Organization ID -> Organization Type ID - org['id']: org['organization_type_id'] - for org in static_meta['organization'] + org["id"]: org["organization_type_id"] + for org in static_meta["organization"] } meta = { - 'data_calculated': timezone.now(), + "data_calculated": timezone.now(), **static_meta, **dynamic_meta, } public_data = [] confidential_data = [] - for ary in Assessment.objects.prefetch_related('lead', 'lead__attachment').filter(project=project).all(): + for ary in Assessment.objects.prefetch_related("lead", "lead__attachment").filter(project=project).all(): metadata_raw = ary.metadata or {} - basic_information = metadata_raw.get('basic_information') or {} - additional_documents = metadata_raw.get('additional_documents') or {} + basic_information = metadata_raw.get("basic_information") or {} + additional_documents = metadata_raw.get("additional_documents") or {} methodology_raw = ary.methodology or {} - methodology_attributes = methodology_raw.get('attributes') or [] + methodology_attributes = methodology_raw.get("attributes") or [] score_raw = ary.score or {} - pillars = score_raw.get('pillars') or {} - matrix_pillars = score_raw.get('matrix_pillars') or {} + pillars = score_raw.get("pillars") or {} + matrix_pillars = score_raw.get("matrix_pillars") or {} scores = { - 'final_scores': { - 'score_pillar': { - score_pillar['id']: score_raw.get(f"{score_pillar['id']}-score") - for score_pillar in meta['final_scores_array']['score_pillar'] + "final_scores": { + "score_pillar": { + score_pillar["id"]: score_raw.get(f"{score_pillar['id']}-score") + for score_pillar in meta["final_scores_array"]["score_pillar"] + }, + "score_matrix_pillar": { + sm_pillar["id"]: score_raw.get(f"{sm_pillar['id']}-matrix-score") + for sm_pillar in meta["final_scores_array"]["score_matrix_pillar"] }, - 'score_matrix_pillar': { - sm_pillar['id']: score_raw.get(f"{sm_pillar['id']}-matrix-score") - for sm_pillar in meta['final_scores_array']['score_matrix_pillar'] - } }, - # Analytical Density (Matrix Score Pillar) **{ # key: Sector id (Food, Livelihood, Education (Selected Sectors) @@ -236,28 +210,24 @@ def get_project_ary_entry_stats(project): score_matrix_pillar_key: { sector_id: analytical_density_scale.get(scale_id) for sector_id, scale_id in ( - matrix_pillars.get( - str(dynamic_fields[score_matrix_pillar_key]['pk']) - ) or {} + matrix_pillars.get(str(dynamic_fields[score_matrix_pillar_key]["pk"])) or {} ).items() } - for score_matrix_pillar_key in ['analytical_density'] + for score_matrix_pillar_key in ["analytical_density"] }, - **{ # NOTE: Make sure the keys don't conflit with outer keys scorepillar_key: { - option['id']: ( - pillars.get(str(dynamic_fields[scorepillar_key]['pk'])) or {} - ).get(str(option['id'])) + option["id"]: (pillars.get(str(dynamic_fields[scorepillar_key]["pk"])) or {}).get(str(option["id"])) for option in meta[scorepillar_key] - } for scorepillar_key in [ - 'fit_for_purpose_array', - 'trustworthiness_array', - 'analytical_rigor_array', - 'analytical_writing_array', + } + for scorepillar_key in [ + "fit_for_purpose_array", + "trustworthiness_array", + "analytical_rigor_array", + "analytical_writing_array", ] - } + }, } lead = ary.lead @@ -267,103 +237,95 @@ def get_project_ary_entry_stats(project): # confidential data (if lead is confidential) lead_source_data = {} if ( - lead.source_type in [Lead.SourceType.DISK, Lead.SourceType.DROPBOX, Lead.SourceType.GOOGLE_DRIVE] and - lead.attachment and lead.attachment.file + lead.source_type in [Lead.SourceType.DISK, Lead.SourceType.DROPBOX, Lead.SourceType.GOOGLE_DRIVE] + and lead.attachment + and lead.attachment.file ): - lead_source_data['attachment'] = lead.attachment.file.url + lead_source_data["attachment"] = lead.attachment.file.url elif lead.source_type == Lead.SourceType.WEBSITE: - lead_source_data['url'] = lead.url + lead_source_data["url"] = lead.url elif lead.source_type == Lead.SourceType.TEXT: - lead_source_data['text'] = lead.text + lead_source_data["text"] = lead.text ary_data = { - 'pk': ary.pk, - 'created_at': ary.created_at, - 'date': ary.lead.published_on, - 'lead': { + "pk": ary.pk, + "created_at": ary.created_at, + "date": ary.lead.published_on, + "lead": { **lead_data, **lead_source_data, }, - - 'focus': _get_integer_array(methodology_raw.get('focuses') or []), - 'protection_info_management': _get_integer_array(methodology_raw.get('protection_info') or []), - 'sector': _get_integer_array(methodology_raw.get('sectors') or []), - 'scores': scores or [], - 'geo': get_valid_geo_ids(methodology_raw.get('locations') or []), - 'affected_groups': _get_integer_array(methodology_raw.get('affected_groups') or []), - - 'organization_and_stakeholder_type': [ + "focus": _get_integer_array(methodology_raw.get("focuses") or []), + "protection_info_management": _get_integer_array(methodology_raw.get("protection_info") or []), + "sector": _get_integer_array(methodology_raw.get("sectors") or []), + "scores": scores or [], + "geo": get_valid_geo_ids(methodology_raw.get("locations") or []), + "affected_groups": _get_integer_array(methodology_raw.get("affected_groups") or []), + "organization_and_stakeholder_type": [ # Organization Type ID, Organization ID [organization_type_map.get(organization_id), organization_id] for field_id in stakeholder_fields_id for organization_id in basic_information.get(str(field_id)) or [] ], - # Metadata Fields Data **{ - key: basic_information.get(str(dynamic_fields[selector]['pk'])) + key: basic_information.get(str(dynamic_fields[selector]["pk"])) for key, selector in ( # NOTE: Make sure the keys are not conflicting with outer keys - ('assessment_type', 'assessment_type'), - ('language', 'language'), - ('coordination', 'coordination'), - ('frequency', 'frequency'), + ("assessment_type", "assessment_type"), + ("language", "language"), + ("coordination", "coordination"), + ("frequency", "frequency"), ) }, - # Housholds and Individuals **{ unit_of_analysis_type: sum( - attribute.get(str(methodology_attributes_fields['sampling_size_field_pk'])) or 0 + attribute.get(str(methodology_attributes_fields["sampling_size_field_pk"])) or 0 for attribute in methodology_attributes - if int( - attribute.get(str(dynamic_fields['data_collection_technique']['pk'])) or -1 - ) == int(methodology_attributes_fields[unit_of_analysis_type]) - ) for unit_of_analysis_type in ['households', 'individuals'] + if int(attribute.get(str(dynamic_fields["data_collection_technique"]["pk"])) or -1) + == int(methodology_attributes_fields[unit_of_analysis_type]) + ) + for unit_of_analysis_type in ["households", "individuals"] }, - - 'data_collection_technique_sample_size': { - technique['id']: sum( - attribute.get(str(methodology_attributes_fields['sampling_size_field_pk'])) or 0 + "data_collection_technique_sample_size": { + technique["id"]: sum( + attribute.get(str(methodology_attributes_fields["sampling_size_field_pk"])) or 0 for attribute in methodology_attributes - if int( - attribute.get(str(dynamic_fields['data_collection_technique']['pk'])) or -1 - ) == int(technique['id']) - ) for technique in dynamic_meta['data_collection_technique'] + if int(attribute.get(str(dynamic_fields["data_collection_technique"]["pk"])) or -1) == int(technique["id"]) + ) + for technique in dynamic_meta["data_collection_technique"] }, - # Methodology Fields Data **{ - key: [ - attribute.get(str(dynamic_fields[selector]['pk'])) - for attribute in methodology_attributes - ] for key, selector in ( + key: [attribute.get(str(dynamic_fields[selector]["pk"])) for attribute in methodology_attributes] + for key, selector in ( # NOTE: Make sure the keys are not conflicting with outer keys - ('data_collection_technique', 'data_collection_technique'), - ('unit_of_analysis', 'type_of_unit_of_analysis'), - ('unit_of_reporting', 'units'), - ('sampling_approach', 'sampling_approach'), + ("data_collection_technique", "data_collection_technique"), + ("unit_of_analysis", "type_of_unit_of_analysis"), + ("unit_of_reporting", "units"), + ("sampling_approach", "sampling_approach"), ) }, - - 'methodology_content': [ + "methodology_content": [ 1 if methodology_raw.get(content_type) else 0 - for content_type in ['objectives', 'data_collection_techniques', 'sampling', 'limitations'] + for content_type in ["objectives", "data_collection_techniques", "sampling", "limitations"] ], - - 'additional_documentation': [ + "additional_documentation": [ len(additional_documents.get(doc_type) or []) - for doc_type in ['executive_summary', 'assessment_data', 'questionnaire', 'misc'] + for doc_type in ["executive_summary", "assessment_data", "questionnaire", "misc"] ], } confidential_data.append(ary_data) # Hide source data from confidential leads for unrestricted users if ary.lead.confidentiality == Lead.Confidentiality.CONFIDENTIAL: - public_data.append({ - **ary_data, - 'lead': lead_data, # No source data - }) + public_data.append( + { + **ary_data, + "lead": lead_data, # No source data + } + ) else: public_data.append(ary_data) @@ -371,17 +333,17 @@ def get_project_ary_entry_stats(project): geo_array = _get_project_geoareas(project) return { - 'geo_data': geo_array, - 'entry_data': entry_stats, - 'ary_data': { - 'meta': meta, - 'data': public_data, + "geo_data": geo_array, + "entry_data": entry_stats, + "ary_data": { + "meta": meta, + "data": public_data, }, }, { - 'geo_data': geo_array, - 'entry_data': entry_stats, - 'ary_data': { - 'meta': meta, - 'data': confidential_data, + "geo_data": geo_array, + "entry_data": entry_stats, + "ary_data": { + "meta": meta, + "data": confidential_data, }, } diff --git a/apps/ary/tests/test_apis.py b/apps/ary/tests/test_apis.py index 787dbdc6af..e4635c55cb 100644 --- a/apps/ary/tests/test_apis.py +++ b/apps/ary/tests/test_apis.py @@ -1,20 +1,20 @@ -from dateutil.relativedelta import relativedelta - -from django.utils import timezone - -from deep.tests import TestCase - -from project.models import Project -from user.models import User -from lead.models import Lead from ary.models import ( + AffectedGroup, Assessment, AssessmentTemplate, - MetadataGroup, MetadataField, MetadataOption, + MetadataField, + MetadataGroup, + MetadataOption, MethodologyGroup, Sector, - AffectedGroup, ) +from dateutil.relativedelta import relativedelta +from django.utils import timezone +from lead.models import Lead +from project.models import Project +from user.models import User + +from deep.tests import TestCase class AssessmentTests(TestCase): @@ -26,12 +26,12 @@ def test_create_assessment(self): assessment_count = Assessment.objects.count() lead = self.create_lead() - url = '/api/v1/assessments/' + url = "/api/v1/assessments/" data = { - 'lead': lead.pk, - 'project': lead.project.pk, - 'metadata': {'test_meta': 'Test'}, - 'methodology': {'test_methodology': 'Test'}, + "lead": lead.pk, + "project": lead.project.pk, + "metadata": {"test_meta": "Test"}, + "methodology": {"test_methodology": "Test"}, } self.authenticate() @@ -39,21 +39,20 @@ def test_create_assessment(self): self.assert_201(response) self.assertEqual(Assessment.objects.count(), assessment_count + 1) - self.assertEqual(response.data['version_id'], 1) - self.assertEqual(response.data['metadata'], data['metadata']) - self.assertEqual(response.data['methodology'], - data['methodology']) + self.assertEqual(response.data["version_id"], 1) + self.assertEqual(response.data["metadata"], data["metadata"]) + self.assertEqual(response.data["methodology"], data["methodology"]) def test_create_assessment_no_project_yes_lead(self): assessment_count = Assessment.objects.count() lead = self.create_lead() - url = '/api/v1/assessments/' + url = "/api/v1/assessments/" data = { - 'lead': lead.pk, - 'project': lead.project.pk, - 'metadata': {'test_meta': 'Test'}, - 'methodology': {'test_methodology': 'Test'}, + "lead": lead.pk, + "project": lead.project.pk, + "metadata": {"test_meta": "Test"}, + "methodology": {"test_methodology": "Test"}, } self.authenticate() @@ -61,10 +60,9 @@ def test_create_assessment_no_project_yes_lead(self): self.assert_201(response) self.assertEqual(Assessment.objects.count(), assessment_count + 1) - self.assertEqual(response.data['version_id'], 1) - self.assertEqual(response.data['metadata'], data['metadata']) - self.assertEqual(response.data['methodology'], - data['methodology']) + self.assertEqual(response.data["version_id"], 1) + self.assertEqual(response.data["metadata"], data["metadata"]) + self.assertEqual(response.data["methodology"], data["methodology"]) def test_create_assessment_no_perm(self): assessment_count = Assessment.objects.count() @@ -74,12 +72,12 @@ def test_create_assessment_no_perm(self): lead.project.add_member(user, self.view_only_role) - url = '/api/v1/assessments/' + url = "/api/v1/assessments/" data = { - 'lead': lead.pk, - 'project': lead.project.pk, - 'metadata': {'test_meta': 'Test'}, - 'methodology': {'test_methodology': 'Test'}, + "lead": lead.pk, + "project": lead.project.pk, + "metadata": {"test_meta": "Test"}, + "methodology": {"test_methodology": "Test"}, } self.authenticate(user) @@ -93,10 +91,10 @@ def test_lead_assessment(self): assessment_count = Assessment.objects.count() lead = self.create_lead() - url = '/api/v1/lead-assessments/{}/'.format(lead.pk) + url = "/api/v1/lead-assessments/{}/".format(lead.pk) data = { - 'metadata': {'test_meta': 'Test 1'}, - 'methodology': {'test_methodology': 'Test 2'}, + "metadata": {"test_meta": "Test 1"}, + "methodology": {"test_methodology": "Test 2"}, } self.authenticate() @@ -104,20 +102,19 @@ def test_lead_assessment(self): self.assert_200(response) self.assertEqual(Assessment.objects.count(), assessment_count + 1) - self.assertEqual(response.data['version_id'], 1) - self.assertEqual(response.data['metadata'], data['metadata']) - self.assertEqual(response.data['methodology'], - data['methodology']) + self.assertEqual(response.data["version_id"], 1) + self.assertEqual(response.data["metadata"], data["metadata"]) + self.assertEqual(response.data["methodology"], data["methodology"]) # Next test editing the assessment - data['metadata'] = {'test_meta': 'Test 1 new'} + data["metadata"] = {"test_meta": "Test 1 new"} response = self.client.put(url, data) self.assert_200(response) - self.assertEqual(response.data['version_id'], 2) - self.assertEqual(response.data['metadata'], data['metadata']) + self.assertEqual(response.data["version_id"], 2) + self.assertEqual(response.data["metadata"], data["metadata"]) def test_get_template(self): template = self.create(AssessmentTemplate) @@ -135,7 +132,7 @@ def test_get_template(self): self.create(AffectedGroup, parent=ag_parent, template=template) self.create(AffectedGroup, parent=ag_parent, template=template) - url = '/api/v1/assessment-templates/{}/'.format(template.id) + url = "/api/v1/assessment-templates/{}/".format(template.id) self.authenticate() response = self.client.get(url) @@ -149,26 +146,24 @@ def test_project_assessment_template(self): project.assessment_template = template project.save() - url = '/api/v1/projects/{}/assessment-template/'.format( - project.id - ) + url = "/api/v1/projects/{}/assessment-template/".format(project.id) self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['id'], template.id) - self.assertEqual(response.data['title'], template.title) + self.assertEqual(response.data["id"], template.id) + self.assertEqual(response.data["title"], template.title) def test_options(self): - url = '/api/v1/assessment-options/' + url = "/api/v1/assessment-options/" self.authenticate() response = self.client.get(url) self.assert_200(response) def test_ary_copy_from_project_with_only_view(self): - url = '/api/v1/assessment-copy/' + url = "/api/v1/assessment-copy/" source_project = self.create(Project, role=self.view_only_role) dest_project = self.create(Project, role=self.admin_role) @@ -177,8 +172,8 @@ def test_ary_copy_from_project_with_only_view(self): arys_count = Assessment.objects.all().count() data = { - 'projects': [dest_project.pk], - 'assessments': [ary.pk], + "projects": [dest_project.pk], + "assessments": [ary.pk], } self.authenticate() @@ -188,26 +183,28 @@ def test_ary_copy_from_project_with_only_view(self): assert arys_count == Assessment.objects.all().count(), "No new assessment should have been created" def test_ary_copy(self): - url = '/api/v1/assessment-copy/' + url = "/api/v1/assessment-copy/" # Projects [Source] # NOTE: make sure the source projects have create/edit permissions - project1s = self.create(Project, title='project1s', role=self.admin_role) - project2s = self.create(Project, title='project2s', role=self.admin_role) - project3s = self.create(Project, title='project3s') - project4s = self.create(Project, title='project4s', role=self.normal_role) + project1s = self.create(Project, title="project1s", role=self.admin_role) + project2s = self.create(Project, title="project2s", role=self.admin_role) + project3s = self.create(Project, title="project3s") + project4s = self.create(Project, title="project4s", role=self.normal_role) # Projects [Destination] - project1d = self.create(Project, title='project1d') - project2d = self.create(Project, title='project2d', role=self.admin_role) - project4d = self.create(Project, title='project4d', role=self.view_only_role) + project1d = self.create(Project, title="project1d") + project2d = self.create(Project, title="project2d", role=self.admin_role) + project4d = self.create(Project, title="project4d", role=self.view_only_role) lead1s = self.create( - Lead, title='lead1s', source_type=Lead.SourceType.WEBSITE, url='https://random-source-11010', project=project1s) + Lead, title="lead1s", source_type=Lead.SourceType.WEBSITE, url="https://random-source-11010", project=project1s + ) lead2s = self.create( - Lead, title='lead2s', source_type=Lead.SourceType.WEBSITE, url='https://random-source-11011', project=project2s) - lead3s = self.create(Lead, title='lead3s', project=project3s) - lead4s = self.create(Lead, title='lead4s', project=project4s) + Lead, title="lead2s", source_type=Lead.SourceType.WEBSITE, url="https://random-source-11011", project=project2s + ) + lead3s = self.create(Lead, title="lead3s", project=project3s) + lead4s = self.create(Lead, title="lead4s", project=project4s) # ary1 Info (Will be used later for testing) @@ -219,18 +216,18 @@ def test_ary_copy(self): # For duplicate url validation check # Lead + Assessment - lead1d = self.create(Lead, title='lead1d', source_type=lead1s.source_type, url=lead1s.url, project=project1d) + lead1d = self.create(Lead, title="lead1d", source_type=lead1s.source_type, url=lead1s.url, project=project1d) self.create(Assessment, project=lead1d.project, lead=lead1d) # Request body data [also contains unauthorized projects and Assessments] data = { - 'projects': sorted([project4d.pk, project2d.pk, project1d.pk, project1s.pk]), - 'assessments': sorted([ary3.pk, ary2.pk, ary1.pk, ary4.pk]), + "projects": sorted([project4d.pk, project2d.pk, project1d.pk, project1s.pk]), + "assessments": sorted([ary3.pk, ary2.pk, ary1.pk, ary4.pk]), } # data [only contains authorized projects and assessments] validate_data = { - 'projects': sorted([project2d.pk, project1s.pk]), - 'assessments': sorted([ary4.pk, ary2.pk, ary1.pk]), + "projects": sorted([project2d.pk, project1s.pk]), + "assessments": sorted([ary4.pk, ary2.pk, ary1.pk]), } ary_stats = [ @@ -239,7 +236,6 @@ def test_ary_copy(self): (project2s, 1, 1), (project3s, 1, 1), (project4s, 1, 1), - (project1d, 1, 1), (project2d, 0, 3), (project4d, 0, 0), @@ -248,7 +244,7 @@ def test_ary_copy(self): # Current ARY Count for project, old_ary_count, _ in ary_stats: current_ary_count = Assessment.objects.filter(project_id=project.pk).count() - assert old_ary_count == current_ary_count, f'Project: {project.title} Assessment current count is different' + assert old_ary_count == current_ary_count, f"Project: {project.title} Assessment current count is different" self.authenticate() response = self.client.post(url, data) @@ -257,8 +253,8 @@ def test_ary_copy(self): rdata = response.json() # Sort the data since we are comparing lists sorted_rdata = { - 'projects': sorted(rdata['projects']), - 'assessments': sorted(rdata['assessments']), + "projects": sorted(rdata["projects"]), + "assessments": sorted(rdata["assessments"]), } self.assert_201(response) self.assertNotEqual(sorted_rdata, data) @@ -267,8 +263,7 @@ def test_ary_copy(self): # New ARY Count (after assessment-copy) for project, _, new_ary_count in ary_stats: current_ary_count = Assessment.objects.filter(project_id=project.pk).count() - assert new_ary_count == current_ary_count, \ - f'Project: {project.title} {project.pk} Assessment new count is different' + assert new_ary_count == current_ary_count, f"Project: {project.title} {project.pk} Assessment new count is different" def test_filter_assessment(self): now = timezone.now() @@ -284,20 +279,20 @@ def test_filter_assessment(self): self.update_obj(self.create(Assessment, lead=lead3, project=project), created_at=now + relativedelta(days=-2)) self.update_obj(self.create(Assessment, lead=lead4, project=project), created_at=now) - params = {'created_at__gte': now.strftime('%Y-%m-%d%z')} - url = '/api/v1/assessments/' + params = {"created_at__gte": now.strftime("%Y-%m-%d%z")} + url = "/api/v1/assessments/" self.authenticate() respose = self.client.get(url, params) self.assert_200(respose) - self.assertEqual(len(respose.data['results']), 4) + self.assertEqual(len(respose.data["results"]), 4) def test_assessment_options(self): - url = '/api/v1/assessment-options/' + url = "/api/v1/assessment-options/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertIn('created_by', response.data) - self.assertIn('project', response.data) + self.assertIn("created_by", response.data) + self.assertIn("project", response.data) def test_assessment_options_in_project(self): user1 = self.create_user() @@ -315,29 +310,26 @@ def test_assessment_options_in_project(self): self.create(Assessment, lead=lead3, project=project1, created_by=user2) # filter by project2 - url = f'/api/v1/assessment-options/?project={project2.id}' + url = f"/api/v1/assessment-options/?project={project2.id}" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) - projects = response.data['project'] + projects = response.data["project"] self.assertEqual(len(projects), 1) - self.assertEqual(projects[0]['key'], project2.id, projects) - self.assertEqual(projects[0]['value'], project2.title, projects) + self.assertEqual(projects[0]["key"], project2.id, projects) + self.assertEqual(projects[0]["value"], project2.title, projects) # gives all the assessment that the user has created for the project - self.assertEqual( - set([item['key'] for item in response.data['created_by']]), - set([user1.id, user2.id]) - ) + self.assertEqual(set([item["key"] for item in response.data["created_by"]]), set([user1.id, user2.id])) # filter by project1 - url = f'/api/v1/assessment-options/?project={project1.id}' + url = f"/api/v1/assessment-options/?project={project1.id}" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) - projects = response.data['project'] + projects = response.data["project"] self.assertEqual(len(projects), 1) - self.assertEqual(projects[0]['key'], project1.id, projects) - self.assertEqual(projects[0]['value'], project1.title, projects) + self.assertEqual(projects[0]["key"], project1.id, projects) + self.assertEqual(projects[0]["value"], project1.title, projects) # gives all the assessment that the user has created for the project - self.assertEqual(user2.id, response.data['created_by'][0]['key']) - self.assertEqual(len(response.data['created_by']), 1) + self.assertEqual(user2.id, response.data["created_by"][0]["key"]) + self.assertEqual(len(response.data["created_by"]), 1) diff --git a/apps/ary/tests/test_mutations.py b/apps/ary/tests/test_mutations.py index 934a3b2d6d..b2670c9130 100644 --- a/apps/ary/tests/test_mutations.py +++ b/apps/ary/tests/test_mutations.py @@ -1,14 +1,14 @@ -from utils.graphene.tests import GraphQLTestCase - from ary.factories import AssessmentFactory -from project.factories import ProjectFactory from lead.factories import LeadFactory +from project.factories import ProjectFactory from user.factories import UserFactory +from utils.graphene.tests import GraphQLTestCase + class TestAssessmentMutation(GraphQLTestCase): def test_assessment_delete_mutation(self): - query = ''' + query = """ mutation MyMutation ($projectId: ID! $assessmentId: ID!) { project(id: $projectId) { assessmentDelete(id: $assessmentId) { @@ -20,7 +20,7 @@ def test_assessment_delete_mutation(self): } } } - ''' + """ project = ProjectFactory.create() member_user = UserFactory.create() non_member_user = UserFactory.create() @@ -30,11 +30,7 @@ def test_assessment_delete_mutation(self): ary = AssessmentFactory.create(project=project, lead=lead1) def _query_check(**kwargs): - return self.query_check( - query, - variables={'projectId': project.id, 'assessmentId': ary.id}, - **kwargs - ) + return self.query_check(query, variables={"projectId": project.id, "assessmentId": ary.id}, **kwargs) # -- Without login _query_check(assert_for_error=True) @@ -42,8 +38,8 @@ def _query_check(**kwargs): # --- member user self.force_login(member_user) content = _query_check(assert_for_error=False) - self.assertEqual(content['data']['project']['assessmentDelete']['ok'], True) - self.assertIdEqual(content['data']['project']['assessmentDelete']['result']['id'], ary.id) + self.assertEqual(content["data"]["project"]["assessmentDelete"]["ok"], True) + self.assertIdEqual(content["data"]["project"]["assessmentDelete"]["result"]["id"], ary.id) # --- non_member user self.force_login(non_member_user) diff --git a/apps/ary/tests/test_permissions.py b/apps/ary/tests/test_permissions.py index dc6b8d7d49..10ac1066dd 100644 --- a/apps/ary/tests/test_permissions.py +++ b/apps/ary/tests/test_permissions.py @@ -1,44 +1,41 @@ -from deep.tests import TestCase - +from ary.models import PlannedAssessment from lead.models import Lead from project.models import Project, ProjectRole -from ary.models import PlannedAssessment from project.permissions import get_project_permissions_value +from deep.tests import TestCase + class TestAssessmentPermissions(TestCase): def setUp(self): super().setUp() self.no_assmt_creation_role = ProjectRole.objects.create( - title='No Assessment Creation Role', - entry_permissions=get_project_permissions_value('entry', '__all__'), - lead_permissions=get_project_permissions_value('lead', '__all__'), - setup_permissions=get_project_permissions_value('setup', '__all__'), - export_permissions=get_project_permissions_value('export', '__all__'), + title="No Assessment Creation Role", + entry_permissions=get_project_permissions_value("entry", "__all__"), + lead_permissions=get_project_permissions_value("lead", "__all__"), + setup_permissions=get_project_permissions_value("setup", "__all__"), + export_permissions=get_project_permissions_value("export", "__all__"), assessment_permissions=0, ) self.assmt_creation_role = ProjectRole.objects.create( - title='Assessment Creation Role', - entry_permissions=get_project_permissions_value('entry', '__all__'), - lead_permissions=get_project_permissions_value('lead', '__all__'), - setup_permissions=get_project_permissions_value('setup', '__all__'), - export_permissions=get_project_permissions_value('export', '__all__'), - assessment_permissions=get_project_permissions_value('assessment', ['create']), + title="Assessment Creation Role", + entry_permissions=get_project_permissions_value("entry", "__all__"), + lead_permissions=get_project_permissions_value("lead", "__all__"), + setup_permissions=get_project_permissions_value("setup", "__all__"), + export_permissions=get_project_permissions_value("export", "__all__"), + assessment_permissions=get_project_permissions_value("assessment", ["create"]), ) def test_create_assessment_no_permission(self): - project = self.create( - Project, - role=self.no_assmt_creation_role - ) + project = self.create(Project, role=self.no_assmt_creation_role) lead = self.create(Lead, project=project) - url = '/api/v1/assessments/' + url = "/api/v1/assessments/" data = { - 'lead': lead.pk, - 'project': lead.project.pk, - 'metadata': {'test_meta': 'Test'}, - 'methodology': {'test_methodology': 'Test'}, + "lead": lead.pk, + "project": lead.project.pk, + "metadata": {"test_meta": "Test"}, + "methodology": {"test_methodology": "Test"}, } self.authenticate() @@ -46,17 +43,14 @@ def test_create_assessment_no_permission(self): self.assert_403(response) def test_create_assessment_with_permission(self): - project = self.create( - Project, - role=self.assmt_creation_role - ) + project = self.create(Project, role=self.assmt_creation_role) lead = self.create(Lead, project=project) - url = '/api/v1/assessments/' + url = "/api/v1/assessments/" data = { - 'lead': lead.pk, - 'project': lead.project.pk, - 'metadata': {'test_meta': 'Test'}, - 'methodology': {'test_methodology': 'Test'}, + "lead": lead.pk, + "project": lead.project.pk, + "metadata": {"test_meta": "Test"}, + "methodology": {"test_methodology": "Test"}, } self.authenticate() @@ -68,36 +62,33 @@ class TestPlannedAssessmentPermissions(TestCase): def setUp(self): super().setUp() self.no_assmt_creation_role = ProjectRole.objects.create( - title='No Assessment Creation Role', - entry_permissions=get_project_permissions_value('entry', '__all__'), - lead_permissions=get_project_permissions_value('lead', '__all__'), - setup_permissions=get_project_permissions_value('setup', '__all__'), - export_permissions=get_project_permissions_value('export', '__all__'), + title="No Assessment Creation Role", + entry_permissions=get_project_permissions_value("entry", "__all__"), + lead_permissions=get_project_permissions_value("lead", "__all__"), + setup_permissions=get_project_permissions_value("setup", "__all__"), + export_permissions=get_project_permissions_value("export", "__all__"), assessment_permissions=0, ) self.assmt_creation_role = ProjectRole.objects.create( - title='Assessment Creation Role', - entry_permissions=get_project_permissions_value('entry', '__all__'), - lead_permissions=get_project_permissions_value('lead', '__all__'), - setup_permissions=get_project_permissions_value('setup', '__all__'), - export_permissions=get_project_permissions_value('export', '__all__'), - assessment_permissions=get_project_permissions_value('assessment', ['create']), + title="Assessment Creation Role", + entry_permissions=get_project_permissions_value("entry", "__all__"), + lead_permissions=get_project_permissions_value("lead", "__all__"), + setup_permissions=get_project_permissions_value("setup", "__all__"), + export_permissions=get_project_permissions_value("export", "__all__"), + assessment_permissions=get_project_permissions_value("assessment", ["create"]), ) def test_create_panned_assessment_no_permission(self): - project = self.create( - Project, - role=self.no_assmt_creation_role - ) + project = self.create(Project, role=self.no_assmt_creation_role) lead = self.create(Lead, project=project) - url = '/api/v1/planned-assessments/' + url = "/api/v1/planned-assessments/" data = { - 'lead': lead.pk, - 'project': lead.project.pk, - 'metadata': {'test_meta': 'Test'}, - 'title': 'This is title', - 'methodology': {'test_methodology': 'Test'}, + "lead": lead.pk, + "project": lead.project.pk, + "metadata": {"test_meta": "Test"}, + "title": "This is title", + "methodology": {"test_methodology": "Test"}, } self.authenticate() @@ -106,19 +97,16 @@ def test_create_panned_assessment_no_permission(self): def test_create_panned_assessment_with_permission(self): initial_count = PlannedAssessment.objects.count() - project = self.create( - Project, - role=self.assmt_creation_role - ) + project = self.create(Project, role=self.assmt_creation_role) lead = self.create(Lead, project=project) - url = '/api/v1/planned-assessments/' + url = "/api/v1/planned-assessments/" data = { - 'lead': lead.pk, - 'project': lead.project.pk, - 'title': 'This is title', - 'metadata': {'test_meta': 'Test'}, - 'methodology': {'test_methodology': 'Test'}, + "lead": lead.pk, + "project": lead.project.pk, + "title": "This is title", + "metadata": {"test_meta": "Test"}, + "methodology": {"test_methodology": "Test"}, } self.authenticate() diff --git a/apps/ary/tests/test_schemas.py b/apps/ary/tests/test_schemas.py index bb09a8bde2..83c2c61a3d 100644 --- a/apps/ary/tests/test_schemas.py +++ b/apps/ary/tests/test_schemas.py @@ -1,16 +1,15 @@ -from utils.graphene.tests import GraphQLTestCase - -from lead.models import Lead - from ary.factories import AssessmentFactory -from project.factories import ProjectFactory from lead.factories import LeadFactory +from lead.models import Lead +from project.factories import ProjectFactory from user.factories import UserFactory +from utils.graphene.tests import GraphQLTestCase + class TestAssessmentQuery(GraphQLTestCase): def test_assessment_query(self): - query = ''' + query = """ query MyQuery ($id: ID!) { project(id: $id) { assessments(ordering: "id") { @@ -34,7 +33,7 @@ def test_assessment_query(self): } } } - ''' + """ project1 = ProjectFactory.create() project2 = ProjectFactory.create() member_user = UserFactory.create() @@ -54,23 +53,23 @@ def test_assessment_query(self): # -- non member user (Project 1) self.force_login(non_member_user) - content = self.query_check(query, variables={'id': project1.id}) - self.assertEqual(content['data']['project']['assessments']['totalCount'], 0) - self.assertListIds(content['data']['project']['assessments']['results'], [], content) + content = self.query_check(query, variables={"id": project1.id}) + self.assertEqual(content["data"]["project"]["assessments"]["totalCount"], 0) + self.assertListIds(content["data"]["project"]["assessments"]["results"], [], content) # -- non confidential member user (Project 1) self.force_login(non_confidential_member_user) - content = self.query_check(query, variables={'id': project1.id}) - self.assertEqual(content['data']['project']['assessments']['totalCount'], 1) - self.assertListIds(content['data']['project']['assessments']['results'], [ary2], content) + content = self.query_check(query, variables={"id": project1.id}) + self.assertEqual(content["data"]["project"]["assessments"]["totalCount"], 1) + self.assertListIds(content["data"]["project"]["assessments"]["results"], [ary2], content) # -- member user (Project 1) self.force_login(member_user) - content = self.query_check(query, variables={'id': project1.id}) - self.assertEqual(content['data']['project']['assessments']['totalCount'], 2) - self.assertListIds(content['data']['project']['assessments']['results'], [ary1, ary2], content) + content = self.query_check(query, variables={"id": project1.id}) + self.assertEqual(content["data"]["project"]["assessments"]["totalCount"], 2) + self.assertListIds(content["data"]["project"]["assessments"]["results"], [ary1, ary2], content) # -- member user (Project 2) - content = self.query_check(query, variables={'id': project2.id}) - self.assertEqual(content['data']['project']['assessments']['totalCount'], 1) - self.assertEqual(content['data']['project']['assessments']['results'][0]['id'], str(ary3.id)) + content = self.query_check(query, variables={"id": project2.id}) + self.assertEqual(content["data"]["project"]["assessments"]["totalCount"], 1) + self.assertEqual(content["data"]["project"]["assessments"]["results"][0]["id"], str(ary3.id)) diff --git a/apps/ary/utils.py b/apps/ary/utils.py index 97c8f14127..93c0605630 100644 --- a/apps/ary/utils.py +++ b/apps/ary/utils.py @@ -1,23 +1,21 @@ -from geo.models import Region, GeoArea +from assessment_registry.models import AdditionalDocument +from geo.models import GeoArea, Region from organization.models import Organization from utils.common import parse_number -from assessment_registry.models import AdditionalDocument - def get_title_or_none(Model): def _get_title(val): instance = Model.objects.filter(id=val).first() return instance and instance.title + return _get_title def get_location_title(val): if isinstance(val, dict): - return val.get('geo_json') and \ - val['geo_json'].get('properties') and \ - val['geo_json']['properties'].get('title') + return val.get("geo_json") and val["geo_json"].get("properties") and val["geo_json"]["properties"].get("title") instance = GeoArea.objects.filter(id=val).first() return instance and instance.title @@ -27,6 +25,7 @@ def _get_title(val): _val = int(val) if _val in IntegerEnum: return IntegerEnum(_val).label + return _get_title @@ -34,6 +33,7 @@ def get_model_attr_or_none(Model, attr): def _get_attr(val): instance = Model.objects.filter(id=val).first() return instance and instance.__dict__.get(attr) + return _get_attr @@ -43,6 +43,7 @@ def _get_attrs(val): if not instance: return {attr: None for attr in attrs} return {attr: instance.__dict__.get(attr) for attr in attrs} + return _get_attrs @@ -56,14 +57,14 @@ def get_organization_name(did): if org: m_org = org.data return { - 'name': m_org.title, - 'type': m_org.organization_type and m_org.organization_type.title, - 'key': did, + "name": m_org.title, + "type": m_org.organization_type and m_org.organization_type.title, + "key": did, } return { - 'name': '', - 'type': '', - 'key': did, + "name": "", + "type": "", + "key": did, } @@ -73,10 +74,7 @@ def get_additional_documents(assessment): for document_type in all_document_types: doc_list = [] - docs = AdditionalDocument.objects.filter( - assessment_registry=assessment, - document_type=document_type - ) + docs = AdditionalDocument.objects.filter(assessment_registry=assessment, document_type=document_type) for doc in docs: doc = { "id": doc.id, @@ -89,12 +87,12 @@ def get_additional_documents(assessment): FIELDS_KEYS_VALUE_EXTRACTORS = { - 'Country': get_country_name, - 'Donor': get_organization_name, - 'Partner': get_organization_name, - 'Partners': get_organization_name, - 'Lead Organization': get_organization_name, - 'International Partners': get_organization_name, - 'Government': get_organization_name, - 'National Partners': get_organization_name, + "Country": get_country_name, + "Donor": get_organization_name, + "Partner": get_organization_name, + "Partners": get_organization_name, + "Lead Organization": get_organization_name, + "International Partners": get_organization_name, + "Government": get_organization_name, + "National Partners": get_organization_name, } diff --git a/apps/ary/views.py b/apps/ary/views.py index 9b77dd9524..ac5c4ed38c 100644 --- a/apps/ary/views.py +++ b/apps/ary/views.py @@ -1,21 +1,14 @@ import copy -from django.contrib.auth.models import User -from django.http import Http404 -from rest_framework import ( - filters, - mixins, - permissions, - response, - views, - viewsets, -) import django_filters - -from deep.permissions import ModifyPermission, CreateAssessmentPermission +from django.contrib.auth.models import User +from django.http import Http404 +from lead.views import BaseCopyView, LeadCopyView from project.models import Project from project.permissions import PROJECT_PERMISSIONS as PROJ_PERMS -from lead.views import BaseCopyView, LeadCopyView +from rest_framework import filters, mixins, permissions, response, views, viewsets + +from deep.permissions import CreateAssessmentPermission, ModifyPermission from .filters import AssessmentFilterSet, PlannedAssessmentFilterSet from .models import ( @@ -26,22 +19,20 @@ ) from .serializers import ( AssessmentSerializer, - PlannedAssessmentSerializer, AssessmentTemplateSerializer, LeadAssessmentSerializer, LeadGroupAssessmentSerializer, + PlannedAssessmentSerializer, ) class AssessmentViewSet(viewsets.ModelViewSet): serializer_class = AssessmentSerializer - permission_classes = [permissions.IsAuthenticated, CreateAssessmentPermission, - ModifyPermission] - filter_backends = (django_filters.rest_framework.DjangoFilterBackend, - filters.OrderingFilter, filters.SearchFilter) + permission_classes = [permissions.IsAuthenticated, CreateAssessmentPermission, ModifyPermission] + filter_backends = (django_filters.rest_framework.DjangoFilterBackend, filters.OrderingFilter, filters.SearchFilter) filterset_class = AssessmentFilterSet - ordering_fields = ('lead__title', 'created_by', 'created_at') - search_fields = ('lead__title',) + ordering_fields = ("lead__title", "created_by", "created_at") + search_fields = ("lead__title",) def get_queryset(self): return Assessment.get_for(self.request.user) @@ -49,22 +40,19 @@ def get_queryset(self): class PlannedAssessmentViewSet(viewsets.ModelViewSet): serializer_class = PlannedAssessmentSerializer - permission_classes = [permissions.IsAuthenticated, CreateAssessmentPermission, - ModifyPermission] - filter_backends = (django_filters.rest_framework.DjangoFilterBackend, - filters.OrderingFilter, filters.SearchFilter) + permission_classes = [permissions.IsAuthenticated, CreateAssessmentPermission, ModifyPermission] + filter_backends = (django_filters.rest_framework.DjangoFilterBackend, filters.OrderingFilter, filters.SearchFilter) filterset_class = PlannedAssessmentFilterSet - ordering_fields = ('title', 'created_by', 'created_at') - search_fields = ('title',) + ordering_fields = ("title", "created_by", "created_at") + search_fields = ("title",) def get_queryset(self): return PlannedAssessment.get_for(self.request.user) -class LeadAssessmentViewSet(mixins.RetrieveModelMixin, - mixins.UpdateModelMixin, - mixins.DestroyModelMixin, - viewsets.GenericViewSet): +class LeadAssessmentViewSet( + mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, viewsets.GenericViewSet +): """ Assessments accessed using associated lead id. @@ -73,11 +61,11 @@ class LeadAssessmentViewSet(mixins.RetrieveModelMixin, In put requests, if there is no existing assessment, one is automatically created. """ + serializer_class = LeadAssessmentSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] - lookup_field = 'lead' - lookup_url_kwarg = 'pk' + permission_classes = [permissions.IsAuthenticated, ModifyPermission] + lookup_field = "lead" + lookup_url_kwarg = "pk" def get_queryset(self): return Assessment.get_for(self.request.user) @@ -85,7 +73,7 @@ def get_queryset(self): def update(self, request, *args, **kwargs): # For put/patch request, we want to set `lead` data # from url - partial = kwargs.pop('partial', False) + partial = kwargs.pop("partial", False) try: instance = self.get_object() except Http404: @@ -93,23 +81,21 @@ def update(self, request, *args, **kwargs): data = { **request.data, - 'lead': kwargs['pk'], + "lead": kwargs["pk"], } - serializer = self.get_serializer(instance, data=data, - partial=partial) + serializer = self.get_serializer(instance, data=data, partial=partial) serializer.is_valid(raise_exception=True) self.perform_update(serializer) - if getattr(instance, '_prefetched_objects_cache', None): + if getattr(instance, "_prefetched_objects_cache", None): instance._prefetched_objects_cache = {} return response.Response(serializer.data) -class LeadGroupAssessmentViewSet(mixins.RetrieveModelMixin, - mixins.UpdateModelMixin, - mixins.DestroyModelMixin, - viewsets.GenericViewSet): +class LeadGroupAssessmentViewSet( + mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, viewsets.GenericViewSet +): """ Assessments accessed using associated lead group id. @@ -118,11 +104,11 @@ class LeadGroupAssessmentViewSet(mixins.RetrieveModelMixin, In put requests, if there is no existing assessment, one is automatically created. """ + serializer_class = LeadGroupAssessmentSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] - lookup_field = 'lead_group' - lookup_url_kwarg = 'pk' + permission_classes = [permissions.IsAuthenticated, ModifyPermission] + lookup_field = "lead_group" + lookup_url_kwarg = "pk" def get_queryset(self): return Assessment.get_for(self.request.user) @@ -130,7 +116,7 @@ def get_queryset(self): def update(self, request, *args, **kwargs): # For put/patch request, we want to set `lead_group` data # from url - partial = kwargs.pop('partial', False) + partial = kwargs.pop("partial", False) try: instance = self.get_object() except Http404: @@ -138,14 +124,13 @@ def update(self, request, *args, **kwargs): data = { **request.data, - 'lead_group': kwargs['pk'], + "lead_group": kwargs["pk"], } - serializer = self.get_serializer(instance, data=data, - partial=partial) + serializer = self.get_serializer(instance, data=data, partial=partial) serializer.is_valid(raise_exception=True) self.perform_update(serializer) - if getattr(instance, '_prefetched_objects_cache', None): + if getattr(instance, "_prefetched_objects_cache", None): instance._prefetched_objects_cache = {} return response.Response(serializer.data) @@ -155,44 +140,43 @@ class AssessmentOptionsView(views.APIView): permission_classes = [permissions.IsAuthenticated] def get(self, request, version=None): - project_query = request.GET.get('project') - fields_query = request.GET.get('fields') + project_query = request.GET.get("project") + fields_query = request.GET.get("fields") projects = Project.get_for_member(request.user) if project_query: - projects = projects.filter(id__in=project_query.split(',')) + projects = projects.filter(id__in=project_query.split(",")) fields = None if fields_query: - fields = fields_query.split(',') + fields = fields_query.split(",") options = {} - if (fields is None or 'created_by' in fields): + if fields is None or "created_by" in fields: assessment_qs = Assessment.objects.filter(project__in=projects) - options['created_by'] = [ + options["created_by"] = [ { - 'key': user.id, - 'value': user.profile.get_display_name(), + "key": user.id, + "value": user.profile.get_display_name(), } - for user in User.objects.filter( - pk__in=assessment_qs.distinct().values('created_by') - ).select_related('profile') + for user in User.objects.filter(pk__in=assessment_qs.distinct().values("created_by")).select_related("profile") ] - if (fields is None or 'project' in fields): - options['project'] = [ + if fields is None or "project" in fields: + options["project"] = [ { - 'key': project.id, - 'value': project.title, - } for project in projects.distinct() + "key": project.id, + "value": project.title, + } + for project in projects.distinct() ] - if (fields is None or 'methodology_protection_info' in fields): - options['methodology_protection_info'] = [ + if fields is None or "methodology_protection_info" in fields: + options["methodology_protection_info"] = [ { - 'key': value, - 'value': label, + "key": value, + "value": label, } for value, label in MethodologyProtectionInfo.choices ] @@ -202,8 +186,7 @@ def get(self, request, version=None): class AssessmentTemplateViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = AssessmentTemplateSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_queryset(self): return AssessmentTemplate.get_for(self.request.user) @@ -215,23 +198,24 @@ class AssessmentCopyView(BaseCopyView): """ CLONE_PERMISSION = PROJ_PERMS.assessment - CLONE_ROLE = 'role__assessment_permissions' - CLONE_ENTITY_NAME = 'assessment' + CLONE_ROLE = "role__assessment_permissions" + CLONE_ENTITY_NAME = "assessment" CLONE_ENTITY = Assessment def get_clone_context(self, request): - return { - 'lead_create_access_project_ids': set(LeadCopyView.get_project_ids_with_create_access(request)) - } + return {"lead_create_access_project_ids": set(LeadCopyView.get_project_ids_with_create_access(request))} @classmethod def clone_entity(cls, original_ary, project_id, user, context): lead, is_new = LeadCopyView.clone_or_get_lead( - original_ary.lead, project_id, user, context, - context['lead_create_access_project_ids'], + original_ary.lead, + project_id, + user, + context, + context["lead_create_access_project_ids"], ) # Skip assessment creation if lead already has a assessment (or use lead.refresh_from_db()) - if lead is None or (not is_new and getattr(lead, 'assessment', None)): + if lead is None or (not is_new and getattr(lead, "assessment", None)): return ary = copy.deepcopy(original_ary) diff --git a/apps/assessment_registry/admin.py b/apps/assessment_registry/admin.py index f656b301fd..9fdd09e66e 100644 --- a/apps/assessment_registry/admin.py +++ b/apps/assessment_registry/admin.py @@ -1,34 +1,34 @@ -from django.contrib import admin from admin_auto_filters.filters import AutocompleteFilterFactory +from django.contrib import admin from .models import ( + Answer, AssessmentRegistry, AssessmentRegistryOrganization, MethodologyAttribute, Question, - Answer, - ScoreRating, ScoreAnalyticalDensity, + ScoreRating, Summary, - SummarySubPillarIssue, SummaryFocus, - SummarySubDimensionIssue, SummaryIssue, + SummarySubDimensionIssue, + SummarySubPillarIssue, ) @admin.register(Question) class QuestionAdmin(admin.ModelAdmin): - list_display = ('id', 'sector', 'question') + list_display = ("id", "sector", "question") readonly_fields = ( - 'created_by', - 'modified_by', - 'client_id', + "created_by", + "modified_by", + "client_id", ) exclude = ( - 'created_by', - 'modified_by', - 'client_id', + "created_by", + "modified_by", + "client_id", ) def save_model(self, request, obj, form, change): @@ -42,81 +42,81 @@ def save_model(self, request, obj, form, change): class MethodologyAttributeInline(admin.TabularInline): model = MethodologyAttribute extra = 0 - exclude = ('created_by', 'modified_by', 'client_id') + exclude = ("created_by", "modified_by", "client_id") class AnswerInline(admin.TabularInline): model = Answer extra = 0 - exclude = ('created_by', 'modified_by', 'client_id') + exclude = ("created_by", "modified_by", "client_id") class ScoreInline(admin.TabularInline): model = ScoreRating extra = 0 - exclude = ('created_by', 'modified_by', 'client_id') + exclude = ("created_by", "modified_by", "client_id") class AnalyticalDensityInline(admin.TabularInline): model = ScoreAnalyticalDensity extra = 0 - exclude = ('created_by', 'modified_by', 'client_id') + exclude = ("created_by", "modified_by", "client_id") class SummaryInline(admin.TabularInline): model = Summary extra = 0 - exclude = ('created_by', 'modified_by', 'client_id') + exclude = ("created_by", "modified_by", "client_id") class SummarySubPillarIssueInline(admin.TabularInline): model = SummarySubPillarIssue extra = 0 - exclude = ('created_by', 'modified_by', 'client_id') + exclude = ("created_by", "modified_by", "client_id") class SummaryFocusInline(admin.TabularInline): model = SummaryFocus extra = 0 - exclude = ('created_by', 'modified_by', 'client_id') + exclude = ("created_by", "modified_by", "client_id") class SummarySubDimensionIssueInline(admin.TabularInline): model = SummarySubDimensionIssue extra = 0 - exclude = ('created_by', 'modified_by', 'client_id') + exclude = ("created_by", "modified_by", "client_id") class StakeHolderInline(admin.TabularInline): model = AssessmentRegistryOrganization extra = 0 - exclude = ('created_by', 'modified_by') + exclude = ("created_by", "modified_by") # TODO: Readonly mode @admin.register(SummaryIssue) class SummaryIssueAdmin(admin.ModelAdmin): - search_fields = ('sub_dimension',) - autocomplete_fields = ('parent',) + search_fields = ("sub_dimension",) + autocomplete_fields = ("parent",) @admin.register(AssessmentRegistry) class AssessmentRegistryAdmin(admin.ModelAdmin): - list_display = ('id', 'project', 'lead', 'created_at', 'publication_date') - readonly_fields = ('created_at', 'modified_at') + list_display = ("id", "project", "lead", "created_at", "publication_date") + readonly_fields = ("created_at", "modified_at") autocomplete_fields = ( - 'created_by', - 'modified_by', - 'project', - 'bg_countries', - 'locations', - 'lead', - 'project', + "created_by", + "modified_by", + "project", + "bg_countries", + "locations", + "lead", + "project", ) list_filter = ( - AutocompleteFilterFactory('Project', 'project'), - AutocompleteFilterFactory('Created By', 'created_by'), - 'created_at', + AutocompleteFilterFactory("Project", "project"), + AutocompleteFilterFactory("Created By", "created_by"), + "created_at", ) inlines = [ MethodologyAttributeInline, @@ -130,4 +130,4 @@ class AssessmentRegistryAdmin(admin.ModelAdmin): ] def get_queryset(self, request): - return super().get_queryset(request).prefetch_related('project', 'lead') + return super().get_queryset(request).prefetch_related("project", "lead") diff --git a/apps/assessment_registry/apps.py b/apps/assessment_registry/apps.py index 5aa7411010..e462794d88 100644 --- a/apps/assessment_registry/apps.py +++ b/apps/assessment_registry/apps.py @@ -2,4 +2,4 @@ class AssessmentRegistryConfig(AppConfig): - name = 'assessment_registry' + name = "assessment_registry" diff --git a/apps/assessment_registry/dashboard_schema.py b/apps/assessment_registry/dashboard_schema.py index 46dcb49cd1..90126ee0f6 100644 --- a/apps/assessment_registry/dashboard_schema.py +++ b/apps/assessment_registry/dashboard_schema.py @@ -1,30 +1,31 @@ -import graphene from dataclasses import dataclass -from django.db.models import Count, Sum, Avg, Case, Value, When +import graphene +from deep_explore.schema import count_by_date_queryset_generator from django.contrib.postgres.aggregates.general import ArrayAgg +from django.db import connection as django_db_connection from django.db import models +from django.db.models import Avg, Case, Count, Sum, Value, When from django.db.models.functions import TruncDay, TruncMonth -from django.db import connection as django_db_connection from geo.schema import ProjectGeoAreaType - +from organization.schema import OrganizationType as OrganizationObjectType from deep.caches import CacheHelper, CacheKey from utils.graphene.enums import EnumDescription + from .enums import ( AssessmentRegistryAffectedGroupTypeEnum, AssessmentRegistryCoordinationTypeEnum, AssessmentRegistryDataCollectionTechniqueTypeEnum, AssessmentRegistryFocusTypeEnum, AssessmentRegistryProtectionInfoTypeEnum, + AssessmentRegistryProximityTypeEnum, + AssessmentRegistrySamplingApproachTypeEnum, + AssessmentRegistryScoreCriteriaTypeEnum, AssessmentRegistrySectorTypeEnum, AssessmentRegistryUnitOfAnalysisTypeEnum, AssessmentRegistryUnitOfReportingTypeEnum, - AssessmentRegistrySamplingApproachTypeEnum, - AssessmentRegistryProximityTypeEnum, - AssessmentRegistryScoreCriteriaTypeEnum, ) -from deep_explore.schema import count_by_date_queryset_generator from .filter_set import ( AssessmentDashboardFilterDataInputType, AssessmentDashboardFilterSet, @@ -34,7 +35,6 @@ AssessmentRegistryOrganization, MethodologyAttribute, ) -from organization.schema import OrganizationType as OrganizationObjectType # TODO? NODE_CACHE_TIMEOUT = 60 * 60 * 1 @@ -43,6 +43,7 @@ def node_cache(cache_key): def cache_key_gen(root: AssessmentDashboardStat, *_): return root.cache_key + return CacheHelper.gql_cache( cache_key, timeout=NODE_CACHE_TIMEOUT, @@ -362,9 +363,7 @@ class AssessmentDashboardStatisticsType(graphene.ObjectType): assessment_per_affected_group = graphene.List(graphene.NonNull(AssessmentAffectedGroupCountByDateType)) assessment_per_humanitarian_sector = graphene.List(graphene.NonNull(AssessmentHumanitrainSectorCountByDateType)) assessment_per_protection_management = graphene.List(graphene.NonNull(AssessmentProtectionInformationCountByDateType)) - assessment_per_affected_group_and_sector = graphene.List( - graphene.NonNull(AssessmentPerAffectedGroupAndSectorCountByDateType) - ) + assessment_per_affected_group_and_sector = graphene.List(graphene.NonNull(AssessmentPerAffectedGroupAndSectorCountByDateType)) assessment_per_affected_group_and_geoarea = graphene.List( graphene.NonNull(AssessmentPerAffectedGroupAndGeoAreaCountByDateType) ) @@ -384,9 +383,7 @@ class AssessmentDashboardStatisticsType(graphene.ObjectType): assessment_by_sampling_approach_and_geolocation = graphene.List( graphene.NonNull(AssessmentByGeographicalAndSamplingApproachCountByDateType) ) - assessment_by_proximity_and_geolocation = graphene.List( - graphene.NonNull(AssessmentByGeographicalAndProximityCountByDateType) - ) + assessment_by_proximity_and_geolocation = graphene.List(graphene.NonNull(AssessmentByGeographicalAndProximityCountByDateType)) assessment_by_unit_of_analysis_and_geolocation = graphene.List( graphene.NonNull(AssessmentByGeographicalAndUnit_Of_AnalysisCountByDateType) ) @@ -398,15 +395,16 @@ class AssessmentDashboardStatisticsType(graphene.ObjectType): median_quality_score_over_time_by_month = graphene.List(graphene.NonNull(MedianQualityScoreOverTimeDateType)) median_quality_score_of_each_dimension = graphene.List(graphene.NonNull(MedianScoreOfEachDimensionType)) median_quality_score_of_each_dimension_by_date = graphene.List(graphene.NonNull(MedianScoreOfEachDimensionDateType)) - median_quality_score_of_each_dimension_by_date_month = graphene.List( - graphene.NonNull(MedianScoreOfEachDimensionDateType)) + median_quality_score_of_each_dimension_by_date_month = graphene.List(graphene.NonNull(MedianScoreOfEachDimensionDateType)) median_quality_score_of_analytical_density = graphene.List(graphene.NonNull(MedianScoreOfAnalyticalDensityType)) median_quality_score_by_analytical_density_date = graphene.List(graphene.NonNull(MedianScoreOfAnalyticalDensityDateType)) median_quality_score_by_analytical_density_date_month = graphene.List( - graphene.NonNull(MedianScoreOfAnalyticalDensityDateType)) + graphene.NonNull(MedianScoreOfAnalyticalDensityDateType) + ) median_quality_score_by_geoarea_and_sector = graphene.List(graphene.NonNull(MedianScoreOfGeographicalAndSectorDateType)) median_quality_score_by_geoarea_and_sector_by_month = graphene.List( - graphene.NonNull(MedianScoreOfGeographicalAndSectorDateType)) + graphene.NonNull(MedianScoreOfGeographicalAndSectorDateType) + ) median_quality_score_by_geoarea_and_affected_group = graphene.List( graphene.NonNull(MedianScoreOfGeoAreaAndAffectedGroupDateType) ) @@ -415,20 +413,20 @@ class AssessmentDashboardStatisticsType(graphene.ObjectType): @staticmethod def custom_resolver(root, info, _filter): - assessment_qs = ( - AssessmentRegistry.objects.filter( - project=info.context.active_project, - **get_global_filters(_filter), - ) + assessment_qs = AssessmentRegistry.objects.filter( + project=info.context.active_project, + **get_global_filters(_filter), ) assessment_qs_filter = AssessmentDashboardFilterSet(queryset=assessment_qs, data=_filter.get("assessment")).qs methodology_attribute_qs = MethodologyAttribute.objects.select_related("assessment_registry").filter( assessment_registry__in=assessment_qs_filter ) - cache_key = CacheHelper.generate_hash({ - 'project': info.context.active_project.id, - 'filter': _filter.__dict__, - }) + cache_key = CacheHelper.generate_hash( + { + "project": info.context.active_project.id, + "filter": _filter.__dict__, + } + ) return AssessmentDashboardStat( cache_key=cache_key, assessment_registry_qs=assessment_qs_filter, @@ -449,9 +447,12 @@ def resolve_total_stakeholder(root: AssessmentDashboardStat, info) -> int: @staticmethod @node_cache(CacheKey.AssessmentDashboard.TOTAL_COLLECTION_TECHNIQUE_COUNT) def resolve_total_collection_technique(root: AssessmentDashboardStat, info) -> int: - return root.methodology_attribute_qs\ - .filter(data_collection_technique__isnull=False)\ - .values("data_collection_technique").distinct().count() + return ( + root.methodology_attribute_qs.filter(data_collection_technique__isnull=False) + .values("data_collection_technique") + .distinct() + .count() + ) @staticmethod @node_cache(CacheKey.AssessmentDashboard.ASSESSMENT_COUNT) @@ -467,10 +468,10 @@ def resolve_assessment_count(root: AssessmentDashboardStat, info): def resolve_stakeholder_count(root: AssessmentDashboardStat, info): return ( root.assessment_registry_qs.filter(stakeholders__organization_type__title__isnull=False) - .values(stakeholder=models.F('stakeholders__organization_type__title')) - .annotate(count=Count('id')) - .order_by('stakeholder') - .values('count', 'stakeholder') + .values(stakeholder=models.F("stakeholders__organization_type__title")) + .annotate(count=Count("id")) + .order_by("stakeholder") + .values("count", "stakeholder") ) @staticmethod @@ -481,7 +482,7 @@ def resolve_collection_technique_count(root: AssessmentDashboardStat, info): .values("data_collection_technique") .annotate(count=Count("data_collection_technique")) .order_by("data_collection_technique") - .values('data_collection_technique', 'count') + .values("data_collection_technique", "count") ) @staticmethod @@ -498,7 +499,8 @@ def resolve_total_singlesector_assessment(root: AssessmentDashboardStat, info) - @node_cache(CacheKey.AssessmentDashboard.ASSESSMENT_BY_GEOAREA) def resolve_assessment_geographic_areas(root: AssessmentDashboardStat, info): return ( - root.assessment_registry_qs.filter(locations__isnull=False).values("locations") + root.assessment_registry_qs.filter(locations__isnull=False) + .values("locations") .annotate( region=models.F("locations__admin_level__region"), count=Count("locations__id"), @@ -508,13 +510,13 @@ def resolve_assessment_geographic_areas(root: AssessmentDashboardStat, info): code=models.F("locations__code"), ) .values( - 'locations', - 'count', - 'assessment_ids', - 'geo_area', - 'admin_level_id', - 'code', - 'region', + "locations", + "count", + "assessment_ids", + "geo_area", + "admin_level_id", + "code", + "region", ) .order_by("locations") ) @@ -527,53 +529,69 @@ def resolve_assessment_by_over_time(root: AssessmentDashboardStat, info): @staticmethod @node_cache(CacheKey.AssessmentDashboard.ASSESSMENT_PER_FRAMEWORK_PILLAR) def resolve_assessment_per_framework_pillar(root: AssessmentDashboardStat, info): - return root.assessment_registry_qs.annotate( - focus=models.Func(models.F("focuses"), function="unnest"), - ).values('focus').order_by('focus').annotate( - count=Count('id') - ).values('focus', 'count').annotate( - date=TruncDay('created_at') - ).values('focus', 'count', 'date') + return ( + root.assessment_registry_qs.annotate( + focus=models.Func(models.F("focuses"), function="unnest"), + ) + .values("focus") + .order_by("focus") + .annotate(count=Count("id")) + .values("focus", "count") + .annotate(date=TruncDay("created_at")) + .values("focus", "count", "date") + ) @staticmethod @node_cache(CacheKey.AssessmentDashboard.ASSESSMENT_PER_AFFECTED_GROUP) def resolve_assessment_per_affected_group(root: AssessmentDashboardStat, info): - return root.assessment_registry_qs.annotate( - affected_group=models.Func(models.F('affected_groups'), function='unnest'), - ).values('affected_group').order_by('affected_group').annotate( - count=Count('id') - ).values('affected_group', 'count').annotate( - date=TruncDay('created_at') - ).values('affected_group', 'count', 'date') + return ( + root.assessment_registry_qs.annotate( + affected_group=models.Func(models.F("affected_groups"), function="unnest"), + ) + .values("affected_group") + .order_by("affected_group") + .annotate(count=Count("id")) + .values("affected_group", "count") + .annotate(date=TruncDay("created_at")) + .values("affected_group", "count", "date") + ) @staticmethod @node_cache(CacheKey.AssessmentDashboard.ASSESSMENT_PER_HUMANITRATION_SECTOR) def resolve_assessment_per_humanitarian_sector(root: AssessmentDashboardStat, info): - return root.assessment_registry_qs.annotate( - sector=models.Func(models.F('sectors'), function='unnest'), - ).values('sector').order_by('sector').annotate( - count=Count('id') - ).values('sector', 'count').annotate( - date=TruncDay('created_at') - ).values('sector', 'count', 'date') + return ( + root.assessment_registry_qs.annotate( + sector=models.Func(models.F("sectors"), function="unnest"), + ) + .values("sector") + .order_by("sector") + .annotate(count=Count("id")) + .values("sector", "count") + .annotate(date=TruncDay("created_at")) + .values("sector", "count", "date") + ) @staticmethod @node_cache(CacheKey.AssessmentDashboard.ASSESSMENT_PER_PROTECTION_MANAGEMENT) def resolve_assessment_per_protection_management(root: AssessmentDashboardStat, info): - return root.assessment_registry_qs.annotate( - protection_management=models.Func(models.F('protection_info_mgmts'), function='unnest'), - ).values('protection_management').order_by('protection_management').annotate( - count=Count('id') - ).values('protection_management', 'count').annotate( - date=TruncDay('created_at') - ).values('protection_management', 'count', 'date') + return ( + root.assessment_registry_qs.annotate( + protection_management=models.Func(models.F("protection_info_mgmts"), function="unnest"), + ) + .values("protection_management") + .order_by("protection_management") + .annotate(count=Count("id")) + .values("protection_management", "count") + .annotate(date=TruncDay("created_at")) + .values("protection_management", "count", "date") + ) @staticmethod @node_cache(CacheKey.AssessmentDashboard.ASSESSMENT_AFFECTED_GROUP_AND_SECTOR) def resolve_assessment_per_affected_group_and_sector(root: AssessmentDashboardStat, info): # TODO : Global filter and assessment filter need to implement with django_db_connection.cursor() as cursor: - query = f''' + query = f""" SELECT sector, affected_group, @@ -593,7 +611,7 @@ def resolve_assessment_per_affected_group_and_sector(root: AssessmentDashboardSt WHERE project_id = {info.context.active_project.id} GROUP BY sector, affected_group ORDER BY sector, affected_group DESC; - ''' + """ cursor.execute(query, {}) return [ AssessmentPerAffectedGroupAndSectorCountByDateType(sector=data[0], affected_group=data[1], count=data[2]) @@ -634,9 +652,7 @@ def resolve_assessment_per_sector_and_geoarea(root: AssessmentDashboardStat, inf @node_cache(CacheKey.AssessmentDashboard.ASSESSMENT_BY_LEAD_ORGANIZATION) def resolve_assessment_by_lead_organization(root: AssessmentDashboardStat, info): return ( - AssessmentRegistryOrganization.objects.filter( - organization_type=AssessmentRegistryOrganization.Type.LEAD_ORGANIZATION - ) + AssessmentRegistryOrganization.objects.filter(organization_type=AssessmentRegistryOrganization.Type.LEAD_ORGANIZATION) .values(date=TruncDay("assessment_registry__created_at")) .filter(assessment_registry__in=root.assessment_registry_qs) .annotate(count=Count("organization")) @@ -716,7 +732,7 @@ def resolve_assessment_by_data_collection_technique_and_geolocation(root: Assess admin_level_id=models.F("assessment_registry__locations__admin_level_id"), ) .annotate(count=Count("assessment_registry__locations")) - .values('data_collection_technique', 'geo_area', 'region', 'admin_level_id', 'count') + .values("data_collection_technique", "geo_area", "region", "admin_level_id", "count") .order_by("assessment_registry__locations") ) @@ -732,7 +748,7 @@ def resolve_assessment_by_sampling_approach_and_geolocation(root: AssessmentDash admin_level_id=models.F("assessment_registry__locations__admin_level_id"), ) .annotate(count=Count("assessment_registry__locations")) - .values('sampling_approach', 'geo_area', 'region', 'admin_level_id', 'count') + .values("sampling_approach", "geo_area", "region", "admin_level_id", "count") .order_by("assessment_registry__locations") ) @@ -748,7 +764,7 @@ def resolve_assessment_by_proximity_and_geolocation(root: AssessmentDashboardSta admin_level_id=models.F("assessment_registry__locations__admin_level_id"), ) .annotate(count=Count("assessment_registry__locations")) - .values('proximity', 'geo_area', 'region', 'admin_level_id', 'count') + .values("proximity", "geo_area", "region", "admin_level_id", "count") .order_by("assessment_registry__locations") ) @@ -756,14 +772,15 @@ def resolve_assessment_by_proximity_and_geolocation(root: AssessmentDashboardSta @node_cache(CacheKey.AssessmentDashboard.UNIT_OF_ANALYSIS_AND_GEOLOCATION) def resolve_assessment_by_unit_of_analysis_and_geolocation(root: AssessmentDashboardStat, info): return ( - root.methodology_attribute_qs.filter(assessment_registry__locations__isnull=False).values( + root.methodology_attribute_qs.filter(assessment_registry__locations__isnull=False) + .values( "unit_of_analysis", geo_area=models.F("assessment_registry__locations"), region=models.F("assessment_registry__locations__admin_level__region"), admin_level_id=models.F("assessment_registry__locations__admin_level_id"), ) .annotate(count=Count("assessment_registry__locations")) - .values('geo_area', 'region', 'admin_level_id', 'unit_of_analysis', 'count') + .values("geo_area", "region", "admin_level_id", "unit_of_analysis", "count") .order_by("assessment_registry__locations") ) @@ -771,14 +788,15 @@ def resolve_assessment_by_unit_of_analysis_and_geolocation(root: AssessmentDashb @node_cache(CacheKey.AssessmentDashboard.UNIT_REPORTING_AND_GEOLOCATION) def resolve_assessment_by_unit_of_reporting_and_geolocation(root: AssessmentDashboardStat, info): return ( - root.methodology_attribute_qs.filter(assessment_registry__locations__isnull=False).values( + root.methodology_attribute_qs.filter(assessment_registry__locations__isnull=False) + .values( "unit_of_reporting", geo_area=models.F("assessment_registry__locations"), region=models.F("assessment_registry__locations__admin_level__region"), admin_level_id=models.F("assessment_registry__locations__admin_level_id"), ) .annotate(count=Count("assessment_registry__locations")) - .values('unit_of_reporting', 'count', 'geo_area', 'region', 'admin_level_id') + .values("unit_of_reporting", "count", "geo_area", "region", "admin_level_id") .order_by("assessment_registry__locations") ) @@ -808,11 +826,15 @@ def resolve_median_quality_score_by_geo_area(root: AssessmentDashboardStat, info final_score=( Avg( ( - models.F("analytical_density__figure_provided__len") * - models.F("analytical_density__analysis_level_covered__len") - ) / models.Value(10) - ) + (models.F("score_rating_matrix")) - ) / Count("id") * 5 + models.F("analytical_density__figure_provided__len") + * models.F("analytical_density__analysis_level_covered__len") + ) + / models.Value(10) + ) + + (models.F("score_rating_matrix")) + ) + / Count("id") + * 5 ) .order_by() .values( @@ -849,11 +871,15 @@ def resolve_median_quality_score_over_time(root: AssessmentDashboardStat, info): final_score=( Avg( ( - models.F("analytical_density__figure_provided__len") * - models.F("analytical_density__analysis_level_covered__len") - ) / models.Value(10) - ) + Sum(models.F("score_rating_matrix")) - ) / Count("id") * 5 + models.F("analytical_density__figure_provided__len") + * models.F("analytical_density__analysis_level_covered__len") + ) + / models.Value(10) + ) + + Sum(models.F("score_rating_matrix")) + ) + / Count("id") + * 5 ) .values("final_score", "date") ).exclude(final_score__isnull=True) @@ -883,11 +909,15 @@ def resolve_median_quality_score_over_time_by_month(root: AssessmentDashboardSta final_score=( Avg( ( - models.F("analytical_density__figure_provided__len") * - models.F("analytical_density__analysis_level_covered__len") - ) / models.Value(10) - ) + Sum(models.F("score_rating_matrix")) - ) / Count("id") * 5 + models.F("analytical_density__figure_provided__len") + * models.F("analytical_density__analysis_level_covered__len") + ) + / models.Value(10) + ) + + Sum(models.F("score_rating_matrix")) + ) + / Count("id") + * 5 ) .values("final_score", "date") ).exclude(final_score__isnull=True) @@ -971,10 +1001,11 @@ def resolve_median_quality_score_of_analytical_density(root: AssessmentDashboard .annotate( final_score=( Avg( - models.F("analytical_density__figure_provided__len") * - models.F("analytical_density__analysis_level_covered__len") + models.F("analytical_density__figure_provided__len") + * models.F("analytical_density__analysis_level_covered__len") ) - ) / models.Value(10) + ) + / models.Value(10) ) .order_by() .values("final_score", sector=models.F("analytical_density__sector")) @@ -989,10 +1020,11 @@ def resolve_median_quality_score_by_analytical_density_date(root: AssessmentDash .annotate( final_score=( Avg( - models.F("analytical_density__figure_provided__len") * - models.F("analytical_density__analysis_level_covered__len") + models.F("analytical_density__figure_provided__len") + * models.F("analytical_density__analysis_level_covered__len") ) - ) / models.Value(10) + ) + / models.Value(10) ) .values("final_score", "date", sector=models.F("analytical_density__sector")) .exclude(analytical_density__sector__isnull=True) @@ -1006,10 +1038,11 @@ def resolve_median_quality_score_by_analytical_density_date_month(root: Assessme .annotate( final_score=( Avg( - models.F("analytical_density__figure_provided__len") * - models.F("analytical_density__analysis_level_covered__len") + models.F("analytical_density__figure_provided__len") + * models.F("analytical_density__analysis_level_covered__len") ) - ) / models.Value(10) + ) + / models.Value(10) ) .values("final_score", "date", sector=models.F("analytical_density__sector")) .exclude(analytical_density__sector__isnull=True) @@ -1027,10 +1060,11 @@ def resolve_median_quality_score_by_geoarea_and_sector(root: AssessmentDashboard .annotate( final_score=( Avg( - models.F("analytical_density__figure_provided__len") * - models.F("analytical_density__analysis_level_covered__len") + models.F("analytical_density__figure_provided__len") + * models.F("analytical_density__analysis_level_covered__len") ) - ) / models.Value(10) + ) + / models.Value(10) ) .annotate(geo_area=models.F("locations"), sector=models.F("analytical_density__sector")) .values("geo_area", "final_score", "sector", "date") @@ -1048,10 +1082,11 @@ def resolve_median_quality_score_by_geoarea_and_sector_by_month(root: Assessment .annotate( final_score=( Avg( - models.F("analytical_density__figure_provided__len") * - models.F("analytical_density__analysis_level_covered__len") + models.F("analytical_density__figure_provided__len") + * models.F("analytical_density__analysis_level_covered__len") ) - ) / models.Value(10) + ) + / models.Value(10) ) .annotate(geo_area=models.F("locations"), sector=models.F("analytical_density__sector")) .values("geo_area", "final_score", "sector", "date") @@ -1080,11 +1115,15 @@ def resolve_median_quality_score_by_geoarea_and_affected_group(root: AssessmentD final_score=( Avg( ( - models.F("analytical_density__figure_provided__len") * - models.F("analytical_density__analysis_level_covered__len") - ) / models.Value(10) - ) + Sum(models.F("score_rating_matrix")) - ) / Count("id") * 5, + models.F("analytical_density__figure_provided__len") + * models.F("analytical_density__analysis_level_covered__len") + ) + / models.Value(10) + ) + + Sum(models.F("score_rating_matrix")) + ) + / Count("id") + * 5, ) .annotate(affected_group=models.Func(models.F("affected_groups"), function="unnest")) .values("final_score", "date", "affected_group", geo_area=models.F("locations")) @@ -1115,11 +1154,15 @@ def resolve_median_quality_score_by_geoarea_and_affected_group_by_month(root: As final_score=( Avg( ( - models.F("analytical_density__figure_provided__len") * - models.F("analytical_density__analysis_level_covered__len") - ) / models.Value(10) - ) + Sum(models.F("score_rating_matrix")) - ) / Count("id") * 5, + models.F("analytical_density__figure_provided__len") + * models.F("analytical_density__analysis_level_covered__len") + ) + / models.Value(10) + ) + + Sum(models.F("score_rating_matrix")) + ) + / Count("id") + * 5, ) .annotate(affected_group=models.Func(models.F("affected_groups"), function="unnest")) .values("final_score", "date", "affected_group", geo_area=models.F("locations")) @@ -1137,10 +1180,11 @@ def resolve_median_score_by_sector_and_affected_group(root: AssessmentDashboardS .annotate( final_score=( Avg( - models.F("analytical_density__figure_provided__len") * - models.F("analytical_density__analysis_level_covered__len") + models.F("analytical_density__figure_provided__len") + * models.F("analytical_density__analysis_level_covered__len") ) - ) / models.Value(10) + ) + / models.Value(10) ) .annotate( affected_group=models.Func(models.F("affected_groups"), function="unnest"), @@ -1159,10 +1203,11 @@ def resolve_median_score_by_sector_and_affected_group_by_month(root: AssessmentD .annotate( final_score=( Avg( - models.F("analytical_density__figure_provided__len") * - models.F("analytical_density__analysis_level_covered__len") + models.F("analytical_density__figure_provided__len") + * models.F("analytical_density__analysis_level_covered__len") ) - ) / models.Value(10) + ) + / models.Value(10) ) .annotate( affected_group=models.Func(models.F("affected_groups"), function="unnest"), diff --git a/apps/assessment_registry/dataloaders.py b/apps/assessment_registry/dataloaders.py index 08fbfdfed1..2fa47bb5c7 100644 --- a/apps/assessment_registry/dataloaders.py +++ b/apps/assessment_registry/dataloaders.py @@ -1,16 +1,14 @@ -from django.db.models import Count from collections import defaultdict -from promise import Promise -from django.utils.functional import cached_property from django.db import connection as django_db_connection +from django.db.models import Count +from django.utils.functional import cached_property from geo.schema import get_geo_area_queryset_for_project_geo_area_type +from promise import Promise + from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin -from .models import ( - AssessmentRegistryOrganization, - SummaryIssue, -) +from .models import AssessmentRegistryOrganization, SummaryIssue class AssessmentRegistryOrganizationsLoader(DataLoaderWithContext): @@ -25,32 +23,20 @@ def batch_load_fn(self, keys): class AssessmentRegistryIssueLoader(DataLoaderWithContext): def batch_load_fn(self, keys): qs = SummaryIssue.objects.filter(id__in=keys) - _map = { - issue.pk: issue - for issue in qs - } + _map = {issue.pk: issue for issue in qs} return Promise.resolve([_map.get(key) for key in keys]) class AssessmentRegistryIssueChildLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - qs = SummaryIssue.objects.filter( - parent__in=keys - ).values( - 'parent' - ).annotate( - child_count=Count( - 'id' - ) - ).values( - 'parent', - 'child_count' + qs = ( + SummaryIssue.objects.filter(parent__in=keys) + .values("parent") + .annotate(child_count=Count("id")) + .values("parent", "child_count") ) - counts_map = { - obj['parent']: obj['child_count'] - for obj in qs - } + counts_map = {obj["parent"]: obj["child_count"] for obj in qs} return Promise.resolve([counts_map.get(key, 0) for key in keys]) @@ -58,7 +44,7 @@ def batch_load_fn(self, keys): class SummaryIssueLevelLoader(DataLoaderWithContext): def batch_load_fn(self, keys): with django_db_connection.cursor() as cursor: - select_sql = f''' + select_sql = f""" WITH RECURSIVE parents AS ( SELECT sub_g.id, @@ -80,12 +66,9 @@ def batch_load_fn(self, keys): count(*) FROM parents GROUP BY main_entity_id - ''' + """ cursor.execute(select_sql, (tuple(keys),)) - _map = { - _id: level - for _id, level in cursor.fetchall() - } + _map = {_id: level for _id, level in cursor.fetchall()} return Promise.resolve([_map.get(key, 0) for key in keys]) diff --git a/apps/assessment_registry/enums.py b/apps/assessment_registry/enums.py index 183e6401a2..d4f76905a1 100644 --- a/apps/assessment_registry/enums.py +++ b/apps/assessment_registry/enums.py @@ -4,118 +4,113 @@ ) from .models import ( + AdditionalDocument, AssessmentRegistry, + AssessmentRegistryOrganization, MethodologyAttribute, - AdditionalDocument, + Question, + ScoreAnalyticalDensity, + ScoreRating, Summary, - SummaryIssue, SummaryFocus, - ScoreRating, - ScoreAnalyticalDensity, - Question, + SummaryIssue, SummarySubDimensionIssue, - AssessmentRegistryOrganization, ) - AssessmentRegistryCrisisTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.CrisisType, name='AssessmentRegistryCrisisTypeEnum' + AssessmentRegistry.CrisisType, name="AssessmentRegistryCrisisTypeEnum" ) AssessmentRegistryPreparednessTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.PreparednessType, name='AssessmentRegistryPreparednessTypeEnum' + AssessmentRegistry.PreparednessType, name="AssessmentRegistryPreparednessTypeEnum" ) AssessmentRegistryExternalSupportTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.ExternalSupportType, name='AssessmentRegistryExternalTypeEnum' + AssessmentRegistry.ExternalSupportType, name="AssessmentRegistryExternalTypeEnum" ) AssessmentRegistryCoordinationTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.CoordinationType, name='AssessmentRegistryCoordinationTypeEnum' -) -AssessmentRegistryDetailTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.Type, name='AssessmentRegistryDetailTypeEnum' + AssessmentRegistry.CoordinationType, name="AssessmentRegistryCoordinationTypeEnum" ) +AssessmentRegistryDetailTypeEnum = convert_enum_to_graphene_enum(AssessmentRegistry.Type, name="AssessmentRegistryDetailTypeEnum") AssessmentRegistryFamilyTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.FamilyType, name='AssessmentRegistryFamilyTypeEnum' + AssessmentRegistry.FamilyType, name="AssessmentRegistryFamilyTypeEnum" ) AssessmentRegistryFrequencyTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.FrequencyType, name='AssessmentRegistryFrequencyTypeEnum' + AssessmentRegistry.FrequencyType, name="AssessmentRegistryFrequencyTypeEnum" ) AssessmentRegistryConfidentialityTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.ConfidentialityType, name='AssessmentRegistryConfidentialityTypeEnum' + AssessmentRegistry.ConfidentialityType, name="AssessmentRegistryConfidentialityTypeEnum" ) AssessmentRegistryLanguageTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.Language, name='AssessmentRegistryLanguageTypeEnum' + AssessmentRegistry.Language, name="AssessmentRegistryLanguageTypeEnum" ) AssessmentRegistryFocusTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.FocusType, name='AssessmentRegistryFocusTypeEnum' + AssessmentRegistry.FocusType, name="AssessmentRegistryFocusTypeEnum" ) AssessmentRegistrySectorTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.SectorType, name='AssessmentRegistrySectorTypeEnum' + AssessmentRegistry.SectorType, name="AssessmentRegistrySectorTypeEnum" ) AssessmentRegistryProtectionInfoTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.ProtectionInfoType, name='AssessmentRegistryProtectionInfoTypeEnum' + AssessmentRegistry.ProtectionInfoType, name="AssessmentRegistryProtectionInfoTypeEnum" ) AssessmentRegistryProtectionRiskTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.ProtectionRiskType, name='AssessmentRegistryProtectionRiskTypeEnum' + AssessmentRegistry.ProtectionRiskType, name="AssessmentRegistryProtectionRiskTypeEnum" ) AssessmentRegistryStatusTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.StatusType, name='AssessmentRegistryStatusTypeEnum' + AssessmentRegistry.StatusType, name="AssessmentRegistryStatusTypeEnum" ) AssessmentRegistryAffectedGroupTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistry.AffectedGroupType, name='AssessmentRegistryAffectedGroupTypeEnum' + AssessmentRegistry.AffectedGroupType, name="AssessmentRegistryAffectedGroupTypeEnum" ) AssessmentRegistryDataCollectionTechniqueTypeEnum = convert_enum_to_graphene_enum( - MethodologyAttribute.CollectionTechniqueType, name='AssessmentRegistryDataCollectionTechniqueTypeEnum' + MethodologyAttribute.CollectionTechniqueType, name="AssessmentRegistryDataCollectionTechniqueTypeEnum" ) AssessmentRegistrySamplingApproachTypeEnum = convert_enum_to_graphene_enum( - MethodologyAttribute.SamplingApproachType, name='AssessmentRegistrySamplingApproachTypeEnum' + MethodologyAttribute.SamplingApproachType, name="AssessmentRegistrySamplingApproachTypeEnum" ) AssessmentRegistryProximityTypeEnum = convert_enum_to_graphene_enum( - MethodologyAttribute.ProximityType, name='AssessmentRegistryProximityTypeEnum' + MethodologyAttribute.ProximityType, name="AssessmentRegistryProximityTypeEnum" ) AssessmentRegistryUnitOfAnalysisTypeEnum = convert_enum_to_graphene_enum( - MethodologyAttribute.UnitOfAnalysisType, name='AssessmentRegistryUnitOfAnalysisTypeEnum' + MethodologyAttribute.UnitOfAnalysisType, name="AssessmentRegistryUnitOfAnalysisTypeEnum" ) AssessmentRegistryUnitOfReportingTypeEnum = convert_enum_to_graphene_enum( - MethodologyAttribute.UnitOfReportingType, name='AssessmentRegistryUnitOfReportingTypeEnum' + MethodologyAttribute.UnitOfReportingType, name="AssessmentRegistryUnitOfReportingTypeEnum" ) AssessmentRegistryDocumentTypeEnum = convert_enum_to_graphene_enum( - AdditionalDocument.DocumentType, name='AssessmentRegistryDocumentTypeEnum' + AdditionalDocument.DocumentType, name="AssessmentRegistryDocumentTypeEnum" ) AssessmentRegistryScoreCriteriaTypeEnum = convert_enum_to_graphene_enum( - ScoreRating.ScoreCriteria, name='AssessmentRegistryScoreCriteriaTypeEnum' + ScoreRating.ScoreCriteria, name="AssessmentRegistryScoreCriteriaTypeEnum" ) AssessmentRegistryScoreAnalyticalStatementTypeEnum = convert_enum_to_graphene_enum( - ScoreRating.AnalyticalStatement, name='AssessmentRegistryScoreAnalyticalStatementTypeEnum' + ScoreRating.AnalyticalStatement, name="AssessmentRegistryScoreAnalyticalStatementTypeEnum" ) AssessmentRegistryAnalysisLevelTypeEnum = convert_enum_to_graphene_enum( - ScoreAnalyticalDensity.AnalysisLevelCovered, name='AssessmentRegistryAnalysisLevelTypeEnum' + ScoreAnalyticalDensity.AnalysisLevelCovered, name="AssessmentRegistryAnalysisLevelTypeEnum" ) AssessmentRegistryAnalysisFigureTypeEnum = convert_enum_to_graphene_enum( - ScoreAnalyticalDensity.FigureProvidedByAssessment, name='AssessmentRegistryAnalysisFigureTypeEnum' -) -AssessmentRegistryRatingTypeEnum = convert_enum_to_graphene_enum( - ScoreRating.RatingType, name='AssessmentRegistryRatingType' + ScoreAnalyticalDensity.FigureProvidedByAssessment, name="AssessmentRegistryAnalysisFigureTypeEnum" ) +AssessmentRegistryRatingTypeEnum = convert_enum_to_graphene_enum(ScoreRating.RatingType, name="AssessmentRegistryRatingType") AssessmentRegistryCNAQuestionSectorTypeEnum = convert_enum_to_graphene_enum( - Question.QuestionSector, name='AssessmentRegistryCNAQuestionSectorTypeEnum' + Question.QuestionSector, name="AssessmentRegistryCNAQuestionSectorTypeEnum" ) AssessmentRegistryCNAQuestionSubSectorTypeEnum = convert_enum_to_graphene_enum( - Question.QuestionSubSector, name='AssessmentRegistryCNAQuestionSubSectorTypeEnum' + Question.QuestionSubSector, name="AssessmentRegistryCNAQuestionSubSectorTypeEnum" ) AssessmentRegistrySummaryPillarTypeEnum = convert_enum_to_graphene_enum( - Summary.Pillar, name='AssessmentRegistrySummaryPillarTypeEnum' + Summary.Pillar, name="AssessmentRegistrySummaryPillarTypeEnum" ) AssessmentRegistrySummaryFocusDimensionTypeEnum = convert_enum_to_graphene_enum( - SummaryFocus.Dimension, name='AssessmentRegistrySummaryFocusDimensionTypeEnum' + SummaryFocus.Dimension, name="AssessmentRegistrySummaryFocusDimensionTypeEnum" ) AssessmentRegistrySummarySubDimensionTypeEnum = convert_enum_to_graphene_enum( - SummaryIssue.SubDimension, name='AssessmentRegistrySummarySubDimensionTypeEnum' + SummaryIssue.SubDimension, name="AssessmentRegistrySummarySubDimensionTypeEnum" ) AssessmentRegistrySummarySubPillarTypeEnum = convert_enum_to_graphene_enum( - SummaryIssue.SubPillar, name='AssessmentRegistrySummarySubPillarTypeEnum' + SummaryIssue.SubPillar, name="AssessmentRegistrySummarySubPillarTypeEnum" ) AssessmentRegistryOrganizationTypeEnum = convert_enum_to_graphene_enum( - AssessmentRegistryOrganization.Type, name='AssessmentRegistryOrganizationTypeEnum' + AssessmentRegistryOrganization.Type, name="AssessmentRegistryOrganizationTypeEnum" ) enum_map = { get_enum_name_from_django_field(field): enum diff --git a/apps/assessment_registry/factories.py b/apps/assessment_registry/factories.py index d0d00c6107..0c5aaa4e9c 100644 --- a/apps/assessment_registry/factories.py +++ b/apps/assessment_registry/factories.py @@ -1,25 +1,25 @@ -import typing -import random import datetime -import factory -from factory import fuzzy -from factory.django import DjangoModelFactory -from django.db import models +import random +import typing +import factory from assessment_registry.models import ( - Question, + AdditionalDocument, Answer, AssessmentRegistry, MethodologyAttribute, - AdditionalDocument, - ScoreRating, + Question, ScoreAnalyticalDensity, - SummaryIssue, + ScoreRating, Summary, SummaryFocus, - SummarySubPillarIssue, + SummaryIssue, SummarySubDimensionIssue, + SummarySubPillarIssue, ) +from django.db import models +from factory import fuzzy +from factory.django import DjangoModelFactory DEFAULT_START_DATE = datetime.date(year=2017, month=1, day=1) @@ -28,9 +28,7 @@ def _choices(enum: typing.Type[models.IntegerChoices]): """ Get key from Django Choices """ - return [ - key for key, _ in enum.choices - ] + return [key for key, _ in enum.choices] class FuzzyChoiceList(fuzzy.FuzzyChoice): @@ -43,10 +41,7 @@ def fuzz(self): self.choices = list(self.choices_generator) if self.choices_len is None: self.choices_len = len(self.choices) - value = random.sample( - self.choices, - random.randint(0, self.choices_len) - ) + value = random.sample(self.choices, random.randint(0, self.choices_len)) if self.getter is None: return value return self.getter(value) @@ -71,7 +66,7 @@ class Meta: class SummarySubPillarIssueFactory(DjangoModelFactory): - text = factory.Faker('text') + text = factory.Faker("text") order = factory.Sequence(lambda n: n) class Meta: @@ -97,7 +92,7 @@ class Meta: class SummarySubDimensionIssueFactory(DjangoModelFactory): sector = fuzzy.FuzzyChoice(_choices(AssessmentRegistry.SectorType)) - text = factory.Faker('text') + text = factory.Faker("text") order = factory.Sequence(lambda n: n) class Meta: @@ -110,7 +105,7 @@ class Meta: class AnswerFactory(DjangoModelFactory): - answer = factory.Faker('boolean') + answer = factory.Faker("boolean") class Meta: model = Answer @@ -139,7 +134,7 @@ class Meta: class AdditionalDocumentFactory(DjangoModelFactory): document_type = fuzzy.FuzzyChoice(_choices(AdditionalDocument.DocumentType)) - external_link = 'https://example.com/invalid-file-link' + external_link = "https://example.com/invalid-file-link" class Meta: model = AdditionalDocument @@ -148,7 +143,7 @@ class Meta: class ScoreRatingFactory(DjangoModelFactory): score_type = fuzzy.FuzzyChoice(_choices(ScoreRating.ScoreCriteria)) rating = fuzzy.FuzzyChoice(_choices(ScoreRating.RatingType)) - reason = factory.Faker('text') + reason = factory.Faker("text") class Meta: model = ScoreRating @@ -179,13 +174,13 @@ class AssessmentRegistryFactory(DjangoModelFactory): publication_date = fuzzy.FuzzyDate(DEFAULT_START_DATE) # Additional Documents - executive_summary = factory.Faker('text') + executive_summary = factory.Faker("text") # Methodology - objectives = factory.Faker('text') - data_collection_techniques = factory.Faker('text') - sampling = factory.Faker('text') - limitations = factory.Faker('text') + objectives = factory.Faker("text") + data_collection_techniques = factory.Faker("text") + sampling = factory.Faker("text") + limitations = factory.Faker("text") # Focus # -- Focus Sectors @@ -194,13 +189,13 @@ class AssessmentRegistryFactory(DjangoModelFactory): protection_info_mgmts = FuzzyChoiceList(_choices(AssessmentRegistry.ProtectionInfoType)) affected_groups = FuzzyChoiceList(_choices(AssessmentRegistry.AffectedGroupType)) - metadata_complete = factory.Faker('boolean') - additional_document_complete = factory.Faker('boolean') - focus_complete = factory.Faker('boolean') - methodology_complete = factory.Faker('boolean') - summary_complete = factory.Faker('boolean') - cna_complete = factory.Faker('boolean') - score_complete = factory.Faker('boolean') + metadata_complete = factory.Faker("boolean") + additional_document_complete = factory.Faker("boolean") + focus_complete = factory.Faker("boolean") + methodology_complete = factory.Faker("boolean") + summary_complete = factory.Faker("boolean") + cna_complete = factory.Faker("boolean") + score_complete = factory.Faker("boolean") class Meta: model = AssessmentRegistry @@ -212,6 +207,4 @@ def bg_countries(self, create, extracted, **_): if extracted: for country in extracted: - self.bg_countries.add( # pyright: ignore [reportGeneralTypeIssues] - country - ) + self.bg_countries.add(country) # pyright: ignore [reportGeneralTypeIssues] diff --git a/apps/assessment_registry/filter_set.py b/apps/assessment_registry/filter_set.py index ae86c1a582..9d26a38973 100644 --- a/apps/assessment_registry/filter_set.py +++ b/apps/assessment_registry/filter_set.py @@ -1,29 +1,31 @@ -from deep.filter_set import generate_type_for_filter_set, OrderEnumMixin from user_resource.filters import UserResourceGqlFilterSet -from .models import AssessmentRegistry + +from deep.filter_set import OrderEnumMixin, generate_type_for_filter_set from utils.graphene.filters import IDListFilter, MultipleInputFilter + from .enums import ( AssessmentRegistryAffectedGroupTypeEnum, AssessmentRegistryCoordinationTypeEnum, AssessmentRegistryDetailTypeEnum, AssessmentRegistryFamilyTypeEnum, - AssessmentRegistryFrequencyTypeEnum, AssessmentRegistryFocusTypeEnum, + AssessmentRegistryFrequencyTypeEnum, AssessmentRegistrySectorTypeEnum, ) +from .models import AssessmentRegistry class AssessmentDashboardFilterSet(OrderEnumMixin, UserResourceGqlFilterSet): - stakeholder = IDListFilter(field_name='stakeholders') - lead_organization = IDListFilter(field_name='stakeholders') - location = IDListFilter(field_name='locations') - affected_group = MultipleInputFilter(AssessmentRegistryAffectedGroupTypeEnum, method='filter_affected_group') + stakeholder = IDListFilter(field_name="stakeholders") + lead_organization = IDListFilter(field_name="stakeholders") + location = IDListFilter(field_name="locations") + affected_group = MultipleInputFilter(AssessmentRegistryAffectedGroupTypeEnum, method="filter_affected_group") family = MultipleInputFilter(AssessmentRegistryFamilyTypeEnum) frequency = MultipleInputFilter(AssessmentRegistryFrequencyTypeEnum) - coordination_type = MultipleInputFilter(AssessmentRegistryCoordinationTypeEnum, field_name='coordinated_joint') - assessment_type = MultipleInputFilter(AssessmentRegistryDetailTypeEnum, field_name='details_type') - focuses = MultipleInputFilter(AssessmentRegistryFocusTypeEnum, method='filter_focuses') - sectors = MultipleInputFilter(AssessmentRegistrySectorTypeEnum, method='filter_sectors') + coordination_type = MultipleInputFilter(AssessmentRegistryCoordinationTypeEnum, field_name="coordinated_joint") + assessment_type = MultipleInputFilter(AssessmentRegistryDetailTypeEnum, field_name="details_type") + focuses = MultipleInputFilter(AssessmentRegistryFocusTypeEnum, method="filter_focuses") + sectors = MultipleInputFilter(AssessmentRegistrySectorTypeEnum, method="filter_sectors") class Meta: model = AssessmentRegistry diff --git a/apps/assessment_registry/filters.py b/apps/assessment_registry/filters.py index f6ad3772a1..1720cf3d49 100644 --- a/apps/assessment_registry/filters.py +++ b/apps/assessment_registry/filters.py @@ -1,16 +1,18 @@ import django_filters from django.db import models from django.db.models import Q - +from lead.models import Lead +from project.models import Project +from user.models import User from user_resource.filters import UserResourceGqlFilterSet -from utils.graphene.filters import SimpleInputFilter -from user.models import User -from project.models import Project -from lead.models import Lead +from utils.graphene.filters import SimpleInputFilter +from .enums import ( + AssessmentRegistrySummarySubDimensionTypeEnum, + AssessmentRegistrySummarySubPillarTypeEnum, +) from .models import AssessmentRegistry, SummaryIssue -from .enums import AssessmentRegistrySummarySubPillarTypeEnum, AssessmentRegistrySummarySubDimensionTypeEnum class AssessmentRegistryGQFilterSet(UserResourceGqlFilterSet): @@ -18,7 +20,7 @@ class AssessmentRegistryGQFilterSet(UserResourceGqlFilterSet): date_to = django_filters.DateFilter(required=False) project = django_filters.ModelMultipleChoiceFilter( queryset=Project.objects.all(), - field_name='lead__project', + field_name="lead__project", ) lead = django_filters.ModelMultipleChoiceFilter( queryset=Lead.objects.all(), @@ -27,16 +29,12 @@ class AssessmentRegistryGQFilterSet(UserResourceGqlFilterSet): queryset=User.objects.all(), ) publication_date_lte = django_filters.DateFilter( - field_name='publication_date', - lookup_expr='lte', - input_formats=['%Y-%m-%d%z'] + field_name="publication_date", lookup_expr="lte", input_formats=["%Y-%m-%d%z"] ) publication_date_gte = django_filters.DateFilter( - field_name='publication_date', - lookup_expr='gte', - input_formats=['%Y-%m-%d%z'] + field_name="publication_date", lookup_expr="gte", input_formats=["%Y-%m-%d%z"] ) - search = django_filters.CharFilter(method='filter_assessment_registry') + search = django_filters.CharFilter(method="filter_assessment_registry") class Meta: model = AssessmentRegistry @@ -44,9 +42,9 @@ class Meta: filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda f: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda f: { + "lookup_expr": "icontains", }, }, } @@ -54,27 +52,26 @@ class Meta: def filter_assessment_registry(self, qs, name, value): if not value: return qs - return qs.filter( - Q(lead__title__icontains=value) - ).distinct() + return qs.filter(Q(lead__title__icontains=value)).distinct() class AssessmentRegistryIssueGQFilterSet(django_filters.FilterSet): sub_pillar = SimpleInputFilter(AssessmentRegistrySummarySubPillarTypeEnum) sub_dimension = SimpleInputFilter(AssessmentRegistrySummarySubDimensionTypeEnum) - search = django_filters.CharFilter(method='filter_assessment_registry_issues') - is_parent = django_filters.BooleanFilter(method='filter_is_parent') + search = django_filters.CharFilter(method="filter_assessment_registry_issues") + is_parent = django_filters.BooleanFilter(method="filter_is_parent") class Meta: model = SummaryIssue - fields = ('label', 'parent',) + fields = ( + "label", + "parent", + ) def filter_assessment_registry_issues(self, qs, name, value): if not value: return qs - return qs.filter( - label__icontains=value - ) + return qs.filter(label__icontains=value) def filter_is_parent(self, qs, name, value): if value is None: diff --git a/apps/assessment_registry/management/commands/generate_dummy_assessments.py b/apps/assessment_registry/management/commands/generate_dummy_assessments.py index da5710df2b..671a626c6e 100644 --- a/apps/assessment_registry/management/commands/generate_dummy_assessments.py +++ b/apps/assessment_registry/management/commands/generate_dummy_assessments.py @@ -1,66 +1,64 @@ -import random import datetime +import random import typing -from factory import fuzzy -from django.core.management.base import BaseCommand -from django.db import transaction, models -from django.conf import settings - -from user.models import User from ary.models import AssessmentTemplate -from lead.models import Lead -from project.models import Project, ProjectRole -from geo.models import Region, GeoArea -from organization.models import Organization -from assessment_registry.models import ( - AssessmentRegistryOrganization, - Question, - SummaryIssue, -) -from project.factories import ProjectFactory -from lead.factories import LeadFactory, LeadPreviewFactory from assessment_registry.factories import ( + AdditionalDocumentFactory, + AnswerFactory, AssessmentRegistryFactory, MethodologyAttributeFactory, - AdditionalDocumentFactory, - ScoreRatingFactory, ScoreAnalyticalDensityFactory, - AnswerFactory, - SummaryMetaFactory, - SummarySubPillarIssueFactory, + ScoreRatingFactory, SummaryFocusFactory, + SummaryMetaFactory, SummarySubDimensionIssueFactory, + SummarySubPillarIssueFactory, ) +from assessment_registry.models import ( + AssessmentRegistryOrganization, + Question, + SummaryIssue, +) +from django.conf import settings +from django.core.management.base import BaseCommand +from django.db import models, transaction +from factory import fuzzy +from geo.models import GeoArea, Region +from lead.factories import LeadFactory, LeadPreviewFactory +from lead.models import Lead +from organization.models import Organization +from project.factories import ProjectFactory +from project.models import Project, ProjectRole +from user.models import User - -DUMMY_PROJECT_PREFIX = 'Dummy Project (Assessment)' +DUMMY_PROJECT_PREFIX = "Dummy Project (Assessment)" DEFAULT_START_DATETIME = datetime.datetime(year=2017, month=1, day=1, tzinfo=datetime.timezone.utc) created_at_fuzzy = fuzzy.FuzzyDateTime(DEFAULT_START_DATETIME) class Command(BaseCommand): def add_arguments(self, parser): - parser.add_argument('--delete-existing', dest='delete_existing', action='store_true') - parser.add_argument('--regions-from-project', dest='project_for_regions') - parser.add_argument('--user-email', dest='user_email', required=True) - parser.add_argument('--leads-counts', dest='leads_count', type=int, default=50) + parser.add_argument("--delete-existing", dest="delete_existing", action="store_true") + parser.add_argument("--regions-from-project", dest="project_for_regions") + parser.add_argument("--user-email", dest="user_email", required=True) + parser.add_argument("--leads-counts", dest="leads_count", type=int, default=50) def handle(self, **kwargs): if not settings.ALLOW_DUMMY_DATA_GENERATION: self.stderr.write( - 'Dummy data generation is not allowed for this instance.' - ' Use environment variable ALLOW_DUMMY_DATA_GENERATION to enable this' + "Dummy data generation is not allowed for this instance." + " Use environment variable ALLOW_DUMMY_DATA_GENERATION to enable this" ) return - user_email = kwargs['user_email'] - leads_count = kwargs['leads_count'] - delete_existing = kwargs['delete_existing'] - project_for_regions = kwargs['project_for_regions'] + user_email = kwargs["user_email"] + leads_count = kwargs["leads_count"] + delete_existing = kwargs["delete_existing"] + project_for_regions = kwargs["project_for_regions"] user = User.objects.get(email=user_email) self.ur_args = { - 'created_by': user, - 'modified_by': user, + "created_by": user, + "modified_by": user, } self.run(user, leads_count, delete_existing, project_for_regions) @@ -75,7 +73,7 @@ def _fuzzy_created_at(model, objects): for obj in objects: obj.created_at = created_at_fuzzy.fuzz() update_objs.append(obj) - model.objects.bulk_update(update_objs, ('created_at',)) + model.objects.bulk_update(update_objs, ("created_at",)) def generate_leads(self, project: Project, count: int): # Leads @@ -87,10 +85,7 @@ def generate_leads(self, project: Project, count: int): ) # Previews # NOTE: Bulk create is throwing onetoone key already exists error - [ - LeadPreviewFactory.create(lead=lead) - for lead in leads - ] + [LeadPreviewFactory.create(lead=lead) for lead in leads] # Fuzzy out the created at self._fuzzy_created_at(Lead, leads) return leads @@ -98,47 +93,44 @@ def generate_leads(self, project: Project, count: int): def generate_assessments(self, project: Project, leads: typing.List[Lead]): # Organization data assessment_organization_types = [c[0] for c in AssessmentRegistryOrganization.Type.choices] - organizations = list( - Organization.objects.only('id')[:300] - ) + organizations = list(Organization.objects.only("id")[:300]) organizations_len = len(organizations) # Geo data regions = list( Region.objects.filter( project=project, - ).only('id') + ).only("id") ) geo_areas = list( GeoArea.objects.filter( admin_level__region__project=project, admin_level__level__in=[1, 2], - ).annotate( - region_id=models.F('admin_level__region'), - ).only('id') + ) + .annotate( + region_id=models.F("admin_level__region"), + ) + .only("id") ) regions_len = len(regions) # Assessment Questions - ary_questions = list(Question.objects.only('id').all()[:200]) + ary_questions = list(Question.objects.only("id").all()[:200]) ary_questions_len = len(ary_questions) # Issues - summary_issues = list(SummaryIssue.objects.only('id').all()[:100]) + summary_issues = list(SummaryIssue.objects.only("id").all()[:100]) # Assessments total_leads = len(leads) for index, lead in enumerate(leads, start=1): - self.stdout.write(f'Processing for lead ({index}/{total_leads}): {lead}') + self.stdout.write(f"Processing for lead ({index}/{total_leads}): {lead}") assessment_registry = AssessmentRegistryFactory.create( project=project, lead=lead, **self.ur_args, ) assessment_registry.created_at = fuzzy.FuzzyDateTime(lead.created_at).fuzz() - assessment_registry.save(update_fields=('created_at',)) + assessment_registry.save(update_fields=("created_at",)) if organizations: - _organizations = random.sample( - organizations, - random.randint(0, organizations_len) - ) + _organizations = random.sample(organizations, random.randint(0, organizations_len)) stakeholders = [] for assessment_organization_type, organization in zip( assessment_organization_types, @@ -169,10 +161,8 @@ def generate_assessments(self, project: Project, leads: typing.List[Lead]): geo_area for geo_area in geo_areas # Annotated field region_id - if geo_area.region_id in [ # pyright: ignore [reportGeneralTypeIssues] - region.id - for region in selected_regions - ] + if geo_area.region_id + in [region.id for region in selected_regions] # pyright: ignore [reportGeneralTypeIssues] ] assessment_registry.locations.add( *random.sample( @@ -186,7 +176,7 @@ def generate_assessments(self, project: Project, leads: typing.List[Lead]): del geo_areas_filtered del selected_regions - ary_params = {'assessment_registry': assessment_registry} + ary_params = {"assessment_registry": assessment_registry} AdditionalDocumentFactory.create_batch(random.randint(0, 10), **ary_params) MethodologyAttributeFactory.create_batch(random.randint(0, 10), **ary_params) @@ -196,10 +186,7 @@ def generate_assessments(self, project: Project, leads: typing.List[Lead]): # Questions for question_ in random.sample(ary_questions, random.randint(0, ary_questions_len)): # NOTE: With BulkCreate unique error is thrown - AnswerFactory.create( - question=question_, - **ary_params - ) + AnswerFactory.create(question=question_, **ary_params) # Summary SummaryMetaFactory.create(**ary_params) if summary_issues: @@ -220,32 +207,33 @@ def run(self, user: User, leads_count: int, delete_existing: bool, project_for_r existing_dummy_projects_count = existing_dummy_projects.count() if delete_existing: if existing_dummy_projects_count: - self.stdout.write(f'There are {existing_dummy_projects_count} existing dummy projects.') - for _id, title, creator in existing_dummy_projects.values_list('id', 'title', 'created_by__email'): - self.stdout.write(f'{_id}: {title} - {creator}') + self.stdout.write(f"There are {existing_dummy_projects_count} existing dummy projects.") + for _id, title, creator in existing_dummy_projects.values_list("id", "title", "created_by__email"): + self.stdout.write(f"{_id}: {title} - {creator}") result = input("%s " % "This will delete above projects. Are you sure? type YES to delete: ") - if result == 'YES': + if result == "YES": with transaction.atomic(): - Project.objects.filter( - pk__in=existing_dummy_projects.values('id') - ).delete() + Project.objects.filter(pk__in=existing_dummy_projects.values("id")).delete() existing_dummy_projects_count = 0 project = ProjectFactory.create( - title=f'{DUMMY_PROJECT_PREFIX} {existing_dummy_projects_count}', + title=f"{DUMMY_PROJECT_PREFIX} {existing_dummy_projects_count}", assessment_template=AssessmentTemplate.objects.first(), **self.ur_args, ) project.created_at = DEFAULT_START_DATETIME - project.save(update_fields=('created_at',)) + project.save(update_fields=("created_at",)) if project_for_regions is None: # Using top used regions project_regions = list( Region.objects.filter( is_published=True, - ).annotate( - project_count=models.Count('project'), - ).order_by('-project_count').only('id')[:5] + ) + .annotate( + project_count=models.Count("project"), + ) + .order_by("-project_count") + .only("id")[:5] ) else: # Using regions from provided project @@ -253,16 +241,18 @@ def run(self, user: User, leads_count: int, delete_existing: bool, project_for_r Region.objects.filter( is_published=True, project=project_for_regions, - ).distinct().only('id') + ) + .distinct() + .only("id") ) - assert len(project_regions) > 0, 'There are no regions in selected project' + assert len(project_regions) > 0, "There are no regions in selected project" project.regions.add(*project_regions) project.add_member(user, role=ProjectRole.objects.get(type=ProjectRole.Type.ADMIN)) - self.stdout.write(f'Generating assessments for new project: {project.title}') + self.stdout.write(f"Generating assessments for new project: {project.title}") # Leads - self.stdout.write(f'Generating {leads_count} leads') + self.stdout.write(f"Generating {leads_count} leads") leads = self.generate_leads(project, leads_count) # Assessments - self.stdout.write(f'Generating assessments for {leads_count} leads') + self.stdout.write(f"Generating assessments for {leads_count} leads") self.generate_assessments(project, leads) diff --git a/apps/assessment_registry/management/commands/migrate_old_assessments.py b/apps/assessment_registry/management/commands/migrate_old_assessments.py index ea06417913..60f36ae704 100644 --- a/apps/assessment_registry/management/commands/migrate_old_assessments.py +++ b/apps/assessment_registry/management/commands/migrate_old_assessments.py @@ -1,27 +1,25 @@ -from django.db import transaction -from django.core.management.base import BaseCommand -from django.db.models import Subquery, OuterRef - from ary.models import ( Assessment, ScoreQuestionnaire, ScoreQuestionnaireSector, ScoreQuestionnaireSubSector, ) -from gallery.models import File -from geo.models import Region, GeoArea -from organization.models import Organization - from assessment_registry.models import ( + AdditionalDocument, + Answer, AssessmentRegistry, AssessmentRegistryOrganization, MethodologyAttribute, - AdditionalDocument, Question, - Answer, ScoreAnalyticalDensity, ScoreRating, ) +from django.core.management.base import BaseCommand +from django.db import transaction +from django.db.models import OuterRef, Subquery +from gallery.models import File +from geo.models import GeoArea, Region +from organization.models import Organization def empty_str_to_none(value): @@ -43,26 +41,23 @@ def get_key(choice_model, label): def get_choice_field_key(metadata, value, choice_model): for schema in metadata: if isinstance(value, int): - value = schema['schema']['options'][value] + value = schema["schema"]["options"][value] return get_key(choice_model, value) - elif schema['value'] == value: + elif schema["value"] == value: return get_key(choice_model, value) else: get_key(choice_model, value) def save_countries(assessment_registry, metadata): - countries = Region.objects.filter(id__in=metadata['Country']['key']) + countries = Region.objects.filter(id__in=metadata["Country"]["key"]) if countries: for country in countries: assessment_registry.bg_countries.add(country) def get_affected_groups_key(choice_model, label): - choices = { - k: v.split('/')[-1] - for k, v in choice_model.choices - } + choices = {k: v.split("/")[-1] for k, v in choice_model.choices} if not label: return for key, value in choices.items(): @@ -74,66 +69,63 @@ def get_affected_groups_key(choice_model, label): def create_stakeholders(organizations, assessment_reg, org_type): for org in organizations: AssessmentRegistryOrganization.objects.create( - organization_type=org_type, - assessment_registry=assessment_reg, - organization=org + organization_type=org_type, assessment_registry=assessment_reg, organization=org ) def save_stakeholders(metadata_dict, assessment_registry): - stakeholders_dict = { - k: v for k, v in AssessmentRegistryOrganization.Type.choices - } + stakeholders_dict = {k: v for k, v in AssessmentRegistryOrganization.Type.choices} for org_type_key, org_type_value in stakeholders_dict.items(): - stakeholder_keys = [] if not metadata_dict[org_type_value]['key'] else metadata_dict[org_type_value]['key'] + stakeholder_keys = [] if not metadata_dict[org_type_value]["key"] else metadata_dict[org_type_value]["key"] organizations = Organization.objects.filter(id__in=stakeholder_keys) if organizations: create_stakeholders(organizations, assessment_registry, org_type_key) def save_locations(methodology_json, assessment_registry): - if methodology_json.get('Locations'): - locations = GeoArea.objects.filter(title__in=methodology_json.get('Locations')) + if methodology_json.get("Locations"): + locations = GeoArea.objects.filter(title__in=methodology_json.get("Locations")) if locations: for loc in locations: assessment_registry.locations.add(loc) def save_methodology_attributes(methodology_json, assessment_registry): - methodology_attributes = methodology_json.get('Attributes', None) + methodology_attributes = methodology_json.get("Attributes", None) if methodology_attributes: for attribute in methodology_attributes: MethodologyAttribute.objects.create( assessment_registry=assessment_registry, - data_collection_technique=empty_str_to_none(attribute['Collection Technique'][0]['key']), - sampling_approach=empty_str_to_none(attribute['Sampling'][1]['key']), - sampling_size=empty_str_to_none(attribute['Sampling'][0]['key']), - proximity=empty_str_to_none(attribute['Proximity'][0]['key']), - unit_of_analysis=empty_str_to_none(attribute['Unit of Analysis'][0]['key']), - unit_of_reporting=empty_str_to_none(attribute['Unit of Reporting'][0]['key']) + data_collection_technique=empty_str_to_none(attribute["Collection Technique"][0]["key"]), + sampling_approach=empty_str_to_none(attribute["Sampling"][1]["key"]), + sampling_size=empty_str_to_none(attribute["Sampling"][0]["key"]), + proximity=empty_str_to_none(attribute["Proximity"][0]["key"]), + unit_of_analysis=empty_str_to_none(attribute["Unit of Analysis"][0]["key"]), + unit_of_reporting=empty_str_to_none(attribute["Unit of Reporting"][0]["key"]), ) def get_focus_data(methodology_json): def _get_focus_key(model_choice, label): - if label == 'Impact (scope & Scale)': + if label == "Impact (scope & Scale)": return get_key(model_choice, AssessmentRegistry.FocusType.IMPACT.label) - if label == 'Information and communication': + if label == "Information and communication": return get_key(model_choice, AssessmentRegistry.FocusType.INFORMATION_AND_COMMUNICATION.label) return get_key(model_choice, label) - focus_data = [_get_focus_key(AssessmentRegistry.FocusType, value) for value in methodology_json.get('Focuses') or []] + + focus_data = [_get_focus_key(AssessmentRegistry.FocusType, value) for value in methodology_json.get("Focuses") or []] return list(filter(lambda x: x is not None, focus_data)) def get_sector_data(methodology_json): def _get_sector_key(model_choice, label): - if label == 'Food': + if label == "Food": return get_key(model_choice, AssessmentRegistry.SectorType.FOOD_SECURITY.label) - if label == 'WASH': + if label == "WASH": return get_key(model_choice, AssessmentRegistry.SectorType.WASH.label) return get_key(model_choice, label) - sector_data = [_get_sector_key(AssessmentRegistry.SectorType, value) for value in methodology_json.get('Sectors') or []] + sector_data = [_get_sector_key(AssessmentRegistry.SectorType, value) for value in methodology_json.get("Sectors") or []] return list(filter(lambda x: x is not None, sector_data)) @@ -142,58 +134,47 @@ def create_additional_document(assessment_reg, old_file_type, old_file_id=None, def _save_additional_doc(doc_type): AdditionalDocument.objects.create( - document_type=doc_type, - assessment_registry=assessment_reg, - file=file, - external_link=external_link or "" + document_type=doc_type, assessment_registry=assessment_reg, file=file, external_link=external_link or "" ) - if old_file_type == 'assessment_data': + if old_file_type == "assessment_data": _save_additional_doc(AdditionalDocument.DocumentType.ASSESSMENT_DATABASE) - if old_file_type == 'misc': + if old_file_type == "misc": _save_additional_doc(AdditionalDocument.DocumentType.MISCELLANEOUS) - if old_file_type == 'questionnaire': + if old_file_type == "questionnaire": _save_additional_doc(AdditionalDocument.DocumentType.QUESTIONNAIRE) def save_additional_documents(old_ary, assessment_registry): - old_ary_additional_docs = (old_ary.metadata or {}).get('additional_documents') + old_ary_additional_docs = (old_ary.metadata or {}).get("additional_documents") for k, v in old_ary_additional_docs.items(): if v: for file in v: - file_id = file.get('id', None) + file_id = file.get("id", None) if file_id: create_additional_document( assessment_reg=assessment_registry, old_file_type=k, - old_file_id=file['id'], + old_file_id=file["id"], ) else: - create_additional_document( - assessment_reg=assessment_registry, - old_file_type=k, - external_link=file['url'] - ) + create_additional_document(assessment_reg=assessment_registry, old_file_type=k, external_link=file["url"]) def migrate_score_data(old_ary, assessment_reg): score_json = old_ary.get_score_json() - analytical_density_data = (score_json.get('matrix_pillars'))['Analytical Density'] - sector_value_dict = [(k, v['value']) for k, v in analytical_density_data.items() if not v['value'] == ''] + analytical_density_data = (score_json.get("matrix_pillars"))["Analytical Density"] + sector_value_dict = [(k, v["value"]) for k, v in analytical_density_data.items() if not v["value"] == ""] for sector, value in sector_value_dict: sector_key = get_key(AssessmentRegistry.SectorType, sector) if sector_key: - ScoreAnalyticalDensity.objects.get_or_create( - assessment_registry=assessment_reg, - sector=sector_key, - score=value * 2 - ) + ScoreAnalyticalDensity.objects.get_or_create(assessment_registry=assessment_reg, sector=sector_key, score=value * 2) - score_rating_data = score_json.get('pillars') + score_rating_data = score_json.get("pillars") score_criteria_list = [] for analytical_statement, score_criterias in score_rating_data.items(): - score_criteria_score = [(criteria, v['value']) for criteria, v in score_criterias.items()] + score_criteria_score = [(criteria, v["value"]) for criteria, v in score_criterias.items()] score_criteria_list.extend(score_criteria_score) for score_criteria, score_value in score_criteria_list: @@ -202,12 +183,12 @@ def migrate_score_data(old_ary, assessment_reg): ScoreRating.objects.get_or_create( assessment_registry=assessment_reg, score_type=get_key(ScoreRating.ScoreCriteria, score_criteria), - rating=score_value + rating=score_value, ) def migrate_cna_questions(): - cna_sectors = ScoreQuestionnaireSector.objects.filter(method='cna') + cna_sectors = ScoreQuestionnaireSector.objects.filter(method="cna") cna_subsectors = ScoreQuestionnaireSubSector.objects.filter(sector__in=cna_sectors) cna_questions = ScoreQuestionnaire.objects.filter(sub_sector__in=cna_subsectors) @@ -215,38 +196,36 @@ def migrate_cna_questions(): Question.objects.get_or_create( sub_sector=get_key(Question.QuestionSubSector, question.sub_sector.title), sector=get_key(Question.QuestionSector, question.sub_sector.sector.title), - question=question.text + question=question.text, ) def migrate_cna_data(old_ary, new_assessment_registry): questionnaire = old_ary.questionnaire if questionnaire: - questions = questionnaire.get('cna', None) + questions = questionnaire.get("cna", None) if questions: - cna = questions['questions'] + cna = questions["questions"] for k, v in cna.items(): if v: old_q = ScoreQuestionnaire.objects.get(id=int(k)) ass_question = Question.objects.get( question=old_q.text, sector=get_key(Question.QuestionSector, old_q.sub_sector.sector.title), - sub_sector=get_key(Question.QuestionSubSector, old_q.sub_sector.title) + sub_sector=get_key(Question.QuestionSubSector, old_q.sub_sector.title), ) if ass_question: Answer.objects.create( - question=ass_question, - assessment_registry=new_assessment_registry, - answer=v['value'] + question=ass_question, assessment_registry=new_assessment_registry, answer=v["value"] ) def update_new_ary_created_update_date(): AssessmentRegistry.objects.update( - created_at=Subquery(Assessment.objects.filter(lead=OuterRef('lead')).values('created_at')), - created_by=Subquery(Assessment.objects.filter(lead=OuterRef('lead')).values('created_by')), - modified_at=Subquery(Assessment.objects.filter(lead=OuterRef('lead')).values('modified_at')), - modified_by=Subquery(Assessment.objects.filter(lead=OuterRef('lead')).values('modified_by')) + created_at=Subquery(Assessment.objects.filter(lead=OuterRef("lead")).values("created_at")), + created_by=Subquery(Assessment.objects.filter(lead=OuterRef("lead")).values("created_by")), + modified_at=Subquery(Assessment.objects.filter(lead=OuterRef("lead")).values("modified_at")), + modified_by=Subquery(Assessment.objects.filter(lead=OuterRef("lead")).values("modified_by")), ) @@ -260,7 +239,7 @@ def handle(self, *args, **kwargs): except Exception: failed_ids.append(ary.id) if not failed_ids == []: - self.stdout.write(f'Failed to migrate data IDs: {failed_ids}') + self.stdout.write(f"Failed to migrate data IDs: {failed_ids}") update_new_ary_created_update_date() @transaction.atomic @@ -268,78 +247,82 @@ def map_old_to_new_data(self, assessment_id): old_ary = Assessment.objects.get(id=assessment_id) if not old_ary: return - self.stdout.write(f'Migrating data for assessment id {old_ary.id}') + self.stdout.write(f"Migrating data for assessment id {old_ary.id}") meta_data_json = old_ary.get_metadata_json() - meta_data = meta_data_json.get('Background') + meta_data_json.get('Details') + \ - meta_data_json.get('Dates') + meta_data_json.get('Stakeholders') + meta_data = ( + meta_data_json.get("Background") + + meta_data_json.get("Details") + + meta_data_json.get("Dates") + + meta_data_json.get("Stakeholders") + ) metadata_dict = {} for d in meta_data: - k = d['schema']['name'] + k = d["schema"]["name"] v = dict() - v['value'] = d['value'] - v['key'] = d['key'] + v["value"] = d["value"] + v["key"] = d["key"] metadata_dict[k] = v methodology_json = old_ary.get_methodology_json() # protection management - old_protection_mgmt = methodology_json.get('Protection Info', None) + old_protection_mgmt = methodology_json.get("Protection Info", None) new_protection_mgmt = [] if old_protection_mgmt: for value in old_protection_mgmt: new_protection_mgmt.append(get_key(AssessmentRegistry.ProtectionInfoType, value)) # Affected Groups - old_affected_groups = methodology_json.get('Affected Groups', None) + old_affected_groups = methodology_json.get("Affected Groups", None) new_affected_groups = [] if old_affected_groups: - old_affected_groups_list = [aff_grp['title'] for aff_grp in old_affected_groups] + old_affected_groups_list = [aff_grp["title"] for aff_grp in old_affected_groups] for aff_grp in old_affected_groups_list: new_affected_groups.append(get_affected_groups_key(AssessmentRegistry.AffectedGroupType, aff_grp)) def _get_bg_crisis_type(): - crisis_type = get_choice_field_key( - meta_data_json.get('Background'), - metadata_dict['Crisis Type']['value'], - AssessmentRegistry.CrisisType - ) if metadata_dict['Crisis Type']['key'] == 14 or 11 else metadata_dict['Crisis Type']['key'] + crisis_type = ( + get_choice_field_key( + meta_data_json.get("Background"), metadata_dict["Crisis Type"]["value"], AssessmentRegistry.CrisisType + ) + if metadata_dict["Crisis Type"]["key"] == 14 or 11 + else metadata_dict["Crisis Type"]["key"] + ) return crisis_type input_data = { - 'project': old_ary.project, - 'lead': old_ary.lead, - 'bg_crisis_type': _get_bg_crisis_type(), - 'bg_crisis_start_date': metadata_dict['Crisis Start Date']['value'], - 'bg_preparedness': metadata_dict['Preparedness']['key'], - 'external_support': get_key(AssessmentRegistry.ExternalSupportType, metadata_dict['External Support']['value']), - 'coordinated_joint': metadata_dict['Coordination']['key'], - 'cost_estimates_usd': metadata_dict['Cost estimates in USD']['key'], - 'details_type': metadata_dict['Type']['key'], - 'family': metadata_dict['Family']['key'], - 'frequency': metadata_dict['Frequency']['key'], - 'confidentiality': get_key(AssessmentRegistry.ConfidentialityType, metadata_dict['Confidentiality']['value']), - 'language': metadata_dict['Language']['key'], - 'no_of_pages': metadata_dict['Number of Pages']['key'], - 'data_collection_start_date': metadata_dict['Data Collection Start Date']['value'], - 'data_collection_end_date': metadata_dict['Data Collection End Date']['value'], - 'publication_date': metadata_dict['Publication Date']['value'], - 'executive_summary': '', - 'focuses': get_focus_data(methodology_json) or [], - 'sectors': get_sector_data(methodology_json) or [], - 'protection_info_mgmts': new_protection_mgmt, - 'affected_groups': new_affected_groups, - 'sampling': methodology_json.get('Sampling', None), - 'objectives': methodology_json.get('Objectives', None), - 'limitations': methodology_json.get('Limitations', None), - 'data_collection_techniques': methodology_json.get('Data Collection Technique', None), - - 'created_at': old_ary.created_at, - 'modified_at': old_ary.modified_at, - 'created_by': old_ary.created_by, - 'modified_by': old_ary.modified_by - + "project": old_ary.project, + "lead": old_ary.lead, + "bg_crisis_type": _get_bg_crisis_type(), + "bg_crisis_start_date": metadata_dict["Crisis Start Date"]["value"], + "bg_preparedness": metadata_dict["Preparedness"]["key"], + "external_support": get_key(AssessmentRegistry.ExternalSupportType, metadata_dict["External Support"]["value"]), + "coordinated_joint": metadata_dict["Coordination"]["key"], + "cost_estimates_usd": metadata_dict["Cost estimates in USD"]["key"], + "details_type": metadata_dict["Type"]["key"], + "family": metadata_dict["Family"]["key"], + "frequency": metadata_dict["Frequency"]["key"], + "confidentiality": get_key(AssessmentRegistry.ConfidentialityType, metadata_dict["Confidentiality"]["value"]), + "language": metadata_dict["Language"]["key"], + "no_of_pages": metadata_dict["Number of Pages"]["key"], + "data_collection_start_date": metadata_dict["Data Collection Start Date"]["value"], + "data_collection_end_date": metadata_dict["Data Collection End Date"]["value"], + "publication_date": metadata_dict["Publication Date"]["value"], + "executive_summary": "", + "focuses": get_focus_data(methodology_json) or [], + "sectors": get_sector_data(methodology_json) or [], + "protection_info_mgmts": new_protection_mgmt, + "affected_groups": new_affected_groups, + "sampling": methodology_json.get("Sampling", None), + "objectives": methodology_json.get("Objectives", None), + "limitations": methodology_json.get("Limitations", None), + "data_collection_techniques": methodology_json.get("Data Collection Technique", None), + "created_at": old_ary.created_at, + "modified_at": old_ary.modified_at, + "created_by": old_ary.created_by, + "modified_by": old_ary.modified_by, } assessment_reg, created = AssessmentRegistry.objects.get_or_create(**input_data) @@ -359,4 +342,4 @@ def _get_bg_crisis_type(): migrate_score_data(old_ary, assessment_reg) - self.stdout.write(f'Migrating data for assessment id {old_ary.id} Done') + self.stdout.write(f"Migrating data for assessment id {old_ary.id} Done") diff --git a/apps/assessment_registry/models.py b/apps/assessment_registry/models.py index c3f7779bfe..2b68ea01fe 100644 --- a/apps/assessment_registry/models.py +++ b/apps/assessment_registry/models.py @@ -1,183 +1,184 @@ -from django.db import models from django.contrib.postgres.fields import ArrayField from django.core.exceptions import ValidationError +from django.db import models from django.utils.translation import gettext_lazy as _ - -from user_resource.models import UserResource -from geo.models import Region -from organization.models import Organization from gallery.models import File +from geo.models import GeoArea, Region from lead.models import Lead -from geo.models import GeoArea +from organization.models import Organization +from user_resource.models import UserResource class AssessmentRegistry(UserResource): class CrisisType(models.IntegerChoices): - EARTH_QUAKE = 1, 'Earthquake' - GROUND_SHAKING = 2, 'Ground Shaking' - TSUNAMI = 3, 'Tsunami' - VOLCANO = 4, 'Volcano' - VOLCANIC_ERUPTION = 5, 'Volcanic Eruption' - MASS_MOMENT_DRY = 6, 'Mass Movement (Dry)' - ROCK_FALL = 7, 'Rockfall' - AVALANCE = 8, 'Avalance' - LANDSLIDE = 9, 'Landslide' - SUBSIDENCE = 10, 'Subsidence' - EXTRA_TROPICAL_CYCLONE = 11, 'Extra Tropical Cyclone' - TROPICAL_CYCLONE = 12, 'Tropical Cyclone' - LOCAL_STROM = 13, 'Local/Convective Strom' - FLOOD_RAIN = 14, 'Flood/Rain' - GENERAL_RIVER_FLOOD = 15, 'General River Flood' - FLASH_FLOOD = 16, 'Flash Flood' - STROM_SURGE_FLOOD = 17, 'Strom surge/Coastal Flood' - MASS_MOVEMENT_WET = 18, 'Mass Movement (wet)' - EXTREME_TEMPERATURE = 19, 'Extreme Temperature' - HEAT_WAVE = 20, 'Heat Wave' - COLD_WAVE = 21, 'Cold Wave' - EXTREME_WEATHER_CONDITION = 22, 'Extreme Weather Conditions' - DROUGHT = 23, 'Drought' - WILDFIRE = 24, 'Wildfire' - POPULATION_DISPLACEMENT = 25, 'Population displacement' - CONFLICT = 26, 'Conflict' - ECONOMIC = 27, 'Economic' - EPIDEMIC = 28, 'Epidemic' + EARTH_QUAKE = 1, "Earthquake" + GROUND_SHAKING = 2, "Ground Shaking" + TSUNAMI = 3, "Tsunami" + VOLCANO = 4, "Volcano" + VOLCANIC_ERUPTION = 5, "Volcanic Eruption" + MASS_MOMENT_DRY = 6, "Mass Movement (Dry)" + ROCK_FALL = 7, "Rockfall" + AVALANCE = 8, "Avalance" + LANDSLIDE = 9, "Landslide" + SUBSIDENCE = 10, "Subsidence" + EXTRA_TROPICAL_CYCLONE = 11, "Extra Tropical Cyclone" + TROPICAL_CYCLONE = 12, "Tropical Cyclone" + LOCAL_STROM = 13, "Local/Convective Strom" + FLOOD_RAIN = 14, "Flood/Rain" + GENERAL_RIVER_FLOOD = 15, "General River Flood" + FLASH_FLOOD = 16, "Flash Flood" + STROM_SURGE_FLOOD = 17, "Strom surge/Coastal Flood" + MASS_MOVEMENT_WET = 18, "Mass Movement (wet)" + EXTREME_TEMPERATURE = 19, "Extreme Temperature" + HEAT_WAVE = 20, "Heat Wave" + COLD_WAVE = 21, "Cold Wave" + EXTREME_WEATHER_CONDITION = 22, "Extreme Weather Conditions" + DROUGHT = 23, "Drought" + WILDFIRE = 24, "Wildfire" + POPULATION_DISPLACEMENT = 25, "Population displacement" + CONFLICT = 26, "Conflict" + ECONOMIC = 27, "Economic" + EPIDEMIC = 28, "Epidemic" class PreparednessType(models.IntegerChoices): - WITH_PREPAREDNESS = 1, 'With preparedness' - WITHOUT_PREPAREDNESS = 2, 'Without preparedness' + WITH_PREPAREDNESS = 1, "With preparedness" + WITHOUT_PREPAREDNESS = 2, "Without preparedness" class ExternalSupportType(models.IntegerChoices): - EXTERNAL_SUPPORT_RECIEVED = 1, 'External support received' - NO_EXTERNAL_SUPPORT_RECEIVED = 2, 'No external support received' + EXTERNAL_SUPPORT_RECIEVED = 1, "External support received" + NO_EXTERNAL_SUPPORT_RECEIVED = 2, "No external support received" class CoordinationType(models.IntegerChoices): - COORDINATED = 1, 'Coordinated - Joint' - HARMONIZED = 2, 'Coordinated - Harmonized' - UNCOORDINATED = 3, 'Uncoordinated' + COORDINATED = 1, "Coordinated - Joint" + HARMONIZED = 2, "Coordinated - Harmonized" + UNCOORDINATED = 3, "Uncoordinated" class Type(models.IntegerChoices): - INITIAL = 3, 'Initial' - RAPID = 2, 'Rapid' - IN_DEPTH = 1, 'In-depth' - MONITORING = 4, 'Monitoring' - REGISTRATION = 6, 'Registration' - OTHER = 5, 'Other' + INITIAL = 3, "Initial" + RAPID = 2, "Rapid" + IN_DEPTH = 1, "In-depth" + MONITORING = 4, "Monitoring" + REGISTRATION = 6, "Registration" + OTHER = 5, "Other" class FamilyType(models.IntegerChoices): - DISPLACEMENT_TRAKING_MATRIX = 1, 'Displacement Tracking Matrix' - MULTI_CLUSTER_INITIAL_AND_RAPID_ASSESSMENT = 2, 'Multi Cluster Initial and Rapid Assessment (MIRA)' - MULTI_SECTORIAL_NEEDS_ASSESSMENT = 3, 'Multi Sectoral Needs Assessment (MSNA)' - EMERGENCY_FOOD_SECURITY_ASSESSMENT = 4, 'Emergency Food Security Assessment (EFSA)' - COMPREHENSIVE_FOOD_SECURITY_AND_VULNERABILITY_ANALYSIS = \ - 5, 'Comprehensive Food Security and Vulnerability Analysis (CFSVA)' - PROTECTION_MONITORING = 6, 'Protection Monitoring' - HUMANITARIAN_NEEDS_OVERVIEW = 7, 'Humanitarian Needs Overview (HNO)' - BRIEFING_NOTE = 8, 'Briefing Note' - REGISTRATION = 9, 'Registration' - IDP_PROFILING_EXERCISE = 10, 'IDPs profiling exercise' - CENSUS = 11, 'Census' - REFUGEE_AND_MIGRANT_RESPONSE_PLAN = 12, 'Refugee and Migrant Response Plan (RMRP)' - RUFUGEE_RESPONSE_PLAN = 13, 'Refugee Response Plan (RRP)' - SMART_NUTRITION_SURVEY = 14, 'Smart Nutrition Survey' - OTHER = 15, 'Other' + DISPLACEMENT_TRAKING_MATRIX = 1, "Displacement Tracking Matrix" + MULTI_CLUSTER_INITIAL_AND_RAPID_ASSESSMENT = 2, "Multi Cluster Initial and Rapid Assessment (MIRA)" + MULTI_SECTORIAL_NEEDS_ASSESSMENT = 3, "Multi Sectoral Needs Assessment (MSNA)" + EMERGENCY_FOOD_SECURITY_ASSESSMENT = 4, "Emergency Food Security Assessment (EFSA)" + COMPREHENSIVE_FOOD_SECURITY_AND_VULNERABILITY_ANALYSIS = ( + 5, + "Comprehensive Food Security and Vulnerability Analysis (CFSVA)", + ) + PROTECTION_MONITORING = 6, "Protection Monitoring" + HUMANITARIAN_NEEDS_OVERVIEW = 7, "Humanitarian Needs Overview (HNO)" + BRIEFING_NOTE = 8, "Briefing Note" + REGISTRATION = 9, "Registration" + IDP_PROFILING_EXERCISE = 10, "IDPs profiling exercise" + CENSUS = 11, "Census" + REFUGEE_AND_MIGRANT_RESPONSE_PLAN = 12, "Refugee and Migrant Response Plan (RMRP)" + RUFUGEE_RESPONSE_PLAN = 13, "Refugee Response Plan (RRP)" + SMART_NUTRITION_SURVEY = 14, "Smart Nutrition Survey" + OTHER = 15, "Other" class FrequencyType(models.IntegerChoices): - ONE_OFF = 1, 'One off' - REGULAR = 2, 'Regular' + ONE_OFF = 1, "One off" + REGULAR = 2, "Regular" class ConfidentialityType(models.IntegerChoices): - UNPROTECTED = 1, 'Unprotected' - CONFIDENTIAL = 2, 'Confidential' + UNPROTECTED = 1, "Unprotected" + CONFIDENTIAL = 2, "Confidential" class Language(models.IntegerChoices): - ENGLISH = 1, 'English' - FRENCH = 2, 'French' - SPANISH = 3, 'Spanish' - ARABIC = 4, 'Arabic' - PORTUGESE = 5, 'Portugese' + ENGLISH = 1, "English" + FRENCH = 2, "French" + SPANISH = 3, "Spanish" + ARABIC = 4, "Arabic" + PORTUGESE = 5, "Portugese" class FocusType(models.IntegerChoices): - CONTEXT = 1, 'Context' - SHOCK_EVENT = 2, 'Shock/Event' - DISPLACEMENT = 3, 'Displacement' - CASUALTIES = 4, 'Casualties' - INFORMATION_AND_COMMUNICATION = 5, 'Information & Communication' - HUMANITERIAN_ACCESS = 6, 'Humanitarian Access' - IMPACT = 7, 'Impact (scope & Scale)' - HUMANITARIAN_CONDITIONS = 8, 'Humanitarian Conditions' - PEOPLE_AT_RISK = 9, 'People at risk' - PRIORITIES_AND_PREFERENCES = 10, 'Priorities & Preferences' - RESPONSE_AND_CAPACITIES = 11, 'Response & Capacities' + CONTEXT = 1, "Context" + SHOCK_EVENT = 2, "Shock/Event" + DISPLACEMENT = 3, "Displacement" + CASUALTIES = 4, "Casualties" + INFORMATION_AND_COMMUNICATION = 5, "Information & Communication" + HUMANITERIAN_ACCESS = 6, "Humanitarian Access" + IMPACT = 7, "Impact (scope & Scale)" + HUMANITARIAN_CONDITIONS = 8, "Humanitarian Conditions" + PEOPLE_AT_RISK = 9, "People at risk" + PRIORITIES_AND_PREFERENCES = 10, "Priorities & Preferences" + RESPONSE_AND_CAPACITIES = 11, "Response & Capacities" class SectorType(models.IntegerChoices): - FOOD_SECURITY = 1, 'Food Security' - HEALTH = 2, 'Health' - SHELTER = 3, 'Shelter' - WASH = 4, 'Wash' - PROTECTION = 5, 'Protection' - NUTRITION = 6, 'Nutrition' - LIVELIHOOD = 7, 'Livelihood' - EDUCATION = 8, 'Education' - LOGISTICS = 9, 'Logistics' - INTER_CROSS_SECTOR = 10, 'Inter/Cross Sector' + FOOD_SECURITY = 1, "Food Security" + HEALTH = 2, "Health" + SHELTER = 3, "Shelter" + WASH = 4, "Wash" + PROTECTION = 5, "Protection" + NUTRITION = 6, "Nutrition" + LIVELIHOOD = 7, "Livelihood" + EDUCATION = 8, "Education" + LOGISTICS = 9, "Logistics" + INTER_CROSS_SECTOR = 10, "Inter/Cross Sector" class ProtectionInfoType(models.IntegerChoices): - PROTECTION_MONITORING = 1, 'Protection Monitoring' - PROTECTION_NEEDS_ASSESSMENT = 2, 'Protection Needs Assessment' - CASE_MANAGEMENT = 3, 'Case Management' - POPULATION_DATA = 4, 'Population Data' - PROTECTION_RESPONSE_M_E = 5, 'Protection Response M&E' - COMMUNICATING_WITH_IN_AFFECTED_COMMUNITIES = 6, 'Communicating with(in) Affected Communities' - SECURITY_AND_SITUATIONAL_AWARENESS = 7, 'Security & Situational Awareness' - SECTORAL_SYSTEM_OTHER = 8, 'Sectoral Systems/Other' + PROTECTION_MONITORING = 1, "Protection Monitoring" + PROTECTION_NEEDS_ASSESSMENT = 2, "Protection Needs Assessment" + CASE_MANAGEMENT = 3, "Case Management" + POPULATION_DATA = 4, "Population Data" + PROTECTION_RESPONSE_M_E = 5, "Protection Response M&E" + COMMUNICATING_WITH_IN_AFFECTED_COMMUNITIES = 6, "Communicating with(in) Affected Communities" + SECURITY_AND_SITUATIONAL_AWARENESS = 7, "Security & Situational Awareness" + SECTORAL_SYSTEM_OTHER = 8, "Sectoral Systems/Other" class StatusType(models.IntegerChoices): - PLANNED = 1, 'Planned' - ONGOING = 2, 'Ongoing' - FINALIZED = 3, 'Finalized' + PLANNED = 1, "Planned" + ONGOING = 2, "Ongoing" + FINALIZED = 3, "Finalized" class AffectedGroupType(models.IntegerChoices): - ALL = 1, 'All' - ALL_AFFECTED = 2, 'All/Affected' - ALL_NOT_AFFECTED = 3, 'All/Not Affected' - ALL_AFFECTED_NOT_DISPLACED = 4, 'All/Affected/Not Displaced' - ALL_AFFECTED_DISPLACED = 5, 'All/Affected/Displaced' - ALL_AFFECTED_DISPLACED_IN_TRANSIT = 6, 'All/Affected/Displaced/In Transit' - ALL_AFFECTED_DISPLACED_MIGRANTS = 7, 'All/Affected/Displaced/Migrants' - ALL_AFFECTED_DISPLACED_IDPS = 8, 'All/Affected/Displaced/IDPs' - ALL_AFFECTED_DISPLACED_ASYLUM_SEEKER = 9, 'All/Affected/Displaced/Asylum Seeker' - ALL_AFFECTED_DISPLACED_OTHER_OF_CONCERN = 10, 'All/Affected/Displaced/Other of concerns' - ALL_AFFECTED_DISPLACED_RETURNEES = 11, 'All/Affected/Displaced/Returnees' - ALL_AFFECTED_DISPLACED_REFUGEES = 12, 'All/Affected/Displaced/Refugees' - ALL_AFFECTED_DISPLACED_MIGRANTS_IN_TRANSIT = 13, 'All/Affected/Displaced/Migrants/In transit' - ALL_AFFECTED_DISPLACED_MIGRANTS_PERMANENTS = 14, 'All/Affected/Displaced/Migrants/Permanents' - ALL_AFFECTED_DISPLACED_MIGRANTS_PENDULAR = 15, 'All/Affected/Displaced/Migrants/Pendular' - ALL_AFFECTED_NOT_DISPLACED_NO_HOST = 16, 'All/Affected/Not Displaced/Not Host' - ALL_AFFECTED_NOT_DISPLACED_HOST = 17, 'All/Affected/Not Displaced/Host' + ALL = 1, "All" + ALL_AFFECTED = 2, "All/Affected" + ALL_NOT_AFFECTED = 3, "All/Not Affected" + ALL_AFFECTED_NOT_DISPLACED = 4, "All/Affected/Not Displaced" + ALL_AFFECTED_DISPLACED = 5, "All/Affected/Displaced" + ALL_AFFECTED_DISPLACED_IN_TRANSIT = 6, "All/Affected/Displaced/In Transit" + ALL_AFFECTED_DISPLACED_MIGRANTS = 7, "All/Affected/Displaced/Migrants" + ALL_AFFECTED_DISPLACED_IDPS = 8, "All/Affected/Displaced/IDPs" + ALL_AFFECTED_DISPLACED_ASYLUM_SEEKER = 9, "All/Affected/Displaced/Asylum Seeker" + ALL_AFFECTED_DISPLACED_OTHER_OF_CONCERN = 10, "All/Affected/Displaced/Other of concerns" + ALL_AFFECTED_DISPLACED_RETURNEES = 11, "All/Affected/Displaced/Returnees" + ALL_AFFECTED_DISPLACED_REFUGEES = 12, "All/Affected/Displaced/Refugees" + ALL_AFFECTED_DISPLACED_MIGRANTS_IN_TRANSIT = 13, "All/Affected/Displaced/Migrants/In transit" + ALL_AFFECTED_DISPLACED_MIGRANTS_PERMANENTS = 14, "All/Affected/Displaced/Migrants/Permanents" + ALL_AFFECTED_DISPLACED_MIGRANTS_PENDULAR = 15, "All/Affected/Displaced/Migrants/Pendular" + ALL_AFFECTED_NOT_DISPLACED_NO_HOST = 16, "All/Affected/Not Displaced/Not Host" + ALL_AFFECTED_NOT_DISPLACED_HOST = 17, "All/Affected/Not Displaced/Host" class ProtectionRiskType(models.IntegerChoices): - ABDUCATION_KIDNAPPING = 1, \ - 'Abduction, kidnapping, enforced disappearance, arbitrary or unlawful arrest and/or detention' - ATTACKS_ON_CIVILIANS = 2, 'Attacks on civilians and other unlawful killings, and attacks on civilian objects' - CHILD_AND_FORCED = 3, 'Child and forced family separation' - EARLY_AND_FORCED_MARRIAGE = 4, 'Child, early or forced marriage' - DISCRIMINATION_AND_STIGMATIZATION = 5, \ - 'Discrimination and stigmatization, denial of resources, opportunities, services and/or humanitarian access' - DISINFORMATION_AND_DENIAL = 6, 'Disinformation and denial of access to information' - FORCED_RECRUITMENT = 7, 'Forced recruitment and association of children in armed forces and groups' - GENDER_BASED_VIOLENCE = 8, 'Gender-based violence' - IMPEDIMENTS_AND_RESTRICTIONS = 9, 'Impediments and/or restrictions to access to legal identity, remedies and justice' - PRESENCE_OF_MINE = 10, 'Presence of Mine and other explosive ordnance' - PSYCHOLOGICAL_INFLICATED_DISTRESS = 11, 'Psychological/emotional abuse or inflicted distress' - DESTRUCTION_OF_PERSONAL_PROPERTY = 12, 'Theft, extortion, forced eviction or destruction of personal property' - DEGRADING_TREATMENT = 13, 'Torture or cruel, inhuman, degrading treatment or punishment' - TRAFFICKING_IN_PERSONS = 14, 'Trafficking in persons, forced labour or slavery-like practices' - UNLAWFUL_IMPEDIMENTS = 15, \ - 'Unlawful impediments or restrictions to freedom of movement, siege and forced displacement' - - project = models.ForeignKey('project.Project', on_delete=models.CASCADE) + ABDUCATION_KIDNAPPING = 1, "Abduction, kidnapping, enforced disappearance, arbitrary or unlawful arrest and/or detention" + ATTACKS_ON_CIVILIANS = 2, "Attacks on civilians and other unlawful killings, and attacks on civilian objects" + CHILD_AND_FORCED = 3, "Child and forced family separation" + EARLY_AND_FORCED_MARRIAGE = 4, "Child, early or forced marriage" + DISCRIMINATION_AND_STIGMATIZATION = ( + 5, + "Discrimination and stigmatization, denial of resources, opportunities, services and/or humanitarian access", + ) + DISINFORMATION_AND_DENIAL = 6, "Disinformation and denial of access to information" + FORCED_RECRUITMENT = 7, "Forced recruitment and association of children in armed forces and groups" + GENDER_BASED_VIOLENCE = 8, "Gender-based violence" + IMPEDIMENTS_AND_RESTRICTIONS = 9, "Impediments and/or restrictions to access to legal identity, remedies and justice" + PRESENCE_OF_MINE = 10, "Presence of Mine and other explosive ordnance" + PSYCHOLOGICAL_INFLICATED_DISTRESS = 11, "Psychological/emotional abuse or inflicted distress" + DESTRUCTION_OF_PERSONAL_PROPERTY = 12, "Theft, extortion, forced eviction or destruction of personal property" + DEGRADING_TREATMENT = 13, "Torture or cruel, inhuman, degrading treatment or punishment" + TRAFFICKING_IN_PERSONS = 14, "Trafficking in persons, forced labour or slavery-like practices" + UNLAWFUL_IMPEDIMENTS = 15, "Unlawful impediments or restrictions to freedom of movement, siege and forced displacement" + + project = models.ForeignKey("project.Project", on_delete=models.CASCADE) lead = models.OneToOneField( - Lead, on_delete=models.CASCADE, + Lead, + on_delete=models.CASCADE, ) # Metadata Group # -- Background Fields @@ -206,8 +207,8 @@ class ProtectionRiskType(models.IntegerChoices): # -- Stakeholders stakeholders = models.ManyToManyField( Organization, - through='AssessmentRegistryOrganization', - through_fields=('assessment_registry', 'organization'), + through="AssessmentRegistryOrganization", + through_fields=("assessment_registry", "organization"), blank=True, ) @@ -224,20 +225,11 @@ class ProtectionRiskType(models.IntegerChoices): # -- Focus Sectors focuses = ArrayField(models.IntegerField(choices=FocusType.choices), default=list, blank=True) sectors = ArrayField(models.IntegerField(choices=SectorType.choices), default=list, blank=True) - protection_info_mgmts = ArrayField( - models.IntegerField(choices=ProtectionInfoType.choices), - default=list, blank=True - ) - protection_risks = ArrayField( - models.IntegerField(choices=ProtectionRiskType.choices), - default=list, blank=True - ) - affected_groups = ArrayField( - models.IntegerField(choices=AffectedGroupType.choices), - default=list, blank=True - ) + protection_info_mgmts = ArrayField(models.IntegerField(choices=ProtectionInfoType.choices), default=list, blank=True) + protection_risks = ArrayField(models.IntegerField(choices=ProtectionRiskType.choices), default=list, blank=True) + affected_groups = ArrayField(models.IntegerField(choices=AffectedGroupType.choices), default=list, blank=True) - locations = models.ManyToManyField(GeoArea, related_name='focus_location_assessment_reg', blank=True) + locations = models.ManyToManyField(GeoArea, related_name="focus_location_assessment_reg", blank=True) metadata_complete = models.BooleanField(default=False) additional_document_complete = models.BooleanField(default=False) focus_complete = models.BooleanField(default=False) @@ -258,11 +250,11 @@ def can_delete(self, user): class AssessmentRegistryOrganization(models.Model): class Type(models.IntegerChoices): - LEAD_ORGANIZATION = 1, 'Lead Organization' # Project Owner - INTERNATIONAL_PARTNER = 2, 'International Partners' - NATIONAL_PARTNER = 3, 'National Partners' - DONOR = 4, 'Donor' - GOVERNMENT = 5, 'Government' + LEAD_ORGANIZATION = 1, "Lead Organization" # Project Owner + INTERNATIONAL_PARTNER = 2, "International Partners" + NATIONAL_PARTNER = 3, "National Partners" + DONOR = 4, "Donor" + GOVERNMENT = 5, "Government" organization_type = models.IntegerField(choices=Type.choices) organization = models.ForeignKey(Organization, on_delete=models.CASCADE) @@ -271,57 +263,57 @@ class Type(models.IntegerChoices): class MethodologyAttribute(UserResource): class CollectionTechniqueType(models.IntegerChoices): - SECONDARY_DATA_REVIEW = 1, 'Secondary Data Review' - KEY_INFORMAT_INTERVIEW = 2, 'Key Informant Interview' - DIRECT_OBSERVATION = 3, 'Direct Observation' - COMMUNITY_GROUP_DISCUSSION = 4, 'Community Group Discussion' - FOCUS_GROUP_DISCUSSION = 5, 'Focus Group Discussion' - HOUSEHOLD_INTERVIEW = 6, 'Household Interview' - INDIVIDUAL_INTERVIEW = 7, 'Individual Interview' - SATELLITE_IMAGERY = 8, 'Satellite Imagery' + SECONDARY_DATA_REVIEW = 1, "Secondary Data Review" + KEY_INFORMAT_INTERVIEW = 2, "Key Informant Interview" + DIRECT_OBSERVATION = 3, "Direct Observation" + COMMUNITY_GROUP_DISCUSSION = 4, "Community Group Discussion" + FOCUS_GROUP_DISCUSSION = 5, "Focus Group Discussion" + HOUSEHOLD_INTERVIEW = 6, "Household Interview" + INDIVIDUAL_INTERVIEW = 7, "Individual Interview" + SATELLITE_IMAGERY = 8, "Satellite Imagery" class SamplingApproachType(models.IntegerChoices): - NON_RANDOM_SELECTION = 1, 'Non-Random Selection' - RANDOM_SELECTION = 2, 'Random Selection' - FULL_ENUMERATION = 3, 'Full Enumeration' + NON_RANDOM_SELECTION = 1, "Non-Random Selection" + RANDOM_SELECTION = 2, "Random Selection" + FULL_ENUMERATION = 3, "Full Enumeration" class ProximityType(models.IntegerChoices): - FACE_TO_FACE = 1, 'Face-to-Face' - REMOTE = 2, 'Remote' - MIXED = 3, 'Mixed' + FACE_TO_FACE = 1, "Face-to-Face" + REMOTE = 2, "Remote" + MIXED = 3, "Mixed" class UnitOfAnalysisType(models.IntegerChoices): - CRISIS = 1, 'Crisis' - COUNTRY = 2, 'Country' - REGION = 3, 'Region' - PROVINCE_GOV_PREFECTURE = 4, 'Province/governorate/prefecture' - DEPARTMENT_DISTRICT = 5, 'Department/District' - SUB_DISTRICT_COUNTRY = 6, 'Sub-District/Country' - MUNICIPALITY = 7, 'Municipality' - NEIGHBORHOOD_QUARTIER = 8, 'Neighborhood/Quartier' - COMMUNITY_SITE = 9, 'Community/Site' - AFFECTED_GROUP = 10, 'Affected group' - HOUSEHOLD = 11, 'Household' - INDIVIDUAL = 12, 'Individual' + CRISIS = 1, "Crisis" + COUNTRY = 2, "Country" + REGION = 3, "Region" + PROVINCE_GOV_PREFECTURE = 4, "Province/governorate/prefecture" + DEPARTMENT_DISTRICT = 5, "Department/District" + SUB_DISTRICT_COUNTRY = 6, "Sub-District/Country" + MUNICIPALITY = 7, "Municipality" + NEIGHBORHOOD_QUARTIER = 8, "Neighborhood/Quartier" + COMMUNITY_SITE = 9, "Community/Site" + AFFECTED_GROUP = 10, "Affected group" + HOUSEHOLD = 11, "Household" + INDIVIDUAL = 12, "Individual" class UnitOfReportingType(models.IntegerChoices): - CRISIS = 1, 'Crisis' - COUNTRY = 2, 'Country' - REGION = 3, 'Region' - PROVINCE_GOV_PREFECTURE = 4, 'Province/governorate/prefecture' - DEPARTMENT_DISTRICT = 5, 'Department/District' - SUB_DISTRICT_COUNTRY = 6, 'Sub-District/Country' - MUNICIPALITY = 7, 'Municipality' - NEIGHBORHOOD_QUARTIER = 8, 'Neighborhood/Quartier' - COMMUNITY_SITE = 9, 'Community/Site' - AFFECTED_GROUP = 10, 'Affected group' - HOUSEHOLD = 11, 'Household' - INDIVIDUAL = 12, 'Individual' + CRISIS = 1, "Crisis" + COUNTRY = 2, "Country" + REGION = 3, "Region" + PROVINCE_GOV_PREFECTURE = 4, "Province/governorate/prefecture" + DEPARTMENT_DISTRICT = 5, "Department/District" + SUB_DISTRICT_COUNTRY = 6, "Sub-District/Country" + MUNICIPALITY = 7, "Municipality" + NEIGHBORHOOD_QUARTIER = 8, "Neighborhood/Quartier" + COMMUNITY_SITE = 9, "Community/Site" + AFFECTED_GROUP = 10, "Affected group" + HOUSEHOLD = 11, "Household" + INDIVIDUAL = 12, "Individual" assessment_registry = models.ForeignKey( AssessmentRegistry, on_delete=models.CASCADE, - related_name='methodology_attributes', + related_name="methodology_attributes", ) data_collection_technique = models.IntegerField(choices=CollectionTechniqueType.choices, null=True, blank=True) sampling_approach = models.IntegerField(choices=SamplingApproachType.choices, null=True, blank=True) @@ -333,34 +325,29 @@ class UnitOfReportingType(models.IntegerChoices): class AdditionalDocument(UserResource): class DocumentType(models.IntegerChoices): - ASSESSMENT_DATABASE = 1, 'Assessment database' - QUESTIONNAIRE = 2, 'Questionnaire' - MISCELLANEOUS = 3, 'Miscellaneous' + ASSESSMENT_DATABASE = 1, "Assessment database" + QUESTIONNAIRE = 2, "Questionnaire" + MISCELLANEOUS = 3, "Miscellaneous" assessment_registry = models.ForeignKey( AssessmentRegistry, on_delete=models.CASCADE, - related_name='additional_documents', + related_name="additional_documents", ) document_type = models.IntegerField(choices=DocumentType.choices) - file = models.ForeignKey( - File, - on_delete=models.SET_NULL, - related_name='assessment_reg_file', - null=True, blank=True - ) + file = models.ForeignKey(File, on_delete=models.SET_NULL, related_name="assessment_reg_file", null=True, blank=True) external_link = models.URLField(max_length=500, blank=True) def __str__(self): - return f'FileID: {self.file_id}' + return f"FileID: {self.file_id}" class ScoreRating(UserResource): class AnalyticalStatement(models.IntegerChoices): - FIT_FOR_PURPOSE = 1, 'Fit for Purpose' - TRUSTWORTHINESS = 2, 'Trustworthiness' - ANALYTICAL_RIGOR = 3, 'Analytical Rigor' - ANALYTICAL_WRITING = 4, 'Analytical Writing' + FIT_FOR_PURPOSE = 1, "Fit for Purpose" + TRUSTWORTHINESS = 2, "Trustworthiness" + ANALYTICAL_RIGOR = 3, "Analytical Rigor" + ANALYTICAL_WRITING = 4, "Analytical Writing" class ScoreCriteria(models.IntegerChoices): RELEVANCE = 1, "Relevance" @@ -425,7 +412,7 @@ class RatingType(models.IntegerChoices): assessment_registry = models.ForeignKey( AssessmentRegistry, on_delete=models.CASCADE, - related_name='score_ratings', + related_name="score_ratings", ) score_type = models.IntegerField(choices=ScoreCriteria.choices) rating = models.IntegerField(choices=RatingType.choices, default=RatingType.FAIR) @@ -434,40 +421,39 @@ class RatingType(models.IntegerChoices): class ScoreAnalyticalDensity(UserResource): class AnalysisLevelCovered(models.IntegerChoices): - ISSUE_UNMET_NEEDS_ARE_DETAILED = 1, 'Issues/unmet needs are detailed' - ISSUE_UNMET_NEEDS_ARE_PRIORITIZED_RANKED = 2, 'Issues/unmet needs are prioritized/ranked' - CAUSES_OR_UNDERLYING_MECHANISMS_BEHIND_ISSUES_UNMET_NEEDS_ARE_DETAILED = 3, \ - 'Causes or underlying mechanisms behind issues/unmet needs are detailed' - CAUSES_OR_UNDERLYING_MECHANISMS_BEHIND_ISSUES_UNMET_NEEDS_ARE_PRIORITIZED_RANKED = 4, \ - 'Causes or underlying mechanisms behind issues/unmet needs are prioritized/ranked' - SEVERITY_OF_SOME_ALL_ISSUE_UNMET_NEEDS_IS_DETAILED = 5, \ - 'Severity of some/all issues/unmet_needs_is_detailed' - FUTURE_ISSUES_UNMET_NEEDS_ARE_DETAILED = 6, 'Future issues/unmet needs are detailed' - FUTURE_ISSUES_UNMET_NEEDS_ARE_PRIORITIZED_RANKED = 7, 'Future issues/unmet needs are prioritized/ranked' - SEVERITY_OF_SOME_ALL_FUTURE_ISSUE_UNMET_NEEDS_IS_DETAILED = 8, \ - 'Severity of some/all future issues/unmet_needs_is_detailed' - RECOMMENDATIONS_INTERVENTIONS_ARE_DETAILED = 9, 'Recommendations/interventions are detailed' - RECOMMENDATIONS_INTERVENTIONS_ARE_PRIORITIZED_RANKED = 10, \ - 'Recommendations/interventions are prioritized/ranked' + ISSUE_UNMET_NEEDS_ARE_DETAILED = 1, "Issues/unmet needs are detailed" + ISSUE_UNMET_NEEDS_ARE_PRIORITIZED_RANKED = 2, "Issues/unmet needs are prioritized/ranked" + CAUSES_OR_UNDERLYING_MECHANISMS_BEHIND_ISSUES_UNMET_NEEDS_ARE_DETAILED = ( + 3, + "Causes or underlying mechanisms behind issues/unmet needs are detailed", + ) + CAUSES_OR_UNDERLYING_MECHANISMS_BEHIND_ISSUES_UNMET_NEEDS_ARE_PRIORITIZED_RANKED = ( + 4, + "Causes or underlying mechanisms behind issues/unmet needs are prioritized/ranked", + ) + SEVERITY_OF_SOME_ALL_ISSUE_UNMET_NEEDS_IS_DETAILED = 5, "Severity of some/all issues/unmet_needs_is_detailed" + FUTURE_ISSUES_UNMET_NEEDS_ARE_DETAILED = 6, "Future issues/unmet needs are detailed" + FUTURE_ISSUES_UNMET_NEEDS_ARE_PRIORITIZED_RANKED = 7, "Future issues/unmet needs are prioritized/ranked" + SEVERITY_OF_SOME_ALL_FUTURE_ISSUE_UNMET_NEEDS_IS_DETAILED = ( + 8, + "Severity of some/all future issues/unmet_needs_is_detailed", + ) + RECOMMENDATIONS_INTERVENTIONS_ARE_DETAILED = 9, "Recommendations/interventions are detailed" + RECOMMENDATIONS_INTERVENTIONS_ARE_PRIORITIZED_RANKED = 10, "Recommendations/interventions are prioritized/ranked" class FigureProvidedByAssessment(models.IntegerChoices): - TOTAL_POP_IN_THE_ASSESSED_AREAS = 1, 'Total population in the assessed areas' - TOTAL_POP_EXPOSED_TO_THE_SHOCK_EVENT = 2, 'Total population exposed to the shock/event' - TOTAL_POP_AFFECTED_LIVING_IN_THE_AFFECTED_AREAS = 3, \ - 'Total population affected/living in the affected area' - TOTAL_POP_FACING_HUMANITARIAN_ACCESS_CONSTRAINTS = 4, 'Total population facing humanitarian access constraints' - TOTAL_POP_IN_NEED = 5, 'Total population in need' - TOTAL_POP_IN_CRITICAL_NEED = 6, 'Total population in critical need' - TOTAL_POP_IN_SEVERE_NEED = 7, 'Total population in severe need' - TOTAL_POP_IN_MODERATE_NEED = 8, 'Total population in moderate need' - TOTAL_POP_AT_RISK_VULNERABLE = 9, 'Total population at risk/vulnerable' - TOTAL_POP_REACHED_BY_ASSISTANCE = 10, 'Total population reached by assistance' - - assessment_registry = models.ForeignKey( - AssessmentRegistry, - on_delete=models.CASCADE, - related_name='analytical_density' - ) + TOTAL_POP_IN_THE_ASSESSED_AREAS = 1, "Total population in the assessed areas" + TOTAL_POP_EXPOSED_TO_THE_SHOCK_EVENT = 2, "Total population exposed to the shock/event" + TOTAL_POP_AFFECTED_LIVING_IN_THE_AFFECTED_AREAS = 3, "Total population affected/living in the affected area" + TOTAL_POP_FACING_HUMANITARIAN_ACCESS_CONSTRAINTS = 4, "Total population facing humanitarian access constraints" + TOTAL_POP_IN_NEED = 5, "Total population in need" + TOTAL_POP_IN_CRITICAL_NEED = 6, "Total population in critical need" + TOTAL_POP_IN_SEVERE_NEED = 7, "Total population in severe need" + TOTAL_POP_IN_MODERATE_NEED = 8, "Total population in moderate need" + TOTAL_POP_AT_RISK_VULNERABLE = 9, "Total population at risk/vulnerable" + TOTAL_POP_REACHED_BY_ASSISTANCE = 10, "Total population reached by assistance" + + assessment_registry = models.ForeignKey(AssessmentRegistry, on_delete=models.CASCADE, related_name="analytical_density") sector = models.IntegerField(choices=AssessmentRegistry.SectorType.choices) analysis_level_covered = ArrayField(models.IntegerField(choices=AnalysisLevelCovered.choices), default=list, blank=True) figure_provided = ArrayField(models.IntegerField(choices=FigureProvidedByAssessment.choices), default=list, blank=True) @@ -476,59 +462,63 @@ class FigureProvidedByAssessment(models.IntegerChoices): class Question(UserResource): class QuestionSector(models.IntegerChoices): - RELEVANCE = 1, 'Relevance' - COMPREHENSIVENESS = 2, 'Comprehensiveness' - ETHICS = 3, 'Ethics' - METHODOLOGICAL_RIGOR = 4, 'Methodological rigor' - ANALYTICAL_VALUE = 5, 'Analytical value' - TIMELINESS = 6, 'Timeliness' - EFFECTIVE_COMMUNICATION = 7, 'Effective communication' - USE = 8, 'Use', - PEOPLE_CENTERED_AND_INCLUSIVE = 9, 'People-centred and inclusive' - ACCOUNTABILITY_TO_AFFECTED_POPULATIONS = 10, 'Accountability to Affected Populations' - DO_NOT_HARM = 11, 'Do no harm' - DESIGNED_WITH_PURPOSE = 12, 'Designed with a purpose' - COMPETENCY_AND_CAPACITY = 13, 'Competency and capacity' - IMPARTIALITY = 14, 'Impartiality' - COORDINATION_AND_DATA_MINIMIZATION = 15, 'Coordination and data minimization' - JOINT_ANALYSIS = 16, 'Joint analysis' - ACKNOWLEDGE_DISSENTING_VOICES_IN_JOINT_NEEDS_ANALYSIS = 17, 'Acknowledge dissenting voices in joint needs analysis' - IFORMED_CONSENT_CONFIDENTIALITY_AND_DATA_SECURITY = 18, 'Informed consent, confidentiality and data security' - SHARING_RESULTS = 19, 'Sharing results (data and analysis)' - TRANSPARENCY_BETWEEN_ACTORS = 20, 'Transparency between actors' - MINIMUM_TECHNICAL_STANDARDS = 21, 'Minimum technical standards' + RELEVANCE = 1, "Relevance" + COMPREHENSIVENESS = 2, "Comprehensiveness" + ETHICS = 3, "Ethics" + METHODOLOGICAL_RIGOR = 4, "Methodological rigor" + ANALYTICAL_VALUE = 5, "Analytical value" + TIMELINESS = 6, "Timeliness" + EFFECTIVE_COMMUNICATION = 7, "Effective communication" + USE = ( + 8, + "Use", + ) + PEOPLE_CENTERED_AND_INCLUSIVE = 9, "People-centred and inclusive" + ACCOUNTABILITY_TO_AFFECTED_POPULATIONS = 10, "Accountability to Affected Populations" + DO_NOT_HARM = 11, "Do no harm" + DESIGNED_WITH_PURPOSE = 12, "Designed with a purpose" + COMPETENCY_AND_CAPACITY = 13, "Competency and capacity" + IMPARTIALITY = 14, "Impartiality" + COORDINATION_AND_DATA_MINIMIZATION = 15, "Coordination and data minimization" + JOINT_ANALYSIS = 16, "Joint analysis" + ACKNOWLEDGE_DISSENTING_VOICES_IN_JOINT_NEEDS_ANALYSIS = 17, "Acknowledge dissenting voices in joint needs analysis" + IFORMED_CONSENT_CONFIDENTIALITY_AND_DATA_SECURITY = 18, "Informed consent, confidentiality and data security" + SHARING_RESULTS = 19, "Sharing results (data and analysis)" + TRANSPARENCY_BETWEEN_ACTORS = 20, "Transparency between actors" + MINIMUM_TECHNICAL_STANDARDS = 21, "Minimum technical standards" class QuestionSubSector(models.IntegerChoices): - RELEVANCE = 1, 'Relevance' - GEOGRAPHIC_COMPREHENSIVENESS = 2, 'Geographic comprehensiveness' - SECTORAL_COMPREHENSIVENESS = 3, 'Sectoral comprehensiveness' - AFFECTED_AND_VULNERABLE_GROUPS_COMPREHENSIVENESS = 4, 'Affected and vulnerable groups comprehensiveness' - SAFETY_AND_PROTECTION = 5, 'Safety and protection' - HUMANITARIAN_PRINCIPLES = 6, 'Humanitarian principles' - CONTRIBUTION = 7, 'Contribution' - TRANSPARENCY = 8, 'Transparency' - MITIGATING_BIAS = 9, 'Mitigating bias' - PARTICIPATION = 10, 'Participation' - CONTEXT_SPECIFICITY = 11, 'Context specificity' - ANALYTICAL_STANDARDS = 12, 'Analytical standards' - DESCRIPTIONS = 13, 'Description' - EXPLANATION = 14, 'Explanation' - INTERPRETATION = 15, 'Interpretation' - ANTICIPATION = 16, 'Anticipation' - TIMELINESS = 17, 'Timeliness' - USER_FRIENDLY_PRESENTATION = 18, 'User-friendly presentation' - ACTIVE_DISSEMINATION = 19, 'Active dissemination' - USE_FOR_COLLECTIVE_PLANNING = 20, 'Use for collective planning' - BUY_IN_AND_USE_BY_HUMANITARIAN_CLUSTERS_SECTORS = 21, 'Buy-in and use by humanitarian clusters / sectors' - BUY_IN_AND_USE_BY_UN_AGENCIES = 22, 'Buy-in and use by UN agencies' - BUY_IN_AND_USE_BY_INTERNATIONAL_NGO = 23, 'Buy-in and use by international non-governmental organizations (INGOs)' - BUY_IN_AND_USE_BY_LOCAL_NGO = 24, 'Buy-in and use by local non-governmental organizations (local NGOs)' - BUY_IN_AND_USE_BY_MEMBER_OF_RED_CROSS_RED_CRESENT_MOVEMENT = 25, \ - 'Buy-in and use by members of the Red Cross / Red Crescent Movement' - BUY_IN_AND_USE_BY_DONORS = 26, 'Buy-in and use by donors' - BUY_IN_AND_USE_BY_NATIONAL_AND_LOCAL_GOVERNMENT_AGENCIES = 27, \ - 'Buy-in and use by national and local government agencies' - BUY_IN_AND_USE_BY_DEVELOPMENT_AND_STABILIZATION_ACTORS = 28, 'Buy-in and use by development and stabilization actors' + RELEVANCE = 1, "Relevance" + GEOGRAPHIC_COMPREHENSIVENESS = 2, "Geographic comprehensiveness" + SECTORAL_COMPREHENSIVENESS = 3, "Sectoral comprehensiveness" + AFFECTED_AND_VULNERABLE_GROUPS_COMPREHENSIVENESS = 4, "Affected and vulnerable groups comprehensiveness" + SAFETY_AND_PROTECTION = 5, "Safety and protection" + HUMANITARIAN_PRINCIPLES = 6, "Humanitarian principles" + CONTRIBUTION = 7, "Contribution" + TRANSPARENCY = 8, "Transparency" + MITIGATING_BIAS = 9, "Mitigating bias" + PARTICIPATION = 10, "Participation" + CONTEXT_SPECIFICITY = 11, "Context specificity" + ANALYTICAL_STANDARDS = 12, "Analytical standards" + DESCRIPTIONS = 13, "Description" + EXPLANATION = 14, "Explanation" + INTERPRETATION = 15, "Interpretation" + ANTICIPATION = 16, "Anticipation" + TIMELINESS = 17, "Timeliness" + USER_FRIENDLY_PRESENTATION = 18, "User-friendly presentation" + ACTIVE_DISSEMINATION = 19, "Active dissemination" + USE_FOR_COLLECTIVE_PLANNING = 20, "Use for collective planning" + BUY_IN_AND_USE_BY_HUMANITARIAN_CLUSTERS_SECTORS = 21, "Buy-in and use by humanitarian clusters / sectors" + BUY_IN_AND_USE_BY_UN_AGENCIES = 22, "Buy-in and use by UN agencies" + BUY_IN_AND_USE_BY_INTERNATIONAL_NGO = 23, "Buy-in and use by international non-governmental organizations (INGOs)" + BUY_IN_AND_USE_BY_LOCAL_NGO = 24, "Buy-in and use by local non-governmental organizations (local NGOs)" + BUY_IN_AND_USE_BY_MEMBER_OF_RED_CROSS_RED_CRESENT_MOVEMENT = ( + 25, + "Buy-in and use by members of the Red Cross / Red Crescent Movement", + ) + BUY_IN_AND_USE_BY_DONORS = 26, "Buy-in and use by donors" + BUY_IN_AND_USE_BY_NATIONAL_AND_LOCAL_GOVERNMENT_AGENCIES = 27, "Buy-in and use by national and local government agencies" + BUY_IN_AND_USE_BY_DEVELOPMENT_AND_STABILIZATION_ACTORS = 28, "Buy-in and use by development and stabilization actors" QUESTION_SECTOR_SUB_SECTOR_MAP = { QuestionSector.RELEVANCE: [ @@ -636,12 +626,12 @@ def clean(self): sub_sector = self.sub_sector # NOTE We are adding this validation here because Question are added from Admin Panel. - if hasattr(self, 'sector'): + if hasattr(self, "sector"): print("Sector Str", str(sector)) - if hasattr(self, 'sub_sector'): + if hasattr(self, "sub_sector"): print("Sub Sector Str", str(sub_sector)) if sub_sector not in Question.QUESTION_SECTOR_SUB_SECTOR_MAP[sector]: - raise ValidationError('Invalid sub-sector selected for given sector provided') + raise ValidationError("Invalid sub-sector selected for given sector provided") def save(self, *args, **kwargs): self.full_clean() @@ -655,15 +645,8 @@ def __str__(self): class Answer(UserResource): - assessment_registry = models.ForeignKey( - AssessmentRegistry, - on_delete=models.CASCADE, - related_name='answer' - ) - question = models.ForeignKey( - Question, - on_delete=models.CASCADE - ) + assessment_registry = models.ForeignKey(AssessmentRegistry, on_delete=models.CASCADE, related_name="answer") + question = models.ForeignKey(Question, on_delete=models.CASCADE) answer = models.BooleanField() class Meta: @@ -675,13 +658,13 @@ def __str__(self): class Summary(UserResource): class Pillar(models.IntegerChoices): - CONTEXT = 1, 'Context' - EVENT_SHOCK = 2, 'Event/Shock' - DISPLACEMENT = 3, 'Displacement' - INFORMATION_AND_COMMUNICATION = 4, 'Information & Communication' - HUMANITARIAN_ACCESS = 5, 'Humanitarian Access' + CONTEXT = 1, "Context" + EVENT_SHOCK = 2, "Event/Shock" + DISPLACEMENT = 3, "Displacement" + INFORMATION_AND_COMMUNICATION = 4, "Information & Communication" + HUMANITARIAN_ACCESS = 5, "Humanitarian Access" - assessment_registry = models.OneToOneField(AssessmentRegistry, related_name='summary', on_delete=models.CASCADE) + assessment_registry = models.OneToOneField(AssessmentRegistry, related_name="summary", on_delete=models.CASCADE) total_people_assessed = models.IntegerField(null=True, blank=True) total_dead = models.IntegerField(null=True, blank=True) total_injured = models.IntegerField(null=True, blank=True) @@ -692,15 +675,9 @@ class Pillar(models.IntegerChoices): class SummarySubPillarIssue(UserResource): assessment_registry = models.ForeignKey( - AssessmentRegistry, - on_delete=models.CASCADE, - related_name='summary_sub_sector_issue_ary' - ) - summary_issue = models.ForeignKey( - 'SummaryIssue', - on_delete=models.CASCADE, - related_name='summary_subsector_issue' + AssessmentRegistry, on_delete=models.CASCADE, related_name="summary_sub_sector_issue_ary" ) + summary_issue = models.ForeignKey("SummaryIssue", on_delete=models.CASCADE, related_name="summary_subsector_issue") text = models.TextField(blank=True) order = models.IntegerField() lead_preview_text_ref = models.JSONField(default=None, blank=True, null=True) @@ -708,17 +685,13 @@ class SummarySubPillarIssue(UserResource): class SummaryFocus(UserResource): class Dimension(models.IntegerChoices): - IMPACT = 1, 'Impact' - HUMANITARIAN_CONDITIONS = 2, 'Humanitarian Conditions' - PRIORITIES_AND_PREFERENCES = 3, 'Priorities & Preferences' - CONCLUSIONS = 4, 'Conclusions' - HUMANITARIAN_POPULATION_FIGURES = 5, 'Humanitarian Population Figures' + IMPACT = 1, "Impact" + HUMANITARIAN_CONDITIONS = 2, "Humanitarian Conditions" + PRIORITIES_AND_PREFERENCES = 3, "Priorities & Preferences" + CONCLUSIONS = 4, "Conclusions" + HUMANITARIAN_POPULATION_FIGURES = 5, "Humanitarian Population Figures" - assessment_registry = models.ForeignKey( - AssessmentRegistry, - on_delete=models.CASCADE, - related_name='summary_focus' - ) + assessment_registry = models.ForeignKey(AssessmentRegistry, on_delete=models.CASCADE, related_name="summary_focus") sector = models.IntegerField(choices=AssessmentRegistry.SectorType.choices) percentage_of_people_affected = models.IntegerField(null=True, blank=True) @@ -738,44 +711,44 @@ class Meta: class SummaryIssue(models.Model): class SubPillar(models.IntegerChoices): - POLITICS = 1, 'Politics' - DEMOGRAPHY = 2, 'Demography' - SOCIO_CULTURAL = 3, 'Socio-cultural' - ENVIRONMENT = 4, 'Environment' - SECURITY_AND_STABILITY = 5, 'Security & Stability' - ECONOMICS = 6, 'Economics' - CHARACTERISTICS = 7, 'Characteristics' - DRIVERS_AND_AGGRAVATING_FACTORS = 8, 'Drivers and Aggravating Factors' - MITIGATING_FACTORS = 9, 'Mitigating factors' - HAZARDS_AND_THREATS = 10, 'Hazards & Threats' - DISPLACEMENT_CHARACTERISTICS = 11, 'Characteristics' - PUSH_FACTORS = 12, 'Push factors' - PULL_FACTORS = 13, 'Pull factors' - INTENTIONS = 14, 'Intentions' - LOCAL_INTREGATIONS = 15, 'Local Integrations' - SOURCE_AND_MEANS = 16, 'Source & Means' - CHALLANGES_AND_BARRIERS = 17, 'Challenges & Barriers' - KNOWLEDGE_AND_INFO_GAPS_HUMAN = 18, 'Knowledge & Info Gaps (by humanitarians)' - KNOWLEDGE_AND_INFO_GAPS_POP = 19, 'Knowledge & Info Gaps (by population)' - POPULATION_TO_RELIEF = 20, 'Population to Relief' - RELIEF_TO_POPULATION = 21, 'Relief to Population' + POLITICS = 1, "Politics" + DEMOGRAPHY = 2, "Demography" + SOCIO_CULTURAL = 3, "Socio-cultural" + ENVIRONMENT = 4, "Environment" + SECURITY_AND_STABILITY = 5, "Security & Stability" + ECONOMICS = 6, "Economics" + CHARACTERISTICS = 7, "Characteristics" + DRIVERS_AND_AGGRAVATING_FACTORS = 8, "Drivers and Aggravating Factors" + MITIGATING_FACTORS = 9, "Mitigating factors" + HAZARDS_AND_THREATS = 10, "Hazards & Threats" + DISPLACEMENT_CHARACTERISTICS = 11, "Characteristics" + PUSH_FACTORS = 12, "Push factors" + PULL_FACTORS = 13, "Pull factors" + INTENTIONS = 14, "Intentions" + LOCAL_INTREGATIONS = 15, "Local Integrations" + SOURCE_AND_MEANS = 16, "Source & Means" + CHALLANGES_AND_BARRIERS = 17, "Challenges & Barriers" + KNOWLEDGE_AND_INFO_GAPS_HUMAN = 18, "Knowledge & Info Gaps (by humanitarians)" + KNOWLEDGE_AND_INFO_GAPS_POP = 19, "Knowledge & Info Gaps (by population)" + POPULATION_TO_RELIEF = 20, "Population to Relief" + RELIEF_TO_POPULATION = 21, "Relief to Population" class SubDimension(models.IntegerChoices): - DRIVERS = 1, 'Drivers' - IMPACT_ON_PEOPLE = 2, 'Impact on people' - IMPACT_ON_SYSTEM = 3, 'Impact on System, Network and services' - LIVING_STANDARDS = 4, 'Living standards' - COPING_MECHANISMS = 5, 'Coping mechanisms' - PHYSICAL_AND_MENTAL_WELL_BEING = 6, 'Physical and Mental well being' - NEEDS_POP = 7, 'Needs (by population)' - NEEDS_HUMAN = 8, 'Needs (by humanitarians)' - INTERVENTIONS_POP = 9, 'Interventions (by population)' - INTERVENTIONS_HUMAN = 10, 'Interventions (by humanitarians)' - DEMOGRAPHIC_GROUPS = 11, 'Demographic Groups' - GROUPS_WITH_SPECIFIC_NEEDS = 12, 'Groups with Specific Needs' - GEOGRAPHICAL_AREAS = 13, 'Geographical areas' - PEOPLE_AT_RISKS = 14, 'People at risk' - FOCAL_ISSUES = 15, 'Focal issues' + DRIVERS = 1, "Drivers" + IMPACT_ON_PEOPLE = 2, "Impact on people" + IMPACT_ON_SYSTEM = 3, "Impact on System, Network and services" + LIVING_STANDARDS = 4, "Living standards" + COPING_MECHANISMS = 5, "Coping mechanisms" + PHYSICAL_AND_MENTAL_WELL_BEING = 6, "Physical and Mental well being" + NEEDS_POP = 7, "Needs (by population)" + NEEDS_HUMAN = 8, "Needs (by humanitarians)" + INTERVENTIONS_POP = 9, "Interventions (by population)" + INTERVENTIONS_HUMAN = 10, "Interventions (by humanitarians)" + DEMOGRAPHIC_GROUPS = 11, "Demographic Groups" + GROUPS_WITH_SPECIFIC_NEEDS = 12, "Groups with Specific Needs" + GEOGRAPHICAL_AREAS = 13, "Geographical areas" + PEOPLE_AT_RISKS = 14, "People at risk" + FOCAL_ISSUES = 15, "Focal issues" PILLAR_SUB_PILLAR_MAP = { Summary.Pillar.CONTEXT: [ @@ -808,7 +781,7 @@ class SubDimension(models.IntegerChoices): Summary.Pillar.HUMANITARIAN_ACCESS: [ SubPillar.POPULATION_TO_RELIEF, SubPillar.RELIEF_TO_POPULATION, - ] + ], } DIMMENSION_SUB_DIMMENSION_MAP = { @@ -840,7 +813,7 @@ class SubDimension(models.IntegerChoices): sub_pillar = models.IntegerField(choices=SubPillar.choices, blank=True, null=True) sub_dimension = models.IntegerField(choices=SubDimension.choices, blank=True, null=True) parent = models.ForeignKey( - 'SummaryIssue', + "SummaryIssue", on_delete=models.CASCADE, blank=True, null=True, @@ -854,16 +827,10 @@ def __str__(self): class SummarySubDimensionIssue(UserResource): assessment_registry = models.ForeignKey( - AssessmentRegistry, - on_delete=models.CASCADE, - related_name='summary_focus_subsector_issue_ary' + AssessmentRegistry, on_delete=models.CASCADE, related_name="summary_focus_subsector_issue_ary" ) sector = models.IntegerField(choices=AssessmentRegistry.SectorType.choices) - summary_issue = models.ForeignKey( - SummaryIssue, - on_delete=models.CASCADE, - related_name='summary_focus_subsector_issue' - ) + summary_issue = models.ForeignKey(SummaryIssue, on_delete=models.CASCADE, related_name="summary_focus_subsector_issue") text = models.TextField(blank=True) order = models.IntegerField() lead_preview_text_ref = models.JSONField(default=None, blank=True, null=True) diff --git a/apps/assessment_registry/mutation.py b/apps/assessment_registry/mutation.py index 27077ad5e9..0dacdff82a 100644 --- a/apps/assessment_registry/mutation.py +++ b/apps/assessment_registry/mutation.py @@ -1,31 +1,26 @@ import graphene +from deep.permissions import ProjectPermissions as PP from utils.graphene.mutation import ( - generate_input_type_for_serializer, - PsGrapheneMutation, GrapheneMutation, + PsDeleteMutation, + PsGrapheneMutation, + generate_input_type_for_serializer, ) -from deep.permissions import ProjectPermissions as PP -from utils.graphene.mutation import PsDeleteMutation from .models import AssessmentRegistry, SummaryIssue -from .schema import AssessmentRegistryType, AssessmentRegistrySummaryIssueType -from .serializers import ( - AssessmentRegistrySerializer, - IssueSerializer, -) +from .schema import AssessmentRegistrySummaryIssueType, AssessmentRegistryType +from .serializers import AssessmentRegistrySerializer, IssueSerializer AssessmentRegistryCreateInputType = generate_input_type_for_serializer( - 'AssessmentRegistryCreateInputType', - serializer_class=AssessmentRegistrySerializer + "AssessmentRegistryCreateInputType", serializer_class=AssessmentRegistrySerializer ) AssessmentRegistrySummaryIssueCreateInputType = generate_input_type_for_serializer( - 'AssessmentRegistrySummaryIssueCreateInputType', - serializer_class=IssueSerializer + "AssessmentRegistrySummaryIssueCreateInputType", serializer_class=IssueSerializer ) -class AssessmentRegsitryMutationMixin(): +class AssessmentRegsitryMutationMixin: @classmethod def filter_queryset(cls, qs, info): return qs.filter(project=info.context.active_project) @@ -68,16 +63,17 @@ class Arguments: class DeleteAssessmentRegistry(AssessmentRegsitryMutationMixin, PsDeleteMutation): class Arguments: id = graphene.ID(required=True) + model = AssessmentRegistry result = graphene.Field(AssessmentRegistryType) permissions = [PP.Permission.DELETE_LEAD] -class ProjectMutation(): +class ProjectMutation: create_assessment_registry = CreateAssessmentRegistry.Field() update_assessment_registry = UpdateAssessmentRegistry.Field() delete_assessment_registry = DeleteAssessmentRegistry.Field() -class Mutation(): +class Mutation: create_assessment_reg_summary_issue = AssessmentRegistryCreateIssue.Field() diff --git a/apps/assessment_registry/schema.py b/apps/assessment_registry/schema.py index a42180b1f0..23acde4f2f 100644 --- a/apps/assessment_registry/schema.py +++ b/apps/assessment_registry/schema.py @@ -1,68 +1,67 @@ import graphene +from geo.schema import ProjectGeoAreaType from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField, PageGraphqlPagination +from lead.schema import LeadDetailType +from user_resource.schema import UserResourceMixin +from deep.permissions import ProjectPermissions as PP from utils.common import render_string_for_graphql -from utils.graphene.types import ClientIdMixin, CustomDjangoListObjectType +from utils.graphene.enums import EnumDescription from utils.graphene.fields import DjangoPaginatedListObjectField from utils.graphene.pagination import NoOrderingPageGraphqlPagination -from utils.graphene.enums import EnumDescription -from deep.permissions import ProjectPermissions as PP -from user_resource.schema import UserResourceMixin - -from lead.schema import LeadDetailType -from geo.schema import ProjectGeoAreaType +from utils.graphene.types import ClientIdMixin, CustomDjangoListObjectType -from .models import ( - AssessmentRegistry, - MethodologyAttribute, - AdditionalDocument, - Summary, - SummarySubPillarIssue, - SummaryIssue, - SummaryFocus, - SummarySubDimensionIssue, - ScoreRating, - ScoreAnalyticalDensity, - Question, - Answer, - AssessmentRegistryOrganization, -) -from .filters import AssessmentRegistryGQFilterSet, AssessmentRegistryIssueGQFilterSet from .enums import ( - AssessmentRegistryCrisisTypeEnum, - AssessmentRegistryPreparednessTypeEnum, - AssessmentRegistryExternalSupportTypeEnum, + AssessmentRegistryAffectedGroupTypeEnum, + AssessmentRegistryAnalysisFigureTypeEnum, + AssessmentRegistryAnalysisLevelTypeEnum, + AssessmentRegistryCNAQuestionSectorTypeEnum, + AssessmentRegistryCNAQuestionSubSectorTypeEnum, + AssessmentRegistryConfidentialityTypeEnum, AssessmentRegistryCoordinationTypeEnum, + AssessmentRegistryCrisisTypeEnum, + AssessmentRegistryDataCollectionTechniqueTypeEnum, AssessmentRegistryDetailTypeEnum, + AssessmentRegistryDocumentTypeEnum, + AssessmentRegistryExternalSupportTypeEnum, AssessmentRegistryFamilyTypeEnum, + AssessmentRegistryFocusTypeEnum, AssessmentRegistryFrequencyTypeEnum, - AssessmentRegistryConfidentialityTypeEnum, AssessmentRegistryLanguageTypeEnum, - AssessmentRegistryFocusTypeEnum, - AssessmentRegistrySectorTypeEnum, + AssessmentRegistryOrganizationTypeEnum, + AssessmentRegistryPreparednessTypeEnum, AssessmentRegistryProtectionInfoTypeEnum, AssessmentRegistryProtectionRiskTypeEnum, - AssessmentRegistryStatusTypeEnum, - AssessmentRegistryAffectedGroupTypeEnum, - AssessmentRegistryDataCollectionTechniqueTypeEnum, - AssessmentRegistrySamplingApproachTypeEnum, AssessmentRegistryProximityTypeEnum, - AssessmentRegistryUnitOfAnalysisTypeEnum, - AssessmentRegistryUnitOfReportingTypeEnum, - AssessmentRegistryDocumentTypeEnum, + AssessmentRegistryRatingTypeEnum, + AssessmentRegistrySamplingApproachTypeEnum, AssessmentRegistryScoreAnalyticalStatementTypeEnum, AssessmentRegistryScoreCriteriaTypeEnum, - AssessmentRegistryRatingTypeEnum, - AssessmentRegistryAnalysisLevelTypeEnum, - AssessmentRegistryAnalysisFigureTypeEnum, - AssessmentRegistryCNAQuestionSectorTypeEnum, - AssessmentRegistryCNAQuestionSubSectorTypeEnum, - AssessmentRegistrySummaryPillarTypeEnum, - AssessmentRegistrySummarySubPillarTypeEnum, + AssessmentRegistrySectorTypeEnum, + AssessmentRegistryStatusTypeEnum, AssessmentRegistrySummaryFocusDimensionTypeEnum, + AssessmentRegistrySummaryPillarTypeEnum, AssessmentRegistrySummarySubDimensionTypeEnum, - AssessmentRegistryOrganizationTypeEnum, + AssessmentRegistrySummarySubPillarTypeEnum, + AssessmentRegistryUnitOfAnalysisTypeEnum, + AssessmentRegistryUnitOfReportingTypeEnum, +) +from .filters import AssessmentRegistryGQFilterSet, AssessmentRegistryIssueGQFilterSet +from .models import ( + AdditionalDocument, + Answer, + AssessmentRegistry, + AssessmentRegistryOrganization, + MethodologyAttribute, + Question, + ScoreAnalyticalDensity, + ScoreRating, + Summary, + SummaryFocus, + SummaryIssue, + SummarySubDimensionIssue, + SummarySubPillarIssue, ) @@ -70,12 +69,12 @@ class AssessmentRegistryOrganizationType(DjangoObjectType, UserResourceMixin, Cl class Meta: model = AssessmentRegistryOrganization only_fields = ( - 'id', - 'organization', + "id", + "organization", ) organization_type = graphene.Field(AssessmentRegistryOrganizationTypeEnum, required=True) - organization_type_display = EnumDescription(source='get_organization_type_display', required=True) + organization_type_display = EnumDescription(source="get_organization_type_display", required=True) @staticmethod def resolve_organization(root, info): @@ -86,15 +85,15 @@ class QuestionType(DjangoObjectType, UserResourceMixin): class Meta: model = Question only_fields = ( - 'id', - 'question', + "id", + "question", ) sector = graphene.Field(AssessmentRegistryCNAQuestionSectorTypeEnum, required=False) - sector_display = EnumDescription(source='get_sector_display', required=False) + sector_display = EnumDescription(source="get_sector_display", required=False) sub_sector = graphene.Field(AssessmentRegistryCNAQuestionSubSectorTypeEnum, required=False) - sub_sector_display = EnumDescription(source='get_sub_sector_display', required=False) + sub_sector_display = EnumDescription(source="get_sub_sector_display", required=False) class SummaryOptionType(graphene.ObjectType): @@ -172,25 +171,26 @@ class ScoreRatingType(DjangoObjectType, UserResourceMixin, ClientIdMixin): class Meta: model = ScoreRating only_fields = ( - 'id', - 'reason', + "id", + "reason", ) score_type = graphene.Field(AssessmentRegistryScoreCriteriaTypeEnum, required=True) - score_type_display = EnumDescription(source='get_score_type_display', required=True) + score_type_display = EnumDescription(source="get_score_type_display", required=True) rating = graphene.Field(AssessmentRegistryRatingTypeEnum, required=True) - rating_display = EnumDescription(source='get_rating_display', required=True) + rating_display = EnumDescription(source="get_rating_display", required=True) class ScoreAnalyticalDensityType(DjangoObjectType, UserResourceMixin, ClientIdMixin): class Meta: model = ScoreAnalyticalDensity only_fields = ( - 'id', 'score', + "id", + "score", ) sector = graphene.Field(AssessmentRegistrySectorTypeEnum, required=True) - sector_display = EnumDescription(source='get_sector_display', required=True) + sector_display = EnumDescription(source="get_sector_display", required=True) analysis_level_covered = graphene.List(graphene.NonNull(AssessmentRegistryAnalysisLevelTypeEnum), required=True) figure_provided = graphene.List(graphene.NonNull(AssessmentRegistryAnalysisFigureTypeEnum), required=True) @@ -209,32 +209,32 @@ class MethodologyAttributeType(DjangoObjectType, UserResourceMixin, ClientIdMixi class Meta: model = MethodologyAttribute only_fields = ( - 'id', - 'sampling_size', + "id", + "sampling_size", ) data_collection_technique = graphene.Field(AssessmentRegistryDataCollectionTechniqueTypeEnum, required=False) - data_collection_technique_display = EnumDescription(source='get_data_collection_technique_display', required=False) + data_collection_technique_display = EnumDescription(source="get_data_collection_technique_display", required=False) sampling_approach = graphene.Field(AssessmentRegistrySamplingApproachTypeEnum, required=False) - sampling_appraoch_display = EnumDescription(source='get_sampling_approach_display', required=False) + sampling_appraoch_display = EnumDescription(source="get_sampling_approach_display", required=False) proximity = graphene.Field(AssessmentRegistryProximityTypeEnum, required=False) - proximity_display = EnumDescription(source='get_proximity_display', required=False) + proximity_display = EnumDescription(source="get_proximity_display", required=False) unit_of_analysis = graphene.Field(AssessmentRegistryUnitOfAnalysisTypeEnum, required=False) - unit_of_analysis_display = EnumDescription(source='get_unit_of_analysis_display', required=False) + unit_of_analysis_display = EnumDescription(source="get_unit_of_analysis_display", required=False) unit_of_reporting = graphene.Field(AssessmentRegistryUnitOfReportingTypeEnum, required=False) - unit_of_reporting_display = EnumDescription(source='get_unit_of_reporting_display', required=False) + unit_of_reporting_display = EnumDescription(source="get_unit_of_reporting_display", required=False) class AdditionalDocumentType(DjangoObjectType, UserResourceMixin, ClientIdMixin): class Meta: model = AdditionalDocument only_fields = ( - 'id', - 'file', + "id", + "file", ) document_type = graphene.Field(AssessmentRegistryDocumentTypeEnum, required=True) - document_type_display = EnumDescription(source='get_document_type_display', required=True) + document_type_display = EnumDescription(source="get_document_type_display", required=True) external_link = graphene.String(required=False) def resolve_external_link(root, info, **kwargs): @@ -246,32 +246,32 @@ def resolve_file(root, info, **kwargs): class CNAType(DjangoObjectType, UserResourceMixin, ClientIdMixin): - question = graphene.Field(QuestionType, required=True) # TODO: Dataloader + question = graphene.Field(QuestionType, required=True) # TODO: Dataloader class Meta: model = Answer only_fields = ( - 'id', - 'question', - 'answer', + "id", + "question", + "answer", ) class AssessmentRegistrySummaryIssueType(DjangoObjectType, UserResourceMixin): sub_pillar = graphene.Field(AssessmentRegistrySummarySubPillarTypeEnum, required=False) - sub_pillar_display = EnumDescription(source='get_sub_pillar_display', required=False) + sub_pillar_display = EnumDescription(source="get_sub_pillar_display", required=False) sub_dimension = graphene.Field(AssessmentRegistrySummarySubDimensionTypeEnum, required=False) - sub_dimension_display = EnumDescription(source='get_sub_dimension_display', required=False) + sub_dimension_display = EnumDescription(source="get_sub_dimension_display", required=False) child_count = graphene.Int(required=True) level = graphene.Int(required=False) class Meta: model = SummaryIssue only_fields = [ - 'id', - 'parent', # TODO: Dataloader - 'label', - 'full_label', + "id", + "parent", # TODO: Dataloader + "label", + "full_label", ] @staticmethod @@ -293,13 +293,13 @@ class SummaryMetaType(DjangoObjectType, UserResourceMixin): class Meta: model = Summary only_fields = [ - 'id', - 'total_people_assessed', - 'total_dead', - 'total_injured', - 'total_missing', - 'total_people_facing_hum_access_cons', - 'percentage_of_people_facing_hum_access_cons', + "id", + "total_people_assessed", + "total_dead", + "total_injured", + "total_missing", + "total_people_facing_hum_access_cons", + "percentage_of_people_facing_hum_access_cons", ] @@ -307,11 +307,11 @@ class SummarySubPillarIssueType(DjangoObjectType, UserResourceMixin, ClientIdMix class Meta: model = SummarySubPillarIssue only_fields = [ - 'id', - 'text', - 'order', - 'summary_issue', - 'lead_preview_text_ref', + "id", + "text", + "order", + "summary_issue", + "lead_preview_text_ref", ] @staticmethod @@ -321,44 +321,44 @@ def resolve_summary_issue(root, info, **kwargs): class SummaryFocusMetaType(DjangoObjectType, UserResourceMixin, ClientIdMixin): sector = graphene.Field(AssessmentRegistrySectorTypeEnum, required=False) - sector_display = EnumDescription(source='get_sector_display', required=False) + sector_display = EnumDescription(source="get_sector_display", required=False) class Meta: model = SummaryFocus only_fields = [ - 'id', - 'percentage_of_people_affected', - 'total_people_affected', - 'percentage_of_moderate', - 'percentage_of_severe', - 'percentage_of_critical', - 'percentage_in_need', - 'total_moderate', - 'total_severe', - 'total_critical', - 'total_in_need', - 'total_pop_assessed', - 'total_not_affected', - 'total_affected', - 'total_people_in_need', - 'total_people_moderately_in_need', - 'total_people_severly_in_need', - 'total_people_critically_in_need', + "id", + "percentage_of_people_affected", + "total_people_affected", + "percentage_of_moderate", + "percentage_of_severe", + "percentage_of_critical", + "percentage_in_need", + "total_moderate", + "total_severe", + "total_critical", + "total_in_need", + "total_pop_assessed", + "total_not_affected", + "total_affected", + "total_people_in_need", + "total_people_moderately_in_need", + "total_people_severly_in_need", + "total_people_critically_in_need", ] class SummaryFocusSubDimensionIssueType(DjangoObjectType, UserResourceMixin, ClientIdMixin): sector = graphene.Field(AssessmentRegistrySectorTypeEnum, required=True) - sector_display = EnumDescription(source='get_sector_display', required=True) + sector_display = EnumDescription(source="get_sector_display", required=True) class Meta: model = SummarySubDimensionIssue only_fields = [ - 'id', - 'summary_issue', - 'text', - 'order', - 'lead_preview_text_ref', + "id", + "summary_issue", + "text", + "order", + "lead_preview_text_ref", ] @staticmethod @@ -367,57 +367,57 @@ def resolve_summary_issue(root, info, **kwargs): class AssessmentRegistryType( - DjangoObjectType, - UserResourceMixin, - ClientIdMixin, + DjangoObjectType, + UserResourceMixin, + ClientIdMixin, ): class Meta: model = AssessmentRegistry only_fields = ( - 'id', - 'bg_countries', - 'bg_crisis_start_date', - 'cost_estimates_usd', - 'no_of_pages', - 'data_collection_start_date', - 'data_collection_end_date', - 'publication_date', - 'executive_summary', - 'objectives', - 'data_collection_techniques', - 'sampling', - 'limitations', + "id", + "bg_countries", + "bg_crisis_start_date", + "cost_estimates_usd", + "no_of_pages", + "data_collection_start_date", + "data_collection_end_date", + "publication_date", + "executive_summary", + "objectives", + "data_collection_techniques", + "sampling", + "limitations", "metadata_complete", "additional_document_complete", "focus_complete", "methodology_complete", "summary_complete", "cna_complete", - "score_complete" + "score_complete", ) # TODO: We might need to define dataloaders here for fields which are used for listing in client side - project = graphene.ID(source='project_id', required=True) + project = graphene.ID(source="project_id", required=True) lead = graphene.NonNull(LeadDetailType) bg_crisis_type = graphene.Field(AssessmentRegistryCrisisTypeEnum, required=True) - bg_crisis_type_display = EnumDescription(source='get_bg_crisis_type_display', required=True) + bg_crisis_type_display = EnumDescription(source="get_bg_crisis_type_display", required=True) bg_preparedness = graphene.Field(AssessmentRegistryPreparednessTypeEnum, required=True) - bg_preparedness_display = EnumDescription(source='get_bg_preparedness_display', required=True) + bg_preparedness_display = EnumDescription(source="get_bg_preparedness_display", required=True) external_support = graphene.Field(AssessmentRegistryExternalSupportTypeEnum, required=True) - external_support_display = EnumDescription(source='get_external_support_display', required=True) + external_support_display = EnumDescription(source="get_external_support_display", required=True) coordinated_joint = graphene.Field(AssessmentRegistryCoordinationTypeEnum, required=True) - coordinated_joint_display = EnumDescription(source='get_coordinated_joint_display', required=True) + coordinated_joint_display = EnumDescription(source="get_coordinated_joint_display", required=True) details_type = graphene.Field(AssessmentRegistryDetailTypeEnum, required=True) - details_type_display = EnumDescription(source='get_details_type_display', required=True) + details_type_display = EnumDescription(source="get_details_type_display", required=True) family = graphene.Field(AssessmentRegistryFamilyTypeEnum, required=True) - family_display = EnumDescription(source='get_family_display', required=True) + family_display = EnumDescription(source="get_family_display", required=True) frequency = graphene.Field(AssessmentRegistryFrequencyTypeEnum, required=True) - frequency_display = EnumDescription(source='get_frequency_display', required=True) + frequency_display = EnumDescription(source="get_frequency_display", required=True) confidentiality = graphene.Field(AssessmentRegistryConfidentialityTypeEnum, required=True) - confidentiality_display = EnumDescription(source='get_confidentiality_display', required=True) + confidentiality_display = EnumDescription(source="get_confidentiality_display", required=True) status = graphene.Field(AssessmentRegistryStatusTypeEnum, required=True) - status_display = EnumDescription(source='get_status_display', required=True) + status_display = EnumDescription(source="get_status_display", required=True) language = graphene.List(graphene.NonNull(AssessmentRegistryLanguageTypeEnum), required=True) focuses = graphene.List(graphene.NonNull(AssessmentRegistryFocusTypeEnum), required=True) sectors = graphene.List(graphene.NonNull(AssessmentRegistrySectorTypeEnum), required=True) @@ -497,8 +497,8 @@ class ProjectQuery: assessment_registries = DjangoPaginatedListObjectField( AssessmentRegistryListType, pagination=NoOrderingPageGraphqlPagination( - page_size_query_param='pageSize', - ) + page_size_query_param="pageSize", + ), ) assessment_registry_options = graphene.Field(AssessmentRegistryOptionsType) @@ -511,11 +511,8 @@ def resolve_assessment_registry_options(root, info, **kwargs): return AssessmentRegistryOptionsType -class Query(): +class Query: assessment_reg_summary_issue = DjangoObjectField(AssessmentRegistrySummaryIssueType) assessment_reg_summary_issues = DjangoPaginatedListObjectField( - AssessmentRegistrySummaryIssueListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + AssessmentRegistrySummaryIssueListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) diff --git a/apps/assessment_registry/serializers.py b/apps/assessment_registry/serializers.py index f8cfeb17a0..154f91af5f 100644 --- a/apps/assessment_registry/serializers.py +++ b/apps/assessment_registry/serializers.py @@ -1,38 +1,35 @@ from rest_framework import serializers - -from .utils import get_hierarchy_level from user_resource.serializers import UserResourceSerializer + from deep.serializers import ( + IntegerIDField, ProjectPropertySerializerMixin, TempClientIdMixin, - IntegerIDField, ) + from .models import ( + AdditionalDocument, + Answer, AssessmentRegistry, + AssessmentRegistryOrganization, MethodologyAttribute, - AdditionalDocument, - SummaryIssue, + ScoreAnalyticalDensity, + ScoreRating, Summary, - SummarySubPillarIssue, SummaryFocus, + SummaryIssue, SummarySubDimensionIssue, - ScoreRating, - ScoreAnalyticalDensity, - Answer, - AssessmentRegistryOrganization, + SummarySubPillarIssue, ) +from .utils import get_hierarchy_level -class AssessmentRegistryOrganizationSerializer( - TempClientIdMixin, - UserResourceSerializer, - serializers.ModelSerializer -): +class AssessmentRegistryOrganizationSerializer(TempClientIdMixin, UserResourceSerializer, serializers.ModelSerializer): id = IntegerIDField(required=False) class Meta: model = AssessmentRegistryOrganization - fields = ('id', 'client_id', 'organization', 'organization_type') + fields = ("id", "client_id", "organization", "organization_type") class MethodologyAttributeSerializer(TempClientIdMixin, serializers.ModelSerializer): @@ -41,8 +38,14 @@ class MethodologyAttributeSerializer(TempClientIdMixin, serializers.ModelSeriali class Meta: model = MethodologyAttribute fields = ( - "id", "client_id", "data_collection_technique", "sampling_approach", "sampling_size", - "proximity", "unit_of_analysis", "unit_of_reporting", + "id", + "client_id", + "data_collection_technique", + "sampling_approach", + "sampling_size", + "proximity", + "unit_of_analysis", + "unit_of_reporting", ) @@ -51,20 +54,24 @@ class AdditionalDocumentSerializer(TempClientIdMixin, UserResourceSerializer): class Meta: model = AdditionalDocument - fields = ("id", "client_id", "document_type", "file", "external_link",) + fields = ( + "id", + "client_id", + "document_type", + "file", + "external_link", + ) class IssueSerializer(UserResourceSerializer): class Meta: model = SummaryIssue - fields = ( - 'sub_pillar', 'sub_dimension', 'parent', 'label' - ) + fields = ("sub_pillar", "sub_dimension", "parent", "label") def validate(self, data): - sub_pillar = data.get('sub_pillar') - sub_dimension = data.get('sub_dimension') - parent = data.get('parent') + sub_pillar = data.get("sub_pillar") + sub_dimension = data.get("sub_dimension") + parent = data.get("parent") if all([sub_pillar, sub_dimension]): raise serializers.ValidationError("Cannot select both sub_pillar and sub_dimension field.") @@ -97,8 +104,13 @@ class SummaryMetaSerializer(UserResourceSerializer): class Meta: model = Summary fields = ( - "id", "total_people_assessed", "total_dead", "total_injured", "total_missing", - "total_people_facing_hum_access_cons", "percentage_of_people_facing_hum_access_cons", + "id", + "total_people_assessed", + "total_dead", + "total_injured", + "total_missing", + "total_people_facing_hum_access_cons", + "percentage_of_people_facing_hum_access_cons", ) @@ -108,9 +120,19 @@ class SummaryFocusMetaSerializer(UserResourceSerializer, TempClientIdMixin): class Meta: model = SummaryFocus fields = ( - "id", "client_id", "sector", "percentage_of_people_affected", "total_people_affected", "percentage_of_moderate", - "percentage_of_severe", "percentage_of_critical", "percentage_in_need", "total_moderate", - "total_severe", "total_critical", "total_in_need", + "id", + "client_id", + "sector", + "percentage_of_people_affected", + "total_people_affected", + "percentage_of_moderate", + "percentage_of_severe", + "percentage_of_critical", + "percentage_in_need", + "total_moderate", + "total_severe", + "total_critical", + "total_in_need", ) @@ -135,7 +157,13 @@ class ScoreRatingSerializer(UserResourceSerializer, TempClientIdMixin): class Meta: model = ScoreRating - fields = ("id", "client_id", "score_type", "rating", "reason",) + fields = ( + "id", + "client_id", + "score_type", + "rating", + "reason", + ) class ScoreAnalyticalDensitySerializer(UserResourceSerializer, TempClientIdMixin): @@ -143,7 +171,14 @@ class ScoreAnalyticalDensitySerializer(UserResourceSerializer, TempClientIdMixin class Meta: model = ScoreAnalyticalDensity - fields = ("id", "client_id", "sector", "analysis_level_covered", "figure_provided", "score",) + fields = ( + "id", + "client_id", + "sector", + "analysis_level_covered", + "figure_provided", + "score", + ) class CNAAnswerSerializer(TempClientIdMixin, UserResourceSerializer): @@ -151,46 +186,28 @@ class CNAAnswerSerializer(TempClientIdMixin, UserResourceSerializer): class Meta: model = Answer - fields = ("id", 'client_id', 'question', 'answer') + fields = ("id", "client_id", "question", "answer") class AssessmentRegistrySerializer(UserResourceSerializer, ProjectPropertySerializerMixin): stakeholders = AssessmentRegistryOrganizationSerializer( - source='assessmentregistryorganization_set', + source="assessmentregistryorganization_set", many=True, required=False, ) - methodology_attributes = MethodologyAttributeSerializer( - many=True, required=False - ) - additional_documents = AdditionalDocumentSerializer( - many=True, required=False - ) - score_ratings = ScoreRatingSerializer( - many=True, required=False - ) - score_analytical_density = ScoreAnalyticalDensitySerializer( - source="analytical_density", many=True, required=False - ) - cna = CNAAnswerSerializer( - source='answer', - many=True, - required=False - ) - summary_pillar_meta = SummaryMetaSerializer(source='summary', required=False) - - summary_sub_pillar_issue = SummarySubPillarIssueSerializer( - source="summary_sub_sector_issue_ary", many=True, required=False - ) - summary_dimension_meta = SummaryFocusMetaSerializer( - source='summary_focus', many=True, required=False - ) + methodology_attributes = MethodologyAttributeSerializer(many=True, required=False) + additional_documents = AdditionalDocumentSerializer(many=True, required=False) + score_ratings = ScoreRatingSerializer(many=True, required=False) + score_analytical_density = ScoreAnalyticalDensitySerializer(source="analytical_density", many=True, required=False) + cna = CNAAnswerSerializer(source="answer", many=True, required=False) + summary_pillar_meta = SummaryMetaSerializer(source="summary", required=False) + + summary_sub_pillar_issue = SummarySubPillarIssueSerializer(source="summary_sub_sector_issue_ary", many=True, required=False) + summary_dimension_meta = SummaryFocusMetaSerializer(source="summary_focus", many=True, required=False) summary_sub_dimension_issue = SummarySubDimensionSerializer( source="summary_focus_subsector_issue_ary", many=True, required=False ) - additional_documents = AdditionalDocumentSerializer( - many=True, required=False - ) + additional_documents = AdditionalDocumentSerializer(many=True, required=False) class Meta: model = AssessmentRegistry @@ -241,7 +258,7 @@ class Meta: "methodology_complete", "summary_complete", "cna_complete", - "score_complete" + "score_complete", ) def validate_score_ratings(self, data): @@ -265,19 +282,19 @@ def validate_score_analytical_density(self, data): def validate_stakeholders(self, data): stakeholders_list = [] for org in data: - org.pop('client_id', None) + org.pop("client_id", None) if org in stakeholders_list: - raise serializers.ValidationError('Dublicate organization selected') + raise serializers.ValidationError("Dublicate organization selected") stakeholders_list.append(org) def validate_cna(self, data): question_list = [] for question in data: - question.pop('client_id', None) + question.pop("client_id", None) if question in question_list: - raise serializers.ValidationError('Dublicate question selected') + raise serializers.ValidationError("Dublicate question selected") question_list.append(question) def validate(self, data): - data['project'] = self.project + data["project"] = self.project return data diff --git a/apps/assessment_registry/tests/test_dashboard_schema.py b/apps/assessment_registry/tests/test_dashboard_schema.py index ec8680c151..c676a0417b 100644 --- a/apps/assessment_registry/tests/test_dashboard_schema.py +++ b/apps/assessment_registry/tests/test_dashboard_schema.py @@ -1,24 +1,21 @@ from datetime import date, timedelta -from utils.graphene.tests import GraphQLTestCase - -from organization.factories import OrganizationFactory -from geo.factories import RegionFactory, AdminLevelFactory, GeoAreaFactory -from gallery.factories import FileFactory -from project.factories import ProjectFactory -from user.factories import UserFactory -from lead.factories import LeadFactory -from assessment_registry.factories import ( - QuestionFactory, - SummaryIssueFactory, -) +from assessment_registry.factories import QuestionFactory, SummaryIssueFactory from assessment_registry.models import ( + AdditionalDocument, AssessmentRegistry, MethodologyAttribute, - AdditionalDocument, - ScoreRating, Question, + ScoreRating, ) +from gallery.factories import FileFactory +from geo.factories import AdminLevelFactory, GeoAreaFactory, RegionFactory +from lead.factories import LeadFactory +from organization.factories import OrganizationFactory +from project.factories import ProjectFactory +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLTestCase class AssessmentDashboardQuerySchema(GraphQLTestCase): @@ -355,21 +352,21 @@ def _query_check(filter=None, **kwargs): self.assertEqual(content["assessmentPerFrameworkPillar"][0]["date"], str(date.today())) # assessment dashboard tab 2 self.assertEqual( - content['assessmentByDataCollectionTechniqueAndGeolocation'][0]['dataCollectionTechnique'], - "SECONDARY_DATA_REVIEW") + content["assessmentByDataCollectionTechniqueAndGeolocation"][0]["dataCollectionTechnique"], "SECONDARY_DATA_REVIEW" + ) self.assertEqual( - content['assessmentByDataCollectionTechniqueAndGeolocation'][1]['dataCollectionTechnique'], - "KEY_INFORMAT_INTERVIEW") - self.assertEqual(content['assessmentByDataCollectionTechniqueAndGeolocation'][0]['geoArea'], str(self.geo_area1.id)) - self.assertEqual(content['assessmentByDataCollectionTechniqueAndGeolocation'][1]['geoArea'], str(self.geo_area1.id)) - self.assertEqual(content['assessmentByDataCollectionTechniqueAndGeolocation'][0]['count'], 1) - self.assertEqual(content['assessmentByDataCollectionTechniqueAndGeolocation'][1]['count'], 1) - self.assertEqual(content['assessmentByProximityAndGeolocation'][0]['count'], 2) - self.assertEqual(content['assessmentByProximityAndGeolocation'][0]['proximity'], "FACE_TO_FACE") - self.assertEqual(content['assessmentByProximityAndGeolocation'][0]['geoArea'], str(self.geo_area1.id)) + content["assessmentByDataCollectionTechniqueAndGeolocation"][1]["dataCollectionTechnique"], "KEY_INFORMAT_INTERVIEW" + ) + self.assertEqual(content["assessmentByDataCollectionTechniqueAndGeolocation"][0]["geoArea"], str(self.geo_area1.id)) + self.assertEqual(content["assessmentByDataCollectionTechniqueAndGeolocation"][1]["geoArea"], str(self.geo_area1.id)) + self.assertEqual(content["assessmentByDataCollectionTechniqueAndGeolocation"][0]["count"], 1) + self.assertEqual(content["assessmentByDataCollectionTechniqueAndGeolocation"][1]["count"], 1) + self.assertEqual(content["assessmentByProximityAndGeolocation"][0]["count"], 2) + self.assertEqual(content["assessmentByProximityAndGeolocation"][0]["proximity"], "FACE_TO_FACE") + self.assertEqual(content["assessmentByProximityAndGeolocation"][0]["geoArea"], str(self.geo_area1.id)) # assessment Dashboard tab 3 - self.assertEqual(content['medianQualityScoreByAnalyticalDensityDate'][0]['sector'], "FOOD_SECURITY") - self.assertEqual(content['medianQualityScoreByAnalyticalDensityDate'][0]['sectorDisplay'], "Food Security") - self.assertEqual(content['medianQualityScoreByAnalyticalDensityDate'][0]['date'], str(date.today())) - self.assertEqual(content['medianQualityScoreByAnalyticalDensityDate'][1]['finalScore'], 0.0) - self.assertEqual(content['medianQualityScoreByGeoArea'][0]['finalScore'], 8.75) + self.assertEqual(content["medianQualityScoreByAnalyticalDensityDate"][0]["sector"], "FOOD_SECURITY") + self.assertEqual(content["medianQualityScoreByAnalyticalDensityDate"][0]["sectorDisplay"], "Food Security") + self.assertEqual(content["medianQualityScoreByAnalyticalDensityDate"][0]["date"], str(date.today())) + self.assertEqual(content["medianQualityScoreByAnalyticalDensityDate"][1]["finalScore"], 0.0) + self.assertEqual(content["medianQualityScoreByGeoArea"][0]["finalScore"], 8.75) diff --git a/apps/assessment_registry/tests/test_mutation.py b/apps/assessment_registry/tests/test_mutation.py index 64eff90bde..94663b7e7d 100644 --- a/apps/assessment_registry/tests/test_mutation.py +++ b/apps/assessment_registry/tests/test_mutation.py @@ -1,27 +1,24 @@ -from utils.graphene.tests import GraphQLTestCase - -from organization.factories import OrganizationFactory -from geo.factories import GeoAreaFactory, AdminLevelFactory, RegionFactory -from gallery.factories import FileFactory -from project.factories import ProjectFactory -from user.factories import UserFactory -from lead.factories import LeadFactory -from assessment_registry.factories import ( - QuestionFactory, - SummaryIssueFactory, -) +from assessment_registry.factories import QuestionFactory, SummaryIssueFactory from assessment_registry.models import ( + AdditionalDocument, AssessmentRegistry, MethodologyAttribute, - AdditionalDocument, - ScoreRating, Question, ScoreAnalyticalDensity, + ScoreRating, ) +from gallery.factories import FileFactory +from geo.factories import AdminLevelFactory, GeoAreaFactory, RegionFactory +from lead.factories import LeadFactory +from organization.factories import OrganizationFactory +from project.factories import ProjectFactory +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLTestCase class TestAssessmentRegistryMutation(GraphQLTestCase): - CREATE_ASSESSMENT_REGISTRY_QUERY = ''' + CREATE_ASSESSMENT_REGISTRY_QUERY = """ mutation MyMutation ($projectId: ID!, $input: AssessmentRegistryCreateInputType!) { project(id:$projectId) { createAssessmentRegistry(data: $input) { @@ -122,7 +119,7 @@ class TestAssessmentRegistryMutation(GraphQLTestCase): } } } -''' +""" def setUp(self): super().setUp() @@ -138,7 +135,7 @@ def setUp(self): self.question1 = QuestionFactory.create( sector=Question.QuestionSector.RELEVANCE.value, sub_sector=Question.QuestionSubSector.RELEVANCE.value, - question="test question" + question="test question", ) self.file = FileFactory.create() self.project1.add_member(self.member_user, role=self.project_role_member) @@ -147,10 +144,7 @@ def setUp(self): def test_create_assessment_registry(self): def _query_check(minput, **kwargs): return self.query_check( - self.CREATE_ASSESSMENT_REGISTRY_QUERY, - minput=minput, - variables={'projectId': self.project1.id}, - **kwargs + self.CREATE_ASSESSMENT_REGISTRY_QUERY, minput=minput, variables={"projectId": self.project1.id}, **kwargs ) minput = dict( @@ -167,21 +161,21 @@ def _query_check(minput, **kwargs): focuses=[ self.genum(AssessmentRegistry.FocusType.CONTEXT), self.genum(AssessmentRegistry.FocusType.HUMANITERIAN_ACCESS), - self.genum(AssessmentRegistry.FocusType.DISPLACEMENT) + self.genum(AssessmentRegistry.FocusType.DISPLACEMENT), ], frequency=self.genum(AssessmentRegistry.FrequencyType.ONE_OFF), protectionInfoMgmts=[ self.genum(AssessmentRegistry.ProtectionInfoType.PROTECTION_MONITORING), - self.genum(AssessmentRegistry.ProtectionInfoType.PROTECTION_NEEDS_ASSESSMENT) + self.genum(AssessmentRegistry.ProtectionInfoType.PROTECTION_NEEDS_ASSESSMENT), ], protectionRisks=[ self.genum(AssessmentRegistry.ProtectionRiskType.ABDUCATION_KIDNAPPING), - self.genum(AssessmentRegistry.ProtectionRiskType.ATTACKS_ON_CIVILIANS) + self.genum(AssessmentRegistry.ProtectionRiskType.ATTACKS_ON_CIVILIANS), ], sectors=[ self.genum(AssessmentRegistry.SectorType.HEALTH), self.genum(AssessmentRegistry.SectorType.SHELTER), - self.genum(AssessmentRegistry.SectorType.WASH) + self.genum(AssessmentRegistry.SectorType.WASH), ], lead=self.lead1.id, locations=[self.geo_area1.id, self.geo_area2.id], @@ -193,10 +187,7 @@ def _query_check(minput, **kwargs): noOfPages=10, publicationDate="2023-01-01", sampling="test", - language=[ - self.genum(AssessmentRegistry.Language.ENGLISH), - self.genum(AssessmentRegistry.Language.SPANISH) - ], + language=[self.genum(AssessmentRegistry.Language.ENGLISH), self.genum(AssessmentRegistry.Language.SPANISH)], bgCountries=[self.region.id], affectedGroups=[self.genum(AssessmentRegistry.AffectedGroupType.ALL_AFFECTED)], metadataComplete=True, @@ -210,27 +201,27 @@ def _query_check(minput, **kwargs): samplingApproach=self.genum(MethodologyAttribute.SamplingApproachType.NON_RANDOM_SELECTION), samplingSize=10, unitOfAnalysis=self.genum(MethodologyAttribute.UnitOfAnalysisType.CRISIS), - unitOfReporting=self.genum(MethodologyAttribute.UnitOfReportingType.CRISIS) + unitOfReporting=self.genum(MethodologyAttribute.UnitOfReportingType.CRISIS), ), ], additionalDocuments=[ dict( documentType=self.genum(AdditionalDocument.DocumentType.ASSESSMENT_DATABASE), externalLink="", - file=str(self.file.id) + file=str(self.file.id), ), ], scoreRatings=[ dict( scoreType=self.genum(ScoreRating.ScoreCriteria.ASSUMPTIONS), rating=self.genum(ScoreRating.RatingType.VERY_POOR), - reason="test" + reason="test", ), dict( scoreType=self.genum(ScoreRating.ScoreCriteria.RELEVANCE), rating=self.genum(ScoreRating.RatingType.VERY_POOR), - reason="test" - ) + reason="test", + ), ], scoreAnalyticalDensity=[ dict( @@ -244,11 +235,7 @@ def _query_check(minput, **kwargs): ], score=1, ), - dict( - sector=self.genum(AssessmentRegistry.SectorType.SHELTER), - analysisLevelCovered=[], - score=2 - ) + dict(sector=self.genum(AssessmentRegistry.SectorType.SHELTER), analysisLevelCovered=[], score=2), ], cna=[ dict( @@ -256,9 +243,7 @@ def _query_check(minput, **kwargs): question=self.question1.id, ) ], - summaryPillarMeta=dict( - totalPeopleAssessed=1000 - ), + summaryPillarMeta=dict(totalPeopleAssessed=1000), summarySubPillarIssue=[ dict( summaryIssue=self.summary_issue1.id, @@ -281,18 +266,18 @@ def _query_check(minput, **kwargs): sector=self.genum(AssessmentRegistry.SectorType.FOOD_SECURITY), order=1, ) - ] + ], ) self.force_login(self.member_user) content = _query_check(minput, okay=False) - data = content['data']['project']['createAssessmentRegistry']['result'] - self.assertEqual(data['costEstimatesUsd'], minput['costEstimatesUsd'], data) - self.assertIsNotNone(data['methodologyAttributes']) - self.assertIsNotNone(data['additionalDocuments']) - self.assertIsNotNone(data['cna']) - self.assertIsNotNone(data['summaryPillarMeta']) - self.assertIsNotNone(data['summaryDimensionMeta']) - self.assertIsNotNone(data['summarySubPillarIssue']) - self.assertIsNotNone(data['summarySubDimensionIssue']) - self.assertEqual(data['metadataComplete'], True) - self.assertIsNotNone(data['protectionRisks']) + data = content["data"]["project"]["createAssessmentRegistry"]["result"] + self.assertEqual(data["costEstimatesUsd"], minput["costEstimatesUsd"], data) + self.assertIsNotNone(data["methodologyAttributes"]) + self.assertIsNotNone(data["additionalDocuments"]) + self.assertIsNotNone(data["cna"]) + self.assertIsNotNone(data["summaryPillarMeta"]) + self.assertIsNotNone(data["summaryDimensionMeta"]) + self.assertIsNotNone(data["summarySubPillarIssue"]) + self.assertIsNotNone(data["summarySubDimensionIssue"]) + self.assertEqual(data["metadataComplete"], True) + self.assertIsNotNone(data["protectionRisks"]) diff --git a/apps/assessment_registry/tests/test_schemas.py b/apps/assessment_registry/tests/test_schemas.py index 8b2328c227..19434e51cd 100644 --- a/apps/assessment_registry/tests/test_schemas.py +++ b/apps/assessment_registry/tests/test_schemas.py @@ -1,55 +1,53 @@ -from utils.graphene.tests import GraphQLTestCase - -from assessment_registry.factories import AssessmentRegistryFactory -from organization.factories import OrganizationFactory -from geo.factories import RegionFactory -from gallery.factories import FileFactory -from project.factories import ProjectFactory -from user.factories import UserFactory -from lead.factories import LeadFactory from assessment_registry.factories import ( - QuestionFactory, - MethodologyAttributeFactory, AdditionalDocumentFactory, - ScoreRatingFactory, - ScoreAnalyticalDensityFactory, AnswerFactory, - SummaryMetaFactory, - SummarySubPillarIssueFactory, - SummaryIssueFactory, + AssessmentRegistryFactory, + MethodologyAttributeFactory, + QuestionFactory, + ScoreAnalyticalDensityFactory, + ScoreRatingFactory, SummaryFocusFactory, + SummaryIssueFactory, + SummaryMetaFactory, SummarySubDimensionIssueFactory, + SummarySubPillarIssueFactory, ) -from lead.models import Lead -from project.models import Project from assessment_registry.models import ( - AssessmentRegistry, AdditionalDocument, - ScoreRating, + AssessmentRegistry, Question, + ScoreRating, SummaryIssue, ) +from gallery.factories import FileFactory +from geo.factories import RegionFactory +from lead.factories import LeadFactory +from lead.models import Lead +from organization.factories import OrganizationFactory +from project.factories import ProjectFactory +from project.models import Project +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLTestCase class TestAssessmentRegistryQuerySchema(GraphQLTestCase): def setUp(self): super().setUp() self.question1 = QuestionFactory.create( - sector=Question.QuestionSector.RELEVANCE, - sub_sector=Question.QuestionSubSector.RELEVANCE, - question='test question' + sector=Question.QuestionSector.RELEVANCE, sub_sector=Question.QuestionSubSector.RELEVANCE, question="test question" ) self.question2 = QuestionFactory.create( sector=Question.QuestionSector.COMPREHENSIVENESS, sub_sector=Question.QuestionSubSector.GEOGRAPHIC_COMPREHENSIVENESS, - question='test question', + question="test question", ) self.country1, self.country2 = RegionFactory.create_batch(2) self.organization1, self.organization2 = OrganizationFactory.create_batch(2) self.org_list = [self.organization1.id, self.organization2.id] def test_assessment_registry_query(self): - query = ''' + query = """ query MyQuery ($projectId: ID! $assessmentRegistryId: ID!) { project(id: $projectId) { assessmentRegistry (id: $assessmentRegistryId) { @@ -110,7 +108,7 @@ def test_assessment_registry_query(self): } } } - ''' + """ project1 = ProjectFactory.create(status=Project.Status.ACTIVE) @@ -130,7 +128,7 @@ def test_assessment_registry_query(self): cna_complete=False, protection_risks=[ AssessmentRegistry.ProtectionRiskType.ABDUCATION_KIDNAPPING, - AssessmentRegistry.ProtectionRiskType.ATTACKS_ON_CIVILIANS + AssessmentRegistry.ProtectionRiskType.ATTACKS_ON_CIVILIANS, ], ) @@ -142,23 +140,23 @@ def test_assessment_registry_query(self): AdditionalDocumentFactory.create( assessment_registry=assessment_registry, document_type=AdditionalDocument.DocumentType.ASSESSMENT_DATABASE, - file=FileFactory() + file=FileFactory(), ) # Add Score Ratings ScoreRatingFactory.create( assessment_registry=assessment_registry, score_type=ScoreRating.ScoreCriteria.RELEVANCE, - rating=ScoreRating.RatingType.GOOD + rating=ScoreRating.RatingType.GOOD, ) ScoreRatingFactory.create( assessment_registry=assessment_registry, score_type=ScoreRating.ScoreCriteria.TIMELINESS, - rating=ScoreRating.RatingType.GOOD + rating=ScoreRating.RatingType.GOOD, ) ScoreRatingFactory.create( assessment_registry=assessment_registry, score_type=ScoreRating.ScoreCriteria.GRANULARITY, - rating=ScoreRating.RatingType.GOOD + rating=ScoreRating.RatingType.GOOD, ) # Add Score Analytical Density ScoreAnalyticalDensityFactory.create( @@ -170,16 +168,8 @@ def test_assessment_registry_query(self): sector=AssessmentRegistry.SectorType.SHELTER, ) # Add Answer to the question - AnswerFactory.create( - assessment_registry=assessment_registry, - question=self.question1, - answer=True - ) - AnswerFactory.create( - assessment_registry=assessment_registry, - question=self.question2, - answer=False - ) + AnswerFactory.create(assessment_registry=assessment_registry, question=self.question1, answer=True) + AnswerFactory.create(assessment_registry=assessment_registry, question=self.question2, answer=False) SummaryMetaFactory.create( assessment_registry=assessment_registry, ) @@ -201,42 +191,40 @@ def test_assessment_registry_query(self): def _query_check(assessment_registry, **kwargs): return self.query_check( - query, - variables={ - 'projectId': project1.id, - 'assessmentRegistryId': assessment_registry.id - }, **kwargs) + query, variables={"projectId": project1.id, "assessmentRegistryId": assessment_registry.id}, **kwargs + ) # -- non member user self.force_login(non_member_user) content1 = _query_check(assessment_registry) - self.assertIsNone(content1['data']['project']['assessmentRegistry']) + self.assertIsNone(content1["data"]["project"]["assessmentRegistry"]) # --- member user self.force_login(member_user) content = _query_check(assessment_registry) - self.assertIsNotNone(content['data']['project']['assessmentRegistry']['id']) - self.assertEqual(content['data']['project']['assessmentRegistry']['lead']['id'], str(lead_1.id), ) - self.assertIsNotNone(content['data']['project']['assessmentRegistry']['bgCountries']) - self.assertEqual(len(content['data']['project']['assessmentRegistry']['bgCountries']), 2) - self.assertEqual(len(content['data']['project']['assessmentRegistry']['methodologyAttributes']), 2) - self.assertEqual(len(content['data']['project']['assessmentRegistry']['additionalDocuments']), 1) - self.assertIsNotNone( - content['data']['project']['assessmentRegistry']['additionalDocuments'][0]['file']['file']['url'] + self.assertIsNotNone(content["data"]["project"]["assessmentRegistry"]["id"]) + self.assertEqual( + content["data"]["project"]["assessmentRegistry"]["lead"]["id"], + str(lead_1.id), ) - self.assertEqual(len(content['data']['project']['assessmentRegistry']['scoreRatings']), 3) - self.assertEqual(len(content['data']['project']['assessmentRegistry']['scoreAnalyticalDensity']), 2) - self.assertEqual(len(content['data']['project']['assessmentRegistry']['cna']), 2) - - self.assertEqual(len(content['data']['project']['assessmentRegistry']['summaryPillarMeta']), 1) - self.assertEqual(len(content['data']['project']['assessmentRegistry']['summarySubPillarIssue']), 1) - self.assertEqual(len(content['data']['project']['assessmentRegistry']['summaryDimensionMeta']), 1) - self.assertEqual(len(content['data']['project']['assessmentRegistry']['summarySubDimensionIssue']), 1) - self.assertEqual(content['data']['project']['assessmentRegistry']['cnaComplete'], False) - self.assertEqual(len(content['data']['project']['assessmentRegistry']['protectionRisks']), 2) + self.assertIsNotNone(content["data"]["project"]["assessmentRegistry"]["bgCountries"]) + self.assertEqual(len(content["data"]["project"]["assessmentRegistry"]["bgCountries"]), 2) + self.assertEqual(len(content["data"]["project"]["assessmentRegistry"]["methodologyAttributes"]), 2) + self.assertEqual(len(content["data"]["project"]["assessmentRegistry"]["additionalDocuments"]), 1) + self.assertIsNotNone(content["data"]["project"]["assessmentRegistry"]["additionalDocuments"][0]["file"]["file"]["url"]) + self.assertEqual(len(content["data"]["project"]["assessmentRegistry"]["scoreRatings"]), 3) + self.assertEqual(len(content["data"]["project"]["assessmentRegistry"]["scoreAnalyticalDensity"]), 2) + self.assertEqual(len(content["data"]["project"]["assessmentRegistry"]["cna"]), 2) + + self.assertEqual(len(content["data"]["project"]["assessmentRegistry"]["summaryPillarMeta"]), 1) + self.assertEqual(len(content["data"]["project"]["assessmentRegistry"]["summarySubPillarIssue"]), 1) + self.assertEqual(len(content["data"]["project"]["assessmentRegistry"]["summaryDimensionMeta"]), 1) + self.assertEqual(len(content["data"]["project"]["assessmentRegistry"]["summarySubDimensionIssue"]), 1) + self.assertEqual(content["data"]["project"]["assessmentRegistry"]["cnaComplete"], False) + self.assertEqual(len(content["data"]["project"]["assessmentRegistry"]["protectionRisks"]), 2) def test_list_assessment_registry_query(self): - query = ''' + query = """ query MyQuery ($id: ID!) { project(id: $id) { assessmentRegistries { @@ -249,7 +237,7 @@ def test_list_assessment_registry_query(self): } } } - ''' + """ project1 = ProjectFactory.create() project2 = ProjectFactory.create() @@ -287,8 +275,10 @@ def _query_check(**kwargs): return self.query_check( query, variables={ - 'id': project1.id, - }, **kwargs) + "id": project1.id, + }, + **kwargs, + ) # -- Without login _query_check(assert_for_error=True) @@ -296,21 +286,21 @@ def _query_check(**kwargs): # -- non member user self.force_login(non_member_user) content = _query_check(okay=False) - self.assertEqual(content['data']['project']['assessmentRegistries']['totalCount'], 0) + self.assertEqual(content["data"]["project"]["assessmentRegistries"]["totalCount"], 0) # -- With login self.force_login(member_user) content = _query_check(okay=False) - self.assertEqual(content['data']['project']['assessmentRegistries']['totalCount'], 4, content) + self.assertEqual(content["data"]["project"]["assessmentRegistries"]["totalCount"], 4, content) # -- non confidential member user self.force_login(non_confidential_member_user) content = _query_check(okay=False) - self.assertEqual(content['data']['project']['assessmentRegistries']['totalCount'], 3) + self.assertEqual(content["data"]["project"]["assessmentRegistries"]["totalCount"], 3) def test_issue_list_query_filter(self): - query = ''' + query = """ query MyQuery ( $isParent: Boolean $label: String @@ -326,7 +316,7 @@ def test_issue_list_query_filter(self): } } } - ''' + """ member_user = UserFactory.create() self.force_login(member_user) @@ -337,22 +327,23 @@ def test_issue_list_query_filter(self): for filter_data, expected_issues in [ ({}, [child_issue1, child_issue2, child_issue3, parent_issue1, parent_issue2, parent_issue3]), - ({'isParent': True}, [parent_issue1, parent_issue2, parent_issue3]), - ({'isParent': False}, [child_issue1, child_issue2, child_issue3]), + ({"isParent": True}, [parent_issue1, parent_issue2, parent_issue3]), + ({"isParent": False}, [child_issue1, child_issue2, child_issue3]), ]: content = self.query_check(query, variables={**filter_data}) self.assertListIds( - content['data']['assessmentRegSummaryIssues']['results'], expected_issues, - {'response': content, 'filter': filter_data} + content["data"]["assessmentRegSummaryIssues"]["results"], + expected_issues, + {"response": content, "filter": filter_data}, ) # check for child count - content = self.query_check(query)['data']['assessmentRegSummaryIssues']['results'] + content = self.query_check(query)["data"]["assessmentRegSummaryIssues"]["results"] parents = [str(parent.id) for parent in SummaryIssue.objects.filter(parent=None)] - child_count_list = [item['childCount'] for item in content if item['id'] in parents] + child_count_list = [item["childCount"] for item in content if item["id"] in parents] self.assertEqual(child_count_list, [1, 2, 0]) # check for level - self.assertEqual(set([item['level'] for item in content if item['id'] in parents]), {1}) - self.assertEqual(set([item['level'] for item in content if item['id'] not in parents]), {2}) + self.assertEqual(set([item["level"] for item in content if item["id"] in parents]), {1}) + self.assertEqual(set([item["level"] for item in content if item["id"] not in parents]), {2}) diff --git a/apps/assisted_tagging/admin.py b/apps/assisted_tagging/admin.py index 8bcc2de9f4..1d064bbe77 100644 --- a/apps/assisted_tagging/admin.py +++ b/apps/assisted_tagging/admin.py @@ -1,30 +1,34 @@ # Register your models here. from admin_auto_filters.filters import AutocompleteFilterFactory +from assisted_tagging.models import ( + AssistedTaggingModelPredictionTag, + AssistedTaggingPrediction, + DraftEntry, +) from django.contrib import admin -from assisted_tagging.models import AssistedTaggingModelPredictionTag, AssistedTaggingPrediction, DraftEntry from deep.admin import VersionAdmin @admin.register(DraftEntry) class DraftEntryAdmin(VersionAdmin): - search_fields = ['lead'] + search_fields = ["lead"] list_display = [ - 'lead', - 'prediction_status', + "lead", + "prediction_status", ] - list_filter = ( - AutocompleteFilterFactory('Lead', 'lead'), - AutocompleteFilterFactory('Project', 'project'), - 'type' - ) + list_filter = (AutocompleteFilterFactory("Lead", "lead"), AutocompleteFilterFactory("Project", "project"), "type") - autocomplete_fields = ('project', 'lead', 'related_geoareas',) + autocomplete_fields = ( + "project", + "lead", + "related_geoareas", + ) @admin.register(AssistedTaggingPrediction) class AssistedTaggingPredictionAdmin(VersionAdmin): - search_fields = ['draft_entry'] + search_fields = ["draft_entry"] list_display = [ "data_type", "draft_entry", @@ -32,21 +36,18 @@ class AssistedTaggingPredictionAdmin(VersionAdmin): "is_selected", "tag", ] - list_filter = ( - AutocompleteFilterFactory('DraftEntry', 'draft_entry'), - - ) + list_filter = (AutocompleteFilterFactory("DraftEntry", "draft_entry"),) # NOTE: Skipping model_version. Only few of them exists - autocomplete_fields = ('draft_entry', 'category', 'tag') + autocomplete_fields = ("draft_entry", "category", "tag") @admin.register(AssistedTaggingModelPredictionTag) class AssistedTaggingModelPredictionTagAdmin(VersionAdmin): - search_fields = ['parent_tag'] + search_fields = ["parent_tag"] list_display = [ - 'name', - 'is_category', - 'tag_id', - 'parent_tag', + "name", + "is_category", + "tag_id", + "parent_tag", ] - autocomplete_fields = ('parent_tag',) + autocomplete_fields = ("parent_tag",) diff --git a/apps/assisted_tagging/apps.py b/apps/assisted_tagging/apps.py index aedcc18c9f..7c4594a832 100644 --- a/apps/assisted_tagging/apps.py +++ b/apps/assisted_tagging/apps.py @@ -2,4 +2,4 @@ class AssistedTaggingConfig(AppConfig): - name = 'assisted_tagging' + name = "assisted_tagging" diff --git a/apps/assisted_tagging/dataloaders.py b/apps/assisted_tagging/dataloaders.py index 959403e55f..2d7d721e31 100644 --- a/apps/assisted_tagging/dataloaders.py +++ b/apps/assisted_tagging/dataloaders.py @@ -1,17 +1,15 @@ from collections import defaultdict -from promise import Promise - -from django.utils.functional import cached_property from assisted_tagging.models import AssistedTaggingPrediction +from django.utils.functional import cached_property +from promise import Promise from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin class DraftEntryPredicationsLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - assisted_tagging_qs = AssistedTaggingPrediction.objects\ - .filter(draft_entry_id__in=keys, is_selected=True) + assisted_tagging_qs = AssistedTaggingPrediction.objects.filter(draft_entry_id__in=keys, is_selected=True) _map = defaultdict(list) for assisted_tagging in assisted_tagging_qs: _map[assisted_tagging.draft_entry_id].append(assisted_tagging) diff --git a/apps/assisted_tagging/enums.py b/apps/assisted_tagging/enums.py index 3301924fb3..23c85ff26b 100644 --- a/apps/assisted_tagging/enums.py +++ b/apps/assisted_tagging/enums.py @@ -5,18 +5,13 @@ get_enum_name_from_django_field, ) -from .models import ( - DraftEntry, - AssistedTaggingPrediction, -) +from .models import AssistedTaggingPrediction, DraftEntry -DraftEntryPredictionStatusEnum = convert_enum_to_graphene_enum( - DraftEntry.PredictionStatus, name='DraftEntryPredictionStatusEnum') +DraftEntryPredictionStatusEnum = convert_enum_to_graphene_enum(DraftEntry.PredictionStatus, name="DraftEntryPredictionStatusEnum") AssistedTaggingPredictionDataTypeEnum = convert_enum_to_graphene_enum( - AssistedTaggingPrediction.DataType, name='AssistedTaggingPredictionDataTypeEnum') -DraftEntryTypeEnum = convert_enum_to_graphene_enum( - DraftEntry.Type, name="DraftEntryTypeEnum" + AssistedTaggingPrediction.DataType, name="AssistedTaggingPredictionDataTypeEnum" ) +DraftEntryTypeEnum = convert_enum_to_graphene_enum(DraftEntry.Type, name="DraftEntryTypeEnum") enum_map = { get_enum_name_from_django_field(field): enum @@ -30,13 +25,13 @@ class AssistedTaggingModelOrderingEnum(graphene.Enum): # ASC - ASC_ID = 'id' + ASC_ID = "id" # DESC - DESC_ID = f'-{ASC_ID}' + DESC_ID = f"-{ASC_ID}" class AssistedTaggingModelPredictionTagOrderingEnum(graphene.Enum): # ASC - ASC_ID = 'id' + ASC_ID = "id" # DESC - DESC_ID = f'-{ASC_ID}' + DESC_ID = f"-{ASC_ID}" diff --git a/apps/assisted_tagging/factories.py b/apps/assisted_tagging/factories.py index cb448b5029..95518f2956 100644 --- a/apps/assisted_tagging/factories.py +++ b/apps/assisted_tagging/factories.py @@ -3,32 +3,32 @@ from .models import ( AssistedTaggingModel, - AssistedTaggingModelVersion, AssistedTaggingModelPredictionTag, - DraftEntry, + AssistedTaggingModelVersion, AssistedTaggingPrediction, - WrongPredictionReview, + DraftEntry, MissingPredictionReview, + WrongPredictionReview, ) class AssistedTaggingModelFactory(DjangoModelFactory): - name = factory.Sequence(lambda n: f'Model-{n}') + name = factory.Sequence(lambda n: f"Model-{n}") class Meta: model = AssistedTaggingModel class AssistedTaggingModelVersionFactory(DjangoModelFactory): - version = factory.Sequence(lambda n: f'version-{n}') + version = factory.Sequence(lambda n: f"version-{n}") class Meta: model = AssistedTaggingModelVersion class AssistedTaggingModelPredictionTagFactory(DjangoModelFactory): - name = factory.Sequence(lambda n: f'name-{n}') - tag_id = factory.Sequence(lambda n: f'tag-{n}') + name = factory.Sequence(lambda n: f"name-{n}") + tag_id = factory.Sequence(lambda n: f"tag-{n}") class Meta: model = AssistedTaggingModelPredictionTag diff --git a/apps/assisted_tagging/filters.py b/apps/assisted_tagging/filters.py index eacfbdb935..8fb9492964 100644 --- a/apps/assisted_tagging/filters.py +++ b/apps/assisted_tagging/filters.py @@ -1,20 +1,15 @@ import django_filters +from utils.graphene.filters import IDFilter, IDListFilter, MultipleInputFilter + +from .enums import DraftEntryTypeEnum from .models import DraftEntry -from utils.graphene.filters import ( - IDFilter, - MultipleInputFilter, - IDListFilter, -) -from .enums import ( - DraftEntryTypeEnum -) class DraftEntryFilterSet(django_filters.FilterSet): - lead = IDFilter(field_name='lead') - draft_entry_types = MultipleInputFilter(DraftEntryTypeEnum, field_name='type') - ignore_ids = IDListFilter(method='filter_ignore_draft_ids', help_text='Ids are filtered out.') + lead = IDFilter(field_name="lead") + draft_entry_types = MultipleInputFilter(DraftEntryTypeEnum, field_name="type") + ignore_ids = IDListFilter(method="filter_ignore_draft_ids", help_text="Ids are filtered out.") is_discarded = django_filters.BooleanFilter() class Meta: diff --git a/apps/assisted_tagging/models.py b/apps/assisted_tagging/models.py index 85875daf3b..8e4485bd36 100644 --- a/apps/assisted_tagging/models.py +++ b/apps/assisted_tagging/models.py @@ -1,22 +1,23 @@ # from django.contrib.postgres.fields import ArrayField from __future__ import annotations + from typing import Union -from django.db import models -from django.db.models.functions import Concat from analysis_framework.models import Widget -from project.models import Project +from django.db import models +from django.db.models.functions import Concat +from geo.models import GeoArea from lead.models import Lead +from project.models import Project from user_resource.models import UserResource, UserResourceCreated -from geo.models import GeoArea class AssistedTaggingModel(models.Model): # This is for refering model id within deep. This can change. Source is the deepl. class ModelID(models.TextChoices): - MAIN = 'all_tags_model', 'All tags model' - GEO = 'geolocation', 'Geo Location' - RELIABILITY = 'reliability', 'Reliability' + MAIN = "all_tags_model", "All tags model" + GEO = "geolocation", "Geo Location" + RELIABILITY = "reliability", "Reliability" model_id = models.CharField(max_length=256) name = models.CharField(max_length=256) @@ -25,15 +26,15 @@ def __int__(self): self.versions: models.QuerySet[AssistedTaggingModelVersion] def __str__(self): - return f'<{self.name}> {self.model_id}' + return f"<{self.name}> {self.model_id}" @property def latest_version(self): - return self.versions.order_by('-version').first() + return self.versions.order_by("-version").first() class AssistedTaggingModelVersion(models.Model): - model = models.ForeignKey(AssistedTaggingModel, on_delete=models.CASCADE, related_name='versions') + model = models.ForeignKey(AssistedTaggingModel, on_delete=models.CASCADE, related_name="versions") version = models.CharField(max_length=256) # 'MAJOR.MINOR.PATCH' # Extra attributes (TODO: Later) # endpoint = models.CharField(max_length=256) @@ -46,21 +47,31 @@ def __str__(self): @classmethod def get_latest_models_version(cls) -> models.QuerySet: - return AssistedTaggingModelVersion.objects.annotate( - model_with_version=Concat( - models.F('model_id'), models.F('version'), - output_field=models.CharField(), - ) - ).filter( - model_with_version__in=AssistedTaggingModelVersion.objects.order_by().values('model').annotate( - max_version=models.Max('version'), - ).annotate( + return ( + AssistedTaggingModelVersion.objects.annotate( model_with_version=Concat( - models.F('model_id'), models.F('max_version'), + models.F("model_id"), + models.F("version"), output_field=models.CharField(), ) - ).values('model_with_version') - ).order_by('model_with_version') + ) + .filter( + model_with_version__in=AssistedTaggingModelVersion.objects.order_by() + .values("model") + .annotate( + max_version=models.Max("version"), + ) + .annotate( + model_with_version=Concat( + models.F("model_id"), + models.F("max_version"), + output_field=models.CharField(), + ) + ) + .values("model_with_version") + ) + .order_by("model_with_version") + ) class AssistedTaggingModelPredictionTag(models.Model): @@ -73,7 +84,7 @@ class AssistedTaggingModelPredictionTag(models.Model): is_category = models.BooleanField(default=False) is_deprecated = models.BooleanField(default=False) parent_tag = models.ForeignKey( - 'assisted_tagging.AssistedTaggingModelPredictionTag', + "assisted_tagging.AssistedTaggingModelPredictionTag", on_delete=models.PROTECT, null=True, blank=True, @@ -85,18 +96,19 @@ def __str__(self): class DraftEntry(UserResourceCreated): class PredictionStatus(models.IntegerChoices): - PENDING = 0, 'Pending' - STARTED = 1, 'Started' - DONE = 2, 'Done' - SEND_FAILED = 3, 'Send Failed' + PENDING = 0, "Pending" + STARTED = 1, "Started" + DONE = 2, "Done" + SEND_FAILED = 3, "Send Failed" class Type(models.IntegerChoices): - AUTO = 0, 'Auto Extraction' # NLP defiend extraction text - MANUAL = 1, 'Manual Extraction' # manual defined extraction text + AUTO = 0, "Auto Extraction" # NLP defiend extraction text + MANUAL = 1, "Manual Extraction" # manual defined extraction text + page = models.IntegerField(default=0) text_order = models.IntegerField(default=0) - project = models.ForeignKey(Project, on_delete=models.CASCADE, related_name='+') - lead = models.ForeignKey(Lead, on_delete=models.CASCADE, related_name='+') + project = models.ForeignKey(Project, on_delete=models.CASCADE, related_name="+") + lead = models.ForeignKey(Lead, on_delete=models.CASCADE, related_name="+") excerpt = models.TextField() prediction_status = models.SmallIntegerField(choices=PredictionStatus.choices, default=PredictionStatus.PENDING) # After successfull prediction @@ -107,7 +119,7 @@ class Type(models.IntegerChoices): is_discarded = models.BooleanField(default=False) def __str__(self): - return f'{self.id}' + return f"{self.id}" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -122,8 +134,8 @@ def get_existing_draft_entry(cls, project: Project, lead: Lead, excerpt: str) -> excerpt=excerpt, ).first() if ( - already_existing_draft_entry and - not already_existing_draft_entry.predictions.filter( + already_existing_draft_entry + and not already_existing_draft_entry.predictions.filter( ~models.Q(model_version__in=AssistedTaggingModelVersion.get_latest_models_version()), ).exists() ): @@ -135,16 +147,15 @@ def clear_data(self): def save_geo_data(self): from geo.filter_set import GeoAreaGqlFilterSet + geo_values = list( - AssistedTaggingPrediction.objects - .filter( - draft_entry=self, - model_version__model__model_id=AssistedTaggingModel.ModelID.GEO.value - ).values_list('value', flat=True) + AssistedTaggingPrediction.objects.filter( + draft_entry=self, model_version__model__model_id=AssistedTaggingModel.ModelID.GEO.value + ).values_list("value", flat=True) ) if geo_values: geo_areas_qs = GeoAreaGqlFilterSet( - data={'titles': geo_values}, + data={"titles": geo_values}, queryset=GeoArea.get_for_project(self.project), ).qs self.related_geoareas.set(geo_areas_qs) @@ -152,13 +163,13 @@ def save_geo_data(self): class AssistedTaggingPrediction(models.Model): class DataType(models.IntegerChoices): - RAW = 0, 'Raw' # data is stored in value - TAG = 1, 'Tag' # data is stored in category + tag + RAW = 0, "Raw" # data is stored in value + TAG = 1, "Tag" # data is stored in category + tag data_type = models.SmallIntegerField(choices=DataType.choices) - model_version = models.ForeignKey(AssistedTaggingModelVersion, on_delete=models.CASCADE, related_name='+') - draft_entry = models.ForeignKey(DraftEntry, on_delete=models.CASCADE, related_name='predictions') + model_version = models.ForeignKey(AssistedTaggingModelVersion, on_delete=models.CASCADE, related_name="+") + draft_entry = models.ForeignKey(DraftEntry, on_delete=models.CASCADE, related_name="predictions") # For RAW DataType value = models.CharField(max_length=255, blank=True) # For Tag DataType @@ -193,7 +204,7 @@ class WrongPredictionReview(UserResource): prediction = models.ForeignKey( AssistedTaggingPrediction, on_delete=models.CASCADE, - related_name='wrong_prediction_reviews', + related_name="wrong_prediction_reviews", ) client_id = None # Removing field from UserResource @@ -204,7 +215,7 @@ def __str__(self): class MissingPredictionReview(UserResource): - draft_entry = models.ForeignKey(DraftEntry, on_delete=models.CASCADE, related_name='missing_prediction_reviews') + draft_entry = models.ForeignKey(DraftEntry, on_delete=models.CASCADE, related_name="missing_prediction_reviews") category = models.ForeignKey(AssistedTaggingModelPredictionTag, on_delete=models.CASCADE, related_name="+") tag = models.ForeignKey(AssistedTaggingModelPredictionTag, on_delete=models.CASCADE, related_name="+") client_id = None # Removing field from UserResource diff --git a/apps/assisted_tagging/mutation.py b/apps/assisted_tagging/mutation.py index abfd134c37..a9f4d9830e 100644 --- a/apps/assisted_tagging/mutation.py +++ b/apps/assisted_tagging/mutation.py @@ -1,18 +1,14 @@ import graphene +from deep.permissions import ProjectPermissions as PP from utils.graphene.mutation import ( - generate_input_type_for_serializer, - PsGrapheneMutation, PsDeleteMutation, - mutation_is_not_valid + PsGrapheneMutation, + generate_input_type_for_serializer, + mutation_is_not_valid, ) -from deep.permissions import ProjectPermissions as PP -from .models import ( - DraftEntry, - MissingPredictionReview, - WrongPredictionReview, -) +from .models import DraftEntry, MissingPredictionReview, WrongPredictionReview from .schema import ( DraftEntryType, MissingPredictionReviewType, @@ -20,42 +16,40 @@ ) from .serializers import ( DraftEntryGqlSerializer, - WrongPredictionReviewGqlSerializer, MissingPredictionReviewGqlSerializer, TriggerDraftEntryGqlSerializer, - UpdateDraftEntrySerializer + UpdateDraftEntrySerializer, + WrongPredictionReviewGqlSerializer, ) - DraftEntryInputType = generate_input_type_for_serializer( - 'DraftEntryInputType', + "DraftEntryInputType", serializer_class=DraftEntryGqlSerializer, ) WrongPredictionReviewInputType = generate_input_type_for_serializer( - 'WrongPredictionReviewInputType', + "WrongPredictionReviewInputType", serializer_class=WrongPredictionReviewGqlSerializer, ) MissingPredictionReviewInputType = generate_input_type_for_serializer( - 'MissingPredictionReviewInputType', + "MissingPredictionReviewInputType", serializer_class=MissingPredictionReviewGqlSerializer, ) TriggerAutoDraftEntryInputType = generate_input_type_for_serializer( - "TriggerAutoDraftEntryInputType", - serializer_class=TriggerDraftEntryGqlSerializer + "TriggerAutoDraftEntryInputType", serializer_class=TriggerDraftEntryGqlSerializer ) UpdateDraftEntryInputType = generate_input_type_for_serializer( - "UpdateDraftEntryInputType", - serializer_class=UpdateDraftEntrySerializer + "UpdateDraftEntryInputType", serializer_class=UpdateDraftEntrySerializer ) class CreateDraftEntry(PsGrapheneMutation): class Arguments: data = DraftEntryInputType(required=True) + model = DraftEntry serializer_class = DraftEntryGqlSerializer result = graphene.Field(DraftEntryType) @@ -65,6 +59,7 @@ class Arguments: class CreateMissingPredictionReview(PsGrapheneMutation): class Arguments: data = MissingPredictionReviewInputType(required=True) + model = MissingPredictionReview serializer_class = MissingPredictionReviewGqlSerializer result = graphene.Field(MissingPredictionReviewType) @@ -74,6 +69,7 @@ class Arguments: class CreateWrongPredictionReview(PsGrapheneMutation): class Arguments: data = WrongPredictionReviewInputType(required=True) + model = WrongPredictionReview serializer_class = WrongPredictionReviewGqlSerializer result = graphene.Field(MissingPredictionReviewType) @@ -83,6 +79,7 @@ class Arguments: class DeleteMissingPredictionReview(PsDeleteMutation): class Arguments: id = graphene.ID(required=True) + model = MissingPredictionReview result = graphene.Field(MissingPredictionReviewType) permissions = [PP.Permission.CREATE_ENTRY] @@ -98,6 +95,7 @@ def filter_queryset(cls, qs, info): class DeleteWrongPredictionReview(PsDeleteMutation): class Arguments: id = graphene.ID(required=True) + model = WrongPredictionReview result = graphene.Field(WrongPredictionReviewType) permissions = [PP.Permission.CREATE_ENTRY] @@ -109,20 +107,22 @@ def filter_queryset(cls, qs, info): created_by=info.context.user, ) + # auto draft_entry_create class TriggerAutoDraftEntry(PsGrapheneMutation): class Arguments: data = TriggerAutoDraftEntryInputType(required=True) + model = DraftEntry serializer_class = TriggerDraftEntryGqlSerializer permissions = [PP.Permission.CREATE_ENTRY] @classmethod def perform_mutate(cls, root, info, **kwargs): - data = kwargs['data'] - serializer = cls.serializer_class(data=data, context={'request': info.context.request}) + data = kwargs["data"] + serializer = cls.serializer_class(data=data, context={"request": info.context.request}) if errors := mutation_is_not_valid(serializer): return cls(errors=errors, ok=False) serializer.save() @@ -133,6 +133,7 @@ class UpdateDraftEntry(PsGrapheneMutation): class Arguments: data = UpdateDraftEntryInputType(required=True) id = graphene.ID(required=True) + model = DraftEntry serializer_class = UpdateDraftEntrySerializer result = graphene.Field(DraftEntryType) diff --git a/apps/assisted_tagging/schema.py b/apps/assisted_tagging/schema.py index 8bc4265035..5c62672291 100644 --- a/apps/assisted_tagging/schema.py +++ b/apps/assisted_tagging/schema.py @@ -1,32 +1,27 @@ import graphene +from assisted_tagging.filters import DraftEntryFilterSet +from django.db.models import Prefetch +from geo.schema import ProjectGeoAreaType from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField -from django.db.models import Prefetch -from assisted_tagging.filters import DraftEntryFilterSet - -from utils.graphene.enums import EnumDescription from user_resource.schema import UserResourceMixin -from deep.permissions import ProjectPermissions as PP -from geo.schema import ( - ProjectGeoAreaType, -) +from deep.permissions import ProjectPermissions as PP +from utils.graphene.enums import EnumDescription from utils.graphene.fields import DjangoPaginatedListObjectField from utils.graphene.pagination import NoOrderingPageGraphqlPagination from utils.graphene.types import CustomDjangoListObjectType + +from .enums import AssistedTaggingPredictionDataTypeEnum, DraftEntryPredictionStatusEnum from .models import ( - DraftEntry, AssistedTaggingModel, - AssistedTaggingModelVersion, AssistedTaggingModelPredictionTag, + AssistedTaggingModelVersion, AssistedTaggingPrediction, + DraftEntry, MissingPredictionReview, WrongPredictionReview, ) -from .enums import ( - DraftEntryPredictionStatusEnum, - AssistedTaggingPredictionDataTypeEnum, -) # -- Global Level @@ -34,42 +29,40 @@ class AssistedTaggingModelVersionType(DjangoObjectType): class Meta: model = AssistedTaggingModelVersion only_fields = ( - 'id', - 'version', + "id", + "version", ) class AssistedTaggingModelType(DjangoObjectType): - versions = graphene.List( - graphene.NonNull(AssistedTaggingModelVersionType) - ) + versions = graphene.List(graphene.NonNull(AssistedTaggingModelVersionType)) class Meta: model = AssistedTaggingModel only_fields = ( - 'id', - 'name', - 'model_id', + "id", + "name", + "model_id", ) @staticmethod def resolve_versions(root, info, **kwargs): - return root.versions.all() # NOTE: Prefetched + return root.versions.all() # NOTE: Prefetched class AssistedTaggingModelPredictionTagType(DjangoObjectType): - parent_tag = graphene.ID(source='parent_tag_id') + parent_tag = graphene.ID(source="parent_tag_id") class Meta: model = AssistedTaggingModelPredictionTag only_fields = ( - 'id', - 'name', - 'group', - 'tag_id', - 'is_category', - 'is_deprecated', - 'hide_in_analysis_framework_mapping', + "id", + "name", + "group", + "tag_id", + "is_category", + "is_deprecated", + "hide_in_analysis_framework_mapping", ) @@ -80,16 +73,14 @@ class AssistedTaggingRootQueryType(graphene.ObjectType): ) prediction_tag = DjangoObjectField(AssistedTaggingModelPredictionTagType) - prediction_tags = graphene.List( - graphene.NonNull(AssistedTaggingModelPredictionTagType) - ) + prediction_tags = graphene.List(graphene.NonNull(AssistedTaggingModelPredictionTagType)) @staticmethod def resolve_tagging_models(root, info, **kwargs): return AssistedTaggingModel.objects.prefetch_related( Prefetch( - 'versions', - queryset=AssistedTaggingModelVersion.objects.order_by('-version'), + "versions", + queryset=AssistedTaggingModelVersion.objects.order_by("-version"), ), ).all() @@ -100,7 +91,7 @@ def resolve_prediction_tags(root, info, **kwargs): # -- Project Level def get_draft_entry_qs(info): # TODO use dataloader - qs = DraftEntry.objects.filter(project=info.context.active_project).order_by('page', 'text_order') + qs = DraftEntry.objects.filter(project=info.context.active_project).order_by("page", "text_order") if PP.check_permission(info, PP.Permission.VIEW_ENTRY): return qs return qs.none() @@ -114,65 +105,58 @@ def get_draft_entry_with_filter_qs(info, filters): class WrongPredictionReviewType(UserResourceMixin, DjangoObjectType): - prediction = graphene.ID(source='prediction_id', required=True) + prediction = graphene.ID(source="prediction_id", required=True) class Meta: model = WrongPredictionReview - only_fields = ( - 'id', - ) + only_fields = ("id",) class AssistedTaggingPredictionType(DjangoObjectType): - model_version = graphene.ID(source='model_version_id', required=True) - draft_entry = graphene.ID(source='draft_entry_id', required=True) + model_version = graphene.ID(source="model_version_id", required=True) + draft_entry = graphene.ID(source="draft_entry_id", required=True) data_type = graphene.Field(AssistedTaggingPredictionDataTypeEnum, required=True) - data_type_display = EnumDescription(source='get_data_type_display', required=True) - category = graphene.ID(source='category_id') - tag = graphene.ID(source='tag_id') + data_type_display = EnumDescription(source="get_data_type_display", required=True) + category = graphene.ID(source="category_id") + tag = graphene.ID(source="tag_id") class Meta: model = AssistedTaggingPrediction only_fields = ( - 'id', - 'value', - 'prediction', - 'threshold', - 'is_selected', + "id", + "value", + "prediction", + "threshold", + "is_selected", ) - ''' + + """ NOTE: model_version_deepl_model_id and wrong_prediction_review are not included here because they are not used in client - ''' + """ class MissingPredictionReviewType(UserResourceMixin, DjangoObjectType): - category = graphene.ID(source='category_id', required=True) - tag = graphene.ID(source='tag_id', required=True) - draft_entry = graphene.ID(source='draft_entry_id', required=True) + category = graphene.ID(source="category_id", required=True) + tag = graphene.ID(source="tag_id", required=True) + draft_entry = graphene.ID(source="draft_entry_id", required=True) class Meta: model = MissingPredictionReview - only_fields = ( - 'id', - ) + only_fields = ("id",) class DraftEntryType(DjangoObjectType): prediction_status = graphene.Field(DraftEntryPredictionStatusEnum, required=True) - prediction_status_display = EnumDescription(source='get_prediction_status_display', required=True) - prediction_tags = graphene.List( - graphene.NonNull(AssistedTaggingPredictionType) - ) - geo_areas = graphene.List( - graphene.NonNull(ProjectGeoAreaType) - ) + prediction_status_display = EnumDescription(source="get_prediction_status_display", required=True) + prediction_tags = graphene.List(graphene.NonNull(AssistedTaggingPredictionType)) + geo_areas = graphene.List(graphene.NonNull(ProjectGeoAreaType)) class Meta: model = DraftEntry only_fields = ( - 'id', - 'excerpt', - 'prediction_received_at', + "id", + "excerpt", + "prediction_received_at", ) @staticmethod @@ -202,7 +186,7 @@ class AssistedTaggingQueryType(graphene.ObjectType): draft_entries = DjangoPaginatedListObjectField( DraftEntryListType, pagination=NoOrderingPageGraphqlPagination( - page_size_query_param='pageSize', + page_size_query_param="pageSize", ), ) diff --git a/apps/assisted_tagging/serializers.py b/apps/assisted_tagging/serializers.py index a0baecaa2b..589fe27c33 100644 --- a/apps/assisted_tagging/serializers.py +++ b/apps/assisted_tagging/serializers.py @@ -1,8 +1,10 @@ +from analysis_framework.models import Widget from django.db import transaction +from lead.models import Lead from rest_framework import serializers -from user_resource.serializers import UserResourceSerializer, UserResourceCreatedMixin +from user_resource.serializers import UserResourceCreatedMixin, UserResourceSerializer + from deep.serializers import ProjectPropertySerializerMixin, TempClientIdMixin -from analysis_framework.models import Widget from .models import ( DraftEntry, @@ -10,10 +12,9 @@ PredictionTagAnalysisFrameworkWidgetMapping, WrongPredictionReview, ) -from lead.models import Lead from .tasks import ( + trigger_request_for_auto_draft_entry_task, trigger_request_for_draft_entry_task, - trigger_request_for_auto_draft_entry_task ) @@ -21,10 +22,10 @@ class DraftEntryBaseSerializer(serializers.Serializer): def validate_lead(self, lead): if lead.project != self.project: - raise serializers.ValidationError('Only lead from current project are allowed.') + raise serializers.ValidationError("Only lead from current project are allowed.") af = lead.project.analysis_framework if af is None or not af.assisted_tagging_enabled: - raise serializers.ValidationError('Assisted tagging is disabled for the Framework used by this project.') + raise serializers.ValidationError("Assisted tagging is disabled for the Framework used by this project.") return lead @@ -37,26 +38,24 @@ class DraftEntryGqlSerializer( class Meta: model = DraftEntry fields = ( - 'lead', - 'excerpt', + "lead", + "excerpt", ) def create(self, data): # Use already existing draft entry if found - project = data['lead'].project + project = data["lead"].project already_existing_draft_entry = DraftEntry.get_existing_draft_entry( project, - data['lead'], - data['excerpt'], + data["lead"], + data["excerpt"], ) if already_existing_draft_entry: return already_existing_draft_entry # Create new one and send trigger to deepl. - data['project'] = project + data["project"] = project instance = super().create(data) - transaction.on_commit( - lambda: trigger_request_for_draft_entry_task.delay(instance.pk) - ) + transaction.on_commit(lambda: trigger_request_for_draft_entry_task.delay(instance.pk)) return instance def update(self, *_): @@ -66,18 +65,16 @@ def update(self, *_): class WrongPredictionReviewGqlSerializer(UserResourceSerializer, serializers.ModelSerializer): class Meta: model = WrongPredictionReview - fields = ( - 'prediction', - ) + fields = ("prediction",) def validate_prediction(self, prediction): - if prediction.draft_entry.project != self.context['request'].active_project: - raise serializers.ValidationError('Prediction not part of the active project.') + if prediction.draft_entry.project != self.context["request"].active_project: + raise serializers.ValidationError("Prediction not part of the active project.") return prediction def validate(self, data): - if self.instance and self.instance.created_by != self.context['request'].user: - raise serializers.ValidationError('Only reviewer can edit this review') + if self.instance and self.instance.created_by != self.context["request"].user: + raise serializers.ValidationError("Only reviewer can edit this review") return data @@ -85,19 +82,19 @@ class MissingPredictionReviewGqlSerializer(UserResourceSerializer): class Meta: model = MissingPredictionReview fields = ( - 'draft_entry', - 'tag', - 'category', + "draft_entry", + "tag", + "category", ) def validate_draft_entry(self, draft_entry): - if draft_entry.project != self.context['request'].active_project: - raise serializers.ValidationError('Draft Entry not part of the active project.') + if draft_entry.project != self.context["request"].active_project: + raise serializers.ValidationError("Draft Entry not part of the active project.") return draft_entry def validate(self, data): - if self.instance and self.instance.created_by != self.context['request'].user: - raise serializers.ValidationError('Only reviewer can edit this review') + if self.instance and self.instance.created_by != self.context["request"].user: + raise serializers.ValidationError("Only reviewer can edit this review") return data @@ -110,26 +107,22 @@ class PredictionTagAnalysisFrameworkMapSerializer(TempClientIdMixin, serializers class Meta: model = PredictionTagAnalysisFrameworkWidgetMapping fields = ( - 'id', - 'widget', - 'tag', - 'association', - 'client_id', # From TempClientIdMixin + "id", + "widget", + "tag", + "association", + "client_id", # From TempClientIdMixin ) def validate(self, data): - tag = data.get('tag', self.instance and self.instance.tag) - association = data.get('association', self.instance and self.instance.association) - widget = data.get('widget', self.instance and self.instance.widget) + tag = data.get("tag", self.instance and self.instance.tag) + association = data.get("association", self.instance and self.instance.association) + widget = data.get("widget", self.instance and self.instance.widget) skip_tag = widget.widget_id in self.TAG_NOT_REQUIRED_FOR_WIDGET_TYPE if tag is None and not skip_tag: - raise serializers.ValidationError(dict( - tag='Tag is required for this widget.' - )) + raise serializers.ValidationError(dict(tag="Tag is required for this widget.")) if association is None and not skip_tag: - raise serializers.ValidationError(dict( - association='Association is required for this widget.' - )) + raise serializers.ValidationError(dict(association="Association is required for this widget.")) return data @@ -141,28 +134,24 @@ class TriggerDraftEntryGqlSerializer( ): class Meta: model = DraftEntry - fields = ( - 'lead', - ) + fields = ("lead",) def create(self, data): - lead = data['lead'] + lead = data["lead"] if lead.leadpreview.text_extraction_id is None: raise serializers.DjangoValidationError("Assisted tagging is not available in old lead") if lead.auto_entry_extraction_status == Lead.AutoExtractionStatus.SUCCESS: raise serializers.DjangoValidationError("Already Triggered") if not lead.leadpreview.text_extract: - raise serializers.DjangoValidationError('Simplifed Text is empty') + raise serializers.DjangoValidationError("Simplifed Text is empty") draft_entry_qs = DraftEntry.objects.filter(lead=lead, type=DraftEntry.Type.AUTO) if draft_entry_qs.exists(): - raise serializers.DjangoValidationError('Draft entry already exists') + raise serializers.DjangoValidationError("Draft entry already exists") # Use already existing draft entry if found # Create new one and send trigger to deepl lead.auto_entry_extraction_status = Lead.AutoExtractionStatus.PENDING - lead.save(update_fields=['auto_entry_extraction_status']) - transaction.on_commit( - lambda: trigger_request_for_auto_draft_entry_task.delay(lead.id) - ) + lead.save(update_fields=["auto_entry_extraction_status"]) + transaction.on_commit(lambda: trigger_request_for_auto_draft_entry_task.delay(lead.id)) return True def update(self, instance, validate_data): @@ -170,13 +159,13 @@ def update(self, instance, validate_data): class UpdateDraftEntrySerializer( - DraftEntryBaseSerializer, ProjectPropertySerializerMixin, UserResourceSerializer, serializers.ModelSerializer + DraftEntryBaseSerializer, ProjectPropertySerializerMixin, UserResourceSerializer, serializers.ModelSerializer ): class Meta: model = DraftEntry fields = ( - 'lead', - 'is_discarded', + "lead", + "is_discarded", ) def create(self, _): diff --git a/apps/assisted_tagging/tasks.py b/apps/assisted_tagging/tasks.py index c809c30bd4..36f1bd6707 100644 --- a/apps/assisted_tagging/tasks.py +++ b/apps/assisted_tagging/tasks.py @@ -1,34 +1,30 @@ import logging -import requests +import requests from celery import shared_task -from lead.models import Lead - -from utils.common import redis_lock -from deep.deepl import DeeplServiceEndpoint from deepl_integration.handlers import ( AssistedTaggingDraftEntryHandler, AutoAssistedTaggingDraftEntryHandler, - BaseHandler as DeepHandler ) +from deepl_integration.handlers import BaseHandler as DeepHandler +from lead.models import Lead + +from deep.deepl import DeeplServiceEndpoint +from utils.common import redis_lock from .models import ( - DraftEntry, AssistedTaggingModel, - AssistedTaggingModelVersion, AssistedTaggingModelPredictionTag, + AssistedTaggingModelVersion, + DraftEntry, ) - logger = logging.getLogger(__name__) def sync_tags_with_deepl(): def _get_existing_tags_by_tagid(): - return { - tag.tag_id: tag # tag_id is from deepl - for tag in AssistedTaggingModelPredictionTag.objects.all() - } + return {tag.tag_id: tag for tag in AssistedTaggingModelPredictionTag.objects.all()} # tag_id is from deepl response = requests.get(DeeplServiceEndpoint.ASSISTED_TAGGING_TAGS_ENDPOINT, headers=DeepHandler.REQUEST_HEADERS).json() existing_tags_by_tagid = _get_existing_tags_by_tagid() @@ -37,10 +33,10 @@ def _get_existing_tags_by_tagid(): updated_tags = [] for tag_id, tag_meta in response.items(): assisted_tag = existing_tags_by_tagid.get(tag_id, AssistedTaggingModelPredictionTag()) - assisted_tag.name = tag_meta['label'] - assisted_tag.group = tag_meta.get('group') - assisted_tag.is_category = tag_meta['is_category'] - assisted_tag.hide_in_analysis_framework_mapping = tag_meta['hide_in_analysis_framework_mapping'] + assisted_tag.name = tag_meta["label"] + assisted_tag.group = tag_meta.get("group") + assisted_tag.is_category = tag_meta["is_category"] + assisted_tag.hide_in_analysis_framework_mapping = tag_meta["hide_in_analysis_framework_mapping"] if assisted_tag.pk: updated_tags.append(assisted_tag) else: @@ -53,60 +49,57 @@ def _get_existing_tags_by_tagid(): AssistedTaggingModelPredictionTag.objects.bulk_update( updated_tags, fields=( - 'name', - 'group', - 'is_category', - 'hide_in_analysis_framework_mapping', - ) + "name", + "group", + "is_category", + "hide_in_analysis_framework_mapping", + ), ) # For parent relation updated_tags = [] existing_tags_by_tagid = _get_existing_tags_by_tagid() for tag_id, tag_meta in response.items(): - if tag_meta.get('parent_id') is None: + if tag_meta.get("parent_id") is None: continue assisted_tag = existing_tags_by_tagid[tag_id] - parent_tag = existing_tags_by_tagid[tag_meta['parent_id']] + parent_tag = existing_tags_by_tagid[tag_meta["parent_id"]] if parent_tag.pk == assisted_tag.parent_tag_id: continue assisted_tag.parent_tag = parent_tag updated_tags.append(assisted_tag) if updated_tags: - AssistedTaggingModelPredictionTag.objects.bulk_update( - updated_tags, - fields=('parent_tag',) - ) + AssistedTaggingModelPredictionTag.objects.bulk_update(updated_tags, fields=("parent_tag",)) def sync_models_with_deepl(): models_data = requests.get(DeeplServiceEndpoint.ASSISTED_TAGGING_MODELS_ENDPOINT).json() for model_meta in models_data.values(): assisted_model, _ = AssistedTaggingModel.objects.get_or_create( - model_id=model_meta['id'], + model_id=model_meta["id"], ) AssistedTaggingModelVersion.objects.get_or_create( model=assisted_model, - version=model_meta['version'], + version=model_meta["version"], ) @shared_task -@redis_lock('trigger_request_for_draft_entry_task_{0}', 60 * 60 * 0.5) +@redis_lock("trigger_request_for_draft_entry_task_{0}", 60 * 60 * 0.5) def trigger_request_for_draft_entry_task(draft_entry_id): draft_entry = DraftEntry.objects.get(pk=draft_entry_id) return AssistedTaggingDraftEntryHandler.send_trigger_request_to_extractor(draft_entry) @shared_task -@redis_lock('trigger_request_for_auto_draft_entry_task_{0}', 60 * 60 * 0.5) +@redis_lock("trigger_request_for_auto_draft_entry_task_{0}", 60 * 60 * 0.5) def trigger_request_for_auto_draft_entry_task(lead_id): lead = Lead.objects.get(id=lead_id) return AutoAssistedTaggingDraftEntryHandler.auto_trigger_request_to_extractor(lead) @shared_task -@redis_lock('sync_tags_with_deepl_task', 60 * 60 * 0.5) +@redis_lock("sync_tags_with_deepl_task", 60 * 60 * 0.5) def sync_tags_with_deepl_task(): return ( sync_tags_with_deepl(), diff --git a/apps/assisted_tagging/tests/test_query.py b/apps/assisted_tagging/tests/test_query.py index 19601d9fa2..492af0e2fb 100644 --- a/apps/assisted_tagging/tests/test_query.py +++ b/apps/assisted_tagging/tests/test_query.py @@ -1,41 +1,36 @@ from unittest.mock import patch -from snapshottest.django import TestCase as SnapShotTextCase - -from utils.graphene.tests import GraphQLTestCase -from deep.tests import TestCase -from assisted_tagging.models import ( - AssistedTaggingPrediction, +from assisted_tagging.factories import ( + AssistedTaggingModelFactory, + AssistedTaggingModelPredictionTagFactory, + AssistedTaggingModelVersionFactory, + AssistedTaggingPredictionFactory, + DraftEntryFactory, + MissingPredictionReviewFactory, ) - -from deepl_integration.handlers import AssistedTaggingDraftEntryHandler -from assisted_tagging.tasks import sync_tags_with_deepl from assisted_tagging.models import ( AssistedTaggingModel, - AssistedTaggingModelVersion, AssistedTaggingModelPredictionTag, + AssistedTaggingModelVersion, + AssistedTaggingPrediction, DraftEntry, ) - +from assisted_tagging.tasks import sync_tags_with_deepl +from deepl_integration.handlers import AssistedTaggingDraftEntryHandler +from geo.factories import AdminLevelFactory, GeoAreaFactory, RegionFactory from lead.factories import LeadFactory -from user.factories import UserFactory from project.factories import ProjectFactory -from geo.factories import RegionFactory, AdminLevelFactory, GeoAreaFactory +from snapshottest.django import TestCase as SnapShotTextCase +from user.factories import UserFactory -from assisted_tagging.factories import ( - AssistedTaggingModelFactory, - AssistedTaggingModelPredictionTagFactory, - AssistedTaggingModelVersionFactory, - DraftEntryFactory, - AssistedTaggingPredictionFactory, - MissingPredictionReviewFactory, -) +from deep.tests import TestCase +from utils.graphene.tests import GraphQLTestCase class TestAssistedTaggingQuery(GraphQLTestCase): ENABLE_NOW_PATCHER = True - ASSISTED_TAGGING_NLP_DATA = ''' + ASSISTED_TAGGING_NLP_DATA = """ query MyQuery ($taggingModelId: ID!, $predictionTag: ID!) { assistedTagging { predictionTags { @@ -76,9 +71,9 @@ class TestAssistedTaggingQuery(GraphQLTestCase): } } } - ''' + """ - ASSISTED_TAGGING_DRAFT_ENTRY = ''' + ASSISTED_TAGGING_DRAFT_ENTRY = """ query MyQuery ($projectId: ID!, $draftEntryId: ID!) { project(id: $projectId) { assistedTagging { @@ -104,7 +99,7 @@ class TestAssistedTaggingQuery(GraphQLTestCase): } } } - ''' + """ def test_unified_connector_nlp_data(self): user = UserFactory.create() @@ -130,57 +125,69 @@ def test_unified_connector_nlp_data(self): variables=dict( taggingModelId=model1.id, predictionTag=tag1.id, - ) - )['data']['assistedTagging'] - self.assertEqual(content['predictionTags'], [ + ), + )["data"]["assistedTagging"] + self.assertEqual( + content["predictionTags"], + [ + dict( + id=str(tag.id), + tagId=tag.tag_id, + isDeprecated=tag.is_deprecated, + isCategory=tag.is_category, + group=tag.group, + hideInAnalysisFrameworkMapping=tag.hide_in_analysis_framework_mapping, + parentTag=tag.parent_tag_id and str(tag.parent_tag_id), + ) + for tag in [tag1, *other_tags] + ], + ) + self.assertEqual( + content["predictionTag"], dict( - id=str(tag.id), - tagId=tag.tag_id, - isDeprecated=tag.is_deprecated, - isCategory=tag.is_category, - group=tag.group, - hideInAnalysisFrameworkMapping=tag.hide_in_analysis_framework_mapping, - parentTag=tag.parent_tag_id and str(tag.parent_tag_id), - ) - for tag in [tag1, *other_tags] - ]) - self.assertEqual(content['predictionTag'], dict( - id=str(tag1.id), - tagId=tag1.tag_id, - isDeprecated=tag1.is_deprecated, - isCategory=tag1.is_category, - group=tag1.group, - hideInAnalysisFrameworkMapping=tag1.hide_in_analysis_framework_mapping, - parentTag=tag1.parent_tag_id and str(tag1.parent_tag_id), - )) + id=str(tag1.id), + tagId=tag1.tag_id, + isDeprecated=tag1.is_deprecated, + isCategory=tag1.is_category, + group=tag1.group, + hideInAnalysisFrameworkMapping=tag1.hide_in_analysis_framework_mapping, + parentTag=tag1.parent_tag_id and str(tag1.parent_tag_id), + ), + ) - self.assertEqual(content['taggingModels'], [ + self.assertEqual( + content["taggingModels"], + [ + dict( + id=str(_model.id), + modelId=_model.model_id, + name=_model.name, + versions=[ + dict( + id=str(model_version.id), + version=str(model_version.version), + ) + for model_version in _model.versions.order_by("-version").all() + ], + ) + for _model in [model1, *other_models] + ], + ) + self.assertEqual( + content["taggingModel"], dict( - id=str(_model.id), - modelId=_model.model_id, - name=_model.name, + id=str(model1.id), + modelId=model1.model_id, + name=model1.name, versions=[ dict( id=str(model_version.id), version=str(model_version.version), ) - for model_version in _model.versions.order_by('-version').all() + for model_version in model1.versions.all() ], - ) - for _model in [model1, *other_models] - ]) - self.assertEqual(content['taggingModel'], dict( - id=str(model1.id), - modelId=model1.model_id, - name=model1.name, - versions=[ - dict( - id=str(model_version.id), - version=str(model_version.version), - ) - for model_version in model1.versions.all() - ], - )) + ), + ) def test_unified_connector_draft_entry(self): project = ProjectFactory.create() @@ -193,16 +200,16 @@ def test_unified_connector_draft_entry(self): project.regions.add(region) self.maxDiff = None - GeoAreaFactory.create(admin_level=admin_level, title='Nepal') - GeoAreaFactory.create(admin_level=admin_level, title='Bagmati') - GeoAreaFactory.create(admin_level=admin_level, title='Kathmandu') + GeoAreaFactory.create(admin_level=admin_level, title="Nepal") + GeoAreaFactory.create(admin_level=admin_level, title="Bagmati") + GeoAreaFactory.create(admin_level=admin_level, title="Kathmandu") model1 = AssistedTaggingModelFactory.create() geo_model = AssistedTaggingModelFactory.create(model_id=AssistedTaggingModel.ModelID.GEO) latest_model1_version = AssistedTaggingModelVersionFactory.create_batch(2, model=model1)[0] latest_geo_model_version = AssistedTaggingModelVersionFactory.create(model=geo_model) category1, tag1, *other_tags = AssistedTaggingModelPredictionTagFactory.create_batch(5) - draft_entry1 = DraftEntryFactory.create(project=project, lead=lead, excerpt='sample excerpt') + draft_entry1 = DraftEntryFactory.create(project=project, lead=lead, excerpt="sample excerpt") prediction1 = AssistedTaggingPredictionFactory.create( data_type=AssistedTaggingPrediction.DataType.TAG, @@ -218,14 +225,14 @@ def test_unified_connector_draft_entry(self): data_type=AssistedTaggingPrediction.DataType.RAW, model_version=latest_geo_model_version, draft_entry=draft_entry1, - value='Nepal', + value="Nepal", is_selected=True, ) prediction3 = AssistedTaggingPredictionFactory.create( data_type=AssistedTaggingPrediction.DataType.RAW, model_version=latest_geo_model_version, draft_entry=draft_entry1, - value='Kathmandu', + value="Kathmandu", is_selected=True, ) draft_entry1.save_geo_data() @@ -246,56 +253,58 @@ def _query_check(**kwargs): # -- with login (non-member) self.force_login(another_user) content = _query_check() - self.assertIsNone(content['data']['project']['assistedTagging']) + self.assertIsNone(content["data"]["project"]["assistedTagging"]) # -- with login (member) self.force_login(user) - content = _query_check()['data']['project']['assistedTagging']['draftEntry'] - self.assertEqual(content, dict( - id=str(draft_entry1.pk), - excerpt=draft_entry1.excerpt, - predictionReceivedAt=None, - predictionStatus=self.genum(draft_entry1.prediction_status), - predictionStatusDisplay=draft_entry1.get_prediction_status_display(), - predictionTags=[ - dict( - id=str(prediction1.pk), - modelVersion=str(prediction1.model_version_id), - dataType=self.genum(prediction1.data_type), - dataTypeDisplay=prediction1.get_data_type_display(), - value='', - category=str(prediction1.category_id), - tag=str(prediction1.tag_id), - ), - dict( - id=str(prediction2.id), - modelVersion=str(prediction2.model_version.id), - dataType=self.genum(prediction2.data_type), - dataTypeDisplay=prediction2.get_data_type_display(), - value=prediction2.value, - category=None, - tag=None, - ), - dict( - id=str(prediction3.id), - modelVersion=str(prediction3.model_version.id), - dataType=self.genum(prediction3.data_type), - dataTypeDisplay=prediction3.get_data_type_display(), - value=prediction3.value, - category=None, - tag=None, - ) - ], - geoAreas=[ - dict( - title='Nepal', - ), - dict( - title='Kathmandu', - ) - - ], - )) + content = _query_check()["data"]["project"]["assistedTagging"]["draftEntry"] + self.assertEqual( + content, + dict( + id=str(draft_entry1.pk), + excerpt=draft_entry1.excerpt, + predictionReceivedAt=None, + predictionStatus=self.genum(draft_entry1.prediction_status), + predictionStatusDisplay=draft_entry1.get_prediction_status_display(), + predictionTags=[ + dict( + id=str(prediction1.pk), + modelVersion=str(prediction1.model_version_id), + dataType=self.genum(prediction1.data_type), + dataTypeDisplay=prediction1.get_data_type_display(), + value="", + category=str(prediction1.category_id), + tag=str(prediction1.tag_id), + ), + dict( + id=str(prediction2.id), + modelVersion=str(prediction2.model_version.id), + dataType=self.genum(prediction2.data_type), + dataTypeDisplay=prediction2.get_data_type_display(), + value=prediction2.value, + category=None, + tag=None, + ), + dict( + id=str(prediction3.id), + modelVersion=str(prediction3.model_version.id), + dataType=self.genum(prediction3.data_type), + dataTypeDisplay=prediction3.get_data_type_display(), + value=prediction3.value, + category=None, + tag=None, + ), + ], + geoAreas=[ + dict( + title="Nepal", + ), + dict( + title="Kathmandu", + ), + ], + ), + ) class AssistedTaggingCallbackApiTest(TestCase, SnapShotTextCase): @@ -315,934 +324,743 @@ class AssistedTaggingCallbackApiTest(TestCase, SnapShotTextCase): "client_id": "random-client-id", "model_tags": { "1": { - "101": { - "prediction": 0.002, - "threshold": 0.14, - "is_selected": False - }, - "102": { - "prediction": 0.648, - "threshold": 0.17, - "is_selected": True - }, - "103": { - "prediction": 0.027, - "threshold": 0.1, - "is_selected": False - }, - "104": { - "prediction": 0.062, - "threshold": 0.14, - "is_selected": False - } + "101": {"prediction": 0.002, "threshold": 0.14, "is_selected": False}, + "102": {"prediction": 0.648, "threshold": 0.17, "is_selected": True}, + "103": {"prediction": 0.027, "threshold": 0.1, "is_selected": False}, + "104": {"prediction": 0.062, "threshold": 0.14, "is_selected": False}, }, "3": { - "301": { - "prediction": 0.001, - "threshold": 0.01, - "is_selected": False - }, - "302": { - "prediction": 0.001, - "threshold": 0.11, - "is_selected": False - }, - "303": { - "prediction": 0.083, - "threshold": 0.38, - "is_selected": False - }, - "304": { - "prediction": 0.086, - "threshold": 0.01, - "is_selected": True - }, - "315": { - "prediction": 0.003, - "threshold": 0.45, - "is_selected": False - }, - "316": { - "prediction": 0.001, - "threshold": 0.06, - "is_selected": False - }, - "317": { - "prediction": 0.004, - "threshold": 0.28, - "is_selected": False - }, - "318": { - "prediction": 0.0, - "threshold": 0.13, - "is_selected": False - } + "301": {"prediction": 0.001, "threshold": 0.01, "is_selected": False}, + "302": {"prediction": 0.001, "threshold": 0.11, "is_selected": False}, + "303": {"prediction": 0.083, "threshold": 0.38, "is_selected": False}, + "304": {"prediction": 0.086, "threshold": 0.01, "is_selected": True}, + "315": {"prediction": 0.003, "threshold": 0.45, "is_selected": False}, + "316": {"prediction": 0.001, "threshold": 0.06, "is_selected": False}, + "317": {"prediction": 0.004, "threshold": 0.28, "is_selected": False}, + "318": {"prediction": 0.0, "threshold": 0.13, "is_selected": False}, }, "2": { - "219": { - "prediction": 0.003, - "threshold": 0.13, - "is_selected": False - }, - "217": { - "prediction": 0.001, - "threshold": 0.04, - "is_selected": False - }, - "218": { - "prediction": 0.004, - "threshold": 0.09, - "is_selected": False - }, - "204": { - "prediction": 0.007, - "threshold": 0.14, - "is_selected": False - }, - "216": { - "prediction": 0.003, - "threshold": 0.13, - "is_selected": False - }, - "214": { - "prediction": 0.001, - "threshold": 0.09, - "is_selected": False - }, - "209": { - "prediction": 0.458, - "threshold": 0.05, - "is_selected": True - } + "219": {"prediction": 0.003, "threshold": 0.13, "is_selected": False}, + "217": {"prediction": 0.001, "threshold": 0.04, "is_selected": False}, + "218": {"prediction": 0.004, "threshold": 0.09, "is_selected": False}, + "204": {"prediction": 0.007, "threshold": 0.14, "is_selected": False}, + "216": {"prediction": 0.003, "threshold": 0.13, "is_selected": False}, + "214": {"prediction": 0.001, "threshold": 0.09, "is_selected": False}, + "209": {"prediction": 0.458, "threshold": 0.05, "is_selected": True}, }, "6": { - "601": { - "prediction": 0.0, - "threshold": 0.06, - "is_selected": False - }, - "602": { - "prediction": 0.001, - "threshold": 0.48, - "is_selected": False - }, - "603": { - "prediction": 0.022, - "threshold": 0.34, - "is_selected": False - }, - "604": { - "prediction": 0.0, - "threshold": 0.16, - "is_selected": False - } + "601": {"prediction": 0.0, "threshold": 0.06, "is_selected": False}, + "602": {"prediction": 0.001, "threshold": 0.48, "is_selected": False}, + "603": {"prediction": 0.022, "threshold": 0.34, "is_selected": False}, + "604": {"prediction": 0.0, "threshold": 0.16, "is_selected": False}, }, "5": { - "501": { - "prediction": 0.0, - "threshold": 0.45, - "is_selected": False - }, - "502": { - "prediction": 0.0, - "threshold": 0.48, - "is_selected": False - } + "501": {"prediction": 0.0, "threshold": 0.45, "is_selected": False}, + "502": {"prediction": 0.0, "threshold": 0.48, "is_selected": False}, }, "8": { - "801": { - "prediction": 0.0, - "threshold": 0.66, - "is_selected": False - }, - "802": { - "prediction": 0.0, - "threshold": 0.3, - "is_selected": False - } + "801": {"prediction": 0.0, "threshold": 0.66, "is_selected": False}, + "802": {"prediction": 0.0, "threshold": 0.3, "is_selected": False}, }, "4": { - "401": { - "prediction": 0.001, - "threshold": 0.29, - "is_selected": False - }, - "402": { - "prediction": 0.001, - "threshold": 0.45, - "is_selected": False - }, - "407": { - "prediction": 0.0, - "threshold": 0.07, - "is_selected": False - }, - "408": { - "prediction": 0.001, - "threshold": 0.11, - "is_selected": False - }, - "412": { - "prediction": 0.0, - "threshold": 0.36, - "is_selected": False - } - }, - "7": { - "701": { - "prediction": 0.008, - "threshold": 0.27, - "is_selected": False - } + "401": {"prediction": 0.001, "threshold": 0.29, "is_selected": False}, + "402": {"prediction": 0.001, "threshold": 0.45, "is_selected": False}, + "407": {"prediction": 0.0, "threshold": 0.07, "is_selected": False}, + "408": {"prediction": 0.001, "threshold": 0.11, "is_selected": False}, + "412": {"prediction": 0.0, "threshold": 0.36, "is_selected": False}, }, + "7": {"701": {"prediction": 0.008, "threshold": 0.27, "is_selected": False}}, "9": { - "904": { - "prediction": -1, - "threshold": -1, - "is_selected": False - }, - "905": { - "prediction": -1, - "threshold": -1, - "is_selected": False - }, - "907": { - "prediction": -1, - "threshold": -1, - "is_selected": False - } - } - }, - "geolocations": [ - "New York" - ], - "model_info": { - "id": "all_tags_model", - "version": "1.0.0" + "904": {"prediction": -1, "threshold": -1, "is_selected": False}, + "905": {"prediction": -1, "threshold": -1, "is_selected": False}, + "907": {"prediction": -1, "threshold": -1, "is_selected": False}, + }, }, - "prediction_status": True + "geolocations": ["New York"], + "model_info": {"id": "all_tags_model", "version": "1.0.0"}, + "prediction_status": True, } DEEPL_TAGS_MOCK_RESPONSE = { - '101': { - 'label': 'Agriculture', - 'group': 'Sectors', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '1', - }, - '102': { - 'label': 'Cross', - 'group': 'Sectors', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '1', - }, - '103': { - 'label': 'Education', - 'group': 'Sectors', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '1', - }, - '104': { - 'label': 'Food Security', - 'group': 'Sectors', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '1', - }, - '201': { - 'label': 'Environment', - 'group': 'Context', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '202': { - 'label': 'Socio Cultural', - 'group': 'Context', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '203': { - 'label': 'Economy', - 'group': 'Context', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '204': { - 'label': 'Demography', - 'group': 'Context', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '205': { - 'label': 'Legal & Policy', - 'group': 'Context', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '206': { - 'label': 'Security & Stability', - 'group': 'Context', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '207': { - 'label': 'Politics', - 'group': 'Context', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '208': { - 'label': 'Type And Characteristics', - 'group': 'Shock/Event', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '209': { - 'label': 'Underlying/Aggravating Factors', - 'group': 'Shock/Event', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '210': { - 'label': 'Hazard & Threats', - 'group': 'Shock/Event', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '212': { - 'label': 'Type/Numbers/Movements', - 'group': 'Displacement', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '213': { - 'label': 'Push Factors', - 'group': 'Displacement', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '214': { - 'label': 'Pull Factors', - 'group': 'Displacement', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '215': { - 'label': 'Intentions', - 'group': 'Displacement', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '216': { - 'label': 'Local Integration', - 'group': 'Displacement', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '217': { - 'label': 'Injured', - 'group': 'Casualties', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '218': { - 'label': 'Missing', - 'group': 'Casualties', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '219': { - 'label': 'Dead', - 'group': 'Casualties', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '220': { - 'label': 'Relief To Population', - 'group': 'Humanitarian Access', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '221': { - 'label': 'Population To Relief', - 'group': 'Humanitarian Access', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '222': { - 'label': 'Physical Constraints', - 'group': 'Humanitarian Access', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '223': { - 'label': 'Number Of People Facing Humanitarian Access Constraints/Humanitarian Access Gaps', - 'group': 'Humanitarian Access', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '224': { - 'label': 'Communication Means And Preferences', - 'group': 'Information And Communication', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '225': { - 'label': 'Information Challenges And Barriers', - 'group': 'Information And Communication', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '226': { - 'label': 'Knowledge And Info Gaps (Pop)', - 'group': 'Information And Communication', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '227': { - 'label': 'Knowledge And Info Gaps (Hum)', - 'group': 'Information And Communication', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '228': { - 'label': 'Cases', - 'group': 'Covid-19', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '229': { - 'label': 'Contact Tracing', - 'group': 'Covid-19', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '230': { - 'label': 'Deaths', - 'group': 'Covid-19', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '231': { - 'label': 'Hospitalization & Care', - 'group': 'Covid-19', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '232': { - 'label': 'Restriction Measures', - 'group': 'Covid-19', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '233': { - 'label': 'Testing', - 'group': 'Covid-19', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '234': { - 'label': 'Vaccination', - 'group': 'Covid-19', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '2', - }, - '301': { - 'label': 'Number Of People At Risk', - 'group': 'At Risk', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '302': { - 'label': 'Risk And Vulnerabilities', - 'group': 'At Risk', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '303': { - 'label': 'International Response', - 'group': 'Capacities & Response', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '304': { - 'label': 'Local Response', - 'group': 'Capacities & Response', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '305': { - 'label': 'National Response', - 'group': 'Capacities & Response', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '306': { - 'label': 'Number Of People Reached/Response Gaps', - 'group': 'Capacities & Response', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '307': { - 'label': 'Coping Mechanisms', - 'group': 'Humanitarian Conditions', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '308': { - 'label': 'Living Standards', - 'group': 'Humanitarian Conditions', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '309': { - 'label': 'Number Of People In Need', - 'group': 'Humanitarian Conditions', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '310': { - 'label': 'Physical And Mental Well Being', - 'group': 'Humanitarian Conditions', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '311': { - 'label': 'Driver/Aggravating Factors', - 'group': 'Impact', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '312': { - 'label': 'Impact On People', - 'group': 'Impact', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '313': { - 'label': 'Impact On Systems, Services And Networks', - 'group': 'Impact', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '314': { - 'label': 'Number Of People Affected', - 'group': 'Impact', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '315': { - 'label': 'Expressed By Humanitarian Staff', - 'group': 'Priority Interventions', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '316': { - 'label': 'Expressed By Population', - 'group': 'Priority Interventions', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '317': { - 'label': 'Expressed By Humanitarian Staff', - 'group': 'Priority Needs', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '318': { - 'label': 'Expressed By Population', - 'group': 'Priority Needs', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '3', - }, - '401': { - 'label': 'Child Head of Household', - 'group': 'Specific Needs Group', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '4', - }, - '402': { - 'label': 'Chronically Ill', - 'group': 'Specific Needs Group', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '4', - }, - '403': { - 'label': 'Elderly Head of Household', - 'group': 'Specific Needs Group', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '4', - }, - '404': { - 'label': 'Female Head of Household', - 'group': 'Specific Needs Group', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '4', - }, - '405': { - 'label': 'GBV survivors', - 'group': 'Specific Needs Group', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '4', - }, - '406': { - 'label': 'Indigenous people', - 'group': 'Specific Needs Group', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '4', - }, - '407': { - 'label': 'LGBTQI+', - 'group': 'Specific Needs Group', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '4', - }, - '408': { - 'label': 'Minorities', - 'group': 'Specific Needs Group', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '4', - }, - '409': { - 'label': 'Persons with Disability', - 'group': 'Specific Needs Group', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '4', - }, - '410': { - 'label': 'Pregnant or Lactating Women', - 'group': 'Specific Needs Group', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '4', - }, - '411': { - 'label': 'Single Women (including Widows)', - 'group': 'Specific Needs Group', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '4', - }, - '412': { - 'label': 'Unaccompanied or Separated Children', - 'group': 'Specific Needs Group', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '4', - }, - '901': { - 'label': 'Infants/Toddlers (<5 years old) ', - 'group': 'Demographic Groups', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '9', - }, - '902': { - 'label': 'Female Children/Youth (5 to 17 years old)', - 'group': 'Demographic Groups', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '9', - }, - '903': { - 'label': 'Male Children/Youth (5 to 17 years old)', - 'group': 'Demographic Groups', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '9', - }, - '904': { - 'label': 'Female Adult (18 to 59 years old)', - 'group': 'Demographic Groups', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '9', - }, - '905': { - 'label': 'Male Adult (18 to 59 years old)', - 'group': 'Demographic Groups', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '9', - }, - '906': { - 'label': 'Female Older Persons (60+ years old)', - 'group': 'Demographic Groups', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '9', - }, - '907': { - 'label': 'Male Older Persons (60+ years old)', - 'group': 'Demographic Groups', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '9', - }, - '701': { - 'label': 'Critical', - 'group': 'Severity', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '7', - }, - '702': { - 'label': 'Major', - 'group': 'Severity', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '7', - }, - '703': { - 'label': 'Minor Problem', - 'group': 'Severity', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '7', - }, - '704': { - 'label': 'No problem', - 'group': 'Severity', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '7', - }, - '705': { - 'label': 'Of Concern', - 'group': 'Severity', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '7', - }, - '801': { - 'label': 'Asylum Seekers', - 'group': 'Affected Groups', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '8', - }, - '802': { - 'label': 'Host', - 'group': 'Affected Groups', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '8', - }, - '803': { - 'label': 'IDP', - 'group': 'Affected Groups', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '8', - }, - '804': { - 'label': 'Migrants', - 'group': 'Affected Groups', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '8', - }, - '805': { - 'label': 'Refugees', - 'group': 'Affected Groups', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '8', - }, - '806': { - 'label': 'Returnees', - 'group': 'Affected Groups', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '8', - }, - '1001': { - 'label': 'Completely reliable', - 'group': 'Reliability', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '10', - }, - '1002': { - 'label': 'Usually reliable', - 'group': 'Reliability', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '10', - }, - '1003': { - 'label': 'Fairly Reliable', - 'group': 'Reliability', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '10', - }, - '1004': { - 'label': 'Unreliable', - 'group': 'Reliability', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '10', - }, - '501': { - 'label': 'Female', - 'group': 'Gender', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '5', - }, - '502': { - 'label': 'Male', - 'group': 'Gender', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '5', - }, - '601': { - 'label': 'Adult (18 to 59 years old)', - 'group': 'Age', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '6', - }, - '602': { - 'label': 'Children/Youth (5 to 17 years old)', - 'group': 'Age', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '6', - }, - '603': { - 'label': 'Infants/Toddlers (<5 years old)', - 'group': 'Age', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '6', - }, - '604': { - 'label': 'Older Persons (60+ years old)', - 'group': 'Age', - 'hide_in_analysis_framework_mapping': False, - 'is_category': False, - 'parent_id': '6', - }, - '1': { - 'label': 'sectors', - 'is_category': True, - 'hide_in_analysis_framework_mapping': True - }, - '2': { - 'label': 'subpillars_1d', - 'is_category': True, - 'hide_in_analysis_framework_mapping': True - }, - '3': { - 'label': 'subpillars_2d', - 'is_category': True, - 'hide_in_analysis_framework_mapping': True - }, - '6': { - 'label': 'age', - 'is_category': True, - 'hide_in_analysis_framework_mapping': True - }, - '5': { - 'label': 'gender', - 'is_category': True, - 'hide_in_analysis_framework_mapping': True - }, - '9': { - 'label': 'demographic_group', - 'is_category': True, - 'hide_in_analysis_framework_mapping': True - }, - '8': { - 'label': 'affected_groups', - 'is_category': True, - 'hide_in_analysis_framework_mapping': True - }, - '4': { - 'label': 'specific_needs_groups', - 'is_category': True, - 'hide_in_analysis_framework_mapping': True - }, - '7': { - 'label': 'severity', - 'is_category': True, - 'hide_in_analysis_framework_mapping': True - }, - '10': { - 'label': 'reliability', - 'is_category': True, - 'hide_in_analysis_framework_mapping': True - }, + "101": { + "label": "Agriculture", + "group": "Sectors", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "1", + }, + "102": { + "label": "Cross", + "group": "Sectors", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "1", + }, + "103": { + "label": "Education", + "group": "Sectors", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "1", + }, + "104": { + "label": "Food Security", + "group": "Sectors", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "1", + }, + "201": { + "label": "Environment", + "group": "Context", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "202": { + "label": "Socio Cultural", + "group": "Context", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "203": { + "label": "Economy", + "group": "Context", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "204": { + "label": "Demography", + "group": "Context", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "205": { + "label": "Legal & Policy", + "group": "Context", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "206": { + "label": "Security & Stability", + "group": "Context", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "207": { + "label": "Politics", + "group": "Context", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "208": { + "label": "Type And Characteristics", + "group": "Shock/Event", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "209": { + "label": "Underlying/Aggravating Factors", + "group": "Shock/Event", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "210": { + "label": "Hazard & Threats", + "group": "Shock/Event", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "212": { + "label": "Type/Numbers/Movements", + "group": "Displacement", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "213": { + "label": "Push Factors", + "group": "Displacement", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "214": { + "label": "Pull Factors", + "group": "Displacement", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "215": { + "label": "Intentions", + "group": "Displacement", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "216": { + "label": "Local Integration", + "group": "Displacement", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "217": { + "label": "Injured", + "group": "Casualties", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "218": { + "label": "Missing", + "group": "Casualties", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "219": { + "label": "Dead", + "group": "Casualties", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "220": { + "label": "Relief To Population", + "group": "Humanitarian Access", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "221": { + "label": "Population To Relief", + "group": "Humanitarian Access", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "222": { + "label": "Physical Constraints", + "group": "Humanitarian Access", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "223": { + "label": "Number Of People Facing Humanitarian Access Constraints/Humanitarian Access Gaps", + "group": "Humanitarian Access", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "224": { + "label": "Communication Means And Preferences", + "group": "Information And Communication", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "225": { + "label": "Information Challenges And Barriers", + "group": "Information And Communication", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "226": { + "label": "Knowledge And Info Gaps (Pop)", + "group": "Information And Communication", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "227": { + "label": "Knowledge And Info Gaps (Hum)", + "group": "Information And Communication", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "228": { + "label": "Cases", + "group": "Covid-19", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "229": { + "label": "Contact Tracing", + "group": "Covid-19", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "230": { + "label": "Deaths", + "group": "Covid-19", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "231": { + "label": "Hospitalization & Care", + "group": "Covid-19", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "232": { + "label": "Restriction Measures", + "group": "Covid-19", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "233": { + "label": "Testing", + "group": "Covid-19", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "234": { + "label": "Vaccination", + "group": "Covid-19", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "2", + }, + "301": { + "label": "Number Of People At Risk", + "group": "At Risk", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "302": { + "label": "Risk And Vulnerabilities", + "group": "At Risk", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "303": { + "label": "International Response", + "group": "Capacities & Response", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "304": { + "label": "Local Response", + "group": "Capacities & Response", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "305": { + "label": "National Response", + "group": "Capacities & Response", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "306": { + "label": "Number Of People Reached/Response Gaps", + "group": "Capacities & Response", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "307": { + "label": "Coping Mechanisms", + "group": "Humanitarian Conditions", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "308": { + "label": "Living Standards", + "group": "Humanitarian Conditions", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "309": { + "label": "Number Of People In Need", + "group": "Humanitarian Conditions", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "310": { + "label": "Physical And Mental Well Being", + "group": "Humanitarian Conditions", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "311": { + "label": "Driver/Aggravating Factors", + "group": "Impact", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "312": { + "label": "Impact On People", + "group": "Impact", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "313": { + "label": "Impact On Systems, Services And Networks", + "group": "Impact", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "314": { + "label": "Number Of People Affected", + "group": "Impact", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "315": { + "label": "Expressed By Humanitarian Staff", + "group": "Priority Interventions", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "316": { + "label": "Expressed By Population", + "group": "Priority Interventions", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "317": { + "label": "Expressed By Humanitarian Staff", + "group": "Priority Needs", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "318": { + "label": "Expressed By Population", + "group": "Priority Needs", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "3", + }, + "401": { + "label": "Child Head of Household", + "group": "Specific Needs Group", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "4", + }, + "402": { + "label": "Chronically Ill", + "group": "Specific Needs Group", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "4", + }, + "403": { + "label": "Elderly Head of Household", + "group": "Specific Needs Group", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "4", + }, + "404": { + "label": "Female Head of Household", + "group": "Specific Needs Group", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "4", + }, + "405": { + "label": "GBV survivors", + "group": "Specific Needs Group", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "4", + }, + "406": { + "label": "Indigenous people", + "group": "Specific Needs Group", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "4", + }, + "407": { + "label": "LGBTQI+", + "group": "Specific Needs Group", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "4", + }, + "408": { + "label": "Minorities", + "group": "Specific Needs Group", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "4", + }, + "409": { + "label": "Persons with Disability", + "group": "Specific Needs Group", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "4", + }, + "410": { + "label": "Pregnant or Lactating Women", + "group": "Specific Needs Group", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "4", + }, + "411": { + "label": "Single Women (including Widows)", + "group": "Specific Needs Group", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "4", + }, + "412": { + "label": "Unaccompanied or Separated Children", + "group": "Specific Needs Group", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "4", + }, + "901": { + "label": "Infants/Toddlers (<5 years old) ", + "group": "Demographic Groups", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "9", + }, + "902": { + "label": "Female Children/Youth (5 to 17 years old)", + "group": "Demographic Groups", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "9", + }, + "903": { + "label": "Male Children/Youth (5 to 17 years old)", + "group": "Demographic Groups", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "9", + }, + "904": { + "label": "Female Adult (18 to 59 years old)", + "group": "Demographic Groups", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "9", + }, + "905": { + "label": "Male Adult (18 to 59 years old)", + "group": "Demographic Groups", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "9", + }, + "906": { + "label": "Female Older Persons (60+ years old)", + "group": "Demographic Groups", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "9", + }, + "907": { + "label": "Male Older Persons (60+ years old)", + "group": "Demographic Groups", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "9", + }, + "701": { + "label": "Critical", + "group": "Severity", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "7", + }, + "702": { + "label": "Major", + "group": "Severity", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "7", + }, + "703": { + "label": "Minor Problem", + "group": "Severity", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "7", + }, + "704": { + "label": "No problem", + "group": "Severity", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "7", + }, + "705": { + "label": "Of Concern", + "group": "Severity", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "7", + }, + "801": { + "label": "Asylum Seekers", + "group": "Affected Groups", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "8", + }, + "802": { + "label": "Host", + "group": "Affected Groups", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "8", + }, + "803": { + "label": "IDP", + "group": "Affected Groups", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "8", + }, + "804": { + "label": "Migrants", + "group": "Affected Groups", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "8", + }, + "805": { + "label": "Refugees", + "group": "Affected Groups", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "8", + }, + "806": { + "label": "Returnees", + "group": "Affected Groups", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "8", + }, + "1001": { + "label": "Completely reliable", + "group": "Reliability", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "10", + }, + "1002": { + "label": "Usually reliable", + "group": "Reliability", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "10", + }, + "1003": { + "label": "Fairly Reliable", + "group": "Reliability", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "10", + }, + "1004": { + "label": "Unreliable", + "group": "Reliability", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "10", + }, + "501": { + "label": "Female", + "group": "Gender", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "5", + }, + "502": { + "label": "Male", + "group": "Gender", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "5", + }, + "601": { + "label": "Adult (18 to 59 years old)", + "group": "Age", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "6", + }, + "602": { + "label": "Children/Youth (5 to 17 years old)", + "group": "Age", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "6", + }, + "603": { + "label": "Infants/Toddlers (<5 years old)", + "group": "Age", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "6", + }, + "604": { + "label": "Older Persons (60+ years old)", + "group": "Age", + "hide_in_analysis_framework_mapping": False, + "is_category": False, + "parent_id": "6", + }, + "1": {"label": "sectors", "is_category": True, "hide_in_analysis_framework_mapping": True}, + "2": {"label": "subpillars_1d", "is_category": True, "hide_in_analysis_framework_mapping": True}, + "3": {"label": "subpillars_2d", "is_category": True, "hide_in_analysis_framework_mapping": True}, + "6": {"label": "age", "is_category": True, "hide_in_analysis_framework_mapping": True}, + "5": {"label": "gender", "is_category": True, "hide_in_analysis_framework_mapping": True}, + "9": {"label": "demographic_group", "is_category": True, "hide_in_analysis_framework_mapping": True}, + "8": {"label": "affected_groups", "is_category": True, "hide_in_analysis_framework_mapping": True}, + "4": {"label": "specific_needs_groups", "is_category": True, "hide_in_analysis_framework_mapping": True}, + "7": {"label": "severity", "is_category": True, "hide_in_analysis_framework_mapping": True}, + "10": {"label": "reliability", "is_category": True, "hide_in_analysis_framework_mapping": True}, } def setUp(self): super().setUp() - self.sync_request_mock = patch('assisted_tagging.tasks.requests') + self.sync_request_mock = patch("assisted_tagging.tasks.requests") mock = self.sync_request_mock.start() mock.get.return_value.status_code = 200 mock.get.return_value.json.return_value = self.DEEPL_TAGS_MOCK_RESPONSE @@ -1261,15 +1079,11 @@ def _get_current_model_stats(): model_count=AssistedTaggingModel.objects.count(), model_version_count=AssistedTaggingModelVersion.objects.count(), tag_count=AssistedTaggingModelPredictionTag.objects.count(), - models=list( - AssistedTaggingModel.objects.values('model_id', 'name').order_by('model_id') - ), + models=list(AssistedTaggingModel.objects.values("model_id", "name").order_by("model_id")), model_versions=list( - AssistedTaggingModelVersion.objects.values('model__model_id', 'version').order_by('model__model_id') - ), - tags=list( - AssistedTaggingModelPredictionTag.objects.values('name', 'tag_id', 'is_deprecated').order_by('tag_id') + AssistedTaggingModelVersion.objects.values("model__model_id", "version").order_by("model__model_id") ), + tags=list(AssistedTaggingModelPredictionTag.objects.values("name", "tag_id", "is_deprecated").order_by("tag_id")), ) def _get_current_prediction_stats(): @@ -1277,30 +1091,30 @@ def _get_current_prediction_stats(): prediction_count=AssistedTaggingPrediction.objects.count(), predictions=list( AssistedTaggingPrediction.objects.values( - 'data_type', - 'model_version__model__model_id', - 'draft_entry__excerpt', - 'value', - 'category__tag_id', - 'tag__tag_id', - 'prediction', - 'threshold', - 'is_selected', + "data_type", + "model_version__model__model_id", + "draft_entry__excerpt", + "value", + "category__tag_id", + "tag__tag_id", + "prediction", + "threshold", + "is_selected", ).order_by( - 'data_type', - 'model_version__model__model_id', - 'draft_entry__excerpt', - 'value', - 'category__tag_id', - 'tag__tag_id', - 'prediction', - 'threshold', - 'is_selected', + "data_type", + "model_version__model__model_id", + "draft_entry__excerpt", + "value", + "category__tag_id", + "tag__tag_id", + "prediction", + "threshold", + "is_selected", ) - ) + ), ) - url = '/api/v1/callback/assisted-tagging-draft-entry-prediction/' + url = "/api/v1/callback/assisted-tagging-draft-entry-prediction/" project = ProjectFactory.create() lead = LeadFactory.create(project=project) draft_args = dict( @@ -1310,17 +1124,17 @@ def _get_current_prediction_stats(): ) draft_entry1 = DraftEntryFactory.create( **draft_args, - excerpt='sample excerpt 101', + excerpt="sample excerpt 101", ) draft_entry2 = DraftEntryFactory.create( **draft_args, - excerpt='sample excerpt 102', + excerpt="sample excerpt 102", ) # ------ Invalid entry_id data = { **self.DEEPL_CALLBACK_MOCK_DATA, - 'client_id': 'invalid-id', + "client_id": "invalid-id", } response = self.client.post(url, data) @@ -1330,7 +1144,7 @@ def _get_current_prediction_stats(): # ----- Valid entry_id data = { **self.DEEPL_CALLBACK_MOCK_DATA, - 'client_id': AssistedTaggingDraftEntryHandler.get_client_id(draft_entry1), + "client_id": AssistedTaggingDraftEntryHandler.get_client_id(draft_entry1), } self.maxDiff = None @@ -1352,7 +1166,7 @@ def _get_current_prediction_stats(): # ----- Valid entry_id send with same type of data data = { **self.DEEPL_CALLBACK_MOCK_DATA, - 'client_id': AssistedTaggingDraftEntryHandler.get_client_id(draft_entry2), + "client_id": AssistedTaggingDraftEntryHandler.get_client_id(draft_entry2), } current_model_stats = _get_current_model_stats() @@ -1363,59 +1177,61 @@ def _get_current_prediction_stats(): current_model_stats = _get_current_model_stats() current_prediction_stats = _get_current_prediction_stats() - self.assertMatchSnapshot(current_model_stats, 'final-current-model-stats') - self.assertMatchSnapshot(current_prediction_stats, 'final-current-prediction-stats') + self.assertMatchSnapshot(current_model_stats, "final-current-model-stats") + self.assertMatchSnapshot(current_prediction_stats, "final-current-prediction-stats") def test_tags_sync(self): def _get_current_tags(): return list( AssistedTaggingModelPredictionTag.objects.values( - 'name', - 'group', - 'tag_id', - 'is_deprecated', - 'is_category', - 'hide_in_analysis_framework_mapping', - 'parent_tag__tag_id', - ).order_by('tag_id') + "name", + "group", + "tag_id", + "is_deprecated", + "is_category", + "hide_in_analysis_framework_mapping", + "parent_tag__tag_id", + ).order_by("tag_id") ) self.maxDiff = None self.assertEqual(len(_get_current_tags()), 0) sync_tags_with_deepl() self.assertNotEqual(len(_get_current_tags()), 0) - self.assertMatchSnapshot(_get_current_tags(), 'sync-tags') + self.assertMatchSnapshot(_get_current_tags(), "sync-tags") class TestAssistedTaggingModules(GraphQLTestCase): def test_assisted_tagging_model_version_latest_model_fetch(self): model1, model2, model3 = AssistedTaggingModelFactory.create_batch(3) - model1_v1 = AssistedTaggingModelVersionFactory.create(model=model1, version='v1.0.0') - model1_v1_1 = AssistedTaggingModelVersionFactory.create(model=model1, version='v1.0.1') - model2_v1 = AssistedTaggingModelVersionFactory.create(model=model2, version='v1.0.0') - model3_v0_1 = AssistedTaggingModelVersionFactory.create(model=model2, version='v0.0.1') - model3_v1 = AssistedTaggingModelVersionFactory.create(model=model3, version='v1.0.0') + model1_v1 = AssistedTaggingModelVersionFactory.create(model=model1, version="v1.0.0") + model1_v1_1 = AssistedTaggingModelVersionFactory.create(model=model1, version="v1.0.1") + model2_v1 = AssistedTaggingModelVersionFactory.create(model=model2, version="v1.0.0") + model3_v0_1 = AssistedTaggingModelVersionFactory.create(model=model2, version="v0.0.1") + model3_v1 = AssistedTaggingModelVersionFactory.create(model=model3, version="v1.0.0") latest_models = list(AssistedTaggingModelVersion.get_latest_models_version()) assert model1_v1 not in latest_models assert model3_v0_1 not in latest_models - assert set(latest_models) == set([ - model1_v1_1, - model2_v1, - model3_v1, - ]) + assert set(latest_models) == set( + [ + model1_v1_1, + model2_v1, + model3_v1, + ] + ) def test_get_existing_draft_entry(self): # Model model1, model2, model3 = AssistedTaggingModelFactory.create_batch(3) # Model Versions - model1_v1 = AssistedTaggingModelVersionFactory.create(model=model1, version='v1.0.0') - model1_v1_1 = AssistedTaggingModelVersionFactory.create(model=model1, version='v1.0.1') - model2_v1 = AssistedTaggingModelVersionFactory.create(model=model2, version='v1.0.0') + model1_v1 = AssistedTaggingModelVersionFactory.create(model=model1, version="v1.0.0") + model1_v1_1 = AssistedTaggingModelVersionFactory.create(model=model1, version="v1.0.1") + model2_v1 = AssistedTaggingModelVersionFactory.create(model=model2, version="v1.0.0") project = ProjectFactory.create() lead = LeadFactory.create(project=project) - excerpt = 'test-101' + excerpt = "test-101" draft_entry1 = DraftEntryFactory.create(project=project, lead=lead, excerpt=excerpt) category1, tag1 = AssistedTaggingModelPredictionTagFactory.create_batch(2) @@ -1440,11 +1256,14 @@ def test_get_existing_draft_entry(self): **prediction_common_params, ) - assert DraftEntry.get_existing_draft_entry( - project, - lead, - excerpt=excerpt, - ) is None + assert ( + DraftEntry.get_existing_draft_entry( + project, + lead, + excerpt=excerpt, + ) + is None + ) # Clear out predictions draft_entry1.predictions.all().delete() @@ -1459,8 +1278,11 @@ def test_get_existing_draft_entry(self): model_version=model2_v1, **prediction_common_params, ) - assert DraftEntry.get_existing_draft_entry( - project, - lead, - excerpt=excerpt, - ) == draft_entry1 + assert ( + DraftEntry.get_existing_draft_entry( + project, + lead, + excerpt=excerpt, + ) + == draft_entry1 + ) diff --git a/apps/bulk_data_migration/apps.py b/apps/bulk_data_migration/apps.py index 59174491d9..495f8698e0 100644 --- a/apps/bulk_data_migration/apps.py +++ b/apps/bulk_data_migration/apps.py @@ -2,4 +2,4 @@ class BulkDataMigrationConfig(AppConfig): - name = 'bulk_data_migration' + name = "bulk_data_migration" diff --git a/apps/bulk_data_migration/entry_images/migrate.py b/apps/bulk_data_migration/entry_images/migrate.py index a949d4563f..f04c02a91d 100644 --- a/apps/bulk_data_migration/entry_images/migrate.py +++ b/apps/bulk_data_migration/entry_images/migrate.py @@ -1,4 +1,5 @@ from urllib.parse import urljoin + import reversion from entry.models import Entry from entry.utils import base64_to_deep_image @@ -27,15 +28,11 @@ def migrate_entry(entry, root_url): if new_image == image: return - Entry.objects.filter( - id=entry.id - ).update( - image=new_image - ) + Entry.objects.filter(id=entry.id).update(image=new_image) def migrate(*args): - print('This should be already migrated') + print("This should be already migrated") return root_url = args[0] with reversion.create_revision(): diff --git a/apps/bulk_data_migration/entry_images_v2/migrate.py b/apps/bulk_data_migration/entry_images_v2/migrate.py index 8a456902a5..e2f0556931 100644 --- a/apps/bulk_data_migration/entry_images_v2/migrate.py +++ b/apps/bulk_data_migration/entry_images_v2/migrate.py @@ -1,27 +1,26 @@ -from django.db.models import Q from django.conf import settings - -from utils.common import parse_number - -from lead.models import LeadPreviewImage +from django.db.models import Q from entry.models import Entry from gallery.models import File +from lead.models import LeadPreviewImage + +from utils.common import parse_number """ python3 manage.py bulk_migrate entry_images_v2 """ -FILE_API_PREFIX = '{protocol}://{domain}{url}'.format( +FILE_API_PREFIX = "{protocol}://{domain}{url}".format( protocol=settings.HTTP_PROTOCOL, domain=settings.DJANGO_API_HOST, - url='/file/', + url="/file/", ) -S3_URL_PREFIX = f'https://{settings.AWS_STORAGE_BUCKET_NAME_MEDIA}.s3.amazonaws.com/{settings.MEDIAFILES_LOCATION}/' +S3_URL_PREFIX = f"https://{settings.AWS_STORAGE_BUCKET_NAME_MEDIA}.s3.amazonaws.com/{settings.MEDIAFILES_LOCATION}/" def _get_file_from_file_url(entry, string): try: - fileid = parse_number(string.rstrip('/').split('/')[-1]) + fileid = parse_number(string.rstrip("/").split("/")[-1]) except IndexError: return return fileid and File.objects.filter(id=fileid).first() @@ -29,11 +28,11 @@ def _get_file_from_file_url(entry, string): def _get_file_from_s3_url(entry, string): try: - file_path = '/'.join(string.split('?')[0].split('/')[4:]) + file_path = "/".join(string.split("?")[0].split("/")[4:]) except IndexError: return # NOTE: For lead-preview generate gallery files - if file_path.startswith('lead-preview/'): + if file_path.startswith("lead-preview/"): lead_preview = LeadPreviewImage.objects.filter(file=file_path).first() if lead_preview and lead_preview.file and lead_preview.file.storage.exists(lead_preview.file.name): return lead_preview.clone_as_deep_file(entry.created_by) @@ -62,21 +61,21 @@ def migrate_entry(entry): def migrate(*args, **kwargs): qs = Entry.objects.filter( Q(image_raw__isnull=False), - ~Q(image_raw=''), + ~Q(image_raw=""), image__isnull=True, ) total = qs.count() success = 0 index = 1 - print('File string saved:', qs.filter(image_raw__startswith=FILE_API_PREFIX).count()) - print('S3 string saved (lead images):', qs.filter(image_raw__startswith=S3_URL_PREFIX).count()) + print("File string saved:", qs.filter(image_raw__startswith=FILE_API_PREFIX).count()) + print("S3 string saved (lead images):", qs.filter(image_raw__startswith=S3_URL_PREFIX).count()) for entry in qs.iterator(): - print(f'Processing {index} of {total}', end='\r') + print(f"Processing {index} of {total}", end="\r") if migrate_entry(entry): success += 1 index += 1 - print('Summary:') - print(f'\t-Total: {total}') - success and print(f'\t-Success: {success}') - (total - success) and print(f'\t-Failed: {total - success}') + print("Summary:") + print(f"\t-Total: {total}") + success and print(f"\t-Success: {success}") + (total - success) and print(f"\t-Failed: {total - success}") diff --git a/apps/bulk_data_migration/management/commands/bulk_migrate.py b/apps/bulk_data_migration/management/commands/bulk_migrate.py index f06c75a130..84aa8f3863 100644 --- a/apps/bulk_data_migration/management/commands/bulk_migrate.py +++ b/apps/bulk_data_migration/management/commands/bulk_migrate.py @@ -1,23 +1,18 @@ -from django.core.management.base import BaseCommand import importlib +from django.core.management.base import BaseCommand + class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - 'arg_list', - nargs='+', - ) - parser.add_argument( - '--filters_file', - type=str, - default=None + "arg_list", + nargs="+", ) + parser.add_argument("--filters_file", type=str, default=None) def handle(self, *args, **kwargs): - arg_list = kwargs.pop('arg_list', []) + arg_list = kwargs.pop("arg_list", []) migration_type = arg_list[0] - migrate = importlib.import_module( - 'bulk_data_migration.{}.migrate'.format(migration_type) - ).migrate + migrate = importlib.import_module("bulk_data_migration.{}.migrate".format(migration_type)).migrate migrate(*(arg_list[1:]), **kwargs) diff --git a/apps/bulk_data_migration/management/commands/classify_leads.py b/apps/bulk_data_migration/management/commands/classify_leads.py index 5ca696b93f..4b237353f7 100644 --- a/apps/bulk_data_migration/management/commands/classify_leads.py +++ b/apps/bulk_data_migration/management/commands/classify_leads.py @@ -1,29 +1,29 @@ from django.core.management.base import BaseCommand from django.db.models import Q from django.db.models.functions import Length - from lead.models import Lead from lead.tasks import classify_lead class Command(BaseCommand): - help = 'Classify leads whose preview have been generated but do not have classified_doc_id' + help = "Classify leads whose preview have been generated but do not have classified_doc_id" def handle(self, *args, **options): - leads = Lead.objects.filter( - ~Q(leadpreview=None), - ~Q(leadpreview__text_extract=None), - ~Q(leadpreview__text_extract__regex=r'^\W*$'), - leadpreview__classified_doc_id=None, - ).annotate( - text_len=Length('leadpreview__text_extract') - ).filter( - text_len__lte=5000 # Texts of length 5000 do not pose huge computation in DEEPL - ).prefetch_related('leadpreview')[:50] + leads = ( + Lead.objects.filter( + ~Q(leadpreview=None), + ~Q(leadpreview__text_extract=None), + ~Q(leadpreview__text_extract__regex=r"^\W*$"), + leadpreview__classified_doc_id=None, + ) + .annotate(text_len=Length("leadpreview__text_extract")) + .filter(text_len__lte=5000) # Texts of length 5000 do not pose huge computation in DEEPL + .prefetch_related("leadpreview")[:50] + ) - print('\nNOTE: that only 50 leads will be classified at a time.\n') + print("\nNOTE: that only 50 leads will be classified at a time.\n") for i, lead in enumerate(leads): - print('Classifying lead', lead.id, 'Lead Count:', i + 1) + print("Classifying lead", lead.id, "Lead Count:", i + 1) classify_lead(lead) - print('Complete!!\n') + print("Complete!!\n") diff --git a/apps/bulk_data_migration/management/commands/entries_highlight_migrate.py b/apps/bulk_data_migration/management/commands/entries_highlight_migrate.py index d2be1509e7..e8f3194256 100644 --- a/apps/bulk_data_migration/management/commands/entries_highlight_migrate.py +++ b/apps/bulk_data_migration/management/commands/entries_highlight_migrate.py @@ -1,30 +1,25 @@ -from django.core.management.base import BaseCommand -from django.db.models.functions import StrIndex -from django.db.models import F - import math -from lead.models import Lead +from django.core.management.base import BaseCommand +from django.db.models import F +from django.db.models.functions import StrIndex from entry.models import Entry +from lead.models import Lead class Command(BaseCommand): - help = 'Check if entry text is in lead text and populate dropped_text accordingly.' + help = "Check if entry text is in lead text and populate dropped_text accordingly." def handle(self, *args, **options): chunk_size = 200 - leads = Lead.objects.filter( - leadpreview__text_extract__isnull=False - ) + leads = Lead.objects.filter(leadpreview__text_extract__isnull=False) leads_count = leads.count() total_chunks = math.ceil(leads_count / chunk_size) n = 1 for lead in leads.iterator(chunk_size=chunk_size): - print(f'Updating entries from lead chunk {n} of {total_chunks}') + print(f"Updating entries from lead chunk {n} of {total_chunks}") lead.entry_set.filter(entry_type=Entry.TagType.EXCERPT).annotate( - index=StrIndex('lead__leadpreview__text_extract', F('excerpt')) - ).filter(index__gt=0).update( - dropped_excerpt=F('excerpt') - ) + index=StrIndex("lead__leadpreview__text_extract", F("excerpt")) + ).filter(index__gt=0).update(dropped_excerpt=F("excerpt")) n += 1 - print('Done.') + print("Done.") diff --git a/apps/bulk_data_migration/management/commands/generate_preview.py b/apps/bulk_data_migration/management/commands/generate_preview.py index e5fcc0fc88..f8eeeaf620 100644 --- a/apps/bulk_data_migration/management/commands/generate_preview.py +++ b/apps/bulk_data_migration/management/commands/generate_preview.py @@ -1,21 +1,15 @@ from django.core.management.base import BaseCommand - from lead.tasks import generate_previews class Command(BaseCommand): - help = 'Extract preview/images from leads' + help = "Extract preview/images from leads" def add_arguments(self, parser): - parser.add_argument( - '--lead_id', - nargs='+', - type=int, - help='List of lead ids' - ) + parser.add_argument("--lead_id", nargs="+", type=int, help="List of lead ids") def handle(self, *args, **options): - if options['lead_id']: - generate_previews.delay(options['lead_id']) + if options["lead_id"]: + generate_previews.delay(options["lead_id"]) else: generate_previews.delay() diff --git a/apps/bulk_data_migration/management/commands/update_attribute_for_scale_export.py b/apps/bulk_data_migration/management/commands/update_attribute_for_scale_export.py index 8b6a1c640f..a955825840 100644 --- a/apps/bulk_data_migration/management/commands/update_attribute_for_scale_export.py +++ b/apps/bulk_data_migration/management/commands/update_attribute_for_scale_export.py @@ -1,12 +1,11 @@ from django.core.management.base import BaseCommand - from entry.models import Attribute from entry.utils import update_entry_attribute class Command(BaseCommand): - help = 'Update attributes to export scales' + help = "Update attributes to export scales" def handle(self, *args, **options): - for each in Attribute.objects.filter(widget__widget_id__in=['scaleWidget', 'conditionalWidget']): + for each in Attribute.objects.filter(widget__widget_id__in=["scaleWidget", "conditionalWidget"]): update_entry_attribute(each) diff --git a/apps/bulk_data_migration/management/commands/update_attribute_for_widget.py b/apps/bulk_data_migration/management/commands/update_attribute_for_widget.py index 09d330806e..e1d63e3227 100644 --- a/apps/bulk_data_migration/management/commands/update_attribute_for_widget.py +++ b/apps/bulk_data_migration/management/commands/update_attribute_for_widget.py @@ -1,20 +1,12 @@ -from datetime import datetime, timedelta import time +from datetime import datetime, timedelta from django.core.management.base import BaseCommand -from django.db.models import ( - Q, - Max, - OuterRef, - Subquery, - DateTimeField, - Exists -) - +from django.db.models import DateTimeField, Exists, Max, OuterRef, Q, Subquery from entry.models import Attribute, ExportData from entry.utils import update_entry_attribute -from entry.widgets.store import widget_store from entry.widgets import conditional_widget +from entry.widgets.store import widget_store from lead.models import Lead HIGH = 3 @@ -23,25 +15,23 @@ class Command(BaseCommand): - help = 'Update attributes to export widget' + help = "Update attributes to export widget" def add_arguments(self, parser): parser.add_argument( - '--priority', + "--priority", type=int, - help='Priority based on last activity of leads (high: >{}, medium: {}, low: <{})'.format( - HIGH, MEDIUM, LOW - ), + help="Priority based on last activity of leads (high: >{}, medium: {}, low: <{})".format(HIGH, MEDIUM, LOW), ) parser.add_argument( - '--project', + "--project", type=int, - help='Specific project export data migration', + help="Specific project export data migration", ) parser.add_argument( - '--widget', + "--widget", type=str, - help='Specific widget export data migration', + help="Specific widget export data migration", ) def update_attributes(self, widget, qs): @@ -50,70 +40,75 @@ def update_attributes(self, widget, qs): if widget == conditional_widget.WIDGET_ID: # conditional widget is handled within each overview widget return - current_widget_data_version = getattr(widget_store[widget], 'DATA_VERSION', None) + current_widget_data_version = getattr(widget_store[widget], "DATA_VERSION", None) to_be_changed_export_data_exists = Exists( ExportData.objects.filter( - exportable__analysis_framework=OuterRef('widget__analysis_framework'), - exportable__widget_key=OuterRef('widget__key'), - entry_id=OuterRef('entry') + exportable__analysis_framework=OuterRef("widget__analysis_framework"), + exportable__widget_key=OuterRef("widget__key"), + entry_id=OuterRef("entry"), ).filter( - ~Q(data__has_key='common') | - ~Q(data__common__has_key='version') | - ( - Q(data__has_key='common') & - Q(data__common__has_key='version') & - ~Q(data__common__version=current_widget_data_version) + ~Q(data__has_key="common") + | ~Q(data__common__has_key="version") + | ( + Q(data__has_key="common") + & Q(data__common__has_key="version") + & ~Q(data__common__version=current_widget_data_version) ) ) ) - attribute_qs = qs.filter( - widget__widget_id=widget, - ).annotate( - export_data_exists=to_be_changed_export_data_exists - ).filter(export_data_exists=True) + attribute_qs = ( + qs.filter( + widget__widget_id=widget, + ) + .annotate(export_data_exists=to_be_changed_export_data_exists) + .filter(export_data_exists=True) + ) total_to_process = attribute_qs.count() - print(f'Processing for {widget}. Total attributes to process: {total_to_process}') + print(f"Processing for {widget}. Total attributes to process: {total_to_process}") if total_to_process == 0: # Nothing to do here return for index, attr in enumerate(attribute_qs.iterator(), start=1): - print(f' - {index}/{total_to_process}', end='\r') + print(f" - {index}/{total_to_process}", end="\r") update_entry_attribute(attr) - print(f' - Updated {total_to_process}') + print(f" - Updated {total_to_process}") def handle(self, *args, **options): old = time.time() qs = Attribute.objects.all() - if options.get('project'): - qs = qs.filter(entry__project=options['project']) - elif options.get('priority'): + if options.get("project"): + qs = qs.filter(entry__project=options["project"]) + elif options.get("priority"): today = datetime.today() - priority = options['priority'] + priority = options["priority"] last_30_days_ago = today - timedelta(days=30) last_60_days_ago = today - timedelta(days=60) qs = qs.annotate( - last_lead_added=Subquery(Lead.objects.filter( - project=OuterRef('entry__project_id') - ).order_by().values('project').annotate(max=Max('created_at')).values('max')[:1], - output_field=DateTimeField()) + last_lead_added=Subquery( + Lead.objects.filter(project=OuterRef("entry__project_id")) + .order_by() + .values("project") + .annotate(max=Max("created_at")) + .values("max")[:1], + output_field=DateTimeField(), + ) ) if priority >= HIGH: qs = qs.filter(last_lead_added__gte=last_30_days_ago) elif priority == MEDIUM: - qs = qs.filter(last_lead_added__lt=last_30_days_ago, - last_lead_added__gte=last_60_days_ago) + qs = qs.filter(last_lead_added__lt=last_30_days_ago, last_lead_added__gte=last_60_days_ago) else: qs = qs.filter(last_lead_added__lt=last_60_days_ago) - if options.get('widget') in widget_store.keys(): - widget = options['widget'] + if options.get("widget") in widget_store.keys(): + widget = options["widget"] qs = qs.filter(Q(widget__widget_id=widget) | Q(widget__widget_id=conditional_widget.WIDGET_ID)) self.update_attributes(widget, qs) else: for widget in widget_store.keys(): - self.update_attributes(widget, qs.filter( - Q(widget__widget_id=widget) | Q(widget__widget_id=conditional_widget.WIDGET_ID) - )) - print(f'Checked on {qs.count()} attributes.') - print(f'It took {time.time() - old} seconds.') + self.update_attributes( + widget, qs.filter(Q(widget__widget_id=widget) | Q(widget__widget_id=conditional_widget.WIDGET_ID)) + ) + print(f"Checked on {qs.count()} attributes.") + print(f"It took {time.time() - old} seconds.") diff --git a/apps/bulk_data_migration/v1_2/ary.py b/apps/bulk_data_migration/v1_2/ary.py index d34bea735f..6eeed89cc1 100644 --- a/apps/bulk_data_migration/v1_2/ary.py +++ b/apps/bulk_data_migration/v1_2/ary.py @@ -1,19 +1,19 @@ from ary.models import Assessment + from utils.common import random_key def migrate_assessment(obj): methodology = obj.methodology - attributes = methodology.get('attributes') + attributes = methodology.get("attributes") if not attributes: return for attribute in attributes: - if not attribute.get('key'): - attribute['key'] = random_key() + if not attribute.get("key"): + attribute["key"] = random_key() - Assessment.objects.filter(id=obj.id)\ - .update(methodology=methodology) + Assessment.objects.filter(id=obj.id).update(methodology=methodology) def migrate_ary(**filters): diff --git a/apps/bulk_data_migration/v1_2/geo.py b/apps/bulk_data_migration/v1_2/geo.py index aa2c399b73..570eba0986 100644 --- a/apps/bulk_data_migration/v1_2/geo.py +++ b/apps/bulk_data_migration/v1_2/geo.py @@ -4,12 +4,12 @@ def migrate_widget(widget_data): def migrate_val(v): if isinstance(v, dict): - return v['key'] + return v["key"] return v def migrate_attribute(data): - value = data.get('values') or [] + value = data.get("values") or [] return { - 'value': [migrate_val(v) for v in value], + "value": [migrate_val(v) for v in value], } diff --git a/apps/bulk_data_migration/v1_2/matrix1d.py b/apps/bulk_data_migration/v1_2/matrix1d.py index d637427bb2..bbd6a3b071 100644 --- a/apps/bulk_data_migration/v1_2/matrix1d.py +++ b/apps/bulk_data_migration/v1_2/matrix1d.py @@ -3,9 +3,7 @@ def migrate_widget(widget_data): def migrate_attribute(data): - if data.get('value'): + if data.get("value"): return data - return { - 'value': data - } + return {"value": data} diff --git a/apps/bulk_data_migration/v1_2/matrix2d.py b/apps/bulk_data_migration/v1_2/matrix2d.py index d637427bb2..bbd6a3b071 100644 --- a/apps/bulk_data_migration/v1_2/matrix2d.py +++ b/apps/bulk_data_migration/v1_2/matrix2d.py @@ -3,9 +3,7 @@ def migrate_widget(widget_data): def migrate_attribute(data): - if data.get('value'): + if data.get("value"): return data - return { - 'value': data - } + return {"value": data} diff --git a/apps/bulk_data_migration/v1_2/migrate.py b/apps/bulk_data_migration/v1_2/migrate.py index d632288d3d..1748fdfde7 100644 --- a/apps/bulk_data_migration/v1_2/migrate.py +++ b/apps/bulk_data_migration/v1_2/migrate.py @@ -1,37 +1,28 @@ import json -import reversion -from analysis_framework.utils import update_widgets, Widget -from entry.utils import update_attributes, Attribute +import reversion +from analysis_framework.utils import Widget, update_widgets +from entry.utils import Attribute, update_attributes -from .projects import migrate_projects +from . import excerpt, geo, matrix1d, matrix2d, number_matrix, organigram, scale from .ary import migrate_ary -from . import ( - matrix1d, - matrix2d, - scale, - excerpt, - organigram, - geo, - number_matrix, -) - +from .projects import migrate_projects widgets = { - 'matrix1dWidget': matrix1d, - 'matrix2dWidget': matrix2d, - 'scaleWidget': scale, - 'excerptWidget': excerpt, - 'organigramWidget': organigram, - 'geoWidget': geo, - 'numberMatrixWidget': number_matrix, + "matrix1dWidget": matrix1d, + "matrix2dWidget": matrix2d, + "scaleWidget": scale, + "excerptWidget": excerpt, + "organigramWidget": organigram, + "geoWidget": geo, + "numberMatrixWidget": number_matrix, } default_added_from = { - 'matrix1dWidget': 'overview', - 'matrix2dWidget': 'overview', - 'numberMatrixWidget': 'overview', - 'excerptWidget': 'overview', + "matrix1dWidget": "overview", + "matrix2dWidget": "overview", + "numberMatrixWidget": "overview", + "excerptWidget": "overview", # if not specified here, default is assumed to be list } @@ -41,15 +32,14 @@ def migrate_widgets(**kwargs): if not widget.properties: widget.properties = {} - if not widget.properties.get('added_from'): - widget.properties['added_from'] = \ - default_added_from.get(widget.widget_id, 'list') + if not widget.properties.get("added_from"): + widget.properties["added_from"] = default_added_from.get(widget.widget_id, "list") - widget_data = widget.properties.get('data') + widget_data = widget.properties.get("data") w = widgets.get(widget.widget_id) if widget_data and w: - widget.properties['data'] = w.migrate_widget(widget_data) + widget.properties["data"] = w.migrate_widget(widget_data) widget.save() @@ -71,18 +61,18 @@ def migrate_attributes(**kwargs): def migrate(*args, **kwargs): - if not kwargs.get('filters_file'): + if not kwargs.get("filters_file"): project_filters = {} widget_filters = {} attributes_filters = {} ary_filters = {} else: - with open(kwargs['filters_file']) as f: + with open(kwargs["filters_file"]) as f: filter_data = json.load(f) - project_filters = filter_data.get('project_filters', {}) - widget_filters = filter_data.get('widget_filters', {}) - attributes_filters = filter_data.get('attributes_filters', {}) - ary_filters = filter_data.get('ary_filters', {}) + project_filters = filter_data.get("project_filters", {}) + widget_filters = filter_data.get("widget_filters", {}) + attributes_filters = filter_data.get("attributes_filters", {}) + ary_filters = filter_data.get("ary_filters", {}) with reversion.create_revision(): migrate_projects(**project_filters) diff --git a/apps/bulk_data_migration/v1_2/number_matrix.py b/apps/bulk_data_migration/v1_2/number_matrix.py index d637427bb2..bbd6a3b071 100644 --- a/apps/bulk_data_migration/v1_2/number_matrix.py +++ b/apps/bulk_data_migration/v1_2/number_matrix.py @@ -3,9 +3,7 @@ def migrate_widget(widget_data): def migrate_attribute(data): - if data.get('value'): + if data.get("value"): return data - return { - 'value': data - } + return {"value": data} diff --git a/apps/bulk_data_migration/v1_2/organigram.py b/apps/bulk_data_migration/v1_2/organigram.py index 3350d2859e..4a2a1ad30b 100644 --- a/apps/bulk_data_migration/v1_2/organigram.py +++ b/apps/bulk_data_migration/v1_2/organigram.py @@ -4,12 +4,12 @@ def migrate_widget(widget_data): def migrate_val(v): if isinstance(v, dict): - return v['id'] + return v["id"] return v def migrate_attribute(data): - value = data.get('values') or [] + value = data.get("values") or [] return { - 'value': [migrate_val(v) for v in value], + "value": [migrate_val(v) for v in value], } diff --git a/apps/bulk_data_migration/v1_2/scale.py b/apps/bulk_data_migration/v1_2/scale.py index 6bcc939ce2..17cb972adb 100644 --- a/apps/bulk_data_migration/v1_2/scale.py +++ b/apps/bulk_data_migration/v1_2/scale.py @@ -1,16 +1,14 @@ def migrate_widget(widget_data): - units = widget_data.get('scale_units') or [] + units = widget_data.get("scale_units") or [] for unit in units: - if unit.get('title'): - unit['label'] = unit['title'] - unit.pop('title') + if unit.get("title"): + unit["label"] = unit["title"] + unit.pop("title") return widget_data def migrate_attribute(data): - if data.get('value'): + if data.get("value"): return data - return { - 'value': data.get('selected_scale') - } + return {"value": data.get("selected_scale")} diff --git a/apps/category_editor/admin.py b/apps/category_editor/admin.py index 2a6bc7edd9..e862822118 100644 --- a/apps/category_editor/admin.py +++ b/apps/category_editor/admin.py @@ -1,9 +1,12 @@ +from category_editor.models import CategoryEditor from django.contrib import admin from reversion.admin import VersionAdmin -from category_editor.models import CategoryEditor @admin.register(CategoryEditor) class CategoryEditorAdmin(VersionAdmin): - search_fields = ('title',) - autocomplete_fields = ('created_by', 'modified_by',) + search_fields = ("title",) + autocomplete_fields = ( + "created_by", + "modified_by", + ) diff --git a/apps/category_editor/apps.py b/apps/category_editor/apps.py index cc00169b22..19db499f0b 100644 --- a/apps/category_editor/apps.py +++ b/apps/category_editor/apps.py @@ -2,4 +2,4 @@ class CategoryEditorConfig(AppConfig): - name = 'category_editor' + name = "category_editor" diff --git a/apps/category_editor/models.py b/apps/category_editor/models.py index 7f376c656c..62bb8d6bfb 100644 --- a/apps/category_editor/models.py +++ b/apps/category_editor/models.py @@ -13,11 +13,14 @@ def clone(self, user, overrides={}): """ Clone category editor """ - title = overrides.get('title', '{} (cloned)'.format( - # Allowing addition of ' (cloned)' to charfield with maxlen 255 - # by stripping off extra chars - self.title[:230] - )) + title = overrides.get( + "title", + "{} (cloned)".format( + # Allowing addition of ' (cloned)' to charfield with maxlen 255 + # by stripping off extra chars + self.title[:230] + ), + ) category_editor = CategoryEditor( title=title, data=self.data, @@ -35,9 +38,7 @@ def get_for(user): it's project """ return CategoryEditor.objects.filter( - models.Q(project=None) | - models.Q(project__members=user) | - models.Q(project__user_groups__members=user) + models.Q(project=None) | models.Q(project__members=user) | models.Q(project__user_groups__members=user) ).distinct() def can_get(self, user): @@ -51,10 +52,11 @@ def can_modify(self, user): * it belongs to a project where the user is admin """ import project + return ( - self.created_by == user or - user.is_superuser or - project.models.ProjectMembership.objects.filter( + self.created_by == user + or user.is_superuser + or project.models.ProjectMembership.objects.filter( project__in=self.project_set.all(), member=user, role__in=project.models.ProjectRole.get_admin_roles(), diff --git a/apps/category_editor/serializers.py b/apps/category_editor/serializers.py index bdd7af874d..0b01965e24 100644 --- a/apps/category_editor/serializers.py +++ b/apps/category_editor/serializers.py @@ -1,53 +1,49 @@ +from category_editor.models import CategoryEditor from drf_dynamic_fields import DynamicFieldsMixin +from project.models import Project from rest_framework import serializers - -from deep.serializers import RemoveNullFieldsMixin from user_resource.serializers import UserResourceSerializer -from project.models import Project -from category_editor.models import CategoryEditor +from deep.serializers import RemoveNullFieldsMixin -class CategoryEditorSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, UserResourceSerializer): +class CategoryEditorSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): is_admin = serializers.SerializerMethodField() project = serializers.IntegerField( write_only=True, required=False, ) projects = serializers.PrimaryKeyRelatedField( - source='project_set.all', + source="project_set.all", read_only=True, many=True, ) class Meta: model = CategoryEditor - fields = ('__all__') + fields = "__all__" def validate_project(self, project): try: project = Project.objects.get(id=project) except Project.DoesNotExist: - raise serializers.ValidationError( - 'Project matching query does not exist' - ) + raise serializers.ValidationError("Project matching query does not exist") - if not project.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid project') + if not project.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid project") return project.id def create(self, validated_data): - project = validated_data.pop('project', None) + project = validated_data.pop("project", None) ce = super().create(validated_data) if project: project = Project.objects.get(id=project) project.category_editor = ce - project.modified_by = self.context['request'].user + project.modified_by = self.context["request"].user project.save() return ce def get_is_admin(self, category_editor): - return category_editor.can_modify(self.context['request'].user) + return category_editor.can_modify(self.context["request"].user) diff --git a/apps/category_editor/tests/test_apis.py b/apps/category_editor/tests/test_apis.py index 94a03e2343..efedbb02a2 100644 --- a/apps/category_editor/tests/test_apis.py +++ b/apps/category_editor/tests/test_apis.py @@ -1,17 +1,18 @@ -from deep.tests import TestCase from category_editor.models import CategoryEditor from project.models import Project +from deep.tests import TestCase + class CategoryEditorTests(TestCase): def test_create_category_editor(self): project = self.create(Project, role=self.admin_role) ce_count = CategoryEditor.objects.count() - url = '/api/v1/category-editors/' + url = "/api/v1/category-editors/" data = { - 'title': 'New Category Editor', - 'project': project.id, + "title": "New Category Editor", + "project": project.id, } self.authenticate() @@ -20,45 +21,40 @@ def test_create_category_editor(self): self.assertEqual(CategoryEditor.objects.count(), ce_count + 1) project = Project.objects.get(id=project.id) - self.assertEqual(project.category_editor.id, response.data['id']) + self.assertEqual(project.category_editor.id, response.data["id"]) def test_clone_category_editor(self): category_editor = self.create(CategoryEditor) - project = self.create( - Project, category_editor=category_editor, - role=self.admin_role - ) + project = self.create(Project, category_editor=category_editor, role=self.admin_role) - url = '/api/v1/clone-category-editor/{}/'.format(category_editor.id) + url = "/api/v1/clone-category-editor/{}/".format(category_editor.id) data = { - 'project': project.id, + "project": project.id, } self.authenticate() response = self.client.post(url, data) self.assert_201(response) - self.assertNotEqual(response.data['id'], category_editor.id) - self.assertEqual(response.data['title'], - category_editor.title[:230] + ' (cloned)') + self.assertNotEqual(response.data["id"], category_editor.id) + self.assertEqual(response.data["title"], category_editor.title[:230] + " (cloned)") project = Project.objects.get(id=project.id) - self.assertNotEqual(project.category_editor.id, - category_editor.id) + self.assertNotEqual(project.category_editor.id, category_editor.id) - self.assertEqual(project.category_editor.id, response.data['id']) + self.assertEqual(project.category_editor.id, response.data["id"]) def test_classify(self): ce_data = { - 'categories': [ + "categories": [ { - 'title': 'Sector', - 'subcategories': [ + "title": "Sector", + "subcategories": [ { - 'title': 'WASH', - 'ngrams': { - 1: ['affected', 'water'], - 2: ['affected not', 'water not'], + "title": "WASH", + "ngrams": { + 1: ["affected", "water"], + 2: ["affected not", "water not"], }, }, ], @@ -67,17 +63,13 @@ def test_classify(self): } category_editor = self.create(CategoryEditor, data=ce_data) - project = self.create( - Project, category_editor=category_editor, - role=self.admin_role) - - text = 'My water aaloooo' - url = '/api/v1/projects/{}/category-editor/classify/'.format( - project.id - ) + project = self.create(Project, category_editor=category_editor, role=self.admin_role) + + text = "My water aaloooo" + url = "/api/v1/projects/{}/category-editor/classify/".format(project.id) data = { - 'text': text, - 'category': 'sector', + "text": text, + "category": "sector", } self.authenticate() @@ -86,15 +78,13 @@ def test_classify(self): expected = [ { - 'title': 'WASH', - 'keywords': [ - {'start': 3, 'length': 5, 'subcategory': 'WASH'}, + "title": "WASH", + "keywords": [ + {"start": 3, "length": 5, "subcategory": "WASH"}, ], }, ] - got = [dict(c) for c in response.data.get('classifications')] + got = [dict(c) for c in response.data.get("classifications")] for g in got: - g['keywords'] = [ - dict(k) for k in g['keywords'] - ] + g["keywords"] = [dict(k) for k in g["keywords"]] self.assertEqual(got, expected) diff --git a/apps/category_editor/views.py b/apps/category_editor/views.py index f3029de351..138467e563 100644 --- a/apps/category_editor/views.py +++ b/apps/category_editor/views.py @@ -1,25 +1,18 @@ -from rest_framework import ( - exceptions, - permissions, - response, - status, - views, - viewsets, -) -from deep.permissions import ModifyPermission +import re -from project.models import Project from lead.models import LeadPreview +from project.models import Project +from rest_framework import exceptions, permissions, response, status, views, viewsets + +from deep.permissions import ModifyPermission + from .models import CategoryEditor from .serializers import CategoryEditorSerializer -import re - class CategoryEditorViewSet(viewsets.ModelViewSet): serializer_class = CategoryEditorSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_queryset(self): return CategoryEditor.get_for(self.request.user) @@ -29,14 +22,10 @@ class CategoryEditorCloneView(views.APIView): permission_classes = [permissions.IsAuthenticated] def post(self, request, ce_id, version=None): - if not CategoryEditor.objects.filter( - id=ce_id - ).exists(): + if not CategoryEditor.objects.filter(id=ce_id).exists(): raise exceptions.NotFound() - category_editor = CategoryEditor.objects.get( - id=ce_id - ) + category_editor = CategoryEditor.objects.get(id=ce_id) if not category_editor.can_get(request.user): raise exceptions.PermissionDenied() @@ -46,16 +35,18 @@ def post(self, request, ce_id, version=None): ) serializer = CategoryEditorSerializer( new_ce, - context={'request': request}, + context={"request": request}, ) - project = request.data.get('project') + project = request.data.get("project") if project: project = Project.objects.get(id=project) if not project.can_modify(request.user): - raise exceptions.ValidationError({ - 'project': 'Invalid project', - }) + raise exceptions.ValidationError( + { + "project": "Invalid project", + } + ) project.category_editor = new_ce project.modified_by = request.user project.save() @@ -84,17 +75,17 @@ def post(self, request, project_id, version=None): if not ce_data: return response.Response( { - 'classifications': [], + "classifications": [], }, status=status.HTTP_200_OK, ) - text = request.data.get('text') - preview_id = request.data.get('preview_id') + text = request.data.get("text") + preview_id = request.data.get("preview_id") errors = {} if not text and not preview_id: - errors['text'] = 'Value not provided' - errors['preview_id'] = 'Value not provided' + errors["text"] = "Value not provided" + errors["preview_id"] = "Value not provided" if not text: text = LeadPreview.objects.get(id=preview_id).text_extract @@ -104,18 +95,18 @@ def post(self, request, project_id, version=None): raise exceptions.ValidationError(errors) classifications = [] - for category in ce_data.get('categories'): + for category in ce_data.get("categories"): self._classify(ce_data, category, text, classifications) return response.Response( { - 'classifications': classifications, + "classifications": classifications, }, status=status.HTTP_200_OK, ) def _classify(self, ce_data, category, text, results): - subcategories = category.get('subcategories', []) + subcategories = category.get("subcategories", []) for subcategory in subcategories: self._process_subcategory(subcategory, text.lower(), results) @@ -123,31 +114,30 @@ def _classify(self, ce_data, category, text, results): return results def _process_subcategory(self, category, text, results): - title = category.get('title') - ngrams = category.get('ngrams', {}) + title = category.get("title") + ngrams = category.get("ngrams", {}) category_results = [] - results.append({ - 'title': title, - 'keywords': category_results, - }) + results.append( + { + "title": title, + "keywords": category_results, + } + ) for _, ngram in ngrams.items(): - [ - category_results.extend(self._search_word(title, word, text)) - for word in ngram - if word.lower() in text - ] + [category_results.extend(self._search_word(title, word, text)) for word in ngram if word.lower() in text] - subcategories = category.get('subcategories', []) + subcategories = category.get("subcategories", []) for subcategory in subcategories: self._process_subcategory(subcategory, text, results) def _search_word(self, title, word, text): return [ { - 'start': a.start(), - 'length': len(word), - 'subcategory': title, - } for a in list(re.finditer(word, text)) + "start": a.start(), + "length": len(word), + "subcategory": title, + } + for a in list(re.finditer(word, text)) ] diff --git a/apps/client_page_meta/admin.py b/apps/client_page_meta/admin.py index aa7eb75552..b0bde239fa 100644 --- a/apps/client_page_meta/admin.py +++ b/apps/client_page_meta/admin.py @@ -6,5 +6,5 @@ @admin.register(Page) class PageAdmin(VersionAdmin): - search_fields = ('title', 'page_id') - list_display = ('title', 'page_id', 'help_url') + search_fields = ("title", "page_id") + list_display = ("title", "page_id", "help_url") diff --git a/apps/client_page_meta/apps.py b/apps/client_page_meta/apps.py index bbb4e4f53e..51436e479f 100644 --- a/apps/client_page_meta/apps.py +++ b/apps/client_page_meta/apps.py @@ -2,4 +2,4 @@ class ClientPageMetaConfig(AppConfig): - name = 'client_page_meta' + name = "client_page_meta" diff --git a/apps/client_page_meta/models.py b/apps/client_page_meta/models.py index 89d7aa492e..0e452f8e51 100644 --- a/apps/client_page_meta/models.py +++ b/apps/client_page_meta/models.py @@ -7,4 +7,4 @@ class Page(models.Model): help_url = models.TextField() def __str__(self): - return '{} {}'.format(self.title, self.page_id) + return "{} {}".format(self.title, self.page_id) diff --git a/apps/client_page_meta/serializers.py b/apps/client_page_meta/serializers.py index 79c30698ab..e57a4ef472 100644 --- a/apps/client_page_meta/serializers.py +++ b/apps/client_page_meta/serializers.py @@ -6,4 +6,4 @@ class PageSerializer(serializers.ModelSerializer): class Meta: model = Page - fields = ('__all__') + fields = "__all__" diff --git a/apps/client_page_meta/views.py b/apps/client_page_meta/views.py index 9c6788321c..508e673beb 100644 --- a/apps/client_page_meta/views.py +++ b/apps/client_page_meta/views.py @@ -1,10 +1,10 @@ from rest_framework import viewsets -from .serializers import PageSerializer from .models import Page +from .serializers import PageSerializer class PageViewSet(viewsets.ReadOnlyModelViewSet): queryset = Page.objects.all() serializer_class = PageSerializer - lookup_field = 'page_id' + lookup_field = "page_id" diff --git a/apps/commons/apps.py b/apps/commons/apps.py index 90c6dcdf93..f0fa26579e 100644 --- a/apps/commons/apps.py +++ b/apps/commons/apps.py @@ -2,7 +2,7 @@ class CommonsConfig(AppConfig): - name = 'commons' + name = "commons" def ready(self): import commons.receivers # noqa: F401 diff --git a/apps/commons/management/commands/run_celery_dev.py b/apps/commons/management/commands/run_celery_dev.py index cf15dc2687..56b94a6f21 100644 --- a/apps/commons/management/commands/run_celery_dev.py +++ b/apps/commons/management/commands/run_celery_dev.py @@ -1,23 +1,23 @@ -import shlex import os +import shlex import subprocess from django.core.management.base import BaseCommand from django.utils import autoreload -from deep.celery import CeleryQueue +from deep.celery import CeleryQueue -WORKER_STATE_DIR = '/var/run/celery' +WORKER_STATE_DIR = "/var/run/celery" CMD = ( f"celery -A deep worker -Q {','.join(CeleryQueue.ALL_QUEUES)} -B --concurrency=2 -l info " - '--scheduler django_celery_beat.schedulers:DatabaseScheduler ' - f'--statedb={WORKER_STATE_DIR}/worker.state' + "--scheduler django_celery_beat.schedulers:DatabaseScheduler " + f"--statedb={WORKER_STATE_DIR}/worker.state" ) def restart_celery(*args, **kwargs): - kill_worker_cmd = 'pkill -9 celery' + kill_worker_cmd = "pkill -9 celery" subprocess.call(shlex.split(kill_worker_cmd)) subprocess.call(shlex.split(CMD)) @@ -25,7 +25,7 @@ def restart_celery(*args, **kwargs): class Command(BaseCommand): def handle(self, *args, **options): - self.stdout.write('Starting celery worker with autoreload...') + self.stdout.write("Starting celery worker with autoreload...") if not os.path.exists(WORKER_STATE_DIR): os.makedirs(WORKER_STATE_DIR) autoreload.run_with_reloader(restart_celery, args=None, kwargs=None) diff --git a/apps/commons/receivers.py b/apps/commons/receivers.py index 0ee99c8411..aaf8388fe7 100644 --- a/apps/commons/receivers.py +++ b/apps/commons/receivers.py @@ -2,11 +2,7 @@ from django.db import models from django.db.transaction import on_commit from django.dispatch import receiver - -from lead.models import ( - LeadPreview, - LeadPreviewImage, -) +from lead.models import LeadPreview, LeadPreviewImage from unified_connector.models import ConnectorLeadPreviewImage diff --git a/apps/commons/schema_snapshots.py b/apps/commons/schema_snapshots.py index 376018d168..c59dcb6a65 100644 --- a/apps/commons/schema_snapshots.py +++ b/apps/commons/schema_snapshots.py @@ -1,7 +1,7 @@ import typing -from django.test import override_settings from django.core.files.base import ContentFile +from django.test import override_settings from utils.files import generate_json_file_for_upload @@ -182,7 +182,7 @@ class DeepExplore: """ class AnalysisReport: - SnapshotFragment = ''' + SnapshotFragment = """ fragment OrganizationGeneralResponse on OrganizationType { id title @@ -645,10 +645,10 @@ class AnalysisReport: } } } - ''' + """ Snapshot = ( - SnapshotFragment + - '''\n + SnapshotFragment + + """\n query MyQuery($projectID: ID!, $reportID: ID!) { project(id: $projectID) { analysisReport(id: $reportID) { @@ -656,7 +656,7 @@ class AnalysisReport: } } } - ''' + """ ) @@ -666,9 +666,9 @@ class DummyContext: @override_settings( CACHES={ - 'default': { - 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', - 'LOCATION': 'unique-snowflake', + "default": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + "LOCATION": "unique-snowflake", } }, ) @@ -677,17 +677,13 @@ def generate_query_snapshot( variables: dict, data_callback: typing.Callable = lambda x: x, context: typing.Optional[object] = None, -) -> \ - typing.Tuple[typing.Optional[ContentFile], typing.Optional[dict]]: +) -> typing.Tuple[typing.Optional[ContentFile], typing.Optional[dict]]: # To avoid circular dependency from deep.schema import schema as gql_schema + if context is None: context = DummyContext() - result = gql_schema.execute( - query, - context=context, - variables=variables - ) + result = gql_schema.execute(query, context=context, variables=variables) if result.errors: return None, result.errors return generate_json_file_for_upload(data_callback(result.data)), None diff --git a/apps/commons/tests/test_common.py b/apps/commons/tests/test_common.py index 6fb17c411b..6aa11072d2 100644 --- a/apps/commons/tests/test_common.py +++ b/apps/commons/tests/test_common.py @@ -1,4 +1,5 @@ from django.conf import settings + from deep.tests import TestCase diff --git a/apps/commons/views.py b/apps/commons/views.py index c44cef0965..0664008026 100644 --- a/apps/commons/views.py +++ b/apps/commons/views.py @@ -4,31 +4,25 @@ import random import string -from django.views.generic import View from django.http import FileResponse, HttpResponse - - +from django.views.generic import View from geo.models import GeoArea -from tabular.viz import ( - barchart, - histograms, - map as _map, -) - +from tabular.viz import barchart, histograms +from tabular.viz import map as _map logger = logging.getLogger(__name__) try: import pandas as pd except ImportError as e: - logger.warning(f'ImportError: {e}') + logger.warning(f"ImportError: {e}") -STRINGS = string.ascii_uppercase + string.digits + 'चैनपुर नगरपालिका à€' +STRINGS = string.ascii_uppercase + string.digits + "चैनपुर नगरपालिका à€" def _get_random_string(N): - return ''.join(random.choice(STRINGS) for _ in range(N)) + return "".join(random.choice(STRINGS) for _ in range(N)) def _get_random_number(min, max): @@ -36,18 +30,20 @@ def _get_random_number(min, max): def _get_image_response(fp, image_format): - if image_format == 'svg': + if image_format == "svg": return HttpResponse( - ''' + """ {} - '''.format(fp.read().decode('utf-8')) + """.format( + fp.read().decode("utf-8") + ) ) - return FileResponse(fp, content_type='image/png') + return FileResponse(fp, content_type="image/png") class RenderChart(View): @@ -55,6 +51,7 @@ class RenderChart(View): Debug chart rendering NOTE: Use Only For Debug """ + MAX_VALUE_LEN = 1000 MAX_VALUE_INTEGER = 100 MAX_ROW = 10 @@ -63,55 +60,58 @@ class RenderChart(View): def get_geo_data(self): return [ { - 'count': _get_random_number(1, self.MAX_COUNT), - 'value': geoarea.id, - } for geoarea in GeoArea.objects.filter(admin_level_id=2) + "count": _get_random_number(1, self.MAX_COUNT), + "value": geoarea.id, + } + for geoarea in GeoArea.objects.filter(admin_level_id=2) ] def get_data(self, number=False): return [ { - 'count': _get_random_number(10, self.MAX_COUNT), - 'value': _get_random_number(10, self.MAX_VALUE_INTEGER) - if number else _get_random_string(self.MAX_VALUE_LEN), - } for row in range(self.MAX_ROW) + "count": _get_random_number(10, self.MAX_COUNT), + "value": _get_random_number(10, self.MAX_VALUE_INTEGER) if number else _get_random_string(self.MAX_VALUE_LEN), + } + for row in range(self.MAX_ROW) ] def get(self, request): - image_format = request.GET.get('format', 'png') - chart_type = request.GET.get('chart_type', 'barchart') - if chart_type in ['histograms']: + image_format = request.GET.get("format", "png") + chart_type = request.GET.get("chart_type", "barchart") + if chart_type in ["histograms"]: df = pd.DataFrame(self.get_data(number=True)) - elif chart_type in ['map']: + elif chart_type in ["map"]: df = pd.DataFrame(self.get_geo_data()) else: df = pd.DataFrame(self.get_data()) params = { - 'x_label': 'Test Label', - 'y_label': 'count', - 'chart_size': (8, 4), - 'data': df, - 'format': image_format, + "x_label": "Test Label", + "y_label": "count", + "chart_size": (8, 4), + "data": df, + "format": image_format, } - if chart_type == 'barchart': - params['data']['value'] = params['data']['value'].str.slice(0, 20) + '...' + if chart_type == "barchart": + params["data"]["value"] = params["data"]["value"].str.slice(0, 20) + "..." fp = barchart.plotly(**params) - elif chart_type == 'histograms': + elif chart_type == "histograms": new_data = [] - values = df['value'].tolist() - counts = df['count'].tolist() + values = df["value"].tolist() + counts = df["count"].tolist() for index, value in enumerate(values): new_data.extend([value for i in range(counts[index])]) - params['data'] = pd.to_numeric(new_data) + params["data"] = pd.to_numeric(new_data) fp = histograms.plotly(**params) - elif chart_type == 'map': - adjust_df = pd.DataFrame([ - {'value': 0, 'count': 0}, # Count 0 is min's max value - {'value': 0, 'count': 5}, # Count 5 is max's min value - ]) - params['data'] = params['data'].append(adjust_df, ignore_index=True) + elif chart_type == "map": + adjust_df = pd.DataFrame( + [ + {"value": 0, "count": 0}, # Count 0 is min's max value + {"value": 0, "count": 5}, # Count 5 is max's min value + ] + ) + params["data"] = params["data"].append(adjust_df, ignore_index=True) fp = _map.plot(**params) - return _get_image_response(fp[0]['image'], fp[0]['format']) + return _get_image_response(fp[0]["image"], fp[0]["format"]) diff --git a/apps/connector/admin.py b/apps/connector/admin.py index 8306601498..9f09557e75 100644 --- a/apps/connector/admin.py +++ b/apps/connector/admin.py @@ -1,17 +1,24 @@ +from connector.models import Connector, ConnectorSource, EMMConfig from django.contrib import admin from deep.admin import VersionAdmin -from connector.models import Connector, EMMConfig, ConnectorSource @admin.register(Connector) class ConnectorAdmin(VersionAdmin): - autocomplete_fields = ('created_by', 'modified_by',) + autocomplete_fields = ( + "created_by", + "modified_by", + ) @admin.register(EMMConfig) class EMMConfigAdmin(VersionAdmin): - list_display = ('entity_tag', 'trigger_tag', 'trigger_attribute',) + list_display = ( + "entity_tag", + "trigger_tag", + "trigger_attribute", + ) admin.site.register(ConnectorSource) diff --git a/apps/connector/apps.py b/apps/connector/apps.py index 49b7b05338..c3fb3440b7 100644 --- a/apps/connector/apps.py +++ b/apps/connector/apps.py @@ -3,18 +3,20 @@ class ConnectorConfig(AppConfig): - name = 'connector' + name = "connector" def ready(self): from connector.models import ConnectorSource - from .sources.store import source_store + from utils.common import kebabcase_to_titlecase + from .sources.store import source_store + try: for key in source_store.keys(): ConnectorSource.objects.get_or_create( key=key, - defaults={'title': kebabcase_to_titlecase(key)}, + defaults={"title": kebabcase_to_titlecase(key)}, ) except ProgrammingError: # Because, ready() is called before the migration to create ConnectorSource table is run diff --git a/apps/connector/management/commands/create_connector_sources.py b/apps/connector/management/commands/create_connector_sources.py index 6da143c499..c8a86d1bf9 100644 --- a/apps/connector/management/commands/create_connector_sources.py +++ b/apps/connector/management/commands/create_connector_sources.py @@ -1,7 +1,7 @@ -from django.core.management.base import BaseCommand - from connector.models import ConnectorSource from connector.sources.store import source_store +from django.core.management.base import BaseCommand + from utils.common import kebabcase_to_titlecase @@ -9,15 +9,16 @@ class Command(BaseCommand): """ This is a command to add connector sources if not already created. """ + def handle(self, *args, **kwargs): - print('Creating connector sources that are not created') + print("Creating connector sources that are not created") for key in source_store.keys(): obj, created = ConnectorSource.objects.get_or_create( key=key, - defaults={'title': kebabcase_to_titlecase(key)}, + defaults={"title": kebabcase_to_titlecase(key)}, ) if created: - print(f'Created source for {key}') + print(f"Created source for {key}") else: - print(f'Source for {key} already exists.') - print('Done') + print(f"Source for {key} already exists.") + print("Done") diff --git a/apps/connector/models.py b/apps/connector/models.py index e56e0b9a04..26b70a6a06 100644 --- a/apps/connector/models.py +++ b/apps/connector/models.py @@ -1,20 +1,18 @@ from django.db import models - +from project.models import Project +from user.models import User from user_resource.models import UserResource from utils.common import is_valid_regex -from project.models import Project -from user.models import User - class ConnectorSource(models.Model): - STATUS_BROKEN = 'broken' - STATUS_WORKING = 'working' + STATUS_BROKEN = "broken" + STATUS_WORKING = "working" STATUS_CHOICES = ( - (STATUS_BROKEN, 'Broken'), - (STATUS_WORKING, 'Working'), + (STATUS_BROKEN, "Broken"), + (STATUS_WORKING, "Working"), ) key = models.CharField(max_length=100, primary_key=True) @@ -39,10 +37,8 @@ class Connector(UserResource): ) params = models.JSONField(default=None, blank=True, null=True) - users = models.ManyToManyField(User, blank=True, - through='ConnectorUser') - projects = models.ManyToManyField(Project, blank=True, - through='ConnectorProject') + users = models.ManyToManyField(User, blank=True, through="ConnectorUser") + projects = models.ManyToManyField(Project, blank=True, through="ConnectorProject") def __str__(self): return self.title @@ -50,9 +46,7 @@ def __str__(self): @staticmethod def get_for(user): return Connector.objects.filter( - models.Q(users=user) | - models.Q(projects__members=user) | - models.Q(projects__user_groups__members=user) + models.Q(users=user) | models.Q(projects__members=user) | models.Q(projects__user_groups__members=user) ).distinct() def can_get(self, user): @@ -62,10 +56,10 @@ def can_modify(self, user): return ConnectorUser.objects.filter( connector=self, user=user, - role='admin', + role="admin", ).exists() - def add_member(self, user, role='normal'): + def add_member(self, user, role="normal"): return ConnectorUser.objects.create( user=user, role=role, @@ -79,21 +73,20 @@ class ConnectorUser(models.Model): """ ROLES = ( - ('normal', 'Normal'), - ('admin', 'Admin'), + ("normal", "Normal"), + ("admin", "Admin"), ) user = models.ForeignKey(User, on_delete=models.CASCADE) connector = models.ForeignKey(Connector, on_delete=models.CASCADE) - role = models.CharField(max_length=96, choices=ROLES, - default='normal') + role = models.CharField(max_length=96, choices=ROLES, default="normal") added_at = models.DateTimeField(auto_now_add=True) def __str__(self): - return '{} @ {}'.format(str(self.user), self.connector.title) + return "{} @ {}".format(str(self.user), self.connector.title) class Meta: - unique_together = ('user', 'connector') + unique_together = ("user", "connector") @staticmethod def get_for(user): @@ -112,21 +105,20 @@ class ConnectorProject(models.Model): """ ROLES = ( - ('self', 'For self only'), - ('global', 'For all members of project'), + ("self", "For self only"), + ("global", "For all members of project"), ) project = models.ForeignKey(Project, on_delete=models.CASCADE) connector = models.ForeignKey(Connector, on_delete=models.CASCADE) - role = models.CharField(max_length=96, choices=ROLES, - default='self') + role = models.CharField(max_length=96, choices=ROLES, default="self") added_at = models.DateTimeField(auto_now_add=True) def __str__(self): - return '{} @ {}'.format(str(self.project), self.connector.title) + return "{} @ {}".format(str(self.project), self.connector.title) class Meta: - unique_together = ('project', 'connector') + unique_together = ("project", "connector") @staticmethod def get_for(user): @@ -139,11 +131,11 @@ def can_modify(self, user): return self.connector.can_modify(user) -EMM_SEPARATOR_DEFAULT = ';' -EMM_TRIGGER_REGEX_DEFAULT = r'(\((?P[a-zA-Z ]+)\)){0,1}(?P[a-zA-Z ]+)\[(?P\d+)]' -EMM_ENTITY_TAG_DEFAULT = 'emm:entity' -EMM_TRIGGER_TAG_DEFAULT = 'category' -EMM_TRIGGER_ATTRIBUTE_DEFAULT = 'emm:trigger' +EMM_SEPARATOR_DEFAULT = ";" +EMM_TRIGGER_REGEX_DEFAULT = r"(\((?P[a-zA-Z ]+)\)){0,1}(?P[a-zA-Z ]+)\[(?P\d+)]" +EMM_ENTITY_TAG_DEFAULT = "emm:entity" +EMM_TRIGGER_TAG_DEFAULT = "category" +EMM_TRIGGER_ATTRIBUTE_DEFAULT = "emm:trigger" class EMMConfig(models.Model): @@ -154,12 +146,12 @@ class EMMConfig(models.Model): trigger_attribute = models.CharField(max_length=50, default=EMM_TRIGGER_ATTRIBUTE_DEFAULT) def __str__(self): - return f'{self.entity_tag}:{self.trigger_tag}:{self.trigger_attribute}' + return f"{self.entity_tag}:{self.trigger_tag}:{self.trigger_attribute}" # Just Allow to have a single config def save(self, *args, **kwargs): self.pk = 1 # Check if valid regex if not is_valid_regex(self.trigger_regex): - raise Exception(f'{self.trigger_regex} is not a valid Regular Expression') + raise Exception(f"{self.trigger_regex} is not a valid Regular Expression") super().save(*args, **kwargs) diff --git a/apps/connector/serializers.py b/apps/connector/serializers.py index 86f4909ce8..875fc17b1b 100644 --- a/apps/connector/serializers.py +++ b/apps/connector/serializers.py @@ -1,35 +1,24 @@ from drf_dynamic_fields import DynamicFieldsMixin +from lead.models import Lead +from lead.views import check_if_url_exists +from organization.serializers import SimpleOrganizationSerializer from rest_framework import serializers +from user_resource.serializers import UserResourceSerializer from deep.serializers import RemoveNullFieldsMixin -from organization.serializers import SimpleOrganizationSerializer -from user_resource.serializers import UserResourceSerializer -from lead.models import Lead -from lead.views import check_if_url_exists +from .models import Connector, ConnectorProject, ConnectorSource, ConnectorUser from .sources.store import source_store -from .models import ( - Connector, - ConnectorSource, - ConnectorUser, - ConnectorProject, -) -class SourceOptionSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, - serializers.Serializer): +class SourceOptionSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.Serializer): key = serializers.CharField() field_type = serializers.CharField() title = serializers.CharField() - options = serializers.ListField( - serializers.DictField(serializers.CharField) - ) + options = serializers.ListField(serializers.DictField(serializers.CharField)) -class SourceSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, - serializers.Serializer): +class SourceSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.Serializer): title = serializers.CharField() key = serializers.CharField() options = SourceOptionSerializer(many=True) @@ -54,10 +43,9 @@ class SourceEMMTriggerSerializer(serializers.Serializer): count = serializers.IntegerField() -class SourceDataSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): +class SourceDataSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): existing = serializers.SerializerMethodField() - key = serializers.CharField(source='id') + key = serializers.CharField(source="id") emm_entities = serializers.SerializerMethodField() emm_triggers = serializers.SerializerMethodField() @@ -66,134 +54,131 @@ class SourceDataSerializer(RemoveNullFieldsMixin, authors_detail = serializers.SerializerMethodField() # TODO: Remove (Legacy) - author_detail = SimpleOrganizationSerializer(source='author', read_only=True) + author_detail = SimpleOrganizationSerializer(source="author", read_only=True) - source_detail = SimpleOrganizationSerializer(source='source', read_only=True) + source_detail = SimpleOrganizationSerializer(source="source", read_only=True) published_on = serializers.DateField(read_only=True) class Meta: model = Lead fields = ( - 'key', 'title', 'source', 'source_type', 'url', - 'published_on', 'existing', - 'emm_entities', 'emm_triggers', 'source_detail', - 'author_detail', 'authors', 'authors_detail', - 'source_raw', 'author_raw', + "key", + "title", + "source", + "source_type", + "url", + "published_on", + "existing", + "emm_entities", + "emm_triggers", + "source_detail", + "author_detail", + "authors", + "authors_detail", + "source_raw", + "author_raw", ) def get_authors(self, lead): - if hasattr(lead, '_authors'): + if hasattr(lead, "_authors"): return [author.pk for author in lead._authors] return [] def get_authors_detail(self, lead): - if hasattr(lead, '_authors'): + if hasattr(lead, "_authors"): return SimpleOrganizationSerializer(lead._authors, many=True).data return [] def get_emm_entities(self, lead): - if hasattr(lead, '_emm_entities'): + if hasattr(lead, "_emm_entities"): return SourceEMMEntitiesSerializer(lead._emm_entities, many=True).data return [] def get_emm_triggers(self, lead): - if hasattr(lead, '_emm_triggers'): + if hasattr(lead, "_emm_triggers"): return SourceEMMTriggerSerializer(lead._emm_triggers, many=True).data return [] def get_existing(self, lead): - if not self.context.get('request'): + if not self.context.get("request"): return False - return check_if_url_exists(lead.url, - self.context['request'].user, - self.context.get('project')) + return check_if_url_exists(lead.url, self.context["request"].user, self.context.get("project")) -class ConnectorUserSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, - serializers.ModelSerializer): - email = serializers.CharField(source='user.email', read_only=True) +class ConnectorUserSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): + email = serializers.CharField(source="user.email", read_only=True) display_name = serializers.CharField( - source='user.profile.get_display_name', + source="user.profile.get_display_name", read_only=True, ) class Meta: model = ConnectorUser - fields = ('id', 'user', 'display_name', 'email', - 'connector', 'role', 'added_at') + fields = ("id", "user", "display_name", "email", "connector", "role", "added_at") def get_unique_together_validators(self): return [] # Validations def validate_connector(self, connector): - if not connector.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid connector') + if not connector.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid connector") return connector -class ConnectorProjectSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, - serializers.ModelSerializer): - title = serializers.CharField(source='project.title', - read_only=True) +class ConnectorProjectSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): + title = serializers.CharField(source="project.title", read_only=True) class Meta: model = ConnectorProject - fields = ('id', 'project', 'title', - 'connector', 'role', 'added_at') + fields = ("id", "project", "title", "connector", "role", "added_at") def get_unique_together_validators(self): return [] # Validations def validate_connector(self, connector): - if not connector.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid connector') + if not connector.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid connector") return connector -class ConnectorSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, UserResourceSerializer): +class ConnectorSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): users = ConnectorUserSerializer( - source='connectoruser_set', + source="connectoruser_set", many=True, required=False, ) projects = ConnectorProjectSerializer( - source='connectorproject_set', + source="connectorproject_set", many=True, required=False, ) source = serializers.PrimaryKeyRelatedField(queryset=ConnectorSource.objects.all()) - source_title = serializers.CharField(source='source.title', read_only=True) + source_title = serializers.CharField(source="source.title", read_only=True) role = serializers.SerializerMethodField() filters = serializers.SerializerMethodField() - status = serializers.CharField(source='source.status', read_only=True) + status = serializers.CharField(source="source.status", read_only=True) class Meta: model = Connector - fields = ('__all__') + fields = "__all__" def create(self, validated_data): connector = super().create(validated_data) ConnectorUser.objects.create( connector=connector, - user=self.context['request'].user, - role='admin', + user=self.context["request"].user, + role="admin", ) return connector def get_role(self, connector): - request = self.context['request'] - user = request.GET.get('user', request.user) + request = self.context["request"] + user = request.GET.get("user", request.user) - usership = ConnectorUser.objects.filter( - connector=connector, - user=user - ).first() + usership = ConnectorUser.objects.filter(connector=connector, user=user).first() if usership: return usership.role @@ -201,6 +186,6 @@ def get_role(self, connector): def get_filters(self, connector): source = source_store[connector.source.key]() - if not hasattr(source, 'filters'): + if not hasattr(source, "filters"): return [] return source.filters diff --git a/apps/connector/sources/store.py b/apps/connector/sources/store.py index a56dfaa976..c669a339a6 100644 --- a/apps/connector/sources/store.py +++ b/apps/connector/sources/store.py @@ -1,31 +1,33 @@ +import random from collections import OrderedDict + from unified_connector.sources import ( - atom_feed, - rss_feed, acaps_briefing_notes, - unhcr_portal, - relief_web, + atom_feed, + emm, + humanitarian_response, pdna, + relief_web, research_center, + rss_feed, + unhcr_portal, wpf, - humanitarian_response, - emm, ) -import random - -source_store = OrderedDict([ - (atom_feed.AtomFeed.key, atom_feed.AtomFeed), - ('rss-feed', rss_feed.RssFeed), - ('emm', emm.EMM), - ('acaps-briefing-notes', acaps_briefing_notes.AcapsBriefingNotes), - ('unhcr-portal', unhcr_portal.UNHCRPortal), - ('relief-web', relief_web.ReliefWeb), - ('post-disaster-needs-assessment', pdna.PDNA), - ('research-resource-center', research_center.ResearchResourceCenter), - ('world-food-programme', wpf.WorldFoodProgramme), - ('humanitarian-response', humanitarian_response.HumanitarianResponse), -]) +source_store = OrderedDict( + [ + (atom_feed.AtomFeed.key, atom_feed.AtomFeed), + ("rss-feed", rss_feed.RssFeed), + ("emm", emm.EMM), + ("acaps-briefing-notes", acaps_briefing_notes.AcapsBriefingNotes), + ("unhcr-portal", unhcr_portal.UNHCRPortal), + ("relief-web", relief_web.ReliefWeb), + ("post-disaster-needs-assessment", pdna.PDNA), + ("research-resource-center", research_center.ResearchResourceCenter), + ("world-food-programme", wpf.WorldFoodProgramme), + ("humanitarian-response", humanitarian_response.HumanitarianResponse), + ] +) sources = None diff --git a/apps/connector/tests/connector_content_mock_data.py b/apps/connector/tests/connector_content_mock_data.py index 290cbd66db..d6cada5fdc 100644 --- a/apps/connector/tests/connector_content_mock_data.py +++ b/apps/connector/tests/connector_content_mock_data.py @@ -1,4 +1,4 @@ -RSS_FEED_MOCK_DATA = ''' +RSS_FEED_MOCK_DATA = """ @@ -459,4 +459,6 @@ -'''.encode('utf-8') +""".encode( + "utf-8" +) diff --git a/apps/connector/tests/test_apis.py b/apps/connector/tests/test_apis.py index 04820e2c05..c8937527cc 100644 --- a/apps/connector/tests/test_apis.py +++ b/apps/connector/tests/test_apis.py @@ -1,18 +1,14 @@ from unittest.mock import patch -from deep.tests import TestCase -from user.models import User -from project.models import Project +from connector.models import Connector, ConnectorSource, ConnectorUser # EMMConfig, +from connector.sources import store +from connector.sources.store import acaps_briefing_notes, get_random_source from organization.models import Organization -from connector.sources.store import get_random_source, acaps_briefing_notes +from project.models import Project from unified_connector.sources.base import OrganizationSearch -from connector.models import ( - Connector, - ConnectorSource, - ConnectorUser, - # EMMConfig, -) -from connector.sources import store +from user.models import User + +from deep.tests import TestCase from .connector_content_mock_data import RSS_FEED_MOCK_DATA @@ -22,40 +18,39 @@ def get_source_object(key): SAMPLE_RSS_PARAMS = { - 'feed-url': 'https://reliefweb.int/country/afg/rss.xml?primary_country=16', - 'title-field': 'title', - 'source-field': 'source', - 'author-field': 'author', - 'date-field': 'pubDate', - 'url-field': 'link', + "feed-url": "https://reliefweb.int/country/afg/rss.xml?primary_country=16", + "title-field": "title", + "source-field": "source", + "author-field": "author", + "date-field": "pubDate", + "url-field": "link", } SAMPLE_ATOM_PARAMS = { - 'feed-url': 'https://feedly.com/f/Lmh0gtsFqdkr3hzoDFuOeass.atom?count=10', - 'title-field': 'title', - 'source-field': 'author', - 'author-field': 'author', - 'date-field': 'published', - 'url-field': 'link', + "feed-url": "https://feedly.com/f/Lmh0gtsFqdkr3hzoDFuOeass.atom?count=10", + "title-field": "title", + "source-field": "author", + "author-field": "author", + "date-field": "published", + "url-field": "link", } SAMPLE_EMM_PARAMS = { - 'feed-url': 'https://emm.newsbrief.eu/rss/rss?type=category&' - 'id=filter-FocusedMyanmarEW-Q&language=en&duplicates=false', - 'url-field': 'link', - 'date-field': 'pubDate', - 'source-field': 'source', - 'author-field': 'source', - 'title-field': 'title', + "feed-url": "https://emm.newsbrief.eu/rss/rss?type=category&" "id=filter-FocusedMyanmarEW-Q&language=en&duplicates=false", + "url-field": "link", + "date-field": "pubDate", + "source-field": "source", + "author-field": "source", + "title-field": "title", } class ConnectorApiTest(TestCase): def test_create_connector(self): - url = '/api/v1/connectors/' + url = "/api/v1/connectors/" data = { - 'title': 'Test connector', - 'source': get_random_source(), + "title": "Test connector", + "source": get_random_source(), } connector_count = Connector.objects.count() @@ -65,89 +60,88 @@ def test_create_connector(self): self.assert_201(response) self.assertEqual(Connector.objects.count(), connector_count + 1) - self.assertEqual(response.data['title'], data['title']) + self.assertEqual(response.data["title"], data["title"]) # Test that the user has been made admin - self.assertEqual(len(response.data['users']), 1) - self.assertEqual(response.data['users'][0]['user'], self.user.pk) + self.assertEqual(len(response.data["users"]), 1) + self.assertEqual(response.data["users"][0]["user"], self.user.pk) - user = ConnectorUser.objects.get(pk=response.data['users'][0]['id']) + user = ConnectorUser.objects.get(pk=response.data["users"][0]["id"]) self.assertEqual(user.user.pk, self.user.pk) - self.assertEqual(user.role, 'admin') + self.assertEqual(user.role, "admin") def test_add_user(self): - connector = self.create(Connector, role='admin') + connector = self.create(Connector, role="admin") test_user = self.create(User) - url = '/api/v1/connector-users/' + url = "/api/v1/connector-users/" data = { - 'user': test_user.pk, - 'connector': connector.pk, - 'role': 'normal', + "user": test_user.pk, + "connector": connector.pk, + "role": "normal", } self.authenticate() response = self.client.post(url, data) self.assert_201(response) - self.assertEqual(response.data['role'], data['role']) - self.assertEqual(response.data['user'], data['user']) - self.assertEqual(response.data['connector'], data['connector']) + self.assertEqual(response.data["role"], data["role"]) + self.assertEqual(response.data["user"], data["user"]) + self.assertEqual(response.data["connector"], data["connector"]) def test_add_project(self): - connector = self.create(Connector, role='admin') + connector = self.create(Connector, role="admin") test_project = self.create(Project) - url = '/api/v1/connector-projects/' + url = "/api/v1/connector-projects/" data = { - 'project': test_project.pk, - 'connector': connector.pk, - 'role': 'self', + "project": test_project.pk, + "connector": connector.pk, + "role": "self", } self.authenticate() response = self.client.post(url, data) self.assert_201(response) - self.assertEqual(response.data['role'], data['role']) - self.assertEqual(response.data['project'], data['project']) - self.assertEqual(response.data['connector'], data['connector']) + self.assertEqual(response.data["role"], data["role"]) + self.assertEqual(response.data["project"], data["project"]) + self.assertEqual(response.data["connector"], data["connector"]) def test_list_sources(self): - url = '/api/v1/connector-sources/' + url = "/api/v1/connector-sources/" self.authenticate() response = self.client.get(url) self.assert_200(response) - @patch('unified_connector.sources.rss_feed.requests') + @patch("unified_connector.sources.rss_feed.requests") def test_connector_leads(self, mock_requests): mock_requests.get.return_value.content = RSS_FEED_MOCK_DATA - connector = self.create( - Connector, - source=get_source_object('rss-feed'), - params=SAMPLE_RSS_PARAMS, - role='self' - ) - url = '/api/v1/connectors/{}/leads/'.format(connector.id) + connector = self.create(Connector, source=get_source_object("rss-feed"), params=SAMPLE_RSS_PARAMS, role="self") + url = "/api/v1/connectors/{}/leads/".format(connector.id) self.authenticate() response = self.client.post(url) self.assert_200(response) - self.assertIsNotNone(response.data.get('results')) - self.assertTrue(response.data['count'] == 20) - self.assertIsInstance(response.data['results'], list) + self.assertIsNotNone(response.data.get("results")) + self.assertTrue(response.data["count"] == 20) + self.assertIsInstance(response.data["results"], list) - first_lead = response.data['results'][0] + first_lead = response.data["results"][0] for key in [ - 'source_raw', 'source', 'source_detail', - 'author_raw', 'author_detail', - 'authors', 'authors_detail', + "source_raw", + "source", + "source_detail", + "author_raw", + "author_detail", + "authors", + "authors_detail", ]: self.assertTrue(first_lead[key] not in [None, []]) - self.assertIsNotNone(first_lead['authors'][0] == first_lead['authors_detail'][0]['id']) + self.assertIsNotNone(first_lead["authors"][0] == first_lead["authors_detail"][0]["id"]) # FIXME: Fix the broken tests by mocking # def test_get_leads_from_connector(self): @@ -278,68 +272,66 @@ def test_get_connector_fields(self): Connector, source=get_source_object(store.atom_feed.AtomFeed.key), params=SAMPLE_ATOM_PARAMS, - role='self', + role="self", ) - url = '/api/v1/connectors/' + url = "/api/v1/connectors/" self.authenticate() resp = self.client.get(url) self.assert_200(resp) - data = resp.data['results'] + data = resp.data["results"] assert len(data) == 1 - assert data[0]['id'] == connector.id - assert 'source' in data[0] - assert 'source_title' in data[0] + assert data[0]["id"] == connector.id + assert "source" in data[0] + assert "source_title" in data[0] class ConnectorSourcesApiTest(TestCase): """ NOTE: The basic connector sources are added from the migration. """ + statuses = [ConnectorSource.STATUS_BROKEN, ConnectorSource.STATUS_WORKING] def setUp(self): super().setUp() # Set acaps status working, since might be set broken by other test functions - acaps_source = ConnectorSource.objects.get(key='acaps-briefing-notes') + acaps_source = ConnectorSource.objects.get(key="acaps-briefing-notes") acaps_source.status = ConnectorSource.STATUS_WORKING acaps_source.save() def test_get_connector_sources_has_status_key(self): - url = '/api/v1/connector-sources/' + url = "/api/v1/connector-sources/" self.authenticate() response = self.client.get(url) self.assert_200(response) - data = response.data['results'] + data = response.data["results"] for each in data: - assert 'status' in each - assert each['status'] in self.statuses + assert "status" in each + assert each["status"] in self.statuses def test_get_connector_acaps_status_broken(self): - acaps_source = ConnectorSource.objects.get(key='acaps-briefing-notes') + acaps_source = ConnectorSource.objects.get(key="acaps-briefing-notes") acaps_source.status = ConnectorSource.STATUS_BROKEN acaps_source.save() - url = '/api/v1/connector-sources/' + url = "/api/v1/connector-sources/" self.authenticate() response = self.client.get(url) self.assert_200(response) - data = response.data['results'] + data = response.data["results"] for each in data: - assert 'status' in each - if each['key'] == 'acaps-briefing-notes': - assert each['status'] == ConnectorSource.STATUS_BROKEN + assert "status" in each + if each["key"] == "acaps-briefing-notes": + assert each["status"] == ConnectorSource.STATUS_BROKEN else: - assert each['status'] == ConnectorSource.STATUS_WORKING + assert each["status"] == ConnectorSource.STATUS_WORKING def test_get_connectors_have_status_key(self): - url = '/api/v1/connectors/' - data = { - 'title': 'Test Acaps connector', - 'source': acaps_briefing_notes.AcapsBriefingNotes.key - } + url = "/api/v1/connectors/" + data = {"title": "Test Acaps connector", "source": acaps_briefing_notes.AcapsBriefingNotes.key} self.authenticate() response = self.client.post(url, data) @@ -347,22 +339,19 @@ def test_get_connectors_have_status_key(self): response = self.client.get(url) self.assert_200(response) - data = response.data['results'] + data = response.data["results"] for each in data: - assert 'status' in each - assert each['status'] in self.statuses + assert "status" in each + assert each["status"] in self.statuses def test_get_acaps_connector_broken(self): - acaps_source = ConnectorSource.objects.get(key='acaps-briefing-notes') + acaps_source = ConnectorSource.objects.get(key="acaps-briefing-notes") acaps_source.status = ConnectorSource.STATUS_BROKEN acaps_source.save() - url = '/api/v1/connectors/' - data = { - 'title': 'Test Acaps connector', - 'source': acaps_briefing_notes.AcapsBriefingNotes.key - } + url = "/api/v1/connectors/" + data = {"title": "Test Acaps connector", "source": acaps_briefing_notes.AcapsBriefingNotes.key} self.authenticate() response = self.client.post(url, data) @@ -370,20 +359,20 @@ def test_get_acaps_connector_broken(self): response = self.client.get(url) self.assert_200(response) - data = response.data['results'] + data = response.data["results"] for each in data: - assert 'status' in each - if each['source'] == 'acaps-briefing-notes': - assert each['status'] == ConnectorSource.STATUS_BROKEN + assert "status" in each + if each["source"] == "acaps-briefing-notes": + assert each["status"] == ConnectorSource.STATUS_BROKEN else: - assert each['status'] == ConnectorSource.STATUS_BROKEN + assert each["status"] == ConnectorSource.STATUS_BROKEN def test_organization_search_util(self): organization_titles = [ - 'Deep', - 'New Deep', - 'Old Deep', + "Deep", + "New Deep", + "Old Deep", ] Organization.objects.filter(title__in=organization_titles).all().delete() assert Organization.objects.filter(title__in=organization_titles).count() == 0 diff --git a/apps/connector/utils.py b/apps/connector/utils.py index 247bfa58c0..df9973ec7b 100644 --- a/apps/connector/utils.py +++ b/apps/connector/utils.py @@ -1,12 +1,11 @@ import logging -from django.core.cache import cache from django.conf import settings +from django.core.cache import cache from deep.caches import CacheKey from utils.common import replace_ns - logger = logging.getLogger(__name__) @@ -23,11 +22,12 @@ class WrappedClass(ConnectorClass): This wraps the basic connector class and provides functionalities like profiling on fetch and caching on get_content """ + def get_leads(self, *args, **kwargs): try: ret = super().get_leads(*args, **kwargs) except Exception as e: - logger.error('Connector: Get lead failed', exc_info=True) + logger.error("Connector: Get lead failed", exc_info=True) raise ConnectorGetLeadException( f"Parsing Connector Source data for {self.title} failed. " "Maybe the source HTML structure has changed " @@ -41,7 +41,7 @@ def get_content(self, url, params): This will get the cached content if present else fetch from respective source """ - url_params = f'{url}:{str(params)}' + url_params = f"{url}:{str(params)}" cache_key = CacheKey.CONNECTOR_KEY_FORMAT.format(hash(url_params)) data = cache.get(cache_key) @@ -57,7 +57,7 @@ def get_content(self, url, params): def get_rss_fields(item, nsmap, parent_tag=None): - tag = '{}/{}'.format(parent_tag, item.tag) if parent_tag else item.tag + tag = "{}/{}".format(parent_tag, item.tag) if parent_tag else item.tag childs = item.getchildren() fields = [] if len(childs) > 0: @@ -66,8 +66,10 @@ def get_rss_fields(item, nsmap, parent_tag=None): children_fields.extend(get_rss_fields(child, nsmap, tag)) fields.extend(children_fields) else: - fields.append({ - 'key': tag, - 'label': replace_ns(nsmap, tag), - }) + fields.append( + { + "key": tag, + "label": replace_ns(nsmap, tag), + } + ) return fields diff --git a/apps/connector/views.py b/apps/connector/views.py index 1af018babc..5c1056e296 100644 --- a/apps/connector/views.py +++ b/apps/connector/views.py @@ -1,32 +1,21 @@ from django.db import models -from rest_framework import ( - exceptions, - permissions, - response, - views, - viewsets, -) - +from project.models import Project +from rest_framework import exceptions, permissions, response, views, viewsets from rest_framework.decorators import action +from unified_connector.sources.base import Source + from deep.permissions import ModifyPermission -from project.models import Project from utils.common import parse_number +from .models import Connector, ConnectorProject, ConnectorUser from .serializers import ( - SourceSerializer, - SourceDataSerializer, - + ConnectorProjectSerializer, ConnectorSerializer, ConnectorUserSerializer, - ConnectorProjectSerializer, -) -from .models import ( - Connector, - ConnectorUser, - ConnectorProject, + SourceDataSerializer, + SourceSerializer, ) from .sources.store import source_store -from unified_connector.sources.base import Source class SourceViewSet(viewsets.ViewSet): @@ -36,10 +25,12 @@ def list(self, request, version=None): sources = [s() for s in source_store.values()] serializer = SourceSerializer(sources, many=True) results = serializer.data - return response.Response({ - 'count': len(results), - 'results': results, - }) + return response.Response( + { + "count": len(results), + "results": results, + } + ) class SourceQueryView(views.APIView): @@ -47,26 +38,28 @@ class SourceQueryView(views.APIView): def query(self, source_type, query, params): source = source_store[source_type]() - method = getattr(source, 'query_{}'.format(query)) + method = getattr(source, "query_{}".format(query)) query_params = self.request.query_params - offset = parse_number(query_params.get('offset')) or 0 - limit = parse_number(query_params.get('limit')) or Source.DEFAULT_PER_PAGE + offset = parse_number(query_params.get("offset")) or 0 + limit = parse_number(query_params.get("limit")) or Source.DEFAULT_PER_PAGE args = () - if query == 'leads': + if query == "leads": args = (offset, limit) results = method(params, *args) if isinstance(results, list): - return response.Response({ - 'count': len(results), - 'results': results, - 'has_emm_triggers': getattr(source, 'has_emm_triggers', False), - 'has_emm_entities': getattr(source, 'has_emm_entities', False), - }) + return response.Response( + { + "count": len(results), + "results": results, + "has_emm_triggers": getattr(source, "has_emm_triggers", False), + "has_emm_entities": getattr(source, "has_emm_entities", False), + } + ) return response.Response(results) @@ -79,52 +72,48 @@ def post(self, request, source_type, query, version=None): class ConnectorViewSet(viewsets.ModelViewSet): serializer_class = ConnectorSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_queryset(self): - user = self.request.GET.get('user', self.request.user) - project_ids = self.request.GET.get('projects') + user = self.request.GET.get("user", self.request.user) + project_ids = self.request.GET.get("projects") connectors = Connector.get_for(user) - role = self.request.GET.get('role') + role = self.request.GET.get("role") if role: users = ConnectorUser.objects.filter( role=role, user=user, ) - connectors = connectors.filter( - connectoruser__in=users - ) + connectors = connectors.filter(connectoruser__in=users) if not project_ids: return connectors - project_ids = project_ids.split(',') + project_ids = project_ids.split(",") projects = ConnectorProject.objects.filter( project__id__in=project_ids, ) - self_projects = projects.filter(role='self') - global_projects = projects.filter(role='global') + self_projects = projects.filter(role="self") + global_projects = projects.filter(role="global") return connectors.filter( - models.Q(connectorproject__in=self_projects, users=user) | - models.Q(connectorproject__in=global_projects), + models.Q(connectorproject__in=self_projects, users=user) | models.Q(connectorproject__in=global_projects), ) @action( detail=True, permission_classes=[permissions.IsAuthenticated], - methods=['post'], - url_path='leads', - serializer_class=SourceDataSerializer + methods=["post"], + url_path="leads", + serializer_class=SourceDataSerializer, ) def get_leads(self, request, pk=None, version=None): connector = self.get_object() if not connector.can_get(request.user): raise exceptions.PermissionDenied() - project_id = request.data.pop('project', None) + project_id = request.data.pop("project", None) project = project_id and Project.objects.get(id=project_id) params = { @@ -142,29 +131,30 @@ def get_leads(self, request, pk=None, version=None): serializer = SourceDataSerializer( data, many=True, - context={'request': request, 'project': project}, + context={"request": request, "project": project}, ) results = serializer.data - return response.Response({ - 'count': count, - 'has_emm_triggers': getattr(source, 'has_emm_triggers', False), - 'has_emm_entities': getattr(source, 'has_emm_entities', False), - 'results': results, - }) + return response.Response( + { + "count": count, + "has_emm_triggers": getattr(source, "has_emm_triggers", False), + "has_emm_entities": getattr(source, "has_emm_entities", False), + "results": results, + } + ) class ConnectorUserViewSet(viewsets.ModelViewSet): serializer_class = ConnectorUserSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_serializer(self, *args, **kwargs): - data = kwargs.get('data') - list = data and data.get('list') + data = kwargs.get("data") + list = data and data.get("list") if list: - kwargs.pop('data') - kwargs.pop('many', None) + kwargs.pop("data") + kwargs.pop("many", None) return super().get_serializer( data=list, many=True, @@ -182,15 +172,14 @@ def get_queryset(self): class ConnectorProjectViewSet(viewsets.ModelViewSet): serializer_class = ConnectorProjectSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_serializer(self, *args, **kwargs): - data = kwargs.get('data') - list = data and data.get('list') + data = kwargs.get("data") + list = data and data.get("list") if list: - kwargs.pop('data') - kwargs.pop('many', None) + kwargs.pop("data") + kwargs.pop("many", None) return super().get_serializer( data=list, many=True, diff --git a/apps/deduplication/models.py b/apps/deduplication/models.py index 92b94c73bb..85e10ab08a 100644 --- a/apps/deduplication/models.py +++ b/apps/deduplication/models.py @@ -1,11 +1,10 @@ -import pickle -from django.db import models - import logging +import pickle -from apps.user_resource.models import UserResourceCreated +from django.db import models from project.models import Project +from apps.user_resource.models import UserResourceCreated logger = logging.getLogger(__name__) @@ -69,16 +68,10 @@ class Meta: def load_index(self): """This sets the attribute index if pickle is present""" - if ( - hasattr(self, "index_pickle") and - self.index_pickle is not None and - self.pickle_version is not None - ): + if hasattr(self, "index_pickle") and self.index_pickle is not None and self.pickle_version is not None: supported_formats = pickle.compatible_formats if self.pickle_version not in supported_formats: - logger.warn( - "Pickle versions not compatible, setting index to None" - ) + logger.warn("Pickle versions not compatible, setting index to None") self._index = None else: self._index = pickle.loads(self.index_pickle) @@ -88,9 +81,7 @@ def load_index(self): @property def index(self): - if not self._index_loaded or ( - self._index is None and self.index_pickle is not None - ): + if not self._index_loaded or (self._index is None and self.index_pickle is not None): self.load_index() return self._index diff --git a/apps/deduplication/receivers.py b/apps/deduplication/receivers.py index d5c380af99..dfa1abf625 100644 --- a/apps/deduplication/receivers.py +++ b/apps/deduplication/receivers.py @@ -1,16 +1,13 @@ -from django.db import transaction, models -from django.dispatch import receiver - from deduplication.models import LSHIndex +from django.db import models, transaction +from django.dispatch import receiver from lead.models import Lead, LeadDuplicates @receiver(models.signals.post_delete, sender=LSHIndex) def set_leads_as_unindexed(sender, instance, **kwargs): # set leads is_indexed False - transaction.on_commit( - lambda: clear_duplicates(instance) - ) + transaction.on_commit(lambda: clear_duplicates(instance)) @transaction.atomic @@ -24,7 +21,4 @@ def clear_duplicates(index_obj: LSHIndex): duplicate_leads_count=0, ) - LeadDuplicates.objects.filter( - models.Q(source_lead_id__in=lead_ids) | - models.Q(target_lead_id__in=lead_ids) - ).delete() + LeadDuplicates.objects.filter(models.Q(source_lead_id__in=lead_ids) | models.Q(target_lead_id__in=lead_ids)).delete() diff --git a/apps/deduplication/tasks/indexing.py b/apps/deduplication/tasks/indexing.py index d6ea038a78..9643065b0a 100644 --- a/apps/deduplication/tasks/indexing.py +++ b/apps/deduplication/tasks/indexing.py @@ -1,16 +1,17 @@ import pickle -from django.db import transaction -from django.db.models import F -from django.utils import timezone + from celery import shared_task from celery.utils.log import get_task_logger from datasketch import LeanMinHash, MinHashLSH - -from utils.common import batched -from lead.models import Lead -from project.models import Project from deduplication.models import LSHIndex from deduplication.utils import get_minhash, insert_to_index +from django.db import transaction +from django.db.models import F +from django.utils import timezone +from lead.models import Lead +from project.models import Project + +from utils.common import batched logger = get_task_logger(__name__) @@ -21,14 +22,13 @@ def find_and_set_duplicate_leads(index: MinHashLSH, lead: Lead, minhash: LeanMin duplicate_leads_count = duplicate_leads_qs.count() if duplicate_leads_count > 0: lead.duplicate_leads_count += duplicate_leads_count - duplicate_leads_qs\ - .update(duplicate_leads_count=F('duplicate_leads_count') + 1) + duplicate_leads_qs.update(duplicate_leads_count=F("duplicate_leads_count") + 1) lead.duplicate_leads.set(duplicate_leads_qs) - lead.save(update_fields=['duplicate_leads_count']) + lead.save(update_fields=["duplicate_leads_count"]) def process_and_index_lead(lead: Lead, index: MinHashLSH): - text = lead.leadpreview.text_extract if hasattr(lead, 'leadpreview') else lead.text + text = lead.leadpreview.text_extract if hasattr(lead, "leadpreview") else lead.text if not text: return index minhash = get_minhash(text) @@ -37,7 +37,7 @@ def process_and_index_lead(lead: Lead, index: MinHashLSH): insert_to_index(index, lead.id, minhash) lead.is_indexed = True lead.indexed_at = timezone.now() - lead.save(update_fields=['is_indexed', 'indexed_at']) + lead.save(update_fields=["is_indexed", "indexed_at"]) return index @@ -61,16 +61,13 @@ def process_and_index_leads( index_obj.index = index index_obj.save() except Exception: - logger.error( - f"Error creating index for project {project.title}({project.id})", - exc_info=True - ) + logger.error(f"Error creating index for project {project.title}({project.id})", exc_info=True) index_obj.has_errored = True - index_obj.save(update_fields=['has_errored']) + index_obj.save(update_fields=["has_errored"]) else: index_obj.status = LSHIndex.IndexStatus.CREATED - index_obj.save(update_fields=['status']) + index_obj.save(update_fields=["status"]) def create_project_index(project: Project): @@ -119,7 +116,7 @@ def index_lead_and_calculate_duplicates(lead_id: int): logger.error(f"Cannot index inexistent lead(id={lead_id})") return - text = lead.leadpreview.text_extract if hasattr(lead, 'leadpreview') else lead.text + text = lead.leadpreview.text_extract if hasattr(lead, "leadpreview") else lead.text if not text: return @@ -132,7 +129,7 @@ def index_lead_and_calculate_duplicates(lead_id: int): index = process_and_index_lead(lead, index_obj.index) index_obj.index = index - index_obj.save(update_fields=['index_pickle']) + index_obj.save(update_fields=["index_pickle"]) @shared_task @@ -152,4 +149,4 @@ def remove_lead_from_index(lead_id: int): return index.remove(lead.id) index_obj.index = index - index_obj.save(update_fields=['index_pickle']) + index_obj.save(update_fields=["index_pickle"]) diff --git a/apps/deduplication/tests/test_tasks.py b/apps/deduplication/tests/test_tasks.py index f8fdebe8ed..032a4bae04 100644 --- a/apps/deduplication/tests/test_tasks.py +++ b/apps/deduplication/tests/test_tasks.py @@ -1,21 +1,22 @@ -import pytest from unittest.mock import patch -from django.db.models import Q -from deep.tests import TestCase -from project.factories import ProjectFactory -from lead.factories import LeadPreviewFactory, LeadFactory -from lead.receivers import update_index_and_duplicates -from lead.models import Lead, LeadDuplicates -from deduplication.models import LSHIndex +import pytest from deduplication.factories import LSHIndexFactory +from deduplication.models import LSHIndex from deduplication.tasks.indexing import ( - process_and_index_lead, + create_project_index, get_index_object_for_project, index_lead_and_calculate_duplicates, + process_and_index_lead, remove_lead_from_index, - create_project_index, ) +from django.db.models import Q +from lead.factories import LeadFactory, LeadPreviewFactory +from lead.models import Lead, LeadDuplicates +from lead.receivers import update_index_and_duplicates +from project.factories import ProjectFactory + +from deep.tests import TestCase @pytest.mark.django_db @@ -45,7 +46,7 @@ def test_get_index_object_for_project_existing(self): final_count = LSHIndex.objects.count() assert final_count == original_count - @patch('deduplication.tasks.indexing.get_index_object_for_project') + @patch("deduplication.tasks.indexing.get_index_object_for_project") def test_index_lead_and_calculate_duplicates_no_text(self, get_index_func): """When lead has no text, the function should return early without calling the function get_index_object_for_project @@ -55,7 +56,7 @@ def test_index_lead_and_calculate_duplicates_no_text(self, get_index_func): index_lead_and_calculate_duplicates(lead.id) get_index_func.assert_not_called() - @patch('deduplication.tasks.indexing.process_and_index_lead') + @patch("deduplication.tasks.indexing.process_and_index_lead") def test_index_lead_and_calculate_duplicates_errored_index(self, process_lead_func): project = ProjectFactory.create() lead = LeadFactory.create(project=project) @@ -180,10 +181,9 @@ def test_update_index_and_duplicates(self): project_leads = Lead.objects.filter(project=project) assert project_leads.filter(duplicate_leads_count=0).count() == 0, "Leads should have duplicates" - assert LeadDuplicates.objects.filter( - Q(source_lead_id=first_lead.id) | - Q(target_lead_id=first_lead.id) - ).count() > 0, "There should be duplicates entries for the first lead" + assert ( + LeadDuplicates.objects.filter(Q(source_lead_id=first_lead.id) | Q(target_lead_id=first_lead.id)).count() > 0 + ), "There should be duplicates entries for the first lead" # NOTE: this should have been called by signal update_index_and_duplicates(first_lead) diff --git a/apps/deduplication/utils.py b/apps/deduplication/utils.py index 9190b03aaf..3d933511f6 100644 --- a/apps/deduplication/utils.py +++ b/apps/deduplication/utils.py @@ -1,6 +1,6 @@ import re -from datasketch import MinHash, LeanMinHash +from datasketch import LeanMinHash, MinHash from deduplication.models import LSHIndex diff --git a/apps/deep_explore/apps.py b/apps/deep_explore/apps.py index e625b7fd07..ef092c6d26 100644 --- a/apps/deep_explore/apps.py +++ b/apps/deep_explore/apps.py @@ -2,4 +2,4 @@ class DeepExploreConfig(AppConfig): - name = 'deep_explore' + name = "deep_explore" diff --git a/apps/deep_explore/enums.py b/apps/deep_explore/enums.py index 93795f59df..8405fd2b60 100644 --- a/apps/deep_explore/enums.py +++ b/apps/deep_explore/enums.py @@ -5,10 +5,10 @@ from .models import PublicExploreSnapshot -PublicExploreSnapshotTypeEnum = convert_enum_to_graphene_enum( - PublicExploreSnapshot.Type, name='PublicExploreSnapshotTypeEnum') +PublicExploreSnapshotTypeEnum = convert_enum_to_graphene_enum(PublicExploreSnapshot.Type, name="PublicExploreSnapshotTypeEnum") PublicExploreSnapshotGlobalTypeEnum = convert_enum_to_graphene_enum( - PublicExploreSnapshot.GlobalType, name='PublicExploreSnapshotGlobalTypeEnum') + PublicExploreSnapshot.GlobalType, name="PublicExploreSnapshotGlobalTypeEnum" +) enum_map = { diff --git a/apps/deep_explore/filter_set.py b/apps/deep_explore/filter_set.py index 28c8a2dc62..8ae0a0c890 100644 --- a/apps/deep_explore/filter_set.py +++ b/apps/deep_explore/filter_set.py @@ -1,19 +1,19 @@ -from django.db import models import django_filters +from django.db import models +from entry.models import Entry +from user_resource.filters import UserResourceGqlFilterSet from deep.filter_set import OrderEnumMixin, generate_type_for_filter_set from utils.graphene.filters import IDListFilter -from user_resource.filters import UserResourceGqlFilterSet from .models import Project -from entry.models import Entry class ExploreProjectFilterSet(OrderEnumMixin, UserResourceGqlFilterSet): organizations = IDListFilter(distinct=True) - is_test = django_filters.BooleanFilter(method='filter_is_test') - search = django_filters.CharFilter(method='filter_title') - exclude_entry_less_than = django_filters.BooleanFilter(method='filter_exclude_entry_less_than') + is_test = django_filters.BooleanFilter(method="filter_is_test") + search = django_filters.CharFilter(method="filter_title") + exclude_entry_less_than = django_filters.BooleanFilter(method="filter_exclude_entry_less_than") regions = IDListFilter(distinct=True) class Meta: @@ -33,14 +33,17 @@ def filter_title(self, qs, _, value): def filter_exclude_entry_less_than(self, qs, _, value): if value is True: return qs.annotate( - entry_count=models.functions.Coalesce(models.Subquery( - Entry.objects.filter( - project=models.OuterRef('id') - ).order_by().values('project').annotate( - count=models.Count('id', distinct=True) - ).values('count')[:1], - output_field=models.IntegerField() - ), 0) + entry_count=models.functions.Coalesce( + models.Subquery( + Entry.objects.filter(project=models.OuterRef("id")) + .order_by() + .values("project") + .annotate(count=models.Count("id", distinct=True)) + .values("count")[:1], + output_field=models.IntegerField(), + ), + 0, + ) ).filter(entry_count__gt=100) # False and None has same result return qs @@ -52,7 +55,7 @@ def qs(self): ExploreProjectFilterDataType, ExploreProjectFilterDataInputType = generate_type_for_filter_set( ExploreProjectFilterSet, - 'project.schema.ProjectListType', - 'ExploreProjectFilterDataType', - 'ExploreProjectFilterDataInputType', + "project.schema.ProjectListType", + "ExploreProjectFilterDataType", + "ExploreProjectFilterDataInputType", ) diff --git a/apps/deep_explore/management/commands/update_deep_explore_data.py b/apps/deep_explore/management/commands/update_deep_explore_data.py index 8e2134b96d..d547eabb6c 100644 --- a/apps/deep_explore/management/commands/update_deep_explore_data.py +++ b/apps/deep_explore/management/commands/update_deep_explore_data.py @@ -1,17 +1,17 @@ import time -from django.db import transaction -from django.core.management.base import BaseCommand - -from deep.caches import CacheKey from deep_explore.tasks import ( - update_deep_explore_entries_count_by_geo_aggreagate, generate_public_deep_explore_snapshot, + update_deep_explore_entries_count_by_geo_aggreagate, ) +from django.core.management.base import BaseCommand +from django.db import transaction from geo.models import GeoArea +from deep.caches import CacheKey -class ShowRunTime(): + +class ShowRunTime: def __init__(self, command: BaseCommand, func_name): self.func_name = func_name self.command = command @@ -22,9 +22,7 @@ def __enter__(self): def __exit__(self, *_): self.command.stdout.write( - self.command.style.SUCCESS( - f"{self.func_name} Runtime: {time.time() - self.start_time} seconds" - ) + self.command.style.SUCCESS(f"{self.func_name} Runtime: {time.time() - self.start_time} seconds") ) @@ -33,24 +31,20 @@ def handle(self, **_): start_time = time.time() # Try to clear cache - with ShowRunTime(self, 'Clear existing memory caches'): - print(f'Clear status: {CacheKey.ExploreDeep.clear_cache()}') + with ShowRunTime(self, "Clear existing memory caches"): + print(f"Clear status: {CacheKey.ExploreDeep.clear_cache()}") # Calculate centroid for geo_areas if not already. - with ShowRunTime(self, 'GeoCentroid Update'): + with ShowRunTime(self, "GeoCentroid Update"): with transaction.atomic(): GeoArea.sync_centroid() # Update explore data - with ShowRunTime(self, 'Geo Entries Aggregate Update'): + with ShowRunTime(self, "Geo Entries Aggregate Update"): update_deep_explore_entries_count_by_geo_aggreagate() # Update public snapshots - with ShowRunTime(self, 'DeepExplore Public Snapshot Update'): + with ShowRunTime(self, "DeepExplore Public Snapshot Update"): generate_public_deep_explore_snapshot() - self.stdout.write( - self.style.SUCCESS( - f"Total Runtime: {time.time() - start_time} seconds" - ) - ) + self.stdout.write(self.style.SUCCESS(f"Total Runtime: {time.time() - start_time} seconds")) diff --git a/apps/deep_explore/models.py b/apps/deep_explore/models.py index ac81dc80fc..1d1ba1d30a 100644 --- a/apps/deep_explore/models.py +++ b/apps/deep_explore/models.py @@ -1,14 +1,14 @@ -from django.db import models from django.core.exceptions import ValidationError - -from project.models import Project +from django.db import models from geo.models import GeoArea +from project.models import Project class AggregateTracker(models.Model): """ Used to track aggregated data last updated status """ + class Type(models.IntegerChoices): ENTRIES_COUNT_BY_GEO_AREA = 1 @@ -25,54 +25,59 @@ class EntriesCountByGeoAreaAggregate(models.Model): """ Used as cache to calculate entry - geo_area stats """ + project = models.ForeignKey(Project, on_delete=models.CASCADE) geo_area = models.ForeignKey(GeoArea, on_delete=models.CASCADE) date = models.DateField() entries_count = models.IntegerField() class Meta: - ordering = ('date',) - unique_together = ('project', 'geo_area', 'date') + ordering = ("date",) + unique_together = ("project", "geo_area", "date") class PublicExploreSnapshot(models.Model): """ Used to store snapshot used by public dashboard """ + class Type(models.IntegerChoices): - GLOBAL = 1, 'Global Snapshot' - YEARLY_SNAPSHOT = 2, 'Yearly Snapshot' + GLOBAL = 1, "Global Snapshot" + YEARLY_SNAPSHOT = 2, "Yearly Snapshot" class GlobalType(models.IntegerChoices): - FULL = 1, 'Full Dataset' - TIME_SERIES = 2, 'Time Series Dataset' + FULL = 1, "Full Dataset" + TIME_SERIES = 2, "Time Series Dataset" type = models.SmallIntegerField(choices=Type.choices) global_type = models.SmallIntegerField(choices=GlobalType.choices, null=True) start_date = models.DateField() end_date = models.DateField() year = models.SmallIntegerField(unique=True, null=True) - file = models.FileField(upload_to='deep-explore/public-snapshot/', max_length=255) + file = models.FileField(upload_to="deep-explore/public-snapshot/", max_length=255) # Empty for global - download_file = models.FileField(upload_to='deep-explore/public-excel-export/', max_length=255, blank=True) + download_file = models.FileField(upload_to="deep-explore/public-excel-export/", max_length=255, blank=True) class Meta: - ordering = ('type', 'year',) + ordering = ( + "type", + "year", + ) def clean(self): validation_set = [ - (PublicExploreSnapshot.Type.GLOBAL, PublicExploreSnapshot.GlobalType.TIME_SERIES, 'file'), - (PublicExploreSnapshot.Type.GLOBAL, PublicExploreSnapshot.GlobalType.FULL, 'file'), - (PublicExploreSnapshot.Type.GLOBAL, PublicExploreSnapshot.GlobalType.FULL, 'download_file'), - (PublicExploreSnapshot.Type.YEARLY_SNAPSHOT, None, 'year'), - (PublicExploreSnapshot.Type.YEARLY_SNAPSHOT, None, 'file'), - (PublicExploreSnapshot.Type.YEARLY_SNAPSHOT, None, 'download_file'), + (PublicExploreSnapshot.Type.GLOBAL, PublicExploreSnapshot.GlobalType.TIME_SERIES, "file"), + (PublicExploreSnapshot.Type.GLOBAL, PublicExploreSnapshot.GlobalType.FULL, "file"), + (PublicExploreSnapshot.Type.GLOBAL, PublicExploreSnapshot.GlobalType.FULL, "download_file"), + (PublicExploreSnapshot.Type.YEARLY_SNAPSHOT, None, "year"), + (PublicExploreSnapshot.Type.YEARLY_SNAPSHOT, None, "file"), + (PublicExploreSnapshot.Type.YEARLY_SNAPSHOT, None, "download_file"), ] - fields = ['year', 'file', 'download_file'] + fields = ["year", "file", "download_file"] check_set = (self.type, self.global_type) for field in fields: if (*check_set, field) in validation_set: if getattr(self, field) is None: raise ValidationError( - f'+ needs to be defined.' + f"+ needs to be defined." ) diff --git a/apps/deep_explore/schema.py b/apps/deep_explore/schema.py index 05baf0e3ea..5906b16d48 100644 --- a/apps/deep_explore/schema.py +++ b/apps/deep_explore/schema.py @@ -1,33 +1,29 @@ import copy -import graphene -from typing import List, Callable -from datetime import timedelta from dataclasses import dataclass +from datetime import timedelta +from typing import Callable, List +import graphene +from analysis_framework.models import AnalysisFramework +from deep_explore.models import EntriesCountByGeoAreaAggregate, PublicExploreSnapshot +from django.contrib.postgres.aggregates.general import ArrayAgg from django.db import models +from django.db.models.functions import TruncDay, TruncMonth from django.utils import timezone -from django.db.models.functions import ( - TruncMonth, - TruncDay, -) -from django.contrib.postgres.aggregates.general import ArrayAgg -from graphene_django import DjangoObjectType, DjangoListField +from entry.models import Entry +from geo.models import Region +from graphene_django import DjangoListField, DjangoObjectType +from lead.models import Lead +from organization.models import Organization +from project.models import Project, ProjectMembership +from user.models import User -from deep.caches import CacheKey, CacheHelper +from deep.caches import CacheHelper, CacheKey from utils.graphene.geo_scalars import PointScalar from utils.graphene.types import FileFieldType -from organization.models import Organization -from geo.models import Region -from user.models import User -from project.models import Project, ProjectMembership -from lead.models import Lead -from entry.models import Entry -from analysis_framework.models import AnalysisFramework -from deep_explore.models import EntriesCountByGeoAreaAggregate, PublicExploreSnapshot +from .enums import PublicExploreSnapshotGlobalTypeEnum, PublicExploreSnapshotTypeEnum from .filter_set import ExploreProjectFilterDataInputType, ExploreProjectFilterSet -from .enums import PublicExploreSnapshotTypeEnum, PublicExploreSnapshotGlobalTypeEnum - # TODO? NODE_CACHE_TIMEOUT = 60 * 60 * 1 # 1 Hour @@ -44,10 +40,10 @@ def cache_key_gen(root: ExploreDashboardStatRoot, *_): ) -def get_global_filters(_filter: dict, date_field='created_at'): +def get_global_filters(_filter: dict, date_field="created_at"): return { - f'{date_field}__gte': _filter['date_from'], - f'{date_field}__lte': _filter['date_to'], + f"{date_field}__gte": _filter["date_from"], + f"{date_field}__lte": _filter["date_to"], } @@ -65,11 +61,7 @@ class ExploreCountByDateType(graphene.ObjectType): def count_by_date_queryset_generator(qs: models.QuerySet, trunc_func: Callable): # Used by ExploreCountByDateListType - return qs.values( - date=trunc_func('created_at') - ).annotate( - count=models.Count('id') - ).order_by('date') + return qs.values(date=trunc_func("created_at")).annotate(count=models.Count("id")).order_by("date") def get_top_ten_organizations_list( @@ -81,30 +73,31 @@ def get_top_ten_organizations_list( return [ { **data, - 'id': data.pop('org_id'), - 'title': data.pop('org_title'), + "id": data.pop("org_id"), + "title": data.pop("org_title"), } for data in leads_qs.filter( - **{f'{lead_field}__in': organization_queryset}, + **{f"{lead_field}__in": organization_queryset}, project__in=project_qs, - ).annotate( - org_id=models.functions.Coalesce( - models.F(f'{lead_field}__parent'), - models.F(f'{lead_field}__id') - ), - org_title=models.functions.Coalesce( - models.F(f'{lead_field}__parent__title'), - models.F(f'{lead_field}__title') - ), - ).order_by().values('org_id', 'org_title').annotate( - leads_count=models.Count('id', distinct=True), - projects_count=models.Count('project', distinct=True), - ).order_by('-leads_count', '-projects_count').values( - 'org_id', - 'org_title', - 'leads_count', - 'projects_count', - ).distinct()[:10] + ) + .annotate( + org_id=models.functions.Coalesce(models.F(f"{lead_field}__parent"), models.F(f"{lead_field}__id")), + org_title=models.functions.Coalesce(models.F(f"{lead_field}__parent__title"), models.F(f"{lead_field}__title")), + ) + .order_by() + .values("org_id", "org_title") + .annotate( + leads_count=models.Count("id", distinct=True), + projects_count=models.Count("project", distinct=True), + ) + .order_by("-leads_count", "-projects_count") + .values( + "org_id", + "org_title", + "leads_count", + "projects_count", + ) + .distinct()[:10] ] @@ -116,37 +109,41 @@ def get_top_ten_frameworks_list( # Calcuate projects/entries count projects_count_by_af = { af: count - for af, count in projects_qs.filter( - analysis_framework__in=analysis_framework_qs - ).order_by().values('analysis_framework').annotate( - count=models.Count('id'), - ).values_list('analysis_framework', 'count') + for af, count in projects_qs.filter(analysis_framework__in=analysis_framework_qs) + .order_by() + .values("analysis_framework") + .annotate( + count=models.Count("id"), + ) + .values_list("analysis_framework", "count") } entries_count_by_af = { af: count - for af, count in entries_qs.filter( - analysis_framework__in=analysis_framework_qs - ).order_by().values('analysis_framework').annotate( - count=models.Count('id'), - ).values_list('analysis_framework', 'count') + for af, count in entries_qs.filter(analysis_framework__in=analysis_framework_qs) + .order_by() + .values("analysis_framework") + .annotate( + count=models.Count("id"), + ) + .values_list("analysis_framework", "count") } # Sort AF id using projects/entries count - af_count_data = sorted([ - (af_id, entries_count_by_af.get(af_id, 0), projects_count_by_af.get(af_id, 0)) - for af_id in set([*projects_count_by_af.keys(), *entries_count_by_af.keys()]) - ], key=lambda x: x[1:], reverse=True)[:10] + af_count_data = sorted( + [ + (af_id, entries_count_by_af.get(af_id, 0), projects_count_by_af.get(af_id, 0)) + for af_id in set([*projects_count_by_af.keys(), *entries_count_by_af.keys()]) + ], + key=lambda x: x[1:], + reverse=True, + )[:10] # Fetch Top ten AF - af_data = { - af['id']: af - for af in analysis_framework_qs.distinct().filter( - ).values('id', 'title') - } + af_data = {af["id"]: af for af in analysis_framework_qs.distinct().filter().values("id", "title")} # Return AF data with projects/entries count return [ { **af_data[af_id], - 'entries_count': entries_count, - 'projects_count': projects_count, + "entries_count": entries_count, + "projects_count": projects_count, } for af_id, entries_count, projects_count in af_count_data if af_id in af_data @@ -172,17 +169,25 @@ def _order_by_lead(x): af: count for af, count in leads_qs.filter( project__in=projects_qs, - ).order_by().values('project').annotate( - count=models.Count('id'), - ).values_list('project', 'count') + ) + .order_by() + .values("project") + .annotate( + count=models.Count("id"), + ) + .values_list("project", "count") } entries_count_by_project = { af: count for af, count in entries_qs.filter( project__in=projects_qs, - ).order_by().values('project').annotate( - count=models.Count('id'), - ).values_list('project', 'count') + ) + .order_by() + .values("project") + .annotate( + count=models.Count("id"), + ) + .values_list("project", "count") } # Sort Project id using projects/entries count project_count_data = sorted( @@ -190,24 +195,17 @@ def _order_by_lead(x): (project_id, entries_count_by_project.get(project_id, 0), leads_count_by_project.get(project_id, 0)) for project_id in set([*leads_count_by_project.keys(), *entries_count_by_project.keys()]) ], - key=( - _order_by_entry if order_by_entry - else _order_by_lead - ), + key=(_order_by_entry if order_by_entry else _order_by_lead), reverse=True, )[:10] # Fetch Top ten Project - project_data = { - af['id']: af - for af in projects_qs.distinct().filter( - ).values('id', 'title') - } + project_data = {af["id"]: af for af in projects_qs.distinct().filter().values("id", "title")} # Return Project data with projects/entries count return [ { **project_data[project_id], - 'entries_count': entries_count, - 'leads_count': leads_count, + "entries_count": entries_count, + "leads_count": leads_count, } for project_id, entries_count, leads_count in project_count_data if project_id in project_data @@ -259,7 +257,7 @@ class ExploreDeepStatEntriesCountByCentroidType(graphene.ObjectType): @dataclass -class ExploreDashboardStatRoot(): +class ExploreDashboardStatRoot: cache_key: str analysis_framework_qs: models.QuerySet entries_count_by_geo_area_aggregate_qs: models.QuerySet @@ -324,41 +322,39 @@ def resolve_total_entries(root: ExploreDashboardStatRoot, *_) -> int: @staticmethod @CacheHelper.gql_cache(CacheKey.ExploreDeep.TOTAL_ENTRIES_ADDED_LAST_WEEK_COUNT, timeout=NODE_CACHE_TIMEOUT) def resolve_total_entries_added_last_week(*_) -> int: - return Entry.objects.filter( - created_at__gte=timezone.now().date() - timedelta(days=7) - ).count() + return Entry.objects.filter(created_at__gte=timezone.now().date() - timedelta(days=7)).count() @staticmethod @node_cache(CacheKey.ExploreDeep.TOTAL_ACTIVE_USERS_COUNT) def resolve_total_active_users(root: ExploreDashboardStatRoot, *_) -> int: - created_by_qs = root.leads_qs.values('created_by').union( - root.entries_qs.values('created_by'), + created_by_qs = root.leads_qs.values("created_by").union( + root.entries_qs.values("created_by"), # Modified By - root.leads_qs.values('modified_by'), - root.entries_qs.values('modified_by'), + root.leads_qs.values("modified_by"), + root.entries_qs.values("modified_by"), ) - return User.objects.filter(id__in=created_by_qs).values('id').distinct().count() + return User.objects.filter(id__in=created_by_qs).values("id").distinct().count() @staticmethod @node_cache(CacheKey.ExploreDeep.TOTAL_AUTHORS_COUNT) def resolve_total_authors(root: ExploreDashboardStatRoot, *_) -> int: - return root.leads_qs.values('authors').distinct().count() + return root.leads_qs.values("authors").distinct().count() @staticmethod @node_cache(CacheKey.ExploreDeep.TOTAL_PUBLISHERS_COUNT) def resolve_total_publishers(root: ExploreDashboardStatRoot, *_) -> int: - return root.leads_qs.values('source').distinct().count() + return root.leads_qs.values("source").distinct().count() # --- Array data ---- @staticmethod @node_cache(CacheKey.ExploreDeep.TOP_TEN_AUTHORS_LIST) def resolve_top_ten_authors(root: ExploreDashboardStatRoot, *_): - return get_top_ten_organizations_list(root.organization_qs, root.leads_qs, root.projects_qs, 'authors') + return get_top_ten_organizations_list(root.organization_qs, root.leads_qs, root.projects_qs, "authors") @staticmethod @node_cache(CacheKey.ExploreDeep.TOP_TEN_PUBLISHERS_LIST) def resolve_top_ten_publishers(root: ExploreDashboardStatRoot, *_): - return get_top_ten_organizations_list(root.organization_qs, root.leads_qs, root.projects_qs, 'source') + return get_top_ten_organizations_list(root.organization_qs, root.leads_qs, root.projects_qs, "source") @staticmethod @node_cache(CacheKey.ExploreDeep.TOP_TEN_FRAMEWORKS_LIST) @@ -369,20 +365,27 @@ def resolve_top_ten_frameworks(root: ExploreDashboardStatRoot, *_): @node_cache(CacheKey.ExploreDeep.TOP_TEN_PROJECTS_BY_USERS_LIST) def resolve_top_ten_projects_by_users(root: ExploreDashboardStatRoot, *_): return list( - root.projects_qs.distinct().annotate( + root.projects_qs.distinct() + .annotate( users_count=models.functions.Coalesce( models.Subquery( - ProjectMembership.objects.filter( - project=models.OuterRef('pk') - ).order_by().values('project').annotate( - count=models.Count('member', distinct=True), - ).values('count')[:1], - output_field=models.IntegerField() - ), 0), - ).order_by('-users_count').values( - 'id', - 'title', - 'users_count', + ProjectMembership.objects.filter(project=models.OuterRef("pk")) + .order_by() + .values("project") + .annotate( + count=models.Count("member", distinct=True), + ) + .values("count")[:1], + output_field=models.IntegerField(), + ), + 0, + ), + ) + .order_by("-users_count") + .values( + "id", + "title", + "users_count", )[:10] ) @@ -425,27 +428,35 @@ def resolve_entries_count_by_day(root: ExploreDashboardStatRoot, *_): @staticmethod def resolve_entries_count_by_region(root: ExploreDashboardStatRoot, *_): - return root.entries_count_by_geo_area_aggregate_qs\ - .order_by().values('geo_area').annotate( - count=models.Sum('entries_count'), - ).values( - 'count', - centroid=models.F('geo_area__centroid'), + return ( + root.entries_count_by_geo_area_aggregate_qs.order_by() + .values("geo_area") + .annotate( + count=models.Sum("entries_count"), + ) + .values( + "count", + centroid=models.F("geo_area__centroid"), ) + ) @staticmethod def resolve_projects_by_region(root: ExploreDashboardStatRoot, *_): - return Region.objects.annotate( - project_ids=ArrayAgg( - 'project', - distinct=True, - ordering='project', - filter=models.Q(project__in=root.projects_qs), - ), - ).filter(project_ids__isnull=False).values( - 'id', - 'centroid', - 'project_ids', + return ( + Region.objects.annotate( + project_ids=ArrayAgg( + "project", + distinct=True, + ordering="project", + filter=models.Q(project__in=root.projects_qs), + ), + ) + .filter(project_ids__isnull=False) + .values( + "id", + "centroid", + "project_ids", + ) ) @staticmethod @@ -465,23 +476,22 @@ def custom_resolver(request, _filter): ref_projects_qs = ExploreProjectFilterSet( request=request, queryset=project_queryset(), - data=_filter.get('project'), + data=_filter.get("project"), ).qs projects_qs = copy.deepcopy(ref_projects_qs).filter(**get_global_filters(_filter)) organization_qs = Organization.objects.filter(**get_global_filters(_filter)) analysis_framework_qs = AnalysisFramework.objects.filter(**get_global_filters(_filter)) - registered_users = User.objects.filter(**get_global_filters(_filter, date_field='date_joined')) + registered_users = User.objects.filter(**get_global_filters(_filter, date_field="date_joined")) # With ref_projects_qs as filter entries_qs = Entry.objects.filter(**get_global_filters(_filter), project__in=ref_projects_qs) leads_qs = Lead.objects.filter(**get_global_filters(_filter), project__in=ref_projects_qs) - entries_count_by_geo_area_aggregate_qs = EntriesCountByGeoAreaAggregate.objects\ - .filter( - **get_global_filters(_filter, date_field='date'), - project__in=ref_projects_qs, - geo_area__centroid__isempty=False, - ) + entries_count_by_geo_area_aggregate_qs = EntriesCountByGeoAreaAggregate.objects.filter( + **get_global_filters(_filter, date_field="date"), + project__in=ref_projects_qs, + geo_area__centroid__isempty=False, + ) cache_key = CacheHelper.generate_hash(_filter.__dict__) return ExploreDashboardStatRoot( @@ -501,11 +511,12 @@ class PublicExploreSnapshotType(DjangoObjectType): class Meta: model = PublicExploreSnapshot only_fields = ( - 'id', - 'start_date', - 'end_date', - 'year', + "id", + "start_date", + "end_date", + "year", ) + type = graphene.Field(PublicExploreSnapshotTypeEnum, required=True) global_type = graphene.Field(PublicExploreSnapshotGlobalTypeEnum) file = graphene.Field(FileFieldType) @@ -513,10 +524,7 @@ class Meta: class Query: - deep_explore_stats = graphene.Field( - ExploreDashboardStatType, - filter=ExploreDeepFilterInputType(required=True) - ) + deep_explore_stats = graphene.Field(ExploreDashboardStatType, filter=ExploreDeepFilterInputType(required=True)) public_deep_explore_yearly_snapshots = DjangoListField(PublicExploreSnapshotType) public_deep_explore_global_snapshots = DjangoListField(PublicExploreSnapshotType) diff --git a/apps/deep_explore/tasks.py b/apps/deep_explore/tasks.py index 9349e9c26a..1244ffe814 100644 --- a/apps/deep_explore/tasks.py +++ b/apps/deep_explore/tasks.py @@ -1,23 +1,23 @@ -import logging import datetime +import logging import time +from typing import Tuple, Union + import pytz -from typing import Union, Tuple +from analysis_framework.models import Widget from celery import shared_task +from commons.schema_snapshots import SnapshotQuery, generate_query_snapshot from dateutil.relativedelta import relativedelta - from django.db import connection, models, transaction -from django.utils import timezone from django.test import override_settings +from django.utils import timezone from djangorestframework_camel_case.util import underscoreize +from entry.models import Attribute, Entry +from export.tasks.tasks_projects import generate_projects_stats +from geo.models import GeoArea +from project.models import Project -from commons.schema_snapshots import generate_query_snapshot, SnapshotQuery from utils.common import redis_lock -from entry.models import Entry, Attribute -from project.models import Project -from geo.models import GeoArea -from analysis_framework.models import Widget -from export.tasks.tasks_projects import generate_projects_stats from .models import ( AggregateTracker, @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) -class DateHelper(): +class DateHelper: @staticmethod def py_date(string: Union[str, None]) -> Union[datetime.date, None]: if string: @@ -131,7 +131,7 @@ def update_deep_explore_entries_count_by_geo_aggreagate(start_over=False): from_date = DateHelper.py_date(tracker.value) else: # Look at entry data - from_date = Entry.objects.aggregate(date=models.Min('created_at__date'))['date'] + from_date = Entry.objects.aggregate(date=models.Min("created_at__date"))["date"] until_date = timezone.now().date() # NOTE: Stats will not include this date params = dict( @@ -139,14 +139,14 @@ def update_deep_explore_entries_count_by_geo_aggreagate(start_over=False): until_date=DateHelper.str(until_date), ) if from_date is None or from_date >= until_date: - logger.info(f'Nothing to do here...{params}') + logger.info(f"Nothing to do here...{params}") return with transaction.atomic(): start_time = time.time() with connection.cursor() as cursor: cursor.execute(get_update_entries_count_by_geo_area_aggregate_sql(), params) - logger.info(f'Rows affected: {cursor.rowcount}') + logger.info(f"Rows affected: {cursor.rowcount}") logger.info(f"Successfull. Runtime: {time.time() - start_time} seconds") tracker.value = until_date logger.info(f"Saving date {tracker.value} as last tracker") @@ -170,8 +170,8 @@ def get_or_create(_type: PublicExploreSnapshot.Type, start_date: datetime.date, def _get_date_filter(min_date: datetime.date, max_date: datetime.date) -> dict: return { - 'dateFrom': min_date.isoformat(), - 'dateTo': max_date.isoformat(), + "dateFrom": min_date.isoformat(), + "dateTo": max_date.isoformat(), } def _get_date_meta(min_year, max_year) -> Tuple[Tuple[datetime.date, datetime.date], dict]: @@ -184,9 +184,9 @@ def _get_date_meta(min_year, max_year) -> Tuple[Tuple[datetime.date, datetime.da @override_settings( CACHES={ - 'default': { - 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', - 'LOCATION': 'unique-snowflake', + "default": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + "LOCATION": "unique-snowflake", } }, ) @@ -197,14 +197,14 @@ def _save_snapshot( snapshot, generate_download_file=True, ): - file_content, errors = generate_query_snapshot(gql_query, {'filter': filters}) + file_content, errors = generate_query_snapshot(gql_query, {"filter": filters}) if file_content is None: - logger.error(f'Failed to generate: {errors}', exc_info=True) + logger.error(f"Failed to generate: {errors}", exc_info=True) return # Delete current file snapshot.file.delete() # Save new file - snapshot.file.save(f'{snapshot_filename}.json', file_content) + snapshot.file.save(f"{snapshot_filename}.json", file_content) if generate_download_file: # Skip for Global snapshots # Generate download_file = generate_projects_stats( @@ -215,18 +215,18 @@ def _save_snapshot( # Delete current file snapshot.download_file.delete() # Save new file - snapshot.download_file.save(f'{snapshot_filename}.csv', download_file) + snapshot.download_file.save(f"{snapshot_filename}.csv", download_file) snapshot.save() # Global year range - data_min_date = Project.objects.aggregate(min_created_at=models.Min('created_at'))['min_created_at'] + data_min_date = Project.objects.aggregate(min_created_at=models.Min("created_at"))["min_created_at"] data_max_date = timezone.now() - relativedelta(days=1) date_range, date_filter = (data_min_date, data_max_date), _get_date_filter(data_min_date, data_max_date) # Global - Time series _save_snapshot( SnapshotQuery.DeepExplore.GLOBAL_TIME_SERIES, date_filter, - 'Global-time-series-snapshot', + "Global-time-series-snapshot", get_or_create( PublicExploreSnapshot.Type.GLOBAL, *date_range, @@ -238,7 +238,7 @@ def _save_snapshot( _save_snapshot( SnapshotQuery.DeepExplore.GLOBAL_FULL, date_filter, - 'Global-full-snapshot', + "Global-full-snapshot", get_or_create( PublicExploreSnapshot.Type.GLOBAL, *date_range, @@ -251,7 +251,7 @@ def _save_snapshot( _save_snapshot( SnapshotQuery.DeepExplore.YEARLY, date_filter, - f'{year}-snapshot', + f"{year}-snapshot", get_or_create( PublicExploreSnapshot.Type.YEARLY_SNAPSHOT, *date_range, @@ -261,7 +261,7 @@ def _save_snapshot( @shared_task -@redis_lock('update_deep_explore_entries_count_by_geo_aggreagate') +@redis_lock("update_deep_explore_entries_count_by_geo_aggreagate") def update_deep_explore_entries_count_by_geo_aggreagate_task(): # Weekly clean-up old data and calculate from start. # https://docs.python.org/3/library/datetime.html#datetime.datetime.weekday @@ -272,6 +272,6 @@ def update_deep_explore_entries_count_by_geo_aggreagate_task(): @shared_task -@redis_lock('update_public_deep_explore_snapshot') +@redis_lock("update_public_deep_explore_snapshot") def update_public_deep_explore_snapshot(): return generate_public_deep_explore_snapshot() diff --git a/apps/deep_explore/tests.py b/apps/deep_explore/tests.py index 1b9fc11984..5b158d572e 100644 --- a/apps/deep_explore/tests.py +++ b/apps/deep_explore/tests.py @@ -1,11 +1,11 @@ -from utils.graphene.tests import GraphQLSnapShotTestCase - -from organization.factories import OrganizationFactory -from user.factories import UserFactory from analysis_framework.factories import AnalysisFrameworkFactory -from project.factories import ProjectFactory -from lead.factories import LeadFactory from entry.factories import EntryFactory +from lead.factories import LeadFactory +from organization.factories import OrganizationFactory +from project.factories import ProjectFactory +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLSnapShotTestCase class TestDeepExploreStats(GraphQLSnapShotTestCase): @@ -28,19 +28,21 @@ def setUpClass(cls): project_8 = cls.update_obj(ProjectFactory.create(analysis_framework=analysis_framework2), created_at="2020-09-11") # Some leads - lead_1 = cls.update_obj(LeadFactory.create( - project=project_1, source=organization1, - created_by=user, - ), created_at="2020-10-11") + lead_1 = cls.update_obj( + LeadFactory.create( + project=project_1, + source=organization1, + created_by=user, + ), + created_at="2020-10-11", + ) lead_1.authors.add(organization1) lead_2 = cls.update_obj( LeadFactory.create(project=project_2, source=organization2, created_by=user), created_at="2020-10-11" ) lead_2.authors.add(organization2) cls.update_obj(LeadFactory.create(project=project_3), created_at="2021-10-11") - cls.update_obj( - LeadFactory.create(project=project_4, source=organization2, created_by=user2), created_at="2024-10-11" - ) + cls.update_obj(LeadFactory.create(project=project_4, source=organization2, created_by=user2), created_at="2024-10-11") lead_5 = cls.update_obj(LeadFactory.create(project=project_5), created_at="2021-10-11") lead_5.authors.add(organization1) cls.update_obj(LeadFactory.create(project=project_6, created_by=user), created_at="2023-10-11") @@ -48,9 +50,7 @@ def setUpClass(cls): LeadFactory.create(project=project_7, source=organization1, created_by=user2), created_at="2020-11-11" ) lead_7.authors.add(organization2) - cls.update_obj( - LeadFactory.create(project=project_8, source=organization1, created_by=user2), created_at="2020-09-11" - ) + cls.update_obj(LeadFactory.create(project=project_8, source=organization1, created_by=user2), created_at="2020-09-11") # Some entry cls.update_obj(EntryFactory.create(project=project_1, created_by=user, lead=lead_1), created_at="2020-10-11") @@ -119,19 +119,21 @@ def test_explore_deep_dashboard(self): project_8 = self.update_obj(ProjectFactory.create(analysis_framework=analysis_framework2), created_at="2020-09-11") # Some leads - lead_1 = self.update_obj(LeadFactory.create( - project=project_1, source=organization1, - created_by=user, - ), created_at="2020-10-11") + lead_1 = self.update_obj( + LeadFactory.create( + project=project_1, + source=organization1, + created_by=user, + ), + created_at="2020-10-11", + ) lead_1.authors.add(organization1) lead_2 = self.update_obj( LeadFactory.create(project=project_2, source=organization2, created_by=user), created_at="2020-10-11" ) lead_2.authors.add(organization2) self.update_obj(LeadFactory.create(project=project_3), created_at="2021-10-11") - self.update_obj( - LeadFactory.create(project=project_4, source=organization2, created_by=user2), created_at="2024-10-11" - ) + self.update_obj(LeadFactory.create(project=project_4, source=organization2, created_by=user2), created_at="2024-10-11") lead_5 = self.update_obj(LeadFactory.create(project=project_5), created_at="2021-10-11") lead_5.authors.add(organization1) self.update_obj(LeadFactory.create(project=project_6, created_by=user), created_at="2023-10-11") @@ -139,9 +141,7 @@ def test_explore_deep_dashboard(self): LeadFactory.create(project=project_7, source=organization1, created_by=user2), created_at="2020-11-11" ) lead_7.authors.add(organization2) - self.update_obj( - LeadFactory.create(project=project_8, source=organization1, created_by=user2), created_at="2020-09-11" - ) + self.update_obj(LeadFactory.create(project=project_8, source=organization1, created_by=user2), created_at="2020-09-11") # Some entry self.update_obj(EntryFactory.create(project=project_1, created_by=user, lead=lead_1), created_at="2020-10-11") @@ -157,22 +157,19 @@ def _query_check(filter=None, **kwargs): return self.query_check( query, variables={ - 'filter': filter, + "filter": filter, }, - **kwargs + **kwargs, ) - filter = { - "dateFrom": "2020-10-01", - "dateTo": "2021-11-11" - } + filter = {"dateFrom": "2020-10-01", "dateTo": "2021-11-11"} self.force_login(user) - content = _query_check(filter)['data']['deepExploreStats'] + content = _query_check(filter)["data"]["deepExploreStats"] self.assertIsNotNone(content, content) - self.assertEqual(content['totalActiveUsers'], 3) - self.assertEqual(content['totalAuthors'], 3) - self.assertEqual(content['totalEntries'], 4) - self.assertEqual(content['totalLeads'], 4) - self.assertEqual(content['totalProjects'], 4) - self.assertEqual(content['totalPublishers'], 2) - self.assertEqual(content['totalRegisteredUsers'], 3) + self.assertEqual(content["totalActiveUsers"], 3) + self.assertEqual(content["totalAuthors"], 3) + self.assertEqual(content["totalEntries"], 4) + self.assertEqual(content["totalLeads"], 4) + self.assertEqual(content["totalProjects"], 4) + self.assertEqual(content["totalPublishers"], 2) + self.assertEqual(content["totalRegisteredUsers"], 3) diff --git a/apps/deep_migration/apps.py b/apps/deep_migration/apps.py index f34680d167..1ea49ff3ba 100644 --- a/apps/deep_migration/apps.py +++ b/apps/deep_migration/apps.py @@ -2,4 +2,4 @@ class DeepMigrationConfig(AppConfig): - name = 'deep_migration' + name = "deep_migration" diff --git a/apps/deep_migration/management/commands/migrate_af_changes_v2_v3.py b/apps/deep_migration/management/commands/migrate_af_changes_v2_v3.py index bee2ecead5..39cf9fe413 100644 --- a/apps/deep_migration/management/commands/migrate_af_changes_v2_v3.py +++ b/apps/deep_migration/management/commands/migrate_af_changes_v2_v3.py @@ -1,25 +1,22 @@ -import json import copy +import json +from analysis_framework.models import Section, Widget from django.core.management.base import BaseCommand -from analysis_framework.models import Widget, Section from entry.models import Attribute def clone_data(src_data, mapping): - return { - key: src_data.get(src_key) - for key, src_key, _ in mapping - } + return {key: src_data.get(src_key) for key, src_key, _ in mapping} def verifiy_data(data, src_data, mapping): for key, _, is_required in mapping: - if is_required and data.get(key) in ['', None]: - print(f'-- {key}') + if is_required and data.get(key) in ["", None]: + print(f"-- {key}") print(json.dumps(src_data, indent=2)) print(json.dumps(data, indent=2)) - raise Exception('Data is required here') + raise Exception("Data is required here") # -- Widget convertors @@ -57,51 +54,49 @@ def matrix1d_property_convertor(properties): """ if properties in [None, {}]: return { - 'rows': [], + "rows": [], } ROW_MAP = [ # dest, src keys, required - ('key', 'key', True), - ('label', 'title', True), - ('tooltip', 'tooltip', False), - ('order', 'order', True), - ('color', 'color', True), + ("key", "key", True), + ("label", "title", True), + ("tooltip", "tooltip", False), + ("order", "order", True), + ("color", "color", True), ] CELL_MAP = [ # dest, src keys - ('key', 'key', True), - ('label', 'value', True), - ('tooltip', 'tooltip', False), - ('order', 'order', True), + ("key", "key", True), + ("label", "value", True), + ("tooltip", "tooltip", False), + ("order", "order", True), ] new_rows = [] row_order = 0 - for row in properties['rows']: + for row in properties["rows"]: new_row = clone_data(row, ROW_MAP) - new_row['label'] = new_row['label'] or 'Untitled Row' - new_row['color'] = new_row['color'] or '#808080' - new_row['order'] = row_order + new_row["label"] = new_row["label"] or "Untitled Row" + new_row["color"] = new_row["color"] or "#808080" + new_row["order"] = row_order new_cells = [] cell_order = 0 - for cell in row['cells']: + for cell in row["cells"]: new_cell = clone_data(cell, CELL_MAP) - new_cell['label'] = new_cell['label'] or 'Untitled Cell' - new_cell['order'] = cell_order + new_cell["label"] = new_cell["label"] or "Untitled Cell" + new_cell["order"] = cell_order verifiy_data(new_cell, cell, CELL_MAP) new_cells.append(new_cell) cell_order += 1 - new_row['cells'] = new_cells + new_row["cells"] = new_cells verifiy_data(new_row, row, ROW_MAP) new_rows.append(new_row) row_order += 1 # New property - return { - 'rows': new_rows - } + return {"rows": new_rows} def matrix2d_property_convertor(properties): @@ -161,58 +156,58 @@ def matrix2d_property_convertor(properties): ROW_MAP = [ # dest, src keys - ('key', 'id', True), - ('label', 'title', True), - ('tooltip', 'tooltip', False), - ('order', 'order', True), - ('color', 'color', True), + ("key", "id", True), + ("label", "title", True), + ("tooltip", "tooltip", False), + ("order", "order", True), + ("color", "color", True), ] SUB_ROW_MAP = [ # dest, src keys - ('key', 'id', True), - ('label', 'title', True), - ('tooltip', 'tooltip', False), - ('order', 'order', True), + ("key", "id", True), + ("label", "title", True), + ("tooltip", "tooltip", False), + ("order", "order", True), ] COLUMN_MAP = [ # dest, src keys - ('key', 'id', True), - ('label', 'title', True), - ('tooltip', 'tooltip', False), - ('order', 'order', True), + ("key", "id", True), + ("label", "title", True), + ("tooltip", "tooltip", False), + ("order", "order", True), ] SUB_COLUMN_MAP = [ # dest, src keys - ('key', 'id', True), - ('label', 'title', True), - ('tooltip', 'tooltip', False), - ('order', 'order', True), + ("key", "id", True), + ("label", "title", True), + ("tooltip", "tooltip", False), + ("order", "order", True), ] if properties in [None, {}]: return { - 'rows': [], - 'columns': [], + "rows": [], + "columns": [], } # rows/dimensions new_rows = [] row_order = 0 - for dimension in properties['dimensions']: + for dimension in properties["dimensions"]: new_row = clone_data(dimension, ROW_MAP) - new_row['order'] = row_order - new_row['label'] = new_row['label'] or 'Untitled Row' - new_row['color'] = new_row['color'] or '#808080' + new_row["order"] = row_order + new_row["label"] = new_row["label"] or "Untitled Row" + new_row["color"] = new_row["color"] or "#808080" new_sub_rows = [] sub_row_order = 0 - for subdimension in dimension['subdimensions']: + for subdimension in dimension["subdimensions"]: new_sub_row = clone_data(subdimension, SUB_ROW_MAP) - new_sub_row['label'] = new_sub_row['label'] or 'Untitled SubRow' - new_sub_row['order'] = sub_row_order + new_sub_row["label"] = new_sub_row["label"] or "Untitled SubRow" + new_sub_row["order"] = sub_row_order verifiy_data(new_sub_row, subdimension, SUB_ROW_MAP) new_sub_rows.append(new_sub_row) sub_row_order += 1 - new_row['subRows'] = new_sub_rows + new_row["subRows"] = new_sub_rows verifiy_data(new_row, dimension, ROW_MAP) new_rows.append(new_row) @@ -221,29 +216,29 @@ def matrix2d_property_convertor(properties): # columns/sectors new_columns = [] column_order = 0 - for sector in properties['sectors']: + for sector in properties["sectors"]: new_column = clone_data(sector, COLUMN_MAP) - new_column['order'] = column_order - new_column['label'] = new_column['label'] or 'Untitled Column' + new_column["order"] = column_order + new_column["label"] = new_column["label"] or "Untitled Column" new_sub_columns = [] sub_column_order = 0 - for subsector in sector['subsectors']: + for subsector in sector["subsectors"]: new_sub_column = clone_data(subsector, SUB_COLUMN_MAP) - new_sub_column['label'] = new_sub_column['label'] or 'Untitled SubRow' - new_sub_column['order'] = sub_column_order + new_sub_column["label"] = new_sub_column["label"] or "Untitled SubRow" + new_sub_column["order"] = sub_column_order verifiy_data(new_sub_column, subsector, SUB_COLUMN_MAP) new_sub_columns.append(new_sub_column) sub_column_order += 1 - new_column['subColumns'] = new_sub_columns + new_column["subColumns"] = new_sub_columns verifiy_data(new_column, sector, COLUMN_MAP) new_columns.append(new_column) column_order += 1 # New property return { - 'rows': new_rows, - 'columns': new_columns, + "rows": new_rows, + "columns": new_columns, } @@ -264,31 +259,28 @@ def multiselect_property_convertor(properties): """ if properties in [None, {}]: return { - 'options': [], + "options": [], } OPTION_MAP = [ # dest, src keys - ('key', 'key', True), - ('label', 'label', True), - ('tooltip', 'tooltip', False), - ('order', 'order', True), + ("key", "key", True), + ("label", "label", True), + ("tooltip", "tooltip", False), + ("order", "order", True), ] - options = ( - properties if isinstance(properties, list) - else properties['options'] - ) + options = properties if isinstance(properties, list) else properties["options"] new_options = [] option_order = 0 for option in options: new_option = clone_data(option, OPTION_MAP) - new_option['label'] = new_option['label'] or 'Untitled' - new_option['order'] = option_order + new_option["label"] = new_option["label"] or "Untitled" + new_option["order"] = option_order verifiy_data(new_option, option, OPTION_MAP) new_options.append(new_option) option_order += 1 return { - 'options': new_options, + "options": new_options, } @@ -309,36 +301,34 @@ def organigram_property_convertor(properties): """ OPTION_MAP = [ # dest, src keys - ('key', 'key', True), - ('label', 'title', True), - ('tooltip', 'tooltip', False), - ('order', 'order', True), + ("key", "key", True), + ("label", "title", True), + ("tooltip", "tooltip", False), + ("order", "order", True), ] def _get_all_new_options(option, order=0): if option == {}: return new_option = clone_data(option, OPTION_MAP) - new_option['label'] = new_option['label'] or 'Untitled' - new_option['order'] = order + new_option["label"] = new_option["label"] or "Untitled" + new_option["order"] = order verifiy_data(new_option, option, OPTION_MAP) order += 1 new_childerns = [] - for organ in option.pop('organs', []): + for organ in option.pop("organs", []): child_organ = _get_all_new_options(organ, order=order) if child_organ: new_childerns.append(child_organ) - new_option['children'] = new_childerns + new_option["children"] = new_childerns return new_option if properties in [None, []]: return { - 'options': [], + "options": [], } - return { - 'options': _get_all_new_options(properties) - } + return {"options": _get_all_new_options(properties)} def scale_property_convertor(properties): @@ -362,38 +352,38 @@ def scale_property_convertor(properties): """ OPTION_MAP = [ # dest, src keys - ('key', 'key', True), - ('label', 'label', True), - ('tooltip', 'tooltip', False), - ('color', 'color', True), - ('order', 'order', True), + ("key", "key", True), + ("label", "label", True), + ("tooltip", "tooltip", False), + ("color", "color", True), + ("order", "order", True), ] if properties in [None, {}]: return { - 'options': [], + "options": [], } new_options = [] default_option = None scale_order = 0 - for scale_unit in properties['scale_units'] or []: + for scale_unit in properties["scale_units"] or []: new_option = clone_data(scale_unit, OPTION_MAP) - new_option['label'] = new_option['label'] or scale_unit.get('title') or 'Untitled' - new_option['color'] = new_option['color'] or '#808080' - new_option['order'] = scale_order + new_option["label"] = new_option["label"] or scale_unit.get("title") or "Untitled" + new_option["color"] = new_option["color"] or "#808080" + new_option["order"] = scale_order # For default case - if scale_unit.get('default'): + if scale_unit.get("default"): if default_option is not None: - print(f'- Multiple defaults found: {default_option}') - default_option = scale_unit['key'] + print(f"- Multiple defaults found: {default_option}") + default_option = scale_unit["key"] verifiy_data(new_option, scale_unit, OPTION_MAP) new_options.append(new_option) scale_order += 1 return { - 'options': new_options, - 'defaultValue': default_option, + "options": new_options, + "defaultValue": default_option, } @@ -402,7 +392,7 @@ def scale_property_convertor(properties): def matrix1d_attribute_data_convertor(data): - value = data.get('value') or {} + value = data.get("value") or {} new_value = {} for row_key, row_data in value.items(): new_row_data = {} @@ -411,13 +401,11 @@ def matrix1d_attribute_data_convertor(data): new_row_data[cell_key] = True if new_row_data not in MAXTIX_DATA_EMPTY_VALUES: new_value[row_key] = new_row_data - return { - 'value': new_value - } + return {"value": new_value} def matrix2d_attribute_data_convertor(data): - value = data.get('value') or {} + value = data.get("value") or {} new_value = {} for row_key, row_data in value.items(): new_row_data = {} @@ -430,49 +418,47 @@ def matrix2d_attribute_data_convertor(data): new_row_data[sub_row_key] = new_sub_row_data if new_row_data not in MAXTIX_DATA_EMPTY_VALUES: new_value[row_key] = new_row_data - return { - 'value': new_value - } + return {"value": new_value} def date_range_attribute_data_convertor(data): - value = data.get('value') or {} + value = data.get("value") or {} return { - 'value': { - 'startDate': value.get('from'), - 'endDate': value.get('to'), + "value": { + "startDate": value.get("from"), + "endDate": value.get("to"), } } def time_range_attribute_data_convertor(data): - value = data.get('value') or {} + value = data.get("value") or {} return { - 'value': { - 'startTime': value.get('from'), - 'endTime': value.get('to'), + "value": { + "startTime": value.get("from"), + "endTime": value.get("to"), } } def geo_attribute_data_convertor(data): - values = data.get('value') or [] + values = data.get("value") or [] geo_ids = [] polygons = [] points = [] for value in values: if isinstance(value, dict): - value_type = value.get('type') - if value_type == 'Point': + value_type = value.get("type") + if value_type == "Point": points.append(value) else: polygons.append(value) else: geo_ids.append(value) return { - 'value': geo_ids, - 'polygons': polygons, - 'points': points, + "value": geo_ids, + "polygons": polygons, + "points": points, } @@ -494,72 +480,72 @@ def geo_attribute_data_convertor(data): } CONDITIONAL_OPERATOR_MAP = { - ('matrix1dWidget', 'containsPillar'): 'matrix1d-rows-selected', - ('matrix1dWidget', 'containsSubpillar'): 'matrix1d-cells-selected', - ('matrix2dWidget', 'containsDimension'): 'matrix2d-rows-selected', - ('matrix2dWidget', 'containsSubdimension'): 'matrix2d-sub-rows-selected', - ('multiselectWidget', 'isSelected'): 'multi-selection-selected', - ('scaleWidget', 'isEqualTo'): 'scale-selected', - ('selectWidget', 'isSelected'): 'single-selection-selected', + ("matrix1dWidget", "containsPillar"): "matrix1d-rows-selected", + ("matrix1dWidget", "containsSubpillar"): "matrix1d-cells-selected", + ("matrix2dWidget", "containsDimension"): "matrix2d-rows-selected", + ("matrix2dWidget", "containsSubdimension"): "matrix2d-sub-rows-selected", + ("multiselectWidget", "isSelected"): "multi-selection-selected", + ("scaleWidget", "isEqualTo"): "scale-selected", + ("selectWidget", "isSelected"): "single-selection-selected", } def get_widgets_from_conditional(conditional_widget): af_widget_qs = Widget.objects.filter(analysis_framework_id=conditional_widget.analysis_framework_id) - widgets = conditional_widget.properties.get('data', {}).get('widgets', []) + widgets = conditional_widget.properties.get("data", {}).get("widgets", []) for conditional_widget_data in widgets: - widget_data = conditional_widget_data['widget'] + widget_data = conditional_widget_data["widget"] - legacy_conditions = conditional_widget_data['conditions'] - if len(legacy_conditions['list']) == 0: + legacy_conditions = conditional_widget_data["conditions"] + if len(legacy_conditions["list"]) == 0: continue - if len(legacy_conditions['list']) > 1: - raise Exception('Found multiple list. Not supported') - legacy_condition = legacy_conditions['list'][0] + if len(legacy_conditions["list"]) > 1: + raise Exception("Found multiple list. Not supported") + legacy_condition = legacy_conditions["list"][0] - parent_widget_key = legacy_condition['widget_key'] + parent_widget_key = legacy_condition["widget_key"] parent_widget = af_widget_qs.get(key=parent_widget_key) - operator = CONDITIONAL_OPERATOR_MAP[(legacy_condition['widget_id'], legacy_condition['condition_type'])] - condition_attributes = legacy_condition.get('attributes') or {} - invert = legacy_condition.get('invert_logic') - - if operator == 'matrix1d-rows-selected': - condition_collection = condition_attributes.get('pillars') or {} - elif operator == 'matrix1d-cells-selected': - condition_collection = condition_attributes.get('subpillars') or {} - elif operator == 'matrix2d-rows-selected': - condition_collection = condition_attributes.get('dimensions') or {} - elif operator == 'matrix2d-sub-rows-selected': - condition_collection = condition_attributes.get('subdimensions') or {} - elif operator == 'multi-selection-selected': - condition_collection = condition_attributes.get('selections') or {} - elif operator == 'scale-selected': - condition_collection = condition_attributes.get('scales') or {} - elif operator == 'single-selection-selected': - condition_collection = condition_attributes.get('selections') or {} + operator = CONDITIONAL_OPERATOR_MAP[(legacy_condition["widget_id"], legacy_condition["condition_type"])] + condition_attributes = legacy_condition.get("attributes") or {} + invert = legacy_condition.get("invert_logic") + + if operator == "matrix1d-rows-selected": + condition_collection = condition_attributes.get("pillars") or {} + elif operator == "matrix1d-cells-selected": + condition_collection = condition_attributes.get("subpillars") or {} + elif operator == "matrix2d-rows-selected": + condition_collection = condition_attributes.get("dimensions") or {} + elif operator == "matrix2d-sub-rows-selected": + condition_collection = condition_attributes.get("subdimensions") or {} + elif operator == "multi-selection-selected": + condition_collection = condition_attributes.get("selections") or {} + elif operator == "scale-selected": + condition_collection = condition_attributes.get("scales") or {} + elif operator == "single-selection-selected": + condition_collection = condition_attributes.get("selections") or {} else: - raise Exception('Found unhandled attribute data') - condition_value = condition_collection.get('values') or [] - operatorModifier = 'every' if condition_collection.get('test_every') else 'some' + raise Exception("Found unhandled attribute data") + condition_value = condition_collection.get("values") or [] + operatorModifier = "every" if condition_collection.get("test_every") else "some" conditions = [ dict( - key=legacy_condition['key'], - conjunctionOperator=legacy_conditions['operator'], + key=legacy_condition["key"], + conjunctionOperator=legacy_conditions["operator"], order=1, invert=invert, operatorModifier=operatorModifier, - operator=CONDITIONAL_OPERATOR_MAP[(legacy_condition['widget_id'], legacy_condition['condition_type'])], + operator=CONDITIONAL_OPERATOR_MAP[(legacy_condition["widget_id"], legacy_condition["condition_type"])], value=condition_value, ) ] new_widget = Widget( analysis_framework_id=conditional_widget.analysis_framework_id, - key=widget_data['key'], - widget_id=widget_data['widget_id'], - title=widget_data['title'], - properties=widget_data['properties'], + key=widget_data["key"], + widget_id=widget_data["widget_id"], + title=widget_data["title"], + properties=widget_data["properties"], conditional_parent_widget=parent_widget, conditional_conditions=conditions, ) @@ -568,13 +554,13 @@ def get_widgets_from_conditional(conditional_widget): def get_attribute_from_conditional_data(widget_qs, attribute): conditional_data = copy.deepcopy(attribute.data or {}) - conditional_value = (conditional_data or {}).get('value') - if conditional_value in [None, {}] or 'selected_widget_key' not in conditional_value: + conditional_value = (conditional_data or {}).get("value") + if conditional_value in [None, {}] or "selected_widget_key" not in conditional_value: return - selected_widget_key = conditional_value['selected_widget_key'] + selected_widget_key = conditional_value["selected_widget_key"] selected_widget = widget_qs.get(key=selected_widget_key) value = conditional_value.get(selected_widget_key) - data = (value or {}).get('data') or {} + data = (value or {}).get("data") or {} return Attribute( entry=attribute.entry, widget=selected_widget, @@ -586,13 +572,13 @@ def get_number_matrix_widget_data(widget_data): rows = {} columns = {} if widget_data: - for row in widget_data.get('row_headers'): - rows[row['key']] = row['title'] - for column in widget_data.get('column_headers'): - columns[column['key']] = column['title'] + for row in widget_data.get("row_headers"): + rows[row["key"]] = row["title"] + for column in widget_data.get("column_headers"): + columns[column["key"]] = column["title"] return { - 'rows': rows, - 'columns': columns, + "rows": rows, + "columns": columns, } @@ -600,13 +586,13 @@ def get_number_matrix_attribute_data(widget_data, attribute_value): extracted_data = [] if attribute_value: for row_key, row_data in attribute_value.items(): - row_label = widget_data['rows'].get(row_key, 'N/A') + row_label = widget_data["rows"].get(row_key, "N/A") if not row_data: continue for column_key, value in row_data.items(): - column_label = widget_data['columns'].get(column_key, 'N/A') - extracted_data.append(f'({row_label}, {column_label}, {value})') - return ','.join(extracted_data) + column_label = widget_data["columns"].get(column_key, "N/A") + extracted_data.append(f"({row_label}, {column_label}, {value})") + return ",".join(extracted_data) class Command(BaseCommand): @@ -622,8 +608,10 @@ def get_section_for_af_id(self, af_id): if af_id not in self.af_default_sections: self.af_default_sections[af_id] = Section.objects.get_or_create( analysis_framework_id=af_id, - title='Overview', - )[0] # NOTE: Check if multiple exists if required + title="Overview", + )[ + 0 + ] # NOTE: Check if multiple exists if required return self.af_default_sections[af_id] def handle(self, *args, **kwargs): @@ -631,68 +619,73 @@ def handle(self, *args, **kwargs): attribute_qs = Attribute.objects.exclude(widget_version=self.CURRENT_VERSION) # Migrate Widget Data - print(f'Widgets (Total: {widget_qs.count()})') + print(f"Widgets (Total: {widget_qs.count()})") for widget_type, widget_property_convertor in WIDGET_MIGRATION_MAP.items(): - print(f'\n- {widget_type}') + print(f"\n- {widget_type}") required_widgets_qs = widget_qs.filter(widget_id=widget_type) total = required_widgets_qs.count() for index, widget in enumerate(required_widgets_qs.all(), 1): # Update properties. - if widget.properties.get('added_from') == 'overview': # Requires section + if widget.properties.get("added_from") == "overview": # Requires section widget.section = self.get_section_for_af_id(widget.analysis_framework_id) - old_properties = copy.deepcopy(widget.properties.get('data') or {}) + old_properties = copy.deepcopy(widget.properties.get("data") or {}) widget.properties = widget_property_convertor(old_properties) # print_property(widget) - widget.properties['old_properties'] = old_properties # Clean-up this later. + widget.properties["old_properties"] = old_properties # Clean-up this later. widget.version = self.CURRENT_VERSION # Save - widget.save(update_fields=('properties', 'version', 'section')) - print(f'-- Saved ({index})/({total})', end='\r') + widget.save(update_fields=("properties", "version", "section")) + print(f"-- Saved ({index})/({total})", end="\r") # Migrate Entry Attribute Data - print(f'Entry Attributes (Total: {attribute_qs.count()})') + print(f"Entry Attributes (Total: {attribute_qs.count()})") for widget_type, attribute_data_convertor in ATTRIBUTE_MIGRATION_MAP.items(): - print(f'\n- {widget_type}') + print(f"\n- {widget_type}") required_attribute_qs = attribute_qs.filter(widget__widget_id=widget_type) total = required_attribute_qs.count() for index, attribute in enumerate(required_attribute_qs.iterator(chunk_size=1000)): # Update properties. old_data = copy.deepcopy(attribute.data or {}) attribute.data = attribute_data_convertor(old_data) - attribute.data['old_data'] = old_data # Clean-up this later. + attribute.data["old_data"] = old_data # Clean-up this later. attribute.widget_version = self.CURRENT_VERSION # Save - attribute.save(update_fields=('data', 'widget_version',)) - print(f'-- Saved ({index})/({total})', end='\r') - print('') + attribute.save( + update_fields=( + "data", + "widget_version", + ) + ) + print(f"-- Saved ({index})/({total})", end="\r") + print("") # Conditional Widgets conditional_widget_qs = widget_qs.filter(widget_id=Widget.WidgetType.CONDITIONAL) total = conditional_widget_qs.count() - print(f'Conditional Widgets (Total: {total})') + print(f"Conditional Widgets (Total: {total})") for index, conditional_widget in enumerate(conditional_widget_qs.all(), 1): for widget in get_widgets_from_conditional(conditional_widget): # Update properties. - if conditional_widget.properties.get('added_from') == 'overview': # Requires section + if conditional_widget.properties.get("added_from") == "overview": # Requires section widget.section = self.get_section_for_af_id(widget.analysis_framework_id) - old_properties = copy.deepcopy(widget.properties.get('data') or {}) + old_properties = copy.deepcopy(widget.properties.get("data") or {}) widget_property_convertor = WIDGET_MIGRATION_MAP.get(widget.widget_id) if widget_property_convertor is not None: widget.properties = widget_property_convertor(old_properties) - widget.properties['from_conditional_widget'] = True + widget.properties["from_conditional_widget"] = True widget.version = self.CURRENT_VERSION # Save widget.save() - conditional_widget.properties = {'old_data': conditional_widget.properties} + conditional_widget.properties = {"old_data": conditional_widget.properties} conditional_widget.version = self.CURRENT_VERSION - conditional_widget.save(update_fields=('version',)) - print(f'-- Saved ({index})/({total})', end='\r') - print('') + conditional_widget.save(update_fields=("version",)) + print(f"-- Saved ({index})/({total})", end="\r") + print("") # Migrate Conditional Entry Attribute Data conditional_attribute_qs = attribute_qs.filter(widget__widget_id=Widget.WidgetType.CONDITIONAL) total = conditional_attribute_qs.count() - print(f'Conditional Entry Attributes (Total: {total})') + print(f"Conditional Entry Attributes (Total: {total})") for index, attribute in enumerate(conditional_attribute_qs.iterator(chunk_size=1000)): # Update properties. try: @@ -708,29 +701,34 @@ def handle(self, *args, **kwargs): new_attribute.data = attribute_data_convertor(copy.deepcopy(new_attribute.data)) new_attribute.widget_version = self.CURRENT_VERSION new_attribute.save() - attribute.data = {'old_data': attribute.data} + attribute.data = {"old_data": attribute.data} attribute.widget_version = self.CURRENT_VERSION - attribute.save(update_fields=('data', 'widget_version',)) - print(f'-- Saved ({index})/({total})', end='\r') - print('') + attribute.save( + update_fields=( + "data", + "widget_version", + ) + ) + print(f"-- Saved ({index})/({total})", end="\r") + print("") # Now migrate all number_matrix to Text number_matrix_widget_qs = widget_qs.filter(widget_id=Widget.WidgetType.NUMBER_MATRIX) total = number_matrix_widget_qs.count() - print(f'Number Matrix Widget (Total: {total})') + print(f"Number Matrix Widget (Total: {total})") for index, widget in enumerate(number_matrix_widget_qs.iterator(chunk_size=1000)): - widget.title = f'{widget.title} (Previously Number Matrix)' + widget.title = f"{widget.title} (Previously Number Matrix)" widget.widget_id = Widget.WidgetType.TEXT - if widget.properties.get('added_from') == 'overview': # Requires section + if widget.properties.get("added_from") == "overview": # Requires section widget.section = self.get_section_for_af_id(widget.analysis_framework_id) widget.properties = { - 'migrated_from_number_matrix': True, - 'old_data': copy.deepcopy(widget.properties), + "migrated_from_number_matrix": True, + "old_data": copy.deepcopy(widget.properties), } widget.version = self.CURRENT_VERSION - widget.save(update_fields=('title', 'widget_id', 'properties', 'version', 'section')) - print(f'-- Saved ({index})/({total})', end='\r') - print('') + widget.save(update_fields=("title", "widget_id", "properties", "version", "section")) + print(f"-- Saved ({index})/({total})", end="\r") + print("") # Migrate Number Matrix Entry Attribute Data number_matrix_attribute_qs = attribute_qs.filter(widget__properties__migrated_from_number_matrix=True) @@ -738,25 +736,30 @@ def handle(self, *args, **kwargs): if total: number_matrix_widget_label_map = {} for widget in Widget.objects.filter(properties__migrated_from_number_matrix=True): - widget_data = (widget.properties and widget.properties.get('old_data', {}).get('data')) + widget_data = widget.properties and widget.properties.get("old_data", {}).get("data") number_matrix_widget_label_map[widget.pk] = get_number_matrix_widget_data(widget_data) - print(f'Number Matrix Entry Attributes (Total: {total})') + print(f"Number Matrix Entry Attributes (Total: {total})") for index, attribute in enumerate(number_matrix_attribute_qs.iterator(chunk_size=1000)): attribute.data = { - 'value': get_number_matrix_attribute_data( + "value": get_number_matrix_attribute_data( number_matrix_widget_label_map[attribute.widget_id], - (attribute.data or {}).get('value'), + (attribute.data or {}).get("value"), ), - 'old_data': attribute.data, + "old_data": attribute.data, } attribute.widget_version = self.CURRENT_VERSION - attribute.save(update_fields=('data', 'widget_version',)) - print(f'-- Saved ({index})/({total})', end='\r') - print('') + attribute.save( + update_fields=( + "data", + "widget_version", + ) + ) + print(f"-- Saved ({index})/({total})", end="\r") + print("") # Finally just update for this widgets (Not changes are required for this widgets) - print('Update normal widgets:') + print("Update normal widgets:") print( widget_qs.filter( widget_id__in=[ @@ -774,7 +777,7 @@ def handle(self, *args, **kwargs): ) # Just update for this widget's attributes (Not changes are required for this widgets) - print('Update normal attributes:') + print("Update normal attributes:") print( attribute_qs.filter( widget__widget_id__in=[ @@ -797,11 +800,13 @@ def handle(self, *args, **kwargs): def print_property(widget): import json - print('-' * 22) + + print("-" * 22) print(json.dumps(widget.properties, indent=2)) def print_attribute_data(widget): import json - print('-' * 22) + + print("-" * 22) print(json.dumps(widget.data, indent=2)) diff --git a/apps/deep_migration/management/commands/migrate_analysis_framework.py b/apps/deep_migration/management/commands/migrate_analysis_framework.py index 0b528a0321..fd16aa5795 100644 --- a/apps/deep_migration/management/commands/migrate_analysis_framework.py +++ b/apps/deep_migration/management/commands/migrate_analysis_framework.py @@ -1,26 +1,14 @@ import json +import re -from deep_migration.utils import ( - MigrationCommand, - get_source_url, - request_with_auth, -) - +import reversion +from analysis_framework.models import AnalysisFramework, Exportable, Filter, Widget from deep_migration.models import ( AnalysisFrameworkMigration, ProjectMigration, UserMigration, ) - -from analysis_framework.models import ( - AnalysisFramework, - Widget, - Exportable, - Filter, -) - -import reversion -import re +from deep_migration.utils import MigrationCommand, get_source_url, request_with_auth def get_user(old_user_id): @@ -35,38 +23,36 @@ def get_project(project_id): def snap(x, default=16): if isinstance(x, str): - x = int(re.sub(r'[^\d-]+', '', x)) + x = int(re.sub(r"[^\d-]+", "", x)) return round(x / default) * default class Command(MigrationCommand): def run(self): - if self.kwargs.get('data_file'): - with open(self.kwargs['data_file']) as f: + if self.kwargs.get("data_file"): + with open(self.kwargs["data_file"]) as f: frameworks = json.load(f) else: - query = self.kwargs.get('query_str', '') - frameworks = request_with_auth( - get_source_url('entry-templates', 'v1', query) - ) + query = self.kwargs.get("query_str", "") + frameworks = request_with_auth(get_source_url("entry-templates", "v1", query)) if not frameworks: - print('Couldn\'t find AF data at') + print("Couldn't find AF data at") with reversion.create_revision(): - new_frameworks_file = open('new_afs.txt', 'a') + new_frameworks_file = open("new_afs.txt", "a") for framework in frameworks: self.import_framework(framework, new_frameworks_file) new_frameworks_file.close() def import_framework(self, data, file): - print('------------') - print('Migrating analysis framework') + print("------------") + print("Migrating analysis framework") - old_id = data['id'] - title = data['name'] + old_id = data["id"] + title = data["name"] - print('{} - {}'.format(old_id, title)) + print("{} - {}".format(old_id, title)) migration, _ = AnalysisFrameworkMigration.objects.get_or_create( old_id=old_id, @@ -77,21 +63,20 @@ def import_framework(self, data, file): title=title, ) migration.analysis_framework = framework - file.write('{}\n'.format(framework.id)) + file.write("{}\n".format(framework.id)) migration.save() else: return migration.analysis_framework framework = migration.analysis_framework - framework.created_by = get_user(data['created_by']) + framework.created_by = get_user(data["created_by"]) framework.modified_by = framework.created_by framework.save() - AnalysisFramework.objects.filter(id=framework.id)\ - .update(created_at=data['created_at']) + AnalysisFramework.objects.filter(id=framework.id).update(created_at=data["created_at"]) - projects = data['projects'] + projects = data["projects"] for project_id in projects: project = get_project(project_id) if project: @@ -100,397 +85,375 @@ def import_framework(self, data, file): # Let's start migrating widgets - elements = data['elements'] + elements = data["elements"] for element in elements: self.migrate_widget(framework, element) return framework def migrate_widget(self, framework, element): - print('Migrating widget {}'.format(element['id'])) + print("Migrating widget {}".format(element["id"])) type_method_map = { - 'pageOneExcerptBox': self.migrate_excerpt, - 'pageTwoExcerptBox': self.migrate_excerpt, - - 'matrix1d': self.migrate_matrix1d, - 'matrix2d': self.migrate_matrix2d, - - 'number-input': self.migrate_number, - 'date-input': self.migrate_date, - 'scale': self.migrate_scale, - - 'organigram': self.migrate_organigram, - 'multiselect': self.migrate_multiselect, - 'geolocations': self.migrate_geo, - - 'number2d': self.migrate_number_matrix, + "pageOneExcerptBox": self.migrate_excerpt, + "pageTwoExcerptBox": self.migrate_excerpt, + "matrix1d": self.migrate_matrix1d, + "matrix2d": self.migrate_matrix2d, + "number-input": self.migrate_number, + "date-input": self.migrate_date, + "scale": self.migrate_scale, + "organigram": self.migrate_organigram, + "multiselect": self.migrate_multiselect, + "geolocations": self.migrate_geo, + "number2d": self.migrate_number_matrix, } - method = type_method_map.get(element['type']) + method = type_method_map.get(element["type"]) if method: method(framework, element) def migrate_excerpt(self, framework, element): widget, _ = Widget.objects.get_or_create( - widget_id='excerptWidget', + widget_id="excerptWidget", analysis_framework=framework, defaults={ - 'title': ( - element.get('label') or - element.get('excerptLabel') or 'Excerpt' - ), - 'key': element['id'], - 'properties': {}, + "title": (element.get("label") or element.get("excerptLabel") or "Excerpt"), + "key": element["id"], + "properties": {}, }, ) - if element['id'] == 'page-one-excerpt': - widget.properties.update({ - 'overview_grid_layout': self.get_layout(element), - }) - elif element['id'] == 'page-two-excerpt': - widget.properties.update({ - 'list_grid_layout': self.get_layout(element), - }) + if element["id"] == "page-one-excerpt": + widget.properties.update( + { + "overview_grid_layout": self.get_layout(element), + } + ) + elif element["id"] == "page-two-excerpt": + widget.properties.update( + { + "list_grid_layout": self.get_layout(element), + } + ) widget.save() def migrate_number(self, framework, element): - title = element['label'] or 'Number' + title = element["label"] or "Number" widget, _ = Widget.objects.get_or_create( - widget_id='numberWidget', + widget_id="numberWidget", analysis_framework=framework, - key=element['id'], + key=element["id"], defaults={ - 'title': title, - 'properties': { - 'list_grid_layout': self.get_layout(element), + "title": title, + "properties": { + "list_grid_layout": self.get_layout(element), }, }, ) filter, _ = Filter.objects.get_or_create( - key=element['id'], + key=element["id"], analysis_framework=framework, - widget_key=element['id'], + widget_key=element["id"], defaults={ - 'title': title, - 'filter_type': 'number', - 'properties': { - 'type': 'number', + "title": title, + "filter_type": "number", + "properties": { + "type": "number", }, }, ) exportable, _ = Exportable.objects.get_or_create( - widget_key=element['id'], + widget_key=element["id"], analysis_framework=framework, defaults={ - 'data': { - 'excel': {'title': title} - }, + "data": {"excel": {"title": title}}, }, ) def migrate_date(self, framework, element): - title = element['label'] or 'Date' + title = element["label"] or "Date" widget, _ = Widget.objects.get_or_create( - widget_id='dateWidget', + widget_id="dateWidget", analysis_framework=framework, - key=element['id'], + key=element["id"], defaults={ - 'title': title, - 'properties': { - 'list_grid_layout': self.get_layout(element), + "title": title, + "properties": { + "list_grid_layout": self.get_layout(element), }, }, ) filter, _ = Filter.objects.get_or_create( - key=element['id'], + key=element["id"], analysis_framework=framework, - widget_key=element['id'], + widget_key=element["id"], defaults={ - 'title': title, - 'filter_type': 'number', - 'properties': { - 'type': 'date', + "title": title, + "filter_type": "number", + "properties": { + "type": "date", }, }, ) exportable, _ = Exportable.objects.get_or_create( - widget_key=element['id'], + widget_key=element["id"], analysis_framework=framework, defaults={ - 'data': { - 'excel': {'title': title} - }, + "data": {"excel": {"title": title}}, }, ) def migrate_scale(self, framework, element): - title = element['label'] or 'Scale' + title = element["label"] or "Scale" widget, _ = Widget.objects.get_or_create( - widget_id='scaleWidget', + widget_id="scaleWidget", analysis_framework=framework, - key=element['id'], + key=element["id"], defaults={ - 'title': title, - 'properties': { - 'list_grid_layout': self.get_layout(element), - 'data': { - 'scale_units': self.convert_scale_values( - element['scaleValues'] - ), - 'value': self.get_default_scale_value( - element['scaleValues'] - ), + "title": title, + "properties": { + "list_grid_layout": self.get_layout(element), + "data": { + "scale_units": self.convert_scale_values(element["scaleValues"]), + "value": self.get_default_scale_value(element["scaleValues"]), }, }, }, ) filter, _ = Filter.objects.get_or_create( - key=element['id'], + key=element["id"], analysis_framework=framework, - widget_key=element['id'], + widget_key=element["id"], defaults={ - 'title': title, - 'properties': { - 'type': 'multiselect-range', - 'options': self.convert_scale_filter_values( - element['scaleValues'] - ), + "title": title, + "properties": { + "type": "multiselect-range", + "options": self.convert_scale_filter_values(element["scaleValues"]), }, }, ) exportable, _ = Exportable.objects.get_or_create( - widget_key=element['id'], + widget_key=element["id"], analysis_framework=framework, defaults={ - 'data': { - 'excel': {'title': title} - }, + "data": {"excel": {"title": title}}, }, ) def convert_scale_values(self, values): return [ { - 'title': v['name'], - 'color': v['color'], - 'key': v['id'], - 'default': v['default'], - } for v in values + "title": v["name"], + "color": v["color"], + "key": v["id"], + "default": v["default"], + } + for v in values ] def get_default_scale_value(self, values): - return next(( - v['id'] for v in values if v['default'] - ), None) + return next((v["id"] for v in values if v["default"]), None) def convert_scale_filter_values(self, values): return [ { - 'label': v['name'], - 'key': v['id'], - } for v in values + "label": v["name"], + "key": v["id"], + } + for v in values ] def migrate_organigram(self, framework, element): - title = element['label'] or 'Organigram' + title = element["label"] or "Organigram" widget, _ = Widget.objects.get_or_create( - widget_id='organigramWidget', + widget_id="organigramWidget", analysis_framework=framework, - key=element['id'], + key=element["id"], defaults={ - 'title': title, - 'properties': { - 'list_grid_layout': self.get_layout(element), - 'data': self.convert_organigram(element['nodes']), + "title": title, + "properties": { + "list_grid_layout": self.get_layout(element), + "data": self.convert_organigram(element["nodes"]), }, }, ) filter, _ = Filter.objects.get_or_create( - key=element['id'], + key=element["id"], analysis_framework=framework, - widget_key=element['id'], + widget_key=element["id"], defaults={ - 'title': title, - 'properties': { - 'type': 'multiselect', - 'options': self.get_organigram_filter_nodes( - element['nodes'] - ), + "title": title, + "properties": { + "type": "multiselect", + "options": self.get_organigram_filter_nodes(element["nodes"]), }, }, ) exportable, _ = Exportable.objects.get_or_create( - widget_key=element['id'], + widget_key=element["id"], analysis_framework=framework, defaults={ - 'data': { - 'excel': {'title': title} - }, + "data": {"excel": {"title": title}}, }, ) def convert_organigram(self, nodes): - parent_nodes = self.get_organigram_nodes(nodes, '') + parent_nodes = self.get_organigram_nodes(nodes, "") return parent_nodes and parent_nodes[0] def get_organigram_nodes(self, nodes, parent): return [ { - 'key': node['id'], - 'title': node['name'], - 'organs': self.get_organigram_nodes(nodes, node['id']), - } for node in nodes if node['parent'] == parent + "key": node["id"], + "title": node["name"], + "organs": self.get_organigram_nodes(nodes, node["id"]), + } + for node in nodes + if node["parent"] == parent ] - def get_organigram_filter_nodes(self, nodes, parent='', prefix=''): + def get_organigram_filter_nodes(self, nodes, parent="", prefix=""): values = [] for node in nodes: - title = '{}{}'.format(prefix, node['name']) - if node['parent'] == parent: - values.append({ - 'key': node['id'], - 'label': title, - }) - - values.extend(self.get_organigram_filter_nodes( - nodes, - node['id'], - '{} / '.format(title), - )) + title = "{}{}".format(prefix, node["name"]) + if node["parent"] == parent: + values.append( + { + "key": node["id"], + "label": title, + } + ) + + values.extend( + self.get_organigram_filter_nodes( + nodes, + node["id"], + "{} / ".format(title), + ) + ) return values def migrate_multiselect(self, framework, element): - title = element['label'] or 'Groups' + title = element["label"] or "Groups" widget, _ = Widget.objects.get_or_create( - widget_id='multiselectWidget', + widget_id="multiselectWidget", analysis_framework=framework, - key=element['id'], + key=element["id"], defaults={ - 'title': title, - 'properties': { - 'list_grid_layout': self.get_layout(element), - 'data': { - 'options': self.convert_multiselect( - element['options'] - ), + "title": title, + "properties": { + "list_grid_layout": self.get_layout(element), + "data": { + "options": self.convert_multiselect(element["options"]), }, }, }, ) filter, _ = Filter.objects.get_or_create( - key=element['id'], + key=element["id"], analysis_framework=framework, - widget_key=element['id'], + widget_key=element["id"], defaults={ - 'title': title, - 'properties': { - 'type': 'multiselect', - 'options': self.convert_multiselect( - element['options'] - ), + "title": title, + "properties": { + "type": "multiselect", + "options": self.convert_multiselect(element["options"]), }, }, ) exportable, _ = Exportable.objects.get_or_create( - widget_key=element['id'], + widget_key=element["id"], analysis_framework=framework, defaults={ - 'data': { - 'excel': {'title': title} - }, + "data": {"excel": {"title": title}}, }, ) def convert_multiselect(self, options): return [ { - 'key': option['id'], - 'label': option['text'], - } for option in options + "key": option["id"], + "label": option["text"], + } + for option in options ] def migrate_geo(self, framework, element): - title = element['label'] or 'Geo' + title = element["label"] or "Geo" widget, _ = Widget.objects.get_or_create( - widget_id='geoWidget', + widget_id="geoWidget", analysis_framework=framework, - key=element['id'], + key=element["id"], defaults={ - 'title': title, - 'properties': { - 'list_grid_layout': self.get_layout(element), + "title": title, + "properties": { + "list_grid_layout": self.get_layout(element), }, }, ) filter, _ = Filter.objects.get_or_create( - key=element['id'], + key=element["id"], analysis_framework=framework, - widget_key=element['id'], + widget_key=element["id"], defaults={ - 'title': title, - 'properties': { - 'type': 'geo', + "title": title, + "properties": { + "type": "geo", }, }, ) exportable, _ = Exportable.objects.get_or_create( - widget_key=element['id'], + widget_key=element["id"], analysis_framework=framework, defaults={ - 'data': { - 'excel': {'type': 'geo'} - }, + "data": {"excel": {"type": "geo"}}, }, ) def migrate_number_matrix(self, framework, element): widget, _ = Widget.objects.get_or_create( - widget_id='numberMatrixWidget', + widget_id="numberMatrixWidget", analysis_framework=framework, - key=element['id'], + key=element["id"], defaults={ - 'title': element['title'], - 'properties': { - 'overview_grid_layout': self.get_layout(element), - 'list_grid_layout': self.get_layout(element.get('list')), - 'data': self.convert_number_matrix(element['rows'], - element['columns']), + "title": element["title"], + "properties": { + "overview_grid_layout": self.get_layout(element), + "list_grid_layout": self.get_layout(element.get("list")), + "data": self.convert_number_matrix(element["rows"], element["columns"]), }, }, ) filter, _ = Filter.objects.get_or_create( - key=element['id'], + key=element["id"], analysis_framework=framework, - widget_key=element['id'], + widget_key=element["id"], defaults={ - 'title': element['title'], - 'properties': { - 'type': 'number-2d', + "title": element["title"], + "properties": { + "type": "number-2d", }, }, ) exportable, _ = Exportable.objects.get_or_create( - widget_key=element['id'], + widget_key=element["id"], analysis_framework=framework, defaults={ - 'data': self.convert_number_matrix_export(element['rows'], - element['columns']), + "data": self.convert_number_matrix_export(element["rows"], element["columns"]), }, ) @@ -498,339 +461,337 @@ def convert_number_matrix(self, rows, columns): row_headers = [] column_headers = [] for row in rows: - row_headers.append({'key': row['id'], - 'title': row['title']}) + row_headers.append({"key": row["id"], "title": row["title"]}) for column in columns: - column_headers.append({'key': column['id'], - 'title': column['title']}) + column_headers.append({"key": column["id"], "title": column["title"]}) return { - 'row_headers': row_headers, - 'column_headers': column_headers, + "row_headers": row_headers, + "column_headers": column_headers, } def convert_number_matrix_export(self, rows, columns): titles = [] for row in rows: for column in columns: - titles.append('{}-{}'.format(row['title'], column['title'])) + titles.append("{}-{}".format(row["title"], column["title"])) return { - 'excel': { - 'type': 'multiple', - 'titles': titles, + "excel": { + "type": "multiple", + "titles": titles, }, } def migrate_matrix1d(self, framework, element): widget, _ = Widget.objects.get_or_create( - widget_id='matrix1dWidget', + widget_id="matrix1dWidget", analysis_framework=framework, - key=element['id'], + key=element["id"], defaults={ - 'title': element['title'], - 'properties': { - 'overview_grid_layout': self.get_layout(element), - 'list_grid_layout': self.get_layout(element.get('list')), - 'data': { - 'rows': self.convert_matrix1d_rows(element['pillars']), + "title": element["title"], + "properties": { + "overview_grid_layout": self.get_layout(element), + "list_grid_layout": self.get_layout(element.get("list")), + "data": { + "rows": self.convert_matrix1d_rows(element["pillars"]), }, }, }, ) filter, _ = Filter.objects.get_or_create( - key=element['id'], + key=element["id"], analysis_framework=framework, - widget_key=element['id'], + widget_key=element["id"], defaults={ - 'title': element['title'], - 'properties': { - 'type': 'multiselect', - 'options': self.convert_matrix1d_filter(element['pillars']) - }, + "title": element["title"], + "properties": {"type": "multiselect", "options": self.convert_matrix1d_filter(element["pillars"])}, }, ) exportable, _ = Exportable.objects.get_or_create( - widget_key=element['id'], + widget_key=element["id"], analysis_framework=framework, defaults={ - 'data': self.convert_matrix1d_export(element['pillars']), + "data": self.convert_matrix1d_export(element["pillars"]), }, ) def convert_matrix1d_rows(self, pillars): return [ { - 'key': pillar['id'], - 'title': pillar['name'], - 'color': pillar['color'], - 'tooltip': pillar['tooltip'], - 'cells': self.convert_matrix1d_cells( - pillar['subpillars'] - ) - } for pillar in pillars + "key": pillar["id"], + "title": pillar["name"], + "color": pillar["color"], + "tooltip": pillar["tooltip"], + "cells": self.convert_matrix1d_cells(pillar["subpillars"]), + } + for pillar in pillars ] def convert_matrix1d_cells(self, subpillars): return [ { - 'key': subpillar['id'], - 'value': subpillar['name'], - } for subpillar in subpillars + "key": subpillar["id"], + "value": subpillar["name"], + } + for subpillar in subpillars ] def convert_matrix1d_filter(self, pillars): options = [] for pillar in pillars: - options.append({ - 'key': pillar['id'], - 'label': pillar['name'], - }) - - for subpillar in pillar['subpillars']: - options.append({ - 'key': subpillar['id'], - 'label': '{} / {}'.format( - pillar['name'], - subpillar['name'], - ), - }) + options.append( + { + "key": pillar["id"], + "label": pillar["name"], + } + ) + + for subpillar in pillar["subpillars"]: + options.append( + { + "key": subpillar["id"], + "label": "{} / {}".format( + pillar["name"], + subpillar["name"], + ), + } + ) return options def convert_matrix1d_export(self, pillars): excel = { - 'titles': ['Dimension', 'Subdimension'], - 'type': 'multiple', + "titles": ["Dimension", "Subdimension"], + "type": "multiple", } levels = [] for pillar in pillars: sublevels = [ { - 'id': subpillar['id'], - 'title': subpillar['name'], - } for subpillar in pillar['subpillars'] + "id": subpillar["id"], + "title": subpillar["name"], + } + for subpillar in pillar["subpillars"] ] - levels.append({ - 'id': pillar['id'], - 'title': pillar['name'], - 'sublevels': sublevels, - }) + levels.append( + { + "id": pillar["id"], + "title": pillar["name"], + "sublevels": sublevels, + } + ) - report = {'levels': levels} + report = {"levels": levels} return { - 'excel': excel, - 'report': report, + "excel": excel, + "report": report, } def migrate_matrix2d(self, framework, element): widget, _ = Widget.objects.get_or_create( - widget_id='matrix2dWidget', + widget_id="matrix2dWidget", analysis_framework=framework, - key=element['id'], + key=element["id"], defaults={ - 'title': element['title'], - 'properties': { - 'overview_grid_layout': self.get_layout(element), - 'list_grid_layout': self.get_layout(element.get('list')), - 'data': { - 'dimensions': self.convert_matrix2d_dimensions( - element['pillars'] - ), - 'sectors': self.convert_matrix2d_sectors( - element['sectors'] - ), + "title": element["title"], + "properties": { + "overview_grid_layout": self.get_layout(element), + "list_grid_layout": self.get_layout(element.get("list")), + "data": { + "dimensions": self.convert_matrix2d_dimensions(element["pillars"]), + "sectors": self.convert_matrix2d_sectors(element["sectors"]), }, }, }, ) filter1, _ = Filter.objects.get_or_create( - key='{}-dimensions'.format(element['id']), + key="{}-dimensions".format(element["id"]), analysis_framework=framework, - widget_key=element['id'], + widget_key=element["id"], defaults={ - 'title': '{} Dimensions'.format(element['title']), - 'properties': { - 'type': 'multiselect', - 'options': self.convert_matrix2d_filter1( - element['pillars'] - ) - }, + "title": "{} Dimensions".format(element["title"]), + "properties": {"type": "multiselect", "options": self.convert_matrix2d_filter1(element["pillars"])}, }, ) filter2, _ = Filter.objects.get_or_create( - key='{}-sectors'.format(element['id']), + key="{}-sectors".format(element["id"]), analysis_framework=framework, - widget_key=element['id'], + widget_key=element["id"], defaults={ - 'title': '{} Sectors'.format(element['title']), - 'properties': { - 'type': 'multiselect', - 'options': self.convert_matrix2d_filter2( - element['sectors'] - ) - }, + "title": "{} Sectors".format(element["title"]), + "properties": {"type": "multiselect", "options": self.convert_matrix2d_filter2(element["sectors"])}, }, ) exportable, _ = Exportable.objects.get_or_create( - widget_key=element['id'], + widget_key=element["id"], analysis_framework=framework, defaults={ - 'data': self.convert_matrix2d_export(element), + "data": self.convert_matrix2d_export(element), }, ) def convert_matrix2d_dimensions(self, pillars): return [ { - 'id': pillar['id'], - 'title': pillar['title'], - 'color': pillar['color'], - 'tooltip': pillar['tooltip'], - 'subdimensions': self.convert_matrix2d_subdimensions( - pillar['subpillars'] - ) - } for pillar in pillars if pillar.get('id') + "id": pillar["id"], + "title": pillar["title"], + "color": pillar["color"], + "tooltip": pillar["tooltip"], + "subdimensions": self.convert_matrix2d_subdimensions(pillar["subpillars"]), + } + for pillar in pillars + if pillar.get("id") ] def convert_matrix2d_subdimensions(self, subpillars): return [ { - 'id': subpillar['id'], - 'title': subpillar['title'], - 'tooltip': subpillar['tooltip'], - } for subpillar in subpillars if subpillar.get('id') + "id": subpillar["id"], + "title": subpillar["title"], + "tooltip": subpillar["tooltip"], + } + for subpillar in subpillars + if subpillar.get("id") ] def convert_matrix2d_sectors(self, sectors): return [ - { - 'id': sector['id'], - 'title': sector['title'], - 'subsectors': self.convert_matrix2d_subsectors( - sector['subsectors'] - ) - } for sector in sectors if sector.get('id') + {"id": sector["id"], "title": sector["title"], "subsectors": self.convert_matrix2d_subsectors(sector["subsectors"])} + for sector in sectors + if sector.get("id") ] def convert_matrix2d_subsectors(self, subsectors): return [ { - 'id': subsector['id'], - 'title': subsector['title'], - } for subsector in subsectors if subsector.get('id') + "id": subsector["id"], + "title": subsector["title"], + } + for subsector in subsectors + if subsector.get("id") ] def convert_matrix2d_filter1(self, pillars): options = [] for pillar in pillars: - if not pillar.get('id'): + if not pillar.get("id"): continue - options.append({ - 'key': pillar['id'], - 'label': pillar['title'], - }) - - for subpillar in pillar['subpillars']: - options.append({ - 'key': subpillar['id'], - 'label': '{} / {}'.format( - pillar['title'], - subpillar['title'], - ), - }) + options.append( + { + "key": pillar["id"], + "label": pillar["title"], + } + ) + + for subpillar in pillar["subpillars"]: + options.append( + { + "key": subpillar["id"], + "label": "{} / {}".format( + pillar["title"], + subpillar["title"], + ), + } + ) return options def convert_matrix2d_filter2(self, sectors): options = [] for sector in sectors: - if not sector.get('id'): + if not sector.get("id"): continue - options.append({ - 'key': sector['id'], - 'label': sector['title'], - }) - - for subsector in sector['subsectors']: - options.append({ - 'key': subsector['id'], - 'label': '{} / {}'.format( - sector['title'], - subsector['title'], - ), - }) + options.append( + { + "key": sector["id"], + "label": sector["title"], + } + ) + + for subsector in sector["subsectors"]: + options.append( + { + "key": subsector["id"], + "label": "{} / {}".format( + sector["title"], + subsector["title"], + ), + } + ) return options def convert_matrix2d_export(self, element): excel = { - 'type': 'multiple', - 'titles': ['Dimension', 'Subdimension', 'Sector', 'Subsectors'], + "type": "multiple", + "titles": ["Dimension", "Subdimension", "Sector", "Subsectors"], } levels = [] - for sector in element['sectors']: - if not sector.get('id'): + for sector in element["sectors"]: + if not sector.get("id"): continue sublevels = [] - for pillar in element['pillars']: - if not pillar.get('id'): + for pillar in element["pillars"]: + if not pillar.get("id"): continue subsublevels = [] - for subpillar in pillar['subpillars']: - if not subpillar.get('id'): + for subpillar in pillar["subpillars"]: + if not subpillar.get("id"): continue - subsublevels.append({ - 'id': subpillar['id'], - 'title': subpillar['title'], - }) - - sublevels.append({ - 'id': pillar['id'], - 'title': pillar['title'], - 'sublevels': subsublevels, - }) - - levels.append({ - 'id': sector['id'], - 'title': sector['title'], - 'sublevels': sublevels, - }) + subsublevels.append( + { + "id": subpillar["id"], + "title": subpillar["title"], + } + ) + + sublevels.append( + { + "id": pillar["id"], + "title": pillar["title"], + "sublevels": subsublevels, + } + ) + + levels.append( + { + "id": sector["id"], + "title": sector["title"], + "sublevels": sublevels, + } + ) report = { - 'levels': levels, + "levels": levels, } return { - 'excel': excel, - 'report': report, + "excel": excel, + "report": report, } def get_layout(self, element): if not element: element = {} default_size = { - 'width': element.get('width') or 240, - 'height': element.get('height') or 240, + "width": element.get("width") or 240, + "height": element.get("height") or 240, } default_position = { - 'left': element.get('left') or 0, - 'top': element.get('top') or 0, + "left": element.get("left") or 0, + "top": element.get("top") or 0, } layout = { - **(element.get('size') or default_size), - **(element.get('position') or default_position), + **(element.get("size") or default_size), + **(element.get("position") or default_position), } - return { - key: snap(value) - for key, value - in layout.items() - } + return {key: snap(value) for key, value in layout.items()} diff --git a/apps/deep_migration/management/commands/migrate_entry.py b/apps/deep_migration/management/commands/migrate_entry.py index 00e0e0c0e9..1b729b8a57 100644 --- a/apps/deep_migration/management/commands/migrate_entry.py +++ b/apps/deep_migration/management/commands/migrate_entry.py @@ -1,34 +1,18 @@ import json +from datetime import datetime -from deep_migration.utils import ( - MigrationCommand, - get_source_url, - request_with_auth, -) - +import reversion +from analysis_framework.models import Exportable, Filter from deep_migration.models import ( AnalysisFrameworkMigration, - ProjectMigration, + CountryMigration, LeadMigration, + ProjectMigration, UserMigration, - CountryMigration, ) -from analysis_framework.models import ( - Filter, - Exportable, -) -from entry.models import ( - Entry, - Attribute, - FilterData, - ExportData, -) -from geo.models import Region, GeoArea - -from datetime import datetime - -import reversion - +from deep_migration.utils import MigrationCommand, get_source_url, request_with_auth +from entry.models import Attribute, Entry, ExportData, FilterData +from geo.models import GeoArea, Region ONE_DAY = 24 * 60 * 60 * 1000 @@ -49,52 +33,46 @@ def get_lead(lead_id): def get_analysis_framework(lead_id): - migration = AnalysisFrameworkMigration.objects.filter( - old_id=lead_id - ).first() + migration = AnalysisFrameworkMigration.objects.filter(old_id=lead_id).first() return migration and migration.analysis_framework def get_region(code): - migration = CountryMigration.objects.filter( - code=code - ).first() + migration = CountryMigration.objects.filter(code=code).first() return migration and migration.region class Command(MigrationCommand): def run(self): - if self.kwargs.get('data_file'): - with open(self.kwargs['data_file']) as f: + if self.kwargs.get("data_file"): + with open(self.kwargs["data_file"]) as f: entries = json.load(f) else: - entries = request_with_auth( - get_source_url('entries', query='template=1', version='v1') - ) + entries = request_with_auth(get_source_url("entries", query="template=1", version="v1")) if not entries: - print('Couldn\'t find entries data') + print("Couldn't find entries data") with reversion.create_revision(): for entry in entries: self.import_entry(entry) def import_entry(self, data): - print('------------') - print('Migrating entries') + print("------------") + print("Migrating entries") - lead_id = data['lead'] - print('For lead - {}'.format(lead_id)) + lead_id = data["lead"] + print("For lead - {}".format(lead_id)) lead = get_lead(lead_id) if not lead: - print('Lead not migrated yet') + print("Lead not migrated yet") return lead.entry_set.all().delete() framework = None - project_id = data['event'] + project_id = data["event"] if project_id: project = get_project(project_id) framework = project.analysis_framework @@ -105,48 +83,46 @@ def import_entry(self, data): self.regions = regions if not framework: - template_id = data['template'] + template_id = data["template"] if not template_id: - print('Not an entry with analysis framework') + print("Not an entry with analysis framework") return framework = get_analysis_framework(template_id) if not framework: - print('Analysis framework not migrated yet') + print("Analysis framework not migrated yet") return - print('Lead title: {}'.format(lead.title)) - informations = data['informations'] + print("Lead title: {}".format(lead.title)) + informations = data["informations"] for information in informations: self.import_information(data, lead, framework, information) def import_information(self, entry_data, lead, framework, data): - old_id = data['id'] - print('Entry info - {}'.format(old_id)) + old_id = data["id"] + print("Entry info - {}".format(old_id)) entry = Entry( lead=lead, analysis_framework=framework, ) - if data.get('excerpt'): - entry.excerpt = data['excerpt'] + if data.get("excerpt"): + entry.excerpt = data["excerpt"] entry.entry_type = Entry.TagType.EXCERPT - elif data.get('image'): - entry.image_raw = data['image'] + elif data.get("image"): + entry.image_raw = data["image"] entry.entry_type = Entry.TagType.IMAGE - entry.created_by = get_user(entry_data['created_by']) + entry.created_by = get_user(entry_data["created_by"]) entry.modified_by = entry.created_by entry.project = entry.lead.project entry.save() - Entry.objects.filter(id=entry.id).update( - created_at=entry_data['created_at'] - ) + Entry.objects.filter(id=entry.id).update(created_at=entry_data["created_at"]) # Start migrating the attributes - elements = data['elements'] + elements = data["elements"] # TODO migrate excerpt and image widget for element in elements: self.migrate_attribute(entry, framework, element) @@ -154,23 +130,23 @@ def import_information(self, entry_data, lead, framework, data): return entry def migrate_attribute(self, entry, framework, element): - print('Migrating element {}'.format(element['id'])) + print("Migrating element {}".format(element["id"])) - widget = framework.widget_set.filter(key=element['id']).first() + widget = framework.widget_set.filter(key=element["id"]).first() if not widget: - print('Widget not migrated yet') + print("Widget not migrated yet") return widget_method_map = { - 'numberWidget': self.migrate_number, - 'dateWidget': self.migrate_date, - 'scaleWidget': self.migrate_scale, - 'multiselectWidget': self.migrate_multiselect, - 'organigramWidget': self.migrate_organigram, - 'geoWidget': self.migrate_geo, - 'matrix1dWidget': self.migrate_matrix1d, - 'matrix2dWidget': self.migrate_matrix2d, - 'numberMatrixWidget': self.migrate_number_matrix, + "numberWidget": self.migrate_number, + "dateWidget": self.migrate_date, + "scaleWidget": self.migrate_scale, + "multiselectWidget": self.migrate_multiselect, + "organigramWidget": self.migrate_organigram, + "geoWidget": self.migrate_geo, + "matrix1dWidget": self.migrate_matrix1d, + "matrix2dWidget": self.migrate_matrix2d, + "numberMatrixWidget": self.migrate_number_matrix, } method = widget_method_map.get(widget.widget_id) @@ -182,12 +158,11 @@ def migrate_attribute_data(self, entry, widget, data): entry=entry, widget=widget, defaults={ - 'data': data, + "data": data, }, ) - def migrate_filter_data(self, entry, widget, index=0, - number=None, values=None): + def migrate_filter_data(self, entry, widget, index=0, number=None, values=None): filter = Filter.objects.filter( widget_key=widget.key, analysis_framework=widget.analysis_framework, @@ -196,8 +171,8 @@ def migrate_filter_data(self, entry, widget, index=0, entry=entry, filter=filter, defaults={ - 'number': number, - 'values': values, + "number": number, + "values": values, }, ) @@ -210,115 +185,152 @@ def migrate_export_data(self, entry, widget, data): entry=entry, exportable=exportable, defaults={ - 'data': data, + "data": data, }, ) def migrate_number(self, entry, widget, element): - value = element['value'] and int(element['value']) - self.migrate_attribute_data(entry, widget, { - 'value': value, - }) + value = element["value"] and int(element["value"]) + self.migrate_attribute_data( + entry, + widget, + { + "value": value, + }, + ) self.migrate_filter_data(entry, widget, number=value) - self.migrate_export_data(entry, widget, { - 'excel': { - 'value': str(value), - } - }) + self.migrate_export_data( + entry, + widget, + { + "excel": { + "value": str(value), + } + }, + ) def migrate_date(self, entry, widget, element): - value = element['value'] + value = element["value"] try: - date = datetime.strptime(value, '%Y-%m-%d') + date = datetime.strptime(value, "%Y-%m-%d") except Exception: - date = datetime.strptime(value, '%d-%m-%Y') + date = datetime.strptime(value, "%d-%m-%Y") - self.migrate_attribute_data(entry, widget, { - 'value': date and date.strftime('%Y-%m-%d'), - }) + self.migrate_attribute_data( + entry, + widget, + { + "value": date and date.strftime("%Y-%m-%d"), + }, + ) number = date and int(date.timestamp() / ONE_DAY) self.migrate_filter_data(entry, widget, number=number) - self.migrate_export_data(entry, widget, { - 'excel': { - 'value': date and date.strftime('%d-%m-%Y'), - } - }) + self.migrate_export_data( + entry, + widget, + { + "excel": { + "value": date and date.strftime("%d-%m-%Y"), + } + }, + ) def migrate_scale(self, entry, widget, element): - value = element.get('value') - self.migrate_attribute_data(entry, widget, { - 'selectedScale': value, - }) + value = element.get("value") + self.migrate_attribute_data( + entry, + widget, + { + "selectedScale": value, + }, + ) self.migrate_filter_data(entry, widget, values=[value]) - widget_data = widget.properties['data'] - scale_units = widget_data['scale_units'] - scale = next(( - s for s in scale_units - if s['key'] == value - ), None) - self.migrate_export_data(entry, widget, { - 'excel': { - 'value': scale['title'] if scale else '', - } - }) + widget_data = widget.properties["data"] + scale_units = widget_data["scale_units"] + scale = next((s for s in scale_units if s["key"] == value), None) + self.migrate_export_data( + entry, + widget, + { + "excel": { + "value": scale["title"] if scale else "", + } + }, + ) def migrate_multiselect(self, entry, widget, element): - value = element.get('value') or [] - self.migrate_attribute_data(entry, widget, { - 'value': value, - }) + value = element.get("value") or [] + self.migrate_attribute_data( + entry, + widget, + { + "value": value, + }, + ) self.migrate_filter_data(entry, widget, values=value) - widget_data = widget.properties['data'] - options = widget_data['options'] + widget_data = widget.properties["data"] + options = widget_data["options"] label_list = [] for item in value: - option = next(( - o for o in options - if o['key'] == item - ), None) - label_list.append(option['label']) - - self.migrate_export_data(entry, widget, { - 'excel': { - 'type': 'list', - 'value': label_list, - } - }) + option = next((o for o in options if o["key"] == item), None) + label_list.append(option["label"]) + + self.migrate_export_data( + entry, + widget, + { + "excel": { + "type": "list", + "value": label_list, + } + }, + ) def migrate_organigram(self, entry, widget, element): - value = element.get('value') or [] - widget_data = widget.properties['data'] + value = element.get("value") or [] + widget_data = widget.properties["data"] nodes = self.get_organigram_nodes([widget_data], value) - self.migrate_attribute_data(entry, widget, { - 'values': nodes, - }) + self.migrate_attribute_data( + entry, + widget, + { + "values": nodes, + }, + ) self.migrate_filter_data( - entry, widget, + entry, + widget, values=self.get_organigram_filter_data([widget_data], value), ) - self.migrate_export_data(entry, widget, { - 'excel': { - 'type': 'list', - 'value': [n['name'] for n in nodes], - } - }) + self.migrate_export_data( + entry, + widget, + { + "excel": { + "type": "list", + "value": [n["name"] for n in nodes], + } + }, + ) def get_organigram_nodes(self, organs, keys): nodes = [] for organ in organs: - if organ['key'] in keys: - nodes.append({ - 'id': organ['key'], - 'name': organ['title'], - }) - - children = self.get_organigram_nodes(organ['organs'], keys) + if organ["key"] in keys: + nodes.append( + { + "id": organ["key"], + "name": organ["title"], + } + ) + + children = self.get_organigram_nodes(organ["organs"], keys) nodes = nodes + children return nodes @@ -326,41 +338,51 @@ def get_organigram_nodes(self, organs, keys): def get_organigram_filter_data(self, organs, keys): filter_data = [] for organ in organs: - children = self.get_organigram_filter_data(organ['organs'], keys) - if children or organ['key'] in keys: - filter_data.append(organ['key']) + children = self.get_organigram_filter_data(organ["organs"], keys) + if children or organ["key"] in keys: + filter_data.append(organ["key"]) filter_data = filter_data + children return filter_data def migrate_geo(self, entry, widget, element): - areas = [self.get_geo_area(v) for v in element.get('value', [])] + areas = [self.get_geo_area(v) for v in element.get("value", [])] values = [ { - 'key': str(area.id), - 'short_label': area.get_label(), - 'label': area.get_label(), - } for area in areas if area + "key": str(area.id), + "short_label": area.get_label(), + "label": area.get_label(), + } + for area in areas + if area ] - keys = [str(v['key']) for v in values] + keys = [str(v["key"]) for v in values] - self.migrate_attribute_data(entry, widget, { - 'values': values, - }) + self.migrate_attribute_data( + entry, + widget, + { + "values": values, + }, + ) self.migrate_filter_data( entry, widget, values=keys, ) - self.migrate_export_data(entry, widget, { - 'excel': { - 'values': keys, - } - }) + self.migrate_export_data( + entry, + widget, + { + "excel": { + "values": keys, + } + }, + ) def get_geo_area(self, value): - splits = value.split(':') + splits = value.split(":") region_code = splits[0] admin_level = splits[1] area_title = splits[2] @@ -384,27 +406,28 @@ def get_geo_area(self, value): return areas.first() def migrate_number_matrix(self, entry, widget, element): - numbers = element.get('numbers') or [] + numbers = element.get("numbers") or [] attribute = self.get_number_matrix(numbers) self.migrate_attribute_data(entry, widget, attribute) - widget_data = widget.properties['data'] - rows = widget_data['row_headers'] - columns = widget_data['column_headers'] + widget_data = widget.properties["data"] + rows = widget_data["row_headers"] + columns = widget_data["column_headers"] self.migrate_export_data( - entry, widget, + entry, + widget, self.get_number_matrix_export_data(attribute, rows, columns), ) def get_number_matrix(self, numbers): attribute = {} for number in numbers: - row = number['row'] - column = number['column'] + row = number["row"] + column = number["column"] if row not in attribute: attribute[row] = {} - attribute[row][column] = int(number['value']) + attribute[row][column] = int(number["value"]) return attribute @@ -413,43 +436,41 @@ def get_number_matrix_export_data(self, attribute, rows, columns): for row in rows: row_values = [] for column in columns: - value = attribute.get(row['key'], {}).get(column['key'], None) + value = attribute.get(row["key"], {}).get(column["key"], None) if value is not None: excel_values.append(str(value)) row_values.append(int(value)) else: - excel_values.append('') + excel_values.append("") is_same = len(row_values) == 0 or len(set(row_values)) == 1 - excel_values.append('True' if is_same else 'False') + excel_values.append("True" if is_same else "False") - return { - 'excel': { - 'values': excel_values - } - } + return {"excel": {"values": excel_values}} def migrate_matrix1d(self, entry, widget, element): - selections = element.get('selections') or [] + selections = element.get("selections") or [] attribute = self.get_matrix1d_attribute(selections) self.migrate_attribute_data(entry, widget, attribute) - widget_data = widget.properties['data'] - rows = widget_data['rows'] + widget_data = widget.properties["data"] + rows = widget_data["rows"] self.migrate_filter_data( - entry, widget, + entry, + widget, values=self.get_matrix1d_filter_values(selections), ) self.migrate_export_data( - entry, widget, + entry, + widget, self.get_matrix1d_export_data(selections, rows), ) def get_matrix1d_attribute(self, selections): attribute = {} for selection in selections: - pillar = selection['pillar'] - subpillar = selection['subpillar'] + pillar = selection["pillar"] + subpillar = selection["subpillar"] if pillar not in attribute: attribute[pillar] = {} attribute[pillar][subpillar] = True @@ -458,8 +479,8 @@ def get_matrix1d_attribute(self, selections): def get_matrix1d_filter_values(self, selections): filter_values = [] for selection in selections: - pillar = selection['pillar'] - subpillar = selection['subpillar'] + pillar = selection["pillar"] + subpillar = selection["subpillar"] if pillar not in filter_values: filter_values.append(pillar) filter_values.append(subpillar) @@ -470,66 +491,66 @@ def get_matrix1d_export_data(self, selections, rows): report_values = [] for selection in selections: - row = next((r for r in rows if r['key'] == selection['pillar']), - None) + row = next((r for r in rows if r["key"] == selection["pillar"]), None) if not row: continue - cell = next((c for c in row['cells'] - if c['key'] == selection['subpillar']), - None) + cell = next((c for c in row["cells"] if c["key"] == selection["subpillar"]), None) if not cell: continue - excel_values.append([row['title'], cell['value']]) - report_values.append('{}-{}'.format( - row['key'], - cell['key'], - )) + excel_values.append([row["title"], cell["value"]]) + report_values.append( + "{}-{}".format( + row["key"], + cell["key"], + ) + ) return { - 'excel': { - 'type': 'lists', - 'values': excel_values, + "excel": { + "type": "lists", + "values": excel_values, }, - 'report': { - 'keys': report_values, + "report": { + "keys": report_values, }, } def migrate_matrix2d(self, entry, widget, element): - selections = element.get('selections') or [] + selections = element.get("selections") or [] attribute = self.get_matrix2d_attribute(selections) self.migrate_attribute_data(entry, widget, attribute) - widget_data = widget.properties['data'] + widget_data = widget.properties["data"] - filter_values1, filter_values2 = self.get_matrix2d_filter_values( - selections - ) + filter_values1, filter_values2 = self.get_matrix2d_filter_values(selections) self.migrate_filter_data( - entry, widget, + entry, + widget, index=0, values=filter_values1, ) self.migrate_filter_data( - entry, widget, + entry, + widget, index=1, values=filter_values2, ) self.migrate_export_data( - entry, widget, + entry, + widget, self.get_matrix2d_export_data(selections, widget_data), ) def get_matrix2d_attribute(self, selections): attribute = {} for selection in selections: - pillar = selection.get('pillar') - subpillar = selection.get('subpillar') - sector = selection.get('sector') + pillar = selection.get("pillar") + subpillar = selection.get("subpillar") + sector = selection.get("sector") if not pillar or not subpillar or not sector: continue - subsectors = selection.get('subsectors') or [] + subsectors = selection.get("subsectors") or [] if pillar not in attribute: attribute[pillar] = {} @@ -543,12 +564,12 @@ def get_matrix2d_filter_values(self, selections): filter_values2 = [] for selection in selections: - pillar = selection.get('pillar') - subpillar = selection.get('subpillar') - sector = selection.get('sector') + pillar = selection.get("pillar") + subpillar = selection.get("subpillar") + sector = selection.get("sector") if not pillar or not subpillar or not sector: continue - subsectors = selection.get('subsectors') or [] + subsectors = selection.get("subsectors") or [] if pillar not in filter_values1: filter_values1.append(pillar) @@ -562,55 +583,54 @@ def get_matrix2d_filter_values(self, selections): def get_matrix2d_export_data(self, selections, data): excel_values = [] report_values = [] - dimensions = data['dimensions'] - sectors = data['sectors'] + dimensions = data["dimensions"] + sectors = data["sectors"] for selection in selections: - pillar = selection.get('pillar') - subpillar = selection.get('subpillar') - sector = selection.get('sector') + pillar = selection.get("pillar") + subpillar = selection.get("subpillar") + sector = selection.get("sector") if not pillar or not subpillar or not sector: continue - dim = next((d for d in dimensions - if d['id'] == pillar), - None) + dim = next((d for d in dimensions if d["id"] == pillar), None) if not dim: continue - sub = next((s for s in dim['subdimensions'] - if s['id'] == subpillar), - None) + sub = next((s for s in dim["subdimensions"] if s["id"] == subpillar), None) if not sub: continue - sector = next((s for s in sectors - if s['id'] == sector), - None) + sector = next((s for s in sectors if s["id"] == sector), None) if not sector: continue subsector_names = [] - for subsector in selection.get('subsectors') or []: - ss = next((ss for ss in sector['subsectors'] - if ss['id'] == subsector), None) + for subsector in selection.get("subsectors") or []: + ss = next((ss for ss in sector["subsectors"] if ss["id"] == subsector), None) if ss: - subsector_names.append(ss['title']) - - excel_values.append([ - dim['title'], - sub['title'], - sector['title'], - ','.join(subsector_names), - ]) - report_values.append('{}-{}-{}'.format( - dim['id'], sub['id'], sector['id'], - )) + subsector_names.append(ss["title"]) + + excel_values.append( + [ + dim["title"], + sub["title"], + sector["title"], + ",".join(subsector_names), + ] + ) + report_values.append( + "{}-{}-{}".format( + dim["id"], + sub["id"], + sector["id"], + ) + ) return { - 'excel': { - 'type': 'lists', - 'values': excel_values, + "excel": { + "type": "lists", + "values": excel_values, }, - 'report': { - 'keys': report_values, + "report": { + "keys": report_values, }, } diff --git a/apps/deep_migration/management/commands/migrate_geo.py b/apps/deep_migration/management/commands/migrate_geo.py index c8a9d27296..7a289c0405 100644 --- a/apps/deep_migration/management/commands/migrate_geo.py +++ b/apps/deep_migration/management/commands/migrate_geo.py @@ -1,45 +1,36 @@ import requests - +from deep_migration.models import AdminLevelMigration, CountryMigration from deep_migration.utils import ( MigrationCommand, - get_source_url, get_migrated_gallery_file, + get_source_url, ) - -from geo.models import ( - Region, - AdminLevel, -) -from deep_migration.models import ( - CountryMigration, - AdminLevelMigration, -) - +from geo.models import AdminLevel, Region from geo.tasks import load_geo_areas class Command(MigrationCommand): def run(self): - data = requests.get(get_source_url('countries')).json() + data = requests.get(get_source_url("countries")).json() - if not data or not data.get('data'): - print('Couldn\'t find countries data') + if not data or not data.get("data"): + print("Couldn't find countries data") return - countries = data['data'] + countries = data["data"] for country in countries: self.import_country(country) def import_country(self, country): - print('------------') - print('Migrating country') + print("------------") + print("Migrating country") - code = country['reference_code'] - modified_code = country['code'] - title = country['name'] - print('{} - {}'.format(code, title)) + code = country["reference_code"] + modified_code = country["code"] + title = country["name"] + print("{} - {}".format(code, title)) - public = (code == modified_code) + public = code == modified_code migration, _ = CountryMigration.objects.get_or_create( code=modified_code, @@ -59,14 +50,14 @@ def import_country(self, country): region.title = title region.public = public - region.regional_groups = country['regions'] - region.key_figures = country['key_figures'] - region.media_sources = country['media_sources'] + region.regional_groups = country["regions"] + region.key_figures = country["key_figures"] + region.media_sources = country["media_sources"] region.save() - admin_levels = country['admin_levels'] - admin_levels.sort(key=lambda a: a['level']) + admin_levels = country["admin_levels"] + admin_levels.sort(key=lambda a: a["level"]) parent = None for admin_level in admin_levels: parent = self.import_admin_level(region, parent, admin_level) @@ -75,18 +66,15 @@ def import_country(self, country): return region def import_admin_level(self, region, parent, data): - print('Migrating admin level') + print("Migrating admin level") - old_id = data['id'] - title = data['name'] - level = data['level'] + old_id = data["id"] + title = data["name"] + level = data["level"] - print('{} - {}'.format(data['id'], - data['name'])) + print("{} - {}".format(data["id"], data["name"])) - migration, _ = AdminLevelMigration.objects.get_or_create( - old_id=old_id - ) + migration, _ = AdminLevelMigration.objects.get_or_create(old_id=old_id) if not migration.admin_level: admin_level = AdminLevel.objects.create( @@ -102,11 +90,11 @@ def import_admin_level(self, region, parent, data): admin_level = migration.admin_level admin_level.parent = parent admin_level.level = level - admin_level.name_prop = data['property_name'] - admin_level.code_prop = data['property_pcode'] + admin_level.name_prop = data["property_name"] + admin_level.code_prop = data["property_pcode"] if level > 0: - admin_level.parent_name_prop = 'NAME_{}'.format(level - 1) - admin_level.geo_shape_file = get_migrated_gallery_file(data['geojson']) + admin_level.parent_name_prop = "NAME_{}".format(level - 1) + admin_level.geo_shape_file = get_migrated_gallery_file(data["geojson"]) admin_level.stale_geo_areas = True admin_level.save() diff --git a/apps/deep_migration/management/commands/migrate_lead.py b/apps/deep_migration/management/commands/migrate_lead.py index 4eec5c8891..b2062f38d8 100644 --- a/apps/deep_migration/management/commands/migrate_lead.py +++ b/apps/deep_migration/management/commands/migrate_lead.py @@ -1,21 +1,15 @@ import json +import reversion +from deep_migration.models import LeadMigration, ProjectMigration, UserMigration from deep_migration.utils import ( MigrationCommand, + get_migrated_gallery_file, get_source_url, request_with_auth, - get_migrated_gallery_file, -) - -from deep_migration.models import ( - LeadMigration, - ProjectMigration, - UserMigration, ) -from lead.models import Lead - from django.utils.dateparse import parse_date -import reversion +from lead.models import Lead def get_user(old_user_id): @@ -29,50 +23,50 @@ def get_project(project_id): CONFIDENTIALITY_MAP = { - 'UNP': Lead.Confidentiality.UNPROTECTED, - 'PRO': Lead.Confidentiality.PROTECTED, - 'RES': Lead.Confidentiality.RESTRICTED, - 'CON': Lead.Confidentiality.CONFIDENTIAL, + "UNP": Lead.Confidentiality.UNPROTECTED, + "PRO": Lead.Confidentiality.PROTECTED, + "RES": Lead.Confidentiality.RESTRICTED, + "CON": Lead.Confidentiality.CONFIDENTIAL, } STATUS_MAP = { - 'PEN': Lead.Status.NOT_TAGGED, - 'PRO': Lead.Status.PROTECTED, + "PEN": Lead.Status.NOT_TAGGED, + "PRO": Lead.Status.PROTECTED, } class Command(MigrationCommand): def run(self): - if self.kwargs.get('data_file'): - with open(self.kwargs['data_file']) as f: + if self.kwargs.get("data_file"): + with open(self.kwargs["data_file"]) as f: leads = json.load(f) else: - data = request_with_auth(get_source_url('leads')) + data = request_with_auth(get_source_url("leads")) - if not data or not data.get('data'): - print('Couldn\'t find leads data') + if not data or not data.get("data"): + print("Couldn't find leads data") - leads = data['data'] + leads = data["data"] with reversion.create_revision(): for lead in leads: self.import_lead(lead) def import_lead(self, data): - print('------------') - print('Migrating lead') + print("------------") + print("Migrating lead") - old_id = data['id'] - title = data['name'] - project_id = data['event'] + old_id = data["id"] + title = data["name"] + project_id = data["event"] project = get_project(project_id) if not project: - print('Project with old id: {} doesn\'t exist'.format(project_id)) + print("Project with old id: {} doesn't exist".format(project_id)) return None - print('{} - {}'.format(old_id, title)) + print("{} - {}".format(old_id, title)) migration, _ = LeadMigration.objects.get_or_create( old_id=old_id, @@ -89,34 +83,31 @@ def import_lead(self, data): lead = migration.lead lead.title = title - lead.source = data['source'] or '' - lead.confidentiality = CONFIDENTIALITY_MAP[data['confidentiality']] - lead.status = STATUS_MAP[data['status']] + lead.source = data["source"] or "" + lead.confidentiality = CONFIDENTIALITY_MAP[data["confidentiality"]] + lead.status = STATUS_MAP[data["status"]] - lead.published_on = data['published_at'] and \ - parse_date(data['published_at']) - lead.created_by = get_user(data['created_by']) + lead.published_on = data["published_at"] and parse_date(data["published_at"]) + lead.created_by = get_user(data["created_by"]) lead.modified_by = lead.created_by - if data.get('description'): + if data.get("description"): lead.source_type = Lead.SourceType.TEXT - lead.text = data['description'] + lead.text = data["description"] - elif data.get('url'): + elif data.get("url"): lead.source_type = Lead.SourceType.WEBSITE - lead.url = data['url'] + lead.url = data["url"] - elif data.get('attachment'): + elif data.get("attachment"): lead.source_type = Lead.SourceType.DISK - lead.attachment = get_migrated_gallery_file( - data['attachment']['url'] - ) + lead.attachment = get_migrated_gallery_file(data["attachment"]["url"]) lead.save() - if data.get('assigned_to'): - lead.assignee.add(get_user(data.get('assigned_to'))) + if data.get("assigned_to"): + lead.assignee.add(get_user(data.get("assigned_to"))) - Lead.objects.filter(id=lead.id).update(created_at=data['created_at']) + Lead.objects.filter(id=lead.id).update(created_at=data["created_at"]) return lead diff --git a/apps/deep_migration/management/commands/migrate_project.py b/apps/deep_migration/management/commands/migrate_project.py index 6bb840a92a..282c336850 100644 --- a/apps/deep_migration/management/commands/migrate_project.py +++ b/apps/deep_migration/management/commands/migrate_project.py @@ -1,19 +1,10 @@ import json -from deep_migration.utils import ( - MigrationCommand, - get_source_url, - request_with_auth, -) -from deep_migration.models import ( - CountryMigration, - ProjectMigration, - UserMigration, -) -from project.models import Project, ProjectMembership, ProjectRole - -from django.utils.dateparse import parse_date import reversion +from deep_migration.models import CountryMigration, ProjectMigration, UserMigration +from deep_migration.utils import MigrationCommand, get_source_url, request_with_auth +from django.utils.dateparse import parse_date +from project.models import Project, ProjectMembership, ProjectRole def get_user(old_user_id): @@ -28,27 +19,27 @@ def get_region(reference_code): class Command(MigrationCommand): def run(self): - if self.kwargs.get('data_file'): - with open(self.kwargs['data_file']) as f: + if self.kwargs.get("data_file"): + with open(self.kwargs["data_file"]) as f: projects = json.load(f) else: - projects = request_with_auth(get_source_url('events2', 'v1')) + projects = request_with_auth(get_source_url("events2", "v1")) if not projects: - print('Couldn\'t find projects data') + print("Couldn't find projects data") with reversion.create_revision(): for project in projects: self.import_project(project) def import_project(self, data): - print('------------') - print('Migrating project') + print("------------") + print("Migrating project") - old_id = data['id'] - title = data['name'] + old_id = data["id"] + title = data["name"] - print('{} - {}'.format(old_id, title)) + print("{} - {}".format(old_id, title)) migration, _ = ProjectMigration.objects.get_or_create( old_id=old_id, @@ -63,31 +54,29 @@ def import_project(self, data): return migration.project project = migration.project - project.start_date = data['start_date'] and \ - parse_date(data['start_date']) - project.end_date = data['end_date'] and \ - parse_date(data['end_date']) + project.start_date = data["start_date"] and parse_date(data["start_date"]) + project.end_date = data["end_date"] and parse_date(data["end_date"]) project.save() - for user_id in data['admins']: + for user_id in data["admins"]: user = get_user(user_id) if user: ProjectMembership.objects.get_or_create( project=project, member=user, - defaults={'role': ProjectRole.get_admin_roles().first()}, + defaults={"role": ProjectRole.get_admin_roles().first()}, ) - for user_id in data['members']: + for user_id in data["members"]: user = get_user(user_id) if user: ProjectMembership.objects.get_or_create( project=project, member=user, - defaults={'role': ProjectRole.get_default_role()}, + defaults={"role": ProjectRole.get_default_role()}, ) - for region_code in data['countries']: + for region_code in data["countries"]: region = get_region(region_code) if region and region not in project.regions.all(): project.regions.add(region) diff --git a/apps/deep_migration/management/commands/migrate_user.py b/apps/deep_migration/management/commands/migrate_user.py index 8de9471654..e2860f3d99 100644 --- a/apps/deep_migration/management/commands/migrate_user.py +++ b/apps/deep_migration/management/commands/migrate_user.py @@ -1,44 +1,42 @@ import json -from django.contrib.auth.models import User import requests - +from deep_migration.models import UserMigration from deep_migration.utils import ( MigrationCommand, - get_source_url, get_migrated_gallery_file, + get_source_url, ) - -from deep_migration.models import UserMigration +from django.contrib.auth.models import User class Command(MigrationCommand): def run(self): - if self.kwargs.get('data_file'): - with open(self.kwargs['data_file']) as f: + if self.kwargs.get("data_file"): + with open(self.kwargs["data_file"]) as f: data = json.load(f) else: - data = requests.get(get_source_url('users2', 'v1')).json() + data = requests.get(get_source_url("users2", "v1")).json() if not data: - print('Couldn\'t find users data') + print("Couldn't find users data") return for user in data: self.import_user(user) def import_user(self, data): - print('------------') - print('Migrating user') + print("------------") + print("Migrating user") - old_id = data['id'] - username = data['username'] - email = data['email'] + old_id = data["id"] + username = data["username"] + email = data["email"] - first_name = data['first_name'] - last_name = data['last_name'] + first_name = data["first_name"] + last_name = data["last_name"] - print('{} - {} {}'.format(old_id, first_name, last_name)) + print("{} - {} {}".format(old_id, first_name, last_name)) migration, _ = UserMigration.objects.get_or_create( old_id=old_id, @@ -60,11 +58,9 @@ def import_user(self, data): user.first_name = first_name user.last_name = last_name - user.profile.organization = data['organization'] - user.profile.display_picture = get_migrated_gallery_file( - data['photo'] - ) - user.profile.hid = data['hid'] + user.profile.organization = data["organization"] + user.profile.display_picture = get_migrated_gallery_file(data["photo"]) + user.profile.hid = data["hid"] user.save() return user diff --git a/apps/deep_migration/management/commands/migrate_user_group.py b/apps/deep_migration/management/commands/migrate_user_group.py index 14372e1959..bee4496aef 100644 --- a/apps/deep_migration/management/commands/migrate_user_group.py +++ b/apps/deep_migration/management/commands/migrate_user_group.py @@ -1,23 +1,14 @@ import json +import reversion +from deep_migration.models import ProjectMigration, UserGroupMigration, UserMigration from deep_migration.utils import ( MigrationCommand, + get_migrated_gallery_file, get_source_url, request_with_auth, - get_migrated_gallery_file, -) - -from deep_migration.models import ( - UserGroupMigration, - UserMigration, - ProjectMigration, ) -from user_group.models import ( - UserGroup, - GroupMembership, -) - -import reversion +from user_group.models import GroupMembership, UserGroup def get_user(old_user_id): @@ -32,29 +23,27 @@ def get_project(project_id): class Command(MigrationCommand): def run(self): - if self.kwargs.get('data_file'): - with open(self.kwargs['data_file']) as f: + if self.kwargs.get("data_file"): + with open(self.kwargs["data_file"]) as f: user_groups = json.load(f) else: - user_groups = request_with_auth( - get_source_url('user-groups', 'v1') - ) + user_groups = request_with_auth(get_source_url("user-groups", "v1")) if not user_groups: - print('Couldn\'t find user groups data') + print("Couldn't find user groups data") with reversion.create_revision(): for user_group in user_groups: self.import_user_group(user_group) def import_user_group(self, data): - print('------------') - print('Migrating user group') + print("------------") + print("Migrating user group") - old_id = data['id'] - title = data['name'] + old_id = data["id"] + title = data["name"] - print('{} - {}'.format(old_id, title)) + print("{} - {}".format(old_id, title)) migration, _ = UserGroupMigration.objects.get_or_create( old_id=old_id, @@ -69,32 +58,30 @@ def import_user_group(self, data): return migration.user_group user_group = migration.user_group - user_group.description = data['description'] - user_group.display_picture = get_migrated_gallery_file( - data['photo'] - ) - user_group.global_crisis_monitoring = data['acaps'] + user_group.description = data["description"] + user_group.display_picture = get_migrated_gallery_file(data["photo"]) + user_group.global_crisis_monitoring = data["acaps"] user_group.save() - for user_id in data['admins']: + for user_id in data["admins"]: user = get_user(user_id) if user: GroupMembership.objects.get_or_create( group=user_group, member=user, - defaults={'role': 'admin'}, + defaults={"role": "admin"}, ) - for user_id in data['members']: + for user_id in data["members"]: user = get_user(user_id) if user: GroupMembership.objects.get_or_create( group=user_group, member=user, - defaults={'role': 'normal'}, + defaults={"role": "normal"}, ) - for project_id in data['projects']: + for project_id in data["projects"]: project = get_project(project_id) if project: project.user_groups.add(user_group) diff --git a/apps/deep_migration/models.py b/apps/deep_migration/models.py index f767782931..0a4488e5f9 100644 --- a/apps/deep_migration/models.py +++ b/apps/deep_migration/models.py @@ -1,13 +1,10 @@ +from analysis_framework.models import AnalysisFramework from django.contrib.auth.models import User from django.db import models -from geo.models import ( - Region, - AdminLevel, -) -from user_group.models import UserGroup -from project.models import Project +from geo.models import AdminLevel, Region from lead.models import Lead -from analysis_framework.models import AnalysisFramework +from project.models import Project +from user_group.models import UserGroup class BaseMigration(models.Model): @@ -16,60 +13,81 @@ class BaseMigration(models.Model): class Meta: abstract = True - ordering = ['-first_migrated_at'] + ordering = ["-first_migrated_at"] class UserMigration(models.Model): old_id = models.IntegerField(unique=True) user = models.ForeignKey( - User, on_delete=models.CASCADE, - default=None, blank=True, null=True, + User, + on_delete=models.CASCADE, + default=None, + blank=True, + null=True, ) class CountryMigration(BaseMigration): code = models.CharField(max_length=50, unique=True) region = models.ForeignKey( - Region, on_delete=models.CASCADE, - default=None, blank=True, null=True, + Region, + on_delete=models.CASCADE, + default=None, + blank=True, + null=True, ) class AdminLevelMigration(BaseMigration): old_id = models.IntegerField(unique=True) admin_level = models.ForeignKey( - AdminLevel, on_delete=models.CASCADE, - default=None, blank=True, null=True, + AdminLevel, + on_delete=models.CASCADE, + default=None, + blank=True, + null=True, ) class ProjectMigration(BaseMigration): old_id = models.IntegerField(unique=True) project = models.ForeignKey( - Project, on_delete=models.CASCADE, - default=None, blank=True, null=True, + Project, + on_delete=models.CASCADE, + default=None, + blank=True, + null=True, ) class LeadMigration(BaseMigration): old_id = models.IntegerField(unique=True) lead = models.ForeignKey( - Lead, on_delete=models.CASCADE, - default=None, blank=True, null=True, + Lead, + on_delete=models.CASCADE, + default=None, + blank=True, + null=True, ) class AnalysisFrameworkMigration(BaseMigration): old_id = models.IntegerField(unique=True) analysis_framework = models.ForeignKey( - AnalysisFramework, on_delete=models.CASCADE, - default=None, blank=True, null=True, + AnalysisFramework, + on_delete=models.CASCADE, + default=None, + blank=True, + null=True, ) class UserGroupMigration(BaseMigration): old_id = models.IntegerField(unique=True) user_group = models.ForeignKey( - UserGroup, on_delete=models.CASCADE, - default=None, blank=True, null=True, + UserGroup, + on_delete=models.CASCADE, + default=None, + blank=True, + null=True, ) diff --git a/apps/deep_migration/utils.py b/apps/deep_migration/utils.py index feb1cea144..e3217a677e 100644 --- a/apps/deep_migration/utils.py +++ b/apps/deep_migration/utils.py @@ -1,48 +1,39 @@ -from requests.auth import HTTPBasicAuth -from urllib.parse import urlparse, unquote -from django.core.management.base import BaseCommand - -import requests import os +from urllib.parse import unquote, urlparse +import requests +from django.core.management.base import BaseCommand from gallery.models import File +from requests.auth import HTTPBasicAuth class MigrationCommand(BaseCommand): def add_arguments(self, parser): parser.add_argument( - '--url', - dest='DEEP_1_URL', - ) - parser.add_argument( - '--user', - dest='DEEP_1_USER', - ) - parser.add_argument( - '--password', - dest='DEEP_1_PASSWORD', + "--url", + dest="DEEP_1_URL", ) parser.add_argument( - '--use_s3', - dest='DJANGO_USE_S3', + "--user", + dest="DEEP_1_USER", ) parser.add_argument( - '--query_str', - type=str, - default='' + "--password", + dest="DEEP_1_PASSWORD", ) parser.add_argument( - '--data_file', - type=str, - default=None + "--use_s3", + dest="DJANGO_USE_S3", ) + parser.add_argument("--query_str", type=str, default="") + parser.add_argument("--data_file", type=str, default=None) def handle(self, *args, **kwargs): valid_keys = [ - 'DEEP_1_URL', - 'DEEP_1_USER', - 'DEEP_1_PASSWORD', - 'DJANGO_USE_S3', + "DEEP_1_URL", + "DEEP_1_USER", + "DEEP_1_PASSWORD", + "DJANGO_USE_S3", ] self.kwargs = kwargs for key, value in kwargs.items(): @@ -60,19 +51,19 @@ def run(self): # PASSWORD = os.environ.get('DEEP_1_PASSWORD', 'admin123') -def get_source_url(suffix, version='v2', query=''): - BASE_URL = os.environ.get('DEEP_1_URL', 'http://172.21.0.1:9000') - return '{}/api/{}/{}/?{}'.format(BASE_URL, version, suffix, query) +def get_source_url(suffix, version="v2", query=""): + BASE_URL = os.environ.get("DEEP_1_URL", "http://172.21.0.1:9000") + return "{}/api/{}/{}/?{}".format(BASE_URL, version, suffix, query) def get_migrated_s3_key(s3_url): url_data = urlparse(s3_url) - old_key = unquote(url_data.path)[len('/media/'):] - return 'deep-v1/{}'.format(old_key) + old_key = unquote(url_data.path)[len("/media/") :] + return "deep-v1/{}".format(old_key) def is_using_s3(): - return os.environ.get('DJANGO_USE_S3', 'False').lower() == 'true' + return os.environ.get("DJANGO_USE_S3", "False").lower() == "true" def get_migrated_gallery_file(s3_url, title=None): @@ -83,19 +74,19 @@ def get_migrated_gallery_file(s3_url, title=None): key = get_migrated_s3_key(s3_url) if not title: - title = key.split('/')[-1] + title = key.split("/")[-1] gallery_file, _ = File.objects.get_or_create( file=key, defaults={ - 'title': title, - } + "title": title, + }, ) return gallery_file def request_with_auth(url): - USERNAME = os.environ.get('DEEP_1_USER', 'test@toggle.com') - PASSWORD = os.environ.get('DEEP_1_PASSWORD', 'admin123') + USERNAME = os.environ.get("DEEP_1_USER", "test@toggle.com") + PASSWORD = os.environ.get("DEEP_1_PASSWORD", "admin123") return requests.get(url, auth=HTTPBasicAuth(USERNAME, PASSWORD)).json() diff --git a/apps/deepl_integration/apps.py b/apps/deepl_integration/apps.py index 6154a083d1..867de2881c 100644 --- a/apps/deepl_integration/apps.py +++ b/apps/deepl_integration/apps.py @@ -2,4 +2,4 @@ class DeeplIntegrationConfig(AppConfig): - name = 'deepl_integration' + name = "deepl_integration" diff --git a/apps/deepl_integration/handlers.py b/apps/deepl_integration/handlers.py index 031a5f3d63..5260b59c97 100644 --- a/apps/deepl_integration/handlers.py +++ b/apps/deepl_integration/handlers.py @@ -1,59 +1,53 @@ -import os -import json import copy -import requests +import json import logging -from typing import List, Type +import os from functools import reduce +from typing import List, Type from urllib.parse import urlparse -from django.conf import settings -from django.urls import reverse -from django.utils.encoding import DjangoUnicodeDecodeError -from django.utils import timezone -from django.core.paginator import Paginator -from django.db import transaction, models -from rest_framework import serializers - -from deep.token import DeepTokenGenerator -from deep.deepl import DeeplServiceEndpoint -from utils.common import UidBase64Helper, get_full_media_url -from utils.request import RequestHelper -from deep.exceptions import DeepBaseException - +import requests +from analysis.models import ( + AnalyticalStatementGeoEntry, + AnalyticalStatementGeoTask, + AnalyticalStatementNGram, + AutomaticSummary, + TopicModel, + TopicModelCluster, +) from assisted_tagging.models import ( - DraftEntry, AssistedTaggingModel, - AssistedTaggingModelVersion, AssistedTaggingModelPredictionTag, + AssistedTaggingModelVersion, AssistedTaggingPrediction, + DraftEntry, ) +from django.conf import settings +from django.core.paginator import Paginator +from django.db import models, transaction +from django.urls import reverse +from django.utils import timezone +from django.utils.encoding import DjangoUnicodeDecodeError +from entry.models import Entry +from geo.filter_set import GeoAreaGqlFilterSet +from geo.models import GeoArea +from lead.models import Lead, LeadPreview, LeadPreviewImage +from lead.typings import NlpExtractorDocument +from rest_framework import serializers from unified_connector.models import ( ConnectorLead, ConnectorLeadPreviewImage, ConnectorSource, UnifiedConnector, ) -from lead.models import ( - Lead, - LeadPreview, - LeadPreviewImage, -) -from lead.typings import NlpExtractorDocument -from entry.models import Entry -from analysis.models import ( - TopicModel, - TopicModelCluster, - AutomaticSummary, - AnalyticalStatementNGram, - AnalyticalStatementGeoTask, - AnalyticalStatementGeoEntry, -) -from geo.models import GeoArea -from geo.filter_set import GeoAreaGqlFilterSet -from .models import DeeplTrackBaseModel +from deep.deepl import DeeplServiceEndpoint +from deep.exceptions import DeepBaseException +from deep.token import DeepTokenGenerator +from utils.common import UidBase64Helper, get_full_media_url +from utils.request import RequestHelper +from .models import DeeplTrackBaseModel logger = logging.getLogger(__name__) @@ -74,10 +68,10 @@ def generate_file_url_for_new_deepl_server(file): def custom_error_handler(exception, url=None): if isinstance(exception, requests.exceptions.ConnectionError): - raise serializers.ValidationError(f'ConnectionError on provided file: {url}') + raise serializers.ValidationError(f"ConnectionError on provided file: {url}") if isinstance(exception, json.decoder.JSONDecodeError): - raise serializers.ValidationError(f'Failed to parse provided json file: {url}') - raise serializers.ValidationError(f'Failed to handle the provided file: : {url}') + raise serializers.ValidationError(f"Failed to parse provided json file: {url}") + raise serializers.ValidationError(f"Failed to handle the provided file: : {url}") class DefaultClientIdGenerator(DeepTokenGenerator): @@ -94,8 +88,8 @@ class NlpRequestType: class BaseHandler: REQUEST_HEADERS = { - 'Content-Type': 'application/json', - 'Authorization': f'Token {settings.DEEPL_SERVER_TOKEN}', + "Content-Type": "application/json", + "Authorization": f"Token {settings.DEEPL_SERVER_TOKEN}", } # --- Override @@ -108,31 +102,29 @@ class BaseHandler: class Exception: class InvalidTokenValue(DeepBaseException): - default_message = 'Invalid Token' + default_message = "Invalid Token" class InvalidOrExpiredToken(DeepBaseException): - default_message = 'Invalid/expired token in client_id' + default_message = "Invalid/expired token in client_id" class ObjectNotFound(DeepBaseException): - default_message = 'No draft entry found for provided id' + default_message = "No draft entry found for provided id" @classmethod def get_callback_url(cls, **kwargs): - return ( - settings.DEEPL_SERVICE_CALLBACK_DOMAIN + - reverse( - cls.callback_url_name, kwargs={ - 'version': 'v1', - **kwargs, - }, - ) + return settings.DEEPL_SERVICE_CALLBACK_DOMAIN + reverse( + cls.callback_url_name, + kwargs={ + "version": "v1", + **kwargs, + }, ) @classmethod def get_client_id(cls, instance: models.Model) -> str: uid = UidBase64Helper.encode(instance.pk) token = cls.client_id_generator.make_token(instance) - return f'{uid}-{token}' + return f"{uid}-{token}" @classmethod def get_object_using_client_id(cls, client_id): @@ -141,69 +133,64 @@ def get_object_using_client_id(cls, client_id): - Raise error if invalid/404/expired """ try: - uidb64, token = client_id.split('-', 1) + uidb64, token = client_id.split("-", 1) uid = UidBase64Helper.decode(uidb64) except (ValueError, DjangoUnicodeDecodeError): raise cls.Exception.InvalidTokenValue() if (instance := cls.model.objects.filter(id=uid).first()) is None: - raise cls.Exception.ObjectNotFound(f'No {cls.model.__name__} found for provided id: {uid}') + raise cls.Exception.ObjectNotFound(f"No {cls.model.__name__} found for provided id: {uid}") if not cls.client_id_generator.check_token(instance, token): raise cls.Exception.InvalidOrExpiredToken() return instance @classmethod def send_trigger_request_to_extractor(cls, *_): - raise Exception('Not implemented yet.') + raise Exception("Not implemented yet.") @classmethod def save_data(cls, *_): - raise Exception('Not implemented yet.') + raise Exception("Not implemented yet.") class AssistedTaggingDraftEntryHandler(BaseHandler): model = DraftEntry - callback_url_name = 'assisted_tagging_draft_entry_prediction_callback' + callback_url_name = "assisted_tagging_draft_entry_prediction_callback" @classmethod def send_trigger_request_to_extractor(cls, draft_entry): source_organization = draft_entry.lead.source - author_organizations = [ - author.data.title - for author in draft_entry.lead.authors.all() - ] + author_organizations = [author.data.title for author in draft_entry.lead.authors.all()] payload = { - 'entries': [ + "entries": [ { - 'client_id': cls.get_client_id(draft_entry), - 'entry': draft_entry.excerpt, + "client_id": cls.get_client_id(draft_entry), + "entry": draft_entry.excerpt, } ], - 'lead': draft_entry.lead_id, - 'project': draft_entry.project_id, - 'publishing_organization': source_organization and source_organization.data.title, - 'authoring_organization': author_organizations, - 'callback_url': cls.get_callback_url(), + "lead": draft_entry.lead_id, + "project": draft_entry.project_id, + "publishing_organization": source_organization and source_organization.data.title, + "authoring_organization": author_organizations, + "callback_url": cls.get_callback_url(), } response_content = None try: response = requests.post( - DeeplServiceEndpoint.ASSISTED_TAGGING_ENTRY_PREDICT_ENDPOINT, - headers=cls.REQUEST_HEADERS, - json=payload + DeeplServiceEndpoint.ASSISTED_TAGGING_ENTRY_PREDICT_ENDPOINT, headers=cls.REQUEST_HEADERS, json=payload ) response_content = response.content if response.status_code == 202: return True except Exception: - logger.error('Assisted tagging send failed, Exception occurred!!', exc_info=True) + logger.error("Assisted tagging send failed, Exception occurred!!", exc_info=True) draft_entry.prediction_status = DraftEntry.PredictionStatus.SEND_FAILED - draft_entry.save(update_fields=('prediction_status',)) + draft_entry.save(update_fields=("prediction_status",)) logger.error( - 'Assisted tagging send failed!!', + "Assisted tagging send failed!!", extra={ - 'data': { - 'payload': payload, - 'response': response_content, + "data": { + "payload": payload, + "response": response_content, }, }, ) @@ -219,35 +206,37 @@ def get_versions_map(): lambda acc, item: acc | item, [ models.Q( - model__model_id=model_data['id'], - version=model_data['version'], + model__model_id=model_data["id"], + version=model_data["version"], ) for model_data in models_data ], ) - ).select_related('model').all() + ) + .select_related("model") + .all() } existing_model_versions = get_versions_map() new_model_versions = [ - model_data - for model_data in models_data - if (model_data['id'], model_data['version']) not in existing_model_versions + model_data for model_data in models_data if (model_data["id"], model_data["version"]) not in existing_model_versions ] if new_model_versions: - AssistedTaggingModelVersion.objects.bulk_create([ - AssistedTaggingModelVersion( - model=AssistedTaggingModel.objects.get_or_create( - model_id=model_data['id'], - defaults=dict( - name=model_data['id'], - ), - )[0], - version=model_data['version'], - ) - for model_data in models_data - ]) + AssistedTaggingModelVersion.objects.bulk_create( + [ + AssistedTaggingModelVersion( + model=AssistedTaggingModel.objects.get_or_create( + model_id=model_data["id"], + defaults=dict( + name=model_data["id"], + ), + )[0], + version=model_data["version"], + ) + for model_data in models_data + ] + ) existing_model_versions = get_versions_map() return existing_model_versions @@ -256,27 +245,22 @@ def _get_or_create_tags_map(cls, tags): from assisted_tagging.tasks import sync_tags_with_deepl_task def get_tags_map(): - return { - tag_id: _id - for _id, tag_id in AssistedTaggingModelPredictionTag.objects.values_list('id', 'tag_id') - } + return {tag_id: _id for _id, tag_id in AssistedTaggingModelPredictionTag.objects.values_list("id", "tag_id")} current_tags_map = get_tags_map() # Check if new tags needs to be created - new_tags = [ - tag - for tag in tags - if tag not in current_tags_map - ] + new_tags = [tag for tag in tags if tag not in current_tags_map] if new_tags: # Create new tags - AssistedTaggingModelPredictionTag.objects.bulk_create([ - AssistedTaggingModelPredictionTag( - name=new_tag, - tag_id=new_tag, - ) - for new_tag in new_tags - ]) + AssistedTaggingModelPredictionTag.objects.bulk_create( + [ + AssistedTaggingModelPredictionTag( + name=new_tag, + tag_id=new_tag, + ) + for new_tag in new_tags + ] + ) # Refetch current_tags_map = get_tags_map() sync_tags_with_deepl_task.delay() @@ -284,12 +268,12 @@ def get_tags_map(): @classmethod def _process_model_preds(cls, model_version, current_tags_map, draft_entry, model_prediction): - prediction_status = model_prediction['prediction_status'] + prediction_status = model_prediction["prediction_status"] if not prediction_status: # If False no tags are provided return - tags = model_prediction.get('model_tags', {}) # NLP TagId - values = model_prediction.get('values', []) # Raw value + tags = model_prediction.get("model_tags", {}) # NLP TagId + values = model_prediction.get("values", []) # Raw value common_attrs = dict( model_version=model_version, @@ -298,9 +282,9 @@ def _process_model_preds(cls, model_version, current_tags_map, draft_entry, mode new_predictions = [] for category_tag, tags in tags.items(): for tag, prediction_data in tags.items(): - prediction_value = prediction_data.get('prediction') - threshold_value = prediction_data.get('threshold') - is_selected = prediction_data.get('is_selected', False) + prediction_value = prediction_data.get("prediction") + threshold_value = prediction_data.get("threshold") + is_selected = prediction_data.get("is_selected", False) new_predictions.append( AssistedTaggingPrediction( **common_attrs, @@ -328,23 +312,21 @@ def _process_model_preds(cls, model_version, current_tags_map, draft_entry, mode def save_data(cls, draft_entry, data): model_preds = data # Save if new tags are provided - current_tags_map = cls._get_or_create_tags_map([ - tag - for category_tag, tags in model_preds['model_tags'].items() - for tag in [ - category_tag, - *tags.keys(), - ] - ]) - models_version_map = cls._get_or_create_models_version( + current_tags_map = cls._get_or_create_tags_map( [ - model_preds['model_info'] + tag + for category_tag, tags in model_preds["model_tags"].items() + for tag in [ + category_tag, + *tags.keys(), + ] ] ) + models_version_map = cls._get_or_create_models_version([model_preds["model_info"]]) with transaction.atomic(): draft_entry.clear_data() # Clear old data if exists draft_entry.calculated_at = timezone.now() - model_version = models_version_map[(model_preds['model_info']['id'], model_preds['model_info']['version'])] + model_version = models_version_map[(model_preds["model_info"]["id"], model_preds["model_info"]["version"])] cls._process_model_preds(model_version, current_tags_map, draft_entry, model_preds) draft_entry.prediction_status = DraftEntry.PredictionStatus.DONE draft_entry.save_geo_data() @@ -356,7 +338,7 @@ class AutoAssistedTaggingDraftEntryHandler(BaseHandler): # TODO: Fix N+1 issues here. Try to do bulk_update for each models. # Or do this Async model = Lead - callback_url_name = 'auto-assisted_tagging_draft_entry_prediction_callback' + callback_url_name = "auto-assisted_tagging_draft_entry_prediction_callback" @classmethod def auto_trigger_request_to_extractor(cls, lead): @@ -368,31 +350,29 @@ def auto_trigger_request_to_extractor(cls, lead): "text_extraction_id": str(lead_preview.text_extraction_id), } ], - "callback_url": cls.get_callback_url() + "callback_url": cls.get_callback_url(), } response_content = None try: response = requests.post( - url=DeeplServiceEndpoint.ENTRY_EXTRACTION_CLASSIFICATION, - headers=cls.REQUEST_HEADERS, - json=payload + url=DeeplServiceEndpoint.ENTRY_EXTRACTION_CLASSIFICATION, headers=cls.REQUEST_HEADERS, json=payload ) response_content = response.content if response.status_code == 202: lead.auto_entry_extraction_status = Lead.AutoExtractionStatus.PENDING - lead.save(update_fields=('auto_entry_extraction_status',)) + lead.save(update_fields=("auto_entry_extraction_status",)) return True except Exception: - logger.error('Entry Extraction send failed, Exception occurred!!', exc_info=True) + logger.error("Entry Extraction send failed, Exception occurred!!", exc_info=True) lead.auto_entry_extraction_status = Lead.AutoExtractionStatus.FAILED - lead.save(update_fields=('auto_entry_extraction_status',)) + lead.save(update_fields=("auto_entry_extraction_status",)) logger.error( - 'Entry Extraction send failed!!', + "Entry Extraction send failed!!", extra={ - 'data': { - 'payload': payload, - 'response': response_content, + "data": { + "payload": payload, + "response": response_content, }, }, ) @@ -408,35 +388,37 @@ def get_versions_map(): lambda acc, item: acc | item, [ models.Q( - model__model_id=model_data['name'], - version=model_data['version'], + model__model_id=model_data["name"], + version=model_data["version"], ) for model_data in models_data ], ) - ).select_related('model').all() + ) + .select_related("model") + .all() } existing_model_versions = get_versions_map() new_model_versions = [ - model_data - for model_data in models_data - if (model_data['name'], model_data['version']) not in existing_model_versions + model_data for model_data in models_data if (model_data["name"], model_data["version"]) not in existing_model_versions ] if new_model_versions: - AssistedTaggingModelVersion.objects.bulk_create([ - AssistedTaggingModelVersion( - model=AssistedTaggingModel.objects.get_or_create( - model_id=model_data['name'], - defaults=dict( - name=model_data['name'], - ), - )[0], - version=model_data['version'], - ) - for model_data in models_data - ]) + AssistedTaggingModelVersion.objects.bulk_create( + [ + AssistedTaggingModelVersion( + model=AssistedTaggingModel.objects.get_or_create( + model_id=model_data["name"], + defaults=dict( + name=model_data["name"], + ), + )[0], + version=model_data["version"], + ) + for model_data in models_data + ] + ) existing_model_versions = get_versions_map() return existing_model_versions @@ -445,27 +427,22 @@ def _get_or_create_tags_map(cls, tags): from assisted_tagging.tasks import sync_tags_with_deepl_task def get_tags_map(): - return { - tag_id: _id - for _id, tag_id in AssistedTaggingModelPredictionTag.objects.values_list('id', 'tag_id') - } + return {tag_id: _id for _id, tag_id in AssistedTaggingModelPredictionTag.objects.values_list("id", "tag_id")} current_tags_map = get_tags_map() # Check if new tags needs to be created - new_tags = [ - tag - for tag in tags - if tag not in current_tags_map - ] + new_tags = [tag for tag in tags if tag not in current_tags_map] if new_tags: # Create new tags - AssistedTaggingModelPredictionTag.objects.bulk_create([ - AssistedTaggingModelPredictionTag( - name=new_tag, - tag_id=new_tag, - ) - for new_tag in new_tags - ]) + AssistedTaggingModelPredictionTag.objects.bulk_create( + [ + AssistedTaggingModelPredictionTag( + name=new_tag, + tag_id=new_tag, + ) + for new_tag in new_tags + ] + ) # Refetch current_tags_map = get_tags_map() sync_tags_with_deepl_task.delay() @@ -473,12 +450,12 @@ def get_tags_map(): @classmethod def _process_model_preds(cls, model_version, current_tags_map, draft_entry, model_prediction): - prediction_status = model_prediction['prediction_status'] + prediction_status = model_prediction["prediction_status"] if not prediction_status: # If False no tags are provided return - tags = model_prediction.get('classification', {}) # NLP TagId - values = model_prediction.get('values', []) # Raw value + tags = model_prediction.get("classification", {}) # NLP TagId + values = model_prediction.get("values", []) # Raw value common_attrs = dict( model_version=model_version, @@ -487,9 +464,9 @@ def _process_model_preds(cls, model_version, current_tags_map, draft_entry, mode new_predictions = [] for category_tag, tags in tags.items(): for tag, prediction_data in tags.items(): - prediction_value = prediction_data.get('prediction') - threshold_value = prediction_data.get('threshold') - is_selected = prediction_data.get('is_selected', False) + prediction_value = prediction_data.get("prediction") + threshold_value = prediction_data.get("threshold") + is_selected = prediction_data.get("is_selected", False) new_predictions.append( AssistedTaggingPrediction( **common_attrs, @@ -521,49 +498,49 @@ def save_data(cls, lead, data_url): data = RequestHelper(url=data_url, ignore_error=True).json() draft_entry_qs = DraftEntry.objects.filter(lead=lead, type=DraftEntry.Type.AUTO) if draft_entry_qs.exists(): - raise serializers.ValidationError('Draft entries already exit') - for model_preds in data['blocks']: - if not model_preds['relevant']: + raise serializers.ValidationError("Draft entries already exit") + for model_preds in data["blocks"]: + if not model_preds["relevant"]: continue - classification = model_preds['classification'] - current_tags_map = cls._get_or_create_tags_map([ - tag - for category_tag, tags in classification.items() - for tag in [ - category_tag, - *tags.keys(), + classification = model_preds["classification"] + current_tags_map = cls._get_or_create_tags_map( + [ + tag + for category_tag, tags in classification.items() + for tag in [ + category_tag, + *tags.keys(), + ] ] - ]) - models_version_map = cls._get_or_create_models_version([ - data['classification_model_info'] - ]) + ) + models_version_map = cls._get_or_create_models_version([data["classification_model_info"]]) draft = DraftEntry.objects.create( - page=model_preds['page'], - text_order=model_preds['textOrder'], + page=model_preds["page"], + text_order=model_preds["textOrder"], project=lead.project, lead=lead, - excerpt=model_preds['text'], + excerpt=model_preds["text"], prediction_status=DraftEntry.PredictionStatus.DONE, - type=DraftEntry.Type.AUTO + type=DraftEntry.Type.AUTO, ) - if model_preds['geolocations']: + if model_preds["geolocations"]: geo_areas_qs = GeoAreaGqlFilterSet( - data={'titles': [geo['entity'] for geo in model_preds['geolocations']]}, - queryset=GeoArea.get_for_project(lead.project) - ).qs.distinct('title') + data={"titles": [geo["entity"] for geo in model_preds["geolocations"]]}, + queryset=GeoArea.get_for_project(lead.project), + ).qs.distinct("title") draft.related_geoareas.set(geo_areas_qs) model_version = models_version_map[ - (data['classification_model_info']['name'], data['classification_model_info']['version']) + (data["classification_model_info"]["name"], data["classification_model_info"]["version"]) ] cls._process_model_preds(model_version, current_tags_map, draft, model_preds) lead.auto_entry_extraction_status = Lead.AutoExtractionStatus.SUCCESS - lead.save(update_fields=('auto_entry_extraction_status',)) + lead.save(update_fields=("auto_entry_extraction_status",)) return lead class LeadExtractionHandler(BaseHandler): model = Lead - callback_url_name = 'lead_extract_callback' + callback_url_name = "lead_extract_callback" RETRY_COUNTDOWN = 10 * 60 # 10 min @@ -575,30 +552,23 @@ def send_trigger_request_to_extractor( high_priority=False, ): payload = { - 'documents': documents, - 'callback_url': callback_url, - 'request_type': NlpRequestType.USER if high_priority else NlpRequestType.SYSTEM, + "documents": documents, + "callback_url": callback_url, + "request_type": NlpRequestType.USER if high_priority else NlpRequestType.SYSTEM, } response_content = None try: response = requests.post( - DeeplServiceEndpoint.DOCS_EXTRACTOR_ENDPOINT, - headers=cls.REQUEST_HEADERS, - data=json.dumps(payload) + DeeplServiceEndpoint.DOCS_EXTRACTOR_ENDPOINT, headers=cls.REQUEST_HEADERS, data=json.dumps(payload) ) response_content = response.content if response.status_code == 202: return True except Exception: - logger.error('Lead Extraction Failed, Exception occurred!!', exc_info=True) + logger.error("Lead Extraction Failed, Exception occurred!!", exc_info=True) logger.error( - 'Lead Extraction Request Failed!!', - extra={ - 'data': { - 'payload': payload, - 'response': response_content - } - }, + "Lead Extraction Request Failed!!", + extra={"data": {"payload": payload, "response": response_content}}, ) @classmethod @@ -617,8 +587,8 @@ def trigger_lead_extract(cls, lead, task_instance=None): success = cls.send_trigger_request_to_extractor( [ { - 'url': url_to_extract, - 'client_id': cls.get_client_id(lead), + "url": url_to_extract, + "client_id": cls.get_client_id(lead), } ], cls.get_callback_url(), @@ -646,22 +616,19 @@ def save_data( # and create new one LeadPreview.objects.create( lead=lead, - text_extract=RequestHelper(url=text_source_uri, ignore_error=True).get_text(sanitize=True) or '', + text_extract=RequestHelper(url=text_source_uri, ignore_error=True).get_text(sanitize=True) or "", word_count=word_count, page_count=page_count, text_extraction_id=text_extraction_id, ) # Save extracted images as LeadPreviewImage instances # TODO: The logic is same for unified_connector leads as well. Maybe have a single func? - image_base_path = f'{lead.pk}' + image_base_path = f"{lead.pk}" for image_uri in images_uri: lead_image = LeadPreviewImage(lead=lead) image_obj = RequestHelper(url=image_uri, ignore_error=True).get_file() if image_obj: - lead_image.file.save( - os.path.join(image_base_path, os.path.basename(urlparse(image_uri).path)), - image_obj - ) + lead_image.file.save(os.path.join(image_base_path, os.path.basename(urlparse(image_uri).path)), image_obj) lead_image.save() lead.update_extraction_status(Lead.ExtractionStatus.SUCCESS) return lead @@ -697,7 +664,7 @@ def save_lead_data_using_connector_lead( class UnifiedConnectorLeadHandler(BaseHandler): model = ConnectorLead - callback_url_name = 'unified_connector_lead_extract_callback' + callback_url_name = "unified_connector_lead_extract_callback" @staticmethod def save_data( @@ -708,11 +675,11 @@ def save_data( page_count: int, text_extraction_id: str, ): - connector_lead.simplified_text = RequestHelper(url=text_source_uri, ignore_error=True).get_text(sanitize=True) or '' + connector_lead.simplified_text = RequestHelper(url=text_source_uri, ignore_error=True).get_text(sanitize=True) or "" connector_lead.word_count = word_count connector_lead.page_count = page_count connector_lead.text_extraction_id = text_extraction_id - image_base_path = f'{connector_lead.pk}' + image_base_path = f"{connector_lead.pk}" for image_uri in images_uri: lead_image = ConnectorLeadPreviewImage(connector_lead=connector_lead) image_obj = RequestHelper(url=image_uri, ignore_error=True).get_file() @@ -732,7 +699,7 @@ def _process_unified_source(cls, source): source_fetcher = source.source_fetcher() leads, _ = source_fetcher.get_leads(params, source.created_by) - current_source_leads_id = set(source.source_leads.values_list('connector_lead_id', flat=True)) + current_source_leads_id = set(source.source_leads.values_list("connector_lead_id", flat=True)) for connector_lead in leads: connector_lead, _ = ConnectorLead.get_or_create_from_lead(connector_lead) if connector_lead.id not in current_source_leads_id: @@ -744,8 +711,8 @@ def _send_trigger_request_to_extraction(cls, connector_leads: List[ConnectorLead return LeadExtractionHandler.send_trigger_request_to_extractor( [ { - 'url': connector_lead.url, - 'client_id': cls.get_client_id(connector_lead), + "url": connector_lead.url, + "client_id": cls.get_client_id(connector_lead), } for connector_lead in connector_leads ], @@ -757,28 +724,28 @@ def _send_trigger_request_to_extraction(cls, connector_leads: List[ConnectorLead @classmethod def send_retry_trigger_request_to_extractor( - cls, connector_leads_qs: models.QuerySet[ConnectorLead], + cls, + connector_leads_qs: models.QuerySet[ConnectorLead], chunk_size=500, ) -> int: connector_leads = list( # Fetch all now - connector_leads_qs - .filter(extraction_status=ConnectorLead.ExtractionStatus.RETRYING) - .only('id', 'url').distinct()[:chunk_size] + connector_leads_qs.filter(extraction_status=ConnectorLead.ExtractionStatus.RETRYING) + .only("id", "url") + .distinct()[:chunk_size] ) extraction_status = ConnectorLead.ExtractionStatus.RETRYING if cls._send_trigger_request_to_extraction(connector_leads): # True if request is successfully send extraction_status = ConnectorLead.ExtractionStatus.STARTED - ConnectorLead.objects\ - .filter(pk__in=[c.pk for c in connector_leads])\ - .update(extraction_status=extraction_status) + ConnectorLead.objects.filter(pk__in=[c.pk for c in connector_leads]).update(extraction_status=extraction_status) return len(connector_leads) @classmethod def send_trigger_request_to_extractor(cls, connector_leads_qs: models.QuerySet[ConnectorLead]): paginator = Paginator( - connector_leads_qs.filter( - extraction_status=ConnectorLead.ExtractionStatus.PENDING - ).only('id', 'url').order_by('id').distinct(), + connector_leads_qs.filter(extraction_status=ConnectorLead.ExtractionStatus.PENDING) + .only("id", "url") + .order_by("id") + .distinct(), 100, ) processed = 0 @@ -793,26 +760,24 @@ def send_trigger_request_to_extractor(cls, connector_leads_qs: models.QuerySet[C if cls._send_trigger_request_to_extraction(connector_leads): # True if request is successfully send extraction_status = ConnectorLead.ExtractionStatus.STARTED processed += len(connector_leads) - ConnectorLead.objects\ - .filter(pk__in=[c.pk for c in connector_leads])\ - .update(extraction_status=extraction_status) + ConnectorLead.objects.filter(pk__in=[c.pk for c in connector_leads]).update(extraction_status=extraction_status) return processed @classmethod def process_unified_connector_source(cls, source): source.status = ConnectorSource.Status.PROCESSING source.start_date = timezone.now() - source.save(update_fields=('status', 'start_date')) - update_fields = ['status', 'last_fetched_at', 'end_date'] + source.save(update_fields=("status", "start_date")) + update_fields = ["status", "last_fetched_at", "end_date"] try: # Fetch leads cls._process_unified_source(source) source.status = ConnectorSource.Status.SUCCESS source.generate_stats(commit=False) - update_fields.append('stats') + update_fields.append("stats") except Exception: source.status = ConnectorSource.Status.FAILURE - logger.error(f'Failed to process source: {source}', exc_info=True) + logger.error(f"Failed to process source: {source}", exc_info=True) source.last_fetched_at = timezone.now() source.end_date = timezone.now() source.save(update_fields=update_fields) @@ -821,7 +786,7 @@ def process_unified_connector_source(cls, source): def process_unified_connector(cls, unified_connector_id): unified_connector = UnifiedConnector.objects.get(pk=unified_connector_id) if not unified_connector.is_active: - logger.warning(f'Skippping processing for inactive connector (pk:{unified_connector.pk}) {unified_connector}') + logger.warning(f"Skippping processing for inactive connector (pk:{unified_connector.pk}) {unified_connector}") return for source in unified_connector.sources.all(): cls.process_unified_connector_source(source) @@ -837,14 +802,12 @@ class NewNlpServerBaseHandler(BaseHandler): @classmethod def get_callback_url(cls, **kwargs): - return ( - settings.DEEPL_SERVER_CALLBACK_DOMAIN + - reverse( - cls.callback_url_name, kwargs={ - 'version': 'v1', - **kwargs, - }, - ) + return settings.DEEPL_SERVER_CALLBACK_DOMAIN + reverse( + cls.callback_url_name, + kwargs={ + "version": "v1", + **kwargs, + }, ) @classmethod @@ -855,14 +818,12 @@ def get_trigger_payload(cls, _: DeeplTrackBaseModel) -> dict: def send_trigger_request_to_extractor(cls, obj: DeeplTrackBaseModel): # Base payload attributes payload = { - 'mock': settings.DEEPL_SERVER_AS_MOCK, - 'client_id': cls.get_client_id(obj), - 'callback_url': cls.get_callback_url(), + "mock": settings.DEEPL_SERVER_AS_MOCK, + "client_id": cls.get_client_id(obj), + "callback_url": cls.get_callback_url(), } # Additional payload attributes - payload.update( - cls.get_trigger_payload(obj) - ) + payload.update(cls.get_trigger_payload(obj)) try: response = requests.post( @@ -872,42 +833,37 @@ def send_trigger_request_to_extractor(cls, obj: DeeplTrackBaseModel): ) if response.status_code == 202: obj.status = cls.model.Status.STARTED - obj.save(update_fields=('status',)) + obj.save(update_fields=("status",)) return True except Exception: - logger.error(f'{cls.model.__name__} send failed, Exception occurred!!', exc_info=True) - _response = locals().get('response') + logger.error(f"{cls.model.__name__} send failed, Exception occurred!!", exc_info=True) + _response = locals().get("response") error_extra_context = { - 'payload': payload, + "payload": payload, } if _response is not None: - error_extra_context.update({ - 'response': _response.content, - 'response_status_code': _response.status_code, - }) - logger.error( - f'{cls.model.__name__} send failed!!', - extra={ - 'data': { - 'context': error_extra_context + error_extra_context.update( + { + "response": _response.content, + "response_status_code": _response.status_code, } - } - ) + ) + logger.error(f"{cls.model.__name__} send failed!!", extra={"data": {"context": error_extra_context}}) obj.status = cls.model.Status.SEND_FAILED - obj.save(update_fields=('status',)) + obj.save(update_fields=("status",)) class AnalysisTopicModelHandler(NewNlpServerBaseHandler): model = TopicModel endpoint = DeeplServiceEndpoint.ANALYSIS_TOPIC_MODEL - callback_url_name = 'analysis_topic_model_callback' + callback_url_name = "analysis_topic_model_callback" @classmethod def get_trigger_payload(cls, obj: TopicModel): return { - 'entries_url': generate_file_url_for_new_deepl_server(obj.entries_file), - 'cluster_size': settings.ANALYTICAL_ENTRIES_COUNT, - 'max_clusters_num': settings.ANALYTICAL_STATEMENT_COUNT, + "entries_url": generate_file_url_for_new_deepl_server(obj.entries_file), + "cluster_size": settings.ANALYTICAL_ENTRIES_COUNT, + "max_clusters_num": settings.ANALYTICAL_STATEMENT_COUNT, } @staticmethod @@ -915,20 +871,19 @@ def save_data( topic_model: TopicModel, data: dict, ): - data_url = data['presigned_s3_url'] + data_url = data["presigned_s3_url"] entries_data = RequestHelper(url=data_url, custom_error_handler=custom_error_handler).json() if entries_data: # Clear existing TopicModelCluster.objects.filter(topic_model=topic_model).delete() # Create new cluster in bulk - new_clusters = TopicModelCluster.objects.bulk_create([ - TopicModelCluster(topic_model=topic_model, title=_['label']) - for _ in entries_data.values() - ]) + new_clusters = TopicModelCluster.objects.bulk_create( + [TopicModelCluster(topic_model=topic_model, title=_["label"]) for _ in entries_data.values()] + ) # Create new cluster-entry relation in bulk new_cluster_entries = [] for cluster, entries_id in zip(new_clusters, entries_data.values()): - for entry_id in entries_id['entry_id']: + for entry_id in entries_id["entry_id"]: new_cluster_entries.append( TopicModelCluster.entries.through( topicmodelcluster=cluster, @@ -943,12 +898,12 @@ def save_data( class AnalysisAutomaticSummaryHandler(NewNlpServerBaseHandler): model = AutomaticSummary endpoint = DeeplServiceEndpoint.ANALYSIS_AUTOMATIC_SUMMARY - callback_url_name = 'analysis_automatic_summary_callback' + callback_url_name = "analysis_automatic_summary_callback" @classmethod def get_trigger_payload(cls, obj: AutomaticSummary): return { - 'entries_url': generate_file_url_for_new_deepl_server(obj.entries_file), + "entries_url": generate_file_url_for_new_deepl_server(obj.entries_file), } @staticmethod @@ -956,7 +911,7 @@ def save_data( a_summary: AutomaticSummary, data: dict, ): - data_url = data['presigned_s3_url'] + data_url = data["presigned_s3_url"] summary_text = RequestHelper(url=data_url, custom_error_handler=custom_error_handler).get_text() a_summary.status = AutomaticSummary.Status.SUCCESS a_summary.summary = summary_text @@ -966,13 +921,13 @@ def save_data( class AnalyticalStatementNGramHandler(NewNlpServerBaseHandler): model = AnalyticalStatementNGram endpoint = DeeplServiceEndpoint.ANALYSIS_AUTOMATIC_NGRAM - callback_url_name = 'analysis_automatic_ngram_callback' + callback_url_name = "analysis_automatic_ngram_callback" @classmethod def get_trigger_payload(cls, obj: AnalyticalStatementNGram): return { - 'entries_url': generate_file_url_for_new_deepl_server(obj.entries_file), - 'ngrams_config': {}, + "entries_url": generate_file_url_for_new_deepl_server(obj.entries_file), + "ngrams_config": {}, } @staticmethod @@ -980,12 +935,12 @@ def save_data( a_ngram: AnalyticalStatementNGram, data: dict, ): - data_url = data['presigned_s3_url'] + data_url = data["presigned_s3_url"] ngram_data = RequestHelper(url=data_url, custom_error_handler=custom_error_handler).json() if ngram_data: - a_ngram.unigrams = ngram_data.get('unigrams') or {} - a_ngram.bigrams = ngram_data.get('bigrams') or {} - a_ngram.trigrams = ngram_data.get('trigrams') or {} + a_ngram.unigrams = ngram_data.get("unigrams") or {} + a_ngram.bigrams = ngram_data.get("bigrams") or {} + a_ngram.trigrams = ngram_data.get("trigrams") or {} a_ngram.status = AnalyticalStatementNGram.Status.SUCCESS a_ngram.save() @@ -993,12 +948,12 @@ def save_data( class AnalyticalStatementGeoHandler(NewNlpServerBaseHandler): model = AnalyticalStatementGeoTask endpoint = DeeplServiceEndpoint.ANALYSIS_GEO - callback_url_name = 'analysis_geo_callback' + callback_url_name = "analysis_geo_callback" @classmethod def get_trigger_payload(cls, obj: AnalyticalStatementNGram): return { - 'entries_url': generate_file_url_for_new_deepl_server(obj.entries_file), + "entries_url": generate_file_url_for_new_deepl_server(obj.entries_file), } @staticmethod @@ -1006,26 +961,20 @@ def save_data( geo_task: AnalyticalStatementGeoTask, data: dict, ): - data_url = data['presigned_s3_url'] + data_url = data["presigned_s3_url"] geo_data = RequestHelper(url=data_url, custom_error_handler=custom_error_handler).json() if geo_data is not None: geo_entry_objs = [] # Clear out existing - AnalyticalStatementGeoEntry.objects.filter( - task=geo_task - ).delete() + AnalyticalStatementGeoEntry.objects.filter(task=geo_task).delete() existing_entries_id = set( Entry.objects.filter( - project=geo_task.project, - id__in=[ - int(entry_geo_data['entry_id']) - for entry_geo_data in geo_data - ] - ).values_list('id', flat=True) + project=geo_task.project, id__in=[int(entry_geo_data["entry_id"]) for entry_geo_data in geo_data] + ).values_list("id", flat=True) ) for entry_geo_data in geo_data: - entry_id = int(entry_geo_data['entry_id']) - data = entry_geo_data.get('locations') + entry_id = int(entry_geo_data["entry_id"]) + data = entry_geo_data.get("locations") if data and entry_id in existing_entries_id: geo_entry_objs.append( AnalyticalStatementGeoEntry( @@ -1039,4 +988,4 @@ def save_data( geo_task.status = AnalyticalStatementGeoTask.Status.SUCCESS else: geo_task.status = AnalyticalStatementGeoTask.Status.FAILED - geo_task.save(update_fields=('status',)) + geo_task.save(update_fields=("status",)) diff --git a/apps/deepl_integration/models.py b/apps/deepl_integration/models.py index d1fd49a154..1e0a31bcc2 100644 --- a/apps/deepl_integration/models.py +++ b/apps/deepl_integration/models.py @@ -5,12 +5,13 @@ class DeeplTrackBaseModel(models.Model): """ Provide basic fields which are consistent between NLP related models """ + class Status(models.IntegerChoices): - PENDING = 0, 'Pending' - STARTED = 1, 'Started' # INITIATED in deepl side - SUCCESS = 2, 'Success' - FAILED = 3, 'Failed' - SEND_FAILED = 4, 'Send Failed' + PENDING = 0, "Pending" + STARTED = 1, "Started" # INITIATED in deepl side + SUCCESS = 2, "Success" + FAILED = 3, "Failed" + SEND_FAILED = 4, "Send Failed" status = models.PositiveSmallIntegerField(choices=Status.choices, default=Status.PENDING) diff --git a/apps/deepl_integration/serializers.py b/apps/deepl_integration/serializers.py index c4d9f98d15..8eedbc327d 100644 --- a/apps/deepl_integration/serializers.py +++ b/apps/deepl_integration/serializers.py @@ -1,34 +1,29 @@ -from typing import Type import logging -from rest_framework import serializers - -from django.db import transaction, models +from typing import Type +from analysis.models import ( + AnalyticalStatementGeoTask, + AnalyticalStatementNGram, + AutomaticSummary, + TopicModel, +) +from assisted_tagging.models import AssistedTaggingPrediction, DraftEntry +from deduplication.tasks.indexing import index_lead_and_calculate_duplicates from deepl_integration.handlers import ( - BaseHandler, + AnalysisAutomaticSummaryHandler, + AnalysisTopicModelHandler, + AnalyticalStatementGeoHandler, + AnalyticalStatementNGramHandler, AssistedTaggingDraftEntryHandler, + AutoAssistedTaggingDraftEntryHandler, + BaseHandler, LeadExtractionHandler, UnifiedConnectorLeadHandler, - AnalysisTopicModelHandler, - AnalysisAutomaticSummaryHandler, - AnalyticalStatementNGramHandler, - AnalyticalStatementGeoHandler, - AutoAssistedTaggingDraftEntryHandler ) - -from deduplication.tasks.indexing import index_lead_and_calculate_duplicates -from assisted_tagging.models import ( - AssistedTaggingPrediction, - DraftEntry, -) -from unified_connector.models import ConnectorLead +from django.db import models, transaction from lead.models import Lead -from analysis.models import ( - TopicModel, - AutomaticSummary, - AnalyticalStatementNGram, - AnalyticalStatementGeoTask, -) +from rest_framework import serializers +from unified_connector.models import ConnectorLead from .models import DeeplTrackBaseModel @@ -41,14 +36,16 @@ class BaseCallbackSerializer(serializers.Serializer): client_id = serializers.CharField() def validate(self, data): - client_id = data['client_id'] + client_id = data["client_id"] try: - data['object'] = self.nlp_handler.get_object_using_client_id(client_id) + data["object"] = self.nlp_handler.get_object_using_client_id(client_id) except Exception: - logger.error('Failed to parse client id', exc_info=True) - raise serializers.ValidationError({ - 'client_id': 'Failed to parse client id', - }) + logger.error("Failed to parse client id", exc_info=True) + raise serializers.ValidationError( + { + "client_id": "Failed to parse client id", + } + ) return data @@ -56,9 +53,9 @@ class DeeplServerBaseCallbackSerializer(BaseCallbackSerializer): class Status(models.IntegerChoices): # NOTE: Defined by NLP # INITIATED = 1, 'Initiated' # Not needed or used by deep - SUCCESS = 2, 'Success' - FAILED = 3, 'Failed' - INPUT_URL_PROCESS_FAILED = 4, 'Input url process failed' + SUCCESS = 2, "Success" + FAILED = 3, "Failed" + INPUT_URL_PROCESS_FAILED = 4, "Input url process failed" status = serializers.ChoiceField(choices=Status.choices) @@ -68,11 +65,13 @@ class LeadExtractCallbackSerializer(DeeplServerBaseCallbackSerializer): """ Serialize deepl extractor """ + url = serializers.CharField(required=False) # Data fields images_path = serializers.ListField( child=serializers.CharField(allow_blank=True), - required=False, default=[], + required=False, + default=[], ) text_path = serializers.CharField(required=False, allow_null=True) total_words_count = serializers.IntegerField(required=False, default=0, allow_null=True) @@ -83,32 +82,28 @@ class LeadExtractCallbackSerializer(DeeplServerBaseCallbackSerializer): def validate(self, data): data = super().validate(data) # Additional validation - if data['status'] == self.Status.SUCCESS and data.get('text_path') in [None, '']: - raise serializers.ValidationError({ - 'text_path': 'text_path is required when extraction status is success' - }) - if data['status'] == self.Status.SUCCESS: + if data["status"] == self.Status.SUCCESS and data.get("text_path") in [None, ""]: + raise serializers.ValidationError({"text_path": "text_path is required when extraction status is success"}) + if data["status"] == self.Status.SUCCESS: errors = {} - for key in ['text_path', 'total_words_count', 'total_pages', 'text_extraction_id']: + for key in ["text_path", "total_words_count", "total_pages", "text_extraction_id"]: if key not in data or data[key] is None: - errors[key] = ( - f"<{key=} or {data.get('key')=}> is missing. Required when the extraction status is Success" - ) + errors[key] = f"<{key=} or {data.get('key')=}> is missing. Required when the extraction status is Success" if errors: raise serializers.ValidationError(errors) return data def create(self, data): - success = data['status'] == self.Status.SUCCESS - lead = data['object'] # Added from validate + success = data["status"] == self.Status.SUCCESS + lead = data["object"] # Added from validate if success: lead = self.nlp_handler.save_data( lead, - data['text_path'], - data.get('images_path', [])[:10], # TODO: Support for more images, too much image will error. - data.get('total_words_count'), - data.get('total_pages'), - data.get('text_extraction_id'), + data["text_path"], + data.get("images_path", [])[:10], # TODO: Support for more images, too much image will error. + data.get("total_words_count"), + data.get("total_pages"), + data.get("text_extraction_id"), ) # Add to deduplication index transaction.on_commit(lambda: index_lead_and_calculate_duplicates.delay(lead.id)) @@ -122,10 +117,12 @@ class UnifiedConnectorLeadExtractCallbackSerializer(DeeplServerBaseCallbackSeria """ Serialize deepl extractor """ + # Data fields images_path = serializers.ListField( child=serializers.CharField(allow_blank=True), - required=False, default=[], + required=False, + default=[], ) text_path = serializers.CharField(required=False, allow_null=True) total_words_count = serializers.IntegerField(required=False, default=0, allow_null=True) @@ -136,28 +133,26 @@ class UnifiedConnectorLeadExtractCallbackSerializer(DeeplServerBaseCallbackSeria def validate(self, data): data = super().validate(data) - if data['status'] == self.Status.SUCCESS: + if data["status"] == self.Status.SUCCESS: errors = {} - for key in ['text_path', 'total_words_count', 'total_pages', 'text_extraction_id']: + for key in ["text_path", "total_words_count", "total_pages", "text_extraction_id"]: if key not in data or data[key] is None: - errors[key] = ( - f"<{key=} or {data.get('key')=}> is missing. Required when the extraction status is Success" - ) + errors[key] = f"<{key=} or {data.get('key')=}> is missing. Required when the extraction status is Success" if errors: raise serializers.ValidationError(errors) return data def create(self, data): - success = data['status'] == self.Status.SUCCESS - connector_lead = data['object'] # Added from validate + success = data["status"] == self.Status.SUCCESS + connector_lead = data["object"] # Added from validate if success: return self.nlp_handler.save_data( connector_lead, - data['text_path'], - data.get('images_path', [])[:10], # TODO: Support for more images, to much image will error. - data['total_words_count'], - data['total_pages'], - data['text_extraction_id'], + data["text_path"], + data.get("images_path", [])[:10], # TODO: Support for more images, to much image will error. + data["total_words_count"], + data["total_pages"], + data["text_extraction_id"], ) connector_lead.update_extraction_status(ConnectorLead.ExtractionStatus.FAILED) return connector_lead @@ -208,6 +203,7 @@ class ModelPredictionCallbackSerializerTagValue(serializers.Serializer): required=False, ) is_selected = serializers.BooleanField() + values = serializers.ListSerializer( child=serializers.CharField(), required=False, @@ -227,7 +223,7 @@ class AssistedTaggingDraftEntryPredictionCallbackSerializer(BaseCallbackSerializ nlp_handler = AssistedTaggingDraftEntryHandler def create(self, validated_data): - draft_entry = validated_data['object'] + draft_entry = validated_data["object"] if draft_entry.prediction_status == DraftEntry.PredictionStatus.DONE: return draft_entry return self.nlp_handler.save_data( @@ -253,10 +249,10 @@ class AutoAssistedTaggingDraftEntryCallbackSerializer(BaseCallbackSerializer): nlp_handler = AutoAssistedTaggingDraftEntryHandler def create(self, validated_data): - obj = validated_data['object'] + obj = validated_data["object"] return self.nlp_handler.save_data( obj, - validated_data['entry_extraction_classification_path'], + validated_data["entry_extraction_classification_path"], ) @@ -265,12 +261,12 @@ class EntriesCollectionBaseCallbackSerializer(DeeplServerBaseCallbackSerializer) presigned_s3_url = serializers.URLField() def create(self, validated_data): - obj = validated_data['object'] - if validated_data['status'] == self.Status.SUCCESS: + obj = validated_data["object"] + if validated_data["status"] == self.Status.SUCCESS: self.nlp_handler.save_data(obj, validated_data) else: obj.status = self.model.Status.FAILED - obj.save(update_fields=('status',)) + obj.save(update_fields=("status",)) return obj diff --git a/apps/deepl_integration/views.py b/apps/deepl_integration/views.py index 264d0ce6ec..f981fffc1c 100644 --- a/apps/deepl_integration/views.py +++ b/apps/deepl_integration/views.py @@ -1,21 +1,16 @@ from typing import Type -from rest_framework import ( - views, - permissions, - response, - status, - serializers, -) + +from rest_framework import permissions, response, serializers, status, views from .serializers import ( + AnalysisAutomaticSummaryCallbackSerializer, + AnalysisTopicModelCallbackSerializer, + AnalyticalStatementGeoCallbackSerializer, + AnalyticalStatementNGramCallbackSerializer, AssistedTaggingDraftEntryPredictionCallbackSerializer, + AutoAssistedTaggingDraftEntryCallbackSerializer, LeadExtractCallbackSerializer, UnifiedConnectorLeadExtractCallbackSerializer, - AnalysisTopicModelCallbackSerializer, - AnalysisAutomaticSummaryCallbackSerializer, - AnalyticalStatementNGramCallbackSerializer, - AnalyticalStatementGeoCallbackSerializer, - AutoAssistedTaggingDraftEntryCallbackSerializer ) diff --git a/apps/docs/inspectors.py b/apps/docs/inspectors.py index c8785ebaa3..ef33767d54 100644 --- a/apps/docs/inspectors.py +++ b/apps/docs/inspectors.py @@ -1,23 +1,24 @@ from collections import OrderedDict -from django.test.client import RequestFactory -from django.db import models from django.core.exceptions import FieldDoesNotExist +from django.db import models +from django.test.client import RequestFactory from rest_framework import exceptions, serializers from rest_framework.compat import uritemplate +from user.models import User from deep.serializers import RecursiveSerializer -from user.models import User from utils.common import to_camelcase -from .utils import is_list_view, is_custom_action + from . import schema +from .utils import is_custom_action, is_list_view def format_field_name(field_name, required, camelcase): if camelcase: field_name = to_camelcase(field_name) if required: - field_name += '*' + field_name += "*" return field_name @@ -37,13 +38,12 @@ def field_to_schema(field, camelcase=True): elif isinstance(field, serializers.Serializer): return schema.Object( - properties=OrderedDict([ - ( - format_field_name(value.field_name, - value.required, camelcase), - field_to_schema(value) - ) for value in field.fields.values() - ]), + properties=OrderedDict( + [ + (format_field_name(value.field_name, value.required, camelcase), field_to_schema(value)) + for value in field.fields.values() + ] + ), ) elif isinstance(field, serializers.ManyRelatedField): @@ -86,40 +86,35 @@ def field_to_schema(field, camelcase=True): elif isinstance(field, (serializers.FileField, serializers.ImageField)): return schema.File() - if field.style.get('base_template') == 'textarea.html': - return schema.String( - format='textarea' - ) + if field.style.get("base_template") == "textarea.html": + return schema.String(format="textarea") return schema.String() def get_pk_description(model, model_field): if isinstance(model_field, models.AutoField): - value_type = 'unique integer value' + value_type = "unique integer value" elif isinstance(model_field, models.UUIDField): - value_type = 'UUID string' + value_type = "UUID string" else: - value_type = 'unique value' + value_type = "unique value" - return 'A {value_type} identifying this {title}'.format( + return "A {value_type} identifying this {title}".format( value_type=value_type, title=model._meta.verbose_name, ) class Field: - def __init__(self, - title='', - required=False, - schema=None): + def __init__(self, title="", required=False, schema=None): self.title = title self.required = required self.schema = schema def __str__(self): if self.required: - return self.title + '*' + return self.title + "*" return self.title def __repr__(self): @@ -146,7 +141,7 @@ def get_path_fields(self): view = self.view path = self.path - model = getattr(getattr(view, 'queryset', None), 'model', None) + model = getattr(getattr(view, "queryset", None), "model", None) for variable in uritemplate.variables(path): schema_cls = schema.String @@ -167,41 +162,36 @@ def get_path_fields(self): # elif model_field.primary_key: # description = get_pk_description(model, model_field) - if hasattr(view, 'lookup_value_regex') and \ - view.lookup_field == variable: - kwargs['pattern'] = view.lookup_value_regex + if hasattr(view, "lookup_value_regex") and view.lookup_field == variable: + kwargs["pattern"] = view.lookup_value_regex if isinstance(model_field, models.AutoField): schema_cls = schema.Integer # Check other field types ? It's mostly string though - field = Field( - title=variable, - required=True, - schema=schema_cls(**kwargs) - ) + field = Field(title=variable, required=True, schema=schema_cls(**kwargs)) self.path_fields.append(field) def get_serializer_fields(self): view = self.view method = self.method - if hasattr(view, 'action') and is_custom_action(view.action): + if hasattr(view, "action") and is_custom_action(view.action): action = getattr(view, view.action) - if getattr(action, 'delete_view', False): + if getattr(action, "delete_view", False): return - if method in ('DELETE',): + if method in ("DELETE",): return - takes_request = method in ('PUT', 'PATCH', 'POST') + takes_request = method in ("PUT", "PATCH", "POST") - if not hasattr(view, 'get_serializer'): + if not hasattr(view, "get_serializer"): return try: view.request = RequestFactory() - view.request.user = User(username='test') + view.request.user = User(username="test") serializer = view.get_serializer() except exceptions.APIException: serializer = None @@ -226,7 +216,7 @@ def get_serializer_fields(self): if isinstance(field, serializers.HiddenField): continue - required = field.required and method != 'PATCH' + required = field.required and method != "PATCH" out_field = Field( title=to_camelcase(field.field_name), required=required, @@ -246,35 +236,40 @@ def handle_pagination(self): if is_list_view(self.path, self.method, self.view): response_fields = [] - response_fields.append(Field( - title='count', - required=True, - schema=schema.Integer(), - )) - - response_fields.append(Field( - title='next', - required=False, - schema=schema.URL(), - )) - - response_fields.append(Field( - title='previous', - required=False, - schema=schema.URL(), - )) - - response_fields.append(Field( - title='results', - required=True, - schema=schema.Array( - items=schema.Object( - properties=OrderedDict([ - (str(field), field.schema) for field in - self.response_fields - ]) + response_fields.append( + Field( + title="count", + required=True, + schema=schema.Integer(), + ) + ) + + response_fields.append( + Field( + title="next", + required=False, + schema=schema.URL(), + ) + ) + + response_fields.append( + Field( + title="previous", + required=False, + schema=schema.URL(), + ) + ) + + response_fields.append( + Field( + title="results", + required=True, + schema=schema.Array( + items=schema.Object( + properties=OrderedDict([(str(field), field.schema) for field in self.response_fields]) + ), ), - ), - )) + ) + ) self.response_fields = response_fields diff --git a/apps/docs/utils.py b/apps/docs/utils.py index 86b6e76bd3..f989ddfa2c 100644 --- a/apps/docs/utils.py +++ b/apps/docs/utils.py @@ -5,6 +5,7 @@ def mark_as_list(): def decorator(func): func.list_view = True return func + return decorator @@ -12,13 +13,12 @@ def mark_as_delete(): def decorator(func): func.delete_view = True return func + return decorator def is_custom_action(action): - return action not in set([ - 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' - ]) + return action not in set(["retrieve", "list", "create", "update", "partial_update", "destroy"]) def is_list_view(path, method, view): @@ -26,24 +26,23 @@ def is_list_view(path, method, view): Return True if the given path/method appears to represent a list view. """ - if hasattr(view, 'action'): + if hasattr(view, "action"): # Viewsets have an explicitly defined action, which we can inspect. # If a custom action, check if the detail attribute is set # otherwise check if the action is `list`. if is_custom_action(view.action): action = getattr(view, view.action) - return getattr(action, 'list_view', False) or \ - not action.detail - return view.action == 'list' + return getattr(action, "list_view", False) or not action.detail + return view.action == "list" - if method.lower() != 'get': + if method.lower() != "get": return False if isinstance(view, RetrieveModelMixin): return False - path_components = path.strip('/').split('/') - if path_components and '{' in path_components[-1]: + path_components = path.strip("/").split("/") + if path_components and "{" in path_components[-1]: return False return True diff --git a/apps/entry/admin.py b/apps/entry/admin.py index 13512843aa..0dd5f0132f 100644 --- a/apps/entry/admin.py +++ b/apps/entry/admin.py @@ -1,28 +1,24 @@ import reversion - -from reversion.admin import VersionAdmin from admin_auto_filters.filters import AutocompleteFilterFactory from django.contrib import admin - -from deep.admin import query_buttons - -from entry.models import ( - Entry, +from entry.models import ( # Entry Group Attribute, - FilterData, - ExportData, + Entry, EntryComment, - - # Entry Group - ProjectEntryLabel, + ExportData, + FilterData, LeadEntryGroup, + ProjectEntryLabel, ) +from reversion.admin import VersionAdmin + +from deep.admin import query_buttons class AttributeInline(admin.StackedInline): model = Attribute extra = 0 - raw_id_fields = ('widget',) + raw_id_fields = ("widget",) class EntryCommentInline(admin.TabularInline): @@ -33,57 +29,62 @@ class EntryCommentInline(admin.TabularInline): class FilterDataInline(admin.StackedInline): model = FilterData extra = 0 - raw_id_fields = ('filter',) + raw_id_fields = ("filter",) class ExportDataInline(admin.StackedInline): model = ExportData extra = 0 - raw_id_fields = ('exportable',) + raw_id_fields = ("exportable",) @admin.register(Entry) class EntryAdmin(VersionAdmin): - custom_inlines = [('attribute', AttributeInline), - ('filter', FilterDataInline), - ('exportable', ExportDataInline), - ('Entry Comment', EntryCommentInline), - ] + custom_inlines = [ + ("attribute", AttributeInline), + ("filter", FilterDataInline), + ("exportable", ExportDataInline), + ("Entry Comment", EntryCommentInline), + ] list_display = [ - 'lead', 'project', 'created_by', 'created_at', - query_buttons('View', [inline[0] for inline in custom_inlines]), + "lead", + "project", + "created_by", + "created_at", + query_buttons("View", [inline[0] for inline in custom_inlines]), ] - search_fields = ('lead__title',) - list_filter = ( - AutocompleteFilterFactory('Project', 'project'), - AutocompleteFilterFactory('User', 'created_by'), - 'created_at' - ) + search_fields = ("lead__title",) + list_filter = (AutocompleteFilterFactory("Project", "project"), AutocompleteFilterFactory("User", "created_by"), "created_at") autocomplete_fields = ( - 'lead', 'project', 'created_by', 'modified_by', 'analysis_framework', 'tabular_field', - 'image', 'controlled_changed_by', 'verified_by', + "lead", + "project", + "created_by", + "modified_by", + "analysis_framework", + "tabular_field", + "image", + "controlled_changed_by", + "verified_by", ) - ordering = ('project', 'created_by', 'created_at') + ordering = ("project", "created_by", "created_at") def get_queryset(self, request): - return Entry.objects.select_related('project', 'created_by', 'lead') + return Entry.objects.select_related("project", "created_by", "lead") def get_inline_instances(self, request, obj=None): inlines = [] for name, inline in self.custom_inlines: - if request.GET.get(f'show_{name}', 'False').lower() == 'true': + if request.GET.get(f"show_{name}", "False").lower() == "true": inlines.append(inline(self.model, self.admin_site)) return inlines @admin.register(ProjectEntryLabel) class ProjectEntryLabelAdmin(VersionAdmin): - search_fields = ('title',) - autocomplete_fields = ('created_by', 'modified_by', 'project') - list_filter = ( - AutocompleteFilterFactory('Project', 'project'), - ) - list_display = ('__str__', 'color') + search_fields = ("title",) + autocomplete_fields = ("created_by", "modified_by", "project") + list_filter = (AutocompleteFilterFactory("Project", "project"),) + list_display = ("__str__", "color") reversion.register(LeadEntryGroup) diff --git a/apps/entry/apps.py b/apps/entry/apps.py index 105d99eee9..ecc718afa4 100644 --- a/apps/entry/apps.py +++ b/apps/entry/apps.py @@ -2,7 +2,7 @@ class EntryConfig(AppConfig): - name = 'entry' + name = "entry" def ready(self): import entry.receivers # noqa diff --git a/apps/entry/dataloaders.py b/apps/entry/dataloaders.py index 722ac52212..c1afd4819e 100644 --- a/apps/entry/dataloaders.py +++ b/apps/entry/dataloaders.py @@ -1,39 +1,31 @@ from collections import defaultdict -from promise import Promise -from django.utils.functional import cached_property -from django.db import models - -from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin - from analysis_framework.models import Widget +from django.db import models +from django.utils.functional import cached_property +from geo.schema import get_geo_area_queryset_for_project_geo_area_type +from promise import Promise from quality_assurance.models import EntryReviewComment -from geo.schema import get_geo_area_queryset_for_project_geo_area_type +from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin -from .models import ( - Entry, - Attribute, - EntryGroupLabel, -) +from .models import Attribute, Entry, EntryGroupLabel class EntryLoader(DataLoaderWithContext): def batch_load_fn(self, keys): entry_qs = Entry.objects.filter(id__in=keys) - _map = { - entry.id: entry - for entry in entry_qs - } + _map = {entry.id: entry for entry in entry_qs} return Promise.resolve([_map[key] for key in keys]) class EntryAttributesLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - attributes_qs = Attribute.objects\ - .filter(entry__in=keys)\ - .exclude(widget__widget_id__in=Widget.DEPRECATED_TYPES)\ - .annotate(widget_type=models.F('widget__widget_id')) + attributes_qs = ( + Attribute.objects.filter(entry__in=keys) + .exclude(widget__widget_id__in=Widget.DEPRECATED_TYPES) + .annotate(widget_type=models.F("widget__widget_id")) + ) attributes = defaultdict(list) for attribute in attributes_qs: attributes[attribute.entry_id].append(attribute) @@ -42,34 +34,20 @@ def batch_load_fn(self, keys): class EntryProjectLabelsLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - group_labels_qs = EntryGroupLabel.get_stat_for_entry( - EntryGroupLabel.objects.filter(entry__in=keys) - ) + group_labels_qs = EntryGroupLabel.get_stat_for_entry(EntryGroupLabel.objects.filter(entry__in=keys)) group_labels = defaultdict(list) for group_label in group_labels_qs: - group_labels[group_label['entry']].append(group_label) + group_labels[group_label["entry"]].append(group_label) return Promise.resolve([group_labels.get(key) for key in keys]) class AttributeGeoSelectedOptionsLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - geo_area_qs = get_geo_area_queryset_for_project_geo_area_type().filter( - id__in={id for ids in keys for id in ids} - ).defer('polygons') - geo_area_map = { - str(geo_area.id): geo_area - for geo_area in geo_area_qs - } - return Promise.resolve( - [ - [ - geo_area_map[str(id)] - for id in ids - if id in geo_area_map - ] - for ids in keys - ] + geo_area_qs = ( + get_geo_area_queryset_for_project_geo_area_type().filter(id__in={id for ids in keys for id in ids}).defer("polygons") ) + geo_area_map = {str(geo_area.id): geo_area for geo_area in geo_area_qs} + return Promise.resolve([[geo_area_map[str(id)] for id in ids if id in geo_area_map] for ids in keys]) class ReviewCommentsLoader(DataLoaderWithContext): @@ -83,23 +61,20 @@ def batch_load_fn(self, keys): class ReviewCommentsCountLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - count_qs = EntryReviewComment.objects\ - .filter(entry__in=keys)\ - .order_by().values('entry')\ - .annotate(count=models.Count('id'))\ - .values_list('entry', 'count') - counts = { - entry_id: count - for entry_id, count in count_qs - } + count_qs = ( + EntryReviewComment.objects.filter(entry__in=keys) + .order_by() + .values("entry") + .annotate(count=models.Count("id")) + .values_list("entry", "count") + ) + counts = {entry_id: count for entry_id, count in count_qs} return Promise.resolve([counts.get(key, 0) for key in keys]) class EntryVerifiedByLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - verified_by_through_qs = Entry.verified_by.through.objects\ - .filter(entry__in=keys)\ - .prefetch_related('user') + verified_by_through_qs = Entry.verified_by.through.objects.filter(entry__in=keys).prefetch_related("user") _map = defaultdict(list) for item in verified_by_through_qs.all(): _map[item.entry_id].append(item.user) @@ -108,15 +83,14 @@ def batch_load_fn(self, keys): class EntryVerifiedByCountLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - count_qs = Entry.verified_by.through.objects\ - .filter(entry__in=keys)\ - .order_by().values('entry')\ - .annotate(count=models.Count('id'))\ - .values_list('entry', 'count') - counts = { - entry: count - for entry, count in count_qs - } + count_qs = ( + Entry.verified_by.through.objects.filter(entry__in=keys) + .order_by() + .values("entry") + .annotate(count=models.Count("id")) + .values_list("entry", "count") + ) + counts = {entry: count for entry, count in count_qs} return Promise.resolve([counts.get(key, 0) for key in keys]) diff --git a/apps/entry/enums.py b/apps/entry/enums.py index e98f268a8d..c29f116eb4 100644 --- a/apps/entry/enums.py +++ b/apps/entry/enums.py @@ -5,11 +5,6 @@ from .models import Entry -EntryTagTypeEnum = convert_enum_to_graphene_enum(Entry.TagType, name='EntryTagTypeEnum') +EntryTagTypeEnum = convert_enum_to_graphene_enum(Entry.TagType, name="EntryTagTypeEnum") -enum_map = { - get_enum_name_from_django_field(field): enum - for field, enum in ( - (Entry.entry_type, EntryTagTypeEnum), - ) -} +enum_map = {get_enum_name_from_django_field(field): enum for field, enum in ((Entry.entry_type, EntryTagTypeEnum),)} diff --git a/apps/entry/errors.py b/apps/entry/errors.py index a1d2d57399..71402f028d 100644 --- a/apps/entry/errors.py +++ b/apps/entry/errors.py @@ -6,4 +6,4 @@ class EntryValidationVersionMismatchError(exceptions.ValidationError): status_code = 400 code = error_codes.ENTRY_VALIDATION_VERSION_MISMATCH - default_detail = 'Version Mismatch' + default_detail = "Version Mismatch" diff --git a/apps/entry/factories.py b/apps/entry/factories.py index 6f49b618dc..c0bbe98029 100644 --- a/apps/entry/factories.py +++ b/apps/entry/factories.py @@ -1,14 +1,9 @@ import factory from factory import fuzzy from factory.django import DjangoModelFactory - from gallery.factories import FileFactory -from .models import ( - Entry, - Attribute, - EntryComment, -) +from .models import Attribute, Entry, EntryComment class EntryFactory(DjangoModelFactory): @@ -22,9 +17,9 @@ class Meta: @classmethod def _create(cls, model_class, *args, **kwargs): entry = model_class(*args, **kwargs) - if getattr(entry, 'project', None) is None: # Use lead's project if project is not provided + if getattr(entry, "project", None) is None: # Use lead's project if project is not provided entry.project = entry.lead.project - if getattr(entry, 'analysis_framework', None) is None: # Use lead's project's AF if AF is not provided + if getattr(entry, "analysis_framework", None) is None: # Use lead's project's AF if AF is not provided entry.analysis_framework = entry.lead.project.analysis_framework entry.save() return entry diff --git a/apps/entry/filter_set.py b/apps/entry/filter_set.py index a7f813752f..2843e5f0b6 100644 --- a/apps/entry/filter_set.py +++ b/apps/entry/filter_set.py @@ -1,55 +1,45 @@ import copy -from functools import reduce from datetime import datetime +from functools import reduce -import graphene import django_filters -from django.db import models +import graphene +from analysis_framework.models import Filter, Widget from django.contrib.auth.models import User from django.contrib.postgres.aggregates.general import ArrayAgg +from django.db import models +from entry.widgets.date_widget import parse_date_str +from entry.widgets.time_widget import parse_time_str +from geo.models import GeoArea from graphene_django.filter.filterset import GrapheneFilterSetMixin from graphene_django.filter.utils import get_filtering_args_from_filterset - +from lead.enums import LeadConfidentialityEnum, LeadPriorityEnum, LeadStatusEnum +from lead.models import Lead +from organization.models import OrganizationType +from quality_assurance.models import EntryReviewComment from user_resource.filters import UserResourceGqlFilterSet + +from deep.filter_set import DjangoFilterCSVWidget from utils.common import is_valid_number from utils.graphene.fields import ( - generate_simple_object_type_from_input_type, - generate_object_field_from_input_type, compare_input_output_type_fields, + generate_object_field_from_input_type, + generate_simple_object_type_from_input_type, ) from utils.graphene.filters import ( - IDListFilter, - MultipleInputFilter, DateGteFilter, DateLteFilter, + IDListFilter, + MultipleInputFilter, ) -from deep.filter_set import DjangoFilterCSVWidget -from analysis_framework.models import Filter -from lead.models import Lead -from organization.models import OrganizationType -from analysis_framework.models import Widget -from geo.models import GeoArea -from quality_assurance.models import EntryReviewComment - -from lead.enums import ( - LeadStatusEnum, - LeadPriorityEnum, - LeadConfidentialityEnum, -) -from entry.widgets.date_widget import parse_date_str -from entry.widgets.time_widget import parse_time_str -from .models import ( - Entry, - EntryComment, - ProjectEntryLabel, -) from .enums import EntryTagTypeEnum - +from .models import Entry, EntryComment, ProjectEntryLabel # TODO: Find out whether we need to call timezone.make_aware # from django.utils module to all datetime objects below + # We don't use UserResourceFilterSet since created_at and modified_at # are overridden below class EntryFilterMixin(django_filters.filterset.FilterSet): @@ -57,9 +47,16 @@ class EntryFilterMixin(django_filters.filterset.FilterSet): Entry filter set Basic filtering with lead, excerpt, lead title and dates """ + class CommentStatus(models.TextChoices): - RESOLVED = 'resolved', 'Resolved', - UNRESOLVED = 'unresolved', 'Unresolved', + RESOLVED = ( + "resolved", + "Resolved", + ) + UNRESOLVED = ( + "unresolved", + "Unresolved", + ) lead = django_filters.ModelMultipleChoiceFilter( queryset=Lead.objects.all(), @@ -72,89 +69,88 @@ class CommentStatus(models.TextChoices): ) created_at = django_filters.DateTimeFilter( - field_name='created_at', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + input_formats=["%Y-%m-%d%z"], ) created_at__gt = django_filters.DateTimeFilter( - field_name='created_at', - lookup_expr='gt', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + lookup_expr="gt", + input_formats=["%Y-%m-%d%z"], ) created_at__lt = django_filters.DateTimeFilter( - field_name='created_at', - lookup_expr='lt', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + lookup_expr="lt", + input_formats=["%Y-%m-%d%z"], ) created_at__gte = django_filters.DateTimeFilter( - field_name='created_at', - lookup_expr='gte', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + lookup_expr="gte", + input_formats=["%Y-%m-%d%z"], ) created_at__lte = django_filters.DateTimeFilter( - field_name='created_at', - lookup_expr='lte', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + lookup_expr="lte", + input_formats=["%Y-%m-%d%z"], ) lead_published_on = django_filters.DateFilter( - field_name='lead__published_on', - + field_name="lead__published_on", ) lead_published_on__gt = django_filters.DateFilter( - field_name='lead__published_on', - lookup_expr='gt', - + field_name="lead__published_on", + lookup_expr="gt", ) lead_published_on__lt = django_filters.DateFilter( - field_name='lead__published_on', - lookup_expr='lt', - + field_name="lead__published_on", + lookup_expr="lt", ) lead_published_on__gte = django_filters.DateFilter( - field_name='lead__published_on', - lookup_expr='gte', - + field_name="lead__published_on", + lookup_expr="gte", ) lead_published_on__lte = django_filters.DateFilter( - field_name='lead__published_on', - lookup_expr='lte', - + field_name="lead__published_on", + lookup_expr="lte", ) lead_assignee = django_filters.ModelMultipleChoiceFilter( - label='Lead Assignees', + label="Lead Assignees", queryset=User.objects.all(), - field_name='lead__assignee', + field_name="lead__assignee", ) comment_status = django_filters.ChoiceFilter( - label='Comment Status', choices=CommentStatus.choices, method='comment_status_filter', + label="Comment Status", + choices=CommentStatus.choices, + method="comment_status_filter", ) comment_assignee = django_filters.ModelMultipleChoiceFilter( - label='Comment Assignees', + label="Comment Assignees", queryset=User.objects.all(), - field_name='entrycomment__assignees', + field_name="entrycomment__assignees", ) comment_created_by = django_filters.ModelMultipleChoiceFilter( - label='Comment Created by', - queryset=User.objects.all(), method='comment_created_by_filter', + label="Comment Created by", + queryset=User.objects.all(), + method="comment_created_by_filter", ) geo_custom_shape = django_filters.CharFilter( - label='GEO Custom Shapes', - method='geo_custom_shape_filter', + label="GEO Custom Shapes", + method="geo_custom_shape_filter", ) # Entry Group Label Filters project_entry_labels = django_filters.ModelMultipleChoiceFilter( - label='Project Entry Labels', - queryset=ProjectEntryLabel.objects.all(), method='project_entry_labels_filter', + label="Project Entry Labels", + queryset=ProjectEntryLabel.objects.all(), + method="project_entry_labels_filter", ) lead_group_label = django_filters.CharFilter( - label='Lead Group Label', - method='lead_group_label_filter', + label="Lead Group Label", + method="lead_group_label_filter", ) authoring_organization_types = django_filters.ModelMultipleChoiceFilter( - method='authoring_organization_types_filter', + method="authoring_organization_types_filter", widget=DjangoFilterCSVWidget, queryset=OrganizationType.objects.all(), ) @@ -163,20 +159,27 @@ class Meta: model = Entry fields = { **{ - x: ['exact'] for x in [ - 'id', 'excerpt', 'lead__title', 'created_at', - 'created_by', 'modified_at', 'modified_by', 'project', - 'controlled', + x: ["exact"] + for x in [ + "id", + "excerpt", + "lead__title", + "created_at", + "created_by", + "modified_at", + "modified_by", + "project", + "controlled", ] }, - 'created_at': ['exact', 'lt', 'gt', 'lte', 'gte'], + "created_at": ["exact", "lt", "gt", "lte", "gte"], # 'lead_published_on': ['exact', 'lt', 'gt', 'lte', 'gte'], } filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda f: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda f: { + "lookup_expr": "icontains", }, }, } @@ -208,9 +211,10 @@ def geo_custom_shape_filter(self, queryset, name, value): lambda acc, item: acc | item, [ models.Q( - attribute__widget__widget_id='geoWidget', - attribute__data__value__contains=[{'type': v}], - ) for v in value.split(',') + attribute__widget__widget_id="geoWidget", + attribute__data__value__contains=[{"type": v}], + ) + for v in value.split(",") ], ) return queryset.filter(query_params) @@ -232,8 +236,7 @@ def authoring_organization_types_filter(self, qs, name, value): if value: qs = qs.annotate( organization_types=models.functions.Coalesce( - 'lead__authors__parent__organization_type', - 'lead__authors__organization_type' + "lead__authors__parent__organization_type", "lead__authors__organization_type" ) ) if isinstance(value[0], OrganizationType): @@ -245,7 +248,7 @@ def authoring_organization_types_filter(self, qs, name, value): def qs(self): qs = super().qs # Note: Since we cannot have `.distinct()` inside a subquery - if self.data.get('from_subquery', False): + if self.data.get("from_subquery", False): return Entry.objects.filter(id__in=qs) return qs.distinct() @@ -255,20 +258,27 @@ class Meta: model = Entry fields = { **{ - x: ['exact'] for x in [ - 'id', 'excerpt', 'lead__title', 'created_at', - 'created_by', 'modified_at', 'modified_by', 'project', - 'controlled', + x: ["exact"] + for x in [ + "id", + "excerpt", + "lead__title", + "created_at", + "created_by", + "modified_at", + "modified_by", + "project", + "controlled", ] }, - 'created_at': ['exact', 'lt', 'gt', 'lte', 'gte'], + "created_at": ["exact", "lt", "gt", "lte", "gte"], # 'lead_published_on': ['exact', 'lt', 'gt', 'lte', 'gte'], } filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda _: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda _: { + "lookup_expr": "icontains", }, }, } @@ -277,32 +287,32 @@ class Meta: class EntryCommentFilterSet(django_filters.FilterSet): class Meta: model = EntryComment - fields = ('created_by', 'is_resolved', 'resolved_at') + fields = ("created_by", "is_resolved", "resolved_at") def get_filtered_entries_using_af_filter( - entries, filters, queries, - project=None, new_query_structure=False, + entries, + filters, + queries, + project=None, + new_query_structure=False, ): queries = copy.deepcopy(queries) region_max_level = 0 if project: - region_max_level = project.regions\ - .annotate(adminlevel_count=models.Count('adminlevel'))\ - .aggregate(max_level=models.Max('adminlevel_count'))['max_level'] or 0 + region_max_level = ( + project.regions.annotate(adminlevel_count=models.Count("adminlevel")).aggregate( + max_level=models.Max("adminlevel_count") + )["max_level"] + or 0 + ) if isinstance(queries, list): - queries = { - q['filter_key']: q - for q in queries - } - elif 'filterable_data' in queries: + queries = {q["filter_key"]: q for q in queries} + elif "filterable_data" in queries: # XXX: Pass new structure. - queries = { - q['filter_key']: q - for q in queries['filterable_data'] - } + queries = {q["filter_key"]: q for q in queries["filterable_data"]} new_query_structure = True # NOTE: lets not use `.distinct()` in this function as it is used by a subquery in `lead/models.py`. @@ -316,19 +326,19 @@ def get_filtered_entries_using_af_filter( if not new_query_structure: value = queries.get(_filter.key) - value_lte = queries.get(_filter.key + '__lt') - value_gte = queries.get(_filter.key + '__gt') - value_and = queries.get(_filter.key + '__and') + value_lte = queries.get(_filter.key + "__lt") + value_gte = queries.get(_filter.key + "__gt") + value_and = queries.get(_filter.key + "__and") if value_and: value = value_and use_and_operator = True use_exclude = False - value_exclude = queries.get(_filter.key + '_exclude') + value_exclude = queries.get(_filter.key + "_exclude") if value_exclude: value = value_exclude use_and_operator = False use_exclude = True - value_exclude_and = queries.get(_filter.key + '_exclude_and') + value_exclude_and = queries.get(_filter.key + "_exclude_and") if value_exclude_and: value = value_exclude use_exclude = True @@ -339,16 +349,13 @@ def get_filtered_entries_using_af_filter( if not query: continue - value = query.get('value') - value_gte = query.get('value_gte') - value_lte = query.get('value_lte') - value_list = [ - v for v in query.get('value_list') or [] - if v is not None - ] - use_exclude = query.get('use_exclude') - use_and_operator = query.get('use_and_operator') - include_sub_regions = query.get('include_sub_regions') + value = query.get("value") + value_gte = query.get("value_gte") + value_lte = query.get("value_lte") + value_list = [v for v in query.get("value_list") or [] if v is not None] + use_exclude = query.get("use_exclude") + use_and_operator = query.get("use_and_operator") + include_sub_regions = query.get("include_sub_regions") if not any([value, value_gte, value_lte, value_list]): continue @@ -367,9 +374,9 @@ def get_filtered_entries_using_af_filter( Widget.WidgetType.TIME, Widget.WidgetType.TIME_RANGE, ]: - value = value and parse_time_str(value)['time_val'] - value_gte = value_gte and parse_time_str(value_gte)['time_val'] - value_lte = value_lte and parse_time_str(value_lte)['time_val'] + value = value and parse_time_str(value)["time_val"] + value_gte = value_gte and parse_time_str(value_gte)["time_val"] + value_lte = value_lte and parse_time_str(value_lte)["time_val"] if _filter.filter_type == Filter.FilterType.NUMBER: if value: @@ -405,32 +412,38 @@ def get_filtered_entries_using_af_filter( ) if value_lte and value_gte: - q = models.Q( - filterdata__from_number__lte=value_lte, - filterdata__to_number__gte=value_lte, - ) | models.Q( - filterdata__from_number__lte=value_gte, - filterdata__to_number__gte=value_gte, - ) | models.Q( - filterdata__from_number__gte=value_gte, - filterdata__to_number__lte=value_lte, + q = ( + models.Q( + filterdata__from_number__lte=value_lte, + filterdata__to_number__gte=value_lte, + ) + | models.Q( + filterdata__from_number__lte=value_gte, + filterdata__to_number__gte=value_gte, + ) + | models.Q( + filterdata__from_number__gte=value_gte, + filterdata__to_number__lte=value_lte, + ) ) entries = entries.filter(q, filterdata__filter=_filter) elif _filter.filter_type == Filter.FilterType.LIST: if value_list and not isinstance(value_list, list): - value_list = value_list.split(',') + value_list = value_list.split(",") if value_list: # Fetch sub-regions if required if region_max_level and include_sub_regions and _filter.widget_type == Widget.WidgetType.GEO: # XXX: simple values('id') doesn't work. Better way? - value_list = GeoArea.\ - get_sub_childrens(value_list, level=region_max_level)\ - .filter(admin_level__region__project=project)\ - .order_by().values('admin_level__region__project')\ - .annotate(ids=ArrayAgg('id'))\ - .values('ids') + value_list = ( + GeoArea.get_sub_childrens(value_list, level=region_max_level) + .filter(admin_level__region__project=project) + .order_by() + .values("admin_level__region__project") + .annotate(ids=ArrayAgg("id")) + .values("ids") + ) query_filter = models.Q( filterdata__filter=_filter, @@ -450,7 +463,7 @@ def get_filtered_entries_using_af_filter( else: entries = entries.filter(query_filter) - return entries.order_by('-lead__created_by', 'lead', 'created_by') + return entries.order_by("-lead__created_by", "lead", "created_by") def get_filtered_entries(user, queries): @@ -459,28 +472,28 @@ def get_filtered_entries(user, queries): entries = Entry.get_for(user) filters = Filter.get_for(user, with_widget_type=True) - project_id = queries.get('project') + project_id = queries.get("project") if project_id: entries = entries.filter(lead__project__id=project_id) filters = filters.filter(analysis_framework__project__id=project_id) - entries_id = queries.get('entries_id') + entries_id = queries.get("entries_id") if entries_id: entries = entries.filter(id__in=entries_id) - entry_type = queries.get('entry_type') + entry_type = queries.get("entry_type") if entry_type: entries = entries.filter(entry_type__in=entry_type) - lead_status = queries.get('lead_status') + lead_status = queries.get("lead_status") if lead_status: entries = entries.filter(lead__status__in=lead_status) - lead_priority = queries.get('lead_priority') + lead_priority = queries.get("lead_priority") if lead_priority: entries = entries.filter(lead__priority__in=lead_priority) - lead_confidentiality = queries.get('lead_confidentiality') + lead_confidentiality = queries.get("lead_confidentiality") if lead_confidentiality: entries = entries.filter(lead__confidentiality__in=lead_confidentiality) @@ -495,23 +508,23 @@ def get_filtered_entries(user, queries): def parse_date(val): try: - val = val.replace(':', '') - return datetime.strptime(val, '%Y-%m-%d%z') + val = val.replace(":", "") + return datetime.strptime(val, "%Y-%m-%d%z") except Exception: return None QUERY_MAP = { - 'created_at': parse_date, - 'created_at__gt': parse_date, - 'created_at__lt': parse_date, - 'created_at__gte': parse_date, - 'created_at__lte': parse_date, - 'lead__published_on': parse_date, - 'lead__published_on__gt': parse_date, - 'lead__published_on__lt': parse_date, - 'lead__published_on__gte': parse_date, - 'lead__published_on__lte': parse_date, + "created_at": parse_date, + "created_at__gt": parse_date, + "created_at__lt": parse_date, + "created_at__gte": parse_date, + "created_at__lte": parse_date, + "lead__published_on": parse_date, + "lead__published_on__gt": parse_date, + "lead__published_on__lt": parse_date, + "lead__published_on__gte": parse_date, + "lead__published_on__lte": parse_date, } @@ -548,63 +561,67 @@ class EntryFilterDataInputType(graphene.InputObjectType): - use_and_operator (Use AND to filter) - use_exclude (Exclude entry using filter value) """ + filter_key = graphene.ID(required=True) - value = graphene.String(description='Valid for single value widgets') - value_gte = graphene.String(description='Valid for range or single value widgets') - value_lte = graphene.String(description='Valid for range or single value widgets') - value_list = graphene.List(graphene.NonNull(graphene.String), description='Valid for list value widgets') - use_exclude = graphene.Boolean(description='Only for array values') - use_and_operator = graphene.Boolean(description='Used AND instead of OR') - include_sub_regions = graphene.Boolean(description='Only valid for GEO widget values') + value = graphene.String(description="Valid for single value widgets") + value_gte = graphene.String(description="Valid for range or single value widgets") + value_lte = graphene.String(description="Valid for range or single value widgets") + value_list = graphene.List(graphene.NonNull(graphene.String), description="Valid for list value widgets") + use_exclude = graphene.Boolean(description="Only for array values") + use_and_operator = graphene.Boolean(description="Used AND instead of OR") + include_sub_regions = graphene.Boolean(description="Only valid for GEO widget values") # ----------------------------- Graphql Filters --------------------------------------- class EntryGQFilterSet(GrapheneFilterSetMixin, UserResourceGqlFilterSet): # Lead fields - leads = IDListFilter(field_name='lead') - lead_created_by = IDListFilter(field_name='lead__created_by') + leads = IDListFilter(field_name="lead") + lead_created_by = IDListFilter(field_name="lead__created_by") lead_published_on = django_filters.DateFilter() - lead_published_on_gte = DateGteFilter(field_name='lead__published_on') - lead_published_on_lte = DateLteFilter(field_name='lead__published_on') - lead_title = django_filters.CharFilter(lookup_expr='icontains', field_name='lead__title') - lead_assignees = IDListFilter(label='Lead Assignees', field_name='lead__assignee') - lead_statuses = MultipleInputFilter(LeadStatusEnum, field_name='lead__status') - lead_priorities = MultipleInputFilter(LeadPriorityEnum, field_name='lead__priority') - lead_confidentialities = MultipleInputFilter(LeadConfidentialityEnum, field_name='lead__confidentiality') - lead_authoring_organization_types = IDListFilter(method='authoring_organization_types_filter') - lead_author_organizations = IDListFilter(field_name='lead__authors') - lead_source_organizations = IDListFilter(field_name='lead__source') - lead_has_assessment = django_filters.BooleanFilter(method='lead_has_assessment_filter', help_text='Lead has assessment.') - lead_is_assessment = django_filters.BooleanFilter(field_name='lead__is_assessment_lead') - - search = django_filters.CharFilter(method='search_filter') + lead_published_on_gte = DateGteFilter(field_name="lead__published_on") + lead_published_on_lte = DateLteFilter(field_name="lead__published_on") + lead_title = django_filters.CharFilter(lookup_expr="icontains", field_name="lead__title") + lead_assignees = IDListFilter(label="Lead Assignees", field_name="lead__assignee") + lead_statuses = MultipleInputFilter(LeadStatusEnum, field_name="lead__status") + lead_priorities = MultipleInputFilter(LeadPriorityEnum, field_name="lead__priority") + lead_confidentialities = MultipleInputFilter(LeadConfidentialityEnum, field_name="lead__confidentiality") + lead_authoring_organization_types = IDListFilter(method="authoring_organization_types_filter") + lead_author_organizations = IDListFilter(field_name="lead__authors") + lead_source_organizations = IDListFilter(field_name="lead__source") + lead_has_assessment = django_filters.BooleanFilter(method="lead_has_assessment_filter", help_text="Lead has assessment.") + lead_is_assessment = django_filters.BooleanFilter(field_name="lead__is_assessment_lead") + + search = django_filters.CharFilter(method="search_filter") created_by = IDListFilter() modified_by = IDListFilter() - entry_types = MultipleInputFilter(EntryTagTypeEnum, field_name='entry_type') - project_entry_labels = IDListFilter(label='Project Entry Labels', method='project_entry_labels_filter') - entries_id = IDListFilter(field_name='id') - geo_custom_shape = django_filters.CharFilter(label='GEO Custom Shapes', method='geo_custom_shape_filter') + entry_types = MultipleInputFilter(EntryTagTypeEnum, field_name="entry_type") + project_entry_labels = IDListFilter(label="Project Entry Labels", method="project_entry_labels_filter") + entries_id = IDListFilter(field_name="id") + geo_custom_shape = django_filters.CharFilter(label="GEO Custom Shapes", method="geo_custom_shape_filter") # Entry Group Label Filters - lead_group_label = django_filters.CharFilter(label='Lead Group Label', method='lead_group_label_filter') + lead_group_label = django_filters.CharFilter(label="Lead Group Label", method="lead_group_label_filter") # Dynamic filterable data - filterable_data = MultipleInputFilter(EntryFilterDataInputType, method='filterable_data_filter') - has_comment = django_filters.BooleanFilter(method='filter_commented_entries') - is_verified = django_filters.BooleanFilter(method='filter_verified_entries') + filterable_data = MultipleInputFilter(EntryFilterDataInputType, method="filterable_data_filter") + has_comment = django_filters.BooleanFilter(method="filter_commented_entries") + is_verified = django_filters.BooleanFilter(method="filter_verified_entries") class Meta: model = Entry fields = { **{ - x: ['exact'] for x in [ - 'id', 'excerpt', 'controlled', + x: ["exact"] + for x in [ + "id", + "excerpt", + "controlled", ] }, } filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda _: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda _: { + "lookup_expr": "icontains", }, }, } @@ -614,7 +631,7 @@ def filterable_data_filter(self, queryset, _, value): project = self.request and self.request.active_project if project is None or project.analysis_framework_id is None: # This needs to be defined - raise Exception(f'Both should be defined {project=} {project and project.analysis_framework_id=}') + raise Exception(f"Both should be defined {project=} {project and project.analysis_framework_id=}") filters = Filter.qs_with_widget_type().filter(analysis_framework_id=project.analysis_framework_id).all() return get_filtered_entries_using_af_filter(queryset, filters, value, project=project, new_query_structure=True) return queryset @@ -625,9 +642,10 @@ def geo_custom_shape_filter(self, queryset, name, value): lambda acc, item: acc | item, [ models.Q( - attribute__widget__widget_id='geoWidget', - attribute__data__value__contains=[{'type': v}], - ) for v in value.split(',') + attribute__widget__widget_id="geoWidget", + attribute__data__value__contains=[{"type": v}], + ) + for v in value.split(",") ], ) return queryset.filter(query_params) @@ -654,8 +672,7 @@ def authoring_organization_types_filter(self, qs, name, value): if value: qs = qs.annotate( organization_types=models.functions.Coalesce( - 'lead__authors__parent__organization_type', - 'lead__authors__organization_type' + "lead__authors__parent__organization_type", "lead__authors__organization_type" ) ) if isinstance(value[0], OrganizationType): @@ -692,26 +709,24 @@ def search_filter(self, qs, _, value): def qs(self): qs = super().qs # Note: Since we cannot have `.distinct()` inside a subquery - if self.data.get('from_subquery', False): + if self.data.get("from_subquery", False): return Entry.objects.filter(id__in=qs) return qs.distinct() def get_entry_filter_object_type(input_type): - new_fields_map = generate_object_field_from_input_type(input_type, skip_fields=['filterable_data']) - new_fields_map['filterable_data'] = graphene.List( - graphene.NonNull( - generate_simple_object_type_from_input_type(EntryFilterDataInputType) - ) + new_fields_map = generate_object_field_from_input_type(input_type, skip_fields=["filterable_data"]) + new_fields_map["filterable_data"] = graphene.List( + graphene.NonNull(generate_simple_object_type_from_input_type(EntryFilterDataInputType)) ) - new_type = type('EntriesFilterDataType', (graphene.ObjectType,), new_fields_map) + new_type = type("EntriesFilterDataType", (graphene.ObjectType,), new_fields_map) compare_input_output_type_fields(input_type, new_type) return new_type EntriesFilterDataInputType = type( - 'EntriesFilterDataInputType', + "EntriesFilterDataInputType", (graphene.InputObjectType,), - get_filtering_args_from_filterset(EntryGQFilterSet, 'entry.schema.EntryListType') + get_filtering_args_from_filterset(EntryGQFilterSet, "entry.schema.EntryListType"), ) EntriesFilterDataType = get_entry_filter_object_type(EntriesFilterDataInputType) diff --git a/apps/entry/models.py b/apps/entry/models.py index 1771b2798a..40594b43fd 100644 --- a/apps/entry/models.py +++ b/apps/entry/models.py @@ -1,24 +1,19 @@ +from analysis_framework.models import AnalysisFramework, Exportable, Filter, Widget +from assisted_tagging.models import DraftEntry from django.contrib.contenttypes.fields import GenericRelation from django.contrib.postgres.aggregates.general import ArrayAgg from django.contrib.postgres.fields import ArrayField from django.db import models - -from deep.middleware import get_current_user -from utils.common import parse_number +from gallery.models import File +from lead.models import Lead +from notification.models import Assignment from project.mixins import ProjectEntityMixin from project.permissions import PROJECT_PERMISSIONS -from gallery.models import File from user.models import User from user_resource.models import UserResource -from lead.models import Lead -from notification.models import Assignment -from analysis_framework.models import ( - AnalysisFramework, - Widget, - Filter, - Exportable, -) -from assisted_tagging.models import DraftEntry + +from deep.middleware import get_current_user +from utils.common import parse_number class Entry(UserResource, ProjectEntityMixin): @@ -30,15 +25,22 @@ class Entry(UserResource, ProjectEntityMixin): """ class TagType(models.TextChoices): - EXCERPT = 'excerpt', 'Excerpt', - IMAGE = 'image', 'Image', - DATA_SERIES = 'dataSeries', 'Data Series' # NOTE: data saved as tabular_field id + EXCERPT = ( + "excerpt", + "Excerpt", + ) + IMAGE = ( + "image", + "Image", + ) + DATA_SERIES = "dataSeries", "Data Series" # NOTE: data saved as tabular_field id lead = models.ForeignKey(Lead, on_delete=models.CASCADE) - project = models.ForeignKey('project.Project', on_delete=models.CASCADE) + project = models.ForeignKey("project.Project", on_delete=models.CASCADE) order = models.IntegerField(default=1) analysis_framework = models.ForeignKey( - AnalysisFramework, on_delete=models.CASCADE, + AnalysisFramework, + on_delete=models.CASCADE, ) information_date = models.DateField(default=None, null=True, blank=True) @@ -46,7 +48,7 @@ class TagType(models.TextChoices): excerpt = models.TextField(blank=True) image = models.ForeignKey(File, on_delete=models.SET_NULL, null=True, blank=True) image_raw = models.TextField(blank=True) - tabular_field = models.ForeignKey('tabular.Field', on_delete=models.CASCADE, null=True, blank=True) + tabular_field = models.ForeignKey("tabular.Field", on_delete=models.CASCADE, null=True, blank=True) dropped_excerpt = models.TextField(blank=True) # NOTE: Original Exceprt. Modified version is stored in excerpt excerpt_modified = models.BooleanField(default=False) @@ -54,9 +56,7 @@ class TagType(models.TextChoices): # NOTE: verification is also called controlled in QA controlled = models.BooleanField(default=False, blank=True, null=True) - controlled_changed_by = models.ForeignKey( - User, blank=True, null=True, - related_name='+', on_delete=models.SET_NULL) + controlled_changed_by = models.ForeignKey(User, blank=True, null=True, related_name="+", on_delete=models.SET_NULL) # NOTE: verified_by is related to review comment verified_by = models.ManyToManyField(User, blank=True) draft_entry = models.ForeignKey(DraftEntry, on_delete=models.SET_NULL, null=True, blank=True) @@ -72,25 +72,20 @@ def annotate_comment_count(cls, qs): def _count_subquery(subquery): return models.functions.Coalesce( models.Subquery( - subquery.values('entry').order_by().annotate( - count=models.Count('id', distinct=True) - ).values('count')[:1], - output_field=models.IntegerField() - ), 0, + subquery.values("entry").order_by().annotate(count=models.Count("id", distinct=True)).values("count")[:1], + output_field=models.IntegerField(), + ), + 0, ) current_user = get_current_user() - verified_by_qs = cls.verified_by.through.objects.filter(entry=models.OuterRef('pk')) - entrycomment_qs = EntryComment.objects.filter(entry=models.OuterRef('pk'), parent=None) + verified_by_qs = cls.verified_by.through.objects.filter(entry=models.OuterRef("pk")) + entrycomment_qs = EntryComment.objects.filter(entry=models.OuterRef("pk"), parent=None) return qs.annotate( verified_by_count=_count_subquery(verified_by_qs), is_verified_by_current_user=models.Exists(verified_by_qs.filter(user=current_user)), - resolved_comment_count=_count_subquery( - entrycomment_qs.filter(is_resolved=True) - ), - unresolved_comment_count=_count_subquery( - entrycomment_qs.filter(is_resolved=False) - ), + resolved_comment_count=_count_subquery(entrycomment_qs.filter(is_resolved=True)), + unresolved_comment_count=_count_subquery(entrycomment_qs.filter(is_resolved=False)), ) def __init__(self, *args, **kwargs): @@ -99,7 +94,7 @@ def __init__(self, *args, **kwargs): def __str__(self): if self.entry_type == Entry.TagType.IMAGE: - return 'Image ({})'.format(self.lead.title) + return "Image ({})".format(self.lead.title) else: return '"{}" ({})'.format( self.excerpt[:30], @@ -111,14 +106,14 @@ def save(self, *args, **kwargs): super().save(*args, **kwargs) def get_image_url(self): - if hasattr(self, 'image_url'): + if hasattr(self, "image_url"): return self.image_url gallery_file = None if self.image: gallery_file = self.image elif self.image_raw: - fileid = parse_number(self.image_raw.rstrip('/').split('/')[-1]) # remove last slash if present + fileid = parse_number(self.image_raw.rstrip("/").split("/")[-1]) # remove last slash if present if fileid: gallery_file = File.objects.filter(id=fileid).first() self.image_url = gallery_file and gallery_file.get_file_url() @@ -134,26 +129,29 @@ def get_for(cls, user): # NOTE: This is quite complicated because user can have two view roles: # view entry or view_only_unprotected, both of which return different results - qs = cls.objects.filter( - project__projectmembership__member=user, - ).annotate( - # Get permission value for view_only_unprotected permission - view_unprotected=models.F( - 'project__projectmembership__role__entry_permissions' - ).bitand(view_unprotected_perm_value), - # Get permission value for view permission - view_all=models.F( - 'project__projectmembership__role__entry_permissions' - ).bitand(view_perm_value) - ).filter( - # If entry is view only unprotected, filter entries with - # lead confidentiality not confidential - ( - models.Q(view_unprotected=view_unprotected_perm_value) & - ~models.Q(lead__confidentiality=Lead.Confidentiality.CONFIDENTIAL) - ) | - # Or, return nothing if view_all is not present - models.Q(view_all=view_perm_value) + qs = ( + cls.objects.filter( + project__projectmembership__member=user, + ) + .annotate( + # Get permission value for view_only_unprotected permission + view_unprotected=models.F("project__projectmembership__role__entry_permissions").bitand( + view_unprotected_perm_value + ), + # Get permission value for view permission + view_all=models.F("project__projectmembership__role__entry_permissions").bitand(view_perm_value), + ) + .filter( + # If entry is view only unprotected, filter entries with + # lead confidentiality not confidential + ( + models.Q(view_unprotected=view_unprotected_perm_value) + & ~models.Q(lead__confidentiality=Lead.Confidentiality.CONFIDENTIAL) + ) + | + # Or, return nothing if view_all is not present + models.Q(view_all=view_perm_value) + ) ) return qs @@ -164,26 +162,24 @@ def get_exportable_queryset(cls, qs): return qs.annotate( # Get permission value for create_only_unprotected export - create_only_unprotected=models.F( - 'project__projectmembership__role__export_permissions' - ).bitand(export_unprotected_perm_value), + create_only_unprotected=models.F("project__projectmembership__role__export_permissions").bitand( + export_unprotected_perm_value + ), # Get permission value for create permission - create_all=models.F( - 'project__projectmembership__role__export_permissions' - ).bitand(export_perm_value) + create_all=models.F("project__projectmembership__role__export_permissions").bitand(export_perm_value), ).filter( # Priority given to create_only_unprotected export permission i.e. # if create_only_unprotected is true, then fetch non confidential entries ( - models.Q(create_only_unprotected=export_unprotected_perm_value) & - ~models.Q(lead__confidentiality=Lead.Confidentiality.CONFIDENTIAL) - ) | - models.Q(create_all=export_perm_value) + models.Q(create_only_unprotected=export_unprotected_perm_value) + & ~models.Q(lead__confidentiality=Lead.Confidentiality.CONFIDENTIAL) + ) + | models.Q(create_all=export_perm_value) ) class Meta(UserResource.Meta): - verbose_name_plural = 'entries' - ordering = ['order', '-created_at'] + verbose_name_plural = "entries" + ordering = ["order", "-created_at"] class Attribute(models.Model): @@ -193,6 +189,7 @@ class Attribute(models.Model): Note that attributes are set by widgets and has the reference for that widget. """ + entry = models.ForeignKey(Entry, on_delete=models.CASCADE) widget = models.ForeignKey(Widget, on_delete=models.CASCADE) # Widget's version when the attribute was saved (Set by client) @@ -202,10 +199,11 @@ class Attribute(models.Model): def save(self, *args, **kwargs): super().save(*args, **kwargs) from .utils import update_entry_attribute + update_entry_attribute(self) def __str__(self): - return 'Attribute ({}, {})'.format( + return "Attribute ({}, {})".format( self.entry.lead.title, self.widget.title, ) @@ -217,8 +215,7 @@ def get_for(user): it's entry """ return Attribute.objects.filter( - models.Q(entry__lead__project__members=user) | - models.Q(entry__lead__project__user_groups__members=user) + models.Q(entry__lead__project__members=user) | models.Q(entry__lead__project__user_groups__members=user) ).distinct() def can_get(self, user): @@ -232,13 +229,16 @@ class FilterData(models.Model): """ Filter data for an entry to use for filterting """ + entry = models.ForeignKey(Entry, on_delete=models.CASCADE) filter = models.ForeignKey(Filter, on_delete=models.CASCADE) # List of text values values = ArrayField( models.CharField(max_length=100, blank=True), - default=None, blank=True, null=True, + default=None, + blank=True, + null=True, ) # Just number for numeric comparision @@ -256,8 +256,7 @@ def get_for(user): it's entry """ return FilterData.objects.filter( - models.Q(entry__lead__project__members=user) | - models.Q(entry__lead__project__user_groups__members=user) + models.Q(entry__lead__project__members=user) | models.Q(entry__lead__project__user_groups__members=user) ).distinct() def can_get(self, user): @@ -267,7 +266,7 @@ def can_modify(self, user): return self.entry.can_modify(user) def __str__(self): - return 'Filter data ({}, {})'.format( + return "Filter data ({}, {})".format( self.entry.lead.title, self.filter.title, ) @@ -277,6 +276,7 @@ class ExportData(models.Model): """ Export data for an entry """ + entry = models.ForeignKey(Entry, on_delete=models.CASCADE) exportable = models.ForeignKey(Exportable, on_delete=models.CASCADE) data = models.JSONField(default=None, blank=True, null=True) @@ -287,11 +287,11 @@ def get_for(user): Export data can only be accessed by users who have access to it's entry """ - return ExportData.objects.select_related('entry__lead__project')\ - .filter( - models.Q(entry__lead__project__members=user) | - models.Q(entry__lead__project__user_groups__members=user))\ + return ( + ExportData.objects.select_related("entry__lead__project") + .filter(models.Q(entry__lead__project__members=user) | models.Q(entry__lead__project__user_groups__members=user)) .distinct() + ) def can_get(self, user): return self.entry.can_get(user) @@ -300,7 +300,7 @@ def can_modify(self, user): return self.entry.can_modify(user) def __str__(self): - return 'Export data ({}, {})'.format( + return "Export data ({}, {})".format( self.entry.lead.title, self.exportable.widget_key, ) @@ -308,18 +308,20 @@ def __str__(self): class EntryComment(models.Model): entry = models.ForeignKey(Entry, on_delete=models.CASCADE) - created_by = models.ForeignKey(User, related_name='%(class)s_created', on_delete=models.CASCADE) + created_by = models.ForeignKey(User, related_name="%(class)s_created", on_delete=models.CASCADE) assignees = models.ManyToManyField(User, blank=True) is_resolved = models.BooleanField(null=True, blank=True, default=False) resolved_at = models.DateTimeField(null=True, blank=True) parent = models.ForeignKey( - 'EntryComment', - null=True, blank=True, on_delete=models.CASCADE, + "EntryComment", + null=True, + blank=True, + on_delete=models.CASCADE, ) - assignments = GenericRelation(Assignment, related_query_name='entry_comment') + assignments = GenericRelation(Assignment, related_query_name="entry_comment") def __str__(self): - return f'{self.entry}: {self.text} (Resolved: {self.is_resolved})' + return f"{self.entry}: {self.text} (Resolved: {self.is_resolved})" def can_delete(self, user): return self.can_modify(user) @@ -329,11 +331,11 @@ def can_modify(self, user): @staticmethod def get_for(user): - return EntryComment.objects.prefetch_related('entrycommenttext_set')\ - .filter( - models.Q(entry__lead__project__members=user) | - models.Q(entry__lead__project__user_groups__members=user))\ + return ( + EntryComment.objects.prefetch_related("entrycommenttext_set") + .filter(models.Q(entry__lead__project__members=user) | models.Q(entry__lead__project__user_groups__members=user)) .distinct() + ) @property def text(self): @@ -342,10 +344,10 @@ def text(self): return comment_text.text def get_related_users(self, skip_owner_user=True): - users = list(self.entrycomment_set.values_list('created_by', flat=True)) - users.extend(self.assignees.values_list('id', flat=True)) + users = list(self.entrycomment_set.values_list("created_by", flat=True)) + users.extend(self.assignees.values_list("id", flat=True)) if self.parent: - users.extend(self.parent.assignees.values_list('id', flat=True)) + users.extend(self.parent.assignees.values_list("id", flat=True)) users.append(self.parent.created_by_id) queryset = User.objects.filter(pk__in=set(users)) if skip_owner_user: @@ -364,13 +366,14 @@ class ProjectEntryLabel(UserResource): """ Labels defined for entries in Project Scope """ - project = models.ForeignKey('project.Project', on_delete=models.CASCADE) + + project = models.ForeignKey("project.Project", on_delete=models.CASCADE) title = models.CharField(max_length=225) order = models.IntegerField(default=1) color = models.CharField(max_length=255, null=True, blank=True) def __str__(self): - return f'{self.project}: {self.title}' + return f"{self.project}: {self.title}" def can_modify(self, user): return self.project.can_modify(user) @@ -380,12 +383,13 @@ class LeadEntryGroup(UserResource): """ Groups defined for entries in Lead Scope """ + lead = models.ForeignKey(Lead, on_delete=models.CASCADE) title = models.CharField(max_length=225) order = models.IntegerField(default=1) def __str__(self): - return f'{self.lead}: {self.title}' + return f"{self.lead}: {self.title}" def can_modify(self, user): return self.lead.can_modify(user) @@ -395,6 +399,7 @@ class EntryGroupLabel(UserResource): """ Relation between Groups, Labels and Entries """ + label = models.ForeignKey(ProjectEntryLabel, on_delete=models.CASCADE) group = models.ForeignKey(LeadEntryGroup, on_delete=models.CASCADE) entry = models.ForeignKey(Entry, on_delete=models.CASCADE) @@ -402,22 +407,32 @@ class EntryGroupLabel(UserResource): class Meta: # Only single entry allowd in label:group pair - unique_together = ('label', 'group',) + unique_together = ( + "label", + "group", + ) @staticmethod def get_stat_for_entry(qs): - return qs.order_by().values('entry', 'label').annotate( - count=models.Count('id'), - groups=ArrayAgg('group__title'), - ).values( - 'entry', 'count', 'groups', - label_id=models.F('label__id'), - label_color=models.F('label__color'), - label_title=models.F('label__title') + return ( + qs.order_by() + .values("entry", "label") + .annotate( + count=models.Count("id"), + groups=ArrayAgg("group__title"), + ) + .values( + "entry", + "count", + "groups", + label_id=models.F("label__id"), + label_color=models.F("label__color"), + label_title=models.F("label__title"), + ) ) def __str__(self): - return f'[{self.label}]:{self.group} -> {self.entry}' + return f"[{self.label}]:{self.group} -> {self.entry}" def can_modify(self, user): return self.entry.can_modify(user) diff --git a/apps/entry/mutation.py b/apps/entry/mutation.py index 8fa8415a22..ebaa084a01 100644 --- a/apps/entry/mutation.py +++ b/apps/entry/mutation.py @@ -1,27 +1,24 @@ import graphene +from deep.permissions import ProjectPermissions as PP from utils.graphene.mutation import ( - generate_input_type_for_serializer, - PsGrapheneMutation, PsBulkGrapheneMutation, PsDeleteMutation, + PsGrapheneMutation, + generate_input_type_for_serializer, ) -from deep.permissions import ProjectPermissions as PP from .models import Entry from .schema import EntryType -from .serializers import ( - EntryGqSerializer as EntrySerializer, -) - +from .serializers import EntryGqSerializer as EntrySerializer EntryInputType = generate_input_type_for_serializer( - 'EntryInputType', + "EntryInputType", serializer_class=EntrySerializer, ) -class EntryMutationMixin(): +class EntryMutationMixin: @classmethod def filter_queryset(cls, qs, info): return qs.filter( @@ -35,6 +32,7 @@ def filter_queryset(cls, qs, info): class CreateEntry(EntryMutationMixin, PsGrapheneMutation): class Arguments: data = EntryInputType(required=True) + model = Entry serializer_class = EntrySerializer result = graphene.Field(EntryType) @@ -45,6 +43,7 @@ class UpdateEntry(EntryMutationMixin, PsGrapheneMutation): class Arguments: data = EntryInputType(required=True) id = graphene.ID(required=True) + model = Entry serializer_class = EntrySerializer result = graphene.Field(EntryType) @@ -54,6 +53,7 @@ class Arguments: class DeleteEntry(EntryMutationMixin, PsDeleteMutation): class Arguments: id = graphene.ID(required=True) + model = Entry result = graphene.Field(EntryType) permissions = [PP.Permission.DELETE_ENTRY] @@ -76,7 +76,7 @@ class Arguments: permissions = [PP.Permission.CREATE_ENTRY] -class Mutation(): +class Mutation: entry_create = CreateEntry.Field() entry_update = UpdateEntry.Field() entry_delete = DeleteEntry.Field() diff --git a/apps/entry/receivers.py b/apps/entry/receivers.py index 5fd6d69652..2e3f65d792 100644 --- a/apps/entry/receivers.py +++ b/apps/entry/receivers.py @@ -1,7 +1,7 @@ from django.db import models from django.dispatch import receiver - from lead.models import Lead + from .models import Entry @@ -10,4 +10,4 @@ def update_lead_status(sender, instance, created, **kwargs): lead = instance.lead if lead.status == Lead.Status.NOT_TAGGED: lead.status = Lead.Status.IN_PROGRESS - lead.save(update_fields=['status']) + lead.save(update_fields=["status"]) diff --git a/apps/entry/schema.py b/apps/entry/schema.py index f9a369aa9b..402ca43cc9 100644 --- a/apps/entry/schema.py +++ b/apps/entry/schema.py @@ -1,28 +1,23 @@ import graphene - +from analysis_framework.enums import WidgetWidgetTypeEnum +from analysis_framework.models import Widget from django.db.models import QuerySet +from geo.schema import ProjectGeoAreaType from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField, PageGraphqlPagination - -from utils.common import has_prefetched -from utils.graphene.enums import EnumDescription -from utils.graphene.types import CustomDjangoListObjectType, ClientIdMixin -from utils.graphene.fields import DjangoPaginatedListObjectField, DjangoListField -from user_resource.schema import UserResourceMixin -from deep.permissions import ProjectPermissions as PP from lead.models import Lead from user.schema import UserType +from user_resource.schema import UserResourceMixin -from analysis_framework.models import Widget -from analysis_framework.enums import WidgetWidgetTypeEnum -from geo.schema import ProjectGeoAreaType +from deep.permissions import ProjectPermissions as PP +from utils.common import has_prefetched +from utils.graphene.enums import EnumDescription +from utils.graphene.fields import DjangoListField, DjangoPaginatedListObjectField +from utils.graphene.types import ClientIdMixin, CustomDjangoListObjectType -from .models import ( - Entry, - Attribute, -) from .enums import EntryTagTypeEnum from .filter_set import EntryGQFilterSet +from .models import Attribute, Entry def get_entry_qs(info): @@ -45,6 +40,7 @@ class EntryGroupLabelType(graphene.ObjectType): """ NOTE: Data is generated from entry_project_labels [EntryProjectLabelsLoader] """ + label_id = graphene.ID(required=True) label_title = graphene.String(required=True) label_color = graphene.String() @@ -58,13 +54,15 @@ class Meta: model = Attribute skip_registry = True only_fields = ( - 'id', 'data', 'widget_version', + "id", + "data", + "widget_version", ) widget = graphene.ID(required=True) widget_version = graphene.Int(required=True) widget_type = graphene.Field(WidgetWidgetTypeEnum, required=True) - widget_type_display = EnumDescription(source='get_widget_type', required=True) + widget_type_display = EnumDescription(source="get_widget_type", required=True) # NOTE: This requires region_title and admin_level_title to be annotated geo_selected_options = graphene.List(graphene.NonNull(ProjectGeoAreaType)) @@ -78,25 +76,32 @@ def resolve_widget_type(root, info, **_): @staticmethod def resolve_geo_selected_options(root, info, **_): - if root.widget_type == Widget.WidgetType.GEO and root.data and root.data.get('value'): - return info.context.dl.entry.attribute_geo_selected_options.load( - tuple(root.data['value']) # needs to be hashable - ) + if root.widget_type == Widget.WidgetType.GEO and root.data and root.data.get("value"): + return info.context.dl.entry.attribute_geo_selected_options.load(tuple(root.data["value"])) # needs to be hashable class EntryType(UserResourceMixin, ClientIdMixin, DjangoObjectType): class Meta: model = Entry only_fields = ( - 'id', - 'lead', 'project', 'analysis_framework', 'information_date', 'order', - 'excerpt', 'dropped_excerpt', 'image', 'tabular_field', 'highlight_hidden', - 'controlled', 'controlled_changed_by', - 'client_id', + "id", + "lead", + "project", + "analysis_framework", + "information_date", + "order", + "excerpt", + "dropped_excerpt", + "image", + "tabular_field", + "highlight_hidden", + "controlled", + "controlled_changed_by", + "client_id", ) entry_type = graphene.Field(EntryTagTypeEnum, required=True) - entry_type_display = EnumDescription(source='get_entry_type_display', required=True) + entry_type_display = EnumDescription(source="get_entry_type_display", required=True) attributes = graphene.List(graphene.NonNull(AttributeType)) project_labels = graphene.List(graphene.NonNull(EntryGroupLabelType)) verified_by = DjangoListField(UserType) @@ -126,14 +131,14 @@ def resolve_review_comments_count(root, info, **_): @staticmethod def resolve_verified_by(root, info, **_): # Use cache if available - if has_prefetched(root, 'verified_by'): + if has_prefetched(root, "verified_by"): return root.verified_by.all() return info.context.dl.entry.verified_by.load(root.pk) @staticmethod def resolve_verified_by_count(root, info, **_): # Use cache if available - if has_prefetched(root, 'verified_by'): + if has_prefetched(root, "verified_by"): return len(root.verified_by.all()) return info.context.dl.entry.verified_by_count.load(root.pk) @@ -146,12 +151,7 @@ class Meta: class Query: entry = DjangoObjectField(EntryType) - entries = DjangoPaginatedListObjectField( - EntryListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) - ) + entries = DjangoPaginatedListObjectField(EntryListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize")) @staticmethod def resolve_entries(root, info, **_) -> QuerySet: diff --git a/apps/entry/serializers.py b/apps/entry/serializers.py index 8bef07f866..8a5da98be2 100644 --- a/apps/entry/serializers.py +++ b/apps/entry/serializers.py @@ -1,116 +1,129 @@ import logging +from analysis_framework.serializers import AnalysisFrameworkSerializer from drf_dynamic_fields import DynamicFieldsMixin from drf_writable_nested.mixins import UniqueFieldsMixin +from gallery.models import File +from gallery.serializers import FileSerializer, SimpleFileSerializer +from geo.models import GeoArea, Region +from geo.serializers import SimpleRegionSerializer +from lead.models import Lead, LeadPreviewImage +from lead.serializers import LeadSerializer +from organization.serializers import SimpleOrganizationSerializer +from project.models import Project, ProjectMembership from rest_framework import serializers +from tabular.serializers import FieldProcessedOnlySerializer +from user.serializers import ( + ComprehensiveUserSerializer, + EntryCommentUserSerializer, + SimpleUserSerializer, +) +from user_resource.serializers import ( + DeprecatedUserResourceSerializer, + UserResourceSerializer, +) -from deep.writable_nested_serializers import ListToDictField from deep.serializers import ( IntegerIDField, ProjectPropertySerializerMixin, RemoveNullFieldsMixin, TempClientIdMixin, ) -from organization.serializers import SimpleOrganizationSerializer -from user_resource.serializers import UserResourceSerializer, DeprecatedUserResourceSerializer -from gallery.models import File -from gallery.serializers import FileSerializer, SimpleFileSerializer -from project.models import Project -from lead.serializers import LeadSerializer -from lead.models import Lead, LeadPreviewImage -from analysis_framework.serializers import AnalysisFrameworkSerializer -from geo.models import GeoArea, Region -from geo.serializers import SimpleRegionSerializer -from tabular.serializers import FieldProcessedOnlySerializer -from user.serializers import EntryCommentUserSerializer, ComprehensiveUserSerializer, SimpleUserSerializer -from project.models import ProjectMembership +from deep.writable_nested_serializers import ListToDictField -from .widgets.store import widget_store -from .models import ( +from .models import ( # Entry Grouping Attribute, Entry, EntryComment, EntryCommentText, + EntryGroupLabel, ExportData, FilterData, - # Entry Grouping - ProjectEntryLabel, LeadEntryGroup, - EntryGroupLabel, + ProjectEntryLabel, ) from .utils import base64_to_deep_image +from .widgets.store import widget_store logger = logging.getLogger(__name__) -class AttributeSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, serializers.ModelSerializer): +class AttributeSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): class Meta: model = Attribute - fields = '__all__' + fields = "__all__" # Validations def validate_entry(self, entry): - if not entry.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid Entry') + if not entry.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid Entry") return entry -class FilterDataSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, serializers.ModelSerializer): +class FilterDataSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): class Meta: model = FilterData - fields = '__all__' + fields = "__all__" # Validations def validate_entry(self, entry): - if not entry.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid Entry') + if not entry.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid Entry") return entry -class ExportDataSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, serializers.ModelSerializer): +class ExportDataSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): class Meta: model = ExportData - fields = '__all__' + fields = "__all__" # Validations def validate_entry(self, entry): - if not entry.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid Entry') + if not entry.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid Entry") return entry class SimpleAttributeSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): class Meta: model = Attribute - fields = ('id', 'data', 'widget',) + fields = ( + "id", + "data", + "widget", + ) -class SimpleFilterDataSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): +class SimpleFilterDataSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): class Meta: model = FilterData - fields = ('id', 'filter', 'values', 'number',) + fields = ( + "id", + "filter", + "values", + "number", + ) -class SimpleExportDataSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): +class SimpleExportDataSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): class Meta: model = ExportData - fields = ('id', 'exportable', 'data',) + fields = ( + "id", + "exportable", + "data", + ) class ProjectEntryLabelSerializer(DynamicFieldsMixin, UserResourceSerializer): class Meta: model = ProjectEntryLabel - fields = '__all__' - read_only_fields = ('project',) + fields = "__all__" + read_only_fields = ("project",) def validate(self, data): - data['project_id'] = int(self.context['view'].kwargs['project_id']) + data["project_id"] = int(self.context["view"].kwargs["project_id"]) return data @@ -122,35 +135,35 @@ class EntryGroupLabelSerializer(UniqueFieldsMixin, UserResourceSerializer): group_id = serializers.PrimaryKeyRelatedField(read_only=True) label_id = serializers.PrimaryKeyRelatedField(queryset=ProjectEntryLabel.objects.all()) entry_id = serializers.PrimaryKeyRelatedField(queryset=Entry.objects.all()) - entry_client_id = serializers.CharField(source='entry.client_id', read_only=True) + entry_client_id = serializers.CharField(source="entry.client_id", read_only=True) def validate(self, data): - data['label'] = data.pop('label_id') - data['entry'] = data.pop('entry_id') + data["label"] = data.pop("label_id") + data["entry"] = data.pop("entry_id") return data class Meta: model = EntryGroupLabel - fields = ('id', 'label_id', 'group_id', 'entry_id', 'entry_client_id') + fields = ("id", "label_id", "group_id", "entry_id", "entry_client_id") class LeadEntryGroupSerializer(UserResourceSerializer): - selections = EntryGroupLabelSerializer(source='entrygrouplabel_set', many=True) + selections = EntryGroupLabelSerializer(source="entrygrouplabel_set", many=True) class Meta: model = LeadEntryGroup - fields = '__all__' - read_only_fields = ('lead',) + fields = "__all__" + read_only_fields = ("lead",) def validate(self, data): - data['lead_id'] = int(self.context['view'].kwargs['lead_id']) + data["lead_id"] = int(self.context["view"].kwargs["lead_id"]) # Custom validation check (It is disabled in EntryGroupLabelSerializer because of nested serializer issue) selections_labels = [] - for selection in self.initial_data['selections']: - selections_label = selection['label_id'] + for selection in self.initial_data["selections"]: + selections_label = selection["label_id"] if selections_label in selections_labels: - raise serializers.ValidationError('Only one entry is allowed for [Group, Label] set') + raise serializers.ValidationError("Only one entry is allowed for [Group, Label] set") selections_labels.append(selections_label) return data @@ -160,82 +173,86 @@ class EntryLeadSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): attachment = FileSerializer(read_only=True) tabular_book = serializers.SerializerMethodField() - assignee_details = SimpleUserSerializer(source='get_assignee', read_only=True) - authors_detail = SimpleOrganizationSerializer(source='authors', many=True, read_only=True) - source_detail = SimpleOrganizationSerializer(source='source', read_only=True) - confidentiality_display = serializers.CharField(source='get_confidentiality_display', read_only=True) - created_by_details = SimpleUserSerializer(source='get_created_by', read_only=True) + assignee_details = SimpleUserSerializer(source="get_assignee", read_only=True) + authors_detail = SimpleOrganizationSerializer(source="authors", many=True, read_only=True) + source_detail = SimpleOrganizationSerializer(source="source", read_only=True) + confidentiality_display = serializers.CharField(source="get_confidentiality_display", read_only=True) + created_by_details = SimpleUserSerializer(source="get_created_by", read_only=True) page_count = serializers.IntegerField( - source='leadpreview.page_count', + source="leadpreview.page_count", read_only=True, ) class Meta: model = Lead fields = ( - 'id', 'title', 'created_at', 'url', 'attachment', 'tabular_book', - 'client_id', 'assignee', 'assignee_details', 'published_on', - 'authors_detail', 'source_detail', 'confidentiality_display', - 'created_by_details', 'page_count', 'confidentiality', + "id", + "title", + "created_at", + "url", + "attachment", + "tabular_book", + "client_id", + "assignee", + "assignee_details", + "published_on", + "authors_detail", + "source_detail", + "confidentiality_display", + "created_by_details", + "page_count", + "confidentiality", ) def get_tabular_book(self, obj): file = obj.attachment - if file and hasattr(file, 'book'): + if file and hasattr(file, "book"): return file.book.id return None -class EntrySerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, DeprecatedUserResourceSerializer): +class EntrySerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, DeprecatedUserResourceSerializer): attributes = ListToDictField( child=SimpleAttributeSerializer(many=True), - key='widget', - source='attribute_set', + key="widget", + source="attribute_set", required=False, ) - project = serializers.PrimaryKeyRelatedField( - required=False, - queryset=Project.objects.all() - ) + project = serializers.PrimaryKeyRelatedField(required=False, queryset=Project.objects.all()) resolved_comment_count = serializers.SerializerMethodField() unresolved_comment_count = serializers.SerializerMethodField() project_labels = serializers.SerializerMethodField() controlled_changed_by_details = SimpleUserSerializer( - source='controlled_changed_by', + source="controlled_changed_by", read_only=True, ) - image_details = SimpleFileSerializer(source='image', read_only=True) - lead_image = serializers.PrimaryKeyRelatedField( - required=False, - write_only=True, - queryset=LeadPreviewImage.objects.all() - ) + image_details = SimpleFileSerializer(source="image", read_only=True) + lead_image = serializers.PrimaryKeyRelatedField(required=False, write_only=True, queryset=LeadPreviewImage.objects.all()) # NOTE: Provided by annotate `annotate_comment_count` verified_by_count = serializers.IntegerField(read_only=True) is_verified_by_current_user = serializers.BooleanField(read_only=True) class Meta: model = Entry - fields = '__all__' + fields = "__all__" def get_project_labels(self, entry): # Should be provided from view - label_count = self.context.get('entry_group_label_count') + label_count = self.context.get("entry_group_label_count") if label_count is not None: return label_count.get(entry.pk) or [] # Fallback return EntryGroupLabel.get_stat_for_entry(entry.entrygrouplabel_set) def get_resolved_comment_count(self, entry): - if hasattr(entry, 'resolved_comment_count'): + if hasattr(entry, "resolved_comment_count"): return entry.resolved_comment_count return entry.entrycomment_set.filter(parent=None, is_resolved=True).count() def get_unresolved_comment_count(self, entry): - if hasattr(entry, 'unresolved_comment_count'): + if hasattr(entry, "unresolved_comment_count"): return entry.unresolved_comment_count return entry.entrycomment_set.filter(parent=None, is_resolved=False).count() @@ -244,30 +261,34 @@ def validate(self, data): - Lead image is copied to deep gallery files - Raw image (base64) are saved as deep gallery files """ - request = self.context['request'] - lead = data.get('lead') or (self.instance and self.instance.lead) - image = data.get('image') - image_raw = data.pop('image_raw', None) - lead_image = data.pop('lead_image', None) + request = self.context["request"] + lead = data.get("lead") or (self.instance and self.instance.lead) + image = data.get("image") + image_raw = data.pop("image_raw", None) + lead_image = data.pop("lead_image", None) # If gallery file is provided make sure user owns the file if image: if ( - (self.instance and self.instance.image) != image and - not image.is_public and - image.created_by != self.context['request'].user + (self.instance and self.instance.image) != image + and not image.is_public + and image.created_by != self.context["request"].user ): - raise serializers.ValidationError({ - 'image': f'You don\'t have permission to attach image: {image}', - }) + raise serializers.ValidationError( + { + "image": f"You don't have permission to attach image: {image}", + } + ) return data # If lead image is provided make sure lead are same elif lead_image: if lead_image.lead != lead: - raise serializers.ValidationError({ - 'lead_image': f'You don\'t have permission to attach lead image: {lead_image}', - }) - data['image'] = lead_image.clone_as_deep_file(request.user) + raise serializers.ValidationError( + { + "lead_image": f"You don't have permission to attach lead image: {lead_image}", + } + ) + data["image"] = lead_image.clone_as_deep_file(request.user) elif image_raw: generated_image = base64_to_deep_image( image_raw, @@ -275,12 +296,12 @@ def validate(self, data): request.user, ) if isinstance(generated_image, File): - data['image'] = generated_image + data["image"] = generated_image return data def create(self, validated_data): - if validated_data.get('project') is None: - validated_data['project'] = validated_data['lead'].project + if validated_data.get("project") is None: + validated_data["project"] = validated_data["lead"].project return super().create(validated_data) @@ -290,7 +311,7 @@ def update(self, instance, validated_data): class EntryProccesedSerializer(EntrySerializer): - tabular_field_data = FieldProcessedOnlySerializer(source='tabular_field') + tabular_field_data = FieldProcessedOnlySerializer(source="tabular_field") class EntryRetriveSerializer(EntrySerializer): @@ -299,34 +320,30 @@ class EntryRetriveSerializer(EntrySerializer): class EntryRetriveProccesedSerializer(EntrySerializer): lead = EntryLeadSerializer() - tabular_field_data = FieldProcessedOnlySerializer(source='tabular_field') + tabular_field_data = FieldProcessedOnlySerializer(source="tabular_field") -class EditEntriesDataSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): - lead = LeadSerializer(source='*', read_only=True) +class EditEntriesDataSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): + lead = LeadSerializer(source="*", read_only=True) entries = serializers.SerializerMethodField() analysis_framework = AnalysisFrameworkSerializer( - source='project.analysis_framework', + source="project.analysis_framework", read_only=True, ) geo_options = serializers.SerializerMethodField() - regions = SimpleRegionSerializer( - source='project.regions', many=True, read_only=True) + regions = SimpleRegionSerializer(source="project.regions", many=True, read_only=True) entry_labels = serializers.SerializerMethodField() entry_groups = serializers.SerializerMethodField() class Meta: model = Lead - fields = ( - 'lead', 'entries', 'analysis_framework', 'geo_options', 'regions', - 'entry_labels', 'entry_groups' - ) + fields = ("lead", "entries", "analysis_framework", "geo_options", "regions", "entry_labels", "entry_groups") def get_entries(self, lead): return EntrySerializer( - Entry.annotate_comment_count(lead.entry_set), many=True, + Entry.annotate_comment_count(lead.entry_set), + many=True, context=self.context, ).data @@ -354,66 +371,69 @@ def get_entry_groups(self, lead): class ComprehensiveAttributeSerializer( - DynamicFieldsMixin, - serializers.ModelSerializer, + DynamicFieldsMixin, + serializers.ModelSerializer, ): - title = serializers.CharField(source='widget.title') - widget_id = serializers.IntegerField(source='widget.pk') - type = serializers.CharField(source='widget.widget_id') + title = serializers.CharField(source="widget.title") + widget_id = serializers.IntegerField(source="widget.pk") + type = serializers.CharField(source="widget.widget_id") value = serializers.SerializerMethodField() class Meta: model = Attribute - fields = ('id', 'title', 'widget_id', 'type', 'value') + fields = ("id", "title", "widget_id", "type", "value") def _get_default_value(self, _, widget, data, widget_data): return { - 'error': 'Unsupported Widget Type, Contact Admin', - 'raw': data, + "error": "Unsupported Widget Type, Contact Admin", + "raw": data, } def _get_initial_wigets_meta(self, instance): - projects_id = self.context['queryset'].order_by('project_id')\ - .values_list('project_id', flat=True).distinct() - regions_id = Region.objects.filter(project__in=projects_id).values_list('pk', flat=True) + projects_id = self.context["queryset"].order_by("project_id").values_list("project_id", flat=True).distinct() + regions_id = Region.objects.filter(project__in=projects_id).values_list("pk", flat=True) geo_areas = {} admin_levels = {} - geo_area_queryset = GeoArea.objects.prefetch_related('admin_level').filter( - admin_level__region__in=regions_id, - ).distinct() + geo_area_queryset = ( + GeoArea.objects.prefetch_related("admin_level") + .filter( + admin_level__region__in=regions_id, + ) + .distinct() + ) for geo_area in geo_area_queryset: geo_areas[geo_area.pk] = { - 'id': geo_area.pk, - 'title': geo_area.title, - 'pcode': geo_area.code, - 'admin_level': geo_area.admin_level_id, - 'parent': geo_area.parent_id, + "id": geo_area.pk, + "title": geo_area.title, + "pcode": geo_area.code, + "admin_level": geo_area.admin_level_id, + "parent": geo_area.parent_id, } if admin_levels.get(geo_area.admin_level_id) is None: admin_level = geo_area.admin_level admin_levels[geo_area.admin_level_id] = { - 'id': admin_level.pk, - 'level': admin_level.level, - 'title': admin_level.title, + "id": admin_level.pk, + "level": admin_level.level, + "title": admin_level.title, } return { - 'geo-widget': { - 'admin_levels': admin_levels, - 'geo_areas': geo_areas, + "geo-widget": { + "admin_levels": admin_levels, + "geo_areas": geo_areas, }, } def _get_value(self, instance): - if not hasattr(self, 'widgets_meta'): + if not hasattr(self, "widgets_meta"): self.widgets_meta = self._get_initial_wigets_meta(instance) widget = instance.widget data = instance.data or {} return getattr( widget_store.get(instance.widget.widget_id, {}), - 'get_comprehensive_data', + "get_comprehensive_data", self._get_default_value, )(self.widgets_meta, widget, data, widget.properties) @@ -421,50 +441,62 @@ def get_value(self, instance): try: return self._get_value(instance) except Exception: - logger.warning('Comprehensive Error!! (Widget:{instance})', exc_info=True) + logger.warning("Comprehensive Error!! (Widget:{instance})", exc_info=True) class ComprehensiveEntriesSerializer( - DynamicFieldsMixin, - serializers.ModelSerializer, + DynamicFieldsMixin, + serializers.ModelSerializer, ): - tabular_field = serializers.HyperlinkedRelatedField(read_only=True, view_name='tabular_field-detail') - attributes = ComprehensiveAttributeSerializer(source='attribute_set', many=True, read_only=True) + tabular_field = serializers.HyperlinkedRelatedField(read_only=True, view_name="tabular_field-detail") + attributes = ComprehensiveAttributeSerializer(source="attribute_set", many=True, read_only=True) created_by = ComprehensiveUserSerializer(read_only=True) modified_by = ComprehensiveUserSerializer(read_only=True) - original_excerpt = serializers.CharField(source='dropped_excerpt', read_only=True) + original_excerpt = serializers.CharField(source="dropped_excerpt", read_only=True) class Meta: model = Entry fields = ( - 'id', 'created_at', 'modified_at', 'entry_type', 'excerpt', 'image_raw', 'tabular_field', - 'attributes', 'created_by', 'modified_by', 'project', 'original_excerpt', + "id", + "created_at", + "modified_at", + "entry_type", + "excerpt", + "image_raw", + "tabular_field", + "attributes", + "created_by", + "modified_by", + "project", + "original_excerpt", ) class EntryCommentTextSerializer(serializers.ModelSerializer): class Meta: model = EntryCommentText - exclude = ('comment',) + exclude = ("comment",) class EntryCommentSerializer(serializers.ModelSerializer): - created_by_detail = EntryCommentUserSerializer(source='created_by', read_only=True) - assignees_detail = EntryCommentUserSerializer(source='assignees', read_only=True, many=True) + created_by_detail = EntryCommentUserSerializer(source="created_by", read_only=True) + assignees_detail = EntryCommentUserSerializer(source="assignees", read_only=True, many=True) text = serializers.CharField() - lead = serializers.IntegerField(source='entry.lead_id', read_only=True) + lead = serializers.IntegerField(source="entry.lead_id", read_only=True) text_history = EntryCommentTextSerializer( - source='entrycommenttext_set', many=True, read_only=True, + source="entrycommenttext_set", + many=True, + read_only=True, ) class Meta: model = EntryComment - fields = '__all__' - read_only_fields = ('entry', 'is_resolved', 'created_by', 'resolved_at') + fields = "__all__" + read_only_fields = ("entry", "is_resolved", "created_by", "resolved_at") def _get_entry(self): - if not hasattr(self, '_entry'): - entry = Entry.objects.get(pk=int(self.context['entry_id'])) + if not hasattr(self, "_entry"): + entry = Entry.objects.get(pk=int(self.context["entry_id"])) self._entry = entry return self._entry @@ -477,46 +509,46 @@ def add_comment_text(self, comment, text): def validate_parent(self, parent_comment): if parent_comment: if parent_comment.entry != self._get_entry(): - raise serializers.ValidationError('Selected parent comment is assigned to different entry') + raise serializers.ValidationError("Selected parent comment is assigned to different entry") return parent_comment def validate(self, data): - assignees = data.get('assignees') - data['entry'] = entry = self._get_entry() + assignees = data.get("assignees") + data["entry"] = entry = self._get_entry() # Check if all assignes are members if assignees: current_members_id = set( ProjectMembership.objects.filter(project=entry.project, member__in=assignees) - .values_list('member', flat=True) + .values_list("member", flat=True) .distinct() ) assigned_users_id = set([a.id for a in assignees]) if current_members_id != assigned_users_id: - raise serializers.ValidationError({'assignees': "Selected assignees don't belong to this project"}) + raise serializers.ValidationError({"assignees": "Selected assignees don't belong to this project"}) - is_patch = self.context['request'].method == 'PATCH' + is_patch = self.context["request"].method == "PATCH" if self.instance and self.instance.is_resolved: - raise serializers.ValidationError('Comment is resolved, no changes allowed') - parent_comment = data.get('parent') + raise serializers.ValidationError("Comment is resolved, no changes allowed") + parent_comment = data.get("parent") if parent_comment: # Reply comment if parent_comment.is_resolved: - raise serializers.ValidationError('Parent comment is resolved, no addition allowed') + raise serializers.ValidationError("Parent comment is resolved, no addition allowed") if parent_comment.parent is not None: - raise serializers.ValidationError('2-level of comment only allowed') - data['entry'] = parent_comment.entry - data['assignees'] = [] + raise serializers.ValidationError("2-level of comment only allowed") + data["entry"] = parent_comment.entry + data["assignees"] = [] else: # Root comment - if not data.get('assignees') and not is_patch: - raise serializers.ValidationError('Root comment should have at least one assignee') - data['created_by'] = self.context['request'].user + if not data.get("assignees") and not is_patch: + raise serializers.ValidationError("Root comment should have at least one assignee") + data["created_by"] = self.context["request"].user return data def comment_save(self, validated_data, instance=None): """ Comment Middleware save logic """ - text = validated_data.pop('text', '').strip() + text = validated_data.pop("text", "").strip() text_change = True if instance is None: # Create instance = super().create(validated_data) @@ -536,19 +568,26 @@ def update(self, instance, validated_data): class SimpleEntrySerializer(serializers.ModelSerializer): - image_details = SimpleFileSerializer(source='image', read_only=True) - tabular_field_data = FieldProcessedOnlySerializer(source='tabular_field') + image_details = SimpleFileSerializer(source="image", read_only=True) + tabular_field_data = FieldProcessedOnlySerializer(source="tabular_field") class Meta: model = Entry - fields = ('id', 'excerpt', 'dropped_excerpt', - 'image', 'image_details', 'entry_type', - 'tabular_field', 'tabular_field_data') + fields = ( + "id", + "excerpt", + "dropped_excerpt", + "image", + "image_details", + "entry_type", + "tabular_field", + "tabular_field_data", + ) class EntryCommentTextSerializer(serializers.ModelSerializer): class Meta: - exclude = ('comment',) + exclude = ("comment",) # --------------------- Graphql Serializers ---------------------------------------- @@ -559,66 +598,72 @@ class AttributeGqSerializer(TempClientIdMixin, serializers.ModelSerializer): class Meta: model = Attribute fields = ( - 'id', 'data', 'widget', 'widget_version', - 'client_id', # From TempClientIdMixin + "id", + "data", + "widget", + "widget_version", + "client_id", # From TempClientIdMixin ) def validate(self, validated_data): if self.instance: # For update, remove widget on save # Don't allow changing widget if instance is provided - validated_data.pop('widget', None) + validated_data.pop("widget", None) else: # For create, make sure widget is in active AF - active_af = self.context['request'].active_project.analysis_framework + active_af = self.context["request"].active_project.analysis_framework if not active_af: - raise serializers.ValidationError({ - 'widget': 'There is not active Framework attached', - }) - if not active_af.widget_set.filter(pk=validated_data['widget'].pk).exists(): - raise serializers.ValidationError({ - 'widget': "Given widget doesn't exists in Active Framework", - }) + raise serializers.ValidationError( + { + "widget": "There is not active Framework attached", + } + ) + if not active_af.widget_set.filter(pk=validated_data["widget"].pk).exists(): + raise serializers.ValidationError( + { + "widget": "Given widget doesn't exists in Active Framework", + } + ) return validated_data class EntryGqSerializer(ProjectPropertySerializerMixin, TempClientIdMixin, UserResourceSerializer): id = IntegerIDField(required=False) - attributes = AttributeGqSerializer(source='attribute_set', required=False, many=True) + attributes = AttributeGqSerializer(source="attribute_set", required=False, many=True) image_raw = serializers.CharField( required=False, write_only=True, help_text=( - 'This is used to add raw base64 images.' - ' This will be changed into gallery image and supplied back in image field.' - ) + "This is used to add raw base64 images." " This will be changed into gallery image and supplied back in image field." + ), ) lead_image = serializers.PrimaryKeyRelatedField( required=False, write_only=True, queryset=LeadPreviewImage.objects.all(), help_text=( - 'This is used to add images from Lead Preview Images.' - ' This will be changed into gallery image and supplied back in image field.' - ) + "This is used to add images from Lead Preview Images." + " This will be changed into gallery image and supplied back in image field." + ), ) class Meta: model = Entry fields = ( - 'id', - 'lead', - 'order', - 'information_date', - 'entry_type', - 'image', - 'image_raw', - 'lead_image', - 'tabular_field', - 'excerpt', - 'dropped_excerpt', - 'highlight_hidden', - 'attributes', - 'draft_entry', - 'client_id', + "id", + "lead", + "order", + "information_date", + "entry_type", + "image", + "image_raw", + "lead_image", + "tabular_field", + "excerpt", + "dropped_excerpt", + "highlight_hidden", + "attributes", + "draft_entry", + "client_id", ) # NOTE: This is a custom function (apps/user_resource/serializers.py::UserResourceSerializer) @@ -632,7 +677,7 @@ def validate_lead(self, lead): if lead.project_id != self.project.pk: raise serializers.ValidationError("Don't have access to this lead") if self.instance and lead != self.instance.lead: - raise serializers.ValidationError('Changing lead is not allowed') + raise serializers.ValidationError("Changing lead is not allowed") return lead def validate(self, data): @@ -640,27 +685,27 @@ def validate(self, data): - Lead image is copied to deep gallery files - Raw image (base64) are saved as deep gallery files """ - request = self.context['request'] - image = data.get('image') - image_raw = data.pop('image_raw', None) - lead_image = data.pop('lead_image', None) + request = self.context["request"] + image = data.get("image") + image_raw = data.pop("image_raw", None) + lead_image = data.pop("lead_image", None) # ---------------- Lead - lead = data['lead'] + lead = data["lead"] if self.instance and lead != self.instance.lead: - raise serializers.ValidationError({ - 'lead': 'Changing lead is not allowed' - }) + raise serializers.ValidationError({"lead": "Changing lead is not allowed"}) # ----------------- Validate Draft entry if provided - draft_entry = data.get('draft_entry') + draft_entry = data.get("draft_entry") if draft_entry and draft_entry.lead != lead: - raise serializers.ValidationError({ - 'draft_entry': 'Only attach draft entry from current lead.', - }) + raise serializers.ValidationError( + { + "draft_entry": "Only attach draft entry from current lead.", + } + ) # ---------------- Project if not self.instance: # For create only - data['project'] = self.context['request'].active_project + data["project"] = self.context["request"].active_project # -------------- Active AF id from project active_af_id = self.project.analysis_framework_id @@ -671,36 +716,36 @@ def validate(self, data): "Entry's original Framework is different from project's framework. Conflict detected.", ) else: # For update, set entry's AF with active AF - data['analysis_framework_id'] = active_af_id + data["analysis_framework_id"] = active_af_id # ---------------- Set/validate image properly # If gallery file is provided make sure user owns the file if image: - if ( - (self.instance and self.instance.image) != image and - not image.is_public and - image.created_by != request.user - ): - raise serializers.ValidationError({ - 'image': f'You don\'t have permission to attach image: {image}', - }) + if (self.instance and self.instance.image) != image and not image.is_public and image.created_by != request.user: + raise serializers.ValidationError( + { + "image": f"You don't have permission to attach image: {image}", + } + ) # If lead image is provided make sure lead are same elif lead_image: if lead_image.lead != lead: - raise serializers.ValidationError({ - 'lead_image': f'You don\'t have permission to attach lead image: {lead_image}', - }) - data['image'] = lead_image.clone_as_deep_file(request.user) + raise serializers.ValidationError( + { + "lead_image": f"You don't have permission to attach lead image: {lead_image}", + } + ) + data["image"] = lead_image.clone_as_deep_file(request.user) elif image_raw: generated_image = base64_to_deep_image(image_raw, lead, request.user) if isinstance(generated_image, File): - data['image'] = generated_image + data["image"] = generated_image return data def update(self, instance, validated_data): # once altered, unverify the entry if its controlled if instance and instance.controlled: - validated_data['controlled'] = False - validated_data['verification_last_changed_by'] = self.context['request'].user + validated_data["controlled"] = False + validated_data["verification_last_changed_by"] = self.context["request"].user return super().update(instance, validated_data) diff --git a/apps/entry/stats.py b/apps/entry/stats.py index 8c0ab7e085..3b81d8e4a3 100644 --- a/apps/entry/stats.py +++ b/apps/entry/stats.py @@ -1,73 +1,84 @@ import json +from analysis_framework.models import Filter, Widget from django.contrib.gis.db.models import Extent -from django.utils import timezone from django.db.models import Prefetch - +from django.utils import timezone from geo.models import GeoArea -from analysis_framework.models import Widget, Filter + from apps.entry.widgets.geo_widget import get_valid_geo_ids -from .models import Entry, Attribute +from .models import Attribute, Entry SUPPORTED_WIDGETS = [ - 'matrix1dWidget', 'matrix2dWidget', 'scaleWidget', 'multiselectWidget', 'organigramWidget', 'geoWidget', - 'conditionalWidget', + "matrix1dWidget", + "matrix2dWidget", + "scaleWidget", + "multiselectWidget", + "organigramWidget", + "geoWidget", + "conditionalWidget", ] def _get_lead_data(lead): if lead: return { - 'id': lead.id, - 'title': lead.title, - 'source_type': lead.source_type, - 'confidentiality': lead.confidentiality, - 'source_raw': lead.source_raw, - 'source': lead.source and { - 'id': lead.source.id, - 'title': lead.source.data.title, + "id": lead.id, + "title": lead.title, + "source_type": lead.source_type, + "confidentiality": lead.confidentiality, + "source_raw": lead.source_raw, + "source": lead.source + and { + "id": lead.source.id, + "title": lead.source.data.title, }, - 'author_raw': lead.author_raw, - 'authors': [ + "author_raw": lead.author_raw, + "authors": [ { - 'id': author.id, - 'title': author.data.title, + "id": author.id, + "title": author.data.title, # TODO: Legacy: Remove `or` logic after all the author are migrated to authors from author - } for author in lead.authors.all() or ([lead.author] if lead.author else []) + } + for author in lead.authors.all() or ([lead.author] if lead.author else []) ], } return { - 'id': None, - 'title': None, - 'source_type': None, - 'confidentiality': None, - 'source_raw': None, - 'source': None, - 'author_raw': None, - 'authors': [], + "id": None, + "title": None, + "source_type": None, + "confidentiality": None, + "source_raw": None, + "source": None, + "author_raw": None, + "authors": [], } def _get_project_geoareas(project): - qs = GeoArea.objects.filter( - admin_level__region__in=project.regions.values_list('id'), - admin_level__level__in=[0, 1, 2], - ).annotate(extent=Extent('polygons')).values('pk', 'admin_level__level', 'title', 'polygons', 'extent', 'parent') + qs = ( + GeoArea.objects.filter( + admin_level__region__in=project.regions.values_list("id"), + admin_level__level__in=[0, 1, 2], + ) + .annotate(extent=Extent("polygons")) + .values("pk", "admin_level__level", "title", "polygons", "extent", "parent") + ) geo_array = [] for geoarea in qs: - polygons = geoarea['polygons'] + polygons = geoarea["polygons"] centroid = polygons.centroid geo = { - 'id': geoarea['pk'], - 'admin_level': geoarea['admin_level__level'], - 'parent': geoarea['parent'], - 'name': geoarea['title'], - 'centroid': [centroid.x, centroid.y], - 'bounds': [geoarea['extent'][:2], geoarea['extent'][2:]], + "id": geoarea["pk"], + "admin_level": geoarea["admin_level__level"], + "parent": geoarea["parent"], + "name": geoarea["title"], + "centroid": [centroid.x, centroid.y], + "bounds": [geoarea["extent"][:2], geoarea["extent"][2:]], } - geo['polygons'] = json.loads(polygons.geojson) # TODO: + geo["polygons"] = json.loads(polygons.geojson) # TODO: geo_array.append(geo) return geo_array @@ -76,20 +87,20 @@ def _get_widget_info(config, widgets, skip_data=False, default=None): if not config and default is not None: return default - widget = widgets[config['pk']] + widget = widgets[config["pk"]] def _return(properties): return { - '_widget': widget, - 'pk': widget.pk, - 'config': config, - 'properties': properties, + "_widget": widget, + "pk": widget.pk, + "config": config, + "properties": properties, } if skip_data: return _return(None) - if widget.widget_id == 'organigramWidget': + if widget.widget_id == "organigramWidget": w_filter = Filter.objects.filter( widget_key=widget.key, analysis_framework_id=widget.analysis_framework_id, @@ -97,49 +108,48 @@ def _return(properties): return _return(w_filter.properties if w_filter else None) properties = widget.properties - if config.get('is_conditional_widget'): # TODO: Remove this + if config.get("is_conditional_widget"): # TODO: Remove this # TODO: Skipping conditional widget, in new this is not needed return default return _return(properties) def _get_attribute_widget_value(cd_widget_map, w_value, widget_type, widget_pk=None): - if widget_type in ['scaleWidget', 'multiselectWidget', 'organigramWidget']: + if widget_type in ["scaleWidget", "multiselectWidget", "organigramWidget"]: return w_value - elif widget_type == 'geoWidget': + elif widget_type == "geoWidget": # XXX: We don't need this now, as only string are stored here. Remove later. return get_valid_geo_ids(w_value) - elif widget_type == 'conditionalWidget': + elif widget_type == "conditionalWidget": cd_config = cd_widget_map.get(widget_pk) if cd_config is None: return - selected_widget_key = cd_config['widget_key'] - selected_widget_type = cd_config['widget_type'] - selected_widget_pk = cd_config.get('widget_pk') - return _get_attribute_widget_value( - cd_widget_map, - w_value[selected_widget_key]['data']['value'], - selected_widget_type, - selected_widget_pk, - ) if w_value.get(selected_widget_key) else None - elif widget_type in ['matrix1dWidget', 'matrix2dWidget']: - context_keys = [ - f'{widget_pk}-{_value}' - for _value in ( - w_value.keys() if isinstance(w_value, dict) else [] + selected_widget_key = cd_config["widget_key"] + selected_widget_type = cd_config["widget_type"] + selected_widget_pk = cd_config.get("widget_pk") + return ( + _get_attribute_widget_value( + cd_widget_map, + w_value[selected_widget_key]["data"]["value"], + selected_widget_type, + selected_widget_pk, ) - ] + if w_value.get(selected_widget_key) + else None + ) + elif widget_type in ["matrix1dWidget", "matrix2dWidget"]: + context_keys = [f"{widget_pk}-{_value}" for _value in (w_value.keys() if isinstance(w_value, dict) else [])] sectors_keys = [] - if widget_type == 'matrix2dWidget': # Collect sector data from here + if widget_type == "matrix2dWidget": # Collect sector data from here sectors_keys = [ - [f'{widget_pk}-{pillar_key}', subpillar_key, sector_key] + [f"{widget_pk}-{pillar_key}", subpillar_key, sector_key] for pillar_key, pillar in w_value.items() for subpillar_key, subpillar in pillar.items() for sector_key in subpillar.keys() ] return { - 'context_keys': context_keys, - 'sectors_keys': sectors_keys, + "context_keys": context_keys, + "sectors_keys": sectors_keys, } @@ -148,10 +158,10 @@ def _get_attribute_data(collector, attribute, cd_widget_map): widget_pk = attribute.widget.pk data = attribute.data - if widget_type not in SUPPORTED_WIDGETS or not data or not data.get('value'): + if widget_type not in SUPPORTED_WIDGETS or not data or not data.get("value"): return - collector[widget_pk] = _get_attribute_widget_value(cd_widget_map, data['value'], widget_type, widget_pk) + collector[widget_pk] = _get_attribute_widget_value(cd_widget_map, data["value"], widget_type, widget_pk) def get_project_entries_stats(project, skip_geo_data=False): @@ -191,34 +201,27 @@ def get_project_entries_stats(project, skip_geo_data=False): """ af = project.analysis_framework - config = af.properties.get('stats_config') + config = af.properties.get("stats_config") - widgets_pk = [ - info['pk'] - for _info in config.values() - for info in (_info if isinstance(_info, list) else [_info]) - ] + widgets_pk = [info["pk"] for _info in config.values() for info in (_info if isinstance(_info, list) else [_info])] cd_widget_map = { - w_config['pk']: w_config + w_config["pk"]: w_config for _w_config in config.values() for w_config in (_w_config if isinstance(_w_config, list) else [_w_config]) - if w_config.get('is_conditional_widget') + if w_config.get("is_conditional_widget") } - widgets = { - widget.pk: widget - for widget in Widget.objects.filter(pk__in=widgets_pk, analysis_framework=af) - } + widgets = {widget.pk: widget for widget in Widget.objects.filter(pk__in=widgets_pk, analysis_framework=af)} w_reliability_default = w_severity_default = w_multiselect_widget_default = w_organigram_widget_default = { - 'pk': None, - 'properties': { - 'options': [], + "pk": None, + "properties": { + "options": [], }, } - w1ds = [_get_widget_info(_config, widgets) for _config in config['widget_1d'] or []] - w2ds = [_get_widget_info(_config, widgets) for _config in config['widget_2d'] or []] + w1ds = [_get_widget_info(_config, widgets) for _config in config["widget_1d"] or []] + w2ds = [_get_widget_info(_config, widgets) for _config in config["widget_2d"] or []] w_multiselect_widgets = [ _get_widget_info( @@ -226,7 +229,7 @@ def get_project_entries_stats(project, skip_geo_data=False): widgets, default=w_multiselect_widget_default, ) - for _config in config.get('multiselect_widgets') or [] + for _config in config.get("multiselect_widgets") or [] ] w_organigram_widgets = [ @@ -235,161 +238,161 @@ def get_project_entries_stats(project, skip_geo_data=False): widgets, default=w_organigram_widget_default, ) - for _config in config.get('organigram_widgets') or [] + for _config in config.get("organigram_widgets") or [] ] - w_severity = _get_widget_info(config.get('severity_widget'), widgets, default=w_severity_default) - w_reliability = _get_widget_info(config.get('reliability_widget'), widgets, default=w_reliability_default) + w_severity = _get_widget_info(config.get("severity_widget"), widgets, default=w_severity_default) + w_reliability = _get_widget_info(config.get("reliability_widget"), widgets, default=w_reliability_default) - w_geo = _get_widget_info(config['geo_widget'], widgets, skip_data=True) + w_geo = _get_widget_info(config["geo_widget"], widgets, skip_data=True) matrix_widgets = [ - {'id': w['pk'], 'type': w_type, 'title': w['_widget'].title} - for widgets, w_type in [[w1ds, 'widget_1d'], [w2ds, 'widget_2d']] + {"id": w["pk"], "type": w_type, "title": w["_widget"].title} + for widgets, w_type in [[w1ds, "widget_1d"], [w2ds, "widget_2d"]] for w in widgets ] multiselect_widgets = [ { - 'id': w['pk'], - 'title': w['_widget'].title, + "id": w["pk"], + "title": w["_widget"].title, } for w in w_multiselect_widgets ] organigram_widgets = [ { - 'id': w['pk'], - 'title': w['_widget'].title, + "id": w["pk"], + "title": w["_widget"].title, } for w in w_organigram_widgets ] context_array = [ { - 'id': f"{w['pk']}-{dimension[id_key]}", - 'widget_id': w['pk'], - 'name': dimension['label'], - 'color': dimension.get('color'), - } for id_key, w, dimensions in [ - *[('key', w1d, w1d['properties']['rows']) for w1d in w1ds], - *[('key', w2d, w2d['properties']['rows']) for w2d in w2ds], - ] for dimension in dimensions + "id": f"{w['pk']}-{dimension[id_key]}", + "widget_id": w["pk"], + "name": dimension["label"], + "color": dimension.get("color"), + } + for id_key, w, dimensions in [ + *[("key", w1d, w1d["properties"]["rows"]) for w1d in w1ds], + *[("key", w2d, w2d["properties"]["rows"]) for w2d in w2ds], + ] + for dimension in dimensions ] framework_groups_array = [ { - 'id': subdimension['key'], - 'widget_id': w2d['pk'], - 'context_id': f"{w2d['pk']}-{dimension['key']}", - 'name': subdimension['label'], - 'tooltip': subdimension.get('tooltip'), + "id": subdimension["key"], + "widget_id": w2d["pk"], + "context_id": f"{w2d['pk']}-{dimension['key']}", + "name": subdimension["label"], + "tooltip": subdimension.get("tooltip"), } for w2d in w2ds - for dimension in w2d['properties']['rows'] - for subdimension in dimension['subRows'] + for dimension in w2d["properties"]["rows"] + for subdimension in dimension["subRows"] ] sector_array = [ { - 'id': sector['key'], - 'widget_id': w2d['pk'], - 'name': sector['label'], - 'tooltip': sector.get('tooltip'), + "id": sector["key"], + "widget_id": w2d["pk"], + "name": sector["label"], + "tooltip": sector.get("tooltip"), } for w2d in w2ds - for sector in w2d['properties']['columns'] + for sector in w2d["properties"]["columns"] ] organigram_array = [ { - 'id': option['key'], - 'name': option['label'], + "id": option["key"], + "name": option["label"], } for _widget in w_organigram_widgets - for option in _widget['properties']['options'] + for option in _widget["properties"]["options"] ] multiselect_array = [ { - 'id': option['key'], - 'widget_id': _widget['pk'], - 'name': option['label'], + "id": option["key"], + "widget_id": _widget["pk"], + "name": option["label"], } for _widget in w_multiselect_widgets - for option in _widget['properties']['options'] + for option in _widget["properties"]["options"] ] severity_units = [ { - 'id': severity['key'], - 'color': severity.get('color'), - 'name': severity['label'], - } for severity in w_severity['properties']['options'] + "id": severity["key"], + "color": severity.get("color"), + "name": severity["label"], + } + for severity in w_severity["properties"]["options"] ] reliability_units = [ { - 'id': reliability['key'], - 'color': reliability.get('color'), - 'name': reliability['label'], - } for reliability in w_reliability['properties']['options'] + "id": reliability["key"], + "color": reliability.get("color"), + "name": reliability["label"], + } + for reliability in w_reliability["properties"]["options"] ] meta = { - 'data_calculated': timezone.now(), - 'matrix_widgets': matrix_widgets, - 'multiselect_widgets': multiselect_widgets, - 'organigram_widgets': organigram_widgets, - 'context_array': context_array, - 'framework_groups_array': framework_groups_array, - 'sector_array': sector_array, - 'multiselect_array': multiselect_array, - 'organigram_array': organigram_array, - 'severity_units': severity_units, - 'reliability_units': reliability_units, + "data_calculated": timezone.now(), + "matrix_widgets": matrix_widgets, + "multiselect_widgets": multiselect_widgets, + "organigram_widgets": organigram_widgets, + "context_array": context_array, + "framework_groups_array": framework_groups_array, + "sector_array": sector_array, + "multiselect_array": multiselect_array, + "organigram_array": organigram_array, + "severity_units": severity_units, + "reliability_units": reliability_units, } if not skip_geo_data: - meta['geo_array'] = _get_project_geoareas(project) + meta["geo_array"] = _get_project_geoareas(project) data = [] entries = Entry.objects.filter(project=project).prefetch_related( Prefetch( - 'attribute_set', + "attribute_set", queryset=Attribute.objects.filter(widget_id__in=widgets_pk), ), - 'attribute_set__widget', - 'lead', + "attribute_set__widget", + "lead", ) for entry in entries.all(): collector = {} for attribute in entry.attribute_set.all(): _get_attribute_data(collector, attribute, cd_widget_map) - data.append({ - 'pk': entry.pk, - 'created_date': entry.created_at, - 'lead': _get_lead_data(entry.lead), - 'date': entry.lead.published_on, - 'severity': collector.get(w_severity['pk']), - 'reliability': collector.get(w_reliability['pk']), - 'geo': collector.get(w_geo['pk'], []), - 'multiselect': { - _config['pk']: collector.get(_config['pk'], []) - for _config in w_multiselect_widgets - }, - 'organigram': { - _config['pk']: collector.get(_config['pk'], []) - for _config in w_organigram_widgets - }, - 'context_sector': { - w['pk']: { - 'context': collector.get(w['pk'], {}).get('context_keys', []), - 'sector': collector.get(w['pk'], {}).get('sectors_keys', []), - } - for w in [*w1ds, *w2ds] - }, - }) + data.append( + { + "pk": entry.pk, + "created_date": entry.created_at, + "lead": _get_lead_data(entry.lead), + "date": entry.lead.published_on, + "severity": collector.get(w_severity["pk"]), + "reliability": collector.get(w_reliability["pk"]), + "geo": collector.get(w_geo["pk"], []), + "multiselect": {_config["pk"]: collector.get(_config["pk"], []) for _config in w_multiselect_widgets}, + "organigram": {_config["pk"]: collector.get(_config["pk"], []) for _config in w_organigram_widgets}, + "context_sector": { + w["pk"]: { + "context": collector.get(w["pk"], {}).get("context_keys", []), + "sector": collector.get(w["pk"], {}).get("sectors_keys", []), + } + for w in [*w1ds, *w2ds] + }, + } + ) return { - 'meta': meta, - 'data': data, + "meta": meta, + "data": data, } diff --git a/apps/entry/tests/entry_widget_test_data.py b/apps/entry/tests/entry_widget_test_data.py index fd4fa99eb4..957c6bd09d 100644 --- a/apps/entry/tests/entry_widget_test_data.py +++ b/apps/entry/tests/entry_widget_test_data.py @@ -1,423 +1,468 @@ # NOTE: This structure and value are set through https://github.com/the-deep/client WIDGET_PROPERTIES = { - 'selectWidget': { - 'options': [ - {'key': 'option-1', 'label': 'Option 1'}, - {'key': 'option-2', 'label': 'Option 2'}, - {'key': 'option-3', 'label': 'Option 3'} + "selectWidget": { + "options": [ + {"key": "option-1", "label": "Option 1"}, + {"key": "option-2", "label": "Option 2"}, + {"key": "option-3", "label": "Option 3"}, ] }, - 'multiselectWidget': { - 'options': [ - {'key': 'option-1', 'label': 'Option 1'}, - {'key': 'option-2', 'label': 'Option 2'}, - {'key': 'option-3', 'label': 'Option 3'} + "multiselectWidget": { + "options": [ + {"key": "option-1", "label": "Option 1"}, + {"key": "option-2", "label": "Option 2"}, + {"key": "option-3", "label": "Option 3"}, ] }, - 'scaleWidget': { - 'defaultValue': 'scale-1', - 'options': [ - {'key': 'scale-1', 'color': '#470000', 'label': 'Scale 1'}, - {'key': 'scale-2', 'color': '#a40000', 'label': 'Scale 2'}, - {'key': 'scale-3', 'color': '#d40000', 'label': 'Scale 3'} - ] + "scaleWidget": { + "defaultValue": "scale-1", + "options": [ + {"key": "scale-1", "color": "#470000", "label": "Scale 1"}, + {"key": "scale-2", "color": "#a40000", "label": "Scale 2"}, + {"key": "scale-3", "color": "#d40000", "label": "Scale 3"}, + ], }, - 'organigramWidget': { - 'key': 'base', - 'label': 'Base Node', - 'children': [{ - 'key': 'node-1', - 'label': 'Node 1', - 'children': [{ - 'key': 'node-2', - 'label': 'Node 2', - 'children': [ - {'key': 'node-3', 'label': 'Node 3', 'children': []}, - {'key': 'node-4', 'label': 'Node 4', 'children': []}, - {'key': 'node-5', 'label': 'Node 5', 'children': []}, + "organigramWidget": { + "key": "base", + "label": "Base Node", + "children": [ + { + "key": "node-1", + "label": "Node 1", + "children": [ { - 'key': 'node-6', - 'label': 'Node 6', - 'children': [{ - 'key': 'node-7', - 'label': 'Node 7', - 'children': [ - {'key': 'node-8', 'label': 'Node 8', 'children': []} - ] - }] + "key": "node-2", + "label": "Node 2", + "children": [ + {"key": "node-3", "label": "Node 3", "children": []}, + {"key": "node-4", "label": "Node 4", "children": []}, + {"key": "node-5", "label": "Node 5", "children": []}, + { + "key": "node-6", + "label": "Node 6", + "children": [ + { + "key": "node-7", + "label": "Node 7", + "children": [{"key": "node-8", "label": "Node 8", "children": []}], + } + ], + }, + ], } - ] - }] - }] + ], + } + ], }, - - 'matrix1dWidget': { - 'rows': [ + "matrix1dWidget": { + "rows": [ { - 'key': 'pillar-1', - 'cells': [ - {'key': 'subpillar-1', 'label': 'Politics', 'tooltip': ''}, - {'key': 'subpillar-2', 'label': 'Security', 'tooltip': 'Secure is good'}, - {'key': 'subpillar-3', 'label': 'Legal & Policy'}, - {'key': 'subpillar-4', 'label': 'Demography'}, - {'key': 'subpillar-5', 'label': 'Economy'}, - {'key': 'subpillar-5', 'label': 'Socio Cultural'}, - {'key': 'subpillar-7', 'label': 'Environment'}, + "key": "pillar-1", + "cells": [ + {"key": "subpillar-1", "label": "Politics", "tooltip": ""}, + {"key": "subpillar-2", "label": "Security", "tooltip": "Secure is good"}, + {"key": "subpillar-3", "label": "Legal & Policy"}, + {"key": "subpillar-4", "label": "Demography"}, + {"key": "subpillar-5", "label": "Economy"}, + {"key": "subpillar-5", "label": "Socio Cultural"}, + {"key": "subpillar-7", "label": "Environment"}, ], - 'color': '#c26b27', - 'label': 'Context', - 'tooltip': 'Information about the environment in which humanitarian actors operates and the crisis happen', # noqa E501 - }, { - 'key': 'pillar-2', - 'cells': [ - {'key': 'subpillar-8', 'label': 'Affected Groups'}, - {'key': 'subpillar-9', 'label': 'Population Movement'}, - {'key': 'subpillar-10', 'label': 'Push/Pull Factors'}, - {'key': 'subpillar-11', 'label': 'Casualties'}, + "color": "#c26b27", + "label": "Context", + "tooltip": "Information about the environment in which humanitarian actors operates and the crisis happen", # noqa E501 + }, + { + "key": "pillar-2", + "cells": [ + {"key": "subpillar-8", "label": "Affected Groups"}, + {"key": "subpillar-9", "label": "Population Movement"}, + {"key": "subpillar-10", "label": "Push/Pull Factors"}, + {"key": "subpillar-11", "label": "Casualties"}, ], - 'color': '#efaf78', - 'label': 'Humanitarian Profile', - 'tooltip': 'Information related to the population affected, including affected residents and displaced people', # noqa E501 - }, { - 'key': 'pillar-3', - 'cells': [ - {'key': 'subpillar-12', 'label': 'Relief to Beneficiaries'}, - {'key': 'subpillar-13', 'label': 'Beneficiaries to Relief'}, - {'key': 'subpillar-14', 'label': 'Physical Constraints'}, - {'key': 'subpillar-15', 'label': 'Humanitarian Access Gaps'}, + "color": "#efaf78", + "label": "Humanitarian Profile", + "tooltip": "Information related to the population affected, including affected residents and displaced people", # noqa E501 + }, + { + "key": "pillar-3", + "cells": [ + {"key": "subpillar-12", "label": "Relief to Beneficiaries"}, + {"key": "subpillar-13", "label": "Beneficiaries to Relief"}, + {"key": "subpillar-14", "label": "Physical Constraints"}, + {"key": "subpillar-15", "label": "Humanitarian Access Gaps"}, ], - 'color': '#b9b2a5', - 'label': 'Humanitarian Access', - 'tooltip': 'Information related to restrictions and constraints in accessing or being accessed by people in need', # noqa E501 - }, { - 'key': 'pillar-4', - 'cells': [ - {'key': 'subpillar-16', 'label': 'Communication Means & Channels'}, - {'key': 'subpillar-17', 'label': 'Information Challenges'}, - {'key': 'subpillar-18', 'label': 'Information Needs & Gaps'}, + "color": "#b9b2a5", + "label": "Humanitarian Access", + "tooltip": "Information related to restrictions and constraints in accessing or being accessed by people in need", # noqa E501 + }, + { + "key": "pillar-4", + "cells": [ + {"key": "subpillar-16", "label": "Communication Means & Channels"}, + {"key": "subpillar-17", "label": "Information Challenges"}, + {"key": "subpillar-18", "label": "Information Needs & Gaps"}, ], - 'color': '#9bd65b', - 'label': 'Information', - 'tooltip': 'Information about information, including communication means, information challenges and information needs', # noqa E501 - }] + "color": "#9bd65b", + "label": "Information", + "tooltip": "Information about information, including communication means, information challenges and information needs", # noqa E501 + }, + ] }, - - 'matrix2dWidget': { - 'columns': [ - {'key': 'sector-9', 'label': 'Cross', 'tooltip': 'Cross sectoral information', 'subColumns': []}, - {'key': 'sector-0', 'label': 'Food', 'tooltip': '...', 'subColumns': []}, - {'key': 'sector-1', 'label': 'Livelihoods', 'tooltip': '...', 'subColumns': []}, - {'key': 'sector-2', 'label': 'Health', 'tooltip': '...', 'subColumns': []}, - {'key': 'sector-3', 'label': 'Nutrition', 'tooltip': '...', 'subColumns': []}, + "matrix2dWidget": { + "columns": [ + {"key": "sector-9", "label": "Cross", "tooltip": "Cross sectoral information", "subColumns": []}, + {"key": "sector-0", "label": "Food", "tooltip": "...", "subColumns": []}, + {"key": "sector-1", "label": "Livelihoods", "tooltip": "...", "subColumns": []}, + {"key": "sector-2", "label": "Health", "tooltip": "...", "subColumns": []}, + {"key": "sector-3", "label": "Nutrition", "tooltip": "...", "subColumns": []}, { - 'key': 'sector-4', - 'label': 'WASH', - 'tooltip': '...', - 'subColumns': [ - {'key': 'subsector-1', 'label': 'Water'}, - {'key': 'subsector-2', 'label': 'Sanitation'}, - {'key': 'subsector-3', 'label': 'Hygiene'}, - {'key': 'subsector-4', 'label': 'Waste management', 'tooltip': ''}, - {'key': 'subsector-5', 'label': 'Vector control', 'tooltip': ''} - ] + "key": "sector-4", + "label": "WASH", + "tooltip": "...", + "subColumns": [ + {"key": "subsector-1", "label": "Water"}, + {"key": "subsector-2", "label": "Sanitation"}, + {"key": "subsector-3", "label": "Hygiene"}, + {"key": "subsector-4", "label": "Waste management", "tooltip": ""}, + {"key": "subsector-5", "label": "Vector control", "tooltip": ""}, + ], }, - {'key': 'sector-5', 'label': 'Shelter', 'tooltip': '...', 'subColumns': []}, + {"key": "sector-5", "label": "Shelter", "tooltip": "...", "subColumns": []}, { - 'key': 'sector-7', - 'label': 'Education', - 'tooltip': '.....', - 'subColumns': [ - {'key': 'subsector-6', 'label': 'Learning Environment', 'tooltip': ''}, - {'key': 'subsector-7', 'label': 'Teaching and Learning', 'tooltip': ''}, - {'key': 'subsector-8', 'label': 'Teachers and Education Personnel', 'tooltip': ''}, - ] + "key": "sector-7", + "label": "Education", + "tooltip": ".....", + "subColumns": [ + {"key": "subsector-6", "label": "Learning Environment", "tooltip": ""}, + {"key": "subsector-7", "label": "Teaching and Learning", "tooltip": ""}, + {"key": "subsector-8", "label": "Teachers and Education Personnel", "tooltip": ""}, + ], }, - {'key': 'sector-8', 'label': 'Protection', 'tooltip': '', 'subColumns': []}, - {'key': 'sector-10', 'label': 'Agriculture', 'tooltip': '...', 'subColumns': []}, - {'key': 'sector-11', 'label': 'Logistics', 'tooltip': '...', 'subColumns': []} + {"key": "sector-8", "label": "Protection", "tooltip": "", "subColumns": []}, + {"key": "sector-10", "label": "Agriculture", "tooltip": "...", "subColumns": []}, + {"key": "sector-11", "label": "Logistics", "tooltip": "...", "subColumns": []}, ], - 'rows': [ + "rows": [ { - 'key': 'dimension-0', - 'color': '#eae285', - 'label': 'Scope & Scale', - 'tooltip': 'Information about the direct and indirect impact of the disaster or crisis', - 'subRows': [ - {'key': 'subdimension-0', 'label': 'Drivers/Aggravating Factors', 'tooltip': '...'}, - {'key': 'subdimension-3', 'label': 'System Disruption', 'tooltip': '...'}, - {'key': 'subdimension-4', 'label': 'Damages & Losses', 'tooltip': '...'}, - {'key': 'subdimension-6', 'label': 'Lessons Learnt', 'tooltip': '...'} - ] + "key": "dimension-0", + "color": "#eae285", + "label": "Scope & Scale", + "tooltip": "Information about the direct and indirect impact of the disaster or crisis", + "subRows": [ + {"key": "subdimension-0", "label": "Drivers/Aggravating Factors", "tooltip": "..."}, + {"key": "subdimension-3", "label": "System Disruption", "tooltip": "..."}, + {"key": "subdimension-4", "label": "Damages & Losses", "tooltip": "..."}, + {"key": "subdimension-6", "label": "Lessons Learnt", "tooltip": "..."}, + ], }, { - 'key': 'dimension-1', - 'color': '#fba855', - 'label': 'Humanitarian Conditions', - 'tooltip': '...', - 'subRows': [ - {'key': 'subdimension-1', 'label': 'Living Standards', 'tooltip': '...'}, - {'key': 'us9kizxxwha7cpgb', 'label': 'Coping Mechanisms', 'tooltip': ''}, - {'key': 'subdimension-7', 'label': 'Physical & mental wellbeing', 'tooltip': '..'}, - {'key': 'subdimension-8', 'label': 'Risks & Vulnerabilities', 'tooltip': '...'}, - {'key': 'ejve4vklgge9ysxm', 'label': 'People with Specific Needs', 'tooltip': ''}, - {'key': 'subdimension-10', 'label': 'Unmet Needs', 'tooltip': '...'}, - {'key': 'subdimension-16', 'label': 'Lessons Learnt', 'tooltip': '...'}, - ] + "key": "dimension-1", + "color": "#fba855", + "label": "Humanitarian Conditions", + "tooltip": "...", + "subRows": [ + {"key": "subdimension-1", "label": "Living Standards", "tooltip": "..."}, + {"key": "us9kizxxwha7cpgb", "label": "Coping Mechanisms", "tooltip": ""}, + {"key": "subdimension-7", "label": "Physical & mental wellbeing", "tooltip": ".."}, + {"key": "subdimension-8", "label": "Risks & Vulnerabilities", "tooltip": "..."}, + {"key": "ejve4vklgge9ysxm", "label": "People with Specific Needs", "tooltip": ""}, + {"key": "subdimension-10", "label": "Unmet Needs", "tooltip": "..."}, + {"key": "subdimension-16", "label": "Lessons Learnt", "tooltip": "..."}, + ], }, { - 'key': 'dimension-2', - 'color': '#92c5f6', - 'label': 'Capacities & Response', - 'tooltip': '...', - 'subRows': [ - {'key': '7iiastsikxackbrt', 'label': 'System Functionality', 'tooltip': '...'}, - {'key': 'subdimension-11', 'label': 'Government', 'tooltip': '...'}, - {'key': 'drk4j92jwvmck7dc', 'label': 'LNGO', 'tooltip': '...'}, - {'key': 'subdimension-12', 'label': 'International', 'tooltip': '...'}, - {'key': 'subdimension-14', 'label': 'Response Gaps', 'tooltip': '...'}, - {'key': 'subdimension-15', 'label': 'Lessons Learnt', 'tooltip': '...'}, - ] - } - ] + "key": "dimension-2", + "color": "#92c5f6", + "label": "Capacities & Response", + "tooltip": "...", + "subRows": [ + {"key": "7iiastsikxackbrt", "label": "System Functionality", "tooltip": "..."}, + {"key": "subdimension-11", "label": "Government", "tooltip": "..."}, + {"key": "drk4j92jwvmck7dc", "label": "LNGO", "tooltip": "..."}, + {"key": "subdimension-12", "label": "International", "tooltip": "..."}, + {"key": "subdimension-14", "label": "Response Gaps", "tooltip": "..."}, + {"key": "subdimension-15", "label": "Lessons Learnt", "tooltip": "..."}, + ], + }, + ], }, - - 'dateWidget': { - 'information_date_selected': False, + "dateWidget": { + "information_date_selected": False, }, - 'numberWidget': { - 'maxValue': 0, - 'minvalue': 12, + "numberWidget": { + "maxValue": 0, + "minvalue": 12, }, - 'dateRangeWidget': {}, - 'timeWidget': {}, - 'timeRangeWidget': {}, - 'textWidget': {}, + "dateRangeWidget": {}, + "timeWidget": {}, + "timeRangeWidget": {}, + "textWidget": {}, } # NOTE: This structure and value are set through https://github.com/the-deep/client # c_response is for comprehensive API widget response ATTRIBUTE_DATA = { - 'selectWidget': [{ - 'data': {'value': 'option-3'}, - 'c_response': 'Option 3', - }, { - 'data': {'value': 'option-5'}, - 'c_response': None, - }], - - 'multiselectWidget': [{ - 'data': {'value': ['option-3', 'option-1']}, - 'c_response': ['Option 3', 'Option 1'], - }, { - 'data': {'value': ['option-5', 'option-1']}, - 'c_response': ['Option 1'], - }], - - 'scaleWidget': [{ - 'data': {'value': 'scale-1'}, - 'c_response': { - 'min': {'key': 'scale-1', 'color': '#470000', 'label': 'Scale 1'}, - 'max': {'key': 'scale-3', 'color': '#d40000', 'label': 'Scale 3'}, - 'label': 'Scale 1', - 'index': 1, + "selectWidget": [ + { + "data": {"value": "option-3"}, + "c_response": "Option 3", }, - }, { - 'data': {'value': 'scale-5'}, - 'c_response': { - 'min': {'key': 'scale-1', 'color': '#470000', 'label': 'Scale 1'}, - 'max': {'key': 'scale-3', 'color': '#d40000', 'label': 'Scale 3'}, - 'label': None, - 'index': None, + { + "data": {"value": "option-5"}, + "c_response": None, }, - }], - - 'dateWidget': [{ - 'data': {'value': '2019-06-25'}, - 'c_response': '25-06-2019', - }, { - 'data': {'value': None}, - 'c_response': None, - }], - - 'dateRangeWidget': [{ - 'data': {'value': {'startDate': '2012-06-25', 'endDate': '2019-06-22'}}, - 'c_response': { - 'from': '25-06-2012', - 'to': '22-06-2019', + ], + "multiselectWidget": [ + { + "data": {"value": ["option-3", "option-1"]}, + "c_response": ["Option 3", "Option 1"], }, - }], - - 'timeWidget': [{ - 'data': {'value': '22:34:00'}, - 'c_response': '22:34', - }, { - 'data': {'value': None}, - 'c_response': None, - }], - - 'numberWidget': [{ - 'data': {'value': '12'}, - 'c_response': '12', - }, { - 'data': {'value': None}, - 'c_response': None, - }], - - 'textWidget': [{ - 'data': {'value': 'This is a sample text'}, - 'c_response': 'This is a sample text', - }, { - 'data': {'value': None}, - 'c_response': '', - }], - - 'matrix1dWidget': [{ - 'data': { - 'value': { - 'pillar-2': {'subpillar-8': True}, - 'pillar-1': {'subpillar-7': False}, - 'pillar-4': {'subpillar-18': True}, + { + "data": {"value": ["option-5", "option-1"]}, + "c_response": ["Option 1"], + }, + ], + "scaleWidget": [ + { + "data": {"value": "scale-1"}, + "c_response": { + "min": {"key": "scale-1", "color": "#470000", "label": "Scale 1"}, + "max": {"key": "scale-3", "color": "#d40000", "label": "Scale 3"}, + "label": "Scale 1", + "index": 1, }, }, - 'c_response': [{ - 'id': 'subpillar-8', - 'value': 'Affected Groups', - 'row': { - 'id': 'pillar-2', - 'title': 'Humanitarian Profile', + { + "data": {"value": "scale-5"}, + "c_response": { + "min": {"key": "scale-1", "color": "#470000", "label": "Scale 1"}, + "max": {"key": "scale-3", "color": "#d40000", "label": "Scale 3"}, + "label": None, + "index": None, }, - }, { - 'id': 'subpillar-18', - 'value': 'Information Needs & Gaps', - 'row': { - 'id': 'pillar-4', - 'title': 'Information', + }, + ], + "dateWidget": [ + { + "data": {"value": "2019-06-25"}, + "c_response": "25-06-2019", + }, + { + "data": {"value": None}, + "c_response": None, + }, + ], + "dateRangeWidget": [ + { + "data": {"value": {"startDate": "2012-06-25", "endDate": "2019-06-22"}}, + "c_response": { + "from": "25-06-2012", + "to": "22-06-2019", }, - }], - }, { - 'data': { - 'value': { - 'pillar-2': {'subpillar-8': True}, - 'pillar-1': {'subpillar-12': False}, - 'pillar-4': {'subpillar-122': True}, + } + ], + "timeWidget": [ + { + "data": {"value": "22:34:00"}, + "c_response": "22:34", + }, + { + "data": {"value": None}, + "c_response": None, + }, + ], + "numberWidget": [ + { + "data": {"value": "12"}, + "c_response": "12", + }, + { + "data": {"value": None}, + "c_response": None, + }, + ], + "textWidget": [ + { + "data": {"value": "This is a sample text"}, + "c_response": "This is a sample text", + }, + { + "data": {"value": None}, + "c_response": "", + }, + ], + "matrix1dWidget": [ + { + "data": { + "value": { + "pillar-2": {"subpillar-8": True}, + "pillar-1": {"subpillar-7": False}, + "pillar-4": {"subpillar-18": True}, + }, }, + "c_response": [ + { + "id": "subpillar-8", + "value": "Affected Groups", + "row": { + "id": "pillar-2", + "title": "Humanitarian Profile", + }, + }, + { + "id": "subpillar-18", + "value": "Information Needs & Gaps", + "row": { + "id": "pillar-4", + "title": "Information", + }, + }, + ], }, - 'c_response': [{ - 'id': 'subpillar-8', - 'value': 'Affected Groups', - 'row': { - 'id': 'pillar-2', - 'title': 'Humanitarian Profile', + { + "data": { + "value": { + "pillar-2": {"subpillar-8": True}, + "pillar-1": {"subpillar-12": False}, + "pillar-4": {"subpillar-122": True}, + }, }, - }], - }], - - 'matrix2dWidget': [{ - 'data': { - 'value': { - 'dimension-0': { - 'subdimension-4': { - 'sector-1': [], - 'sector-4': ['subsector-2', 'subsector-4'], - 'sector-7': ['subsector-8', 'subsector-6'] - } + "c_response": [ + { + "id": "subpillar-8", + "value": "Affected Groups", + "row": { + "id": "pillar-2", + "title": "Humanitarian Profile", + }, } - }, + ], }, - 'c_response': [{ - 'dimension': {'id': 'dimension-0', 'title': 'Scope & Scale'}, - 'subdimension': {'id': 'subdimension-4', 'title': 'Damages & Losses'}, - 'sector': {'id': 'sector-1', 'title': 'Livelihoods'}, - 'subsectors': [] - }, { - 'dimension': {'id': 'dimension-0', 'title': 'Scope & Scale'}, - 'subdimension': {'id': 'subdimension-4', 'title': 'Damages & Losses'}, - 'sector': {'id': 'sector-4', 'title': 'WASH'}, - 'subsectors': [ - {'id': 'subsector-2', 'title': 'Sanitation'}, - {'id': 'subsector-4', 'title': 'Waste management'} - ] - }, { - 'dimension': {'id': 'dimension-0', 'title': 'Scope & Scale'}, - 'subdimension': {'id': 'subdimension-4', 'title': 'Damages & Losses'}, - 'sector': {'id': 'sector-7', 'title': 'Education'}, - 'subsectors': [ - {'id': 'subsector-8', 'title': 'Teachers and Education Personnel'}, - {'id': 'subsector-6', 'title': 'Learning Environment'} - ] - }], - }, { - 'data': { - 'value': { - 'dimension-0': { - 'subdimension-4': { - 'sector-1': [], - 'sector-4': ['subsector-10', 'subsector-4'], - 'sector-7': ['subsector-4', 'subsector-122'] + ], + "matrix2dWidget": [ + { + "data": { + "value": { + "dimension-0": { + "subdimension-4": { + "sector-1": [], + "sector-4": ["subsector-2", "subsector-4"], + "sector-7": ["subsector-8", "subsector-6"], + } } }, - 'dimension-1': { - 'subdimension-9': { - 'sector-1': [], - } + }, + "c_response": [ + { + "dimension": {"id": "dimension-0", "title": "Scope & Scale"}, + "subdimension": {"id": "subdimension-4", "title": "Damages & Losses"}, + "sector": {"id": "sector-1", "title": "Livelihoods"}, + "subsectors": [], + }, + { + "dimension": {"id": "dimension-0", "title": "Scope & Scale"}, + "subdimension": {"id": "subdimension-4", "title": "Damages & Losses"}, + "sector": {"id": "sector-4", "title": "WASH"}, + "subsectors": [ + {"id": "subsector-2", "title": "Sanitation"}, + {"id": "subsector-4", "title": "Waste management"}, + ], + }, + { + "dimension": {"id": "dimension-0", "title": "Scope & Scale"}, + "subdimension": {"id": "subdimension-4", "title": "Damages & Losses"}, + "sector": {"id": "sector-7", "title": "Education"}, + "subsectors": [ + {"id": "subsector-8", "title": "Teachers and Education Personnel"}, + {"id": "subsector-6", "title": "Learning Environment"}, + ], + }, + ], + }, + { + "data": { + "value": { + "dimension-0": { + "subdimension-4": { + "sector-1": [], + "sector-4": ["subsector-10", "subsector-4"], + "sector-7": ["subsector-4", "subsector-122"], + } + }, + "dimension-1": { + "subdimension-9": { + "sector-1": [], + } + }, }, }, + "c_response": [ + { + "dimension": {"id": "dimension-0", "title": "Scope & Scale"}, + "subdimension": {"id": "subdimension-4", "title": "Damages & Losses"}, + "sector": {"id": "sector-1", "title": "Livelihoods"}, + "subsectors": [], + }, + { + "dimension": {"id": "dimension-0", "title": "Scope & Scale"}, + "subdimension": {"id": "subdimension-4", "title": "Damages & Losses"}, + "sector": {"id": "sector-4", "title": "WASH"}, + "subsectors": [{"id": "subsector-4", "title": "Waste management"}], + }, + { + "dimension": {"id": "dimension-0", "title": "Scope & Scale"}, + "subdimension": {"id": "subdimension-4", "title": "Damages & Losses"}, + "sector": {"id": "sector-7", "title": "Education"}, + "subsectors": [], + }, + ], }, - 'c_response': [{ - 'dimension': {'id': 'dimension-0', 'title': 'Scope & Scale'}, - 'subdimension': {'id': 'subdimension-4', 'title': 'Damages & Losses'}, - 'sector': {'id': 'sector-1', 'title': 'Livelihoods'}, - 'subsectors': [] - }, { - 'dimension': {'id': 'dimension-0', 'title': 'Scope & Scale'}, - 'subdimension': {'id': 'subdimension-4', 'title': 'Damages & Losses'}, - 'sector': {'id': 'sector-4', 'title': 'WASH'}, - 'subsectors': [ - {'id': 'subsector-4', 'title': 'Waste management'} - ] - }, { - 'dimension': {'id': 'dimension-0', 'title': 'Scope & Scale'}, - 'subdimension': {'id': 'subdimension-4', 'title': 'Damages & Losses'}, - 'sector': {'id': 'sector-7', 'title': 'Education'}, - 'subsectors': [] - }], - }], - - 'timeRangeWidget': [{ - 'data': {'value': {'startTime': '18:05:00', 'endTime': '23:05:00'}}, - 'c_response': { - 'from': '18:05', - 'to': '23:05', + ], + "timeRangeWidget": [ + { + "data": {"value": {"startTime": "18:05:00", "endTime": "23:05:00"}}, + "c_response": { + "from": "18:05", + "to": "23:05", + }, + } + ], + "organigramWidget": [ + { + "data": {"value": ["node-1", "node-8"]}, + "c_response": [ + { + "key": "node-1", + "title": "Node 1", + "parents": [{"key": "base", "title": "Base Node"}], + }, + { + "key": "node-8", + "title": "Node 8", + "parents": [ + {"key": "node-7", "title": "Node 7"}, + {"key": "node-6", "title": "Node 6"}, + {"key": "node-2", "title": "Node 2"}, + {"key": "node-1", "title": "Node 1"}, + {"key": "base", "title": "Base Node"}, + ], + }, + ], }, - }], - - 'organigramWidget': [{ - 'data': {'value': ['node-1', 'node-8']}, - 'c_response': [{ - 'key': 'node-1', - 'title': 'Node 1', - 'parents': [{'key': 'base', 'title': 'Base Node'}], - }, { - 'key': 'node-8', - 'title': 'Node 8', - 'parents': [ - {'key': 'node-7', 'title': 'Node 7'}, - {'key': 'node-6', 'title': 'Node 6'}, - {'key': 'node-2', 'title': 'Node 2'}, - {'key': 'node-1', 'title': 'Node 1'}, - {'key': 'base', 'title': 'Base Node'}, - ] - }], - }, { - 'data': {'value': ['node-1', 'node-9', 'base']}, - 'c_response': [{ - 'key': 'base', - 'title': 'Base Node', - 'parents': [], - }, { - 'key': 'node-1', - 'title': 'Node 1', - 'parents': [{'key': 'base', 'title': 'Base Node'}], - }], - }], + { + "data": {"value": ["node-1", "node-9", "base"]}, + "c_response": [ + { + "key": "base", + "title": "Base Node", + "parents": [], + }, + { + "key": "node-1", + "title": "Node 1", + "parents": [{"key": "base", "title": "Base Node"}], + }, + ], + }, + ], } diff --git a/apps/entry/tests/test_apis.py b/apps/entry/tests/test_apis.py index a090be0f69..365bfb7c34 100644 --- a/apps/entry/tests/test_apis.py +++ b/apps/entry/tests/test_apis.py @@ -1,74 +1,59 @@ import autofixture - -from reversion.models import Version - -from deep.tests import TestCase -from project.models import Project -from user.models import User -from lead.models import Lead, LeadPreviewImage -from organization.models import Organization, OrganizationType -from analysis_framework.models import ( - AnalysisFramework, Widget, Filter -) +from analysis_framework.models import AnalysisFramework, Filter, Widget from entry.models import ( - Entry, Attribute, + Entry, + EntryGroupLabel, FilterData, - ProjectEntryLabel, LeadEntryGroup, - EntryGroupLabel, + ProjectEntryLabel, ) - from gallery.models import File -from tabular.models import Sheet, Field +from lead.models import Lead, LeadPreviewImage +from organization.models import Organization, OrganizationType +from project.models import Project +from reversion.models import Version +from tabular.models import Field, Sheet +from user.models import User + +from deep.tests import TestCase class EntryTests(TestCase): def create_entry_with_data_series(self): sheet = autofixture.create_one(Sheet, generate_fk=True) series = [ # create some dummy values - { - 'value': 'male', 'processed_value': 'male', - 'invalid': False, 'empty': False - }, - { - 'value': 'female', 'processed_value': 'female', - 'invalid': False, 'empty': False - }, - { - 'value': 'female', 'processed_value': 'female', - 'invalid': False, 'empty': False - }, + {"value": "male", "processed_value": "male", "invalid": False, "empty": False}, + {"value": "female", "processed_value": "female", "invalid": False, "empty": False}, + {"value": "female", "processed_value": "female", "invalid": False, "empty": False}, ] cache_series = [ - {'value': 'male', 'count': 1}, - {'value': 'female', 'count': 2}, + {"value": "male", "count": 1}, + {"value": "female", "count": 2}, ] health_stats = { - 'invalid': 10, - 'total': 20, - 'empty': 10, + "invalid": 10, + "total": 20, + "empty": 10, } field = autofixture.create_one( Field, field_values={ - 'sheet': sheet, - 'title': 'Abrakadabra', - 'type': Field.STRING, - 'data': series, - 'cache': { - 'status': Field.CACHE_SUCCESS, - 'series': cache_series, - 'health_stats': health_stats, - 'images': [], + "sheet": sheet, + "title": "Abrakadabra", + "type": Field.STRING, + "data": series, + "cache": { + "status": Field.CACHE_SUCCESS, + "series": cache_series, + "health_stats": health_stats, + "images": [], }, - } + }, ) - entry = self.create_entry( - tabular_field=field, entry_type=Entry.TagType.DATA_SERIES - ) + entry = self.create_entry(tabular_field=field, entry_type=Entry.TagType.DATA_SERIES) return entry, field def test_search_filter_polygon(self): @@ -77,36 +62,34 @@ def test_search_filter_polygon(self): Widget, analysis_framework=lead.project.analysis_framework, widget_id=Widget.WidgetType.GEO, - key='geoWidget-101', + key="geoWidget-101", ) - url = '/api/v1/entries/' + url = "/api/v1/entries/" data = { - 'lead': lead.pk, - 'project': lead.project.pk, - 'analysis_framework': geo_widget.analysis_framework.pk, - 'excerpt': 'This is test excerpt', - 'attributes': { + "lead": lead.pk, + "project": lead.project.pk, + "analysis_framework": geo_widget.analysis_framework.pk, + "excerpt": "This is test excerpt", + "attributes": { geo_widget.pk: { - 'data': { - 'value': [1, 2, {'type': 'Point'}] - }, + "data": {"value": [1, 2, {"type": "Point"}]}, }, }, } self.authenticate() self.client.post(url, data) - data['attributes'][geo_widget.pk]['data']['value'] = [{'type': 'Polygon'}] + data["attributes"][geo_widget.pk]["data"]["value"] = [{"type": "Polygon"}] self.client.post(url, data) - data['attributes'][geo_widget.pk]['data']['value'] = [{'type': 'Line'}, {'type': 'Polygon'}] + data["attributes"][geo_widget.pk]["data"]["value"] = [{"type": "Line"}, {"type": "Polygon"}] self.client.post(url, data) - filters = {'geo_custom_shape': 'Point'} + filters = {"geo_custom_shape": "Point"} self.post_filter_test(filters, 1) - filters['geo_custom_shape'] = 'Polygon' + filters["geo_custom_shape"] = "Polygon" self.post_filter_test(filters, 2) - filters['geo_custom_shape'] = 'Point,Line,Polygon' + filters["geo_custom_shape"] = "Point,Line,Polygon" self.post_filter_test(filters, 3) def test_filter_entries_by_type(self): @@ -119,21 +102,21 @@ def test_filter_entries_by_type(self): self.authenticate() self.post_filter_test( # Filter - {'entry_type': [Entry.TagType.EXCERPT, Entry.TagType.IMAGE]}, + {"entry_type": [Entry.TagType.EXCERPT, Entry.TagType.IMAGE]}, # Count - Entry.objects.filter(entry_type__in=[Entry.TagType.EXCERPT, Entry.TagType.IMAGE]).count() + Entry.objects.filter(entry_type__in=[Entry.TagType.EXCERPT, Entry.TagType.IMAGE]).count(), ) self.post_filter_test( # Filter - {'entry_type': [Entry.TagType.EXCERPT]}, + {"entry_type": [Entry.TagType.EXCERPT]}, # Count - Entry.objects.filter(entry_type__in=[Entry.TagType.EXCERPT]).count() + Entry.objects.filter(entry_type__in=[Entry.TagType.EXCERPT]).count(), ) self.post_filter_test( # Filter - {'entry_type': [Entry.TagType.IMAGE, Entry.TagType.DATA_SERIES]}, + {"entry_type": [Entry.TagType.IMAGE, Entry.TagType.DATA_SERIES]}, # Count - Entry.objects.filter(entry_type__in=[Entry.TagType.IMAGE, Entry.TagType.DATA_SERIES]).count() + Entry.objects.filter(entry_type__in=[Entry.TagType.IMAGE, Entry.TagType.DATA_SERIES]).count(), ) def test_search_filter_entry_group_label(self): @@ -145,13 +128,13 @@ def test_search_filter_entry_group_label(self): entry2 = self.create_entry(lead=lead) # Labels - label1 = self.create(ProjectEntryLabel, project=project, title='Label 1', order=1, color='#23f23a') - label2 = self.create(ProjectEntryLabel, project=project, title='Label 2', order=2, color='#23f23a') + label1 = self.create(ProjectEntryLabel, project=project, title="Label 1", order=1, color="#23f23a") + label2 = self.create(ProjectEntryLabel, project=project, title="Label 2", order=2, color="#23f23a") # Groups - group1 = self.create(LeadEntryGroup, lead=lead, title='Group 1', order=1) - group2 = self.create(LeadEntryGroup, lead=lead, title='Group 2', order=2) - group3 = self.create(LeadEntryGroup, lead=lead, title='Group 3', order=3) + group1 = self.create(LeadEntryGroup, lead=lead, title="Group 1", order=1) + group2 = self.create(LeadEntryGroup, lead=lead, title="Group 2", order=2) + group3 = self.create(LeadEntryGroup, lead=lead, title="Group 3", order=3) [ self.create(EntryGroupLabel, group=group, label=label, entry=entry) @@ -162,14 +145,14 @@ def test_search_filter_entry_group_label(self): ] ] - default_filter = {'project': project.id} + default_filter = {"project": project.id} self.authenticate() - self.post_filter_test({**default_filter, 'project_entry_labels': [label1.pk]}, 2) - self.post_filter_test({**default_filter, 'project_entry_labels': [label2.pk]}, 1) + self.post_filter_test({**default_filter, "project_entry_labels": [label1.pk]}, 2) + self.post_filter_test({**default_filter, "project_entry_labels": [label2.pk]}, 1) - self.post_filter_test({**default_filter, 'lead_group_label': group1.title}, 2) - self.post_filter_test({**default_filter, 'lead_group_label': 'Group'}, 2) - self.post_filter_test({**default_filter, 'lead_group_label': group3.title}, 0) + self.post_filter_test({**default_filter, "lead_group_label": group1.title}, 2) + self.post_filter_test({**default_filter, "lead_group_label": "Group"}, 2) + self.post_filter_test({**default_filter, "lead_group_label": group3.title}, 0) def test_create_entry(self): entry_count = Entry.objects.count() @@ -179,18 +162,18 @@ def test_create_entry(self): Widget, analysis_framework=lead.project.analysis_framework, widget_id=Widget.WidgetType.TEXT, - key='text-102', + key="text-102", ) - url = '/api/v1/entries/' + url = "/api/v1/entries/" data = { - 'lead': lead.pk, - 'project': lead.project.pk, - 'analysis_framework': widget.analysis_framework.pk, - 'excerpt': 'This is test excerpt', - 'attributes': { + "lead": lead.pk, + "project": lead.project.pk, + "analysis_framework": widget.analysis_framework.pk, + "excerpt": "This is test excerpt", + "attributes": { widget.pk: { - 'data': {'a': 'b'}, + "data": {"a": "b"}, }, }, } @@ -201,26 +184,23 @@ def test_create_entry(self): r_data = response.json() self.assertEqual(Entry.objects.count(), entry_count + 1) - self.assertEqual(r_data['versionId'], 1) - self.assertEqual(r_data['excerpt'], data['excerpt']) + self.assertEqual(r_data["versionId"], 1) + self.assertEqual(r_data["excerpt"], data["excerpt"]) - attributes = r_data['attributes'] + attributes = r_data["attributes"] self.assertEqual(len(attributes.values()), 1) - attribute = Attribute.objects.get( - id=attributes[str(widget.pk)]['id'] - ) + attribute = Attribute.objects.get(id=attributes[str(widget.pk)]["id"]) self.assertEqual(attribute.widget.pk, widget.pk) - self.assertEqual(attribute.data['a'], 'b') + self.assertEqual(attribute.data["a"], "b") # Check if project matches - entry = Entry.objects.get(id=r_data['id']) + entry = Entry.objects.get(id=r_data["id"]) self.assertEqual(entry.project, entry.lead.project) def test_create_entry_no_project(self): - """Even without project parameter, entry should be created(using project from lead) - """ + """Even without project parameter, entry should be created(using project from lead)""" entry_count = Entry.objects.count() lead = self.create_lead() @@ -228,17 +208,17 @@ def test_create_entry_no_project(self): Widget, analysis_framework=lead.project.analysis_framework, widget_id=Widget.WidgetType.TEXT, - key='text-103', + key="text-103", ) - url = '/api/v1/entries/' + url = "/api/v1/entries/" data = { - 'lead': lead.pk, - 'analysis_framework': widget.analysis_framework.pk, - 'excerpt': 'This is test excerpt', - 'attributes': { + "lead": lead.pk, + "analysis_framework": widget.analysis_framework.pk, + "excerpt": "This is test excerpt", + "attributes": { widget.pk: { - 'data': {'a': 'b'}, + "data": {"a": "b"}, }, }, } @@ -249,21 +229,19 @@ def test_create_entry_no_project(self): r_data = response.json() self.assertEqual(Entry.objects.count(), entry_count + 1) - self.assertEqual(r_data['versionId'], 1) - self.assertEqual(r_data['excerpt'], data['excerpt']) + self.assertEqual(r_data["versionId"], 1) + self.assertEqual(r_data["excerpt"], data["excerpt"]) - attributes = r_data['attributes'] + attributes = r_data["attributes"] self.assertEqual(len(attributes.values()), 1) - attribute = Attribute.objects.get( - id=attributes[str(widget.pk)]['id'] - ) + attribute = Attribute.objects.get(id=attributes[str(widget.pk)]["id"]) self.assertEqual(attribute.widget.pk, widget.pk) - self.assertEqual(attribute.data['a'], 'b') + self.assertEqual(attribute.data["a"], "b") # Check if project matches - entry = Entry.objects.get(id=r_data['id']) + entry = Entry.objects.get(id=r_data["id"]) self.assertEqual(entry.project, entry.lead.project) def test_create_entry_no_perm(self): @@ -274,21 +252,21 @@ def test_create_entry_no_perm(self): Widget, analysis_framework=lead.project.analysis_framework, widget_id=Widget.WidgetType.TEXT, - key='text-104', + key="text-104", ) user = self.create(User) lead.project.add_member(user, self.view_only_role) - url = '/api/v1/entries/' + url = "/api/v1/entries/" data = { - 'lead': lead.pk, - 'project': lead.project.pk, - 'analysis_framework': widget.analysis_framework.pk, - 'excerpt': 'This is test excerpt', - 'attributes': { + "lead": lead.pk, + "project": lead.project.pk, + "analysis_framework": widget.analysis_framework.pk, + "excerpt": "This is test excerpt", + "attributes": { widget.pk: { - 'data': {'a': 'b'}, + "data": {"a": "b"}, }, }, } @@ -302,7 +280,7 @@ def test_create_entry_no_perm(self): def test_delete_entry(self): entry = self.create_entry() - url = '/api/v1/entries/{}/'.format(entry.id) + url = "/api/v1/entries/{}/".format(entry.id) self.authenticate() @@ -314,7 +292,7 @@ def test_delete_entry_no_perm(self): user = self.create(User) entry.project.add_member(user, self.view_only_role) - url = '/api/v1/entries/{}/'.format(entry.id) + url = "/api/v1/entries/{}/".format(entry.id) self.authenticate(user) @@ -325,14 +303,14 @@ def test_duplicate_entry(self): entry_count = Entry.objects.count() lead = self.create_lead() - client_id = 'randomId123' - url = '/api/v1/entries/' + client_id = "randomId123" + url = "/api/v1/entries/" data = { - 'lead': lead.pk, - 'project': lead.project.pk, - 'excerpt': 'Test excerpt', - 'analysis_framework': lead.project.analysis_framework.id, - 'client_id': client_id, + "lead": lead.pk, + "project": lead.project.pk, + "excerpt": "Test excerpt", + "analysis_framework": lead.project.analysis_framework.id, + "client_id": client_id, } self.authenticate() @@ -341,7 +319,7 @@ def test_duplicate_entry(self): r_data = response.json() self.assertEqual(Entry.objects.count(), entry_count + 1) - self.assertEqual(r_data['clientId'], client_id) + self.assertEqual(r_data["clientId"], client_id) response = self.client.post(url, data) self.assert_500(response) @@ -352,29 +330,29 @@ def test_patch_attributes(self): Widget, analysis_framework=entry.lead.project.analysis_framework, widget_id=Widget.WidgetType.TEXT, - key='text-105', + key="text-105", ) widget2 = self.create( Widget, analysis_framework=entry.lead.project.analysis_framework, widget_id=Widget.WidgetType.TEXT, - key='text-106', + key="text-106", ) self.create( Attribute, - data={'a': 'b'}, + data={"a": "b"}, widget=widget1, ) - url = '/api/v1/entries/{}/'.format(entry.id) + url = "/api/v1/entries/{}/".format(entry.id) data = { - 'attributes': { + "attributes": { widget1.pk: { - 'data': {'c': 'd'}, + "data": {"c": "d"}, }, widget2.pk: { - 'data': {'e': 'f'}, - } + "data": {"e": "f"}, + }, }, } @@ -383,21 +361,21 @@ def test_patch_attributes(self): self.assert_200(response) r_data = response.json() - attributes = r_data['attributes'] + attributes = r_data["attributes"] self.assertEqual(len(attributes.values()), 2) attribute1 = attributes[str(widget1.pk)] - self.assertEqual(attribute1['data']['c'], 'd') + self.assertEqual(attribute1["data"]["c"], "d") attribute2 = attributes[str(widget2.pk)] - self.assertEqual(attribute2['data']['e'], 'f') + self.assertEqual(attribute2["data"]["e"], "f") def test_entry_options(self): - url = '/api/v1/entry-options/' + url = "/api/v1/entry-options/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertIn('created_by', response.data) + self.assertIn("created_by", response.data) def test_entry_options_in_project(self): user1 = self.create_user() @@ -415,106 +393,102 @@ def test_entry_options_in_project(self): self.create(Entry, lead=lead3, project=project1, created_by=user2) # filter by project2 - url = f'/api/v1/entry-options/?project={project2.id}' + url = f"/api/v1/entry-options/?project={project2.id}" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) # gives all the member of the project - self.assertEqual( - set([item['key'] for item in response.data['created_by']]), - set([user1.id, user2.id]) - ) + self.assertEqual(set([item["key"] for item in response.data["created_by"]]), set([user1.id, user2.id])) # filter by project1 - url = f'/api/v1/entry-options/?project={project1.id}' + url = f"/api/v1/entry-options/?project={project1.id}" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) # gives all the member of the project - self.assertEqual(user1.id, response.data['created_by'][0]['key']) - self.assertEqual(len(response.data['created_by']), 1) + self.assertEqual(user1.id, response.data["created_by"][0]["key"]) + self.assertEqual(len(response.data["created_by"]), 1) def filter_test(self, params, count=1): - url = '/api/v1/entries/?{}'.format(params) + url = "/api/v1/entries/?{}".format(params) self.authenticate() response = self.client.get(url) self.assert_200(response) r_data = response.json() - self.assertEqual(len(r_data['results']), count) + self.assertEqual(len(r_data["results"]), count) def post_filter_test(self, filters, count=1, skip_auth=False): - return super().post_filter_test('/api/v1/entries/filter/', filters, count=count, skip_auth=skip_auth) + return super().post_filter_test("/api/v1/entries/filter/", filters, count=count, skip_auth=skip_auth) def both_filter_test(self, filters, count=1): self.filter_test(filters, count) - k, v = filters.split('=') + k, v = filters.split("=") filters = {k: v} self.post_filter_test(filters, count) def test_filters(self): entry = self.create_entry() - self.filter_test('controlled=False', 1) - self.filter_test('controlled=True', 0) + self.filter_test("controlled=False", 1) + self.filter_test("controlled=True", 0) filter = self.create( Filter, analysis_framework=entry.analysis_framework, - widget_key='test_filter', - key='test_filter', - title='Test Filter', + widget_key="test_filter", + key="test_filter", + title="Test Filter", filter_type=Filter.FilterType.NUMBER, ) self.create(FilterData, entry=entry, filter=filter, number=500) - self.both_filter_test('test_filter=500') - self.both_filter_test('test_filter__lt=600') - self.both_filter_test('test_filter__gt=400') - self.both_filter_test('test_filter__lt=400', 0) + self.both_filter_test("test_filter=500") + self.both_filter_test("test_filter__lt=600") + self.both_filter_test("test_filter__gt=400") + self.both_filter_test("test_filter__lt=400", 0) filter = self.create( Filter, analysis_framework=entry.analysis_framework, - widget_key='test_list_filter', - key='test_list_filter', - title='Test List Filter', + widget_key="test_list_filter", + key="test_list_filter", + title="Test List Filter", filter_type=Filter.FilterType.LIST, ) - self.create(FilterData, entry=entry, filter=filter, - values=['abc', 'def', 'ghi']) + self.create(FilterData, entry=entry, filter=filter, values=["abc", "def", "ghi"]) - self.both_filter_test('test_list_filter=abc') - self.both_filter_test('test_list_filter=ghi,def', 1) - self.both_filter_test('test_list_filter=uml,hij', 0) + self.both_filter_test("test_list_filter=abc") + self.both_filter_test("test_list_filter=ghi,def", 1) + self.both_filter_test("test_list_filter=uml,hij", 0) - entry.excerpt = 'hello' + entry.excerpt = "hello" entry.save() - self.post_filter_test({'search': 'el'}, 1) - self.post_filter_test({'search': 'pollo'}, 0) + self.post_filter_test({"search": "el"}, 1) + self.post_filter_test({"search": "pollo"}, 0) def test_lead_published_on_filter(self): - lead1 = self.create_lead(published_on='2020-09-25') - lead2 = self.create_lead(published_on='2020-09-26') - lead3 = self.create_lead(published_on='2020-09-27') + lead1 = self.create_lead(published_on="2020-09-25") + lead2 = self.create_lead(published_on="2020-09-26") + lead3 = self.create_lead(published_on="2020-09-27") self.create_entry(lead=lead1) self.create_entry(lead=lead2) self.create_entry(lead=lead3) filters = { - 'lead_published_on__gte': '2020-09-25', - 'lead_published_on__lte': '2020-09-26', + "lead_published_on__gte": "2020-09-25", + "lead_published_on__lte": "2020-09-26", } self.authenticate() self.post_filter_test(filters, 2) # simulate filter behaviour of today from the frontend filters = { - 'lead_published_on__gte': '2020-09-25', - 'lead_published_on__lt': '2020-09-26', + "lead_published_on__gte": "2020-09-25", + "lead_published_on__lt": "2020-09-26", } self.post_filter_test(filters, 1) @@ -538,79 +512,76 @@ def test_lead_assignee_filter(self): # test assignee created by self user filters = { - 'lead_assignee': [self.user.pk], + "lead_assignee": [self.user.pk], } self.authenticate() self.post_filter_test(filters, 4) # test assignee created by another user filters = { - 'lead_assignee': [another_user.pk], + "lead_assignee": [another_user.pk], } self.post_filter_test(filters, 2) # test assignee created by both users filters = { - 'lead_assignee': [self.user.pk, another_user.pk], + "lead_assignee": [self.user.pk, another_user.pk], } self.post_filter_test(filters, 6) def test_search_filter(self): entry, field = self.create_entry_with_data_series() filters = { - 'search': 'kadabra', + "search": "kadabra", } self.authenticate() self.post_filter_test(filters) # Should have single result filters = { - 'comment_status': 'resolved', - 'comment_assignee': self.user.pk, - 'comment_created_by': self.user.pk, + "comment_status": "resolved", + "comment_assignee": self.user.pk, + "comment_created_by": self.user.pk, } self.post_filter_test(filters, 0) # Should have no result - filters['comment_status'] = 'unresolved' + filters["comment_status"] = "unresolved" self.post_filter_test(filters, 0) # Should have no result def test_project_label_api(self): project = self.create_project(is_private=True) - label1 = self.create(ProjectEntryLabel, project=project, title='Label 1', color='color', order=1) - label2 = self.create(ProjectEntryLabel, project=project, title='Label 2', color='color', order=2) - label3 = self.create(ProjectEntryLabel, project=project, title='Label 3', color='color', order=3) + label1 = self.create(ProjectEntryLabel, project=project, title="Label 1", color="color", order=1) + label2 = self.create(ProjectEntryLabel, project=project, title="Label 2", color="color", order=2) + label3 = self.create(ProjectEntryLabel, project=project, title="Label 3", color="color", order=3) # Non member user self.authenticate(self.create_user()) - url = f'/api/v1/projects/{project.pk}/entry-labels/' + url = f"/api/v1/projects/{project.pk}/entry-labels/" response = self.client.get(url) self.assert_403(response) # List API self.authenticate() - url = f'/api/v1/projects/{project.pk}/entry-labels/' + url = f"/api/v1/projects/{project.pk}/entry-labels/" response = self.client.get(url) - assert len(response.json()['results']) == 3 + assert len(response.json()["results"]) == 3 # Bulk update API - url = f'/api/v1/projects/{project.pk}/entry-labels/bulk-update-order/' + url = f"/api/v1/projects/{project.pk}/entry-labels/bulk-update-order/" order_data = [ - {'id': label1.pk, 'order': 3}, - {'id': label2.pk, 'order': 2}, - {'id': label3.pk, 'order': 1}, + {"id": label1.pk, "order": 3}, + {"id": label2.pk, "order": 2}, + {"id": label3.pk, "order": 1}, ] response = self.client.post(url, order_data) - self.assertEqual( - {d['id']: d['order'] for d in order_data}, - {d['id']: d['order'] for d in response.json()} - ) + self.assertEqual({d["id"]: d["order"] for d in order_data}, {d["id"]: d["order"] for d in response.json()}) def test_control_entry(self): entry = self.create_entry() user = self.create(User) entry.project.add_member(user, self.view_only_role) - control_url = '/api/v1/entries/{}/control/'.format(entry.id) - uncontrol_url = '/api/v1/entries/{}/uncontrol/'.format(entry.id) + control_url = "/api/v1/entries/{}/control/".format(entry.id) + uncontrol_url = "/api/v1/entries/{}/uncontrol/".format(entry.id) self.authenticate(user) @@ -623,23 +594,23 @@ def test_control_entry(self): self.authenticate() current_version = Version.objects.get_for_object(entry).count() - response = self.client.post(control_url, {'version_id': current_version}, format='json') + response = self.client.post(control_url, {"version_id": current_version}, format="json") self.assert_200(response) entry.refresh_from_db() self.assertTrue(entry.controlled) current_version = Version.objects.get_for_object(entry).count() - response = self.client.post(uncontrol_url, {'version_id': current_version}, format='json') + response = self.client.post(uncontrol_url, {"version_id": current_version}, format="json") self.assert_200(response) response_data = response.json() - assert response_data['id'] == entry.pk - assert response_data['versionId'] != current_version - assert response_data['versionId'] == current_version + 1 + assert response_data["id"] == entry.pk + assert response_data["versionId"] != current_version + assert response_data["versionId"] == current_version + 1 entry.refresh_from_db() self.assertFalse(entry.controlled) # With old current_version - response = self.client.post(uncontrol_url, {'version_id': current_version}, format='json') + response = self.client.post(uncontrol_url, {"version_id": current_version}, format="json") self.assert_400(response) def test_authoring_organization_filter(self): @@ -665,29 +636,25 @@ def test_authoring_organization_filter(self): self.create_entry(lead=lead3) # Test for GET - url = '/api/v1/entries/?authoring_organization_types={}' + url = "/api/v1/entries/?authoring_organization_types={}" self.authenticate() response = self.client.get(url.format(organization_type1.id)) self.assert_200(response) - assert len(response.data['results']) == 2, "There should be 2 entry" + assert len(response.data["results"]) == 2, "There should be 2 entry" # get multiple leads - organization_type_query = ','.join([ - str(id) for id in [organization_type1.id, organization_type3.id] - ]) + organization_type_query = ",".join([str(id) for id in [organization_type1.id, organization_type3.id]]) response = self.client.get(url.format(organization_type_query)) - assert len(response.data['results']) == 3, "There should be 3 entry" + assert len(response.data["results"]) == 3, "There should be 3 entry" # filter single post filters = { - 'authoring_organization_types': [organization_type1.id], + "authoring_organization_types": [organization_type1.id], } self.post_filter_test(filters, 2) - filters = { - 'authoring_organization_types': [organization_type1.id, organization_type3.id] - } + filters = {"authoring_organization_types": [organization_type1.id, organization_type3.id]} self.post_filter_test(filters, 3) def test_entry_image_validation(self): @@ -695,72 +662,75 @@ def test_entry_image_validation(self): user = self.create_user() lead.project.add_member(user, role=self.normal_role) - url = '/api/v1/entries/' + url = "/api/v1/entries/" data = { - 'lead': lead.pk, - 'project': lead.project.pk, - 'analysis_framework': lead.project.analysis_framework.pk, - 'excerpt': 'This is test excerpt', - 'attributes': {}, + "lead": lead.pk, + "project": lead.project.pk, + "analysis_framework": lead.project.analysis_framework.pk, + "excerpt": "This is test excerpt", + "attributes": {}, } self.authenticate() image = self.create_gallery_file() # Using raw image - data['image_raw'] = '' # noqa: E501 + data["image_raw"] = ( + "" # noqa: E501 + ) response = self.client.post(url, data) self.assert_201(response) - assert 'image' in response.data - assert 'image_details' in response.data - data.pop('image_raw') + assert "image" in response.data + assert "image_details" in response.data + data.pop("image_raw") # Try to update entry with another user. we don't want 400 here as we are not updating image self.authenticate(user) - response = self.client.patch(f"{url}{response.data['id']}/", {'attributes': {}}) + response = self.client.patch(f"{url}{response.data['id']}/", {"attributes": {}}) self.assert_200(response) - assert 'image' in response.data - assert 'image_details' in response.data + assert "image" in response.data + assert "image_details" in response.data self.authenticate() # Using lead image (same lead) - data['lead_image'] = self.create(LeadPreviewImage, lead=lead, file=image.file).pk + data["lead_image"] = self.create(LeadPreviewImage, lead=lead, file=image.file).pk response = self.client.post(url, data) self.assert_201(response) - assert 'image' in response.data - assert 'image_details' in response.data - data.pop('lead_image') + assert "image" in response.data + assert "image_details" in response.data + data.pop("lead_image") # Using lead image (different lead) - data['lead_image'] = self.create(LeadPreviewImage, lead=self.create_lead(), file=image.file).pk + data["lead_image"] = self.create(LeadPreviewImage, lead=self.create_lead(), file=image.file).pk response = self.client.post(url, data) self.assert_400(response) - data.pop('lead_image') + data.pop("lead_image") # Using gallery file (owned) - data['image'] = image.pk + data["image"] = image.pk response = self.client.post(url, data) self.assert_201(response) - assert 'image' in response.data - assert 'image_details' in response.data - data.pop('image') + assert "image" in response.data + assert "image_details" in response.data + data.pop("image") # Using gallery file (not owned) image.created_by = self.root_user image.is_public = False image.save() - data['image'] = image.pk + data["image"] = image.pk response = self.client.post(url, data) self.assert_400(response) - data.pop('image') + data.pop("image") # Using gallery file (not owned but public) image.is_public = True image.save() - data['image'] = image.pk + data["image"] = image.pk response = self.client.post(url, data) self.assert_201(response) - data.pop('image') + data.pop("image") + # TODO: test export data and filter data apis def test_entry_id_filter(self): @@ -773,53 +743,43 @@ def test_entry_id_filter(self): self.authenticate(user) # only the entry of project that user is member - self.post_filter_test({'entries_id': [entry1.pk, entry3.pk]}, 1, skip_auth=False) + self.post_filter_test({"entries_id": [entry1.pk, entry3.pk]}, 1, skip_auth=False) # try filtering out the entries that the user is not member of # Only the entry of project that user is member - self.post_filter_test({'entries_id': [entry1.pk, entry2.pk, entry3.pk]}, 2, skip_auth=False) + self.post_filter_test({"entries_id": [entry1.pk, entry2.pk, entry3.pk]}, 2, skip_auth=False) # try authenticating with default user created with project self.authenticate() # There should be 3 the entry - self.post_filter_test({'entries_id': [entry1.pk, entry2.pk, entry3.pk]}, 3, skip_auth=False) + self.post_filter_test({"entries_id": [entry1.pk, entry2.pk, entry3.pk]}, 3, skip_auth=False) class EntryTest(TestCase): def setUp(self): super().setUp() - self.file = File.objects.create(title='test') + self.file = File.objects.create(title="test") def create_project(self): analysis_framework = self.create(AnalysisFramework) - return self.create( - Project, analysis_framework=analysis_framework, - role=self.admin_role - ) + return self.create(Project, analysis_framework=analysis_framework, role=self.admin_role) def create_lead(self, **fields): project = self.create_project() return self.create(Lead, project=project, **fields) def create_entry(self, **fields): - lead = fields.pop('lead', self.create_lead()) - return self.create( - Entry, lead=lead, project=lead.project, - analysis_framework=lead.project.analysis_framework, - **fields - ) + lead = fields.pop("lead", self.create_lead()) + return self.create(Entry, lead=lead, project=lead.project, analysis_framework=lead.project.analysis_framework, **fields) def test_entry_no_image(self): - entry = self.create_entry(image=None, image_raw='') + entry = self.create_entry(image=None, image_raw="") assert entry.get_image_url() is None def test_entry_image(self): - entry_image_url = '/some/path' + entry_image_url = "/some/path" file = File.objects.get(id=self.file.id) - entry_with_raw_image = self.create_entry( - image=None, - image_raw='{}/{}'.format(entry_image_url, file.id) - ) + entry_with_raw_image = self.create_entry(image=None, image_raw="{}/{}".format(entry_image_url, file.id)) entry_with_image = self.create_entry(image=File.objects.get(id=file.id)) assert entry_with_raw_image.get_image_url() == file.get_file_url() assert entry_with_image.get_image_url() == file.get_file_url() @@ -845,33 +805,45 @@ def test_list_entries_summary(self): self.create_entry(lead=lead2) self.create_entry(lead=lead2) - url = '/api/v1/entries/filter/' + url = "/api/v1/entries/filter/" self.authenticate() - response = self.client.post(url, dict(calculate_summary='1')) + response = self.client.post(url, dict(calculate_summary="1")) self.assert_200(response) r_data = response.json() - self.assertIn('summary', r_data) - summ = r_data['summary'] - self.assertEqual(summ['totalControlledEntries'], Entry.objects.filter(controlled=True).count()) - self.assertEqual(summ['totalUncontrolledEntries'], Entry.objects.filter(controlled=False).count()) - self.assertEqual(summ['totalLeads'], len([lead1, lead2])) - self.assertEqual(summ['totalSources'], len({org1, org3})) - - self.assertTrue({'org': {'id': org_type1.id, 'shortName': org_type1.short_name, 'title': org_type1.title}, 'count': 2} in summ['orgTypeCount']) # noqa: E501 - self.assertTrue({'org': {'id': org_type2.id, 'shortName': org_type2.short_name, 'title': org_type2.title}, 'count': 1} in summ['orgTypeCount']) # noqa: E501 - - url = '/api/v1/entries/?calculate_summary=1' + self.assertIn("summary", r_data) + summ = r_data["summary"] + self.assertEqual(summ["totalControlledEntries"], Entry.objects.filter(controlled=True).count()) + self.assertEqual(summ["totalUncontrolledEntries"], Entry.objects.filter(controlled=False).count()) + self.assertEqual(summ["totalLeads"], len([lead1, lead2])) + self.assertEqual(summ["totalSources"], len({org1, org3})) + + self.assertTrue( + {"org": {"id": org_type1.id, "shortName": org_type1.short_name, "title": org_type1.title}, "count": 2} + in summ["orgTypeCount"] + ) # noqa: E501 + self.assertTrue( + {"org": {"id": org_type2.id, "shortName": org_type2.short_name, "title": org_type2.title}, "count": 1} + in summ["orgTypeCount"] + ) # noqa: E501 + + url = "/api/v1/entries/?calculate_summary=1" self.authenticate() response = self.client.get(url) self.assert_200(response) r_data = response.json() - self.assertIn('summary', r_data) - summ = r_data['summary'] - self.assertEqual(summ['totalControlledEntries'], Entry.objects.filter(controlled=True).count()) - self.assertEqual(summ['totalUncontrolledEntries'], Entry.objects.filter(controlled=False).count()) - self.assertEqual(summ['totalLeads'], len([lead1, lead2])) - self.assertEqual(summ['totalSources'], len({org1, org3})) - self.assertTrue({'org': {'id': org_type1.id, 'shortName': org_type1.short_name, 'title': org_type1.title}, 'count': 2} in summ['orgTypeCount']) # noqa: E501 - self.assertTrue({'org': {'id': org_type2.id, 'shortName': org_type2.short_name, 'title': org_type2.title}, 'count': 1} in summ['orgTypeCount']) # noqa: E501 + self.assertIn("summary", r_data) + summ = r_data["summary"] + self.assertEqual(summ["totalControlledEntries"], Entry.objects.filter(controlled=True).count()) + self.assertEqual(summ["totalUncontrolledEntries"], Entry.objects.filter(controlled=False).count()) + self.assertEqual(summ["totalLeads"], len([lead1, lead2])) + self.assertEqual(summ["totalSources"], len({org1, org3})) + self.assertTrue( + {"org": {"id": org_type1.id, "shortName": org_type1.short_name, "title": org_type1.title}, "count": 2} + in summ["orgTypeCount"] + ) # noqa: E501 + self.assertTrue( + {"org": {"id": org_type2.id, "shortName": org_type2.short_name, "title": org_type2.title}, "count": 1} + in summ["orgTypeCount"] + ) # noqa: E501 diff --git a/apps/entry/tests/test_comprehensive_apis.py b/apps/entry/tests/test_comprehensive_apis.py index 0e6c30a7d7..2da0ff8c35 100644 --- a/apps/entry/tests/test_comprehensive_apis.py +++ b/apps/entry/tests/test_comprehensive_apis.py @@ -1,17 +1,11 @@ +from analysis_framework.models import Widget +from entry.models import Attribute +from entry.widgets.store import widget_store from parameterized import parameterized from deep.tests import TestCase -from analysis_framework.models import ( - Widget, -) -from entry.models import ( - Attribute, -) -from entry.widgets.store import widget_store - -from .entry_widget_test_data import WIDGET_PROPERTIES, ATTRIBUTE_DATA - +from .entry_widget_test_data import ATTRIBUTE_DATA, WIDGET_PROPERTIES SKIP_WIDGETS = [ Widget.WidgetType.GEO, @@ -39,8 +33,8 @@ def create_widget(self, widget_id, widget_properties): analysis_framework=project.analysis_framework, properties=widget_properties, widget_id=widget_id, - key=f'{widget_id}-{self._counter}', - title=f'{widget_id}-{self._counter} (Title)', + key=f"{widget_id}-{self._counter}", + title=f"{widget_id}-{self._counter} (Title)", ) self._counter += 1 return widget @@ -64,43 +58,52 @@ def assertAttributeValue(self, widgets_meta, widget_id, widget_properties, attr_ expected_c_response = expected_c_response or {} widget, attribute = self.create_attribute(widget_id, widget_properties, attr_data) data = attribute.data or {} - c_resposne = self.get_data_selector(widget_id)( - widgets_meta, widget, data, widget.properties, - ) or {} + c_resposne = ( + self.get_data_selector(widget_id)( + widgets_meta, + widget, + data, + widget.properties, + ) + or {} + ) if widget_id in (Widget.WidgetType.SCALE,): # new key 'scale' is appended - self.assertTrue( - expected_c_response.items() <= c_resposne.items(), - (expected_c_response.items(), c_resposne.items()) - ) + self.assertTrue(expected_c_response.items() <= c_resposne.items(), (expected_c_response.items(), c_resposne.items())) else: self.assertEqual(expected_c_response, c_resposne) def _test_widget(self, widget_id): widget_properties = WIDGET_PROPERTIES[widget_id] - if not hasattr(self, 'widgets_meta'): + if not hasattr(self, "widgets_meta"): self.widgets_meta = {} for attribute_data in ATTRIBUTE_DATA[widget_id]: - attr_data = attribute_data['data'] - expected_c_response = attribute_data['c_response'] + attr_data = attribute_data["data"] + expected_c_response = attribute_data["c_response"] self.assertAttributeValue( - self.widgets_meta, widget_id, widget_properties, - attr_data, expected_c_response, + self.widgets_meta, + widget_id, + widget_properties, + attr_data, + expected_c_response, ) def test_comprehensive_api(self): self.authenticate() project = self.create_project() - url = f'/api/v1/projects/{project.pk}/comprehensive-entries/' + url = f"/api/v1/projects/{project.pk}/comprehensive-entries/" response = self.client.get(url) self.assert_200(response) - @parameterized.expand([ - [widget_id] for widget_id, widget_meta in widget_store.items() - if hasattr(widget_meta, 'get_comprehensive_data') and widget_id not in SKIP_WIDGETS - ]) + @parameterized.expand( + [ + [widget_id] + for widget_id, widget_meta in widget_store.items() + if hasattr(widget_meta, "get_comprehensive_data") and widget_id not in SKIP_WIDGETS + ] + ) def test_comprehensive_(self, widget_id): self.maxDiff = None self._test_widget(widget_id) diff --git a/apps/entry/tests/test_entry_comment.py b/apps/entry/tests/test_entry_comment.py index 3505903ee6..81eb63a7da 100644 --- a/apps/entry/tests/test_entry_comment.py +++ b/apps/entry/tests/test_entry_comment.py @@ -1,27 +1,31 @@ -from deep.tests import TestCase from entry.models import EntryComment from notification.models import Notification +from deep.tests import TestCase + class EntryCommentTests(TestCase): def setUp(self): super().setUp() self.entry = self.create_entry() - self.comment = self.create(EntryComment, **{ - 'entry': self.entry, - 'assignees': [self.user], - 'text': 'This is a comment text', - 'parent': None, - }) + self.comment = self.create( + EntryComment, + **{ + "entry": self.entry, + "assignees": [self.user], + "text": "This is a comment text", + "parent": None, + }, + ) assert self.comment.is_resolved is False self.entry.project.add_member(self.root_user) def test_create_comment(self): - url = f'/api/v1/entries/{self.entry.pk}/entry-comments/' + url = f"/api/v1/entries/{self.entry.pk}/entry-comments/" data = { - 'assignees': [self.user.pk], - 'text': 'This is first comment', - 'parent': None, + "assignees": [self.user.pk], + "text": "This is first comment", + "parent": None, } self.authenticate() @@ -29,75 +33,78 @@ def test_create_comment(self): self.assert_201(response) # Throw error if assignee is not provided for root comment - data.pop('assignees') + data.pop("assignees") response = self.client.post(url, data) self.assert_400(response) - data['assignees'] = [self.user.pk] + data["assignees"] = [self.user.pk] # Throw error if text is not provided - data['text'] = None + data["text"] = None response = self.client.post(url, data) self.assert_400(response) def test_create_comment_reply(self): entry_2 = self.create_entry() - url = f'/api/v1/entries/{entry_2.pk}/entry-comments/' + url = f"/api/v1/entries/{entry_2.pk}/entry-comments/" data = { - 'assignees': [self.user.pk], - 'text': 'This is first comment', - 'parent': self.comment.pk, + "assignees": [self.user.pk], + "text": "This is first comment", + "parent": self.comment.pk, } self.authenticate() response = self.client.post(url, data) self.assert_400(response) - assert 'parent' in response.data['errors'] - - comment_2 = self.create(EntryComment, **{ - 'entry': entry_2, - 'assignees': [self.user], - 'text': 'This is a comment text', - 'parent': None, - }) - data['parent'] = comment_2.pk + assert "parent" in response.data["errors"] + + comment_2 = self.create( + EntryComment, + **{ + "entry": entry_2, + "assignees": [self.user], + "text": "This is a comment text", + "parent": None, + }, + ) + data["parent"] = comment_2.pk response = self.client.post(url, data) self.assert_201(response) - assert response.data['entry'] == entry_2.pk, 'Should be same to parent entry' - assert response.data['assignees'] == [], 'There should be no assignee in reply comment' + assert response.data["entry"] == entry_2.pk, "Should be same to parent entry" + assert response.data["assignees"] == [], "There should be no assignee in reply comment" def test_comment_text_history(self): - url = f'/api/v1/entries/{self.entry.pk}/entry-comments/' + url = f"/api/v1/entries/{self.entry.pk}/entry-comments/" data = { - 'assignees': [self.user.pk], - 'text': 'This is first comment', - 'parent': None, + "assignees": [self.user.pk], + "text": "This is first comment", + "parent": None, } self.authenticate() response = self.client.post(url, data) self.assert_201(response) - comment_id = response.json()['id'] + comment_id = response.json()["id"] # Patch new text - new_text = 'this is second comment' - response = self.client.patch(f'{url}{comment_id}/', {'text': new_text}) + new_text = "this is second comment" + response = self.client.patch(f"{url}{comment_id}/", {"text": new_text}) r_data = response.json() - assert r_data['text'] == new_text - assert len(r_data['textHistory']) == 2 + assert r_data["text"] == new_text + assert len(r_data["textHistory"]) == 2 # Patch same text again - response = self.client.patch(f'{url}{comment_id}/', {'text': new_text}) + response = self.client.patch(f"{url}{comment_id}/", {"text": new_text}) r_data = response.json() - assert r_data['text'] == new_text - assert len(r_data['textHistory']) == 2 + assert r_data["text"] == new_text + assert len(r_data["textHistory"]) == 2 def test_comment_resolve(self): - url = f'/api/v1/entries/{self.entry.pk}/entry-comments/' + url = f"/api/v1/entries/{self.entry.pk}/entry-comments/" data = { - 'assignees': [self.user.pk], - 'text': 'This is first comment', - 'parent': None, + "assignees": [self.user.pk], + "text": "This is first comment", + "parent": None, } self.authenticate(self.root_user) @@ -105,89 +112,86 @@ def test_comment_resolve(self): response = self.client.post(url, data) r_data = response.json() self.assert_201(response) - assert r_data['isResolved'] is False - parent_comment_id_1 = r_data['id'] + assert r_data["isResolved"] is False + parent_comment_id_1 = r_data["id"] self.authenticate(self.user) # Add comment response = self.client.post(url, data) r_data = response.json() self.assert_201(response) - assert r_data['isResolved'] is False - parent_comment_id = r_data['id'] + assert r_data["isResolved"] is False + parent_comment_id = r_data["id"] # Add reply comment - data['parent'] = parent_comment_id + data["parent"] = parent_comment_id response = self.client.post(url, data) r_data = response.json() self.assert_201(response) - assert r_data['isResolved'] is False - comment_id = r_data['id'] + assert r_data["isResolved"] is False + comment_id = r_data["id"] # Throw error if resolved request is send for reply - response = self.client.post(f'{url}{comment_id}/resolve/') + response = self.client.post(f"{url}{comment_id}/resolve/") self.assert_400(response) # Send resolve request to comment - response = self.client.post(f'{url}{parent_comment_id}/resolve/') + response = self.client.post(f"{url}{parent_comment_id}/resolve/") r_data = response.json() - assert r_data['isResolved'] is True + assert r_data["isResolved"] is True # Throw error if reply is added for resolved comment - data['parent'] = parent_comment_id + data["parent"] = parent_comment_id response = self.client.post(url, data) self.assert_400(response) # Throw error if request send to resolved other user's comment - response = self.client.post(f'{url}{parent_comment_id_1}/resolve/') + response = self.client.post(f"{url}{parent_comment_id_1}/resolve/") r_data = response.json() self.assert_403(response) def test_comment_delete(self): - url = f'/api/v1/entries/{self.entry.pk}/entry-comments/' + url = f"/api/v1/entries/{self.entry.pk}/entry-comments/" user1 = self.user user2 = self.root_user data = { - 'assignees': [self.user.pk], - 'text': 'This is first comment', - 'parent': None, + "assignees": [self.user.pk], + "text": "This is first comment", + "parent": None, } self.authenticate(user1) # Add comment by user1 response = self.client.post(url, data) self.assert_201(response) - comment1_id = response.json()['id'] + comment1_id = response.json()["id"] self.authenticate(user2) # Add comment by user2 response = self.client.post(url, data) self.assert_201(response) - comment2_id = response.json()['id'] + comment2_id = response.json()["id"] self.authenticate(user1) - response = self.client.delete(f'{url}{comment1_id}/') + response = self.client.delete(f"{url}{comment1_id}/") self.assert_204(response) - response = self.client.delete(f'{url}{comment2_id}/') + response = self.client.delete(f"{url}{comment2_id}/") self.assert_403(response) def test_comment_notification(self): """ Used to send Notification using DEEP and Email """ + def _get_comment_users_pk(pk): - return set( - EntryComment.objects.get(pk=pk).get_related_users().values_list('pk', flat=True) - ) + return set(EntryComment.objects.get(pk=pk).get_related_users().values_list("pk", flat=True)) def _clear_notifications(): return Notification.objects.all().delete() def _get_notifications_receivers(): - return set( - Notification.objects.values_list('receiver', flat=True) - ), set( - Notification.objects.values_list('notification_type', flat=True).distinct() + return set(Notification.objects.values_list("receiver", flat=True)), set( + Notification.objects.values_list("notification_type", flat=True).distinct() ) reviewer = self.create_user() @@ -196,11 +200,11 @@ def _get_notifications_receivers(): for user in [reviewer, tagger1, tagger2]: self.entry.project.add_member(user, role=self.normal_role) - url = f'/api/v1/entries/{self.entry.pk}/entry-comments/' + url = f"/api/v1/entries/{self.entry.pk}/entry-comments/" data = { - 'assignees': [tagger1.pk], - 'text': 'This is first comment', - 'parent': None, + "assignees": [tagger1.pk], + "text": "This is first comment", + "parent": None, } # Create a commit @@ -208,22 +212,22 @@ def _get_notifications_receivers(): self.authenticate(reviewer) # Need self.captureOnCommitCallbacks as this API uses transation.on_commit with self.captureOnCommitCallbacks(execute=True): - comment1_id = self.client.post(url, data).json()['id'] + comment1_id = self.client.post(url, data).json()["id"] assert _get_comment_users_pk(comment1_id) == set([tagger1.pk]) assert _get_notifications_receivers() == ( set([tagger1.pk]), set([Notification.Type.ENTRY_COMMENT_ADD]), ) - data['parent'] = comment1_id - data['assignees'] = [] + data["parent"] = comment1_id + data["assignees"] = [] # Create a reply 1 _clear_notifications() self.authenticate(tagger1) - data['text'] = 'this is first reply' + data["text"] = "this is first reply" with self.captureOnCommitCallbacks(execute=True): - reply1_id = self.client.post(url, data).json()['id'] + reply1_id = self.client.post(url, data).json()["id"] assert _get_comment_users_pk(reply1_id) == set([reviewer.pk]) assert _get_notifications_receivers() == ( set([reviewer.pk]), @@ -233,9 +237,9 @@ def _get_notifications_receivers(): # Create a reply 2 _clear_notifications() self.authenticate(reviewer) - data['text'] = 'this is second reply' + data["text"] = "this is second reply" with self.captureOnCommitCallbacks(execute=True): - reply2_id = self.client.post(url, data).json()['id'] + reply2_id = self.client.post(url, data).json()["id"] assert _get_comment_users_pk(reply2_id) == set([tagger1.pk]) assert _get_notifications_receivers() == ( set([tagger1.pk]), # Targeted users for notification @@ -245,9 +249,9 @@ def _get_notifications_receivers(): # Create a reply 3 _clear_notifications() self.authenticate(tagger2) - data['text'] = 'this is third reply' + data["text"] = "this is third reply" with self.captureOnCommitCallbacks(execute=True): - reply3_id = self.client.post(url, data).json()['id'] + reply3_id = self.client.post(url, data).json()["id"] assert _get_comment_users_pk(reply3_id) == set([reviewer.pk, tagger1.pk]) assert _get_notifications_receivers() == ( set([reviewer.pk, tagger1.pk]), @@ -257,9 +261,12 @@ def _get_notifications_receivers(): # Update reply 3 _clear_notifications() with self.captureOnCommitCallbacks(execute=True): - self.client.patch(f'{url}{reply3_id}/', { - 'text': 'updating the third reply text', - }) + self.client.patch( + f"{url}{reply3_id}/", + { + "text": "updating the third reply text", + }, + ) assert _get_comment_users_pk(reply3_id) == set([reviewer.pk, tagger1.pk]) assert _get_notifications_receivers() == ( set([reviewer.pk, tagger1.pk]), @@ -270,9 +277,12 @@ def _get_notifications_receivers(): _clear_notifications() self.authenticate(reviewer) with self.captureOnCommitCallbacks(execute=True): - self.client.patch(f'{url}{comment1_id}/', { - 'assignees': [tagger2.pk], - }) + self.client.patch( + f"{url}{comment1_id}/", + { + "assignees": [tagger2.pk], + }, + ) assert _get_notifications_receivers() == ( set([tagger1.pk, tagger2.pk]), set([Notification.Type.ENTRY_COMMENT_ASSIGNEE_CHANGE]), @@ -282,9 +292,12 @@ def _get_notifications_receivers(): _clear_notifications() self.authenticate(reviewer) with self.captureOnCommitCallbacks(execute=True): - self.client.patch(f'{url}{comment1_id}/', { - 'text': 'updating the comment text', - }) + self.client.patch( + f"{url}{comment1_id}/", + { + "text": "updating the comment text", + }, + ) assert _get_notifications_receivers() == ( set([tagger1.pk, tagger2.pk]), set([Notification.Type.ENTRY_COMMENT_MODIFY]), @@ -294,7 +307,7 @@ def _get_notifications_receivers(): _clear_notifications() self.authenticate(reviewer) with self.captureOnCommitCallbacks(execute=True): - self.client.post(f'{url}{comment1_id}/resolve/') + self.client.post(f"{url}{comment1_id}/resolve/") assert _get_notifications_receivers() == ( set([tagger1.pk, tagger2.pk]), set([Notification.Type.ENTRY_COMMENT_RESOLVED]), @@ -312,20 +325,20 @@ def test_entry_comment_put(self): # Non member user data = { - 'text': 'Test comment', - 'assignees': [user1.pk, user2.pk], + "text": "Test comment", + "assignees": [user1.pk, user2.pk], } self.authenticate(user) - url = f'/api/v1/entries/{entry.pk}/entry-comments/' + url = f"/api/v1/entries/{entry.pk}/entry-comments/" response = self.client.post(url, data) self.assert_201(response) - comment_id = response.json()['id'] + comment_id = response.json()["id"] - url = f'/api/v1/entries/{entry.pk}/entry-comments/{comment_id}/' - data['text'] = 'updated test comment' - data['assignees'] = [user1.pk] + url = f"/api/v1/entries/{entry.pk}/entry-comments/{comment_id}/" + data["text"] = "updated test comment" + data["assignees"] = [user1.pk] response = self.client.put(url, data) self.assert_200(response) @@ -344,10 +357,10 @@ def test_entry_comment_permissions(self): entry = self.create_entry(lead=lead, project=project) entry.project.add_member(user2) - url = f'/api/v1/entries/{entry.id}/entry-comments/' + url = f"/api/v1/entries/{entry.id}/entry-comments/" data = { - 'text': 'test_entry_comment', - 'assignees': [user2.id], + "text": "test_entry_comment", + "assignees": [user2.id], } self.authenticate(user1) @@ -362,12 +375,12 @@ def test_entry_comment_permissions(self): self.assert_201(response) # Check if member can create entry comment with non-member assignee - data['assignees'] = [user3.id] + data["assignees"] = [user3.id] response = self.client.post(url, data) self.assert_400(response) - assert 'assignees' in response.data['errors'] + assert "assignees" in response.data["errors"] - data['assignees'] = [user2.id] + data["assignees"] = [user2.id] # Comment owner should be able to update comment response = self.client.put(f"{url}{resp_data['id']}/", data) self.assert_200(response) diff --git a/apps/entry/tests/test_migrations.py b/apps/entry/tests/test_migrations.py index 9f3f2aa7bd..5a2b390e22 100644 --- a/apps/entry/tests/test_migrations.py +++ b/apps/entry/tests/test_migrations.py @@ -1,16 +1,15 @@ import importlib +from analysis_framework.factories import AnalysisFrameworkFactory from django.db import models -from deep.tests import TestCase - +from entry.factories import EntryFactory from entry.models import Entry -from quality_assurance.models import EntryReviewComment - +from lead.factories import LeadFactory from project.factories import ProjectFactory -from analysis_framework.factories import AnalysisFrameworkFactory +from quality_assurance.models import EntryReviewComment from user.factories import UserFactory -from lead.factories import LeadFactory -from entry.factories import EntryFactory + +from deep.tests import TestCase class TestCustomMigrationsLogic(TestCase): @@ -21,7 +20,7 @@ class TestCustomMigrationsLogic(TestCase): """ def test_test_entry_review_verify_control_migrations(self): - migration_file = importlib.import_module('entry.migrations.0031_entry-migrate-verify-to-review-comment') + migration_file = importlib.import_module("entry.migrations.0031_entry-migrate-verify-to-review-comment") # 3 normal users + Additional non-active user user1, user2, user3, _ = UserFactory.create_batch(4) @@ -44,18 +43,18 @@ def test_test_entry_review_verify_control_migrations(self): # 2 verified review comment are created and 1 unverified review comment is created assert EntryReviewComment.objects.count() == 3 # Related review comment are created by user last action on entry. - assert set(EntryReviewComment.objects.values_list('created_by_id', flat=True)) == set([user1.pk, user2.pk, user3.pk]) + assert set(EntryReviewComment.objects.values_list("created_by_id", flat=True)) == set([user1.pk, user2.pk, user3.pk]) assert EntryReviewComment.objects.filter(comment_type=EntryReviewComment.CommentType.VERIFY).count() == 2 assert EntryReviewComment.objects.filter(comment_type=EntryReviewComment.CommentType.UNVERIFY).count() == 1 assert set( EntryReviewComment.objects.filter( comment_type=EntryReviewComment.CommentType.VERIFY, - ).values_list('created_by_id', flat=True) + ).values_list("created_by_id", flat=True) ) == set([user1.pk, user2.pk]) assert set( EntryReviewComment.objects.filter( comment_type=EntryReviewComment.CommentType.UNVERIFY, - ).values_list('created_by_id', flat=True) + ).values_list("created_by_id", flat=True) ) == set([user3.pk]) # All controlled, controlled_changed_by should be reset. assert Entry.objects.filter(controlled=True).count() == 0 @@ -63,19 +62,19 @@ def test_test_entry_review_verify_control_migrations(self): def test_entry_dropped_excerpt_migrations(self): def _get_excerpt_snapshot(): - return list(Entry.objects.order_by('id').values_list('excerpt', 'dropped_excerpt', 'excerpt_modified')) + return list(Entry.objects.order_by("id").values_list("excerpt", "dropped_excerpt", "excerpt_modified")) - migration_file = importlib.import_module('entry.migrations.0036_entry_excerpt_modified') + migration_file = importlib.import_module("entry.migrations.0036_entry_excerpt_modified") af = AnalysisFrameworkFactory.create() project = ProjectFactory.create(analysis_framework=af) lead = LeadFactory.create(project=project) # Create entry before data migrate - EntryFactory.create(lead=lead, excerpt='', dropped_excerpt='') - EntryFactory.create(lead=lead, excerpt='sample-1', dropped_excerpt='') - EntryFactory.create(lead=lead, excerpt='sample-2', dropped_excerpt='sample-2-updated') - EntryFactory.create(lead=lead, excerpt='sample-3', dropped_excerpt='sample-3') + EntryFactory.create(lead=lead, excerpt="", dropped_excerpt="") + EntryFactory.create(lead=lead, excerpt="sample-1", dropped_excerpt="") + EntryFactory.create(lead=lead, excerpt="sample-2", dropped_excerpt="sample-2-updated") + EntryFactory.create(lead=lead, excerpt="sample-3", dropped_excerpt="sample-3") old_excerpt_snaphost = _get_excerpt_snapshot() # Apply the migration logic @@ -83,15 +82,15 @@ def _get_excerpt_snapshot(): new_excerpt_snaphost = _get_excerpt_snapshot() assert Entry.objects.count() == 4 - assert Entry.objects.filter(dropped_excerpt='').count() == 1 + assert Entry.objects.filter(dropped_excerpt="").count() == 1 assert Entry.objects.filter(excerpt_modified=True).count() == 1 - assert Entry.objects.filter(dropped_excerpt=models.F('excerpt')).count() == 3 - assert Entry.objects.exclude(dropped_excerpt=models.F('excerpt')).count() == 1 + assert Entry.objects.filter(dropped_excerpt=models.F("excerpt")).count() == 3 + assert Entry.objects.exclude(dropped_excerpt=models.F("excerpt")).count() == 1 assert new_excerpt_snaphost != old_excerpt_snaphost assert new_excerpt_snaphost == [ - ('', '', False), - ('sample-1', 'sample-1', False), - ('sample-2', 'sample-2-updated', True), - ('sample-3', 'sample-3', False), + ("", "", False), + ("sample-1", "sample-1", False), + ("sample-2", "sample-2-updated", True), + ("sample-3", "sample-3", False), ] diff --git a/apps/entry/tests/test_mutations.py b/apps/entry/tests/test_mutations.py index 56d703672d..3763ceeeca 100644 --- a/apps/entry/tests/test_mutations.py +++ b/apps/entry/tests/test_mutations.py @@ -1,15 +1,13 @@ +from analysis_framework.factories import AnalysisFrameworkFactory, WidgetFactory from django.utils import timezone - -from utils.graphene.tests import GraphQLSnapShotTestCase - +from entry.factories import EntryAttributeFactory, EntryFactory from entry.models import Entry - -from user.factories import UserFactory -from entry.factories import EntryFactory, EntryAttributeFactory -from project.factories import ProjectFactory -from lead.factories import LeadFactory -from analysis_framework.factories import AnalysisFrameworkFactory, WidgetFactory from gallery.factories import FileFactory +from lead.factories import LeadFactory +from project.factories import ProjectFactory +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLSnapShotTestCase class TestEntryMutation(GraphQLSnapShotTestCase): @@ -17,9 +15,10 @@ class TestEntryMutation(GraphQLSnapShotTestCase): TODO: - Make sure only 1 attribute is allowed for one widget """ + factories_used = [FileFactory] - CREATE_ENTRY_QUERY = ''' + CREATE_ENTRY_QUERY = """ mutation MyMutation ($projectId: ID!, $input: EntryInputType!) { project(id: $projectId) { entryCreate(data: $input) { @@ -49,9 +48,9 @@ class TestEntryMutation(GraphQLSnapShotTestCase): } } } - ''' + """ - UPDATE_ENTRY_QUERY = ''' + UPDATE_ENTRY_QUERY = """ mutation MyMutation ($projectId: ID!, $entryId: ID!, $input: EntryInputType!) { project(id: $projectId) { entryUpdate(id: $entryId data: $input) { @@ -81,9 +80,9 @@ class TestEntryMutation(GraphQLSnapShotTestCase): } } } - ''' + """ - DELETE_ENTRY_QUERY = ''' + DELETE_ENTRY_QUERY = """ mutation MyMutation ($projectId: ID!, $entryId: ID!) { project(id: $projectId) { entryDelete(id: $entryId) { @@ -113,9 +112,9 @@ class TestEntryMutation(GraphQLSnapShotTestCase): } } } - ''' + """ - BULK_ENTRY_QUERY = ''' + BULK_ENTRY_QUERY = """ mutation MyMutation ($projectId: ID!, $deleteIds: [ID!], $items: [BulkEntryInputType!]) { project(id: $projectId) { entryBulk(deleteIds: $deleteIds items: $items) { @@ -165,7 +164,7 @@ class TestEntryMutation(GraphQLSnapShotTestCase): } } } - ''' + """ def setUp(self): super().setUp() @@ -192,9 +191,9 @@ def test_entry_create(self): """ minput = dict( attributes=[ - dict(widget=self.widget1.pk, data=self.dummy_data, clientId='client-id-attribute-1', widgetVersion=1), - dict(widget=self.widget2.pk, data=self.dummy_data, clientId='client-id-attribute-2', widgetVersion=1), - dict(widget=self.widget3.pk, data=self.dummy_data, clientId='client-id-attribute-3', widgetVersion=1), + dict(widget=self.widget1.pk, data=self.dummy_data, clientId="client-id-attribute-1", widgetVersion=1), + dict(widget=self.widget2.pk, data=self.dummy_data, clientId="client-id-attribute-2", widgetVersion=1), + dict(widget=self.widget3.pk, data=self.dummy_data, clientId="client-id-attribute-3", widgetVersion=1), ], order=1, lead=self.lead.pk, @@ -202,19 +201,14 @@ def test_entry_create(self): image=self.other_file.pk, # leadImage='', highlightHidden=False, - excerpt='This is a text', + excerpt="This is a text", entryType=self.genum(Entry.TagType.EXCERPT), - droppedExcerpt='This is a dropped text', - clientId='entry-101', + droppedExcerpt="This is a dropped text", + clientId="entry-101", ) def _query_check(**kwargs): - return self.query_check( - self.CREATE_ENTRY_QUERY, - minput=minput, - variables={'projectId': self.project.id}, - **kwargs - ) + return self.query_check(self.CREATE_ENTRY_QUERY, minput=minput, variables={"projectId": self.project.id}, **kwargs) # -- Without login _query_check(assert_for_error=True) @@ -231,12 +225,12 @@ def _query_check(**kwargs): # Invalid input self.force_login(self.member_user) response = _query_check(okay=False) - self.assertMatchSnapshot(response, 'error') + self.assertMatchSnapshot(response, "error") # Valid input - minput['image'] = self.our_file.pk + minput["image"] = self.our_file.pk response = _query_check() - self.assertMatchSnapshot(response, 'success') + self.assertMatchSnapshot(response, "success") def test_entry_update(self): """ @@ -246,9 +240,9 @@ def test_entry_update(self): minput = dict( attributes=[ - dict(widget=self.widget1.pk, data=self.dummy_data, clientId='client-id-attribute-1', widgetVersion=1), - dict(widget=self.widget2.pk, data=self.dummy_data, clientId='client-id-attribute-2', widgetVersion=1), - dict(widget=self.widget1.pk, data=self.dummy_data, clientId='client-id-attribute-3', widgetVersion=1), + dict(widget=self.widget1.pk, data=self.dummy_data, clientId="client-id-attribute-1", widgetVersion=1), + dict(widget=self.widget2.pk, data=self.dummy_data, clientId="client-id-attribute-2", widgetVersion=1), + dict(widget=self.widget1.pk, data=self.dummy_data, clientId="client-id-attribute-3", widgetVersion=1), ], order=1, lead=self.lead.pk, @@ -256,18 +250,15 @@ def test_entry_update(self): image=self.other_file.pk, # leadImage='', highlightHidden=False, - excerpt='This is a text', + excerpt="This is a text", entryType=self.genum(Entry.TagType.EXCERPT), - droppedExcerpt='This is a dropped text', - clientId='entry-101', + droppedExcerpt="This is a dropped text", + clientId="entry-101", ) def _query_check(**kwargs): return self.query_check( - self.UPDATE_ENTRY_QUERY, - minput=minput, - variables={'projectId': self.project.id, 'entryId': entry.id}, - **kwargs + self.UPDATE_ENTRY_QUERY, minput=minput, variables={"projectId": self.project.id, "entryId": entry.id}, **kwargs ) # -- Without login @@ -285,12 +276,12 @@ def _query_check(**kwargs): # Invalid input self.force_login(self.member_user) response = _query_check(okay=False) - self.assertMatchSnapshot(response, 'error') + self.assertMatchSnapshot(response, "error") # Valid input - minput['image'] = self.our_file.pk + minput["image"] = self.our_file.pk response = _query_check() - self.assertMatchSnapshot(response, 'success') + self.assertMatchSnapshot(response, "success") def test_entry_delete(self): """ @@ -300,9 +291,7 @@ def test_entry_delete(self): def _query_check(**kwargs): return self.query_check( - self.DELETE_ENTRY_QUERY, - variables={'projectId': self.project.id, 'entryId': entry.id}, - **kwargs + self.DELETE_ENTRY_QUERY, variables={"projectId": self.project.id, "entryId": entry.id}, **kwargs ) # -- Without login @@ -319,16 +308,15 @@ def _query_check(**kwargs): # --- member user # Invalid input self.force_login(self.member_user) - content = _query_check(okay=False)['data']['project']['entryDelete']['result'] - self.assertIdEqual(content['id'], entry.id) + content = _query_check(okay=False)["data"]["project"]["entryDelete"]["result"] + self.assertIdEqual(content["id"], entry.id) def test_entry_bulk(self): """ This test makes sure only valid users can bulk create/update/delete entry """ entry1, entry2 = EntryFactory.create_batch( - 2, - project=self.project, lead=self.lead, analysis_framework=self.project.analysis_framework + 2, project=self.project, lead=self.lead, analysis_framework=self.project.analysis_framework ) entry2_att1 = EntryAttributeFactory.create(entry=entry2, widget=self.widget1, data=self.dummy_data) @@ -341,14 +329,14 @@ def test_entry_bulk(self): dict( widget=self.widget1.pk, data=self.dummy_data, - clientId='client-id-old-new-attribute-1', - widgetVersion=1 + clientId="client-id-old-new-attribute-1", + widgetVersion=1, ), dict( id=entry2_att1.pk, widget=self.widget1.pk, data=self.dummy_data, - clientId='client-id-old-attribute-1', + clientId="client-id-old-attribute-1", widgetVersion=1, ), ], @@ -358,17 +346,17 @@ def test_entry_bulk(self): image=self.other_file.pk, # leadImage='', highlightHidden=False, - excerpt='This is a text (UPDATED)', + excerpt="This is a text (UPDATED)", entryType=self.genum(Entry.TagType.EXCERPT), - droppedExcerpt='This is a dropped text (UPDATED)', - clientId='entry-old-101 (UPDATED)', + droppedExcerpt="This is a dropped text (UPDATED)", + clientId="entry-old-101 (UPDATED)", ), dict( attributes=[ dict( widget=self.widget1.pk, data=self.dummy_data, - clientId='client-id-new-attribute-1', + clientId="client-id-new-attribute-1", widgetVersion=1, ), ], @@ -378,20 +366,16 @@ def test_entry_bulk(self): image=self.other_file.pk, # leadImage='', highlightHidden=False, - excerpt='This is a text (NEW)', + excerpt="This is a text (NEW)", entryType=self.genum(Entry.TagType.EXCERPT), - droppedExcerpt='This is a dropped text (NEW)', - clientId='entry-new-102', - ) + droppedExcerpt="This is a dropped text (NEW)", + clientId="entry-new-102", + ), ], ) def _query_check(**kwargs): - return self.query_check( - self.BULK_ENTRY_QUERY, - variables={'projectId': self.project.id, **minput}, - **kwargs - ) + return self.query_check(self.BULK_ENTRY_QUERY, variables={"projectId": self.project.id, **minput}, **kwargs) # -- Without login _query_check(assert_for_error=True) @@ -408,12 +392,12 @@ def _query_check(**kwargs): self.force_login(self.member_user) # Invalid input response = _query_check(okay=False) - self.assertMatchSnapshot(response, 'error') + self.assertMatchSnapshot(response, "error") # Valid input - minput['items'][0]['image'] = self.our_file.pk - minput['items'][1]['image'] = self.our_file.pk + minput["items"][0]["image"] = self.our_file.pk + minput["items"][1]["image"] = self.our_file.pk response = _query_check() - self.assertMatchSnapshot(response, 'success') + self.assertMatchSnapshot(response, "success") # TODO: Add test for other entry attributes id as well diff --git a/apps/entry/tests/test_permissions.py b/apps/entry/tests/test_permissions.py index 210cf6160f..53cd0bf3be 100644 --- a/apps/entry/tests/test_permissions.py +++ b/apps/entry/tests/test_permissions.py @@ -1,29 +1,29 @@ -from deep.tests import TestCase -from entry.models import Lead -from entry.models import Entry +from analysis_framework.models import AnalysisFramework +from entry.models import Entry, Lead from project.models import Project, ProjectRole from project.permissions import PROJECT_PERMISSIONS, get_project_permissions_value -from analysis_framework.models import AnalysisFramework + +from deep.tests import TestCase class TestEntryPermissions(TestCase): def setUp(self): super().setUp() self.no_entry_creation_role = ProjectRole.objects.create( - title='No Lead Creation Role', + title="No Lead Creation Role", entry_permissions=0, - lead_permissions=get_project_permissions_value('lead', '__all__'), - setup_permissions=get_project_permissions_value('setup', '__all__'), - export_permissions=get_project_permissions_value('export', '__all__'), - assessment_permissions=get_project_permissions_value('assessment', '__all__'), + lead_permissions=get_project_permissions_value("lead", "__all__"), + setup_permissions=get_project_permissions_value("setup", "__all__"), + export_permissions=get_project_permissions_value("export", "__all__"), + assessment_permissions=get_project_permissions_value("assessment", "__all__"), ) self.entry_creation_role = ProjectRole.objects.create( - title='Lead Creation Role', - entry_permissions=get_project_permissions_value('entry', ['create']), - lead_permissions=get_project_permissions_value('lead', '__all__'), - setup_permissions=get_project_permissions_value('setup', '__all__'), - export_permissions=get_project_permissions_value('export', '__all__'), - assessment_permissions=get_project_permissions_value('assessment', '__all__'), + title="Lead Creation Role", + entry_permissions=get_project_permissions_value("entry", ["create"]), + lead_permissions=get_project_permissions_value("lead", "__all__"), + setup_permissions=get_project_permissions_value("setup", "__all__"), + export_permissions=get_project_permissions_value("export", "__all__"), + assessment_permissions=get_project_permissions_value("assessment", "__all__"), ) def test_cannot_view_confidential_entry_without_permissions(self): @@ -39,24 +39,24 @@ def test_cannot_view_confidential_entry_without_permissions(self): entry1 = self.create(Entry, lead=lead1, project=project) entry_confidential = self.create(Entry, lead=lead_confidential, project=project) - url = '/api/v1/entries/' + url = "/api/v1/entries/" self.authenticate() resp = self.client.get(url) self.assert_200(resp) - entries_ids = set([x['id'] for x in resp.data['results']]) + entries_ids = set([x["id"] for x in resp.data["results"]]) assert entries_ids == {entry1.id} # Check particular non-confidential entry, should return 200 - url = f'/api/v1/entries/{entry1.id}/' + url = f"/api/v1/entries/{entry1.id}/" self.authenticate() resp = self.client.get(url) self.assert_200(resp) # Check particular confidential entry, should return 404 - url = f'/api/v1/entries/{entry_confidential.id}/' + url = f"/api/v1/entries/{entry_confidential.id}/" self.authenticate() resp = self.client.get(url) @@ -65,19 +65,15 @@ def test_cannot_view_confidential_entry_without_permissions(self): def test_create_entry_no_permission(self): # Create project where self.user has no entry creation permission af = self.create(AnalysisFramework) - project = self.create( - Project, - analysis_framework=af, - role=self.no_entry_creation_role - ) + project = self.create(Project, analysis_framework=af, role=self.no_entry_creation_role) lead = self.create(Lead, project=project) - url = '/api/v1/entries/' + url = "/api/v1/entries/" data = { - 'project': project.pk, - 'lead': lead.pk, - 'analysis_framework': project.analysis_framework.pk, - 'excerpt': 'This is test excerpt', - 'attributes': {}, + "project": project.pk, + "lead": lead.pk, + "analysis_framework": project.analysis_framework.pk, + "excerpt": "This is test excerpt", + "attributes": {}, } self.authenticate() @@ -87,19 +83,15 @@ def test_create_entry_no_permission(self): def test_create_entry_with_permission(self): # Create project where self.user has no entry creation permission af = self.create(AnalysisFramework) - project = self.create( - Project, - analysis_framework=af, - role=self.entry_creation_role - ) + project = self.create(Project, analysis_framework=af, role=self.entry_creation_role) lead = self.create(Lead, project=project) - url = '/api/v1/entries/' + url = "/api/v1/entries/" data = { - 'project': project.pk, - 'lead': lead.pk, - 'analysis_framework': project.analysis_framework.pk, - 'excerpt': 'This is test excerpt', - 'attributes': {}, + "project": project.pk, + "lead": lead.pk, + "analysis_framework": project.analysis_framework.pk, + "excerpt": "This is test excerpt", + "attributes": {}, } self.authenticate() @@ -108,10 +100,7 @@ def test_create_entry_with_permission(self): def create_project(self): analysis_framework = self.create(AnalysisFramework) - return self.create( - Project, analysis_framework=analysis_framework, - role=self.admin_role - ) + return self.create(Project, analysis_framework=analysis_framework, role=self.admin_role) def create_lead(self, project=None): project = project or self.create_project() diff --git a/apps/entry/tests/test_schemas.py b/apps/entry/tests/test_schemas.py index b4a49b2265..79d716f437 100644 --- a/apps/entry/tests/test_schemas.py +++ b/apps/entry/tests/test_schemas.py @@ -1,22 +1,19 @@ -from lead.models import Lead - -from entry.models import Entry -from quality_assurance.models import EntryReviewComment +from analysis_framework.factories import AnalysisFrameworkFactory, WidgetFactory from analysis_framework.models import Widget - -from utils.graphene.tests import GraphQLTestCase - -from user.factories import UserFactory -from geo.factories import RegionFactory, AdminLevelFactory, GeoAreaFactory -from project.factories import ProjectFactory +from assessment_registry.factories import AssessmentRegistryFactory +from entry.factories import EntryAttributeFactory, EntryFactory +from entry.models import Entry +from geo.factories import AdminLevelFactory, GeoAreaFactory, RegionFactory from lead.factories import LeadFactory -from entry.factories import EntryFactory, EntryAttributeFactory -from analysis_framework.factories import AnalysisFrameworkFactory, WidgetFactory +from lead.models import Lead +from lead.tests.test_schemas import TestLeadQuerySchema from organization.factories import OrganizationFactory, OrganizationTypeFactory -from assessment_registry.factories import AssessmentRegistryFactory +from project.factories import ProjectFactory from quality_assurance.factories import EntryReviewCommentFactory +from quality_assurance.models import EntryReviewComment +from user.factories import UserFactory -from lead.tests.test_schemas import TestLeadQuerySchema +from utils.graphene.tests import GraphQLTestCase class TestEntryQuery(GraphQLTestCase): @@ -29,7 +26,7 @@ def setUp(self): def test_lead_entries_query(self): # Includes permissions checks - query = ''' + query = """ query MyQuery ($projectId: ID! $leadId: ID!) { project(id: $projectId) { lead(id: $leadId) { @@ -65,26 +62,26 @@ def test_lead_entries_query(self): } } } - ''' + """ lead = LeadFactory.create(project=self.project) entry = EntryFactory.create(project=self.project, analysis_framework=self.af, lead=lead) def _query_check(**kwargs): - return self.query_check(query, variables={'projectId': self.project.pk, 'leadId': lead.id}, **kwargs) + return self.query_check(query, variables={"projectId": self.project.pk, "leadId": lead.id}, **kwargs) # Without login _query_check(assert_for_error=True) # With login self.force_login(self.user) content = _query_check() - results = content['data']['project']['lead']['entries'] - self.assertEqual(len(content['data']['project']['lead']['entries']), 1, content) - self.assertIdEqual(results[0]['id'], entry.pk, results) + results = content["data"]["project"]["lead"]["entries"] + self.assertEqual(len(content["data"]["project"]["lead"]["entries"]), 1, content) + self.assertIdEqual(results[0]["id"], entry.pk, results) def test_entries_query(self): # Includes permissions checks - query = ''' + query = """ query MyQuery ($projectId: ID!) { project(id: $projectId) { entries (ordering: "-id") { @@ -121,7 +118,7 @@ def test_entries_query(self): } } } - ''' + """ user = UserFactory.create() lead = LeadFactory.create(project=self.project) @@ -130,7 +127,7 @@ def test_entries_query(self): conf_entry = EntryFactory.create(project=self.project, analysis_framework=self.af, lead=conf_lead) def _query_check(**kwargs): - return self.query_check(query, variables={'projectId': self.project.pk}, **kwargs) + return self.query_check(query, variables={"projectId": self.project.pk}, **kwargs) # Without login _query_check(assert_for_error=True) @@ -138,27 +135,27 @@ def _query_check(**kwargs): self.force_login(user) # -- Without membership content = _query_check() - results = content['data']['project']['entries']['results'] - self.assertEqual(content['data']['project']['entries']['totalCount'], 0, content) + results = content["data"]["project"]["entries"]["results"] + self.assertEqual(content["data"]["project"]["entries"]["totalCount"], 0, content) self.assertEqual(len(results), 0, results) # -- Without membership (confidential only) current_membership = self.project.add_member(user, role=self.project_role_reader_non_confidential) content = _query_check() - results = content['data']['project']['entries']['results'] - self.assertEqual(content['data']['project']['entries']['totalCount'], 1, content) - self.assertIdEqual(results[0]['id'], entry.pk, results) + results = content["data"]["project"]["entries"]["results"] + self.assertEqual(content["data"]["project"]["entries"]["totalCount"], 1, content) + self.assertIdEqual(results[0]["id"], entry.pk, results) # -- With membership (non-confidential only) current_membership.delete() self.project.add_member(user, role=self.project_role_reader) content = _query_check() - results = content['data']['project']['entries']['results'] - self.assertEqual(content['data']['project']['entries']['totalCount'], 2, content) - self.assertIdEqual(results[0]['id'], conf_entry.pk, results) - self.assertIdEqual(results[1]['id'], entry.pk, results) + results = content["data"]["project"]["entries"]["results"] + self.assertEqual(content["data"]["project"]["entries"]["totalCount"], 2, content) + self.assertIdEqual(results[0]["id"], conf_entry.pk, results) + self.assertIdEqual(results[1]["id"], entry.pk, results) def test_entry_query(self): # Includes permissions checks - query = ''' + query = """ query MyQuery ($projectId: ID! $entryId: ID!) { project(id: $projectId) { entry (id: $entryId) { @@ -192,7 +189,7 @@ def test_entry_query(self): } } } - ''' + """ user = UserFactory.create() lead = LeadFactory.create(project=self.project) @@ -201,7 +198,7 @@ def test_entry_query(self): conf_entry = EntryFactory.create(project=self.project, analysis_framework=self.af, lead=conf_lead) def _query_check(entry, **kwargs): - return self.query_check(query, variables={'projectId': self.project.pk, 'entryId': entry.id}, **kwargs) + return self.query_check(query, variables={"projectId": self.project.pk, "entryId": entry.id}, **kwargs) # Without login _query_check(entry, assert_for_error=True) @@ -210,25 +207,25 @@ def _query_check(entry, **kwargs): self.force_login(user) # -- Without membership content = _query_check(entry) # Normal entry - self.assertEqual(content['data']['project']['entry'], None, content) + self.assertEqual(content["data"]["project"]["entry"], None, content) content = _query_check(conf_entry) # Confidential entry - self.assertEqual(content['data']['project']['entry'], None, content) + self.assertEqual(content["data"]["project"]["entry"], None, content) # -- Without membership (confidential only) current_membership = self.project.add_member(user, role=self.project_role_reader_non_confidential) content = _query_check(entry) # Normal entry - self.assertNotEqual(content['data']['project']['entry'], None, content) + self.assertNotEqual(content["data"]["project"]["entry"], None, content) content = _query_check(conf_entry) # Confidential entry - self.assertEqual(content['data']['project']['entry'], None, content) + self.assertEqual(content["data"]["project"]["entry"], None, content) # -- With membership (non-confidential only) current_membership.delete() self.project.add_member(user, role=self.project_role_reader) content = _query_check(entry) # Normal entry - self.assertNotEqual(content['data']['project']['entry'], None, content) + self.assertNotEqual(content["data"]["project"]["entry"], None, content) content = _query_check(conf_entry) # Confidential entry - self.assertNotEqual(content['data']['project']['entry'], None, content) + self.assertNotEqual(content["data"]["project"]["entry"], None, content) def test_entry_query_filter(self): - query = ''' + query = """ query MyQuery ( $projectId: ID! $leadAuthoringOrganizationTypes: [ID!] @@ -300,7 +297,7 @@ def test_entry_query_filter(self): } } } - ''' + """ af = AnalysisFrameworkFactory.create() project = ProjectFactory.create(analysis_framework=af) @@ -317,7 +314,7 @@ def test_entry_query_filter(self): project.add_member(member2, role=self.project_role_reader) lead1 = LeadFactory.create( project=project, - title='Test 1', + title="Test 1", source_type=Lead.SourceType.TEXT, confidentiality=Lead.Confidentiality.CONFIDENTIAL, authors=[org1, org2], @@ -329,7 +326,7 @@ def test_entry_query_filter(self): lead2 = LeadFactory.create( project=project, source_type=Lead.SourceType.TEXT, - title='Test 2', + title="Test 2", assignee=[member2], authors=[org2, org3], priority=Lead.Priority.HIGH, @@ -338,15 +335,15 @@ def test_entry_query_filter(self): lead3 = LeadFactory.create( project=project, source_type=Lead.SourceType.WEBSITE, - url='https://wwwexample.com/sample-1', - title='Sample 1', + url="https://wwwexample.com/sample-1", + title="Sample 1", confidentiality=Lead.Confidentiality.CONFIDENTIAL, authors=[org1, org3], priority=Lead.Priority.LOW, ) lead4 = LeadFactory.create( project=project, - title='Sample 2', + title="Sample 2", authors=[org1], priority=Lead.Priority.MEDIUM, ) @@ -355,18 +352,22 @@ def test_entry_query_filter(self): other_lead = LeadFactory.create(project=other_project) outside_entry = EntryFactory.create(project=other_project, analysis_framework=af, lead=other_lead) entry1_1 = EntryFactory.create( - project=project, analysis_framework=af, lead=lead1, entry_type=Entry.TagType.EXCERPT, controlled=False) + project=project, analysis_framework=af, lead=lead1, entry_type=Entry.TagType.EXCERPT, controlled=False + ) entry1_1.verified_by.add(user) entry2_1 = EntryFactory.create( - project=project, analysis_framework=af, lead=lead2, entry_type=Entry.TagType.IMAGE, controlled=True) + project=project, analysis_framework=af, lead=lead2, entry_type=Entry.TagType.IMAGE, controlled=True + ) entry2_1.verified_by.add(user) entry3_1 = EntryFactory.create( - project=project, analysis_framework=af, lead=lead3, entry_type=Entry.TagType.EXCERPT, controlled=False) + project=project, analysis_framework=af, lead=lead3, entry_type=Entry.TagType.EXCERPT, controlled=False + ) entry4_1 = EntryFactory.create( - project=project, analysis_framework=af, lead=lead4, entry_type=Entry.TagType.EXCERPT, controlled=False) + project=project, analysis_framework=af, lead=lead4, entry_type=Entry.TagType.EXCERPT, controlled=False + ) # For assessment filters AssessmentRegistryFactory.create(project=project, lead=lead1) @@ -377,7 +378,7 @@ def test_entry_query_filter(self): EntryReviewCommentFactory(entry=entry3_1, created_by=member1, comment_type=EntryReviewComment.CommentType.UNCONTROL) # Change lead1 status to TAGGED lead1.status = Lead.Status.TAGGED - lead1.save(update_fields=['status']) + lead1.save(update_fields=["status"]) # Some leads/entries in other projects other_leads = LeadFactory.create_batch(3, project=ProjectFactory.create(analysis_framework=af)) @@ -388,47 +389,44 @@ def test_entry_query_filter(self): # TODO: Add direct test for filter_set as well (is used within export) for filter_data, expected_entries in [ - ({'controlled': True}, [entry2_1]), - ({'controlled': False}, [entry1_1, entry3_1, entry4_1]), - ({'entriesId': [entry1_1.id, entry2_1.id, outside_entry.id]}, [entry1_1, entry2_1]), - ({'entryTypes': [self.genum(Entry.TagType.EXCERPT)]}, [entry1_1, entry3_1, entry4_1]), + ({"controlled": True}, [entry2_1]), + ({"controlled": False}, [entry1_1, entry3_1, entry4_1]), + ({"entriesId": [entry1_1.id, entry2_1.id, outside_entry.id]}, [entry1_1, entry2_1]), + ({"entryTypes": [self.genum(Entry.TagType.EXCERPT)]}, [entry1_1, entry3_1, entry4_1]), ( - {'entryTypes': [self.genum(Entry.TagType.EXCERPT), self.genum(Entry.TagType.IMAGE)]}, + {"entryTypes": [self.genum(Entry.TagType.EXCERPT), self.genum(Entry.TagType.IMAGE)]}, [entry1_1, entry2_1, entry3_1, entry4_1], ), # TODO: ({'projectEntryLabels': []}, []), # TODO: ({'geoCustomShape': []}, []), # Lead filters - ({'leadAuthoringOrganizationTypes': [org_type2.pk]}, [entry1_1, entry2_1, entry3_1]), - ({'leadAuthoringOrganizationTypes': [org_type1.pk, org_type2.pk]}, [entry1_1, entry2_1, entry3_1, entry4_1]), - ({'leads': [lead1.pk, lead2.pk]}, [entry1_1, entry2_1]), - ({'leadTitle': 'test'}, [entry1_1, entry2_1]), - ({'leadAssignees': [member2.pk]}, [entry2_1]), - ({'leadAssignees': [member1.pk, member2.pk]}, [entry1_1, entry2_1]), - ({'leadConfidentialities': self.genum(Lead.Confidentiality.CONFIDENTIAL)}, [entry1_1, entry3_1]), - ({'leadPriorities': [self.genum(Lead.Priority.HIGH)]}, [entry1_1, entry2_1]), + ({"leadAuthoringOrganizationTypes": [org_type2.pk]}, [entry1_1, entry2_1, entry3_1]), + ({"leadAuthoringOrganizationTypes": [org_type1.pk, org_type2.pk]}, [entry1_1, entry2_1, entry3_1, entry4_1]), + ({"leads": [lead1.pk, lead2.pk]}, [entry1_1, entry2_1]), + ({"leadTitle": "test"}, [entry1_1, entry2_1]), + ({"leadAssignees": [member2.pk]}, [entry2_1]), + ({"leadAssignees": [member1.pk, member2.pk]}, [entry1_1, entry2_1]), + ({"leadConfidentialities": self.genum(Lead.Confidentiality.CONFIDENTIAL)}, [entry1_1, entry3_1]), + ({"leadPriorities": [self.genum(Lead.Priority.HIGH)]}, [entry1_1, entry2_1]), + ({"leadPriorities": [self.genum(Lead.Priority.LOW), self.genum(Lead.Priority.HIGH)]}, [entry1_1, entry2_1, entry3_1]), + ({"leadStatuses": [self.genum(Lead.Status.NOT_TAGGED)]}, []), + ({"leadStatuses": [self.genum(Lead.Status.IN_PROGRESS)]}, [entry2_1, entry3_1, entry4_1]), + ({"leadStatuses": [self.genum(Lead.Status.TAGGED)]}, [entry1_1]), ( - {'leadPriorities': [self.genum(Lead.Priority.LOW), self.genum(Lead.Priority.HIGH)]}, - [entry1_1, entry2_1, entry3_1] - ), - ({'leadStatuses': [self.genum(Lead.Status.NOT_TAGGED)]}, []), - ({'leadStatuses': [self.genum(Lead.Status.IN_PROGRESS)]}, [entry2_1, entry3_1, entry4_1]), - ({'leadStatuses': [self.genum(Lead.Status.TAGGED)]}, [entry1_1]), - ( - {'leadStatuses': [self.genum(Lead.Status.IN_PROGRESS), self.genum(Lead.Status.TAGGED)]}, - [entry1_1, entry2_1, entry3_1, entry4_1] + {"leadStatuses": [self.genum(Lead.Status.IN_PROGRESS), self.genum(Lead.Status.TAGGED)]}, + [entry1_1, entry2_1, entry3_1, entry4_1], ), - ({'leadIsAssessment': True}, [entry1_1, entry2_1]), - ({'leadIsAssessment': False}, [entry3_1, entry4_1]), - ({'leadHasAssessment': True}, [entry1_1]), - ({'leadHasAssessment': False}, [entry2_1, entry3_1, entry4_1]), - ({'hasComment': True}, [entry1_1, entry3_1]), - ({'hasComment': False}, [entry2_1, entry4_1]), - ({'isVerified': True}, [entry1_1, entry2_1]), - ({'isVerified': False}, [entry3_1, entry4_1]), - ({'search': str(entry1_1.id)}, [entry1_1]), - ({'search': str('1.11')}, []), - ({'search': lead2.title}, [entry2_1]), + ({"leadIsAssessment": True}, [entry1_1, entry2_1]), + ({"leadIsAssessment": False}, [entry3_1, entry4_1]), + ({"leadHasAssessment": True}, [entry1_1]), + ({"leadHasAssessment": False}, [entry2_1, entry3_1, entry4_1]), + ({"hasComment": True}, [entry1_1, entry3_1]), + ({"hasComment": False}, [entry2_1, entry4_1]), + ({"isVerified": True}, [entry1_1, entry2_1]), + ({"isVerified": False}, [entry3_1, entry4_1]), + ({"search": str(entry1_1.id)}, [entry1_1]), + ({"search": str("1.11")}, []), + ({"search": lead2.title}, [entry2_1]), # TODO: Common filters # ({'excerpt': []}, []), # ({'modifiedAt': []}, []), @@ -443,31 +441,31 @@ def test_entry_query_filter(self): # ({'leadPublishedOnLte': []}, []), ]: # Entry filter test - content = self.query_check(query, variables={'projectId': project.id, **filter_data}) + content = self.query_check(query, variables={"projectId": project.id, **filter_data}) self.assertListIds( - content['data']['project']['entries']['results'], expected_entries, - {'response': content, 'filter': filter_data} + content["data"]["project"]["entries"]["results"], expected_entries, {"response": content, "filter": filter_data} ) # Lead filter test expected_leads = set([entry.lead for entry in expected_entries]) content = self.query_check( TestLeadQuerySchema.lead_filter_query, variables={ - 'projectId': project.id, - 'hasEntries': True, - 'entriesFilterData': filter_data, - } + "projectId": project.id, + "hasEntries": True, + "entriesFilterData": filter_data, + }, ) self.assertListIds( - content['data']['project']['leads']['results'], expected_leads, - {'response': content, 'filter': filter_data}, + content["data"]["project"]["leads"]["results"], + expected_leads, + {"response": content, "filter": filter_data}, ) class TestEntryFilterDataQuery(GraphQLTestCase): def setUp(self): super().setUp() - self.entries_query = ''' + self.entries_query = """ query MyQuery ($projectId: ID! $filterableData: [EntryFilterDataInputType!]) { project(id: $projectId) { entries (filterableData: $filterableData) { @@ -477,7 +475,7 @@ def setUp(self): } } } - ''' + """ # AnalysisFramework setup self.af = AnalysisFrameworkFactory.create() @@ -498,28 +496,16 @@ def setUp(self): # For LIST Filter self.widget_multiselect = WidgetFactory.create( analysis_framework=self.af, - key='multiselect-widget-101', - title='Multiselect Widget', + key="multiselect-widget-101", + title="Multiselect Widget", widget_id=Widget.WidgetType.MULTISELECT, properties={ - 'data': { - 'options': [ - { - 'key': 'key-101', - 'label': 'Key label 101' - }, - { - 'key': 'key-102', - 'label': 'Key label 102' - }, - { - 'key': 'key-103', - 'label': 'Key label 103' - }, - { - 'key': 'key-104', - 'label': 'Key label 104' - }, + "data": { + "options": [ + {"key": "key-101", "label": "Key label 101"}, + {"key": "key-102", "label": "Key label 102"}, + {"key": "key-103", "label": "Key label 103"}, + {"key": "key-104", "label": "Key label 104"}, ] }, }, @@ -527,28 +513,28 @@ def setUp(self): # For Number Filter self.widget_number = WidgetFactory.create( analysis_framework=self.af, - key='number-widget-101', - title='Number Widget', + key="number-widget-101", + title="Number Widget", widget_id=Widget.WidgetType.NUMBER, ) # For INTERSECTS Filter self.widget_date_range = WidgetFactory.create( analysis_framework=self.af, - key='date-range-widget-101', - title='DateRange Widget', + key="date-range-widget-101", + title="DateRange Widget", widget_id=Widget.WidgetType.DATE_RANGE, ) # For TEXT Filter self.widget_text = WidgetFactory.create( analysis_framework=self.af, - key='text-widget-101', - title='Text Widget', + key="text-widget-101", + title="Text Widget", widget_id=Widget.WidgetType.TEXT, ) self.widget_geo = WidgetFactory.create( analysis_framework=self.af, - key='geo-widget-101', - title='GEO Widget', + key="geo-widget-101", + title="GEO Widget", widget_id=Widget.WidgetType.GEO, ) @@ -569,56 +555,50 @@ def test(self): entry3_2 = EntryFactory.create(**self.entry_create_kwargs, lead=self.lead3) # Create attributes for multiselect (LIST Filter) - EntryAttributeFactory.create(entry=entry1_1, widget=self.widget_multiselect, data={'value': ['key-101', 'key-102']}) - EntryAttributeFactory.create(entry=entry2_1, widget=self.widget_multiselect, data={'value': ['key-102', 'key-103']}) + EntryAttributeFactory.create(entry=entry1_1, widget=self.widget_multiselect, data={"value": ["key-101", "key-102"]}) + EntryAttributeFactory.create(entry=entry2_1, widget=self.widget_multiselect, data={"value": ["key-102", "key-103"]}) # Create attributes for time (NUMBER Filter) - EntryAttributeFactory.create(entry=entry1_1, widget=self.widget_number, data={'value': 10001}) - EntryAttributeFactory.create(entry=entry3_1, widget=self.widget_number, data={'value': 10002}) + EntryAttributeFactory.create(entry=entry1_1, widget=self.widget_number, data={"value": 10001}) + EntryAttributeFactory.create(entry=entry3_1, widget=self.widget_number, data={"value": 10002}) # Create attributes for date range (INTERSECTS Filter) EntryAttributeFactory.create( entry=entry2_1, widget=self.widget_date_range, - data={'value': {'startDate': '2020-01-10', 'endDate': '2020-01-20'}}, + data={"value": {"startDate": "2020-01-10", "endDate": "2020-01-20"}}, ) EntryAttributeFactory.create( entry=entry3_1, widget=self.widget_date_range, - data={'value': {'startDate': '2020-01-10', 'endDate': '2020-02-20'}}, + data={"value": {"startDate": "2020-01-10", "endDate": "2020-02-20"}}, ) EntryAttributeFactory.create( entry=entry3_2, widget=self.widget_date_range, - data={'value': {'startDate': '2020-01-15', 'endDate': '2020-01-25'}}, + data={"value": {"startDate": "2020-01-15", "endDate": "2020-01-25"}}, ) # Create attributes for text (TEXT Filter) - EntryAttributeFactory.create(entry=entry1_1, widget=self.widget_text, data={'value': 'This is a test 1'}) - EntryAttributeFactory.create(entry=entry3_1, widget=self.widget_text, data={'value': 'This is a test 2'}) + EntryAttributeFactory.create(entry=entry1_1, widget=self.widget_text, data={"value": "This is a test 1"}) + EntryAttributeFactory.create(entry=entry3_1, widget=self.widget_text, data={"value": "This is a test 2"}) # Create attributes for GEO (LIST Filter) EntryAttributeFactory.create( - entry=entry1_1, widget=self.widget_geo, - data={'value': [self.geo_area_3_2.pk]} # Leaf tagged + entry=entry1_1, widget=self.widget_geo, data={"value": [self.geo_area_3_2.pk]} # Leaf tagged ) + EntryAttributeFactory.create(entry=entry2_1, widget=self.widget_geo, data={"value": [self.geo_area_1.pk]}) # Root tagged EntryAttributeFactory.create( - entry=entry2_1, widget=self.widget_geo, - data={'value': [self.geo_area_1.pk]} # Root tagged + entry=entry3_1, widget=self.widget_geo, data={"value": [self.geo_area_2_1.pk]} # Middle child tagged ) EntryAttributeFactory.create( - entry=entry3_1, widget=self.widget_geo, - data={'value': [self.geo_area_2_1.pk]} # Middle child tagged - ) - EntryAttributeFactory.create( - entry=entry3_2, widget=self.widget_geo, - data={'value': [self.geo_area_1.pk, self.geo_area_3_2.pk]} # Middle child tagged + leaf node + entry=entry3_2, + widget=self.widget_geo, + data={"value": [self.geo_area_1.pk, self.geo_area_3_2.pk]}, # Middle child tagged + leaf node ) # Some entries with different AF - other_entry = EntryFactory.create(lead=self.lead1, analysis_framework=AnalysisFrameworkFactory.create(title='Other')) + other_entry = EntryFactory.create(lead=self.lead1, analysis_framework=AnalysisFrameworkFactory.create(title="Other")) + EntryAttributeFactory.create(entry=other_entry, widget=self.widget_multiselect, data={"value": ["key-101", "key-102"]}) + EntryAttributeFactory.create(entry=other_entry, widget=self.widget_number, data={"value": 10002}) EntryAttributeFactory.create( - entry=other_entry, widget=self.widget_multiselect, data={'value': ['key-101', 'key-102']}) - EntryAttributeFactory.create(entry=other_entry, widget=self.widget_number, data={'value': 10002}) - EntryAttributeFactory.create( - entry=other_entry, widget=self.widget_geo, - data={'value': [self.geo_area_3_2.pk]} # Leaf tagged + entry=other_entry, widget=self.widget_geo, data={"value": [self.geo_area_3_2.pk]} # Leaf tagged ) # Some leads/entries in other projects @@ -631,80 +611,78 @@ def test(self): for filter_name, filter_data, expected_entries in [ # NUMBER Filter Cases ( - 'number-filter-1', + "number-filter-1", [ { - 'filterKey': self.widget_number.key, - 'value': '10001', - 'valueGte': '10002', # This is ignored when value is provided - 'valueLte': '10005', # This is ignored when value is provided + "filterKey": self.widget_number.key, + "value": "10001", + "valueGte": "10002", # This is ignored when value is provided + "valueLte": "10005", # This is ignored when value is provided }, ], [entry1_1], ), ( - 'number-filter-2', + "number-filter-2", [ { - 'filterKey': self.widget_number.key, - 'valueGte': '10001', - 'valueLte': '10005', + "filterKey": self.widget_number.key, + "valueGte": "10001", + "valueLte": "10005", }, ], [entry1_1, entry3_1], ), ( - 'number-filter-3', + "number-filter-3", [ { - 'filterKey': self.widget_number.key, - 'valueLte': '10001', + "filterKey": self.widget_number.key, + "valueLte": "10001", }, ], [entry1_1], ), ( - 'number-filter-4', + "number-filter-4", [ { - 'filterKey': self.widget_number.key, - 'valueGte': '10002', + "filterKey": self.widget_number.key, + "valueGte": "10002", }, ], [entry3_1], ), - # TEXT Filter Cases ( - 'text-filter-1', + "text-filter-1", [ { - 'filterKey': self.widget_text.key, - 'value': 'This is a test', - 'valueGte': '10002', # This is ignored + "filterKey": self.widget_text.key, + "value": "This is a test", + "valueGte": "10002", # This is ignored }, ], [entry1_1, entry3_1], ), ( - 'text-filter-2', + "text-filter-2", [ { - 'filterKey': self.widget_text.key, - 'value': 'This is a test 1', - 'valueLte': '10002', # This is ignored + "filterKey": self.widget_text.key, + "value": "This is a test 1", + "valueLte": "10002", # This is ignored }, ], [entry1_1], ), - # INTERSECTS TODO: May need more test cases ( - 'intersect-filter-1', + "intersect-filter-1", [ { - 'filterKey': self.widget_date_range.key, - 'value': '2020-01-10', + "filterKey": self.widget_date_range.key, + "value": "2020-01-10", # 'valueLte': '2020-01-01', # TODO: # 'valueGte': '2020-01-30', # TODO: }, @@ -712,163 +690,163 @@ def test(self): [entry2_1, entry3_1], ), ( - 'intersect-filter-2', + "intersect-filter-2", [ { - 'filterKey': self.widget_date_range.key, - 'valueGte': '2020-01-01', - 'valueLte': '2020-01-30', + "filterKey": self.widget_date_range.key, + "valueGte": "2020-01-01", + "valueLte": "2020-01-30", }, ], [entry2_1, entry3_1, entry3_2], ), ( - 'intersect-filter-3', + "intersect-filter-3", [ { - 'filterKey': self.widget_date_range.key, - 'valueGte': '2020-01-30', # Only one is ignored + "filterKey": self.widget_date_range.key, + "valueGte": "2020-01-30", # Only one is ignored }, ], [entry1_1, entry2_1, entry3_1, entry3_2], ), - # LIST Filter ( - 'list-filter-1', + "list-filter-1", [ { - 'filterKey': self.widget_multiselect.key, - 'value': '13', # This is ignored + "filterKey": self.widget_multiselect.key, + "value": "13", # This is ignored }, ], [entry1_1, entry2_1, entry3_1, entry3_2], ), ( - 'list-filter-2', + "list-filter-2", [ { - 'filterKey': self.widget_multiselect.key, - 'valueList': ['key-101', 'key-102'], + "filterKey": self.widget_multiselect.key, + "valueList": ["key-101", "key-102"], }, ], [entry1_1, entry2_1], ), ( - 'list-filter-3', + "list-filter-3", [ { - 'filterKey': self.widget_multiselect.key, - 'valueList': ['key-101', 'key-102'], - 'useAndOperator': True, + "filterKey": self.widget_multiselect.key, + "valueList": ["key-101", "key-102"], + "useAndOperator": True, }, ], [entry1_1], ), ( - 'list-filter-4', + "list-filter-4", [ { - 'filterKey': self.widget_multiselect.key, - 'valueList': ['key-101', 'key-102'], - 'useAndOperator': True, - 'useExclude': True, + "filterKey": self.widget_multiselect.key, + "valueList": ["key-101", "key-102"], + "useAndOperator": True, + "useExclude": True, }, ], [entry2_1, entry3_1, entry3_2], ), ( - 'list-filter-5', + "list-filter-5", [ { - 'filterKey': self.widget_multiselect.key, - 'valueList': ['key-101', 'key-102'], - 'useExclude': True, + "filterKey": self.widget_multiselect.key, + "valueList": ["key-101", "key-102"], + "useExclude": True, }, ], [entry3_1, entry3_2], ), - # GEO (LIST) Filter ( - 'geo-filter-1', + "geo-filter-1", [ { - 'filterKey': self.widget_geo.key, - 'valueList': [self.geo_area_1.pk], + "filterKey": self.widget_geo.key, + "valueList": [self.geo_area_1.pk], }, ], [entry2_1, entry3_2], ), ( - 'geo-filter-2', + "geo-filter-2", [ { - 'filterKey': self.widget_geo.key, - 'valueList': [self.geo_area_1.pk], - 'includeSubRegions': True, + "filterKey": self.widget_geo.key, + "valueList": [self.geo_area_1.pk], + "includeSubRegions": True, }, ], [entry1_1, entry2_1, entry3_1, entry3_2], ), ( - 'geo-filter-3', + "geo-filter-3", [ { - 'filterKey': self.widget_geo.key, - 'valueList': [self.geo_area_1.pk], - 'includeSubRegions': True, - 'useExclude': True, + "filterKey": self.widget_geo.key, + "valueList": [self.geo_area_1.pk], + "includeSubRegions": True, + "useExclude": True, }, ], [], ), ( - 'geo-filter-4', + "geo-filter-4", [ { - 'filterKey': self.widget_geo.key, - 'valueList': [self.geo_area_2_2.pk], - 'includeSubRegions': True, + "filterKey": self.widget_geo.key, + "valueList": [self.geo_area_2_2.pk], + "includeSubRegions": True, }, ], [entry1_1, entry3_2], ), ( - 'geo-filter-5', + "geo-filter-5", [ { - 'filterKey': self.widget_geo.key, - 'valueList': [self.geo_area_2_2.pk], - 'includeSubRegions': True, - 'useExclude': True, + "filterKey": self.widget_geo.key, + "valueList": [self.geo_area_2_2.pk], + "includeSubRegions": True, + "useExclude": True, }, ], [entry2_1, entry3_1], - ) + ), ]: # Lead filter test content = self.query_check( self.entries_query, - variables={'projectId': self.project.id, 'filterableData': filter_data}, + variables={"projectId": self.project.id, "filterableData": filter_data}, ) self.assertListIds( - content['data']['project']['entries']['results'], expected_entries, - {'response': content, 'filter': filter_data, 'filter_name': filter_name} + content["data"]["project"]["entries"]["results"], + expected_entries, + {"response": content, "filter": filter_data, "filter_name": filter_name}, ) # Lead filter test expected_leads = set([entry.lead for entry in expected_entries]) content = self.query_check( TestLeadQuerySchema.lead_filter_query, variables={ - 'projectId': self.project.id, - 'hasEntries': True, - 'entriesFilterData': { - 'filterableData': filter_data, - } - } + "projectId": self.project.id, + "hasEntries": True, + "entriesFilterData": { + "filterableData": filter_data, + }, + }, ) self.assertListIds( - content['data']['project']['leads']['results'], expected_leads, - {'response': content, 'filter': filter_data, 'filter_name': filter_name} + content["data"]["project"]["leads"]["results"], + expected_leads, + {"response": content, "filter": filter_data, "filter_name": filter_name}, ) diff --git a/apps/entry/utils.py b/apps/entry/utils.py index c88dd1ffa5..684165b376 100644 --- a/apps/entry/utils.py +++ b/apps/entry/utils.py @@ -1,9 +1,10 @@ from entry.models import Attribute from gallery.models import File + from utils.image import decode_base64_if_possible -from .widgets.utils import set_filter_data, set_export_data from .widgets.store import widget_store +from .widgets.utils import set_export_data, set_filter_data def update_entry_attribute(attribute): @@ -21,8 +22,8 @@ def update_entry_attribute(attribute): widget.properties or {}, ) - filter_data_list = update_info.get('filter_data') - export_data = update_info.get('export_data') + filter_data_list = update_info.get("filter_data") + export_data = update_info.get("export_data") if filter_data_list: for filter_data in filter_data_list: @@ -55,9 +56,9 @@ def base64_to_deep_image(image, lead, user): if isinstance(decoded_file, str): return image - mime_type = '' + mime_type = "" if header: - mime_type = header[len('data:'):] + mime_type = header[len("data:") :] file = File.objects.create( title=decoded_file.name, diff --git a/apps/entry/views.py b/apps/entry/views.py index 52d8355dab..12777445d5 100644 --- a/apps/entry/views.py +++ b/apps/entry/views.py @@ -1,37 +1,43 @@ from collections import defaultdict +import django_filters +from analysis_framework.models import Filter, Widget from django.contrib.auth.models import User from django.db import models from django.utils import timezone -from rest_framework.decorators import action +from lead.models import Lead +from organization.models import OrganizationType +from project.models import Project from rest_framework import ( filters, generics, + mixins, permissions, response, + serializers, views, viewsets, - serializers, - mixins, ) +from rest_framework.decorators import action from reversion.models import Version -import django_filters +from tabular.models import Field as TabularField -from deep.permissions import ModifyPermission, IsProjectMember, CreateEntryPermission -from project.models import Project -from lead.models import Lead -from analysis_framework.models import Widget -from organization.models import OrganizationType -from analysis_framework.models import Filter +from deep.permissions import CreateEntryPermission, IsProjectMember, ModifyPermission from .errors import EntryValidationVersionMismatchError -from .widgets import matrix1d_widget, matrix2d_widget -from .models import ( - Entry, Attribute, FilterData, ExportData, EntryComment, - # Entry Grouping - ProjectEntryLabel, LeadEntryGroup, EntryGroupLabel, +from .filter_set import EntryCommentFilterSet, EntryFilterSet, get_filtered_entries +from .models import ( # Entry Grouping + Attribute, + Entry, + EntryComment, + EntryGroupLabel, + ExportData, + FilterData, + LeadEntryGroup, + ProjectEntryLabel, ) -from .serializers import ( +from .pagination import ComprehensiveEntriesSetPagination +from .serializers import ( # Entry Grouping AttributeSerializer, ComprehensiveEntriesSerializer, EditEntriesDataSerializer, @@ -42,25 +48,18 @@ EntrySerializer, ExportDataSerializer, FilterDataSerializer, - # Entry Grouping - ProjectEntryLabelDetailSerializer, LeadEntryGroupSerializer, + ProjectEntryLabelDetailSerializer, ) -from .pagination import ComprehensiveEntriesSetPagination -from .filter_set import ( - EntryFilterSet, - EntryCommentFilterSet, - get_filtered_entries, -) -from tabular.models import Field as TabularField +from .widgets import matrix1d_widget, matrix2d_widget class EntrySummaryPaginationMixin(object): def get_entries_filters(self): - if hasattr(self, '_entry_filters'): + if hasattr(self, "_entry_filters"): return self._entry_filters - filters = self.request.data.get('filters', []) + filters = self.request.data.get("filters", []) filters = {f[0]: f[1] for f in filters} self._entry_filters = filters return self._entry_filters @@ -68,15 +67,14 @@ def get_entries_filters(self): def get_counts_by_matrix_2d(self, qs): # Project should be provided filters = self.get_entries_filters() - project = filters.get('project') + project = filters.get("project") if project is None: return {} # Pull necessary widgets widgets = Widget.objects.filter( - analysis_framework__project=project, - widget_id__in=[matrix1d_widget.WIDGET_ID, matrix2d_widget.WIDGET_ID] - ).values_list('key', 'widget_id', 'properties') + analysis_framework__project=project, widget_id__in=[matrix1d_widget.WIDGET_ID, matrix2d_widget.WIDGET_ID] + ).values_list("key", "widget_id", "properties") # Pull necessary filters filters = { @@ -86,14 +84,8 @@ def get_counts_by_matrix_2d(self, qs): key__in=[ _key for key, widget_id, _ in widgets - for _key in ( - [ - f'{key}-dimensions', - f'{key}-sectors' - ] if widget_id == matrix2d_widget.WIDGET_ID else - [key] - ) - ] + for _key in ([f"{key}-dimensions", f"{key}-sectors"] if widget_id == matrix2d_widget.WIDGET_ID else [key]) + ], ) } @@ -101,92 +93,75 @@ def get_counts_by_matrix_2d(self, qs): agg_data = qs.aggregate( **{ f"{key}__{ele['id' if widget_id == matrix2d_widget.WIDGET_ID else 'key']}__{control_status}": models.Count( - 'id', + "id", filter=models.Q( - controlled=control_status == 'controlled', - filterdata__filter=( - filters[f'{key}-{data_type}' if widget_id == matrix2d_widget.WIDGET_ID else key] - ), - filterdata__values__contains=[ - ele['id' if widget_id == matrix2d_widget.WIDGET_ID else 'key'] - ], + controlled=control_status == "controlled", + filterdata__filter=(filters[f"{key}-{data_type}" if widget_id == matrix2d_widget.WIDGET_ID else key]), + filterdata__values__contains=[ele["id" if widget_id == matrix2d_widget.WIDGET_ID else "key"]], ), distinct=True, ) for key, widget_id, properties in widgets - for data_type in ( - ['sectors', 'dimensions'] if widget_id == matrix2d_widget.WIDGET_ID else - ['rows'] - ) - for _ele in properties['data'][data_type] - for ele in [ - _ele, - *( - _ele.get(f'sub{data_type}' if widget_id == matrix2d_widget.WIDGET_ID else 'cells') or [] - ) - ] - for control_status in ['controlled', 'uncontrolled'] + for data_type in (["sectors", "dimensions"] if widget_id == matrix2d_widget.WIDGET_ID else ["rows"]) + for _ele in properties["data"][data_type] + for ele in [_ele, *(_ele.get(f"sub{data_type}" if widget_id == matrix2d_widget.WIDGET_ID else "cells") or [])] + for control_status in ["controlled", "uncontrolled"] } ) # Re-structure data (also snake-case to camel case conversion will change the key) response = defaultdict(lambda: defaultdict(lambda: defaultdict(int))) for key, count in agg_data.items(): - widget_key, label_key, controlled_status = key.split('__') + widget_key, label_key, controlled_status = key.split("__") response[widget_key][label_key][controlled_status] = count return [ { - 'widget_key': widget_key, - 'label_key': label_key, - 'controlled_count': count['controlled'], - 'uncontrolled_count': count['uncontrolled'], + "widget_key": widget_key, + "label_key": label_key, + "controlled_count": count["controlled"], + "uncontrolled_count": count["uncontrolled"], } for widget_key, widget_data in response.items() for label_key, count in widget_data.items() ] def get_paginated_response(self, data): - calculate_summary = self.request.data.get( - 'calculate_summary', - self.request.GET.get('calculate_summary', '0') - ) == '1' - calculate_count_per_toc_item = self.request.data.get( - 'calculate_count_per_toc_item', - self.request.GET.get('calculate_count_per_toc_item', '0') - ) == '1' + calculate_summary = self.request.data.get("calculate_summary", self.request.GET.get("calculate_summary", "0")) == "1" + calculate_count_per_toc_item = ( + self.request.data.get("calculate_count_per_toc_item", self.request.GET.get("calculate_count_per_toc_item", "0")) + == "1" + ) if not calculate_summary: return super().get_paginated_response(data) qs = self.filter_queryset(self.get_queryset()) - q = qs.annotate( - org=models.functions.Coalesce('lead__authors__parent', 'lead__authors') - ).values('org').annotate( - org_type=models.functions.Coalesce( - 'lead__authors__parent__organization_type', - 'lead__authors__organization_type' + q = ( + qs.annotate(org=models.functions.Coalesce("lead__authors__parent", "lead__authors")) + .values("org") + .annotate( + org_type=models.functions.Coalesce("lead__authors__parent__organization_type", "lead__authors__organization_type") ) - ).values('org_type').order_by('org_type').annotate( - count=models.Count('org', distinct=True) - ).values('org_type', 'count') + .values("org_type") + .order_by("org_type") + .annotate(count=models.Count("org", distinct=True)) + .values("org_type", "count") + ) - q = { - each['org_type']: each['count'] - for each in q if each['org_type'] - } + q = {each["org_type"]: each["count"] for each in q if each["org_type"]} org_types = OrganizationType.objects.filter(id__in=q.keys()) org_type_count = [ { - 'org': { - 'id': org_type.id, - 'title': org_type.title, - 'shortName': org_type.short_name, + "org": { + "id": org_type.id, + "title": org_type.title, + "shortName": org_type.short_name, }, - 'count': q[org_type.id] + "count": q[org_type.id], } for org_type in org_types ] - total_sources = qs.values('lead__source_id').annotate(count=models.Count('lead__source_id')).count() - total_leads = qs.values('lead_id').annotate(count=models.Count('lead_id')).count() + total_sources = qs.values("lead__source_id").annotate(count=models.Count("lead__source_id")).count() + total_leads = qs.values("lead_id").annotate(count=models.Count("lead_id")).count() summary_data = dict( total_controlled_entries=qs.filter(controlled=True).count(), total_uncontrolled_entries=qs.filter(controlled=False).count(), @@ -195,39 +170,35 @@ def get_paginated_response(self, data): total_sources=total_sources, ) if calculate_count_per_toc_item: - summary_data['count_per_toc_item'] = self.get_counts_by_matrix_2d(qs) + summary_data["count_per_toc_item"] = self.get_counts_by_matrix_2d(qs) - return response.Response({ - **super().get_paginated_response(data).data, - 'summary': summary_data - }) + return response.Response({**super().get_paginated_response(data).data, "summary": summary_data}) class EntryViewSet(EntrySummaryPaginationMixin, viewsets.ModelViewSet): """ Entry view set """ + serializer_class = EntrySerializer - permission_classes = [permissions.IsAuthenticated, CreateEntryPermission, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, CreateEntryPermission, ModifyPermission] - filter_backends = (django_filters.rest_framework.DjangoFilterBackend, - filters.SearchFilter) + filter_backends = (django_filters.rest_framework.DjangoFilterBackend, filters.SearchFilter) filterset_class = EntryFilterSet - search_fields = ('lead__title', 'excerpt') + search_fields = ("lead__title", "excerpt") def get_queryset(self): return get_filtered_entries(self.request.user, self.request.GET) def get_serializer_class(self): - if self.action == 'list': + if self.action == "list": return EntryRetriveSerializer return super().get_serializer_class() @action( detail=False, - url_path='processed', + url_path="processed", serializer_class=EntryProccesedSerializer, ) def get_proccessed_entries(self, request, version=None): @@ -238,23 +209,18 @@ def get_proccessed_entries(self, request, version=None): def _validate_entry_version(self, entry, requested_version): if requested_version is None: - raise EntryValidationVersionMismatchError('Version is required') + raise EntryValidationVersionMismatchError("Version is required") current_entry_version = Version.objects.get_for_object(entry).count() if requested_version != current_entry_version: raise EntryValidationVersionMismatchError( - f'Version mismatch. Current version in server: {current_entry_version}.' - f' Requested version: {requested_version}' + f"Version mismatch. Current version in server: {current_entry_version}." + f" Requested version: {requested_version}" ) - @action( - detail=True, - permission_classes=[permissions.IsAuthenticated, ModifyPermission], - url_path='control', - methods=['post'] - ) + @action(detail=True, permission_classes=[permissions.IsAuthenticated, ModifyPermission], url_path="control", methods=["post"]) def control_entry(self, request, **kwargs): entry = self.get_object() - self._validate_entry_version(entry, request.data.get('version_id')) + self._validate_entry_version(entry, request.data.get("version_id")) entry.control(request.user) return response.Response( self.get_serializer_class()( @@ -264,14 +230,11 @@ def control_entry(self, request, **kwargs): ) @action( - detail=True, - permission_classes=[permissions.IsAuthenticated, ModifyPermission], - url_path='uncontrol', - methods=['post'] + detail=True, permission_classes=[permissions.IsAuthenticated, ModifyPermission], url_path="uncontrol", methods=["post"] ) def uncontrol_entry(self, request, **kwargs): entry = self.get_object() - self._validate_entry_version(entry, request.data.get('version_id')) + self._validate_entry_version(entry, request.data.get("version_id")) entry.control(request.user, controlled=False) return response.Response( self.get_serializer_class()( @@ -285,54 +248,51 @@ class EntryFilterView(EntrySummaryPaginationMixin, generics.GenericAPIView): """ Entry view for getting entries based filters in POST body """ + serializer_class = EntryRetriveProccesedSerializer permission_classes = [permissions.IsAuthenticated] def get_queryset(self): filters = self.get_entries_filters() - queryset = get_filtered_entries(self.request.user, filters).select_related( - 'lead', 'lead__attachment', - 'controlled_changed_by', - ).prefetch_related('lead__assignee') + queryset = ( + get_filtered_entries(self.request.user, filters) + .select_related( + "lead", + "lead__attachment", + "controlled_changed_by", + ) + .prefetch_related("lead__assignee") + ) queryset = Entry.annotate_comment_count(queryset) - project = filters.get('project') - search = filters.get('search') + project = filters.get("project") + search = filters.get("search") if search: # For searching tabular columns field_filters = {} if project: - field_filters['sheet__book__project'] = project + field_filters["sheet__book__project"] = project - fields = TabularField.objects.filter( - title__icontains=search, - **field_filters - ) + fields = TabularField.objects.filter(title__icontains=search, **field_filters) queryset = queryset.filter( - models.Q(lead__title__icontains=search) | - models.Q(excerpt__icontains=search) | - ( - models.Q( - tabular_field__in=models.Subquery( - fields.values_list('pk', flat=True)) - ) - ) + models.Q(lead__title__icontains=search) + | models.Q(excerpt__icontains=search) + | (models.Q(tabular_field__in=models.Subquery(fields.values_list("pk", flat=True)))) ) - return ( - queryset - .select_related( - 'image', 'lead', - 'created_by__profile', 'modified_by__profile', - ).prefetch_related( - 'attribute_set', - 'lead__authors', - 'lead__assignee', - 'lead__assignee__profile', - 'lead__leadpreview', - ) + return queryset.select_related( + "image", + "lead", + "created_by__profile", + "modified_by__profile", + ).prefetch_related( + "attribute_set", + "lead__authors", + "lead__assignee", + "lead__assignee__profile", + "lead__leadpreview", ) def post(self, request, **kwargs): @@ -343,7 +303,7 @@ def post(self, request, **kwargs): entry_group_label_count = {} entry_group_label_qs = EntryGroupLabel.get_stat_for_entry(EntryGroupLabel.objects.filter(entry__in=page)) for count_data in entry_group_label_qs: - entry_id = count_data.pop('entry') + entry_id = count_data.pop("entry") if entry_id not in entry_group_label_count: entry_group_label_count[entry_id] = [count_data] else: @@ -352,8 +312,8 @@ def post(self, request, **kwargs): # Custom Context serializer_class = self.get_serializer_class() context = self.get_serializer_context() - context['entry_group_label_count'] = entry_group_label_count - context['post_is_used_for_filter'] = True + context["entry_group_label_count"] = entry_group_label_count + context["post_is_used_for_filter"] = True if page is not None: serializer = serializer_class(page, many=True, context=context) @@ -367,21 +327,22 @@ class EditEntriesDataViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet) """ Page API for Edit Entries """ + serializer_class = EditEntriesDataSerializer permission_classes = [permissions.IsAuthenticated] def get_queryset(self): # TODO: Optimize this queryset return Lead.get_for(self.request.user).select_related( - 'project', - 'project__analysis_framework', 'project__analysis_framework__organization', + "project", + "project__analysis_framework", + "project__analysis_framework__organization", ) class AttributeViewSet(viewsets.ModelViewSet): serializer_class = AttributeSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_queryset(self): return Attribute.get_for(self.request.user) @@ -389,8 +350,7 @@ def get_queryset(self): class FilterDataViewSet(viewsets.ModelViewSet): serializer_class = FilterDataSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_queryset(self): return FilterData.get_for(self.request.user) @@ -398,8 +358,7 @@ def get_queryset(self): class ExportDataViewSet(viewsets.ModelViewSet): serializer_class = ExportDataSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_queryset(self): return ExportData.get_for(self.request.user) @@ -409,45 +368,50 @@ class EntryOptionsView(views.APIView): """ Options for various attributes related to entry """ + permission_classes = [permissions.IsAuthenticated] def get(self, request, version=None): - project_query = request.GET.get('project') - fields_query = request.GET.get('fields') + project_query = request.GET.get("project") + fields_query = request.GET.get("fields") projects = Project.get_for_member(request.user) if project_query: - projects = projects.filter(id__in=project_query.split(',')) + projects = projects.filter(id__in=project_query.split(",")) fields = None if fields_query: - fields = fields_query.split(',') + fields = fields_query.split(",") options = { - 'lead_status': [ + "lead_status": [ { - 'key': s[0], - 'value': s[1], - } for s in Lead.Status.choices + "key": s[0], + "value": s[1], + } + for s in Lead.Status.choices ], - 'lead_priority': [ + "lead_priority": [ { - 'key': s[0], - 'value': s[1], - } for s in Lead.Priority.choices + "key": s[0], + "value": s[1], + } + for s in Lead.Priority.choices ], - 'lead_confidentiality': [ + "lead_confidentiality": [ { - 'key': s[0], - 'value': s[1], - } for s in Lead.Confidentiality.choices + "key": s[0], + "value": s[1], + } + for s in Lead.Confidentiality.choices ], - 'organization_types': [ + "organization_types": [ { - 'key': organization_type.id, - 'value': organization_type.title, - } for organization_type in OrganizationType.objects.all() - ] + "key": organization_type.id, + "value": organization_type.title, + } + for organization_type in OrganizationType.objects.all() + ], } def _filter_by_projects(qs, projects): @@ -455,18 +419,18 @@ def _filter_by_projects(qs, projects): qs = qs.filter(project=p) return qs - if fields is None or 'created_by' in fields: + if fields is None or "created_by" in fields: created_by = _filter_by_projects(User.objects, projects) - options['created_by'] = [ + options["created_by"] = [ { - 'key': user.id, - 'value': user.profile.get_display_name(), - } for user in created_by.distinct() + "key": user.id, + "value": user.profile.get_display_name(), + } + for user in created_by.distinct() ] - if fields is None or 'project_entry_labels' in fields: - options['project_entry_label'] = ProjectEntryLabel.objects.filter( - project__in=projects).values('id', 'title', 'color') + if fields is None or "project_entry_labels" in fields: + options["project_entry_label"] = ProjectEntryLabel.objects.filter(project__in=projects).values("id", "title", "color") return response.Response(options) @@ -476,6 +440,7 @@ class ComprehensiveEntriesViewSet(viewsets.ReadOnlyModelViewSet): Comprehensive API for Entries TODO: Should we create this view also?? """ + serializer_class = ComprehensiveEntriesSerializer permission_classes = [permissions.IsAuthenticated] pagination_class = ComprehensiveEntriesSetPagination @@ -486,21 +451,23 @@ def get_queryset(self): ignore_widget_type = [Widget.WidgetType.EXCERPT.value] prefetch_related_fields = [ models.Prefetch( - 'attribute_set', + "attribute_set", queryset=Attribute.objects.exclude(widget__widget_id__in=ignore_widget_type), ), models.Prefetch( - 'attribute_set__widget', + "attribute_set__widget", queryset=Widget.objects.exclude(widget_id__in=ignore_widget_type), ), - 'created_by', 'created_by__profile', - 'modified_by', 'modified_by__profile', + "created_by", + "created_by__profile", + "modified_by", + "modified_by__profile", ] return Entry.get_for(self.request.user).prefetch_related(*prefetch_related_fields) def get_serializer_context(self): context = super().get_serializer_context() - context['queryset'] = self.get_queryset() + context["queryset"] = self.get_queryset() return context @@ -511,27 +478,27 @@ class EntryCommentViewSet(viewsets.ModelViewSet): filterset_class = EntryCommentFilterSet def get_queryset(self): - return EntryComment.get_for(self.request.user).filter(entry=self.kwargs['entry_id']) + return EntryComment.get_for(self.request.user).filter(entry=self.kwargs["entry_id"]) def get_serializer_context(self): return { **super().get_serializer_context(), - 'entry_id': self.kwargs.get('entry_id'), + "entry_id": self.kwargs.get("entry_id"), } @action( detail=True, - url_path='resolve', - methods=['post'], + url_path="resolve", + methods=["post"], ) def resolve_comment(self, request, pk, entry_id=None, version=None): comment = self.get_object() if comment.is_resolved: - raise serializers.ValidationError('Already Resolved') + raise serializers.ValidationError("Already Resolved") if comment.parent: - raise serializers.ValidationError('only root comment can be resolved') + raise serializers.ValidationError("only root comment can be resolved") if comment.created_by != request.user: - raise serializers.ValidationError('only comment owner can resolve') + raise serializers.ValidationError("only comment owner can resolve") comment.is_resolved = True comment.resolved_at = timezone.now() comment.save() @@ -542,18 +509,20 @@ class ProjectEntryLabelViewSet(viewsets.ModelViewSet): # TODO: Restrict non-admin for update/create serializer_class = ProjectEntryLabelDetailSerializer permission_classes = [ - permissions.IsAuthenticated, IsProjectMember, ModifyPermission, + permissions.IsAuthenticated, + IsProjectMember, + ModifyPermission, ] def get_queryset(self): - return ProjectEntryLabel.objects.filter(project=self.kwargs['project_id']).annotate( - entry_count=models.Count('entrygrouplabel__entry', distinct=True), + return ProjectEntryLabel.objects.filter(project=self.kwargs["project_id"]).annotate( + entry_count=models.Count("entrygrouplabel__entry", distinct=True), ) @action( detail=False, - url_path='bulk-update-order', - methods=['post'], + url_path="bulk-update-order", + methods=["post"], ) def bulk_update_order(self, request, *args, **kwargs): # TODO: Restrict non-admin @@ -564,13 +533,11 @@ def bulk_update_order(self, request, *args, **kwargs): [{"id": 1, "order": 2}, {"id": 2, "order": 1}] ``` """ - labels_order = { - label['id']: label['order'] for label in request.data if label.get('id') - } + labels_order = {label["id"]: label["order"] for label in request.data if label.get("id")} labels = [] for label in self.get_queryset().filter(id__in=labels_order.keys()).all(): label.order = labels_order[label.pk] - label.save(update_fields=['order']) + label.save(update_fields=["order"]) labels.append(label) return response.Response(self.get_serializer(labels, many=True).data) @@ -578,8 +545,10 @@ def bulk_update_order(self, request, *args, **kwargs): class LeadEntryGroupViewSet(viewsets.ModelViewSet): serializer_class = LeadEntryGroupSerializer permission_classes = [ - permissions.IsAuthenticated, IsProjectMember, ModifyPermission, + permissions.IsAuthenticated, + IsProjectMember, + ModifyPermission, ] def get_queryset(self): - return LeadEntryGroup.objects.filter(lead=self.kwargs['lead_id']) + return LeadEntryGroup.objects.filter(lead=self.kwargs["lead_id"]) diff --git a/apps/entry/widgets/conditional_widget.py b/apps/entry/widgets/conditional_widget.py index 7673c61f57..52c967e785 100644 --- a/apps/entry/widgets/conditional_widget.py +++ b/apps/entry/widgets/conditional_widget.py @@ -1,4 +1,6 @@ -from analysis_framework.widgets.conditional_widget import WIDGET_ID # type: ignore # noqa:F401 +from analysis_framework.widgets.conditional_widget import ( # type: ignore # noqa:F401 + WIDGET_ID, +) class Dummy: @@ -8,13 +10,10 @@ class Dummy: def update_attribute(widget, data, widget_data): from entry.widgets.store import widget_store - value = data.get('value', {}) - selected_widget_key = value.get('selected_widget_key') + value = data.get("value", {}) + selected_widget_key = value.get("selected_widget_key") - widgets = [ - w.get('widget') - for w in (widget_data.get('widgets') or []) - ] + widgets = [w.get("widget") for w in (widget_data.get("widgets") or [])] filter_data = [] excel_data = [] @@ -22,64 +21,69 @@ def update_attribute(widget, data, widget_data): report_keys = [] common_data = {} for w in widgets: - widget_module = widget_store.get(w.get('widget_id')) - common_data[getattr(widget_module, 'WIDGET_ID')] = getattr(widget_module, 'DATA_VERSION', None) + widget_module = widget_store.get(w.get("widget_id")) + common_data[getattr(widget_module, "WIDGET_ID")] = getattr(widget_module, "DATA_VERSION", None) if not widget_module: continue - w_key = w.get('key') + w_key = w.get("key") if w_key == selected_widget_key: - w_data = value.get(w_key, {}).get('data', {}) + w_data = value.get(w_key, {}).get("data", {}) else: w_data = {} - w_widget_data = w.get('properties', {}).get('data', {}) + w_widget_data = w.get("properties", {}).get("data", {}) w_obj = Dummy() w_obj.key = w_key - w_obj.title = w['title'] - w_obj.widget_id = w['widget_id'] + w_obj.title = w["title"] + w_obj.widget_id = w["widget_id"] update_info = widget_module.update_attribute( w_obj, w_data, w_widget_data, ) - w_filter_data = update_info.get('filter_data') or [] - w_export_data = update_info.get('export_data') - - filter_data = filter_data + [{ - **wfd, - 'key': '{}-{}'.format( - widget.key, - wfd.get('key', w_key), - ) - } for wfd in w_filter_data] + w_filter_data = update_info.get("filter_data") or [] + w_export_data = update_info.get("export_data") + + filter_data = filter_data + [ + { + **wfd, + "key": "{}-{}".format( + widget.key, + wfd.get("key", w_key), + ), + } + for wfd in w_filter_data + ] if w_export_data: - excel_data.append({ - **w_export_data.get('data', {}).get('common', {}), - **w_export_data.get('data', {}).get('excel', {}), - }) + excel_data.append( + { + **w_export_data.get("data", {}).get("common", {}), + **w_export_data.get("data", {}).get("excel", {}), + } + ) report_datum = { - **w_export_data.get('data', {}).get('common', {}), - **w_export_data.get('data', {}).get('report', {}), + **w_export_data.get("data", {}).get("common", {}), + **w_export_data.get("data", {}).get("report", {}), } - report_keys += report_datum.get('keys') or [] + report_keys += report_datum.get("keys") or [] report_data.append(report_datum) else: excel_data.append(None) return { - 'filter_data': filter_data, - 'export_data': { - 'data': { - 'excel': excel_data, - 'report': { - 'other': report_data, - 'keys': report_keys, + "filter_data": filter_data, + "export_data": { + "data": { + "excel": excel_data, + "report": { + "other": report_data, + "keys": report_keys, }, - 'common': common_data + "common": common_data, # TODO: 'condition': }, }, @@ -89,42 +93,41 @@ def update_attribute(widget, data, widget_data): def get_comprehensive_data(widgets_meta, widget, data, widget_data): from entry.widgets.store import widget_store - value = data.get('value', {}) - selected_widget_key = value.get('selected_widget_key') + value = data.get("value", {}) + selected_widget_key = value.get("selected_widget_key") selected_widgets = [ - w.get('widget') - for w in (widget_data.get('widgets') or []) if w.get('widget', {}).get('key') == selected_widget_key + w.get("widget") for w in (widget_data.get("widgets") or []) if w.get("widget", {}).get("key") == selected_widget_key ] selected_widget = selected_widgets[0] if selected_widgets else None if selected_widget is None: return None - widget_module = widget_store.get(selected_widget.get('widget_id')) + widget_module = widget_store.get(selected_widget.get("widget_id")) if widget_module is None: return None - w_key = selected_widget.get('key') + w_key = selected_widget.get("key") if w_key == selected_widget_key: - w_data = value.get(w_key, {}).get('data', {}) + w_data = value.get(w_key, {}).get("data", {}) else: w_data = {} - w_widget_data = selected_widget.get('properties', {}).get('data', {}) + w_widget_data = selected_widget.get("properties", {}).get("data", {}) w_obj = Dummy() w_obj.pk = f"${w_key}-{selected_widget.get('widget_id')}" w_obj.key = w_key return { - 'id': selected_widget.get('key'), - 'type': selected_widget.get('widget_id'), - 'title': selected_widget.get('title'), - 'value': widget_module.get_comprehensive_data( + "id": selected_widget.get("key"), + "type": selected_widget.get("widget_id"), + "title": selected_widget.get("title"), + "value": widget_module.get_comprehensive_data( widgets_meta, w_obj, w_data, w_widget_data, - ) + ), } diff --git a/apps/entry/widgets/date_range_widget.py b/apps/entry/widgets/date_range_widget.py index 96b60059f4..54ecc62010 100644 --- a/apps/entry/widgets/date_range_widget.py +++ b/apps/entry/widgets/date_range_widget.py @@ -7,9 +7,9 @@ def _get_date(widget, data, widget_properties): - value = data.get('value', {}) - from_value = value.get('startDate') # TODO: use from - to_value = value.get('endDate') # TODO: use to + value = data.get("value", {}) + from_value = value.get("startDate") # TODO: use from + to_value = value.get("endDate") # TODO: use to from_date, from_number = parse_date_str(from_value) to_date, to_number = parse_date_str(to_value) @@ -29,23 +29,22 @@ def update_attribute(widget, data, widget_properties): return { # NOTE: Please update the data version when you update the data format - 'filter_data': [{ - 'from_number': from_number, - 'to_number': to_number, - }], - - 'export_data': { - 'data': { - 'common': { - 'values': [from_date, to_date], - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, + "filter_data": [ + { + "from_number": from_number, + "to_number": to_number, + } + ], + "export_data": { + "data": { + "common": { + "values": [from_date, to_date], + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, }, - 'excel': { - }, - 'report': { - } + "excel": {}, + "report": {}, }, }, } @@ -54,6 +53,6 @@ def update_attribute(widget, data, widget_properties): def get_comprehensive_data(_, *args): (from_date, to_date) = _get_date(*args)[1] return { - 'from': from_date, - 'to': to_date, + "from": from_date, + "to": to_date, } diff --git a/apps/entry/widgets/date_widget.py b/apps/entry/widgets/date_widget.py index 2f7791965f..ef0ea3d98c 100644 --- a/apps/entry/widgets/date_widget.py +++ b/apps/entry/widgets/date_widget.py @@ -1,7 +1,7 @@ +from analysis_framework.widgets.date_widget import WIDGET_ID from dateutil.parser import parse as date_parse from utils.common import ONE_DAY, deep_date_format -from analysis_framework.widgets.date_widget import WIDGET_ID # NOTE: Please update the data version when you update the data format # this is tallied against the version stored in the export json data @@ -16,7 +16,7 @@ def parse_date_str(value): def _get_date(widget, data, widget_properties): - value = data.get('value') + value = data.get("value") return parse_date_str(value) @@ -25,22 +25,21 @@ def update_attribute(widget, data, widget_properties): return { # NOTE: Please update the data version when you update the data format - 'filter_data': [{ - 'number': number, - }], - - 'export_data': { - 'data': { - 'common': { - 'value': date, - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, - }, - 'excel': { + "filter_data": [ + { + "number": number, + } + ], + "export_data": { + "data": { + "common": { + "value": date, + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, }, - 'report': { - } + "excel": {}, + "report": {}, } }, } diff --git a/apps/entry/widgets/geo_widget.py b/apps/entry/widgets/geo_widget.py index be36082178..42f9c9d7c1 100644 --- a/apps/entry/widgets/geo_widget.py +++ b/apps/entry/widgets/geo_widget.py @@ -1,14 +1,14 @@ -from utils.common import is_valid_number from analysis_framework.widgets.geo_widget import WIDGET_ID +from utils.common import is_valid_number DATA_VERSION = 1 def _get_geoareas_from_polygon(geo_value): try: - properties = geo_value['geo_json']['properties'] - return properties['geoareas'], properties.get('title'), geo_value['region'] + properties = geo_value["geo_json"]["properties"] + return properties["geoareas"], properties.get("title"), geo_value["region"] except (AttributeError, KeyError): return [], None, None @@ -28,14 +28,14 @@ def get_valid_geo_ids(raw_values, extract_polygon_title=False): else: # This will be a polygon pgeo_areas, ptitle, pregion_id = _get_geoareas_from_polygon(raw_value) - geo_areas.extend([ - int(id) for id in pgeo_areas if is_valid_number(id) - ]) + geo_areas.extend([int(id) for id in pgeo_areas if is_valid_number(id)]) if extract_polygon_title and ptitle and pregion_id: - polygons.append({ - 'region_id': pregion_id, - 'title': ptitle, - }) + polygons.append( + { + "region_id": pregion_id, + "title": ptitle, + } + ) geo_areas = list(set(geo_areas)) if extract_polygon_title: @@ -48,31 +48,28 @@ def update_attribute(widget, data, widget_properties): data: { value: [], polygons: [], points: [] } """ - all_values = [ - *(data.get('value') or []), - *(data.get('polygons') or []), - *(data.get('points') or []) - ] + all_values = [*(data.get("value") or []), *(data.get("polygons") or []), *(data.get("points") or [])] values, polygons = get_valid_geo_ids( all_values, extract_polygon_title=True, ) return { - 'filter_data': [{ - 'values': values, - }], - - 'export_data': { - 'data': { - 'common': { - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, - 'values': values, # GEOAREA IDs + "filter_data": [ + { + "values": values, + } + ], + "export_data": { + "data": { + "common": { + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, + "values": values, # GEOAREA IDs }, - 'excel': { - 'polygons': polygons, # Polygons + "excel": { + "polygons": polygons, # Polygons }, } }, @@ -80,30 +77,30 @@ def update_attribute(widget, data, widget_properties): def _get_geo_area_parents(geo_areas, admin_levels, geo_area): - parent = geo_areas.get(geo_area['parent']) + parent = geo_areas.get(geo_area["parent"]) if parent is None: return [] - parents = [{ - 'id': parent['id'], - 'title': parent['title'], - 'pcode': parent['pcode'], - 'admin_level': admin_levels.get(parent['admin_level']), - }] - p_parent = geo_areas.get(parent.get('parent')) + parents = [ + { + "id": parent["id"], + "title": parent["title"], + "pcode": parent["pcode"], + "admin_level": admin_levels.get(parent["admin_level"]), + } + ] + p_parent = geo_areas.get(parent.get("parent")) if p_parent: - parents.extend( - _get_geo_area_parents(geo_areas, admin_levels, p_parent) - ) + parents.extend(_get_geo_area_parents(geo_areas, admin_levels, p_parent)) return parents def get_comprehensive_data(widgets_meta, widget, data, widget_properties): - geo_areas = widgets_meta['geo-widget']['geo_areas'] - admin_levels = widgets_meta['geo-widget']['admin_levels'] + geo_areas = widgets_meta["geo-widget"]["geo_areas"] + admin_levels = widgets_meta["geo-widget"]["admin_levels"] # Ignore invalid ids - geo_areas_id = get_valid_geo_ids((data or {}).get('value') or []) + geo_areas_id = get_valid_geo_ids((data or {}).get("value") or []) values = [] @@ -111,13 +108,15 @@ def get_comprehensive_data(widgets_meta, widget, data, widget_properties): geo_area = geo_areas.get(int(geo_area_id)) if geo_area is None: continue - admin_level = admin_levels.get(geo_area.get('admin_level')) - values.append({ - 'id': geo_area['id'], - 'title': geo_area['title'], - 'pcode': geo_area['pcode'], - 'admin_level': admin_level, - 'parent': _get_geo_area_parents(geo_areas, admin_levels, geo_area), - }) + admin_level = admin_levels.get(geo_area.get("admin_level")) + values.append( + { + "id": geo_area["id"], + "title": geo_area["title"], + "pcode": geo_area["pcode"], + "admin_level": admin_level, + "parent": _get_geo_area_parents(geo_areas, admin_levels, geo_area), + } + ) return values or geo_areas_id diff --git a/apps/entry/widgets/matrix1d_widget.py b/apps/entry/widgets/matrix1d_widget.py index ec2a608524..2d53b7ade4 100644 --- a/apps/entry/widgets/matrix1d_widget.py +++ b/apps/entry/widgets/matrix1d_widget.py @@ -16,16 +16,13 @@ def update_attribute(widget, data, widget_properties): excel_values = [] report_values = [] - data_value = (data or {}).get('value', {}) - rows = widget_properties.get('rows', []) + data_value = (data or {}).get("value", {}) + rows = widget_properties.get("rows", []) for row_key, row in data_value.items(): row_exists = False - row_data = next(( - r for r in rows - if r.get('key') == row_key - ), {}) - cells = row_data.get('cells', []) + row_data = next((r for r in rows if r.get("key") == row_key), {}) + cells = row_data.get("cells", []) if not row: continue @@ -33,44 +30,43 @@ def update_attribute(widget, data, widget_properties): for cell_key, cell in row.items(): if cell: row_exists = True - cell_data = next(( - c for c in cells - if c.get('key') == cell_key - ), {}) + cell_data = next((c for c in cells if c.get("key") == cell_key), {}) filter_values.append(cell_key) - excel_values.append([ - row_data.get('label'), - cell_data.get('label'), - ]) - report_values.append('{}-{}'.format(row_key, cell_key)) + excel_values.append( + [ + row_data.get("label"), + cell_data.get("label"), + ] + ) + report_values.append("{}-{}".format(row_key, cell_key)) if row_exists: filter_values.append(row_key) return { - 'filter_data': [{ - 'values': filter_values, - }], - - 'export_data': { - 'data': { - 'common': { - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, + "filter_data": [ + { + "values": filter_values, + } + ], + "export_data": { + "data": { + "common": { + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, }, - 'excel': { - 'type': 'lists', - 'values': excel_values, + "excel": { + "type": "lists", + "values": excel_values, }, - 'report': { - 'keys': report_values, + "report": { + "keys": report_values, }, } }, - } @@ -78,32 +74,34 @@ def _get_headers(widgets_meta, widget, widget_properties): if widgets_meta.get(widget.pk) is not None: widget_meta = widgets_meta[widget.pk] return ( - widget_meta['pillar_header_map'], - widget_meta['subpillar_header_map'], + widget_meta["pillar_header_map"], + widget_meta["subpillar_header_map"], ) pillar_header_map = {} subpillar_header_map = {} - for pillar in widget_properties.get('rows', []): + for pillar in widget_properties.get("rows", []): subpillar_keys = [] - pillar_header_map[pillar['key']] = pillar - for subpillar in pillar['cells']: - subpillar_header_map[subpillar['key']] = subpillar - subpillar_keys.append(subpillar['key']) - pillar_header_map[pillar['key']]['subpillar_keys'] = subpillar_keys + pillar_header_map[pillar["key"]] = pillar + for subpillar in pillar["cells"]: + subpillar_header_map[subpillar["key"]] = subpillar + subpillar_keys.append(subpillar["key"]) + pillar_header_map[pillar["key"]]["subpillar_keys"] = subpillar_keys widgets_meta[widget.pk] = { - 'pillar_header_map': pillar_header_map, - 'subpillar_header_map': subpillar_header_map, + "pillar_header_map": pillar_header_map, + "subpillar_header_map": subpillar_header_map, } return pillar_header_map, subpillar_header_map def get_comprehensive_data(widgets_meta, widget, data, widget_properties): - data_value = (data or {}).get('value') or {} + data_value = (data or {}).get("value") or {} pillar_header_map, subpillar_header_map = _get_headers( - widgets_meta, widget, widget_properties, + widgets_meta, + widget, + widget_properties, ) values = [] @@ -113,15 +111,17 @@ def get_comprehensive_data(widgets_meta, widget, data, widget_properties): pillar_header = pillar_header_map.get(pillar_key) subpillar_header = subpillar_header_map.get(subpillar_key) if ( - not subpillar_selected or - pillar_header is None or - subpillar_header is None or - subpillar_key not in pillar_header.get('subpillar_keys', []) + not subpillar_selected + or pillar_header is None + or subpillar_header is None + or subpillar_key not in pillar_header.get("subpillar_keys", []) ): continue - values.append({ - 'id': subpillar_header['key'], - 'value': subpillar_header['label'], - 'row': {'id': pillar_header['key'], 'title': pillar_header['label']}, - }) + values.append( + { + "id": subpillar_header["key"], + "value": subpillar_header["label"], + "row": {"id": pillar_header["key"], "title": pillar_header["label"]}, + } + ) return values diff --git a/apps/entry/widgets/matrix2d_widget.py b/apps/entry/widgets/matrix2d_widget.py index 67dd4001e5..893ea6bf9d 100644 --- a/apps/entry/widgets/matrix2d_widget.py +++ b/apps/entry/widgets/matrix2d_widget.py @@ -4,9 +4,9 @@ def update_attribute(widget, data, widget_properties): - data = (data or {}).get('value', {}) - rows = widget_properties.get('rows', []) - columns = widget_properties.get('columns', []) + data = (data or {}).get("value", {}) + rows = widget_properties.get("rows", []) + columns = widget_properties.get("columns", []) filter1_values = [] filter2_values = [] @@ -17,11 +17,8 @@ def update_attribute(widget, data, widget_properties): for key, row in data.items(): dim_exists = False - row_data = next(( - d for d in rows - if d.get('key') == key - ), {}) - sub_rows = row_data.get('subRows', []) + row_data = next((d for d in rows if d.get("key") == key), {}) + sub_rows = row_data.get("subRows", []) if row is None: continue @@ -29,10 +26,7 @@ def update_attribute(widget, data, widget_properties): for sub_key, sub_row in row.items(): subdim_exists = False - sub_row_data = next(( - s for s in sub_rows - if s.get('key') == sub_key - ), {}) + sub_row_data = next((s for s in sub_rows if s.get("key") == sub_key), {}) if row is None: continue @@ -49,70 +43,57 @@ def update_attribute(widget, data, widget_properties): filter2_values.append(column_key) filter2_values.extend(sub_columns) - column_data = next(( - s for s in columns - if s.get('key') == column_key - ), {}) + column_data = next((s for s in columns if s.get("key") == column_key), {}) def get_ss_title(ss): - return next(( - ssd.get('label') for ssd - in column_data.get('subColumns', []) - if ssd.get('key') == ss - ), '') - - excel_values.append([ - row.get('label'), - sub_row_data.get('label'), - column_data.get('label'), - [get_ss_title(ss) for ss in sub_columns], - ]) + return next((ssd.get("label") for ssd in column_data.get("subColumns", []) if ssd.get("key") == ss), "") - # Without sub_columns {column}-{row}-{sub-row} - report_values.append( - '{}-{}-{}'.format(column_key, key, sub_key) - ) - # With sub_columns {column}-{sub-column}-{row}-{sub-row} - report_values.extend( + excel_values.append( [ - '{}-{}-{}-{}'.format(column_key, ss, key, sub_key) - for ss in sub_columns + row.get("label"), + sub_row_data.get("label"), + column_data.get("label"), + [get_ss_title(ss) for ss in sub_columns], ] ) + # Without sub_columns {column}-{row}-{sub-row} + report_values.append("{}-{}-{}".format(column_key, key, sub_key)) + # With sub_columns {column}-{sub-column}-{row}-{sub-row} + report_values.extend(["{}-{}-{}-{}".format(column_key, ss, key, sub_key) for ss in sub_columns]) + if subdim_exists: filter1_values.append(sub_key) if dim_exists: filter1_values.append(key) return { - 'filter_data': [ + "filter_data": [ { - 'key': '{}-rows'.format(widget.key), - 'values': filter1_values, + "key": "{}-rows".format(widget.key), + "values": filter1_values, }, { - 'key': '{}-columns'.format(widget.key), - 'values': filter2_values, + "key": "{}-columns".format(widget.key), + "values": filter2_values, }, ], - - 'export_data': { - 'data': { - 'common': { - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, + "export_data": { + "data": { + "common": { + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, }, - 'excel': { - 'type': 'lists', - 'values': excel_values, + "excel": { + "type": "lists", + "values": excel_values, }, - 'report': { - 'keys': report_values, + "report": { + "keys": report_values, }, }, - } + }, } @@ -120,38 +101,38 @@ def _get_headers(widgets_meta, widget, widget_properties): if widgets_meta.get(widget.pk) is not None: widget_meta = widgets_meta[widget.pk] return ( - widget_meta['dimension_header_map'], - widget_meta['subdimension_header_map'], - widget_meta['sector_header_map'], - widget_meta['subsector_header_map'], + widget_meta["dimension_header_map"], + widget_meta["subdimension_header_map"], + widget_meta["sector_header_map"], + widget_meta["subsector_header_map"], ) dimension_header_map = {} subdimension_header_map = {} - for dimension in widget_properties.get('rows', []): + for dimension in widget_properties.get("rows", []): subdimension_keys = [] - dimension_header_map[dimension['key']] = dimension - for subdimension in dimension['subRows']: - subdimension_header_map[subdimension['key']] = subdimension - subdimension_keys.append(subdimension['key']) - dimension_header_map[dimension['key']]['subdimension_keys'] = subdimension_keys + dimension_header_map[dimension["key"]] = dimension + for subdimension in dimension["subRows"]: + subdimension_header_map[subdimension["key"]] = subdimension + subdimension_keys.append(subdimension["key"]) + dimension_header_map[dimension["key"]]["subdimension_keys"] = subdimension_keys sector_header_map = {} subsector_header_map = {} - for sector in widget_properties.get('columns', []): + for sector in widget_properties.get("columns", []): subsector_keys = [] - sector_header_map[sector['key']] = sector - for subsector in sector['subColumns']: - subsector_header_map[subsector['key']] = subsector - subsector_keys.append(subsector['key']) - sector_header_map[sector['key']]['subsector_keys'] = subsector_keys + sector_header_map[sector["key"]] = sector + for subsector in sector["subColumns"]: + subsector_header_map[subsector["key"]] = subsector + subsector_keys.append(subsector["key"]) + sector_header_map[sector["key"]]["subsector_keys"] = subsector_keys widgets_meta[widget.pk] = { - 'dimension_header_map': dimension_header_map, - 'subdimension_header_map': subdimension_header_map, - 'sector_header_map': sector_header_map, - 'subsector_header_map': subsector_header_map, + "dimension_header_map": dimension_header_map, + "subdimension_header_map": subdimension_header_map, + "sector_header_map": sector_header_map, + "subsector_header_map": subsector_header_map, } return ( dimension_header_map, @@ -165,20 +146,20 @@ def _get_subsectors(subsector_header_map, sector_header, subsectors): subsectors_header = [] for subsector_key in subsectors: subsector_header = subsector_header_map.get(subsector_key) - if subsector_header and subsector_key in sector_header['subsector_keys']: - subsectors_header.append( - {'id': subsector_header['key'], 'title': subsector_header['label']} - ) + if subsector_header and subsector_key in sector_header["subsector_keys"]: + subsectors_header.append({"id": subsector_header["key"], "title": subsector_header["label"]}) return subsectors_header def get_comprehensive_data(widgets_meta, widget, data, widget_properties): - data = (data or {}).get('value') or {} + data = (data or {}).get("value") or {} values = [] ( - dimension_header_map, subdimension_header_map, - sector_header_map, subsector_header_map, + dimension_header_map, + subdimension_header_map, + sector_header_map, + subsector_header_map, ) = _get_headers(widgets_meta, widget, widget_properties) for dimension_key, dimension_value in data.items(): @@ -188,18 +169,22 @@ def get_comprehensive_data(widgets_meta, widget, data, widget_properties): subdimension_header = subdimension_header_map.get(subdimension_key) sector_header = sector_header_map.get(sector_key) if ( - dimension_header is None or - subdimension_header is None or - sector_header is None or - subdimension_key not in dimension_header['subdimension_keys'] + dimension_header is None + or subdimension_header is None + or sector_header is None + or subdimension_key not in dimension_header["subdimension_keys"] ): continue - values.append({ - 'dimension': {'id': dimension_header['key'], 'title': dimension_header['label']}, - 'subdimension': {'id': subdimension_header['key'], 'title': subdimension_header['label']}, - 'sector': {'id': sector_header['key'], 'title': sector_header['label']}, - 'subsectors': _get_subsectors( - subsector_header_map, sector_header, selected_subsectors, - ), - }) + values.append( + { + "dimension": {"id": dimension_header["key"], "title": dimension_header["label"]}, + "subdimension": {"id": subdimension_header["key"], "title": subdimension_header["label"]}, + "sector": {"id": sector_header["key"], "title": sector_header["label"]}, + "subsectors": _get_subsectors( + subsector_header_map, + sector_header, + selected_subsectors, + ), + } + ) return values diff --git a/apps/entry/widgets/multiselect_widget.py b/apps/entry/widgets/multiselect_widget.py index 3ac2dde4aa..0d9d5ae92e 100644 --- a/apps/entry/widgets/multiselect_widget.py +++ b/apps/entry/widgets/multiselect_widget.py @@ -1,22 +1,18 @@ from analysis_framework.widgets.multiselect_widget import WIDGET_ID - # NOTE: Please update the data version when you update the data format DATA_VERSION = 1 def _get_label_list(widget, data, widget_data): - values = data.get('value', []) - options = widget_data.get('options', []) + values = data.get("value", []) + options = widget_data.get("options", []) label_list = [] for item in values: - option = next(( - o for o in options - if o.get('key') == item - ), None) + option = next((o for o in options if o.get("key") == item), None) if option: - label_list.append(option.get('label') or 'Unknown') + label_list.append(option.get("label") or "Unknown") return label_list, values @@ -25,21 +21,21 @@ def update_attribute(widget, data, widget_data): label_list, values = _get_label_list(widget, data, widget_data) return { - 'filter_data': [{ - 'values': values, - }], - - 'export_data': { - 'data': { - 'common': { - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, - 'type': 'list', - 'value': label_list, - }, - 'excel': { + "filter_data": [ + { + "values": values, + } + ], + "export_data": { + "data": { + "common": { + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, + "type": "list", + "value": label_list, }, + "excel": {}, } }, } diff --git a/apps/entry/widgets/number_matrix_widget.py b/apps/entry/widgets/number_matrix_widget.py index 5190aaa7cb..9605a49b8a 100644 --- a/apps/entry/widgets/number_matrix_widget.py +++ b/apps/entry/widgets/number_matrix_widget.py @@ -6,63 +6,57 @@ def update_attribute(widget, _data, widget_data): - data = (_data or {}).get('value') or {} - row_headers = widget_data.get('row_headers', []) - column_headers = widget_data.get('column_headers', []) + data = (_data or {}).get("value") or {} + row_headers = widget_data.get("row_headers", []) + column_headers = widget_data.get("column_headers", []) excel_values = [] for row_header in row_headers: row_values = [] for column_header in column_headers: - value = (data.get(row_header.get('key')) or {}).get( - column_header.get('key'), + value = (data.get(row_header.get("key")) or {}).get( + column_header.get("key"), ) if value is None: - excel_values.append('') + excel_values.append("") else: row_values.append(value) excel_values.append(str(value)) is_same = len(row_values) == 0 or len(set(row_values)) == 1 - excel_values.append('True' if is_same else 'False') + excel_values.append("True" if is_same else "False") return { - 'filter_data': [], - 'export_data': { - 'data': { - 'common': { - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, + "filter_data": [], + "export_data": { + "data": { + "common": { + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, }, - 'excel': { - 'values': excel_values, + "excel": { + "values": excel_values, }, } - } + }, } def get_comprehensive_data(widgets_meta, widget, _data, widget_data): - data = (_data or {}).get('value') or {} + data = (_data or {}).get("value") or {} if widgets_meta.get(widget.pk) is None: - row_headers_map = { - row['key']: row - for row in widget_data.get('row_headers', []) - } - column_headers_map = { - col['key']: col - for col in widget_data.get('column_headers', []) - } + row_headers_map = {row["key"]: row for row in widget_data.get("row_headers", [])} + column_headers_map = {col["key"]: col for col in widget_data.get("column_headers", [])} widgets_meta[widget.pk] = { - 'row_headers_map': row_headers_map, - 'column_headers_map': column_headers_map, + "row_headers_map": row_headers_map, + "column_headers_map": column_headers_map, } else: widget_meta = widgets_meta[widget.pk] - row_headers_map = widget_meta['row_headers_map'] - column_headers_map = widget_meta['column_headers_map'] + row_headers_map = widget_meta["row_headers_map"] + column_headers_map = widget_meta["column_headers_map"] values = [] for row_key, row_value in data.items(): @@ -70,9 +64,11 @@ def get_comprehensive_data(widgets_meta, widget, _data, widget_data): row_header = row_headers_map.get(row_key) col_header = column_headers_map.get(col_key) if row_header and col_header: - values.append({ - 'value': value, - 'row': {'id': row_key, 'title': row_header['title']}, - 'column': {'id': col_key, 'title': col_header['title']}, - }) + values.append( + { + "value": value, + "row": {"id": row_key, "title": row_header["title"]}, + "column": {"id": col_key, "title": col_header["title"]}, + } + ) return values diff --git a/apps/entry/widgets/number_widget.py b/apps/entry/widgets/number_widget.py index 62e90f0a14..c8a94a4136 100644 --- a/apps/entry/widgets/number_widget.py +++ b/apps/entry/widgets/number_widget.py @@ -1,11 +1,10 @@ from analysis_framework.widgets.number_widget import WIDGET_ID - DATA_VERSION = 1 def _get_number(widget, data, widget_data): - value = data.get('value') + value = data.get("value") return value and str(value), value @@ -14,19 +13,20 @@ def update_attribute(*args): widget = args[0] return { - 'filter_data': [{ - 'number': value, - }], - - 'export_data': { - 'data': { - 'common': { - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, + "filter_data": [ + { + "number": value, + } + ], + "export_data": { + "data": { + "common": { + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, }, - 'excel': { - 'value': str_value, + "excel": { + "value": str_value, }, }, }, diff --git a/apps/entry/widgets/organigram_widget.py b/apps/entry/widgets/organigram_widget.py index e14b98e116..b3bf8f43ef 100644 --- a/apps/entry/widgets/organigram_widget.py +++ b/apps/entry/widgets/organigram_widget.py @@ -5,91 +5,100 @@ def _get_parent_nodes(node_mapping, node_key): node = node_mapping[node_key] - parent_node_key = node.get('parent_node') + parent_node_key = node.get("parent_node") parent_node = node_mapping.get(parent_node_key) - parent_nodes = [{ - 'key': parent_node_key, - 'title': node.get('parent_title'), - }] if parent_node_key else [] + parent_nodes = ( + [ + { + "key": parent_node_key, + "title": node.get("parent_title"), + } + ] + if parent_node_key + else [] + ) if parent_node: - parent_nodes.extend( - _get_parent_nodes(node_mapping, parent_node_key) - ) + parent_nodes.extend(_get_parent_nodes(node_mapping, parent_node_key)) return parent_nodes def _get_selected_nodes_with_parent(node, selected_ids, node_mapping=None): node_mapping = node_mapping or {} - organs = node.get('children', []) + organs = node.get("children", []) - if 'key' not in node: + if "key" not in node: return [] - if node['key'] not in node_mapping: - node_mapping[node['key']] = { - 'key': node['key'], - 'title': node.get('label'), + if node["key"] not in node_mapping: + node_mapping[node["key"]] = { + "key": node["key"], + "title": node.get("label"), } selected = [] - if node['key'] in selected_ids: - selected.append({ - 'key': node['key'], - 'title': node['label'], - 'parents': _get_parent_nodes(node_mapping, node['key']), - }) + if node["key"] in selected_ids: + selected.append( + { + "key": node["key"], + "title": node["label"], + "parents": _get_parent_nodes(node_mapping, node["key"]), + } + ) for organ in organs: - if 'key' not in organ: + if "key" not in organ: continue - node_mapping[organ['key']] = { - 'key': organ['key'], - 'title': organ['label'], - 'parent_node': node['key'], - 'parent_title': node['label'], + node_mapping[organ["key"]] = { + "key": organ["key"], + "title": organ["label"], + "parent_node": node["key"], + "parent_title": node["label"], } selected.extend( _get_selected_nodes_with_parent( - organ, selected_ids, node_mapping=node_mapping, + organ, + selected_ids, + node_mapping=node_mapping, ) ) return selected def update_attribute(widget, data, widget_properties): - values = data.get('value', []) - base_node = widget_properties.get('options', {}) + values = data.get("value", []) + base_node = widget_properties.get("options", {}) selected_nodes_with_parents = [ [ *[ # Don't show base/root as parent nodes - parent_node['title'] if base_node.get('key') != parent_node['key'] else '' - for parent_node in node['parents'] + parent_node["title"] if base_node.get("key") != parent_node["key"] else "" + for parent_node in node["parents"] ][::-1], - node['title'], + node["title"], ] for node in _get_selected_nodes_with_parent(base_node, set(values)) ] return { - 'filter_data': [{ - 'values': values, - }], - - 'export_data': { - 'data': { - 'common': { - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, - 'values': selected_nodes_with_parents, + "filter_data": [ + { + "values": values, + } + ], + "export_data": { + "data": { + "common": { + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, + "values": selected_nodes_with_parents, }, - 'excel': { - 'type': 'lists', + "excel": { + "type": "lists", }, }, }, @@ -97,5 +106,5 @@ def update_attribute(widget, data, widget_properties): def get_comprehensive_data(_, widget, data, widget_properties): - values = data.get('value', []) + values = data.get("value", []) return _get_selected_nodes_with_parent(widget_properties, set(values)) diff --git a/apps/entry/widgets/scale_widget.py b/apps/entry/widgets/scale_widget.py index ef16cc32f0..fc17b8796e 100644 --- a/apps/entry/widgets/scale_widget.py +++ b/apps/entry/widgets/scale_widget.py @@ -1,29 +1,24 @@ from analysis_framework.widgets.scale_widget import WIDGET_ID - # NOTE: Please update the data version when you update the data format DATA_VERSION = 1 def _get_scale(widget, data, widget_properties): - selected_scale = data.get('value') + selected_scale = data.get("value") selected_scales = [selected_scale] if selected_scale is not None else [] - options = widget_properties.get('options', []) - scale = next(( - s for s in options - if s['key'] == selected_scale - ), None) + options = widget_properties.get("options", []) + scale = next((s for s in options if s["key"] == selected_scale), None) scale = scale or {} return { # Note: Please change the DATA_VERSION when you change the data format - # widget_id will be used to alter rendering in report - 'widget_id': getattr(widget, 'widget_id', ''), + "widget_id": getattr(widget, "widget_id", ""), # widget related attributes - 'title': getattr(widget, 'title', ''), - 'label': scale.get('label'), - 'color': scale.get('color'), + "title": getattr(widget, "title", ""), + "label": scale.get("label"), + "color": scale.get("color"), }, selected_scales @@ -32,25 +27,26 @@ def update_attribute(widget, data, widget_properties): return { # Note: Please change the DATA_VERSION when you change the data format - 'filter_data': [{ - 'values': selected_scales, - }], - - 'export_data': { - 'data': { - 'common': { - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, + "filter_data": [ + { + "values": selected_scales, + } + ], + "export_data": { + "data": { + "common": { + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, + }, + "excel": { + "value": scale["label"], }, - 'excel': { - 'value': scale['label'], + "report": { + "title": scale["title"], + "label": scale["label"], + "color": scale["color"], }, - 'report': { - 'title': scale['title'], - 'label': scale['label'], - 'color': scale['color'], - } }, }, } @@ -58,25 +54,24 @@ def update_attribute(widget, data, widget_properties): def get_comprehensive_data(widgets_meta, widget, data, widget_properties): scale, selected_scales = _get_scale(widget, data, widget_properties) - options = widget_properties.get('options', []) + options = widget_properties.get("options", []) if widgets_meta.get(widget.pk) is None: # To avoid calculating meta at each attribute widgets_meta[widget.pk] = {} min_option, max_option = {}, {} if options: min_option, max_option = {**options[0]}, {**options[len(options) - 1]} - min_option['key'], max_option['key'] = min_option.pop('key'), max_option.pop('key') + min_option["key"], max_option["key"] = min_option.pop("key"), max_option.pop("key") widgets_meta[widget.pk] = { - 'min': min_option, - 'max': max_option, + "min": min_option, + "max": max_option, } return { **widgets_meta[widget.pk], - 'scale': scale, - 'label': scale['label'], - 'index': ([ - (i + 1) for i, v in enumerate(options) - if v['key'] == selected_scales[0] - ] or [None])[0] if selected_scales else None, + "scale": scale, + "label": scale["label"], + "index": ( + ([(i + 1) for i, v in enumerate(options) if v["key"] == selected_scales[0]] or [None])[0] if selected_scales else None + ), } diff --git a/apps/entry/widgets/select_widget.py b/apps/entry/widgets/select_widget.py index c86e86b841..8813df2dc8 100644 --- a/apps/entry/widgets/select_widget.py +++ b/apps/entry/widgets/select_widget.py @@ -1,23 +1,19 @@ from analysis_framework.widgets.select_widget import WIDGET_ID - # NOTE: Please update the data version when you update the data format DATA_VERSION = 1 def _get_label_list(widget, data, widget_properties): - values = data.get('value') + values = data.get("value") values = [values] if values is not None else [] - options = widget_properties.get('options', []) + options = widget_properties.get("options", []) label_list = [] for item in values: - option = next(( - o for o in options - if o.get('key') == item - ), None) + option = next((o for o in options if o.get("key") == item), None) if option: - label_list.append(option.get('label') or 'Unknown') + label_list.append(option.get("label") or "Unknown") return label_list, values @@ -26,21 +22,21 @@ def update_attribute(widget, data, widget_properties): label_list, values = _get_label_list(widget, data, widget_properties) return { - 'filter_data': [{ - 'values': values, - }], - - 'export_data': { - 'data': { - 'common': { - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, - 'type': 'list', - 'value': label_list, - }, - 'excel': { + "filter_data": [ + { + "values": values, + } + ], + "export_data": { + "data": { + "common": { + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, + "type": "list", + "value": label_list, }, + "excel": {}, }, }, } diff --git a/apps/entry/widgets/store.py b/apps/entry/widgets/store.py index ceda527609..b9dec2301a 100644 --- a/apps/entry/widgets/store.py +++ b/apps/entry/widgets/store.py @@ -1,22 +1,21 @@ from . import ( - date_widget, + conditional_widget, date_range_widget, - time_widget, - time_range_widget, - number_widget, - scale_widget, - select_widget, - multiselect_widget, + date_widget, geo_widget, - organigram_widget, matrix1d_widget, matrix2d_widget, + multiselect_widget, number_matrix_widget, - conditional_widget, + number_widget, + organigram_widget, + scale_widget, + select_widget, text_widget, + time_range_widget, + time_widget, ) - widget_store = { widget.WIDGET_ID: widget for widget in ( diff --git a/apps/entry/widgets/text_widget.py b/apps/entry/widgets/text_widget.py index e513fed55b..59e6356523 100644 --- a/apps/entry/widgets/text_widget.py +++ b/apps/entry/widgets/text_widget.py @@ -1,13 +1,10 @@ from analysis_framework.widgets.text_widget import WIDGET_ID - DATA_VERSION = 1 def _get_text(widget, data, widget_properties): - return str( - data.get('value') or '' - ) + return str(data.get("value") or "") def update_attribute(*args, **kwargs): @@ -15,19 +12,20 @@ def update_attribute(*args, **kwargs): widget = args[0] return { - 'filter_data': [{ - 'text': text, - }], - - 'export_data': { - 'data': { - 'common': { - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, + "filter_data": [ + { + "text": text, + } + ], + "export_data": { + "data": { + "common": { + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, }, - 'excel': { - 'value': text, + "excel": { + "value": text, }, } }, diff --git a/apps/entry/widgets/time_range_widget.py b/apps/entry/widgets/time_range_widget.py index 406c1a105b..401408bf1f 100644 --- a/apps/entry/widgets/time_range_widget.py +++ b/apps/entry/widgets/time_range_widget.py @@ -1,25 +1,25 @@ from analysis_framework.widgets.time_range_widget import WIDGET_ID -from .time_widget import parse_time_str +from .time_widget import parse_time_str # NOTE: Please update the data version when you update the data format DATA_VERSION = 1 def _get_time(widget, data, widget_properties): - value = data.get('value') or {} - from_value = value.get('startTime') # TODO: use from - to_value = value.get('endTime') # TODO: use to + value = data.get("value") or {} + from_value = value.get("startTime") # TODO: use from + to_value = value.get("endTime") # TODO: use to from_time = from_value and parse_time_str(from_value) to_time = to_value and parse_time_str(to_value) # NOTE: Please update the data version when you update the data format return ( - from_time and from_time['time_val'], - to_time and to_time['time_val'], + from_time and from_time["time_val"], + to_time and to_time["time_val"], ), ( - from_time and from_time['time_str'], - to_time and to_time['time_str'], + from_time and from_time["time_str"], + to_time and to_time["time_str"], ) @@ -34,23 +34,22 @@ def update_attribute(widget, data, widget_properties): return { # NOTE: Please update the data version when you update the data format - 'filter_data': [{ - 'from_number': from_number, - 'to_number': to_number, - }], - - 'export_data': { - 'data': { - 'common': { - 'values': [from_time, to_time], - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, - }, - 'excel': { - }, - 'report': { + "filter_data": [ + { + "from_number": from_number, + "to_number": to_number, + } + ], + "export_data": { + "data": { + "common": { + "values": [from_time, to_time], + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, }, + "excel": {}, + "report": {}, }, }, } @@ -59,6 +58,6 @@ def update_attribute(widget, data, widget_properties): def get_comprehensive_data(_, *args): (from_time, to_time) = _get_time(*args)[1] return { - 'from': from_time, - 'to': to_time, + "from": from_time, + "to": to_time, } diff --git a/apps/entry/widgets/time_widget.py b/apps/entry/widgets/time_widget.py index 669bca7f8c..fe11599089 100644 --- a/apps/entry/widgets/time_widget.py +++ b/apps/entry/widgets/time_widget.py @@ -1,6 +1,5 @@ from analysis_framework.widgets.time_widget import WIDGET_ID - # NOTE: Please update the data version when you update the data format DATA_VERSION = 1 @@ -8,20 +7,20 @@ # NOTE: Please update the data version when you update the data format # This is also used in time_range_widget def parse_time_str(time_string): - splits = time_string.split(':') + splits = time_string.split(":") h = int(splits[0]) m = int(splits[1]) return { - 'time_str': '{:02d}:{:02d}'.format(h, m), - 'time_val': h * 60 + m, + "time_str": "{:02d}:{:02d}".format(h, m), + "time_val": h * 60 + m, } def _get_time(widget, data, widget_properties): - value = data.get('value') + value = data.get("value") time = value and parse_time_str(value) # NOTE: Please update the data version when you update the data format - return time and time['time_val'], value and time['time_str'] + return time and time["time_val"], value and time["time_str"] def update_attribute(widget, data, widget_properties): @@ -29,22 +28,17 @@ def update_attribute(widget, data, widget_properties): return { # NOTE: Please update the data version when you update the data format - 'filter_data': [{ - 'number': time_val - }], - - 'export_data': { - 'data': { - 'common': { - 'value': time_str, - 'widget_id': WIDGET_ID, - 'widget_key': widget.key, - 'version': DATA_VERSION, - }, - 'excel': { - }, - 'report': { + "filter_data": [{"number": time_val}], + "export_data": { + "data": { + "common": { + "value": time_str, + "widget_id": WIDGET_ID, + "widget_key": widget.key, + "version": DATA_VERSION, }, + "excel": {}, + "report": {}, }, }, } diff --git a/apps/entry/widgets/utils.py b/apps/entry/widgets/utils.py index 872bf1bc57..cbf4552f0e 100644 --- a/apps/entry/widgets/utils.py +++ b/apps/entry/widgets/utils.py @@ -1,13 +1,16 @@ -from analysis_framework.models import Filter, Exportable -from entry.models import FilterData, ExportData +from analysis_framework.models import Exportable, Filter +from entry.models import ExportData, FilterData def set_filter_data( - entry, widget, key=None, - number=None, - from_number=None, to_number=None, - values=None, - text=None, + entry, + widget, + key=None, + number=None, + from_number=None, + to_number=None, + values=None, + text=None, ): key = key or widget.key filter = Filter.objects.filter( @@ -19,11 +22,11 @@ def set_filter_data( entry=entry, filter=filter, defaults={ - 'number': number, - 'values': values, - 'from_number': from_number, - 'to_number': to_number, - 'text': text, + "number": number, + "values": values, + "from_number": from_number, + "to_number": to_number, + "text": text, }, ) return f @@ -38,7 +41,7 @@ def set_export_data(entry, widget, data): entry=entry, exportable=exportable, defaults={ - 'data': data, + "data": data, }, ) return e diff --git a/apps/export/admin.py b/apps/export/admin.py index 26d6829f5b..e4722dfe06 100644 --- a/apps/export/admin.py +++ b/apps/export/admin.py @@ -1,44 +1,40 @@ -from django.contrib import admin -from django.utils.translation import gettext_lazy as _ -from django.utils.safestring import mark_safe -from django.db import models -from django.contrib import messages from admin_auto_filters.filters import AutocompleteFilterFactory +from django.contrib import admin, messages +from django.db import models +from django.utils.safestring import mark_safe +from django.utils.translation import gettext_lazy as _ from deep.admin import ModelAdmin, document_preview from .models import Export, GenericExport from .tasks import export_task - TRIGGER_LIMIT = 5 def trigger_retry(modeladmin, request, queryset): - for export_id in queryset.values_list('id', flat=True).distinct()[:TRIGGER_LIMIT]: + for export_id in queryset.values_list("id", flat=True).distinct()[:TRIGGER_LIMIT]: export_task.delay(export_id, force=True) messages.add_message( - request, messages.INFO, + request, + messages.INFO, mark_safe( - 'Successfully force triggerd retry for exports:

' + - '
'.join( - '& {} : {}'.format(*value) - for value in queryset.values_list('id', 'title').distinct()[:TRIGGER_LIMIT] - ) - ) + "Successfully force triggerd retry for exports:

" + + "
".join("& {} : {}".format(*value) for value in queryset.values_list("id", "title").distinct()[:TRIGGER_LIMIT]) + ), ) -trigger_retry.short_description = 'Force trigger export process for selected export, limit: {}'.format(TRIGGER_LIMIT) +trigger_retry.short_description = "Force trigger export process for selected export, limit: {}".format(TRIGGER_LIMIT) class HaveExecutionTimeFilter(admin.SimpleListFilter): class Parameter(models.TextChoices): - TRUE = 'true', _('Yes') - FALSE = 'false', _('False') + TRUE = "true", _("Yes") + FALSE = "false", _("False") - title = _('Have Execution Time') - parameter_name = 'have_execution_time' + title = _("Have Execution Time") + parameter_name = "have_execution_time" def lookups(self, *_): return self.Parameter.choices @@ -55,30 +51,51 @@ def queryset(self, _, queryset): @admin.register(Export) class ExportAdmin(ModelAdmin): list_display = ( - 'title', 'file', 'type', 'exported_by', 'exported_at', 'execution_time', 'project', 'export_type', - 'format', 'is_preview', 'status', + "title", + "file", + "type", + "exported_by", + "exported_at", + "execution_time", + "project", + "export_type", + "format", + "is_preview", + "status", ) - search_fields = ('title',) - readonly_fields = (document_preview('file'),) + search_fields = ("title",) + readonly_fields = (document_preview("file"),) list_filter = ( - 'type', 'export_type', 'format', 'status', 'is_preview', 'is_deleted', 'is_archived', - ('ended_at', admin.EmptyFieldListFilter), + "type", + "export_type", + "format", + "status", + "is_preview", + "is_deleted", + "is_archived", + ("ended_at", admin.EmptyFieldListFilter), HaveExecutionTimeFilter, - AutocompleteFilterFactory('Project', 'project'), - AutocompleteFilterFactory('Analysis Framework', 'project__analysis_framework'), - AutocompleteFilterFactory('Exported By', 'exported_by'), + AutocompleteFilterFactory("Project", "project"), + AutocompleteFilterFactory("Analysis Framework", "project__analysis_framework"), + AutocompleteFilterFactory("Exported By", "exported_by"), ) actions = [trigger_retry] - autocomplete_fields = ('project', 'exported_by',) + autocomplete_fields = ( + "project", + "exported_by", + ) def get_queryset(self, request): - return super().get_queryset(request)\ - .annotate(execution_time=models.F('ended_at') - models.F('started_at'))\ - .select_related('exported_by', 'project') + return ( + super() + .get_queryset(request) + .annotate(execution_time=models.F("ended_at") - models.F("started_at")) + .select_related("exported_by", "project") + ) @admin.display( - ordering='execution_time', - description='Execution Time', + ordering="execution_time", + description="Execution Time", ) def execution_time(self, obj): return obj.execution_time @@ -87,26 +104,37 @@ def execution_time(self, obj): @admin.register(GenericExport) class GenericExportAdmin(ModelAdmin): list_display = ( - 'title', 'file', 'type', 'exported_by', 'exported_at', 'execution_time', - 'format', 'status', + "title", + "file", + "type", + "exported_by", + "exported_at", + "execution_time", + "format", + "status", ) - search_fields = ('title',) - readonly_fields = (document_preview('file'),) + search_fields = ("title",) + readonly_fields = (document_preview("file"),) list_filter = ( - 'type', 'format', 'status', - ('ended_at', admin.EmptyFieldListFilter), + "type", + "format", + "status", + ("ended_at", admin.EmptyFieldListFilter), HaveExecutionTimeFilter, - AutocompleteFilterFactory('Exported By', 'exported_by'), + AutocompleteFilterFactory("Exported By", "exported_by"), ) def get_queryset(self, request): - return super().get_queryset(request)\ - .annotate(execution_time=models.F('ended_at') - models.F('started_at'))\ - .select_related('exported_by') + return ( + super() + .get_queryset(request) + .annotate(execution_time=models.F("ended_at") - models.F("started_at")) + .select_related("exported_by") + ) @admin.display( - ordering='execution_time', - description='Execution Time', + ordering="execution_time", + description="Execution Time", ) def execution_time(self, obj): return obj.execution_time diff --git a/apps/export/analyses/excel_exporter.py b/apps/export/analyses/excel_exporter.py index 90eb6723b0..dc0594827f 100644 --- a/apps/export/analyses/excel_exporter.py +++ b/apps/export/analyses/excel_exporter.py @@ -1,58 +1,61 @@ -from django.core.files.base import ContentFile +import logging +from django.core.files.base import ContentFile from export.formats.xlsx import WorkBook -import logging - from deep.permalinks import Permalink -logger = logging.getLogger('__name__') +logger = logging.getLogger("__name__") class ExcelExporter: def __init__(self, analytical_statement_entries): self.wb = WorkBook() self.split = None - self.analysis_sheet = self.wb.get_active_sheet().set_title('Analysis') + self.analysis_sheet = self.wb.get_active_sheet().set_title("Analysis") self.titles = [ - 'Analysis Pillar ID', - 'Analysis Pillar', - 'Assignee', - 'Statement ID', - 'Statement', - 'Entry ID', - 'Entry', - 'Entry Link', - 'Source Link' + "Analysis Pillar ID", + "Analysis Pillar", + "Assignee", + "Statement ID", + "Statement", + "Entry ID", + "Entry", + "Entry Link", + "Source Link", ] def add_analytical_statement_entries(self, analytical_statement_entries): self.analysis_sheet.append([self.titles]) # FIXME: Use values only instead of fetching everything. qs = analytical_statement_entries.select_related( - 'entry', - 'entry__lead', - 'analytical_statement', - 'analytical_statement__analysis_pillar', - 'analytical_statement__analysis_pillar__assignee', + "entry", + "entry__lead", + "analytical_statement", + "analytical_statement__analysis_pillar", + "analytical_statement__analysis_pillar__assignee", ) for analytical_statement_entry in qs.iterator(): entry = analytical_statement_entry.entry lead = entry.lead analytical_statement = analytical_statement_entry.analytical_statement analysis_pillar = analytical_statement.analysis_pillar - self.analysis_sheet.append([[ - analysis_pillar.id, - analysis_pillar.title, - analysis_pillar.assignee.get_display_name(), - analytical_statement.pk, - analytical_statement.statement, - entry.id, - entry.excerpt, - Permalink.entry(entry.project_id, lead.id, entry.id), - lead.url or Permalink.lead_share_view(lead.uuid), - ]]) + self.analysis_sheet.append( + [ + [ + analysis_pillar.id, + analysis_pillar.title, + analysis_pillar.assignee.get_display_name(), + analytical_statement.pk, + analytical_statement.statement, + entry.id, + entry.excerpt, + Permalink.entry(entry.project_id, lead.id, entry.id), + lead.url or Permalink.lead_share_view(lead.uuid), + ] + ] + ) return self def export(self): diff --git a/apps/export/apps.py b/apps/export/apps.py index 15c9d19716..01b6c57eb6 100644 --- a/apps/export/apps.py +++ b/apps/export/apps.py @@ -2,4 +2,4 @@ class ExportConfig(AppConfig): - name = 'export' + name = "export" diff --git a/apps/export/assessments/excel_exporter.py b/apps/export/assessments/excel_exporter.py index 377c448371..719c2ccaad 100644 --- a/apps/export/assessments/excel_exporter.py +++ b/apps/export/assessments/excel_exporter.py @@ -1,45 +1,44 @@ +import logging from collections import OrderedDict from django.core.files.base import ContentFile - -from export.formats.xlsx import WorkBook, RowsBuilder +from export.formats.xlsx import RowsBuilder, WorkBook from openpyxl.styles import Alignment, Font -from utils.common import deep_date_format, underscore_to_title -import logging +from utils.common import deep_date_format, underscore_to_title -logger = logging.getLogger('django') +logger = logging.getLogger("django") class ExcelExporter: """ NOTE: Legacy exporter (Not used) """ + def __init__(self, decoupled=True): self.wb = WorkBook() # Create two worksheets if decoupled: - self.split = self.wb.get_active_sheet()\ - .set_title('Split Assessments') - self.group = self.wb.create_sheet('Grouped Assessments') + self.split = self.wb.get_active_sheet().set_title("Split Assessments") + self.group = self.wb.create_sheet("Grouped Assessments") else: self.split = None - self.group = self.wb.get_active_sheet().set_title('Assessments') + self.group = self.wb.get_active_sheet().set_title("Assessments") self.decoupled = decoupled # Cells to be merged self.merge_cells = {} # Initial titles self.lead_titles = [ - 'Date of Source Publication', - 'Imported By', - 'Source Title', - 'Publisher', + "Date of Source Publication", + "Imported By", + "Source Title", + "Publisher", ] self.titles = [*self.lead_titles] self.col_types = { - 0: 'date', + 0: "date", } self._titles_dict = {k: True for k in self.titles} @@ -68,17 +67,15 @@ def to_flattened_key_vals(self, dictdata, parents=[]): # check if list elements are dict or not for i in v: if isinstance(i, dict): - flat.update( - self.to_flattened_key_vals(i, [k, *parents]) - ) + flat.update(self.to_flattened_key_vals(i, [k, *parents])) else: - vals = flat.get(k, {}).get('value', []) + vals = flat.get(k, {}).get("value", []) vals.append(i) # FIXME: assigning parents is repeated every step - flat[k] = {'value': vals, 'parents': parents} + flat[k] = {"value": vals, "parents": parents} else: # Just add key value - flat[k] = {'value': v, 'parents': parents} + flat[k] = {"value": v, "parents": parents} return flat def add_assessments(self, assessments): @@ -94,35 +91,34 @@ def add_assessment(self, assessment): # update the titles for k, v in flat.items(): - parent = v['parents'][-1] + parent = v["parents"][-1] header_titles = self._headers_titles.get(parent, []) if k not in header_titles: header_titles.append(k) self._headers_titles[parent] = header_titles - ''' + """ if not self._titles_dict.get(k): self.titles.append(k) self._titles_dict[k] = True - ''' + """ return self def get_titles(self): - return [ - *self.lead_titles, - *[y for k, v in self._headers_titles.items() for y in v] - ] + return [*self.lead_titles, *[y for k, v in self._headers_titles.items() for y in v]] def assessments_to_rows(self): for index, assessment in enumerate(self._assessments): rows = RowsBuilder(self.split, self.group, split=False) lead = assessment.lead - rows.add_value_list([ - deep_date_format(lead.created_at), - lead.created_by.username, - lead.title, - (lead.source and lead.source.title) or lead.source_raw, - ]) + rows.add_value_list( + [ + deep_date_format(lead.created_at), + lead.created_by.username, + lead.title, + (lead.source and lead.source.title) or lead.source_raw, + ] + ) headers_dict = {} flat = self._flats[index] for i, t in enumerate(self.get_titles()): @@ -135,19 +131,19 @@ def assessments_to_rows(self): self._title_headers.append("") continue - v = flat[t]['value'] - val = ', '.join([str(x) for x in v]) if isinstance(v, list) else str(v) + v = flat[t]["value"] + val = ", ".join([str(x) for x in v]) if isinstance(v, list) else str(v) rows.add_value(val) - header = flat[t]['parents'][-1] + header = flat[t]["parents"][-1] if not self._headers_dict.get(header): self._title_headers.append(header.upper()) self._headers_dict[header] = True else: - self.merge_cells[header]['end'] += 1 + self.merge_cells[header]["end"] += 1 if not headers_dict.get(header): - self.merge_cells[header] = {'start': i, 'end': i} + self.merge_cells[header] = {"start": i, "end": i} headers_dict[header] = True else: self._title_headers.append("") @@ -177,12 +173,9 @@ def export(self): if self.merge_cells: sheet = self.wb.wb.active for k, v in self.merge_cells.items(): - sheet.merge_cells( - start_row=1, start_column=v['start'] + 1, - end_row=1, end_column=v['end'] + 1 - ) - cell = sheet.cell(row=1, column=v['start'] + 1) - cell.alignment = Alignment(horizontal='center') + sheet.merge_cells(start_row=1, start_column=v["start"] + 1, end_row=1, end_column=v["end"] + 1) + cell = sheet.cell(row=1, column=v["start"] + 1) + cell.alignment = Alignment(horizontal="center") self.group.set_col_types(self.col_types) if self.split: @@ -249,18 +242,13 @@ def add_headers(self): for header, info in headerinfo.items(): wb_sheet = self.wb_sheets[sheet].ws if info: - wb_sheet.merge_cells( - start_row=1, - start_column=counter, - end_row=1, - end_column=counter + len(info) - 1 - ) + wb_sheet.merge_cells(start_row=1, start_column=counter, end_row=1, end_column=counter + len(info) - 1) counter += len(info) else: counter += 1 # Styling cell = wb_sheet.cell(row=1, column=counter) - cell.alignment = Alignment(horizontal='center') + cell.alignment = Alignment(horizontal="center") cell.font = Font(bold=True) # Style sub headers for i, header in enumerate(sub_header_row): @@ -288,7 +276,7 @@ def export(self): # Remove default sheet only if other sheets present if self.wb_sheets: - self.wb.wb.remove(self.wb.wb.get_sheet_by_name('Sheet')) + self.wb.wb.remove(self.wb.wb.get_sheet_by_name("Sheet")) buffer = self.wb.save() return ContentFile(buffer) diff --git a/apps/export/entries/excel_exporter.py b/apps/export/entries/excel_exporter.py index 69d38a0d0c..ab0be423e3 100644 --- a/apps/export/entries/excel_exporter.py +++ b/apps/export/entries/excel_exporter.py @@ -1,19 +1,16 @@ import logging + +from analysis_framework.models import Widget from django.core.files.base import ContentFile from django.db import models +from entry.models import Entry, ExportData, LeadEntryGroup, ProjectEntryLabel +from export.formats.xlsx import RowsBuilder, WorkBook +from export.models import Export +from lead.models import Lead from deep.permalinks import Permalink -from utils.common import ( - excel_column_name, - get_valid_xml_string as xstr, - deep_date_parse, -) -from export.formats.xlsx import WorkBook, RowsBuilder - -from analysis_framework.models import Widget -from entry.models import Entry, ExportData, ProjectEntryLabel, LeadEntryGroup -from lead.models import Lead -from export.models import Export +from utils.common import deep_date_parse, excel_column_name +from utils.common import get_valid_xml_string as xstr logger = logging.getLogger(__name__) @@ -26,14 +23,11 @@ def get_hyperlink(url, text): class ExcelExporter: class ColumnsData: TITLES = { - **{ - key: label - for key, label in Export.StaticColumn.choices - }, + **{key: label for key, label in Export.StaticColumn.choices}, # Override labels here. - Export.StaticColumn.ENTRY_EXCERPT: lambda self: [ - 'Modified Excerpt', 'Original Excerpt' - ] if self.modified_excerpt_exists else ['Excerpt'], + Export.StaticColumn.ENTRY_EXCERPT: lambda self: ( + ["Modified Excerpt", "Original Excerpt"] if self.modified_excerpt_exists else ["Excerpt"] + ), } def __init__( @@ -58,68 +52,61 @@ def __init__( # Create worksheets(Main, Grouped, Entry Groups, Bibliography) if decoupled: - self.split = self.wb.get_active_sheet()\ - .set_title('Split Entries') - self.group = self.wb.create_sheet('Grouped Entries') + self.split = self.wb.get_active_sheet().set_title("Split Entries") + self.group = self.wb.create_sheet("Grouped Entries") else: self.split = None - self.group = self.wb.get_active_sheet().set_title('Entries') + self.group = self.wb.get_active_sheet().set_title("Entries") - self.entry_groups_sheet = self.wb.create_sheet('Entry Groups') + self.entry_groups_sheet = self.wb.create_sheet("Entry Groups") self.decoupled = decoupled self.columns = columns - self.bibliography_sheet = self.wb.create_sheet('Bibliography') + self.bibliography_sheet = self.wb.create_sheet("Bibliography") self.modified_excerpt_exists = entries.filter(excerpt_modified=True).exists() project_entry_labels = ProjectEntryLabel.objects.filter( project=self.project, - ).order_by('order') + ).order_by("order") - self.label_id_title_map = { - _id: title for _id, title in project_entry_labels.values_list('id', 'title') - } + self.label_id_title_map = {_id: title for _id, title in project_entry_labels.values_list("id", "title")} - lead_groups = LeadEntryGroup.objects.filter(lead__project=self.project).order_by('order') + lead_groups = LeadEntryGroup.objects.filter(lead__project=self.project).order_by("order") self.group_id_title_map = {x.id: x.title for x in lead_groups} # Create matrix of labels and groups self.group_label_matrix = { - (group.lead_id, group.id): { - _id: None for _id in self.label_id_title_map.keys() - } - for group in lead_groups + (group.lead_id, group.id): {_id: None for _id in self.label_id_title_map.keys()} for group in lead_groups } self.lead_id_titles_map = { _id: title for _id, title in Lead.objects.filter( - project=self.project, - id__in=[_id for _id, _ in self.group_label_matrix.keys()] - ).values_list('id', 'title') + project=self.project, id__in=[_id for _id, _ in self.group_label_matrix.keys()] + ).values_list("id", "title") } self.entry_group_titles = [ - 'Lead', - 'Group', + "Lead", + "Group", *self.label_id_title_map.values(), ] self.entry_groups_sheet.append([self.entry_group_titles]) self.col_types = { - 0: 'date', - 2: 'date', + 0: "date", + 2: "date", } # Keep track of sheet data present - ''' + """ tabular_sheets = { 'leadtitle-sheettitle': { 'field1_title': col_num_in_sheet, 'field2_title': col_num_in_sheet, } } - ''' + """ self.tabular_sheets = {} # Keep track of tabular fields @@ -130,43 +117,45 @@ def __init__( self._sheets = {} def log_error(self, message, **kwargs): - logger.error(f'[EXPORT:{self.export_object.id}] {message}', **kwargs) + logger.error(f"[EXPORT:{self.export_object.id}] {message}", **kwargs) def load_exportable_titles(self, data, regions): - export_type = data.get('type') - col_type = data.get('col_type') + export_type = data.get("type") + col_type = data.get("col_type") exportable_titles = [] - if export_type == 'geo' and regions: + if export_type == "geo" and regions: self.region_data = {} for region in regions: admin_levels = region.adminlevel_set.all() admin_level_data = [] - exportable_titles.append(f'{region.title} Polygons') + exportable_titles.append(f"{region.title} Polygons") for admin_level in admin_levels: exportable_titles.append(admin_level.title) - exportable_titles.append('{} (code)'.format(admin_level.title)) + exportable_titles.append("{} (code)".format(admin_level.title)) # Collect geo area names for each admin level - admin_level_data.append({ - 'id': admin_level.id, - 'geo_area_titles': admin_level.get_geo_area_titles(), - }) + admin_level_data.append( + { + "id": admin_level.id, + "geo_area_titles": admin_level.get_geo_area_titles(), + } + ) self.region_data[region.id] = admin_level_data - elif export_type == 'multiple': + elif export_type == "multiple": index = len(exportable_titles) - exportable_titles.extend(data.get('titles')) + exportable_titles.extend(data.get("titles")) if col_type: for i in range(index, len(exportable_titles)): self.col_types[i] = col_type[i - index] - elif data.get('title'): + elif data.get("title"): index = len(exportable_titles) - exportable_titles.append(data.get('title')) + exportable_titles.append(data.get("title")) if col_type: self.col_types[index] = col_type return exportable_titles @@ -182,15 +171,15 @@ def load_exportables(self, exportables, regions=None): if self.columns is not None: _exportables = [] for column in self.columns: - if not column['is_widget']: - _exportables.append(column['static_column']) + if not column["is_widget"]: + _exportables.append(column["static_column"]) continue - widget_key = column['widget_key'] + widget_key = column["widget_key"] exportable = widget_exportables.get(widget_key) if exportable: _exportables.append(exportable) else: - self.log_error(f'Non-existing widget key is passed <{widget_key}>') + self.log_error(f"Non-existing widget key is passed <{widget_key}>") else: _exportables = [ *self.ColumnsData.TITLES.keys(), @@ -214,10 +203,8 @@ def load_exportables(self, exportables, regions=None): else: # For each exportable, create titles according to type # and data - data = exportable.data.get('excel') - column_titles.extend( - self.load_exportable_titles(data, regions) - ) + data = exportable.data.get("excel") + column_titles.extend(self.load_exportable_titles(data, regions)) if self.decoupled and self.split: self.split.append([column_titles]) @@ -244,9 +231,9 @@ def add_entries_from_excel_data_for_static_column( elif exportable == Export.StaticColumn.ENTRY_CREATED_AT: return self.date_renderer(entry.created_at) elif exportable == Export.StaticColumn.ENTRY_CONTROL_STATUS: - return 'Controlled' if entry.controlled else 'Uncontrolled' + return "Controlled" if entry.controlled else "Uncontrolled" elif exportable == Export.StaticColumn.LEAD_ID: - return f'{lead.id}' + return f"{lead.id}" elif exportable == Export.StaticColumn.LEAD_TITLE: return lead.title elif exportable == Export.StaticColumn.LEAD_URL: @@ -264,9 +251,9 @@ def add_entries_from_excel_data_for_static_column( elif exportable == Export.StaticColumn.LEAD_ASSIGNEE: return assignee and assignee.profile.get_display_name() elif exportable == Export.StaticColumn.ENTRY_ID: - return f'{entry.id}' + return f"{entry.id}" elif exportable == Export.StaticColumn.LEAD_ENTRY_ID: - return f'{lead.id}-{entry.id}' + return f"{lead.id}-{entry.id}" elif exportable == Export.StaticColumn.ENTRY_EXCERPT: entry_excerpt = self.get_entry_data(entry) if self.modified_excerpt_exists: @@ -274,10 +261,10 @@ def add_entries_from_excel_data_for_static_column( return entry_excerpt def add_entries_from_excel_data(self, rows, data, export_data): - export_type = data.get('type') + export_type = data.get("type") - if export_type == 'nested': - children = data.get('children') + if export_type == "nested": + children = data.get("children") for i, child in enumerate(children): if export_data is None or i >= len(export_data): _export_data = None @@ -289,11 +276,11 @@ def add_entries_from_excel_data(self, rows, data, export_data): _export_data, ) - elif export_type == 'multiple': - col_span = len(data.get('titles')) + elif export_type == "multiple": + col_span = len(data.get("titles")) if export_data: - if export_data.get('type') == 'lists': - export_data_values = export_data.get('values') + if export_data.get("type") == "lists": + export_data_values = export_data.get("values") rows_of_value_lists = [] for export_data_value in export_data_values: # Handle for Matrix2D subsectors @@ -305,13 +292,13 @@ def add_entries_from_excel_data(self, rows, data, export_data): for subsector in export_data_value[3]: rows_of_value_lists.append(export_data_value[:3] + [subsector]) else: - rows_of_value_lists.append(export_data_value[:3] + ['']) - elif len(export_data_value) != len(data.get('titles')): - titles_len = len(data.get('titles')) + rows_of_value_lists.append(export_data_value[:3] + [""]) + elif len(export_data_value) != len(data.get("titles")): + titles_len = len(data.get("titles")) values_len = len(export_data_value) if titles_len > values_len: # Add additional empty cells - rows_of_value_lists.append(export_data_value + [''] * (titles_len - values_len)) + rows_of_value_lists.append(export_data_value + [""] * (titles_len - values_len)) else: # Remove extra cells rows_of_value_lists.append(export_data_value[:titles_len]) @@ -319,34 +306,33 @@ def add_entries_from_excel_data(self, rows, data, export_data): rows_of_value_lists.append(export_data_value) rows.add_rows_of_value_lists( # Filter if all values are None - [ - x for x in rows_of_value_lists - if x is not None and not all(y is None for y in x) - ], + [x for x in rows_of_value_lists if x is not None and not all(y is None for y in x)], col_span, ) else: - export_data_values = export_data.get('values') - if export_data.get('widget_key') == Widget.WidgetType.DATE_RANGE.value: + export_data_values = export_data.get("values") + if export_data.get("widget_key") == Widget.WidgetType.DATE_RANGE.value: if len(export_data_values) == 2 and any(export_data_values): - rows.add_value_list([ - self.date_renderer(deep_date_parse(export_data_values[0], raise_exception=False)), - self.date_renderer(deep_date_parse(export_data_values[1], raise_exception=False)), - ]) + rows.add_value_list( + [ + self.date_renderer(deep_date_parse(export_data_values[0], raise_exception=False)), + self.date_renderer(deep_date_parse(export_data_values[1], raise_exception=False)), + ] + ) else: rows.add_value_list(export_data_values) else: - rows.add_value_list([''] * col_span) + rows.add_value_list([""] * col_span) - elif export_type == 'geo' and self.regions: + elif export_type == "geo" and self.regions: geo_id_values = [] region_geo_polygons = {} if export_data: - geo_id_values = [str(v) for v in export_data.get('values') or []] - for geo_polygon in export_data.get('polygons') or []: - region_id = geo_polygon['region_id'] + geo_id_values = [str(v) for v in export_data.get("values") or []] + for geo_polygon in export_data.get("polygons") or []: + region_id = geo_polygon["region_id"] region_geo_polygons[region_id] = region_geo_polygons.get(region_id) or [] - region_geo_polygons[region_id].append(geo_polygon['title']) + region_geo_polygons[region_id].append(geo_polygon["title"]) for region in self.regions: admin_levels = self.region_data[region.id] @@ -357,7 +343,7 @@ def add_entries_from_excel_data(self, rows, data, export_data): rows.add_rows_of_values(geo_polygons) for rev_level, admin_level in enumerate(admin_levels[::-1]): - geo_area_titles = admin_level['geo_area_titles'] + geo_area_titles = admin_level["geo_area_titles"] level = max_levels - rev_level for geo_id in geo_id_values: if geo_id not in geo_area_titles: @@ -366,61 +352,59 @@ def add_entries_from_excel_data(self, rows, data, export_data): rows_value.append(self.geoarea_data_cache[geo_id]) continue - row_values = ['' for i in range(0, max_levels - level)] * 2 + row_values = ["" for i in range(0, max_levels - level)] * 2 - title = geo_area_titles[geo_id].get('title', '') - code = geo_area_titles[geo_id].get('code', '') - parent_id = geo_area_titles[geo_id].get('parent_id') + title = geo_area_titles[geo_id].get("title", "") + code = geo_area_titles[geo_id].get("code", "") + parent_id = geo_area_titles[geo_id].get("parent_id") row_values.extend([code, title]) for _level in range(0, level - 1)[::-1]: if parent_id: - _geo_area_titles = admin_levels[_level]['geo_area_titles'] + _geo_area_titles = admin_levels[_level]["geo_area_titles"] _geo_area = _geo_area_titles.get(parent_id) or {} - _title = _geo_area.get('title', '') - _code = _geo_area.get('code', '') - parent_id = _geo_area.get('parent_id') + _title = _geo_area.get("title", "") + _code = _geo_area.get("code", "") + parent_id = _geo_area.get("parent_id") row_values.extend([_code, _title]) else: - row_values.extend(['', '']) + row_values.extend(["", ""]) rows_value.append(row_values[::-1]) self.geoarea_data_cache[geo_id] = row_values[::-1] if len(rows_value) > 0: rows.add_rows_of_value_lists(rows_value) else: - rows.add_rows_of_value_lists([['' for i in range(0, max_levels)] * 2]) + rows.add_rows_of_value_lists([["" for i in range(0, max_levels)] * 2]) else: if export_data: - if export_data.get('type') == 'list': + if export_data.get("type") == "list": row_values = [ # This is in hope of filtering out non-existent data from excel row - x for x in export_data.get('value', []) + x + for x in export_data.get("value", []) if x is not None ] - rows.add_rows_of_values(row_values if row_values else ['']) + rows.add_rows_of_values(row_values if row_values else [""]) else: - rows.add_value(export_data.get('value')) + rows.add_value(export_data.get("value")) else: - rows.add_value('') + rows.add_value("") def get_data_series(self, entry): lead = entry.lead field = entry.tabular_field if field is None: - return '' + return "" self.tabular_fields[field.id] = field # Get Sheet title which is Lead title - Sheet title # Worksheet title is limited to 31 as excel's tab length is capped to 31 - worksheet_title = '{}-{}'.format(lead.title, field.sheet.title) + worksheet_title = "{}-{}".format(lead.title, field.sheet.title) if not self._sheets.get(worksheet_title) and len(worksheet_title) > 31: - self._sheets[worksheet_title] = '{}-{}'.format( - worksheet_title[:28], - len(self.wb.wb.worksheets) - ) + self._sheets[worksheet_title] = "{}-{}".format(worksheet_title[:28], len(self.wb.wb.worksheets)) elif not self._sheets.get(worksheet_title): self._sheets[worksheet_title] = worksheet_title worksheet_title = self._sheets[worksheet_title] @@ -448,18 +432,15 @@ def get_data_series(self, entry): self.tabular_sheets[worksheet_title] = worksheet_data # Insert field title to sheet in first row - tabular_sheet['{}1'.format(sheet_col_name)].value =\ - field.title + tabular_sheet["{}1".format(sheet_col_name)].value = field.title # Add field values to corresponding column for i, x in enumerate(field.actual_data): - tabular_sheet[ - '{}{}'.format(sheet_col_name, 2 + i) - ].value = x.get('processed_value') or x['value'] + tabular_sheet["{}{}".format(sheet_col_name, 2 + i)].value = x.get("processed_value") or x["value"] else: sheet_col_name = excel_column_name(col_number) - link = f'#\'{worksheet_title}\'!{sheet_col_name}1' + link = f"#'{worksheet_title}'!{sheet_col_name}1" return get_hyperlink(link, field.title) def get_entry_data(self, entry): @@ -474,15 +455,15 @@ def get_entry_data(self, entry): return self.get_data_series(entry) except Exception: self.log_error( - 'Data Series EXCEL Export Failed for entry', + "Data Series EXCEL Export Failed for entry", exc_info=1, - extra={'data': {'entry_id': entry.pk}}, + extra={"data": {"entry_id": entry.pk}}, ) - return '' + return "" def add_entries(self, entries): - iterable_entries = entries[:Export.PREVIEW_ENTRY_SIZE] if self.is_preview else entries + iterable_entries = entries[: Export.PREVIEW_ENTRY_SIZE] if self.is_preview else entries for i, entry in enumerate(iterable_entries): # Export each entry # Start building rows and export data for each exportable @@ -491,8 +472,8 @@ def add_entries(self, entries): # Add it to appropriate row/column in self.group_label_matrix for group_label in entry.entrygrouplabel_set.all(): key = (group_label.group.lead_id, group_label.group_id) - entries_sheet_name = 'Grouped Entries' if self.decoupled else 'Entries' - link = f'#\'{entries_sheet_name}\'!A{i+2}' + entries_sheet_name = "Grouped Entries" if self.decoupled else "Entries" + link = f"#'{entries_sheet_name}'!A{i+2}" self.group_label_matrix[key][group_label.label_id] = get_hyperlink(link, entry.excerpt[:50]) lead = entry.lead @@ -517,20 +498,17 @@ def add_entries(self, entries): # exportable. # And write some value based on type and data # or empty strings if no data. - data = exportable.data.get('excel') + data = exportable.data.get("excel") export_data = ExportData.objects.filter( exportable=exportable, entry=entry, data__excel__isnull=False, ).first() - if export_data and type(export_data.data.get('excel', {})) == list: - export_data = export_data.data.get('excel', []) + if export_data and type(export_data.data.get("excel", {})) == list: + export_data = export_data.data.get("excel", []) else: - export_data = export_data and { - **export_data.data.get('common', {}), - **export_data.data.get('excel', {}) - } + export_data = export_data and {**export_data.data.get("common", {}), **export_data.data.get("excel", {})} self.add_entries_from_excel_data(rows, data, export_data) rows.apply() @@ -546,33 +524,38 @@ def add_entries(self, entries): return self def add_bibliography_sheet(self, leads_qs): - self.bibliography_sheet.append([['Author', 'Source', 'Published Date', 'Title', 'Entries Count']]) + self.bibliography_sheet.append([["Author", "Source", "Published Date", "Title", "Entries Count"]]) qs = leads_qs # This is annotated from LeadGQFilterSet.filter_queryset if not use total entries count - if 'filtered_entry_count' not in qs.query.annotations: + if "filtered_entry_count" not in qs.query.annotations: qs = qs.annotate( filtered_entry_count=models.functions.Coalesce( models.Subquery( Entry.objects.filter( project=self.project, analysis_framework=self.project.analysis_framework_id, - lead=models.OuterRef('pk'), - ).order_by().values('lead') - .annotate(count=models.Count('id')) - .values('count'), + lead=models.OuterRef("pk"), + ) + .order_by() + .values("lead") + .annotate(count=models.Count("id")) + .values("count"), output_field=models.IntegerField(), - ), 0, + ), + 0, ) ) for lead in qs: self.bibliography_sheet.append( - [[ - lead.get_authors_display(), - lead.get_source_display(), - self.date_renderer(lead.published_on), - get_hyperlink(lead.url, lead.title) if lead.url else lead.title, - lead.filtered_entry_count, - ]] + [ + [ + lead.get_authors_display(), + lead.get_source_display(), + self.date_renderer(lead.published_on), + get_hyperlink(lead.url, lead.title) if lead.url else lead.title, + lead.filtered_entry_count, + ] + ] ) def export(self, leads_qs): diff --git a/apps/export/entries/json_exporter.py b/apps/export/entries/json_exporter.py index 34053d74d6..c07a359a1c 100644 --- a/apps/export/entries/json_exporter.py +++ b/apps/export/entries/json_exporter.py @@ -1,7 +1,8 @@ -from utils.files import generate_json_file_for_upload from analysis_framework.models import Widget from export.models import Export +from utils.files import generate_json_file_for_upload + class JsonExporter: def __init__(self, is_preview=False): @@ -12,7 +13,7 @@ def load_exportables(self, exportables): self.exportables = exportables self.widget_ids = [] - self.data['widgets'] = [] + self.data["widgets"] = [] for exportable in self.exportables: widget = Widget.objects.get( analysis_framework=exportable.analysis_framework, @@ -21,44 +22,44 @@ def load_exportables(self, exportables): self.widget_ids.append(widget.id) data = {} - data['id'] = widget.key - data['widget_type'] = widget.widget_id - data['title'] = widget.title - data['properties'] = widget.properties - self.data['widgets'].append(data) + data["id"] = widget.key + data["widget_type"] = widget.widget_id + data["title"] = widget.title + data["properties"] = widget.properties + self.data["widgets"].append(data) return self def add_entries(self, entries): - self.data['entries'] = [] + self.data["entries"] = [] - iterable_entries = entries[:Export.PREVIEW_ENTRY_SIZE] if self.is_preview else entries + iterable_entries = entries[: Export.PREVIEW_ENTRY_SIZE] if self.is_preview else entries for entry in iterable_entries: lead = entry.lead data = {} - data['id'] = entry.id - data['lead_id'] = lead.id - data['lead'] = lead.title - data['source'] = lead.get_source_display() - data['priority'] = lead.get_priority_display() - data['author'] = lead.get_authors_display() - data['date'] = lead.published_on - data['excerpt'] = entry.excerpt - data['image'] = entry.get_image_url() - data['attributes'] = [] - data['data_series'] = {} + data["id"] = entry.id + data["lead_id"] = lead.id + data["lead"] = lead.title + data["source"] = lead.get_source_display() + data["priority"] = lead.get_priority_display() + data["author"] = lead.get_authors_display() + data["date"] = lead.published_on + data["excerpt"] = entry.excerpt + data["image"] = entry.get_image_url() + data["attributes"] = [] + data["data_series"] = {} for attribute in entry.attribute_set.all(): attribute_data = {} - attribute_data['widget_id'] = attribute.widget.key - attribute_data['data'] = attribute.data - data['attributes'].append(attribute_data) + attribute_data["widget_id"] = attribute.widget.key + attribute_data["data"] = attribute.data + data["attributes"].append(attribute_data) if entry.tabular_field: - data['data_series'] = { - 'options': entry.tabular_field.options, - 'data': entry.tabular_field.actual_data, + data["data_series"] = { + "options": entry.tabular_field.options, + "data": entry.tabular_field.actual_data, } - self.data['entries'].append(data) + self.data["entries"].append(data) return self def export(self): diff --git a/apps/export/entries/report_exporter.py b/apps/export/entries/report_exporter.py index 745919a6bd..edfc3a9015 100644 --- a/apps/export/entries/report_exporter.py +++ b/apps/export/entries/report_exporter.py @@ -1,56 +1,46 @@ +import logging import os import tempfile -import logging from datetime import datetime from subprocess import call +from analysis_framework.models import Widget +from ary.export.affected_groups_info import ( + get_affected_groups_info as ary_get_affected_groups_info, +) +from ary.export.data_collection_techniques_info import ( + get_data_collection_techniques_info as ary_get_data_collection_techniques_info, +) from django.conf import settings from django.core.files.base import ContentFile, File -from django.db.models import ( - Case, - When, - Q, -) +from django.db.models import Case, Q, When from docx.shared import Inches -from deep.permalinks import Permalink -from utils.common import deep_date_parse, deep_date_format - -from export.formats.docx import Document - -from analysis_framework.models import Widget -from entry.models import ( - Entry, - ExportData, - Attribute, - # EntryGroupLabel, -) +from entry.models import Attribute, Entry, ExportData # EntryGroupLabel, from entry.widgets import ( - scale_widget, - time_widget, - date_widget, - time_range_widget, date_range_widget, + date_widget, geo_widget, - select_widget, multiselect_widget, organigram_widget, + scale_widget, + select_widget, + time_range_widget, + time_widget, ) from entry.widgets.store import widget_store - -from ary.export.affected_groups_info import get_affected_groups_info as ary_get_affected_groups_info -from ary.export.data_collection_techniques_info import ( - get_data_collection_techniques_info as ary_get_data_collection_techniques_info -) - +from export.formats.docx import Document +from export.models import Export from lead.models import Lead from tabular.viz import renderer as viz_renderer -from export.models import Export + +from deep.permalinks import Permalink +from utils.common import deep_date_format, deep_date_parse logger = logging.getLogger(__name__) -SEPARATOR = ', ' -INTERNAL_SEPARATOR = '; ' -ASSESSMENT_ICON_IMAGE_PATH = os.path.join(settings.BASE_DIR, 'apps/static/image/drop-icon.png') +SEPARATOR = ", " +INTERNAL_SEPARATOR = "; " +ASSESSMENT_ICON_IMAGE_PATH = os.path.join(settings.BASE_DIR, "apps/static/image/drop-icon.png") class ExportDataVersionMismatch(Exception): @@ -66,7 +56,7 @@ def _add_common(para, text, bold): @staticmethod def _add_scale_widget_data(para, label, color, bold): """ - Output: + Output: """ para.add_oval_shape(color) para.add_run(label, bold) @@ -80,8 +70,8 @@ def _get_scale_widget_data(cls, data, bold, **kwargs): - color as described here: apps.entry.widgets.scale_widget._get_scale """ - label = data.get('label') - color = data.get('color') + label = data.get("label") + color = data.get("color") if label and color: return cls._add_scale_widget_data, label, color, bold @@ -92,22 +82,16 @@ def _get_date_range_widget_data(cls, data, bold, **kwargs): - tuple (from, to) as described here: apps.entry.widgets.date_range_widget._get_date """ - date_renderer = kwargs['date_renderer'] - values = data.get('values', []) + date_renderer = kwargs["date_renderer"] + values = data.get("values", []) if len(values) == 2 and any(values): - label = '{} - {}'.format( + label = "{} - {}".format( date_renderer( - deep_date_parse( - values[0], - raise_exception=False - ), + deep_date_parse(values[0], raise_exception=False), fallback="N/A", ), date_renderer( - deep_date_parse( - values[1], - raise_exception=False - ), + deep_date_parse(values[1], raise_exception=False), fallback="N/A", ), ) @@ -120,9 +104,9 @@ def _get_time_range_widget_data(cls, data, bold, **kwargs): - tuple (from, to) as described here: apps.entry.widgets.time_range_widget._get_time """ - values = data.get('values', []) + values = data.get("values", []) if len(values) == 2 and any(values): - text = '{} - {}'.format( + text = "{} - {}".format( values[0] or "~~:~~", values[1] or "~~:~~", ) @@ -135,25 +119,25 @@ def _get_date_widget_data(cls, data, bold, **kwargs): - string (=date) as described here: apps.entry.widgets.date_widget """ - value = data.get('value') + value = data.get("value") if not value: return - date_renderer = kwargs['date_renderer'] + date_renderer = kwargs["date_renderer"] _value = date_renderer(deep_date_parse(value, raise_exception=False)) if _value: return cls._add_common, _value, bold @classmethod def _get_time_widget_data(cls, data, bold, **kwargs): - value = data.get('value') + value = data.get("value") if value: return cls._add_common, value, bold @classmethod def _get_select_widget_data(cls, data, bold, **kwargs): - type_ = data.get('type') - value = [str(v) for v in data.get('value') or []] - if type_ == 'list' and value: + type_ = data.get("type") + value = [str(v) for v in data.get("value") or []] + if type_ == "list" and value: return cls._add_common, INTERNAL_SEPARATOR.join(value), bold @classmethod @@ -162,20 +146,17 @@ def _get_multi_select_widget_data(cls, data, bold, **kwargs): @classmethod def _get_organigram_widget_data(cls, data, bold, **kwargs): - text = INTERNAL_SEPARATOR.join( - '/'.join(value_with_parent) - for value_with_parent in data.get('values') - ) + text = INTERNAL_SEPARATOR.join("/".join(value_with_parent) for value_with_parent in data.get("values")) return cls._add_common, text, bold @classmethod def _get_geo_widget_data(cls, data, bold, **kwargs): # XXX: Cache this value. # Right now everything needs to be loaded so doing this at entry save can take lot of memory - geo_id_values = [str(v) for v in data.get('values') or []] + geo_id_values = [str(v) for v in data.get("values") or []] if len(geo_id_values) == 0: return - geo_values = kwargs['_get_geo_admin_level_1_data'](geo_id_values) + geo_values = kwargs["_get_geo_admin_level_1_data"](geo_id_values) if geo_values: return cls._add_common, geo_values, bold @@ -195,8 +176,8 @@ def get_widget_information_into_report( """ if not isinstance(report, dict): return - if 'widget_id' in report: - widget_id = report.get('widget_id') + if "widget_id" in report: + widget_id = report.get("widget_id") mapper = { scale_widget.WIDGET_ID: cls._get_scale_widget_data, date_range_widget.WIDGET_ID: cls._get_date_range_widget_data, @@ -209,9 +190,9 @@ def get_widget_information_into_report( multiselect_widget.WIDGET_ID: cls._get_multi_select_widget_data, } if widget_id in mapper.keys(): - if report.get('version') != widget_store[widget_id].DATA_VERSION: + if report.get("version") != widget_store[widget_id].DATA_VERSION: raise ExportDataVersionMismatch( - f'{widget_id} widget data is not upto date. Export data being exported: {report}' + f"{widget_id} widget data is not upto date. Export data being exported: {report}" ) return mapper[widget_id](report, bold, **kwargs) @@ -233,17 +214,11 @@ def __init__( self.show_entry_widget_data = show_entry_widget_data # self.entry_group_labels = {} # TODO: Remove entry group labels? - self.doc = Document( - os.path.join(settings.APPS_DIR, 'static/doc_export/template.docx') - ) + self.doc = Document(os.path.join(settings.APPS_DIR, "static/doc_export/template.docx")) self.lead_ids = [] # ordered list of widget ids - self.exporting_widgets_ids = [ - int(_id) for _id in exporting_widgets or [] - ] - self.exporting_widgets_keys = list( - Widget.objects.filter(id__in=self.exporting_widgets_ids).values_list('key', flat=True) - ) + self.exporting_widgets_ids = [int(_id) for _id in exporting_widgets or []] + self.exporting_widgets_keys = list(Widget.objects.filter(id__in=self.exporting_widgets_ids).values_list("key", flat=True)) self.region_data = {} # XXX: Limit memory usage? (Or use redis?) self.geoarea_data_cache = {} @@ -263,10 +238,7 @@ def load_exportables(self, exportables, regions): self.exportables = exportables - geo_data_required = Widget.objects.filter( - id__in=self.exporting_widgets_ids, - widget_id=geo_widget.WIDGET_ID - ).exists() + geo_data_required = Widget.objects.filter(id__in=self.exporting_widgets_ids, widget_id=geo_widget.WIDGET_ID).exists() # Load geo data if required if geo_data_required: self.region_data = {} @@ -274,9 +246,9 @@ def load_exportables(self, exportables, regions): # Collect geo area names for each admin level self.region_data[region.id] = [ { - 'id': admin_level.id, - 'level': admin_level.level, - 'geo_area_titles': admin_level.get_geo_area_titles(), + "id": admin_level.id, + "level": admin_level.level, + "geo_area_titles": admin_level.get_geo_area_titles(), } for admin_level in region.adminlevel_set.all() ] @@ -310,10 +282,7 @@ def load_text_from_text_widgets(self, entries, text_widget_ids): """ # User defined widget order (Generate order map) widget_ids = [int(id) for id in text_widget_ids] - widget_map = { - int(id): index - for index, id in enumerate(text_widget_ids) - } + widget_map = {int(id): index for index, id in enumerate(text_widget_ids)} attribute_qs = Attribute.objects.filter( entry__in=entries, @@ -330,13 +299,13 @@ def load_text_from_text_widgets(self, entries, text_widget_ids): widget_id, widget_type, ) in attribute_qs.values_list( - 'entry_id', - 'data__value', # Text - 'widget__title', - 'widget__id', - 'widget__widget_id', # Widget Type + "entry_id", + "data__value", # Text + "widget__title", + "widget__id", + "widget__widget_id", # Widget Type ): - if widget_type == 'conditionalWidget': + if widget_type == "conditionalWidget": continue widget_order = widget_map[widget_id] @@ -358,35 +327,35 @@ def _generate_legend_page(self, project): self.legend_paragraph.add_next_paragraph(para) # todo in a table - scale_widgets = project.analysis_framework.widget_set.filter(widget_id='scaleWidget') + scale_widgets = project.analysis_framework.widget_set.filter(widget_id="scaleWidget") for widget in scale_widgets[::-1]: - if not hasattr(widget, 'title'): + if not hasattr(widget, "title"): continue title_para = self.doc.add_paragraph() title_para.ref.paragraph_format.right_indent = Inches(0.25) - title_para.add_run(f'{widget.title}') - for legend in widget.properties.get('options', [])[::-1]: + title_para.add_run(f"{widget.title}") + for legend in widget.properties.get("options", [])[::-1]: para = self.doc.add_paragraph() para.ref.paragraph_format.right_indent = Inches(0.25) - para.add_oval_shape(legend.get('color')) + para.add_oval_shape(legend.get("color")) para.add_run(f' {legend.get("label", "-Missing-")}') self.legend_paragraph.add_next_paragraph(para) self.legend_paragraph.add_next_paragraph(title_para) - cond_widgets = project.analysis_framework.widget_set.filter(widget_id='conditionalWidget') + cond_widgets = project.analysis_framework.widget_set.filter(widget_id="conditionalWidget") for c_widget in cond_widgets[::-1]: for widget in filter( - lambda x: x.get('widget', {}).get('widget_id') == 'scaleWidget', - c_widget.properties.get('data', {}).get('widgets', []) + lambda x: x.get("widget", {}).get("widget_id") == "scaleWidget", + c_widget.properties.get("data", {}).get("widgets", []), ): - if not widget.get('widget', {}).get('title'): + if not widget.get("widget", {}).get("title"): continue title_para = self.doc.add_paragraph() title_para.ref.paragraph_format.right_indent = Inches(0.25) title_para.add_run(f'{widget.get("widget", {}).get("title")}') - for legend in widget.get('widget', {}).get('properties', {}).get('options', [])[::-1]: + for legend in widget.get("widget", {}).get("properties", {}).get("options", [])[::-1]: para = self.doc.add_paragraph() para.ref.paragraph_format.right_indent = Inches(0.25) - para.add_oval_shape(legend.get('color')) + para.add_oval_shape(legend.get("color")) para.add_run(f' {legend.get("label", "-Missing-")}') self.legend_paragraph.add_next_paragraph(para) self.legend_paragraph.add_next_paragraph(title_para) @@ -398,7 +367,7 @@ def _get_geo_admin_level_1_data(self, geo_id_values): render_values = [] for region_id, admin_levels in self.region_data.items(): for admin_level in admin_levels: - geo_area_titles = admin_level['geo_area_titles'] + geo_area_titles = admin_level["geo_area_titles"] for geo_id in geo_id_values: if geo_id not in geo_area_titles: continue @@ -408,20 +377,20 @@ def _get_geo_admin_level_1_data(self, geo_id_values): continue self.geoarea_data_cache[geo_id] = None - title = geo_area_titles[geo_id].get('title') - parent_id = geo_area_titles[geo_id].get('parent_id') - if admin_level['level'] == 1: + title = geo_area_titles[geo_id].get("title") + parent_id = geo_area_titles[geo_id].get("parent_id") + if admin_level["level"] == 1: title and render_values.append(title) self.geoarea_data_cache[geo_id] = title continue # Try to look through parent - for _level in range(0, admin_level['level'])[::-1]: + for _level in range(0, admin_level["level"])[::-1]: if parent_id: - _geo_area_titles = admin_levels[_level]['geo_area_titles'] + _geo_area_titles = admin_levels[_level]["geo_area_titles"] _geo_area = _geo_area_titles.get(parent_id) or {} - _title = _geo_area.get('title') - parent_id = _geo_area.get('parent_id') + _title = _geo_area.get("title") + parent_id = _geo_area.get("parent_id") if _level == 1: _title and render_values.append(_title) self.geoarea_data_cache[geo_id] = title @@ -433,8 +402,8 @@ def _get_geo_admin_level_1_data(self, geo_id_values): def _add_assessment_info_for_entry(self, assessment, para, bold=True): def _add_assessment_icon(): # NOTE: Add icon here - run = para.add_run('', bold=bold) - with open(ASSESSMENT_ICON_IMAGE_PATH, 'rb') as fp: + run = para.add_run("", bold=bold) + with open(ASSESSMENT_ICON_IMAGE_PATH, "rb") as fp: run.add_inline_image(fp, width=Inches(0.15), height=Inches(0.15)) cache = self.assessment_data_cache.get(assessment.pk) @@ -443,28 +412,32 @@ def _add_assessment_icon(): if cache is None: cache = {} # Collect Assessment GEO Data - cache['locations'] = self._get_geo_admin_level_1_data( - assessment.locations.values_list('id', flat=True), + cache["locations"] = self._get_geo_admin_level_1_data( + assessment.locations.values_list("id", flat=True), + ) + cache["affected_groups_info"] = INTERNAL_SEPARATOR.join( + [ + "/".join([str(s) for s in info.values() if s]) + for info in ary_get_affected_groups_info(assessment)["affected_groups_info"] + ] + ) + cache["data_collection_techniques_info"] = INTERNAL_SEPARATOR.join( + [ + f"{info['Sampling Size']} {info['Data Collection Technique']}" + for info in ary_get_data_collection_techniques_info(assessment)["data_collection_technique"] + if info.get("Sampling Size") + ] ) - cache['affected_groups_info'] = INTERNAL_SEPARATOR.join([ - '/'.join([str(s) for s in info.values() if s]) - for info in ary_get_affected_groups_info(assessment)['affected_groups_info'] - ]) - cache['data_collection_techniques_info'] = INTERNAL_SEPARATOR.join([ - f"{info['Sampling Size']} {info['Data Collection Technique']}" - for info in ary_get_data_collection_techniques_info(assessment)['data_collection_technique'] - if info.get('Sampling Size') - ]) dc_start_date = deep_date_format(assessment.data_collection_start_date) dc_end_date = deep_date_format(assessment.data_collection_end_date) if dc_start_date or dc_end_date: - cache['data_collection_date'] = f'Data collection: {dc_start_date} - {dc_end_date}' + cache["data_collection_date"] = f"Data collection: {dc_start_date} - {dc_end_date}" self.assessment_data_cache[assessment.pk] = cache - locations = cache['locations'] - affected_groups_info = cache['affected_groups_info'] - data_collection_techniques_info = cache['data_collection_techniques_info'] - data_collection_date = cache['data_collection_date'] + locations = cache["locations"] + affected_groups_info = cache["affected_groups_info"] + data_collection_techniques_info = cache["data_collection_techniques_info"] + data_collection_date = cache["data_collection_date"] to_process_fuctions = [ func @@ -474,29 +447,30 @@ def _add_assessment_icon(): (affected_groups_info, lambda: para.add_run(affected_groups_info, bold=bold)), (data_collection_techniques_info, lambda: para.add_run(data_collection_techniques_info, bold=bold)), (data_collection_date, lambda: para.add_run(data_collection_date, bold=bold)), - ] if condition + ] + if condition ] - para.add_run(' [', bold=True) + para.add_run(" [", bold=True) # Finally add all assessment data to the docx total_process_functions = len(to_process_fuctions) - 1 for index, add_data in enumerate(to_process_fuctions): add_data() if index < total_process_functions: para.add_run(SEPARATOR, bold=bold) - para.add_run('] ', bold=True) + para.add_run("] ", bold=True) def _generate_for_entry_widget_data(self, entry, para): if entry.id not in self.entry_widget_data_cache: raw_export_data = [] for each in entry.exportdata_set.all(): export_datum = { - **(each.data.get('common') or {}), - **(each.data.get('report') or {}), + **(each.data.get("common") or {}), + **(each.data.get("report") or {}), } - if export_datum.get('widget_key') and export_datum['widget_key'] in self.exporting_widgets_keys: + if export_datum.get("widget_key") and export_datum["widget_key"] in self.exporting_widgets_keys: raw_export_data.append(export_datum) - raw_export_data.sort(key=lambda x: self.exporting_widgets_keys.index(x['widget_key'])) + raw_export_data.sort(key=lambda x: self.exporting_widgets_keys.index(x["widget_key"])) export_data = [] if raw_export_data: @@ -511,20 +485,19 @@ def _generate_for_entry_widget_data(self, entry, para): export_data.append(resp) except ExportDataVersionMismatch: logger.error( - f'ExportDataVersionMismatch: For entry {entry.id}, project {entry.project.id}', - exc_info=True + f"ExportDataVersionMismatch: For entry {entry.id}, project {entry.project.id}", exc_info=True ) self.entry_widget_data_cache[entry.id] = export_data export_data = self.entry_widget_data_cache[entry.id] if export_data: - para.add_run(' [', bold=True) + para.add_run(" [", bold=True) export_data_len = len(export_data) - 1 for index, [func, *args] in enumerate(export_data): func(para, *args) # Add to para if index < export_data_len: para.add_run(SEPARATOR, bold=True) - para.add_run('] ', bold=True) + para.add_run("] ", bold=True) def _generate_for_entry(self, entry): """ @@ -535,14 +508,14 @@ def _generate_for_entry(self, entry): # entry-lead id if self.show_lead_entry_id: - para.add_run('[', bold=True) + para.add_run("[", bold=True) # Add lead-entry id url = Permalink.entry(entry.project_id, entry.lead_id, entry.id) para.add_hyperlink(url, f"{entry.lead_id}-{entry.id}") - para.add_run(']', bold=True) + para.add_run("]", bold=True) # Assessment Data - if self.show_assessment_data and getattr(entry.lead, 'assessment', None): + if self.show_assessment_data and getattr(entry.lead, "assessment", None): self._add_assessment_info_for_entry(entry.lead.assessmentregistry, para, bold=True) # Entry widget Data @@ -554,13 +527,10 @@ def _generate_for_entry(self, entry): # where source is hyperlinked to appropriate url # Excerpt can also be image - excerpt = ( - entry.excerpt if entry.entry_type == Entry.TagType.EXCERPT - else '' - ) + excerpt = entry.excerpt if entry.entry_type == Entry.TagType.EXCERPT else "" if self.citation_style == Export.CitationStyle.STYLE_1: - para.add_run(excerpt.rstrip('.')) + para.add_run(excerpt.rstrip(".")) else: # Default para.add_run(excerpt) @@ -572,7 +542,7 @@ def _generate_for_entry(self, entry): for order in sorted(entry_texts.keys()): title, text = entry_texts[order] para = self.doc.add_paragraph().justify() - para.add_run(f'{title}: ', bold=True) + para.add_run(f"{title}: ", bold=True) para.add_run(text) para = self.doc.add_paragraph().justify() @@ -585,12 +555,12 @@ def _generate_for_entry(self, entry): # para.add_run().add_image(entry.image_raw) elif entry.entry_type == Entry.TagType.DATA_SERIES and entry.tabular_field: image = viz_renderer.get_entry_image(entry) - h_stats = (entry.tabular_field.cache or {}).get('health_stats', {}) + h_stats = (entry.tabular_field.cache or {}).get("health_stats", {}) - image_text = ' Total values: {}'.format(h_stats.get('total', 'N/A')) - for key in ['invalid', 'null']: + image_text = " Total values: {}".format(h_stats.get("total", "N/A")) + for key in ["invalid", "null"]: if h_stats.get(key): - image_text += f', {key.title()} values: {h_stats.get(key)}' if h_stats.get(key) else '' + image_text += f", {key.title()} values: {h_stats.get(key)}" if h_stats.get(key) else "" if image: self.doc.add_image(image) @@ -603,15 +573,15 @@ def _generate_for_entry(self, entry): # --- Reference Start if widget_texts_exists or image: - para.add_run('(') # Starting from new line + para.add_run("(") # Starting from new line else: - para.add_run(' (') # Starting from same line + para.add_run(" (") # Starting from same line if self.citation_style == Export.CitationStyle.STYLE_1: - source = '' + source = "" author = lead.get_authors_display(short_name=True) else: # Default - source = lead.get_source_display() or 'Reference' + source = lead.get_source_display() or "Reference" author = lead.get_authors_display() url = lead.url or Permalink.lead_share_view(lead.uuid) @@ -619,18 +589,18 @@ def _generate_for_entry(self, entry): if self.citation_style == Export.CitationStyle.STYLE_1: # Add author is available - if (author and author.lower() != (source or '').lower()): + if author and author.lower() != (source or "").lower(): if url: - para.add_hyperlink(url, f'{author} ') + para.add_hyperlink(url, f"{author} ") else: - para.add_run(f'{author} ') + para.add_run(f"{author} ") else: # Default # Add author is available - if (author and author.lower() != (source or '').lower()): + if author and author.lower() != (source or "").lower(): if url: - para.add_hyperlink(url, f'{author}, ') + para.add_hyperlink(url, f"{author}, ") else: - para.add_run(f'{author}, ') + para.add_run(f"{author}, ") # Add source (with url if available) if url: para.add_hyperlink(url, source) @@ -639,9 +609,9 @@ def _generate_for_entry(self, entry): # Add (confidential/restricted) to source without , if lead.confidentiality == Lead.Confidentiality.CONFIDENTIAL: - para.add_run(' (confidential)') + para.add_run(" (confidential)") elif lead.confidentiality == Lead.Confidentiality.RESTRICTED: - para.add_run(' (restricted)') + para.add_run(" (restricted)") if self.citation_style == Export.CitationStyle.STYLE_1: pass @@ -657,36 +627,36 @@ def _generate_for_entry(self, entry): else: # Default para.add_run(f", {self.date_renderer(date)}") - para.add_run(')') + para.add_run(")") # --- Reference End if self.citation_style == Export.CitationStyle.STYLE_1: - para.add_run('.') + para.add_run(".") self.doc.add_paragraph() def _load_into_levels( - self, - entry, - keys, - levels, - entries_map, - valid_levels, + self, + entry, + keys, + levels, + entries_map, + valid_levels, ): """ Map an entry into corresponding levels """ parent_level_valid = False for level in levels: - level_id = level.get('id') - valid_level = (level_id in keys) + level_id = level.get("id") + valid_level = level_id in keys if valid_level: if not entries_map.get(level_id): entries_map[level_id] = [] entries_map[level_id].append(entry) - sublevels = level.get('sublevels') + sublevels = level.get("sublevels") if sublevels: valid_level = valid_level or self._load_into_levels( entry, @@ -703,44 +673,41 @@ def _load_into_levels( return parent_level_valid def _generate_for_levels( - self, - levels, - level_entries_map, - valid_levels, - structures=None, - heading_level=2, + self, + levels, + level_entries_map, + valid_levels, + structures=None, + heading_level=2, ): """ Generate paragraphs for all entries in this level and recursively do it for further sublevels """ if structures is not None: - level_map = dict((level.get('id'), level) for level in levels) - levels = [level_map[s['id']] for s in structures] + level_map = dict((level.get("id"), level) for level in levels) + levels = [level_map[s["id"]] for s in structures] for level in levels: - if level.get('id') not in valid_levels: + if level.get("id") not in valid_levels: continue - title = level.get('title') - entries = level_entries_map.get(level.get('id')) - sublevels = level.get('sublevels') + title = level.get("title") + entries = level_entries_map.get(level.get("id")) + sublevels = level.get("sublevels") if entries or sublevels: self.doc.add_heading(title, heading_level) self.doc.add_paragraph() if entries: - iterable_entries = entries[:Export.PREVIEW_ENTRY_SIZE] if self.is_preview else entries + iterable_entries = entries[: Export.PREVIEW_ENTRY_SIZE] if self.is_preview else entries [self._generate_for_entry(entry) for entry in iterable_entries] if sublevels: substructures = None if structures: - substructures = next(( - s.get('levels') for s in structures - if s['id'] == level.get('id') - ), None) + substructures = next((s.get("levels") for s in structures if s["id"] == level.get("id")), None) self._generate_for_levels( sublevels, @@ -751,10 +718,7 @@ def _generate_for_levels( ) def _generate_for_uncategorized(self, entries, categorized_entry_processed): - entries = entries.exclude( - Q(exportdata__data__report__keys__isnull=False) | - Q(exportdata__data__report__keys__len__gt=0) - ) + entries = entries.exclude(Q(exportdata__data__report__keys__isnull=False) | Q(exportdata__data__report__keys__len__gt=0)) if entries.count() == 0: return @@ -762,10 +726,10 @@ def _generate_for_uncategorized(self, entries, categorized_entry_processed): if self.is_preview and categorized_entry_processed >= Export.PREVIEW_ENTRY_SIZE: return - self.doc.add_heading('Uncategorized', 2) + self.doc.add_heading("Uncategorized", 2) self.doc.add_paragraph() - iterable_entries = entries[:Export.PREVIEW_ENTRY_SIZE - categorized_entry_processed] if self.is_preview else entries + iterable_entries = entries[: Export.PREVIEW_ENTRY_SIZE - categorized_entry_processed] if self.is_preview else entries for entry in iterable_entries: self._generate_for_entry(entry) @@ -774,7 +738,7 @@ def pre_build_document(self, project): Structure the document """ self.doc.add_heading( - 'DEEP Export — {} — {}'.format( + "DEEP Export — {} — {}".format( self.date_renderer(datetime.today()), project.title, ), @@ -782,7 +746,7 @@ def pre_build_document(self, project): ) self.doc.add_paragraph() - self.legend_heading = self.doc.add_heading('Legends', 2) + self.legend_heading = self.doc.add_heading("Legends", 2) self.legend_paragraph = self.doc.add_paragraph() def add_entries(self, entries): @@ -792,34 +756,31 @@ def add_entries(self, entries): if entries: self.pre_build_document(entries[0].project) exportables = self.exportables - af_levels_map = dict((str(level.get('id')), level.get('levels')) for level in self.levels) + af_levels_map = dict((str(level.get("id")), level.get("levels")) for level in self.levels) uncategorized = False categorized_entry_processed = 0 # NOTE: Used for preview limit only if self.structure: - ids = [s['id'] for s in self.structure] - uncategorized = 'uncategorized' in ids - ids = [id for id in ids if id != 'uncategorized'] - - order = Case(*[ - When(pk=pk, then=pos) - for pos, pk - in enumerate(ids) - ]) + ids = [s["id"] for s in self.structure] + uncategorized = "uncategorized" in ids + ids = [id for id in ids if id != "uncategorized"] + + order = Case(*[When(pk=pk, then=pos) for pos, pk in enumerate(ids)]) exportables = exportables.filter(pk__in=ids).order_by(order) for exportable in exportables: levels = ( # Custom levels provided by client - af_levels_map.get(str(exportable.pk)) or + af_levels_map.get(str(exportable.pk)) + or # Predefined levels available in server - exportable.data.get('report').get('levels') + exportable.data.get("report").get("levels") ) level_entries_map = {} valid_levels = [] - iterable_entries = entries[:Export.PREVIEW_ENTRY_SIZE] if self.is_preview else entries + iterable_entries = entries[: Export.PREVIEW_ENTRY_SIZE] if self.is_preview else entries for entry in iterable_entries: # TODO # Set entry.report_data to all exportdata for all exportable @@ -833,17 +794,18 @@ def add_entries(self, entries): if export_data: self._load_into_levels( - entry, export_data.data.get('report').get('keys'), - levels, level_entries_map, valid_levels, + entry, + export_data.data.get("report").get("keys"), + levels, + level_entries_map, + valid_levels, ) categorized_entry_processed += 1 - structures = self.structure and next(( - s.get('levels') for s in self.structure - if str(s['id']) == str(exportable.id) - ), None) - self._generate_for_levels(levels, level_entries_map, - valid_levels, structures) + structures = self.structure and next( + (s.get("levels") for s in self.structure if str(s["id"]) == str(exportable.id)), None + ) + self._generate_for_levels(levels, level_entries_map, valid_levels, structures) if uncategorized: self._generate_for_uncategorized(entries, categorized_entry_processed) @@ -863,7 +825,7 @@ def export(self, pdf=False): self.doc.add_paragraph().add_horizontal_line() self.doc.add_paragraph() - self.doc.add_heading('Bibliography', 1) + self.doc.add_heading("Bibliography", 1) self.doc.add_paragraph() for lead in leads: @@ -872,12 +834,12 @@ def export(self, pdf=False): para = self.doc.add_paragraph() author = lead.get_authors_display() - source = lead.get_source_display() or 'Missing source' + source = lead.get_source_display() or "Missing source" if author: - para.add_run(f'{author}.') - para.add_run(f' {source}.') - para.add_run(f' {lead.title}.') + para.add_run(f"{author}.") + para.add_run(f" {source}.") + para.add_run(f" {lead.title}.") if lead.published_on: para.add_run(f" {self.date_renderer(lead.published_on)}. ") @@ -886,12 +848,12 @@ def export(self, pdf=False): if url: para.add_hyperlink(url, url) else: - para.add_run('Missing url.') + para.add_run("Missing url.") if lead.confidentiality == Lead.Confidentiality.CONFIDENTIAL: - para.add_run(' (confidential)') + para.add_run(" (confidential)") elif lead.confidentiality == Lead.Confidentiality.RESTRICTED: - para.add_run(' (restricted)') + para.add_run(" (restricted)") self.doc.add_paragraph() # self.doc.add_page_break() @@ -900,11 +862,11 @@ def export(self, pdf=False): temp_doc = tempfile.NamedTemporaryFile(dir=settings.TEMP_DIR) self.doc.save_to_file(temp_doc) - filename = temp_doc.name.split('/')[-1] - temp_pdf = os.path.join(settings.TEMP_DIR, '{}.pdf'.format(filename)) + filename = temp_doc.name.split("/")[-1] + temp_pdf = os.path.join(settings.TEMP_DIR, "{}.pdf".format(filename)) - call(['libreoffice', '--headless', '--convert-to', 'pdf', temp_doc.name, '--outdir', settings.TEMP_DIR]) - file = File(open(temp_pdf, 'rb')) + call(["libreoffice", "--headless", "--convert-to", "pdf", temp_doc.name, "--outdir", settings.TEMP_DIR]) + file = File(open(temp_pdf, "rb")) # Cleanup os.unlink(temp_pdf) diff --git a/apps/export/enums.py b/apps/export/enums.py index fd17b40bca..c42a047edd 100644 --- a/apps/export/enums.py +++ b/apps/export/enums.py @@ -5,23 +5,23 @@ from .models import Export, GenericExport -ExportFormatEnum = convert_enum_to_graphene_enum(Export.Format, name='ExportFormatEnum') -ExportStatusEnum = convert_enum_to_graphene_enum(Export.Status, name='ExportStatusEnum') -ExportDataTypeEnum = convert_enum_to_graphene_enum(Export.DataType, name='ExportDataTypeEnum') -ExportExportTypeEnum = convert_enum_to_graphene_enum(Export.ExportType, name='ExportExportTypeEnum') +ExportFormatEnum = convert_enum_to_graphene_enum(Export.Format, name="ExportFormatEnum") +ExportStatusEnum = convert_enum_to_graphene_enum(Export.Status, name="ExportStatusEnum") +ExportDataTypeEnum = convert_enum_to_graphene_enum(Export.DataType, name="ExportDataTypeEnum") +ExportExportTypeEnum = convert_enum_to_graphene_enum(Export.ExportType, name="ExportExportTypeEnum") ExportExcelSelectedStaticColumnEnum = convert_enum_to_graphene_enum( Export.StaticColumn, - name='ExportExcelSelectedStaticColumnEnum', + name="ExportExcelSelectedStaticColumnEnum", ) -ExportDateFormatEnum = convert_enum_to_graphene_enum(Export.DateFormat, name='ExportDateFormatEnum') +ExportDateFormatEnum = convert_enum_to_graphene_enum(Export.DateFormat, name="ExportDateFormatEnum") ExportReportCitationStyleEnum = convert_enum_to_graphene_enum( Export.CitationStyle, - name='ExportReportCitationStyleEnum', + name="ExportReportCitationStyleEnum", ) -GenericExportFormatEnum = convert_enum_to_graphene_enum(GenericExport.Format, name='GenericExportFormatEnum') -GenericExportStatusEnum = convert_enum_to_graphene_enum(GenericExport.Status, name='GenericExportStatusEnum') -GenericExportDataTypeEnum = convert_enum_to_graphene_enum(GenericExport.DataType, name='GenericExportDataTypeEnum') +GenericExportFormatEnum = convert_enum_to_graphene_enum(GenericExport.Format, name="GenericExportFormatEnum") +GenericExportStatusEnum = convert_enum_to_graphene_enum(GenericExport.Status, name="GenericExportStatusEnum") +GenericExportDataTypeEnum = convert_enum_to_graphene_enum(GenericExport.DataType, name="GenericExportDataTypeEnum") enum_map = { # Need to pass model with abstract base class @@ -39,20 +39,22 @@ ) } -enum_map.update({ - get_enum_name_from_django_field( - None, - field_name='static_column', - serializer_name='ExportExcelSelectedColumnSerializer', - ): ExportExcelSelectedStaticColumnEnum, - get_enum_name_from_django_field( - None, - field_name='date_format', - serializer_name='ExportExtraOptionsSerializer', - ): ExportDateFormatEnum, - get_enum_name_from_django_field( - None, - field_name='report_citation_style', - serializer_name='ExportExtraOptionsSerializer', - ): ExportReportCitationStyleEnum, -}) +enum_map.update( + { + get_enum_name_from_django_field( + None, + field_name="static_column", + serializer_name="ExportExcelSelectedColumnSerializer", + ): ExportExcelSelectedStaticColumnEnum, + get_enum_name_from_django_field( + None, + field_name="date_format", + serializer_name="ExportExtraOptionsSerializer", + ): ExportDateFormatEnum, + get_enum_name_from_django_field( + None, + field_name="report_citation_style", + serializer_name="ExportExtraOptionsSerializer", + ): ExportReportCitationStyleEnum, + } +) diff --git a/apps/export/exporters.py b/apps/export/exporters.py index f85f327b96..4251bd9d89 100644 --- a/apps/export/exporters.py +++ b/apps/export/exporters.py @@ -1,14 +1,9 @@ from utils.files import generate_json_file_for_upload - -DOCX_MIME_TYPE = \ - 'application/vnd.openxmlformats-officedocument.wordprocessingml.document' -PDF_MIME_TYPE = \ - 'application/pdf' -EXCEL_MIME_TYPE = \ - 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' -JSON_MIME_TYPE = \ - 'application/json' +DOCX_MIME_TYPE = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" +PDF_MIME_TYPE = "application/pdf" +EXCEL_MIME_TYPE = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" +JSON_MIME_TYPE = "application/json" class Exporter: diff --git a/apps/export/factories.py b/apps/export/factories.py index ca182b15c9..ea18829d05 100644 --- a/apps/export/factories.py +++ b/apps/export/factories.py @@ -5,7 +5,7 @@ class ExportFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'Export-{n}') + title = factory.Sequence(lambda n: f"Export-{n}") type = factory.fuzzy.FuzzyChoice(Export.DataType) class Meta: diff --git a/apps/export/filter_set.py b/apps/export/filter_set.py index aa5931b8c4..103b17755d 100644 --- a/apps/export/filter_set.py +++ b/apps/export/filter_set.py @@ -1,19 +1,15 @@ import django_filters - from project.models import Project + from utils.graphene.filters import ( - MultipleInputFilter, DateTimeFilter, DateTimeGteFilter, DateTimeLteFilter, + MultipleInputFilter, ) +from .enums import ExportDataTypeEnum, ExportFormatEnum, ExportStatusEnum from .models import Export -from .enums import ( - ExportDataTypeEnum, - ExportFormatEnum, - ExportStatusEnum, -) class ExportFilterSet(django_filters.rest_framework.FilterSet): @@ -21,41 +17,28 @@ class ExportFilterSet(django_filters.rest_framework.FilterSet): Export filter set Also make most fields filerable by multiple values """ + project = django_filters.ModelChoiceFilter( queryset=Project.objects.all(), - field_name='project', + field_name="project", ) ordering = django_filters.CharFilter( - method='ordering_filter', + method="ordering_filter", ) - status = django_filters.MultipleChoiceFilter( - choices=Export.Status.choices, - widget=django_filters.widgets.CSVWidget - ) + status = django_filters.MultipleChoiceFilter(choices=Export.Status.choices, widget=django_filters.widgets.CSVWidget) - type = django_filters.MultipleChoiceFilter( - choices=Export.DataType.choices, - widget=django_filters.widgets.CSVWidget - ) - exported_at__lt = django_filters.DateFilter( - field_name='exported_at', - lookup_expr='lte', - input_formats=['%Y-%m-%d%z'] - ) - exported_at__gte = django_filters.DateFilter( - field_name='exported_at', - lookup_expr='gte', - input_formats=['%Y-%m-%d%z'] - ) + type = django_filters.MultipleChoiceFilter(choices=Export.DataType.choices, widget=django_filters.widgets.CSVWidget) + exported_at__lt = django_filters.DateFilter(field_name="exported_at", lookup_expr="lte", input_formats=["%Y-%m-%d%z"]) + exported_at__gte = django_filters.DateFilter(field_name="exported_at", lookup_expr="gte", input_formats=["%Y-%m-%d%z"]) class Meta: model = Export - fields = ['is_archived'] + fields = ["is_archived"] def ordering_filter(self, qs, name, value): - orderings = [x.strip() for x in value.split(',') if x.strip()] + orderings = [x.strip() for x in value.split(",") if x.strip()] for ordering in orderings: qs = qs.order_by(ordering) @@ -67,10 +50,10 @@ class ExportGQLFilterSet(django_filters.rest_framework.FilterSet): format = MultipleInputFilter(ExportFormatEnum) status = MultipleInputFilter(ExportStatusEnum) - search = django_filters.CharFilter(field_name='title', lookup_expr='icontains') + search = django_filters.CharFilter(field_name="title", lookup_expr="icontains") exported_at = DateTimeFilter() - exported_at_gte = DateTimeGteFilter(field_name='exported_at') - exported_at_lte = DateTimeLteFilter(field_name='exported_at') + exported_at_gte = DateTimeGteFilter(field_name="exported_at") + exported_at_lte = DateTimeLteFilter(field_name="exported_at") class Meta: model = Export diff --git a/apps/export/formats/docx.py b/apps/export/formats/docx.py index ac9203568f..de6ddf7fee 100644 --- a/apps/export/formats/docx.py +++ b/apps/export/formats/docx.py @@ -1,23 +1,20 @@ -import docx -import requests +import base64 import io +import logging import re import tempfile -import base64 -import logging from uuid import uuid4 - +import docx +import requests from docx.enum.dml import MSO_THEME_COLOR_INDEX from docx.oxml import OxmlElement, oxml_parser from docx.oxml.ns import qn from docx.shared import Pt, RGBColor - from PIL import Image from utils.common import get_valid_xml_string - logger = logging.getLogger(__name__) @@ -61,6 +58,7 @@ class Run: """ Single run inside a paragraph """ + def __init__(self, ref): self.ref = ref @@ -72,19 +70,19 @@ def add_image(self, image): try: if image and len(image) > 0: fimage = tempfile.NamedTemporaryFile() - if re.search(r'http[s]?://', image): + if re.search(r"http[s]?://", image): image = requests.get(image, stream=True, timeout=2) _write_file(image, fimage) else: - image = base64.b64decode(image.split(',')[1]) + image = base64.b64decode(image.split(",")[1]) fimage.write(image) self.ref.add_picture(fimage) except Exception: - self.add_text('Invalid Image') + self.add_text("Invalid Image") def add_font_color(self, hex_color_string=None): - hex_color_string = hex_color_string or '#000000' - if '#' in hex_color_string: + hex_color_string = hex_color_string or "#000000" + if "#" in hex_color_string: hex_color_string = hex_color_string[1:] color = RGBColor.from_string(hex_color_string) self.ref.font.color.rgb = color @@ -94,20 +92,20 @@ def add_shading(self, hex_color_string=None): XML representation """ - hex_color_string = hex_color_string or '#888888' - if '#' in hex_color_string: + hex_color_string = hex_color_string or "#888888" + if "#" in hex_color_string: hex_color_string = hex_color_string[1:] rPr = self.ref._element.get_or_add_rPr() - ele = OxmlElement('w:shd') - ele.set(qn('w:fill'), hex_color_string) + ele = OxmlElement("w:shd") + ele.set(qn("w:fill"), hex_color_string) rPr.append(ele) def add_inline_image(self, image, width, height): inline = self.ref.part.new_pic_inline(image, width, height) # Remove left/right spacing - inline.set('distL', '0') - inline.set('distR', '0') + inline.set("distL", "0") + inline.set("distR", "0") return self.ref._r.add_drawing(inline) def add_oval_shape(self, fill_hex_color=None): @@ -115,30 +113,22 @@ def add_oval_shape(self, fill_hex_color=None): https://python-docx.readthedocs.io/en/latest/user/shapes.html https://docs.microsoft.com/en-us/windows/win32/vml/web-workshop---specs---standards----how-to-use-vml-on-web-pages """ - fill_hex_color = fill_hex_color or '#ffffff' + fill_hex_color = fill_hex_color or "#ffffff" color = fill_hex_color - if '#' != color[0]: - color = '#' + color + if "#" != color[0]: + color = "#" + color - pict = OxmlElement('w:pict') - nsmap = dict( - v='urn:schemas-microsoft-com:vml' - ) + pict = OxmlElement("w:pict") + nsmap = dict(v="urn:schemas-microsoft-com:vml") oval_attrs = dict( id=str(uuid4()), - style='width:12pt;height:12pt;z-index:-251658240;mso-position-vertical:top;mso-position-horizontal:left', + style="width:12pt;height:12pt;z-index:-251658240;mso-position-vertical:top;mso-position-horizontal:left", fillcolor=color, ) - oval = oxml_parser.makeelement('{%s}%s' % (nsmap['v'], 'oval'), - attrib=oval_attrs, nsmap=nsmap) + oval = oxml_parser.makeelement("{%s}%s" % (nsmap["v"], "oval"), attrib=oval_attrs, nsmap=nsmap) - border_attrs = dict( - color='gray', - joinstyle='round', - endcap='flat' - ) - stroke = oxml_parser.makeelement('{%s}%s' % (nsmap['v'], 'stroke'), - attrib=border_attrs, nsmap=nsmap) + border_attrs = dict(color="gray", joinstyle="round", endcap="flat") + stroke = oxml_parser.makeelement("{%s}%s" % (nsmap["v"], "stroke"), attrib=border_attrs, nsmap=nsmap) oval.append(stroke) pict.append(oval) self.ref._element.append(pict) @@ -149,6 +139,7 @@ class Paragraph: One paragraph: supports normal text runs, hyperlinks, horizontal lines. """ + def __init__(self, ref): self.ref = ref @@ -165,11 +156,11 @@ def add_hyperlink(self, url, text): is_external=True, ) - hyperlink = docx.oxml.shared.OxmlElement('w:hyperlink') - hyperlink.set(docx.oxml.shared.qn('r:id'), r_id) + hyperlink = docx.oxml.shared.OxmlElement("w:hyperlink") + hyperlink.set(docx.oxml.shared.qn("r:id"), r_id) - new_run = docx.oxml.shared.OxmlElement('w:r') - r_pr = docx.oxml.shared.OxmlElement('w:rPr') + new_run = docx.oxml.shared.OxmlElement("w:r") + r_pr = docx.oxml.shared.OxmlElement("w:rPr") new_run.append(r_pr) new_run.text = get_valid_xml_string(text) @@ -186,31 +177,53 @@ def add_hyperlink(self, url, text): def add_horizontal_line(self): p = self.ref._p p_pr = p.get_or_add_pPr() - p_bdr = OxmlElement('w:pBdr') - - _insert_element_before(p_pr, p_bdr, successors=( - 'w:shd', 'w:tabs', 'w:suppressAutoHyphens', 'w:kinsoku', - 'w:wordWrap', 'w:overflowPunct', 'w:topLinePunct', - 'w:autoSpaceDE', 'w:autoSpaceDN', 'w:bidi', 'w:adjustRightInd', - 'w:snapToGrid', 'w:spacing', 'w:ind', 'w:contextualSpacing', - 'w:mirrorIndents', 'w:suppressOverlap', 'w:jc', - 'w:textDirection', 'w:textAlignment', 'w:textboxTightWrap', - 'w:outlineLvl', 'w:divId', 'w:cnfStyle', 'w:rPr', 'w:sectPr', - 'w:pPrChange' - )) - - bottom = OxmlElement('w:bottom') - bottom.set(qn('w:val'), 'single') - bottom.set(qn('w:sz'), '6') - bottom.set(qn('w:space'), '1') - bottom.set(qn('w:color'), 'auto') + p_bdr = OxmlElement("w:pBdr") + + _insert_element_before( + p_pr, + p_bdr, + successors=( + "w:shd", + "w:tabs", + "w:suppressAutoHyphens", + "w:kinsoku", + "w:wordWrap", + "w:overflowPunct", + "w:topLinePunct", + "w:autoSpaceDE", + "w:autoSpaceDN", + "w:bidi", + "w:adjustRightInd", + "w:snapToGrid", + "w:spacing", + "w:ind", + "w:contextualSpacing", + "w:mirrorIndents", + "w:suppressOverlap", + "w:jc", + "w:textDirection", + "w:textAlignment", + "w:textboxTightWrap", + "w:outlineLvl", + "w:divId", + "w:cnfStyle", + "w:rPr", + "w:sectPr", + "w:pPrChange", + ), + ) + + bottom = OxmlElement("w:bottom") + bottom.set(qn("w:val"), "single") + bottom.set(qn("w:sz"), "6") + bottom.set(qn("w:space"), "1") + bottom.set(qn("w:color"), "auto") p_bdr.append(bottom) return self def justify(self): - self.ref.paragraph_format.alignment = \ - docx.enum.text.WD_ALIGN_PARAGRAPH.JUSTIFY + self.ref.paragraph_format.alignment = docx.enum.text.WD_ALIGN_PARAGRAPH.JUSTIFY return self def delete(self): @@ -223,7 +236,7 @@ def add_shaded_text(self, text, color): run.add_shading(color) def add_oval_shape(self, color): - run = self.add_run(' ') + run = self.add_run(" ") run.add_oval_shape(color) def add_next_paragraph(self, other): @@ -235,6 +248,7 @@ class Document: """ A docx document representation """ + def __init__(self, template=None): self.doc = docx.Document(template) @@ -245,27 +259,20 @@ def add_image(self, image): try: sec = self.doc.sections[-1] try: - cols = int( - sec._sectPr.xpath('./w:cols')[0].get(qn('w:num')) - ) - width = ( - (sec.page_width / cols) - - (sec.right_margin + sec.left_margin) - ) + cols = int(sec._sectPr.xpath("./w:cols")[0].get(qn("w:num"))) + width = (sec.page_width / cols) - (sec.right_margin + sec.left_margin) except Exception: - width = ( - sec.page_width - (sec.right_margin + sec.left_margin) - ) + width = sec.page_width - (sec.right_margin + sec.left_margin) - if hasattr(image, 'read'): + if hasattr(image, "read"): fimage = image elif image and len(image): fimage = tempfile.NamedTemporaryFile() - if re.search(r'http[s]?://', image): + if re.search(r"http[s]?://", image): image = requests.get(image, stream=True, timeout=2) _write_file(image, fimage) else: - image = base64.b64decode(image.split(',')[1]) + image = base64.b64decode(image.split(",")[1]) fimage.write(image) image_width, _ = Image.open(fimage).size @@ -278,9 +285,9 @@ def add_image(self, image): self.doc.paragraphs[-1].alignment = docx.enum.text.WD_ALIGN_PARAGRAPH.CENTER return self except Exception: - self.doc.add_paragraph('Invalid Image') + self.doc.add_paragraph("Invalid Image") logger.error( - 'export.formats.docx Add Image Error!!', + "export.formats.docx Add Image Error!!", exc_info=True, ) return self diff --git a/apps/export/formats/xlsx.py b/apps/export/formats/xlsx.py index f6afe716d7..ff91421284 100644 --- a/apps/export/formats/xlsx.py +++ b/apps/export/formats/xlsx.py @@ -1,16 +1,10 @@ from collections import OrderedDict - from openpyxl import Workbook from openpyxl.utils import get_column_letter from openpyxl.writer.excel import save_virtual_workbook -from utils.common import ( - get_valid_xml_string, - parse_date, - parse_time, - parse_number, -) +from utils.common import get_valid_xml_string, parse_date, parse_number, parse_time def xstr(value): @@ -23,6 +17,7 @@ class WorkBook: """ An xlsx workbook """ + def __init__(self): self.wb = Workbook() @@ -37,15 +32,15 @@ def save(self): COL_TYPES = { - 'date': 'dd-mm-yyyy', - 'time': 'HH:MM', - 'number': '', + "date": "dd-mm-yyyy", + "time": "HH:MM", + "number": "", } TYPE_CONVERTERS = { - 'date': lambda x: parse_date(x) or x, - 'time': lambda x: parse_time(x) or x, - 'number': lambda x: parse_number(x) or x, + "date": lambda x: parse_date(x) or x, + "time": lambda x: parse_time(x) or x, + "number": lambda x: parse_number(x) or x, } @@ -53,6 +48,7 @@ class WorkSheet: """ A worksheet inside a workbook """ + def __init__(self, ws): self.ws = ws @@ -78,9 +74,9 @@ def _set_cell_type(self, cell, col_type): def set_col_types(self, col_types): for col_index, col_type in col_types.items(): for cell_t in self.ws.iter_rows( - min_row=2, - min_col=col_index + 1, - max_col=col_index + 1, + min_row=2, + min_col=col_index + 1, + max_col=col_index + 1, ): if len(cell_t) < 1: continue @@ -92,6 +88,7 @@ class RowsBuilder: """ Rows builder to build rows that permute with new rows """ + def __init__(self, split_sheet=None, group_sheet=None, split=True): self.rows = [[]] self.group_rows = [] @@ -121,7 +118,7 @@ def add_rows_of_values(self, rows): num = len(values) if num == 0: - self.add_value('') + self.add_value("") return self if num == 1: @@ -141,7 +138,7 @@ def add_rows_of_values(self, rows): for j in range(0, len(oldrows)): self.rows[i * len(oldrows) + j].append(values[i]) - self.group_rows.append(', '.join(values)) + self.group_rows.append(", ".join(values)) return self @@ -156,7 +153,7 @@ def add_rows_of_value_lists(self, rows, col_span=1): num = len(values) if num == 0: - self.add_value_list([''] * col_span) + self.add_value_list([""] * col_span) return self if num == 1: @@ -181,10 +178,12 @@ def add_rows_of_value_lists(self, rows, col_span=1): # Convert each zipped to list and convert overall to list as well for column in list(map(list, zip(*values))): # Make sure each column only contains unique values - self.group_rows.append(', '.join( - # sorted(list(dict.fromkeys(column))) - list(OrderedDict.fromkeys(filter(lambda x: x not in [None, ''], column))) - )) + self.group_rows.append( + ", ".join( + # sorted(list(dict.fromkeys(column))) + list(OrderedDict.fromkeys(filter(lambda x: x not in [None, ""], column))) + ) + ) return self diff --git a/apps/export/mime_types.py b/apps/export/mime_types.py index d955e22129..705c6fba52 100644 --- a/apps/export/mime_types.py +++ b/apps/export/mime_types.py @@ -1,9 +1,5 @@ -DOCX_MIME_TYPE = \ - 'application/vnd.openxmlformats-officedocument.wordprocessingml.document' -PDF_MIME_TYPE = \ - 'application/pdf' -EXCEL_MIME_TYPE = \ - 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' -CSV_MIME_TYPE = 'text/csv' -JSON_MIME_TYPE = \ - 'application/json' +DOCX_MIME_TYPE = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" +PDF_MIME_TYPE = "application/pdf" +EXCEL_MIME_TYPE = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" +CSV_MIME_TYPE = "text/csv" +JSON_MIME_TYPE = "application/json" diff --git a/apps/export/models.py b/apps/export/models.py index 25bf7680dc..06da8642dc 100644 --- a/apps/export/models.py +++ b/apps/export/models.py @@ -1,15 +1,12 @@ -import typing import datetime +import typing -from django.db import models -from django.core.cache import cache +from analysis.models import Analysis from django.contrib.auth.models import User +from django.core.cache import cache +from django.db import models from django.utils import timezone from django.utils.crypto import get_random_string - -from deep.caches import CacheKey -from deep.celery import app as celery_app -from project.models import Project from export.mime_types import ( CSV_MIME_TYPE, DOCX_MIME_TYPE, @@ -17,31 +14,34 @@ JSON_MIME_TYPE, PDF_MIME_TYPE, ) -from analysis.models import Analysis +from project.models import Project + +from deep.caches import CacheKey +from deep.celery import app as celery_app def export_upload_to(instance, filename: str) -> str: random_string = get_random_string(length=10) - prefix = 'export' + prefix = "export" if isinstance(instance, GenericExport): - prefix = 'global-export' - return f'{prefix}/{random_string}/{filename}' + prefix = "global-export" + return f"{prefix}/{random_string}/{filename}" class ExportBaseModel(models.Model): class Status(models.TextChoices): - PENDING = 'pending', 'Pending' - STARTED = 'started', 'Started' - SUCCESS = 'success', 'Success' - FAILURE = 'failure', 'Failure' - CANCELED = 'canceled', 'Canceled' + PENDING = "pending", "Pending" + STARTED = "started", "Started" + SUCCESS = "success", "Success" + FAILURE = "failure", "Failure" + CANCELED = "canceled", "Canceled" class Format(models.TextChoices): - CSV = 'csv', 'csv' - XLSX = 'xlsx', 'xlsx' - DOCX = 'docx', 'docx' - PDF = 'pdf', 'pdf' - JSON = 'json', 'json' + CSV = "csv", "csv" + XLSX = "xlsx", "xlsx" + DOCX = "docx", "docx" + PDF = "pdf", "pdf" + JSON = "json", "json" # Mime types MIME_TYPE_MAP = { @@ -51,12 +51,12 @@ class Format(models.TextChoices): Format.PDF: PDF_MIME_TYPE, Format.JSON: JSON_MIME_TYPE, } - DEFAULT_MIME_TYPE = 'application/octet-stream' + DEFAULT_MIME_TYPE = "application/octet-stream" # Used to validate which combination is supported and provide default title DEFAULT_TITLE_LABEL = {} - CELERY_TASK_CACHE_KEY = 'N/A' + CELERY_TASK_CACHE_KEY = "N/A" title = models.CharField(max_length=255) @@ -79,10 +79,7 @@ def __str__(self): @classmethod def get_for(cls, user): - return cls.objects.filter( - exported_by=user, - is_deleted=False - ).distinct() + return cls.objects.filter(exported_by=user, is_deleted=False).distinct() def set_task_id(self, async_id): # Defined timeout is arbitrary now. @@ -105,7 +102,7 @@ def cancel(self, commit=True): celery_app.control.revoke(self.get_task_id(clear=True), terminate=True) self.status = self.Status.CANCELED if commit: - self.save(update_fields=('status',)) + self.save(update_fields=("status",)) class Export(ExportBaseModel): @@ -115,46 +112,47 @@ class Export(ExportBaseModel): Represents an exported file along with few other attributes Scoped by a project """ + Format = ExportBaseModel.Format class DataType(models.TextChoices): - ENTRIES = 'entries', 'Entries' - ASSESSMENTS = 'assessments', 'Assessments' - PLANNED_ASSESSMENTS = 'planned_assessments', 'Planned Assessments' - ANALYSES = 'analyses', 'Analysis' + ENTRIES = "entries", "Entries" + ASSESSMENTS = "assessments", "Assessments" + PLANNED_ASSESSMENTS = "planned_assessments", "Planned Assessments" + ANALYSES = "analyses", "Analysis" class ExportType(models.TextChoices): - EXCEL = 'excel', 'Excel' - REPORT = 'report', 'Report' - JSON = 'json', 'Json' + EXCEL = "excel", "Excel" + REPORT = "report", "Report" + JSON = "json", "Json" # Used by extra options class StaticColumn(models.TextChoices): - LEAD_PUBLISHED_ON = 'lead_published_on', 'Date of Source Publication' - ENTRY_CREATED_BY = 'entry_created_by', 'Imported By' - ENTRY_CREATED_AT = 'entry_created_at', 'Date Imported' - ENTRY_CONTROL_STATUS = 'entry_control_status', 'Verification Status' - LEAD_ID = 'lead_id', 'Source Id' - LEAD_TITLE = 'lead_title', 'Source Title' - LEAD_URL = 'lead_url', 'Source URL' - LEAD_PAGE_COUNT = 'lead_page_count', 'Page Count' - LEAD_ORGANIZATION_TYPE_AUTHOR = 'lead_organization_type_author', 'Authoring Organizations Type' - LEAD_ORGANIZATION_AUTHOR = 'lead_organization_author', 'Author' - LEAD_ORGANIZATION_SOURCE = 'lead_organization_source', 'Publisher' - LEAD_PRIORITY = 'lead_priority', 'Source Priority' - LEAD_ASSIGNEE = 'lead_assignee', 'Assignee' - ENTRY_ID = 'entry_id', 'Entry Id' - LEAD_ENTRY_ID = 'lead_entry_id', 'Source-Entry Id' - ENTRY_EXCERPT = 'entry_excerpt', 'Modified Excerpt, Original Excerpt' + LEAD_PUBLISHED_ON = "lead_published_on", "Date of Source Publication" + ENTRY_CREATED_BY = "entry_created_by", "Imported By" + ENTRY_CREATED_AT = "entry_created_at", "Date Imported" + ENTRY_CONTROL_STATUS = "entry_control_status", "Verification Status" + LEAD_ID = "lead_id", "Source Id" + LEAD_TITLE = "lead_title", "Source Title" + LEAD_URL = "lead_url", "Source URL" + LEAD_PAGE_COUNT = "lead_page_count", "Page Count" + LEAD_ORGANIZATION_TYPE_AUTHOR = "lead_organization_type_author", "Authoring Organizations Type" + LEAD_ORGANIZATION_AUTHOR = "lead_organization_author", "Author" + LEAD_ORGANIZATION_SOURCE = "lead_organization_source", "Publisher" + LEAD_PRIORITY = "lead_priority", "Source Priority" + LEAD_ASSIGNEE = "lead_assignee", "Assignee" + ENTRY_ID = "entry_id", "Entry Id" + LEAD_ENTRY_ID = "lead_entry_id", "Source-Entry Id" + ENTRY_EXCERPT = "entry_excerpt", "Modified Excerpt, Original Excerpt" # Used by extra options for Report class CitationStyle(models.IntegerChoices): - DEFAULT = 1, 'Default' - STYLE_1 = 2, 'Sample 1' # TODO: Update naming + DEFAULT = 1, "Default" + STYLE_1 = 2, "Sample 1" # TODO: Update naming __description__ = { - DEFAULT: 'Entry excerpt. (Author[link], Publisher, Published Date)', - STYLE_1: 'Entry excerpt (Author[link] Published Date).', + DEFAULT: "Entry excerpt. (Author[link], Publisher, Published Date)", + STYLE_1: "Entry excerpt (Author[link] Published Date).", } # Used by extra options @@ -163,25 +161,25 @@ class CitationStyle(models.IntegerChoices): # https://github.com/toggle-corp/fujs/blob/3b1b64199dad249c81d57fc4d26ed800bdccca13/src/date.ts#L77 # TODO: Add a unit test to make sure all label are valid class DateFormat(models.TextChoices): - DEFAULT = '%d-%m-%Y', 'dd-MM-yyyy' - FORMAT_1 = '%d/%m/%Y', 'dd/MM/yyyy' + DEFAULT = "%d-%m-%Y", "dd-MM-yyyy" + FORMAT_1 = "%d/%m/%Y", "dd/MM/yyyy" __description__ = { - DEFAULT: '23-11-2021', - FORMAT_1: '23/11/2021', + DEFAULT: "23-11-2021", + FORMAT_1: "23/11/2021", } # NOTE: Also used to validate which combination is supported DEFAULT_TITLE_LABEL = { - (DataType.ENTRIES, ExportType.EXCEL, Format.XLSX): 'Entries Excel Export', - (DataType.ENTRIES, ExportType.REPORT, Format.DOCX): 'Entries General Export', - (DataType.ENTRIES, ExportType.REPORT, Format.PDF): 'Entries General Export', - (DataType.ENTRIES, ExportType.JSON, Format.JSON): 'Entries JSON Export', - (DataType.ASSESSMENTS, ExportType.EXCEL, Format.XLSX): 'Assessments Excel Export', - (DataType.ASSESSMENTS, ExportType.JSON, Format.JSON): 'Assessments JSON Export', - (DataType.PLANNED_ASSESSMENTS, ExportType.EXCEL, Format.XLSX): 'Planned Assessments Excel Export', - (DataType.PLANNED_ASSESSMENTS, ExportType.JSON, Format.JSON): 'Planned Assessments JSON Export', - (DataType.ANALYSES, ExportType.EXCEL, Format.XLSX): 'Analysis Excel Export', + (DataType.ENTRIES, ExportType.EXCEL, Format.XLSX): "Entries Excel Export", + (DataType.ENTRIES, ExportType.REPORT, Format.DOCX): "Entries General Export", + (DataType.ENTRIES, ExportType.REPORT, Format.PDF): "Entries General Export", + (DataType.ENTRIES, ExportType.JSON, Format.JSON): "Entries JSON Export", + (DataType.ASSESSMENTS, ExportType.EXCEL, Format.XLSX): "Assessments Excel Export", + (DataType.ASSESSMENTS, ExportType.JSON, Format.JSON): "Assessments JSON Export", + (DataType.PLANNED_ASSESSMENTS, ExportType.EXCEL, Format.XLSX): "Planned Assessments Excel Export", + (DataType.PLANNED_ASSESSMENTS, ExportType.JSON, Format.JSON): "Planned Assessments JSON Export", + (DataType.ANALYSES, ExportType.EXCEL, Format.XLSX): "Analysis Excel Export", } CELERY_TASK_CACHE_KEY = CacheKey.EXPORT_TASK_CACHE_KEY_FORMAT @@ -206,7 +204,9 @@ class DateFormat(models.TextChoices): # used for analysis export analysis = models.ForeignKey( - Analysis, null=True, blank=True, + Analysis, + null=True, + blank=True, verbose_name="analysis", on_delete=models.SET_NULL, ) @@ -214,16 +214,13 @@ class DateFormat(models.TextChoices): @classmethod def generate_title(cls, data_type, export_type, export_format): file_label = cls.DEFAULT_TITLE_LABEL[(data_type, export_type, export_format)] - time_str = timezone.now().strftime('%Y%m%d') - return f'{time_str} DEEP {file_label}' + time_str = timezone.now().strftime("%Y%m%d") + return f"{time_str} DEEP {file_label}" @classmethod def get_date_renderer(cls, date_format: DateFormat) -> typing.Callable: - def custom_format(d, fallback: typing.Optional[str] = ''): - if d and ( - isinstance(d, datetime.datetime) or - isinstance(d, datetime.date) - ): + def custom_format(d, fallback: typing.Optional[str] = ""): + if d and (isinstance(d, datetime.datetime) or isinstance(d, datetime.date)): return d.strftime(date_format) if date_format else fallback return fallback @@ -238,15 +235,16 @@ class GenericExport(ExportBaseModel): """ Async export tasks not scoped by a project """ + Format = ExportBaseModel.Format class DataType(models.TextChoices): - PROJECTS_STATS = 'projects_stats', 'Projects Stats' + PROJECTS_STATS = "projects_stats", "Projects Stats" CELERY_TASK_CACHE_KEY = CacheKey.GENERIC_EXPORT_TASK_CACHE_KEY_FORMAT DEFAULT_TITLE_LABEL = { - (DataType.PROJECTS_STATS, Format.CSV): 'Projects Stats', + (DataType.PROJECTS_STATS, Format.CSV): "Projects Stats", } type = models.CharField(max_length=99, choices=DataType.choices) @@ -257,8 +255,8 @@ class DataType(models.TextChoices): @classmethod def generate_title(cls, data_type, export_format): file_label = cls.DEFAULT_TITLE_LABEL[(data_type, export_format)] - time_str = timezone.now().strftime('%Y%m%d') - return f'{time_str} Generic DEEP {file_label}' + time_str = timezone.now().strftime("%Y%m%d") + return f"{time_str} Generic DEEP {file_label}" def save(self, *args, **kwargs): self.title = self.title or self.generate_title(self.type, self.format) diff --git a/apps/export/mutation.py b/apps/export/mutation.py index df30c6ded1..31604c5aaf 100644 --- a/apps/export/mutation.py +++ b/apps/export/mutation.py @@ -1,11 +1,11 @@ import graphene +from deep.permissions import ProjectPermissions as PP from utils.graphene.mutation import ( - generate_input_type_for_serializer, GrapheneMutation, PsGrapheneMutation, + generate_input_type_for_serializer, ) -from deep.permissions import ProjectPermissions as PP from .models import Export from .schema import ( @@ -20,33 +20,32 @@ UserGenericExportCreateGqlSerializer, ) - ExportCreateInputType = generate_input_type_for_serializer( - 'ExportCreateInputType', + "ExportCreateInputType", serializer_class=UserExportCreateGqlSerializer, ) ExportUpdateInputType = generate_input_type_for_serializer( - 'ExportUpdateInputType', + "ExportUpdateInputType", serializer_class=UserExportUpdateGqlSerializer, partial=True, ) GenericExportCreateInputType = generate_input_type_for_serializer( - 'GenericExportCreateInputType', + "GenericExportCreateInputType", serializer_class=UserGenericExportCreateGqlSerializer, ) -class UserExportMutationMixin(): +class UserExportMutationMixin: @classmethod def filter_queryset(cls, _, info): return get_export_qs(info) -class UserGenericExportMutationMixin(): +class UserGenericExportMutationMixin: @classmethod def filter_queryset(cls, _, info): return get_generic_export_qs(info) @@ -55,6 +54,7 @@ def filter_queryset(cls, _, info): class CreateUserExport(PsGrapheneMutation): class Arguments: data = ExportCreateInputType(required=True) + model = Export serializer_class = UserExportCreateGqlSerializer result = graphene.Field(UserExportType) @@ -65,6 +65,7 @@ class UpdateUserExport(UserExportMutationMixin, PsGrapheneMutation): class Arguments: id = graphene.ID(required=True) data = ExportUpdateInputType(required=True) + model = Export serializer_class = UserExportUpdateGqlSerializer result = graphene.Field(UserExportType) @@ -74,6 +75,7 @@ class Arguments: class CancelUserExport(UserExportMutationMixin, PsGrapheneMutation): class Arguments: id = graphene.ID(required=True) + model = Export result = graphene.Field(UserExportType) permissions = [PP.Permission.CREATE_EXPORT] @@ -90,6 +92,7 @@ def perform_mutate(cls, root, info, **kwargs): class DeleteUserExport(UserExportMutationMixin, PsGrapheneMutation): class Arguments: id = graphene.ID(required=True) + model = Export result = graphene.Field(UserExportType) permissions = [PP.Permission.CREATE_EXPORT] @@ -101,7 +104,12 @@ def perform_mutate(cls, root, info, **kwargs): return cls(result=export, errors=errors, ok=True) export.cancel(commit=False) export.is_deleted = True # Soft delete - export.save(update_fields=('status', 'is_deleted',)) + export.save( + update_fields=( + "status", + "is_deleted", + ) + ) return cls(result=export, errors=None, ok=True) @@ -123,6 +131,7 @@ def check_permissions(cls, *args, **_): class CancelUserGenericExport(UserGenericExportMutationMixin, GrapheneMutation): class Arguments: id = graphene.ID(required=True) + model = Export result = graphene.Field(UserGenericExportType) @@ -139,13 +148,13 @@ def perform_mutate(cls, root, info, **kwargs): return cls(result=export, errors=None, ok=True) -class ProjectMutation(): +class ProjectMutation: export_create = CreateUserExport.Field() export_update = UpdateUserExport.Field() export_cancel = CancelUserExport.Field() export_delete = DeleteUserExport.Field() -class Mutation(): +class Mutation: generic_export_create = CreateUserGenericExport.Field() generic_export_cancel = CancelUserGenericExport.Field() diff --git a/apps/export/schema.py b/apps/export/schema.py index cc66fc8edb..fb18bc28a0 100644 --- a/apps/export/schema.py +++ b/apps/export/schema.py @@ -1,32 +1,30 @@ import graphene - from django.db.models import QuerySet -from graphene_django import DjangoObjectType from graphene.types.generic import GenericScalar +from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField, PageGraphqlPagination +from lead.filter_set import LeadsFilterDataType +from lead.schema import LeadFilterDataType, get_lead_filter_data from deep.serializers import URLCachedFileField -from utils.graphene.types import CustomDjangoListObjectType, FileFieldType -from utils.graphene.fields import DjangoPaginatedListObjectField, generate_type_for_serializer - -from lead.schema import ( - LeadFilterDataType, - get_lead_filter_data, +from utils.graphene.fields import ( + DjangoPaginatedListObjectField, + generate_type_for_serializer, ) +from utils.graphene.types import CustomDjangoListObjectType, FileFieldType -from lead.filter_set import LeadsFilterDataType -from .serializers import ExportExtraOptionsSerializer -from .models import Export, GenericExport -from .filter_set import ExportGQLFilterSet from .enums import ( ExportDataTypeEnum, + ExportExportTypeEnum, ExportFormatEnum, ExportStatusEnum, - ExportExportTypeEnum, + GenericExportDataTypeEnum, GenericExportFormatEnum, GenericExportStatusEnum, - GenericExportDataTypeEnum, ) +from .filter_set import ExportGQLFilterSet +from .models import Export, GenericExport +from .serializers import ExportExtraOptionsSerializer def get_export_qs(info): @@ -42,7 +40,7 @@ def get_generic_export_qs(info): ExportExtraOptionsType = generate_type_for_serializer( - 'ExportExtraOptionsType', + "ExportExtraOptionsType", serializer_class=ExportExtraOptionsSerializer, ) @@ -51,13 +49,21 @@ class UserExportType(DjangoObjectType): class Meta: model = Export only_fields = ( - 'id', 'project', 'is_preview', 'title', - 'mime_type', 'extra_options', 'exported_by', - 'exported_at', 'started_at', 'ended_at', 'is_archived', - 'analysis', + "id", + "project", + "is_preview", + "title", + "mime_type", + "extra_options", + "exported_by", + "exported_at", + "started_at", + "ended_at", + "is_archived", + "analysis", ) - project = graphene.ID(source='project_id') + project = graphene.ID(source="project_id") format = graphene.Field(graphene.NonNull(ExportFormatEnum)) type = graphene.Field(graphene.NonNull(ExportDataTypeEnum)) status = graphene.Field(graphene.NonNull(ExportStatusEnum)) @@ -82,10 +88,7 @@ def resolve_file_download_url(root, info, **kwargs): if root.file and root.file.name: return info.context.request.build_absolute_uri( URLCachedFileField.generate_url( - root.file.name, - parameters={ - 'ResponseContentDisposition': f'filename = "{root.title}.{root.format}"' - } + root.file.name, parameters={"ResponseContentDisposition": f'filename = "{root.title}.{root.format}"'} ) ) @@ -94,13 +97,13 @@ class UserGenericExportType(DjangoObjectType): class Meta: model = GenericExport only_fields = ( - 'id', - 'title', - 'mime_type', - 'exported_by', - 'exported_at', - 'started_at', - 'ended_at', + "id", + "title", + "mime_type", + "exported_by", + "exported_at", + "started_at", + "ended_at", ) format = graphene.Field(graphene.NonNull(GenericExportFormatEnum)) @@ -120,10 +123,7 @@ def resolve_file_download_url(root, info, **kwargs): if root.file and root.file.name: return info.context.request.build_absolute_uri( URLCachedFileField.generate_url( - root.file.name, - parameters={ - 'ResponseContentDisposition': f'filename = "{root.title}.{root.format}"' - } + root.file.name, parameters={"ResponseContentDisposition": f'filename = "{root.title}.{root.format}"'} ) ) @@ -137,10 +137,7 @@ class Meta: class ProjectQuery: export = DjangoObjectField(UserExportType) exports = DjangoPaginatedListObjectField( - UserExportListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + UserExportListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) @staticmethod @@ -148,5 +145,5 @@ def resolve_exports(root, info, **kwargs) -> QuerySet: return get_export_qs(info).filter(is_preview=False) -class Query(): +class Query: generic_export = DjangoObjectField(UserGenericExportType) diff --git a/apps/export/serializers.py b/apps/export/serializers.py index 5cadb05c83..d9dc1265cf 100644 --- a/apps/export/serializers.py +++ b/apps/export/serializers.py @@ -1,23 +1,23 @@ +from analysis_framework.models import Exportable, Widget +from deep_explore.filter_set import ExploreProjectFilterSet +from deep_explore.schema import ExploreDeepFilterInputType from django.db import transaction - from drf_dynamic_fields import DynamicFieldsMixin +from entry.filter_set import EntriesFilterDataInputType, EntryGQFilterSet +from lead.filter_set import LeadGQFilterSet, LeadsFilterDataInputType +from project.filter_set import ProjectGqlFilterSet, ProjectsFilterDataInputType from rest_framework import serializers -from utils.graphene.fields import generate_serializer_field_class from deep.serializers import ( - RemoveNullFieldsMixin, + GraphqlSupportDrfSerializerJSONField, ProjectPropertySerializerMixin, + RemoveNullFieldsMixin, StringIDField, - GraphqlSupportDrfSerializerJSONField, ) -from lead.filter_set import LeadGQFilterSet, LeadsFilterDataInputType -from entry.filter_set import EntryGQFilterSet, EntriesFilterDataInputType -from project.filter_set import ProjectGqlFilterSet, ProjectsFilterDataInputType -from deep_explore.schema import ExploreDeepFilterInputType -from deep_explore.filter_set import ExploreProjectFilterSet -from analysis_framework.models import Widget, Exportable -from .tasks import export_task, generic_export_task +from utils.graphene.fields import generate_serializer_field_class + from .models import Export, GenericExport +from .tasks import export_task, generic_export_task class ExportSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): @@ -27,40 +27,42 @@ class ExportSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.Mo class Meta: model = Export - exclude = ('filters',) + exclude = ("filters",) # ------------------- Graphql Serializers ---------------------------------------- # ---- [Start] ExportReportLevel Serialisers class ExportReportLevelWidgetFourthLevelSerializer(serializers.Serializer): """ - Additional sub-level (sub-column) For matrix2d + Additional sub-level (sub-column) For matrix2d """ - id = StringIDField(help_text='Matrix2D: {column-key}-{sub-column}-{row-key}-{sub-row-key}') - title = serializers.CharField(help_text='Matrix2D: {sub-column-label}') + + id = StringIDField(help_text="Matrix2D: {column-key}-{sub-column}-{row-key}-{sub-row-key}") + title = serializers.CharField(help_text="Matrix2D: {sub-column-label}") class ExportReportLevelWidgetSubSubLevelSerializer(serializers.Serializer): # Additional sub-level For matrix2d - id = StringIDField(help_text='Matrix2D: {column-key}-{row-key}-{sub-row-key}') - title = serializers.CharField(help_text='Matrix2D: {sub-row-label}') + id = StringIDField(help_text="Matrix2D: {column-key}-{row-key}-{sub-row-key}") + title = serializers.CharField(help_text="Matrix2D: {sub-row-label}") sublevels = ExportReportLevelWidgetFourthLevelSerializer( many=True, required=False, - help_text='For 2D matrix (sub-column)', + help_text="For 2D matrix (sub-column)", ) class ExportReportLevelWidgetSubLevelSerializer(serializers.Serializer): - id = StringIDField(help_text='Matrix1D: {row-key}-{cell-key}, Matrix2D: {column-key}-{row-key}') - title = serializers.CharField(help_text='Matrix1D: {cell-label}, Matrix2D: {row-label}') - sublevels = ExportReportLevelWidgetSubSubLevelSerializer(many=True, required=False, help_text='For 2D matrix') + id = StringIDField(help_text="Matrix1D: {row-key}-{cell-key}, Matrix2D: {column-key}-{row-key}") + title = serializers.CharField(help_text="Matrix1D: {cell-label}, Matrix2D: {row-label}") + sublevels = ExportReportLevelWidgetSubSubLevelSerializer(many=True, required=False, help_text="For 2D matrix") class ExportReportLevelWidgetLevelSerializer(serializers.Serializer): - id = StringIDField(help_text='Matrix1D: {row-key}, Matrix2D: {column-key}') - title = serializers.CharField(help_text='Matrix1D: {row-label}, Matrix2D: {column-label}') + id = StringIDField(help_text="Matrix1D: {row-key}, Matrix2D: {column-key}") + title = serializers.CharField(help_text="Matrix1D: {row-label}, Matrix2D: {column-label}") sublevels = ExportReportLevelWidgetSubLevelSerializer( - many=True, required=False, help_text='Not required for uncategorized data') + many=True, required=False, help_text="Not required for uncategorized data" + ) class ExportReportLevelWidgetSerializer(serializers.Serializer): @@ -115,40 +117,43 @@ class ExportReportLevelWidgetSerializer(serializers.Serializer): ], } """ - id = StringIDField(help_text='Widget ID') - levels = ExportReportLevelWidgetLevelSerializer(many=True, required=False, help_text='Widget levels') + id = StringIDField(help_text="Widget ID") + levels = ExportReportLevelWidgetLevelSerializer(many=True, required=False, help_text="Widget levels") # ---- [End] ExportReportLevel Serialisers # ---- [Start] ExportReportStructure Serialisers class ExportReportStructureWidgetFourthLevelSerializer(serializers.Serializer): """ - Additional sub-level (sub-column) For matrix2d + Additional sub-level (sub-column) For matrix2d """ - id = StringIDField(help_text='Matrix2D: {column-key}-{sub-column}-{row-key}-{sub-row-key}') + + id = StringIDField(help_text="Matrix2D: {column-key}-{sub-column}-{row-key}-{sub-row-key}") class ExportReportStructureWidgetThirdLevelSerializer(serializers.Serializer): """ # Additional sub-level (sub-row) For matrix2d """ - id = StringIDField(help_text='Matrix2D: {column-key}-{row-key}-{sub-row-key}') + + id = StringIDField(help_text="Matrix2D: {column-key}-{row-key}-{sub-row-key}") levels = ExportReportStructureWidgetFourthLevelSerializer( many=True, required=False, - help_text='For 2D matrix (sub-column)', + help_text="For 2D matrix (sub-column)", ) class ExportReportStructureWidgetSecondLevelSerializer(serializers.Serializer): - id = StringIDField(help_text='Matrix1D: {row-key}-{cell-key}, Matrix2D: {column-key}-{row-key}') - levels = ExportReportStructureWidgetThirdLevelSerializer(many=True, required=False, help_text='For 2D matrix') + id = StringIDField(help_text="Matrix1D: {row-key}-{cell-key}, Matrix2D: {column-key}-{row-key}") + levels = ExportReportStructureWidgetThirdLevelSerializer(many=True, required=False, help_text="For 2D matrix") class ExportReportStructureWidgetFirstLevelSerializer(serializers.Serializer): - id = StringIDField(help_text='Matrix1D: {row-key}, Matrix2D: {column-key}') + id = StringIDField(help_text="Matrix1D: {row-key}, Matrix2D: {column-key}") levels = ExportReportStructureWidgetSecondLevelSerializer( - many=True, required=False, help_text='Not required for uncategorized data') + many=True, required=False, help_text="Not required for uncategorized data" + ) class ExportReportStructureWidgetSerializer(serializers.Serializer): @@ -194,8 +199,10 @@ class ExportReportStructureWidgetSerializer(serializers.Serializer): ], } """ - id = StringIDField(help_text='Widget ID') - levels = ExportReportStructureWidgetFirstLevelSerializer(many=True, required=False, help_text='Widget levels') + + id = StringIDField(help_text="Widget ID") + levels = ExportReportStructureWidgetFirstLevelSerializer(many=True, required=False, help_text="Widget levels") + # ---- [End] ExportReportStructure Serialisers @@ -206,11 +213,11 @@ class ExportExcelSelectedColumnSerializer(serializers.Serializer): static_column = serializers.ChoiceField(choices=Export.StaticColumn.choices, required=False) def validate(self, data): - if data['is_widget']: - if data.get('widget_key') is None: - raise serializers.ValidationError('widget_key key is required when is widget is True') - elif data.get('static_column') is None: - raise serializers.ValidationError('static_column is required when is widget is False') + if data["is_widget"]: + if data.get("widget_key") is None: + raise serializers.ValidationError("widget_key key is required when is widget is True") + elif data.get("static_column") is None: + raise serializers.ValidationError("static_column is required when is widget is False") return data @@ -219,8 +226,7 @@ class ExportExtraOptionsSerializer(ProjectPropertySerializerMixin, serializers.S date_format = serializers.ChoiceField(choices=Export.DateFormat.choices, required=False) # Excel - excel_decoupled = serializers.BooleanField( - help_text="Don't group entries tags. Slower export generation.", required=False) + excel_decoupled = serializers.BooleanField(help_text="Don't group entries tags. Slower export generation.", required=False) excel_columns = ExportExcelSelectedColumnSerializer( required=False, many=True, @@ -235,9 +241,11 @@ class ExportExtraOptionsSerializer(ProjectPropertySerializerMixin, serializers.S report_text_widget_ids = serializers.ListField(child=StringIDField(), allow_empty=True, required=False) report_exporting_widgets = serializers.ListField(child=StringIDField(), allow_empty=True, required=False) report_levels = ExportReportLevelWidgetSerializer( - required=False, many=True, help_text=ExportReportLevelWidgetSerializer.__doc__) + required=False, many=True, help_text=ExportReportLevelWidgetSerializer.__doc__ + ) report_structure = ExportReportStructureWidgetSerializer( - required=False, many=True, help_text=ExportReportStructureWidgetSerializer.__doc__) + required=False, many=True, help_text=ExportReportStructureWidgetSerializer.__doc__ + ) report_citation_style = serializers.ChoiceField(choices=Export.CitationStyle.choices, required=False) @@ -247,14 +255,14 @@ class UserExportBaseGqlMixin(ProjectPropertySerializerMixin, serializers.ModelSe class Meta: model = Export fields = ( - 'title', - 'type', # Data type (entries, assessments, ..) - 'format', # xlsx, docx, pdf, ... - 'export_type', # excel, report, json, ... - 'is_preview', - 'filters', - 'extra_options', - 'analysis', + "title", + "type", # Data type (entries, assessments, ..) + "format", # xlsx, docx, pdf, ... + "export_type", # excel, report, json, ... + "is_preview", + "filters", + "extra_options", + "analysis", ) filters = generate_serializer_field_class(LeadsFilterDataInputType, GraphqlSupportDrfSerializerJSONField)() @@ -272,38 +280,30 @@ def validate_title(self, title): existing_exports = Export.objects.filter( title=title, project=self.project, - exported_by=self.context['request'].user, + exported_by=self.context["request"].user, ) if self.instance: existing_exports = existing_exports.exclude(id=self.instance.id) if existing_exports.exists(): - raise serializers.ValidationError(f'Title {title} already exists.') + raise serializers.ValidationError(f"Title {title} already exists.") return title def validate_filters(self, filters): - filter_set = LeadGQFilterSet(data=filters, request=self.context['request']) + filter_set = LeadGQFilterSet(data=filters, request=self.context["request"]) if not filter_set.is_valid(): raise serializers.ValidationError(filter_set.errors) return filters def validate_report_text_widget_ids(self, widget_ids): if widget_ids: - text_widgets_id = self.widget_qs.filter(widget_id=Widget.WidgetType.TEXT).values_list('id', flat=True) - return [ - widget_id - for widget_id in widget_ids - if widget_id in text_widgets_id - ] + text_widgets_id = self.widget_qs.filter(widget_id=Widget.WidgetType.TEXT).values_list("id", flat=True) + return [widget_id for widget_id in widget_ids if widget_id in text_widgets_id] return [] def validate_report_exporting_widgets(self, widget_ids): if widget_ids: - widgets_id = self.widget_qs.values_list('id', flat=True) - return [ - widget_id - for widget_id in widget_ids - if widget_id in widgets_id - ] + widgets_id = self.widget_qs.values_list("id", flat=True) + return [widget_id for widget_id in widget_ids if widget_id in widgets_id] return [] # TODO: def validate_report_levels(self, widget_ids): @@ -321,36 +321,32 @@ class UserExportCreateGqlSerializer(UserExportBaseGqlMixin, serializers.ModelSer def validate(self, data): # NOTE: We only need to check with create logic (as update only have title for now) # Validate type, export_type and format - data_type = data['type'] - export_type = data['export_type'] - _format = data['format'] + data_type = data["type"] + export_type = data["export_type"] + _format = data["format"] if (data_type, export_type, _format) not in Export.DEFAULT_TITLE_LABEL: - raise serializers.ValidationError(f'Unsupported Export request: {(data_type, export_type, _format)}') + raise serializers.ValidationError(f"Unsupported Export request: {(data_type, export_type, _format)}") return data def create(self, data): - data['title'] = data.get('title') or Export.generate_title(data['type'], data['export_type'], data['format']) - data['exported_by'] = self.context['request'].user - data['project'] = self.project + data["title"] = data.get("title") or Export.generate_title(data["type"], data["export_type"], data["format"]) + data["exported_by"] = self.context["request"].user + data["project"] = self.project export = super().create(data) - transaction.on_commit( - lambda: export.set_task_id(export_task.delay(export.id).id) - ) + transaction.on_commit(lambda: export.set_task_id(export_task.delay(export.id).id)) return export def update(self, _): - raise serializers.ValidationError('Not allowed using this serializer.') + raise serializers.ValidationError("Not allowed using this serializer.") class UserExportUpdateGqlSerializer(UserExportBaseGqlMixin, serializers.ModelSerializer): class Meta: model = Export - fields = ( - 'title', - ) + fields = ("title",) def create(self, _): - raise serializers.ValidationError('Not allowed using this serializer.') + raise serializers.ValidationError("Not allowed using this serializer.") class UserGenericExportFiltersGqlSerializer(serializers.Serializer): @@ -372,10 +368,10 @@ class UserGenericExportCreateGqlSerializer(serializers.ModelSerializer): class Meta: model = GenericExport fields = ( - 'title', - 'type', # Data type - 'format', # csv, xlsx, docx, pdf, ... - 'filters', + "title", + "type", # Data type + "format", # csv, xlsx, docx, pdf, ... + "filters", ) filters = UserGenericExportFiltersGqlSerializer() @@ -385,7 +381,7 @@ def _validate_filterset(filter_data, filter_key, filter_set): filter_data = filter_data.get(filter_key) if not filter_data: return - filter_set = filter_set(data=filter_data, request=self.context['request']) + filter_set = filter_set(data=filter_data, request=self.context["request"]) if not filter_set.is_valid(): return filter_set.errors @@ -393,24 +389,22 @@ def _validate_filterset(filter_data, filter_key, filter_set): # Validate each filter data for filter_key, FilterSet in [ - ('project', ProjectGqlFilterSet), - ('lead', LeadGQFilterSet), - ('entry', EntryGQFilterSet), - ('deep_explore', None), + ("project", ProjectGqlFilterSet), + ("lead", LeadGQFilterSet), + ("entry", EntryGQFilterSet), + ("deep_explore", None), ]: - if filter_key == 'deep_explore': + if filter_key == "deep_explore": filter_data = filters.get(filter_key) or {} if data_type == GenericExport.DataType.PROJECTS_STATS and not filter_data: - errors[filter_key] = [f'This is required for {data_type}'] + errors[filter_key] = [f"This is required for {data_type}"] continue if filterset_errors := _validate_filterset( filter_data, - 'project', + "project", ExploreProjectFilterSet, ): - errors[filter_key] = { - 'project': filterset_errors - } + errors[filter_key] = {"project": filterset_errors} continue # Generic if filterset_errors := _validate_filterset( @@ -424,23 +418,21 @@ def _validate_filterset(filter_data, filter_key, filter_set): def validate(self, data): # Validate type, export_type and format - data_type = data['type'] - _format = data['format'] - filters = data['filters'] + data_type = data["type"] + _format = data["format"] + filters = data["filters"] if (data_type, _format) not in GenericExport.DEFAULT_TITLE_LABEL: - raise serializers.ValidationError(f'Unsupported Export request: {(data_type, _format)}') + raise serializers.ValidationError(f"Unsupported Export request: {(data_type, _format)}") if errors := self._validate_filters(data_type, filters): - raise serializers.ValidationError({'filters': errors}) + raise serializers.ValidationError({"filters": errors}) return data def create(self, data): - data['title'] = data.get('title') or GenericExport.generate_title(data['type'], data['format']) - data['exported_by'] = self.context['request'].user + data["title"] = data.get("title") or GenericExport.generate_title(data["type"], data["format"]) + data["exported_by"] = self.context["request"].user export = super().create(data) - transaction.on_commit( - lambda: export.set_task_id(generic_export_task.delay(export.id).id) - ) + transaction.on_commit(lambda: export.set_task_id(generic_export_task.delay(export.id).id)) return export def update(self, _): - raise serializers.ValidationError('Not allowed using this serializer.') + raise serializers.ValidationError("Not allowed using this serializer.") diff --git a/apps/export/tasks/__init__.py b/apps/export/tasks/__init__.py index f1afd4c865..90f80447e0 100644 --- a/apps/export/tasks/__init__.py +++ b/apps/export/tasks/__init__.py @@ -1,13 +1,14 @@ import logging -from django.utils import timezone from celery import shared_task +from django.utils import timezone +from export.models import Export, GenericExport from deep.celery import CeleryQueue -from export.models import Export, GenericExport -from .tasks_entries import export_entries -from .tasks_assessment import export_assessments + from .tasks_analyses import export_analyses +from .tasks_assessment import export_assessments +from .tasks_entries import export_entries from .tasks_projects import export_projects_stats logger = logging.getLogger(__name__) @@ -24,27 +25,32 @@ def get_export_filename(export): - filename = f'{export.title}.{export.format}' - if getattr(export, 'is_preview', False): - filename = f'(Preview) {filename}' + filename = f"{export.title}.{export.format}" + if getattr(export, "is_preview", False): + filename = f"(Preview) {filename}" return filename @shared_task(queue=CeleryQueue.EXPORT_HEAVY) def export_task(export_id, force=False): - data_type = 'UNKNOWN' + data_type = "UNKNOWN" try: export = Export.objects.get(pk=export_id) data_type = export.type # Skip if export is already started if not force and export.status != Export.Status.PENDING: - logger.warning(f'Export status is {export.get_status_display()}') - return 'SKIPPED' + logger.warning(f"Export status is {export.get_status_display()}") + return "SKIPPED" # Update status to STARTED export.status = Export.Status.STARTED export.started_at = timezone.now() - export.save(update_fields=('status', 'started_at',)) + export.save( + update_fields=( + "status", + "started_at", + ) + ) file = EXPORTER_TYPE[export.type](export) @@ -63,13 +69,18 @@ def export_task(export_id, force=False): if export: export.status = Export.Status.FAILURE export.ended_at = timezone.now() - export.save(update_fields=('status', 'ended_at',)) + export.save( + update_fields=( + "status", + "ended_at", + ) + ) logger.error( - f'Export Failed {data_type}!!', + f"Export Failed {data_type}!!", exc_info=True, extra={ - 'data': { - 'export_id': export_id, + "data": { + "export_id": export_id, }, }, ) @@ -81,19 +92,24 @@ def export_task(export_id, force=False): # NOTE: limit are in seconds @shared_task(queue=CeleryQueue.DEFAULT, time_limit=220, soft_time_limit=120) def generic_export_task(export_id, force=False): - data_type = 'UNKNOWN' + data_type = "UNKNOWN" try: export = GenericExport.objects.get(pk=export_id) data_type = export.type # Skip if export is already started if not force and export.status != GenericExport.Status.PENDING: - logger.warning(f'Generic Export status is {export.get_status_display()}') - return 'SKIPPED' + logger.warning(f"Generic Export status is {export.get_status_display()}") + return "SKIPPED" # Update status to STARTED export.status = GenericExport.Status.STARTED export.started_at = timezone.now() - export.save(update_fields=('status', 'started_at',)) + export.save( + update_fields=( + "status", + "started_at", + ) + ) file = GENERIC_EXPORTER_TYPE[export.type](export) @@ -112,13 +128,18 @@ def generic_export_task(export_id, force=False): if export: export.status = GenericExport.Status.FAILURE export.ended_at = timezone.now() - export.save(update_fields=('status', 'ended_at',)) + export.save( + update_fields=( + "status", + "ended_at", + ) + ) logger.error( - f'Generic Export Failed {data_type}!!', + f"Generic Export Failed {data_type}!!", exc_info=True, extra={ - 'data': { - 'export_id': export_id, + "data": { + "export_id": export_id, }, }, ) diff --git a/apps/export/tasks/tasks_analyses.py b/apps/export/tasks/tasks_analyses.py index 7f6da88eac..af2093a350 100644 --- a/apps/export/tasks/tasks_analyses.py +++ b/apps/export/tasks/tasks_analyses.py @@ -1,7 +1,6 @@ from analysis.models import AnalyticalStatementEntry -from export.models import Export - from export.analyses.excel_exporter import ExcelExporter +from export.models import Export def export_analyses(export): @@ -12,12 +11,10 @@ def export_analyses(export): analytical_statement__analysis_pillar__analysis=analysis ) if export_type == Export.ExportType.EXCEL: - export_data = ExcelExporter(analytical_statement_entries)\ - .add_analytical_statement_entries(analytical_statement_entries)\ - .export() - else: - raise Exception( - f'(Analysis Export) Unkown Export Type Provided: {export_type} for Export: {export.id}' + export_data = ( + ExcelExporter(analytical_statement_entries).add_analytical_statement_entries(analytical_statement_entries).export() ) + else: + raise Exception(f"(Analysis Export) Unkown Export Type Provided: {export_type} for Export: {export.id}") return export_data diff --git a/apps/export/tasks/tasks_assessment.py b/apps/export/tasks/tasks_assessment.py index 135f25f516..72ed505f24 100644 --- a/apps/export/tasks/tasks_assessment.py +++ b/apps/export/tasks/tasks_assessment.py @@ -1,16 +1,15 @@ import copy -from deep.permissions import ProjectPermissions as PP -from deep.filter_set import get_dummy_request -from lead.models import Lead -from lead.filter_set import LeadGQFilterSet -from ary.export import ( - get_export_data_for_assessments, -) -from export.models import Export -from export.exporters import JsonExporter -from export.assessments import NewExcelExporter +from ary.export import get_export_data_for_assessments from assessment_registry.models import AssessmentRegistry +from export.assessments import NewExcelExporter +from export.exporters import JsonExporter +from export.models import Export +from lead.filter_set import LeadGQFilterSet +from lead.models import Lead + +from deep.filter_set import get_dummy_request +from deep.permissions import ProjectPermissions as PP def _export_assessments(export, AssessmentModel, excel_sheet_data_generator): @@ -19,7 +18,7 @@ def _export_assessments(export, AssessmentModel, excel_sheet_data_generator): export_type = export.export_type is_preview = export.is_preview - arys = AssessmentModel.objects.filter(project=project).select_related('project').distinct() + arys = AssessmentModel.objects.filter(project=project).select_related("project").distinct() if AssessmentModel == AssessmentRegistry: # Filter is only available for Assessments (not PlannedAssessment) user_project_permissions = PP.get_permissions(project, user) filters = copy.deepcopy(export.filters) # Avoid mutating database values @@ -30,23 +29,17 @@ def _export_assessments(export, AssessmentModel, excel_sheet_data_generator): dummy_request = get_dummy_request(active_project=project) leads_qs = LeadGQFilterSet(data=filters, queryset=leads_qs, request=dummy_request).qs arys = arys.filter(lead__in=leads_qs) - iterable_arys = arys[:Export.PREVIEW_ASSESSMENT_SIZE] if is_preview else arys + iterable_arys = arys[: Export.PREVIEW_ASSESSMENT_SIZE] if is_preview else arys if export_type == Export.ExportType.JSON: exporter = JsonExporter() - exporter.data = { - ary.project.title: ary.to_exportable_json() - for ary in iterable_arys - } + exporter.data = {ary.project.title: ary.to_exportable_json() for ary in iterable_arys} export_data = exporter.export() elif export_type == Export.ExportType.EXCEL: sheets_data = excel_sheet_data_generator(iterable_arys) - export_data = NewExcelExporter(sheets_data)\ - .export() + export_data = NewExcelExporter(sheets_data).export() else: - raise Exception( - f'(Assessments Export) Unkown Export Type Provided: {export_type} for Export: {export.id}' - ) + raise Exception(f"(Assessments Export) Unkown Export Type Provided: {export_type} for Export: {export.id}") return export_data diff --git a/apps/export/tasks/tasks_entries.py b/apps/export/tasks/tasks_entries.py index 60888c44ed..52bd376cc2 100644 --- a/apps/export/tasks/tasks_entries.py +++ b/apps/export/tasks/tasks_entries.py @@ -1,19 +1,19 @@ import copy -from django.db import models - -from deep.permissions import ProjectPermissions as PP -from deep.filter_set import get_dummy_request from analysis_framework.models import Exportable +from django.db import models +from entry.filter_set import EntryGQFilterSet from entry.models import Entry -from export.models import Export from export.entries.excel_exporter import ExcelExporter -from export.entries.report_exporter import ReportExporter from export.entries.json_exporter import JsonExporter +from export.entries.report_exporter import ReportExporter +from export.models import Export from geo.models import Region -from lead.models import Lead from lead.filter_set import LeadGQFilterSet -from entry.filter_set import EntryGQFilterSet +from lead.models import Lead + +from deep.filter_set import get_dummy_request +from deep.permissions import ProjectPermissions as PP def export_entries(export): @@ -35,35 +35,35 @@ def export_entries(export): # Lead and Entry FilterSet needs request to work with active_project dummy_request = get_dummy_request(active_project=project) leads_qs = LeadGQFilterSet(data=filters, queryset=leads_qs, request=dummy_request).qs.prefetch_related( - 'authors', - 'authors__organization_type', + "authors", + "authors__organization_type", # Also organization parents - 'authors__parent', - 'authors__parent__organization_type', + "authors__parent", + "authors__parent__organization_type", ) entries_qs = EntryGQFilterSet( - data=filters.get('entries_filter_data'), + data=filters.get("entries_filter_data"), request=dummy_request, queryset=Entry.objects.filter( project=export.project, analysis_framework=export.project.analysis_framework_id, lead__in=leads_qs, - ) + ), ).qs # Prefetches entries_qs = entries_qs.prefetch_related( - 'entrygrouplabel_set', + "entrygrouplabel_set", models.Prefetch( - 'lead', + "lead", queryset=Lead.objects.annotate( - page_count=models.F('leadpreview__page_count'), + page_count=models.F("leadpreview__page_count"), ).prefetch_related( - 'authors', - 'authors__organization_type', + "authors", + "authors__organization_type", # Also organization parents - 'authors__parent', - 'authors__parent__organization_type', + "authors__parent", + "authors__parent__organization_type", ), ), ) @@ -73,38 +73,40 @@ def export_entries(export): ).distinct() regions = Region.objects.filter(project=project).distinct() - date_format = extra_options.get('date_format') + date_format = extra_options.get("date_format") if export_type == Export.ExportType.EXCEL: - decoupled = extra_options.get('excel_decoupled', False) - columns = extra_options.get('excel_columns') - export_data = ExcelExporter( - export, - entries_qs, - project, - date_format, - columns=columns, - decoupled=decoupled, - is_preview=is_preview, - )\ - .load_exportables(exportables, regions)\ - .add_entries(entries_qs)\ + decoupled = extra_options.get("excel_decoupled", False) + columns = extra_options.get("excel_columns") + export_data = ( + ExcelExporter( + export, + entries_qs, + project, + date_format, + columns=columns, + decoupled=decoupled, + is_preview=is_preview, + ) + .load_exportables(exportables, regions) + .add_entries(entries_qs) .export(leads_qs) + ) elif export_type == Export.ExportType.REPORT: # which widget data needs to be exported along with - exporting_widgets = extra_options.get('report_exporting_widgets', []) + exporting_widgets = extra_options.get("report_exporting_widgets", []) report_show_attributes = dict( - show_lead_entry_id=extra_options.get('report_show_lead_entry_id', True), - show_assessment_data=extra_options.get('report_show_assessment_data', True), - show_entry_widget_data=extra_options.get('report_show_entry_widget_data', True), + show_lead_entry_id=extra_options.get("report_show_lead_entry_id", True), + show_assessment_data=extra_options.get("report_show_assessment_data", True), + show_entry_widget_data=extra_options.get("report_show_entry_widget_data", True), ) - citation_style = extra_options.get('report_citation_style') - report_structure = extra_options.get('report_structure') - report_levels = extra_options.get('report_levels') - text_widget_ids = extra_options.get('report_text_widget_ids') or [] - show_groups = extra_options.get('report_show_groups') + citation_style = extra_options.get("report_citation_style") + report_structure = extra_options.get("report_structure") + report_levels = extra_options.get("report_levels") + text_widget_ids = extra_options.get("report_text_widget_ids") or [] + show_groups = extra_options.get("report_show_groups") export_data = ( ReportExporter( date_format, @@ -112,7 +114,8 @@ def export_entries(export): exporting_widgets=exporting_widgets, is_preview=is_preview, **report_show_attributes, - ).load_exportables(exportables, regions) + ) + .load_exportables(exportables, regions) .load_levels(report_levels) .load_structure(report_structure) .load_group_lables(entries_qs, show_groups) @@ -122,14 +125,9 @@ def export_entries(export): ) elif export_type == Export.ExportType.JSON: - export_data = JsonExporter(is_preview=is_preview)\ - .load_exportables(exportables)\ - .add_entries(entries_qs)\ - .export() + export_data = JsonExporter(is_preview=is_preview).load_exportables(exportables).add_entries(entries_qs).export() else: - raise Exception( - '(Entries Export) Unkown Export Type Provided: {export_type} for Export: {export.id}' - ) + raise Exception("(Entries Export) Unkown Export Type Provided: {export_type} for Export: {export.id}") return export_data diff --git a/apps/export/tasks/tasks_projects.py b/apps/export/tasks/tasks_projects.py index b501040578..ec96249987 100644 --- a/apps/export/tasks/tasks_projects.py +++ b/apps/export/tasks/tasks_projects.py @@ -1,26 +1,23 @@ import csv from io import StringIO +from deep_explore.filter_set import ExploreProjectFilterSet +from deep_explore.schema import get_global_filters, project_queryset from django.db import models +from organization.models import Organization +from project.models import ProjectMembership, ProjectOrganization, ProjectRole -from utils.files import generate_file_for_upload from deep.filter_set import get_dummy_request -from project.models import ProjectOrganization, ProjectRole, ProjectMembership -from organization.models import Organization -from deep_explore.filter_set import ExploreProjectFilterSet -from deep_explore.schema import get_global_filters, project_queryset +from utils.files import generate_file_for_upload def get_organizations_display(project, organization_type=None): organization_ids_qs = ProjectOrganization.objects.filter(project=project) if organization_type: organization_ids_qs = organization_ids_qs.filter(organization_type=organization_type) - return ','.join([ - org.data.title - for org in Organization.objects.filter( - id__in=organization_ids_qs.values_list('organization', flat=True) - ) - ]) + return ",".join( + [org.data.title for org in Organization.objects.filter(id__in=organization_ids_qs.values_list("organization", flat=True))] + ) def generate_projects_stats(filters, user): @@ -33,48 +30,50 @@ def generate_projects_stats(filters, user): ) if not filters: - raise Exception('This should be defined.') + raise Exception("This should be defined.") - project_filters = filters.get('project') or {} + project_filters = filters.get("project") or {} file = StringIO() headers = [ - 'ID', - 'Title', - 'Created Date', - 'Owners', - 'Start Date', - 'End Date', - 'Last Entry (Date)', - 'Organisation (Project owner)', - 'Project Stakeholders', - 'Geo Areas', - 'Analysis Framework', - 'Description', - 'Status', - 'Test project (Y/N)', - 'Members Count', - 'Sources Count', - 'Entries Count', - '# of Exports', + "ID", + "Title", + "Created Date", + "Owners", + "Start Date", + "End Date", + "Last Entry (Date)", + "Organisation (Project owner)", + "Project Stakeholders", + "Geo Areas", + "Analysis Framework", + "Description", + "Status", + "Test project (Y/N)", + "Members Count", + "Sources Count", + "Entries Count", + "# of Exports", ] projects_qs = project_queryset().annotate( - analysis_framework_title=models.F('analysis_framework__title'), + analysis_framework_title=models.F("analysis_framework__title"), ) projects_qs = ExploreProjectFilterSet(project_filters, queryset=projects_qs, **filterset_attrs).qs projects_qs = projects_qs.filter(**get_global_filters(filters)) - writer = csv.DictWriter(file, fieldnames=headers, extrasaction='ignore') + writer = csv.DictWriter(file, fieldnames=headers, extrasaction="ignore") writer.writeheader() - for project in projects_qs.order_by('-id'): - last_entry = project.entry_set.order_by('-id').first() - owners = ','.join([ - f'{member.member.get_display_name()}' - for member in ProjectMembership.objects.filter(project=project, role=PROJECT_OWNER_ROLE) - ]) + for project in projects_qs.order_by("-id"): + last_entry = project.entry_set.order_by("-id").first() + owners = ",".join( + [ + f"{member.member.get_display_name()}" + for member in ProjectMembership.objects.filter(project=project, role=PROJECT_OWNER_ROLE) + ] + ) regions_qs = project.regions members_qs = project.members @@ -82,34 +81,36 @@ def generate_projects_stats(filters, user): leads_qs = project.lead_set.filter(**get_global_filters(filters)) entries_qs = project.entry_set.filter(**get_global_filters(filters)) - writer.writerow({ - 'ID': project.id, - 'Title': project.title, - 'Created Date': project.created_at, - 'Owners': owners, - 'Start Date': project.start_date, - 'End Date': project.end_date, - 'Last Entry (Date)': last_entry and last_entry.created_at, - 'Organisation (Project owner)': get_organizations_display( - project, - ProjectOrganization.Type.LEAD_ORGANIZATION, - ), - 'Project Stakeholders': get_organizations_display(project), - 'Geo Areas': ','.join(regions_qs.values_list('title', flat=True).distinct()), - 'Analysis Framework': project.analysis_framework_title, - 'Description': project.description, - 'Status': project.status, - 'Test project (Y/N)': 'Y' if project.is_test else 'N', - 'Members Count': members_qs.count(), - 'Sources Count': leads_qs.count(), - 'Entries Count': entries_qs.count(), - '# of Exports': exports_qs.count(), - }) + writer.writerow( + { + "ID": project.id, + "Title": project.title, + "Created Date": project.created_at, + "Owners": owners, + "Start Date": project.start_date, + "End Date": project.end_date, + "Last Entry (Date)": last_entry and last_entry.created_at, + "Organisation (Project owner)": get_organizations_display( + project, + ProjectOrganization.Type.LEAD_ORGANIZATION, + ), + "Project Stakeholders": get_organizations_display(project), + "Geo Areas": ",".join(regions_qs.values_list("title", flat=True).distinct()), + "Analysis Framework": project.analysis_framework_title, + "Description": project.description, + "Status": project.status, + "Test project (Y/N)": "Y" if project.is_test else "N", + "Members Count": members_qs.count(), + "Sources Count": leads_qs.count(), + "Entries Count": entries_qs.count(), + "# of Exports": exports_qs.count(), + } + ) return generate_file_for_upload(file) def export_projects_stats(export): return generate_projects_stats( - (export.filters or {}).get('deep_explore') or {}, + (export.filters or {}).get("deep_explore") or {}, export.exported_by, ) diff --git a/apps/export/tests/test_apis.py b/apps/export/tests/test_apis.py index 604ebeff1b..00f7b1aafa 100644 --- a/apps/export/tests/test_apis.py +++ b/apps/export/tests/test_apis.py @@ -1,23 +1,22 @@ from dateutil.relativedelta import relativedelta - from django.utils import timezone +from export.models import Export +from project.models import Project from deep.tests import TestCase -from project.models import Project -from export.models import Export class ExportTests(TestCase): def test_get_export(self): export = self.create(Export, exported_by=self.user) - url = '/api/v1/exports/{}/'.format(export.id) + url = "/api/v1/exports/{}/".format(export.id) self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['title'], export.title) - self.assertEqual(response.data['exported_by'], self.user.id) + self.assertEqual(response.data["title"], export.title) + self.assertEqual(response.data["exported_by"], self.user.id) def test_trigger_api_without_export_permission(self): # Create project and modify role to have no export permission @@ -28,10 +27,10 @@ def test_trigger_api_without_export_permission(self): role.export_permissions = 0 role.save() - url = '/api/v1/export-trigger/' + url = "/api/v1/export-trigger/" data = { - 'filters': [ - ['project', project.pk], + "filters": [ + ["project", project.pk], ], } @@ -42,7 +41,7 @@ def test_trigger_api_without_export_permission(self): assert Export.objects.count() == 0 def test_trigger_api_with_export_permission(self): - url = '/api/v1/export-trigger/' + url = "/api/v1/export-trigger/" # Create project and modify role to have no export permission project = self.create(Project) @@ -53,10 +52,10 @@ def test_trigger_api_with_export_permission(self): role.save() self.authenticate(self.user) - response = self.client.post(url, data={'filters': [['project', project.id]]}) + response = self.client.post(url, data={"filters": [["project", project.id]]}) self.assert_200(response) - export = Export.objects.get(id=response.data['export_triggered']) + export = Export.objects.get(id=response.data["export_triggered"]) self.assertEqual(export.exported_by, self.user) def test_delete_export(self): @@ -69,20 +68,20 @@ def test_delete_export(self): before_delete = Export.objects.count() # test user can delete his export - url = '/api/v1/exports/{}/'.format(export1.id) + url = "/api/v1/exports/{}/".format(export1.id) self.authenticate(user1) response = self.client.delete(url) self.assert_204(response) # delete from api # test user canot delete other export - url = '/api/v1/exports/{}/'.format(export3.id) + url = "/api/v1/exports/{}/".format(export3.id) self.authenticate(user2) response = self.client.delete(url) self.assert_404(response) - url = '/api/v1/exports/{}/'.format(export2.id) + url = "/api/v1/exports/{}/".format(export2.id) self.authenticate(user2) response = self.client.delete(url) self.assert_204(response) @@ -91,26 +90,22 @@ def test_delete_export(self): self.assertEqual(before_delete, after_delete) # test get the data from api - url = '/api/v1/exports/' + url = "/api/v1/exports/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 1) # should have one export + self.assertEqual(response.data["count"], 1) # should have one export # test update by deleted export - url = '/api/v1/exports/{}/'.format(export1.id) - data = { - 'title': 'Title test' - } + url = "/api/v1/exports/{}/".format(export1.id) + data = {"title": "Title test"} self.authenticate(user1) response = self.client.patch(url, data) self.assert_404(response) # test update by another user - url = '/api/v1/exports/{}/'.format(export1.id) - data = { - 'title': 'Title test' - } + url = "/api/v1/exports/{}/".format(export1.id) + data = {"title": "Title test"} self.authenticate(user2) response = self.client.patch(url, data) self.assert_404(response) @@ -122,12 +117,12 @@ def test_export_filter_by_status(self): self.create(Export, exported_by=self.user, status=Export.Status.FAILURE) self.authenticate() - response = self.client.get(f'/api/v1/exports/?status={Export.Status.PENDING.value}') - assert response.json()['count'] == 1 - response = self.client.get('/api/v1/exports/') - assert response.json()['count'] == 4 - response = self.client.get(f'/api/v1/exports/?status={Export.Status.PENDING.value},{Export.Status.FAILURE.value}') - assert response.json()['count'] == 2 + response = self.client.get(f"/api/v1/exports/?status={Export.Status.PENDING.value}") + assert response.json()["count"] == 1 + response = self.client.get("/api/v1/exports/") + assert response.json()["count"] == 4 + response = self.client.get(f"/api/v1/exports/?status={Export.Status.PENDING.value},{Export.Status.FAILURE.value}") + assert response.json()["count"] == 2 def test_export_filter_by_type(self): types = [ @@ -140,10 +135,10 @@ def test_export_filter_by_type(self): self.create(Export, exported_by=self.user, type=type) self.authenticate() - response = self.client.get(f'/api/v1/exports/?type={Export.DataType.ASSESSMENTS}') - assert response.json()['count'] == 2 - response = self.client.get(f'/api/v1/exports/?type={Export.DataType.ASSESSMENTS},{Export.DataType.ENTRIES}') - assert response.json()['count'] == 3 + response = self.client.get(f"/api/v1/exports/?type={Export.DataType.ASSESSMENTS}") + assert response.json()["count"] == 2 + response = self.client.get(f"/api/v1/exports/?type={Export.DataType.ASSESSMENTS},{Export.DataType.ENTRIES}") + assert response.json()["count"] == 3 def test_export_filter_by_exported_at(self): now = timezone.now() @@ -152,12 +147,12 @@ def test_export_filter_by_exported_at(self): self.update_obj(self.create(Export, exported_by=self.user), exported_at=now + relativedelta(days=day)) self.update_obj(self.create(Export, exported_by=self.user), exported_at=now) - params = {'exported_at__gte': now.strftime('%Y-%m-%d%z')} - url = '/api/v1/exports/' + params = {"exported_at__gte": now.strftime("%Y-%m-%d%z")} + url = "/api/v1/exports/" self.authenticate() respose = self.client.get(url, params) self.assert_200(respose) - self.assertEqual(len(respose.data['results']), 4) + self.assertEqual(len(respose.data["results"]), 4) def test_export_filter_by_archived(self): self.create(Export, exported_by=self.user, is_archived=False) @@ -166,22 +161,22 @@ def test_export_filter_by_archived(self): self.create(Export, exported_by=self.user, is_archived=False) self.authenticate() - response = self.client.get(f'/api/v1/exports/?is_archived={True}') - assert response.json()['count'] == 1 + response = self.client.get(f"/api/v1/exports/?is_archived={True}") + assert response.json()["count"] == 1 def test_export_cancel(self): for initial_status, final_status in [ - (Export.Status.PENDING, Export.Status.CANCELED), - (Export.Status.STARTED, Export.Status.CANCELED), - (Export.Status.SUCCESS, Export.Status.SUCCESS), - (Export.Status.FAILURE, Export.Status.FAILURE), - (Export.Status.CANCELED, Export.Status.CANCELED), + (Export.Status.PENDING, Export.Status.CANCELED), + (Export.Status.STARTED, Export.Status.CANCELED), + (Export.Status.SUCCESS, Export.Status.SUCCESS), + (Export.Status.FAILURE, Export.Status.FAILURE), + (Export.Status.CANCELED, Export.Status.CANCELED), ]: export = self.create(Export, status=initial_status, exported_by=self.user, is_archived=False) - url = '/api/v1/exports/{}/cancel/'.format(export.id) + url = "/api/v1/exports/{}/cancel/".format(export.id) # without export.set_task_id('this-is-random-id'), it will not throw error self.authenticate() response = self.client.post(url) self.assert_200(response) - self.assertEqual(response.data['status'], final_status) + self.assertEqual(response.data["status"], final_status) diff --git a/apps/export/tests/test_mutations.py b/apps/export/tests/test_mutations.py index 29da172544..9e7fe68fc0 100644 --- a/apps/export/tests/test_mutations.py +++ b/apps/export/tests/test_mutations.py @@ -1,25 +1,23 @@ import datetime - from unittest.mock import patch -from utils.graphene.tests import GraphQLTestCase, GraphQLSnapShotTestCase - -from user.factories import UserFactory -from project.factories import ProjectFactory -from export.factories import ExportFactory from analysis.factories import AnalysisFactory from analysis_framework.factories import AnalysisFrameworkFactory -from lead.factories import LeadFactory from entry.factories import EntryFactory - -from lead.models import Lead +from export.factories import ExportFactory from export.models import Export, GenericExport, export_upload_to -from export.tasks import get_export_filename from export.serializers import UserExportCreateGqlSerializer +from export.tasks import get_export_filename +from lead.factories import LeadFactory +from lead.models import Lead +from project.factories import ProjectFactory +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLSnapShotTestCase, GraphQLTestCase class TestExportMutationSchema(GraphQLTestCase): - CREATE_EXPORT_QUERY = ''' + CREATE_EXPORT_QUERY = """ mutation MyMutation ($projectId: ID!, $input: ExportCreateInputType!) { project(id: $projectId) { exportCreate(data: $input) { @@ -128,9 +126,9 @@ class TestExportMutationSchema(GraphQLTestCase): } } } - ''' + """ - UPDATE_EXPORT_QUERY = ''' + UPDATE_EXPORT_QUERY = """ mutation MyMutation ($projectId: ID!, $exportId: ID!, $input: ExportUpdateInputType!) { project(id: $projectId) { exportUpdate(id: $exportId, data: $input) { @@ -239,9 +237,9 @@ class TestExportMutationSchema(GraphQLTestCase): } } } - ''' + """ - CANCEL_EXPORT_QUERY = ''' + CANCEL_EXPORT_QUERY = """ mutation MyMutation ($projectId: ID!, $exportId: ID!) { project(id: $projectId) { exportCancel(id: $exportId) { @@ -346,9 +344,9 @@ class TestExportMutationSchema(GraphQLTestCase): } } } - ''' + """ - DELETE_EXPORT_QUERY = ''' + DELETE_EXPORT_QUERY = """ mutation MyMutation ($projectId: ID!, $exportId: ID!) { project(id: $projectId) { exportDelete(id: $exportId) { @@ -453,7 +451,7 @@ class TestExportMutationSchema(GraphQLTestCase): } } } - ''' + """ def setUp(self): super().setUp() @@ -477,60 +475,55 @@ def test_export_create(self): """ This test makes sure only valid users can create export """ + def _query_check(minput, **kwargs): - return self.query_check( - self.CREATE_EXPORT_QUERY, - minput=minput, - variables={'projectId': self.project.id}, - **kwargs - ) + return self.query_check(self.CREATE_EXPORT_QUERY, minput=minput, variables={"projectId": self.project.id}, **kwargs) minput = dict( format=self.genum(Export.Format.PDF), type=self.genum(Export.DataType.ENTRIES), - title='Export 101', + title="Export 101", exportType=self.genum(Export.ExportType.EXCEL), isPreview=False, - filters={ - 'ids': [], - 'search': None, - 'statuses': [ + "ids": [], + "search": None, + "statuses": [ self.genum(Lead.Status.NOT_TAGGED), self.genum(Lead.Status.IN_PROGRESS), self.genum(Lead.Status.TAGGED), ], - 'assignees': None, - 'priorities': None, - 'createdAtGte': '2021-11-01T00:00:00Z', - 'createdAtLte': '2021-01-01T00:00:00Z', - 'confidentiality': None, - 'publishedOnGte': None, - 'publishedOnLte': None, - 'excludeProvidedLeadsId': True, - 'authoringOrganizationTypes': None, - 'hasEntries': True, - 'entriesFilterData': { - 'controlled': None, - 'createdBy': None, - 'entryTypes': None, - 'filterableData': [ + "assignees": None, + "priorities": None, + "createdAtGte": "2021-11-01T00:00:00Z", + "createdAtLte": "2021-01-01T00:00:00Z", + "confidentiality": None, + "publishedOnGte": None, + "publishedOnLte": None, + "excludeProvidedLeadsId": True, + "authoringOrganizationTypes": None, + "hasEntries": True, + "entriesFilterData": { + "controlled": None, + "createdBy": None, + "entryTypes": None, + "filterableData": [ { - 'filterKey': 'random-element-1', - 'value': None, - 'valueGte': None, - 'valueLte': None, - 'valueList': [ - 'random-value-1', - 'random-value-2', - 'random-value-3', - 'random-value-4', + "filterKey": "random-element-1", + "value": None, + "valueGte": None, + "valueLte": None, + "valueList": [ + "random-value-1", + "random-value-2", + "random-value-3", + "random-value-4", ], - 'useExclude': None, - 'useAndOperator': None, - 'includeSubRegions': None, + "useExclude": None, + "useAndOperator": None, + "includeSubRegions": None, } - ] + ], }, }, ) @@ -544,54 +537,49 @@ def _query_check(minput, **kwargs): # --- member user self.force_login(self.member_user) # ----- (Simple validation) - response = _query_check(minput, okay=False)['data'] - self.assertEqual(response['project']['exportCreate']['result'], None, response) + response = _query_check(minput, okay=False)["data"] + self.assertEqual(response["project"]["exportCreate"]["result"], None, response) # ----- - minput['format'] = self.genum(Export.Format.XLSX) - response = _query_check(minput)['data'] - response_export = response['project']['exportCreate']['result'] + minput["format"] = self.genum(Export.Format.XLSX) + response = _query_check(minput)["data"] + response_export = response["project"]["exportCreate"]["result"] self.assertNotEqual(response_export, None, response) - export = Export.objects.get(pk=response_export['id']) + export = Export.objects.get(pk=response_export["id"]) excepted_filters = { - 'ids': [], - 'search': None, - 'statuses': [ - 'pending', - 'processed', - 'validated', + "ids": [], + "search": None, + "statuses": [ + "pending", + "processed", + "validated", ], - 'assignees': None, - 'priorities': None, - 'created_at_gte': '2021-11-01T00:00:00Z', - 'created_at_lte': '2021-01-01T00:00:00Z', - 'confidentiality': None, - 'published_on_gte': None, - 'published_on_lte': None, - 'exclude_provided_leads_id': True, - 'authoring_organization_types': None, - 'has_entries': True, - 'entries_filter_data': { - 'controlled': None, - 'created_by': None, - 'entry_types': None, - 'filterable_data': [ + "assignees": None, + "priorities": None, + "created_at_gte": "2021-11-01T00:00:00Z", + "created_at_lte": "2021-01-01T00:00:00Z", + "confidentiality": None, + "published_on_gte": None, + "published_on_lte": None, + "exclude_provided_leads_id": True, + "authoring_organization_types": None, + "has_entries": True, + "entries_filter_data": { + "controlled": None, + "created_by": None, + "entry_types": None, + "filterable_data": [ { - 'value': None, - 'value_gte': None, - 'value_lte': None, - 'filter_key': 'random-element-1', - 'value_list': [ - 'random-value-1', - 'random-value-2', - 'random-value-3', - 'random-value-4' - ], - 'use_exclude': None, - 'use_and_operator': None, - 'include_sub_regions': None + "value": None, + "value_gte": None, + "value_lte": None, + "filter_key": "random-element-1", + "value_list": ["random-value-1", "random-value-2", "random-value-3", "random-value-4"], + "use_exclude": None, + "use_and_operator": None, + "include_sub_regions": None, } - ] + ], }, } # Make sure the filters are stored in db properly @@ -603,7 +591,7 @@ def test_export_update(self): """ export = ExportFactory.create(exported_by=self.member_user, **self.common_export_attrs) export_2 = ExportFactory.create( - title='Export 2', + title="Export 2", exported_by=self.member_user, **self.common_export_attrs, ) @@ -612,12 +600,12 @@ def _query_check(minput, **kwargs): return self.query_check( self.UPDATE_EXPORT_QUERY, minput=minput, - mnested=['project'], + mnested=["project"], variables={ - 'projectId': self.project.id, - 'exportId': export.id, + "projectId": self.project.id, + "exportId": export.id, }, - **kwargs + **kwargs, ) # Snapshot @@ -634,37 +622,28 @@ def _query_check(minput, **kwargs): # --- member user self.force_login(self.member_user) # ----- - minput['title'] = 'Export 1 (Updated)' - response = _query_check(minput, okay=True)['data'] - response_export = response['project']['exportUpdate']['result'] + minput["title"] = "Export 1 (Updated)" + response = _query_check(minput, okay=True)["data"] + response_export = response["project"]["exportUpdate"]["result"] self.assertNotEqual(response_export, None, response) export.refresh_from_db() updated_export_data = UserExportCreateGqlSerializer(export).data # Make sure the filters are stored in db properly self.assertNotEqual(updated_export_data, export_data, response) - export_data['title'] = minput['title'] + export_data["title"] = minput["title"] self.assertEqual(updated_export_data, export_data, response) def test_analysis_export(self): # create analysis - analysis1 = AnalysisFactory.create( - project=self.project, - end_date=datetime.datetime.now(), - team_lead=self.member_user - ) + analysis1 = AnalysisFactory.create(project=self.project, end_date=datetime.datetime.now(), team_lead=self.member_user) def _query_check(minput, **kwargs): - return self.query_check( - self.CREATE_EXPORT_QUERY, - minput=minput, - variables={'projectId': self.project.id}, - **kwargs - ) + return self.query_check(self.CREATE_EXPORT_QUERY, minput=minput, variables={"projectId": self.project.id}, **kwargs) minput = dict( format=self.genum(Export.Format.XLSX), type=self.genum(Export.DataType.ANALYSES), - title='Analysis Export 100', + title="Analysis Export 100", exportType=self.genum(Export.ExportType.EXCEL), analysis=analysis1.id, filters={}, @@ -679,9 +658,9 @@ def _query_check(minput, **kwargs): # --- member user self.force_login(self.member_user) - response = _query_check(minput, okay=False)['data'] - self.assertNotEqual(response['project']['exportCreate']['result'], None, response) - self.assertEqual(response['project']['exportCreate']['result']['analysis']['title'], analysis1.title) + response = _query_check(minput, okay=False)["data"] + self.assertNotEqual(response["project"]["exportCreate"]["result"], None, response) + self.assertEqual(response["project"]["exportCreate"]["result"]["analysis"]["title"], analysis1.title) # TODO: Add test case for file check @@ -689,17 +668,18 @@ def test_export_cancel(self): """ This test makes sure only valid users can cancel export """ + def _query_check(export, **kwargs): return self.query_check( - self.CANCEL_EXPORT_QUERY, - variables={'projectId': self.project.id, 'exportId': export.id}, - **kwargs + self.CANCEL_EXPORT_QUERY, variables={"projectId": self.project.id, "exportId": export.id}, **kwargs ) export_pending = ExportFactory.create( - exported_by=self.member_user, status=Export.Status.PENDING, **self.common_export_attrs) + exported_by=self.member_user, status=Export.Status.PENDING, **self.common_export_attrs + ) export_failed = ExportFactory.create( - exported_by=self.member_user, status=Export.Status.FAILURE, **self.common_export_attrs) + exported_by=self.member_user, status=Export.Status.FAILURE, **self.common_export_attrs + ) export2 = ExportFactory.create(exported_by=self.another_member_user, **self.common_export_attrs) # -- Without login @@ -719,21 +699,20 @@ def _query_check(export, **kwargs): # --- member user (with ownership) self.force_login(self.member_user) - content = _query_check(export_failed)['data']['project']['exportCancel']['result'] - self.assertEqual(content['status'], self.genum(Export.Status.FAILURE), content) + content = _query_check(export_failed)["data"]["project"]["exportCancel"]["result"] + self.assertEqual(content["status"], self.genum(Export.Status.FAILURE), content) - content = _query_check(export_pending)['data']['project']['exportCancel']['result'] - self.assertEqual(content['status'], self.genum(Export.Status.CANCELED), content) + content = _query_check(export_pending)["data"]["project"]["exportCancel"]["result"] + self.assertEqual(content["status"], self.genum(Export.Status.CANCELED), content) def test_export_delete(self): """ This test makes sure only valid users can delete export """ + def _query_check(export, **kwargs): return self.query_check( - self.DELETE_EXPORT_QUERY, - variables={'projectId': self.project.id, 'exportId': export.id}, - **kwargs + self.DELETE_EXPORT_QUERY, variables={"projectId": self.project.id, "exportId": export.id}, **kwargs ) export1 = ExportFactory.create(exported_by=self.member_user, **self.common_export_attrs) @@ -756,15 +735,15 @@ def _query_check(export, **kwargs): # --- member user (with ownership) self.force_login(self.member_user) - content = _query_check(export1)['data']['project']['exportDelete']['result'] - self.assertEqual(content['id'], str(export1.id), content) + content = _query_check(export1)["data"]["project"]["exportDelete"]["result"] + self.assertEqual(content["id"], str(export1.id), content) class TestGenericExportMutationSchema(GraphQLSnapShotTestCase): factories_used = [AnalysisFrameworkFactory, ProjectFactory, LeadFactory, UserFactory] ENABLE_NOW_PATCHER = True - CREATE_GENERIC_EXPORT_QUERY = ''' + CREATE_GENERIC_EXPORT_QUERY = """ mutation MyMutation ($input: GenericExportCreateInputType!) { genericExportCreate(data: $input) { ok @@ -789,7 +768,7 @@ class TestGenericExportMutationSchema(GraphQLSnapShotTestCase): } } } - ''' + """ def setUp(self): super().setUp() @@ -815,11 +794,7 @@ def setUp(self): def test_project_stats(self): def _query_check(minput, **kwargs): - return self.query_check( - self.CREATE_GENERIC_EXPORT_QUERY, - minput=minput, - **kwargs - ) + return self.query_check(self.CREATE_GENERIC_EXPORT_QUERY, minput=minput, **kwargs) minput = dict( format=self.genum(GenericExport.Format.CSV), @@ -838,64 +813,57 @@ def _query_check(minput, **kwargs): self.force_login(self.user) with self.captureOnCommitCallbacks(execute=True): - response = _query_check(minput, okay=True)['data'] - self.assertNotEqual(response['genericExportCreate']['result'], None, response) - generic_export = GenericExport.objects.get(pk=response['genericExportCreate']['result']['id']) + response = _query_check(minput, okay=True)["data"] + self.assertNotEqual(response["genericExportCreate"]["result"], None, response) + generic_export = GenericExport.objects.get(pk=response["genericExportCreate"]["result"]["id"]) self.assertEqual(generic_export.status, GenericExport.Status.SUCCESS, response) self.assertNotEqual(generic_export.file.name, None, response) - self.assertMatchSnapshot(generic_export.file.read().decode('utf-8'), 'generic-export-csv') + self.assertMatchSnapshot(generic_export.file.read().decode("utf-8"), "generic-export-csv") class GeneraltestCase(GraphQLTestCase): def test_export_path_generation(self): - MOCK_TIME_STR = '20211205' - MOCK_RANDOM_STRING = 'random-string' + MOCK_TIME_STR = "20211205" + MOCK_RANDOM_STRING = "random-string" user = UserFactory.create() project = ProjectFactory.create() common_args = { - 'type': Export.DataType.ENTRIES, - 'exported_by': user, - 'project': project, + "type": Export.DataType.ENTRIES, + "exported_by": user, + "project": project, } - with \ - patch('export.models.get_random_string') as get_random_string_mock, \ - patch('export.models.timezone') as timezone_mock: + with patch("export.models.get_random_string") as get_random_string_mock, patch("export.models.timezone") as timezone_mock: get_random_string_mock.return_value = MOCK_RANDOM_STRING timezone_mock.now.return_value.strftime.return_value = MOCK_TIME_STR for export, expected_title, expected_filename, _type in [ ( - ExportFactory( - title='', - format=Export.Format.DOCX, - export_type=Export.ExportType.REPORT, - **common_args - ), - f'{MOCK_TIME_STR} DEEP Entries General Export', - f'{MOCK_TIME_STR} DEEP Entries General Export.docx', - 'without-title', + ExportFactory(title="", format=Export.Format.DOCX, export_type=Export.ExportType.REPORT, **common_args), + f"{MOCK_TIME_STR} DEEP Entries General Export", + f"{MOCK_TIME_STR} DEEP Entries General Export.docx", + "without-title", ), ( ExportFactory( - title='test 123', + title="test 123", format=Export.Format.PDF, export_type=Export.ExportType.REPORT, **common_args, ), - 'test 123', - 'test 123.pdf', - 'with-title-01', + "test 123", + "test 123.pdf", + "with-title-01", ), ( ExportFactory( - title='test 321', + title="test 321", format=Export.Format.JSON, export_type=Export.ExportType.JSON, is_preview=True, **common_args, ), - 'test 321', - '(Preview) test 321.json', - 'with-title-02', + "test 321", + "(Preview) test 321.json", + "with-title-02", ), ]: export.save() @@ -903,4 +871,4 @@ def test_export_path_generation(self): # export.title = export.title or generated_title # This is automatically done on export save (mocking here) generated_filename = export_upload_to(export, get_export_filename(export)) self.assertEqual(export.title, expected_title, _type) - self.assertEqual(generated_filename, f'export/{MOCK_RANDOM_STRING}/{expected_filename}', _type) + self.assertEqual(generated_filename, f"export/{MOCK_RANDOM_STRING}/{expected_filename}", _type) diff --git a/apps/export/tests/test_schemas.py b/apps/export/tests/test_schemas.py index 9645bb96c1..f6e31b5a8f 100644 --- a/apps/export/tests/test_schemas.py +++ b/apps/export/tests/test_schemas.py @@ -1,10 +1,9 @@ -from utils.graphene.tests import GraphQLTestCase - -from user.factories import UserFactory +from export.factories import ExportFactory +from export.models import Export from project.factories import ProjectFactory +from user.factories import UserFactory -from export.models import Export -from export.factories import ExportFactory +from utils.graphene.tests import GraphQLTestCase class TestExportQuerySchema(GraphQLTestCase): @@ -12,7 +11,7 @@ def test_export_query(self): """ Test export for project """ - query = ''' + query = """ query MyQuery ($projectId: ID! $exportId: ID!) { project(id: $projectId) { export (id: $exportId) { @@ -21,7 +20,7 @@ def test_export_query(self): } } } - ''' + """ project = ProjectFactory.create() project2 = ProjectFactory.create() @@ -33,7 +32,7 @@ def test_export_query(self): other_export = ExportFactory.create(project=project2, exported_by=user2) def _query_check(export, **kwargs): - return self.query_check(query, variables={'projectId': project.id, 'exportId': export.id}, **kwargs) + return self.query_check(query, variables={"projectId": project.id, "exportId": export.id}, **kwargs) # -- Without login _query_check(export, assert_for_error=True) @@ -41,15 +40,15 @@ def _query_check(export, **kwargs): # --- With login self.force_login(user) content = _query_check(export) - self.assertNotEqual(content['data']['project']['export'], None, content) - self.assertEqual(content['data']['project']['export']['id'], str(export.id)) + self.assertNotEqual(content["data"]["project"]["export"], None, content) + self.assertEqual(content["data"]["project"]["export"]["id"], str(export.id)) self.force_login(user) content = _query_check(other_export) - self.assertEqual(content['data']['project']['export'], None, content) + self.assertEqual(content["data"]["project"]["export"], None, content) def test_exports_query(self): - query = ''' + query = """ query MyQuery ($id: ID!) { project(id: $id) { exports { @@ -63,7 +62,7 @@ def test_exports_query(self): } } } - ''' + """ project = ProjectFactory.create() project2 = ProjectFactory.create() user = UserFactory.create() @@ -74,7 +73,7 @@ def test_exports_query(self): ExportFactory.create_batch(8, project=project2, exported_by=user2) def _query_check(**kwargs): - return self.query_check(query, variables={'id': project.id}, **kwargs) + return self.query_check(query, variables={"id": project.id}, **kwargs) # --- Without login _query_check(assert_for_error=True) @@ -82,17 +81,17 @@ def _query_check(**kwargs): # --- With login self.force_login(user) content = _query_check() - self.assertEqual(content['data']['project']['exports']['totalCount'], 6, content) - self.assertEqual(len(content['data']['project']['exports']['results']), 6, content) + self.assertEqual(content["data"]["project"]["exports"]["totalCount"], 6, content) + self.assertEqual(len(content["data"]["project"]["exports"]["results"]), 6, content) # --- With login by user whose has not exported the export self.force_login(user2) content = _query_check() - self.assertEqual(content['data']['project']['exports']['totalCount'], 0, content) - self.assertEqual(len(content['data']['project']['exports']['results']), 0, content) + self.assertEqual(content["data"]["project"]["exports"]["totalCount"], 0, content) + self.assertEqual(len(content["data"]["project"]["exports"]["results"]), 0, content) def test_exports_type_filter(self): - query = ''' + query = """ query MyQuery ($id: ID!, $type: [ExportDataTypeEnum!]) { project(id: $id) { exports(type: $type){ @@ -106,7 +105,7 @@ def test_exports_type_filter(self): } } } - ''' + """ project = ProjectFactory.create() user = UserFactory.create() project.add_member(user, role=self.project_role_reader_non_confidential) @@ -114,13 +113,7 @@ def test_exports_type_filter(self): ExportFactory.create_batch(2, project=project, exported_by=user, type=Export.DataType.ASSESSMENTS) def _query_check(**kwargs): - return self.query_check( - query, - variables={ - 'id': project.id, - 'type': [self.genum(Export.DataType.ENTRIES)] - }, - **kwargs) + return self.query_check(query, variables={"id": project.id, "type": [self.genum(Export.DataType.ENTRIES)]}, **kwargs) # --- Without login _query_check(assert_for_error=True) @@ -128,11 +121,11 @@ def _query_check(**kwargs): # --- With login self.force_login(user) content = _query_check() - self.assertEqual(content['data']['project']['exports']['totalCount'], 6, content) - self.assertEqual(len(content['data']['project']['exports']['results']), 6, content) + self.assertEqual(content["data"]["project"]["exports"]["totalCount"], 6, content) + self.assertEqual(len(content["data"]["project"]["exports"]["results"]), 6, content) def test_exports_status_filter(self): - query = ''' + query = """ query MyQuery ($id: ID!, $status: [ExportStatusEnum!]) { project(id: $id) { exports(status: $status){ @@ -146,7 +139,7 @@ def test_exports_status_filter(self): } } } - ''' + """ project = ProjectFactory.create() user = UserFactory.create() project.add_member(user, role=self.project_role_reader_non_confidential) @@ -155,13 +148,7 @@ def test_exports_status_filter(self): ExportFactory.create_batch(3, project=project, exported_by=user, status=Export.Status.SUCCESS) def _query_check(**kwargs): - return self.query_check( - query, - variables={ - 'id': project.id, - 'status': [self.genum(Export.Status.PENDING)] - }, - **kwargs) + return self.query_check(query, variables={"id": project.id, "status": [self.genum(Export.Status.PENDING)]}, **kwargs) # --- Without login _query_check(assert_for_error=True) @@ -169,19 +156,17 @@ def _query_check(**kwargs): # --- With login self.force_login(user) content = _query_check() - self.assertEqual(content['data']['project']['exports']['totalCount'], 4, content) - self.assertEqual(len(content['data']['project']['exports']['results']), 4, content) + self.assertEqual(content["data"]["project"]["exports"]["totalCount"], 4, content) + self.assertEqual(len(content["data"]["project"]["exports"]["results"]), 4, content) def _query_check(**kwargs): return self.query_check( query, - variables={ - 'id': project.id, - 'status': [self.genum(Export.Status.PENDING), self.genum(Export.Status.STARTED)] - }, - **kwargs) + variables={"id": project.id, "status": [self.genum(Export.Status.PENDING), self.genum(Export.Status.STARTED)]}, + **kwargs, + ) self.force_login(user) content = _query_check() - self.assertEqual(content['data']['project']['exports']['totalCount'], 6, content) - self.assertEqual(len(content['data']['project']['exports']['results']), 6, content) + self.assertEqual(content["data"]["project"]["exports"]["totalCount"], 6, content) + self.assertEqual(len(content["data"]["project"]["exports"]["results"]), 6, content) diff --git a/apps/export/tests/test_xlsx.py b/apps/export/tests/test_xlsx.py index bae6342102..5afaa69b17 100644 --- a/apps/export/tests/test_xlsx.py +++ b/apps/export/tests/test_xlsx.py @@ -4,31 +4,31 @@ class RowsBuilderTest(TestCase): def test_rows(self): - builder = RowsBuilder()\ - .add_value('Hello')\ - .add_value_list(['My', 'Name'])\ - .add_rows_of_values(['Is', 'Not', 'Jon'])\ - .add_rows_of_values(['1', '2'])\ - .add_rows_of_value_lists([['3', '4'], ['5', '6']]) + builder = ( + RowsBuilder() + .add_value("Hello") + .add_value_list(["My", "Name"]) + .add_rows_of_values(["Is", "Not", "Jon"]) + .add_rows_of_values(["1", "2"]) + .add_rows_of_value_lists([["3", "4"], ["5", "6"]]) + ) result = [ - ['Hello', 'My', 'Name', 'Is', '1', '3', '4'], - ['Hello', 'My', 'Name', 'Not', '1', '3', '4'], - ['Hello', 'My', 'Name', 'Jon', '1', '3', '4'], - ['Hello', 'My', 'Name', 'Is', '2', '3', '4'], - ['Hello', 'My', 'Name', 'Not', '2', '3', '4'], - ['Hello', 'My', 'Name', 'Jon', '2', '3', '4'], - ['Hello', 'My', 'Name', 'Is', '1', '5', '6'], - ['Hello', 'My', 'Name', 'Not', '1', '5', '6'], - ['Hello', 'My', 'Name', 'Jon', '1', '5', '6'], - ['Hello', 'My', 'Name', 'Is', '2', '5', '6'], - ['Hello', 'My', 'Name', 'Not', '2', '5', '6'], - ['Hello', 'My', 'Name', 'Jon', '2', '5', '6'], + ["Hello", "My", "Name", "Is", "1", "3", "4"], + ["Hello", "My", "Name", "Not", "1", "3", "4"], + ["Hello", "My", "Name", "Jon", "1", "3", "4"], + ["Hello", "My", "Name", "Is", "2", "3", "4"], + ["Hello", "My", "Name", "Not", "2", "3", "4"], + ["Hello", "My", "Name", "Jon", "2", "3", "4"], + ["Hello", "My", "Name", "Is", "1", "5", "6"], + ["Hello", "My", "Name", "Not", "1", "5", "6"], + ["Hello", "My", "Name", "Jon", "1", "5", "6"], + ["Hello", "My", "Name", "Is", "2", "5", "6"], + ["Hello", "My", "Name", "Not", "2", "5", "6"], + ["Hello", "My", "Name", "Jon", "2", "5", "6"], ] - group_result = [ - 'Hello', 'My', 'Name', 'Is, Not, Jon', '1, 2', '3, 5', '4, 6' - ] + group_result = ["Hello", "My", "Name", "Is, Not, Jon", "1, 2", "3, 5", "4, 6"] self.assertEqual(result, builder.rows) self.assertEqual(group_result, builder.group_rows) diff --git a/apps/export/views.py b/apps/export/views.py index 5dd37589fb..f118b3c135 100644 --- a/apps/export/views.py +++ b/apps/export/views.py @@ -1,23 +1,14 @@ from django.db import transaction -from rest_framework.decorators import action -from rest_framework import ( - permissions, - response, - views, - viewsets, - status, -) - -from deep.celery import app as celery_app -from export.serializers import ExportSerializer +from export.filter_set import ExportFilterSet from export.models import Export +from export.serializers import ExportSerializer +from export.tasks import export_task from project.models import Project from project.permissions import PROJECT_PERMISSIONS -from export.filter_set import ( - ExportFilterSet, -) +from rest_framework import permissions, response, status, views, viewsets +from rest_framework.decorators import action -from export.tasks import export_task +from deep.celery import app as celery_app class MetaExtractionView(views.APIView): @@ -34,14 +25,14 @@ class ExportViewSet(viewsets.ModelViewSet): def get_queryset(self): qs = Export.get_for(self.request.user) - if self.action == 'list': + if self.action == "list": return qs.filter(is_preview=False) return qs @action( detail=True, - url_path='cancel', - methods=('post',), + url_path="cancel", + methods=("post",), ) def cancel(self, request, pk=None, version=None): export = self.get_object() @@ -62,31 +53,28 @@ class ExportTriggerView(views.APIView): permission_classes = [permissions.IsAuthenticated] def post(self, request, version=None): - filters = request.data.get('filters', []) + filters = request.data.get("filters", []) filters = {f[0]: f[1] for f in filters} - project_id = filters.get('project') - export_type = filters.get('export_type', 'excel') - export_item = filters.get('export_item', 'entry') + project_id = filters.get("project") + export_type = filters.get("export_type", "excel") + export_item = filters.get("export_item", "entry") - is_preview = filters.get('is_preview', False) + is_preview = filters.get("is_preview", False) if project_id: project = Project.objects.get(id=project_id) else: project = None - if export_item == 'entry': + if export_item == "entry": type = Export.DataType.ENTRIES - elif export_item == 'assessment': + elif export_item == "assessment": type = Export.DataType.ASSESSMENTS - elif export_item == 'planned_assessment': + elif export_item == "planned_assessment": type = Export.DataType.PLANNED_ASSESSMENTS else: - return response.Response( - {'export_item': 'Invalid export item name'}, - status=status.HTTP_400_BAD_REQUEST - ) + return response.Response({"export_item": "Invalid export item name"}, status=status.HTTP_400_BAD_REQUEST) if project: # Check permission @@ -102,7 +90,7 @@ def post(self, request, version=None): return response.Response({}, status=status.HTTP_403_FORBIDDEN) export = Export.objects.create( - title='Generating Export.....', + title="Generating Export.....", exported_by=request.user, project=project, type=type, @@ -111,10 +99,10 @@ def post(self, request, version=None): filters=filters, ) - transaction.on_commit( - lambda: export.set_task_id(export_task.delay(export.id).id) - ) + transaction.on_commit(lambda: export.set_task_id(export_task.delay(export.id).id)) - return response.Response({ - 'export_triggered': export.id, - }) + return response.Response( + { + "export_triggered": export.id, + } + ) diff --git a/apps/gallery/admin.py b/apps/gallery/admin.py index e95ff4ae25..02c1e285e8 100644 --- a/apps/gallery/admin.py +++ b/apps/gallery/admin.py @@ -3,14 +3,26 @@ from deep.admin import document_preview -from .models import File from .filters import IsTabularListFilter +from .models import File @admin.register(File) class FileAdmin(VersionAdmin): - list_display = ('title', 'file', 'mime_type',) - readonly_fields = (document_preview('file'),) - search_fields = ('title', 'file', 'mime_type', ) + list_display = ( + "title", + "file", + "mime_type", + ) + readonly_fields = (document_preview("file"),) + search_fields = ( + "title", + "file", + "mime_type", + ) list_filter = (IsTabularListFilter,) - autocomplete_fields = ('created_by', 'modified_by', 'projects',) + autocomplete_fields = ( + "created_by", + "modified_by", + "projects", + ) diff --git a/apps/gallery/apps.py b/apps/gallery/apps.py index 07f5d13453..fac387513e 100644 --- a/apps/gallery/apps.py +++ b/apps/gallery/apps.py @@ -2,4 +2,4 @@ class GalleryConfig(AppConfig): - name = 'gallery' + name = "gallery" diff --git a/apps/gallery/dataloaders.py b/apps/gallery/dataloaders.py index 69c9700796..8f24b2aa81 100644 --- a/apps/gallery/dataloaders.py +++ b/apps/gallery/dataloaders.py @@ -1,18 +1,14 @@ -from promise import Promise from django.utils.functional import cached_property +from gallery.models import File +from promise import Promise from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin -from gallery.models import File - class GalleryFileLoader(DataLoaderWithContext): def batch_load_fn(self, keys): qs = File.objects.filter(pk__in=keys) - _map = { - item.pk: item - for item in qs - } + _map = {item.pk: item for item in qs} return Promise.resolve([_map.get(key) for key in keys]) diff --git a/apps/gallery/factories.py b/apps/gallery/factories.py index 99f7ce56b3..3367836e5b 100644 --- a/apps/gallery/factories.py +++ b/apps/gallery/factories.py @@ -1,8 +1,8 @@ import uuid import factory -from factory.django import DjangoModelFactory from django.core.files.base import ContentFile +from factory.django import DjangoModelFactory from .models import File @@ -12,19 +12,17 @@ class Meta: model = File uuid = factory.LazyAttribute(lambda x: str(uuid.uuid4())) - title = factory.Sequence(lambda n: f'file-{n}') + title = factory.Sequence(lambda n: f"file-{n}") file = factory.LazyAttribute( - lambda _: ContentFile( - factory.django.ImageField()._make_data( - {'width': 1024, 'height': 768} - ), 'example.jpg' - ) + lambda _: ContentFile(factory.django.ImageField()._make_data({"width": 1024, "height": 768}), "example.jpg") ) # is_public = factory.Iterator([True, False]) - mime_type = factory.Faker('mime_type') - metadata = factory.Dict({ - 'md5_hash': factory.Sequence(lambda n: f'random-hash-{n}'), - }) + mime_type = factory.Faker("mime_type") + metadata = factory.Dict( + { + "md5_hash": factory.Sequence(lambda n: f"random-hash-{n}"), + } + ) @factory.post_generation def addresses(self, create, extracted, **kwargs): diff --git a/apps/gallery/filters.py b/apps/gallery/filters.py index 6ef9ca00f3..bd62df1952 100644 --- a/apps/gallery/filters.py +++ b/apps/gallery/filters.py @@ -6,21 +6,21 @@ class IsTabularListFilter(admin.SimpleListFilter): # Human-readable title which will be displayed in the # right admin sidebar just above the filter options. - title = _('Is Tabular Image') + title = _("Is Tabular Image") # Parameter for the filter that will be used in the URL query. - parameter_name = 'is_tabular_image' + parameter_name = "is_tabular_image" def lookups(self, request, model_admin): return ( - ('yes', _('Yes')), - ('no', _('No')), + ("yes", _("Yes")), + ("no", _("No")), ) def queryset(self, request, queryset): - if self.value() == 'yes': + if self.value() == "yes": return queryset.filter(metadata__tabular=True) - elif self.value() == 'no': + elif self.value() == "no": return queryset.filter(~Q(metadata__tabular=True)) else: return queryset.all() diff --git a/apps/gallery/management/commands/calculate_and_store_file_size.py b/apps/gallery/management/commands/calculate_and_store_file_size.py index a986a18b33..aa424c2638 100644 --- a/apps/gallery/management/commands/calculate_and_store_file_size.py +++ b/apps/gallery/management/commands/calculate_and_store_file_size.py @@ -1,9 +1,8 @@ import botocore -from django.db.models.functions import Cast from django.contrib.postgres.fields.jsonb import KeyTextTransform from django.core.management.base import BaseCommand -from django.db.models import Q, IntegerField - +from django.db.models import IntegerField, Q +from django.db.models.functions import Cast from gallery.models import File @@ -11,20 +10,20 @@ class Command(BaseCommand): def handle(self, *args, **options): qs = File.objects.annotate( file_size=Cast( - KeyTextTransform('file_size', 'metadata'), + KeyTextTransform("file_size", "metadata"), IntegerField(), ) - ).filter(~Q(file=''), file_size__isnull=True) + ).filter(~Q(file=""), file_size__isnull=True) to_process_count = qs.count() index = 1 for file in qs.iterator(): file.metadata = file.metadata or {} try: - file.metadata['file_size'] = file.file.size - print(f'Processed {index}/{to_process_count}', end='\r', flush=True) - file.save(update_fields=['metadata']) + file.metadata["file_size"] = file.file.size + print(f"Processed {index}/{to_process_count}", end="\r", flush=True) + file.save(update_fields=["metadata"]) index += 1 except botocore.exceptions.ClientError: pass - print(f'\nProcessed: {index}/{to_process_count} files successfully') + print(f"\nProcessed: {index}/{to_process_count} files successfully") diff --git a/apps/gallery/models.py b/apps/gallery/models.py index a30f98a61c..3e45aa8863 100644 --- a/apps/gallery/models.py +++ b/apps/gallery/models.py @@ -1,7 +1,8 @@ import uuid as python_uuid + +from django.conf import settings from django.contrib.postgres.fields import ArrayField from django.db import models -from django.conf import settings from django.urls import reverse from user_resource.models import UserResource @@ -10,13 +11,12 @@ class File(UserResource): uuid = models.UUIDField(default=python_uuid.uuid4, editable=False, unique=True) title = models.CharField(max_length=255) - file = models.FileField(upload_to='gallery/', max_length=255, - null=True, blank=True, default=None) + file = models.FileField(upload_to="gallery/", max_length=255, null=True, blank=True, default=None) mime_type = models.CharField(max_length=130, blank=True, null=True) metadata = models.JSONField(default=None, blank=True, null=True) is_public = models.BooleanField(default=False) - projects = models.ManyToManyField('project.Project', blank=True) + projects = models.ManyToManyField("project.Project", blank=True) def __str__(self): return self.title @@ -35,13 +35,13 @@ def can_get(self, user): # return self in File.get_for(user) def get_file_url(self): - return '{protocol}://{domain}{url}'.format( + return "{protocol}://{domain}{url}".format( protocol=settings.HTTP_PROTOCOL, domain=settings.DJANGO_API_HOST, url=reverse( - 'gallery_private_url', - kwargs={'uuid': self.uuid, 'filename': self.title}, - ) + "gallery_private_url", + kwargs={"uuid": self.uuid, "filename": self.title}, + ), ) def can_modify(self, user): @@ -55,4 +55,4 @@ class FilePreview(models.Model): extracted = models.BooleanField(default=False) def __str__(self): - return 'Text extracted for {}'.format(self.file) + return "Text extracted for {}".format(self.file) diff --git a/apps/gallery/mutations.py b/apps/gallery/mutations.py index 385202f7cb..e50afc7e37 100644 --- a/apps/gallery/mutations.py +++ b/apps/gallery/mutations.py @@ -1,20 +1,15 @@ import graphene from utils.graphene.mutation import ( - generate_input_type_for_serializer, PsGrapheneMutation, + generate_input_type_for_serializer, ) from .models import File from .schema import GalleryFileType -from .serializers import ( - FileSerializer -) +from .serializers import FileSerializer -FileUploadInputType = generate_input_type_for_serializer( - 'FileUploadInputType', - serializer_class=FileSerializer -) +FileUploadInputType = generate_input_type_for_serializer("FileUploadInputType", serializer_class=FileSerializer) class UploadFile(PsGrapheneMutation): @@ -27,5 +22,5 @@ class Arguments: permissions = [] -class Mutation(): +class Mutation: file_upload = UploadFile.Field() diff --git a/apps/gallery/schema.py b/apps/gallery/schema.py index e6f1d7e239..228dd37afa 100644 --- a/apps/gallery/schema.py +++ b/apps/gallery/schema.py @@ -2,6 +2,7 @@ from graphene_django import DjangoObjectType from utils.graphene.types import FileFieldType + from .models import File @@ -9,11 +10,12 @@ class GalleryFileType(DjangoObjectType): class Meta: model = File only_fields = ( - 'id', - 'title', - 'mime_type', - 'metadata', + "id", + "title", + "mime_type", + "metadata", ) + file = graphene.Field(FileFieldType) @staticmethod @@ -26,7 +28,11 @@ class PublicGalleryFileType(DjangoObjectType): class Meta: model = File skip_registry = True - only_fields = ('title', 'mime_type',) + only_fields = ( + "title", + "mime_type", + ) + file = graphene.Field(FileFieldType) @staticmethod diff --git a/apps/gallery/serializers.py b/apps/gallery/serializers.py index a0943cd18f..222f0ae4e6 100644 --- a/apps/gallery/serializers.py +++ b/apps/gallery/serializers.py @@ -1,77 +1,81 @@ -from drf_dynamic_fields import DynamicFieldsMixin -from rest_framework import serializers +import logging +import os from django.core.files.uploadedfile import InMemoryUploadedFile +from drf_dynamic_fields import DynamicFieldsMixin +from rest_framework import serializers +from user_resource.serializers import UserResourceSerializer -from deep.serializers import RemoveNullFieldsMixin, URLCachedFileField import deep.documents_types as deep_doc_types -from user_resource.serializers import UserResourceSerializer -from utils.external_storages.google_drive import download as g_download +from deep.serializers import RemoveNullFieldsMixin, URLCachedFileField +from utils.common import calculate_md5 from utils.external_storages.dropbox import download as d_download +from utils.external_storages.google_drive import download as g_download from utils.extractor.formats.docx import get_pages_in_docx from utils.extractor.formats.pdf import get_pages_in_pdf -from utils.common import calculate_md5 -from .models import File, FilePreview -import os -import logging +from .models import File, FilePreview logger = logging.getLogger(__name__) -FILE_READONLY_FIELDS = ('metadata', 'mime_type',) +FILE_READONLY_FIELDS = ( + "metadata", + "mime_type", +) -class SimpleFileSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): +class SimpleFileSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): title = serializers.CharField(required=False, read_only=True) file = URLCachedFileField(required=False, read_only=True) mime_type = serializers.CharField(required=False, read_only=True) class Meta: model = File - fields = ('id', 'title', 'file', 'mime_type') + fields = ("id", "title", "file", "mime_type") read_only_fields = FILE_READONLY_FIELDS -class FileSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, UserResourceSerializer): +class FileSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): file = URLCachedFileField(required=True, read_only=False) class Meta: model = File - exclude = ('created_by',) + exclude = ("created_by",) read_only_fields = FILE_READONLY_FIELDS # Validations def validate_file(self, file): extension = os.path.splitext(file.name)[1][1:] - if file.content_type not in deep_doc_types.DEEP_SUPPORTED_MIME_TYPES\ - and extension not in deep_doc_types.DEEP_SUPPORTED_EXTENSIONS: - raise serializers.ValidationError( - 'Unsupported file type {}'.format(file.content_type)) + if ( + file.content_type not in deep_doc_types.DEEP_SUPPORTED_MIME_TYPES + and extension not in deep_doc_types.DEEP_SUPPORTED_EXTENSIONS + ): + raise serializers.ValidationError("Unsupported file type {}".format(file.content_type)) return file def _get_metadata(self, file): - metadata = {'md5_hash': calculate_md5(file.file)} + metadata = {"md5_hash": calculate_md5(file.file)} mime_type = file.content_type if mime_type in deep_doc_types.PDF_MIME_TYPES: - metadata.update({ - 'pages': get_pages_in_pdf(file.file), - }) + metadata.update( + { + "pages": get_pages_in_pdf(file.file), + } + ) elif mime_type in deep_doc_types.DOCX_MIME_TYPES: - metadata.update({ - 'pages': get_pages_in_docx(file.file), - }) + metadata.update( + { + "pages": get_pages_in_docx(file.file), + } + ) return metadata def create(self, validated_data): - validated_data['mime_type'] = validated_data.get('file').content_type + validated_data["mime_type"] = validated_data.get("file").content_type try: - validated_data['metadata'] = self._get_metadata( - validated_data.get('file') - ) + validated_data["metadata"] = self._get_metadata(validated_data.get("file")) except Exception: - logger.error('File create Failed!!', exc_info=True) + logger.error("File create Failed!!", exc_info=True) return super().create(validated_data) @@ -83,13 +87,13 @@ class GoogleDriveFileSerializer(FileSerializer): class Meta: model = File - fields = ('__all__') + fields = "__all__" def create(self, validated_data): - title = validated_data.get('title') - access_token = validated_data.pop('access_token') - file_id = validated_data.pop('file_id') - mime_type = validated_data.get('mime_type', '') + title = validated_data.get("title") + access_token = validated_data.pop("access_token") + file_id = validated_data.pop("file_id") + mime_type = validated_data.get("mime_type", "") file = g_download( file_id, @@ -99,9 +103,7 @@ def create(self, validated_data): ) # TODO: is this good? - validated_data['file'] = InMemoryUploadedFile( - file, None, title, mime_type, None, None - ) + validated_data["file"] = InMemoryUploadedFile(file, None, title, mime_type, None, None) return super().create(validated_data) @@ -112,11 +114,11 @@ class DropboxFileSerializer(FileSerializer): class Meta: model = File - fields = ('__all__') + fields = "__all__" def create(self, validated_data): - title = validated_data.get('title') - file_url = validated_data.pop('file_url') + title = validated_data.get("title") + file_url = validated_data.pop("file_url") file, mime_type = d_download( file_url, @@ -124,17 +126,14 @@ def create(self, validated_data): ) # TODO: is this good? - validated_data['file'] = InMemoryUploadedFile( - file, None, title, mime_type, None, None - ) + validated_data["file"] = InMemoryUploadedFile(file, None, title, mime_type, None, None) - validated_data['mime_type'] = mime_type + validated_data["mime_type"] = mime_type return super().create(validated_data) -class FilePreviewSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, serializers.ModelSerializer): +class FilePreviewSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): class Meta: model = FilePreview - fields = ('id', 'text', 'ngrams', 'extracted') + fields = ("id", "text", "ngrams", "extracted") diff --git a/apps/gallery/tasks.py b/apps/gallery/tasks.py index ab4b9c8c3c..ee73aef8b2 100644 --- a/apps/gallery/tasks.py +++ b/apps/gallery/tasks.py @@ -1,14 +1,12 @@ -import reversion import logging -from redis_store import redis + +import reversion from celery import shared_task +from gallery.models import File, FilePreview +from redis_store import redis -from utils.extractor.file_document import FileDocument from utils.common import sanitize_text -from gallery.models import ( - File, - FilePreview, -) +from utils.extractor.file_document import FileDocument logger = logging.getLogger(__name__) @@ -20,7 +18,7 @@ def _extract_from_file_core(file_preview_id): files = File.objects.filter(id__in=file_preview.file_ids) with reversion.create_revision(): - all_text = '' + all_text = "" for i, file in enumerate(files): try: @@ -32,10 +30,10 @@ def _extract_from_file_core(file_preview_id): text = sanitize_text(text) if i != 0: - all_text += '\n\n' + all_text += "\n\n" all_text += text except Exception: - logger.error('gallery._extract_from_file_core', exc_info=True) + logger.error("gallery._extract_from_file_core", exc_info=True) continue if all_text: file_preview.text = all_text @@ -47,7 +45,7 @@ def _extract_from_file_core(file_preview_id): @shared_task def extract_from_file(file_preview_id): - key = 'file_extraction_{}'.format(file_preview_id) + key = "file_extraction_{}".format(file_preview_id) lock = redis.get_lock(key, 60 * 60 * 24) # Lock lifetime 24 hours have_lock = lock.acquire(blocking=False) if not have_lock: @@ -56,7 +54,7 @@ def extract_from_file(file_preview_id): try: return_value = _extract_from_file_core(file_preview_id) except Exception: - logger.error('gallery.extract_from_file', exc_info=True) + logger.error("gallery.extract_from_file", exc_info=True) return_value = False lock.release() diff --git a/apps/gallery/tests/test_apis.py b/apps/gallery/tests/test_apis.py index 167849e166..38211ff362 100644 --- a/apps/gallery/tests/test_apis.py +++ b/apps/gallery/tests/test_apis.py @@ -1,16 +1,16 @@ import os import tempfile -from django.urls import reverse from django.conf import settings -from django.utils.http import urlsafe_base64_encode +from django.urls import reverse from django.utils.encoding import force_bytes - -from deep.tests import TestCase +from django.utils.http import urlsafe_base64_encode +from entry.models import Entry from gallery.models import File, FilePreview from lead.models import Lead from project.models import Project -from entry.models import Entry + +from deep.tests import TestCase class GalleryTests(TestCase): @@ -18,11 +18,11 @@ def setUp(self): super().setUp() tmp_file = tempfile.NamedTemporaryFile(delete=False) - tmp_file.write(b'Hello world') + tmp_file.write(b"Hello world") tmp_file.close() - path = os.path.join(settings.TEST_DIR, 'documents') - self.supported_file = os.path.join(path, 'doc.docx') + path = os.path.join(settings.TEST_DIR, "documents") + self.supported_file = os.path.join(path, "doc.docx") self.unsupported_file = tmp_file.name @@ -32,20 +32,20 @@ def tearDown(self): def test_upload_supported_file(self): file_count = File.objects.count() - url = '/api/v1/files/' + url = "/api/v1/files/" data = { - 'title': 'Test file', - 'file': open(self.supported_file, 'rb'), - 'isPublic': True, + "title": "Test file", + "file": open(self.supported_file, "rb"), + "isPublic": True, } self.authenticate() - response = self.client.post(url, data, format='multipart') + response = self.client.post(url, data, format="multipart") self.assert_201(response) self.assertEqual(File.objects.count(), file_count + 1) - self.assertEqual(response.data['title'], data['title']) + self.assertEqual(response.data["title"], data["title"]) # Let's delete the file from the filesystem to keep # things clean @@ -58,12 +58,12 @@ def test_upload_supported_file(self): def test_upload_unsupported_file(self): file_count = File.objects.count() - url = '/api/v1/files/' + url = "/api/v1/files/" data = { - 'title': 'Test file', - 'file': open(self.unsupported_file, 'rb'), - 'isPublic': True, + "title": "Test file", + "file": open(self.unsupported_file, "rb"), + "isPublic": True, } self.authenticate() @@ -73,70 +73,69 @@ def test_upload_unsupported_file(self): self.assertEqual(File.objects.count(), file_count) def test_trigger_api(self): - url = '/api/v1/file-extraction-trigger/' + url = "/api/v1/file-extraction-trigger/" data = { - 'file_ids': [1], + "file_ids": [1], } self.authenticate() response = self.client.post(url, data) self.assert_200(response) - self.assertTrue(FilePreview.objects.filter( - id=response.data['extraction_triggered'] - ).exists()) + self.assertTrue(FilePreview.objects.filter(id=response.data["extraction_triggered"]).exists()) def test_duplicate_trigger_api(self): preview = self.create(FilePreview, file_ids=[1, 2]) - url = '/api/v1/file-extraction-trigger/' + url = "/api/v1/file-extraction-trigger/" data = { - 'file_ids': [2, 1], + "file_ids": [2, 1], } self.authenticate() response = self.client.post(url, data) self.assert_200(response) - self.assertEqual(response.data['extraction_triggered'], preview.id) + self.assertEqual(response.data["extraction_triggered"], preview.id) def test_preview_api(self): preview = self.create(FilePreview, file_ids=[]) - url = '/api/v1/file-previews/{}/'.format(preview.id) + url = "/api/v1/file-previews/{}/".format(preview.id) self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['text'], preview.text) + self.assertEqual(response.data["text"], preview.text) def test_meta_api_no_file(self): - url = 'api/v1/meta-extraction/1000/' + url = "api/v1/meta-extraction/1000/" self.authenticate() response = self.client.get(url) self.assert_404(response) def test_get_file_private_no_random_string(self): - url = '/private-file/1/' + url = "/private-file/1/" self.authenticate() response = self.client.get(url) self.assert_404(response) def test_public_to_private_file_url(self): - urlf_public = '/public-file/{}/{}/{}' + urlf_public = "/public-file/{}/{}/{}" file_id = self.save_file_with_api() file = File.objects.get(id=file_id) url = urlf_public.format( urlsafe_base64_encode(force_bytes(file_id)), - 'random-strings-xxyyzz', + "random-strings-xxyyzz", file.title, ) - redirect_url = 'http://testserver' + reverse( - 'gallery_private_url', + redirect_url = "http://testserver" + reverse( + "gallery_private_url", kwargs={ - 'uuid': file.uuid, 'filename': file.title, + "uuid": file.uuid, + "filename": file.title, }, ) response = self.client.get(url) @@ -144,10 +143,10 @@ def test_public_to_private_file_url(self): assert response.url == redirect_url, f"Should return {redirect_url}" def test_private_file_url(self): - urlf = '/private-file/{}/{}' + urlf = "/private-file/{}/{}" - file_id = self.save_file_with_api({'isPublic': False}) - entry_file_id = self.save_file_with_api({'isPublic': False}) + file_id = self.save_file_with_api({"isPublic": False}) + entry_file_id = self.save_file_with_api({"isPublic": False}) file = File.objects.get(id=file_id) entry_file = File.objects.get(id=entry_file_id) @@ -190,18 +189,18 @@ def test_private_file_url(self): assert response.status_code == 302, "Should return 302 redirect" def save_file_with_api(self, kwargs={}): - url = '/api/v1/files/' + url = "/api/v1/files/" data = { - 'title': 'Test file', - 'file': open(self.supported_file, 'rb'), - 'isPublic': True, + "title": "Test file", + "file": open(self.supported_file, "rb"), + "isPublic": True, **kwargs, } self.authenticate() - response = self.client.post(url, data, format='multipart') + response = self.client.post(url, data, format="multipart") self.assert_201(response) - return response.data['id'] + return response.data["id"] # NOTE: Test for files diff --git a/apps/gallery/tests/test_mutations.py b/apps/gallery/tests/test_mutations.py index 1c8ba84764..9f633b0eb7 100644 --- a/apps/gallery/tests/test_mutations.py +++ b/apps/gallery/tests/test_mutations.py @@ -1,15 +1,15 @@ -from utils.graphene.tests import GraphQLTestCase import json -from graphene_file_upload.django.testing import GraphQLFileUploadTestCase from django.core.files.temp import NamedTemporaryFile - -from user.factories import UserFactory from gallery.models import File +from graphene_file_upload.django.testing import GraphQLFileUploadTestCase +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLTestCase class TestUploadFileMutation(GraphQLFileUploadTestCase, GraphQLTestCase): - UPLOAD_FILE = ''' + UPLOAD_FILE = """ mutation MyMutation ($data: FileUploadInputType!) { fileUpload(data: $data) { ok @@ -23,41 +23,34 @@ class TestUploadFileMutation(GraphQLFileUploadTestCase, GraphQLTestCase): } } } -''' +""" def setUp(self): super().setUp() - self.variables = { - "data": {"title": 'test', "file": None} - } + self.variables = {"data": {"title": "test", "file": None}} self.user = UserFactory.create() self.force_login(self.user) def test_upload_file(self): - file_text = b'preview image text' - with NamedTemporaryFile(suffix='.jpeg') as t_file: + file_text = b"preview image text" + with NamedTemporaryFile(suffix=".jpeg") as t_file: t_file.write(file_text) t_file.seek(0) response = self._client.post( - '/graphql', + "/graphql", data={ - 'operations': json.dumps({ - 'query': self.UPLOAD_FILE, - 'variables': self.variables - }), - 't_file': t_file, - 'map': json.dumps({ - 't_file': ['variables.data.file'] - }) - } + "operations": json.dumps({"query": self.UPLOAD_FILE, "variables": self.variables}), + "t_file": t_file, + "map": json.dumps({"t_file": ["variables.data.file"]}), + }, ) content = response.json() self.assertResponseNoErrors(response) - created_by_user = File.objects.get(id=content['data']['fileUpload']['result']['id']).created_by - self.assertTrue(content['data']['fileUpload']['ok'], content) - self.assertTrue(content['data']['fileUpload']['result']['file']["name"]) - file_name = content['data']['fileUpload']['result']['file']["name"] - file_url = content['data']['fileUpload']['result']['file']["url"] - self.assertTrue(file_name.endswith('.jpeg')) + created_by_user = File.objects.get(id=content["data"]["fileUpload"]["result"]["id"]).created_by + self.assertTrue(content["data"]["fileUpload"]["ok"], content) + self.assertTrue(content["data"]["fileUpload"]["result"]["file"]["name"]) + file_name = content["data"]["fileUpload"]["result"]["file"]["name"] + file_url = content["data"]["fileUpload"]["result"]["file"]["url"] + self.assertTrue(file_name.endswith(".jpeg")) self.assertEqual(created_by_user, self.user) self.assertTrue(file_url.endswith(file_name)) diff --git a/apps/gallery/tests/test_tasks.py b/apps/gallery/tests/test_tasks.py index e8414971f2..1ff38adaa2 100644 --- a/apps/gallery/tests/test_tasks.py +++ b/apps/gallery/tests/test_tasks.py @@ -1,34 +1,31 @@ -from django.core.files.uploadedfile import SimpleUploadedFile +import logging +from os.path import join + from django.conf import settings +from django.core.files.uploadedfile import SimpleUploadedFile from django.test import TestCase +from gallery.models import File, FilePreview from gallery.tasks import extract_from_file -from os.path import join -import logging -from utils.common import ( - get_or_write_file, - makedirs, -) +from utils.common import get_or_write_file, makedirs from utils.extractor.tests.test_file_document import DOCX_FILE -from gallery.models import File, FilePreview - logger = logging.getLogger(__name__) class ExtractFromFileTaskTest(TestCase): def setUp(self): # This is similar to test_file_document - self.path = join(settings.TEST_DIR, 'documents_attachment') - self.documents = join(settings.TEST_DIR, 'documents') + self.path = join(settings.TEST_DIR, "documents_attachment") + self.documents = join(settings.TEST_DIR, "documents") makedirs(self.path) # Create the sample file self.file = File.objects.create( - title='test', + title="test", file=SimpleUploadedFile( name=DOCX_FILE, - content=open(join(self.documents, DOCX_FILE), 'rb').read(), + content=open(join(self.documents, DOCX_FILE), "rb").read(), ), ) @@ -39,7 +36,7 @@ def setUp(self): def test_extraction(self): # TODO: - print('SKIPING THIS AS WE ARE NOT USING DEEPL RIGHT NOW') + print("SKIPING THIS AS WE ARE NOT USING DEEPL RIGHT NOW") return # Check if extraction works succesfully result = extract_from_file(self.file_preview.id) @@ -49,17 +46,15 @@ def test_extraction(self): self.file_preview = FilePreview.objects.get(id=self.file_preview.id) if not self.file_preview.extracted: border_len = 50 - logger.warning('*' * border_len) - logger.warning('---- File extraction is not working ----') - logger.warning('Probably an issue with DEEPL integration') - logger.warning('*' * border_len) + logger.warning("*" * border_len) + logger.warning("---- File extraction is not working ----") + logger.warning("Probably an issue with DEEPL integration") + logger.warning("*" * border_len) # This is similar to test_file_document path = join(self.path, DOCX_FILE) - extracted = get_or_write_file( - path + '.txt', self.file_preview.text - ) + extracted = get_or_write_file(path + ".txt", self.file_preview.text) self.assertEqual( - ' '.join(self.file_preview.text.split()), - ' '.join(extracted.read().split()), + " ".join(self.file_preview.text.split()), + " ".join(extracted.read().split()), ) diff --git a/apps/gallery/views.py b/apps/gallery/views.py index 79d8b21439..b3e89ae2f9 100644 --- a/apps/gallery/views.py +++ b/apps/gallery/views.py @@ -1,51 +1,48 @@ import logging -from django.urls import reverse -from django.views.generic import View + +import django_filters from django.conf import settings from django.db import models, transaction +from django.shortcuts import get_object_or_404, redirect +from django.urls import reverse from django.utils.encoding import force_text from django.utils.http import urlsafe_base64_decode -from django.shortcuts import redirect, get_object_or_404 - +from django.views.generic import View +from entry.models import Entry +from lead.models import Lead +from project.models import Project from rest_framework import ( - views, - viewsets, - permissions, - response, + decorators, + exceptions, filters, mixins, - exceptions, - decorators, + permissions, + response, status, + views, + viewsets, ) -import django_filters +from user_resource.filters import UserResourceFilterSet -from deep.permissions import ModifyPermission from deep.permalinks import Permalink -from project.models import Project -from lead.models import Lead -from entry.models import Entry -from user_resource.filters import UserResourceFilterSet +from deep.permissions import ModifyPermission +from utils.extractor.formats import ods, xlsx -from utils.extractor.formats import ( - xlsx, - ods, -) +from .models import File, FilePreview from .serializers import ( + DropboxFileSerializer, + FilePreviewSerializer, FileSerializer, GoogleDriveFileSerializer, - DropboxFileSerializer, - FilePreviewSerializer ) from .tasks import extract_from_file -from .models import File, FilePreview logger = logging.getLogger(__name__) META_EXTRACTION_FUNCTIONS = { # The functions take file as argument - 'xlsx': xlsx.extract_meta, - 'ods': ods.extract_meta, + "xlsx": xlsx.extract_meta, + "ods": ods.extract_meta, } @@ -56,24 +53,25 @@ def DEFAULT_EXTRACTION_FUNCTION(file): # TODO: Remove this after all entry images are migrated class FileView(View): def get(self, request, file_id): - return response.Response({ - 'error': 'This API is deprecated', - }, status=status.HTTP_403_FORBIDDEN) + return response.Response( + { + "error": "This API is deprecated", + }, + status=status.HTTP_403_FORBIDDEN, + ) class PrivateFileView(views.APIView): def get(self, request, uuid=None, filename=None): - queryset = File.objects.prefetch_related('lead_set') + queryset = File.objects.prefetch_related("lead_set") file = get_object_or_404(queryset, uuid=uuid) if file.lead_set.count() == 1: # Redirect to new url - return redirect( - Permalink.lead_share_view(file.lead_set.first().uuid) - ) + return redirect(Permalink.lead_share_view(file.lead_set.first().uuid)) # Redirect to old url return redirect( reverse( - 'deprecated_gallery_private_url', + "deprecated_gallery_private_url", kwargs=dict( uuid=uuid, filename=filename, @@ -86,45 +84,55 @@ class DeprecatedPrivateFileView(views.APIView): permission_classes = [permissions.IsAuthenticated] def get(self, request, uuid=None, filename=None): - queryset = File.objects.prefetch_related('lead_set') + queryset = File.objects.prefetch_related("lead_set") file = get_object_or_404(queryset, uuid=uuid) user = request.user - leads_pk = file.lead_set.values_list('pk', flat=True) + leads_pk = file.lead_set.values_list("pk", flat=True) if ( - file.is_public or - Lead.get_for(user).filter(pk__in=leads_pk).exists() or - Entry.get_for(user).filter(image=file).exists() or - Entry.get_for(user).filter( - image_raw=request.build_absolute_uri( - reverse('file', kwargs={'file_id': file.pk}), - ) - ).exists() - # TODO: Add Profile + file.is_public + or Lead.get_for(user).filter(pk__in=leads_pk).exists() + or Entry.get_for(user).filter(image=file).exists() + or Entry.get_for(user) + .filter( + image_raw=request.build_absolute_uri( + reverse("file", kwargs={"file_id": file.pk}), + ) + ) + .exists() + # TODO: Add Profile ): if file.file: return redirect(request.build_absolute_uri(file.file.url)) - return response.Response({ - 'error': 'File doesn\'t exists', - }, status=status.HTTP_404_NOT_FOUND) - return response.Response({ - 'error': 'Access Forbidden, Contact Admin', - }, status=status.HTTP_403_FORBIDDEN) + return response.Response( + { + "error": "File doesn't exists", + }, + status=status.HTTP_404_NOT_FOUND, + ) + return response.Response( + { + "error": "Access Forbidden, Contact Admin", + }, + status=status.HTTP_403_FORBIDDEN, + ) class PublicFileView(View): """ NOTE: Public File API is deprecated. """ + def get(self, request, fidb64=None, token=None, filename=None): file_id = force_text(urlsafe_base64_decode(fidb64)) file = get_object_or_404(File, id=file_id) return redirect( request.build_absolute_uri( reverse( - 'gallery_private_url', + "gallery_private_url", kwargs={ - 'uuid': file.uuid, 'filename': filename, + "uuid": file.uuid, + "filename": filename, }, ) ) @@ -135,10 +143,7 @@ def filter_files_by_projects(qs, name, value): if len(value) == 0: return qs - return qs.filter( - models.Q(projects__in=value) | - models.Q(lead__project__in=value) - ) + return qs.filter(models.Q(projects__in=value) | models.Q(lead__project__in=value)) class FileFilterSet(UserResourceFilterSet): @@ -152,7 +157,7 @@ class FileFilterSet(UserResourceFilterSet): """ projects = django_filters.ModelMultipleChoiceFilter( - field_name='projects', + field_name="projects", queryset=Project.objects.all(), widget=django_filters.widgets.CSVWidget, method=filter_files_by_projects, @@ -160,13 +165,13 @@ class FileFilterSet(UserResourceFilterSet): class Meta: model = File - fields = ['id', 'title', 'mime_type'] + fields = ["id", "title", "mime_type"] filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda f: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda f: { + "lookup_expr": "icontains", }, }, } @@ -174,37 +179,33 @@ class Meta: class FileViewSet(viewsets.ModelViewSet): serializer_class = FileSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] - filter_backends = (django_filters.rest_framework.DjangoFilterBackend, - filters.SearchFilter, filters.OrderingFilter) + permission_classes = [permissions.IsAuthenticated, ModifyPermission] + filter_backends = (django_filters.rest_framework.DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter) filterset_class = FileFilterSet - search_fields = ('title', 'file') + search_fields = ("title", "file") def get_queryset(self): return File.get_for(self.request.user) @decorators.action( detail=True, - url_path='preview', + url_path="preview", ) def get_preview(self, request, pk=None, version=None): obj = self.get_object() - url = self.get_serializer(obj).data.get('file') + url = self.get_serializer(obj).data.get("file") response = redirect(request.build_absolute_uri(url)) return response class GoogleDriveFileViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet): serializer_class = GoogleDriveFileSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] class DropboxFileViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet): serializer_class = DropboxFileSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] class FilePreviewViewSet(viewsets.ReadOnlyModelViewSet): @@ -219,7 +220,7 @@ class FileExtractionTriggerView(views.APIView): permission_classes = [permissions.IsAuthenticated] def post(self, request, version=None): - file_ids = request.data.get('file_ids') + file_ids = request.data.get("file_ids") # Check if preview with same file ids already exists existing = FilePreview.objects.filter( @@ -229,9 +230,11 @@ def post(self, request, version=None): # If so, just return that if existing: - return response.Response({ - 'extraction_triggered': existing.id, - }) + return response.Response( + { + "extraction_triggered": existing.id, + } + ) file_preview = FilePreview.objects.create( file_ids=file_ids, @@ -239,36 +242,35 @@ def post(self, request, version=None): ) if not settings.TESTING: - transaction.on_commit( - lambda: extract_from_file.delay(file_preview.id) - ) + transaction.on_commit(lambda: extract_from_file.delay(file_preview.id)) - return response.Response({ - 'extraction_triggered': file_preview.id, - }) + return response.Response( + { + "extraction_triggered": file_preview.id, + } + ) class MetaExtractionView(views.APIView): permission_classes = [permissions.IsAuthenticated] def get(self, request, file_id=None, version=None): - file_type = request.query_params.get('file_type') + file_type = request.query_params.get("file_type") if file_type is None: - raise exceptions.ValidationError({ - 'file_type': 'file_type should be present' - }) + raise exceptions.ValidationError({"file_type": "file_type should be present"}) file = File.objects.filter(id=file_id).first() if file is None: raise exceptions.NotFound() extraction_function = META_EXTRACTION_FUNCTIONS.get( - file_type, DEFAULT_EXTRACTION_FUNCTION, + file_type, + DEFAULT_EXTRACTION_FUNCTION, ) try: return response.Response(extraction_function(file.file)) except Exception: - logger.warning('Exception while extracting file {}'.format(file.id)) - raise exceptions.ValidationError('Can\'t get metadata. Check if the file has correct format.') + logger.warning("Exception while extracting file {}".format(file.id)) + raise exceptions.ValidationError("Can't get metadata. Check if the file has correct format.") diff --git a/apps/geo/admin.py b/apps/geo/admin.py index 3e1031e539..d1fd5049f5 100644 --- a/apps/geo/admin.py +++ b/apps/geo/admin.py @@ -1,66 +1,62 @@ -from django.contrib import admin +from django.contrib import admin, messages from django.utils.safestring import mark_safe -from django.contrib import messages from deep.admin import VersionAdmin, linkify -from .models import Region, AdminLevel, GeoArea -from .tasks import cal_region_cache, cal_admin_level_cache +from .models import AdminLevel, GeoArea, Region +from .tasks import cal_admin_level_cache, cal_region_cache def trigger_region_cache_reset(_, request, queryset): - cal_region_cache.delay( - list(queryset.values_list('id', flat=True).distinct()) - ) + cal_region_cache.delay(list(queryset.values_list("id", flat=True).distinct())) messages.add_message( - request, messages.INFO, + request, + messages.INFO, mark_safe( - 'Successfully triggered regions:

' + - '
'.join( - '* {0} : ({1}) {2}'.format(*value) - for value in queryset.values_list('id', 'code', 'title').distinct() - ) - ) + "Successfully triggered regions:

" + + "
".join("* {0} : ({1}) {2}".format(*value) for value in queryset.values_list("id", "code", "title").distinct()) + ), ) -trigger_region_cache_reset.short_description = 'Trigger cache reset for selected Regions' +trigger_region_cache_reset.short_description = "Trigger cache reset for selected Regions" def trigger_admin_level_cache_reset(_, request, queryset): - cal_admin_level_cache.delay( - list(queryset.values_list('id', flat=True).distinct()) - ) + cal_admin_level_cache.delay(list(queryset.values_list("id", flat=True).distinct())) messages.add_message( - request, messages.INFO, + request, + messages.INFO, mark_safe( - 'Successfully triggered Admin Levels:

' + - '
'.join( - '* {0} : (level={1}) {2}'.format(*value) - for value in queryset.values_list('id', 'level', 'title').distinct() + "Successfully triggered Admin Levels:

" + + "
".join( + "* {0} : (level={1}) {2}".format(*value) for value in queryset.values_list("id", "level", "title").distinct() ) - ) + ), ) -trigger_admin_level_cache_reset.short_description = 'Trigger cache reset for selected AdminLevels' +trigger_admin_level_cache_reset.short_description = "Trigger cache reset for selected AdminLevels" class AdminLevelInline(admin.StackedInline): model = AdminLevel - autocomplete_fields = ('parent', 'geo_shape_file',) - exclude = ('geo_area_titles',) + autocomplete_fields = ( + "parent", + "geo_shape_file", + ) + exclude = ("geo_area_titles",) max_num = 0 @admin.register(Region) class RegionAdmin(VersionAdmin): - list_display = ('title', 'project_count') - search_fields = ('title',) + list_display = ("title", "project_count") + search_fields = ("title",) inlines = [AdminLevelInline] - exclude = ('geo_options',) + exclude = ("geo_options",) actions = [trigger_region_cache_reset] - autocomplete_fields = ('created_by', 'modified_by') + autocomplete_fields = ("created_by", "modified_by") list_per_page = 10 def project_count(self, instance): @@ -69,17 +65,31 @@ def project_count(self, instance): @admin.register(AdminLevel) class AdminLevelAdmin(VersionAdmin): - search_fields = ('title', 'region__title',) - list_display = ('title', linkify('region'),) - autocomplete_fields = ('region',) + AdminLevelInline.autocomplete_fields - exclude = ('geo_area_titles',) + search_fields = ( + "title", + "region__title", + ) + list_display = ( + "title", + linkify("region"), + ) + autocomplete_fields = ("region",) + AdminLevelInline.autocomplete_fields + exclude = ("geo_area_titles",) actions = [trigger_admin_level_cache_reset] list_per_page = 10 @admin.register(GeoArea) class GeoAreaAdmin(VersionAdmin): - search_fields = ('title',) - list_display = ('title', linkify('admin_level'), linkify('parent'), 'code',) - autocomplete_fields = ('parent', 'admin_level',) + search_fields = ("title",) + list_display = ( + "title", + linkify("admin_level"), + linkify("parent"), + "code", + ) + autocomplete_fields = ( + "parent", + "admin_level", + ) list_per_page = 10 diff --git a/apps/geo/apps.py b/apps/geo/apps.py index 7799a7487b..e40f307c3d 100644 --- a/apps/geo/apps.py +++ b/apps/geo/apps.py @@ -2,7 +2,7 @@ class GeoConfig(AppConfig): - name = 'geo' + name = "geo" def ready(self): import utils.db.functions # noqa diff --git a/apps/geo/dataloaders.py b/apps/geo/dataloaders.py index d9c73e200d..f50b8a1c2d 100644 --- a/apps/geo/dataloaders.py +++ b/apps/geo/dataloaders.py @@ -1,20 +1,20 @@ from collections import defaultdict -from promise import Promise -from django.utils.functional import cached_property -from django.db.models import Prefetch +from assessment_registry.models import AssessmentRegistry from assisted_tagging.models import DraftEntry +from django.db.models import Prefetch +from django.utils.functional import cached_property +from geo.schema import get_geo_area_queryset_for_project_geo_area_type +from promise import Promise from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin -from geo.schema import get_geo_area_queryset_for_project_geo_area_type from .models import AdminLevel -from assessment_registry.models import AssessmentRegistry class AdminLevelLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - adminlevel_qs = AdminLevel.objects.filter(region__in=keys).defer('geo_area_titles') + adminlevel_qs = AdminLevel.objects.filter(region__in=keys).defer("geo_area_titles") _map = defaultdict(list) for adminlevel in adminlevel_qs: _map[adminlevel.region_id].append(adminlevel) @@ -23,10 +23,9 @@ def batch_load_fn(self, keys): class AssessmentRegistryGeoAreaLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - ary_geo_area_qs = AssessmentRegistry.locations.through.objects\ - .filter(assessmentregistry__in=keys).prefetch_related( - Prefetch('geoarea', queryset=get_geo_area_queryset_for_project_geo_area_type()) - ) + ary_geo_area_qs = AssessmentRegistry.locations.through.objects.filter(assessmentregistry__in=keys).prefetch_related( + Prefetch("geoarea", queryset=get_geo_area_queryset_for_project_geo_area_type()) + ) _map = defaultdict(list) for ary_geo_area in ary_geo_area_qs.all(): _map[ary_geo_area.assessmentregistry_id].append(ary_geo_area.geoarea) @@ -35,10 +34,11 @@ def batch_load_fn(self, keys): class DraftEntryGeoAreaLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - draft_entry_geo_area_qs = DraftEntry.objects\ - .filter(id__in=keys).prefetch_related( - Prefetch('related_geoareas', queryset=get_geo_area_queryset_for_project_geo_area_type()) - ).only('pk') + draft_entry_geo_area_qs = ( + DraftEntry.objects.filter(id__in=keys) + .prefetch_related(Prefetch("related_geoareas", queryset=get_geo_area_queryset_for_project_geo_area_type())) + .only("pk") + ) _map = defaultdict(list) for draft_entry_geo_area in draft_entry_geo_area_qs.all(): _map[draft_entry_geo_area.pk].extend(draft_entry_geo_area.related_geoareas.all()) diff --git a/apps/geo/enums.py b/apps/geo/enums.py index 78ca1d46ca..84c66e42e3 100644 --- a/apps/geo/enums.py +++ b/apps/geo/enums.py @@ -3,8 +3,8 @@ class GeoAreaOrderingEnum(graphene.Enum): # ASC - ASC_ID = 'id' - ASC_ADMIN_LEVEL = 'admin_level__level' + ASC_ID = "id" + ASC_ADMIN_LEVEL = "admin_level__level" # DESC - DESC_ID = f'-{ASC_ID}' - DESC_ADMIN_LEVEL = f'-{ASC_ADMIN_LEVEL}' + DESC_ID = f"-{ASC_ID}" + DESC_ADMIN_LEVEL = f"-{ASC_ADMIN_LEVEL}" diff --git a/apps/geo/factories.py b/apps/geo/factories.py index 7f9a552606..ed6f8aa87f 100644 --- a/apps/geo/factories.py +++ b/apps/geo/factories.py @@ -1,32 +1,27 @@ import factory from factory import fuzzy from factory.django import DjangoModelFactory - -from geo.models import ( - Region, - AdminLevel, - GeoArea, -) +from geo.models import AdminLevel, GeoArea, Region class RegionFactory(DjangoModelFactory): code = fuzzy.FuzzyText(length=3) - title = factory.Sequence(lambda n: f'Region-{n}') + title = factory.Sequence(lambda n: f"Region-{n}") class Meta: model = Region class AdminLevelFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'Admin-Level-{n}') + title = factory.Sequence(lambda n: f"Admin-Level-{n}") class Meta: model = AdminLevel class GeoAreaFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'GeoArea-{n}') - code = factory.Sequence(lambda n: f'code-{n}') + title = factory.Sequence(lambda n: f"GeoArea-{n}") + code = factory.Sequence(lambda n: f"code-{n}") class Meta: model = GeoArea diff --git a/apps/geo/filter_set.py b/apps/geo/filter_set.py index ad11807e36..ae18782414 100644 --- a/apps/geo/filter_set.py +++ b/apps/geo/filter_set.py @@ -2,30 +2,18 @@ import django_filters from django.db import models - -from deep.filter_set import OrderEnumMixin -from utils.graphene.filters import ( - IDListFilter, - StringListFilter, - MultipleInputFilter, -) - from project.models import Project from user_resource.filters import UserResourceFilterSet -from .models import ( - AdminLevel, - GeoArea, - Region, -) +from deep.filter_set import OrderEnumMixin +from utils.graphene.filters import IDListFilter, MultipleInputFilter, StringListFilter + from .enums import GeoAreaOrderingEnum +from .models import AdminLevel, GeoArea, Region class GeoAreaFilterSet(django_filters.rest_framework.FilterSet): - label = django_filters.CharFilter( - label='Geo Area Label', - method='geo_area_label' - ) + label = django_filters.CharFilter(label="Geo Area Label", method="geo_area_label") class Meta: model = GeoArea @@ -43,22 +31,22 @@ class RegionFilterSet(UserResourceFilterSet): Filter by code, title and public fields """ + # NOTE: This filter the regions not in the supplied project exclude_project = django_filters.ModelMultipleChoiceFilter( - method='exclude_project_region_filter', + method="exclude_project_region_filter", widget=django_filters.widgets.CSVWidget, queryset=Project.objects.all(), ) class Meta: model = Region - fields = ['id', 'code', 'title', 'public', 'project', - 'created_at', 'created_by', 'modified_at', 'modified_by'] + fields = ["id", "code", "title", "public", "project", "created_at", "created_by", "modified_at", "modified_by"] filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda f: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda f: { + "lookup_expr": "icontains", }, }, } @@ -75,34 +63,30 @@ class AdminLevelFilterSet(django_filters.rest_framework.FilterSet): Filter by title, region and parent """ + class Meta: model = AdminLevel - fields = ['id', 'title', 'region', 'parent'] + fields = ["id", "title", "region", "parent"] filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda _: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda _: { + "lookup_expr": "icontains", }, }, } + # ------------------------------ Graphql filters ----------------------------------- class GeoAreaGqlFilterSet(OrderEnumMixin, django_filters.rest_framework.FilterSet): - ids = IDListFilter(field_name='id') - region_ids = IDListFilter(field_name='admin_level__region') - admin_level_ids = IDListFilter(field_name='admin_level') - search = django_filters.CharFilter( - label='Geo Area Label search', - method='geo_area_label' - ) - titles = StringListFilter( - label='Geo Area Label search (Multiple titles)', - method='filter_titles' - ) - ordering = MultipleInputFilter(GeoAreaOrderingEnum, method='ordering_filter') + ids = IDListFilter(field_name="id") + region_ids = IDListFilter(field_name="admin_level__region") + admin_level_ids = IDListFilter(field_name="admin_level") + search = django_filters.CharFilter(label="Geo Area Label search", method="geo_area_label") + titles = StringListFilter(label="Geo Area Label search (Multiple titles)", method="filter_titles") + ordering = MultipleInputFilter(GeoAreaOrderingEnum, method="ordering_filter") class Meta: model = GeoArea @@ -117,24 +101,13 @@ def filter_titles(self, queryset, _, values): if values: # Let's only use 10 max. _values = set(values[:10]) - return queryset.filter( - reduce( - lambda acc, item: acc | item, - [ - models.Q(title__iexact=value) - for value in _values - ] - ) - ) + return queryset.filter(reduce(lambda acc, item: acc | item, [models.Q(title__iexact=value) for value in _values])) return queryset class RegionGqlFilterSet(RegionFilterSet): - search = django_filters.CharFilter( - label='Region label search', - method='region_search' - ) - exclude_project = IDListFilter(method='exclude_project_region_filter') + search = django_filters.CharFilter(label="Region label search", method="region_search") + exclude_project = IDListFilter(method="exclude_project_region_filter") def region_search(self, queryset, _, value): if value: diff --git a/apps/geo/management/commands/retrigger_geo_cache.py b/apps/geo/management/commands/retrigger_geo_cache.py index c8b3793efa..c5a3a8ee34 100644 --- a/apps/geo/management/commands/retrigger_geo_cache.py +++ b/apps/geo/management/commands/retrigger_geo_cache.py @@ -1,25 +1,25 @@ import logging + from django.core.management.base import BaseCommand from django.db import transaction - -from geo.models import Region, AdminLevel +from geo.models import AdminLevel, Region logger = logging.getLogger(__name__) class Command(BaseCommand): - help = 'Re-trigger cached data for all geo entities. For specific objects admin panel' + help = "Re-trigger cached data for all geo entities. For specific objects admin panel" def add_arguments(self, parser): parser.add_argument( - '--region', - action='store_true', - help='Calculate all regions cache', + "--region", + action="store_true", + help="Calculate all regions cache", ) parser.add_argument( - '--admin-level', - action='store_true', - help='Calculate all regions admin level cache', + "--admin-level", + action="store_true", + help="Calculate all regions admin level cache", ) def calculate(self, Model): @@ -33,12 +33,12 @@ def calculate(self, Model): item.calc_cache() success_ids.append(item.pk) except Exception: - logger.error(f'{Model.__name__} Cache Calculation Failed!!', exc_info=True) - self.stdout.write(self.style.SUCCESS(f'{success_ids=}')) + logger.error(f"{Model.__name__} Cache Calculation Failed!!", exc_info=True) + self.stdout.write(self.style.SUCCESS(f"{success_ids=}")) def handle(self, *_, **options): - calculate_regions = options['region'] - calculate_admin_levels = options['admin_level'] + calculate_regions = options["region"] + calculate_admin_levels = options["admin_level"] if calculate_regions: self.calculate(Region) diff --git a/apps/geo/models.py b/apps/geo/models.py index f16efbdb91..e105762318 100644 --- a/apps/geo/models.py +++ b/apps/geo/models.py @@ -1,16 +1,16 @@ import json - from typing import List, Union + from django.contrib.gis.db import models -from django.core.serializers import serialize -from django.db import transaction, connection -from django.contrib.gis.gdal import Envelope from django.contrib.gis.db.models.aggregates import Union as PgUnion from django.contrib.gis.db.models.functions import Centroid +from django.contrib.gis.gdal import Envelope +from django.core.serializers import serialize +from django.db import connection, transaction +from gallery.models import File +from user_resource.models import UserResource from utils.files import generate_json_file_for_upload -from user_resource.models import UserResource -from gallery.models import File class Region(UserResource): @@ -21,6 +21,7 @@ class Region(UserResource): Region can be global in which case it will be available directly to public. Project specific regions won't be available publicly. """ + code = models.CharField(max_length=10) title = models.CharField(max_length=255) public = models.BooleanField(default=True) @@ -45,30 +46,32 @@ def __str__(self): return f"[{'Public' if self.public else 'Private'}] {self.title}" class Meta: - ordering = ['title', 'code'] + ordering = ["title", "code"] def calc_cache(self, save=True): self.geo_options = [ { - 'label': '{} / {}'.format(geo_area.admin_level.title, geo_area.title), - 'title': geo_area.title, - 'key': str(geo_area.id), - 'admin_level': geo_area.admin_level.level, - 'admin_level_title': geo_area.admin_level.title, - 'region': self.id, - 'region_title': self.title, - 'parent': geo_area.parent.id if geo_area.parent else None, - } for geo_area in GeoArea.objects.prefetch_related( - 'admin_level', - ).filter( - admin_level__region=self - ).order_by('admin_level__level').distinct() + "label": "{} / {}".format(geo_area.admin_level.title, geo_area.title), + "title": geo_area.title, + "key": str(geo_area.id), + "admin_level": geo_area.admin_level.level, + "admin_level_title": geo_area.admin_level.title, + "region": self.id, + "region_title": self.title, + "parent": geo_area.parent.id if geo_area.parent else None, + } + for geo_area in GeoArea.objects.prefetch_related( + "admin_level", + ) + .filter(admin_level__region=self) + .order_by("admin_level__level") + .distinct() ] # Calculate region centroid - self.centroid = GeoArea.objects\ - .filter(admin_level__region=self)\ - .aggregate(centroid=Centroid(PgUnion(Centroid('polygons'))))['centroid'] + self.centroid = GeoArea.objects.filter(admin_level__region=self).aggregate( + centroid=Centroid(PgUnion(Centroid("polygons"))) + )["centroid"] self.cache_index += 1 # Increment after every calc_cache. This is used by project to generate overall cache. if save: self.save() @@ -76,13 +79,13 @@ def calc_cache(self, save=True): def get_verbose_title(self): if self.public: return self.title - return '{} (Private)'.format(self.title) + return "{} (Private)".format(self.title) def clone_to_private(self, user): region = Region( code=self.code, # Strip off extra chars from title to add ' (cloned) - title='{} (cloned)'.format(self.title[:230]), + title="{} (cloned)".format(self.title[:230]), public=False, regional_groups=self.regional_groups, key_figures=self.key_figures, @@ -107,9 +110,7 @@ def clone_to_private(self, user): @staticmethod def get_for(user): return Region.objects.filter( - models.Q(public=True) | - models.Q(created_by=user) | - models.Q(project__members=user) + models.Q(public=True) | models.Q(created_by=user) | models.Q(project__members=user) ).distinct() def can_get(self, user): @@ -117,19 +118,26 @@ def can_get(self, user): def can_modify(self, user): from project.models import ProjectMembership, ProjectRole + return ( # Either created by user - not self.is_published and ( - (self.created_by == user) or + not self.is_published + and ( + (self.created_by == user) + or # Or is public and user is superuser - (self.public and user.is_superuser) or + (self.public and user.is_superuser) + or # Or is private and user is admin of one of the projects # with this region - (not self.public and ProjectMembership.objects.filter( - project__regions=self, - member=user, - role__in=ProjectRole.get_admin_roles(), - ).exists()) + ( + not self.public + and ProjectMembership.objects.filter( + project__regions=self, + member=user, + role__in=ProjectRole.get_admin_roles(), + ).exists() + ) ) ) @@ -154,10 +162,9 @@ class AdminLevel(models.Model): * parent_name_prop - Property defining name of parent of the geo area * parent_code_prop - Property defining code of parent of the geo area """ + region = models.ForeignKey(Region, on_delete=models.CASCADE) - parent = models.ForeignKey('AdminLevel', - on_delete=models.SET_NULL, - null=True, blank=True, default=None) + parent = models.ForeignKey("AdminLevel", on_delete=models.SET_NULL, null=True, blank=True, default=None) title = models.CharField(max_length=255) level = models.IntegerField(null=True, blank=True, default=None) name_prop = models.CharField(max_length=255, blank=True) @@ -165,18 +172,25 @@ class AdminLevel(models.Model): parent_name_prop = models.CharField(max_length=255, blank=True) parent_code_prop = models.CharField(max_length=255, blank=True) - geo_shape_file = models.ForeignKey(File, on_delete=models.SET_NULL, - null=True, blank=True, default=None) + geo_shape_file = models.ForeignKey(File, on_delete=models.SET_NULL, null=True, blank=True, default=None) tolerance = models.FloatField(default=0.0001) stale_geo_areas = models.BooleanField(default=True) # cache data geojson_file = models.FileField( - upload_to='geojson/', max_length=255, null=True, blank=True, default=None, + upload_to="geojson/", + max_length=255, + null=True, + blank=True, + default=None, ) bounds_file = models.FileField( - upload_to='geo-bounds/', max_length=255, null=True, blank=True, default=None, + upload_to="geo-bounds/", + max_length=255, + null=True, + blank=True, + default=None, ) geo_area_titles = models.JSONField(default=None, blank=True, null=True) @@ -184,7 +198,7 @@ def __str__(self): return self.title class Meta: - ordering = ['level'] + ordering = ["level"] def get_geo_area_titles(self): if not self.geo_area_titles: @@ -194,7 +208,7 @@ def get_geo_area_titles(self): def calc_cache(self, save=True): # Update geo parent_titles data with transaction.atomic(): - GEO_PARENT_DATA_CALC_SQL = f''' + GEO_PARENT_DATA_CALC_SQL = f""" WITH geo_parents_data as ( SELECT id, @@ -237,24 +251,26 @@ def calc_cache(self, save=True): FROM geo_parents_data GP WHERE G.id = GP.id - ''' + """ with connection.cursor() as cursor: - cursor.execute(GEO_PARENT_DATA_CALC_SQL, {'admin_level_id': self.pk}) - - geojson = json.loads(serialize( - 'geojson', - self.geoarea_set.all(), - geometry_field='polygons', - fields=('pk', 'title', 'code', 'cached_data'), - )) + cursor.execute(GEO_PARENT_DATA_CALC_SQL, {"admin_level_id": self.pk}) + + geojson = json.loads( + serialize( + "geojson", + self.geoarea_set.all(), + geometry_field="polygons", + fields=("pk", "title", "code", "cached_data"), + ) + ) # Titles titles = {} for geo_area in self.geoarea_set.all(): titles[str(geo_area.id)] = { - 'title': geo_area.title, - 'parent_id': str(geo_area.parent.pk) if geo_area.parent else None, - 'code': geo_area.code, + "title": geo_area.title, + "parent_id": str(geo_area.parent.pk) if geo_area.parent else None, + "code": geo_area.code, } self.geo_area_titles = titles @@ -267,21 +283,21 @@ def calc_cache(self, save=True): for area in areas[1:]: envelope.expand_to_include(*area.polygons.extent) bounds = { - 'minX': envelope.min_x, - 'minY': envelope.min_y, - 'maxX': envelope.max_x, - 'maxY': envelope.max_y, + "minX": envelope.min_x, + "minY": envelope.min_y, + "maxX": envelope.max_x, + "maxY": envelope.max_y, } except ValueError: pass self.geojson_file.save( - f'admin-level-{self.pk}.json', + f"admin-level-{self.pk}.json", generate_json_file_for_upload(geojson), ) self.bounds_file.save( - f'admin-level-{self.pk}.json', - generate_json_file_for_upload({'bounds': bounds}), + f"admin-level-{self.pk}.json", + generate_json_file_for_upload({"bounds": bounds}), ) if save: self.save() @@ -318,9 +334,7 @@ def clone_to(self, region, parent=None): @staticmethod def get_for(user): return AdminLevel.objects.filter( - models.Q(region__public=True) | - models.Q(region__created_by=user) | - models.Q(region__project__members=user) + models.Q(region__public=True) | models.Q(region__created_by=user) | models.Q(region__project__members=user) ).distinct() def can_get(self, user): @@ -334,11 +348,14 @@ class GeoArea(models.Model): """ An actual geo area in a given admin level """ + admin_level = models.ForeignKey(AdminLevel, on_delete=models.CASCADE) parent = models.ForeignKey( - 'GeoArea', + "GeoArea", on_delete=models.SET_NULL, - null=True, blank=True, default=None, + null=True, + blank=True, + default=None, ) title = models.CharField(max_length=255) code = models.CharField(max_length=255, blank=True) @@ -357,14 +374,9 @@ def __str__(self): @classmethod def sync_centroid(cls): cls.objects.filter( - ( - models.Q(centroid__isempty=True) | - models.Q(centroid__isnull=True) - ), + (models.Q(centroid__isempty=True) | models.Q(centroid__isnull=True)), polygons__isempty=False, - ).update( - centroid=Centroid('polygons') - ) + ).update(centroid=Centroid("polygons")) @classmethod def get_for_project(cls, project, is_published=True): @@ -379,7 +391,7 @@ def clone_to(self, admin_level, parent=None): admin_level=admin_level, parent=parent, # Strip off extra chars from title to add ' (cloned) - title='{} (cloned)'.format(self.title[:230]), + title="{} (cloned)".format(self.title[:230]), code=self.code, data=self.data, polygons=self.polygons, @@ -404,9 +416,9 @@ def get_sub_childrens(cls, value: List[Union[str, int]], level=1): @staticmethod def get_for(user): return AdminLevel.objects.filter( - models.Q(admin_level__region__public=True) | - models.Q(admin_level__region__created_by=user) | - models.Q(admin_level__region__project__members=user) + models.Q(admin_level__region__public=True) + | models.Q(admin_level__region__created_by=user) + | models.Q(admin_level__region__project__members=user) ).distinct() def can_get(self, user): @@ -416,4 +428,4 @@ def can_modify(self, user): return self.admin_level.can_modify(user) def get_label(self): - return '{} / {}'.format(self.admin_level.title, self.title) + return "{} / {}".format(self.admin_level.title, self.title) diff --git a/apps/geo/schema.py b/apps/geo/schema.py index dba9eb526a..eaf99e081e 100644 --- a/apps/geo/schema.py +++ b/apps/geo/schema.py @@ -1,33 +1,32 @@ import graphene +from django.db import models +from geo.filter_set import GeoAreaGqlFilterSet, RegionGqlFilterSet +from geo.models import AdminLevel, GeoArea, Region from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField, PageGraphqlPagination -from django.db import models -from utils.graphene.types import CustomDjangoListObjectType, FileFieldType from utils.graphene.fields import DjangoPaginatedListObjectField from utils.graphene.pagination import NoOrderingPageGraphqlPagination - -from geo.models import Region, GeoArea, AdminLevel -from geo.filter_set import RegionGqlFilterSet, GeoAreaGqlFilterSet +from utils.graphene.types import CustomDjangoListObjectType, FileFieldType def get_users_region_qs(info): - return Region.get_for(info.context.user).defer('geo_options') + return Region.get_for(info.context.user).defer("geo_options") def get_users_adminlevel_qs(info): # NOTE: We don't need geo_area_titles - return AdminLevel.get_for(info.context.user).defer('geo_area_titles') + return AdminLevel.get_for(info.context.user).defer("geo_area_titles") -def get_geo_area_queryset_for_project_geo_area_type(queryset=None, defer_fields=('polygons', 'centroid', 'cached_data')): +def get_geo_area_queryset_for_project_geo_area_type(queryset=None, defer_fields=("polygons", "centroid", "cached_data")): _queryset = queryset if _queryset is None: _queryset = GeoArea.objects _queryset = _queryset.annotate( - region_title=models.F('admin_level__region__title'), - admin_level_title=models.F('admin_level__title'), - admin_level_level=models.F('admin_level__level'), + region_title=models.F("admin_level__region__title"), + admin_level_title=models.F("admin_level__title"), + admin_level_level=models.F("admin_level__level"), ) if defer_fields: _queryset = _queryset.defer(*defer_fields) @@ -38,12 +37,19 @@ class AdminLevelType(DjangoObjectType): class Meta: model = AdminLevel only_fields = ( - 'id', - 'title', 'level', 'tolerance', 'stale_geo_areas', 'geo_shape_file', - 'name_prop', 'code_prop', 'parent_name_prop', 'parent_code_prop', + "id", + "title", + "level", + "tolerance", + "stale_geo_areas", + "geo_shape_file", + "name_prop", + "code_prop", + "parent_name_prop", + "parent_code_prop", ) - parent = graphene.ID(source='parent_id') + parent = graphene.ID(source="parent_id") geojson_file = graphene.Field(FileFieldType) bounds_file = graphene.Field(FileFieldType) @@ -56,9 +62,15 @@ class RegionType(DjangoObjectType): class Meta: model = Region only_fields = ( - 'id', 'title', 'public', 'regional_groups', - 'key_figures', 'population_data', 'media_sources', - 'centroid', 'is_published', + "id", + "title", + "public", + "regional_groups", + "key_figures", + "population_data", + "media_sources", + "centroid", + "is_published", ) @staticmethod @@ -71,9 +83,15 @@ class Meta: model = Region skip_registry = True only_fields = ( - 'id', 'title', 'public', 'regional_groups', - 'key_figures', 'population_data', 'media_sources', - 'centroid', 'is_published', + "id", + "title", + "public", + "regional_groups", + "key_figures", + "population_data", + "media_sources", + "centroid", + "is_published", ) admin_levels = graphene.List(graphene.NonNull(AdminLevelType)) @@ -91,12 +109,7 @@ class Meta: class Query: region = DjangoObjectField(RegionDetailType) - regions = DjangoPaginatedListObjectField( - RegionListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) - ) + regions = DjangoPaginatedListObjectField(RegionListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize")) @staticmethod def resolve_regions(root, info, **kwargs): @@ -108,7 +121,10 @@ def resolve_regions(root, info, **kwargs): class ProjectGeoAreaType(DjangoObjectType): class Meta: model = GeoArea - only_fields = ('id', 'title',) + only_fields = ( + "id", + "title", + ) skip_registry = True region_title = graphene.String(required=True) @@ -117,7 +133,7 @@ class Meta: parent_titles = graphene.List(graphene.NonNull(graphene.String), required=True) def resolve_parent_titles(root, info, **kwargs): - return (root.cached_data or {}).get('parent_titles') or [] + return (root.cached_data or {}).get("parent_titles") or [] class ProjectGeoAreaListType(CustomDjangoListObjectType): @@ -129,14 +145,9 @@ class Meta: class ProjectScopeQuery: geo_areas = DjangoPaginatedListObjectField( - ProjectGeoAreaListType, - pagination=NoOrderingPageGraphqlPagination( - page_size_query_param='pageSize' - ) + ProjectGeoAreaListType, pagination=NoOrderingPageGraphqlPagination(page_size_query_param="pageSize") ) @staticmethod def resolve_geo_areas(queryset, info, **kwargs): - return get_geo_area_queryset_for_project_geo_area_type( - queryset=GeoArea.get_for_project(info.context.active_project) - ) + return get_geo_area_queryset_for_project_geo_area_type(queryset=GeoArea.get_for_project(info.context.active_project)) diff --git a/apps/geo/serializers.py b/apps/geo/serializers.py index a7356576ff..8a0456cd21 100644 --- a/apps/geo/serializers.py +++ b/apps/geo/serializers.py @@ -1,45 +1,46 @@ from django.conf import settings from django.db import transaction from drf_dynamic_fields import DynamicFieldsMixin - -from deep.serializers import RemoveNullFieldsMixin, URLCachedFileField -from rest_framework import serializers -from user_resource.serializers import UserResourceSerializer -from geo.models import ( - Region, - AdminLevel, - GeoArea -) +from gallery.serializers import SimpleFileSerializer +from geo.models import AdminLevel, GeoArea, Region from geo.tasks import load_geo_areas from project.models import Project -from gallery.serializers import SimpleFileSerializer +from rest_framework import serializers +from user_resource.serializers import UserResourceSerializer + +from deep.serializers import RemoveNullFieldsMixin, URLCachedFileField -class SimpleRegionSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): +class SimpleRegionSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): title = serializers.CharField(read_only=True) class Meta: model = Region - fields = ('id', 'title') + fields = ("id", "title") -class SimpleAdminLevelSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): +class SimpleAdminLevelSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): class Meta: model = AdminLevel - fields = ('id', 'title', 'level', 'name_prop', 'code_prop', - 'parent_name_prop', 'parent_code_prop',) + fields = ( + "id", + "title", + "level", + "name_prop", + "code_prop", + "parent_name_prop", + "parent_code_prop", + ) -class RegionSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, UserResourceSerializer): +class RegionSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): """ Region Model Serializer """ + admin_levels = SimpleAdminLevelSerializer( many=True, - source='adminlevel_set', + source="adminlevel_set", read_only=True, ) @@ -51,27 +52,25 @@ class RegionSerializer(RemoveNullFieldsMixin, class Meta: model = Region - exclude = ('geo_options',) + exclude = ("geo_options",) def validate_project(self, project): try: project = Project.objects.get(id=project) except Project.DoesNotExist: - raise serializers.ValidationError( - 'Project matching query does not exist' - ) + raise serializers.ValidationError("Project matching query does not exist") - if not project.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid project') + if not project.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid project") return project.id def validate(self, data): if self.instance and self.instance.is_published: - raise serializers.ValidationError('Published region can\'t be changed. Please contact Admin') + raise serializers.ValidationError("Published region can't be changed. Please contact Admin") return data def create(self, validated_data): - project = validated_data.pop('project', None) + project = validated_data.pop("project", None) region = super().create(validated_data) if project: @@ -81,24 +80,24 @@ def create(self, validated_data): return region -class AdminLevelSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, serializers.ModelSerializer): +class AdminLevelSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): """ Admin Level Model Serializer """ - geo_shape_file_details = SimpleFileSerializer(source='geo_shape_file', read_only=True) + + geo_shape_file_details = SimpleFileSerializer(source="geo_shape_file", read_only=True) geojson_file = URLCachedFileField(required=False, read_only=True) bounds_file = URLCachedFileField(required=False, read_only=True) class Meta: model = AdminLevel - exclude = ('geo_area_titles',) + exclude = ("geo_area_titles",) # Validations def validate_region(self, region): - if not region.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid region') + if not region.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid region") return region def create(self, validated_data): @@ -107,7 +106,7 @@ def create(self, validated_data): admin_level.save() region = admin_level.region - region.modified_by = self.context['request'].user + region.modified_by = self.context["request"].user region.save() if not settings.TESTING: @@ -124,7 +123,7 @@ def update(self, instance, validated_data): admin_level.save() region = admin_level.region - region.modified_by = self.context['request'].user + region.modified_by = self.context["request"].user region.save() if not settings.TESTING: @@ -143,8 +142,4 @@ class GeoAreaSerializer(serializers.ModelSerializer): class Meta: model = GeoArea - fields = ( - 'key', 'label', 'region', 'title', - 'region_title', 'admin_level_level', 'admin_level_title', - 'parent' - ) + fields = ("key", "label", "region", "title", "region_title", "admin_level_level", "admin_level_title", "parent") diff --git a/apps/geo/tasks.py b/apps/geo/tasks.py index 90dab4fc2c..8a7af3adeb 100644 --- a/apps/geo/tasks.py +++ b/apps/geo/tasks.py @@ -1,19 +1,17 @@ +import logging +import os +import tempfile +import zipfile + +import reversion from celery import shared_task from django.conf import settings from django.contrib.gis.gdal import DataSource from django.contrib.gis.geos import GEOSGeometry from django.db.models import Q -from geo.models import Region, AdminLevel, GeoArea - +from geo.models import AdminLevel, GeoArea, Region from redis_store import redis -import os -import reversion -import tempfile -import zipfile - -import logging - logger = logging.getLogger(__name__) @@ -35,7 +33,7 @@ def _save_geo_area(admin_level, parent, feature): if admin_level.code_prop: code = feature.get(admin_level.code_prop) - name = name or '' + name = name or "" geo_area = GeoArea.objects.filter( Q(code=None, title=name) | Q(code=code), @@ -46,7 +44,7 @@ def _save_geo_area(admin_level, parent, feature): geo_area = GeoArea() geo_area.title = name - geo_area.code = code if code else '' + geo_area.code = code if code else "" geo_area.admin_level = admin_level geom = feature.geom @@ -61,31 +59,18 @@ def _save_geo_area(admin_level, parent, feature): # raise Exception('Invalid geometry type for geoarea') geo_area.polygons = geom - feature_names = [ - f.decode('utf-8') if isinstance(f, bytes) else f - for f in feature.fields - ] + feature_names = [f.decode("utf-8") if isinstance(f, bytes) else f for f in feature.fields] if parent: - if admin_level.parent_name_prop and \ - admin_level.parent_name_prop in feature_names: - candidates = GeoArea.objects.filter( - admin_level=parent, - title=feature.get(admin_level.parent_name_prop) - ) + if admin_level.parent_name_prop and admin_level.parent_name_prop in feature_names: + candidates = GeoArea.objects.filter(admin_level=parent, title=feature.get(admin_level.parent_name_prop)) if admin_level.parent_code_prop: - candidates = candidates.filter( - code=feature.get(admin_level.parent_code_prop) - ) + candidates = candidates.filter(code=feature.get(admin_level.parent_code_prop)) geo_area.parent = candidates.first() - elif admin_level.parent_code_prop and \ - admin_level.parent_code_prop in feature_names: - geo_area.parent = GeoArea.objects.filter( - admin_level=parent, - code=feature.get(admin_level.parent_code_prop) - ).first() + elif admin_level.parent_code_prop and admin_level.parent_code_prop in feature_names: + geo_area.parent = GeoArea.objects.filter(admin_level=parent, code=feature.get(admin_level.parent_code_prop)).first() geo_area.save() return geo_area @@ -101,8 +86,7 @@ def _generate_geo_areas(admin_level, parent): # disk. # Then load data from that file filename, extension = os.path.splitext(geo_shape_file.file.name) - f = tempfile.NamedTemporaryFile(suffix=extension, - dir=settings.TEMP_DIR) + f = tempfile.NamedTemporaryFile(suffix=extension, dir=settings.TEMP_DIR) f.write(geo_shape_file.file.read()) # Flush the file before reading it with GDAL @@ -110,14 +94,11 @@ def _generate_geo_areas(admin_level, parent): # the write is complete and will raise an exception. f.flush() - if extension == '.zip': - with tempfile.TemporaryDirectory( - dir=settings.TEMP_DIR - ) as tmpdirname: - zipfile.ZipFile(f.name, 'r').extractall(tmpdirname) + if extension == ".zip": + with tempfile.TemporaryDirectory(dir=settings.TEMP_DIR) as tmpdirname: + zipfile.ZipFile(f.name, "r").extractall(tmpdirname) files = os.listdir(tmpdirname) - shape_file = next((f for f in files if f.endswith('.shp')), - None) + shape_file = next((f for f in files if f.endswith(".shp")), None) data_source = DataSource(os.path.join(tmpdirname, shape_file)) else: data_source = DataSource(f.name) @@ -132,15 +113,14 @@ def _generate_geo_areas(admin_level, parent): for feature in layer: # Each feature is a geo area geo_area = _save_geo_area( - admin_level, parent, + admin_level, + parent, feature, ) added_areas.append(geo_area.id) # Delete all previous geo areas that have not been added - GeoArea.objects.filter( - admin_level=admin_level - ).exclude(id__in=added_areas).delete() + GeoArea.objects.filter(admin_level=admin_level).exclude(id__in=added_areas).delete() admin_level.stale_geo_areas = False admin_level.geojson_file = None @@ -188,9 +168,7 @@ def _load_geo_areas(region_id): if AdminLevel.objects.filter(region=region).count() == 0: return True - parent_admin_levels = AdminLevel.objects.filter( - region=region, parent=None - ) + parent_admin_levels = AdminLevel.objects.filter(region=region, parent=None) completed_levels = [] _extract_from_admin_levels( parent_admin_levels, @@ -205,7 +183,7 @@ def _load_geo_areas(region_id): @shared_task def load_geo_areas(region_id): - key = 'load_geo_areas_{}'.format(region_id) + key = "load_geo_areas_{}".format(region_id) lock = redis.get_lock(key, 60 * 30) # Lock lifetime 30 minutes have_lock = lock.acquire(blocking=False) if not have_lock: @@ -214,7 +192,7 @@ def load_geo_areas(region_id): try: return_value = _load_geo_areas(region_id) except Exception: - logger.error('Load Geo Areas', exc_info=True) + logger.error("Load Geo Areas", exc_info=True) return_value = False lock.release() @@ -232,7 +210,7 @@ def cal_region_cache(regions_id): region.calc_cache() success_regions.append(region.pk) except Exception: - logger.error('Region Cache Calculation Failed!!', exc_info=True) + logger.error("Region Cache Calculation Failed!!", exc_info=True) return success_regions @@ -247,5 +225,5 @@ def cal_admin_level_cache(admin_levels_id): admin_level.calc_cache() success_admin_levels.append(admin_level.pk) except Exception: - logger.error('Admin Level Cache Calculation Failed!!', exc_info=True) + logger.error("Admin Level Cache Calculation Failed!!", exc_info=True) return success_admin_levels diff --git a/apps/geo/tests/test_apis.py b/apps/geo/tests/test_apis.py index 4a61ae38ca..65366775fe 100644 --- a/apps/geo/tests/test_apis.py +++ b/apps/geo/tests/test_apis.py @@ -1,22 +1,23 @@ import json -from deep.tests import TestCase -from geo.models import Region, AdminLevel, GeoArea +from geo.models import AdminLevel, GeoArea, Region from project.models import Project +from deep.tests import TestCase + class RegionTests(TestCase): def test_create_region(self): region_count = Region.objects.count() project = self.create(Project, role=self.admin_role) - url = '/api/v1/regions/' + url = "/api/v1/regions/" data = { - 'code': 'NLP', - 'title': 'Nepal', - 'data': {'testfield': 'testfile'}, - 'public': True, - 'project': project.id, + "code": "NLP", + "title": "Nepal", + "data": {"testfield": "testfile"}, + "public": True, + "project": project.id, } self.authenticate() @@ -24,9 +25,8 @@ def test_create_region(self): self.assert_201(response) self.assertEqual(Region.objects.count(), region_count + 1) - self.assertEqual(response.data['code'], data['code']) - self.assertIn(Region.objects.get(id=response.data['id']), - project.regions.all()) + self.assertEqual(response.data["code"], data["code"]) + self.assertIn(Region.objects.get(id=response.data["id"]), project.regions.all()) def test_region_published_status(self): """ @@ -36,10 +36,8 @@ def test_region_published_status(self): region = self.create(Region, is_published=True) project.regions.add(region) - data = { - 'is_published': False - } - url = f'/api/v1/regions/{region.id}/' + data = {"is_published": False} + url = f"/api/v1/regions/{region.id}/" self.authenticate() response = self.client.patch(url, data) self.assert_403(response) @@ -51,7 +49,7 @@ def test_publish_region(self): region = self.create(Region, created_by=user) project.regions.add(region) - url = f'/api/v1/regions/{region.id}/publish/' + url = f"/api/v1/regions/{region.id}/publish/" data = {} # authenticated with user that has not created region @@ -62,27 +60,27 @@ def test_publish_region(self): self.authenticate(user) response = self.client.post(url, data) self.assert_200(response) - self.assertEqual(response.data['is_published'], True) + self.assertEqual(response.data["is_published"], True) def test_clone_region(self): project = self.create(Project, role=self.admin_role) region = self.create(Region) project.regions.add(region) - url = '/api/v1/clone-region/{}/'.format(region.id) + url = "/api/v1/clone-region/{}/".format(region.id) data = { - 'project': project.id, + "project": project.id, } self.authenticate() response = self.client.post(url, data) self.assert_201(response) - self.assertNotEqual(response.data['id'], region.id) - self.assertFalse(response.data['public']) + self.assertNotEqual(response.data["id"], region.id) + self.assertFalse(response.data["public"]) self.assertFalse(region in project.regions.all()) - new_region = Region.objects.get(id=response.data['id']) + new_region = Region.objects.get(id=response.data["id"]) self.assertTrue(new_region in project.regions.all()) def test_region_filter_not_in_project(self): @@ -96,27 +94,24 @@ def test_region_filter_not_in_project(self): project_2.regions.add(region_3) # filter regions in project - url = f'/api/v1/regions/?project={project_1.id}' + url = f"/api/v1/regions/?project={project_1.id}" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 2) - self.assertEqual( - set(rg['id'] for rg in response.data['results']), - set([region_1.id, region_2.id]) - ) + self.assertEqual(len(response.data["results"]), 2) + self.assertEqual(set(rg["id"] for rg in response.data["results"]), set([region_1.id, region_2.id])) # filter the region that are not in project - url = f'/api/v1/regions/?exclude_project={project_1.id}' + url = f"/api/v1/regions/?exclude_project={project_1.id}" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 1) - self.assertEqual(response.data['results'][0]['id'], region_3.id) + self.assertEqual(len(response.data["results"]), 1) + self.assertEqual(response.data["results"][0]["id"], region_3.id) def test_trigger_api(self): region = self.create(Region) - url = '/api/v1/geo-areas-load-trigger/{}/'.format(region.id) + url = "/api/v1/geo-areas-load-trigger/{}/".format(region.id) self.authenticate() response = self.client.get(url) @@ -128,14 +123,14 @@ def test_create_admin_level(self): admin_level_count = AdminLevel.objects.count() region = self.create(Region) - url = '/api/v1/admin-levels/' + url = "/api/v1/admin-levels/" data = { - 'region': region.pk, - 'title': 'test', - 'name_prop': 'test', - 'pcode_prop': 'test', - 'parent_name_prop': 'test', - 'parent_pcode_prop': 'test', + "region": region.pk, + "title": "test", + "name_prop": "test", + "pcode_prop": "test", + "parent_name_prop": "test", + "parent_pcode_prop": "test", } self.authenticate() @@ -143,72 +138,69 @@ def test_create_admin_level(self): self.assert_201(response) self.assertEqual(AdminLevel.objects.count(), admin_level_count + 1) - self.assertEqual(response.data['title'], data['title']) + self.assertEqual(response.data["title"], data["title"]) class GeoOptionsApi(TestCase): def test_geo_options(self): - region1 = self.create(Region, title='Region 1') - region2 = self.create(Region, title='Region 2') - region3 = self.create(Region, title='Region 3') + region1 = self.create(Region, title="Region 1") + region2 = self.create(Region, title="Region 2") + region3 = self.create(Region, title="Region 3") project = self.create_project() project.regions.add(region1, region2) - admin_level1_1 = self.create(AdminLevel, title='AdminLevel1', region=region1, level=0) - admin_level1_2 = self.create(AdminLevel, title='AdminLevel2', region=region2, level=1) - admin_level2_1 = self.create(AdminLevel, title='AdminLevel1', region=region1, level=0) - self.create(AdminLevel, title='AdminLevel1', region=region2, level=0) - self.create(AdminLevel, title='AdminLevel2', region=region2, level=1) - self.create(AdminLevel, title='AdminLevel1', region=region3, level=0) - self.create(AdminLevel, title='AdminLevel2', region=region3, level=1) - geo_area1_1 = self.create(GeoArea, title='GeoArea1', admin_level=admin_level1_1) - geo_area1_2 = self.create(GeoArea, title='GeoArea2', admin_level=admin_level1_2, parent=geo_area1_1) - self.create(GeoArea, title='GeoArea2', admin_level=admin_level2_1) - - url = f'/api/v1/geo-options/?project={project.pk}' + admin_level1_1 = self.create(AdminLevel, title="AdminLevel1", region=region1, level=0) + admin_level1_2 = self.create(AdminLevel, title="AdminLevel2", region=region2, level=1) + admin_level2_1 = self.create(AdminLevel, title="AdminLevel1", region=region1, level=0) + self.create(AdminLevel, title="AdminLevel1", region=region2, level=0) + self.create(AdminLevel, title="AdminLevel2", region=region2, level=1) + self.create(AdminLevel, title="AdminLevel1", region=region3, level=0) + self.create(AdminLevel, title="AdminLevel2", region=region3, level=1) + geo_area1_1 = self.create(GeoArea, title="GeoArea1", admin_level=admin_level1_1) + geo_area1_2 = self.create(GeoArea, title="GeoArea2", admin_level=admin_level1_2, parent=geo_area1_1) + self.create(GeoArea, title="GeoArea2", admin_level=admin_level2_1) + + url = f"/api/v1/geo-options/?project={project.pk}" self.authenticate() response = self.client.get(url, follow=True) self.assert_200(response) - cached_file_url = response.data['geo_options_cached_file'] + cached_file_url = response.data["geo_options_cached_file"] - data = json.loads(b''.join(list(self.client.get(cached_file_url).streaming_content))) - self.assertEqual( - data[str(region1.id)][1].get('label'), - '{} / {}'.format(admin_level1_1.title, geo_area1_2.title) - ) + data = json.loads(b"".join(list(self.client.get(cached_file_url).streaming_content))) + self.assertEqual(data[str(region1.id)][1].get("label"), "{} / {}".format(admin_level1_1.title, geo_area1_2.title)) # check if parent is present in geo options for _, options in data.items(): for option in options: - assert 'parent' in option + assert "parent" in option # URL should be same for future request response = self.client.get(url, follow=True) self.assert_200(response) - assert cached_file_url == response.data['geo_options_cached_file'] + assert cached_file_url == response.data["geo_options_cached_file"] # URL should be changed if region data is changed region1.refresh_from_db() region1.cache_index += 1 - region1.save(update_fields=('cache_index',)) + region1.save(update_fields=("cache_index",)) response = self.client.get(url, follow=True) self.assert_200(response) - assert cached_file_url != response.data['geo_options_cached_file'] - cached_file_url = response.data['geo_options_cached_file'] + assert cached_file_url != response.data["geo_options_cached_file"] + cached_file_url = response.data["geo_options_cached_file"] # URL should be same again for future request response = self.client.get(url, follow=True) self.assert_200(response) - assert cached_file_url == response.data['geo_options_cached_file'] + assert cached_file_url == response.data["geo_options_cached_file"] # URL shouldn't be changed if non assigned region data is changed region3.refresh_from_db() region3.cache_index += 1 - region3.save(update_fields=('cache_index',)) + region3.save(update_fields=("cache_index",)) response = self.client.get(url, follow=True) self.assert_200(response) - assert cached_file_url == response.data['geo_options_cached_file'] + assert cached_file_url == response.data["geo_options_cached_file"] class TestGeoAreaApi(TestCase): @@ -225,23 +217,23 @@ def test_geo_area(self): project2.add_member(user2) project2.regions.add(region2) - admin_level1 = self.create(AdminLevel, region=region, title='test') + admin_level1 = self.create(AdminLevel, region=region, title="test") admin_level2 = self.create(AdminLevel, region=region) admin_level3 = self.create(AdminLevel, region=region1) admin_level4 = self.create(AdminLevel, region=region2) - geo_area1 = self.create(GeoArea, admin_level=admin_level1, title='me') + geo_area1 = self.create(GeoArea, admin_level=admin_level1, title="me") self.create(GeoArea, admin_level=admin_level2, parent=geo_area1) self.create(GeoArea, admin_level=admin_level4) self.create(GeoArea, admin_level=admin_level3) - url = f'/api/v1/projects/{project.id}/geo-area/' + url = f"/api/v1/projects/{project.id}/geo-area/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 2) # geo area with region `published=True` + self.assertEqual(response.data["count"], 2) # geo area with region `published=True` # test for the label - self.assertEqual(response.data['results'][0]['label'], '{}/{}'.format(admin_level1.title, geo_area1.title)) + self.assertEqual(response.data["results"][0]["label"], "{}/{}".format(admin_level1.title, geo_area1.title)) # test for the not project member self.authenticate(user2) @@ -249,22 +241,22 @@ def test_geo_area(self): self.assert_403(response) # test for the pagination - url = f'/api/v1/projects/{project.id}/geo-area/?limit=1' + url = f"/api/v1/projects/{project.id}/geo-area/?limit=1" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 1) + self.assertEqual(len(response.data["results"]), 1) # test for the search field - url = f'/api/v1/projects/{project.id}/geo-area/?label=test' + url = f"/api/v1/projects/{project.id}/geo-area/?label=test" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 1) + self.assertEqual(response.data["count"], 1) # Passing the label that is not either region or geoarea title - url = f'/api/v1/projects/{project.id}/geo-area/?label=acd' + url = f"/api/v1/projects/{project.id}/geo-area/?label=acd" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 0) + self.assertEqual(response.data["count"], 0) diff --git a/apps/geo/tests/test_schemas.py b/apps/geo/tests/test_schemas.py index 7c0f3bf822..4e29849ee4 100644 --- a/apps/geo/tests/test_schemas.py +++ b/apps/geo/tests/test_schemas.py @@ -1,14 +1,13 @@ -from utils.graphene.tests import GraphQLTestCase - +from geo.factories import AdminLevelFactory, GeoAreaFactory, RegionFactory from geo.models import AdminLevel - from project.factories import ProjectFactory from user.factories import UserFactory -from geo.factories import RegionFactory, AdminLevelFactory, GeoAreaFactory + +from utils.graphene.tests import GraphQLTestCase class TestGeoSchema(GraphQLTestCase): - GEO_QUERY = ''' + GEO_QUERY = """ query GeoQuery( $projectId: ID!, $adminLevelIds: [ID!], @@ -36,7 +35,7 @@ class TestGeoSchema(GraphQLTestCase): } } } - ''' + """ def test_geo_filters(self): user = UserFactory.create() @@ -55,11 +54,7 @@ def test_geo_filters(self): region_4_admin_level_0 = AdminLevelFactory.create(region=region4) # Geo areas region_1_ad_1_geo_areas = GeoAreaFactory.create_batch(3, admin_level=region_1_admin_level_1) - region_1_ad_2_geo_areas = GeoAreaFactory.create_batch( - 5, - admin_level=region_1_admin_level_2, - title='XYZ Geô' - ) + region_1_ad_2_geo_areas = GeoAreaFactory.create_batch(5, admin_level=region_1_admin_level_2, title="XYZ Geô") region_2_ad_1_geo_areas = GeoAreaFactory.create_batch(4, admin_level=region_2_admin_level_1) GeoAreaFactory.create_batch(2, admin_level=region_3_admin_level_0) GeoAreaFactory.create_batch(4, admin_level=region_4_admin_level_0) @@ -69,7 +64,7 @@ def _query_check(filters, **kwargs): self.GEO_QUERY, variables={ **filters, - 'projectId': str(project.id), + "projectId": str(project.id), }, **kwargs, ) @@ -82,29 +77,21 @@ def _query_check(filters, **kwargs): # With filters for name, filters, geo_areas in ( + ("no-filter", dict(), [*region_1_ad_1_geo_areas, *region_1_ad_2_geo_areas, *region_2_ad_1_geo_areas]), + ("invalid-region-id", dict(regionIds=[str(region3.pk)]), []), + ("valid-region-id", dict(regionIds=[str(region1.pk)]), [*region_1_ad_1_geo_areas, *region_1_ad_2_geo_areas]), + ("invalid-admin-level-id", dict(adminLevelIds=[str(region_3_admin_level_0.pk)]), []), ( - 'no-filter', - dict(), - [*region_1_ad_1_geo_areas, *region_1_ad_2_geo_areas, *region_2_ad_1_geo_areas] - ), - ('invalid-region-id', dict(regionIds=[str(region3.pk)]), []), - ( - 'valid-region-id', - dict(regionIds=[str(region1.pk)]), - [*region_1_ad_1_geo_areas, *region_1_ad_2_geo_areas] - ), - ('invalid-admin-level-id', dict(adminLevelIds=[str(region_3_admin_level_0.pk)]), []), - ( - 'valid-admin-level-id', + "valid-admin-level-id", dict(adminLevelIds=[str(region_1_admin_level_1.pk)]), region_1_ad_1_geo_areas, ), - ('search', dict(search='XYZ Geo'), region_1_ad_2_geo_areas), + ("search", dict(search="XYZ Geo"), region_1_ad_2_geo_areas), ): - content = _query_check(filters)['data']['project']['geoAreas'] - self.assertEqual(content['totalCount'], len(geo_areas), (name, content)) - self.assertEqual(len(content['results']), len(geo_areas), (name, content)) - self.assertListIds(content['results'], geo_areas, (name, content)) + content = _query_check(filters)["data"]["project"]["geoAreas"] + self.assertEqual(content["totalCount"], len(geo_areas), (name, content)) + self.assertEqual(len(content["results"]), len(geo_areas), (name, content)) + self.assertListIds(content["results"], geo_areas, (name, content)) def test_geo_query(self): user = UserFactory.create() @@ -130,33 +117,32 @@ def test_geo_query(self): region_3_ad_1_geo_area_01 = GeoAreaFactory.create(admin_level=region_3_admin_level_1) # -- Sub nodes region_1_ad_2_geo_area_01 = GeoAreaFactory.create( - title='child (Region 1, AdminLevel 2) Geo Area 01', + title="child (Region 1, AdminLevel 2) Geo Area 01", admin_level=region_1_admin_level_2, parent=region_1_ad_1_geo_area_01, ) region_1_ad_2_geo_area_02 = GeoAreaFactory.create( - title='child (Region 1, AdminLevel 2) Geo Area 02', + title="child (Region 1, AdminLevel 2) Geo Area 02", admin_level=region_1_admin_level_2, parent=region_1_ad_1_geo_area_01, ) region_1_ad_2_geo_area_03 = GeoAreaFactory.create( - title='child (Region 1, AdminLevel 2) Geo Area 03', - admin_level=region_1_admin_level_2 + title="child (Region 1, AdminLevel 2) Geo Area 03", admin_level=region_1_admin_level_2 ) region_2_ad_2_geo_area_01 = GeoAreaFactory.create(admin_level=region_2_admin_level_2) GeoAreaFactory.create( - title='child (Region 3, AdminLevel 2) Geo Area 01', + title="child (Region 3, AdminLevel 2) Geo Area 01", admin_level=region_3_admin_level_1, parent=region_3_ad_1_geo_area_01, ) # -- Sub Sub nodes region_1_ad_3_geo_area_01 = GeoAreaFactory.create( - title='child (Region 1, AdminLevel 3) Geo Area 01', + title="child (Region 1, AdminLevel 3) Geo Area 01", admin_level=region_1_admin_level_3, parent=region_1_ad_2_geo_area_01, ) region_2_ad_2_geo_area_01 = GeoAreaFactory.create( - title='child (Region 1, AdminLevel 3) Geo Area 01', + title="child (Region 1, AdminLevel 3) Geo Area 01", admin_level=region_2_admin_level_3, ) @@ -164,8 +150,8 @@ def _query_check(**kwargs): return self.query_check( self.GEO_QUERY, variables={ - 'projectId': str(project.id), - 'search': 'child', + "projectId": str(project.id), + "search": "child", }, **kwargs, ) @@ -176,24 +162,24 @@ def _query_check(**kwargs): for admin_level in AdminLevel.objects.all(): admin_level.calc_cache() - content = _query_check()['data']['project']['geoAreas'] - self.assertEqual(content['results'], [ - { - 'id': str(geo_area.id), - 'title': geo_area.title, - 'adminLevelLevel': geo_area.admin_level.level, - 'adminLevelTitle': geo_area.admin_level.title, - 'regionTitle': geo_area.admin_level.region.title, - 'parentTitles': [ - parent.title - for parent in parents - ], - } - for geo_area, parents in [ - (region_1_ad_2_geo_area_01, [region_1_ad_1_geo_area_01]), - (region_1_ad_2_geo_area_02, [region_1_ad_1_geo_area_01]), - (region_1_ad_2_geo_area_03, []), - (region_1_ad_3_geo_area_01, [region_1_ad_1_geo_area_01, region_1_ad_2_geo_area_01]), - (region_2_ad_2_geo_area_01, []) - ] - ]) + content = _query_check()["data"]["project"]["geoAreas"] + self.assertEqual( + content["results"], + [ + { + "id": str(geo_area.id), + "title": geo_area.title, + "adminLevelLevel": geo_area.admin_level.level, + "adminLevelTitle": geo_area.admin_level.title, + "regionTitle": geo_area.admin_level.region.title, + "parentTitles": [parent.title for parent in parents], + } + for geo_area, parents in [ + (region_1_ad_2_geo_area_01, [region_1_ad_1_geo_area_01]), + (region_1_ad_2_geo_area_02, [region_1_ad_1_geo_area_01]), + (region_1_ad_2_geo_area_03, []), + (region_1_ad_3_geo_area_01, [region_1_ad_1_geo_area_01, region_1_ad_2_geo_area_01]), + (region_2_ad_2_geo_area_01, []), + ] + ], + ) diff --git a/apps/geo/tests/test_tasks.py b/apps/geo/tests/test_tasks.py index 2301fa9995..dae775924b 100644 --- a/apps/geo/tests/test_tasks.py +++ b/apps/geo/tests/test_tasks.py @@ -1,24 +1,24 @@ -import re -import os import json +import os +import re import tempfile from django.conf import settings from django.core.files.uploadedfile import SimpleUploadedFile from django.test.utils import override_settings +from gallery.models import File +from geo.models import AdminLevel, GeoArea, Region +from geo.tasks import load_geo_areas from deep.tests import TestCase -from geo.tasks import load_geo_areas -from geo.models import Region, AdminLevel, GeoArea -from gallery.models import File def read_json_from_url(url): file_path = os.path.join( settings.MEDIA_ROOT, - re.search('http://testserver/media/(?P.*)$', url).group('path'), + re.search("http://testserver/media/(?P.*)$", url).group("path"), ) - with open(file_path, 'r') as fp: + with open(file_path, "r") as fp: return json.load(fp) @@ -28,46 +28,44 @@ def setUp(self): super().setUp() # Create a dummy region - region = Region(code='NPL', title='Nepal') + region = Region(code="NPL", title="Nepal") region.save() # Load a shape file from a test shape file and create admin level 0 - admin_level0 = AdminLevel(region=region, parent=None, - title='Zone', - name_prop='ZONE_NAME', - code_prop='HRPCode') + admin_level0 = AdminLevel(region=region, parent=None, title="Zone", name_prop="ZONE_NAME", code_prop="HRPCode") shape_data = open( - os.path.join(settings.TEST_DIR, - 'nepal-geo-json/admin_level2.geo.json'), - 'rb', + os.path.join(settings.TEST_DIR, "nepal-geo-json/admin_level2.geo.json"), + "rb", ).read() admin_level0.geo_shape_file = File.objects.create( - title='al2', + title="al2", file=SimpleUploadedFile( - name='al2.geo.json', + name="al2.geo.json", content=shape_data, - ) + ), ) admin_level0.save() # Load admin level 1 similarly - admin_level1 = AdminLevel(region=region, parent=None, - title='District', - name_prop='DISTRICT', - code_prop='HRPCode', - parent_name_prop='ZONE', - parent_code_prop='HRParent') + admin_level1 = AdminLevel( + region=region, + parent=None, + title="District", + name_prop="DISTRICT", + code_prop="HRPCode", + parent_name_prop="ZONE", + parent_code_prop="HRParent", + ) shape_data = open( - os.path.join(settings.TEST_DIR, - 'nepal-geo-json/admin_level3.geo.json'), - 'rb', + os.path.join(settings.TEST_DIR, "nepal-geo-json/admin_level3.geo.json"), + "rb", ).read() admin_level1.geo_shape_file = File.objects.create( - title='al3', + title="al3", file=SimpleUploadedFile( - name='al3.geo.json', + name="al3.geo.json", content=shape_data, - ) + ), ) admin_level1.parent = admin_level0 @@ -95,21 +93,21 @@ def test_load_areas(self): # Test if a geo area in admin level 0 is correctly set bagmati = GeoArea.objects.filter( - title='Bagmati', + title="Bagmati", admin_level=self.admin_level0, parent=None, - code='NP-C-BAG', + code="NP-C-BAG", ).first() self.assertIsNotNone(bagmati) # Test if a geo area in admin level 1 is correctly set sindhupalchowk = GeoArea.objects.filter( - title='Sindhupalchok', + title="Sindhupalchok", admin_level=self.admin_level1, - parent__title='Bagmati', - parent__code='NP-C-BAG', - code='NP-C-BAG-23', + parent__title="Bagmati", + parent__code="NP-C-BAG", + code="NP-C-BAG-23", ).first() self.assertIsNotNone(sindhupalchowk) @@ -119,7 +117,7 @@ def test_geojson_api(self): self.assertTrue(result) # Test if geojson api works - url = '/api/v1/admin-levels/{}/geojson/'.format(self.admin_level0.pk) + url = "/api/v1/admin-levels/{}/geojson/".format(self.admin_level0.pk) self.authenticate() response = self.client.get(url) @@ -127,17 +125,15 @@ def test_geojson_api(self): # NOTE: response is FileReponse r_data = read_json_from_url(response.url) - self.assertEqual(r_data['type'], 'FeatureCollection') - self.assertIsNotNone(r_data['features']) - self.assertTrue(len(r_data['features']) > 0) + self.assertEqual(r_data["type"], "FeatureCollection") + self.assertIsNotNone(r_data["features"]) + self.assertTrue(len(r_data["features"]) > 0) # Test if geobounds also works - url = '/api/v1/admin-levels/{}/geojson/bounds/'.format( - self.admin_level0.pk - ) + url = "/api/v1/admin-levels/{}/geojson/bounds/".format(self.admin_level0.pk) response = self.client.get(url) self.assert_302(response) r_data = read_json_from_url(response.url) - self.assertIsNotNone(r_data['bounds']) + self.assertIsNotNone(r_data["bounds"]) diff --git a/apps/geo/views.py b/apps/geo/views.py index f4a6578a93..95242e1e26 100644 --- a/apps/geo/views.py +++ b/apps/geo/views.py @@ -1,9 +1,11 @@ -from django.shortcuts import redirect, get_object_or_404 -from django.contrib.gis.geos import GEOSGeometry -from django.contrib.gis.gdal.error import GDALException +import django_filters from django.conf import settings +from django.contrib.gis.gdal.error import GDALException +from django.contrib.gis.geos import GEOSGeometry from django.db import models - +from django.shortcuts import get_object_or_404, redirect +from project.models import Project +from project.tasks import generate_project_geo_region_cache from rest_framework import ( exceptions, filters, @@ -14,44 +16,29 @@ viewsets, ) from rest_framework.decorators import action -import django_filters -from deep.permissions import ( - ModifyPermission, - IsProjectMember -) -from project.models import Project -from project.tasks import generate_project_geo_region_cache +from deep.permissions import IsProjectMember, ModifyPermission -from .models import Region, AdminLevel, GeoArea -from .serializers import ( - AdminLevelSerializer, - RegionSerializer, - GeoAreaSerializer -) -from .filter_set import ( - GeoAreaFilterSet, - AdminLevelFilterSet, - RegionFilterSet -) +from .filter_set import AdminLevelFilterSet, GeoAreaFilterSet, RegionFilterSet +from .models import AdminLevel, GeoArea, Region +from .serializers import AdminLevelSerializer, GeoAreaSerializer, RegionSerializer from .tasks import load_geo_areas class RegionViewSet(viewsets.ModelViewSet): serializer_class = RegionSerializer permission_classes = [permissions.IsAuthenticated, ModifyPermission] - filter_backends = (django_filters.rest_framework.DjangoFilterBackend, - filters.SearchFilter, filters.OrderingFilter) + filter_backends = (django_filters.rest_framework.DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter) filterset_class = RegionFilterSet - search_fields = ('title', 'code') + search_fields = ("title", "code") def get_queryset(self): - return Region.get_for(self.request.user).defer('geo_options') + return Region.get_for(self.request.user).defer("geo_options") @action( detail=True, - url_path='intersects', - methods=('post',), + url_path="intersects", + methods=("post",), # TODO: Better permissions permission_classes=[permissions.IsAuthenticated], ) @@ -59,42 +46,42 @@ def get_intersects(self, request, pk=None, version=None): region = self.get_object() try: geoms = [] - features = request.data['features'] + features = request.data["features"] for feature in features: - geoms.append([feature.get('id'), GEOSGeometry(str(feature['geometry']))]) + geoms.append([feature.get("id"), GEOSGeometry(str(feature["geometry"]))]) except (GDALException, KeyError) as e: - raise exceptions.ValidationError( - f"Geometry parsed failed, Error: {getattr(e, 'message', repr(e))}" - ) - return response.Response([ - { - 'id': id, - 'region_id': region.pk, - 'geoareas': ( - # https://docs.djangoproject.com/en/2.1/ref/contrib/gis/geoquerysets/ - GeoArea.objects.filter( - admin_level__region=region, - polygons__intersects=geom, - ).values_list('id', flat=True) - ), - } - for id, geom in geoms - ]) + raise exceptions.ValidationError(f"Geometry parsed failed, Error: {getattr(e, 'message', repr(e))}") + return response.Response( + [ + { + "id": id, + "region_id": region.pk, + "geoareas": ( + # https://docs.djangoproject.com/en/2.1/ref/contrib/gis/geoquerysets/ + GeoArea.objects.filter( + admin_level__region=region, + polygons__intersects=geom, + ).values_list("id", flat=True) + ), + } + for id, geom in geoms + ] + ) @action( detail=True, - url_path='publish', - methods=['post'], + url_path="publish", + methods=["post"], serializer_class=RegionSerializer, - permission_classes=[permissions.IsAuthenticated] + permission_classes=[permissions.IsAuthenticated], ) def get_published(self, request, pk=None, version=None): region = self.get_object() if not region.can_publish(self.request.user): - raise exceptions.ValidationError('Can be published by user who created it') + raise exceptions.ValidationError("Can be published by user who created it") region.is_published = True - region.save(update_fields=['is_published']) - serializer = RegionSerializer(region, partial=True, context={'request': request}) + region.save(update_fields=["is_published"]) + serializer = RegionSerializer(region, partial=True, context={"request": request}) return response.Response(serializer.data) @@ -110,44 +97,43 @@ def post(self, request, region_id, version=None): raise exceptions.PermissionDenied() new_region = region.clone_to_private(request.user) - serializer = RegionSerializer(new_region, context={'request': request}) + serializer = RegionSerializer(new_region, context={"request": request}) - project = request.data.get('project') + project = request.data.get("project") if project: project = Project.objects.get(id=project) if not project.can_modify(request.user): - raise exceptions.ValidationError({ - 'project': 'Invalid project', - }) + raise exceptions.ValidationError( + { + "project": "Invalid project", + } + ) project.regions.remove(region) project.regions.add(new_region) - return response.Response(serializer.data, - status=status.HTTP_201_CREATED) + return response.Response(serializer.data, status=status.HTTP_201_CREATED) class AdminLevelViewSet(viewsets.ModelViewSet): """ Admin Level API Point """ + serializer_class = AdminLevelSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] - filter_backends = (django_filters.rest_framework.DjangoFilterBackend, - filters.SearchFilter, filters.OrderingFilter) + permission_classes = [permissions.IsAuthenticated, ModifyPermission] + filter_backends = (django_filters.rest_framework.DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter) filterset_class = AdminLevelFilterSet - search_fields = ('title') + search_fields = "title" def get_queryset(self): - return AdminLevel.get_for(self.request.user).select_related('geo_shape_file').defer( - *AdminLevelSerializer.Meta.exclude - ) + return AdminLevel.get_for(self.request.user).select_related("geo_shape_file").defer(*AdminLevelSerializer.Meta.exclude) class GeoAreasLoadTriggerView(views.APIView): """ A trigger for loading geo areas from admin level """ + permission_classes = [permissions.IsAuthenticated] def get(self, request, region_id, version=None): @@ -160,15 +146,18 @@ def get(self, request, region_id, version=None): if not settings.TESTING: load_geo_areas.delay(region_id) - return response.Response({ - 'load_triggered': region_id, - }) + return response.Response( + { + "load_triggered": region_id, + } + ) class GeoJsonView(views.APIView): """ A view that returns geojson for given admin level """ + permission_classes = [permissions.IsAuthenticated] def get(self, request, admin_level_id, version=None): @@ -191,6 +180,7 @@ class GeoBoundsView(views.APIView): """ A view that returns geo bounds for given admin level """ + permission_classes = [permissions.IsAuthenticated] def get(self, request, admin_level_id, version=None): @@ -211,19 +201,17 @@ class GeoOptionsView(views.APIView): permission_classes = [permissions.IsAuthenticated] def get(self, request, version=None): - project = get_object_or_404(Project, pk=request.GET.get('project')) + project = get_object_or_404(Project, pk=request.GET.get("project")) if not project.is_member(request.user): raise exceptions.PermissionDenied() if ( - project.geo_cache_file.name is None or - project.geo_cache_hash is None or - project.geo_cache_hash != str(hash(tuple(project.regions.order_by('id').values_list('cache_index', flat=True)))) + project.geo_cache_file.name is None + or project.geo_cache_hash is None + or project.geo_cache_hash != str(hash(tuple(project.regions.order_by("id").values_list("cache_index", flat=True)))) ): generate_project_geo_region_cache(project) - return response.Response({ - 'geo_options_cached_file': request.build_absolute_uri(project.geo_cache_file.url) - }) + return response.Response({"geo_options_cached_file": request.build_absolute_uri(project.geo_cache_file.url)}) class GeoAreaView(viewsets.ReadOnlyModelViewSet): @@ -232,19 +220,17 @@ class GeoAreaView(viewsets.ReadOnlyModelViewSet): filterset_class = GeoAreaFilterSet def get_queryset(self): - return GeoArea.objects.filter( - admin_level__region__project=self.kwargs['project_id'], - admin_level__region__is_published=True - ).annotate( - label=models.functions.Concat( - models.F('admin_level__title'), - models.Value('/'), - models.F('title'), - output_field=models.fields.CharField() - ), - region=models.F('admin_level__region_id'), - region_title=models.F('admin_level__region__title'), - admin_level_level=models.F('admin_level__level'), - admin_level_title=models.F('admin_level__title'), - key=models.F('id') - ).distinct() + return ( + GeoArea.objects.filter(admin_level__region__project=self.kwargs["project_id"], admin_level__region__is_published=True) + .annotate( + label=models.functions.Concat( + models.F("admin_level__title"), models.Value("/"), models.F("title"), output_field=models.fields.CharField() + ), + region=models.F("admin_level__region_id"), + region_title=models.F("admin_level__region__title"), + admin_level_level=models.F("admin_level__level"), + admin_level_title=models.F("admin_level__title"), + key=models.F("id"), + ) + .distinct() + ) diff --git a/apps/jwt_auth/apps.py b/apps/jwt_auth/apps.py index 904b0a9cd6..a91286e03d 100644 --- a/apps/jwt_auth/apps.py +++ b/apps/jwt_auth/apps.py @@ -2,4 +2,4 @@ class JwtAuthConfig(AppConfig): - name = 'jwt_auth' + name = "jwt_auth" diff --git a/apps/jwt_auth/authentication.py b/apps/jwt_auth/authentication.py index 51cac410f7..2e7df239b7 100644 --- a/apps/jwt_auth/authentication.py +++ b/apps/jwt_auth/authentication.py @@ -4,25 +4,25 @@ from .token import AccessToken, TokenError - # The auth header type 'Bearer' encoded to bytes -AUTH_HEADER_TYPE_BYTES = 'Bearer'.encode(HTTP_HEADER_ENCODING) +AUTH_HEADER_TYPE_BYTES = "Bearer".encode(HTTP_HEADER_ENCODING) # Paths for which no verification of access token in performed # such as expiry verifications # TODO: Use more generalized way to check safe path such as regex -SAFE_PATHS = ['/api/v1/token/refresh/'] +SAFE_PATHS = ["/api/v1/token/refresh/"] class JwtAuthentication(authentication.BaseAuthentication): """ JwtAuthentication for django rest framework """ + def authenticate_header(self, request): """ Value of www-authenticate header in 401 error """ - return 'Bearer realm=api' + return "Bearer realm=api" def authenticate(self, request): """ @@ -31,7 +31,7 @@ def authenticate(self, request): """ # Get header - header = request.META.get('HTTP_AUTHORIZATION') + header = request.META.get("HTTP_AUTHORIZATION") if header is None: return None @@ -63,17 +63,14 @@ def get_access_token(self, header, request): # Improper Bearer header if len(parts) != 2: - raise AuthenticationFailed( - 'Authorization header must be of format: Bearer ' - ) + raise AuthenticationFailed("Authorization header must be of format: Bearer ") token = parts[1] # We got the token string, decode and return the # access token object try: - access_token = AccessToken(token, - verify=request.path not in SAFE_PATHS) + access_token = AccessToken(token, verify=request.path not in SAFE_PATHS) return access_token except TokenError as e: raise AuthenticationFailed(e.message) diff --git a/apps/jwt_auth/captcha.py b/apps/jwt_auth/captcha.py index 5b3ace4ffd..88b8d8c1b8 100644 --- a/apps/jwt_auth/captcha.py +++ b/apps/jwt_auth/captcha.py @@ -1,9 +1,9 @@ -from django.conf import settings import requests +from django.conf import settings from .errors import InvalidCaptchaError -HCAPTCHA_VERIFY_URL = 'https://hcaptcha.com/siteverify' +HCAPTCHA_VERIFY_URL = "https://hcaptcha.com/siteverify" def _validate_hcaptcha(captcha): @@ -11,13 +11,13 @@ def _validate_hcaptcha(captcha): return False data = { - 'secret': settings.HCAPTCHA_SECRET, - 'response': captcha, + "secret": settings.HCAPTCHA_SECRET, + "response": captcha, } response = requests.post(url=HCAPTCHA_VERIFY_URL, data=data) response_json = response.json() - return response_json['success'] + return response_json["success"] def validate_hcaptcha(captcha, raise_on_error=True): diff --git a/apps/jwt_auth/errors.py b/apps/jwt_auth/errors.py index 5aa7959c21..a5c53b4b4c 100644 --- a/apps/jwt_auth/errors.py +++ b/apps/jwt_auth/errors.py @@ -1,56 +1,60 @@ from django.conf import settings + from deep import error_codes class UserNotFoundError(Exception): status_code = 401 code = error_codes.USER_NOT_FOUND - message = 'User not found' + message = "User not found" class UserInactiveError(Exception): status_code = 401 code = error_codes.USER_INACTIVE - message = 'User account is deactivated' + message = "User account is deactivated" def __init__(self, message): - if (message): + if message: self.message = message class UnknownTokenError(Exception): status_code = 400 code = error_codes.TOKEN_INVALID - message = 'Token contains no valid user identification' + message = "Token contains no valid user identification" class NotAuthenticatedError(Exception): status_code = 401 - code = error_codes.NOT_AUTHENTICATED, - message = 'You are not authenticated' + code = (error_codes.NOT_AUTHENTICATED,) + message = "You are not authenticated" class InvalidCaptchaError(Exception): status_code = 401 code = error_codes.INVALID_CAPTCHA - default_detail = 'Invalid captcha! Please, Try Again' + default_detail = "Invalid captcha! Please, Try Again" class AuthenticationFailedError(Exception): status_code = 400 code = error_codes.AUTHENTICATION_FAILED - message = 'No active account found with the given credentials' + message = "No active account found with the given credentials" def __init__(self, login_attempts=None): if login_attempts: remaining = settings.MAX_LOGIN_ATTEMPTS - login_attempts - self.message +=\ - '. You have {} login attempts remaining'.format( - remaining if remaining >= 0 else 0, - ) + self.message += ". You have {} login attempts remaining".format( + remaining if remaining >= 0 else 0, + ) WARN_EXCEPTIONS = [ - UserNotFoundError, UserInactiveError, UnknownTokenError, - NotAuthenticatedError, InvalidCaptchaError, AuthenticationFailedError, + UserNotFoundError, + UserInactiveError, + UnknownTokenError, + NotAuthenticatedError, + InvalidCaptchaError, + AuthenticationFailedError, ] diff --git a/apps/jwt_auth/serializers.py b/apps/jwt_auth/serializers.py index 75d084e5db..422c2ccf19 100644 --- a/apps/jwt_auth/serializers.py +++ b/apps/jwt_auth/serializers.py @@ -4,16 +4,14 @@ from django.contrib.auth import authenticate, models from rest_framework import serializers from rest_framework.exceptions import AuthenticationFailed +from user.utils import send_account_activation +from user.validators import CustomMaximumLengthValidator from utils.hid import hid -from user.utils import send_account_activation -from .token import AccessToken, RefreshToken, TokenError + from .captcha import validate_hcaptcha -from .errors import ( - AuthenticationFailedError, - UserInactiveError, -) -from user.validators import CustomMaximumLengthValidator +from .errors import AuthenticationFailedError, UserInactiveError +from .token import AccessToken, RefreshToken, TokenError logger = logging.getLogger(__name__) @@ -35,8 +33,7 @@ def deactivate_account(self, user): # user.is_active = False # user.save() # send_account_activation(user) - raise UserInactiveError( - message='Account is deactivated, check your email') + raise UserInactiveError(message="Account is deactivated, check your email") def check_login_attempts(self, user, captcha): login_attempts = user.profile.login_attempts @@ -47,16 +44,12 @@ def check_login_attempts(self, user, captcha): def validate(self, data): # NOTE: authenticate only works for active users - user = authenticate( - username=data['username'], - password=data['password'] - ) - captcha = data.get('hcaptcha_response') + user = authenticate(username=data["username"], password=data["password"]) + captcha = data.get("hcaptcha_response") # user not active or user credentials don't match if not user or not user.is_active: - user = models.User.objects.filter(username=data['username'])\ - .first() + user = models.User.objects.filter(username=data["username"]).first() if user: user.profile.login_attempts += 1 user.save() @@ -74,8 +67,8 @@ def validate(self, data): refresh_token = RefreshToken.for_access_token(access_token) return { - 'access': access_token.encode(), - 'refresh': refresh_token.encode(), + "access": access_token.encode(), + "refresh": refresh_token.encode(), } @@ -83,30 +76,26 @@ class TokenRefreshSerializer(serializers.Serializer): refresh = serializers.CharField() def validate(self, data): - user = self.context['request'].user + user = self.context["request"].user try: - refresh_token = RefreshToken(data['refresh']) - user_id = refresh_token['userId'] + refresh_token = RefreshToken(data["refresh"]) + user_id = refresh_token["userId"] except KeyError: - raise serializers.ValidationError( - 'Token contains no valid user identification' - ) + raise serializers.ValidationError("Token contains no valid user identification") except TokenError as e: raise serializers.ValidationError(e.message) if user.id != user_id: - raise serializers.ValidationError( - 'Invalid refresh token' - ) + raise serializers.ValidationError("Invalid refresh token") if not user.is_active: - raise AuthenticationFailed('User not active') + raise AuthenticationFailed("User not active") access_token = AccessToken.for_user(user) return { - 'access': access_token.encode(), + "access": access_token.encode(), } @@ -117,20 +106,20 @@ class HIDTokenObtainPairSerializer(serializers.Serializer): state = serializers.IntegerField(required=False) def validate(self, data): - humanitarian_id = hid.HumanitarianId(data['access_token']) + humanitarian_id = hid.HumanitarianId(data["access_token"]) try: user = humanitarian_id.get_user() except hid.HIDBaseException as e: raise serializers.ValidationError(e.message) except Exception: - logger.error('HID error', exc_info=True) - raise serializers.ValidationError('Unexpected Error') + logger.error("HID error", exc_info=True) + raise serializers.ValidationError("Unexpected Error") access_token = AccessToken.for_user(user) refresh_token = RefreshToken.for_access_token(access_token) return { - 'access': access_token.encode(), - 'refresh': refresh_token.encode(), + "access": access_token.encode(), + "refresh": refresh_token.encode(), } diff --git a/apps/jwt_auth/tests/test_apis.py b/apps/jwt_auth/tests/test_apis.py index fcdd681a89..77152b4eab 100644 --- a/apps/jwt_auth/tests/test_apis.py +++ b/apps/jwt_auth/tests/test_apis.py @@ -1,50 +1,42 @@ -from deep.tests import TestCase from user.models import User +from deep.tests import TestCase + class JwtApiTests(TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.user_password = 'joHnDave!@#123' + self.user_password = "joHnDave!@#123" def test_login_with_password_greater_than_128_characters(self): - data = { - 'username': "Hari@gmail.com", - "password": 'abcd' * 130 - } - url = '/api/v1/token/' + data = {"username": "Hari@gmail.com", "password": "abcd" * 130} + url = "/api/v1/token/" response = self.client.post(url, data) self.assert_400(response) - assert 'password' in response.data['errors'] + assert "password" in response.data["errors"] def test_valid_login(self): - user = User.objects.create_user(username='test@deep.com', password=self.user_password) + user = User.objects.create_user(username="test@deep.com", password=self.user_password) user.is_active = True user.save() # try to login - data = { - 'username': user.username, - 'password': self.user_password - } - url = '/api/v1/token/' + data = {"username": user.username, "password": self.user_password} + url = "/api/v1/token/" # NOTE: Just to make sure empty doesn't throw error - self.client.credentials(HTTP_AUTHORIZATION='') + self.client.credentials(HTTP_AUTHORIZATION="") response = self.client.post(url, data=data) self.assert_200(response) - self.assertIn('access', response.data) + self.assertIn("access", response.data) def test_invalid_login_with_password_length_greater_than_128_character(self): - user = User.objects.create_user(username='test@deep.com', password=self.user_password * 129) + user = User.objects.create_user(username="test@deep.com", password=self.user_password * 129) user.is_active = True user.save() # try to login - data = { - 'username': user.username, - 'password': self.user_password * 129 - } - url = '/api/v1/token/' + data = {"username": user.username, "password": self.user_password * 129} + url = "/api/v1/token/" response = self.client.post(url, data=data) self.assert_400(response) - assert 'password' in response.data['errors'] + assert "password" in response.data["errors"] diff --git a/apps/jwt_auth/token.py b/apps/jwt_auth/token.py index d903817c2b..6470b3881a 100644 --- a/apps/jwt_auth/token.py +++ b/apps/jwt_auth/token.py @@ -1,19 +1,17 @@ -from user.models import User -from django.conf import settings import datetime + import jwt +from django.conf import settings +from user.models import User -from .errors import ( - UnknownTokenError, - UserNotFoundError, - UserInactiveError, -) +from .errors import UnknownTokenError, UserInactiveError, UserNotFoundError class TokenError(Exception): """ Token encode/decode error """ + code = 0x70531 # Trying and failing to hex-speak TOKEN def __init__(self, message): @@ -31,6 +29,7 @@ class Token: """ Wrapper for jwt token """ + def __init__(self, token=None, verify=True): """ Initialize with given jwt string to decode or create a new one @@ -45,25 +44,24 @@ def __init__(self, token=None, verify=True): self.payload = jwt.decode( self.token, SECRET, - algorithms=['HS256'], + algorithms=["HS256"], verify=verify, ) except (jwt.ExpiredSignatureError, jwt.InvalidSignatureError): - raise TokenError('Token is invalid or expired') + raise TokenError("Token is invalid or expired") else: # Not token was given, so create a new one # Also set proper lifetime starting now if self.lifetime: - self.payload['exp'] = \ - datetime.datetime.utcnow() + self.lifetime + self.payload["exp"] = datetime.datetime.utcnow() + self.lifetime # Finally set the proper token type - self.payload['tokenType'] = self.token_type + self.payload["tokenType"] = self.token_type # Leave rest of the payload to be set by inherited classes def encode(self): - return jwt.encode(self.payload, SECRET, algorithm='HS256') + return jwt.encode(self.payload, SECRET, algorithm="HS256") def __repr__(self): return repr(self.payload) @@ -88,7 +86,8 @@ class AccessToken(Token): """ Access token """ - token_type = 'access' + + token_type = "access" lifetime = ACCESS_TOKEN_LIFETIME @staticmethod @@ -98,14 +97,14 @@ def for_user(user): """ token = AccessToken() - token['userId'] = user.id + token["userId"] = user.id return token def get_user(self): """ Get user from the access token """ - user_id = self.payload.get('userId') + user_id = self.payload.get("userId") if not user_id: raise UnknownTokenError() @@ -124,7 +123,8 @@ class RefreshToken(Token): """ Refresh token """ - token_type = 'refresh' + + token_type = "refresh" lifetime = None @staticmethod @@ -135,6 +135,6 @@ def for_access_token(access_token): token = RefreshToken() # For now just set same user id - token['userId'] = access_token['userId'] + token["userId"] = access_token["userId"] return token diff --git a/apps/jwt_auth/views.py b/apps/jwt_auth/views.py index 24e76889f6..aa53ba9b1f 100644 --- a/apps/jwt_auth/views.py +++ b/apps/jwt_auth/views.py @@ -1,4 +1,4 @@ -from rest_framework import generics, status, permissions +from rest_framework import generics, permissions, status from rest_framework.response import Response from . import serializers @@ -11,10 +11,7 @@ def post(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) - return Response( - serializer.validated_data, - status=status.HTTP_200_OK - ) + return Response(serializer.validated_data, status=status.HTTP_200_OK) class TokenObtainPairView(TokenViewBase): diff --git a/apps/lang/admin.py b/apps/lang/admin.py index f4bbf8405a..b790bce7a7 100644 --- a/apps/lang/admin.py +++ b/apps/lang/admin.py @@ -1,29 +1,32 @@ from django.contrib import admin from deep.admin import ReadOnlyMixin -from .models import ( - String, - Link, - LinkCollection, -) + +from .models import Link, LinkCollection, String @admin.register(String) class StringAdmin(admin.ModelAdmin): - search_fields = ('language', 'value',) - list_filter = ('language',) + search_fields = ( + "language", + "value", + ) + list_filter = ("language",) @admin.register(LinkCollection) class LinkCollectionAdmin(ReadOnlyMixin, admin.ModelAdmin): - search_fields = ('key',) + search_fields = ("key",) @admin.register(Link) class LinkAdmin(admin.ModelAdmin): - search_fields = ('key',) - autocomplete_fields = ('link_collection', 'string',) - list_display = ('key', 'string', 'language', 'link_collection') + search_fields = ("key",) + autocomplete_fields = ( + "link_collection", + "string", + ) + list_display = ("key", "string", "language", "link_collection") def get_form(self, request, obj=None, **kwargs): form = super().get_form( diff --git a/apps/lang/apps.py b/apps/lang/apps.py index f21fffa475..66a97cf8fa 100644 --- a/apps/lang/apps.py +++ b/apps/lang/apps.py @@ -2,4 +2,4 @@ class LangConfig(AppConfig): - name = 'lang' + name = "lang" diff --git a/apps/lang/management/commands/import_lang.py b/apps/lang/management/commands/import_lang.py index f8a4040d72..765d1cf75d 100644 --- a/apps/lang/management/commands/import_lang.py +++ b/apps/lang/management/commands/import_lang.py @@ -1,43 +1,36 @@ from csv import DictReader + from django.core.management.base import BaseCommand from django.db import transaction - -from lang.models import String, Link, LinkCollection +from lang.models import Link, LinkCollection, String class Command(BaseCommand): def add_arguments(self, parser): - parser.add_argument('--code', dest='lang_code') - parser.add_argument('filename') + parser.add_argument("--code", dest="lang_code") + parser.add_argument("filename") def handle(self, *args, **kwargs): - filename = kwargs['filename'] - lang_code = kwargs['lang_code'] + filename = kwargs["filename"] + lang_code = kwargs["lang_code"] self.import_language(filename, lang_code) @transaction.atomic def import_language(self, filename, lang_code): reader = DictReader(open(filename)) for i, row in enumerate(reader): - print(f'Loading row #{i}') - string_value = row['sp_text_new'] + print(f"Loading row #{i}") + string_value = row["sp_text_new"] string, _ = String.objects.get_or_create( language=lang_code, value=string_value, ) - links = row['links'].split(', ') + links = row["links"].split(", ") for link_id in links: if len(link_id.strip()) == 0: continue - collection_key, link_key = link_id.split(': ') - collection, _ = LinkCollection.objects.get_or_create( - key=collection_key - ) + collection_key, link_key = link_id.split(": ") + collection, _ = LinkCollection.objects.get_or_create(key=collection_key) - link, _ = Link.objects.get_or_create( - link_collection=collection, - key=link_key, - string=string, - language=lang_code - ) + link, _ = Link.objects.get_or_create(link_collection=collection, key=link_key, string=string, language=lang_code) diff --git a/apps/lang/models.py b/apps/lang/models.py index c8699ba2c3..bb090fd902 100644 --- a/apps/lang/models.py +++ b/apps/lang/models.py @@ -1,5 +1,5 @@ -from django.db import models from django.conf import settings +from django.db import models class String(models.Model): @@ -11,7 +11,7 @@ class String(models.Model): value = models.TextField() def __str__(self): - return '{} ({})'.format(self.value, self.language) + return "{} ({})".format(self.value, self.language) class LinkCollection(models.Model): @@ -28,15 +28,18 @@ class Link(models.Model): default=settings.LANGUAGE_CODE, ) link_collection = models.ForeignKey( - LinkCollection, related_name='links', on_delete=models.CASCADE, + LinkCollection, + related_name="links", + on_delete=models.CASCADE, ) key = models.CharField(max_length=255) string = models.ForeignKey( String, - null=True, blank=True, default=None, + null=True, + blank=True, + default=None, on_delete=models.SET_NULL, ) def __str__(self): - return '{} : {} ({})'.format(self.key, self.string.value, - self.language) + return "{} : {} ({})".format(self.key, self.string.value, self.language) diff --git a/apps/lang/serializers.py b/apps/lang/serializers.py index a6814357e2..e570960461 100644 --- a/apps/lang/serializers.py +++ b/apps/lang/serializers.py @@ -1,31 +1,29 @@ from rest_framework import serializers from deep.serializers import RemoveNullFieldsMixin -from .models import String, Link, LinkCollection +from .models import Link, LinkCollection, String -class LanguageSerializer(RemoveNullFieldsMixin, - serializers.Serializer): + +class LanguageSerializer(RemoveNullFieldsMixin, serializers.Serializer): code = serializers.CharField() title = serializers.CharField() -class StringSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): +class StringSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): action = serializers.CharField(write_only=True) class Meta: model = String - fields = ('id', 'value', 'action') + fields = ("id", "value", "action") -class LinkSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): +class LinkSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): action = serializers.CharField(write_only=True) class Meta: model = Link - fields = ('key', 'string', 'action') + fields = ("key", "string", "action") # Override DictField with partial value set to solve a DRF Bug @@ -34,24 +32,23 @@ class DictField(serializers.DictField): # Expects a object containing 'code', title', `strings` and `links` -class StringsSerializer(RemoveNullFieldsMixin, - serializers.Serializer): +class StringsSerializer(RemoveNullFieldsMixin, serializers.Serializer): code = serializers.CharField(read_only=True) title = serializers.CharField(read_only=True) strings = StringSerializer(many=True) links = DictField(child=LinkSerializer(many=True)) def save(self): - code = self.initial_data['code'] - strings = self.initial_data.get('strings') or [] - link_collections = self.initial_data.get('links') or {} + code = self.initial_data["code"] + strings = self.initial_data.get("strings") or [] + link_collections = self.initial_data.get("links") or {} string_map = {} for string_data in strings: - action = string_data['action'] - id = string_data['id'] + action = string_data["action"] + id = string_data["id"] - if action == 'add': + if action == "add": string = String() else: string = String.objects.filter(id=id).first() @@ -59,25 +56,23 @@ def save(self): if not string: continue - if action == 'delete': + if action == "delete": string.delete() continue string.language = code - string.value = string_data['value'] + string.value = string_data["value"] string.save() string_map[id] = string for collection_key, links in link_collections.items(): - collection, _ = LinkCollection.objects.get_or_create( - key=collection_key - ) + collection, _ = LinkCollection.objects.get_or_create(key=collection_key) for link_data in links: - action = link_data['action'] - key = link_data['key'] + action = link_data["action"] + key = link_data["key"] - if action == 'add': + if action == "add": link = Link() else: link = Link.objects.get( @@ -85,7 +80,7 @@ def save(self): link_collection=collection, ) - if action == 'delete': + if action == "delete": link.delete() continue @@ -93,9 +88,8 @@ def save(self): link.link_collection = collection link.key = key - str_id = link_data['string'] - link.string = string_map.get(str_id) or \ - String.objects.get(id=str_id) + str_id = link_data["string"] + link.string = string_map.get(str_id) or String.objects.get(id=str_id) link.save() LinkCollection.objects.filter(links__isnull=True).delete() diff --git a/apps/lang/tests/test_apis.py b/apps/lang/tests/test_apis.py index 58a7b8971e..45bd2581b3 100644 --- a/apps/lang/tests/test_apis.py +++ b/apps/lang/tests/test_apis.py @@ -1,26 +1,27 @@ from django.conf import settings +from lang.models import Link, LinkCollection, String + from deep.tests import TestCase -from lang.models import String, LinkCollection, Link class LangApiTests(TestCase): def test_update_language(self): lang_code = settings.LANGUAGE_CODE - url = '/api/v1/languages/{}/'.format(lang_code) + url = "/api/v1/languages/{}/".format(lang_code) # First test creating data = { - 'strings': [ - {'id': 's_1', 'value': 's1', 'action': 'add'}, - {'id': 's_2', 'value': 's2', 'action': 'add'}, + "strings": [ + {"id": "s_1", "value": "s1", "action": "add"}, + {"id": "s_2", "value": "s2", "action": "add"}, ], - 'links': { - 'group1': [ - {'key': 'l_1', 'string': 's_1', 'action': 'add'}, - {'key': 'l_2', 'string': 's_2', 'action': 'add'}, + "links": { + "group1": [ + {"key": "l_1", "string": "s_1", "action": "add"}, + {"key": "l_2", "string": "s_2", "action": "add"}, ], - 'group2': [ - {'key': 'l_1', 'string': 's_2', 'action': 'add'}, + "group2": [ + {"key": "l_1", "string": "s_2", "action": "add"}, ], }, } @@ -29,42 +30,39 @@ def test_update_language(self): response = self.client.put(url, data) self.assert_200(response) - s1 = String.objects.filter(value='s1').first() - s2 = String.objects.filter(value='s2').first() + s1 = String.objects.filter(value="s1").first() + s2 = String.objects.filter(value="s2").first() self.assertIsNotNone(s1) self.assertIsNotNone(s2) - group1 = LinkCollection.objects.filter(key='group1').first() - group2 = LinkCollection.objects.filter(key='group2').first() + group1 = LinkCollection.objects.filter(key="group1").first() + group2 = LinkCollection.objects.filter(key="group2").first() self.assertIsNotNone(group1) self.assertIsNotNone(group2) - l1 = Link.objects.filter(link_collection=group1, - key='l_1', string=s1).first() - l2 = Link.objects.filter(link_collection=group1, - key='l_2', string=s2).first() - l3 = Link.objects.filter(link_collection=group2, - key='l_1', string=s2).first() + l1 = Link.objects.filter(link_collection=group1, key="l_1", string=s1).first() + l2 = Link.objects.filter(link_collection=group1, key="l_2", string=s2).first() + l3 = Link.objects.filter(link_collection=group2, key="l_1", string=s2).first() self.assertIsNotNone(l1) self.assertIsNotNone(l2) self.assertIsNotNone(l3) # Then test updating, deleting and creating data = { - 'strings': [ - {'id': s1.id, 'value': 's1 new', 'action': 'edit'}, - {'id': s2.id, 'action': 'delete'}, - {'id': 's_3', 'value': 's3', 'action': 'add'}, + "strings": [ + {"id": s1.id, "value": "s1 new", "action": "edit"}, + {"id": s2.id, "action": "delete"}, + {"id": "s_3", "value": "s3", "action": "add"}, ], - 'links': { - 'group1': [ - {'key': 'l_1', 'action': 'delete'}, - {'key': 'l_2', 'string': 's_3', 'action': 'edit'}, + "links": { + "group1": [ + {"key": "l_1", "action": "delete"}, + {"key": "l_2", "string": "s_3", "action": "edit"}, ], - 'group2': [ - {'key': 'l_1', 'string': s1.id, 'action': 'edit'}, + "group2": [ + {"key": "l_1", "string": s1.id, "action": "edit"}, ], - 'group3': [], + "group3": [], }, } @@ -74,14 +72,14 @@ def test_update_language(self): s1 = String.objects.filter(id=s1.id).first() s2 = String.objects.filter(id=s2.id).first() - s3 = String.objects.filter(value='s3').first() - self.assertEqual(s1.value, 's1 new') + s3 = String.objects.filter(value="s3").first() + self.assertEqual(s1.value, "s1 new") self.assertIsNone(s2) self.assertIsNotNone(s3) - group1 = LinkCollection.objects.filter(key='group1').first() - group2 = LinkCollection.objects.filter(key='group2').first() - group3 = LinkCollection.objects.filter(key='group3').first() + group1 = LinkCollection.objects.filter(key="group1").first() + group2 = LinkCollection.objects.filter(key="group2").first() + group3 = LinkCollection.objects.filter(key="group3").first() self.assertIsNotNone(group1) self.assertIsNotNone(group2) self.assertIsNone(group3) diff --git a/apps/lang/views.py b/apps/lang/views.py index 8ddcfd03b1..0c39b29f31 100644 --- a/apps/lang/views.py +++ b/apps/lang/views.py @@ -1,17 +1,14 @@ from django.conf import settings -from rest_framework import ( - viewsets, - response, - permissions, -) +from rest_framework import permissions, response, viewsets + from deep.permissions import IsSuperAdmin + +from .models import LinkCollection, String from .serializers import LanguageSerializer, StringsSerializer -from .models import String, LinkCollection class LanguageViewSet(viewsets.ViewSet): - permission_classes = [permissions.IsAuthenticated, - IsSuperAdmin] + permission_classes = [permissions.IsAuthenticated, IsSuperAdmin] def retrieve(self, request, pk=None, version=None): code = pk @@ -22,13 +19,10 @@ def get_links(collection): return collection.links.filter(language=code) obj = { - 'code': code, - 'title': language[1], - 'strings': String.objects.filter(language=code), - 'links': { - link_collection.key: get_links(link_collection) - for link_collection in LinkCollection.objects.all() - }, + "code": code, + "title": language[1], + "strings": String.objects.filter(language=code), + "links": {link_collection.key: get_links(link_collection) for link_collection in LinkCollection.objects.all()}, } return response.Response(StringsSerializer(obj).data) @@ -36,8 +30,8 @@ def get_links(collection): def list(self, request, version=None): languages = [ { - 'code': _lang[0], - 'title': _lang[1], + "code": _lang[0], + "title": _lang[1], } for _lang in settings.LANGUAGES ] @@ -47,15 +41,19 @@ def list(self, request, version=None): ) results = serializer.data - return response.Response({ - 'count': len(results), - 'results': results, - }) + return response.Response( + { + "count": len(results), + "results": results, + } + ) def update(self, request, pk=None, version=None): - serializer = StringsSerializer(data={ - 'code': pk, - **request.data, - }) + serializer = StringsSerializer( + data={ + "code": pk, + **request.data, + } + ) serializer.save() return self.retrieve(request, pk=pk) diff --git a/apps/lead/__init__.py b/apps/lead/__init__.py index a2873b4ed0..2b7845766a 100644 --- a/apps/lead/__init__.py +++ b/apps/lead/__init__.py @@ -1 +1 @@ -default_app_config = 'lead.apps.LeadConfig' +default_app_config = "lead.apps.LeadConfig" diff --git a/apps/lead/admin.py b/apps/lead/admin.py index 0052268ff3..60018bf696 100644 --- a/apps/lead/admin.py +++ b/apps/lead/admin.py @@ -1,15 +1,10 @@ -from django.contrib import admin +from admin_auto_filters.filters import AutocompleteFilterFactory +from django.contrib import admin, messages from django.utils.safestring import mark_safe -from django.contrib import messages from reversion.admin import VersionAdmin -from admin_auto_filters.filters import AutocompleteFilterFactory +from .models import EMMEntity, Lead, LeadGroup, LeadPreview, LeadPreviewImage from .tasks import extract_from_lead -from .models import ( - Lead, LeadGroup, - LeadPreview, LeadPreviewImage, - EMMEntity, -) class LeadPreviewInline(admin.StackedInline): @@ -22,61 +17,66 @@ class LeadPreviewImageInline(admin.TabularInline): def trigger_lead_extract(modeladmin, request, queryset): - extract_from_lead.delay( - list(queryset.values_list('id', flat=True).distinct()[:10]) - ) + extract_from_lead.delay(list(queryset.values_list("id", flat=True).distinct()[:10])) messages.add_message( - request, messages.INFO, + request, + messages.INFO, mark_safe( - 'Successfully triggered leads:

' + - '
'.join( - '* {0} : ({1}) {2}'.format(*value) - for value in queryset.values_list('id', 'project_id', 'title').distinct() + "Successfully triggered leads:

" + + "
".join( + "* {0} : ({1}) {2}".format(*value) for value in queryset.values_list("id", "project_id", "title").distinct() ) - ) + ), ) -trigger_lead_extract.short_description = 'Trigger lead extraction' +trigger_lead_extract.short_description = "Trigger lead extraction" @admin.register(Lead) class LeadAdmin(VersionAdmin): inlines = [LeadPreviewInline, LeadPreviewImageInline] - search_fields = ['title'] + search_fields = ["title"] list_filter = ( - AutocompleteFilterFactory('Project', 'project'), - AutocompleteFilterFactory('Created By', 'created_by'), - 'created_at', + AutocompleteFilterFactory("Project", "project"), + AutocompleteFilterFactory("Created By", "created_by"), + "created_at", ) list_display = [ - 'title', 'project', 'created_by', 'created_at', + "title", + "project", + "created_by", + "created_at", ] - ordering = ('project', 'created_by', 'created_at') + ordering = ("project", "created_by", "created_at") autocomplete_fields = ( - 'project', - 'created_by', - 'modified_by', - 'attachment', - 'assignee', - 'source', - 'authors', - 'author', - 'emm_entities', - 'lead_group', - 'connector_lead', - 'duplicate_leads', + "project", + "created_by", + "modified_by", + "attachment", + "assignee", + "source", + "authors", + "author", + "emm_entities", + "lead_group", + "connector_lead", + "duplicate_leads", ) - readonly_fields = ('uuid',) + readonly_fields = ("uuid",) actions = [trigger_lead_extract] @admin.register(LeadGroup) class LeadGroupAdmin(VersionAdmin): - search_fields = ('title',) - autocomplete_fields = ('project', 'created_by', 'modified_by',) + search_fields = ("title",) + autocomplete_fields = ( + "project", + "created_by", + "modified_by", + ) @admin.register(EMMEntity) class EMMEntityAdmin(admin.ModelAdmin): - search_fields = ('name',) + search_fields = ("name",) diff --git a/apps/lead/apps.py b/apps/lead/apps.py index 6a80ecc02d..6e4a1c8149 100644 --- a/apps/lead/apps.py +++ b/apps/lead/apps.py @@ -2,4 +2,4 @@ class LeadConfig(AppConfig): - name = 'lead' + name = "lead" diff --git a/apps/lead/dataloaders.py b/apps/lead/dataloaders.py index 91a6e43c8f..975597304d 100644 --- a/apps/lead/dataloaders.py +++ b/apps/lead/dataloaders.py @@ -1,81 +1,63 @@ -from promise import Promise from collections import defaultdict -from django.utils.functional import cached_property +from assessment_registry.models import AssessmentRegistry +from assisted_tagging.models import DraftEntry from django.db import models - -from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin - +from django.utils.functional import cached_property from entry.models import Entry +from organization.dataloaders import OrganizationLoader from organization.models import Organization +from promise import Promise -from organization.dataloaders import OrganizationLoader +from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin -from .models import Lead, LeadPreview, LeadGroup -from assisted_tagging.models import DraftEntry -from assessment_registry.models import AssessmentRegistry +from .models import Lead, LeadGroup, LeadPreview class LeadPreviewLoader(DataLoaderWithContext): def batch_load_fn(self, keys): lead_preview_qs = LeadPreview.objects.filter(lead__in=keys) - _map = { - lead_preview.lead_id: lead_preview - for lead_preview in lead_preview_qs - } + _map = {lead_preview.lead_id: lead_preview for lead_preview in lead_preview_qs} return Promise.resolve([_map.get(key) for key in keys]) class EntriesCountLoader(DataLoaderWithContext): def batch_load_fn(self, keys): active_af = self.context.active_project.analysis_framework - stat_qs = Entry.objects\ - .filter(lead__in=keys)\ - .order_by('lead').values('lead')\ + stat_qs = ( + Entry.objects.filter(lead__in=keys) + .order_by("lead") + .values("lead") .annotate( total=models.functions.Coalesce( - models.Count( - 'id', - filter=models.Q(analysis_framework=active_af) - ), + models.Count("id", filter=models.Q(analysis_framework=active_af)), 0, ), controlled=models.functions.Coalesce( - models.Count( - 'id', - filter=models.Q(controlled=True, analysis_framework=active_af) - ), + models.Count("id", filter=models.Q(controlled=True, analysis_framework=active_af)), 0, ), - ).values('lead_id', 'total', 'controlled') - _map = { - stat.pop('lead_id'): stat - for stat in stat_qs - } + ) + .values("lead_id", "total", "controlled") + ) + _map = {stat.pop("lead_id"): stat for stat in stat_qs} _dummy = { - 'total': 0, - 'controlled': 0, + "total": 0, + "controlled": 0, } return Promise.resolve([_map.get(key, _dummy) for key in keys]) class LeadGroupLeadCountLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - lead_group_qs = LeadGroup.objects.filter(id__in=keys).annotate( - lead_counts=models.Count('lead', distinct=True) - ) - _map = { - id: count - for id, count in lead_group_qs.values_list('id', 'lead_counts') - } + lead_group_qs = LeadGroup.objects.filter(id__in=keys).annotate(lead_counts=models.Count("lead", distinct=True)) + _map = {id: count for id, count in lead_group_qs.values_list("id", "lead_counts")} return Promise.resolve([_map.get(key, 0) for key in keys]) class LeadAuthorsLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - lead_author_qs = Lead.objects\ - .filter(id__in=keys, authors__isnull=False)\ - .values_list('id', 'authors__id') + lead_author_qs = Lead.objects.filter(id__in=keys, authors__isnull=False).values_list("id", "authors__id") lead_author_map = defaultdict(list) organizations_id = set() for lead_id, author_id in lead_author_qs: @@ -83,52 +65,36 @@ def batch_load_fn(self, keys): organizations_id.add(author_id) organization_qs = Organization.objects.filter(id__in=organizations_id) - _map = { - org.id: org for org in organization_qs - } - return Promise.resolve([ - [ - _map.get(author) - for author in lead_author_map.get(key, []) - ] - for key in keys - ]) + _map = {org.id: org for org in organization_qs} + return Promise.resolve([[_map.get(author) for author in lead_author_map.get(key, [])] for key in keys]) class LeadAssessmentIdLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - assessments_qs = AssessmentRegistry.objects.filter(lead__in=keys).values_list('id', 'lead') - _map = { - lead_id: _id for _id, lead_id in assessments_qs - } + assessments_qs = AssessmentRegistry.objects.filter(lead__in=keys).values_list("id", "lead") + _map = {lead_id: _id for _id, lead_id in assessments_qs} return Promise.resolve([_map.get(key) for key in keys]) class LeadDraftEntryCountLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - stat_qs = DraftEntry.objects\ - .filter(lead__in=keys)\ - .order_by('lead').values('lead')\ + stat_qs = ( + DraftEntry.objects.filter(lead__in=keys) + .order_by("lead") + .values("lead") .annotate( discarded_draft_entry=models.functions.Coalesce( - models.Count( - 'id', - filter=models.Q(is_discarded=True) - ), + models.Count("id", filter=models.Q(is_discarded=True)), 0, ), undiscarded_draft_entry=models.functions.Coalesce( - models.Count( - 'id', - filter=models.Q(is_discarded=False) - ), + models.Count("id", filter=models.Q(is_discarded=False)), 0, ), - ).values('lead_id', 'undiscarded_draft_entry', 'discarded_draft_entry') - _map = { - stat.pop('lead_id'): stat - for stat in stat_qs - } + ) + .values("lead_id", "undiscarded_draft_entry", "discarded_draft_entry") + ) + _map = {stat.pop("lead_id"): stat for stat in stat_qs} return Promise.resolve([_map.get(key, _map) for key in keys]) diff --git a/apps/lead/enums.py b/apps/lead/enums.py index a767fe58de..42fde442e0 100644 --- a/apps/lead/enums.py +++ b/apps/lead/enums.py @@ -7,14 +7,12 @@ from .models import Lead -LeadConfidentialityEnum = convert_enum_to_graphene_enum(Lead.Confidentiality, name='LeadConfidentialityEnum') -LeadStatusEnum = convert_enum_to_graphene_enum(Lead.Status, name='LeadStatusEnum') -LeadPriorityEnum = convert_enum_to_graphene_enum(Lead.Priority, name='LeadPriorityEnum') -LeadSourceTypeEnum = convert_enum_to_graphene_enum(Lead.SourceType, name='LeadSourceTypeEnum') -LeadExtractionStatusEnum = convert_enum_to_graphene_enum(Lead.ExtractionStatus, name='LeadExtractionStatusEnum') -LeadAutoEntryExtractionTypeEnum = convert_enum_to_graphene_enum( - Lead.AutoExtractionStatus, name='LeadAutoEntryExtractionTypeEnum' -) +LeadConfidentialityEnum = convert_enum_to_graphene_enum(Lead.Confidentiality, name="LeadConfidentialityEnum") +LeadStatusEnum = convert_enum_to_graphene_enum(Lead.Status, name="LeadStatusEnum") +LeadPriorityEnum = convert_enum_to_graphene_enum(Lead.Priority, name="LeadPriorityEnum") +LeadSourceTypeEnum = convert_enum_to_graphene_enum(Lead.SourceType, name="LeadSourceTypeEnum") +LeadExtractionStatusEnum = convert_enum_to_graphene_enum(Lead.ExtractionStatus, name="LeadExtractionStatusEnum") +LeadAutoEntryExtractionTypeEnum = convert_enum_to_graphene_enum(Lead.AutoExtractionStatus, name="LeadAutoEntryExtractionTypeEnum") enum_map = { get_enum_name_from_django_field(field): enum @@ -32,26 +30,26 @@ # TODO: Define this dynamically through a list? class LeadOrderingEnum(graphene.Enum): # ASC - ASC_ID = 'id' - ASC_CREATED_AT = 'created_at' - ASC_TITLE = 'title' - ASC_SOURCE = 'source__title' - ASC_PUBLISHED_ON = 'published_on' - ASC_CREATED_BY = 'created_by' - ASC_ASSIGNEE = 'assignee__first_name' - ASC_PRIORITY = 'priority' + ASC_ID = "id" + ASC_CREATED_AT = "created_at" + ASC_TITLE = "title" + ASC_SOURCE = "source__title" + ASC_PUBLISHED_ON = "published_on" + ASC_CREATED_BY = "created_by" + ASC_ASSIGNEE = "assignee__first_name" + ASC_PRIORITY = "priority" # # Custom Filters - ASC_PAGE_COUNT = 'page_count' - ASC_ENTRIES_COUNT = 'entries_count' + ASC_PAGE_COUNT = "page_count" + ASC_ENTRIES_COUNT = "entries_count" # DESC - DESC_ID = f'-{ASC_ID}' - DESC_CREATED_AT = f'-{ASC_CREATED_AT}' - DESC_TITLE = f'-{ASC_TITLE}' - DESC_SOURCE = f'-{ASC_SOURCE}' - DESC_PUBLISHED_ON = f'-{ASC_PUBLISHED_ON}' - DESC_CREATED_BY = f'-{ASC_CREATED_BY}' - DESC_ASSIGNEE = f'-{ASC_ASSIGNEE}' - DESC_PRIORITY = f'-{ASC_PRIORITY}' + DESC_ID = f"-{ASC_ID}" + DESC_CREATED_AT = f"-{ASC_CREATED_AT}" + DESC_TITLE = f"-{ASC_TITLE}" + DESC_SOURCE = f"-{ASC_SOURCE}" + DESC_PUBLISHED_ON = f"-{ASC_PUBLISHED_ON}" + DESC_CREATED_BY = f"-{ASC_CREATED_BY}" + DESC_ASSIGNEE = f"-{ASC_ASSIGNEE}" + DESC_PRIORITY = f"-{ASC_PRIORITY}" # # Custom Filters - DESC_PAGE_COUNT = f'-{ASC_PAGE_COUNT}' - DESC_ENTRIES_COUNT = f'-{ASC_ENTRIES_COUNT}' + DESC_PAGE_COUNT = f"-{ASC_PAGE_COUNT}" + DESC_ENTRIES_COUNT = f"-{ASC_ENTRIES_COUNT}" diff --git a/apps/lead/factories.py b/apps/lead/factories.py index 167974f145..ac58f82aca 100644 --- a/apps/lead/factories.py +++ b/apps/lead/factories.py @@ -1,26 +1,26 @@ -import factory import datetime + +import factory from factory import fuzzy from factory.django import DjangoModelFactory - -from project.factories import ProjectFactory from gallery.factories import FileFactory +from project.factories import ProjectFactory + from .models import ( - Lead, EMMEntity, - LeadGroup, + Lead, LeadEMMTrigger, + LeadGroup, LeadPreview, LeadPreviewImage, UserSavedLeadFilter, ) - DEFAULT_START_DATE = datetime.date(year=2017, month=1, day=1) class LeadFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'Lead-{n}') + title = factory.Sequence(lambda n: f"Lead-{n}") text = fuzzy.FuzzyText(length=100) project = factory.SubFactory(ProjectFactory) attachment = factory.SubFactory(FileFactory) @@ -56,29 +56,29 @@ def emm_entities(self, create, extracted, **kwargs): class EmmEntityFactory(DjangoModelFactory): - name = factory.Sequence(lambda n: f'emm-name-{n}') + name = factory.Sequence(lambda n: f"emm-name-{n}") class Meta: model = EMMEntity class LeadGroupFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'LeadGroup-{n}') + title = factory.Sequence(lambda n: f"LeadGroup-{n}") class Meta: model = LeadGroup class LeadEMMTriggerFactory(DjangoModelFactory): - emm_keyword = factory.Sequence(lambda n: f'emm_keyword-{n}') - emm_risk_factor = factory.Sequence(lambda n: f'emm_risk_factor-{n}') + emm_keyword = factory.Sequence(lambda n: f"emm_keyword-{n}") + emm_risk_factor = factory.Sequence(lambda n: f"emm_risk_factor-{n}") class Meta: model = LeadEMMTrigger class LeadPreviewFactory(DjangoModelFactory): - text_extract = factory.Faker('text', max_nb_chars=4000) + text_extract = factory.Faker("text", max_nb_chars=4000) class Meta: model = LeadPreview diff --git a/apps/lead/filter_set.py b/apps/lead/filter_set.py index 28eeac509b..7963c514ea 100644 --- a/apps/lead/filter_set.py +++ b/apps/lead/filter_set.py @@ -1,36 +1,38 @@ -import graphene import django_filters +import graphene from django.db import models from django.db.models.functions import Coalesce +from entry.filter_set import ( + EntriesFilterDataInputType, + EntriesFilterDataType, + EntryGQFilterSet, +) +from entry.models import Entry +from organization.models import OrganizationType +from project.models import Project +from user.models import User +from user_resource.filters import UserResourceFilterSet, UserResourceGqlFilterSet from deep.filter_set import DjangoFilterCSVWidget, generate_type_for_filter_set -from user_resource.filters import UserResourceFilterSet from utils.graphene.filters import ( - NumberInFilter, - MultipleInputFilter, - SimpleInputFilter, - IDListFilter, - IDFilter, DateGteFilter, DateLteFilter, + IDFilter, + IDListFilter, + MultipleInputFilter, + NumberInFilter, + SimpleInputFilter, ) -from project.models import Project -from organization.models import OrganizationType -from user.models import User -from entry.models import Entry -from entry.filter_set import EntryGQFilterSet, EntriesFilterDataInputType, EntriesFilterDataType -from user_resource.filters import UserResourceGqlFilterSet - -from .models import Lead, LeadGroup, LeadDuplicates from .enums import ( LeadConfidentialityEnum, - LeadStatusEnum, + LeadExtractionStatusEnum, + LeadOrderingEnum, LeadPriorityEnum, LeadSourceTypeEnum, - LeadOrderingEnum, - LeadExtractionStatusEnum, + LeadStatusEnum, ) +from .models import Lead, LeadDuplicates, LeadGroup class LeadFilterSet(django_filters.FilterSet): @@ -44,33 +46,38 @@ class LeadFilterSet(django_filters.FilterSet): """ class Exists(models.TextChoices): - ENTRIES_EXISTS = 'entries_exists', 'Entry Exists' - ASSESSMENT_EXISTS = 'assessment_exists', 'Assessment Exists' - ENTRIES_DO_NOT_EXIST = 'entries_do_not_exist', 'Entries do not exist' - ASSESSMENT_DOES_NOT_EXIST = 'assessment_does_not_exist', 'Assessment does not exist' + ENTRIES_EXISTS = "entries_exists", "Entry Exists" + ASSESSMENT_EXISTS = "assessment_exists", "Assessment Exists" + ENTRIES_DO_NOT_EXIST = "entries_do_not_exist", "Entries do not exist" + ASSESSMENT_DOES_NOT_EXIST = "assessment_does_not_exist", "Assessment does not exist" class CustomFilter(models.TextChoices): - EXCLUDE_EMPTY_FILTERED_ENTRIES = 'exclude_empty_filtered_entries', 'exclude empty filtered entries' + EXCLUDE_EMPTY_FILTERED_ENTRIES = "exclude_empty_filtered_entries", "exclude empty filtered entries" EXCLUDE_EMPTY_CONTROLLED_FILTERED_ENTRIES = ( - 'exclude_empty_controlled_filtered_entries', 'exclude empty controlled filtered entries' + "exclude_empty_controlled_filtered_entries", + "exclude empty controlled filtered entries", ) - search = django_filters.CharFilter(method='search_filter') + search = django_filters.CharFilter(method="search_filter") published_on__lt = django_filters.DateFilter( - field_name='published_on', lookup_expr='lt', + field_name="published_on", + lookup_expr="lt", ) published_on__gt = django_filters.DateFilter( - field_name='published_on', lookup_expr='gt', + field_name="published_on", + lookup_expr="gt", ) published_on__lte = django_filters.DateFilter( - field_name='published_on', lookup_expr='lte', + field_name="published_on", + lookup_expr="lte", ) published_on__gte = django_filters.DateFilter( - field_name='published_on', lookup_expr='gte', + field_name="published_on", + lookup_expr="gte", ) project = django_filters.CharFilter( - method='project_filter', + method="project_filter", ) confidentiality = django_filters.MultipleChoiceFilter( choices=Lead.Confidentiality.choices, @@ -89,78 +96,78 @@ class CustomFilter(models.TextChoices): widget=django_filters.widgets.CSVWidget, ) classified_doc_id = NumberInFilter( - field_name='leadpreview__classified_doc_id', - lookup_expr='in', + field_name="leadpreview__classified_doc_id", + lookup_expr="in", widget=django_filters.widgets.CSVWidget, ) created_at = django_filters.DateTimeFilter( - field_name='created_at', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + input_formats=["%Y-%m-%d%z"], ) created_at__lt = django_filters.DateTimeFilter( - field_name='created_at', - lookup_expr='lt', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + lookup_expr="lt", + input_formats=["%Y-%m-%d%z"], ) created_at__gte = django_filters.DateTimeFilter( - field_name='created_at', lookup_expr='gte', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + lookup_expr="gte", + input_formats=["%Y-%m-%d%z"], ) created_at__lte = django_filters.DateTimeFilter( - field_name='created_at', - lookup_expr='lte', - input_formats=['%Y-%m-%d%z'], + field_name="created_at", + lookup_expr="lte", + input_formats=["%Y-%m-%d%z"], ) exists = django_filters.ChoiceFilter( - label='Exists Choice', - choices=Exists.choices, method='exists_filter', + label="Exists Choice", + choices=Exists.choices, + method="exists_filter", ) emm_entities = django_filters.CharFilter( - method='emm_entities_filter', + method="emm_entities_filter", ) emm_keywords = django_filters.CharFilter( - method='emm_keywords_filter', + method="emm_keywords_filter", ) emm_risk_factors = django_filters.CharFilter( - method='emm_risk_factors_filter', + method="emm_risk_factors_filter", ) ordering = django_filters.CharFilter( - method='ordering_filter', + method="ordering_filter", ) authoring_organization_types = django_filters.ModelMultipleChoiceFilter( - method='authoring_organization_types_filter', + method="authoring_organization_types_filter", widget=DjangoFilterCSVWidget, queryset=OrganizationType.objects.all(), ) # used in export custom_filters = django_filters.ChoiceFilter( - label='Filtered Exists Choice', - choices=CustomFilter.choices, method='filtered_exists_filter', + label="Filtered Exists Choice", + choices=CustomFilter.choices, + method="filtered_exists_filter", ) class Meta: model = Lead fields = { - **{ - x: ['exact'] - for x in ['id', 'text', 'url'] - }, - 'emm_entities': ['exact'], + **{x: ["exact"] for x in ["id", "text", "url"]}, + "emm_entities": ["exact"], # 'emm_keywords': ['exact'], # 'emm_risk_factors': ['exact'], - 'created_at': ['exact', 'lt', 'gt', 'lte', 'gte'], - 'published_on': ['exact', 'lt', 'gt', 'lte', 'gte'], + "created_at": ["exact", "lt", "gt", "lte", "gte"], + "published_on": ["exact", "lt", "gt", "lte", "gte"], } filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda f: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda f: { + "lookup_expr": "icontains", }, }, } @@ -173,7 +180,7 @@ def get_processed_filter_data(raw_filter_data): filter_data = {} for key, value in raw_filter_data.items(): if isinstance(value, list): - filter_data[key] = ','.join([str(x) for x in value]) + filter_data[key] = ",".join([str(x) for x in value]) else: filter_data[key] = value return filter_data @@ -184,17 +191,20 @@ def search_filter(self, qs, name, value): return qs return qs.filter( # By title - models.Q(title__icontains=value) | + models.Q(title__icontains=value) + | # By source - models.Q(source_raw__icontains=value) | - models.Q(source__title__icontains=value) | - models.Q(source__parent__title__icontains=value) | + models.Q(source_raw__icontains=value) + | models.Q(source__title__icontains=value) + | models.Q(source__parent__title__icontains=value) + | # By author - models.Q(author__title__icontains=value) | - models.Q(author__parent__title__icontains=value) | - models.Q(author_raw__icontains=value) | - models.Q(authors__title__icontains=value) | - models.Q(authors__parent__title__icontains=value) | + models.Q(author__title__icontains=value) + | models.Q(author__parent__title__icontains=value) + | models.Q(author_raw__icontains=value) + | models.Q(authors__title__icontains=value) + | models.Q(authors__parent__title__icontains=value) + | # By URL models.Q(url__icontains=value) ).distinct() @@ -202,7 +212,7 @@ def search_filter(self, qs, name, value): def project_filter(self, qs, name, value): # NOTE: @bewakes used this because normal project filter # was giving problem with post filter - project_ids = value.split(',') + project_ids = value.split(",") return qs.filter(project_id__in=project_ids) def exists_filter(self, qs, name, value): @@ -217,40 +227,35 @@ def exists_filter(self, qs, name, value): return qs def emm_entities_filter(self, qs, name, value): - splitted = [x for x in value.split(',') if x] + splitted = [x for x in value.split(",") if x] return qs.filter(emm_entities__in=splitted) def emm_keywords_filter(self, qs, name, value): - splitted = [x for x in value.split(',') if x] + splitted = [x for x in value.split(",") if x] return qs.filter(emm_triggers__emm_keyword__in=splitted) def emm_risk_factors_filter(self, qs, name, value): - splitted = [x for x in value.split(',') if x] + splitted = [x for x in value.split(",") if x] return qs.filter(emm_triggers__emm_risk_factor__in=splitted) def ordering_filter(self, qs, name, value): # NOTE: @bewakes used this because normal ordering filter # was giving problem with post filter # Just clean the order_by fields - orderings = [x.strip() for x in value.split(',') if x.strip()] + orderings = [x.strip() for x in value.split(",") if x.strip()] for ordering in orderings: - if ordering == '-page_count': - qs = qs.order_by(models.F('leadpreview__page_count').desc(nulls_last=True)) - elif ordering == 'page_count': - qs = qs.order_by(models.F('leadpreview__page_count').asc(nulls_first=True)) + if ordering == "-page_count": + qs = qs.order_by(models.F("leadpreview__page_count").desc(nulls_last=True)) + elif ordering == "page_count": + qs = qs.order_by(models.F("leadpreview__page_count").asc(nulls_first=True)) else: qs = qs.order_by(ordering) return qs def authoring_organization_types_filter(self, qs, name, value): if value: - qs = qs.annotate( - organization_types=Coalesce( - 'authors__parent__organization_type', - 'authors__organization_type' - ) - ) + qs = qs.annotate(organization_types=Coalesce("authors__parent__organization_type", "authors__organization_type")) if isinstance(value[0], OrganizationType): return qs.filter(organization_types__in=[ot.id for ot in value]).distinct() return qs.filter(organization_types__in=value).distinct() @@ -277,13 +282,13 @@ class LeadGroupFilterSet(UserResourceFilterSet): class Meta: model = LeadGroup - fields = ['id', 'title'] + fields = ["id", "title"] filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda f: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda f: { + "lookup_expr": "icontains", }, }, } @@ -291,56 +296,54 @@ class Meta: # ------------------------------ Graphql filters ----------------------------------- class LeadGQFilterSet(UserResourceGqlFilterSet): - ids = IDListFilter(method='filter_leads_id', help_text='Empty ids are ignored.') + ids = IDListFilter(method="filter_leads_id", help_text="Empty ids are ignored.") exclude_provided_leads_id = django_filters.BooleanFilter( - method='filter_exclude_provided_leads_id', help_text='Only used when ids are provided.') + method="filter_exclude_provided_leads_id", help_text="Only used when ids are provided." + ) created_by = IDListFilter() modified_by = IDListFilter() - source_types = MultipleInputFilter(LeadSourceTypeEnum, field_name='source_type') - priorities = MultipleInputFilter(LeadPriorityEnum, field_name='priority') + source_types = MultipleInputFilter(LeadSourceTypeEnum, field_name="source_type") + priorities = MultipleInputFilter(LeadPriorityEnum, field_name="priority") confidentiality = SimpleInputFilter(LeadConfidentialityEnum) - statuses = MultipleInputFilter(LeadStatusEnum, field_name='status') - extraction_status = SimpleInputFilter(LeadExtractionStatusEnum, field_name='extraction_status') - assignees = IDListFilter(field_name='assignee') - authoring_organization_types = IDListFilter(method='authoring_organization_types_filter') - author_organizations = IDListFilter(method='authoring_organizations_filter') - source_organizations = IDListFilter(method='source_organizations_filter') + statuses = MultipleInputFilter(LeadStatusEnum, field_name="status") + extraction_status = SimpleInputFilter(LeadExtractionStatusEnum, field_name="extraction_status") + assignees = IDListFilter(field_name="assignee") + authoring_organization_types = IDListFilter(method="authoring_organization_types_filter") + author_organizations = IDListFilter(method="authoring_organizations_filter") + source_organizations = IDListFilter(method="source_organizations_filter") # Filter-only enum filter - has_entries = django_filters.BooleanFilter(method='filter_has_entries', help_text='Lead has entries.') - has_assessment = django_filters.BooleanFilter(method='filter_has_assessment', help_text='Lead has assessment.') - is_assessment = django_filters.BooleanFilter(field_name='is_assessment_lead') - entries_filter_data = SimpleInputFilter(EntriesFilterDataInputType, method='filtered_entries_filter_data') + has_entries = django_filters.BooleanFilter(method="filter_has_entries", help_text="Lead has entries.") + has_assessment = django_filters.BooleanFilter(method="filter_has_assessment", help_text="Lead has assessment.") + is_assessment = django_filters.BooleanFilter(field_name="is_assessment_lead") + entries_filter_data = SimpleInputFilter(EntriesFilterDataInputType, method="filtered_entries_filter_data") - search = django_filters.CharFilter(method='search_filter') + search = django_filters.CharFilter(method="search_filter") published_on = django_filters.DateFilter() - published_on_gte = DateGteFilter(field_name='published_on') - published_on_lte = DateLteFilter(field_name='published_on') + published_on_gte = DateGteFilter(field_name="published_on") + published_on_lte = DateLteFilter(field_name="published_on") - emm_entities = django_filters.CharFilter(method='emm_entities_filter') - emm_keywords = django_filters.CharFilter(method='emm_keywords_filter') - emm_risk_factors = django_filters.CharFilter(method='emm_risk_factors_filter') + emm_entities = django_filters.CharFilter(method="emm_entities_filter") + emm_keywords = django_filters.CharFilter(method="emm_keywords_filter") + emm_risk_factors = django_filters.CharFilter(method="emm_risk_factors_filter") # duplicates - has_duplicates = django_filters.BooleanFilter(method='has_duplicates_filter', help_text='Has duplicate leads') - duplicates_of = IDFilter(method='duplicates_of_filter') + has_duplicates = django_filters.BooleanFilter(method="has_duplicates_filter", help_text="Has duplicate leads") + duplicates_of = IDFilter(method="duplicates_of_filter") - ordering = MultipleInputFilter(LeadOrderingEnum, method='ordering_filter') + ordering = MultipleInputFilter(LeadOrderingEnum, method="ordering_filter") class Meta: model = Lead fields = { - **{ - x: ['exact'] - for x in ['text', 'url'] - }, + **{x: ["exact"] for x in ["text", "url"]}, } filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda _: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda _: { + "lookup_expr": "icontains", }, }, } @@ -352,9 +355,9 @@ def __init__(self, *args, **kwargs): @property def active_project(self) -> Project: if self.request is None: - raise Exception(f'{self.request=} should be defined') + raise Exception(f"{self.request=} should be defined") if self.request.active_project is None: - raise Exception(f'{self.request.active_project=} should be defined') + raise Exception(f"{self.request.active_project=} should be defined") return self.request.active_project # Filters methods @@ -364,23 +367,26 @@ def search_filter(self, qs, name, value): return qs return qs.filter( # By title - models.Q(title__icontains=value) | + models.Q(title__icontains=value) + | # By source - models.Q(source_raw__icontains=value) | - models.Q(source__title__icontains=value) | - models.Q(source__parent__title__icontains=value) | + models.Q(source_raw__icontains=value) + | models.Q(source__title__icontains=value) + | models.Q(source__parent__title__icontains=value) + | # By author - models.Q(author__title__icontains=value) | - models.Q(author__parent__title__icontains=value) | - models.Q(author_raw__icontains=value) | - models.Q(authors__title__icontains=value) | - models.Q(authors__parent__title__icontains=value) | + models.Q(author__title__icontains=value) + | models.Q(author__parent__title__icontains=value) + | models.Q(author_raw__icontains=value) + | models.Q(authors__title__icontains=value) + | models.Q(authors__parent__title__icontains=value) + | # By URL models.Q(url__icontains=value) ).distinct() def ordering_filter(self, qs, name, value): - active_entry_count_field = self.custom_context.get('active_entry_count_field') + active_entry_count_field = self.custom_context.get("active_entry_count_field") for ordering in value: # Custom for entries count (use filter or normal entry count) if active_entry_count_field and ordering in [ @@ -390,37 +396,32 @@ def ordering_filter(self, qs, name, value): if ordering == LeadOrderingEnum.ASC_ENTRIES_COUNT: qs = qs.order_by(active_entry_count_field) else: - qs = qs.order_by(f'-{active_entry_count_field}') + qs = qs.order_by(f"-{active_entry_count_field}") # Custom for page count with nulls_last elif ordering == LeadOrderingEnum.DESC_PAGE_COUNT: - qs = qs.order_by(models.F('leadpreview__page_count').desc(nulls_last=True)) + qs = qs.order_by(models.F("leadpreview__page_count").desc(nulls_last=True)) elif ordering == LeadOrderingEnum.ASC_PAGE_COUNT: - qs = qs.order_by(models.F('leadpreview__page_count').asc(nulls_first=True)) + qs = qs.order_by(models.F("leadpreview__page_count").asc(nulls_first=True)) # For remaining else: qs = qs.order_by(ordering) return qs def emm_entities_filter(self, qs, name, value): - splitted = [x for x in value.split(',') if x] + splitted = [x for x in value.split(",") if x] return qs.filter(emm_entities__in=splitted) def emm_keywords_filter(self, qs, name, value): - splitted = [x for x in value.split(',') if x] + splitted = [x for x in value.split(",") if x] return qs.filter(emm_triggers__emm_keyword__in=splitted) def emm_risk_factors_filter(self, qs, name, value): - splitted = [x for x in value.split(',') if x] + splitted = [x for x in value.split(",") if x] return qs.filter(emm_triggers__emm_risk_factor__in=splitted) def authoring_organization_types_filter(self, qs, name, value): if value: - qs = qs.annotate( - organization_types=Coalesce( - 'authors__parent__organization_type', - 'authors__organization_type' - ) - ) + qs = qs.annotate(organization_types=Coalesce("authors__parent__organization_type", "authors__organization_type")) if isinstance(value[0], OrganizationType): return qs.filter(organization_types__in=[ot.id for ot in value]).distinct() return qs.filter(organization_types__in=value).distinct() @@ -428,13 +429,13 @@ def authoring_organization_types_filter(self, qs, name, value): def authoring_organizations_filter(self, qs, _, value): if value: - qs = qs.annotate(authoring_organizations=Coalesce('authors__parent_id', 'authors__id')) + qs = qs.annotate(authoring_organizations=Coalesce("authors__parent_id", "authors__id")) return qs.filter(authoring_organizations__in=value).distinct() return qs def source_organizations_filter(self, qs, _, value): if value: - qs = qs.annotate(source_organizations=Coalesce('source__parent_id', 'source__id')) + qs = qs.annotate(source_organizations=Coalesce("source__parent_id", "source__id")) return qs.filter(source_organizations__in=value).distinct() return qs @@ -445,7 +446,7 @@ def filter_exclude_provided_leads_id(self, qs, *_): def filter_leads_id(self, qs, _, value): if value is None: return qs - if self.data.get('exclude_provided_leads_id'): + if self.data.get("exclude_provided_leads_id"): return qs.exclude(id__in=value) return qs.filter(id__in=value) @@ -468,30 +469,32 @@ def filtered_entries_filter_data(self, qs, _, value): def filter_queryset(self, qs): def _entry_subquery(entry_qs: models.QuerySet): - subquery_qs = entry_qs.\ - filter( + subquery_qs = ( + entry_qs.filter( project=self.active_project, analysis_framework=self.active_project.analysis_framework_id, - lead=models.OuterRef('pk'), - )\ - .values('lead').order_by()\ - .annotate(count=models.Count('id'))\ - .values('count') + lead=models.OuterRef("pk"), + ) + .values("lead") + .order_by() + .annotate(count=models.Count("id")) + .values("count") + ) return Coalesce( - models.Subquery( - subquery_qs[:1], - output_field=models.IntegerField() - ), 0, + models.Subquery(subquery_qs[:1], output_field=models.IntegerField()), + 0, ) # Pre-annotate required fields for entries count (w/wo filters) - entries_filter_data = self.data.get('entries_filter_data') - has_entries = self.data.get('has_entries') + entries_filter_data = self.data.get("entries_filter_data") + has_entries = self.data.get("has_entries") has_entries_count_ordering = any( - ordering in [ + ordering + in [ LeadOrderingEnum.ASC_ENTRIES_COUNT, LeadOrderingEnum.DESC_ENTRIES_COUNT, - ] for ordering in self.data.get('ordering') or [] + ] + for ordering in self.data.get("ordering") or [] ) # With filter @@ -501,30 +504,25 @@ def _entry_subquery(entry_qs: models.QuerySet): EntryGQFilterSet( data={ **entries_filter_data, - 'from_subquery': True, + "from_subquery": True, }, request=self.request, ).qs ) ) - self.custom_context['active_entry_count_field'] = 'filtered_entry_count' + self.custom_context["active_entry_count_field"] = "filtered_entry_count" # Without filter - if has_entries is not None or ( - entries_filter_data is None and has_entries_count_ordering - ): - self.custom_context['active_entry_count_field'] = self.custom_context.\ - get('active_entry_count_field', 'entry_count') - qs = qs.annotate( - entry_count=_entry_subquery(Entry.objects.all()) - ) + if has_entries is not None or (entries_filter_data is None and has_entries_count_ordering): + self.custom_context["active_entry_count_field"] = self.custom_context.get("active_entry_count_field", "entry_count") + qs = qs.annotate(entry_count=_entry_subquery(Entry.objects.all())) # Call super function return super().filter_queryset(qs) def duplicates_of_filter(self, qs, _, lead_id: int): if lead_id is None: return qs - dup_qs1 = LeadDuplicates.objects.filter(source_lead=lead_id).values_list('target_lead', flat=True) - dup_qs2 = LeadDuplicates.objects.filter(target_lead=lead_id).values_list('source_lead', flat=True) + dup_qs1 = LeadDuplicates.objects.filter(source_lead=lead_id).values_list("target_lead", flat=True) + dup_qs2 = LeadDuplicates.objects.filter(target_lead=lead_id).values_list("source_lead", flat=True) return qs.filter(pk__in=dup_qs1.union(dup_qs2)) def has_duplicates_filter(self, qs, _, val: bool): @@ -540,7 +538,7 @@ def qs(self): class LeadGroupGQFilterSet(UserResourceGqlFilterSet): - search = django_filters.CharFilter(method='filter_title') + search = django_filters.CharFilter(method="filter_title") class Meta: model = LeadGroup @@ -554,10 +552,10 @@ def filter_title(self, qs, name, value): LeadsFilterDataType, LeadsFilterDataInputType = generate_type_for_filter_set( LeadGQFilterSet, - 'lead.schema.LeadListType', - 'LeadsFilterDataType', - 'LeadsFilterDataInputType', + "lead.schema.LeadListType", + "LeadsFilterDataType", + "LeadsFilterDataInputType", custom_new_fields_map={ - 'entries_filter_data': graphene.Field(EntriesFilterDataType), + "entries_filter_data": graphene.Field(EntriesFilterDataType), }, ) diff --git a/apps/lead/models.py b/apps/lead/models.py index eb04b1fe5a..1b6e22dd26 100644 --- a/apps/lead/models.py +++ b/apps/lead/models.py @@ -1,16 +1,16 @@ import uuid as python_uuid -from django.contrib.contenttypes.fields import GenericRelation + from django.conf import settings from django.contrib.auth.models import User +from django.contrib.contenttypes.fields import GenericRelation from django.db import models, transaction - -from project.models import Project -from project.permissions import PROJECT_PERMISSIONS -from project.mixins import ProjectEntityMixin +from gallery.models import File from notification.models import Assignment from organization.models import Organization +from project.mixins import ProjectEntityMixin +from project.models import Project +from project.permissions import PROJECT_PERMISSIONS from user_resource.models import UserResource -from gallery.models import File class LeadGroup(UserResource): @@ -26,12 +26,11 @@ def get_for_project(project): @staticmethod def get_for(user): - return LeadGroup.objects.filter( - models.Q(project__members=user) | - models.Q(project__user_groups__members=user) - ).annotate( - no_of_leads=models.Count('lead', distinct=True) - ).distinct() + return ( + LeadGroup.objects.filter(models.Q(project__members=user) | models.Q(project__user_groups__members=user)) + .annotate(no_of_leads=models.Count("lead", distinct=True)) + .distinct() + ) def can_get(self, user): return self.project.is_member(user) @@ -55,51 +54,53 @@ class Lead(UserResource, ProjectEntityMixin): """ class Confidentiality(models.TextChoices): - UNPROTECTED = 'unprotected', 'Public' - RESTRICTED = 'restricted', 'Restricted' - CONFIDENTIAL = 'confidential', 'Confidential' + UNPROTECTED = "unprotected", "Public" + RESTRICTED = "restricted", "Restricted" + CONFIDENTIAL = "confidential", "Confidential" class Status(models.TextChoices): # TODO: Update enum value - NOT_TAGGED = 'pending', 'Not Tagged' - IN_PROGRESS = 'processed', 'In progress' - TAGGED = 'validated', 'Tagged' + NOT_TAGGED = "pending", "Not Tagged" + IN_PROGRESS = "processed", "In progress" + TAGGED = "validated", "Tagged" class Priority(models.IntegerChoices): - LOW = 100, 'Low' - MEDIUM = 200, 'Medium' - HIGH = 300, 'High' + LOW = 100, "Low" + MEDIUM = 200, "Medium" + HIGH = 300, "High" class SourceType(models.TextChoices): - TEXT = 'text', 'Text' - DISK = 'disk', 'Disk' - WEBSITE = 'website', 'Website' - DROPBOX = 'dropbox', 'Dropbox' - GOOGLE_DRIVE = 'google-drive', 'Google Drive' + TEXT = "text", "Text" + DISK = "disk", "Disk" + WEBSITE = "website", "Website" + DROPBOX = "dropbox", "Dropbox" + GOOGLE_DRIVE = "google-drive", "Google Drive" - RSS = 'rss', 'RSS Feed' - EMM = 'emm', 'EMM' - WEB_API = 'api', 'Web API' - UNKNOWN = 'unknown', 'Unknown' + RSS = "rss", "RSS Feed" + EMM = "emm", "EMM" + WEB_API = "api", "Web API" + UNKNOWN = "unknown", "Unknown" class ExtractionStatus(models.IntegerChoices): - PENDING = 0, 'Pending' - STARTED = 1, 'Started' - RETRYING = 4, 'Retrying' - SUCCESS = 2, 'Success' - FAILED = 3, 'Failed' + PENDING = 0, "Pending" + STARTED = 1, "Started" + RETRYING = 4, "Retrying" + SUCCESS = 2, "Success" + FAILED = 3, "Failed" class AutoExtractionStatus(models.IntegerChoices): NONE = 0, "None" STARTED = 1, "Started" - PENDING = 2, 'Pending' - SUCCESS = 3, 'Success' - FAILED = 4, 'Failed' + PENDING = 2, "Pending" + SUCCESS = 3, "Success" + FAILED = 4, "Failed" lead_group = models.ForeignKey( LeadGroup, on_delete=models.SET_NULL, - null=True, blank=True, default=None, + null=True, + blank=True, + default=None, ) uuid = models.UUIDField(default=python_uuid.uuid4, editable=False, unique=True) @@ -109,12 +110,21 @@ class AutoExtractionStatus(models.IntegerChoices): authors = models.ManyToManyField(Organization, blank=True) # TODO: Remove (Legacy), Make sure to copy author to authors if authors is empty author = models.ForeignKey( - Organization, verbose_name='author (legacy)', related_name='leads_by_author', - on_delete=models.SET_NULL, null=True, blank=True, default=None, + Organization, + verbose_name="author (legacy)", + related_name="leads_by_author", + on_delete=models.SET_NULL, + null=True, + blank=True, + default=None, ) source = models.ForeignKey( - Organization, related_name='leads_by_source', - on_delete=models.SET_NULL, null=True, blank=True, default=None, + Organization, + related_name="leads_by_source", + on_delete=models.SET_NULL, + null=True, + blank=True, + default=None, ) # Legacy Data (Remove after migrating all data) @@ -131,35 +141,38 @@ class AutoExtractionStatus(models.IntegerChoices): text = models.TextField(blank=True) url = models.TextField(blank=True) - extraction_status = models.SmallIntegerField( - choices=ExtractionStatus.choices, default=ExtractionStatus.PENDING) + extraction_status = models.SmallIntegerField(choices=ExtractionStatus.choices, default=ExtractionStatus.PENDING) attachment = models.ForeignKey( - File, on_delete=models.SET_NULL, default=None, null=True, blank=True, + File, + on_delete=models.SET_NULL, + default=None, + null=True, + blank=True, ) - emm_entities = models.ManyToManyField('EMMEntity', blank=True) - assignments = GenericRelation(Assignment, related_query_name='lead') + emm_entities = models.ManyToManyField("EMMEntity", blank=True) + assignments = GenericRelation(Assignment, related_query_name="lead") is_assessment_lead = models.BooleanField(default=False) # Connector # On delete, make sure to update UnifiedConnectorLead aleady_added to false. connector_lead = models.ForeignKey( - 'unified_connector.ConnectorLead', - on_delete=models.SET_NULL, related_name='+', blank=True, null=True + "unified_connector.ConnectorLead", on_delete=models.SET_NULL, related_name="+", blank=True, null=True ) duplicate_leads = models.ManyToManyField( "Lead", blank=True, - through='LeadDuplicates', + through="LeadDuplicates", ) is_indexed = models.BooleanField(default=False) duplicate_leads_count = models.PositiveIntegerField(default=0) indexed_at = models.DateTimeField(null=True, blank=True) auto_entry_extraction_status = models.SmallIntegerField( - choices=AutoExtractionStatus.choices, default=AutoExtractionStatus.NONE) + choices=AutoExtractionStatus.choices, default=AutoExtractionStatus.NONE + ) def __str__(self): - return '{}'.format(self.title) + return "{}".format(self.title) # Lead preview is invalid while saving url/text/attachment # Retrigger extraction at such cases @@ -173,35 +186,32 @@ def __init__(self, *args, **kwargs): def get_dict(self): return { - 'text': self.text, - 'url': self.url, - 'attachment_id': self.attachment_id, + "text": self.text, + "url": self.url, + "attachment_id": self.attachment_id, } def update_extraction_status(self, new_status, commit=True): self.extraction_status = new_status if commit: - self.save(update_fields=('extraction_status',)) + self.save(update_fields=("extraction_status",)) def save(self, *args, **kwargs): super().save(*args, **kwargs) - update_fields = kwargs.get('update_fields') - initial_fields = ['text', 'attachment', 'attachment_id', 'url'] - - if ( - not settings.TESTING and ( - self.id is None or - update_fields is None or - any(x in update_fields for x in initial_fields) - ) - ): + update_fields = kwargs.get("update_fields") + initial_fields = ["text", "attachment", "attachment_id", "url"] + + if not settings.TESTING and (self.id is None or update_fields is None or any(x in update_fields for x in initial_fields)): from lead.tasks import extract_from_lead d1 = self.__initial d2 = self.get_dict() - if not d1 or d1.get('text') != d2.get('text') or \ - d1.get('url') != d2.get('url') or \ - d1.get('attachment_id') != d2.get('attachment_id'): + if ( + not d1 + or d1.get("text") != d2.get("text") + or d1.get("url") != d2.get("url") + or d1.get("attachment_id") != d2.get("attachment_id") + ): transaction.on_commit(lambda: extract_from_lead.delay(self.id)) @classmethod @@ -221,56 +231,72 @@ def get_for(cls, user, filters=None): # NOTE: This is quite complicated because user can have two view roles: # view all or view only unprotected, both of which return different results - qs = cls.objects.filter( - # First filter if user is member - project__projectmembership__member=user, - ).annotate( - # Get permission value for view_only_unprotected permission - view_unprotected=models.F( - 'project__projectmembership__role__lead_permissions' - ).bitand(view_unprotected_perm_value), - # Get permission value for view permission - view_all=models.F( - 'project__projectmembership__role__lead_permissions' - ).bitand(view_perm_value) - ).filter( - # If view only unprotected, filter leads with confidentiality not confidential - ( - models.Q(view_unprotected=view_unprotected_perm_value) & - ~models.Q(confidentiality=Lead.Confidentiality.CONFIDENTIAL) - ) | - # Or, return nothing if view_all is not present - models.Q(view_all=view_perm_value) + qs = ( + cls.objects.filter( + # First filter if user is member + project__projectmembership__member=user, + ) + .annotate( + # Get permission value for view_only_unprotected permission + view_unprotected=models.F("project__projectmembership__role__lead_permissions").bitand( + view_unprotected_perm_value + ), + # Get permission value for view permission + view_all=models.F("project__projectmembership__role__lead_permissions").bitand(view_perm_value), + ) + .filter( + # If view only unprotected, filter leads with confidentiality not confidential + ( + models.Q(view_unprotected=view_unprotected_perm_value) + & ~models.Q(confidentiality=Lead.Confidentiality.CONFIDENTIAL) + ) + | + # Or, return nothing if view_all is not present + models.Q(view_all=view_perm_value) + ) ) # filter entries - entries_filter_data = filters.get('entries_filter_data', {}) + entries_filter_data = filters.get("entries_filter_data", {}) original_filter = {**entries_filter_data} - original_filter.pop('project', None) - entries_filter_data['from_subquery'] = True + original_filter.pop("project", None) + entries_filter_data["from_subquery"] = True return qs.annotate( - entries_count=models.Count('entry', distinct=True), - assessment_id=models.F('assessment'), - controlled_entries_count=models.Count('entry', filter=models.Q(entry__controlled=True)), - filtered_entries_count=models.functions.Coalesce( - models.Subquery( - get_filtered_entries(user, entries_filter_data).filter( - lead=models.OuterRef('pk') - ).values('lead').order_by().annotate( - count=models.Count('id') - ).values('count')[:1], output_field=models.IntegerField() - ), 0 - ) if original_filter else models.F('entries_count'), - controlled_filtered_entries_count=models.functions.Coalesce( - models.Subquery( - get_filtered_entries(user, entries_filter_data).filter( - lead=models.OuterRef('pk'), - controlled=True - ).values('lead').order_by().annotate( - count=models.Count('id') - ).values('count')[:1], output_field=models.IntegerField() - ), 0 - ) if original_filter else models.F('controlled_entries_count'), + entries_count=models.Count("entry", distinct=True), + assessment_id=models.F("assessment"), + controlled_entries_count=models.Count("entry", filter=models.Q(entry__controlled=True)), + filtered_entries_count=( + models.functions.Coalesce( + models.Subquery( + get_filtered_entries(user, entries_filter_data) + .filter(lead=models.OuterRef("pk")) + .values("lead") + .order_by() + .annotate(count=models.Count("id")) + .values("count")[:1], + output_field=models.IntegerField(), + ), + 0, + ) + if original_filter + else models.F("entries_count") + ), + controlled_filtered_entries_count=( + models.functions.Coalesce( + models.Subquery( + get_filtered_entries(user, entries_filter_data) + .filter(lead=models.OuterRef("pk"), controlled=True) + .values("lead") + .order_by() + .annotate(count=models.Count("id")) + .values("count")[:1], + output_field=models.IntegerField(), + ), + 0, + ) + if original_filter + else models.F("controlled_entries_count") + ), ) def get_assignee(self): @@ -289,9 +315,7 @@ def get_source_display(self, short_name=False): def get_authors_display(self, short_name=False): authors = self.authors.all() if authors: - return ','.join([ - (author.data.short_name if short_name else author.data.title) for author in authors - ]) + return ",".join([(author.data.short_name if short_name else author.data.title) for author in authors]) elif self.author: # TODO: Remove (Legacy) return self.author and (self.author.data.short_name if short_name else self.author.data.title) @@ -300,9 +324,7 @@ def get_authors_display(self, short_name=False): def get_authoring_organizations_type_display(self): authors = self.authors.all() if authors: - return ','.join(set([ - author.get_organization_type_display() for author in authors if author.data.organization_type - ])) + return ",".join(set([author.get_organization_type_display() for author in authors if author.data.organization_type])) elif self.author: return self.author.data.organization_type and self.author.data.organization_type.title return @@ -312,54 +334,62 @@ def get_associated_entities(cls, project_id, lead_ids): """ Used for pre-check before deletion """ - from entry.models import Entry from ary.models import Assessment + from entry.models import Entry + return { - 'entries': Entry.objects.filter(lead__in=lead_ids, lead__project_id=project_id).count(), - 'assessments': Assessment.objects.filter(lead__project_id=project_id, lead__in=lead_ids).count(), + "entries": Entry.objects.filter(lead__in=lead_ids, lead__project_id=project_id).count(), + "assessments": Assessment.objects.filter(lead__project_id=project_id, lead__in=lead_ids).count(), } @classmethod def get_emm_summary(cls, lead_qs): # Aggregate emm data - emm_entities = EMMEntity.objects\ - .filter(lead__in=lead_qs).values('name')\ - .annotate(total_count=models.Count('name'))\ - .order_by('-total_count').values('name', 'total_count') - emm_triggers = LeadEMMTrigger.objects\ - .filter(lead__in=lead_qs).values('emm_keyword', 'emm_risk_factor')\ - .annotate(total_count=models.Sum('count'))\ - .order_by('-total_count').values('emm_keyword', 'emm_risk_factor', 'total_count') + emm_entities = ( + EMMEntity.objects.filter(lead__in=lead_qs) + .values("name") + .annotate(total_count=models.Count("name")) + .order_by("-total_count") + .values("name", "total_count") + ) + emm_triggers = ( + LeadEMMTrigger.objects.filter(lead__in=lead_qs) + .values("emm_keyword", "emm_risk_factor") + .annotate(total_count=models.Sum("count")) + .order_by("-total_count") + .values("emm_keyword", "emm_risk_factor", "total_count") + ) return { - 'emm_entities': emm_entities, - 'emm_triggers': emm_triggers, + "emm_entities": emm_entities, + "emm_triggers": emm_triggers, } class LeadPreview(models.Model): class ClassificationStatus(models.TextChoices): - NONE = 'none', 'None' # For leads which are not texts - INITIATED = 'initiated', 'Initiated' - COMPLETED = 'completed', 'Completed' - FAILED = 'failed', 'Failed' # Somehow Failed due to connection error - ERRORED = 'errored', 'Errored' # If errored, no point in retrying + NONE = "none", "None" # For leads which are not texts + INITIATED = "initiated", "Initiated" + COMPLETED = "completed", "Completed" + FAILED = "failed", "Failed" # Somehow Failed due to connection error + ERRORED = "errored", "Errored" # If errored, no point in retrying lead = models.OneToOneField(Lead, on_delete=models.CASCADE) text_extract = models.TextField(blank=True) thumbnail = models.ImageField( - upload_to='lead-thumbnail/', - default=None, null=True, blank=True, - height_field='thumbnail_height', - width_field='thumbnail_width' + upload_to="lead-thumbnail/", + default=None, + null=True, + blank=True, + height_field="thumbnail_height", + width_field="thumbnail_width", ) thumbnail_height = models.IntegerField(default=None, null=True, blank=True) thumbnail_width = models.IntegerField(default=None, null=True, blank=True) word_count = models.IntegerField(default=None, null=True, blank=True) page_count = models.IntegerField(default=None, null=True, blank=True) - classified_doc_id = models.IntegerField(default=None, - null=True, blank=True) + classified_doc_id = models.IntegerField(default=None, null=True, blank=True) classification_status = models.CharField( max_length=20, choices=ClassificationStatus.choices, @@ -369,20 +399,23 @@ class ClassificationStatus(models.TextChoices): text_extraction_id = models.UUIDField(blank=True, null=True) # Saved when TextExtraction is completed def __str__(self): - return 'Text extracted for {}'.format(self.lead) + return "Text extracted for {}".format(self.lead) class LeadPreviewImage(models.Model): """ NOTE: File can be only used by gallery (when attached to a entry) """ + lead = models.ForeignKey( - Lead, related_name='images', on_delete=models.CASCADE, + Lead, + related_name="images", + on_delete=models.CASCADE, ) - file = models.FileField(upload_to='lead-preview/') + file = models.FileField(upload_to="lead-preview/") def __str__(self): - return 'Image extracted for {}'.format(self.lead) + return "Image extracted for {}".format(self.lead) def clone_as_deep_file(self, user): """ @@ -399,20 +432,20 @@ def clone_as_deep_file(self, user): class LeadEMMTrigger(models.Model): - lead = models.ForeignKey(Lead, related_name='emm_triggers', on_delete=models.CASCADE) + lead = models.ForeignKey(Lead, related_name="emm_triggers", on_delete=models.CASCADE) emm_keyword = models.CharField(max_length=100) emm_risk_factor = models.CharField(max_length=100) count = models.PositiveIntegerField(default=0) class Meta: - ordering = ('-count',) + ordering = ("-count",) class EMMEntity(models.Model): name = models.CharField(max_length=150, unique=True) class Meta: - ordering = ('name',) + ordering = ("name",) def __str__(self): return self.name diff --git a/apps/lead/mutation.py b/apps/lead/mutation.py index 122a7b0c3d..0c72072590 100644 --- a/apps/lead/mutation.py +++ b/apps/lead/mutation.py @@ -1,50 +1,44 @@ import graphene +from deep.permissions import ProjectPermissions as PP +from utils.graphene.error_types import CustomErrorType, mutation_is_not_valid from utils.graphene.mutation import ( - generate_input_type_for_serializer, - PsGrapheneMutation, PsBulkGrapheneMutation, PsDeleteMutation, + PsGrapheneMutation, + generate_input_type_for_serializer, ) -from utils.graphene.error_types import ( - mutation_is_not_valid, - CustomErrorType -) -from deep.permissions import ProjectPermissions as PP from .models import Lead, LeadGroup, UserSavedLeadFilter -from .schema import LeadType, LeadGroupType, UserSavedLeadFilterType -from .serializers import ( - LeadGqSerializer as LeadSerializer, - LeadCopyGqSerializer, - UserSavedLeadFilterSerializer, -) - +from .schema import LeadGroupType, LeadType, UserSavedLeadFilterType +from .serializers import LeadCopyGqSerializer +from .serializers import LeadGqSerializer as LeadSerializer +from .serializers import UserSavedLeadFilterSerializer LeadInputType = generate_input_type_for_serializer( - 'LeadInputType', + "LeadInputType", serializer_class=LeadSerializer, ) LeadCopyInputType = generate_input_type_for_serializer( - 'LeadCopyInputType', + "LeadCopyInputType", serializer_class=LeadCopyGqSerializer, ) UserSavedLeadFilterInputType = generate_input_type_for_serializer( - 'UserSavedLeadFilterInputType', + "UserSavedLeadFilterInputType", serializer_class=UserSavedLeadFilterSerializer, ) -class LeadMutationMixin(): +class LeadMutationMixin: @classmethod def filter_queryset(cls, qs, info): return qs.filter(project=info.context.active_project) -class LeadGroupMutationMixin(): +class LeadGroupMutationMixin: @classmethod def filter_queryset(cls, qs, info): return qs.filter(project=info.context.active_project) @@ -53,6 +47,7 @@ def filter_queryset(cls, qs, info): class CreateLead(LeadMutationMixin, PsGrapheneMutation): class Arguments: data = LeadInputType(required=True) + model = Lead serializer_class = LeadSerializer result = graphene.Field(LeadType) @@ -63,6 +58,7 @@ class UpdateLead(LeadMutationMixin, PsGrapheneMutation): class Arguments: data = LeadInputType(required=True) id = graphene.ID(required=True) + model = Lead serializer_class = LeadSerializer result = graphene.Field(LeadType) @@ -72,6 +68,7 @@ class Arguments: class DeleteLead(LeadMutationMixin, PsDeleteMutation): class Arguments: id = graphene.ID(required=True) + model = Lead result = graphene.Field(LeadType) permissions = [PP.Permission.DELETE_LEAD] @@ -80,6 +77,7 @@ class Arguments: class DeleteLeadGroup(LeadGroupMutationMixin, PsDeleteMutation): class Arguments: id = graphene.ID(required=True) + model = LeadGroup result = graphene.Field(LeadGroupType) permissions = [PP.Permission.DELETE_LEAD] @@ -111,7 +109,7 @@ class Arguments: @staticmethod def mutate(root, info, data): - serializer = LeadCopyGqSerializer(data=data, context={'request': info.context.request}) + serializer = LeadCopyGqSerializer(data=data, context={"request": info.context.request}) if errors := mutation_is_not_valid(serializer): return LeadCopy(errors=errors, ok=False) new_leads = serializer.save() @@ -121,6 +119,7 @@ def mutate(root, info, data): class SaveUserSavedLeadFilter(PsGrapheneMutation): class Arguments: data = UserSavedLeadFilterInputType(required=True) + model = Lead serializer_class = UserSavedLeadFilterSerializer result = graphene.Field(UserSavedLeadFilterType) @@ -132,14 +131,14 @@ def mutate(root, info, data): user=info.context.user, project=info.context.active_project, ) - serializer = UserSavedLeadFilterSerializer(instance=instance, data=data, context={'request': info.context.request}) + serializer = UserSavedLeadFilterSerializer(instance=instance, data=data, context={"request": info.context.request}) if errors := mutation_is_not_valid(serializer): return SaveUserSavedLeadFilter(errors=errors, ok=False) updated_instance = serializer.save() return SaveUserSavedLeadFilter(result=updated_instance, errors=None, ok=True) -class Mutation(): +class Mutation: lead_create = CreateLead.Field() lead_update = UpdateLead.Field() lead_delete = DeleteLead.Field() diff --git a/apps/lead/public_schema.py b/apps/lead/public_schema.py index ffa5a4f608..5ea072be35 100644 --- a/apps/lead/public_schema.py +++ b/apps/lead/public_schema.py @@ -1,10 +1,10 @@ import graphene from django.db import models +from gallery.schema import PublicGalleryFileType +from project.public_schema import PublicProjectWithMembershipData from deep.permissions import ProjectPermissions as PP from utils.graphene.enums import EnumDescription -from gallery.schema import PublicGalleryFileType -from project.public_schema import PublicProjectWithMembershipData from .models import Lead from .schema import LeadSourceTypeEnum @@ -15,12 +15,12 @@ def get_public_lead_qs(): models.Q( project__has_publicly_viewable_unprotected_leads=True, confidentiality=Lead.Confidentiality.UNPROTECTED, - ) | - models.Q( + ) + | models.Q( project__has_publicly_viewable_restricted_leads=True, confidentiality=Lead.Confidentiality.RESTRICTED, - ) | - models.Q( + ) + | models.Q( project__has_publicly_viewable_confidential_leads=True, confidentiality=Lead.Confidentiality.CONFIDENTIAL, ) @@ -36,7 +36,7 @@ class PublicLeadDetailType(graphene.ObjectType): published_on = graphene.Date() source_type = graphene.Field(LeadSourceTypeEnum, required=True) - source_type_display = EnumDescription(source='get_source_type_display', required=True) + source_type_display = EnumDescription(source="get_source_type_display", required=True) text = graphene.String() url = graphene.String() attachment = graphene.Field(PublicGalleryFileType) @@ -81,18 +81,21 @@ def _return(lead, project, has_access): if lead: lead.has_project_access = has_access return { - 'project': _project, - 'lead': lead, + "project": _project, + "lead": lead, } def _get_lead_from_qs(qs): - return qs\ - .select_related( - 'project', - 'created_by', - 'source', - 'source__parent', - ).filter(uuid=kwargs['uuid']).first() + return ( + qs.select_related( + "project", + "created_by", + "source", + "source__parent", + ) + .filter(uuid=kwargs["uuid"]) + .first() + ) user = info.context.user public_lead = _get_lead_from_qs(get_public_lead_qs()) @@ -116,8 +119,8 @@ def _get_lead_from_qs(qs): if PP.Permission.VIEW_ALL_LEAD in user_permissions: return _return(lead, lead.project, True) if ( - PP.Permission.VIEW_ONLY_UNPROTECTED_LEAD in user_permissions and - lead.confidentiality != Lead.Confidentiality.CONFIDENTIAL + PP.Permission.VIEW_ONLY_UNPROTECTED_LEAD in user_permissions + and lead.confidentiality != Lead.Confidentiality.CONFIDENTIAL ): return _return(lead, lead.project, True) return _return(None, lead.project, True) diff --git a/apps/lead/receivers.py b/apps/lead/receivers.py index 5a82124530..2f6ca50d5a 100644 --- a/apps/lead/receivers.py +++ b/apps/lead/receivers.py @@ -1,10 +1,9 @@ # Reusable actions +from deduplication.tasks.indexing import remove_lead_from_index from django.db import models, transaction from django.dispatch import receiver - from lead.models import Lead, LeadDuplicates from unified_connector.models import ConnectorSourceLead -from deduplication.tasks.indexing import remove_lead_from_index @receiver(models.signals.post_delete, sender=Lead) @@ -22,9 +21,7 @@ def update_indices(sender, instance, **kwargs): def update_index_and_duplicates(lead: Lead): remove_lead_from_index.delay(lead.id) # Now get all other leads which are duplicates of the lead and update their count - dup_qs1 = LeadDuplicates.objects.filter(source_lead=lead.id).values_list('target_lead', flat=True) - dup_qs2 = LeadDuplicates.objects.filter(target_lead=lead.id).values_list('source_lead', flat=True) + dup_qs1 = LeadDuplicates.objects.filter(source_lead=lead.id).values_list("target_lead", flat=True) + dup_qs2 = LeadDuplicates.objects.filter(target_lead=lead.id).values_list("source_lead", flat=True) dup_leads = Lead.objects.filter(pk__in=dup_qs1.union(dup_qs2)) - dup_leads.update( - duplicate_leads_count=models.F('duplicate_leads_count') - 1 - ) + dup_leads.update(duplicate_leads_count=models.F("duplicate_leads_count") - 1) diff --git a/apps/lead/schema.py b/apps/lead/schema.py index cf352b7031..6f43326ab4 100644 --- a/apps/lead/schema.py +++ b/apps/lead/schema.py @@ -1,50 +1,46 @@ -import graphene from functools import reduce from typing import Union + +import graphene +from analysis_framework.models import Filter as AfFilter +from analysis_framework.models import Widget from django.db import models from django.db.models import QuerySet -from graphene_django import DjangoObjectType, DjangoListField -from graphene_django_extras import DjangoObjectField, PageGraphqlPagination - -from utils.graphene.pagination import NoOrderingPageGraphqlPagination -from utils.graphene.enums import EnumDescription -from utils.graphene.types import CustomDjangoListObjectType, ClientIdMixin -from utils.graphene.fields import DjangoPaginatedListObjectField - -from user.models import User -from organization.models import Organization, OrganizationType as OrganizationTypeModel from geo.models import GeoArea -from analysis_framework.models import Filter as AfFilter, Widget - -from user_resource.schema import UserResourceMixin -from deep.permissions import ProjectPermissions as PP -from deep.permalinks import Permalink -from organization.schema import OrganizationType, OrganizationTypeType -from user.schema import UserType from geo.schema import ProjectGeoAreaType - +from graphene_django import DjangoListField, DjangoObjectType +from graphene_django_extras import DjangoObjectField, PageGraphqlPagination from lead.filter_set import LeadsFilterDataType +from organization.models import Organization +from organization.models import OrganizationType as OrganizationTypeModel +from organization.schema import OrganizationType, OrganizationTypeType +from user.models import User +from user.schema import UserType +from user_resource.schema import UserResourceMixin +from deep.permalinks import Permalink +from deep.permissions import ProjectPermissions as PP +from utils.graphene.enums import EnumDescription +from utils.graphene.fields import DjangoPaginatedListObjectField +from utils.graphene.pagination import NoOrderingPageGraphqlPagination +from utils.graphene.types import ClientIdMixin, CustomDjangoListObjectType -from .models import ( - Lead, - LeadGroup, - LeadPreview, - LeadEMMTrigger, - EMMEntity, - UserSavedLeadFilter, -) from .enums import ( + LeadAutoEntryExtractionTypeEnum, LeadConfidentialityEnum, - LeadStatusEnum, + LeadExtractionStatusEnum, LeadPriorityEnum, LeadSourceTypeEnum, - LeadExtractionStatusEnum, - LeadAutoEntryExtractionTypeEnum, + LeadStatusEnum, ) -from .filter_set import ( - LeadGQFilterSet, - LeadGroupGQFilterSet, +from .filter_set import LeadGQFilterSet, LeadGroupGQFilterSet +from .models import ( + EMMEntity, + Lead, + LeadEMMTrigger, + LeadGroup, + LeadPreview, + UserSavedLeadFilter, ) @@ -86,53 +82,48 @@ def get_lead_emm_entities_qs(info): # Generates database level objects used in filters. def get_lead_filter_data(filters, context): def _filter_by_id(entity_list, entity_id_list): - return [ - entity - for entity in entity_list - if entity.id in entity_id_list - ] + return [entity for entity in entity_list if entity.id in entity_id_list] def _id_to_int(ids): - return [ - int(_id) for _id in ids - ] + return [int(_id) for _id in ids] if filters is None or not isinstance(filters, dict): return {} - entry_filter_data = filters.get('entries_filter_data') or {} + entry_filter_data = filters.get("entries_filter_data") or {} geo_widget_filter_keys = AfFilter.objects.filter( analysis_framework=context.active_project.analysis_framework_id, widget_key__in=Widget.objects.filter( analysis_framework=context.active_project.analysis_framework_id, widget_id=Widget.WidgetType.GEO, - ).values_list('key', flat=True) - ).values_list('key', flat=True) + ).values_list("key", flat=True), + ).values_list("key", flat=True) # Lead Filter Data - created_by_ids = _id_to_int(filters.get('created_by') or []) - modified_by_ids = _id_to_int(filters.get('modified_by') or []) - assignee_ids = _id_to_int(filters.get('assignees') or []) - author_organization_type_ids = _id_to_int(filters.get('authoring_organization_types') or []) - author_organization_ids = _id_to_int(filters.get('author_organizations') or []) - source_organization_ids = _id_to_int(filters.get('source_organizations') or []) + created_by_ids = _id_to_int(filters.get("created_by") or []) + modified_by_ids = _id_to_int(filters.get("modified_by") or []) + assignee_ids = _id_to_int(filters.get("assignees") or []) + author_organization_type_ids = _id_to_int(filters.get("authoring_organization_types") or []) + author_organization_ids = _id_to_int(filters.get("author_organizations") or []) + source_organization_ids = _id_to_int(filters.get("source_organizations") or []) # Entry Filter Data - ef_lead_assignee_ids = _id_to_int(entry_filter_data.get('lead_assignees') or []) - ef_lead_authoring_organizationtype_ids = _id_to_int(entry_filter_data.get('lead_authoring_organization_types') or []) - ef_lead_author_organization_ids = _id_to_int(entry_filter_data.get('lead_author_organizations') or []) - ef_lead_source_organization_ids = _id_to_int(entry_filter_data.get('lead_source_organizations') or []) - ef_lead_created_by_ids = _id_to_int(entry_filter_data.get('lead_created_by') or []) - ef_created_by_ids = _id_to_int(entry_filter_data.get('created_by') or []) - ef_modified_by_ids = _id_to_int(entry_filter_data.get('modified_by') or []) + ef_lead_assignee_ids = _id_to_int(entry_filter_data.get("lead_assignees") or []) + ef_lead_authoring_organizationtype_ids = _id_to_int(entry_filter_data.get("lead_authoring_organization_types") or []) + ef_lead_author_organization_ids = _id_to_int(entry_filter_data.get("lead_author_organizations") or []) + ef_lead_source_organization_ids = _id_to_int(entry_filter_data.get("lead_source_organizations") or []) + ef_lead_created_by_ids = _id_to_int(entry_filter_data.get("lead_created_by") or []) + ef_created_by_ids = _id_to_int(entry_filter_data.get("created_by") or []) + ef_modified_by_ids = _id_to_int(entry_filter_data.get("modified_by") or []) ef_geo_area_ids = set( _id_to_int( reduce( lambda a, b: a + b, [ - filterable_data['value_list'] or [] - for filterable_data in entry_filter_data.get('filterable_data') or [] - if filterable_data.get('filter_key') in geo_widget_filter_keys and filterable_data.get('value_list') - ], [] + filterable_data["value_list"] or [] + for filterable_data in entry_filter_data.get("filterable_data") or [] + if filterable_data.get("filter_key") in geo_widget_filter_keys and filterable_data.get("value_list") + ], + [], ) ) ) @@ -141,43 +132,46 @@ def _id_to_int(ids): users = list( User.objects.filter( projectmembership__project=context.active_project, - id__in=set([ - *created_by_ids, - *modified_by_ids, - *assignee_ids, - *ef_created_by_ids, - *ef_lead_assignee_ids, - *ef_lead_created_by_ids, - *ef_modified_by_ids, - ]) - ).order_by('id') + id__in=set( + [ + *created_by_ids, + *modified_by_ids, + *assignee_ids, + *ef_created_by_ids, + *ef_lead_assignee_ids, + *ef_lead_created_by_ids, + *ef_modified_by_ids, + ] + ), + ).order_by("id") ) organizations = list( Organization.objects.filter( - id__in=set([ - *author_organization_ids, - *source_organization_ids, - *ef_lead_author_organization_ids, - *ef_lead_source_organization_ids, - ]) - ).order_by('id') + id__in=set( + [ + *author_organization_ids, + *source_organization_ids, + *ef_lead_author_organization_ids, + *ef_lead_source_organization_ids, + ] + ) + ).order_by("id") ) organization_types = list( OrganizationTypeModel.objects.filter( - id__in=set([ - *author_organization_type_ids, - *ef_lead_authoring_organizationtype_ids, - ]) - ).order_by('id') + id__in=set( + [ + *author_organization_type_ids, + *ef_lead_authoring_organizationtype_ids, + ] + ) + ).order_by("id") ) geoareas = list( - GeoArea.objects.filter( - admin_level__region__project=context.active_project, - id__in=ef_geo_area_ids - ).order_by('id') + GeoArea.objects.filter(admin_level__region__project=context.active_project, id__in=ef_geo_area_ids).order_by("id") ) return dict( @@ -204,13 +198,13 @@ class LeadPreviewType(DjangoObjectType): class Meta: model = LeadPreview only_fields = ( - 'text_extract', - 'thumbnail', - 'thumbnail_height', - 'thumbnail_width', - 'word_count', - 'page_count', - 'text_extraction_id' + "text_extract", + "thumbnail", + "thumbnail_height", + "thumbnail_width", + "word_count", + "page_count", + "text_extraction_id", # 'classified_doc_id', # 'classification_status', ) @@ -219,7 +213,7 @@ class Meta: class LeadEmmTriggerType(DjangoObjectType): class Meta: model = LeadEMMTrigger - only_fields = ('id', 'emm_keyword', 'emm_risk_factor', 'count') + only_fields = ("id", "emm_keyword", "emm_risk_factor", "count") @staticmethod def get_custom_queryset(queryset, info, **kwargs): @@ -235,7 +229,7 @@ class Meta: class EmmEntityType(DjangoObjectType): class Meta: model = EMMEntity - only_fields = ('id', 'name') + only_fields = ("id", "name") @staticmethod def get_custom_queryset(queryset, info, **kwargs): @@ -273,7 +267,8 @@ class LeadFilterDataType(graphene.ObjectType): entry_filter_lead_assignee_options = graphene.List(graphene.NonNull(UserType), required=True) entry_filter_lead_author_organization_options = graphene.List(graphene.NonNull(OrganizationType), required=True) entry_filter_lead_authoring_organization_type_options = graphene.List( - graphene.NonNull(OrganizationTypeType), required=True, + graphene.NonNull(OrganizationTypeType), + required=True, ) entry_filter_lead_created_by_options = graphene.List(graphene.NonNull(UserType), required=True) entry_filter_lead_source_organization_options = graphene.List(graphene.NonNull(OrganizationType), required=True) @@ -288,10 +283,10 @@ class UserSavedLeadFilterType(DjangoObjectType): class Meta: model = UserSavedLeadFilter only_fields = ( - 'id', - 'title', - 'created_at', - 'modified_at', + "id", + "title", + "created_at", + "modified_at", ) @staticmethod @@ -303,10 +298,11 @@ class LeadGroupType(UserResourceMixin, DjangoObjectType): class Meta: model = LeadGroup only_fields = ( - 'id', - 'title', - 'project', + "id", + "title", + "project", ) + lead_counts = graphene.Int(required=True) @staticmethod @@ -329,21 +325,28 @@ class LeadType(UserResourceMixin, ClientIdMixin, DjangoObjectType): class Meta: model = Lead only_fields = ( - 'id', 'title', 'is_assessment_lead', 'lead_group', 'assignee', 'published_on', - 'text', 'url', 'attachment', - 'client_id', + "id", + "title", + "is_assessment_lead", + "lead_group", + "assignee", + "published_on", + "text", + "url", + "attachment", + "client_id", ) - project = graphene.ID(source='project_id', required=True) + project = graphene.ID(source="project_id", required=True) # Enums source_type = graphene.Field(LeadSourceTypeEnum, required=True) - source_type_display = EnumDescription(source='get_source_type_display', required=True) + source_type_display = EnumDescription(source="get_source_type_display", required=True) priority = graphene.Field(LeadPriorityEnum, required=True) - priority_display = EnumDescription(source='get_priority_display', required=True) + priority_display = EnumDescription(source="get_priority_display", required=True) confidentiality = graphene.Field(LeadConfidentialityEnum, required=True) - confidentiality_display = EnumDescription(source='get_confidentiality_display', required=True) + confidentiality_display = EnumDescription(source="get_confidentiality_display", required=True) status = graphene.Field(LeadStatusEnum, required=True) - status_display = EnumDescription(source='get_status_display', required=True) + status_display = EnumDescription(source="get_status_display", required=True) extraction_status = graphene.Field(LeadExtractionStatusEnum) lead_preview = graphene.Field(LeadPreviewType) @@ -355,13 +358,12 @@ class Meta: emm_entities = DjangoListField(EmmEntityType) emm_triggers = DjangoListField(LeadEmmTriggerType) assessment_id = graphene.ID() - connector_lead = graphene.ID(source='connector_lead_id', required=False) + connector_lead = graphene.ID(source="connector_lead_id", required=False) # Entries count entries_count = graphene.Field(EntriesCountType) filtered_entries_count = graphene.Int( description=( - 'Count used to order or filter-out leads' - '. Can be =null or =entries_count->total or !=entries_count->total.' + "Count used to order or filter-out leads" ". Can be =null or =entries_count->total or !=entries_count->total." ) ) # Duplicate leads @@ -402,7 +404,7 @@ def resolve_entries_count(root, info, **kwargs): @staticmethod def resolve_filtered_entries_count(root, info, **kwargs): # filtered_entry_count is from LeadFilterSet - return getattr(root, 'filtered_entry_count', None) + return getattr(root, "filtered_entry_count", None) @staticmethod def resolve_share_view_url(root: Lead, info, **kwargs): @@ -423,12 +425,19 @@ class Meta: model = Lead skip_registry = True only_fields = ( - 'id', 'title', 'is_assessment_lead', 'lead_group', 'assignee', 'published_on', - 'text', 'url', 'attachment', - 'client_id', + "id", + "title", + "is_assessment_lead", + "lead_group", + "assignee", + "published_on", + "text", + "url", + "attachment", + "client_id", ) - entries = graphene.List(graphene.NonNull('entry.schema.EntryType')) + entries = graphene.List(graphene.NonNull("entry.schema.EntryType")) draft_entry_stat = graphene.Field(DraftEntryCountByLead) @staticmethod @@ -458,27 +467,18 @@ class Query: leads = DjangoPaginatedListObjectField( LeadListType, pagination=NoOrderingPageGraphqlPagination( - page_size_query_param='pageSize', - ) + page_size_query_param="pageSize", + ), ) lead_group = DjangoObjectField(LeadGroupType) lead_groups = DjangoPaginatedListObjectField( - LeadGroupListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + LeadGroupListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) emm_entities = DjangoPaginatedListObjectField( - EmmEntityListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + EmmEntityListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) lead_emm_triggers = DjangoPaginatedListObjectField( - LeadEmmTriggerListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + LeadEmmTriggerListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) # TODO: Add Pagination emm_keywords = graphene.List(graphene.NonNull(EmmKeyWordType)) @@ -504,25 +504,29 @@ def resolve_lead_emm_triggers(root, info, **kwargs) -> QuerySet: @staticmethod def resolve_emm_keywords(root, info, **kwargs): - return LeadEMMTrigger.objects.filter( - lead__project=info.context.active_project - ).values('emm_keyword').annotate( - total_count=models.Sum('count'), - key=models.F('emm_keyword'), - label=models.F('emm_keyword') - ).order_by('emm_keyword') + return ( + LeadEMMTrigger.objects.filter(lead__project=info.context.active_project) + .values("emm_keyword") + .annotate(total_count=models.Sum("count"), key=models.F("emm_keyword"), label=models.F("emm_keyword")) + .order_by("emm_keyword") + ) @staticmethod def resolve_emm_risk_factors(root, info, **kwargs): - return LeadEMMTrigger.objects.filter( - ~models.Q(emm_risk_factor=''), - ~models.Q(emm_risk_factor=None), - lead__project=info.context.active_project, - ).values('emm_risk_factor').annotate( - total_count=models.Sum('count'), - key=models.F('emm_risk_factor'), - label=models.F('emm_risk_factor'), - ).order_by('emm_risk_factor') + return ( + LeadEMMTrigger.objects.filter( + ~models.Q(emm_risk_factor=""), + ~models.Q(emm_risk_factor=None), + lead__project=info.context.active_project, + ) + .values("emm_risk_factor") + .annotate( + total_count=models.Sum("count"), + key=models.F("emm_risk_factor"), + label=models.F("emm_risk_factor"), + ) + .order_by("emm_risk_factor") + ) @staticmethod def resolve_user_saved_lead_filter(root, info, **kwargs): diff --git a/apps/lead/serializers.py b/apps/lead/serializers.py index f8879ae09e..0a7e257fcc 100644 --- a/apps/lead/serializers.py +++ b/apps/lead/serializers.py @@ -1,37 +1,37 @@ import copy import uuid as python_uuid -from django.shortcuts import get_object_or_404 +from deduplication.tasks.indexing import index_lead_and_calculate_duplicates +from deepl_integration.handlers import LeadExtractionHandler from django.db import transaction +from django.shortcuts import get_object_or_404 +from django.utils import timezone from drf_dynamic_fields import DynamicFieldsMixin +from gallery.serializers import File, SimpleFileSerializer +from lead.filter_set import LeadsFilterDataInputType +from organization.serializers import SimpleOrganizationSerializer +from project.models import ProjectMembership +from project.serializers import SimpleProjectSerializer from rest_framework import serializers -from django.utils import timezone +from unified_connector.models import ConnectorSourceLead +from user.models import User +from user.serializers import SimpleUserSerializer +from user_resource.serializers import UserResourceSerializer -from utils.graphene.fields import generate_serializer_field_class from deep.permissions import ProjectPermissions as PP from deep.serializers import ( + GraphqlSupportDrfSerializerJSONField, + IdListField, + IntegerIDField, + ProjectPropertySerializerMixin, RemoveNullFieldsMixin, + StringListField, TempClientIdMixin, - IntegerIDField, URLCachedFileField, - IdListField, - StringListField, WriteOnlyOnCreateSerializerMixin, - ProjectPropertySerializerMixin, - GraphqlSupportDrfSerializerJSONField, ) -from organization.serializers import SimpleOrganizationSerializer -from user.serializers import SimpleUserSerializer -from user_resource.serializers import UserResourceSerializer -from project.serializers import SimpleProjectSerializer -from gallery.serializers import SimpleFileSerializer, File -from user.models import User -from project.models import ProjectMembership -from unified_connector.models import ConnectorSourceLead -from lead.filter_set import LeadsFilterDataInputType +from utils.graphene.fields import generate_serializer_field_class -from deepl_integration.handlers import LeadExtractionHandler -from deduplication.tasks.indexing import index_lead_and_calculate_duplicates from .models import ( EMMEntity, Lead, @@ -45,14 +45,25 @@ def check_if_url_exists(url, user=None, project=None, exception_id=None, return_lead=False): existing_lead = None if not project and user: - existing_lead = url and Lead.get_for(user).filter( - url__icontains=url, - ).exclude(id=exception_id).first() + existing_lead = ( + url + and Lead.get_for(user) + .filter( + url__icontains=url, + ) + .exclude(id=exception_id) + .first() + ) elif project: - existing_lead = url and Lead.objects.filter( - url__icontains=url, - project=project, - ).exclude(id=exception_id).first() + existing_lead = ( + url + and Lead.objects.filter( + url__icontains=url, + project=project, + ) + .exclude(id=exception_id) + .first() + ) if existing_lead: if return_lead: return existing_lead @@ -67,23 +78,32 @@ def raise_or_return_existing_lead(project, lead, source_type, url, text, attachm if source_type == Lead.SourceType.WEBSITE: existing_lead = check_if_url_exists(url, None, project, lead and lead.pk, return_lead=return_lead) - error_message = f'A source with this URL has already been added to Project: {project}' + error_message = f"A source with this URL has already been added to Project: {project}" elif ( - attachment and attachment.metadata and - source_type in [Lead.SourceType.DISK, Lead.SourceType.DROPBOX, Lead.SourceType.GOOGLE_DRIVE] + attachment + and attachment.metadata + and source_type in [Lead.SourceType.DISK, Lead.SourceType.DROPBOX, Lead.SourceType.GOOGLE_DRIVE] ): # For attachment types, check if file already used (using file hash) - existing_lead = Lead.objects.filter( - project=project, - attachment__metadata__md5_hash=attachment.metadata.get('md5_hash'), - ).exclude(pk=lead and lead.pk).first() - error_message = f'A source with this file has already been added to Project: {project}' + existing_lead = ( + Lead.objects.filter( + project=project, + attachment__metadata__md5_hash=attachment.metadata.get("md5_hash"), + ) + .exclude(pk=lead and lead.pk) + .first() + ) + error_message = f"A source with this file has already been added to Project: {project}" elif source_type == Lead.SourceType.TEXT: - existing_lead = Lead.objects.filter( - project=project, - text=text, - ).exclude(pk=lead and lead.pk).first() - error_message = f'A source with this text has already been added to Project: {project}' + existing_lead = ( + Lead.objects.filter( + project=project, + text=text, + ) + .exclude(pk=lead and lead.pk) + .first() + ) + error_message = f"A source with this text has already been added to Project: {project}" if existing_lead: if return_lead: @@ -110,20 +130,27 @@ class EMMEntitySerializer(serializers.Serializer, RemoveNullFieldsMixin, Dynamic name = serializers.CharField() class Meta: - fields = '__all__' + fields = "__all__" -class SimpleLeadSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): +class SimpleLeadSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): class Meta: model = Lead fields = ( - 'id', 'title', 'created_at', 'created_by', - 'source_raw', 'author_raw', - 'source', 'author', + "id", + "title", + "created_at", + "created_by", + "source_raw", + "author_raw", + "source", + "author", ) # Legacy Fields - read_only_fields = ('author_raw', 'source_raw',) + read_only_fields = ( + "author_raw", + "source_raw", + ) class LeadEMMTriggerSerializer(serializers.ModelSerializer, RemoveNullFieldsMixin, DynamicFieldsMixin): @@ -132,15 +159,23 @@ class LeadEMMTriggerSerializer(serializers.ModelSerializer, RemoveNullFieldsMixi class Meta: model = LeadEMMTrigger - fields = ('emm_risk_factor', 'emm_keyword', 'count',) + fields = ( + "emm_risk_factor", + "emm_keyword", + "count", + ) class LeadSerializer( - RemoveNullFieldsMixin, DynamicFieldsMixin, WriteOnlyOnCreateSerializerMixin, UserResourceSerializer, + RemoveNullFieldsMixin, + DynamicFieldsMixin, + WriteOnlyOnCreateSerializerMixin, + UserResourceSerializer, ): """ Lead Model Serializer """ + # annotated in lead.get_for entries_count = serializers.IntegerField(read_only=True) controlled_entries_count = serializers.IntegerField(read_only=True) @@ -149,74 +184,74 @@ class LeadSerializer( assessment_id = serializers.IntegerField(read_only=True) - priority_display = serializers.CharField(source='get_priority_display', read_only=True) + priority_display = serializers.CharField(source="get_priority_display", read_only=True) attachment = SimpleFileSerializer(required=False) thumbnail = URLCachedFileField( - source='leadpreview.thumbnail', + source="leadpreview.thumbnail", read_only=True, ) thumbnail_height = serializers.IntegerField( - source='leadpreview.thumbnail_height', + source="leadpreview.thumbnail_height", read_only=True, ) thumbnail_width = serializers.IntegerField( - source='leadpreview.thumbnail_width', + source="leadpreview.thumbnail_width", read_only=True, ) word_count = serializers.IntegerField( - source='leadpreview.word_count', + source="leadpreview.word_count", read_only=True, ) page_count = serializers.IntegerField( - source='leadpreview.page_count', + source="leadpreview.page_count", read_only=True, ) classified_doc_id = serializers.IntegerField( - source='leadpreview.classified_doc_id', + source="leadpreview.classified_doc_id", read_only=True, ) # TODO: Remove (Legacy) - author_detail = SimpleOrganizationSerializer(source='author', read_only=True) + author_detail = SimpleOrganizationSerializer(source="author", read_only=True) - authors_detail = SimpleOrganizationSerializer(source='authors', many=True, read_only=True) - source_detail = SimpleOrganizationSerializer(source='source', read_only=True) + authors_detail = SimpleOrganizationSerializer(source="authors", many=True, read_only=True) + source_detail = SimpleOrganizationSerializer(source="source", read_only=True) assignee_details = SimpleUserSerializer( - source='get_assignee', + source="get_assignee", # many=True, read_only=True, ) assignee = SingleValueThayMayBeListField( - source='get_assignee.id', + source="get_assignee.id", required=False, ) tabular_book = serializers.SerializerMethodField() emm_triggers = LeadEMMTriggerSerializer(many=True, required=False) emm_entities = EMMEntitySerializer(many=True, required=False) # extra fields added from entryleadserializer - confidentiality_display = serializers.CharField(source='get_confidentiality_display', read_only=True) + confidentiality_display = serializers.CharField(source="get_confidentiality_display", read_only=True) class Meta: model = Lead - fields = ('__all__') + fields = "__all__" # Legacy Fields - read_only_fields = ('author_raw', 'source_raw') - write_only_on_create_fields = ['emm_triggers', 'emm_entities'] + read_only_fields = ("author_raw", "source_raw") + write_only_on_create_fields = ["emm_triggers", "emm_entities"] def get_tabular_book(self, obj): file = obj.attachment - if file and hasattr(file, 'book'): + if file and hasattr(file, "book"): return file.book.id return None @staticmethod def add_update__validate(data, instance, attachment=None): - project = data.get('project', instance and instance.project) - source_type = data.get('source_type', instance and instance.source_type) - text = data.get('text', instance and instance.text) - url = data.get('url', instance and instance.url) + project = data.get("project", instance and instance.project) + source_type = data.get("source_type", instance and instance.source_type) + text = data.get("text", instance and instance.text) + url = data.get("url", instance and instance.url) return raise_or_return_existing_lead( project, @@ -230,30 +265,25 @@ def add_update__validate(data, instance, attachment=None): def validate_is_assessment_lead(self, value): # Allow setting True # For False make sure there are no assessment attached. - if value is False and hasattr(self.instance, 'assessment'): - raise serializers.ValidationError('Lead already has an assessment.') + if value is False and hasattr(self.instance, "assessment"): + raise serializers.ValidationError("Lead already has an assessment.") return value def validate(self, data): - attachment_id = self.get_initial().get('attachment', {}).get('id') - LeadSerializer.add_update__validate( - data, self.instance, - File.objects.filter(pk=attachment_id).first() - ) + attachment_id = self.get_initial().get("attachment", {}).get("id") + LeadSerializer.add_update__validate(data, self.instance, File.objects.filter(pk=attachment_id).first()) return data # TODO: Probably also validate assignee to valid list of users def create(self, validated_data): - assignee_field = validated_data.pop('get_assignee', None) - assignee_id = assignee_field and assignee_field.get('id', None) + assignee_field = validated_data.pop("get_assignee", None) + assignee_id = assignee_field and assignee_field.get("id", None) assignee = assignee_id and get_object_or_404(User, id=assignee_id) - emm_triggers = validated_data.pop('emm_triggers', []) + emm_triggers = validated_data.pop("emm_triggers", []) emm_entities_names = [ - entity['name'] - for entity in validated_data.pop('emm_entities', []) - if isinstance(entity, dict) and 'name' in entity + entity["name"] for entity in validated_data.pop("emm_entities", []) if isinstance(entity, dict) and "name" in entity ] lead = super().create(validated_data) @@ -271,13 +301,13 @@ def create(self, validated_data): return lead def update(self, instance, validated_data): - assignee_field = validated_data.pop('get_assignee', None) - assignee_id = assignee_field and assignee_field.get('id', None) + assignee_field = validated_data.pop("get_assignee", None) + assignee_id = assignee_field and assignee_field.get("id", None) assignee = assignee_id and get_object_or_404(User, id=assignee_id) # We do not update triggers and entities - validated_data.pop('emm_entities', None) - validated_data.pop('emm_triggers', None) + validated_data.pop("emm_entities", None) + validated_data.pop("emm_triggers", None) lead = super().update(instance, validated_data) @@ -288,9 +318,7 @@ def update(self, instance, validated_data): return lead -class LeadPreviewImageSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, - serializers.ModelSerializer): +class LeadPreviewImageSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): """ Serializer for lead preview image """ @@ -299,48 +327,49 @@ class LeadPreviewImageSerializer(RemoveNullFieldsMixin, class Meta: model = LeadPreviewImage - fields = ('id', 'file',) + fields = ( + "id", + "file", + ) -class LeadPreviewSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, serializers.ModelSerializer): +class LeadPreviewSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): """ Serializer for lead preview """ - text = serializers.CharField(source='leadpreview.text_extract', - read_only=True) + text = serializers.CharField(source="leadpreview.text_extract", read_only=True) images = LeadPreviewImageSerializer(many=True, read_only=True) classified_doc_id = serializers.IntegerField( - source='leadpreview.classified_doc_id', + source="leadpreview.classified_doc_id", read_only=True, ) preview_id = serializers.IntegerField( - source='leadpreview.pk', + source="leadpreview.pk", read_only=True, ) class Meta: model = Lead - fields = ('id', 'preview_id', 'text', 'images', 'classified_doc_id') + fields = ("id", "preview_id", "text", "images", "classified_doc_id") -class LeadGroupSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, UserResourceSerializer): - leads = LeadSerializer(source='lead_set', - many=True, - read_only=True) +class LeadGroupSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): + leads = LeadSerializer(source="lead_set", many=True, read_only=True) no_of_leads = serializers.IntegerField(read_only=True) class Meta: model = LeadGroup - fields = ('__all__') + fields = "__all__" class SimpleLeadGroupSerializer(UserResourceSerializer): class Meta: model = LeadGroup - fields = ('id', 'title',) + fields = ( + "id", + "title", + ) class KeyValueSerializer(serializers.Serializer): @@ -410,6 +439,7 @@ class LeadGqSerializer(ProjectPropertySerializerMixin, TempClientIdMixin, UserRe """ Lead Model Serializer for Graphql (NOTE: Don't use this on DRF Views) """ + id = IntegerIDField(required=False) # TODO: Make assigne Foreign key from M2M Field assignee = SingleValueThayMayBeListField(required=False) @@ -421,26 +451,26 @@ class LeadGqSerializer(ProjectPropertySerializerMixin, TempClientIdMixin, UserRe class Meta: model = Lead fields = ( - 'id', - 'title', - 'attachment', - 'status', - 'assignee', - 'confidentiality', - 'source_type', - 'priority', - 'published_on', - 'text', - 'is_assessment_lead', - 'lead_group', - 'url', - 'website', # XXX: Remove this after chrome extension is updated - 'source', - 'authors', - 'emm_triggers', - 'emm_entities', - 'connector_lead', - 'client_id', # From TempClientIdMixin + "id", + "title", + "attachment", + "status", + "assignee", + "confidentiality", + "source_type", + "priority", + "published_on", + "text", + "is_assessment_lead", + "lead_group", + "url", + "website", # XXX: Remove this after chrome extension is updated + "source", + "authors", + "emm_triggers", + "emm_entities", + "connector_lead", + "client_id", # From TempClientIdMixin ) def validate_attachment(self, attachment): @@ -448,26 +478,26 @@ def validate_attachment(self, attachment): if attachment is None or (self.instance and self.instance.attachment_id == attachment.id): return attachment # For new attachment make sure user have permission to attach. - if attachment.created_by != self.context['request'].user: + if attachment.created_by != self.context["request"].user: raise serializers.ValidationError("Attachment not found or you don't have the permission!") return attachment def validate_assignee(self, assignee_id): assignee = self.project.get_all_members().filter(id=assignee_id).first() if assignee is None: - raise serializers.ValidationError('Only project members can be assigneed') + raise serializers.ValidationError("Only project members can be assigneed") return assignee def validate_lead_group(self, lead_group): if lead_group and lead_group.project_id != self.project.id: - raise serializers.ValidationError('LeadGroup should have same project as lead project') + raise serializers.ValidationError("LeadGroup should have same project as lead project") return lead_group def validate_is_assessment_lead(self, value): # Allow setting True # For False make sure there are no assessment attached. - if value is False and hasattr(self.instance, 'assessment'): - raise serializers.ValidationError('Lead already has an assessment.') + if value is False and hasattr(self.instance, "assessment"): + raise serializers.ValidationError("Lead already has an assessment.") return value def validate(self, data): @@ -475,28 +505,33 @@ def validate(self, data): This validator makes sure there is no duplicate leads in a project """ # Using active project here. - data.pop('website', None) # XXX: Remove this after chrome extension is updated - data['project'] = self.project - attachment = data.get('attachment', self.instance and self.instance.attachment) - source_type = data.get('source_type', self.instance and self.instance.source_type) - text = data.get('text', self.instance and self.instance.text) - url = data.get('url', self.instance and self.instance.url) + data.pop("website", None) # XXX: Remove this after chrome extension is updated + data["project"] = self.project + attachment = data.get("attachment", self.instance and self.instance.attachment) + source_type = data.get("source_type", self.instance and self.instance.source_type) + text = data.get("text", self.instance and self.instance.text) + url = data.get("url", self.instance and self.instance.url) raise_or_return_existing_lead( - data['project'], self.instance, source_type, url, text, attachment, + data["project"], + self.instance, + source_type, + url, + text, + attachment, return_lead=False, # Raise exception ) return data def create(self, validated_data): - assignee = validated_data.pop('assignee', None) + assignee = validated_data.pop("assignee", None) # Pop out emm values from validated_data - emm_triggers = validated_data.pop('emm_triggers', []) - emm_entities = validated_data.pop('emm_entities', []) + emm_triggers = validated_data.pop("emm_triggers", []) + emm_entities = validated_data.pop("emm_entities", []) # Create new lead lead = super().create(validated_data) # Save emm entities for entity in emm_entities: - entity = EMMEntity.objects.filter(name=entity['name']).first() + entity = EMMEntity.objects.filter(name=entity["name"]).first() if entity is None: continue lead.emm_entities.add(entity) @@ -526,11 +561,11 @@ def create(self, validated_data): return lead def update(self, instance, validated_data): - has_assignee = 'assignee' in validated_data # For parital updates - assignee = validated_data.pop('assignee', None) + has_assignee = "assignee" in validated_data # For parital updates + assignee = validated_data.pop("assignee", None) # Pop out emm values from validated_data (Only allowed in creation) - validated_data.pop('emm_triggers', []) - validated_data.pop('emm_entities', []) + validated_data.pop("emm_triggers", []) + validated_data.pop("emm_entities", []) # Save lead lead = super().update(instance, validated_data) if has_assignee: @@ -544,10 +579,7 @@ class LeadCopyGqSerializer(ProjectPropertySerializerMixin, serializers.Serialize MAX_PROJECTS_ALLOWED = 10 MAX_LEADS_ALLOWED = 100 - projects = serializers.ListField( - child=IntegerIDField(), - required=True - ) + projects = serializers.ListField(child=IntegerIDField(), required=True) leads = serializers.ListField( child=IntegerIDField(), required=True, @@ -556,18 +588,20 @@ class LeadCopyGqSerializer(ProjectPropertySerializerMixin, serializers.Serialize def validate_projects(self, projects_id): projects_id = list( ProjectMembership.objects.filter( - member=self.context['request'].user, + member=self.context["request"].user, role__type__in=PP.REVERSE_PERMISSION_MAP[PP.Permission.CREATE_LEAD], project__in=projects_id, - ).values_list('project', flat=True).distinct() + ) + .values_list("project", flat=True) + .distinct() ) count = len(projects_id) if count > self.MAX_PROJECTS_ALLOWED: - raise serializers.ValidationError(f'Only {self.MAX_PROJECTS_ALLOWED} are allowed. Provided: {count}') + raise serializers.ValidationError(f"Only {self.MAX_PROJECTS_ALLOWED} are allowed. Provided: {count}") return projects_id def validate_leads(self, leads_id): - allowed_permission = self.context['request'].project_permissions + allowed_permission = self.context["request"].project_permissions lead_qs = Lead.objects.filter( id__in=leads_id, project=self.project, @@ -580,7 +614,7 @@ def validate_leads(self, leads_id): raise serializers.ValidationError("You don't have lead read access") count = lead_qs.count() if count > self.MAX_LEADS_ALLOWED: - raise serializers.ValidationError(f'Only {self.MAX_LEADS_ALLOWED} are allowed. Provided: {count}') + raise serializers.ValidationError(f"Only {self.MAX_LEADS_ALLOWED} are allowed. Provided: {count}") return lead_qs def clone_lead(self, original_lead, project_id, user): @@ -594,18 +628,12 @@ def _get_clone_ready(obj, lead): new_lead.pk = None new_lead.uuid = python_uuid.uuid4() existing_lead = raise_or_return_existing_lead( - project_id, - new_lead, - new_lead.source_type, - new_lead.url, - new_lead.text, - new_lead.attachment, - return_lead=True + project_id, new_lead, new_lead.source_type, new_lead.url, new_lead.text, new_lead.attachment, return_lead=True ) if existing_lead: return - preview = original_lead.leadpreview if hasattr(original_lead, 'leadpreview') else None + preview = original_lead.leadpreview if hasattr(original_lead, "leadpreview") else None preview_images = original_lead.images.all() emm_triggers = original_lead.emm_triggers.all() emm_entities = original_lead.emm_entities.all() @@ -618,7 +646,7 @@ def _get_clone_ready(obj, lead): # update the fields for copied lead new_lead.created_at = timezone.now() - new_lead.created_by = new_lead.modified_by = self.context['request'].user + new_lead.created_by = new_lead.modified_by = self.context["request"].user new_lead.status = Lead.Status.NOT_TAGGED new_lead.save() @@ -634,19 +662,15 @@ def _get_clone_ready(obj, lead): new_lead.authors.set(authors) # Clone Many to one Fields - LeadPreviewImage.objects.bulk_create([ - _get_clone_ready(image, new_lead) for image in preview_images - ]) - LeadEMMTrigger.objects.bulk_create([ - _get_clone_ready(emm_trigger, new_lead) for emm_trigger in emm_triggers - ]) + LeadPreviewImage.objects.bulk_create([_get_clone_ready(image, new_lead) for image in preview_images]) + LeadEMMTrigger.objects.bulk_create([_get_clone_ready(emm_trigger, new_lead) for emm_trigger in emm_triggers]) return new_lead def create(self, validated_data): - projects_id = validated_data.get('projects', []) - leads = validated_data.get('leads', []) - user = self.context['request'].user + projects_id = validated_data.get("projects", []) + leads = validated_data.get("leads", []) + user = self.context["request"].user new_leads = [] for project_id in projects_id: for lead in leads: @@ -661,16 +685,16 @@ class UserSavedLeadFilterSerializer(ProjectPropertySerializerMixin, serializers. class Meta: model = UserSavedLeadFilter fields = ( - 'title', - 'filters', + "title", + "filters", ) def validate(self, data): - existing_qs = UserSavedLeadFilter.objects\ - .filter(user=self.current_user, project=self.project)\ - .exclude(id=self.instance and self.instance.pk) + existing_qs = UserSavedLeadFilter.objects.filter(user=self.current_user, project=self.project).exclude( + id=self.instance and self.instance.pk + ) if existing_qs.exists(): - raise serializers.ValidationError('Only one filter save is allowed for now.') - data['project'] = self.project - data['user'] = self.current_user + raise serializers.ValidationError("Only one filter save is allowed for now.") + data["project"] = self.project + data["user"] = self.current_user return data diff --git a/apps/lead/tasks.py b/apps/lead/tasks.py index ffbc3ed032..6574bc7694 100644 --- a/apps/lead/tasks.py +++ b/apps/lead/tasks.py @@ -2,11 +2,11 @@ from datetime import timedelta from celery import shared_task +from deepl_integration.handlers import LeadExtractionHandler from django.db.models import Q from django.utils import timezone from utils.common import redis_lock -from deepl_integration.handlers import LeadExtractionHandler from .models import Lead @@ -14,7 +14,7 @@ @shared_task(bind=True, max_retries=10) -@redis_lock('lead_extraction_{0}', 60 * 60 * 0.5) +@redis_lock("lead_extraction_{0}", 60 * 60 * 0.5) def extract_from_lead(self, lead_id): """ A task to auto extract text and images from a lead. @@ -28,23 +28,22 @@ def extract_from_lead(self, lead_id): return return LeadExtractionHandler.trigger_lead_extract(lead, task_instance=self) except Exception: - logger.error('Lead Core Extraction Failed!!', exc_info=True) + logger.error("Lead Core Extraction Failed!!", exc_info=True) @shared_task def generate_previews(lead_ids=None): """Generate previews of leads which do not have preview""" lead_ids = lead_ids or Lead.objects.filter( - Q(leadpreview__isnull=True) | - Q(leadpreview__text_extract=''), - ).values_list('id', flat=True) + Q(leadpreview__isnull=True) | Q(leadpreview__text_extract=""), + ).values_list("id", flat=True) for lead_id in lead_ids: extract_from_lead.apply_async((lead_id,), countdown=1) @shared_task -@redis_lock('remaining_lead_extract', 60 * 60 * 0.5) +@redis_lock("remaining_lead_extract", 60 * 60 * 0.5) def remaining_lead_extract(): """ This task looks for pending, failed, retrying leads which are dangling. @@ -66,6 +65,6 @@ def remaining_lead_extract(): count = queryset.count() if count == 0: continue - logger.info(f'[Lead Extraction] {status.label}: {count}') - for lead_id in queryset.values_list('id', flat=True)[:PROCCESS_LEADS_PER_STATUS]: + logger.info(f"[Lead Extraction] {status.label}: {count}") + for lead_id in queryset.values_list("id", flat=True)[:PROCCESS_LEADS_PER_STATUS]: extract_from_lead(lead_id) diff --git a/apps/lead/tests/test_apis.py b/apps/lead/tests/test_apis.py index 681e717ca0..dfb0de9259 100644 --- a/apps/lead/tests/test_apis.py +++ b/apps/lead/tests/test_apis.py @@ -1,73 +1,61 @@ import logging -from datetime import date import uuid +from datetime import date +from unittest import mock -from django.db.models import Q -from django.core.files.uploadedfile import SimpleUploadedFile - -from rest_framework.exceptions import ErrorDetail - -from utils.common import UidBase64Helper -from deep.tests import TestCase - -from user.models import User -from user.serializers import SimpleUserSerializer -from project.models import ( - Project, ProjectMembership, - ProjectUserGroupMembership, +from ary.models import Assessment +from deepl_integration.handlers import ( + AutoAssistedTaggingDraftEntryHandler, + LeadExtractionHandler, ) -from project.serializers import SimpleProjectSerializer +from deepl_integration.serializers import DeeplServerBaseCallbackSerializer +from django.core.files.uploadedfile import SimpleUploadedFile +from django.db.models import Q +from entry.models import Entry, EntryGroupLabel, LeadEntryGroup, ProjectEntryLabel from geo.models import Region - -from organization.models import ( - Organization, - OrganizationType, -) -from organization.serializers import SimpleOrganizationSerializer +from lead.factories import LeadFactory, LeadPreviewFactory from lead.filter_set import LeadFilterSet -from lead.serializers import SimpleLeadGroupSerializer -from deepl_integration.handlers import AutoAssistedTaggingDraftEntryHandler, LeadExtractionHandler -from deepl_integration.serializers import DeeplServerBaseCallbackSerializer -from entry.models import ( - Entry, - ProjectEntryLabel, - LeadEntryGroup, - EntryGroupLabel, -) from lead.models import ( - Lead, - LeadPreview, - LeadPreviewImage, EMMEntity, + Lead, LeadEMMTrigger, LeadGroup, + LeadPreview, + LeadPreviewImage, ) -from user_group.models import UserGroup, GroupMembership -from ary.models import Assessment -from lead.factories import LeadFactory, LeadPreviewFactory -from unittest import mock +from lead.serializers import SimpleLeadGroupSerializer +from organization.models import Organization, OrganizationType +from organization.serializers import SimpleOrganizationSerializer +from project.models import Project, ProjectMembership, ProjectUserGroupMembership +from project.serializers import SimpleProjectSerializer +from rest_framework.exceptions import ErrorDetail +from user.models import User +from user.serializers import SimpleUserSerializer +from user_group.models import GroupMembership, UserGroup +from deep.tests import TestCase +from utils.common import UidBase64Helper logger = logging.getLogger(__name__) # Organization data RELIEFWEB_DATA = { - 'title': 'Reliefweb', - 'short_name': 'reliefweb', - 'long_name': 'reliefweb.int', - 'url': 'https://reliefweb.int', + "title": "Reliefweb", + "short_name": "reliefweb", + "long_name": "reliefweb.int", + "url": "https://reliefweb.int", } REDHUM_DATA = { - 'title': 'Redhum', - 'short_name': 'redhum', - 'long_name': 'redhum.org', - 'url': 'https://redhum.org', + "title": "Redhum", + "short_name": "redhum", + "long_name": "redhum.org", + "url": "https://redhum.org", } UNHCR_DATA = { - 'title': 'UNHCR', - 'short_name': 'unhcr', - 'long_name': 'United Nations High Commissioner for Refugees', - 'url': 'https://www.unhcr.org', + "title": "UNHCR", + "short_name": "unhcr", + "long_name": "United Nations High Commissioner for Refugees", + "url": "https://www.unhcr.org", } @@ -83,16 +71,16 @@ def test_create_lead(self, assignee=None): lead_count = Lead.objects.count() project = self.create(Project, role=self.admin_role) - url = '/api/v1/leads/' + url = "/api/v1/leads/" data = { - 'title': 'Spaceship spotted in sky', - 'project': project.id, - 'source': self.source.pk, - 'author': self.author.pk, - 'confidentiality': Lead.Confidentiality.UNPROTECTED, - 'status': Lead.Status.NOT_TAGGED, - 'text': 'Alien shapeship has been spotted in the sky', - 'assignee': assignee or self.user.id, + "title": "Spaceship spotted in sky", + "project": project.id, + "source": self.source.pk, + "author": self.author.pk, + "confidentiality": Lead.Confidentiality.UNPROTECTED, + "status": Lead.Status.NOT_TAGGED, + "text": "Alien shapeship has been spotted in the sky", + "assignee": assignee or self.user.id, } self.authenticate() @@ -101,26 +89,26 @@ def test_create_lead(self, assignee=None): self.assertEqual(Lead.objects.count(), lead_count + 1) r_data = response.json() - self.assertEqual(r_data['title'], data['title']) - self.assertEqual(r_data['assignee'], self.user.id) + self.assertEqual(r_data["title"], data["title"]) + self.assertEqual(r_data["assignee"], self.user.id) # low is default priority - self.assertEqual(r_data['priority'], Lead.Priority.LOW) + self.assertEqual(r_data["priority"], Lead.Priority.LOW) def test_lead_create_with_status_validated(self, assignee=None): lead_begining = Lead.objects.count() project = self.create(Project, role=self.admin_role) - url = '/api/v1/leads/' + url = "/api/v1/leads/" data = { - 'title': 'Spaceship spotted in sky', - 'project': project.id, - 'source': self.source.pk, - 'author': self.author.pk, - 'confidentiality': Lead.Confidentiality.UNPROTECTED, - 'status': Lead.Status.TAGGED, - 'text': 'Alien shapeship has been spotted in the sky', - 'assignee': assignee or self.user.id, + "title": "Spaceship spotted in sky", + "project": project.id, + "source": self.source.pk, + "author": self.author.pk, + "confidentiality": Lead.Confidentiality.UNPROTECTED, + "status": Lead.Status.TAGGED, + "text": "Alien shapeship has been spotted in the sky", + "assignee": assignee or self.user.id, } self.authenticate() @@ -129,7 +117,7 @@ def test_lead_create_with_status_validated(self, assignee=None): self.assertEqual(Lead.objects.count(), lead_begining + 1) r_data = response.data - self.assertEqual(r_data['status'], data['status']) + self.assertEqual(r_data["status"], data["status"]) def test_pre_bulk_delete_leads(self): project = self.create(Project) @@ -139,12 +127,12 @@ def test_pre_bulk_delete_leads(self): lead_ids = [lead1.id, lead3.id] admin_user = self.create(User) project.add_member(admin_user, role=self.admin_role) - url = '/api/v1/project/{}/leads/dry-bulk-delete/'.format(project.id) + url = "/api/v1/project/{}/leads/dry-bulk-delete/".format(project.id) self.authenticate(admin_user) - response = self.client.post(url, {'leads': lead_ids}) + response = self.client.post(url, {"leads": lead_ids}) self.assert_200(response) r_data = response.data - self.assertIn('entries', r_data) + self.assertIn("entries", r_data) def test_bulk_delete_leads(self): project = self.create(Project) @@ -154,19 +142,19 @@ def test_bulk_delete_leads(self): lead_count = Lead.objects.count() lead_ids = [lead1.id, lead3.id] - url = '/api/v1/project/{}/leads/bulk-delete/'.format(project.id) + url = "/api/v1/project/{}/leads/bulk-delete/".format(project.id) # calling without delete permissions view_user = self.create(User) project.add_member(view_user, role=self.view_only_role) self.authenticate(view_user) - response = self.client.post(url, {'leads': lead_ids}) + response = self.client.post(url, {"leads": lead_ids}) self.assert_403(response) admin_user = self.create(User) project.add_member(admin_user, role=self.admin_role) self.authenticate(admin_user) - response = self.client.post(url, {'leads': lead_ids}) + response = self.client.post(url, {"leads": lead_ids}) self.assert_204(response) self.assertLess(Lead.objects.count(), lead_count) @@ -175,17 +163,17 @@ def test_create_high_priority_lead(self, assignee=None): lead_count = Lead.objects.count() project = self.create(Project, role=self.admin_role) - url = '/api/v1/leads/' + url = "/api/v1/leads/" data = { - 'title': 'Spaceship spotted in sky', - 'project': project.id, - 'source': self.source.pk, - 'author': self.author.pk, - 'confidentiality': Lead.Confidentiality.UNPROTECTED, - 'status': Lead.Status.NOT_TAGGED, - 'text': 'Alien shapeship has been spotted in the sky', - 'assignee': assignee or self.user.id, - 'priority': Lead.Priority.HIGH, + "title": "Spaceship spotted in sky", + "project": project.id, + "source": self.source.pk, + "author": self.author.pk, + "confidentiality": Lead.Confidentiality.UNPROTECTED, + "status": Lead.Status.NOT_TAGGED, + "text": "Alien shapeship has been spotted in the sky", + "assignee": assignee or self.user.id, + "priority": Lead.Priority.HIGH, } self.authenticate() @@ -194,94 +182,94 @@ def test_create_high_priority_lead(self, assignee=None): self.assertEqual(Lead.objects.count(), lead_count + 1) r_data = response.json() - self.assertEqual(r_data['title'], data['title']) - self.assertEqual(r_data['assignee'], self.user.id) + self.assertEqual(r_data["title"], data["title"]) + self.assertEqual(r_data["assignee"], self.user.id) # low is default priority - self.assertEqual(r_data['priority'], Lead.Priority.HIGH) + self.assertEqual(r_data["priority"], Lead.Priority.HIGH) def test_create_lead_with_emm(self): - entity1 = self.create(EMMEntity, name='entity1') - entity2 = self.create(EMMEntity, name='entity2') - entity3 = self.create(EMMEntity, name='entity3') + entity1 = self.create(EMMEntity, name="entity1") + entity2 = self.create(EMMEntity, name="entity2") + entity3 = self.create(EMMEntity, name="entity3") lead_count = Lead.objects.count() project = self.create(Project, role=self.admin_role) - url = '/api/v1/leads/' + url = "/api/v1/leads/" data = { - 'title': 'Spaceship spotted in sky', - 'project': project.id, - 'source': self.source.pk, - 'author': self.author.pk, - 'confidentiality': Lead.Confidentiality.UNPROTECTED, - 'status': Lead.Status.NOT_TAGGED, - 'text': 'Alien shapeship has been spotted in the sky', - 'assignee': self.user.id, - 'emm_entities': [ - {'name': entity1.name}, - {'name': entity2.name}, + "title": "Spaceship spotted in sky", + "project": project.id, + "source": self.source.pk, + "author": self.author.pk, + "confidentiality": Lead.Confidentiality.UNPROTECTED, + "status": Lead.Status.NOT_TAGGED, + "text": "Alien shapeship has been spotted in the sky", + "assignee": self.user.id, + "emm_entities": [ + {"name": entity1.name}, + {"name": entity2.name}, + ], + "emm_triggers": [ + {"emm_keyword": "kw", "emm_risk_factor": "rf", "count": 3}, + {"emm_keyword": "kw1", "emm_risk_factor": "rf1", "count": 6}, ], - 'emm_triggers': [ - {'emm_keyword': 'kw', 'emm_risk_factor': 'rf', 'count': 3}, - {'emm_keyword': 'kw1', 'emm_risk_factor': 'rf1', 'count': 6}, - ] } self.authenticate() - response = self.client.post(url, data, format='json') + response = self.client.post(url, data, format="json") self.assert_201(response) self.assertEqual(Lead.objects.count(), lead_count + 1) r_data = response.data - self.assertEqual(r_data['title'], data['title']) - self.assertEqual(r_data['assignee'], self.user.id) + self.assertEqual(r_data["title"], data["title"]) + self.assertEqual(r_data["assignee"], self.user.id) - assert 'emm_entities' in r_data - assert 'emm_triggers' in r_data - assert len(r_data['emm_entities']) == 2 - assert len(r_data['emm_triggers']) == 2 + assert "emm_entities" in r_data + assert "emm_triggers" in r_data + assert len(r_data["emm_entities"]) == 2 + assert len(r_data["emm_triggers"]) == 2 - lead_id = r_data['id'] + lead_id = r_data["id"] # Check emm triggers created assert LeadEMMTrigger.objects.filter(lead_id=lead_id).count() == 2 # This should not change anything in the database - data['emm_triggers'] = None - data['emm_entities'] = [{'name': entity3.name}] - response = self.client.put(f"{url}{r_data['id']}/", data, format='json') + data["emm_triggers"] = None + data["emm_entities"] = [{"name": entity3.name}] + response = self.client.put(f"{url}{r_data['id']}/", data, format="json") self.assert_200(response) - assert 'emm_entities' in r_data - assert 'emm_triggers' in r_data - assert len(r_data['emm_entities']) == 2 - assert len(r_data['emm_triggers']) == 2 + assert "emm_entities" in r_data + assert "emm_triggers" in r_data + assert len(r_data["emm_entities"]) == 2 + assert len(r_data["emm_triggers"]) == 2 def test_get_lead_check_no_of_entries(self, assignee=None): project = self.create(Project, role=self.admin_role) - url = '/api/v1/leads/' + url = "/api/v1/leads/" data = { - 'title': 'Spaceship spotted in sky', - 'project': project.id, - 'source': self.source.pk, - 'author': self.author.pk, - 'confidentiality': Lead.Confidentiality.UNPROTECTED, - 'status': Lead.Status.NOT_TAGGED, - 'text': 'Alien shapeship has been spotted in the sky', - 'assignee': assignee or self.user.id, + "title": "Spaceship spotted in sky", + "project": project.id, + "source": self.source.pk, + "author": self.author.pk, + "confidentiality": Lead.Confidentiality.UNPROTECTED, + "status": Lead.Status.NOT_TAGGED, + "text": "Alien shapeship has been spotted in the sky", + "assignee": assignee or self.user.id, } self.authenticate() response = self.client.post(url, data) self.assert_201(response) - url = '/api/v1/leads/' + url = "/api/v1/leads/" response = self.client.get(url) r_data = response.json() - assert 'entriesCount' in r_data['results'][0] + assert "entriesCount" in r_data["results"][0] def test_create_lead_no_create_role(self, assignee=None): lead_count = Lead.objects.count() @@ -290,16 +278,16 @@ def test_create_lead_no_create_role(self, assignee=None): test_user = self.create(User) project.add_member(test_user, role=self.view_only_role) - url = '/api/v1/leads/' + url = "/api/v1/leads/" data = { - 'title': 'Spaceship spotted in sky', - 'project': project.id, - 'source': self.source.pk, - 'author': self.author.pk, - 'confidentiality': Lead.Confidentiality.UNPROTECTED, - 'status': Lead.Status.NOT_TAGGED, - 'text': 'Alien shapeship has been spotted in the sky', - 'assignee': assignee or self.user.id, + "title": "Spaceship spotted in sky", + "project": project.id, + "source": self.source.pk, + "author": self.author.pk, + "confidentiality": Lead.Confidentiality.UNPROTECTED, + "status": Lead.Status.NOT_TAGGED, + "text": "Alien shapeship has been spotted in the sky", + "assignee": assignee or self.user.id, } self.authenticate(test_user) @@ -311,7 +299,7 @@ def test_create_lead_no_create_role(self, assignee=None): def test_delete_lead(self): project = self.create(Project, role=self.admin_role) lead = self.create(Lead, project=project) - url = '/api/v1/leads/{}/'.format(lead.id) + url = "/api/v1/leads/{}/".format(lead.id) self.authenticate() response = self.client.delete(url) @@ -324,7 +312,7 @@ def test_delete_lead_no_perm(self): project.add_member(user, self.view_only_role) - url = '/api/v1/leads/{}/'.format(lead.id) + url = "/api/v1/leads/{}/".format(lead.id) self.authenticate(user) response = self.client.delete(url) @@ -343,9 +331,9 @@ def test_update_assignee(self): self.create(ProjectMembership, project=project, member=user) lead = self.create(Lead, project=project) - url = '/api/v1/leads/{}/'.format(lead.id) + url = "/api/v1/leads/{}/".format(lead.id) data = { - 'assignee': user.id, + "assignee": user.id, } self.authenticate() @@ -353,20 +341,19 @@ def test_update_assignee(self): self.assert_200(response) r_data = response.json() - self.assertEqual(r_data['assignee'], user.id) + self.assertEqual(r_data["assignee"], user.id) lead = Lead.objects.get(id=lead.id) self.assertEqual(lead.get_assignee().id, user.id) def test_options(self): def check_default_options(rdata): for option, value in DEFAULT_OPTIONS.items(): - assert rdata[option] == value, f'value should be same for <{option}> DEFAULT_OPTIONS' + assert rdata[option] == value, f"value should be same for <{option}> DEFAULT_OPTIONS" def assert_id(returned_obj, excepted_obj): - assert set([obj['id'] for obj in returned_obj]) ==\ - set([obj['id'] for obj in excepted_obj]) + assert set([obj["id"] for obj in returned_obj]) == set([obj["id"] for obj in excepted_obj]) - url = '/api/v1/lead-options/' + url = "/api/v1/lead-options/" # Project project = self.create(Project) @@ -390,32 +377,26 @@ def assert_id(returned_obj, excepted_obj): # Default options DEFAULT_OPTIONS = { - 'confidentiality': [ - {'key': c[0], 'value': c[1]} for c in Lead.Confidentiality.choices - ], - 'status': [ - {'key': s[0], 'value': s[1]} for s in Lead.Status.choices - ], - 'priority': [ - {'key': s[0], 'value': s[1]} for s in Lead.Priority.choices - ], + "confidentiality": [{"key": c[0], "value": c[1]} for c in Lead.Confidentiality.choices], + "status": [{"key": s[0], "value": s[1]} for s in Lead.Status.choices], + "priority": [{"key": s[0], "value": s[1]} for s in Lead.Priority.choices], } self.authenticate(user) # 404 if user is not member of any one of the projects data = { - 'projects': [project.pk + 1], + "projects": [project.pk + 1], } response = self.client.post(url, data) self.assert_404(response) # 200 if user is member of one of the project [Also other data are filtered by those projects] data = { - 'projects': [project.pk + 1, project.pk], - 'leadGroups': [lead_group1.pk, lead_group2.pk], - 'members': [user1.pk, user2.pk], - 'organizations': [reliefweb.pk, unhcr.pk] + "projects": [project.pk + 1, project.pk], + "leadGroups": [lead_group1.pk, lead_group2.pk], + "members": [user1.pk, user2.pk], + "organizations": [reliefweb.pk, unhcr.pk], } response = self.client.post(url, data) @@ -424,47 +405,47 @@ def assert_id(returned_obj, excepted_obj): self.maxDiff = None # Only members(all) are returned when requested is None data = { - 'projects': [project.pk], + "projects": [project.pk], } response = self.client.post(url, data) rdata = response.data - assert 'has_emm_leads' in rdata - assert not rdata['has_emm_leads'], 'There should be no emm leads in the project' - assert_id(rdata['members'], SimpleUserSerializer([user1, user2, user], many=True).data) - assert rdata['projects'] == SimpleProjectSerializer([project], many=True).data - assert rdata['lead_groups'] == [] - assert rdata['organizations'] == [] + assert "has_emm_leads" in rdata + assert not rdata["has_emm_leads"], "There should be no emm leads in the project" + assert_id(rdata["members"], SimpleUserSerializer([user1, user2, user], many=True).data) + assert rdata["projects"] == SimpleProjectSerializer([project], many=True).data + assert rdata["lead_groups"] == [] + assert rdata["organizations"] == [] check_default_options(rdata) # If value are provided respective data are provided (filtered by permission) data = { - 'projects': [project.pk], - 'leadGroups': [lead_group2.pk], - 'members': [user1.pk, user2.pk, out_user.pk], - 'organizations': [unhcr.pk] + "projects": [project.pk], + "leadGroups": [lead_group2.pk], + "members": [user1.pk, user2.pk, out_user.pk], + "organizations": [unhcr.pk], } response = self.client.post(url, data) rdata = response.data - assert_id(rdata['members'], SimpleUserSerializer([user1, user2], many=True).data) - assert rdata['projects'] == SimpleProjectSerializer([project], many=True).data - assert rdata['lead_groups'] == SimpleLeadGroupSerializer([lead_group2], many=True).data - assert rdata['organizations'] == SimpleOrganizationSerializer([unhcr], many=True).data + assert_id(rdata["members"], SimpleUserSerializer([user1, user2], many=True).data) + assert rdata["projects"] == SimpleProjectSerializer([project], many=True).data + assert rdata["lead_groups"] == SimpleLeadGroupSerializer([lead_group2], many=True).data + assert rdata["organizations"] == SimpleOrganizationSerializer([unhcr], many=True).data check_default_options(rdata) - assert 'emm_entities' in rdata - assert 'emm_keywords' in rdata - assert 'emm_risk_factors' in rdata + assert "emm_entities" in rdata + assert "emm_keywords" in rdata + assert "emm_risk_factors" in rdata def test_emm_options_post(self): - url = '/api/v1/lead-options/' + url = "/api/v1/lead-options/" project = self.create_project() # Create Entities - entity1 = self.create(EMMEntity, name='entity1') - entity2 = self.create(EMMEntity, name='entity2') - entity3 = self.create(EMMEntity, name='enitty3') - entity4 = self.create(EMMEntity, name='entity4') # noqa:F841 + entity1 = self.create(EMMEntity, name="entity1") + entity2 = self.create(EMMEntity, name="entity2") + entity3 = self.create(EMMEntity, name="enitty3") + entity4 = self.create(EMMEntity, name="entity4") # noqa:F841 lead1 = self.create_lead(project=project, emm_entities=[entity1]) lead2 = self.create_lead(project=project, emm_entities=[entity2, entity3]) @@ -473,46 +454,58 @@ def test_emm_options_post(self): # Create LeadEMMTrigger objects with self.create( - LeadEMMTrigger, lead=lead1, count=5, - emm_keyword='keyword1', emm_risk_factor='rf1', + LeadEMMTrigger, + lead=lead1, + count=5, + emm_keyword="keyword1", + emm_risk_factor="rf1", ) self.create( - LeadEMMTrigger, lead=lead2, count=3, - emm_keyword='keyword1', emm_risk_factor='rf2', + LeadEMMTrigger, + lead=lead2, + count=3, + emm_keyword="keyword1", + emm_risk_factor="rf2", ) self.create( - LeadEMMTrigger, lead=lead3, count=3, - emm_keyword='keyword2', emm_risk_factor='rf2', + LeadEMMTrigger, + lead=lead3, + count=3, + emm_keyword="keyword2", + emm_risk_factor="rf2", ) self.create( - LeadEMMTrigger, lead=lead4, count=3, - emm_keyword='keyword1', emm_risk_factor='rf1', + LeadEMMTrigger, + lead=lead4, + count=3, + emm_keyword="keyword1", + emm_risk_factor="rf1", ) data = { - 'projects': [project.id], + "projects": [project.id], } self.authenticate() - response = self.client.post(url, data, format='json') + response = self.client.post(url, data, format="json") self.assert_200(response) data = response.data # No data should be present when not specified in the query - assert 'emm_entities' in data - assert data['emm_entities'] == [] + assert "emm_entities" in data + assert data["emm_entities"] == [] - assert 'emm_keywords' in data - assert data['emm_keywords'] == [] + assert "emm_keywords" in data + assert data["emm_keywords"] == [] - assert 'emm_risk_factors' in data - assert data['emm_risk_factors'] == [] + assert "emm_risk_factors" in data + assert data["emm_risk_factors"] == [] - assert 'has_emm_leads' in data - assert data['has_emm_leads'], 'There are emm leads' + assert "has_emm_leads" in data + assert data["has_emm_leads"], "There are emm leads" data = { - 'projects': [project.id], - 'emm_risk_factors': ['rf1'], # Only risk factors present + "projects": [project.id], + "emm_risk_factors": ["rf1"], # Only risk factors present } self.authenticate() response = self.client.post(url, data) @@ -520,29 +513,29 @@ def test_emm_options_post(self): data = response.data # Check emm_entities - assert 'emm_entities' in data - assert not data['emm_entities'], 'Entities not specified.' + assert "emm_entities" in data + assert not data["emm_entities"], "Entities not specified." # Check emm_risk_factors - assert 'emm_risk_factors' in data - expected_risk_factors_count_set = {('rf1', 'rf1', 8)} - result_risk_factors_count_set = {(x['key'], x['label'], x['total_count']) for x in data['emm_risk_factors']} + assert "emm_risk_factors" in data + expected_risk_factors_count_set = {("rf1", "rf1", 8)} + result_risk_factors_count_set = {(x["key"], x["label"], x["total_count"]) for x in data["emm_risk_factors"]} assert expected_risk_factors_count_set == result_risk_factors_count_set # Check emm_keywords - assert 'emm_keywords' in data - assert not data['emm_entities'], 'Keywords not specified.' + assert "emm_keywords" in data + assert not data["emm_entities"], "Keywords not specified." - assert 'has_emm_leads' in data - assert data['has_emm_leads'], 'There are emm leads' + assert "has_emm_leads" in data + assert data["has_emm_leads"], "There are emm leads" def test_options_assignees_get(self): - url = '/api/v1/lead-options/?projects={}' + url = "/api/v1/lead-options/?projects={}" user = self.create(User) - project = self.create(Project, title='p1') # self.user is member + project = self.create(Project, title="p1") # self.user is member project.add_member(user) # Add user to project project.add_member(self.user) - project1 = self.create(Project, title='p2') + project1 = self.create(Project, title="p2") project1.add_member(self.user) # Add usergroup as well @@ -553,23 +546,23 @@ def test_options_assignees_get(self): ProjectUserGroupMembership.objects.create(project=project, usergroup=usergroup) - projects = f'{project.id}' + projects = f"{project.id}" self.authenticate() resp = self.client.get(url.format(projects)) self.assert_200(resp) data = resp.data - assignee_ids = [int(x['key']) for x in data['assignee']] + assignee_ids = [int(x["key"]) for x in data["assignee"]] - assert 'assignee' in data + assert "assignee" in data # BOTH users should be in assignee since only one project is requested assert self.user.id in assignee_ids assert user.id in assignee_ids assert ugmember.id in assignee_ids assert non_member.id not in assignee_ids - projects = f'{project.id},{project1.id}' + projects = f"{project.id},{project1.id}" self.authenticate() resp = self.client.get(url.format(projects)) @@ -578,8 +571,8 @@ def test_options_assignees_get(self): data = resp.data - assert 'assignee' in data - assignee_ids = [int(x['key']) for x in data['assignee']] + assert "assignee" in data + assignee_ids = [int(x["key"]) for x in data["assignee"]] assert self.user.id in assignee_ids assert user.id not in assignee_ids assert non_member.id not in assignee_ids @@ -589,10 +582,10 @@ def test_emm_options_get(self): project1 = self.create_project() # Create Entities - entity1 = self.create(EMMEntity, name='entity1') - entity2 = self.create(EMMEntity, name='entity2') - entity3 = self.create(EMMEntity, name='entity3') - entity4 = self.create(EMMEntity, name='entity4') # noqa:F841 + entity1 = self.create(EMMEntity, name="entity1") + entity2 = self.create(EMMEntity, name="entity2") + entity3 = self.create(EMMEntity, name="entity3") + entity4 = self.create(EMMEntity, name="entity4") # noqa:F841 lead1 = self.create_lead(project=project, emm_entities=[entity1]) lead2 = self.create_lead(project=project, emm_entities=[entity2, entity3]) @@ -605,74 +598,83 @@ def test_emm_options_get(self): # Create LeadEMMTrigger objects with self.create( - LeadEMMTrigger, lead=lead1, count=5, - emm_keyword='keyword1', emm_risk_factor='rf1', + LeadEMMTrigger, + lead=lead1, + count=5, + emm_keyword="keyword1", + emm_risk_factor="rf1", ) self.create( - LeadEMMTrigger, lead=lead2, count=3, - emm_keyword='keyword1', emm_risk_factor='rf2', + LeadEMMTrigger, + lead=lead2, + count=3, + emm_keyword="keyword1", + emm_risk_factor="rf2", ) self.create( - LeadEMMTrigger, lead=lead3, count=3, - emm_keyword='keyword2', emm_risk_factor='rf2', + LeadEMMTrigger, + lead=lead3, + count=3, + emm_keyword="keyword2", + emm_risk_factor="rf2", ) self.create( - LeadEMMTrigger, lead=lead4, count=3, - emm_keyword='keyword1', emm_risk_factor='', # This should not be present as risk factor + LeadEMMTrigger, + lead=lead4, + count=3, + emm_keyword="keyword1", + emm_risk_factor="", # This should not be present as risk factor ) # NOTE: 3 leads with keyword keyword1, one with keyword2 # 2 leads with factor rf1, 2 with factor rf2 - url = f'/api/v1/lead-options/?projects={project.id}' + url = f"/api/v1/lead-options/?projects={project.id}" self.authenticate() response = self.client.get(url) self.assert_200(response) data = response.data # Check emm_entities - assert 'emm_entities' in data - expected_entity_count_set = { - (entity1.id, entity1.name, 1), - (entity2.id, entity2.name, 2), - (entity3.id, entity3.name, 1)} - result_entity_count_set = {(x['key'], x['label'], x['total_count']) for x in data['emm_entities']} + assert "emm_entities" in data + expected_entity_count_set = {(entity1.id, entity1.name, 1), (entity2.id, entity2.name, 2), (entity3.id, entity3.name, 1)} + result_entity_count_set = {(x["key"], x["label"], x["total_count"]) for x in data["emm_entities"]} assert expected_entity_count_set == result_entity_count_set # Check emm_risk_factors - assert 'emm_risk_factors' in data - expected_risk_factors_count_set = {('rf1', 'rf1', 5), ('rf2', 'rf2', 6)} - result_risk_factors_count_set = {(x['key'], x['label'], x['total_count']) for x in data['emm_risk_factors']} + assert "emm_risk_factors" in data + expected_risk_factors_count_set = {("rf1", "rf1", 5), ("rf2", "rf2", 6)} + result_risk_factors_count_set = {(x["key"], x["label"], x["total_count"]) for x in data["emm_risk_factors"]} assert expected_risk_factors_count_set == result_risk_factors_count_set # Check emm_keywords - assert 'emm_keywords' in data - expected_keywords_count_set = {('keyword1', 'keyword1', 11), ('keyword2', 'keyword2', 3)} - result_keywords_count_set = {(x['key'], x['label'], x['total_count']) for x in data['emm_keywords']} + assert "emm_keywords" in data + expected_keywords_count_set = {("keyword1", "keyword1", 11), ("keyword2", "keyword2", 3)} + result_keywords_count_set = {(x["key"], x["label"], x["total_count"]) for x in data["emm_keywords"]} assert expected_keywords_count_set == result_keywords_count_set - assert 'has_emm_leads' in data - assert data['has_emm_leads'], 'There are emm leads' + assert "has_emm_leads" in data + assert data["has_emm_leads"], "There are emm leads" # Now check options for project1, there should be no emm related data - url = f'/api/v1/lead-options/?projects={project1.id}' + url = f"/api/v1/lead-options/?projects={project1.id}" self.authenticate() response = self.client.get(url) self.assert_200(response) data = response.data - assert 'has_emm_leads' in data - assert not data['has_emm_leads'], 'this Project should not have emm' - assert 'emm_risk_factors' in data - assert not data['emm_risk_factors'] - assert 'emm_keywords' in data - assert not data['emm_keywords'] - assert 'emm_entities' in data - assert not data['emm_entities'] + assert "has_emm_leads" in data + assert not data["has_emm_leads"], "this Project should not have emm" + assert "emm_risk_factors" in data + assert not data["emm_risk_factors"] + assert "emm_keywords" in data + assert not data["emm_keywords"] + assert "emm_entities" in data + assert not data["emm_entities"] def test_trigger_api(self): project = self.create(Project, role=self.admin_role) lead = self.create(Lead, project=project) - url = '/api/v1/lead-extraction-trigger/{}/'.format(lead.id) + url = "/api/v1/lead-extraction-trigger/{}/".format(lead.id) self.authenticate() response = self.client.get(url) @@ -684,16 +686,16 @@ def test_multiple_project(self): lead_count = Lead.objects.count() - url = '/api/v1/leads/' + url = "/api/v1/leads/" data = { - 'title': 'test title', - 'project': [project1.id, project2.id], - 'source': self.source.pk, - 'author': self.author.pk, - 'confidentiality': Lead.Confidentiality.UNPROTECTED, - 'status': Lead.Status.NOT_TAGGED, - 'text': 'this is some random text', - 'assignee': self.user.id, + "title": "test title", + "project": [project1.id, project2.id], + "source": self.source.pk, + "author": self.author.pk, + "confidentiality": Lead.Confidentiality.UNPROTECTED, + "status": Lead.Status.NOT_TAGGED, + "text": "this is some random text", + "assignee": self.user.id, } self.authenticate() @@ -704,36 +706,32 @@ def test_multiple_project(self): self.assertEqual(Lead.objects.count(), lead_count + 2) self.assertEqual(len(r_data), 2) - self.assertEqual(r_data[0].get('project'), project1.id) - self.assertEqual(r_data[1].get('project'), project2.id) + self.assertEqual(r_data[0].get("project"), project1.id) + self.assertEqual(r_data[1].get("project"), project2.id) def test_url_exists(self): project = self.create(Project, role=self.admin_role) - common_url = 'https://same.com/' - lead1 = self.create(Lead, source_type='website', - project=project, - url=common_url) - lead2 = self.create(Lead, source_type='website', - project=project, - url='https://different.com/') - - url = '/api/v1/leads/' + common_url = "https://same.com/" + lead1 = self.create(Lead, source_type="website", project=project, url=common_url) + lead2 = self.create(Lead, source_type="website", project=project, url="https://different.com/") + + url = "/api/v1/leads/" data = { - 'title': 'Spaceship spotted in sky', - 'project': project.id, - 'source': self.source.pk, - 'author': self.author.pk, - 'source_type': 'website', - 'url': common_url, + "title": "Spaceship spotted in sky", + "project": project.id, + "source": self.source.pk, + "author": self.author.pk, + "source_type": "website", + "url": common_url, } self.authenticate() response = self.client.post(url, data) self.assert_400(response) - url = '/api/v1/leads/{}/'.format(lead2.id) + url = "/api/v1/leads/{}/".format(lead2.id) data = { - 'url': common_url, + "url": common_url, } response = self.client.patch(url, data) @@ -741,15 +739,13 @@ def test_url_exists(self): # This should not be raised while editing same lead - url = '/api/v1/leads/{}/'.format(lead1.id) - data = { - 'title': 'Spaceship allegedly spotted in sky' - } + url = "/api/v1/leads/{}/".format(lead1.id) + data = {"title": "Spaceship allegedly spotted in sky"} response = self.client.patch(url, data) self.assert_200(response) def test_lead_copy_from_project_with_only_view(self): - url = '/api/v1/lead-copy/' + url = "/api/v1/lead-copy/" source_project = self.create(Project, role=self.view_only_role) dest_project = self.create(Project, role=self.admin_role) @@ -758,46 +754,46 @@ def test_lead_copy_from_project_with_only_view(self): leads_count = Lead.objects.all().count() data = { - 'projects': [dest_project.pk], - 'leads': [lead.pk], + "projects": [dest_project.pk], + "leads": [lead.pk], } self.authenticate() response = self.client.post(url, data) self.assert_403(response) - assert leads_count == Lead.objects.all().count(), 'No new lead should have been created' + assert leads_count == Lead.objects.all().count(), "No new lead should have been created" def test_lead_copy(self): - url = '/api/v1/lead-copy/' + url = "/api/v1/lead-copy/" # Projects [Source] # NOTE: make sure the source projects have create/edit permissions - project1s = self.create(Project, title='project1s', role=self.admin_role) - project2s = self.create(Project, title='project2s', role=self.admin_role) - project3s = self.create(Project, title='project3s') - project4s = self.create(Project, title='project4s', role=self.normal_role) + project1s = self.create(Project, title="project1s", role=self.admin_role) + project2s = self.create(Project, title="project2s", role=self.admin_role) + project3s = self.create(Project, title="project3s") + project4s = self.create(Project, title="project4s", role=self.normal_role) # Projects [Destination] - project1d = self.create(Project, title='project1d') - project2d = self.create(Project, title='project2d', role=self.admin_role) - project3d = self.create(Project, title='project3d', role=self.admin_role) - project4d = self.create(Project, title='project4d', role=self.view_only_role) + project1d = self.create(Project, title="project1d") + project2d = self.create(Project, title="project2d", role=self.admin_role) + project3d = self.create(Project, title="project3d", role=self.admin_role) + project4d = self.create(Project, title="project4d", role=self.view_only_role) # Lead1 Info (Will be used later for testing) - lead1_title = 'Lead 1 2019--222-' - lead1_text_extract = 'This is a test text extract' - lead1_preview_file = 'invalid_test_file' - author = self.create(Organization, title='blablaone') - author2 = self.create(Organization, title='blablatwo') - emm_keyword = 'emm1' - emm_risk_factor = 'risk1' + lead1_title = "Lead 1 2019--222-" + lead1_text_extract = "This is a test text extract" + lead1_preview_file = "invalid_test_file" + author = self.create(Organization, title="blablaone") + author2 = self.create(Organization, title="blablatwo") + emm_keyword = "emm1" + emm_risk_factor = "risk1" emm_count = 22 - emm_entity_name = 'emm_entity_11' + emm_entity_name = "emm_entity_11" # Generate Leads lead1 = self.create( - Lead, title=lead1_title, project=project1s, source_type=Lead.SourceType.WEBSITE, url='http://example.com' + Lead, title=lead1_title, project=project1s, source_type=Lead.SourceType.WEBSITE, url="http://example.com" ) lead1.authors.set([author, author2]) lead2 = self.create(Lead, project=project2s) @@ -805,26 +801,25 @@ def test_lead_copy(self): lead4 = self.create(Lead, project=project4s) # For duplicate url validation check - self.create( - Lead, title=lead1_title, project=project2d, source_type=Lead.SourceType.WEBSITE, url='http://example.com' - ) + self.create(Lead, title=lead1_title, project=project2d, source_type=Lead.SourceType.WEBSITE, url="http://example.com") # Generating Foreign elements for lead1 self.create(LeadPreview, lead=lead1, text_extract=lead1_text_extract) self.create(LeadPreviewImage, lead=lead1, file=lead1_preview_file) emm_trigger = self.create( - LeadEMMTrigger, lead=lead1, emm_keyword=emm_keyword, emm_risk_factor=emm_risk_factor, count=emm_count) + LeadEMMTrigger, lead=lead1, emm_keyword=emm_keyword, emm_risk_factor=emm_risk_factor, count=emm_count + ) lead1.emm_entities.set([self.create(EMMEntity, name=emm_entity_name)]) # Request body data [also contains unauthorized projects and leads] data = { - 'projects': sorted([project4d.pk, project3d.pk, project2d.pk, project1d.pk, project1s.pk]), - 'leads': sorted([lead3.pk, lead2.pk, lead1.pk, lead4.pk]), + "projects": sorted([project4d.pk, project3d.pk, project2d.pk, project1d.pk, project1s.pk]), + "leads": sorted([lead3.pk, lead2.pk, lead1.pk, lead4.pk]), } # data [only contains authorized projects and leads] validate_data = { - 'projects': sorted([project3d.pk, project2d.pk, project1s.pk]), - 'leads': sorted([lead4.pk, lead2.pk, lead1.pk]), + "projects": sorted([project3d.pk, project2d.pk, project1s.pk]), + "leads": sorted([lead4.pk, lead2.pk, lead1.pk]), } lead_stats = [ @@ -833,7 +828,6 @@ def test_lead_copy(self): (project2s, 1, 1), (project3s, 1, 1), (project4s, 1, 1), - (project1d, 0, 0), (project2d, 0, 3), (project3d, 0, 3), @@ -847,8 +841,8 @@ def test_lead_copy(self): rdata = response.json() # Sort the data since we are comparing lists sorted_rdata = { - 'projects': sorted(rdata['projects']), - 'leads': sorted(rdata['leads']), + "projects": sorted(rdata["projects"]), + "leads": sorted(rdata["leads"]), } self.assert_201(response) self.assertNotEqual(sorted_rdata, data) @@ -856,24 +850,18 @@ def test_lead_copy(self): for project, old_lead_count, new_lead_count in lead_stats: current_lead_count = Lead.objects.filter(project_id=project.pk).count() - assert new_lead_count == current_lead_count, f'Project: {project.title} lead count is different' + assert new_lead_count == current_lead_count, f"Project: {project.title} lead count is different" # Test Foreign Fields self.assertEqual( - Lead.objects.filter(title=lead1_title).count(), - 3, - 'Should have been 3: Original + Custom created + Copy(of original)' + Lead.objects.filter(title=lead1_title).count(), 3, "Should have been 3: Original + Custom created + Copy(of original)" ) self.assertEqual( - Lead.objects.filter(title=lead1_title).exclude( - Q(pk=lead1.pk) | Q(project=project2d) - ).count(), + Lead.objects.filter(title=lead1_title).exclude(Q(pk=lead1.pk) | Q(project=project2d)).count(), 1, - 'Should have been 1: Copy(of original)' + "Should have been 1: Copy(of original)", ) - lead1_copy = Lead.objects.filter(title=lead1_title).exclude( - Q(pk=lead1.pk) | Q(project=project2d) - ).get() + lead1_copy = Lead.objects.filter(title=lead1_title).exclude(Q(pk=lead1.pk) | Q(project=project2d)).get() lead1_copy.refresh_from_db() self.assertEqual( lead1_copy.images.count(), @@ -885,28 +873,28 @@ def test_lead_copy(self): ) emm_trigger = lead1_copy.emm_triggers.filter(emm_risk_factor=emm_risk_factor, emm_keyword=emm_keyword)[0] assert lead1_copy.authors.count() == 2 - assert sorted(lead1_copy.authors.values_list('id', flat=True)) == [author.id, author2.id] + assert sorted(lead1_copy.authors.values_list("id", flat=True)) == [author.id, author2.id] assert lead1_copy.leadpreview.text_extract == lead1_text_extract assert lead1_copy.images.all()[0].file == lead1_preview_file assert emm_trigger.count == emm_count assert lead1_copy.emm_entities.all()[0].name == emm_entity_name def test_lead_duplicate_validation(self): - url = '/api/v1/leads/' + url = "/api/v1/leads/" project = self.create_project() file = self.create_gallery_file() # Test using FILE (HASH) data = { - 'title': 'test title', - 'project': project.pk, - 'source': self.source.pk, - 'author': self.author.pk, - 'source_type': Lead.SourceType.DISK, - 'confidentiality': Lead.Confidentiality.UNPROTECTED, - 'status': Lead.Status.NOT_TAGGED, - 'attachment': {'id': file.pk}, - 'assignee': self.user.id, + "title": "test title", + "project": project.pk, + "source": self.source.pk, + "author": self.author.pk, + "source_type": Lead.SourceType.DISK, + "confidentiality": Lead.Confidentiality.UNPROTECTED, + "status": Lead.Status.NOT_TAGGED, + "attachment": {"id": file.pk}, + "assignee": self.user.id, } self.authenticate() @@ -918,15 +906,15 @@ def test_lead_duplicate_validation(self): # Test using TEXT data = { - 'title': 'test title', - 'project': project.pk, - 'source': self.source.pk, - 'author': self.author.pk, - 'source_type': Lead.SourceType.TEXT, - 'confidentiality': Lead.Confidentiality.UNPROTECTED, - 'status': Lead.Status.NOT_TAGGED, - 'text': 'duplication test 101', - 'assignee': self.user.id, + "title": "test title", + "project": project.pk, + "source": self.source.pk, + "author": self.author.pk, + "source_type": Lead.SourceType.TEXT, + "confidentiality": Lead.Confidentiality.UNPROTECTED, + "status": Lead.Status.NOT_TAGGED, + "text": "duplication test 101", + "assignee": self.user.id, } self.authenticate() @@ -945,25 +933,25 @@ def test_lead_order_by_priority(self): self.create_lead(project=project, priority=Lead.Priority.HIGH) self.create_lead(project=project, priority=Lead.Priority.LOW) - url = '/api/v1/leads/?ordering=priority' + url = "/api/v1/leads/?ordering=priority" self.authenticate() response = self.client.get(url) self.assert_200(response) - leads = response.data['results'] - assert leads[0]['priority'] == Lead.Priority.LOW - assert leads[1]['priority'] == Lead.Priority.MEDIUM - assert leads[2]['priority'] == Lead.Priority.HIGH - assert leads[3]['priority'] == Lead.Priority.HIGH + leads = response.data["results"] + assert leads[0]["priority"] == Lead.Priority.LOW + assert leads[1]["priority"] == Lead.Priority.MEDIUM + assert leads[2]["priority"] == Lead.Priority.HIGH + assert leads[3]["priority"] == Lead.Priority.HIGH - url = '/api/v1/leads/?ordering=-priority' + url = "/api/v1/leads/?ordering=-priority" self.authenticate() response = self.client.get(url) self.assert_200(response) - leads = response.data['results'] - assert leads[0]['priority'] == Lead.Priority.HIGH - assert leads[1]['priority'] == Lead.Priority.HIGH - assert leads[2]['priority'] == Lead.Priority.MEDIUM - assert leads[3]['priority'] == Lead.Priority.LOW + leads = response.data["results"] + assert leads[0]["priority"] == Lead.Priority.HIGH + assert leads[1]["priority"] == Lead.Priority.HIGH + assert leads[2]["priority"] == Lead.Priority.MEDIUM + assert leads[3]["priority"] == Lead.Priority.LOW def test_lead_order_by_page_count(self): # Create lead and lead_previews @@ -980,26 +968,26 @@ def test_lead_order_by_page_count(self): self.create(LeadPreview, lead=lead3, page_count=None) # Ascending ordering - url = '/api/v1/leads/?ordering=,page_count,,' # this also tests leading/trailing/multiple commas + url = "/api/v1/leads/?ordering=,page_count,," # this also tests leading/trailing/multiple commas self.authenticate() response = self.client.get(url) self.assert_200(response) - assert len(response.data['results']) == 3, 'Three leads created' - leads = response.data['results'] - assert leads[0]['id'] == lead3.id, 'Preview3 has no pages' - assert leads[1]['id'] == lead2.id, 'Preview2 has less pages' - assert leads[2]['id'] == lead1.id, 'Preview1 has more pages' + assert len(response.data["results"]) == 3, "Three leads created" + leads = response.data["results"] + assert leads[0]["id"] == lead3.id, "Preview3 has no pages" + assert leads[1]["id"] == lead2.id, "Preview2 has less pages" + assert leads[2]["id"] == lead1.id, "Preview1 has more pages" # Descending ordering - url = '/api/v1/leads/?ordering=,-page_count,,' # this also tests leading/trailing/multiple commas + url = "/api/v1/leads/?ordering=,-page_count,," # this also tests leading/trailing/multiple commas self.authenticate() response = self.client.get(url) self.assert_200(response) - assert len(response.data['results']) == 3, 'Three leads created' - leads = response.data['results'] - assert leads[0]['id'] == lead1.id, 'Preview1 has more pages' - assert leads[1]['id'] == lead2.id, 'Preview2 has less pages' - assert leads[2]['id'] == lead3.id, 'Preview3 has no pages' + assert len(response.data["results"]) == 3, "Three leads created" + leads = response.data["results"] + assert leads[0]["id"] == lead1.id, "Preview1 has more pages" + assert leads[1]["id"] == lead2.id, "Preview2 has less pages" + assert leads[2]["id"] == lead3.id, "Preview3 has no pages" def test_lead_filter(self): project = self.create_project(create_assessment_template=True) @@ -1012,34 +1000,34 @@ def test_lead_filter(self): self.authenticate() - response = self.client.get(f'/api/v1/leads/?project={project2.pk}&priority={Lead.Priority.HIGH}') - assert response.json()['results'][0]['id'] == lead4.pk + response = self.client.get(f"/api/v1/leads/?project={project2.pk}&priority={Lead.Priority.HIGH}") + assert response.json()["results"][0]["id"] == lead4.pk - url = f'/api/v1/leads/?project={project.pk}' + url = f"/api/v1/leads/?project={project.pk}" # Project filter test response = self.client.get(url) - assert response.json()['count'] == 3, 'Lead count should be 3' + assert response.json()["count"] == 3, "Lead count should be 3" # Entries exists filter test self.create_entry(lead=lead1) self.create_entry(lead=lead2) - response = self.client.get(f'{url}&exists={LeadFilterSet.Exists.ENTRIES_EXISTS}') - assert response.json()['count'] == 2, 'Lead count should be 2 for lead with entries' + response = self.client.get(f"{url}&exists={LeadFilterSet.Exists.ENTRIES_EXISTS}") + assert response.json()["count"] == 2, "Lead count should be 2 for lead with entries" # Entries do not exist filter test - response = self.client.get(f'{url}&exists={LeadFilterSet.Exists.ENTRIES_DO_NOT_EXIST}') - assert response.json()['count'] == 1, 'Lead count should be 1 for lead without entries' + response = self.client.get(f"{url}&exists={LeadFilterSet.Exists.ENTRIES_DO_NOT_EXIST}") + assert response.json()["count"] == 1, "Lead count should be 1 for lead without entries" # Assessment exists filter test self.create_assessment(lead=lead1) self.create_assessment(lead=lead3) - response = self.client.get(f'{url}&exists={LeadFilterSet.Exists.ASSESSMENT_EXISTS}') - assert response.json()['count'] == 2, 'Lead count should be 2 for lead with assessment' + response = self.client.get(f"{url}&exists={LeadFilterSet.Exists.ASSESSMENT_EXISTS}") + assert response.json()["count"] == 2, "Lead count should be 2 for lead with assessment" # Assessment does not exist filter test - response = self.client.get(f'{url}&exists={LeadFilterSet.Exists.ASSESSMENT_DOES_NOT_EXIST}') - assert response.json()['count'] == 1, 'Lead count should be 1 for lead without assessment' + response = self.client.get(f"{url}&exists={LeadFilterSet.Exists.ASSESSMENT_DOES_NOT_EXIST}") + assert response.json()["count"] == 1, "Lead count should be 1 for lead without assessment" def test_lead_assignee_filter(self): user1 = self.create_user() @@ -1049,27 +1037,27 @@ def test_lead_assignee_filter(self): self.create_lead(project=project, assignee=[user1, user2]) self.create_lead(project=project, assignee=[user1]) self.create_lead(project=project, assignee=[user2]) - url = f'/api/v1/leads/?assignee={user1.id}' + url = f"/api/v1/leads/?assignee={user1.id}" # authenticate user self.authenticate() # filter by user who is assignee in some leads response = self.client.get(url) - assert len(response.data['results']) == 2 + assert len(response.data["results"]) == 2 # filter by user who is not assignee in any of the lead - url = f'/api/v1/leads/?assignee={user3.id}' + url = f"/api/v1/leads/?assignee={user3.id}" response = self.client.get(url) - assert len(response.data['results']) == 0 + assert len(response.data["results"]) == 0 def test_lead_authoring_organization_type_filter(self): - url = '/api/v1/leads/?authoring_organization_types={}' + url = "/api/v1/leads/?authoring_organization_types={}" - project = self.create_project(title='lead_test_project') - organization_type1 = self.create(OrganizationType, title='National') - organization_type2 = self.create(OrganizationType, title='International') - organization_type3 = self.create(OrganizationType, title='Government') + project = self.create_project(title="lead_test_project") + organization_type1 = self.create(OrganizationType, title="National") + organization_type2 = self.create(OrganizationType, title="International") + organization_type3 = self.create(OrganizationType, title="Government") organization1 = self.create(Organization, organization_type=organization_type1) organization2 = self.create(Organization, organization_type=organization_type2) @@ -1086,125 +1074,121 @@ def test_lead_authoring_organization_type_filter(self): # Authoring organization_type filter test response = self.client.get(url.format(organization_type1.id)) self.assert_200(response) - assert len(response.data['results']) == 2, 'There should be 2 lead' + assert len(response.data["results"]) == 2, "There should be 2 lead" # get multiple leads - organization_type_query = ','.join([ - str(id) for id in [organization_type1.id, organization_type3.id] - ]) + organization_type_query = ",".join([str(id) for id in [organization_type1.id, organization_type3.id]]) response = self.client.get(url.format(organization_type_query)) - assert len(response.data['results']) == 3, 'There should be 3 lead' + assert len(response.data["results"]) == 3, "There should be 3 lead" # test authoring_organization post filter - url = '/api/v1/leads/filter/' - filter_data = {'authoring_organization_types': [organization_type1.id]} + url = "/api/v1/leads/filter/" + filter_data = {"authoring_organization_types": [organization_type1.id]} self.authenticate() - response = self.client.post(url, data=filter_data, format='json') - assert len(response.data['results']) == 2, 'There should be 2 lead' + response = self.client.post(url, data=filter_data, format="json") + assert len(response.data["results"]) == 2, "There should be 2 lead" # test multiple post - filter_data = {'authoring_organization_types': [organization_type1.id, organization_type3.id]} + filter_data = {"authoring_organization_types": [organization_type1.id, organization_type3.id]} self.authenticate() - response = self.client.post(url, filter_data, format='json') - assert len(response.data['results']) == 3, 'There should be 3 lead' + response = self.client.post(url, filter_data, format="json") + assert len(response.data["results"]) == 3, "There should be 3 lead" def test_lead_filter_search(self): - url = '/api/v1/leads/?emm_entities={}' + url = "/api/v1/leads/?emm_entities={}" project = self.create_project() - lead1 = self.create(Lead, project=project, title='mytext') - lead2 = self.create(Lead, project=project, source_raw='thisis_mytext') + lead1 = self.create(Lead, project=project, title="mytext") + lead2 = self.create(Lead, project=project, source_raw="thisis_mytext") self.create(Lead, project=project) - self.create(Lead, project=project, title='nothing_here') + self.create(Lead, project=project, title="nothing_here") - url = '/api/v1/leads/?search={}' + url = "/api/v1/leads/?search={}" self.authenticate() - resp = self.client.get(url.format('mytext')) + resp = self.client.get(url.format("mytext")) self.assert_200(resp) expected_ids = {lead1.id, lead2.id} - obtained_ids = {x['id'] for x in resp.data['results']} + obtained_ids = {x["id"] for x in resp.data["results"]} assert expected_ids == obtained_ids - url = '/api/v1/leads/filter/' - post_data = {'search': 'mytext'} + url = "/api/v1/leads/filter/" + post_data = {"search": "mytext"} self.authenticate() resp = self.client.post(url, post_data) self.assert_200(resp) - obtained_ids = {x['id'] for x in resp.data['results']} + obtained_ids = {x["id"] for x in resp.data["results"]} assert expected_ids == obtained_ids def test_lead_filter_with_entries_filter(self): project = self.create_project() - lead1 = self.create(Lead, project=project, title='mytext') - lead2 = self.create(Lead, project=project, source_raw='thisis_mytext') + lead1 = self.create(Lead, project=project, title="mytext") + lead2 = self.create(Lead, project=project, source_raw="thisis_mytext") lead3 = self.create(Lead, project=project) - url = '/api/v1/leads/filter/' + url = "/api/v1/leads/filter/" post_data = {} self.authenticate() response = self.client.post(url, post_data) - assert response.json()['count'] == 3 + assert response.json()["count"] == 3 - post_data = {'custom_filters': LeadFilterSet.CustomFilter.EXCLUDE_EMPTY_FILTERED_ENTRIES} + post_data = {"custom_filters": LeadFilterSet.CustomFilter.EXCLUDE_EMPTY_FILTERED_ENTRIES} response = self.client.post(url, post_data) - assert response.json()['count'] == 0, 'There are not supposed to be leads with entries' + assert response.json()["count"] == 0, "There are not supposed to be leads with entries" entry1 = self.create(Entry, project=project, lead=lead1, controlled=True, entry_type=Entry.TagType.EXCERPT) self.create(Entry, project=project, lead=lead1, controlled=True, entry_type=Entry.TagType.EXCERPT) - post_data = {'entries_filter': [('controlled', True)]} + post_data = {"entries_filter": [("controlled", True)]} response = self.client.post(url, post_data) - assert response.json()['count'] == 3 - assert set([each['filteredEntriesCount'] for each in response.json()['results']]) \ - == set([0, 0, 2]), \ - response.json() + assert response.json()["count"] == 3 + assert set([each["filteredEntriesCount"] for each in response.json()["results"]]) == set([0, 0, 2]), response.json() entry2 = self.create(Entry, project=project, lead=lead2, controlled=False, entry_type=Entry.TagType.IMAGE) self.create(Entry, project=project, lead=lead3, controlled=False, entry_type=Entry.TagType.DATA_SERIES) post_data = { - 'custom_filters': LeadFilterSet.CustomFilter.EXCLUDE_EMPTY_FILTERED_ENTRIES, - 'entries_filter': [('controlled', True)] + "custom_filters": LeadFilterSet.CustomFilter.EXCLUDE_EMPTY_FILTERED_ENTRIES, + "entries_filter": [("controlled", True)], } response = self.client.post(url, post_data) - assert response.json()['count'] == 1 - assert response.data['results'][0]['id'] == lead1.id, response.data - assert response.json()['results'][0]['filteredEntriesCount'] == 2, response.json() + assert response.json()["count"] == 1 + assert response.data["results"][0]["id"] == lead1.id, response.data + assert response.json()["results"][0]["filteredEntriesCount"] == 2, response.json() - post_data['entries_filter'] = [] - post_data['entries_filter'].append(('entry_type', [Entry.TagType.EXCERPT, Entry.TagType.IMAGE])) + post_data["entries_filter"] = [] + post_data["entries_filter"].append(("entry_type", [Entry.TagType.EXCERPT, Entry.TagType.IMAGE])) response = self.client.post(url, post_data) - self.assertEqual(response.json()['count'], 2, response.json()) + self.assertEqual(response.json()["count"], 2, response.json()) # there should be 1 image entry and 2 excerpt entries - assert set([1, 2]) == set([item['filteredEntriesCount'] for item in response.json()['results']]), response.json() + assert set([1, 2]) == set([item["filteredEntriesCount"] for item in response.json()["results"]]), response.json() # filter by project_entry_labels # Labels - label1 = self.create(ProjectEntryLabel, project=project, title='Label 1', order=1, color='#23f23a') - label2 = self.create(ProjectEntryLabel, project=project, title='Label 2', order=2, color='#23f23a') - self.create(ProjectEntryLabel, project=project, title='Label 3', order=3, color='#23f23a') + label1 = self.create(ProjectEntryLabel, project=project, title="Label 1", order=1, color="#23f23a") + label2 = self.create(ProjectEntryLabel, project=project, title="Label 2", order=2, color="#23f23a") + self.create(ProjectEntryLabel, project=project, title="Label 3", order=3, color="#23f23a") # Groups - group11 = self.create(LeadEntryGroup, lead=lead1, title='Group 1', order=1) - group12 = self.create(LeadEntryGroup, lead=lead1, title='Group 2', order=2) - group21 = self.create(LeadEntryGroup, lead=lead2, title='Group 2', order=2) + group11 = self.create(LeadEntryGroup, lead=lead1, title="Group 1", order=1) + group12 = self.create(LeadEntryGroup, lead=lead1, title="Group 2", order=2) + group21 = self.create(LeadEntryGroup, lead=lead2, title="Group 2", order=2) self.create(EntryGroupLabel, group=group11, label=label1, entry=entry1) self.create(EntryGroupLabel, group=group12, label=label2, entry=entry1) self.create(EntryGroupLabel, group=group21, label=label2, entry=entry2) - post_data['entries_filter'] = [] - post_data['entries_filter'].append(('project_entry_labels', [label1.id])) + post_data["entries_filter"] = [] + post_data["entries_filter"].append(("project_entry_labels", [label1.id])) response = self.client.post(url, post_data) - self.assertEqual(response.json()['count'], 1, response.json()) - assert response.json()['results'][0]['filteredEntriesCount'] == 1, response.json() + self.assertEqual(response.json()["count"], 1, response.json()) + assert response.json()["results"][0]["filteredEntriesCount"] == 1, response.json() - post_data['entries_filter'] = [] - post_data['entries_filter'].append(('project_entry_labels', [label1.id, label2.id])) + post_data["entries_filter"] = [] + post_data["entries_filter"].append(("project_entry_labels", [label1.id, label2.id])) response = self.client.post(url, post_data) - self.assertEqual(response.json()['count'], 2, response.json()) + self.assertEqual(response.json()["count"], 2, response.json()) # lead1 has 1 label1+label2 entries # lead2 has 1 label2 entries - assert [1, 1] == [item['filteredEntriesCount'] for item in response.json()['results']], response.json() + assert [1, 1] == [item["filteredEntriesCount"] for item in response.json()["results"]], response.json() def test_filtered_lead_list_with_controlled_entries_count(self): project = self.create_project() @@ -1215,57 +1199,57 @@ def test_filtered_lead_list_with_controlled_entries_count(self): self.create_entry(lead=lead, project=project, controlled=False) self.create_entry(lead=lead2, project=project, controlled=True) - url = '/api/v1/leads/filter/' + url = "/api/v1/leads/filter/" self.authenticate() resp = self.client.post(url, dict()) self.assert_200(resp) - counts = [x['controlled_entries_count'] for x in resp.data['results']] + counts = [x["controlled_entries_count"] for x in resp.data["results"]] self.assertEqual(counts, [1, 1]) def test_lead_filter_search_by_author(self): project = self.create_project() - author = self.create(Organization, title='blablaone') - author2 = self.create(Organization, title='blablatwo') + author = self.create(Organization, title="blablaone") + author2 = self.create(Organization, title="blablatwo") - lead1 = self.create(Lead, project=project, author=author, author_raw='wood') + lead1 = self.create(Lead, project=project, author=author, author_raw="wood") lead2 = self.create(Lead, project=project, author=author2) self.create(Lead, project=project, author=None) lead = self.create(Lead, project=project) lead.authors.set([author, author2]) - url = '/api/v1/leads/filter/' - post_data = {'search': 'blablaone'} + url = "/api/v1/leads/filter/" + post_data = {"search": "blablaone"} expected_ids = {lead1.id, lead.id} self.authenticate() resp = self.client.post(url, post_data) self.assert_200(resp) - obtained_ids = {x['id'] for x in resp.data['results']} + obtained_ids = {x["id"] for x in resp.data["results"]} assert expected_ids == obtained_ids - post_data = {'search': 'blablatwo'} + post_data = {"search": "blablatwo"} expected_ids = {lead2.id, lead.id} resp = self.client.post(url, post_data) self.assert_200(resp) - obtained_ids = {x['id'] for x in resp.data['results']} + obtained_ids = {x["id"] for x in resp.data["results"]} assert expected_ids == obtained_ids - post_data = {'search': 'wood'} + post_data = {"search": "wood"} expected_ids = {lead1.id} resp = self.client.post(url, post_data) self.assert_200(resp) - obtained_ids = {x['id'] for x in resp.data['results']} + obtained_ids = {x["id"] for x in resp.data["results"]} assert expected_ids == obtained_ids def test_lead_filter_emm_entities(self): - url = '/api/v1/leads/?emm_entities={}' + url = "/api/v1/leads/?emm_entities={}" project = self.create_project() # Create Entities - entity1 = self.create(EMMEntity, name='entity1') - entity2 = self.create(EMMEntity, name='entity2') - entity3 = self.create(EMMEntity, name='entity3') - entity4 = self.create(EMMEntity, name='entity4') # noqa:F841 + entity1 = self.create(EMMEntity, name="entity1") + entity2 = self.create(EMMEntity, name="entity2") + entity3 = self.create(EMMEntity, name="entity3") + entity4 = self.create(EMMEntity, name="entity4") # noqa:F841 lead1 = self.create_lead(project=project, emm_entities=[entity1]) lead2 = self.create_lead(project=project, emm_entities=[entity2, entity3]) @@ -1274,8 +1258,8 @@ def test_lead_filter_emm_entities(self): def _test_response(resp): self.assert_200(resp) - assert len(resp.data['results']) == 3, 'There should be three leads' - ids_list = [x['id']for x in resp.data['results']] + assert len(resp.data["results"]) == 3, "There should be three leads" + ids_list = [x["id"] for x in resp.data["results"]] assert lead1.id in ids_list assert lead2.id in ids_list assert lead3.id in ids_list @@ -1285,30 +1269,30 @@ def _test_response(resp): # Get a single lead resp = self.client.get(url.format(entity1.id)) self.assert_200(resp) - assert len(resp.data['results']) == 1, 'There should be one lead' - assert resp.data['results'][0]['id'] == lead1.id + assert len(resp.data["results"]) == 1, "There should be one lead" + assert resp.data["results"][0]["id"] == lead1.id # Get a multiple leads - entities_query = ','.join([str(entity1.id), str(entity2.id)]) + entities_query = ",".join([str(entity1.id), str(entity2.id)]) resp = self.client.get(url.format(entities_query)) _test_response(resp) # test post filter - url = '/api/v1/leads/filter/' - filter_data = {'emm_entities': [entity1.id]} + url = "/api/v1/leads/filter/" + filter_data = {"emm_entities": [entity1.id]} self.authenticate() - resp = self.client.post(url, filter_data, format='json') - assert len(resp.data['results']) == 1, 'There should be one lead' - assert resp.data['results'][0]['id'] == lead1.id + resp = self.client.post(url, filter_data, format="json") + assert len(resp.data["results"]) == 1, "There should be one lead" + assert resp.data["results"][0]["id"] == lead1.id - filter_data = {'emm_entities': [entity1.id, entity2.id]} + filter_data = {"emm_entities": [entity1.id, entity2.id]} self.authenticate() - resp = self.client.post(url, filter_data, format='json') + resp = self.client.post(url, filter_data, format="json") _test_response(resp) def test_lead_filter_emm_keywords(self): - url = '/api/v1/leads/?emm_keywords={}' + url = "/api/v1/leads/?emm_keywords={}" project = self.create_project() lead1 = self.create_lead(project=project) @@ -1318,44 +1302,56 @@ def test_lead_filter_emm_keywords(self): # Create LeadEMMTrigger objects with self.create( - LeadEMMTrigger, lead=lead1, count=5, - emm_keyword='keyword1', emm_risk_factor='rf1', + LeadEMMTrigger, + lead=lead1, + count=5, + emm_keyword="keyword1", + emm_risk_factor="rf1", ) self.create( - LeadEMMTrigger, lead=lead2, count=3, - emm_keyword='keyword1', emm_risk_factor='rf2', + LeadEMMTrigger, + lead=lead2, + count=3, + emm_keyword="keyword1", + emm_risk_factor="rf2", ) self.create( - LeadEMMTrigger, lead=lead3, count=3, - emm_keyword='keyword3', emm_risk_factor='rf2', + LeadEMMTrigger, + lead=lead3, + count=3, + emm_keyword="keyword3", + emm_risk_factor="rf2", ) self.create( - LeadEMMTrigger, lead=lead4, count=3, - emm_keyword='keyword2', emm_risk_factor='rf1', + LeadEMMTrigger, + lead=lead4, + count=3, + emm_keyword="keyword2", + emm_risk_factor="rf1", ) self.authenticate() # Get a single lead - resp = self.client.get(url.format('keyword1')) + resp = self.client.get(url.format("keyword1")) self.assert_200(resp) - assert len(resp.data['results']) == 2, 'There should be 2 leads' - ids_list = [x['id']for x in resp.data['results']] + assert len(resp.data["results"]) == 2, "There should be 2 leads" + ids_list = [x["id"] for x in resp.data["results"]] assert lead1.id in ids_list assert lead2.id in ids_list # Get multiple leads - entities_query = ','.join(['keyword1', 'keyword2']) + entities_query = ",".join(["keyword1", "keyword2"]) resp = self.client.get(url.format(entities_query)) self.assert_200(resp) - assert len(resp.data['results']) == 3, 'There should be three leads' - ids_list = [x['id']for x in resp.data['results']] + assert len(resp.data["results"]) == 3, "There should be three leads" + ids_list = [x["id"] for x in resp.data["results"]] assert lead1.id in ids_list assert lead2.id in ids_list assert lead4.id in ids_list def test_lead_filter_emm_risk_factors(self): - url = '/api/v1/leads/?emm_risk_factors={}' + url = "/api/v1/leads/?emm_risk_factors={}" project = self.create_project() lead1 = self.create_lead(project=project) @@ -1365,52 +1361,64 @@ def test_lead_filter_emm_risk_factors(self): # Create LeadEMMTrigger objects with self.create( - LeadEMMTrigger, lead=lead1, count=5, - emm_keyword='keyword1', emm_risk_factor='rf1', + LeadEMMTrigger, + lead=lead1, + count=5, + emm_keyword="keyword1", + emm_risk_factor="rf1", ) self.create( - LeadEMMTrigger, lead=lead2, count=3, - emm_keyword='keyword1', emm_risk_factor='rf2', + LeadEMMTrigger, + lead=lead2, + count=3, + emm_keyword="keyword1", + emm_risk_factor="rf2", ) self.create( - LeadEMMTrigger, lead=lead3, count=3, - emm_keyword='keyword3', emm_risk_factor='rf2', + LeadEMMTrigger, + lead=lead3, + count=3, + emm_keyword="keyword3", + emm_risk_factor="rf2", ) self.create( - LeadEMMTrigger, lead=lead4, count=3, - emm_keyword='keyword2', emm_risk_factor='rf1', + LeadEMMTrigger, + lead=lead4, + count=3, + emm_keyword="keyword2", + emm_risk_factor="rf1", ) self.authenticate() # Get a single lead - resp = self.client.get(url.format('rf1')) + resp = self.client.get(url.format("rf1")) self.assert_200(resp) - assert len(resp.data['results']) == 2, 'There should be 2 leads' - ids_list = [x['id']for x in resp.data['results']] + assert len(resp.data["results"]) == 2, "There should be 2 leads" + ids_list = [x["id"] for x in resp.data["results"]] assert lead1.id in ids_list assert lead4.id in ids_list # Get multiple leads - entities_query = ','.join(['rf1', 'rf2']) + entities_query = ",".join(["rf1", "rf2"]) resp = self.client.get(url.format(entities_query)) self.assert_200(resp) - assert len(resp.data['results']) == 4, 'There should be four leads' - ids_list = [x['id'] for x in resp.data['results']] + assert len(resp.data["results"]) == 4, "There should be four leads" + ids_list = [x["id"] for x in resp.data["results"]] assert lead1.id in ids_list assert lead2.id in ids_list assert lead3.id in ids_list assert lead4.id in ids_list def test_get_emm_extra_with_emm_entities_filter(self): - url = '/api/v1/leads/emm-summary/?emm_entities={}' + url = "/api/v1/leads/emm-summary/?emm_entities={}" project = self.create_project() # Create Entities - entity1 = self.create(EMMEntity, name='entity1') - entity2 = self.create(EMMEntity, name='entity2') - entity3 = self.create(EMMEntity, name='entity3') - entity4 = self.create(EMMEntity, name='entity4') # noqa:F841 + entity1 = self.create(EMMEntity, name="entity1") + entity2 = self.create(EMMEntity, name="entity2") + entity3 = self.create(EMMEntity, name="entity3") + entity4 = self.create(EMMEntity, name="entity4") # noqa:F841 self.create_lead(project=project, emm_entities=[entity1]) self.create_lead(project=project, emm_entities=[entity2, entity3]) @@ -1420,32 +1428,32 @@ def test_get_emm_extra_with_emm_entities_filter(self): # Test get filter self.authenticate() # Get a single lead - entities_query = ','.join([str(entity1.id), str(entity2.id)]) + entities_query = ",".join([str(entity1.id), str(entity2.id)]) resp = self.client.get(url.format(entities_query)) self.assert_200(resp) extra = resp.data - assert 'emm_entities' in extra - assert 'emm_triggers' in extra + assert "emm_entities" in extra + assert "emm_triggers" in extra - expected_entities_counts = {('entity1', 1), ('entity2', 2), ('entity3', 1)} - result_entities_counts = {(x['name'], x['total_count']) for x in extra['emm_entities']} + expected_entities_counts = {("entity1", 1), ("entity2", 2), ("entity3", 1)} + result_entities_counts = {(x["name"], x["total_count"]) for x in extra["emm_entities"]} assert expected_entities_counts == result_entities_counts # TODO: test post - filter_data = {'emm_entities': [entity1.id, entity2.id]} - url = '/api/v1/leads/emm-summary/' + filter_data = {"emm_entities": [entity1.id, entity2.id]} + url = "/api/v1/leads/emm-summary/" self.authenticate() - self.client.post(url, data=filter_data, format='json') + self.client.post(url, data=filter_data, format="json") self.assert_200(resp) extra = resp.data - assert 'emm_entities' in extra - assert 'emm_triggers' in extra + assert "emm_entities" in extra + assert "emm_triggers" in extra - expected_entities_counts = {('entity1', 1), ('entity2', 2), ('entity3', 1)} - result_entities_counts = {(x['name'], x['total_count']) for x in extra['emm_entities']} + expected_entities_counts = {("entity1", 1), ("entity2", 2), ("entity3", 1)} + result_entities_counts = {(x["name"], x["total_count"]) for x in extra["emm_entities"]} assert expected_entities_counts == result_entities_counts def test_get_emm_extra_with_emm_keywords_filter(self): @@ -1458,37 +1466,46 @@ def test_get_emm_extra_with_emm_keywords_filter(self): # Create LeadEMMTrigger objects with self.create( - LeadEMMTrigger, lead=lead1, count=5, - emm_keyword='keyword1', emm_risk_factor='rf1', + LeadEMMTrigger, + lead=lead1, + count=5, + emm_keyword="keyword1", + emm_risk_factor="rf1", ) self.create( - LeadEMMTrigger, lead=lead2, count=3, - emm_keyword='keyword1', emm_risk_factor='rf1', + LeadEMMTrigger, + lead=lead2, + count=3, + emm_keyword="keyword1", + emm_risk_factor="rf1", ) self.create( - LeadEMMTrigger, lead=lead3, count=3, - emm_keyword='keyword3', emm_risk_factor='rf2', + LeadEMMTrigger, + lead=lead3, + count=3, + emm_keyword="keyword3", + emm_risk_factor="rf2", ) self.create( - LeadEMMTrigger, lead=lead4, count=3, - emm_keyword='keyword2', emm_risk_factor='rf2', + LeadEMMTrigger, + lead=lead4, + count=3, + emm_keyword="keyword2", + emm_risk_factor="rf2", ) # Test GET - url = '/api/v1/leads/emm-summary/?emm_keywords=keyword1,keyword2' + url = "/api/v1/leads/emm-summary/?emm_keywords=keyword1,keyword2" self.authenticate() resp = self.client.get(url) self.assert_200(resp) data = resp.data - assert 'emm_entities' in data - assert 'emm_triggers' in data + assert "emm_entities" in data + assert "emm_triggers" in data - expected_triggers = {('keyword1', 'rf1', 8), ('keyword2', 'rf2', 3)} - result_triggers = { - (x['emm_keyword'], x['emm_risk_factor'], x['total_count']) - for x in data['emm_triggers'] - } + expected_triggers = {("keyword1", "rf1", 8), ("keyword2", "rf2", 3)} + result_triggers = {(x["emm_keyword"], x["emm_risk_factor"], x["total_count"]) for x in data["emm_triggers"]} assert expected_triggers == result_triggers def test_lead_summary_get(self): @@ -1498,17 +1515,17 @@ def test_lead_summary_get(self): self.create_entry(lead=lead1, controlled=True) self.create_entry(lead=lead2) - url = '/api/v1/leads/summary/' + url = "/api/v1/leads/summary/" self.authenticate() resp = self.client.get(url) self.assert_200(resp) - self.assertEqual(resp.data['total'], 2) - self.assertEqual(resp.data['total_entries'], 3) - self.assertEqual(resp.data['total_controlled_entries'], 1) - self.assertEqual(resp.data['total_uncontrolled_entries'], 2) - assert 'emm_entities' in resp.data - assert 'emm_triggers' in resp.data + self.assertEqual(resp.data["total"], 2) + self.assertEqual(resp.data["total_entries"], 3) + self.assertEqual(resp.data["total_controlled_entries"], 1) + self.assertEqual(resp.data["total_uncontrolled_entries"], 2) + assert "emm_entities" in resp.data + assert "emm_triggers" in resp.data def test_lead_summary_post(self): lead1 = self.create_lead() @@ -1517,69 +1534,63 @@ def test_lead_summary_post(self): self.create_entry(lead=lead1, controlled=True) self.create_entry(lead=lead2) - url = '/api/v1/leads/summary/' + url = "/api/v1/leads/summary/" self.authenticate() - resp = self.client.post(url, data={}, format='json') + resp = self.client.post(url, data={}, format="json") self.assert_200(resp) - self.assertEqual(resp.data['total'], 2) - self.assertEqual(resp.data['total_entries'], 3) - self.assertEqual(resp.data['total_controlled_entries'], 1) - self.assertEqual(resp.data['total_uncontrolled_entries'], 2) - assert 'emm_entities' in resp.data - assert 'emm_triggers' in resp.data + self.assertEqual(resp.data["total"], 2) + self.assertEqual(resp.data["total_entries"], 3) + self.assertEqual(resp.data["total_controlled_entries"], 1) + self.assertEqual(resp.data["total_uncontrolled_entries"], 2) + assert "emm_entities" in resp.data + assert "emm_triggers" in resp.data def test_lead_group_post(self): project = self.create_project() - data = { - 'project': project.id, - 'title': 'Test Lead Group Title' - } - url = '/api/v1/lead-groups/' + data = {"project": project.id, "title": "Test Lead Group Title"} + url = "/api/v1/lead-groups/" self.authenticate() response = self.client.post(url, data) self.assert_201(response) - self.assertEqual(response.data['title'], data['title']) - self.assertEqual(response.data['project'], data['project']) + self.assertEqual(response.data["title"], data["title"]) + self.assertEqual(response.data["project"], data["project"]) def test_lead_group_get(self): project_1 = self.create_project() project_2 = self.create_project() - leadg_1 = self.create(LeadGroup, project=project_1, title='test1') - leadg_2 = self.create(LeadGroup, project=project_1, title='test2') - self.create(LeadGroup, project=project_2, title='test3') - self.create(LeadGroup, project=project_2, title='test4') + leadg_1 = self.create(LeadGroup, project=project_1, title="test1") + leadg_2 = self.create(LeadGroup, project=project_1, title="test2") + self.create(LeadGroup, project=project_2, title="test3") + self.create(LeadGroup, project=project_2, title="test4") - url = '/api/v1/lead-groups/' + url = "/api/v1/lead-groups/" self.authenticate() response = self.client.get(url) self.assert_200(response) self.assertEqual(len(response.data), 4) # test for project filter - url = f'/api/v1/lead-groups/?project={project_1.id}' + url = f"/api/v1/lead-groups/?project={project_1.id}" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 2) - self.assertEqual( - set(lg['id'] for lg in response.data['results']), - set([leadg_1.id, leadg_2.id]) - ) + self.assertEqual(len(response.data["results"]), 2) + self.assertEqual(set(lg["id"] for lg in response.data["results"]), set([leadg_1.id, leadg_2.id])) # test for the search field `title` - url = '/api/v1/lead-groups/?search=test1' + url = "/api/v1/lead-groups/?search=test1" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 1) - self.assertEqual(response.data['results'][0]['id'], leadg_1.id) + self.assertEqual(len(response.data["results"]), 1) + self.assertEqual(response.data["results"][0]["id"], leadg_1.id) def test_authors_and_authoring_organization_type(self): - project = self.create_project(title='lead_test_project') - organization_type1 = self.create(OrganizationType, title='National') - organization_type2 = self.create(OrganizationType, title='International') - organization_type3 = self.create(OrganizationType, title='Government') + project = self.create_project(title="lead_test_project") + organization_type1 = self.create(OrganizationType, title="National") + organization_type2 = self.create(OrganizationType, title="International") + organization_type3 = self.create(OrganizationType, title="Government") organization1 = self.create(Organization, organization_type=organization_type1) organization2 = self.create(Organization, organization_type=organization_type2) @@ -1598,90 +1609,66 @@ def test_authors_and_authoring_organization_type(self): lead4 = self.create_lead(project=project, authors=[organization5, organization1]) lead5 = self.create_lead(project=project, authors=[organization6]) # test for authors - self.assertEqual( - set(lead1.get_authors_display()), - set(','.join([organization2.title, organization3.title])) - ) - self.assertEqual( - lead2.get_authors_display(), - organization1.title # organization1 since lead2 authors have parent - ) - self.assertEqual( - lead3.get_authors_display(), - organization1.title - ) - self.assertEqual( - set(lead4.get_authors_display()), - set(','.join([organization1.title, organization5.title])) - ) + self.assertEqual(set(lead1.get_authors_display()), set(",".join([organization2.title, organization3.title]))) + self.assertEqual(lead2.get_authors_display(), organization1.title) # organization1 since lead2 authors have parent + self.assertEqual(lead3.get_authors_display(), organization1.title) + self.assertEqual(set(lead4.get_authors_display()), set(",".join([organization1.title, organization5.title]))) # test for authoring_oragnizations_type self.assertEqual( set(lead1.get_authoring_organizations_type_display()), - set(','.join([organization_type2.title, organization_type3.title])) + set(",".join([organization_type2.title, organization_type3.title])), ) self.assertEqual( lead2.get_authoring_organizations_type_display(), - organization_type1.title # organization_type1 since lead2 authors have parent - ) - self.assertEqual( - lead3.get_authoring_organizations_type_display(), - organization_type1.title + organization_type1.title, # organization_type1 since lead2 authors have parent ) + self.assertEqual(lead3.get_authoring_organizations_type_display(), organization_type1.title) self.assertEqual( lead4.get_authoring_organizations_type_display(), - organization_type1.title # organization_type1 since both authors have same organization_type - ) - self.assertEqual( - lead5.get_authoring_organizations_type_display(), - '' + organization_type1.title, # organization_type1 since both authors have same organization_type ) + self.assertEqual(lead5.get_authoring_organizations_type_display(), "") def test_is_assessment_lead(self): project = self.create_project() - lead = self.create( - Lead, - project=project, - is_assessment_lead=True - ) + lead = self.create(Lead, project=project, is_assessment_lead=True) - url = f'/api/v1/leads/{lead.id}/' + url = f"/api/v1/leads/{lead.id}/" self.authenticate() response = self.client.get(url) self.assert_200(response) # now create Assessment for the lead self.create(Assessment, lead=lead) - data = { - 'is_assessment_lead': False - } - url = f'/api/v1/leads/{lead.id}/' + data = {"is_assessment_lead": False} + url = f"/api/v1/leads/{lead.id}/" self.authenticate() response = self.client.patch(url, data) self.assert_400(response) self.assertEqual( - response.data['errors']['is_assessment_lead'], - [ErrorDetail(string='Lead already has an assessment.', code='invalid')] + response.data["errors"]["is_assessment_lead"], [ErrorDetail(string="Lead already has an assessment.", code="invalid")] ) # here delete the assessment that has lead Assessment.objects.filter(lead=lead).delete() - data = { - 'is_assessment_lead': False - } - url = f'/api/v1/leads/{lead.id}/' + data = {"is_assessment_lead": False} + url = f"/api/v1/leads/{lead.id}/" self.authenticate() response = self.client.patch(url, data) self.assert_200(response) + # Data to use for testing web info extractor # Including, url of the page and its attributes: # source, country, date, website -SAMPLE_WEB_INFO_URL = 'https://reliefweb.int/report/yemen/yemen-emergency-food-security-and-nutrition-assessment-efsna-2016-preliminary-results' # noqa -SAMPLE_WEB_INFO_SOURCE = 'World Food Programme, UN Children\'s Fund, Food and Agriculture Organization of the United Nations' # noqa -SAMPLE_WEB_INFO_COUNTRY = 'Yemen' +SAMPLE_WEB_INFO_URL = "https://reliefweb.int/report/yemen/yemen-emergency-food-security-and-nutrition-assessment-efsna-2016-preliminary-results" # noqa +SAMPLE_WEB_INFO_SOURCE = ( + "World Food Programme, UN Children's Fund, Food and Agriculture Organization of the United Nations" # noqa +) +SAMPLE_WEB_INFO_COUNTRY = "Yemen" SAMPLE_WEB_INFO_DATE = date(2017, 1, 26) -SAMPLE_WEB_INFO_TITLE = 'Yemen Emergency Food Security and Nutrition Assessment (EFSNA) 2016 - Preliminary Results' # noqa +SAMPLE_WEB_INFO_TITLE = "Yemen Emergency Food Security and Nutrition Assessment (EFSNA) 2016 - Preliminary Results" # noqa class WebInfoExtractionTests(TestCase): @@ -1693,41 +1680,39 @@ def setUp(self): self.unhcr = self.create(Organization, **UNHCR_DATA) def test_redhum(self): - url = '/api/v1/web-info-extract/' + url = "/api/v1/web-info-extract/" data = { - 'url': 'https://redhum.org/documento/3227553', + "url": "https://redhum.org/documento/3227553", } try: self.authenticate() response = self.client.post(url, data) rdata = self.client.post(url, data).data self.assert_200(response) - self.assertEqual(rdata['title'], 'Pregnant women flee lack of maternal health care in Venezuela') - self.assertEqual(rdata['date'], '2019-07-23') - self.assertEqual(rdata['country'], 'Colombia') - self.assertEqual(rdata['url'], data['url']) - self.assertEqual(rdata['source_raw'], 'redhum') - self.assertEqual(rdata['author_raw'], 'United Nations High Commissioner for Refugees') - self.assertEqual(rdata['source'], SimpleOrganizationSerializer(self.redhum).data) - self.assertEqual(rdata['author'], SimpleOrganizationSerializer(self.unhcr).data) + self.assertEqual(rdata["title"], "Pregnant women flee lack of maternal health care in Venezuela") + self.assertEqual(rdata["date"], "2019-07-23") + self.assertEqual(rdata["country"], "Colombia") + self.assertEqual(rdata["url"], data["url"]) + self.assertEqual(rdata["source_raw"], "redhum") + self.assertEqual(rdata["author_raw"], "United Nations High Commissioner for Refugees") + self.assertEqual(rdata["source"], SimpleOrganizationSerializer(self.redhum).data) + self.assertEqual(rdata["author"], SimpleOrganizationSerializer(self.unhcr).data) except Exception: import traceback - logger.warning('\n' + ('*' * 30)) - logger.warning('LEAD WEB INFO EXTRACTION ERROR:') + + logger.warning("\n" + ("*" * 30)) + logger.warning("LEAD WEB INFO EXTRACTION ERROR:") logger.warning(traceback.format_exc()) return def test_extract_web_info(self): # Create a sample project containing the sample country - sample_region = self.create(Region, title=SAMPLE_WEB_INFO_COUNTRY, - public=True) + sample_region = self.create(Region, title=SAMPLE_WEB_INFO_COUNTRY, public=True) sample_project = self.create(Project, role=self.admin_role) sample_project.regions.add(sample_region) - url = '/api/v1/web-info-extract/' - data = { - 'url': SAMPLE_WEB_INFO_URL - } + url = "/api/v1/web-info-extract/" + data = {"url": SAMPLE_WEB_INFO_URL} try: self.authenticate() @@ -1735,37 +1720,34 @@ def test_extract_web_info(self): self.assert_200(response) except Exception: import traceback - logger.warning('\n' + ('*' * 30)) - logger.warning('LEAD WEB INFO EXTRACTION ERROR:') + + logger.warning("\n" + ("*" * 30)) + logger.warning("LEAD WEB INFO EXTRACTION ERROR:") logger.warning(traceback.format_exc()) return expected = { - 'project': sample_project.id, - 'date': SAMPLE_WEB_INFO_DATE, - 'country': SAMPLE_WEB_INFO_COUNTRY, - 'title': SAMPLE_WEB_INFO_TITLE, - 'url': SAMPLE_WEB_INFO_URL, - 'source': SimpleOrganizationSerializer(self.reliefweb).data, - 'source_raw': 'reliefweb', - 'author': None, - 'author_raw': SAMPLE_WEB_INFO_SOURCE, - 'existing': False, + "project": sample_project.id, + "date": SAMPLE_WEB_INFO_DATE, + "country": SAMPLE_WEB_INFO_COUNTRY, + "title": SAMPLE_WEB_INFO_TITLE, + "url": SAMPLE_WEB_INFO_URL, + "source": SimpleOrganizationSerializer(self.reliefweb).data, + "source_raw": "reliefweb", + "author": None, + "author_raw": SAMPLE_WEB_INFO_SOURCE, + "existing": False, } self.assertEqualWithWarning(expected, response.data) class WebInfoDataTestCase(TestCase): def test_relief_web(self): - self.create(Organization, title='Organization 1') - self.create(Organization, title='Organization 2') + self.create(Organization, title="Organization 1") + self.create(Organization, title="Organization 2") - url = '/api/v1/web-info-data/' - data = { - 'url': SAMPLE_WEB_INFO_URL, - 'authors_raw': ['Organization1', 'Organization2'], - 'source_raw': 'Organization1' - } + url = "/api/v1/web-info-data/" + data = {"url": SAMPLE_WEB_INFO_URL, "authors_raw": ["Organization1", "Organization2"], "source_raw": "Organization1"} self.authenticate() response = self.client.post(url, data) self.assert_200(response) @@ -1776,18 +1758,16 @@ def setUp(self): super().setUp() self.lead = LeadFactory.create() - @mock.patch('lead.serializers.index_lead_and_calculate_duplicates.delay') - @mock.patch('deepl_integration.handlers.RequestHelper.get_text') - @mock.patch('deepl_integration.handlers.RequestHelper.get_file') + @mock.patch("lead.serializers.index_lead_and_calculate_duplicates.delay") + @mock.patch("deepl_integration.handlers.RequestHelper.get_text") + @mock.patch("deepl_integration.handlers.RequestHelper.get_file") def test_extractor_callback_url(self, get_file_mock, get_text_mock, index_lead_func): - url = '/api/v1/callback/lead-extract/' + url = "/api/v1/callback/lead-extract/" self.authenticate() - image = SimpleUploadedFile( - name='test_image.jpg', content=b'', content_type='image/jpeg' - ) + image = SimpleUploadedFile(name="test_image.jpg", content=b"", content_type="image/jpeg") get_file_mock.return_value = image - get_text_mock.return_value = 'Extracted text' + get_text_mock.return_value = "Extracted text" # Before callback lead_preview = LeadPreview.objects.filter(lead=self.lead).last() @@ -1796,14 +1776,14 @@ def test_extractor_callback_url(self, get_file_mock, get_text_mock, index_lead_f self.assertEqual(images_count, 0) data = { - 'client_id': LeadExtractionHandler.get_client_id(self.lead), - 'images_path': ['http://random.com/image1.jpeg', 'http://random.com/image1.jpeg'], - 'text_path': 'http://random.com/extracted_file.txt', - 'url': 'http://random.com/pdf_file.pdf', - 'total_words_count': 300, - 'total_pages': 4, - 'status': DeeplServerBaseCallbackSerializer.Status.FAILED.value, - 'text_extraction_id': '00431349-5879-4d59-9827-0b12491c4baa' + "client_id": LeadExtractionHandler.get_client_id(self.lead), + "images_path": ["http://random.com/image1.jpeg", "http://random.com/image1.jpeg"], + "text_path": "http://random.com/extracted_file.txt", + "url": "http://random.com/pdf_file.pdf", + "total_words_count": 300, + "total_pages": 4, + "status": DeeplServerBaseCallbackSerializer.Status.FAILED.value, + "text_extraction_id": "00431349-5879-4d59-9827-0b12491c4baa", } # After callback [Failure] @@ -1814,7 +1794,7 @@ def test_extractor_callback_url(self, get_file_mock, get_text_mock, index_lead_f self.assertEqual(LeadPreview.objects.filter(lead=self.lead).count(), 0) self.assertEqual(LeadPreviewImage.objects.filter(lead=self.lead).count(), 0) - data['status'] = DeeplServerBaseCallbackSerializer.Status.SUCCESS.value + data["status"] = DeeplServerBaseCallbackSerializer.Status.SUCCESS.value # After callback [Success] with self.captureOnCommitCallbacks(execute=True): response = self.client.post(url, data) @@ -1823,7 +1803,7 @@ def test_extractor_callback_url(self, get_file_mock, get_text_mock, index_lead_f self.assertEqual(self.lead.extraction_status, Lead.ExtractionStatus.SUCCESS) self.assertEqual(LeadPreview.objects.filter(lead=self.lead).count(), 1) lead_preview = LeadPreview.objects.filter(lead=self.lead).last() - self.assertEqual(lead_preview.text_extract, 'Extracted text') + self.assertEqual(lead_preview.text_extract, "Extracted text") self.assertEqual(lead_preview.word_count, 300) self.assertEqual(lead_preview.page_count, 4) self.assertEqual(LeadPreviewImage.objects.filter(lead=self.lead).count(), 2) @@ -1841,13 +1821,13 @@ def test_client_id_generator(self): (lead1, lead1_client_id, None), ( lead1, - f'{UidBase64Helper.encode(lead1.pk)}-some-random-id', + f"{UidBase64Helper.encode(lead1.pk)}-some-random-id", LeadExtractionHandler.Exception.InvalidOrExpiredToken, ), - (lead1, '11-some-random-id', LeadExtractionHandler.Exception.InvalidTokenValue), - (lead1, 'some-random-id', LeadExtractionHandler.Exception.InvalidTokenValue), + (lead1, "11-some-random-id", LeadExtractionHandler.Exception.InvalidTokenValue), + (lead1, "some-random-id", LeadExtractionHandler.Exception.InvalidTokenValue), (lead2, lead2_client_id, LeadExtractionHandler.Exception.ObjectNotFound), - (lead2, 'somerandomid', LeadExtractionHandler.Exception.InvalidTokenValue), + (lead2, "somerandomid", LeadExtractionHandler.Exception.InvalidTokenValue), ]: if excepted_exception: with self.assertRaises(excepted_exception): @@ -1862,9 +1842,9 @@ def setUp(self): self.lead = LeadFactory.create() self.lead_preview = LeadPreviewFactory.create(lead=self.lead, text_extraction_id=str(uuid.uuid1())) - @mock.patch('deepl_integration.handlers.RequestHelper.json') + @mock.patch("deepl_integration.handlers.RequestHelper.json") def test_entry_extraction_callback_url(self, get_json_mock): - url = '/api/v1/callback/auto-assisted-tagging-draft-entry-prediction/' + url = "/api/v1/callback/auto-assisted-tagging-draft-entry-prediction/" self.authenticate() SAMPLE_AUTO_ASSISTED_TAGGING = { "metadata": {"total_pages": 10, "total_words_count": 5876}, @@ -1897,9 +1877,7 @@ def test_entry_extraction_callback_url(self, get_json_mock): "meta": {"offset_start": 183, "offset_end": 191, "latitude": None, "longitude": None}, } ], - "classification": { - "1": {"101": {"prediction": 2.0000270270529, "threshold": 0.14, "is_selected": True}} - }, + "classification": {"1": {"101": {"prediction": 2.0000270270529, "threshold": 0.14, "is_selected": True}}}, }, { "type": "text", @@ -1921,18 +1899,18 @@ def test_entry_extraction_callback_url(self, get_json_mock): # Invalid clientId data = { - 'client_id': 'invalid-client-id', - 'entry_extraction_classification_path': 'https://random-domain.com/random-url.json', - 'text_extraction_id': str(self.lead_preview.text_extraction_id), - 'status': 1 + "client_id": "invalid-client-id", + "entry_extraction_classification_path": "https://random-domain.com/random-url.json", + "text_extraction_id": str(self.lead_preview.text_extraction_id), + "status": 1, } response = self.client.post(url, data) self.assert_400(response) # valid ClientID - data['client_id'] = AutoAssistedTaggingDraftEntryHandler.get_client_id(self.lead) + data["client_id"] = AutoAssistedTaggingDraftEntryHandler.get_client_id(self.lead) response = self.client.post(url, data) self.assert_200(response) self.lead.refresh_from_db() - self.assertEqual(str(LeadPreview.objects.get(lead=self.lead).text_extraction_id), data['text_extraction_id']) + self.assertEqual(str(LeadPreview.objects.get(lead=self.lead).text_extraction_id), data["text_extraction_id"]) self.assertEqual(self.lead.auto_entry_extraction_status, Lead.AutoExtractionStatus.SUCCESS) diff --git a/apps/lead/tests/test_filters.py b/apps/lead/tests/test_filters.py index f82055d3f7..dcf304c92e 100644 --- a/apps/lead/tests/test_filters.py +++ b/apps/lead/tests/test_filters.py @@ -1,9 +1,9 @@ -from utils.graphene.tests import GraphQLTestCase - -from lead.filter_set import LeadGroupGQFilterSet from lead.factories import LeadGroupFactory +from lead.filter_set import LeadGroupGQFilterSet from project.factories import ProjectFactory +from utils.graphene.tests import GraphQLTestCase + class TestLeadGroupFilter(GraphQLTestCase): def setUp(self) -> None: @@ -12,14 +12,9 @@ def setUp(self) -> None: def test_search_filter(self): project = ProjectFactory.create() - LeadGroupFactory.create(title='one', project=project) - lg2 = LeadGroupFactory.create(title='two', project=project) - lg3 = LeadGroupFactory.create(title='twoo', project=project) - obtained = self.filter_class(data=dict( - search='tw' - )).qs + LeadGroupFactory.create(title="one", project=project) + lg2 = LeadGroupFactory.create(title="two", project=project) + lg3 = LeadGroupFactory.create(title="twoo", project=project) + obtained = self.filter_class(data=dict(search="tw")).qs expected = [lg2, lg3] - self.assertQuerySetIdEqual( - expected, - obtained - ) + self.assertQuerySetIdEqual(expected, obtained) diff --git a/apps/lead/tests/test_migrations.py b/apps/lead/tests/test_migrations.py index 8d93fe5609..28c067d134 100644 --- a/apps/lead/tests/test_migrations.py +++ b/apps/lead/tests/test_migrations.py @@ -1,9 +1,9 @@ import importlib -from deep.tests import TestCase - -from lead.models import Lead from ary.models import Assessment +from lead.models import Lead + +from deep.tests import TestCase class TestCustomMigrationsLogic(TestCase): @@ -13,7 +13,7 @@ class TestCustomMigrationsLogic(TestCase): """ def test_lead_is_assessment_migration(self): - migration_file = importlib.import_module('lead.migrations.0037_auto_20210715_0432') + migration_file = importlib.import_module("lead.migrations.0037_auto_20210715_0432") lead_1 = self.create_lead() lead_2 = self.create_lead() @@ -29,8 +29,6 @@ def test_lead_is_assessment_migration(self): assert Lead.objects.count() == 4 # should set the lead which have assesmment to `is_assessment_lead=True` - assert set( - Lead.objects.filter(is_assessment_lead=True) - ) == set([lead_3, lead_1, lead_2]) + assert set(Lead.objects.filter(is_assessment_lead=True)) == set([lead_3, lead_1, lead_2]) # check for the lead which has no any assessment created for - assert set(Lead.objects.filter(id=lead_4.id).values_list('is_assessment_lead', flat=True)) == set([False]) + assert set(Lead.objects.filter(id=lead_4.id).values_list("is_assessment_lead", flat=True)) == set([False]) diff --git a/apps/lead/tests/test_mutations.py b/apps/lead/tests/test_mutations.py index 0a894205d5..daf2aee398 100644 --- a/apps/lead/tests/test_mutations.py +++ b/apps/lead/tests/test_mutations.py @@ -1,24 +1,24 @@ from unittest import mock -from utils.graphene.tests import GraphQLTestCase, GraphQLSnapShotTestCase -from organization.factories import OrganizationFactory -from user.factories import UserFactory -from project.factories import ProjectFactory - -from lead.models import Lead from gallery.factories import FileFactory from lead.factories import ( - LeadFactory, EmmEntityFactory, - LeadGroupFactory, LeadEMMTriggerFactory, + LeadFactory, + LeadGroupFactory, LeadPreviewFactory, LeadPreviewImageFactory, ) +from lead.models import Lead +from organization.factories import OrganizationFactory +from project.factories import ProjectFactory +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLSnapShotTestCase, GraphQLTestCase class TestLeadMutationSchema(GraphQLTestCase): - CREATE_LEAD_QUERY = ''' + CREATE_LEAD_QUERY = """ mutation MyMutation ($projectId: ID!, $input: LeadInputType!) { project(id: $projectId) { leadCreate1: leadCreate(data: $input) { @@ -55,7 +55,7 @@ class TestLeadMutationSchema(GraphQLTestCase): } } } - ''' + """ def setUp(self): super().setUp() @@ -67,22 +67,18 @@ def setUp(self): self.project.add_member(self.readonly_member_user, role=self.project_role_reader_non_confidential) self.project.add_member(self.member_user, role=self.project_role_member) - @mock.patch('lead.serializers.index_lead_and_calculate_duplicates.delay') + @mock.patch("lead.serializers.index_lead_and_calculate_duplicates.delay") def test_lead_create(self, index_and_calculate_dups_func): """ This test makes sure only valid users can create lead """ + def _query_check(minput, **kwargs): with self.captureOnCommitCallbacks(execute=True): - return self.query_check( - self.CREATE_LEAD_QUERY, - minput=minput, - variables={'projectId': self.project.id}, - **kwargs - ) + return self.query_check(self.CREATE_LEAD_QUERY, minput=minput, variables={"projectId": self.project.id}, **kwargs) minput = dict( - title='Lead Title 101', + title="Lead Title 101", ) # -- Without login _query_check(minput, assert_for_error=True) @@ -97,8 +93,8 @@ def _query_check(minput, **kwargs): # --- member user self.force_login(self.member_user) - content = _query_check(minput)['data']['project']['leadCreate1']['result'] - self.assertEqual(content['title'], minput['title'], content) + content = _query_check(minput)["data"]["project"]["leadCreate1"]["result"] + self.assertEqual(content["title"], minput["title"], content) index_and_calculate_dups_func.assert_called() @@ -115,23 +111,23 @@ def test_lead_create_validation(self): emm_entity_2 = EmmEntityFactory.create() minput = dict( - title='Lead Title 101', + title="Lead Title 101", confidentiality=self.genum(Lead.Confidentiality.UNPROTECTED), priority=self.genum(Lead.Priority.MEDIUM), status=self.genum(Lead.Status.NOT_TAGGED), - publishedOn='2020-09-25', + publishedOn="2020-09-25", source=org2.pk, authors=[org1.pk, org2.pk], - text='Random Text', - url='', + text="Random Text", + url="", emmEntities=[ dict(name=emm_entity_1.name), dict(name=emm_entity_2.name), ], emmTriggers=[ # Return order is by count so let's keep higher count first - dict(emmKeyword='emm-keyword-1', emmRiskFactor='emm-risk-factor-1', count=20), - dict(emmKeyword='emm-keyword-2', emmRiskFactor='emm-risk-factor-2', count=10), + dict(emmKeyword="emm-keyword-1", emmRiskFactor="emm-risk-factor-1", count=20), + dict(emmKeyword="emm-keyword-2", emmRiskFactor="emm-risk-factor-2", count=10), ], ) @@ -139,8 +135,8 @@ def _query_check(**kwargs): return self.query_check( self.CREATE_LEAD_QUERY, minput=minput, - mnested=['project'], - variables={'projectId': self.project.id}, + mnested=["project"], + variables={"projectId": self.project.id}, **kwargs, ) @@ -148,56 +144,56 @@ def _query_check(**kwargs): self.force_login(self.member_user) # ------ Non member assignee - minput['sourceType'] = self.genum(Lead.SourceType.TEXT) - minput['text'] = 'Text 123' - minput['assignee'] = self.non_member_user.pk - result = _query_check(okay=False)['data']['project']['leadCreate1']['result'] + minput["sourceType"] = self.genum(Lead.SourceType.TEXT) + minput["text"] = "Text 123" + minput["assignee"] = self.non_member_user.pk + result = _query_check(okay=False)["data"]["project"]["leadCreate1"]["result"] self.assertEqual(result, None, result) # ------ Member assignee (TODO: Test partial update as well) + Text Test - minput['assignee'] = self.member_user.pk - minput['text'] = 'Text 123123' # Need to provide different text - result = _query_check(okay=True)['data']['project']['leadCreate1']['result'] - self.assertIdEqual(result['assignee']['id'], minput['assignee'], result) - self.assertCustomDictEqual(result, minput, result, ignore_keys=['id', 'source', 'authors', 'assignee']) - self.assertIdEqual(result['source']['id'], minput['source'], result) - self.assertListIds(result['authors'], minput['authors'], result, get_excepted_list_id=lambda x: str(x)) + minput["assignee"] = self.member_user.pk + minput["text"] = "Text 123123" # Need to provide different text + result = _query_check(okay=True)["data"]["project"]["leadCreate1"]["result"] + self.assertIdEqual(result["assignee"]["id"], minput["assignee"], result) + self.assertCustomDictEqual(result, minput, result, ignore_keys=["id", "source", "authors", "assignee"]) + self.assertIdEqual(result["source"]["id"], minput["source"], result) + self.assertListIds(result["authors"], minput["authors"], result, get_excepted_list_id=lambda x: str(x)) # ------ Disk # File not-owned - minput['sourceType'] = self.genum(Lead.SourceType.DISK) - minput['attachment'] = other_file.pk - result = _query_check(okay=False)['data']['project']['leadCreate1']['result'] + minput["sourceType"] = self.genum(Lead.SourceType.DISK) + minput["attachment"] = other_file.pk + result = _query_check(okay=False)["data"]["project"]["leadCreate1"]["result"] self.assertEqual(result, None, result) # File owned - minput['sourceType'] = self.genum(Lead.SourceType.DISK) - minput['attachment'] = our_file.pk - result = _query_check(okay=True)['data']['project']['leadCreate1']['result'] - self.assertEqual(result['title'], minput['title'], result) + minput["sourceType"] = self.genum(Lead.SourceType.DISK) + minput["attachment"] = our_file.pk + result = _query_check(okay=True)["data"]["project"]["leadCreate1"]["result"] + self.assertEqual(result["title"], minput["title"], result) # -------- Duplicate leads validations # ------------- Text (Using duplicate text) - minput['sourceType'] = self.genum(Lead.SourceType.TEXT) - result = _query_check(okay=False)['data']['project']['leadCreate1']['result'] + minput["sourceType"] = self.genum(Lead.SourceType.TEXT) + result = _query_check(okay=False)["data"]["project"]["leadCreate1"]["result"] self.assertEqual(result, None, result) # ------------- Website - minput['sourceType'] = self.genum(Lead.SourceType.WEBSITE) - minput['url'] = 'http://www.example.com/random-path' - result = _query_check(okay=True)['data']['project']['leadCreate1']['result'] - self.assertCustomDictEqual(result, minput, result, only_keys=['url']) + minput["sourceType"] = self.genum(Lead.SourceType.WEBSITE) + minput["url"] = "http://www.example.com/random-path" + result = _query_check(okay=True)["data"]["project"]["leadCreate1"]["result"] + self.assertCustomDictEqual(result, minput, result, only_keys=["url"]) # Try again will end in error - result = _query_check(okay=False)['data']['project']['leadCreate1']['result'] + result = _query_check(okay=False)["data"]["project"]["leadCreate1"]["result"] self.assertEqual(result, None, result) # ------------- Attachment - minput['sourceType'] = self.genum(Lead.SourceType.DISK) - minput['attachment'] = our_file.pk # Already created this above resulting in error - result = _query_check(okay=False)['data']['project']['leadCreate1']['result'] + minput["sourceType"] = self.genum(Lead.SourceType.DISK) + minput["attachment"] = our_file.pk # Already created this above resulting in error + result = _query_check(okay=False)["data"]["project"]["leadCreate1"]["result"] self.assertEqual(result, None, result) - @mock.patch('lead.receivers.update_index_and_duplicates') + @mock.patch("lead.receivers.update_index_and_duplicates") def test_lead_delete_validation(self, update_indices_func): """ This test checks create lead validations """ - query = ''' + query = """ mutation MyMutation ($projectId: ID! $leadId: ID!) { project(id: $projectId) { leadDelete(id: $leadId) { @@ -211,7 +207,7 @@ def test_lead_delete_validation(self, update_indices_func): } } } - ''' + """ non_access_lead = LeadFactory.create() lead = LeadFactory.create(project=self.project) @@ -220,8 +216,8 @@ def _query_check(lead, will_delete=False, **kwargs): with self.captureOnCommitCallbacks(execute=True): result = self.query_check( query, - mnested=['project'], - variables={'projectId': self.project.id, 'leadId': lead.id}, + mnested=["project"], + variables={"projectId": self.project.id, "leadId": lead.id}, **kwargs, ) if will_delete: @@ -245,12 +241,12 @@ def _query_check(lead, will_delete=False, **kwargs): # ------- login as normal member self.force_login(self.member_user) # Success with normal lead (with project membership) - result = _query_check(lead, will_delete=True, okay=True)['data']['project']['leadDelete']['result'] - self.assertEqual(result['title'], lead.title, result) + result = _query_check(lead, will_delete=True, okay=True)["data"]["project"]["leadDelete"]["result"] + self.assertEqual(result["title"], lead.title, result) update_indices_func.assert_called_once() def test_lead_update_validation(self): - query = ''' + query = """ mutation MyMutation ($projectId: ID! $leadId: ID! $input: LeadInputType!) { project(id: $projectId) { leadUpdate(id: $leadId data: $input) { @@ -280,20 +276,20 @@ def test_lead_update_validation(self): } } } - ''' + """ lead = LeadFactory.create(project=self.project) non_access_lead = LeadFactory.create() user_file = FileFactory.create(created_by=self.member_user) - minput = dict(title='New Lead') + minput = dict(title="New Lead") def _query_check(_lead, **kwargs): return self.query_check( query, minput=minput, - mnested=['project'], - variables={'projectId': self.project.id, 'leadId': _lead.id}, + mnested=["project"], + variables={"projectId": self.project.id, "leadId": _lead.id}, **kwargs, ) @@ -305,46 +301,45 @@ def _query_check(_lead, **kwargs): # ------- Non access lead _query_check(non_access_lead, okay=False) # ------- Access lead - result = _query_check(lead, okay=True)['data']['project']['leadUpdate']['result'] - self.assertEqual(result['title'], minput['title'], result) + result = _query_check(lead, okay=True)["data"]["project"]["leadUpdate"]["result"] + self.assertEqual(result["title"], minput["title"], result) # -------- Duplicate leads validations # ------------ Text (Using duplicate text) new_lead = LeadFactory.create(project=self.project) - minput['sourceType'] = self.genum(Lead.SourceType.TEXT) - minput['text'] = new_lead.text - result = _query_check(lead, okay=False)['data']['project']['leadUpdate']['result'] + minput["sourceType"] = self.genum(Lead.SourceType.TEXT) + minput["text"] = new_lead.text + result = _query_check(lead, okay=False)["data"]["project"]["leadUpdate"]["result"] self.assertEqual(result, None, result) new_lead.delete() # Can save after deleting the conflicting lead. - result = _query_check(lead, okay=True)['data']['project']['leadUpdate']['result'] - self.assertEqual(result['title'], minput['title'], result) + result = _query_check(lead, okay=True)["data"]["project"]["leadUpdate"]["result"] + self.assertEqual(result["title"], minput["title"], result) # ------------ Website (Using duplicate website) new_lead = LeadFactory.create( - project=self.project, source_type=Lead.SourceType.WEBSITE, - url='https://example.com/random-path' + project=self.project, source_type=Lead.SourceType.WEBSITE, url="https://example.com/random-path" ) - minput['sourceType'] = self.genum(Lead.SourceType.WEBSITE) - minput['url'] = new_lead.url - result = _query_check(lead, okay=False)['data']['project']['leadUpdate']['result'] + minput["sourceType"] = self.genum(Lead.SourceType.WEBSITE) + minput["url"] = new_lead.url + result = _query_check(lead, okay=False)["data"]["project"]["leadUpdate"]["result"] self.assertEqual(result, None, result) new_lead.delete() # Can save after deleting the conflicting lead. - result = _query_check(lead, okay=True)['data']['project']['leadUpdate']['result'] - self.assertEqual(result['url'], minput['url'], result) + result = _query_check(lead, okay=True)["data"]["project"]["leadUpdate"]["result"] + self.assertEqual(result["url"], minput["url"], result) # ------------ Attachment (Using duplicate file) new_lead = LeadFactory.create(project=self.project, source_type=Lead.SourceType.DISK, attachment=user_file) - minput['sourceType'] = self.genum(Lead.SourceType.DISK) - minput['attachment'] = new_lead.attachment.pk - result = _query_check(lead, okay=False)['data']['project']['leadUpdate']['result'] + minput["sourceType"] = self.genum(Lead.SourceType.DISK) + minput["attachment"] = new_lead.attachment.pk + result = _query_check(lead, okay=False)["data"]["project"]["leadUpdate"]["result"] self.assertEqual(result, None, result) new_lead.delete() # Can save after deleting the conflicting lead. - result = _query_check(lead, okay=True)['data']['project']['leadUpdate']['result'] - self.assertIdEqual(result['attachment']['id'], minput['attachment'], result) + result = _query_check(lead, okay=True)["data"]["project"]["leadUpdate"]["result"] + self.assertIdEqual(result["attachment"]["id"], minput["attachment"], result) class TestLeadBulkMutationSchema(GraphQLSnapShotTestCase): factories_used = [UserFactory, ProjectFactory, LeadFactory] def test_lead_bulk(self): - query = ''' + query = """ mutation MyMutation ($projectId: ID! $input: [BulkLeadInputType!]) { project(id: $projectId) { leadBulk(items: $input) { @@ -357,42 +352,44 @@ def test_lead_bulk(self): } } } - ''' + """ project = ProjectFactory.create() # User with role user = UserFactory.create() project.add_member(user, role=self.project_role_member) lead1 = LeadFactory.create(project=project) - lead2 = LeadFactory.create(project=project, source_type=Lead.SourceType.WEBSITE, url='https://example.com/path') + lead2 = LeadFactory.create(project=project, source_type=Lead.SourceType.WEBSITE, url="https://example.com/path") lead_count = Lead.objects.count() minput = [ - dict(title='Lead title 1', clientId='new-lead-1'), - dict(title='Lead title 2', clientId='new-lead-2'), + dict(title="Lead title 1", clientId="new-lead-1"), + dict(title="Lead title 2", clientId="new-lead-2"), dict( - title='Lead title 4', sourceType=self.genum(Lead.SourceType.WEBSITE), url='https://example.com/path', - clientId='new-lead-3', + title="Lead title 4", + sourceType=self.genum(Lead.SourceType.WEBSITE), + url="https://example.com/path", + clientId="new-lead-3", ), - dict(id=str(lead1.pk), title='Lead title 3'), - dict(id=str(lead2.pk), title='Lead title 4'), + dict(id=str(lead1.pk), title="Lead title 3"), + dict(id=str(lead2.pk), title="Lead title 4"), ] def _query_check(**kwargs): - return self.query_check(query, minput=minput, variables={'projectId': project.pk}, **kwargs) + return self.query_check(query, minput=minput, variables={"projectId": project.pk}, **kwargs) # --- without login _query_check(assert_for_error=True) # --- with login self.force_login(user) - response = _query_check()['data']['project']['leadBulk'] - self.assertMatchSnapshot(response, 'success') + response = _query_check()["data"]["project"]["leadBulk"] + self.assertMatchSnapshot(response, "success") self.assertEqual(lead_count + 2, Lead.objects.count()) class TestLeadGroupMutation(GraphQLTestCase): def test_lead_group_delete(self): - query = ''' + query = """ mutation MyMutation ($projectId: ID! $leadGroupId: ID!) { project(id: $projectId) { leadGroupDelete(id: $leadGroupId) { @@ -405,7 +402,7 @@ def test_lead_group_delete(self): } } } - ''' + """ project = ProjectFactory.create() member_user = UserFactory.create() non_member_user = UserFactory.create() @@ -413,11 +410,7 @@ def test_lead_group_delete(self): lead_group = LeadGroupFactory.create(project=project) def _query_check(**kwargs): - return self.query_check( - query, - variables={'projectId': project.id, 'leadGroupId': lead_group.id}, - **kwargs - ) + return self.query_check(query, variables={"projectId": project.id, "leadGroupId": lead_group.id}, **kwargs) # -- Without login _query_check(assert_for_error=True) @@ -425,8 +418,8 @@ def _query_check(**kwargs): # --- member user self.force_login(member_user) content = _query_check() - self.assertEqual(content['data']['project']['leadGroupDelete']['ok'], True) - self.assertIdEqual(content['data']['project']['leadGroupDelete']['result']['id'], lead_group.id) + self.assertEqual(content["data"]["project"]["leadGroupDelete"]["ok"], True) + self.assertIdEqual(content["data"]["project"]["leadGroupDelete"]["result"]["id"], lead_group.id) # -- non-member user self.force_login(non_member_user) @@ -435,7 +428,7 @@ def _query_check(**kwargs): class TestLeadCopyMutation(GraphQLTestCase): def test_lead_copy_mutation(self): - query = ''' + query = """ mutation MyMutation ($projectId: ID! $input: LeadCopyInputType!) { project(id: $projectId) { leadCopy(data: $input) { @@ -457,18 +450,18 @@ def test_lead_copy_mutation(self): } } } - ''' + """ member_user = UserFactory.create() member_user_only_protected = UserFactory.create() non_member_user = UserFactory.create() created_by_user = UserFactory.create() # Source Projects - wa_source_project = ProjectFactory.create(title='With access Source Project') # With access - woa_source_project = ProjectFactory.create(title='Without access Source Project') # Without access + wa_source_project = ProjectFactory.create(title="With access Source Project") # With access + woa_source_project = ProjectFactory.create(title="Without access Source Project") # Without access # Destination Projects - wa_destination_project = ProjectFactory.create(title='With access Destination Project') # With access - woa_destination_project = ProjectFactory.create(title='Without access Destination Project') # Without access + wa_destination_project = ProjectFactory.create(title="With access Destination Project") # With access + woa_destination_project = ProjectFactory.create(title="Without access Destination Project") # Without access # Assign access wa_source_project.add_member(member_user) wa_source_project.add_member(member_user_only_protected, role=self.project_role_reader_non_confidential) @@ -478,37 +471,37 @@ def test_lead_copy_mutation(self): woa_source_project.add_member(member_user_only_protected, role=self.project_base_access) # With no lead read access # Lead1 Info (Will be used later for testing) - author1 = OrganizationFactory.create(title='author1') - author2 = OrganizationFactory.create(title='author2') - emm_entity = EmmEntityFactory.create(name='emm_entity_11') + author1 = OrganizationFactory.create(title="author1") + author2 = OrganizationFactory.create(title="author2") + emm_entity = EmmEntityFactory.create(name="emm_entity_11") # Generate some leads in source projects. wa_lead_confidential = LeadFactory.create( - title='Confidential Lead (with-access)', + title="Confidential Lead (with-access)", project=wa_source_project, source_type=Lead.SourceType.WEBSITE, - url='http://confidential-lead.example.com', + url="http://confidential-lead.example.com", confidentiality=Lead.Confidentiality.CONFIDENTIAL, ) wa_lead1 = LeadFactory.create( - title='Lead 1 (with-access)', + title="Lead 1 (with-access)", project=wa_source_project, source_type=Lead.SourceType.WEBSITE, - url='http://example.com', + url="http://example.com", created_by=created_by_user, - status=Lead.Status.TAGGED + status=Lead.Status.TAGGED, ) wa_lead2 = LeadFactory.create( - title='Lead 2 (with-access)', + title="Lead 2 (with-access)", project=wa_source_project, source_type=Lead.SourceType.WEBSITE, - url='http://another.example.com' + url="http://another.example.com", ) woa_lead3 = LeadFactory.create( - title='Lead 3 (without-access)', + title="Lead 3 (without-access)", project=woa_source_project, source_type=Lead.SourceType.WEBSITE, - url='http://another-2.example.com' + url="http://another-2.example.com", ) # Assign authors wa_lead1.authors.set([author1, author2]) @@ -516,12 +509,12 @@ def test_lead_copy_mutation(self): woa_lead3.authors.set([author2]) # Generating Foreign elements for wa_lead1 - wa_lead1_preview = LeadPreviewFactory.create(lead=wa_lead1, text_extract='This is a random text extarct') - wa_lead1_image_preview = LeadPreviewImageFactory.create(lead=wa_lead1, file='test-file-123') + wa_lead1_preview = LeadPreviewFactory.create(lead=wa_lead1, text_extract="This is a random text extarct") + wa_lead1_image_preview = LeadPreviewImageFactory.create(lead=wa_lead1, file="test-file-123") LeadEMMTriggerFactory.create( lead=wa_lead1, - emm_keyword='emm1', - emm_risk_factor='risk1', + emm_keyword="emm1", + emm_risk_factor="risk1", count=22, ) wa_lead1.emm_entities.set([emm_entity]) @@ -536,20 +529,15 @@ def test_lead_copy_mutation(self): # test for single lead copy minput = { - 'projects': [ + "projects": [ wa_destination_project.id, # Lead will be added here woa_destination_project.id, # No Lead are added here ], - 'leads': [ - wa_lead_confidential.id, - wa_lead1.id, - wa_lead2.id, - woa_lead3.id - ] + "leads": [wa_lead_confidential.id, wa_lead1.id, wa_lead2.id, woa_lead3.id], } def _query_check(source_project, **kwargs): - return self.query_check(query, minput=minput, variables={'projectId': source_project.pk}, **kwargs) + return self.query_check(query, minput=minput, variables={"projectId": source_project.pk}, **kwargs) # without login _query_check(wa_source_project, assert_for_error=True) @@ -572,7 +560,7 @@ def _query_check(source_project, **kwargs): wa_current_leads_count = wa_destination_project.lead_set.count() woa_current_leads_count = woa_destination_project.lead_set.count() # Call endpoint - new_leads = _query_check(wa_source_project)['data']['project']['leadCopy']['result'] + new_leads = _query_check(wa_source_project)["data"]["project"]["leadCopy"]["result"] # lets make sure lead is copied to the destination project wa_new_count = wa_destination_project.lead_set.count() woa_new_leads_count = woa_destination_project.lead_set.count() @@ -611,10 +599,10 @@ def _query_check(source_project, **kwargs): self.assertEqual(copied_lead1.confidentiality, wa_lead1.confidentiality) # lets check for the foreign key field copy self.assertEqual(copied_lead1.leadpreview.text_extract, wa_lead1_preview.text_extract) - self.assertEqual(list(copied_lead1.images.values_list('file', flat=True)), [wa_lead1_image_preview.file.name]) + self.assertEqual(list(copied_lead1.images.values_list("file", flat=True)), [wa_lead1_image_preview.file.name]) self.assertEqual( - list(copied_lead1.emm_triggers.values('emm_keyword', 'emm_risk_factor', 'count')), - list(wa_lead1.emm_triggers.values('emm_keyword', 'emm_risk_factor', 'count')), + list(copied_lead1.emm_triggers.values("emm_keyword", "emm_risk_factor", "count")), + list(wa_lead1.emm_triggers.values("emm_keyword", "emm_risk_factor", "count")), ) self.assertEqual( list(copied_lead1.emm_entities.all()), diff --git a/apps/lead/tests/test_permissions.py b/apps/lead/tests/test_permissions.py index 57e1c5e0e2..533ce32f65 100644 --- a/apps/lead/tests/test_permissions.py +++ b/apps/lead/tests/test_permissions.py @@ -1,34 +1,30 @@ -from deep.tests import TestCase - from lead.models import Lead from organization.models import Organization -from project.permissions import PROJECT_PERMISSIONS, get_project_permissions_value from project.models import Project, ProjectRole +from project.permissions import PROJECT_PERMISSIONS, get_project_permissions_value + +from deep.tests import TestCase class TestLeadPermissions(TestCase): def setUp(self): super().setUp() common_role_attrs = { - 'entry_permissions': get_project_permissions_value('entry', '__all__'), - 'setup_permissions': get_project_permissions_value('setup', '__all__'), - 'export_permissions': get_project_permissions_value('export', '__all__'), - 'assessment_permissions': get_project_permissions_value('assessment', '__all__'), + "entry_permissions": get_project_permissions_value("entry", "__all__"), + "setup_permissions": get_project_permissions_value("setup", "__all__"), + "export_permissions": get_project_permissions_value("export", "__all__"), + "assessment_permissions": get_project_permissions_value("assessment", "__all__"), } self.no_lead_creation_role = ProjectRole.objects.create( - title='No Lead Creation Role', - lead_permissions=0, - **common_role_attrs + title="No Lead Creation Role", lead_permissions=0, **common_role_attrs ) self.lead_creation_role = ProjectRole.objects.create( - title='Lead Creation Role', - lead_permissions=get_project_permissions_value('lead', ['create']), - **common_role_attrs + title="Lead Creation Role", lead_permissions=get_project_permissions_value("lead", ["create"]), **common_role_attrs ) self.lead_view_clone_role = ProjectRole.objects.create( - title='Lead View Role', - lead_permissions=get_project_permissions_value('lead', ['view', 'create']), - **common_role_attrs + title="Lead View Role", + lead_permissions=get_project_permissions_value("lead", ["view", "create"]), + **common_role_attrs, ) self.author = self.source = self.create_organization() @@ -47,10 +43,10 @@ def _test_lead_copy_no_permission(self): lead = self.create(Lead, project=source_project) data = { - 'projects': [dest_project.pk], - 'leads': [lead.pk], + "projects": [dest_project.pk], + "leads": [lead.pk], } - url = '/api/v1/lead-copy/' + url = "/api/v1/lead-copy/" self.authenticate() response = self.client.post(url, data) @@ -65,18 +61,17 @@ def test_lead_copy_with_permission(self): initial_lead_count = Lead.objects.count() data = { - 'projects': [dest_project.pk], - 'leads': [lead.pk], + "projects": [dest_project.pk], + "leads": [lead.pk], } - url = '/api/v1/lead-copy/' + url = "/api/v1/lead-copy/" self.authenticate() response = self.client.post(url, data) self.assert_201(response) assert Lead.objects.count() == initial_lead_count + 1, "One more lead should be created" - assert Lead.objects.filter(title=lead.title, project=dest_project).exists(), \ - "Exact same lead should be created" + assert Lead.objects.filter(title=lead.title, project=dest_project).exists(), "Exact same lead should be created" def test_cannot_view_confidential_lead_without_permissions(self): view_unprotected_role = ProjectRole.objects.create( @@ -87,24 +82,24 @@ def test_cannot_view_confidential_lead_without_permissions(self): lead1 = self.create_lead(project=project, confidentiality=Lead.Confidentiality.UNPROTECTED) lead_confidential = self.create_lead(project=project, confidentiality=Lead.Confidentiality.CONFIDENTIAL) - url = '/api/v1/leads/' + url = "/api/v1/leads/" self.authenticate() resp = self.client.get(url) self.assert_200(resp) - leads_ids = set([x['id'] for x in resp.data['results']]) + leads_ids = set([x["id"] for x in resp.data["results"]]) assert leads_ids == {lead1.id}, "Only confidential should be present" # Check get particuar non-confidential lead, should return 200 - url = f'/api/v1/leads/{lead1.id}/' + url = f"/api/v1/leads/{lead1.id}/" self.authenticate() resp = self.client.get(url) self.assert_200(resp) # Check get particuar confidential lead, should return 404 - url = f'/api/v1/leads/{lead_confidential.id}/' + url = f"/api/v1/leads/{lead_confidential.id}/" self.authenticate() resp = self.client.get(url) @@ -113,16 +108,16 @@ def test_cannot_view_confidential_lead_without_permissions(self): def test_create_lead_no_permission(self): # Create a project where self.user has no lead creation permission project = self.create(Project, role=self.no_lead_creation_role) - url = '/api/v1/leads/' + url = "/api/v1/leads/" data = { - 'title': 'Spaceship spotted in sky', - 'project': project.id, - 'source': self.source.pk, - 'author': self.author.pk, - 'confidentiality': Lead.Confidentiality.UNPROTECTED, - 'status': Lead.Status.NOT_TAGGED, - 'text': 'Alien shapeship has been spotted in the sky', - 'assignee': self.user.id, + "title": "Spaceship spotted in sky", + "project": project.id, + "source": self.source.pk, + "author": self.author.pk, + "confidentiality": Lead.Confidentiality.UNPROTECTED, + "status": Lead.Status.NOT_TAGGED, + "text": "Alien shapeship has been spotted in the sky", + "assignee": self.user.id, } self.authenticate() response = self.client.post(url, data) @@ -131,16 +126,16 @@ def test_create_lead_no_permission(self): def test_create_lead_with_permission(self): # Create a project where self.user has no lead creation permission project = self.create(Project, role=self.lead_creation_role) - url = '/api/v1/leads/' + url = "/api/v1/leads/" data = { - 'title': 'Spaceship spotted in sky', - 'project': project.id, - 'source': self.source.pk, - 'author': self.author.pk, - 'confidentiality': Lead.Confidentiality.UNPROTECTED, - 'status': Lead.Status.NOT_TAGGED, - 'text': 'Alien shapeship has been spotted in the sky', - 'assignee': self.user.id, + "title": "Spaceship spotted in sky", + "project": project.id, + "source": self.source.pk, + "author": self.author.pk, + "confidentiality": Lead.Confidentiality.UNPROTECTED, + "status": Lead.Status.NOT_TAGGED, + "text": "Alien shapeship has been spotted in the sky", + "assignee": self.user.id, } self.authenticate() response = self.client.post(url, data) diff --git a/apps/lead/tests/test_schemas.py b/apps/lead/tests/test_schemas.py index 58c8d46368..cf75159de6 100644 --- a/apps/lead/tests/test_schemas.py +++ b/apps/lead/tests/test_schemas.py @@ -1,40 +1,33 @@ import json -from djangorestframework_camel_case.render import CamelCaseJSONRenderer -from utils.graphene.tests import GraphQLTestCase - -from organization.factories import OrganizationTypeFactory, OrganizationFactory -from user.factories import UserFactory -from project.factories import ProjectFactory - -from lead.models import Lead -from analysis_framework.models import Widget - -from entry.factories import EntryFactory -from ary.factories import AssessmentFactory -from geo.factories import ( - RegionFactory, - GeoAreaFactory, - AdminLevelFactory, -) from analysis_framework.factories import ( + AfFilterFactory, AnalysisFrameworkFactory, WidgetFactory, - AfFilterFactory, ) +from analysis_framework.models import Widget +from ary.factories import AssessmentFactory +from djangorestframework_camel_case.render import CamelCaseJSONRenderer +from entry.factories import EntryFactory +from geo.factories import AdminLevelFactory, GeoAreaFactory, RegionFactory +from lead.enums import LeadOrderingEnum from lead.factories import ( + EmmEntityFactory, LeadEMMTriggerFactory, LeadFactory, - EmmEntityFactory, LeadGroupFactory, UserSavedLeadFilterFactory, ) +from lead.models import Lead +from organization.factories import OrganizationFactory, OrganizationTypeFactory +from project.factories import ProjectFactory +from user.factories import UserFactory -from lead.enums import LeadOrderingEnum +from utils.graphene.tests import GraphQLTestCase class TestLeadQuerySchema(GraphQLTestCase): - lead_filter_query = ''' + lead_filter_query = """ query MyQuery ( $projectId: ID! # lead Arguments @@ -98,13 +91,13 @@ class TestLeadQuerySchema(GraphQLTestCase): } } } - ''' + """ def test_lead_query(self): """ Test private + non-private project behaviour """ - query = ''' + query = """ query MyQuery ($projectId: ID! $leadId: ID!) { project(id: $projectId) { lead (id: $leadId) { @@ -114,7 +107,7 @@ def test_lead_query(self): } } } - ''' + """ project = ProjectFactory.create() # User with role @@ -127,7 +120,7 @@ def test_lead_query(self): confidential_lead = LeadFactory.create(project=project, confidentiality=Lead.Confidentiality.CONFIDENTIAL) def _query_check(lead, **kwargs): - return self.query_check(query, variables={'projectId': project.id, 'leadId': lead.id}, **kwargs) + return self.query_check(query, variables={"projectId": project.id, "leadId": lead.id}, **kwargs) # -- Without login _query_check(confidential_lead, assert_for_error=True) @@ -138,23 +131,23 @@ def _query_check(lead, **kwargs): # --- non-member user content = _query_check(normal_lead) - self.assertEqual(content['data']['project']['lead'], None, content) + self.assertEqual(content["data"]["project"]["lead"], None, content) content = _query_check(confidential_lead) - self.assertEqual(content['data']['project']['lead'], None, content) + self.assertEqual(content["data"]["project"]["lead"], None, content) # --- member user self.force_login(member_user) content = _query_check(normal_lead) - self.assertNotEqual(content['data']['project']['lead'], None, content) + self.assertNotEqual(content["data"]["project"]["lead"], None, content) content = _query_check(confidential_lead) - self.assertEqual(content['data']['project']['lead'], None, content) + self.assertEqual(content["data"]["project"]["lead"], None, content) # --- confidential member user self.force_login(confidential_member_user) content = _query_check(normal_lead) - self.assertNotEqual(content['data']['project']['lead'], None, content) + self.assertNotEqual(content["data"]["project"]["lead"], None, content) content = _query_check(confidential_lead) - self.assertNotEqual(content['data']['project']['lead'], None, content) + self.assertNotEqual(content["data"]["project"]["lead"], None, content) def test_lead_query_filter(self): af = AnalysisFrameworkFactory.create() @@ -173,7 +166,7 @@ def test_lead_query_filter(self): project.add_member(member2, role=self.project_role_reader) lead1 = LeadFactory.create( project=project, - title='Test 1', + title="Test 1", source_type=Lead.SourceType.TEXT, confidentiality=Lead.Confidentiality.CONFIDENTIAL, source=org1_child, @@ -185,7 +178,7 @@ def test_lead_query_filter(self): lead2 = LeadFactory.create( project=project, source_type=Lead.SourceType.TEXT, - title='Test 2', + title="Test 2", assignee=[member2], authors=[org2, org3], priority=Lead.Priority.HIGH, @@ -193,8 +186,8 @@ def test_lead_query_filter(self): lead3 = LeadFactory.create( project=project, source_type=Lead.SourceType.WEBSITE, - url='https://wwwexample.com/sample-1', - title='Sample 1', + url="https://wwwexample.com/sample-1", + title="Sample 1", confidentiality=Lead.Confidentiality.CONFIDENTIAL, source=org2, authors=[org1, org3], @@ -202,7 +195,7 @@ def test_lead_query_filter(self): ) lead4 = LeadFactory.create( project=project, - title='Sample 2', + title="Sample 2", source=org3, authors=[org1], priority=Lead.Priority.MEDIUM, @@ -210,7 +203,7 @@ def test_lead_query_filter(self): ) lead5 = LeadFactory.create( project=project, - title='Sample 3', + title="Sample 3", status=Lead.Status.TAGGED, assignee=[member2], source=org3, @@ -228,50 +221,47 @@ def test_lead_query_filter(self): # TODO: Add direct test for filter_set as well (is used within export) for filter_data, expected_leads in [ - ({'search': 'test'}, [lead1, lead2]), - ({'confidentiality': self.genum(Lead.Confidentiality.CONFIDENTIAL)}, [lead1, lead3]), - ({'assignees': [member2.pk]}, [lead2, lead5]), - ({'assignees': [member1.pk, member2.pk]}, [lead1, lead2, lead5]), - ({'authoringOrganizationTypes': [org_type2.pk]}, [lead1, lead2, lead3]), - ({'authoringOrganizationTypes': [org_type1.pk, org_type2.pk]}, [lead1, lead2, lead3, lead4]), - ({'authorOrganizations': [org1.pk, org2.pk]}, [lead1, lead2, lead3, lead4]), - ({'authorOrganizations': [org3.pk]}, [lead2, lead3]), - ({'sourceOrganizations': [org1.pk, org2.pk]}, [lead1, lead3]), - ({'sourceOrganizations': [org3.pk]}, [lead4, lead5]), - ({'priorities': [self.genum(Lead.Priority.HIGH)]}, [lead1, lead2]), - ({'priorities': [self.genum(Lead.Priority.LOW), self.genum(Lead.Priority.HIGH)]}, [lead1, lead2, lead3, lead5]), - ({'sourceTypes': [self.genum(Lead.SourceType.WEBSITE)]}, [lead3]), - ( - {'sourceTypes': [self.genum(Lead.SourceType.TEXT), self.genum(Lead.SourceType.WEBSITE)]}, - [lead1, lead2, lead3] - ), - ({'statuses': [self.genum(Lead.Status.NOT_TAGGED)]}, [lead2, lead3]), - ({'statuses': [self.genum(Lead.Status.IN_PROGRESS), self.genum(Lead.Status.TAGGED)]}, [lead1, lead4, lead5]), - ({'hasEntries': True}, [lead4, lead5]), - ({'hasEntries': False}, [lead1, lead2, lead3]), + ({"search": "test"}, [lead1, lead2]), + ({"confidentiality": self.genum(Lead.Confidentiality.CONFIDENTIAL)}, [lead1, lead3]), + ({"assignees": [member2.pk]}, [lead2, lead5]), + ({"assignees": [member1.pk, member2.pk]}, [lead1, lead2, lead5]), + ({"authoringOrganizationTypes": [org_type2.pk]}, [lead1, lead2, lead3]), + ({"authoringOrganizationTypes": [org_type1.pk, org_type2.pk]}, [lead1, lead2, lead3, lead4]), + ({"authorOrganizations": [org1.pk, org2.pk]}, [lead1, lead2, lead3, lead4]), + ({"authorOrganizations": [org3.pk]}, [lead2, lead3]), + ({"sourceOrganizations": [org1.pk, org2.pk]}, [lead1, lead3]), + ({"sourceOrganizations": [org3.pk]}, [lead4, lead5]), + ({"priorities": [self.genum(Lead.Priority.HIGH)]}, [lead1, lead2]), + ({"priorities": [self.genum(Lead.Priority.LOW), self.genum(Lead.Priority.HIGH)]}, [lead1, lead2, lead3, lead5]), + ({"sourceTypes": [self.genum(Lead.SourceType.WEBSITE)]}, [lead3]), + ({"sourceTypes": [self.genum(Lead.SourceType.TEXT), self.genum(Lead.SourceType.WEBSITE)]}, [lead1, lead2, lead3]), + ({"statuses": [self.genum(Lead.Status.NOT_TAGGED)]}, [lead2, lead3]), + ({"statuses": [self.genum(Lead.Status.IN_PROGRESS), self.genum(Lead.Status.TAGGED)]}, [lead1, lead4, lead5]), + ({"hasEntries": True}, [lead4, lead5]), + ({"hasEntries": False}, [lead1, lead2, lead3]), ( { - 'hasEntries': True, - 'ordering': [self.genum(LeadOrderingEnum.DESC_ENTRIES_COUNT), self.genum(LeadOrderingEnum.ASC_ID)], + "hasEntries": True, + "ordering": [self.genum(LeadOrderingEnum.DESC_ENTRIES_COUNT), self.genum(LeadOrderingEnum.ASC_ID)], }, - [lead5, lead4] + [lead5, lead4], ), ( { - 'hasEntries': True, - 'entriesFilterData': {}, - 'ordering': [self.genum(LeadOrderingEnum.DESC_ENTRIES_COUNT), self.genum(LeadOrderingEnum.ASC_ID)], + "hasEntries": True, + "entriesFilterData": {}, + "ordering": [self.genum(LeadOrderingEnum.DESC_ENTRIES_COUNT), self.genum(LeadOrderingEnum.ASC_ID)], }, - [lead5, lead4] + [lead5, lead4], ), ( { - 'entriesFilterData': {'controlled': True}, - 'ordering': [self.genum(LeadOrderingEnum.DESC_ENTRIES_COUNT), self.genum(LeadOrderingEnum.ASC_ID)], + "entriesFilterData": {"controlled": True}, + "ordering": [self.genum(LeadOrderingEnum.DESC_ENTRIES_COUNT), self.genum(LeadOrderingEnum.ASC_ID)], }, - [lead5] + [lead5], ), - ({'isAssessment': True}, [lead4, lead5]), + ({"isAssessment": True}, [lead4, lead5]), # TODO: # ({'emmEntities': []}, []), # ({'emmKeywords': []}, []), @@ -286,17 +276,16 @@ def test_lead_query_filter(self): # ({'createdAtGte': []}, []), # ({'createdAtLte': []}, []), ]: - content = self.query_check(self.lead_filter_query, variables={'projectId': project.id, **filter_data}) + content = self.query_check(self.lead_filter_query, variables={"projectId": project.id, **filter_data}) self.assertListIds( - content['data']['project']['leads']['results'], expected_leads, - {'response': content, 'filter': filter_data} + content["data"]["project"]["leads"]["results"], expected_leads, {"response": content, "filter": filter_data} ) def test_leads_query(self): """ Test private + non-private project behaviour """ - query = ''' + query = """ query MyQuery ($id: ID!) { project(id: $id) { leads { @@ -311,7 +300,7 @@ def test_leads_query(self): } } } - ''' + """ project = ProjectFactory.create() # User with role @@ -325,7 +314,7 @@ def test_leads_query(self): confidential_leads = LeadFactory.create_batch(6, project=project, confidentiality=Lead.Confidentiality.CONFIDENTIAL) def _query_check(**kwargs): - return self.query_check(query, variables={'id': project.id}, **kwargs) + return self.query_check(query, variables={"id": project.id}, **kwargs) # -- Without login _query_check(assert_for_error=True) @@ -336,23 +325,23 @@ def _query_check(**kwargs): # --- non-member user (zero leads) content = _query_check() - self.assertEqual(content['data']['project']['leads']['totalCount'], 0, content) - self.assertEqual(len(content['data']['project']['leads']['results']), 0, content) + self.assertEqual(content["data"]["project"]["leads"]["totalCount"], 0, content) + self.assertEqual(len(content["data"]["project"]["leads"]["results"]), 0, content) # --- member user (only unprotected leads) self.force_login(member_user) content = _query_check() - self.assertEqual(content['data']['project']['leads']['totalCount'], 5, content) - self.assertListIds(content['data']['project']['leads']['results'], normal_leads, content) + self.assertEqual(content["data"]["project"]["leads"]["totalCount"], 5, content) + self.assertListIds(content["data"]["project"]["leads"]["results"], normal_leads, content) # --- confidential member user (all leads) self.force_login(confidential_member_user) content = _query_check() - self.assertEqual(content['data']['project']['leads']['totalCount'], 11, content) - self.assertListIds(content['data']['project']['leads']['results'], confidential_leads + normal_leads, content) + self.assertEqual(content["data"]["project"]["leads"]["totalCount"], 11, content) + self.assertListIds(content["data"]["project"]["leads"]["results"], confidential_leads + normal_leads, content) def test_lead_query_with_duplicates_true(self): - query = ''' + query = """ query MyQuery ($projectId: ID!) { project(id: $projectId) { leads (hasDuplicates: true) { @@ -365,7 +354,7 @@ def test_lead_query_with_duplicates_true(self): } } } - ''' + """ project = ProjectFactory.create() member_user = UserFactory.create() project.add_member(member_user, role=self.project_role_reader_non_confidential) @@ -384,11 +373,11 @@ def test_lead_query_with_duplicates_true(self): """ def _query_check(lead, **kwargs): - return self.query_check(query, variables={'projectId': project.id}, **kwargs) + return self.query_check(query, variables={"projectId": project.id}, **kwargs) self.force_login(member_user) content = _query_check(lead) - leads_resp = content['data']['project']['leads']['results'] + leads_resp = content["data"]["project"]["leads"]["results"] self.assertEqual(len(leads_resp), 6, "There are 6 leads which have/are duplicates.") for lead_resp in leads_resp: @@ -398,7 +387,7 @@ def _query_check(lead, **kwargs): self.assertEqual(lead_resp["duplicateLeadsCount"], 1) def test_lead_query_with_duplicates_false(self): - query = ''' + query = """ query MyQuery ($projectId: ID!) { project(id: $projectId) { leads (hasDuplicates: false) { @@ -411,7 +400,7 @@ def test_lead_query_with_duplicates_false(self): } } } - ''' + """ project = ProjectFactory.create() member_user = UserFactory.create() project.add_member(member_user, role=self.project_role_reader_non_confidential) @@ -421,18 +410,18 @@ def test_lead_query_with_duplicates_false(self): another_lead = LeadFactory.create(project=project) # noqa def _query_check(lead, **kwargs): - return self.query_check(query, variables={'projectId': project.id}, **kwargs) + return self.query_check(query, variables={"projectId": project.id}, **kwargs) self.force_login(member_user) content = _query_check(lead) - leads_resp = content['data']['project']['leads']['results'] + leads_resp = content["data"]["project"]["leads"]["results"] self.assertEqual(len(leads_resp), 1) lead_resp = leads_resp[0] self.assertEqual(lead_resp["id"], str(another_lead.id)) self.assertEqual(lead_resp["duplicateLeadsCount"], 0) def test_lead_query_with_duplicates(self): - query = ''' + query = """ query MyQuery ($projectId: ID! $duplicatesOf: ID!) { project(id: $projectId) { leads (duplicatesOf: $duplicatesOf) { @@ -443,7 +432,7 @@ def test_lead_query_with_duplicates(self): } } } - ''' + """ project = ProjectFactory.create() member_user = UserFactory.create() project.add_member(member_user, role=self.project_role_reader_non_confidential) @@ -452,7 +441,7 @@ def test_lead_query_with_duplicates(self): lead.duplicate_leads.set(duplicate_leads) def _query_check(lead, **kwargs): - return self.query_check(query, variables={'projectId': project.id, 'duplicatesOf': lead.id}, **kwargs) + return self.query_check(query, variables={"projectId": project.id, "duplicatesOf": lead.id}, **kwargs) self.force_login(member_user) content = _query_check(lead) @@ -465,7 +454,7 @@ def test_lead_query_with_duplicates_reverse(self): If lead A has duplicate_leads = [B, C, D] then querying duplicate leads of either B, C or D should return A. """ - query = ''' + query = """ query MyQuery ($projectId: ID! $duplicatesOf: ID!) { project(id: $projectId) { leads (duplicatesOf: $duplicatesOf) { @@ -476,7 +465,7 @@ def test_lead_query_with_duplicates_reverse(self): } } } - ''' + """ project = ProjectFactory.create() member_user = UserFactory.create() project.add_member(member_user, role=self.project_role_reader_non_confidential) @@ -485,7 +474,7 @@ def test_lead_query_with_duplicates_reverse(self): lead.duplicate_leads.set(duplicate_leads) def _query_check(lead, **kwargs): - return self.query_check(query, variables={'projectId': project.id, 'duplicatesOf': lead.id}, **kwargs) + return self.query_check(query, variables={"projectId": project.id, "duplicatesOf": lead.id}, **kwargs) self.force_login(member_user) for d_lead in duplicate_leads: @@ -499,7 +488,7 @@ def test_leads_fields_query(self): """ Test leads field value """ - query = ''' + query = """ query MyQuery ($id: ID!) { project(id: $id) { analysisFramework { @@ -552,7 +541,7 @@ def test_leads_fields_query(self): } } } - ''' + """ af, af_new = AnalysisFrameworkFactory.create_batch(2) project = ProjectFactory.create(analysis_framework=af) @@ -578,49 +567,53 @@ def test_leads_fields_query(self): # --- member user (only unprotected leads) self.force_login(user) - content = self.query_check(query, variables={'id': project.id}) - self.assertIdEqual(content['data']['project']['analysisFramework']['id'], af.pk) - results = content['data']['project']['leads']['results'] + content = self.query_check(query, variables={"id": project.id}) + self.assertIdEqual(content["data"]["project"]["analysisFramework"]["id"], af.pk) + results = content["data"]["project"]["leads"]["results"] # Count check - self.assertEqual(content['data']['project']['leads']['totalCount'], 3, content) + self.assertEqual(content["data"]["project"]["leads"]["totalCount"], 3, content) self.assertListIds(results, [lead1, lead2, lead3], content) - self.assertEqual(len(results[0]['authors']), 0, content) + self.assertEqual(len(results[0]["authors"]), 0, content) # Source check - self.assertIdEqual(results[0]['source']['id'], org1.id, content) - self.assertEqual(results[0]['source']['logo']['file']['name'], str(org1.logo.file.name), content) - self.assertEqual(results[0]['source']['logo']['file']['url'], self.get_media_url(org1.logo.file.name), content) + self.assertIdEqual(results[0]["source"]["id"], org1.id, content) + self.assertEqual(results[0]["source"]["logo"]["file"]["name"], str(org1.logo.file.name), content) + self.assertEqual(results[0]["source"]["logo"]["file"]["url"], self.get_media_url(org1.logo.file.name), content) # Authors check - self.assertListIds(results[1]['authors'], [org1, org3], content) - self.assertIdEqual(results[1]['source']['mergedAs']['id'], org1.id, content) + self.assertListIds(results[1]["authors"], [org1, org3], content) + self.assertIdEqual(results[1]["source"]["mergedAs"]["id"], org1.id, content) # Entries Count check - for index, (total_count, controlled_count) in enumerate([ - [7, 2], - [10, 0], - [0, 0], - ]): - self.assertEqual(results[index]['entriesCount']['total'], total_count, content) - self.assertEqual(results[index]['entriesCount']['controlled'], controlled_count, content) + for index, (total_count, controlled_count) in enumerate( + [ + [7, 2], + [10, 0], + [0, 0], + ] + ): + self.assertEqual(results[index]["entriesCount"]["total"], total_count, content) + self.assertEqual(results[index]["entriesCount"]["controlled"], controlled_count, content) # Change AF, this will now not show old entries - content = self.query_check(query, variables={'id': project.id}) + content = self.query_check(query, variables={"id": project.id}) project.analysis_framework = af_new - project.save(update_fields=('analysis_framework',)) + project.save(update_fields=("analysis_framework",)) EntryFactory.create_batch(2, lead=lead1, controlled=True) EntryFactory.create_batch(1, lead=lead2, controlled=False) - content = self.query_check(query, variables={'id': project.id}) - self.assertIdEqual(content['data']['project']['analysisFramework']['id'], af_new.pk) - results = content['data']['project']['leads']['results'] + content = self.query_check(query, variables={"id": project.id}) + self.assertIdEqual(content["data"]["project"]["analysisFramework"]["id"], af_new.pk) + results = content["data"]["project"]["leads"]["results"] # Entries Count check (After AF change) - for index, (total_count, controlled_count) in enumerate([ - [2, 2], - [1, 0], - [0, 0], - ]): - self.assertEqual(results[index]['entriesCount']['total'], total_count, content) - self.assertEqual(results[index]['entriesCount']['controlled'], controlled_count, content) + for index, (total_count, controlled_count) in enumerate( + [ + [2, 2], + [1, 0], + [0, 0], + ] + ): + self.assertEqual(results[index]["entriesCount"]["total"], total_count, content) + self.assertEqual(results[index]["entriesCount"]["controlled"], controlled_count, content) def test_leads_entries_query(self): - query = ''' + query = """ query MyQuery ($id: ID!, $leadId: ID!) { project(id: $id) { analysisFramework { @@ -638,7 +631,7 @@ def test_leads_entries_query(self): } } } - ''' + """ af, af_new = AnalysisFrameworkFactory.create_batch(2) user = UserFactory.create() project = ProjectFactory.create(analysis_framework=af) @@ -649,38 +642,38 @@ def test_leads_entries_query(self): not_controlled_entries = EntryFactory.create_batch(3, lead=lead, controlled=False) def _query_check(): - return self.query_check(query, variables={'id': project.id, 'leadId': lead.id}) + return self.query_check(query, variables={"id": project.id, "leadId": lead.id}) # -- With login self.force_login(user) response = _query_check() - self.assertIdEqual(response['data']['project']['analysisFramework']['id'], af.pk) - content = response['data']['project']['lead'] - self.assertIdEqual(content['id'], lead.pk, content) - self.assertEqual(content['entriesCount']['total'], 5, content) - self.assertEqual(content['entriesCount']['controlled'], 2, content) - self.assertListIds(content['entries'], [*controlled_entries, *not_controlled_entries], content) + self.assertIdEqual(response["data"]["project"]["analysisFramework"]["id"], af.pk) + content = response["data"]["project"]["lead"] + self.assertIdEqual(content["id"], lead.pk, content) + self.assertEqual(content["entriesCount"]["total"], 5, content) + self.assertEqual(content["entriesCount"]["controlled"], 2, content) + self.assertListIds(content["entries"], [*controlled_entries, *not_controlled_entries], content) # Now change AF project.analysis_framework = af_new - project.save(update_fields=('analysis_framework',)) + project.save(update_fields=("analysis_framework",)) new_controlled_entries = EntryFactory.create_batch(4, lead=lead, controlled=True) new_not_controlled_entries = EntryFactory.create_batch(2, lead=lead, controlled=False) response = _query_check() - self.assertIdEqual(response['data']['project']['analysisFramework']['id'], af_new.pk) - content = response['data']['project']['lead'] - self.assertIdEqual(content['id'], lead.pk, content) - self.assertEqual(content['entriesCount']['total'], 6, content) - self.assertEqual(content['entriesCount']['controlled'], 4, content) - self.assertListIds(content['entries'], [*new_controlled_entries, *new_not_controlled_entries], content) + self.assertIdEqual(response["data"]["project"]["analysisFramework"]["id"], af_new.pk) + content = response["data"]["project"]["lead"] + self.assertIdEqual(content["id"], lead.pk, content) + self.assertEqual(content["entriesCount"]["total"], 6, content) + self.assertEqual(content["entriesCount"]["controlled"], 4, content) + self.assertListIds(content["entries"], [*new_controlled_entries, *new_not_controlled_entries], content) def test_lead_options_query(self): """ Test leads field value """ - query = ''' + query = """ query MyQuery ($id: ID!) { project(id: $id) { leadGroups { @@ -708,7 +701,7 @@ def test_lead_options_query(self): } } } - ''' + """ project = ProjectFactory.create() project2 = ProjectFactory.create() member_user = UserFactory.create() @@ -725,8 +718,8 @@ def test_lead_options_query(self): emm_entity_3 = EmmEntityFactory.create() lead1 = LeadFactory.create( - project=project, emm_entities=[emm_entity_1, emm_entity_2], - confidentiality=Lead.Confidentiality.CONFIDENTIAL) + project=project, emm_entities=[emm_entity_1, emm_entity_2], confidentiality=Lead.Confidentiality.CONFIDENTIAL + ) lead2 = LeadFactory.create(project=project, emm_entities=[emm_entity_1]) lead3 = LeadFactory.create(project=project, emm_entities=[emm_entity_3]) lead4 = LeadFactory.create(project=project2, emm_entities=[emm_entity_3]) @@ -738,38 +731,38 @@ def test_lead_options_query(self): self.force_login(member_user) # test for lead group - content = self.query_check(query, variables={'id': project.id}) - self.assertEqual(content['data']['project']['leadGroups']['totalCount'], 2) + content = self.query_check(query, variables={"id": project.id}) + self.assertEqual(content["data"]["project"]["leadGroups"]["totalCount"], 2) self.assertEqual( - set(result['id'] for result in content['data']['project']['leadGroups']['results']), - set([str(lead_group1.id), str(lead_group2.id)]) + set(result["id"] for result in content["data"]["project"]["leadGroups"]["results"]), + set([str(lead_group1.id), str(lead_group2.id)]), ) # with different project - content = self.query_check(query, variables={'id': project2.id}) - self.assertEqual(content['data']['project']['leadGroups']['totalCount'], 0) + content = self.query_check(query, variables={"id": project2.id}) + self.assertEqual(content["data"]["project"]["leadGroups"]["totalCount"], 0) # test for emm_entities # login with member_user - content = self.query_check(query, variables={'id': project.id}) - self.assertEqual(content['data']['project']['emmEntities']['totalCount'], 3) + content = self.query_check(query, variables={"id": project.id}) + self.assertEqual(content["data"]["project"]["emmEntities"]["totalCount"], 3) # login with confidential_member_user self.force_login(confidential_member_user) - content = self.query_check(query, variables={'id': project.id}) - self.assertEqual(content['data']['project']['emmEntities']['totalCount'], 3) + content = self.query_check(query, variables={"id": project.id}) + self.assertEqual(content["data"]["project"]["emmEntities"]["totalCount"], 3) # test for lead_emm_trigger # login with confidential_member_user - content = self.query_check(query, variables={'id': project.id}) - self.assertEqual(content['data']['project']['leadEmmTriggers']['totalCount'], 3) + content = self.query_check(query, variables={"id": project.id}) + self.assertEqual(content["data"]["project"]["leadEmmTriggers"]["totalCount"], 3) # test for project that user is not member - content = self.query_check(query, variables={'id': project2.id}) - self.assertEqual(content['data']['project']['leadEmmTriggers']['totalCount'], 0) + content = self.query_check(query, variables={"id": project2.id}) + self.assertEqual(content["data"]["project"]["leadEmmTriggers"]["totalCount"], 0) def test_leads_status(self): - query = ''' + query = """ query MyQuery ($id: ID!) { project(id: $id) { leads(ordering: ASC_ID) { @@ -781,45 +774,46 @@ def test_leads_status(self): } } } - ''' + """ user = UserFactory.create() project = ProjectFactory.create(analysis_framework=AnalysisFrameworkFactory.create()) project.add_member(user) lead1, _ = LeadFactory.create_batch(2, project=project) def _query_check(): - return self.query_check(query, variables={'id': project.id}) + return self.query_check(query, variables={"id": project.id}) self.force_login(user) - content = _query_check()['data']['project']['leads']['results'] + content = _query_check()["data"]["project"]["leads"]["results"] self.assertEqual(len(content), 2, content) self.assertEqual( - set([lead['status'] for lead in content]), {self.genum(Lead.Status.NOT_TAGGED)}, + set([lead["status"] for lead in content]), + {self.genum(Lead.Status.NOT_TAGGED)}, content, ) # Add entry to lead1 entry1 = EntryFactory.create(lead=lead1) - content = _query_check()['data']['project']['leads']['results'] + content = _query_check()["data"]["project"]["leads"]["results"] self.assertEqual(len(content), 2, content) - self.assertEqual(content[0]['status'], self.genum(Lead.Status.IN_PROGRESS), content) - self.assertEqual(content[1]['status'], self.genum(Lead.Status.NOT_TAGGED), content) + self.assertEqual(content[0]["status"], self.genum(Lead.Status.IN_PROGRESS), content) + self.assertEqual(content[1]["status"], self.genum(Lead.Status.NOT_TAGGED), content) # Update lead1 status to TAGGED lead1.status = Lead.Status.TAGGED - lead1.save(update_fields=['status']) - content = _query_check()['data']['project']['leads']['results'] + lead1.save(update_fields=["status"]) + content = _query_check()["data"]["project"]["leads"]["results"] self.assertEqual(len(content), 2, content) - self.assertEqual(content[0]['status'], self.genum(Lead.Status.TAGGED), content) - self.assertEqual(content[1]['status'], self.genum(Lead.Status.NOT_TAGGED), content) + self.assertEqual(content[0]["status"], self.genum(Lead.Status.TAGGED), content) + self.assertEqual(content[1]["status"], self.genum(Lead.Status.NOT_TAGGED), content) # Now update entry1 entry1.save() - content = _query_check()['data']['project']['leads']['results'] + content = _query_check()["data"]["project"]["leads"]["results"] self.assertEqual(len(content), 2, content) # -- We don't change TAGGED -> IN_PROGRESS - self.assertEqual(content[0]['status'], self.genum(Lead.Status.TAGGED), content) - self.assertEqual(content[1]['status'], self.genum(Lead.Status.NOT_TAGGED), content) + self.assertEqual(content[0]["status"], self.genum(Lead.Status.TAGGED), content) + self.assertEqual(content[1]["status"], self.genum(Lead.Status.NOT_TAGGED), content) def test_lead_group_query(self): - query = ''' + query = """ query MyQuery ($id: ID!) { project(id: $id) { leadGroups(ordering: "id") { @@ -835,7 +829,7 @@ def test_lead_group_query(self): } } } - ''' + """ project = ProjectFactory.create() project2 = ProjectFactory.create() member_user = UserFactory.create() @@ -851,28 +845,28 @@ def test_lead_group_query(self): LeadFactory.create_batch(2, project=project, lead_group=lead_group3) self.force_login(member_user) - content = self.query_check(query, variables={'id': project.id}) - self.assertEqual(content['data']['project']['leadGroups']['totalCount'], 2) + content = self.query_check(query, variables={"id": project.id}) + self.assertEqual(content["data"]["project"]["leadGroups"]["totalCount"], 2) self.assertEqual( - set(result['id'] for result in content['data']['project']['leadGroups']['results']), - set([str(lead_group1.id), str(lead_group2.id)]) + set(result["id"] for result in content["data"]["project"]["leadGroups"]["results"]), + set([str(lead_group1.id), str(lead_group2.id)]), ) - self.assertListIds(content['data']['project']['leadGroups']['results'], [lead_group1, lead_group2], content) + self.assertListIds(content["data"]["project"]["leadGroups"]["results"], [lead_group1, lead_group2], content) # login with non_member_user self.force_login(non_member_user) - content = self.query_check(query, variables={'id': project.id}) - self.assertEqual(content['data']['project']['leadGroups']['totalCount'], 0) + content = self.query_check(query, variables={"id": project.id}) + self.assertEqual(content["data"]["project"]["leadGroups"]["totalCount"], 0) # with different project self.force_login(member_user) - content = self.query_check(query, variables={'id': project2.id}) - self.assertEqual(content['data']['project']['leadGroups']['totalCount'], 1) - self.assertEqual(content['data']['project']['leadGroups']['results'][0]['id'], str(lead_group3.id)) - self.assertEqual(content['data']['project']['leadGroups']['results'][0]['leadCounts'], 2) + content = self.query_check(query, variables={"id": project2.id}) + self.assertEqual(content["data"]["project"]["leadGroups"]["totalCount"], 1) + self.assertEqual(content["data"]["project"]["leadGroups"]["results"][0]["id"], str(lead_group3.id)) + self.assertEqual(content["data"]["project"]["leadGroups"]["results"][0]["leadCounts"], 2) def test_public_lead_query(self): - query = ''' + query = """ query MyQuery ($uuid: UUID!) { publicLead(uuid: $uuid) { project { @@ -900,44 +894,47 @@ def test_public_lead_query(self): } } } - ''' + """ project = ProjectFactory.create() # User with role - non_member_user = UserFactory.create(email='non-member@x.y') - member_user = UserFactory.create(email='member@x.y') - confidential_member_user = UserFactory.create(email='confidential-member@x.y') + non_member_user = UserFactory.create(email="non-member@x.y") + member_user = UserFactory.create(email="member@x.y") + confidential_member_user = UserFactory.create(email="confidential-member@x.y") project.add_member(member_user, role=self.project_role_reader_non_confidential) project.add_member(confidential_member_user, role=self.project_role_reader) # Public project unprotected_lead = LeadFactory.create( project=project, confidentiality=Lead.Confidentiality.UNPROTECTED, - title='unprotected_lead', + title="unprotected_lead", ) restricted_lead = LeadFactory.create( project=project, confidentiality=Lead.Confidentiality.RESTRICTED, - title='restricted_lead', + title="restricted_lead", ) confidential_lead = LeadFactory.create( project=project, confidentiality=Lead.Confidentiality.CONFIDENTIAL, - title='confidential_lead', + title="confidential_lead", ) def _query_check(lead): - return self.query_check(query, variables={'uuid': str(lead.uuid)}) + return self.query_check(query, variables={"uuid": str(lead.uuid)}) cases = [ # Public Project # is_private, (public_lead, restricted_lead, confidential_lead) # : [Lead, show_project, show_lead, show_project_title] ( - False, (False, False, False), [ # Project view public leads + False, + (False, False, False), + [ # Project view public leads ( # Without login - None, [ + None, + [ [unprotected_lead, False, False, None], [restricted_lead, False, False, None], [confidential_lead, False, False, None], @@ -945,35 +942,41 @@ def _query_check(lead): ), ( # Non member user - non_member_user, [ + non_member_user, + [ [unprotected_lead, True, False, None], [restricted_lead, True, False, None], [confidential_lead, True, False, None], - ] + ], ), ( # Member user with non-confidential access - member_user, [ + member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, False, None], - ] + ], ), ( # Member user with confidential access - confidential_member_user, [ + confidential_member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, True, True], - ] + ], ), - ] + ], ), ( - False, (True, False, False), [ # Project view public leads + False, + (True, False, False), + [ # Project view public leads ( # Without login - None, [ + None, + [ [unprotected_lead, False, True, True], [restricted_lead, False, False, None], [confidential_lead, False, False, None], @@ -981,35 +984,41 @@ def _query_check(lead): ), ( # Non member user - non_member_user, [ + non_member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, False, None], [confidential_lead, True, False, None], - ] + ], ), ( # Member user with non-confidential access - member_user, [ + member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, False, None], - ] + ], ), ( # Member user with confidential access - confidential_member_user, [ + confidential_member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, True, True], - ] + ], ), - ] + ], ), ( - False, (False, True, False), [ # Project view public leads + False, + (False, True, False), + [ # Project view public leads ( # Without login - None, [ + None, + [ [unprotected_lead, False, False, None], [restricted_lead, False, True, True], [confidential_lead, False, False, None], @@ -1017,35 +1026,41 @@ def _query_check(lead): ), ( # Non member user - non_member_user, [ + non_member_user, + [ [unprotected_lead, True, False, None], [restricted_lead, True, True, True], [confidential_lead, True, False, None], - ] + ], ), ( # Member user with non-confidential access - member_user, [ + member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, False, None], - ] + ], ), ( # Member user with confidential access - confidential_member_user, [ + confidential_member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, True, True], - ] + ], ), - ] + ], ), ( - False, (False, False, True), [ # Project view public leads + False, + (False, False, True), + [ # Project view public leads ( # Without login - None, [ + None, + [ [unprotected_lead, False, False, None], [restricted_lead, False, False, None], [confidential_lead, False, True, True], @@ -1053,36 +1068,42 @@ def _query_check(lead): ), ( # Non member user - non_member_user, [ + non_member_user, + [ [unprotected_lead, True, False, None], [restricted_lead, True, False, None], [confidential_lead, True, True, True], - ] + ], ), ( # Member user with non-confidential access - member_user, [ + member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, True, True], - ] + ], ), ( # Member user with confidential access - confidential_member_user, [ + confidential_member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, True, True], - ] + ], ), - ] + ], ), # Private Project ( - True, (False, False, False), [ # Project view public leads + True, + (False, False, False), + [ # Project view public leads ( # Without login - None, [ + None, + [ [unprotected_lead, False, False, None], [restricted_lead, False, False, None], [confidential_lead, False, False, None], @@ -1090,35 +1111,41 @@ def _query_check(lead): ), ( # Non member user - non_member_user, [ + non_member_user, + [ [unprotected_lead, False, False, None], [restricted_lead, False, False, None], [confidential_lead, False, False, None], - ] + ], ), ( # Member user with non-confidential access - member_user, [ + member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, False, None], - ] + ], ), ( # Member user with confidential access - confidential_member_user, [ + confidential_member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, True, True], - ] + ], ), - ] + ], ), ( - True, (True, False, False), [ # Project view public leads + True, + (True, False, False), + [ # Project view public leads ( # Without login - None, [ + None, + [ [unprotected_lead, False, True, False], [restricted_lead, False, False, None], [confidential_lead, False, False, None], @@ -1126,35 +1153,41 @@ def _query_check(lead): ), ( # Non member user - non_member_user, [ + non_member_user, + [ [unprotected_lead, False, True, False], [restricted_lead, False, False, None], [confidential_lead, False, False, None], - ] + ], ), ( # Member user with non-confidential access - member_user, [ + member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, False, None], - ] + ], ), ( # Member user with confidential access - confidential_member_user, [ + confidential_member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, True, True], - ] + ], ), - ] + ], ), ( - True, (False, True, False), [ # Project view public leads + True, + (False, True, False), + [ # Project view public leads ( # Without login - None, [ + None, + [ [unprotected_lead, False, False, None], [restricted_lead, False, True, False], [confidential_lead, False, False, None], @@ -1162,35 +1195,41 @@ def _query_check(lead): ), ( # Non member user - non_member_user, [ + non_member_user, + [ [unprotected_lead, False, False, None], [restricted_lead, False, True, False], [confidential_lead, False, False, None], - ] + ], ), ( # Member user with non-confidential access - member_user, [ + member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, False, None], - ] + ], ), ( # Member user with confidential access - confidential_member_user, [ + confidential_member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, True, True], - ] + ], ), - ] + ], ), ( - True, (False, False, True), [ # Project view public leads + True, + (False, False, True), + [ # Project view public leads ( # Without login - None, [ + None, + [ [unprotected_lead, False, False, None], [restricted_lead, False, False, None], [confidential_lead, False, True, False], @@ -1198,29 +1237,32 @@ def _query_check(lead): ), ( # Non member user - non_member_user, [ + non_member_user, + [ [unprotected_lead, False, False, None], [restricted_lead, False, False, None], [confidential_lead, False, True, False], - ] + ], ), ( # Member user with non-confidential access - member_user, [ + member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, True, True], - ] + ], ), ( # Member user with confidential access - confidential_member_user, [ + confidential_member_user, + [ [unprotected_lead, True, True, True], [restricted_lead, True, True, True], [confidential_lead, True, True, True], - ] + ], ), - ] + ], ), ] @@ -1239,10 +1281,10 @@ def _query_check(lead): project.has_publicly_viewable_confidential_leads = project_show_confidential_leads project.save( update_fields=( - 'is_private', - 'has_publicly_viewable_unprotected_leads', - 'has_publicly_viewable_restricted_leads', - 'has_publicly_viewable_confidential_leads', + "is_private", + "has_publicly_viewable_unprotected_leads", + "has_publicly_viewable_restricted_leads", + "has_publicly_viewable_confidential_leads", ) ) for user, conditions in user_and_conditions: @@ -1251,7 +1293,7 @@ def _query_check(lead): else: self.logout() for used_lead, expect_project_membership_data, expect_lead, show_project_title in conditions: - content = _query_check(used_lead)['data']['publicLead'] + content = _query_check(used_lead)["data"]["publicLead"] check_meta = dict( project_private=is_private, project_show=dict( @@ -1264,25 +1306,25 @@ def _query_check(lead): ) # Excepted Lead if expect_lead: - self.assertIsNotNone(content['lead'], check_meta) - self.assertEqual(content['lead']['uuid'], str(used_lead.uuid)) + self.assertIsNotNone(content["lead"], check_meta) + self.assertEqual(content["lead"]["uuid"], str(used_lead.uuid)) # Show project title in Lead data. if show_project_title: - self.assertIsNotNone(content['lead']['projectTitle'], check_meta) + self.assertIsNotNone(content["lead"]["projectTitle"], check_meta) else: - self.assertIsNone(content['lead']['projectTitle'], check_meta) + self.assertIsNone(content["lead"]["projectTitle"], check_meta) else: - self.assertIsNone(content['lead'], check_meta) + self.assertIsNone(content["lead"], check_meta) # Show project with membership data if expect_project_membership_data: - self.assertIsNotNone(content['project'], check_meta) - self.assertEqual(content['project']['id'], str(used_lead.project_id)) + self.assertIsNotNone(content["project"], check_meta) + self.assertEqual(content["project"]["id"], str(used_lead.project_id)) else: - self.assertIsNone(content['project'], check_meta) + self.assertIsNone(content["project"], check_meta) class TestUserSavedLeadFilters(GraphQLTestCase): - LEAD_FILTER_SAVE_QUERY = ''' + LEAD_FILTER_SAVE_QUERY = """ query MyQuery ($id: ID!) { project(id: $id) { userSavedLeadFilter { @@ -1324,9 +1366,9 @@ class TestUserSavedLeadFilters(GraphQLTestCase): } } } - ''' + """ - LEAD_FILTER_SAVE_MUTATION = ''' + LEAD_FILTER_SAVE_MUTATION = """ mutation MyMutation ($id: ID!, $input: UserSavedLeadFilterInputType!) { project(id: $id) { leadFilterSave(data: $input) { @@ -1358,7 +1400,7 @@ class TestUserSavedLeadFilters(GraphQLTestCase): } } } - ''' + """ def setUp(self): super().setUp() @@ -1366,8 +1408,8 @@ def setUp(self): self.region1, self.region2 = RegionFactory.create_batch(2) self.project = ProjectFactory.create(analysis_framework=self.af) self.project.regions.add(self.region1) - self.non_member_user = UserFactory.create(email='non-member@x.y') - self.member_user = UserFactory.create(email='member@x.y') + self.non_member_user = UserFactory.create(email="non-member@x.y") + self.member_user = UserFactory.create(email="member@x.y") self.project.add_member(self.member_user) # Create entities used for entryFilterdData @@ -1393,9 +1435,7 @@ def _str_list(int_list): return [str(_int) for _int in int_list] def get_id_obj(objs): - return [ - dict(id=str(obj.pk)) for obj in objs - ] + return [dict(id=str(obj.pk)) for obj in objs] self.custom_filters = dict( assignees=_str_list([m_user1.pk, m_user2.pk]), @@ -1412,11 +1452,13 @@ def get_id_obj(objs): lead_created_by=_str_list([m_user2.pk, m_user3.pk]), lead_source_organizations=_str_list([org2.pk, org3.pk]), modified_by=_str_list([m_user1.pk, m_user3.pk]), - filterable_data=[dict( - filter_key=geo_widget_filter.key, - value_list=_str_list([geoarea1_1_1.pk, geoarea1_2_1.pk, geoarea2_2.pk]), - )], - ) + filterable_data=[ + dict( + filter_key=geo_widget_filter.key, + value_list=_str_list([geoarea1_1_1.pk, geoarea1_2_1.pk, geoarea2_2.pk]), + ) + ], + ), ) self.custom_filters_camel_case = json.loads(CamelCaseJSONRenderer().render(self.custom_filters)) @@ -1439,7 +1481,7 @@ def get_id_obj(objs): def test_user_saved_lead_filter_query(self): def _query_check(**kwargs): - return self.query_check(self.LEAD_FILTER_SAVE_QUERY, variables={'id': str(self.project.id)}, **kwargs) + return self.query_check(self.LEAD_FILTER_SAVE_QUERY, variables={"id": str(self.project.id)}, **kwargs) # Without login _query_check(assert_for_error=True) @@ -1447,17 +1489,17 @@ def _query_check(**kwargs): # login with non-member-user self.force_login(self.non_member_user) content = _query_check() - self.assertIsNone(content['data']['project']['userSavedLeadFilter'], content) + self.assertIsNone(content["data"]["project"]["userSavedLeadFilter"], content) UserSavedLeadFilterFactory.create(user=self.non_member_user, project=self.project) # Same here as well even with saved filter in backend content = _query_check() - self.assertIsNone(content['data']['project']['userSavedLeadFilter'], content) + self.assertIsNone(content["data"]["project"]["userSavedLeadFilter"], content) # login with member-user self.force_login(self.member_user) content = _query_check() - self.assertIsNone(content['data']['project']['userSavedLeadFilter'], content) + self.assertIsNone(content["data"]["project"]["userSavedLeadFilter"], content) UserSavedLeadFilterFactory.create( user=self.member_user, @@ -1465,26 +1507,26 @@ def _query_check(**kwargs): filters=self.custom_filters, ) content = _query_check() - self.assertIsNotNone(content['data']['project']['userSavedLeadFilter'], content) + self.assertIsNotNone(content["data"]["project"]["userSavedLeadFilter"], content) self.maxDiff = None self.assertEqual( - content['data']['project']['userSavedLeadFilter']['filtersData'], + content["data"]["project"]["userSavedLeadFilter"]["filtersData"], self.expected_filter_data_options, content, ) def test_user_saved_lead_filter_mutation(self): minput = { - 'title': 'First Filter', - 'filters': {}, + "title": "First Filter", + "filters": {}, } def _query_check(**kwargs): return self.query_check( self.LEAD_FILTER_SAVE_MUTATION, minput=minput, - mnested=['project'], - variables={'id': str(self.project.id)}, + mnested=["project"], + variables={"id": str(self.project.id)}, **kwargs, ) @@ -1498,13 +1540,13 @@ def _query_check(**kwargs): # login with member-user self.force_login(self.member_user) content = _query_check(okay=True) - self.assertIsNotNone(content['data']['project']['leadFilterSave']['result'], content) + self.assertIsNotNone(content["data"]["project"]["leadFilterSave"]["result"], content) # With some valid data - minput['filters'] = self.custom_filters_camel_case + minput["filters"] = self.custom_filters_camel_case content = _query_check(okay=True) self.assertEqual( - content['data']['project']['leadFilterSave']['result']['filtersData'], + content["data"]["project"]["leadFilterSave"]["result"]["filtersData"], self.expected_filter_data_options, content, ) diff --git a/apps/lead/tests/test_tasks.py b/apps/lead/tests/test_tasks.py index 0bd8d73738..1f791487eb 100644 --- a/apps/lead/tests/test_tasks.py +++ b/apps/lead/tests/test_tasks.py @@ -1,15 +1,15 @@ -import os import logging -from parameterized import parameterized +import os + from django.conf import settings +from lead.models import Lead +from lead.tasks import extract_from_lead +from parameterized import parameterized from deep.tests import TestCase from utils.common import get_or_write_file, makedirs, sanitize_text from utils.extractor.tests.test_web_document import HTML_URL, REDHUM_URL -from lead.tasks import extract_from_lead -from lead.models import Lead - logger = logging.getLogger(__name__) @@ -20,7 +20,7 @@ def create_lead_with_url(self, url): lead.project.is_private = False lead.project.save() - lead.text = '' + lead.text = "" lead.url = url lead.save() return lead @@ -29,16 +29,18 @@ def setUp(self): super().setUp() # This is similar to test_web_document - self.path = os.path.join(settings.TEST_DIR, 'documents_urls') + self.path = os.path.join(settings.TEST_DIR, "documents_urls") makedirs(self.path) # Create the sample lead self.lead = self.create_lead_with_url(HTML_URL) - @parameterized.expand([ - ['relief_url', HTML_URL], # Server Render Page - ['redhum_url', REDHUM_URL], # SPA - ]) + @parameterized.expand( + [ + ["relief_url", HTML_URL], # Server Render Page + ["redhum_url", REDHUM_URL], # SPA + ] + ) def test_extraction_(self, _, url): # Create the sample lead lead = self.create_lead_with_url(url) @@ -54,13 +56,13 @@ def test_extraction_(self, _, url): # This is similar to test_web_document path = os.path.join( self.path, - '.'.join(url.split('/')[-1:]), + ".".join(url.split("/")[-1:]), ) - extracted = get_or_write_file(path + '.txt', lead_preview.text_extract) + extracted = get_or_write_file(path + ".txt", lead_preview.text_extract) self.assertEqual( - ' '.join(lead_preview.text_extract.split()), - sanitize_text(' '.join(extracted.read().split())), + " ".join(lead_preview.text_extract.split()), + sanitize_text(" ".join(extracted.read().split())), ) except Exception: - logger.warning('LEAD EXTRACTION ERROR:', exc_info=True) + logger.warning("LEAD EXTRACTION ERROR:", exc_info=True) return diff --git a/apps/lead/tests/test_unit.py b/apps/lead/tests/test_unit.py index f6f9a53208..a52567d6df 100644 --- a/apps/lead/tests/test_unit.py +++ b/apps/lead/tests/test_unit.py @@ -3,6 +3,6 @@ def test_lead_text_extract_transform_tab_and_nbsp(): """Test fitler and tansform extracted word""" - extracted_text = 'Hello, this is extracted\t text that contains  .' - expected_text = 'Hello, this is extracted text that contains .' + extracted_text = "Hello, this is extracted\t text that contains  ." + expected_text = "Hello, this is extracted text that contains ." assert sanitize_text(extracted_text) == expected_text diff --git a/apps/lead/views.py b/apps/lead/views.py index b922e32770..3fea11972b 100644 --- a/apps/lead/views.py +++ b/apps/lead/views.py @@ -1,82 +1,72 @@ import copy import re -import requests import uuid as python_uuid -from django.utils import timezone +import django_filters +import requests from django.conf import settings from django.contrib.auth.models import User from django.contrib.postgres.search import TrigramSimilarity from django.db import models, transaction - -from deep import compiler # noqa: F401 -from rest_framework.decorators import action +from django.utils import timezone +from drf_yasg.utils import swagger_auto_schema +from entry.models import Entry +from lead.filter_set import LeadFilterSet, LeadGroupFilterSet +from organization.models import Organization, OrganizationType +from organization.serializers import SimpleOrganizationSerializer +from project.models import Project, ProjectMembership +from project.permissions import PROJECT_PERMISSIONS as PROJ_PERMS from rest_framework import ( - serializers, exceptions, filters, permissions, response, + serializers, status, views, viewsets, ) -from drf_yasg.utils import swagger_auto_schema - -import django_filters +from rest_framework.decorators import action +from unified_connector.sources.base import OrganizationSearch -from deep.permissions import ModifyPermission, CreateLeadPermission, DeleteLeadPermission -from deep.paginations import AutocompleteSetPagination +from deep import compiler # noqa: F401 from deep.authentication import CSRFExemptSessionAuthentication - -from lead.filter_set import ( - LeadGroupFilterSet, - LeadFilterSet, +from deep.paginations import AutocompleteSetPagination +from deep.permissions import ( + CreateLeadPermission, + DeleteLeadPermission, + ModifyPermission, ) -from project.models import Project, ProjectMembership -from project.permissions import PROJECT_PERMISSIONS as PROJ_PERMS -from organization.models import Organization, OrganizationType -from organization.serializers import SimpleOrganizationSerializer -from utils.web_info_extractor import get_web_info_extractor from utils.common import DEFAULT_HEADERS -from unified_connector.sources.base import OrganizationSearch -from entry.models import Entry +from utils.web_info_extractor import get_web_info_extractor -from .tasks import extract_from_lead -from .models import ( - LeadGroup, - Lead, - EMMEntity, - LeadEMMTrigger, - LeadPreviewImage, -) +from .models import EMMEntity, Lead, LeadEMMTrigger, LeadGroup, LeadPreviewImage from .serializers import ( - raise_or_return_existing_lead, LeadGroupSerializer, - SimpleLeadGroupSerializer, - LeadSerializer, - LeadPreviewSerializer, - check_if_url_exists, - LeadOptionsSerializer, LeadOptionsBodySerializer, + LeadOptionsSerializer, + LeadPreviewSerializer, + LeadSerializer, LegacyLeadOptionsSerializer, + SimpleLeadGroupSerializer, + check_if_url_exists, + raise_or_return_existing_lead, ) - +from .tasks import extract_from_lead valid_lead_url_regex = re.compile( # http:// or https:// - r'^(?:http)s?://' + r"^(?:http)s?://" # domain... - r'(?:(?:[A-Z0-9]' - r'(?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+' - r'(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' + r"(?:(?:[A-Z0-9]" r"(?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+" r"(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # localhost... - r'localhost|' + r"localhost|" # ...or ip - r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' + r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # optional port - r'(?::\d+)?' - r'(?:/?|[/?]\S+)$', re.IGNORECASE) + r"(?::\d+)?" r"(?:/?|[/?]\S+)$", + re.IGNORECASE, +) def _filter_users_by_projects_memberships(user_qs, projects): @@ -88,13 +78,11 @@ def _filter_users_by_projects_memberships(user_qs, projects): class LeadGroupViewSet(viewsets.ModelViewSet): serializer_class = LeadGroupSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] - filter_backends = (django_filters.rest_framework.DjangoFilterBackend, - filters.SearchFilter, filters.OrderingFilter) + permission_classes = [permissions.IsAuthenticated, ModifyPermission] + filter_backends = (django_filters.rest_framework.DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter) filterset_class = LeadGroupFilterSet authentication_classes = [CSRFExemptSessionAuthentication] - search_fields = ('title',) + search_fields = ("title",) def get_queryset(self): return LeadGroup.get_for(self.request.user) @@ -104,12 +92,13 @@ class ProjectLeadGroupViewSet(LeadGroupViewSet): """ NOTE: Only to be used by Project's action route [DONOT USE DIRECTLY] """ + pagination_class = AutocompleteSetPagination serializer_class = SimpleLeadGroupSerializer filter_backends = (filters.SearchFilter,) def get_queryset(self): - project = Project.objects.get(pk=self.request.query_params['project']) + project = Project.objects.get(pk=self.request.query_params["project"]) return LeadGroup.get_for_project(project) @@ -117,28 +106,30 @@ class LeadViewSet(viewsets.ModelViewSet): """ Lead View """ + serializer_class = LeadSerializer - permission_classes = [permissions.IsAuthenticated, CreateLeadPermission, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, CreateLeadPermission, ModifyPermission] filter_backends = (django_filters.rest_framework.DjangoFilterBackend,) # NOTE: Using LeadFilterSet for both search and ordering filterset_class = LeadFilterSet def get_serializer(self, *args, **kwargs): - data = kwargs.get('data') - project_list = data and data.get('project') + data = kwargs.get("data") + project_list = data and data.get("project") if project_list and isinstance(project_list, list): - kwargs.pop('data') - kwargs.pop('many', None) - data.pop('project') + kwargs.pop("data") + kwargs.pop("many", None) + data.pop("project") data_list = [] for project in project_list: - data_list.append({ - **data, - 'project': project, - }) + data_list.append( + { + **data, + "project": project, + } + ) return super().get_serializer( data=data_list, @@ -154,71 +145,64 @@ def get_serializer(self, *args, **kwargs): def get_queryset(self): filters = dict() - filters['entries_filter_data'] = { - f[0]: f[1] for f in self.request.data.pop('entries_filter', []) - } - if self.request.data.get('project'): - project_id = self.request.data['project'] + filters["entries_filter_data"] = {f[0]: f[1] for f in self.request.data.pop("entries_filter", [])} + if self.request.data.get("project"): + project_id = self.request.data["project"] if isinstance(project_id, list) and len(project_id) > 0: - filters['entries_filter_data']['project'] = project_id[0] + filters["entries_filter_data"]["project"] = project_id[0] else: - filters['entries_filter_data']['project'] = project_id + filters["entries_filter_data"]["project"] = project_id leads = Lead.get_for(self.request.user, filters) - lead_id = self.request.GET.get('similar') + lead_id = self.request.GET.get("similar") if lead_id: similar_lead = Lead.objects.get(id=lead_id) - leads = leads.filter(project=similar_lead.project).annotate( - similarity=TrigramSimilarity('title', similar_lead.title) - ).filter(similarity__gt=0.3).order_by('-similarity') + leads = ( + leads.filter(project=similar_lead.project) + .annotate(similarity=TrigramSimilarity("title", similar_lead.title)) + .filter(similarity__gt=0.3) + .order_by("-similarity") + ) return leads # TODO: Remove this API endpoint after client is using summary - @action( - detail=False, - permission_classes=[permissions.IsAuthenticated], - methods=['get', 'post'], - url_path='emm-summary' - ) + @action(detail=False, permission_classes=[permissions.IsAuthenticated], methods=["get", "post"], url_path="emm-summary") def emm_summary(self, request, version=None): return self.summary(request, version=version, emm_info_only=False) - @action( - detail=False, - permission_classes=[permissions.IsAuthenticated], - methods=['get', 'post'], - url_path='summary' - ) + @action(detail=False, permission_classes=[permissions.IsAuthenticated], methods=["get", "post"], url_path="summary") def summary(self, request, version=None, emm_info_only=False): - if request.method == 'GET': + if request.method == "GET": qs = self.filter_queryset(self.get_queryset()) - elif request.method == 'POST': + elif request.method == "POST": raw_filter_data = request.data filter_data = LeadFilterSet.get_processed_filter_data(raw_filter_data) qs = LeadFilterSet(data=filter_data, queryset=self.get_queryset()).qs emm_info = Lead.get_emm_summary(qs) if emm_info_only: return response.Response(emm_info) - return response.Response({ - 'total': qs.count(), - 'total_entries': Entry.objects.filter(lead__in=qs).count(), - 'total_controlled_entries': Entry.objects.filter(lead__in=qs, controlled=True).count(), - 'total_uncontrolled_entries': Entry.objects.filter(lead__in=qs, controlled=False).count(), - **emm_info, - }) + return response.Response( + { + "total": qs.count(), + "total_entries": Entry.objects.filter(lead__in=qs).count(), + "total_controlled_entries": Entry.objects.filter(lead__in=qs, controlled=True).count(), + "total_uncontrolled_entries": Entry.objects.filter(lead__in=qs, controlled=False).count(), + **emm_info, + } + ) def get_serializer_context(self): context = super().get_serializer_context() - if self.action == 'leads_filter': - context['post_is_used_for_filter'] = True + if self.action == "leads_filter": + context["post_is_used_for_filter"] = True return context @action( detail=False, permission_classes=[permissions.IsAuthenticated], - methods=['post'], + methods=["post"], serializer_class=LeadSerializer, - url_path='filter', + url_path="filter", ) def leads_filter(self, request, version=None): raw_filter_data = request.data @@ -246,29 +230,28 @@ def get_serializer_class(self): @action( detail=False, - methods=['post'], - url_path='dry-bulk-delete', + methods=["post"], + url_path="dry-bulk-delete", ) def dry_bulk_delete(self, request, project_id, version=None): - lead_ids = request.data.get('leads', []) + lead_ids = request.data.get("leads", []) tbd_entities = Lead.get_associated_entities(project_id, lead_ids) return response.Response(tbd_entities, status=status.HTTP_200_OK) @action( detail=False, - methods=['post'], - url_path='bulk-delete', + methods=["post"], + url_path="bulk-delete", ) def bulk_delete(self, request, project_id, version=None): - lead_ids = request.data.get('leads', []) + lead_ids = request.data.get("leads", []) Lead.objects.filter(project_id=project_id, id__in=lead_ids).delete() return response.Response(status=status.HTTP_204_NO_CONTENT) class LeadPreviewViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = LeadPreviewSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_queryset(self): return Lead.get_for(self.request.user) @@ -278,198 +261,227 @@ class LeadOptionsView(views.APIView): """ Options for various attributes related to lead """ + permission_classes = [permissions.IsAuthenticated] # LEGACY SUPPORT @swagger_auto_schema(responses={200: LegacyLeadOptionsSerializer()}) def get(self, request, version=None): - project_query = request.GET.get('projects') - fields_query = request.GET.get('fields') + project_query = request.GET.get("projects") + fields_query = request.GET.get("fields") projects = Project.get_for_member(request.user) if project_query: - projects = projects.filter(id__in=project_query.split(',')) + projects = projects.filter(id__in=project_query.split(",")) fields = None if fields_query: - fields = fields_query.split(',') + fields = fields_query.split(",") options = {} project_filter = models.Q(project__in=projects) - options['lead_group'] = [ - { - 'key': group.id, - 'value': group.title, - } for group in LeadGroup.objects.filter(project_filter).distinct() - ] if (fields is None or 'lead_group' in fields) else [] + options["lead_group"] = ( + [ + { + "key": group.id, + "value": group.title, + } + for group in LeadGroup.objects.filter(project_filter).distinct() + ] + if (fields is None or "lead_group" in fields) + else [] + ) - options['assignee'] = [ - { - 'key': user.id, - 'value': user.profile.get_display_name(), - } for user in _filter_users_by_projects_memberships( - User.objects.all(), projects, - ).prefetch_related('profile').distinct() - ] if (fields is None or 'assignee' in fields) else [] - - options['confidentiality'] = [ + options["assignee"] = ( + [ + { + "key": user.id, + "value": user.profile.get_display_name(), + } + for user in _filter_users_by_projects_memberships( + User.objects.all(), + projects, + ) + .prefetch_related("profile") + .distinct() + ] + if (fields is None or "assignee" in fields) + else [] + ) + + options["confidentiality"] = [ { - 'key': c[0], - 'value': c[1], - } for c in Lead.Confidentiality.choices + "key": c[0], + "value": c[1], + } + for c in Lead.Confidentiality.choices ] - options['status'] = [ + options["status"] = [ { - 'key': s[0], - 'value': s[1], - } for s in Lead.Status.choices + "key": s[0], + "value": s[1], + } + for s in Lead.Status.choices ] - options['priority'] = [ + options["priority"] = [ { - 'key': s[0], - 'value': s[1], - } for s in Lead.Priority.choices + "key": s[0], + "value": s[1], + } + for s in Lead.Priority.choices ] - options['project'] = [ - { - 'key': project.id, - 'value': project.title, - } for project in projects.distinct() - ] if (fields is None or 'projects' in fields) else [] + options["project"] = ( + [ + { + "key": project.id, + "value": project.title, + } + for project in projects.distinct() + ] + if (fields is None or "projects" in fields) + else [] + ) # Create Emm specific options - options['emm_entities'] = EMMEntity.objects.filter( - lead__project__in=projects - ).distinct().values('name').annotate( - total_count=models.Count('lead'), - label=models.F('name'), - key=models.F('id'), - ).values('key', 'label', 'total_count').order_by('name') - - options['emm_keywords'] = LeadEMMTrigger.objects.filter( - lead__project__in=projects - ).values('emm_keyword').annotate( - total_count=models.Sum('count'), - key=models.F('emm_keyword'), - label=models.F('emm_keyword') - ).order_by('emm_keyword') - - options['emm_risk_factors'] = LeadEMMTrigger.objects.filter( - ~models.Q(emm_risk_factor=''), - ~models.Q(emm_risk_factor=None), - lead__project__in=projects, - ).values('emm_risk_factor').annotate( - total_count=models.Sum('count'), - key=models.F('emm_risk_factor'), - label=models.F('emm_risk_factor'), - ).order_by('emm_risk_factor') + options["emm_entities"] = ( + EMMEntity.objects.filter(lead__project__in=projects) + .distinct() + .values("name") + .annotate( + total_count=models.Count("lead"), + label=models.F("name"), + key=models.F("id"), + ) + .values("key", "label", "total_count") + .order_by("name") + ) + + options["emm_keywords"] = ( + LeadEMMTrigger.objects.filter(lead__project__in=projects) + .values("emm_keyword") + .annotate(total_count=models.Sum("count"), key=models.F("emm_keyword"), label=models.F("emm_keyword")) + .order_by("emm_keyword") + ) + + options["emm_risk_factors"] = ( + LeadEMMTrigger.objects.filter( + ~models.Q(emm_risk_factor=""), + ~models.Q(emm_risk_factor=None), + lead__project__in=projects, + ) + .values("emm_risk_factor") + .annotate( + total_count=models.Sum("count"), + key=models.F("emm_risk_factor"), + label=models.F("emm_risk_factor"), + ) + .order_by("emm_risk_factor") + ) # Add info about if the project has emm leads, just check if entities or keywords present - options['has_emm_leads'] = (not not options['emm_entities']) or (not not options['emm_keywords']) + options["has_emm_leads"] = (not not options["emm_entities"]) or (not not options["emm_keywords"]) - options['organization_types'] = [ + options["organization_types"] = [ { - 'key': organization_type.id, - 'value': organization_type.title, - } for organization_type in OrganizationType.objects.all() + "key": organization_type.id, + "value": organization_type.title, + } + for organization_type in OrganizationType.objects.all() ] return response.Response(LegacyLeadOptionsSerializer(options).data) - @swagger_auto_schema( - request_body=LeadOptionsBodySerializer(), - responses={200: LeadOptionsSerializer()} - ) + @swagger_auto_schema(request_body=LeadOptionsBodySerializer(), responses={200: LeadOptionsSerializer()}) def post(self, request, version=None): serializer = LeadOptionsBodySerializer(data=request.data) serializer.is_valid(raise_exception=True) fields = serializer.data - projects_id = fields['projects'] - lead_groups_id = fields['lead_groups'] - organizations_id = fields['organizations'] - members_id = fields['members'] - emm_entities = fields['emm_entities'] - emm_keywords = fields['emm_keywords'] - emm_risk_factors = fields['emm_risk_factors'] - organization_type_ids = fields['organization_types'] + projects_id = fields["projects"] + lead_groups_id = fields["lead_groups"] + organizations_id = fields["organizations"] + members_id = fields["members"] + emm_entities = fields["emm_entities"] + emm_keywords = fields["emm_keywords"] + emm_risk_factors = fields["emm_risk_factors"] + organization_type_ids = fields["organization_types"] projects = Project.get_for_member(request.user).filter( id__in=projects_id, ) if not projects.exists(): - raise exceptions.NotFound('Provided projects not found') + raise exceptions.NotFound("Provided projects not found") project_filter = models.Q(project__in=projects) members_qs = User.objects.filter(id__in=members_id) if len(members_id) else User.objects options = { - 'projects': projects, - + "projects": projects, # Static Options - 'confidentiality': [ + "confidentiality": [ { - 'key': c[0], - 'value': c[1], - } for c in Lead.Confidentiality.choices + "key": c[0], + "value": c[1], + } + for c in Lead.Confidentiality.choices ], - 'status': [ + "status": [ { - 'key': s[0], - 'value': s[1], - } for s in Lead.Status.choices + "key": s[0], + "value": s[1], + } + for s in Lead.Status.choices ], - 'priority': [ + "priority": [ { - 'key': s[0], - 'value': s[1], - } for s in Lead.Priority.choices + "key": s[0], + "value": s[1], + } + for s in Lead.Priority.choices ], - # Dynamic Options - - 'lead_groups': LeadGroup.objects.filter(project_filter, id__in=lead_groups_id).distinct(), - 'members': _filter_users_by_projects_memberships(members_qs, projects)\ - .prefetch_related('profile').distinct(), - 'organizations': Organization.objects.filter(id__in=organizations_id).distinct(), - + "lead_groups": LeadGroup.objects.filter(project_filter, id__in=lead_groups_id).distinct(), + "members": _filter_users_by_projects_memberships(members_qs, projects).prefetch_related("profile").distinct(), + "organizations": Organization.objects.filter(id__in=organizations_id).distinct(), # EMM specific options - 'emm_entities': EMMEntity.objects.filter( + "emm_entities": EMMEntity.objects.filter( lead__project__in=projects, name__in=emm_entities, - ).distinct().values('name').annotate( - total_count=models.Count('lead'), - label=models.F('name'), - key=models.F('id'), - ).values('key', 'label', 'total_count').order_by('name'), - - 'emm_keywords': LeadEMMTrigger.objects.filter( - emm_keyword__in=emm_keywords, - lead__project__in=projects - ).values('emm_keyword').annotate( - total_count=models.Sum('count'), - key=models.F('emm_keyword'), - label=models.F('emm_keyword') - ).values('key', 'label', 'total_count').order_by('emm_keyword'), - - 'emm_risk_factors': LeadEMMTrigger.objects.filter( + ) + .distinct() + .values("name") + .annotate( + total_count=models.Count("lead"), + label=models.F("name"), + key=models.F("id"), + ) + .values("key", "label", "total_count") + .order_by("name"), + "emm_keywords": LeadEMMTrigger.objects.filter(emm_keyword__in=emm_keywords, lead__project__in=projects) + .values("emm_keyword") + .annotate(total_count=models.Sum("count"), key=models.F("emm_keyword"), label=models.F("emm_keyword")) + .values("key", "label", "total_count") + .order_by("emm_keyword"), + "emm_risk_factors": LeadEMMTrigger.objects.filter( emm_risk_factor__in=emm_risk_factors, lead__project__in=projects, - ).values('emm_risk_factor').annotate( - total_count=models.Sum('count'), - key=models.F('emm_risk_factor'), - label=models.F('emm_risk_factor'), - ).order_by('emm_risk_factor'), - - 'has_emm_leads': ( - EMMEntity.objects.filter(lead__project__in=projects).exists() or - LeadEMMTrigger.objects.filter(lead__project__in=projects).exists() + ) + .values("emm_risk_factor") + .annotate( + total_count=models.Sum("count"), + key=models.F("emm_risk_factor"), + label=models.F("emm_risk_factor"), + ) + .order_by("emm_risk_factor"), + "has_emm_leads": ( + EMMEntity.objects.filter(lead__project__in=projects).exists() + or LeadEMMTrigger.objects.filter(lead__project__in=projects).exists() ), - 'organization_types': OrganizationType.objects.filter(id__in=organization_type_ids).distinct(), + "organization_types": OrganizationType.objects.filter(id__in=organization_type_ids).distinct(), } return response.Response(LeadOptionsSerializer(options).data) @@ -478,6 +490,7 @@ class LeadExtractionTriggerView(views.APIView): """ A trigger for extracting lead to generate previews """ + permission_classes = [permissions.IsAuthenticated] def get(self, request, lead_id, version=None): @@ -490,25 +503,28 @@ def get(self, request, lead_id, version=None): if not settings.TESTING: transaction.on_commit(lambda: extract_from_lead.delay(lead_id)) - return response.Response({ - 'extraction_triggered': lead_id, - }) + return response.Response( + { + "extraction_triggered": lead_id, + } + ) class LeadWebsiteFetch(views.APIView): """ Get Information about the website """ + permission_classes = [permissions.IsAuthenticated] def post(self, request, *args, **kwargs): - url = request.data.get('url') + url = request.data.get("url") return self.website_fetch(url) def get(self, request, *args, **kwargs): - url = request.query_params.get('url') + url = request.query_params.get("url") response = self.website_fetch(url) - response['Cache-Control'] = 'max-age={}'.format(60 * 60) + response["Cache-Control"] = "max-age={}".format(60 * 60) return response def website_fetch(self, url): @@ -516,45 +532,39 @@ def website_fetch(self, url): http_url = url if not valid_lead_url_regex.match(url): - return response.Response({ - 'error': 'Url is not valid' - }, status=status.HTTP_400_BAD_REQUEST) + return response.Response({"error": "Url is not valid"}, status=status.HTTP_400_BAD_REQUEST) - if url.find('http://') >= 0: - https_url = url.replace('http://', 'https://', 1) + if url.find("http://") >= 0: + https_url = url.replace("http://", "https://", 1) else: - http_url = url.replace('https://', 'http://', 1) + http_url = url.replace("https://", "http://", 1) try: # Try with https - r = requests.head( - https_url, headers=DEFAULT_HEADERS, - timeout=settings.LEAD_WEBSITE_FETCH_TIMEOUT - ) + r = requests.head(https_url, headers=DEFAULT_HEADERS, timeout=settings.LEAD_WEBSITE_FETCH_TIMEOUT) except requests.exceptions.RequestException: https_url = None # Try with http try: - r = requests.head( - http_url, headers=DEFAULT_HEADERS, - timeout=settings.LEAD_WEBSITE_FETCH_TIMEOUT - ) + r = requests.head(http_url, headers=DEFAULT_HEADERS, timeout=settings.LEAD_WEBSITE_FETCH_TIMEOUT) except requests.exceptions.RequestException: # doesn't work return response.Response( - {'error': 'can\'t fetch url'}, + {"error": "can't fetch url"}, status=status.HTTP_400_BAD_REQUEST, ) - return response.Response({ - 'headers': dict(r.headers), - 'httpsUrl': https_url, - 'httpUrl': http_url, - 'timestamp': timezone.now().timestamp(), - }) + return response.Response( + { + "headers": dict(r.headers), + "httpsUrl": https_url, + "httpUrl": http_url, + "timestamp": timezone.now().timestamp(), + } + ) -class WebInfoViewMixin(): +class WebInfoViewMixin: permission_classes = [permissions.IsAuthenticated] # FIXME: This is also used by chrome-extension, use csrf properly authentication_classes = [CSRFExemptSessionAuthentication] @@ -575,9 +585,7 @@ def _process_data( ): project = None if country: - project = Project.get_for_member(request.user).filter( - regions__title__icontains=country - ).first() + project = Project.get_for_member(request.user).filter(regions__title__icontains=country).first() project = project or request.user.profile.last_active_project organization_search = OrganizationSearch( [source_raw, *authors_raw], @@ -585,18 +593,15 @@ def _process_data( request.user, ) organization_context = { - 'source': self.get_organization(source_raw, organization_search), - 'authors': [ - self.get_organization(author, organization_search) - for author in authors_raw - ], - 'source_raw': source_raw, - 'authors_raw': authors_raw, + "source": self.get_organization(source_raw, organization_search), + "authors": [self.get_organization(author, organization_search) for author in authors_raw], + "source_raw": source_raw, + "authors_raw": authors_raw, } context = { **organization_context, - 'project': project and project.id, - 'existing': check_if_url_exists(url, request.user, project), + "project": project and project.id, + "existing": check_if_url_exists(url, request.user, project), } return context @@ -606,8 +611,9 @@ class WebInfoExtractView(WebInfoViewMixin, views.APIView): """ Extract information from a website for new lead """ + def post(self, request, version=None): - url = request.data.get('url') + url = request.data.get("url") extractor = get_web_info_extractor(url) date = extractor.get_date() @@ -625,13 +631,15 @@ def post(self, request, version=None): country=country, ) - return response.Response({ - **context, - 'title': title, - 'date': date, - 'country': country, - 'url': url, - }) + return response.Response( + { + **context, + "title": title, + "date": date, + "country": country, + "url": url, + } + ) class WebInfoDataView(WebInfoViewMixin, views.APIView): @@ -641,10 +649,10 @@ class WebInfoDataView(WebInfoViewMixin, views.APIView): """ def post(self, request, version=None): - source_raw = request.data.get('source_raw') - authors_raw = request.data.get('authors_raw') - url = request.data.get('url') - country = request.data.get('country') + source_raw = request.data.get("source_raw") + authors_raw = request.data.get("authors_raw") + url = request.data.get("url") + country = request.data.get("country") context = self._process_data( request, @@ -667,6 +675,7 @@ class BaseCopyView(views.APIView): - CLONE_ENTITY_NAME - CLONE_ENTITY """ + permission_classes = [permissions.IsAuthenticated] def __init__(self): @@ -676,7 +685,7 @@ def __init__(self): self.CLONE_ENTITY_NAME self.CLONE_ENTITY except AttributeError as e: - raise Exception(f'{self.__class__.__name__} attributes are not defined properly', str(e)) + raise Exception(f"{self.__class__.__name__} attributes are not defined properly", str(e)) def get_clone_context(self, request): return {} @@ -684,19 +693,24 @@ def get_clone_context(self, request): # Clone Lead @classmethod def clone_entity(cls, original_lead, project_id, user, context): - raise Exception('This method should be defined') + raise Exception("This method should be defined") @classmethod def get_project_ids_with_create_access(cls, request): """ Project ids with create access for given entity """ - project_ids = ProjectMembership.objects.filter( - project_id__in=request.data.get('projects', []), - member=request.user, - ).annotate( - add_permission=models.F(cls.CLONE_ROLE).bitand(cls.CLONE_PERMISSION.create), - ).filter(add_permission=cls.CLONE_PERMISSION.create).values_list('project_id', flat=True) + project_ids = ( + ProjectMembership.objects.filter( + project_id__in=request.data.get("projects", []), + member=request.user, + ) + .annotate( + add_permission=models.F(cls.CLONE_ROLE).bitand(cls.CLONE_PERMISSION.create), + ) + .filter(add_permission=cls.CLONE_PERMISSION.create) + .values_list("project_id", flat=True) + ) return project_ids @transaction.atomic @@ -704,7 +718,7 @@ def post(self, request, *args, **kwargs): context = self.get_clone_context(request) project_ids = self.get_project_ids_with_create_access(request) - entities = self.CLONE_ENTITY.get_for(request.user).filter(pk__in=request.data.get(f'{self.CLONE_ENTITY_NAME}s', [])) + entities = self.CLONE_ENTITY.get_for(request.user).filter(pk__in=request.data.get(f"{self.CLONE_ENTITY_NAME}s", [])) processed_entity = [] processed_entity_by_project = {} @@ -713,17 +727,20 @@ def post(self, request, *args, **kwargs): processed_entity.append(entity.pk) edit_or_create_permission = self.CLONE_PERMISSION.create | self.CLONE_PERMISSION.modify - edit_or_create_membership = ProjectMembership.objects.filter( - member=request.user, - project=entity.project, - ).annotate( - clone_perm=models.F(self.CLONE_ROLE).bitand(edit_or_create_permission) - ).filter(clone_perm__gt=0).first() + edit_or_create_membership = ( + ProjectMembership.objects.filter( + member=request.user, + project=entity.project, + ) + .annotate(clone_perm=models.F(self.CLONE_ROLE).bitand(edit_or_create_permission)) + .filter(clone_perm__gt=0) + .first() + ) if not edit_or_create_membership: raise exceptions.PermissionDenied( - 'You do not have enough permissions to' - f'clone {self.CLONE_ENTITY_NAME} from the project {entity.project.title}' + "You do not have enough permissions to" + f"clone {self.CLONE_ENTITY_NAME} from the project {entity.project.title}" ) for project_id in project_ids: @@ -733,31 +750,39 @@ def post(self, request, *args, **kwargs): # NOTE: To clone entity to another project p_entity = self.clone_entity(entity, project_id, request.user, context) if p_entity: - processed_entity_by_project[project_id] = ( - processed_entity_by_project.get(project_id) or [] - ) + [p_entity.pk] + processed_entity_by_project[project_id] = (processed_entity_by_project.get(project_id) or []) + [p_entity.pk] - return response.Response({ - 'projects': project_ids, - f'{self.CLONE_ENTITY_NAME}s': processed_entity, - f'{self.CLONE_ENTITY_NAME}s_by_projects': processed_entity_by_project, - }, status=201) + return response.Response( + { + "projects": project_ids, + f"{self.CLONE_ENTITY_NAME}s": processed_entity, + f"{self.CLONE_ENTITY_NAME}s_by_projects": processed_entity_by_project, + }, + status=201, + ) class LeadCopyView(BaseCopyView): """ Copy lead to another project """ + CLONE_PERMISSION = PROJ_PERMS.lead - CLONE_ROLE = 'role__lead_permissions' - CLONE_ENTITY_NAME = 'lead' + CLONE_ROLE = "role__lead_permissions" + CLONE_ENTITY_NAME = "lead" CLONE_ENTITY = Lead @classmethod def clone_or_get_lead(cls, lead, project_id, user, context, create_access_project_ids): """Clone or return existing cloned Lead""" existing_lead = raise_or_return_existing_lead( - project_id, lead, lead.source_type, lead.url, lead.text, lead.attachment, return_lead=True, + project_id, + lead, + lead.source_type, + lead.url, + lead.text, + lead.attachment, + return_lead=True, ) if existing_lead: return existing_lead, False @@ -779,7 +804,7 @@ def _get_clone_ready(obj, lead): return obj # LeadGroup? - preview = original_lead.leadpreview if hasattr(lead, 'leadpreview') else None + preview = original_lead.leadpreview if hasattr(lead, "leadpreview") else None preview_images = original_lead.images.all() emm_triggers = original_lead.emm_triggers.all() emm_entities = original_lead.emm_entities.all() @@ -791,7 +816,12 @@ def _get_clone_ready(obj, lead): # By default it raises error if not skip_existing_check: raise_or_return_existing_lead( - project_id, lead, lead.source_type, lead.url, lead.text, lead.attachment, + project_id, + lead, + lead.source_type, + lead.url, + lead.text, + lead.attachment, ) # return existing lead except serializers.ValidationError: @@ -812,11 +842,7 @@ def _get_clone_ready(obj, lead): lead.authors.set(authors) # Clone Many to one Fields - LeadPreviewImage.objects.bulk_create([ - _get_clone_ready(image, lead) for image in preview_images - ]) - LeadEMMTrigger.objects.bulk_create([ - _get_clone_ready(emm_trigger, lead) for emm_trigger in emm_triggers - ]) + LeadPreviewImage.objects.bulk_create([_get_clone_ready(image, lead) for image in preview_images]) + LeadEMMTrigger.objects.bulk_create([_get_clone_ready(emm_trigger, lead) for emm_trigger in emm_triggers]) return lead diff --git a/apps/notification/__init__.py b/apps/notification/__init__.py index 332ad7f2b4..6bccbaabd5 100644 --- a/apps/notification/__init__.py +++ b/apps/notification/__init__.py @@ -1 +1 @@ -default_app_config = 'notification.apps.NotificationConfig' +default_app_config = "notification.apps.NotificationConfig" diff --git a/apps/notification/admin.py b/apps/notification/admin.py index b1c7aad866..c76394cccd 100644 --- a/apps/notification/admin.py +++ b/apps/notification/admin.py @@ -1,14 +1,12 @@ from django.contrib import admin -from .models import ( - Notification, - Assignment -) + +from .models import Assignment, Notification @admin.register(Notification) class Notification(admin.ModelAdmin): - list_display = ('receiver', 'project', 'notification_type', 'timestamp', 'status') - list_filter = ('notification_type', 'status') + list_display = ("receiver", "project", "notification_type", "timestamp", "status") + list_filter = ("notification_type", "status") @admin.register(Assignment) diff --git a/apps/notification/apps.py b/apps/notification/apps.py index aa6ca54082..b99ae2da1a 100644 --- a/apps/notification/apps.py +++ b/apps/notification/apps.py @@ -2,7 +2,7 @@ class NotificationConfig(AppConfig): - name = 'notification' + name = "notification" def ready(self): - from . import receivers # noqa + from . import receivers # noqa diff --git a/apps/notification/dataloaders.py b/apps/notification/dataloaders.py index 464956c3ca..abb4a6e870 100644 --- a/apps/notification/dataloaders.py +++ b/apps/notification/dataloaders.py @@ -1,10 +1,8 @@ -from django.utils.functional import cached_property from django.db import models - -from promise import Promise - -from notification.models import Assignment +from django.utils.functional import cached_property from lead.models import Lead +from notification.models import Assignment +from promise import Promise from quality_assurance.models import EntryReviewComment from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin @@ -16,11 +14,7 @@ def get_model_name(model: models.Model) -> str: class AssignmentLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - assignment_qs = list( - Assignment.objects - .filter(id__in=keys) - .values_list('id', 'content_type__model', 'object_id') - ) + assignment_qs = list(Assignment.objects.filter(id__in=keys).values_list("id", "content_type__model", "object_id")) leads_id = [] entry_review_comment_id = [] @@ -33,36 +27,22 @@ def batch_load_fn(self, keys): _lead_id_map = {} - for _id, title in Lead.objects.filter(id__in=leads_id).values_list('id', 'title'): - _lead_id_map[_id] = dict( - id=_id, - title=title - ) + for _id, title in Lead.objects.filter(id__in=leads_id).values_list("id", "title"): + _lead_id_map[_id] = dict(id=_id, title=title) _entry_review_comment_id_map = {} - for _id, entry_id, lead_id in EntryReviewComment.objects.filter( - id__in=entry_review_comment_id).values_list( - 'id', - 'entry__id', - 'entry__lead_id' + for _id, entry_id, lead_id in EntryReviewComment.objects.filter(id__in=entry_review_comment_id).values_list( + "id", "entry__id", "entry__lead_id" ): - _entry_review_comment_id_map[_id] = dict( - id=_id, - entry_id=entry_id, - lead_id=lead_id - ) + _entry_review_comment_id_map[_id] = dict(id=_id, entry_id=entry_id, lead_id=lead_id) _result = { _id: { - 'content_type': content_type, - 'lead': ( - _lead_id_map.get(object_id) - if content_type == get_model_name(Lead) else None - ), - 'entry_review_comment': ( - _entry_review_comment_id_map.get(object_id) - if content_type == get_model_name(EntryReviewComment) else None + "content_type": content_type, + "lead": (_lead_id_map.get(object_id) if content_type == get_model_name(Lead) else None), + "entry_review_comment": ( + _entry_review_comment_id_map.get(object_id) if content_type == get_model_name(EntryReviewComment) else None ), } for _id, content_type, object_id in assignment_qs diff --git a/apps/notification/enums.py b/apps/notification/enums.py index 83dcf2b9ec..ebe05b6ade 100644 --- a/apps/notification/enums.py +++ b/apps/notification/enums.py @@ -1,15 +1,16 @@ import graphene +from lead.models import Lead +from quality_assurance.models import EntryReviewComment from utils.graphene.enums import ( convert_enum_to_graphene_enum, get_enum_name_from_django_field, ) -from lead.models import Lead -from quality_assurance.models import EntryReviewComment + from .models import Notification -NotificationTypeEnum = convert_enum_to_graphene_enum(Notification.Type, name='NotificationTypeEnum') -NotificationStatusEnum = convert_enum_to_graphene_enum(Notification.Status, name='NotificationStatusEnum') +NotificationTypeEnum = convert_enum_to_graphene_enum(Notification.Type, name="NotificationTypeEnum") +NotificationStatusEnum = convert_enum_to_graphene_enum(Notification.Status, name="NotificationStatusEnum") enum_map = { get_enum_name_from_django_field(field): enum diff --git a/apps/notification/factories.py b/apps/notification/factories.py index 93f884ac23..d4d7eb00c1 100644 --- a/apps/notification/factories.py +++ b/apps/notification/factories.py @@ -1,9 +1,5 @@ from factory.django import DjangoModelFactory - -from notification.models import ( - Assignment, - Notification, -) +from notification.models import Assignment, Notification class NotificationFactory(DjangoModelFactory): diff --git a/apps/notification/filter_set.py b/apps/notification/filter_set.py index d16c08521d..825b4af8ac 100644 --- a/apps/notification/filter_set.py +++ b/apps/notification/filter_set.py @@ -2,8 +2,9 @@ from django.db.models import Q from utils.graphene.filters import SimpleInputFilter -from .models import Notification, Assignment + from .enums import NotificationStatusEnum, NotificationTypeEnum +from .models import Assignment, Notification class NotificationFilterSet(django_filters.FilterSet): @@ -11,77 +12,77 @@ class NotificationFilterSet(django_filters.FilterSet): Notification filter set """ - TRUE = 'true' - FALSE = 'false' + TRUE = "true" + FALSE = "false" BOOLEAN_CHOICES = ( - (TRUE, 'True'), - (FALSE, 'False'), + (TRUE, "True"), + (FALSE, "False"), ) timestamp__lt = django_filters.DateFilter( - field_name='timestamp', lookup_expr='lt', - input_formats=['%Y-%m-%d%z'], + field_name="timestamp", + lookup_expr="lt", + input_formats=["%Y-%m-%d%z"], ) timestamp__gt = django_filters.DateFilter( - field_name='timestamp', lookup_expr='gt', - input_formats=['%Y-%m-%d%z'], + field_name="timestamp", + lookup_expr="gt", + input_formats=["%Y-%m-%d%z"], ) timestamp__lte = django_filters.DateFilter( - field_name='timestamp', lookup_expr='lte', - input_formats=['%Y-%m-%d%z'], + field_name="timestamp", + lookup_expr="lte", + input_formats=["%Y-%m-%d%z"], ) timestamp__gte = django_filters.DateFilter( - field_name='timestamp', lookup_expr='gte', - input_formats=['%Y-%m-%d%z'], - ) - is_pending = django_filters.ChoiceFilter( - label='Action Status', - method='is_pending_filter', - choices=BOOLEAN_CHOICES + field_name="timestamp", + lookup_expr="gte", + input_formats=["%Y-%m-%d%z"], ) + is_pending = django_filters.ChoiceFilter(label="Action Status", method="is_pending_filter", choices=BOOLEAN_CHOICES) class Meta: model = Notification fields = { - 'timestamp': ['exact', 'lt', 'gt', 'lte', 'gte'], - 'status': ['exact'], - 'notification_type': ['exact'], + "timestamp": ["exact", "lt", "gt", "lte", "gte"], + "status": ["exact"], + "notification_type": ["exact"], } def is_pending_filter(self, queryset, name, value): if value == self.TRUE: return queryset.filter( - data__status='pending', + data__status="pending", ).distinct() elif value == self.FALSE: - return queryset.filter( - ~Q(data__status='pending') | Q(data__status__isnull=True) - ).distinct() + return queryset.filter(~Q(data__status="pending") | Q(data__status__isnull=True)).distinct() return queryset class AssignmentFilterSet(django_filters.FilterSet): class Meta: model = Assignment - fields = ('project', 'is_done') + fields = ("project", "is_done") # -------------------- Graphql Filters ----------------------------------- class NotificationGqlFilterSet(django_filters.FilterSet): timestamp = django_filters.DateTimeFilter( - field_name='timestamp', + field_name="timestamp", input_formats=[django_filters.fields.IsoDateTimeField.ISO_8601], ) timestamp_lte = django_filters.DateTimeFilter( - field_name='timestamp', lookup_expr='lte', + field_name="timestamp", + lookup_expr="lte", input_formats=[django_filters.fields.IsoDateTimeField.ISO_8601], ) timestamp_gte = django_filters.DateTimeFilter( - field_name='timestamp', lookup_expr='gte', + field_name="timestamp", + lookup_expr="gte", input_formats=[django_filters.fields.IsoDateTimeField.ISO_8601], ) - is_pending = django_filters.BooleanFilter(label='Action Status', method='is_pending_filter') + is_pending = django_filters.BooleanFilter(label="Action Status", method="is_pending_filter") notification_type = SimpleInputFilter(NotificationTypeEnum) status = SimpleInputFilter(NotificationStatusEnum) @@ -92,11 +93,9 @@ class Meta: def is_pending_filter(self, queryset, _, value): if value is True: return queryset.filter( - data__status='pending', + data__status="pending", ).distinct() elif value is False: - return queryset.filter( - ~Q(data__status='pending') | Q(data__status__isnull=True) - ).distinct() + return queryset.filter(~Q(data__status="pending") | Q(data__status__isnull=True)).distinct() # If none return queryset diff --git a/apps/notification/models.py b/apps/notification/models.py index a34d3d5917..b0b30d38e8 100644 --- a/apps/notification/models.py +++ b/apps/notification/models.py @@ -2,40 +2,33 @@ from django.contrib.contenttypes.models import ContentType from django.db import models from django.utils import timezone - -from user.models import User from project.models import Project +from user.models import User class Notification(models.Model): class Type(models.TextChoices): # Project Joins Notification Types - PROJECT_JOIN_REQUEST = 'project_join_request', 'Join project request' - PROJECT_JOIN_REQUEST_ABORT = 'project_join_request_abort', 'Join project request abort' - PROJECT_JOIN_RESPONSE = 'project_join_response', 'Join project response' + PROJECT_JOIN_REQUEST = "project_join_request", "Join project request" + PROJECT_JOIN_REQUEST_ABORT = "project_join_request_abort", "Join project request abort" + PROJECT_JOIN_RESPONSE = "project_join_response", "Join project response" # Entry Comment Notifications Types - ENTRY_COMMENT_ADD = 'entry_comment_add', 'Entry Comment Add' - ENTRY_COMMENT_MODIFY = 'entry_comment_modify', 'Entry Comment Modify' - ENTRY_COMMENT_ASSIGNEE_CHANGE = 'entry_comment_assignee_change', 'Entry Comment Assignee Change' - ENTRY_COMMENT_REPLY_ADD = 'entry_comment_reply_add', 'Entry Comment Reply Add' - ENTRY_COMMENT_REPLY_MODIFY = 'entry_comment_reply_modify', 'Entry Comment Reply Modify' - ENTRY_COMMENT_RESOLVED = 'entry_comment_resolved', 'Entry Comment Resolved' + ENTRY_COMMENT_ADD = "entry_comment_add", "Entry Comment Add" + ENTRY_COMMENT_MODIFY = "entry_comment_modify", "Entry Comment Modify" + ENTRY_COMMENT_ASSIGNEE_CHANGE = "entry_comment_assignee_change", "Entry Comment Assignee Change" + ENTRY_COMMENT_REPLY_ADD = "entry_comment_reply_add", "Entry Comment Reply Add" + ENTRY_COMMENT_REPLY_MODIFY = "entry_comment_reply_modify", "Entry Comment Reply Modify" + ENTRY_COMMENT_RESOLVED = "entry_comment_resolved", "Entry Comment Resolved" # Entry Comment Review Notifications Types - ENTRY_REVIEW_COMMENT_ADD = 'entry_review_comment_add', 'Entry Review Comment Add' - ENTRY_REVIEW_COMMENT_MODIFY = 'entry_review_comment_modify', 'Entry Review Comment Modify' + ENTRY_REVIEW_COMMENT_ADD = "entry_review_comment_add", "Entry Review Comment Add" + ENTRY_REVIEW_COMMENT_MODIFY = "entry_review_comment_modify", "Entry Review Comment Modify" class Status(models.TextChoices): - SEEN = 'seen', 'Seen' - UNSEEN = 'unseen', 'Unseen' + SEEN = "seen", "Seen" + UNSEEN = "unseen", "Unseen" receiver = models.ForeignKey(User, on_delete=models.CASCADE) - project = models.ForeignKey( - Project, - on_delete=models.CASCADE, - blank=True, - null=True, - default=None - ) + project = models.ForeignKey(Project, on_delete=models.CASCADE, blank=True, null=True, default=None) data = models.JSONField(default=None, blank=True, null=True) notification_type = models.CharField(max_length=48, choices=Type.choices) status = models.CharField( @@ -48,10 +41,10 @@ class Status(models.TextChoices): ) def __str__(self): - return f'{self.notification_type}:: <{self.receiver}> ({self.status})' + return f"{self.notification_type}:: <{self.receiver}> ({self.status})" class Meta: - ordering = ['-timestamp'] + ordering = ["-timestamp"] @staticmethod def get_for(user): @@ -62,17 +55,19 @@ class Assignment(models.Model): """ Assignment Model """ + created_at = models.DateTimeField(auto_now=True) created_by = models.ForeignKey( User, - blank=True, null=True, + blank=True, + null=True, on_delete=models.SET_NULL, - related_name='created_by', + related_name="created_by", ) created_for = models.ForeignKey( User, on_delete=models.CASCADE, - related_name='created_for', + related_name="created_for", ) project = models.ForeignKey( Project, @@ -81,14 +76,15 @@ class Assignment(models.Model): is_done = models.BooleanField(default=False) content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) object_id = models.PositiveIntegerField() - content_object = GenericForeignKey('content_type', 'object_id') + content_object = GenericForeignKey("content_type", "object_id") class Meta: - ordering = ['-created_at'] + ordering = ["-created_at"] @staticmethod def get_for(user): from entry.models import EntryComment + return Assignment.objects.filter( created_for=user, ).exclude( diff --git a/apps/notification/mutation.py b/apps/notification/mutation.py index cc10729b54..45a8178618 100644 --- a/apps/notification/mutation.py +++ b/apps/notification/mutation.py @@ -1,24 +1,18 @@ -from django.utils.translation import gettext - import graphene +from django.utils.translation import gettext +from utils.graphene.error_types import CustomErrorType, mutation_is_not_valid from utils.graphene.mutation import GrapheneMutation, generate_input_type_for_serializer -from utils.graphene.error_types import mutation_is_not_valid, CustomErrorType -from .serializers import AssignmentSerializer, NotificationGqSerializer -from .schema import AssignmentType, NotificationType from .models import Assignment, Notification +from .schema import AssignmentType, NotificationType +from .serializers import AssignmentSerializer, NotificationGqSerializer NotificationStatusInputType = generate_input_type_for_serializer( - 'NotificationStatusInputType', - serializer_class=NotificationGqSerializer - + "NotificationStatusInputType", serializer_class=NotificationGqSerializer ) -AssignmentInputType = generate_input_type_for_serializer( - 'AssignmentInputType', - serializer_class=AssignmentSerializer -) +AssignmentInputType = generate_input_type_for_serializer("AssignmentInputType", serializer_class=AssignmentSerializer) class NotificationStatusUpdate(graphene.Mutation): @@ -32,17 +26,15 @@ class Arguments: @staticmethod def mutate(root, info, data): try: - instance = Notification.objects.get(id=data['id'], receiver=info.context.request.user) + instance = Notification.objects.get(id=data["id"], receiver=info.context.request.user) except Notification.DoesNotExist: - return NotificationStatusUpdate(errors=[ - dict( - field='nonFieldErrors', - messages=gettext('Notification doesnot exist') - ) - ], ok=False) - serializer = NotificationGqSerializer(instance=instance, data=data, - context={'request': info.context.request}, partial=True) + return NotificationStatusUpdate( + errors=[dict(field="nonFieldErrors", messages=gettext("Notification doesnot exist"))], ok=False + ) + serializer = NotificationGqSerializer( + instance=instance, data=data, context={"request": info.context.request}, partial=True + ) if errors := mutation_is_not_valid(serializer): return NotificationStatusUpdate(errors=errors, ok=False) instance = serializer.save() @@ -53,6 +45,7 @@ class AssignmentUpdate(GrapheneMutation): class Arguments: id = graphene.ID(required=True) data = AssignmentInputType(required=True) + model = Assignment result = graphene.Field(AssignmentType) serializer_class = AssignmentSerializer diff --git a/apps/notification/receivers/__init__.py b/apps/notification/receivers/__init__.py index ce5cd5244a..ed7b346eb2 100644 --- a/apps/notification/receivers/__init__.py +++ b/apps/notification/receivers/__init__.py @@ -1,3 +1,3 @@ -from . import project_membership # noqa: F401 -from . import entry_comment # noqa: F401 from . import assignment # noqa: F401 +from . import entry_comment # noqa: F401 +from . import project_membership # noqa: F401 diff --git a/apps/notification/receivers/assignment.py b/apps/notification/receivers/assignment.py index a666261981..5c9db3db3a 100644 --- a/apps/notification/receivers/assignment.py +++ b/apps/notification/receivers/assignment.py @@ -1,21 +1,18 @@ +from django.db.models.signals import m2m_changed, post_delete from django.dispatch import receiver -from django.db.models.signals import ( - m2m_changed, - post_delete, -) - -from deep.middleware import get_current_user from lead.models import Lead from notification.models import Assignment from quality_assurance.models import EntryReviewComment +from deep.middleware import get_current_user + @receiver(m2m_changed, sender=Lead.assignee.through) def lead_assignment_signal(sender, instance, action, **kwargs): - pk_set = kwargs.get('pk_set', []) + pk_set = kwargs.get("pk_set", []) # Gets the username from the request with a middleware helper user = get_current_user() - if action == 'post_add' and pk_set and user: + if action == "post_add" and pk_set and user: for receiver_user in pk_set: if Assignment.objects.filter( lead__id=instance.id, @@ -30,7 +27,7 @@ def lead_assignment_signal(sender, instance, action, **kwargs): created_by=user, ) - elif action == 'post_remove' and pk_set and user: + elif action == "post_remove" and pk_set and user: for receiver_user in pk_set: Assignment.objects.filter( lead__id=instance.id, @@ -39,16 +36,16 @@ def lead_assignment_signal(sender, instance, action, **kwargs): # handling `post_clear` since single assignee is passed # though the api - elif action == 'post_clear': + elif action == "post_clear": Assignment.objects.filter(lead__id=instance.id).delete() @receiver(m2m_changed, sender=EntryReviewComment.mentioned_users.through) def entrycomment_assignment_signal(sender, instance, action, **kwargs): - pk_set = kwargs.get('pk_set', []) + pk_set = kwargs.get("pk_set", []) # Gets the username from the request with a middleware helper user = get_current_user() - if action == 'post_add' and pk_set and user: + if action == "post_add" and pk_set and user: for receiver_user in pk_set: if Assignment.objects.filter( entry_review_comment__id=instance.id, @@ -63,7 +60,7 @@ def entrycomment_assignment_signal(sender, instance, action, **kwargs): created_by=user, ) - elif action == 'post_remove' and pk_set and user: + elif action == "post_remove" and pk_set and user: for receiver_user in pk_set: Assignment.objects.filter( entry_review_comment__id=instance.id, diff --git a/apps/notification/receivers/entry_comment.py b/apps/notification/receivers/entry_comment.py index 8a396c284e..a0c77565c1 100644 --- a/apps/notification/receivers/entry_comment.py +++ b/apps/notification/receivers/entry_comment.py @@ -1,14 +1,9 @@ import logging -from django.dispatch import receiver -from django.db.models.signals import ( - post_save, - pre_save, - m2m_changed, -) -from django.db import transaction from django.conf import settings - +from django.db import transaction +from django.db.models.signals import m2m_changed, post_save, pre_save +from django.dispatch import receiver from entry.models import EntryComment, EntryCommentText from entry.serializers import EntryCommentSerializer from notification.models import Notification @@ -25,8 +20,8 @@ def send_notifications_for_comment(comment_pk, notification_meta): notification_meta = { **notification_meta, - 'project': comment.entry.project, - 'data': EntryCommentSerializer(comment).data, + "project": comment.entry.project, + "data": EntryCommentSerializer(comment).data, } related_users = comment.get_related_users() @@ -41,9 +36,7 @@ def send_notifications_for_comment(comment_pk, notification_meta): if settings.TESTING: send_entry_comment_email(user.pk, comment.pk) else: - transaction.on_commit( - lambda: send_entry_comment_email.delay(user.pk, comment.pk) - ) + transaction.on_commit(lambda: send_entry_comment_email.delay(user.pk, comment.pk)) @receiver(pre_save, sender=EntryComment) @@ -58,24 +51,24 @@ def create_entry_commit_notification(sender, instance, **kwargs): old_comment = EntryComment.objects.get(pk=instance.pk) if instance.is_resolved and old_comment.is_resolved != instance.is_resolved: # Comment is Resolved - meta['notification_type'] = Notification.Type.ENTRY_COMMENT_RESOLVED + meta["notification_type"] = Notification.Type.ENTRY_COMMENT_RESOLVED transaction.on_commit(lambda: send_notifications_for_comment(instance.pk, meta)) instance.receiver_notification_already_send = True @receiver(m2m_changed, sender=EntryComment.assignees.through) def create_entry_commit_notification_post(sender, instance, action, **kwargs): - receiver_notification_already_send = getattr(instance, 'receiver_notification_already_send', False) + receiver_notification_already_send = getattr(instance, "receiver_notification_already_send", False) # Default:False Because when it's patch request with only m2m change, create_entry_commit_notification is not triggered - created = getattr(instance, 'receiver_created', False) + created = getattr(instance, "receiver_created", False) if ( - created or action not in ['post_add', 'post_remove'] or instance.parent or receiver_notification_already_send + created or action not in ["post_add", "post_remove"] or instance.parent or receiver_notification_already_send ): # Notification is handled from commit text creation return meta = {} - meta['notification_type'] = Notification.Type.ENTRY_COMMENT_ASSIGNEE_CHANGE + meta["notification_type"] = Notification.Type.ENTRY_COMMENT_ASSIGNEE_CHANGE instance.receiver_notification_already_send = True transaction.on_commit(lambda: send_notifications_for_comment(instance.pk, meta)) @@ -87,11 +80,11 @@ def create_entry_commit_text_notification(sender, instance, created, **kwargs): comment = instance.comment meta = {} - meta['notification_type'] = ( + meta["notification_type"] = ( Notification.Type.ENTRY_COMMENT_REPLY_ADD if comment.parent else Notification.Type.ENTRY_COMMENT_ADD ) if EntryCommentText.objects.filter(comment=comment).count() > 1: - meta['notification_type'] = ( + meta["notification_type"] = ( Notification.Type.ENTRY_COMMENT_REPLY_MODIFY if comment.parent else Notification.Type.ENTRY_COMMENT_MODIFY ) diff --git a/apps/notification/receivers/project_membership.py b/apps/notification/receivers/project_membership.py index 2b80794e12..43a94f292b 100644 --- a/apps/notification/receivers/project_membership.py +++ b/apps/notification/receivers/project_membership.py @@ -1,12 +1,7 @@ +from django.db.models.signals import post_delete, post_save, pre_save from django.dispatch import receiver -from django.db.models.signals import post_save, post_delete, pre_save - from notification.models import Notification -from project.models import ( - ProjectJoinRequest, - ProjectMembership, - ProjectRole, -) +from project.models import ProjectJoinRequest, ProjectMembership, ProjectRole from project.serializers import ProjectJoinRequestSerializer @@ -14,7 +9,7 @@ def create_notification(sender, instance, created, **kwargs): admins = instance.project.get_admins() data = ProjectJoinRequestSerializer(instance).data - if (created): + if created: for admin in admins: Notification.objects.create( receiver=admin, @@ -25,7 +20,7 @@ def create_notification(sender, instance, created, **kwargs): return # notify the requester as well - if instance.status in ['accepted', 'rejected']: + if instance.status in ["accepted", "rejected"]: Notification.objects.create( receiver=instance.requested_by, notification_type=Notification.Type.PROJECT_JOIN_RESPONSE, @@ -33,9 +28,7 @@ def create_notification(sender, instance, created, **kwargs): data=data, ) - old_notifications = Notification.objects.filter( - data__id=instance.id - ) + old_notifications = Notification.objects.filter(data__id=instance.id) for notification in old_notifications: notification.data = data @@ -52,48 +45,37 @@ def create_notification(sender, instance, created, **kwargs): @receiver(post_delete, sender=ProjectJoinRequest) def update_notification_for_join_request(sender, instance, **kwargs): - old_notifications = Notification.objects.filter( - data__id=instance.id - ) + old_notifications = Notification.objects.filter(data__id=instance.id) for notification in old_notifications: - notification.data['status'] = 'aborted' + notification.data["status"] = "aborted" notification.save() admins = instance.project.get_admins() data = ProjectJoinRequestSerializer(instance).data - data['status'] = 'aborted' + data["status"] = "aborted" for admin in admins: Notification.objects.create( - receiver=admin, - notification_type=Notification.Type.PROJECT_JOIN_REQUEST_ABORT, - project=instance.project, - data=data + receiver=admin, notification_type=Notification.Type.PROJECT_JOIN_REQUEST_ABORT, project=instance.project, data=data ) @receiver(pre_save, sender=ProjectMembership) -def remove_notifications_for_former_project_admin( - sender, instance, **kwargs): +def remove_notifications_for_former_project_admin(sender, instance, **kwargs): admin_roles = ProjectRole.get_admin_roles() try: old_membership = ProjectMembership.objects.get(id=instance.id) - if old_membership.role in admin_roles\ - and instance.role not in admin_roles: - old_notifications = Notification.objects.filter( - receiver=instance.member, - project=instance.project - ) + if old_membership.role in admin_roles and instance.role not in admin_roles: + old_notifications = Notification.objects.filter(receiver=instance.member, project=instance.project) old_notifications.delete() - if old_membership.role not in admin_roles\ - and instance.role in admin_roles: + if old_membership.role not in admin_roles and instance.role in admin_roles: old_project_join_requests = ProjectJoinRequest.objects.filter( project=instance.project, - status='pending', + status="pending", ) for old_project_join_request in old_project_join_requests: @@ -101,21 +83,19 @@ def remove_notifications_for_former_project_admin( receiver=instance.member, notification_type=Notification.Type.PROJECT_JOIN_REQUEST, project=instance.project, - data=ProjectJoinRequestSerializer( - old_project_join_request).data, + data=ProjectJoinRequestSerializer(old_project_join_request).data, ) except ProjectMembership.DoesNotExist: pass @receiver(post_save, sender=ProjectMembership) -def create_notifications_for_new_project_admin( - sender, instance, created, **kwargs): +def create_notifications_for_new_project_admin(sender, instance, created, **kwargs): if created is True: if instance.role in ProjectRole.get_admin_roles(): old_project_join_requests = ProjectJoinRequest.objects.filter( project=instance.project, - status='pending', + status="pending", ) for old_project_join_request in old_project_join_requests: @@ -123,6 +103,5 @@ def create_notifications_for_new_project_admin( receiver=instance.member, notification_type=Notification.Type.PROJECT_JOIN_REQUEST, project=instance.project, - data=ProjectJoinRequestSerializer( - old_project_join_request).data, + data=ProjectJoinRequestSerializer(old_project_join_request).data, ) diff --git a/apps/notification/schema.py b/apps/notification/schema.py index 2302bb9e14..c23d21b5b7 100644 --- a/apps/notification/schema.py +++ b/apps/notification/schema.py @@ -1,21 +1,20 @@ import graphene - from django.db.models import QuerySet from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField, PageGraphqlPagination +from deep.trackers import track_user from utils.graphene.enums import EnumDescription -from utils.graphene.types import CustomDjangoListObjectType from utils.graphene.fields import DjangoPaginatedListObjectField -from deep.trackers import track_user +from utils.graphene.types import CustomDjangoListObjectType -from .models import Assignment, Notification -from .filter_set import NotificationGqlFilterSet, AssignmentFilterSet from .enums import ( - NotificationTypeEnum, + AssignmentContentTypeEnum, NotificationStatusEnum, - AssignmentContentTypeEnum + NotificationTypeEnum, ) +from .filter_set import AssignmentFilterSet, NotificationGqlFilterSet +from .models import Assignment, Notification def get_user_notification_qs(info): @@ -34,13 +33,16 @@ class NotificationType(DjangoObjectType): class Meta: model = Notification only_fields = ( - 'id', 'project', 'data', 'timestamp', + "id", + "project", + "data", + "timestamp", ) notification_type = graphene.Field(graphene.NonNull(NotificationTypeEnum)) - notification_type_display = EnumDescription(source='get_notification_type_display', required=True) + notification_type_display = EnumDescription(source="get_notification_type_display", required=True) status = graphene.Field(graphene.NonNull(NotificationStatusEnum)) - status_display = EnumDescription(source='get_status_display', required=True) + status_display = EnumDescription(source="get_status_display", required=True) @staticmethod def get_custom_queryset(queryset, info, **kwargs): @@ -72,6 +74,7 @@ class AssignmentContentDataType(graphene.ObjectType): class AssignmentType(DjangoObjectType): class Meta: model = Assignment + id = graphene.ID(required=True) project = graphene.Field(AssignmentProjectDetailType) content_data = graphene.Field(AssignmentContentDataType) @@ -96,16 +99,10 @@ class Meta: class Query: notification = DjangoObjectField(NotificationType) notifications = DjangoPaginatedListObjectField( - NotificationListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + NotificationListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) assignments = DjangoPaginatedListObjectField( - AssignmentListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + AssignmentListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) @staticmethod diff --git a/apps/notification/serializers.py b/apps/notification/serializers.py index 8b0522b58b..1b9b7048aa 100644 --- a/apps/notification/serializers.py +++ b/apps/notification/serializers.py @@ -1,19 +1,14 @@ -from rest_framework import serializers - +from entry.models import EntryComment from generic_relations.relations import GenericRelatedField - -from deep.serializers import RemoveNullFieldsMixin -from user.serializers import SimpleUserSerializer -from project.serializers import SimpleProjectSerializer -from deep.serializers import IntegerIDField - from lead.models import Lead +from project.serializers import SimpleProjectSerializer from quality_assurance.models import EntryReviewComment -from entry.models import EntryComment -from .models import ( - Notification, - Assignment -) +from rest_framework import serializers +from user.serializers import SimpleUserSerializer + +from deep.serializers import IntegerIDField, RemoveNullFieldsMixin + +from .models import Assignment, Notification class NotificationSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): @@ -22,20 +17,16 @@ class NotificationSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer) class Meta: model = Notification - fields = ('__all__') - read_only_fields = ( - 'data', 'receiver', 'project', 'notification_type' - ) + fields = "__all__" + read_only_fields = ("data", "receiver", "project", "notification_type") def create(self, validated_data): - id = validated_data.get('id') + id = validated_data.get("id") if id: try: notification = Notification.objects.get(id=id) except Notification.DoesNotExist: - raise serializers.ValidationError({ - 'id': 'Invalid notification id: {}'.format(id) - }) + raise serializers.ValidationError({"id": "Invalid notification id: {}".format(id)}) return self.update(notification, validated_data) return super().create(validated_data) @@ -46,50 +37,53 @@ def get_data(self, notification): class AssignmentEntryCommentSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): - entry_excerpt = serializers.CharField(source='entry.excerpt', read_only=True) - lead = serializers.CharField(source='entry.lead_id', read_only=True) + entry_excerpt = serializers.CharField(source="entry.excerpt", read_only=True) + lead = serializers.CharField(source="entry.lead_id", read_only=True) class Meta: model = EntryComment - fields = ('id', 'text', 'entry', 'entry_excerpt', 'lead') + fields = ("id", "text", "entry", "entry_excerpt", "lead") class AssignmentEntryReviewCommentSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): - entry_excerpt = serializers.CharField(source='entry.excerpt', read_only=True) - lead = serializers.CharField(source='entry.lead_id', read_only=True) + entry_excerpt = serializers.CharField(source="entry.excerpt", read_only=True) + lead = serializers.CharField(source="entry.lead_id", read_only=True) class Meta: model = EntryReviewComment - fields = ('id', 'text', 'entry', 'entry_excerpt', 'lead') + fields = ("id", "text", "entry", "entry_excerpt", "lead") class AssignmentLeadSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): class Meta: model = Lead - fields = ('id', 'title',) + fields = ( + "id", + "title", + ) class AssignmentSerializer(serializers.ModelSerializer): - content_object_details = GenericRelatedField({ - Lead: AssignmentLeadSerializer(), - EntryComment: AssignmentEntryCommentSerializer(), - EntryReviewComment: AssignmentEntryReviewCommentSerializer(), - }, read_only=True, source='content_object') - project_details = SimpleProjectSerializer(source='project', read_only=True) - created_by_details = SimpleUserSerializer(source='created_by', read_only=True) + content_object_details = GenericRelatedField( + { + Lead: AssignmentLeadSerializer(), + EntryComment: AssignmentEntryCommentSerializer(), + EntryReviewComment: AssignmentEntryReviewCommentSerializer(), + }, + read_only=True, + source="content_object", + ) + project_details = SimpleProjectSerializer(source="project", read_only=True) + created_by_details = SimpleUserSerializer(source="created_by", read_only=True) class Meta: model = Assignment - read_only_fields = [ - 'id', - 'created_at', - 'project_details', 'created_by_details', 'content_object_details', 'content_type' - ] - fields = read_only_fields + ['is_done'] + read_only_fields = ["id", "created_at", "project_details", "created_by_details", "content_object_details", "content_type"] + fields = read_only_fields + ["is_done"] def to_representation(self, instance): data = super().to_representation(instance) - data['content_object_type'] = instance.content_type.model + data["content_object_type"] = instance.content_type.model return data @@ -99,9 +93,9 @@ class NotificationGqSerializer(serializers.ModelSerializer): class Meta: model = Notification - fields = ('id', 'status') + fields = ("id", "status") def update(self, instance, validated_data): - if instance and instance.receiver != self.context['request'].user: - raise serializers.ValidationError('Only the recepient of this notification can update its status.') + if instance and instance.receiver != self.context["request"].user: + raise serializers.ValidationError("Only the recepient of this notification can update its status.") return super().update(instance, validated_data) diff --git a/apps/notification/tasks.py b/apps/notification/tasks.py index 075066fd05..38c96caccc 100644 --- a/apps/notification/tasks.py +++ b/apps/notification/tasks.py @@ -1,12 +1,11 @@ from celery import shared_task from django.db import transaction - from entry.models import EntryComment -from user.models import User, EmailCondition +from quality_assurance.models import EntryReviewComment +from user.models import EmailCondition, User from user.utils import send_mail_to_user -from deep.permalinks import Permalink -from quality_assurance.models import EntryReviewComment +from deep.permalinks import Permalink from .models import Notification @@ -16,18 +15,17 @@ def send_entry_comment_email(user_id, comment_id): user = User.objects.get(pk=user_id) comment = EntryComment.objects.get(pk=comment_id) send_mail_to_user( - user, EmailCondition.EMAIL_COMMENT, + user, + EmailCondition.EMAIL_COMMENT, context={ - 'notification_type': Notification.Type.ENTRY_COMMENT_ADD, - 'Notification': Notification, - 'comment': comment, - 'assignees_display': ', '.join( - assignee.profile.get_display_name() for assignee in comment.assignees.all() - ), - 'entry_comment_client_url': Permalink.ientry_comments(comment.entry) + "notification_type": Notification.Type.ENTRY_COMMENT_ADD, + "Notification": Notification, + "comment": comment, + "assignees_display": ", ".join(assignee.profile.get_display_name() for assignee in comment.assignees.all()), + "entry_comment_client_url": Permalink.ientry_comments(comment.entry), }, - subject_template_name='entry/comment_notification_email.txt', - email_template_name='entry/comment_notification_email.html', + subject_template_name="entry/comment_notification_email.txt", + email_template_name="entry/comment_notification_email.html", ) @@ -36,16 +34,17 @@ def send_entry_review_comment_email(user_id, comment_id, notification_type): user = User.objects.get(pk=user_id) comment = EntryReviewComment.objects.get(pk=comment_id) send_mail_to_user( - user, EmailCondition.EMAIL_COMMENT, + user, + EmailCondition.EMAIL_COMMENT, context={ - 'Notification': Notification, - 'CommentType': EntryReviewComment.CommentType, - 'notification_type': notification_type, - 'comment': comment, - 'entry_comment_client_url': Permalink.ientry_comment(comment) + "Notification": Notification, + "CommentType": EntryReviewComment.CommentType, + "notification_type": notification_type, + "comment": comment, + "entry_comment_client_url": Permalink.ientry_comment(comment), }, - subject_template_name='entry/review_comment_notification_email.txt', - email_template_name='entry/review_comment_notification_email.html', + subject_template_name="entry/review_comment_notification_email.txt", + email_template_name="entry/review_comment_notification_email.html", ) @@ -55,14 +54,15 @@ def send_notifications_for_comment(comment_pk, meta): """ comment = EntryReviewComment.objects.get(pk=comment_pk) - text_changed = meta.pop('text_changed') - new_mentioned_users = meta.pop('new_mentioned_users', []) + text_changed = meta.pop("text_changed") + new_mentioned_users = meta.pop("new_mentioned_users", []) if text_changed: related_users = comment.get_related_users() else: related_users = new_mentioned_users from quality_assurance.serializers import EntryReviewCommentNotificationSerializer + for user in related_users: # Create DEEP Notification Objects Notification.objects.create( @@ -77,6 +77,6 @@ def send_notifications_for_comment(comment_pk, meta): lambda: send_entry_review_comment_email.delay( user.pk, comment.pk, - meta['notification_type'], + meta["notification_type"], ) ) diff --git a/apps/notification/templatetags/deep_notification_tags.py b/apps/notification/templatetags/deep_notification_tags.py index d0e7da9070..8a4e8dcd9f 100644 --- a/apps/notification/templatetags/deep_notification_tags.py +++ b/apps/notification/templatetags/deep_notification_tags.py @@ -1,11 +1,9 @@ from django import template from django.conf import settings -from django.templatetags.static import static from django.core.files.storage import FileSystemStorage, get_storage_class - +from django.templatetags.static import static from mdmail.api import EmailContent - register = template.Library() StorageClass = get_storage_class() @@ -16,7 +14,7 @@ def markdown_render(value): if value: content = EmailContent(value) return content.html - return '-' + return "-" @register.filter(is_safe=True) diff --git a/apps/notification/tests/test_apis.py b/apps/notification/tests/test_apis.py index 0a5c9ddc2a..9d4cea8a40 100644 --- a/apps/notification/tests/test_apis.py +++ b/apps/notification/tests/test_apis.py @@ -1,15 +1,15 @@ -import pytest from datetime import timedelta -from django.contrib.contenttypes.models import ContentType -from deep.tests import TestCase +import pytest +from django.contrib.contenttypes.models import ContentType from django.utils import timezone - -from user.models import User from lead.models import Lead -from notification.models import Notification, Assignment -from project.models import ProjectJoinRequest, Project +from notification.models import Assignment, Notification +from project.models import Project, ProjectJoinRequest from quality_assurance.models import EntryReviewComment +from user.models import User + +from deep.tests import TestCase class TestNotificationAPIs(TestCase): @@ -18,8 +18,8 @@ def test_get_notifications(self): project = self.create(Project, role=self.admin_role) user = self.create(User) - url = '/api/v1/notifications/' - data = {'project': project.id} + url = "/api/v1/notifications/" + data = {"project": project.id} self.authenticate() @@ -27,7 +27,7 @@ def test_get_notifications(self): self.assert_200(response) rdata = response.data - assert rdata['count'] == 0, "No notifications so far" + assert rdata["count"] == 0, "No notifications so far" # Now, create notifications self.create_join_request(project, user) @@ -35,22 +35,22 @@ def test_get_notifications(self): response = self.client.get(url, data) self.assert_200(response) data = response.json() - assert data['count'] == 1, "A notification created for join request" - result = data['results'][0] - assert 'receiver' in result - assert 'data' in result - assert 'project' in result - assert 'notificationType' in result - assert 'receiver' in result - assert 'status' in result - assert result['status'] == 'unseen' + assert data["count"] == 1, "A notification created for join request" + result = data["results"][0] + assert "receiver" in result + assert "data" in result + assert "project" in result + assert "notificationType" in result + assert "receiver" in result + assert "status" in result + assert result["status"] == "unseen" # TODO: Check inside data def test_update_notification(self): project = self.create(Project, role=self.admin_role) user = self.create(User) - url = '/api/v1/notifications/status/' + url = "/api/v1/notifications/status/" # Create notification self.create_join_request(project, user) @@ -60,9 +60,7 @@ def test_update_notification(self): self.authenticate() - data = [ - {'id': notifs[0].id, 'status': Notification.Status.SEEN} - ] + data = [{"id": notifs[0].id, "status": Notification.Status.SEEN}] response = self.client.put(url, data) self.assert_200(response) @@ -74,7 +72,7 @@ def test_update_notification_invalid_data(self): project = self.create(Project, role=self.admin_role) user = self.create(User) - url = '/api/v1/notifications/status/' + url = "/api/v1/notifications/status/" # Create notification self.create_join_request(project, user) @@ -86,35 +84,25 @@ def test_update_notification_invalid_data(self): # Let's send one valid and other invalid data, this should give 400 data = [ - { - 'id': notifs[0].id + 1, - 'status': Notification.Status.SEEN + 'a' - }, - { - 'id': notifs[0].id, - 'status': Notification.Status.SEEN - }, + {"id": notifs[0].id + 1, "status": Notification.Status.SEEN + "a"}, + {"id": notifs[0].id, "status": Notification.Status.SEEN}, ] response = self.client.put(url, data) self.assert_400(response), "Invalid id and status should give 400" data = response.data - assert 'errors' in data + assert "errors" in data def create_join_request(self, project, user=None): """Create join_request""" user = user or self.create(User) - join_request = ProjectJoinRequest.objects.create( - project=project, - requested_by=user, - role=self.normal_role - ) + join_request = ProjectJoinRequest.objects.create(project=project, requested_by=user, role=self.normal_role) return join_request def test_get_filtered_notifications(self): project = self.create(Project, role=self.admin_role) user = self.create(User) - url = '/api/v1/notifications/' - params = {'project': project.id} + url = "/api/v1/notifications/" + params = {"project": project.id} self.authenticate() # store the time @@ -128,105 +116,104 @@ def test_get_filtered_notifications(self): response = self.client.get(url, params) self.assert_200(response) data = response.json() - assert data['count'] == 1, "A notification was created for join request but didnot show" + assert data["count"] == 1, "A notification was created for join request but didnot show" # now applying filters # is_pending filter # by default the notification created is in status = pending - params.update(dict(is_pending='false')) + params.update(dict(is_pending="false")) response = self.client.get(url, params) self.assert_200(response) data = response.json() - assert data['count'] == 0, "Expected zero non-pending notification" + assert data["count"] == 0, "Expected zero non-pending notification" - params.update(dict(is_pending='true')) + params.update(dict(is_pending="true")) response = self.client.get(url, params) self.assert_200(response) data = response.json() - assert data['count'] == 1, "Expected one pending notification" + assert data["count"] == 1, "Expected one pending notification" # status filter - params.pop('is_pending', None) - params.update(dict(status='unseen')) + params.pop("is_pending", None) + params.update(dict(status="unseen")) response = self.client.get(url, params) self.assert_200(response) data = response.json() - assert data['count'] == 1, "One Notification should be with unseen status" + assert data["count"] == 1, "One Notification should be with unseen status" - params.update(dict(status='seen')) + params.update(dict(status="seen")) response = self.client.get(url, params) self.assert_200(response) data = response.json() - assert data['count'] == 0, "Zero notification should be with seen status" + assert data["count"] == 0, "Zero notification should be with seen status" # timestamp filter - params.pop('status', None) - params.update(dict(timestamp__gt=before.strftime('%Y-%m-%d%z'))) + params.pop("status", None) + params.update(dict(timestamp__gt=before.strftime("%Y-%m-%d%z"))) response = self.client.get(url, params) self.assert_200(response) data = response.json() - assert data['count'] == 1, "One Notification should be after 'before time' " + assert data["count"] == 1, "One Notification should be after 'before time' " - params.pop('timestamp__gt', None) - params.update(dict(timestamp__lt=before.strftime('%Y-%m-%d%z'))) + params.pop("timestamp__gt", None) + params.update(dict(timestamp__lt=before.strftime("%Y-%m-%d%z"))) response = self.client.get(url, params) self.assert_200(response) data = response.json() - assert data['count'] == 0, "No notification should be before 'before time'" + assert data["count"] == 0, "No notification should be before 'before time'" - params.pop('timestamp__lt', None) - params.update(dict(timestamp__gt=after.strftime('%Y-%m-%d%z'))) + params.pop("timestamp__lt", None) + params.update(dict(timestamp__gt=after.strftime("%Y-%m-%d%z"))) response = self.client.get(url, params) self.assert_200(response) data = response.json() - assert data['count'] == 0, "No notification should be after 'after time'" + assert data["count"] == 0, "No notification should be after 'after time'" - params.update(dict(timestamp__gt=before.strftime('%Y-%m-%d%z'), - timestamp__lt=after.strftime('%Y-%m-%d%z'))) + params.update(dict(timestamp__gt=before.strftime("%Y-%m-%d%z"), timestamp__lt=after.strftime("%Y-%m-%d%z"))) response = self.client.get(url, params) self.assert_200(response) data = response.json() - assert data['count'] == 1, "One notification should be after 'before time' and before 'after time'" + assert data["count"] == 1, "One notification should be after 'before time' and before 'after time'" def test_get_notification_count(self): project = self.create_project() user = self.create(User) - url = '/api/v1/notifications/count/' - data = {'project': project.id} + url = "/api/v1/notifications/count/" + data = {"project": project.id} self.authenticate() response = self.client.get(url, data) self.assert_200(response) data = response.data - assert data['total'] == 0 - assert data['unseen_notifications'] == 0 - assert data['unseen_requests'] == 0 + assert data["total"] == 0 + assert data["unseen_notifications"] == 0 + assert data["unseen_requests"] == 0 # Now, create join request join_request = self.create_join_request(project, user) response = self.client.get(url, data) data = response.data self.assert_200(response) - assert data['total'] == 1 - assert data['unseen_notifications'] == 0 - assert data['unseen_requests'] == 1 + assert data["total"] == 1 + assert data["unseen_notifications"] == 0 + assert data["unseen_requests"] == 1 # Change status of project join request - join_request.status = 'accepted' + join_request.status = "accepted" join_request.responded_by = self.user join_request.save() response = self.client.get(url, data) data = response.data self.assert_200(response) - assert data['total'] == 2 + assert data["total"] == 2 # One notification is of join request # Another new notification is created after user sucessfully joins project - assert data['unseen_notifications'] == 2 - assert data['unseen_requests'] == 0 + assert data["unseen_notifications"] == 2 + assert data["unseen_requests"] == 0 # XXX: @@ -234,7 +221,7 @@ def test_get_notification_count(self): # is causing issue, so running this before all. @pytest.mark.run(order=1) class TestAssignmentApi(TestCase): - """ Api test for assignment model""" + """Api test for assignment model""" def test_get_assignments_lead(self): project = self.create_project() @@ -242,129 +229,125 @@ def test_get_assignments_lead(self): user1 = self.create(User) user2 = self.create(User) - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) data = response.data - assert data['count'] == 0, "No Assignments till now" + assert data["count"] == 0, "No Assignments till now" # try creating lead lead = self.create_lead(project=project, assignee=[user1]) self.create(Lead, project=project1, assignee=[user2]) self.authenticate(user1) - params = {'project': project.id} + params = {"project": project.id} response = self.client.get(url, params) self.assert_200(response) - self.assertEqual(response.data['count'], 1) - self.assertEqual(response.data['results'][0]['project_details']['id'], project.id) - self.assertEqual(response.data['results'][0]['content_object_type'], 'lead') - self.assertEqual(response.data['results'][0]['content_object_details']['id'], lead.id) + self.assertEqual(response.data["count"], 1) + self.assertEqual(response.data["results"][0]["project_details"]["id"], project.id) + self.assertEqual(response.data["results"][0]["content_object_type"], "lead") + self.assertEqual(response.data["results"][0]["content_object_details"]["id"], lead.id) def test_create_assignment_on_lead_title_change(self): project = self.create_project() user1 = self.create(User) - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) data = response.data - assert data['count'] == 0, "No Assignments till now" + assert data["count"] == 0, "No Assignments till now" # create lead with title lead = self.create(Lead, title="Uncommitted", project=project, assignee=[user1]) - url = '/api/v1/leads/' + url = "/api/v1/leads/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 1) + self.assertEqual(response.data["count"], 1) - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 1) + self.assertEqual(response.data["count"], 1) # try to change the title this should not create another assignment - url = '/api/v1/leads/{}/'.format(lead.id) - data = { - 'title': 'Changed' - } + url = "/api/v1/leads/{}/".format(lead.id) + data = {"title": "Changed"} self.authenticate() response = self.client.patch(url, data) self.assert_200(response) # try to check the assignment - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) data = response.data - self.assertEqual(response.data['count'], 1) - self.assertEqual(response.data['results'][0]['content_object_type'], 'lead') - self.assertEqual(response.data['results'][0]['content_object_details']['id'], lead.id) - self.assertEqual(response.data['results'][0]['content_object_details']['title'], 'Changed') # the new title + self.assertEqual(response.data["count"], 1) + self.assertEqual(response.data["results"][0]["content_object_type"], "lead") + self.assertEqual(response.data["results"][0]["content_object_details"]["id"], lead.id) + self.assertEqual(response.data["results"][0]["content_object_details"]["title"], "Changed") # the new title def test_create_assignment_on_lead_assignee_change(self): project = self.create_project() user1 = self.create(User) user2 = self.create(User) - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) data = response.data - assert data['count'] == 0, "No Assignments till now" + assert data["count"] == 0, "No Assignments till now" # create lead with title lead = self.create(Lead, title="Uncommitted", project=project, assignee=[user1]) - url = '/api/v1/leads/' + url = "/api/v1/leads/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 1) + self.assertEqual(response.data["count"], 1) - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 1) + self.assertEqual(response.data["count"], 1) # try to change the title this should not create another assignment - url = '/api/v1/leads/{}/'.format(lead.id) - data = { - 'assignee': user2.id - } + url = "/api/v1/leads/{}/".format(lead.id) + data = {"assignee": user2.id} self.authenticate() response = self.client.patch(url, data) self.assert_200(response) # try to check the assignment - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) data = response.data - assert data['count'] == 0 # changing the assignee should remove fromn the previous assignee + assert data["count"] == 0 # changing the assignee should remove fromn the previous assignee # try to aunthenticate the user2 - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user2) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 1) + self.assertEqual(response.data["count"], 1) def test_get_assignments_entrycomment(self): project = self.create_project() @@ -373,27 +356,27 @@ def test_get_assignments_entrycomment(self): user2 = self.create(User) entry = self.create_entry(project=project) - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) data = response.data - assert data['count'] == 0, "No Assignments till now" + assert data["count"] == 0, "No Assignments till now" entry_comment = self.create(EntryReviewComment, entry=entry, project=project, mentioned_users=[user1]) self.create(EntryReviewComment, entry=entry, project=project1, mentioned_users=[user2]) self.authenticate(user1) - params = {'project': project.id} + params = {"project": project.id} response = self.client.get(url, params) self.assert_200(response) - self.assertEqual(response.data['count'], 1) - self.assertEqual(response.data['results'][0]['project_details']['id'], entry.project.id) - self.assertEqual(response.data['results'][0]['content_object_type'], 'entryreviewcomment') - self.assertEqual(response.data['results'][0]['content_object_details']['id'], entry_comment.id) + self.assertEqual(response.data["count"], 1) + self.assertEqual(response.data["results"][0]["project_details"]["id"], entry.project.id) + self.assertEqual(response.data["results"][0]["content_object_type"], "entryreviewcomment") + self.assertEqual(response.data["results"][0]["content_object_details"]["id"], entry_comment.id) def test_create_assignment_on_entry_comment_text_change(self): project = self.create_project() @@ -403,49 +386,49 @@ def test_create_assignment_on_entry_comment_text_change(self): entry = self.create_entry(project=project) entry.project.add_member(user1) - url1 = '/api/v1/assignments/' + url1 = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url1) self.assert_200(response) data = response.data - assert data['count'] == 0, "No Assignments till now" + assert data["count"] == 0, "No Assignments till now" - url = f'/api/v1/entries/{entry.pk}/review-comments/' + url = f"/api/v1/entries/{entry.pk}/review-comments/" data = { - 'mentioned_users': [user1.pk], - 'text': 'This is first comment', - 'parent': None, + "mentioned_users": [user1.pk], + "text": "This is first comment", + "parent": None, } self.authenticate() response = self.client.post(url, data) self.assert_201(response) - comment_id = response.json()['id'] + comment_id = response.json()["id"] - url1 = '/api/v1/assignments/' + url1 = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url1) self.assert_200(response) - self.assertEqual(response.data['count'], 1) + self.assertEqual(response.data["count"], 1) # Patch new text - new_text = 'this is second comment' + new_text = "this is second comment" self.authenticate() - response = self.client.patch(f'{url}{comment_id}/', {'text': new_text}) + response = self.client.patch(f"{url}{comment_id}/", {"text": new_text}) self.assert_200(response) # try to check the assignment - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) data = response.data - assert data['count'] == 1 - self.assertEqual(response.data['results'][0]['content_object_details']['id'], comment_id) - self.assertEqual(response.data['results'][0]['content_object_details']['text'], new_text) + assert data["count"] == 1 + self.assertEqual(response.data["results"][0]["content_object_details"]["id"], comment_id) + self.assertEqual(response.data["results"][0]["content_object_details"]["text"], new_text) def test_assignment_create_on_entry_comment_assignee_change(self): project = self.create_project() @@ -456,54 +439,54 @@ def test_assignment_create_on_entry_comment_assignee_change(self): for user in [user1, user2]: entry.project.add_member(user, role=self.normal_role) - url1 = '/api/v1/assignments/' + url1 = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url1) self.assert_200(response) data = response.data - assert data['count'] == 0, "No Assignments till now" + assert data["count"] == 0, "No Assignments till now" - url = f'/api/v1/entries/{entry.pk}/review-comments/' + url = f"/api/v1/entries/{entry.pk}/review-comments/" data = { - 'mentioned_users': [user1.pk], - 'text': 'This is first comment', - 'parent': None, + "mentioned_users": [user1.pk], + "text": "This is first comment", + "parent": None, } self.authenticate() response = self.client.post(url, data) self.assert_201(response) - comment_id = response.json()['id'] + comment_id = response.json()["id"] - url1 = '/api/v1/assignments/' + url1 = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url1) self.assert_200(response) - self.assertEqual(response.data['count'], 1) + self.assertEqual(response.data["count"], 1) # Patch new assignee self.authenticate() - response = self.client.patch(f'{url}{comment_id}/', {'mentioned_users': [user2.pk]}) + response = self.client.patch(f"{url}{comment_id}/", {"mentioned_users": [user2.pk]}) self.assert_200(response) # try to check the assignment - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) data = response.data - assert data['count'] == 0 # no assignment for user1 + assert data["count"] == 0 # no assignment for user1 - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user2) response = self.client.get(url) self.assert_200(response) data = response.data - assert data['count'] == 1 # assignment for user2 + assert data["count"] == 1 # assignment for user2 def test_assignment_is_done(self): # XXX: To avoid using content type cache from pre-tests @@ -514,43 +497,43 @@ def test_assignment_is_done(self): user2 = self.create(User) lead = self.create(Lead, project=project) kwargs = { - 'content_object': lead, - 'project': project, - 'created_for': user1, - 'created_by': user2, + "content_object": lead, + "project": project, + "created_for": user1, + "created_by": user2, } assignment = self.create(Assignment, **kwargs) self.create(Assignment, **kwargs) self.create(Assignment, **kwargs) - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 3) + self.assertEqual(response.data["count"], 3) # try to put is_done for single assignment - url = f'/api/v1/assignments/{assignment.id}/' + url = f"/api/v1/assignments/{assignment.id}/" data = { - 'is_done': 'true', + "is_done": "true", } self.authenticate(user1) response = self.client.put(url, data) self.assert_200(response) - self.assertEqual(response.data['is_done'], True) + self.assertEqual(response.data["is_done"], True) - url = '/api/v1/assignments/bulk-mark-as-done/' + url = "/api/v1/assignments/bulk-mark-as-done/" data = { - 'is_done': 'true', + "is_done": "true", } response = self.client.post(url, data) self.assert_200(response) - self.assertEqual(response.data['assignment_updated'], 2) + self.assertEqual(response.data["assignment_updated"], 2) # test for is_done is true - url = '/api/v1/assignments/' + url = "/api/v1/assignments/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['results'][1]['is_done'], True) - self.assertEqual(response.data['results'][2]['is_done'], True) + self.assertEqual(response.data["results"][1]["is_done"], True) + self.assertEqual(response.data["results"][2]["is_done"], True) diff --git a/apps/notification/tests/test_mutation.py b/apps/notification/tests/test_mutation.py index bca376614f..66ce7db378 100644 --- a/apps/notification/tests/test_mutation.py +++ b/apps/notification/tests/test_mutation.py @@ -1,17 +1,17 @@ -from utils.graphene.tests import GraphQLTestCase from django.contrib.contenttypes.models import ContentType - -from user.factories import UserFactory -from project.factories import ProjectFactory from lead.factories import LeadFactory -from notification.factories import NotificationFactory, AssignmentFactory -from notification.models import Assignment, Notification from lead.models import Lead +from notification.factories import AssignmentFactory, NotificationFactory +from notification.models import Assignment, Notification +from project.factories import ProjectFactory +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLTestCase class NotificationMutation(GraphQLTestCase): def test_notification_status_update(self): - self.notification_query = ''' + self.notification_query = """ mutation Mutation($input: NotificationStatusInputType!) { notificationStatusUpdate(data: $input) { ok @@ -22,18 +22,15 @@ def test_notification_status_update(self): } } } - ''' + """ user = UserFactory.create() another_user = UserFactory.create() notification = NotificationFactory.create(status=Notification.Status.UNSEEN, receiver=user) def _query_check(minput, **kwargs): - return self.query_check( - self.notification_query, - minput=minput, - **kwargs - ) + return self.query_check(self.notification_query, minput=minput, **kwargs) + minput = dict(id=notification.id, status=self.genum(Notification.Status.SEEN)) # -- Without login _query_check(minput, assert_for_error=True) @@ -41,7 +38,7 @@ def _query_check(minput, **kwargs): # -- With login self.force_login(user) content = _query_check(minput) - self.assertEqual(content['data']['notificationStatusUpdate']['errors'], None, content) + self.assertEqual(content["data"]["notificationStatusUpdate"]["errors"], None, content) # check for the notification status update(db-level) notification = Notification.objects.get(id=notification.id) @@ -49,20 +46,20 @@ def _query_check(minput, **kwargs): # -- with different user self.force_login(another_user) - content = _query_check(minput, okay=False)['data']['notificationStatusUpdate']['result'] + content = _query_check(minput, okay=False)["data"]["notificationStatusUpdate"]["result"] self.assertEqual(content, None, content) class TestAssignmentMutation(GraphQLTestCase): def test_assginment_bulk_status_mark_as_done(self): - self.assignment_query = ''' + self.assignment_query = """ mutation MyMutation { assignmentBulkStatusMarkAsDone { errors ok } } - ''' + """ project = ProjectFactory.create() user = UserFactory.create() lead = LeadFactory.create() @@ -72,23 +69,20 @@ def test_assginment_bulk_status_mark_as_done(self): object_id=lead.id, content_type=ContentType.objects.get_for_model(Lead), created_for=user, - is_done=False + is_done=False, ) def _query_check(**kwargs): - return self.query_check( - self.assignment_query, - **kwargs - ) + return self.query_check(self.assignment_query, **kwargs) self.force_login(user) content = _query_check() assignments_qs = Assignment.get_for(user).filter(is_done=False) - self.assertEqual(content['data']['assignmentBulkStatusMarkAsDone']['errors'], None) + self.assertEqual(content["data"]["assignmentBulkStatusMarkAsDone"]["errors"], None) self.assertEqual(len(assignments_qs), 0) def test_individual_assignment_update_status(self): - self.indivdual_assignment_query = ''' + self.indivdual_assignment_query = """ mutation Mutation($isDone: Boolean, $id: ID! ){ assignmentUpdate(id: $id, data: {isDone: $isDone}){ ok @@ -99,7 +93,7 @@ def test_individual_assignment_update_status(self): } } } - ''' + """ user = UserFactory.create() project = ProjectFactory.create() @@ -109,16 +103,11 @@ def test_individual_assignment_update_status(self): object_id=lead.id, content_type=ContentType.objects.get_for_model(Lead), created_for=user, - is_done=False - + is_done=False, ) def _query_check(**kwargs): - return self.query_check( - self.indivdual_assignment_query, - variables={"isDone": True, "id": assignment.id}, - **kwargs - ) + return self.query_check(self.indivdual_assignment_query, variables={"isDone": True, "id": assignment.id}, **kwargs) # without login @@ -129,5 +118,5 @@ def _query_check(**kwargs): self.force_login(user) content = _query_check() assignment_qs = Assignment.get_for(user).filter(id=assignment.id, is_done=False) - self.assertEqual(content['data']['assignmentUpdate']['errors'], None) + self.assertEqual(content["data"]["assignmentUpdate"]["errors"], None) self.assertEqual(len(assignment_qs), 0) diff --git a/apps/notification/tests/test_notification.py b/apps/notification/tests/test_notification.py index 9059bd7317..fa5bde187a 100644 --- a/apps/notification/tests/test_notification.py +++ b/apps/notification/tests/test_notification.py @@ -1,18 +1,18 @@ from unittest.mock import patch -from deep.tests import TestCase -from user.models import User -from notification.models import Notification, Assignment -from project.models import ProjectJoinRequest, Project -from lead.models import Lead -from quality_assurance.models import EntryReviewComment - -from user.factories import UserFactory -from project.factories import ProjectFactory from analysis_framework.factories import AnalysisFrameworkFactory from entry.factories import EntryFactory -from quality_assurance.factories import EntryReviewCommentFactory from lead.factories import LeadFactory +from lead.models import Lead +from notification.models import Assignment, Notification +from project.factories import ProjectFactory +from project.models import Project, ProjectJoinRequest +from quality_assurance.factories import EntryReviewCommentFactory +from quality_assurance.models import EntryReviewComment +from user.factories import UserFactory +from user.models import User + +from deep.tests import TestCase class TestNotification(TestCase): @@ -27,24 +27,19 @@ def test_notification_created_on_project_join_request(self): # Add admin user to project project.add_member(admin_user, role=self.admin_role) ProjectJoinRequest.objects.create( - project=project, - requested_by=normal_user, - role=self.normal_role, - data={'reason': 'bla'} + project=project, requested_by=normal_user, role=self.normal_role, data={"reason": "bla"} ) # Get notifications for admin_users for user in [self.user, admin_user]: notifications = Notification.get_for(user) - assert notifications.count() == 1, \ - "A notification should have been created for admin" + assert notifications.count() == 1, "A notification should have been created for admin" notification = notifications[0] assert notification.status == Notification.Status.UNSEEN - assert notification.notification_type ==\ - Notification.Type.PROJECT_JOIN_REQUEST + assert notification.notification_type == Notification.Type.PROJECT_JOIN_REQUEST assert notification.receiver == user - assert notification.data['status'] == 'pending' - assert notification.data['data']['reason'] is not None + assert notification.data["status"] == "pending" + assert notification.data["data"]["reason"] is not None # Get notifications for requesting user # there should be none @@ -57,34 +52,29 @@ def test_notification_updated_on_request_accepted(self): normal_user = self.create(User) join_request = ProjectJoinRequest.objects.create( - project=project, - requested_by=normal_user, - role=self.normal_role, - data={'reason': 'bla'} + project=project, requested_by=normal_user, role=self.normal_role, data={"reason": "bla"} ) # Get notification for self.user notifications = Notification.get_for(self.user) assert notifications.count() == 1 - assert notifications[0].notification_type ==\ - Notification.Type.PROJECT_JOIN_REQUEST + assert notifications[0].notification_type == Notification.Type.PROJECT_JOIN_REQUEST # Update join_request by adding member project.add_member(join_request.requested_by, role=join_request.role) # Manually updateing join_request because add_member does not trigger # receiver for join_request post_save - join_request.status = 'accepted' + join_request.status = "accepted" join_request.role = join_request.role join_request.save() # Get notifications for admin notifications = Notification.get_for(self.user) assert notifications.count() == 2 - new_notif = Notification.get_for(self.user).order_by('-timestamp')[0] - assert new_notif.notification_type ==\ - Notification.Type.PROJECT_JOIN_RESPONSE - assert new_notif.data['status'] == 'accepted' + new_notif = Notification.get_for(self.user).order_by("-timestamp")[0] + assert new_notif.notification_type == Notification.Type.PROJECT_JOIN_RESPONSE + assert new_notif.data["status"] == "accepted" # Get notifications for requesting user # He/She should get a notification saying request is accepted @@ -92,9 +82,8 @@ def test_notification_updated_on_request_accepted(self): assert notifications.count() == 1 new_notif = notifications[0] - assert new_notif.notification_type ==\ - Notification.Type.PROJECT_JOIN_RESPONSE - assert new_notif.data['status'] == 'accepted' + assert new_notif.notification_type == Notification.Type.PROJECT_JOIN_RESPONSE + assert new_notif.data["status"] == "accepted" def test_notification_updated_on_request_rejected(self): project = self.create(Project, role=self.admin_role) @@ -102,31 +91,26 @@ def test_notification_updated_on_request_rejected(self): normal_user = self.create(User) join_request = ProjectJoinRequest.objects.create( - project=project, - requested_by=normal_user, - role=self.normal_role, - data={'reason': 'bla'} + project=project, requested_by=normal_user, role=self.normal_role, data={"reason": "bla"} ) # Get notification for self.user notifications = Notification.get_for(self.user) assert notifications.count() == 1 - assert notifications[0].notification_type ==\ - Notification.Type.PROJECT_JOIN_REQUEST - assert notifications[0].data['status'] == 'pending' + assert notifications[0].notification_type == Notification.Type.PROJECT_JOIN_REQUEST + assert notifications[0].data["status"] == "pending" # Update join_request without adding member - join_request.status = 'rejected' + join_request.status = "rejected" join_request.role = join_request.role join_request.save() # Get notifications for admin notifications = Notification.get_for(self.user) assert notifications.count() == 2 - new_notif = notifications.order_by('-timestamp')[0] - assert new_notif.notification_type ==\ - Notification.Type.PROJECT_JOIN_RESPONSE - assert new_notif.data['status'] == 'rejected' + new_notif = notifications.order_by("-timestamp")[0] + assert new_notif.notification_type == Notification.Type.PROJECT_JOIN_RESPONSE + assert new_notif.data["status"] == "rejected" # Get notifications for requesting user # He/She should get a notification saying request is rejected @@ -134,27 +118,22 @@ def test_notification_updated_on_request_rejected(self): assert notifications.count() == 1 new_notif = notifications[0] - assert new_notif.notification_type ==\ - Notification.Type.PROJECT_JOIN_RESPONSE - assert new_notif.data['status'] == 'rejected' + assert new_notif.notification_type == Notification.Type.PROJECT_JOIN_RESPONSE + assert new_notif.data["status"] == "rejected" def test_notification_updated_on_request_aborted(self): project = self.create(Project, role=self.admin_role) normal_user = self.create(User) join_request = ProjectJoinRequest.objects.create( - project=project, - requested_by=normal_user, - role=self.normal_role, - data={'reason': 'bla'} + project=project, requested_by=normal_user, role=self.normal_role, data={"reason": "bla"} ) # Get notification for self.user notifications = Notification.get_for(self.user) assert notifications.count() == 1 - assert notifications[0].notification_type ==\ - Notification.Type.PROJECT_JOIN_REQUEST - assert notifications[0].data['status'] == 'pending' + assert notifications[0].notification_type == Notification.Type.PROJECT_JOIN_REQUEST + assert notifications[0].data["status"] == "pending" # Now abort join request by deleting it join_request.delete() @@ -162,10 +141,9 @@ def test_notification_updated_on_request_aborted(self): # Get notifications again notifications = Notification.get_for(self.user) assert notifications.count() == 2 - new_notif = notifications.order_by('-timestamp')[0] - assert new_notif.data['status'] == 'aborted' - assert new_notif.notification_type ==\ - Notification.Type.PROJECT_JOIN_REQUEST_ABORT + new_notif = notifications.order_by("-timestamp")[0] + assert new_notif.data["status"] == "aborted" + assert new_notif.notification_type == Notification.Type.PROJECT_JOIN_REQUEST_ABORT # Get notifications for requesting user # there should be none @@ -174,17 +152,15 @@ def test_notification_updated_on_request_aborted(self): class TestAssignment(TestCase): - """ Unit test for Assignment""" + """Unit test for Assignment""" - @patch('notification.receivers.assignment.get_current_user') + @patch("notification.receivers.assignment.get_current_user") def test_create_assignment_create_on_entry_review_comment(self, get_user_mocked_func): af = AnalysisFrameworkFactory.create() project = ProjectFactory.create(analysis_framework=af) user1, user2 = UserFactory.create_batch(2) get_user_mocked_func.return_value = user2 - entry = EntryFactory.create( - lead=LeadFactory.create(project=project) - ) + entry = EntryFactory.create(lead=LeadFactory.create(project=project)) old_assignment_count = Assignment.objects.count() entry_review_comment = EntryReviewCommentFactory.create(entry=entry, entry_comment=None, created_by=user1) @@ -214,7 +190,7 @@ def test_create_assignment_create_on_entry_review_comment(self, get_user_mocked_ assert assignment.count() == 1 # for only the user assert get_user_mocked_func.called - @patch('notification.receivers.assignment.get_current_user') + @patch("notification.receivers.assignment.get_current_user") def test_assignment_create_on_lead_create(self, get_user_mocked_func): project = self.create(Project) user1 = self.create_user() @@ -252,7 +228,7 @@ def test_assignment_create_on_lead_create(self, get_user_mocked_func): assert assignment.count() == 1 # for only the user assert get_user_mocked_func.called - @patch('notification.receivers.assignment.get_current_user') + @patch("notification.receivers.assignment.get_current_user") def test_assignment_on_lead_and_entry_review_comment_delete(self, get_user_mocked_func): project = self.create_project() user1 = self.create(User) diff --git a/apps/notification/tests/test_schemas.py b/apps/notification/tests/test_schemas.py index eecf3b1bdb..bc249bd6f3 100644 --- a/apps/notification/tests/test_schemas.py +++ b/apps/notification/tests/test_schemas.py @@ -1,18 +1,17 @@ import datetime -import pytz +import pytz +from analysis_framework.factories import AnalysisFrameworkFactory from django.contrib.contenttypes.models import ContentType - -from utils.graphene.tests import GraphQLTestCase +from entry.factories import EntryFactory +from lead.factories import LeadFactory +from notification.factories import AssignmentFactory, NotificationFactory from notification.models import Notification - -from user.factories import UserFactory from project.factories import ProjectFactory -from notification.factories import AssignmentFactory, NotificationFactory -from lead.factories import LeadFactory -from entry.factories import EntryFactory from quality_assurance.factories import EntryReviewCommentFactory -from analysis_framework.factories import AnalysisFrameworkFactory +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLTestCase class TestNotificationQuerySchema(GraphQLTestCase): @@ -20,7 +19,7 @@ def test_notifications_query(self): """ Test notification for users """ - query = ''' + query = """ query MyQuery { notifications { totalCount @@ -39,7 +38,7 @@ def test_notifications_query(self): } } } - ''' + """ project = ProjectFactory.create() user = UserFactory.create() @@ -60,16 +59,16 @@ def _query_check(**kwargs): # --- With login self.force_login(user) content = _query_check() - self.assertEqual(content['data']['notifications']['totalCount'], 10, content) - self.assertEqual(len(content['data']['notifications']['results']), 10, content) + self.assertEqual(content["data"]["notifications"]["totalCount"], 10, content) + self.assertEqual(len(content["data"]["notifications"]["results"]), 10, content) self.force_login(another_user) content = _query_check() - self.assertEqual(content['data']['notifications']['totalCount'], 2, content) - self.assertEqual(len(content['data']['notifications']['results']), 2, content) + self.assertEqual(content["data"]["notifications"]["totalCount"], 2, content) + self.assertEqual(len(content["data"]["notifications"]["results"]), 2, content) def test_notification_query(self): - query = ''' + query = """ query MyQuery ($id: ID!) { notification(id: $id) { id @@ -85,7 +84,7 @@ def test_notification_query(self): data } } - ''' + """ project = ProjectFactory.create() user = UserFactory.create() @@ -98,7 +97,7 @@ def test_notification_query(self): other_notification = NotificationFactory.create(project=project, receiver=another_user, **notification_meta) def _query_check(notification, **kwargs): - return self.query_check(query, variables={'id': notification.pk}, **kwargs) + return self.query_check(query, variables={"id": notification.pk}, **kwargs) # -- Without login _query_check(our_notification, assert_for_error=True) @@ -106,13 +105,13 @@ def _query_check(notification, **kwargs): # --- With login self.force_login(user) content = _query_check(our_notification) - self.assertNotEqual(content['data']['notification'], None, content) + self.assertNotEqual(content["data"]["notification"], None, content) content = _query_check(other_notification) - self.assertEqual(content['data']['notification'], None, content) + self.assertEqual(content["data"]["notification"], None, content) def test_notifications_with_filter_query(self): - query = ''' + query = """ query MyQuery ( $timestamp: DateTime, $timestampLte: DateTime, @@ -145,7 +144,7 @@ def test_notifications_with_filter_query(self): } } } - ''' + """ project = ProjectFactory.create() user = UserFactory.create() @@ -169,7 +168,7 @@ def test_notifications_with_filter_query(self): notification_type=Notification.Type.PROJECT_JOIN_REQUEST_ABORT, status=Notification.Status.UNSEEN, timestamp=datetime.datetime(2021, 2, 1, 0, 0, 0, 0, tzinfo=pytz.UTC), - data={'status': 'pending'}, + data={"status": "pending"}, **notification_meta, ) @@ -179,24 +178,24 @@ def _query_check(filters, **kwargs): # --- With login self.force_login(user) for filters, count in [ - ({'status': self.genum(Notification.Status.SEEN)}, 1), - ({'status': self.genum(Notification.Status.UNSEEN)}, 2), - ({'notificationType': self.genum(Notification.Type.PROJECT_JOIN_REQUEST)}, 2), - ({'notificationType': self.genum(Notification.Type.PROJECT_JOIN_REQUEST_ABORT)}, 1), - ({'isPending': True}, 1), - ({'isPending': False}, 2), - ({'timestampGte': '2021-01-01T00:00:00+00:00', 'timestampLte': '2021-01-01T00:00:00+00:00'}, 1), - ({'timestampGte': '2021-01-01T00:00:00+00:00', 'timestampLte': '2021-02-01T00:00:00+00:00'}, 3), - ({'timestamp': '2021-01-01T00:00:00+00:00'}, 1), + ({"status": self.genum(Notification.Status.SEEN)}, 1), + ({"status": self.genum(Notification.Status.UNSEEN)}, 2), + ({"notificationType": self.genum(Notification.Type.PROJECT_JOIN_REQUEST)}, 2), + ({"notificationType": self.genum(Notification.Type.PROJECT_JOIN_REQUEST_ABORT)}, 1), + ({"isPending": True}, 1), + ({"isPending": False}, 2), + ({"timestampGte": "2021-01-01T00:00:00+00:00", "timestampLte": "2021-01-01T00:00:00+00:00"}, 1), + ({"timestampGte": "2021-01-01T00:00:00+00:00", "timestampLte": "2021-02-01T00:00:00+00:00"}, 3), + ({"timestamp": "2021-01-01T00:00:00+00:00"}, 1), ]: content = _query_check(filters) - self.assertEqual(content['data']['notifications']['totalCount'], count, f'\n{filters=} \n{content=}') - self.assertEqual(len(content['data']['notifications']['results']), count, f'\n{filters=} \n{content=}') + self.assertEqual(content["data"]["notifications"]["totalCount"], count, f"\n{filters=} \n{content=}") + self.assertEqual(len(content["data"]["notifications"]["results"]), count, f"\n{filters=} \n{content=}") class TestAssignmentQuerySchema(GraphQLTestCase): def test_assignments_query(self): - query = ''' + query = """ query MyQuery { assignments { results { @@ -230,7 +229,7 @@ def test_assignments_query(self): totalCount } } - ''' + """ # XXX: To avoid using content type cache from pre-tests ContentType.objects.clear_cache() @@ -240,14 +239,8 @@ def test_assignments_query(self): another = UserFactory.create() lead = LeadFactory.create() af = AnalysisFrameworkFactory.create() - entry = EntryFactory.create( - analysis_framework=af, - lead=lead - ) - entry_comment = EntryReviewCommentFactory.create( - entry=entry, - created_by=user - ) + entry = EntryFactory.create(analysis_framework=af, lead=lead) + entry_comment = EntryReviewCommentFactory.create(entry=entry, created_by=user) AssignmentFactory.create_batch( 3, @@ -265,23 +258,18 @@ def _query_check(**kwargs): # -- with login with different user self.force_login(another) content = _query_check() - self.assertEqual(content['data']['assignments']['results'], [], content) + self.assertEqual(content["data"]["assignments"]["results"], [], content) # -- with login normal user self.force_login(user) content = _query_check() - self.assertEqual(content['data']['assignments']['totalCount'], 3) - self.assertEqual(content['data']['assignments']['results'][0]['contentData']['contentType'], 'LEAD') - AssignmentFactory.create_batch( - 3, - project=project, - content_object=entry_comment, - created_for=user - ) + self.assertEqual(content["data"]["assignments"]["totalCount"], 3) + self.assertEqual(content["data"]["assignments"]["results"][0]["contentData"]["contentType"], "LEAD") + AssignmentFactory.create_batch(3, project=project, content_object=entry_comment, created_for=user) content = _query_check() - self.assertEqual(content['data']['assignments']['totalCount'], 6) + self.assertEqual(content["data"]["assignments"]["totalCount"], 6) def test_assignments_with_filter_query(self): - query = ''' + query = """ query MyQuery($isDone: Boolean) { assignments(isDone: $isDone) { totalCount @@ -290,7 +278,7 @@ def test_assignments_with_filter_query(self): } } } - ''' + """ # XXX: To avoid using content type cache from pre-tests ContentType.objects.clear_cache() @@ -299,38 +287,20 @@ def test_assignments_with_filter_query(self): user = UserFactory.create() lead = LeadFactory.create() af = AnalysisFrameworkFactory.create() - entry = EntryFactory.create( - analysis_framework=af, - lead=lead - ) - entry_comment = EntryReviewCommentFactory.create( - entry=entry, - created_by=user - ) + entry = EntryFactory.create(analysis_framework=af, lead=lead) + entry_comment = EntryReviewCommentFactory.create(entry=entry, created_by=user) - AssignmentFactory.create_batch( - 3, - project=project, - content_object=lead, - created_for=user, - is_done=False - ) - AssignmentFactory.create_batch( - 5, - project=project, - content_object=entry_comment, - created_for=user, - is_done=True - ) + AssignmentFactory.create_batch(3, project=project, content_object=lead, created_for=user, is_done=False) + AssignmentFactory.create_batch(5, project=project, content_object=entry_comment, created_for=user, is_done=True) def _query_check(filters, **kwargs): return self.query_check(query, variables=filters, **kwargs) self.force_login(user) for filters, count in [ - ({'isDone': True}, 5), - ({'isDone': False}, 3), + ({"isDone": True}, 5), + ({"isDone": False}, 3), ]: content = _query_check(filters) - self.assertEqual(content['data']['assignments']['totalCount'], count, f'\n{filters=} \n{content=}') - self.assertEqual(len(content['data']['assignments']['results']), count, f'\n{filters=} \n{content=}') + self.assertEqual(content["data"]["assignments"]["totalCount"], count, f"\n{filters=} \n{content=}") + self.assertEqual(len(content["data"]["assignments"]["results"]), count, f"\n{filters=} \n{content=}") diff --git a/apps/notification/views.py b/apps/notification/views.py index 49498b2c76..98e2a107e2 100644 --- a/apps/notification/views.py +++ b/apps/notification/views.py @@ -1,27 +1,19 @@ import django_filters +from notification.filter_set import AssignmentFilterSet, NotificationFilterSet +from rest_framework import exceptions, permissions, response, viewsets from rest_framework.decorators import action -from .serializers import NotificationSerializer, AssignmentSerializer -from .models import Notification, Assignment from deep.paginations import AssignmentPagination -from notification.filter_set import ( - NotificationFilterSet, - AssignmentFilterSet -) -from rest_framework import ( - exceptions, - response, - permissions, - viewsets, -) +from .models import Assignment, Notification +from .serializers import AssignmentSerializer, NotificationSerializer class NotificationViewSet(viewsets.ModelViewSet): serializer_class = NotificationSerializer permission_classes = [permissions.IsAuthenticated] filter_backends = (django_filters.rest_framework.DjangoFilterBackend,) - filterset_fields = ('project',) + filterset_fields = ("project",) filterset_class = NotificationFilterSet def get_queryset(self): @@ -31,7 +23,7 @@ def get_queryset(self): def filter_queryset(self, queryset): qs = super().filter_queryset(queryset) - project = self.request.query_params.get('project') + project = self.request.query_params.get("project") if project is not None: qs.filter(project=project) @@ -41,14 +33,12 @@ def filter_queryset(self, queryset): @action( detail=False, permission_classes=[permissions.IsAuthenticated], - methods=['put'], + methods=["put"], serializer_class=NotificationSerializer, - url_path='status', + url_path="status", ) def status_update(self, request, version=None): - serializer = self.get_serializer( - data=request.data, many=True, partial=True - ) + serializer = self.get_serializer(data=request.data, many=True, partial=True) if not serializer.is_valid(): raise exceptions.ValidationError(serializer.errors) serializer.save() @@ -57,7 +47,7 @@ def status_update(self, request, version=None): @action( detail=False, permission_classes=[permissions.IsAuthenticated], - url_path='count', + url_path="count", ) def get_count(self, request, version=None): request.child_route = True @@ -66,13 +56,13 @@ def get_count(self, request, version=None): unseen_notifications = qs.filter(status=Notification.Status.UNSEEN) - unseen_requests_count = unseen_notifications.filter(data__status='pending').count() + unseen_requests_count = unseen_notifications.filter(data__status="pending").count() unseen_notifications_count = unseen_notifications.count() - unseen_requests_count result = { - 'unseen_notifications': unseen_notifications_count, - 'unseen_requests': unseen_requests_count, - 'total': total, + "unseen_notifications": unseen_notifications_count, + "unseen_requests": unseen_requests_count, + "total": total, } return response.Response(result) @@ -85,19 +75,22 @@ class AssignmentViewSet(viewsets.ModelViewSet): pagination_class = AssignmentPagination def get_queryset(self): - return Assignment.get_for(self.request.user).select_related( - 'project', 'created_by', 'content_type', - ).order_by('-created_at') + return ( + Assignment.get_for(self.request.user) + .select_related( + "project", + "created_by", + "content_type", + ) + .order_by("-created_at") + ) - @action( - detail=False, - methods=['POST'], - permission_classes=[permissions.IsAuthenticated], - url_path='bulk-mark-as-done' - ) + @action(detail=False, methods=["POST"], permission_classes=[permissions.IsAuthenticated], url_path="bulk-mark-as-done") def status(self, request, version=None): queryset = self.filter_queryset(self.get_queryset()).filter(is_done=False) updated_rows_count = queryset.update(is_done=True) - return response.Response({ - 'assignment_updated': updated_rows_count, - }) + return response.Response( + { + "assignment_updated": updated_rows_count, + } + ) diff --git a/apps/organization/actions.py b/apps/organization/actions.py index 4e035e77a3..de7c286b25 100644 --- a/apps/organization/actions.py +++ b/apps/organization/actions.py @@ -1,11 +1,11 @@ -import traceback import logging +import traceback -from django.db import transaction from django.contrib import messages from django.contrib.admin import helpers from django.contrib.admin.utils import model_ngettext from django.core.exceptions import PermissionDenied +from django.db import transaction from django.template.response import TemplateResponse from django.utils.safestring import mark_safe from django.utils.translation import gettext as _ @@ -20,7 +20,7 @@ def _merge_organizations(modeladmin, request, queryset): opts = modeladmin.model._meta mergable_objects, count, perms_needed = modeladmin.get_merged_objects(queryset, request) - if request.POST.get('post'): + if request.POST.get("post"): form = MergeForm(request.POST, organizations=queryset) if perms_needed: raise PermissionDenied @@ -29,11 +29,13 @@ def _merge_organizations(modeladmin, request, queryset): for obj in queryset: obj_display = str(obj) modeladmin.log_merge(request, obj, obj_display) - parent_organization = form.data.get('parent_organization') + parent_organization = form.data.get("parent_organization") modeladmin.merge_queryset(request, parent_organization, queryset) - modeladmin.message_user(request, _("Successfully merged %(count)d %(items)s.") % { - "count": n, "items": model_ngettext(modeladmin.opts, n) - }, messages.SUCCESS) + modeladmin.message_user( + request, + _("Successfully merged %(count)d %(items)s.") % {"count": n, "items": model_ngettext(modeladmin.opts, n)}, + messages.SUCCESS, + ) # Return None to display the change list page again. return None @@ -46,38 +48,37 @@ def _merge_organizations(modeladmin, request, queryset): title = _("Are you sure?") context = { **modeladmin.admin_site.each_context(request), - 'title': title, - 'objects_name': str(objects_name), - 'mergable_objects': mergable_objects, - 'model_count': count, - 'queryset': queryset, - 'perms_lacking': perms_needed, - 'opts': opts, - 'action_checkbox_name': helpers.ACTION_CHECKBOX_NAME, - 'media': modeladmin.media, - 'form': form, - 'adminform': helpers.AdminForm( + "title": title, + "objects_name": str(objects_name), + "mergable_objects": mergable_objects, + "model_count": count, + "queryset": queryset, + "perms_lacking": perms_needed, + "opts": opts, + "action_checkbox_name": helpers.ACTION_CHECKBOX_NAME, + "media": modeladmin.media, + "form": form, + "adminform": helpers.AdminForm( form, - list([(None, {'fields': form.base_fields})]), + list([(None, {"fields": form.base_fields})]), {}, - ) + ), } - return TemplateResponse(request, 'organization/merge_confirmation.html', context) + return TemplateResponse(request, "organization/merge_confirmation.html", context) def merge_organizations(modeladmin, request, queryset): try: return _merge_organizations(modeladmin, request, queryset) except Exception: - logger.error('Error occured while merging organization', exc_info=True) + logger.error("Error occured while merging organization", exc_info=True) messages.add_message( - request, messages.ERROR, - mark_safe( - 'Error occured while merging organization:

' + traceback.format_exc() + '
' - ) + request, + messages.ERROR, + mark_safe("Error occured while merging organization:

" + traceback.format_exc() + "
"), ) -merge_organizations.short_description = 'Merge Organizations' -merge_organizations.allowed_permissions = ('merge',) -merge_organizations.long_description = 'Merge Organizations and reflect changes to other part of the deep' +merge_organizations.short_description = "Merge Organizations" +merge_organizations.allowed_permissions = ("merge",) +merge_organizations.long_description = "Merge Organizations and reflect changes to other part of the deep" diff --git a/apps/organization/admin.py b/apps/organization/admin.py index 5709812d85..bc903ded49 100644 --- a/apps/organization/admin.py +++ b/apps/organization/admin.py @@ -1,42 +1,37 @@ from django import forms -from django.utils.html import format_html -from django.contrib import messages -from django.utils.safestring import mark_safe +from django.contrib import admin, messages +from django.contrib.admin.models import CHANGE, LogEntry from django.db import models from django.http import HttpResponseRedirect from django.shortcuts import redirect from django.urls import path, reverse -from django.contrib.admin.models import LogEntry, CHANGE -from django.contrib import admin +from django.utils.html import format_html +from django.utils.safestring import mark_safe +from gallery.models import File -from deep.admin import document_preview, linkify, ReadOnlyMixin +from deep.admin import ReadOnlyMixin, document_preview, linkify from deep.middleware import get_current_user -from gallery.models import File from .actions import merge_organizations from .filters import IsFromReliefWeb -from .models import ( - OrganizationType, - Organization, -) +from .models import Organization, OrganizationType from .tasks import sync_organization_with_relief_web @admin.register(OrganizationType) class OrganizationTypeAdmin(admin.ModelAdmin): - list_display = ('title', 'get_organization_count', 'get_relief_web_id') - readonly_fields = ('relief_web_id',) - search_fields = ('title',) + list_display = ("title", "get_organization_count", "get_relief_web_id") + readonly_fields = ("relief_web_id",) + search_fields = ("title",) def get_queryset(self, request): - return super().get_queryset(request).annotate( - organization_count=models.Count('organization') - ) + return super().get_queryset(request).annotate(organization_count=models.Count("organization")) def get_organization_count(self, instance): if instance: return instance.organization_count - get_organization_count.short_description = 'Organization Count' + + get_organization_count.short_description = "Organization Count" def get_relief_web_id(self, obj): id = obj.relief_web_id @@ -44,26 +39,27 @@ def get_relief_web_id(self, obj): return format_html( f'{id}' ) - get_relief_web_id.short_description = 'ReliefWeb' - get_relief_web_id.admin_order_field = 'relief_web_id' + + get_relief_web_id.short_description = "ReliefWeb" + get_relief_web_id.admin_order_field = "relief_web_id" class OrganizationInline(ReadOnlyMixin, admin.TabularInline): model = Organization can_delete = False - verbose_name_plural = 'Merged Organizations' + verbose_name_plural = "Merged Organizations" extra = 0 class OrganizationModelForm(forms.ModelForm): - update_logo_direct = forms.ImageField(required=False, help_text='This will replace current logo.') + update_logo_direct = forms.ImageField(required=False, help_text="This will replace current logo.") class Meta: model = Organization - fields = '__all__' + fields = "__all__" def save(self, commit=True): - new_logo_file = self.cleaned_data.pop('update_logo_direct', None) + new_logo_file = self.cleaned_data.pop("update_logo_direct", None) instance = super().save(commit=False) if new_logo_file: mime_type = new_logo_file.content_type @@ -84,80 +80,80 @@ def save(self, commit=True): @admin.register(Organization) class OrganizationAdmin(admin.ModelAdmin): - search_fields = ('title', 'short_name', 'long_name') + search_fields = ("title", "short_name", "long_name") list_display = ( - 'title', - 'short_name', - linkify('organization_type'), - 'source', - 'get_relief_web_id', - 'verified', - 'modified_at', + "title", + "short_name", + linkify("organization_type"), + "source", + "get_relief_web_id", + "verified", + "modified_at", ) - readonly_fields = ( - document_preview('logo', label='Logo Preview', max_height='400px', max_width='300px'), - 'relief_web_id' + readonly_fields = (document_preview("logo", label="Logo Preview", max_height="400px", max_width="300px"), "relief_web_id") + list_filter = ( + "organization_type", + "verified", + IsFromReliefWeb, + "source", ) - list_filter = ('organization_type', 'verified', IsFromReliefWeb, 'source',) actions = (merge_organizations,) - exclude = ('parent',) + exclude = ("parent",) inlines = [OrganizationInline] autocomplete_fields = ( - 'created_by', - 'modified_by', - 'logo', - 'organization_type', - 'regions', - 'parent', + "created_by", + "modified_by", + "logo", + "organization_type", + "regions", + "parent", ) - change_list_template = 'admin/organization_change_list.html' + change_list_template = "admin/organization_change_list.html" form = OrganizationModelForm def get_relief_web_id(self, obj): id = obj.relief_web_id if id: return format_html(f'{id} ') - get_relief_web_id.short_description = 'ReliefWeb' - get_relief_web_id.admin_order_field = 'relief_web_id' + + get_relief_web_id.short_description = "ReliefWeb" + get_relief_web_id.admin_order_field = "relief_web_id" def get_queryset(self, request): - return super().get_queryset(request).prefetch_related('organization_type') + return super().get_queryset(request).prefetch_related("organization_type") - def change_view(self, request, object_id, form_url='', extra_context=None): + def change_view(self, request, object_id, form_url="", extra_context=None): extra_context = extra_context or {} - extra_context['has_merge_permission'] = self.has_merge_permission(request) + extra_context["has_merge_permission"] = self.has_merge_permission(request) return super().change_view(request, object_id, form_url=form_url, extra_context=extra_context) def has_merge_permission(self, request): - return request.user.has_perm('organization.can_merge') + return request.user.has_perm("organization.can_merge") def merge_view(self, request, object_id, extra_context=None): info = self.model._meta.app_label, self.model._meta.model_name org = Organization.objects.get(pk=object_id) org.parent = None - org.save(update_fields=('parent',)) + org.save(update_fields=("parent",)) return HttpResponseRedirect( - reverse('admin:%s_%s_change' % info, kwargs={'object_id': object_id}), + reverse("admin:%s_%s_change" % info, kwargs={"object_id": object_id}), ) def get_urls(self): info = self.model._meta.app_label, self.model._meta.model_name return [ + path("/unmerge/", self.admin_site.admin_view(self.merge_view), name="%s_%s_unmerge" % info), path( - '/unmerge/', - self.admin_site.admin_view(self.merge_view), - name='%s_%s_unmerge' % info - ), - path( - 'trigger-relief-web-sync/', self.admin_site.admin_view(self.trigger_relief_web_sync), - name='organization_relief_web_sync' + "trigger-relief-web-sync/", + self.admin_site.admin_view(self.trigger_relief_web_sync), + name="organization_relief_web_sync", ), ] + super().get_urls() def trigger_relief_web_sync(self, request): sync_organization_with_relief_web.s().delay() - messages.add_message(request, messages.INFO, mark_safe('Successfully triggered organizations re-sync')) - return redirect('admin:organization_organization_changelist') + messages.add_message(request, messages.INFO, mark_safe("Successfully triggered organizations re-sync")) + return redirect("admin:organization_organization_changelist") def get_inline_instances(self, request, obj=None): if obj and obj.related_childs.exists(): @@ -165,7 +161,7 @@ def get_inline_instances(self, request, obj=None): return [] def get_exclude(self, request, obj=None): - if request.GET.get('show_parent', False): + if request.GET.get("show_parent", False): return return self.exclude @@ -186,7 +182,7 @@ def log_merge(self, request, object, object_repr): object_id=object.pk, object_repr=object_repr, action_flag=CHANGE, - change_message='Merged organization', + change_message="Merged organization", ) def get_merged_objects(self, objs, request): @@ -204,17 +200,12 @@ def update_children(related_childs): org_list = [] for child_org in related_childs.all(): if child_org.related_childs.exists(): - org_list.extend( - update_children( - child_org.related_childs - ) - ) + org_list.extend(update_children(child_org.related_childs)) org_list.append(child_org) return org_list + orgs = update_children(queryset) # Make others childern to selected_parent_organization - Organization.objects.filter( - id__in=[org.pk for org in orgs] - ).update(parent=selected_parent_org_id) + Organization.objects.filter(id__in=[org.pk for org in orgs]).update(parent=selected_parent_org_id) # Make selected_parent_organization a root entity Organization.objects.filter(pk=selected_parent_org_id).update(parent=None) diff --git a/apps/organization/apps.py b/apps/organization/apps.py index dad8aa7edc..0395fab3ff 100644 --- a/apps/organization/apps.py +++ b/apps/organization/apps.py @@ -2,4 +2,4 @@ class OrganizationConfig(AppConfig): - name = 'organization' + name = "organization" diff --git a/apps/organization/dataloaders.py b/apps/organization/dataloaders.py index 33dbaba6e9..b55133dc5c 100644 --- a/apps/organization/dataloaders.py +++ b/apps/organization/dataloaders.py @@ -1,41 +1,31 @@ -from promise import Promise -from django.utils.functional import cached_property from django.db.models import F +from django.utils.functional import cached_property +from gallery.models import File +from promise import Promise from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin -from gallery.models import File - from .models import Organization class LogoLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - file_qs = File.objects\ - .annotate(organization_id=F('organization'))\ - .filter(organization__in=keys) - _map = { - file.organization_id: file - for file in file_qs - } + file_qs = File.objects.annotate(organization_id=F("organization")).filter(organization__in=keys) + _map = {file.organization_id: file for file in file_qs} return Promise.resolve([_map.get(key) for key in keys]) class OrganizationLoader(DataLoaderWithContext): def batch_load_fn(self, keys): qs = Organization.objects.filter(id__in=keys) - _map = { - org.pk: org for org in qs - } + _map = {org.pk: org for org in qs} return Promise.resolve([_map.get(key) for key in keys]) class ParentOrganizationLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - qs = Organization.objects.filter(id__in=keys).only('id', 'title') - _map = { - org.pk: org for org in qs - } + qs = Organization.objects.filter(id__in=keys).only("id", "title") + _map = {org.pk: org for org in qs} return Promise.resolve([_map.get(key) for key in keys]) diff --git a/apps/organization/enums.py b/apps/organization/enums.py index 3d0d8d708e..c9bd55bfa3 100644 --- a/apps/organization/enums.py +++ b/apps/organization/enums.py @@ -3,21 +3,21 @@ class OrganizationOrderingEnum(graphene.Enum): # ASC - ASC_ID = 'id' - ASC_CREATED_AT = 'created_at' - ASC_TITLE = 'title' - ASC_SHORT_NAME = 'short_name' - ASC_LONG_NAME = 'long_name' - ASC_ORGANIZATION_TYPE = 'organization_type__title' - ASC_POPULARITY = 'popularity' + ASC_ID = "id" + ASC_CREATED_AT = "created_at" + ASC_TITLE = "title" + ASC_SHORT_NAME = "short_name" + ASC_LONG_NAME = "long_name" + ASC_ORGANIZATION_TYPE = "organization_type__title" + ASC_POPULARITY = "popularity" # DESC - DESC_ID = f'-{ASC_ID}' - DESC_CREATED_AT = f'-{ASC_CREATED_AT}' - DESC_TITLE = f'-{ASC_TITLE}' - DESC_SHORT_NAME = f'-{ASC_SHORT_NAME}' - DESC_LONG_NAME = f'-{ASC_LONG_NAME}' - DESC_ORGANIZATION_TYPE = f'-{ASC_ORGANIZATION_TYPE}' - DESC_POPULARITY = f'-{ASC_POPULARITY}' + DESC_ID = f"-{ASC_ID}" + DESC_CREATED_AT = f"-{ASC_CREATED_AT}" + DESC_TITLE = f"-{ASC_TITLE}" + DESC_SHORT_NAME = f"-{ASC_SHORT_NAME}" + DESC_LONG_NAME = f"-{ASC_LONG_NAME}" + DESC_ORGANIZATION_TYPE = f"-{ASC_ORGANIZATION_TYPE}" + DESC_POPULARITY = f"-{ASC_POPULARITY}" # Custom annotate fields - ASC_TITLE_LENGTH = 'title_length' - DESC_TITLE_LENGTH = f'-{ASC_TITLE_LENGTH}' + ASC_TITLE_LENGTH = "title_length" + DESC_TITLE_LENGTH = f"-{ASC_TITLE_LENGTH}" diff --git a/apps/organization/factories.py b/apps/organization/factories.py index 53049dc97e..60767031e3 100644 --- a/apps/organization/factories.py +++ b/apps/organization/factories.py @@ -1,15 +1,14 @@ import factory from factory import fuzzy from factory.django import DjangoModelFactory - from gallery.factories import FileFactory from .models import Organization, OrganizationType class OrganizationTypeFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'Organization-Type-{n}') - short_name = factory.Sequence(lambda n: f'Organization-Type-Short-Name-{n}') + title = factory.Sequence(lambda n: f"Organization-Type-{n}") + short_name = factory.Sequence(lambda n: f"Organization-Type-Short-Name-{n}") description = fuzzy.FuzzyText(length=100) class Meta: @@ -17,11 +16,11 @@ class Meta: class OrganizationFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'Organization-{n}') + title = factory.Sequence(lambda n: f"Organization-{n}") organization_type = factory.SubFactory(OrganizationTypeFactory) - short_name = factory.Sequence(lambda n: f'Organization-Short-Name-{n}') - long_name = factory.Sequence(lambda n: f'Organization-Long-Name-{n}') - url = fuzzy.FuzzyText(length=50, prefix='https://example.com/') + short_name = factory.Sequence(lambda n: f"Organization-Short-Name-{n}") + long_name = factory.Sequence(lambda n: f"Organization-Long-Name-{n}") + url = fuzzy.FuzzyText(length=50, prefix="https://example.com/") logo = factory.SubFactory(FileFactory) verified = True diff --git a/apps/organization/filters.py b/apps/organization/filters.py index 2d6b4d0ad3..ab0d7999cd 100644 --- a/apps/organization/filters.py +++ b/apps/organization/filters.py @@ -1,34 +1,33 @@ import django_filters - +from assessment_registry.models import AssessmentRegistry from django.contrib import admin -from django.utils.translation import gettext_lazy as _ -from django.db.models.functions import Length from django.db import models - -from utils.graphene.filters import MultipleInputFilter, IDFilter - -from assessment_registry.models import AssessmentRegistry +from django.db.models.functions import Length +from django.utils.translation import gettext_lazy as _ from lead.models import Lead from project.models import Project -from .models import Organization + +from utils.graphene.filters import IDFilter, MultipleInputFilter + from .enums import OrganizationOrderingEnum +from .models import Organization class IsFromReliefWeb(admin.SimpleListFilter): - YES = 'yes' - NO = 'no' + YES = "yes" + NO = "no" # Human-readable title which will be displayed in the # right admin sidebar just above the filter options. - title = _('Is from Relief Web') + title = _("Is from Relief Web") # Parameter for the filter that will be used in the URL query. - parameter_name = 'is_from_relief_web' + parameter_name = "is_from_relief_web" def lookups(self, request, model_admin): return ( - (self.YES, 'Yes'), - (self.NO, 'No'), + (self.YES, "Yes"), + (self.NO, "No"), ) def queryset(self, request, queryset): @@ -41,33 +40,33 @@ def queryset(self, request, queryset): class OrganizationFilterSet(django_filters.FilterSet): - search = django_filters.CharFilter(method='search_filter') - used_in_project_by_lead = IDFilter(method='filter_used_in_project_by_lead') - used_in_project_by_assesment = IDFilter(method='filter_used_in_project_by_assesment') + search = django_filters.CharFilter(method="search_filter") + used_in_project_by_lead = IDFilter(method="filter_used_in_project_by_lead") + used_in_project_by_assesment = IDFilter(method="filter_used_in_project_by_assesment") ordering = MultipleInputFilter( OrganizationOrderingEnum, - method='ordering_filter', + method="ordering_filter", ) class Meta: model = Organization - fields = ['id', 'verified'] + fields = ["id", "verified"] def search_filter(self, qs, _, value): if value: return qs.filter( - models.Q(title__unaccent__icontains=value) | - models.Q(short_name__unaccent__icontains=value) | - models.Q(long_name__unaccent__icontains=value) | - models.Q(related_childs__title__unaccent__icontains=value) | - models.Q(related_childs__short_name__unaccent__icontains=value) | - models.Q(related_childs__long_name__unaccent__icontains=value) + models.Q(title__unaccent__icontains=value) + | models.Q(short_name__unaccent__icontains=value) + | models.Q(long_name__unaccent__icontains=value) + | models.Q(related_childs__title__unaccent__icontains=value) + | models.Q(related_childs__short_name__unaccent__icontains=value) + | models.Q(related_childs__long_name__unaccent__icontains=value) ).distinct() return qs def filter_used_in_project_by_lead(self, qs, _, value): if value: - user = getattr(self.request, 'user', None) + user = getattr(self.request, "user", None) if user is None: return qs project = Project.get_for_gq(user, only_member=True).filter(id=value).first() @@ -77,44 +76,46 @@ def filter_used_in_project_by_lead(self, qs, _, value): lead_organizations_queryset = Lead.objects.filter(project=project) return qs.filter( # Publishers - models.Q(id__in=lead_organizations_queryset.values('source')) | + models.Q(id__in=lead_organizations_queryset.values("source")) + | # Authors - models.Q(id__in=lead_organizations_queryset.values('authors__id')) | + models.Q(id__in=lead_organizations_queryset.values("authors__id")) + | # Project stakeholders - models.Q(id__in=project.organizations.values('id')) + models.Q(id__in=project.organizations.values("id")) ) return qs def filter_used_in_project_by_assesment(self, qs, _, value): if value: - user = getattr(self.request, 'user', None) + user = getattr(self.request, "user", None) if user is None: return qs project = Project.get_for_gq(user, only_member=True).filter(id=value).first() if project is None: return qs assessment_organizations_queryset = AssessmentRegistry.objects.filter(project=project) - return qs.filter( - models.Q(id__in=assessment_organizations_queryset.values('stakeholders')) - ) + return qs.filter(models.Q(id__in=assessment_organizations_queryset.values("stakeholders"))) return qs def ordering_filter(self, qs, _, value): if value: if ( - OrganizationOrderingEnum.ASC_TITLE_LENGTH.value in value or - OrganizationOrderingEnum.DESC_TITLE_LENGTH.value in value + OrganizationOrderingEnum.ASC_TITLE_LENGTH.value in value + or OrganizationOrderingEnum.DESC_TITLE_LENGTH.value in value ): - qs = qs.annotate(**{ - OrganizationOrderingEnum.ASC_TITLE_LENGTH.value: Length('title'), - }) + qs = qs.annotate( + **{ + OrganizationOrderingEnum.ASC_TITLE_LENGTH.value: Length("title"), + } + ) return qs.order_by(*value) return qs @property def qs(self): qs = super().qs - if 'ordering' not in self.data: + if "ordering" not in self.data: # Default is Title Length qs = self.ordering_filter( qs, diff --git a/apps/organization/forms.py b/apps/organization/forms.py index b2c1324cf3..4f3e71bff1 100644 --- a/apps/organization/forms.py +++ b/apps/organization/forms.py @@ -1,4 +1,5 @@ from django import forms + from .models import Organization @@ -9,6 +10,6 @@ class MergeForm(forms.Form): ) def __init__(self, *args, **kwargs): - qs = kwargs.pop('organizations') + qs = kwargs.pop("organizations") super().__init__(*args, **kwargs) - self.fields['parent_organization'].queryset = qs + self.fields["parent_organization"].queryset = qs diff --git a/apps/organization/management/commands/load_organizations.py b/apps/organization/management/commands/load_organizations.py index 3b7fedddb2..698c22b19a 100644 --- a/apps/organization/management/commands/load_organizations.py +++ b/apps/organization/management/commands/load_organizations.py @@ -1,115 +1,106 @@ -from django.db import transaction -from django.core.management.base import BaseCommand -from django.core import files from io import BytesIO -from organization.models import ( - OrganizationType, - Organization, -) -from geo.models import Region -from gallery.models import File - import requests +from django.core import files +from django.core.management.base import BaseCommand +from django.db import transaction +from gallery.models import File +from geo.models import Region +from organization.models import Organization, OrganizationType class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - '--sync-by-name', - action='store_true', - help='Sync using short name ( Used for existing organization without relief web id)', + "--sync-by-name", + action="store_true", + help="Sync using short name ( Used for existing organization without relief web id)", ) def handle(self, *args, **kwargs): - self.sync_by_name = kwargs['sync_by_name'] + self.sync_by_name = kwargs["sync_by_name"] self.fetch_org_types() self.fetch_organizations() def fetch_org_types(self): - print('Fetching organization types') - URL = 'https://api.reliefweb.int/v1/references/organization-types' + print("Fetching organization types") + URL = "https://api.reliefweb.int/v1/references/organization-types" response = requests.get(URL).json() - print('Loading organization types') - total = len(response['data']) - for i, type_data in enumerate(response['data']): + print("Loading organization types") + total = len(response["data"]) + for i, type_data in enumerate(response["data"]): self.load_org_type(type_data) - print('{} out of {}'.format(i + 1, total)) + print("{} out of {}".format(i + 1, total)) def load_org_type(self, type_data): - fields = type_data['fields'] + fields = type_data["fields"] values = { - 'title': fields['name'], - 'description': fields.get('description', ''), - 'relief_web_id': fields.get('id'), + "title": fields["name"], + "description": fields.get("description", ""), + "relief_web_id": fields.get("id"), } OrganizationType.objects.update_or_create( **( # Use short_name to sync (Should only be used once) --sync-by-name - {'title': values['title']} + {"title": values["title"]} if self.sync_by_name # Using relief_web_id to sync - else {'relief_web_id': values['relief_web_id']} + else {"relief_web_id": values["relief_web_id"]} ), - defaults=values + defaults=values, ) @transaction.atomic def fetch_organizations(self, offset=0, limit=1000): - print('Fetching organizations starting from: {}'.format(offset)) - URL = 'https://api.reliefweb.int/v1/sources?fields[include][]=logo&fields[include][]=country.iso3&fields[include][]=shortname&fields[include][]=longname&fields[include][]=homepage&fields[include][]=type&offset={}&limit={}'.format( # noqa + print("Fetching organizations starting from: {}".format(offset)) + URL = "https://api.reliefweb.int/v1/sources?fields[include][]=logo&fields[include][]=country.iso3&fields[include][]=shortname&fields[include][]=longname&fields[include][]=homepage&fields[include][]=type&offset={}&limit={}".format( # noqa offset, limit, ) response = requests.get(URL).json() - print('Loading organizations') - total = response['totalCount'] - for i, org_data in enumerate(response['data']): - print('{} out of {}'.format(i + offset + 1, total)) + print("Loading organizations") + total = response["totalCount"] + for i, org_data in enumerate(response["data"]): + print("{} out of {}".format(i + offset + 1, total)) self.load_organization(org_data) - if len(response['data']) > 0: + if len(response["data"]) > 0: self.fetch_organizations(offset + limit, limit) def _get_organization_type_by_relief_web_id(self, relief_web_id): - if not hasattr(self, '_organization_types'): - self._organization_types = { - org.relief_web_id or 'n/a': org - for org in OrganizationType.objects.all() - } + if not hasattr(self, "_organization_types"): + self._organization_types = {org.relief_web_id or "n/a": org for org in OrganizationType.objects.all()} return self._organization_types.get(relief_web_id) def load_organization(self, org_data): - fields = org_data['fields'] + fields = org_data["fields"] values = { - 'title': fields['name'], - 'short_name': fields.get('shortname', ''), - 'long_name': fields.get('longname', ''), - 'url': fields.get('homepage', ''), - 'relief_web_id': org_data['id'], - 'verified': True, - 'organization_type': self._get_organization_type_by_relief_web_id( - fields.get('type', {}).get('id') - ), + "title": fields["name"], + "short_name": fields.get("shortname", ""), + "long_name": fields.get("longname", ""), + "url": fields.get("homepage", ""), + "relief_web_id": org_data["id"], + "verified": True, + "organization_type": self._get_organization_type_by_relief_web_id(fields.get("type", {}).get("id")), } organization, created = Organization.objects.update_or_create( **( # Use short_name to sync (Should only be used once) --sync-by-name - {'title': values['title'], 'created_by': None} + {"title": values["title"], "created_by": None} if self.sync_by_name # Using relief_web_id to sync - else {'relief_web_id': values['relief_web_id']} + else {"relief_web_id": values["relief_web_id"]} ), defaults=values, ) - countries = fields.get('country', []) + countries = fields.get("country", []) for country in countries: - code = country.get('iso3') + code = country.get("iso3") if not code: continue @@ -122,21 +113,21 @@ def load_organization(self, org_data): organization.regions.add(region) - if created or not fields.get('logo'): + if created or not fields.get("logo"): return - logo_data = fields['logo'] + logo_data = fields["logo"] - resp = requests.get(logo_data['url']) + resp = requests.get(logo_data["url"]) fp = BytesIO() fp.write(resp.content) logo = File.objects.create( is_public=True, - title=logo_data['filename'], - mime_type=logo_data['mimetype'], + title=logo_data["filename"], + mime_type=logo_data["mimetype"], ) - logo.file.save(logo_data['filename'], files.File(fp)) + logo.file.save(logo_data["filename"], files.File(fp)) organization.logo = logo organization.save() diff --git a/apps/organization/management/commands/update_organization_popularity.py b/apps/organization/management/commands/update_organization_popularity.py index 2a9b46175e..70ab33dc9d 100644 --- a/apps/organization/management/commands/update_organization_popularity.py +++ b/apps/organization/management/commands/update_organization_popularity.py @@ -1,10 +1,9 @@ from collections import defaultdict -from django.db import models -from django.core.management.base import BaseCommand -from organization.models import Organization +from django.core.management.base import BaseCommand +from django.db import models from lead.models import Lead - +from organization.models import Organization COUNT_THRESHOLD = 10 @@ -16,22 +15,30 @@ class Command(BaseCommand): def handle(self, *args, **kwargs): lead_qs = Lead.objects.filter(project__is_test=False) - lead_author_qs = lead_qs.filter(authors__isnull=False).annotate( - organization_id=models.functions.Coalesce( - models.F('authors__parent_id'), - models.F('authors__id'), + lead_author_qs = ( + lead_qs.filter(authors__isnull=False) + .annotate( + organization_id=models.functions.Coalesce( + models.F("authors__parent_id"), + models.F("authors__id"), + ) ) - ).order_by().values('organization_id').annotate( - count=models.Count('id') + .order_by() + .values("organization_id") + .annotate(count=models.Count("id")) ) - lead_source_qs = lead_qs.filter(source__isnull=False).annotate( - organization_id=models.functions.Coalesce( - models.F('source__parent_id'), - models.F('source__id'), + lead_source_qs = ( + lead_qs.filter(source__isnull=False) + .annotate( + organization_id=models.functions.Coalesce( + models.F("source__parent_id"), + models.F("source__id"), + ) ) - ).order_by().values('organization_id').annotate( - count=models.Count('id') + .order_by() + .values("organization_id") + .annotate(count=models.Count("id")) ) organization_popularity_map = defaultdict(int) @@ -39,7 +46,7 @@ def handle(self, *args, **kwargs): lead_author_qs, lead_source_qs, ]: - for org_id, count in qs.filter(count__gt=COUNT_THRESHOLD).values_list('organization_id', 'count'): + for org_id, count in qs.filter(count__gt=COUNT_THRESHOLD).values_list("organization_id", "count"): organization_popularity_map[org_id] += count Organization.objects.bulk_update( diff --git a/apps/organization/models.py b/apps/organization/models.py index 88287503cd..9e8f003c78 100644 --- a/apps/organization/models.py +++ b/apps/organization/models.py @@ -1,7 +1,8 @@ from django.db import models -from deep.middleware import get_current_user from user_resource.models import UserResource +from deep.middleware import get_current_user + class OrganizationType(models.Model): title = models.CharField(max_length=255, blank=True) @@ -15,16 +16,18 @@ def __str__(self): class Organization(UserResource): class SourceType(models.IntegerChoices): - WEB_INFO_EXTRACT_VIEW = 0, 'Web info extract VIEW' - WEB_INFO_DATA_VIEW = 1, 'Web Info Data VIEW' - CONNECTOR = 2, 'Connector' + WEB_INFO_EXTRACT_VIEW = 0, "Web info extract VIEW" + WEB_INFO_DATA_VIEW = 1, "Web Info Data VIEW" + CONNECTOR = 2, "Connector" parent = models.ForeignKey( # TODO: should we do this ? on_delete=models.CASCADE - 'Organization', on_delete=models.CASCADE, - null=True, blank=True, - help_text='Deep will use the parent organization data instead of current', - related_name='related_childs', + "Organization", + on_delete=models.CASCADE, + null=True, + blank=True, + help_text="Deep will use the parent organization data instead of current", + related_name="related_childs", ) source = models.PositiveSmallIntegerField(choices=SourceType.choices, null=True, blank=True) @@ -36,17 +39,21 @@ class SourceType(models.IntegerChoices): relief_web_id = models.IntegerField(unique=True, blank=True, null=True) logo = models.ForeignKey( - 'gallery.File', + "gallery.File", on_delete=models.SET_NULL, - null=True, blank=True, default=None, + null=True, + blank=True, + default=None, ) - regions = models.ManyToManyField('geo.Region', blank=True) + regions = models.ManyToManyField("geo.Region", blank=True) organization_type = models.ForeignKey( OrganizationType, on_delete=models.SET_NULL, - null=True, blank=True, default=None, + null=True, + blank=True, + default=None, ) verified = models.BooleanField(default=False) @@ -55,21 +62,17 @@ class SourceType(models.IntegerChoices): class Meta: # Admin panel permissions - permissions = ( - ("can_merge", "Can Merge organizations"), - ) + permissions = (("can_merge", "Can Merge organizations"),) def __str__(self): - return f'{self.pk} : ({self.short_name}) {self.title} ' + ( - '(MERGED)' if self.parent else '' - ) + return f"{self.pk} : ({self.short_name}) {self.title} " + ("(MERGED)" if self.parent else "") @property def data(self): """ Get merged organization if merged """ - if hasattr(self, '_data'): + if hasattr(self, "_data"): return self._data if self.parent_id: diff --git a/apps/organization/mutation.py b/apps/organization/mutation.py index e263d0bb98..5eb43194e2 100644 --- a/apps/organization/mutation.py +++ b/apps/organization/mutation.py @@ -2,15 +2,12 @@ from organization.schema import OrganizationType from organization.serializers import OrganizationGqSerializer -from utils.graphene.mutation import ( - generate_input_type_for_serializer, - GrapheneMutation -) +from utils.graphene.mutation import GrapheneMutation, generate_input_type_for_serializer from .models import Organization OrganizationInputType = generate_input_type_for_serializer( - 'OrganizationInputType', + "OrganizationInputType", serializer_class=OrganizationGqSerializer, ) @@ -18,6 +15,7 @@ class OrganizationCreate(GrapheneMutation): class Arguments: data = OrganizationInputType(required=True) + model = Organization result = graphene.Field(OrganizationType) serializer_class = OrganizationGqSerializer @@ -27,5 +25,5 @@ def check_permissions(cls, info, **kwargs): return True # global permission is always True -class Mutation(): +class Mutation: organization_create = OrganizationCreate.Field() diff --git a/apps/organization/public_schema.py b/apps/organization/public_schema.py index c70a63e82f..b704620b99 100644 --- a/apps/organization/public_schema.py +++ b/apps/organization/public_schema.py @@ -2,18 +2,15 @@ from utils.graphene.types import CustomDjangoListObjectType -from .models import Organization from .filters import OrganizationFilterSet +from .models import Organization class PublicOrganization(DjangoObjectType): class Meta: model = Organization skip_registry = True - fields = ( - 'id', - 'title' - ) + fields = ("id", "title") class PublicOrganizationListObjectType(CustomDjangoListObjectType): diff --git a/apps/organization/schema.py b/apps/organization/schema.py index bfa47bf97b..d544dded30 100644 --- a/apps/organization/schema.py +++ b/apps/organization/schema.py @@ -1,16 +1,16 @@ import graphene +from gallery.models import File +from gallery.schema import GalleryFileType from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField, PageGraphqlPagination +from utils.graphene.fields import DjangoPaginatedListObjectField from utils.graphene.pagination import NoOrderingPageGraphqlPagination from utils.graphene.types import CustomDjangoListObjectType -from utils.graphene.fields import DjangoPaginatedListObjectField - -from gallery.schema import GalleryFileType -from gallery.models import File -from .models import Organization, OrganizationType as _OrganizationType from .filters import OrganizationFilterSet +from .models import Organization +from .models import OrganizationType as _OrganizationType from .public_schema import PublicOrganizationListObjectType @@ -18,10 +18,10 @@ class OrganizationTypeType(DjangoObjectType): class Meta: model = _OrganizationType only_fields = ( - 'id', - 'title', - 'short_name', - 'description', + "id", + "title", + "short_name", + "description", ) @@ -36,14 +36,15 @@ class Meta: model = Organization skip_registry = True only_fields = ( - 'id', - 'title', - 'short_name', - 'long_name', - 'url', - 'logo', - 'verified', + "id", + "title", + "short_name", + "long_name", + "url", + "logo", + "verified", ) + logo = graphene.Field(GalleryFileType) def resolve_logo(root, info, **kwargs) -> File: @@ -54,18 +55,19 @@ class OrganizationType(DjangoObjectType): class Meta: model = Organization only_fields = ( - 'id', - 'title', - 'short_name', - 'long_name', - 'url', - 'logo', - 'regions', - 'organization_type', - 'verified', + "id", + "title", + "short_name", + "long_name", + "url", + "logo", + "regions", + "organization_type", + "verified", ) + logo = graphene.Field(GalleryFileType) - merged_as = graphene.Field(MergedAsOrganizationType, source='parent') + merged_as = graphene.Field(MergedAsOrganizationType, source="parent") def resolve_logo(root, info, **kwargs) -> File: return info.context.dl.organization.logo.load(root.pk) @@ -83,23 +85,14 @@ class Meta: class Query: organization = DjangoObjectField(OrganizationType) organizations = DjangoPaginatedListObjectField( - OrganizationListType, - pagination=NoOrderingPageGraphqlPagination( - page_size_query_param='pageSize' - ) + OrganizationListType, pagination=NoOrderingPageGraphqlPagination(page_size_query_param="pageSize") ) public_organizations = DjangoPaginatedListObjectField( - PublicOrganizationListObjectType, - pagination=NoOrderingPageGraphqlPagination( - page_size_query_param='pageSize' - ) + PublicOrganizationListObjectType, pagination=NoOrderingPageGraphqlPagination(page_size_query_param="pageSize") ) organization_type = DjangoObjectField(OrganizationTypeType) organization_types = DjangoPaginatedListObjectField( - OrganizationTypeListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + OrganizationTypeListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) def resolve_organizations(root, info, **kwargs): diff --git a/apps/organization/serializers.py b/apps/organization/serializers.py index 3f4df84f6e..a45395a3b0 100644 --- a/apps/organization/serializers.py +++ b/apps/organization/serializers.py @@ -1,9 +1,9 @@ +from drf_dynamic_fields import DynamicFieldsMixin +from geo.serializers import SimpleRegionSerializer from rest_framework import serializers from user_resource.serializers import UserResourceSerializer -from drf_dynamic_fields import DynamicFieldsMixin -from geo.serializers import SimpleRegionSerializer -from deep.serializers import URLCachedFileField, RemoveNullFieldsMixin +from deep.serializers import RemoveNullFieldsMixin, URLCachedFileField from .models import Organization, OrganizationType @@ -11,47 +11,55 @@ class OrganizationTypeSerializer(serializers.ModelSerializer): class Meta: model = OrganizationType - fields = ('__all__') + fields = "__all__" class MergedAsOrganizationSerializer(serializers.ModelSerializer): - logo = URLCachedFileField(source='logo.file', read_only=True) + logo = URLCachedFileField(source="logo.file", read_only=True) class Meta: model = Organization - fields = ('id', 'title', 'logo') + fields = ("id", "title", "logo") class SimpleOrganizationSerializer(serializers.ModelSerializer): - logo = URLCachedFileField(source='logo.file', read_only=True) - merged_as = MergedAsOrganizationSerializer(source='parent', read_only=True) + logo = URLCachedFileField(source="logo.file", read_only=True) + merged_as = MergedAsOrganizationSerializer(source="parent", read_only=True) class Meta: model = Organization - fields = ('id', 'title', 'short_name', 'merged_as', 'logo') + fields = ("id", "title", "short_name", "merged_as", "logo") class OrganizationSerializer( - DynamicFieldsMixin, RemoveNullFieldsMixin, UserResourceSerializer, + DynamicFieldsMixin, + RemoveNullFieldsMixin, + UserResourceSerializer, ): organization_type_display = OrganizationTypeSerializer( - source='organization_type', read_only=True, + source="organization_type", + read_only=True, ) regions_display = SimpleRegionSerializer( - source='regions', read_only=True, many=True, + source="regions", + read_only=True, + many=True, ) - logo_url = URLCachedFileField(source='logo.file', allow_null=True, required=False) - merged_as = MergedAsOrganizationSerializer(source='parent', read_only=True) + logo_url = URLCachedFileField(source="logo.file", allow_null=True, required=False) + merged_as = MergedAsOrganizationSerializer(source="parent", read_only=True) client_id = None class Meta: model = Organization - exclude = ('parent',) - read_only_fields = ('verified', 'logo_url',) + exclude = ("parent",) + read_only_fields = ( + "verified", + "logo_url", + ) def create(self, validated_data): organization = super().create(validated_data) - organization.created_by = organization.modified_by = self.context['request'].user + organization.created_by = organization.modified_by = self.context["request"].user return organization @@ -61,17 +69,16 @@ def get_merged_as(self, obj): class ArySourceOrganizationSerializer(DynamicFieldsMixin, UserResourceSerializer): - logo = URLCachedFileField(source='logo.file', allow_null=True) - key = serializers.IntegerField(source='pk') - merged_as = MergedAsOrganizationSerializer(source='parent', read_only=True) + logo = URLCachedFileField(source="logo.file", allow_null=True) + key = serializers.IntegerField(source="pk") + merged_as = MergedAsOrganizationSerializer(source="parent", read_only=True) class Meta: model = Organization - fields = ('key', 'title', 'long_name', - 'short_name', 'logo', 'organization_type', 'merged_as') + fields = ("key", "title", "long_name", "short_name", "logo", "organization_type", "merged_as") class OrganizationGqSerializer(UserResourceSerializer): class Meta: model = Organization - fields = ('title', 'long_name', 'url', 'short_name', 'logo', 'organization_type') + fields = ("title", "long_name", "url", "short_name", "logo", "organization_type") diff --git a/apps/organization/tasks.py b/apps/organization/tasks.py index e399f1515f..0400bc215d 100644 --- a/apps/organization/tasks.py +++ b/apps/organization/tasks.py @@ -2,21 +2,21 @@ from celery import shared_task from django.core.management import call_command -from utils.common import redis_lock +from utils.common import redis_lock logger = logging.getLogger(__name__) @shared_task -@redis_lock('sync_organization_with_relief_web') +@redis_lock("sync_organization_with_relief_web") def sync_organization_with_relief_web(): - call_command('load_organizations') + call_command("load_organizations") return True @shared_task -@redis_lock('update_organization_popularity') +@redis_lock("update_organization_popularity") def update_organization_popularity(): - call_command('update_organization_popularity') + call_command("update_organization_popularity") return True diff --git a/apps/organization/tests/test_mutations.py b/apps/organization/tests/test_mutations.py index ee5e74a951..a5416dd42c 100644 --- a/apps/organization/tests/test_mutations.py +++ b/apps/organization/tests/test_mutations.py @@ -1,10 +1,11 @@ from user.factories import UserFactory + from utils.graphene.tests import GraphQLTestCase class TestOrganizationMutation(GraphQLTestCase): def test_orgainization_query(self): - self.organization_query = ''' + self.organization_query = """ mutation MyMutation ($input : OrganizationInputType!) { organizationCreate(data: $input){ @@ -20,21 +21,14 @@ def test_orgainization_query(self): } } } - ''' + """ user = UserFactory.create() - minput = dict( - title="Test Organization", - shortName="Short Name", - longName="This is long name" - ) + minput = dict(title="Test Organization", shortName="Short Name", longName="This is long name") def _query_check(minput, **kwargs): - return self.query_check( - self.organization_query, - minput=minput, - **kwargs - ) + return self.query_check(self.organization_query, minput=minput, **kwargs) + # without login _query_check(minput, assert_for_error=True) @@ -43,6 +37,6 @@ def _query_check(minput, **kwargs): self.force_login(user) content = _query_check(minput) - self.assertEqual(content['data']['organizationCreate']['errors'], None) - self.assertEqual(content['data']['organizationCreate']['result']['title'], 'Test Organization') - self.assertEqual(content['data']['organizationCreate']['result']['verified'], False) + self.assertEqual(content["data"]["organizationCreate"]["errors"], None) + self.assertEqual(content["data"]["organizationCreate"]["result"]["title"], "Test Organization") + self.assertEqual(content["data"]["organizationCreate"]["result"]["verified"], False) diff --git a/apps/organization/tests/test_schemas.py b/apps/organization/tests/test_schemas.py index 142ea99218..ba3226d0ee 100644 --- a/apps/organization/tests/test_schemas.py +++ b/apps/organization/tests/test_schemas.py @@ -1,18 +1,15 @@ -from utils.graphene.tests import GraphQLTestCase - -from project.factories import ProjectFactory from lead.factories import LeadFactory -from organization.factories import ( - OrganizationTypeFactory, - OrganizationFactory -) +from organization.factories import OrganizationFactory, OrganizationTypeFactory from organization.models import OrganizationType +from project.factories import ProjectFactory from user.factories import UserFactory +from utils.graphene.tests import GraphQLTestCase + class TestOrganizationTypeQuery(GraphQLTestCase): def test_organization_type_query(self): - query = ''' + query = """ query OrganizationType { organizationTypes { results { @@ -24,7 +21,7 @@ def test_organization_type_query(self): totalCount } } - ''' + """ OrganizationType.objects.all().delete() OrganizationTypeFactory.create_batch(3) user = UserFactory.create() @@ -34,11 +31,11 @@ def test_organization_type_query(self): self.force_login(user) content = self.query_check(query) - self.assertEqual(len(content['data']['organizationTypes']['results']), 3, content) - self.assertEqual(content['data']['organizationTypes']['totalCount'], 3, content) + self.assertEqual(len(content["data"]["organizationTypes"]["results"]), 3, content) + self.assertEqual(content["data"]["organizationTypes"]["totalCount"], 3, content) def test_organization_query(self): - query = ''' + query = """ query MyQuery ( $verified: Boolean $search: String @@ -57,19 +54,19 @@ def test_organization_query(self): totalCount } } - ''' - org1 = OrganizationFactory.create(title='org-1', verified=False) - org2 = OrganizationFactory.create(title='org-2', verified=True) - org3 = OrganizationFactory.create(title='org-3', verified=False) + """ + org1 = OrganizationFactory.create(title="org-1", verified=False) + org2 = OrganizationFactory.create(title="org-2", verified=True) + org3 = OrganizationFactory.create(title="org-3", verified=False) org4 = OrganizationFactory.create( - title='org-4', - short_name='org-short-name-4', - long_name='org-long-name-4', + title="org-4", + short_name="org-short-name-4", + long_name="org-long-name-4", verified=True, ) - org5 = OrganizationFactory.create(title='org-5', verified=False) - org6 = OrganizationFactory.create(title='org-5', verified=False) - org7 = OrganizationFactory.create(title='órg-7', verified=False) + org5 = OrganizationFactory.create(title="org-5", verified=False) + org6 = OrganizationFactory.create(title="org-5", verified=False) + org7 = OrganizationFactory.create(title="órg-7", verified=False) all_org = [org7, org6, org5, org4, org3, org2, org1] user, non_member_user = UserFactory.create_batch(2) project = ProjectFactory.create() @@ -86,29 +83,45 @@ def test_organization_query(self): lead2.save() for _user, filters, expected_organizations in [ - (user, {'search': 'Organization-'}, [org7, org6, org5, org3, org2, org1]), - (user, {'verified': True}, [org4, org2]), - (user, {'verified': False}, [org7, org6, org5, org3, org1]), - (user, { - 'search': 'Organization-', - 'verified': True, - }, [org2]), - (user, { - 'search': 'Organization-', - 'verified': False, - }, [org7, org6, org5, org3, org1]), - (user, { - 'usedInProjectByLead': str(project.id), - }, [org6, org5, org3, org2, org1]), - (non_member_user, { - 'usedInProjectByLead': str(project.id), + (user, {"search": "Organization-"}, [org7, org6, org5, org3, org2, org1]), + (user, {"verified": True}, [org4, org2]), + (user, {"verified": False}, [org7, org6, org5, org3, org1]), + ( + user, + { + "search": "Organization-", + "verified": True, + }, + [org2], + ), + ( + user, + { + "search": "Organization-", + "verified": False, + }, + [org7, org6, org5, org3, org1], + ), + ( + user, + { + "usedInProjectByLead": str(project.id), + }, + [org6, org5, org3, org2, org1], + ), + ( + non_member_user, + { + "usedInProjectByLead": str(project.id), # Return all the organizations (Filter not applied) - }, all_org), - # unaccent search - (user, {'search': 'org'}, all_org), - (user, {'search': 'órg'}, all_org), - (user, {'search': 'org-7'}, [org7]), - (user, {'search': 'órg-7'}, [org7]), + }, + all_org, + ), + # unaccent search + (user, {"search": "org"}, all_org), + (user, {"search": "órg"}, all_org), + (user, {"search": "org-7"}, [org7]), + (user, {"search": "órg-7"}, [org7]), ]: # Without authentication ----- self.logout() @@ -118,27 +131,22 @@ def test_organization_query(self): self.force_login(_user) content = self.query_check(query, variables=filters) context = { - 'content': content, - 'user': _user, - 'filters': filters, - 'expected_organizations': expected_organizations, + "content": content, + "user": _user, + "filters": filters, + "expected_organizations": expected_organizations, } - self.assertEqual(len(content['data']['organizations']['results']), len(expected_organizations), context) - self.assertEqual(content['data']['organizations']['totalCount'], len(expected_organizations), context) + self.assertEqual(len(content["data"]["organizations"]["results"]), len(expected_organizations), context) + self.assertEqual(content["data"]["organizations"]["totalCount"], len(expected_organizations), context) self.assertEqual( - [ - item['title'] for item in content['data']['organizations']['results'] - ], - [ - org.title - for org in expected_organizations - ], + [item["title"] for item in content["data"]["organizations"]["results"]], + [org.title for org in expected_organizations], context, ) def test_public_organizations_query(self): - query = ''' + query = """ query PublicOrganizations { publicOrganizations { results { @@ -148,8 +156,8 @@ def test_public_organizations_query(self): totalCount } } - ''' + """ OrganizationFactory.create_batch(4) # should be visible without authentication content = self.query_check(query) - self.assertEqual(content['data']['publicOrganizations']['totalCount'], 4, content) + self.assertEqual(content["data"]["publicOrganizations"]["totalCount"], 4, content) diff --git a/apps/organization/views.py b/apps/organization/views.py index 6b0f34919a..49f3a0c314 100644 --- a/apps/organization/views.py +++ b/apps/organization/views.py @@ -1,46 +1,45 @@ -from rest_framework import viewsets, mixins, permissions, filters - import django_filters +from rest_framework import filters, mixins, permissions, viewsets -from deep.paginations import AutocompleteSetPagination from deep.authentication import CSRFExemptSessionAuthentication +from deep.paginations import AutocompleteSetPagination -from .serializers import ( - OrganizationSerializer, - OrganizationTypeSerializer, -) -from .models import ( - Organization, - OrganizationType, -) +from .models import Organization, OrganizationType +from .serializers import OrganizationSerializer, OrganizationTypeSerializer class OrganizationTypeViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = OrganizationTypeSerializer permission_classes = [permissions.IsAuthenticated] queryset = OrganizationType.objects.all() - filter_backends = (django_filters.rest_framework.DjangoFilterBackend, - filters.SearchFilter, filters.OrderingFilter) - search_fields = ('title', 'description',) + filter_backends = (django_filters.rest_framework.DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter) + search_fields = ( + "title", + "description", + ) class OrganizationViewSet( - mixins.CreateModelMixin, - mixins.RetrieveModelMixin, - mixins.ListModelMixin, - viewsets.GenericViewSet, + mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.ListModelMixin, + viewsets.GenericViewSet, ): serializer_class = OrganizationSerializer permission_classes = [permissions.IsAuthenticated] pagination_class = AutocompleteSetPagination authentication_classes = [CSRFExemptSessionAuthentication] - filter_backends = (django_filters.rest_framework.DjangoFilterBackend, - filters.SearchFilter, filters.OrderingFilter) - search_fields = ('title', 'short_name', 'long_name', 'url',) - filterset_fields = ('verified',) - ordering = ('title',) + filter_backends = (django_filters.rest_framework.DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter) + search_fields = ( + "title", + "short_name", + "long_name", + "url", + ) + filterset_fields = ("verified",) + ordering = ("title",) def get_queryset(self): - if self.kwargs.get('pk'): - return Organization.objects.prefetch_related('parent') + if self.kwargs.get("pk"): + return Organization.objects.prefetch_related("parent") return Organization.objects.filter(parent=None) diff --git a/apps/profiling/apps.py b/apps/profiling/apps.py index f34a4533ed..51393dd272 100644 --- a/apps/profiling/apps.py +++ b/apps/profiling/apps.py @@ -2,4 +2,4 @@ class ProfilingConfig(AppConfig): - name = 'profiling' + name = "profiling" diff --git a/apps/profiling/management/commands/test_profile.py b/apps/profiling/management/commands/test_profile.py index d355302d87..6257090f8b 100644 --- a/apps/profiling/management/commands/test_profile.py +++ b/apps/profiling/management/commands/test_profile.py @@ -1,50 +1,58 @@ +import autofixture from django.core.management.base import BaseCommand from profiling.profiler import Profiler from project.models import Project, ProjectMembership -import autofixture - class Command(BaseCommand): def handle(self, *args, **kwargs): p = Profiler() - print('Creating users') - users = autofixture.create('auth.User', 20, overwrite_defaults=True) + print("Creating users") + users = autofixture.create("auth.User", 20, overwrite_defaults=True) user = users[0] p.authorise_with(user) - print('Creating regions') - autofixture.create('geo.Region', 5, field_values={ - 'created_by': user, - }) + print("Creating regions") + autofixture.create( + "geo.Region", + 5, + field_values={ + "created_by": user, + }, + ) - print('Creating projects') + print("Creating projects") Project.objects.all().delete() - autofixture.create_one('project.Project', field_values={ - 'created_by': user, - }) + autofixture.create_one( + "project.Project", + field_values={ + "created_by": user, + }, + ) project = Project.objects.first() - if not ProjectMembership.objects.filter(project=project, member=user)\ - .exists(): - ProjectMembership.objects.create(project=project, member=user, - role='admin') + if not ProjectMembership.objects.filter(project=project, member=user).exists(): + ProjectMembership.objects.create(project=project, member=user, role="admin") - print('Creating leads') + print("Creating leads") # create_many_leads(1000, user, project) - autofixture.create('lead.Lead', 100, field_values={ - 'created_by': user, - }) + autofixture.create( + "lead.Lead", + 100, + field_values={ + "created_by": user, + }, + ) - print('Starting profiling') + print("Starting profiling") p.profile_get( - '/api/v1/leads/?' - 'status=pending&' - 'published_on__lt=2016-01-10&' - 'assignee={0}&' - 'search=lorem&' - 'limit=100&' - ''.format(users[2].id) + "/api/v1/leads/?" + "status=pending&" + "published_on__lt=2016-01-10&" + "assignee={0}&" + "search=lorem&" + "limit=100&" + "".format(users[2].id) ) p.__del__() diff --git a/apps/profiling/profiler.py b/apps/profiling/profiler.py index 610fb88811..9d359021cb 100644 --- a/apps/profiling/profiler.py +++ b/apps/profiling/profiler.py @@ -1,13 +1,11 @@ -from django import test -from django.conf import settings - import cProfile -import pstats import os +import pstats +from django import test +from django.conf import settings from jwt_auth.token import AccessToken - TEST_SETUP_VERBOSITY = 1 @@ -31,7 +29,7 @@ def create(self): False, ) self.client = test.Client() - self.client.get('/') + self.client.get("/") self.created = True @@ -54,7 +52,7 @@ def destroy(self): def authorise_with(self, user): self.access = AccessToken.for_user(user).encode() - self.auth = 'Bearer {0}'.format(self.access) + self.auth = "Bearer {0}".format(self.access) def start_profiling(self): self.pr = cProfile.Profile(builtins=False) @@ -66,14 +64,13 @@ def stop_profiling(self): self.pr.disable() self.stats = pstats.Stats(self.pr) - self.stats.sort_stats('cumulative') + self.stats.sort_stats("cumulative") self.pr = None def print_stats(self): - regex = '({})|(/db/models.*(fetch|execute_sql))'\ - .format(os.getcwd()) + regex = "({})|(/db/models.*(fetch|execute_sql))".format(os.getcwd()) - print('Stats') + print("Stats") self.stats.print_stats(regex) # print('Callers') @@ -82,7 +79,7 @@ def print_stats(self): # print('Callees') # self.stats.print_callees(regex) - print('End') + print("End") def profile_get(self, *args, **kwargs): self.start_profiling() diff --git a/apps/project/activity.py b/apps/project/activity.py index 9856bdb675..c19de34eee 100644 --- a/apps/project/activity.py +++ b/apps/project/activity.py @@ -1,20 +1,22 @@ -from reversion.models import Version +import json + from project.models import AnalysisFramework, CategoryEditor +from reversion.models import Version + from utils.common import random_key -import json def get_diff(v1, v2): - p1 = json.loads(v1.serialized_data)[0].get('fields') - p2 = json.loads(v2.serialized_data)[0].get('fields') + p1 = json.loads(v1.serialized_data)[0].get("fields") + p2 = json.loads(v2.serialized_data)[0].get("fields") diff = {} def calc_simple_diff(key): if p1.get(key) != p2.get(key): diff[key] = { - 'new': p1.get(key), - 'old': p2.get(key), + "new": p1.get(key), + "old": p2.get(key), } def calc_model_diff(key, model): @@ -25,26 +27,30 @@ def calc_model_diff(key, model): m1 = id1 and model.objects.filter(id=id1).first() m2 = id2 and model.objects.filter(id=id2).first() diff[key] = { - 'new': m1 and {'id': m1.id, 'title': m1.title}, - 'old': m2 and {'id': m2.id, 'title': m2.title}, + "new": m1 and {"id": m1.id, "title": m1.title}, + "old": m2 and {"id": m2.id, "title": m2.title}, } - calc_simple_diff('title') - calc_simple_diff('description') - calc_simple_diff('start_date') - calc_simple_diff('end_date') - calc_model_diff('analysis_framework', AnalysisFramework) - calc_model_diff('category_editor', CategoryEditor) + calc_simple_diff("title") + calc_simple_diff("description") + calc_simple_diff("start_date") + calc_simple_diff("end_date") + calc_model_diff("analysis_framework", AnalysisFramework) + calc_model_diff("category_editor", CategoryEditor) if len(diff.keys()) > 0: return { - 'key': random_key(), - 'fields': diff, - 'user': { - 'name': v1.revision.user.profile.get_display_name(), - 'id': v1.revision.user.id, - } if v1.revision.user else None, # TODO: this is just a fix - 'timestamp': v1.revision.date_created, + "key": random_key(), + "fields": diff, + "user": ( + { + "name": v1.revision.user.profile.get_display_name(), + "id": v1.revision.user.id, + } + if v1.revision.user + else None + ), # TODO: this is just a fix + "timestamp": v1.revision.date_created, } return None diff --git a/apps/project/admin.py b/apps/project/admin.py index c36346ae2b..f17d9da9a3 100644 --- a/apps/project/admin.py +++ b/apps/project/admin.py @@ -1,50 +1,51 @@ import json -from django.contrib import admin -from django.utils.safestring import mark_safe -from django.contrib import messages -from django.db import models +from admin_auto_filters.filters import AutocompleteFilterFactory +from assessment_registry.models import AssessmentRegistry +from django.contrib import admin, messages from django.contrib.postgres.aggregates import StringAgg +from django.db import models +from django.utils.safestring import mark_safe +from entry.models import Entry +from lead.models import Lead from reversion.admin import VersionAdmin -from admin_auto_filters.filters import AutocompleteFilterFactory from deep.admin import linkify -from lead.models import Lead -from entry.models import Entry -from assessment_registry.models import AssessmentRegistry -from .tasks import generate_viz_stats, generate_project_stats_cache from .forms import ProjectRoleForm from .models import ( Project, - ProjectRole, - ProjectMembership, - ProjectUserGroupMembership, - ProjectStats, + ProjectChangeLog, ProjectJoinRequest, + ProjectMembership, ProjectOrganization, - ProjectChangeLog, - ProjectPinned + ProjectPinned, + ProjectRole, + ProjectStats, + ProjectUserGroupMembership, ) +from .tasks import generate_project_stats_cache, generate_viz_stats TRIGGER_LIMIT = 5 def trigger_project_viz_stat_calc(generator): def action(modeladmin, request, queryset): - for project_id in queryset.values_list('project_id', flat=True).distinct()[:TRIGGER_LIMIT]: + for project_id in queryset.values_list("project_id", flat=True).distinct()[:TRIGGER_LIMIT]: generator.delay(project_id, force=True) messages.add_message( - request, messages.INFO, + request, + messages.INFO, mark_safe( - 'Successfully triggered Project Stats Calculation for projects:

' + - '
'.join( - '* {0} : {1}'.format(*value) - for value in queryset.values_list('project_id', 'project__title').distinct()[:TRIGGER_LIMIT] + "Successfully triggered Project Stats Calculation for projects:

" + + "
".join( + "* {0} : {1}".format(*value) + for value in queryset.values_list("project_id", "project__title").distinct()[:TRIGGER_LIMIT] ) - ) + ), ) - action.short_description = 'Trigger project stat calculation' + + action.short_description = "Trigger project stat calculation" return action @@ -52,93 +53,114 @@ def trigger_project_stat_cache_calc(): def action(modeladmin, request, queryset): generate_project_stats_cache.delay(force=True) messages.add_message( - request, messages.INFO, - mark_safe( - 'Successfully triggered Project Stats Cache Calculation for projects.' - ) + request, messages.INFO, mark_safe("Successfully triggered Project Stats Cache Calculation for projects.") ) - action.short_description = 'Trigger project stat cache calculation' + + action.short_description = "Trigger project stat cache calculation" return action class ProjectMembershipInline(admin.TabularInline): model = ProjectMembership extra = 0 - autocomplete_fields = ('added_by', 'linked_group', 'member',) + autocomplete_fields = ( + "added_by", + "linked_group", + "member", + ) class ProjectUserGroupMembershipInline(admin.TabularInline): model = ProjectUserGroupMembership extra = 0 - autocomplete_fields = ('added_by', 'usergroup',) + autocomplete_fields = ( + "added_by", + "usergroup", + ) class ProjectOrganizationInline(admin.TabularInline): model = ProjectOrganization - autocomplete_fields = ('organization',) + autocomplete_fields = ("organization",) class ProjectJoinRequestInline(admin.TabularInline): model = ProjectJoinRequest extra = 0 - autocomplete_fields = ('requested_by', 'responded_by',) + autocomplete_fields = ( + "requested_by", + "responded_by", + ) @admin.register(Project) class ProjectAdmin(VersionAdmin): - search_fields = ['title'] + search_fields = ["title"] list_display = [ - 'title', - linkify('category_editor', 'Category Editor'), - linkify('analysis_framework', 'Assessment Framework'), - linkify('assessment_template', 'Assessment Template'), - 'associated_regions', - 'entries_count', - 'assessment_count', - 'members_count', - 'deleted_at', + "title", + linkify("category_editor", "Category Editor"), + linkify("analysis_framework", "Assessment Framework"), + linkify("assessment_template", "Assessment Template"), + "associated_regions", + "entries_count", + "assessment_count", + "members_count", + "deleted_at", ] autocomplete_fields = ( - 'analysis_framework', 'assessment_template', 'category_editor', - 'created_by', 'modified_by', 'regions', + "analysis_framework", + "assessment_template", + "category_editor", + "created_by", + "modified_by", + "regions", ) list_filter = ( - 'assessment_template', - 'is_private', - 'is_deleted', + "assessment_template", + "is_private", + "is_deleted", ) actions = [trigger_project_stat_cache_calc()] - inlines = [ProjectMembershipInline, - ProjectUserGroupMembershipInline, - ProjectJoinRequestInline, - ProjectOrganizationInline] + inlines = [ProjectMembershipInline, ProjectUserGroupMembershipInline, ProjectJoinRequestInline, ProjectOrganizationInline] def get_queryset(self, request): - def _count_subquery(Model, count_field='id'): + def _count_subquery(Model, count_field="id"): return models.functions.Coalesce( models.Subquery( Model.objects.filter( - project=models.OuterRef('pk'), - ).order_by().values('project') - .annotate(c=models.Count('id', distinct=True)).values('c')[:1], + project=models.OuterRef("pk"), + ) + .order_by() + .values("project") + .annotate(c=models.Count("id", distinct=True)) + .values("c")[:1], output_field=models.IntegerField(), - ), 0) - - return super().get_queryset(request).prefetch_related( - 'category_editor', 'analysis_framework', 'assessment_template', - ).annotate( - leads_count=_count_subquery(Lead), - entries_count=_count_subquery(Entry), - assessment_count=_count_subquery(AssessmentRegistry), - members_count=_count_subquery(ProjectMembership, count_field='member'), - associated_regions_count=models.Count('regions', distinct=True), - associated_regions=StringAgg('regions__title', ',', distinct=True), + ), + 0, + ) + + return ( + super() + .get_queryset(request) + .prefetch_related( + "category_editor", + "analysis_framework", + "assessment_template", + ) + .annotate( + leads_count=_count_subquery(Lead), + entries_count=_count_subquery(Entry), + assessment_count=_count_subquery(AssessmentRegistry), + members_count=_count_subquery(ProjectMembership, count_field="member"), + associated_regions_count=models.Count("regions", distinct=True), + associated_regions=StringAgg("regions__title", ",", distinct=True), + ) ) def get_readonly_fields(self, request, obj=None): # editing an existing object if obj: - return self.readonly_fields + ('is_private', ) + return self.readonly_fields + ("is_private",) return self.readonly_fields def entries_count(self, obj): @@ -153,69 +175,88 @@ def assessment_count(self, obj): def members_count(self, obj): return obj.members_count - entries_count.admin_order_field = 'entries_count' - leads_count.admin_order_field = 'leads_count' - assessment_count.admin_order_field = 'assessment_count' - members_count.admin_order_field = 'members_count' + entries_count.admin_order_field = "entries_count" + leads_count.admin_order_field = "leads_count" + assessment_count.admin_order_field = "assessment_count" + members_count.admin_order_field = "members_count" def associated_regions(self, obj): count = obj.associated_regions_count regions = obj.associated_regions if count == 0: - return '' + return "" elif count == 1: return regions - return f'{regions[:10]}.... ({count})' + return f"{regions[:10]}.... ({count})" @admin.register(ProjectRole) class ProjectRoleAdmin(admin.ModelAdmin): - list_display = ('id', 'title', 'level', 'type', 'is_default_role') + list_display = ("id", "title", "level", "type", "is_default_role") form = ProjectRoleForm @admin.register(ProjectStats) class ProjectEntryStatsAdmin(admin.ModelAdmin): - AF = linkify('project.analysis_framework', 'AF') - - search_fields = ('project__title',) - list_filter = ('status',) - list_display = ('project', 'modified_at', AF, 'status', 'file', 'confidential_file',) + AF = linkify("project.analysis_framework", "AF") + + search_fields = ("project__title",) + list_filter = ("status",) + list_display = ( + "project", + "modified_at", + AF, + "status", + "file", + "confidential_file", + ) actions = [trigger_project_viz_stat_calc(generate_viz_stats)] - autocomplete_fields = ('project',) - readonly_fields = (AF, 'token') + autocomplete_fields = ("project",) + readonly_fields = (AF, "token") def get_queryset(self, request): - return super().get_queryset(request).prefetch_related('project', 'project__analysis_framework') + return super().get_queryset(request).prefetch_related("project", "project__analysis_framework") @admin.register(ProjectChangeLog) class ProjectChangeLogAdmin(admin.ModelAdmin): - search_fields = ('project__title',) + search_fields = ("project__title",) list_filter = ( - AutocompleteFilterFactory('Project', 'project'), - AutocompleteFilterFactory('User', 'user'), - 'action', - 'created_at', + AutocompleteFilterFactory("Project", "project"), + AutocompleteFilterFactory("User", "user"), + "action", + "created_at", + ) + list_display = ( + "project", + "created_at", + "action", + "user", ) - list_display = ('project', 'created_at', 'action', 'user',) - autocomplete_fields = ('project', 'user',) - readonly_fields = ('project', 'created_at', 'action', 'user', 'diff', 'diff_pretty') + autocomplete_fields = ( + "project", + "user", + ) + readonly_fields = ("project", "created_at", "action", "user", "diff", "diff_pretty") def get_queryset(self, request): - return super().get_queryset(request).prefetch_related( - 'project', - 'user', + return ( + super() + .get_queryset(request) + .prefetch_related( + "project", + "user", + ) ) def has_add_permission(self, request, obj=None): return False - @admin.display(description='Diff pretty JSON') + @admin.display(description="Diff pretty JSON") def diff_pretty(self, obj): - return mark_safe(f'
{json.dumps(obj.diff, indent=2)}
') + return mark_safe(f"
{json.dumps(obj.diff, indent=2)}
") @admin.register(ProjectPinned) class ProjectPinnedAdmin(admin.ModelAdmin): - list_display = ('id', 'project', 'user', 'order') + list_display = ("id", "project", "user", "order") diff --git a/apps/project/apps.py b/apps/project/apps.py index 857dae0a9c..65c502d5f1 100644 --- a/apps/project/apps.py +++ b/apps/project/apps.py @@ -2,7 +2,7 @@ class ProjectConfig(AppConfig): - name = 'project' + name = "project" def ready(self): import project.receivers # noqa diff --git a/apps/project/change_log.py b/apps/project/change_log.py index cd17abd9da..40981ad4af 100644 --- a/apps/project/change_log.py +++ b/apps/project/change_log.py @@ -4,38 +4,21 @@ from utils.common import remove_empty_keys_from_dict - -from .models import ( - Project, - ProjectChangeLog, - ProjectOrganization, -) +from .models import Project, ProjectChangeLog, ProjectOrganization def get_flat_dict_diff(list1: List[dict], list2: List[dict], fields: List[str]): def _dict_to_tuple_set(items: List[dict]) -> Set[tuple]: - return set( - tuple( - item[field] - for field in fields - ) - for item in items - ) + return set(tuple(item[field] for field in fields) for item in items) def _tuple_to_dict_list(items: Set[tuple]) -> List[dict]: - return [ - { - field: item[index] - for index, field in enumerate(fields) - } - for item in sorted(items) - ] + return [{field: item[index] for index, field in enumerate(fields)} for item in sorted(items)] set_list1 = _dict_to_tuple_set(list1) set_list2 = _dict_to_tuple_set(list2) return { - 'add': _tuple_to_dict_list(set_list2 - set_list1), - 'remove': _tuple_to_dict_list(set_list1 - set_list2), + "add": _tuple_to_dict_list(set_list2 - set_list1), + "remove": _tuple_to_dict_list(set_list1 - set_list2), } @@ -43,20 +26,20 @@ def get_list_diff(list1, list2): set_list1 = set(list1) set_list2 = set(list2) return { - 'add': sorted(list(set_list2 - set_list1)), - 'remove': sorted(list(set_list1 - set_list2)), + "add": sorted(list(set_list2 - set_list1)), + "remove": sorted(list(set_list1 - set_list2)), } class ProjectOrganizationSerializer(serializers.ModelSerializer): class Meta: model = ProjectOrganization - fields = ('organization', 'organization_type') + fields = ("organization", "organization_type") class ProjectDataSerializer(serializers.ModelSerializer): organizations = serializers.SerializerMethodField() - analysis_framework = serializers.IntegerField(source='analysis_framework_id') + analysis_framework = serializers.IntegerField(source="analysis_framework_id") regions = serializers.SerializerMethodField() # Members member_users = serializers.SerializerMethodField() @@ -66,60 +49,60 @@ class ProjectDataSerializer(serializers.ModelSerializer): class Meta: model = Project scalar_fields = [ - 'title', - 'start_date', - 'end_date', - 'description', - 'is_private', - 'is_test', - 'is_deleted', - 'deleted_at', + "title", + "start_date", + "end_date", + "description", + "is_private", + "is_test", + "is_deleted", + "deleted_at", # Document sharing - 'has_publicly_viewable_unprotected_leads', - 'has_publicly_viewable_restricted_leads', - 'has_publicly_viewable_confidential_leads', + "has_publicly_viewable_unprotected_leads", + "has_publicly_viewable_restricted_leads", + "has_publicly_viewable_confidential_leads", ] fields = ( *scalar_fields, # Defined fields - 'organizations', - 'analysis_framework', - 'regions', - 'member_users', - 'member_user_groups', - 'project_viz_config', + "organizations", + "analysis_framework", + "regions", + "member_users", + "member_user_groups", + "project_viz_config", ) def get_project_viz_config(self, obj): stat = obj.project_stats return { - 'public_share': stat.public_share, - 'token': stat.token, + "public_share": stat.public_share, + "token": stat.token, } def get_organizations(self, obj): return ProjectOrganizationSerializer( - obj.projectorganization_set.order_by('organization_id', 'organization_type'), + obj.projectorganization_set.order_by("organization_id", "organization_type"), many=True, ).data def get_regions(self, obj): - return list(obj.regions.order_by('id').values_list('id', flat=True)) + return list(obj.regions.order_by("id").values_list("id", flat=True)) def get_member_users(self, obj): - return list(obj.members.order_by('id').values_list('id', flat=True)) + return list(obj.members.order_by("id").values_list("id", flat=True)) def get_member_user_groups(self, obj): - return list(obj.user_groups.order_by('id').values_list('id', flat=True)) + return list(obj.user_groups.order_by("id").values_list("id", flat=True)) -class ProjectChangeManager(): +class ProjectChangeManager: ACTION_MAP = { - 'details': ProjectChangeLog.Action.PROJECT_DETAILS, - 'organizations': ProjectChangeLog.Action.ORGANIZATION, - 'regions': ProjectChangeLog.Action.REGION, - 'memberships': ProjectChangeLog.Action.MEMBERSHIP, - 'framework': ProjectChangeLog.Action.FRAMEWORK, + "details": ProjectChangeLog.Action.PROJECT_DETAILS, + "organizations": ProjectChangeLog.Action.ORGANIZATION, + "regions": ProjectChangeLog.Action.REGION, + "memberships": ProjectChangeLog.Action.MEMBERSHIP, + "framework": ProjectChangeLog.Action.FRAMEWORK, } def __init__(self, request, project_id): @@ -141,16 +124,14 @@ def __exit__(self, *_): ) def get_active_project_latest_data(self): - return ProjectDataSerializer( - Project.objects.get(pk=self.project_id) - ).data + return ProjectDataSerializer(Project.objects.get(pk=self.project_id)).data @staticmethod def _framework_change_data(new, old, updated): return { - 'new': new, - 'old': old, - 'updated': updated, + "new": new, + "old": old, + "updated": updated, } @staticmethod @@ -166,13 +147,13 @@ def _track_viz_config(viz_config, new_viz_config): } """ changes = {} - if viz_config['public_share'] != new_viz_config['public_share']: - changes['public_share'] = { - 'old': viz_config['public_share'], - 'new': new_viz_config['public_share'], + if viz_config["public_share"] != new_viz_config["public_share"]: + changes["public_share"] = { + "old": viz_config["public_share"], + "new": new_viz_config["public_share"], } - if viz_config['token'] != new_viz_config['token']: - changes['token_changed'] = True + if viz_config["token"] != new_viz_config["token"]: + changes["token_changed"] = True return changes @classmethod @@ -185,42 +166,44 @@ def _track_details(cls, project_data, new_project_data): if old_value == new_value: continue details_change_data[field] = { - 'old': old_value, - 'new': new_value, + "old": old_value, + "new": new_value, } - details_change_data['project_viz_config'] = cls._track_viz_config( - project_data['project_viz_config'], - new_project_data['project_viz_config'], + details_change_data["project_viz_config"] = cls._track_viz_config( + project_data["project_viz_config"], + new_project_data["project_viz_config"], ) return details_change_data @staticmethod def _track_framework(project_data, new_project_data): - framework_id = project_data['analysis_framework'] - new_framework_id = new_project_data['analysis_framework'] + framework_id = project_data["analysis_framework"] + new_framework_id = new_project_data["analysis_framework"] if framework_id != new_framework_id: return { - 'new': new_framework_id, - 'old': framework_id, + "new": new_framework_id, + "old": framework_id, } @classmethod def log_full_changes(cls, project_id, project_data, new_project_data, user): # TODO: 'properties' - diff_data = remove_empty_keys_from_dict({ - 'details': cls._track_details(project_data, new_project_data), - 'organizations': get_flat_dict_diff( - project_data['organizations'], - new_project_data['organizations'], - ['organization', 'organization_type'], - ), - 'regions': get_list_diff(project_data['regions'], new_project_data['regions']), - 'memberships': { - "users": get_list_diff(project_data['member_users'], new_project_data['member_users']), - "user_groups": get_list_diff(project_data['member_user_groups'], new_project_data['member_user_groups']) - }, - 'framework': cls._track_framework(project_data, new_project_data) - }) + diff_data = remove_empty_keys_from_dict( + { + "details": cls._track_details(project_data, new_project_data), + "organizations": get_flat_dict_diff( + project_data["organizations"], + new_project_data["organizations"], + ["organization", "organization_type"], + ), + "regions": get_list_diff(project_data["regions"], new_project_data["regions"]), + "memberships": { + "users": get_list_diff(project_data["member_users"], new_project_data["member_users"]), + "user_groups": get_list_diff(project_data["member_user_groups"], new_project_data["member_user_groups"]), + }, + "framework": cls._track_framework(project_data, new_project_data), + } + ) if diff_data: action = ProjectChangeLog.Action.MULTIPLE @@ -240,17 +223,15 @@ def log_framework_update(cls, af_id, user): """ Flag that an AF is updated to each project. """ - project_ids = Project.objects\ - .filter(analysis_framework=af_id)\ - .values_list('id', flat=True) + project_ids = Project.objects.filter(analysis_framework=af_id).values_list("id", flat=True) change_logs = [ ProjectChangeLog( user=user, project_id=project_id, action=ProjectChangeLog.Action.FRAMEWORK, diff={ - 'framework': { - 'updated': True, + "framework": { + "updated": True, }, }, ) diff --git a/apps/project/dataloaders.py b/apps/project/dataloaders.py index 27b7292211..4abb0a5a48 100644 --- a/apps/project/dataloaders.py +++ b/apps/project/dataloaders.py @@ -1,52 +1,44 @@ from collections import defaultdict -from promise import Promise +from analysis_framework.models import AnalysisFramework +from django.contrib.postgres.aggregates.general import ArrayAgg +from django.contrib.postgres.fields.jsonb import KeyTextTransform from django.core.cache import cache -from django.utils.functional import cached_property from django.db import models -from django.contrib.postgres.fields.jsonb import KeyTextTransform -from django.contrib.postgres.aggregates.general import ArrayAgg from django.db.models.functions import Cast from django.utils import timezone +from django.utils.functional import cached_property +from entry.models import Entry +from export.models import Export +from geo.models import Region +from lead.models import Lead +from promise import Promise +from user.models import User from deep.caches import CacheKey from utils.common import get_number_of_months_between_dates from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin -from lead.models import Lead -from entry.models import Entry -from export.models import Export -from user.models import User -from geo.models import Region -from analysis_framework.models import AnalysisFramework - -from .models import ( - Project, - ProjectJoinRequest, - ProjectOrganization, -) +from .models import Project, ProjectJoinRequest, ProjectOrganization class ProjectStatLoader(DataLoaderWithContext): def batch_load_fn(self, keys): annotate_data = { - key: Cast(KeyTextTransform(key, 'stats_cache'), models.IntegerField()) + key: Cast(KeyTextTransform(key, "stats_cache"), models.IntegerField()) for key in [ - 'number_of_leads', - 'number_of_leads_not_tagged', - 'number_of_leads_in_progress', - 'number_of_leads_tagged', - 'number_of_entries', - 'number_of_entries_verified', - 'number_of_entries_controlled', - 'number_of_users', + "number_of_leads", + "number_of_leads_not_tagged", + "number_of_leads_in_progress", + "number_of_leads_tagged", + "number_of_entries", + "number_of_entries_verified", + "number_of_entries_controlled", + "number_of_users", ] } stat_qs = Project.objects.filter(id__in=keys).annotate(**annotate_data) - _map = { - project_with_stat.id: project_with_stat - for project_with_stat in stat_qs - } + _map = {project_with_stat.id: project_with_stat for project_with_stat in stat_qs} return Promise.resolve([_map.get(key) for key in keys]) @@ -55,12 +47,9 @@ def batch_load_fn(self, keys): join_status_qs = ProjectJoinRequest.objects.filter( project__in=keys, requested_by=self.context.request.user, - status='pending', - ).values_list('project_id', flat=True) - _map = { - project_id: True - for project_id in join_status_qs - } + status="pending", + ).values_list("project_id", flat=True) + _map = {project_id: True for project_id in join_status_qs} return Promise.resolve([_map.get(key, False) for key in keys]) @@ -70,11 +59,8 @@ def batch_load_fn(self, keys): project__in=keys, requested_by=self.context.request.user, status=ProjectJoinRequest.Status.REJECTED, - ).values_list('project_id', flat=True) - _map = { - project_id: True - for project_id in join_status_qs - } + ).values_list("project_id", flat=True) + _map = {project_id: True for project_id in join_status_qs} return Promise.resolve([_map.get(key, False) for key in keys]) @@ -105,70 +91,84 @@ def get_stats(self): # Projects -- stats_cache__entries_activity are calculated for last 3 months project_count = Project.objects.count() - latest_active_projects_qs = Project.objects\ - .filter(is_private=False)\ - .order_by('-stats_cache__entries_activity', '-created_at') - latest_active_projects = latest_active_projects_qs\ - .values( - 'analysis_framework_id', - project_id=models.F('id'), - project_title=models.F('title'), - analysis_framework_title=models.F('analysis_framework__title'), - )[:5] + latest_active_projects_qs = Project.objects.filter(is_private=False).order_by( + "-stats_cache__entries_activity", "-created_at" + ) + latest_active_projects = latest_active_projects_qs.values( + "analysis_framework_id", + project_id=models.F("id"), + project_title=models.F("title"), + analysis_framework_title=models.F("analysis_framework__title"), + )[:5] # All leads leads_qs = Lead.objects.all() leads_count = leads_qs.count() lead_created_at_range = leads_qs.aggregate( - max_created_at=models.Max('created_at'), - min_created_at=models.Min('created_at'), + max_created_at=models.Max("created_at"), + min_created_at=models.Min("created_at"), ) # Tagged leads tagged_leads_qs = leads_qs.annotate( entries_count=models.Subquery( Entry.objects.filter( - lead=models.OuterRef('pk'), - ).order_by().values('lead').annotate(count=models.Count('id')).values('count')[:1], - output_field=models.IntegerField() + lead=models.OuterRef("pk"), + ) + .order_by() + .values("lead") + .annotate(count=models.Count("id")) + .values("count")[:1], + output_field=models.IntegerField(), ), ).filter(entries_count__gt=0) tagged_leads_count = tagged_leads_qs.count() tagged_lead_created_at_range = tagged_leads_qs.aggregate( - max_created_at=models.Max('created_at'), - min_created_at=models.Min('created_at'), + max_created_at=models.Max("created_at"), + min_created_at=models.Min("created_at"), ) # Exports exports_count = Export.objects.count() exports_created_at_range = Export.objects.aggregate( - max_exported_at=models.Max('exported_at'), - min_exported_at=models.Min('exported_at'), + max_exported_at=models.Max("exported_at"), + min_exported_at=models.Min("exported_at"), ) # Recent frameworks - top_active_frameworks = AnalysisFramework.objects.filter(is_private=False).annotate( - project_count=models.functions.Coalesce( - models.Subquery( - Project.objects.filter( - analysis_framework=models.OuterRef('pk') - ).order_by().values('analysis_framework').annotate( - count=models.Count('id', distinct=True), - ).values('count')[:1], - output_field=models.IntegerField() - ), 0), - source_count=models.functions.Coalesce( - models.Subquery( - Lead.objects.filter( - project__analysis_framework=models.OuterRef('pk') - ).order_by().values('project__analysis_framework').annotate( - count=models.Count('id', distinct=True) - ).values('count')[:1], - output_field=models.IntegerField() - ), 0), - ).order_by('-project_count', '-source_count').values( - analysis_framework_id=models.F('id'), - analysis_framework_title=models.F('title'), - project_count=models.F('project_count'), - source_count=models.F('source_count'), - )[:5] + top_active_frameworks = ( + AnalysisFramework.objects.filter(is_private=False) + .annotate( + project_count=models.functions.Coalesce( + models.Subquery( + Project.objects.filter(analysis_framework=models.OuterRef("pk")) + .order_by() + .values("analysis_framework") + .annotate( + count=models.Count("id", distinct=True), + ) + .values("count")[:1], + output_field=models.IntegerField(), + ), + 0, + ), + source_count=models.functions.Coalesce( + models.Subquery( + Lead.objects.filter(project__analysis_framework=models.OuterRef("pk")) + .order_by() + .values("project__analysis_framework") + .annotate(count=models.Count("id", distinct=True)) + .values("count")[:1], + output_field=models.IntegerField(), + ), + 0, + ), + ) + .order_by("-project_count", "-source_count") + .values( + analysis_framework_id=models.F("id"), + analysis_framework_title=models.F("title"), + project_count=models.F("project_count"), + source_count=models.F("source_count"), + )[:5] + ) return dict( calculated_at=now, @@ -176,32 +176,28 @@ def get_stats(self): total_users=User.objects.filter(is_active=True).count(), total_leads=Lead.objects.count(), total_entries=Entry.objects.count(), - leads_added_weekly=leads_count and ( - leads_count / ( - ( - ( - abs( - lead_created_at_range['max_created_at'] - lead_created_at_range['min_created_at'] - ).days - ) // 7 - ) or 1 - ) + leads_added_weekly=leads_count + and ( + leads_count + / (((abs(lead_created_at_range["max_created_at"] - lead_created_at_range["min_created_at"]).days) // 7) or 1) ), - daily_average_leads_tagged_per_project=tagged_leads_count and ( - tagged_leads_count / ( - ( - abs( - tagged_lead_created_at_range['max_created_at'] - tagged_lead_created_at_range['min_created_at'] - ).days - ) or 1 - ) / (project_count or 1) + daily_average_leads_tagged_per_project=tagged_leads_count + and ( + tagged_leads_count + / ( + (abs(tagged_lead_created_at_range["max_created_at"] - tagged_lead_created_at_range["min_created_at"]).days) + or 1 + ) + / (project_count or 1) ), - generated_exports_monthly=exports_count and ( - exports_count / ( + generated_exports_monthly=exports_count + and ( + exports_count + / ( get_number_of_months_between_dates( - exports_created_at_range['max_exported_at'], - exports_created_at_range['min_exported_at'] - ) or 1 + exports_created_at_range["max_exported_at"], exports_created_at_range["min_exported_at"] + ) + or 1 ) ), top_active_projects=latest_active_projects, @@ -222,9 +218,14 @@ def get_region_queryset(): return Region.objects.all() def batch_load_fn(self, keys): - qs = self.get_region_queryset().filter(project__in=keys).annotate( - projects_id=ArrayAgg('project', filter=models.Q(project__in=keys)), - ).defer('geo_options') + qs = ( + self.get_region_queryset() + .filter(project__in=keys) + .annotate( + projects_id=ArrayAgg("project", filter=models.Q(project__in=keys)), + ) + .defer("geo_options") + ) _map = defaultdict(list) for region in qs.all(): for project_id in region.projects_id: diff --git a/apps/project/enums.py b/apps/project/enums.py index b40345601a..d02ded9ab6 100644 --- a/apps/project/enums.py +++ b/apps/project/enums.py @@ -1,33 +1,32 @@ import graphene from django.db.models.functions import Lower +from deep.permissions import ProjectPermissions as PP from utils.graphene.enums import ( convert_enum_to_graphene_enum, get_enum_name_from_django_field, ) -from deep.permissions import ProjectPermissions as PP from .models import ( Project, - ProjectRole, ProjectJoinRequest, + ProjectMembership, ProjectOrganization, + ProjectRole, ProjectStats, - ProjectMembership, ProjectUserGroupMembership, RecentActivityType, ) ProjectPermissionEnum = graphene.Enum.from_enum(PP.Permission) -ProjectStatusEnum = convert_enum_to_graphene_enum(Project.Status, name='ProjectStatusEnum') -ProjectRoleTypeEnum = convert_enum_to_graphene_enum(ProjectRole.Type, name='ProjectRoleTypeEnum') -ProjectOrganizationTypeEnum = convert_enum_to_graphene_enum(ProjectOrganization.Type, name='ProjectOrganizationTypeEnum') -ProjectJoinRequestStatusEnum = convert_enum_to_graphene_enum(ProjectJoinRequest.Status, name='ProjectJoinRequestStatusEnum') -ProjectStatsStatusEnum = convert_enum_to_graphene_enum(ProjectStats.Status, name='ProjectStatsStatusEnum') -ProjectStatsActionEnum = convert_enum_to_graphene_enum(ProjectStats.Action, name='ProjectStatsActionEnum') -ProjectMembershipBadgeTypeEnum = convert_enum_to_graphene_enum( - ProjectMembership.BadgeType, name='ProjectMembershipBadgeTypeEnum') -RecentActivityTypeEnum = convert_enum_to_graphene_enum(RecentActivityType, name='RecentActivityTypeEnum') +ProjectStatusEnum = convert_enum_to_graphene_enum(Project.Status, name="ProjectStatusEnum") +ProjectRoleTypeEnum = convert_enum_to_graphene_enum(ProjectRole.Type, name="ProjectRoleTypeEnum") +ProjectOrganizationTypeEnum = convert_enum_to_graphene_enum(ProjectOrganization.Type, name="ProjectOrganizationTypeEnum") +ProjectJoinRequestStatusEnum = convert_enum_to_graphene_enum(ProjectJoinRequest.Status, name="ProjectJoinRequestStatusEnum") +ProjectStatsStatusEnum = convert_enum_to_graphene_enum(ProjectStats.Status, name="ProjectStatsStatusEnum") +ProjectStatsActionEnum = convert_enum_to_graphene_enum(ProjectStats.Action, name="ProjectStatsActionEnum") +ProjectMembershipBadgeTypeEnum = convert_enum_to_graphene_enum(ProjectMembership.BadgeType, name="ProjectMembershipBadgeTypeEnum") +RecentActivityTypeEnum = convert_enum_to_graphene_enum(RecentActivityType, name="RecentActivityTypeEnum") enum_map = { get_enum_name_from_django_field(field): enum @@ -43,40 +42,42 @@ } # Additional enums which doesn't have a field in model but are used in serializer -enum_map.update({ - get_enum_name_from_django_field( - None, - field_name='action', # ProjectVizConfigurationSerializer.action - model_name=ProjectStats.__name__, - ): ProjectStatsActionEnum, -}) +enum_map.update( + { + get_enum_name_from_django_field( + None, + field_name="action", # ProjectVizConfigurationSerializer.action + model_name=ProjectStats.__name__, + ): ProjectStatsActionEnum, + } +) class ProjectOrderingEnum(graphene.Enum): # ASC - ASC_TITLE = Lower('title').asc() - ASC_USER_COUNT = 'stats_cache__number_of_users' - ASC_LEAD_COUNT = 'stats_cache__number_of_leads' - ASC_CREATED_AT = 'created_at' - ASC_ANALYSIS_FRAMEWORK = Lower('analysis_framework__title').asc() + ASC_TITLE = Lower("title").asc() + ASC_USER_COUNT = "stats_cache__number_of_users" + ASC_LEAD_COUNT = "stats_cache__number_of_leads" + ASC_CREATED_AT = "created_at" + ASC_ANALYSIS_FRAMEWORK = Lower("analysis_framework__title").asc() # DESC - DESC_TITLE = Lower('title').desc() - DESC_USER_COUNT = f'-{ASC_USER_COUNT}' - DESC_LEAD_COUNT = f'-{ASC_LEAD_COUNT}' - DESC_CREATED_AT = f'-{ASC_CREATED_AT}' - DESC_ANALYSIS_FRAMEWORK = Lower('analysis_framework__title').desc() + DESC_TITLE = Lower("title").desc() + DESC_USER_COUNT = f"-{ASC_USER_COUNT}" + DESC_LEAD_COUNT = f"-{ASC_LEAD_COUNT}" + DESC_CREATED_AT = f"-{ASC_CREATED_AT}" + DESC_ANALYSIS_FRAMEWORK = Lower("analysis_framework__title").desc() class PublicProjectOrderingEnum(graphene.Enum): # ASC - ASC_TITLE = Lower('title').asc() - ASC_USER_COUNT = 'number_of_users' - ASC_LEAD_COUNT = 'number_of_leads' - ASC_CREATED_AT = 'created_at' - ASC_ANALYSIS_FRAMEWORK = Lower('analysis_framework__title').asc() + ASC_TITLE = Lower("title").asc() + ASC_USER_COUNT = "number_of_users" + ASC_LEAD_COUNT = "number_of_leads" + ASC_CREATED_AT = "created_at" + ASC_ANALYSIS_FRAMEWORK = Lower("analysis_framework__title").asc() # DESC - DESC_TITLE = Lower('title').desc() - DESC_USER_COUNT = f'-{ASC_USER_COUNT}' - DESC_LEAD_COUNT = f'-{ASC_LEAD_COUNT}' - DESC_CREATED_AT = f'-{ASC_CREATED_AT}' - DESC_ANALYSIS_FRAMEWORK = Lower('analysis_framework__title').desc() + DESC_TITLE = Lower("title").desc() + DESC_USER_COUNT = f"-{ASC_USER_COUNT}" + DESC_LEAD_COUNT = f"-{ASC_LEAD_COUNT}" + DESC_CREATED_AT = f"-{ASC_CREATED_AT}" + DESC_ANALYSIS_FRAMEWORK = Lower("analysis_framework__title").desc() diff --git a/apps/project/factories.py b/apps/project/factories.py index e91cf68abb..528c6f50fa 100644 --- a/apps/project/factories.py +++ b/apps/project/factories.py @@ -1,19 +1,14 @@ import factory from factory.django import DjangoModelFactory -from .models import ( - Project, - ProjectJoinRequest, - ProjectOrganization, - ProjectPinned, -) +from .models import Project, ProjectJoinRequest, ProjectOrganization, ProjectPinned class ProjectFactory(DjangoModelFactory): class Meta: model = Project - title = factory.Sequence(lambda n: f'Project-{n}') + title = factory.Sequence(lambda n: f"Project-{n}") @factory.post_generation def regions(self, create, extracted, **kwargs): diff --git a/apps/project/filter_set.py b/apps/project/filter_set.py index b156c435fd..b5f5e54dfb 100644 --- a/apps/project/filter_set.py +++ b/apps/project/filter_set.py @@ -1,50 +1,40 @@ +import django_filters import graphene - from django.contrib.postgres.aggregates.general import ArrayAgg -from graphene_django.filter.utils import get_filtering_args_from_filterset from django.db import models from django.db.models.functions import Concat, Lower -import django_filters +from geo.models import Region +from graphene_django.filter.utils import get_filtering_args_from_filterset +from user_resource.filters import UserResourceFilterSet, UserResourceGqlFilterSet -from deep.permissions import ProjectPermissions as PP from deep.filter_set import OrderEnumMixin, generate_type_for_filter_set -from utils.graphene.filters import ( - SimpleInputFilter, - IDListFilter, - MultipleInputFilter, -) -from user_resource.filters import UserResourceFilterSet, UserResourceGqlFilterSet +from deep.permissions import ProjectPermissions as PP +from utils.graphene.filters import IDListFilter, MultipleInputFilter, SimpleInputFilter -from geo.models import Region -from .models import ( - Project, - ProjectMembership, - ProjectUserGroupMembership, -) from .enums import ( + ProjectOrderingEnum, ProjectPermissionEnum, ProjectStatusEnum, - ProjectOrderingEnum, PublicProjectOrderingEnum, ) +from .models import Project, ProjectMembership, ProjectUserGroupMembership class ProjectFilterSet(UserResourceFilterSet): class Meta: model = Project - fields = ['id', 'title', 'status', 'user_groups'] + fields = ["id", "title", "status", "user_groups"] filter_overrides = { models.CharField: { - 'filter_class': django_filters.CharFilter, - 'extra': lambda _: { - 'lookup_expr': 'icontains', + "filter_class": django_filters.CharFilter, + "extra": lambda _: { + "lookup_expr": "icontains", }, }, } - is_current_user_member = django_filters.BooleanFilter( - field_name='is_current_user_member', method='filter_with_membership') + is_current_user_member = django_filters.BooleanFilter(field_name="is_current_user_member", method="filter_with_membership") def filter_with_membership(self, queryset, _, value): if value is not None: @@ -60,29 +50,29 @@ def filter_with_membership(self, queryset, _, value): class ProjectMembershipFilterSet(UserResourceFilterSet): class Meta: model = ProjectMembership - fields = ['id', 'project', 'member'] + fields = ["id", "project", "member"] class ProjectUserGroupMembershipFilterSet(UserResourceFilterSet): class Meta: model = ProjectUserGroupMembership - fields = ['id', 'project', 'usergroup'] + fields = ["id", "project", "usergroup"] def get_filtered_projects(user, queries, annotate=False): projects = Project.get_for(user, annotate) - involvement = queries.get('involvement') + involvement = queries.get("involvement") if involvement: - if involvement == 'my_projects': + if involvement == "my_projects": projects = projects.filter(Project.get_query_for_member(user)) - if involvement == 'not_my_projects': + if involvement == "not_my_projects": projects = projects.exclude(Project.get_query_for_member(user)) - regions = queries.get('regions') or '' + regions = queries.get("regions") or "" if regions: - projects = projects.filter(regions__in=regions.split(',')) + projects = projects.filter(regions__in=regions.split(",")) - ordering = queries.get('ordering') + ordering = queries.get("ordering") if ordering: projects = projects.order_by(ordering) @@ -91,18 +81,17 @@ def get_filtered_projects(user, queries, annotate=False): # -------------------- Graphql Filters ----------------------------------- class ProjectGqlFilterSet(OrderEnumMixin, UserResourceGqlFilterSet): - ids = IDListFilter(field_name='id') - exclude_ids = IDListFilter(method='filter_exclude_ids') + ids = IDListFilter(field_name="id") + exclude_ids = IDListFilter(method="filter_exclude_ids") status = SimpleInputFilter(ProjectStatusEnum) organizations = IDListFilter(distinct=True) - analysis_frameworks = IDListFilter(field_name='analysis_framework') + analysis_frameworks = IDListFilter(field_name="analysis_framework") regions = IDListFilter(distinct=True) - search = django_filters.CharFilter(method='filter_title') - is_current_user_member = django_filters.BooleanFilter( - field_name='is_current_user_member', method='filter_with_membership') - has_permission_access = SimpleInputFilter(ProjectPermissionEnum, method='filter_has_permission_access') - ordering = MultipleInputFilter(ProjectOrderingEnum, method='ordering_filter') - is_test = django_filters.BooleanFilter(field_name='is_test', method='filter_is_test') + search = django_filters.CharFilter(method="filter_title") + is_current_user_member = django_filters.BooleanFilter(field_name="is_current_user_member", method="filter_with_membership") + has_permission_access = SimpleInputFilter(ProjectPermissionEnum, method="filter_has_permission_access") + ordering = MultipleInputFilter(ProjectOrderingEnum, method="ordering_filter") + is_test = django_filters.BooleanFilter(field_name="is_test", method="filter_is_test") class Meta: model = Project @@ -137,47 +126,51 @@ def filter_has_permission_access(self, queryset, _, value): id__in=ProjectMembership.objects.filter( member=self.request.user, role__type__in=PP.REVERSE_PERMISSION_MAP[value], - ).values('project') + ).values("project") ) return queryset class PublicProjectGqlFilterSet(ProjectGqlFilterSet): - ordering = MultipleInputFilter(PublicProjectOrderingEnum, method='ordering_filter') + ordering = MultipleInputFilter(PublicProjectOrderingEnum, method="ordering_filter") class ProjectMembershipGqlFilterSet(UserResourceGqlFilterSet): - search = django_filters.CharFilter(method='filter_search') - members = IDListFilter(distinct=True, field_name='member') + search = django_filters.CharFilter(method="filter_search") + members = IDListFilter(distinct=True, field_name="member") class Meta: model = ProjectMembership - fields = ('id',) + fields = ("id",) def filter_search(self, qs, _, value): if value: - return qs.annotate( - full_name=Lower( - Concat( - 'member__first_name', - models.Value(' '), - 'member__last_name', - models.Value(' '), - 'member__email', - output_field=models.CharField(), - ) - ), - ).filter(full_name__icontains=value).distinct() + return ( + qs.annotate( + full_name=Lower( + Concat( + "member__first_name", + models.Value(" "), + "member__last_name", + models.Value(" "), + "member__email", + output_field=models.CharField(), + ) + ), + ) + .filter(full_name__icontains=value) + .distinct() + ) return qs class ProjectUserGroupMembershipGqlFilterSet(UserResourceGqlFilterSet): - search = django_filters.CharFilter(method='filter_search') - usergroups = IDListFilter(distinct=True, field_name='usergroup') + search = django_filters.CharFilter(method="filter_search") + usergroups = IDListFilter(distinct=True, field_name="usergroup") class Meta: model = ProjectUserGroupMembership - fields = ('id',) + fields = ("id",) def filter_search(self, qs, _, value): if value: @@ -187,12 +180,12 @@ def filter_search(self, qs, _, value): class ProjectByRegionGqlFilterSet(django_filters.FilterSet): RegionProjectFilterData = type( - 'RegionProjectFilterData', + "RegionProjectFilterData", (graphene.InputObjectType,), - get_filtering_args_from_filterset(ProjectGqlFilterSet, 'project.schema.ProjectListType') + get_filtering_args_from_filterset(ProjectGqlFilterSet, "project.schema.ProjectListType"), ) - project_filter = SimpleInputFilter(RegionProjectFilterData, method='filter_project_filter') + project_filter = SimpleInputFilter(RegionProjectFilterData, method="filter_project_filter") class Meta: model = Region @@ -209,17 +202,22 @@ def get_project_queryset(self): def qs(self): project_qs = self.get_project_queryset() # Filter project if filter is provided - project_filter = self.data.get('project_filter') + project_filter = self.data.get("project_filter") if project_filter: project_qs = ProjectGqlFilterSet(data=project_filter, queryset=project_qs, request=self.request).qs - return super().qs.annotate( - projects_id=ArrayAgg( - 'project', - distinct=True, - ordering='project', - filter=models.Q(project__in=project_qs), - ), - ).filter(projects_id__isnull=False).only('id', 'centroid') + return ( + super() + .qs.annotate( + projects_id=ArrayAgg( + "project", + distinct=True, + ordering="project", + filter=models.Q(project__in=project_qs), + ), + ) + .filter(projects_id__isnull=False) + .only("id", "centroid") + ) class PublicProjectByRegionGqlFileterSet(ProjectByRegionGqlFilterSet): @@ -233,7 +231,7 @@ def get_project_queryset(self): ProjectsFilterDataType, ProjectsFilterDataInputType = generate_type_for_filter_set( ProjectGqlFilterSet, - 'project.schema.ProjectListType', - 'ProjectsFilterDataType', - 'ProjectsFilterDataInputType', + "project.schema.ProjectListType", + "ProjectsFilterDataType", + "ProjectsFilterDataInputType", ) diff --git a/apps/project/forms.py b/apps/project/forms.py index 6e3b076d59..521423e230 100644 --- a/apps/project/forms.py +++ b/apps/project/forms.py @@ -1,41 +1,41 @@ from django import forms +from .models import ProjectRole from .permissions import PROJECT_PERMISSIONS from .widgets import PermissionsWidget -from .models import ProjectRole class ProjectRoleForm(forms.ModelForm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fields['lead_permissions'].widget = PermissionsWidget( - 'lead_permissions', # NOTE: this needs to besent to uniquely identify the checkboxes + self.fields["lead_permissions"].widget = PermissionsWidget( + "lead_permissions", # NOTE: this needs to besent to uniquely identify the checkboxes PROJECT_PERMISSIONS.lead, ) - self.fields['entry_permissions'].widget = PermissionsWidget( - 'entry_permissions', + self.fields["entry_permissions"].widget = PermissionsWidget( + "entry_permissions", PROJECT_PERMISSIONS.entry, ) - self.fields['setup_permissions'].widget = PermissionsWidget( - 'setup_permissions', + self.fields["setup_permissions"].widget = PermissionsWidget( + "setup_permissions", PROJECT_PERMISSIONS.setup, ) - self.fields['export_permissions'].widget = PermissionsWidget( - 'export_permissions', + self.fields["export_permissions"].widget = PermissionsWidget( + "export_permissions", PROJECT_PERMISSIONS.export, ) - self.fields['assessment_permissions'].widget = PermissionsWidget( - 'assessment_permissions', + self.fields["assessment_permissions"].widget = PermissionsWidget( + "assessment_permissions", PROJECT_PERMISSIONS.assessment, ) def save(self, commit=True): obj = super().save(commit=False) - obj.lead_permissions = self.cleaned_data['lead_permissions'] - obj.entry_permissions = self.cleaned_data['entry_permissions'] - obj.setup_permissions = self.cleaned_data['setup_permissions'] - obj.export_permissions = self.cleaned_data['export_permissions'] - obj.assessment_permissions = self.cleaned_data['assessment_permissions'] + obj.lead_permissions = self.cleaned_data["lead_permissions"] + obj.entry_permissions = self.cleaned_data["entry_permissions"] + obj.setup_permissions = self.cleaned_data["setup_permissions"] + obj.export_permissions = self.cleaned_data["export_permissions"] + obj.assessment_permissions = self.cleaned_data["assessment_permissions"] obj.save() self.save_m2m() @@ -43,4 +43,4 @@ def save(self, commit=True): class Meta: model = ProjectRole - fields = '__all__' + fields = "__all__" diff --git a/apps/project/management/commands/generate_projects_viz_stats.py b/apps/project/management/commands/generate_projects_viz_stats.py index 6bd86c4d96..7e8a00a546 100644 --- a/apps/project/management/commands/generate_projects_viz_stats.py +++ b/apps/project/management/commands/generate_projects_viz_stats.py @@ -1,19 +1,16 @@ from django.core.management.base import BaseCommand - from project.models import Project from project.tasks import generate_viz_stats class Command(BaseCommand): - help = 'Generate the Project Viz Stats' + help = "Generate the Project Viz Stats" def handle(self, *arg, **options): generate_project_viz_stats() def generate_project_viz_stats(): - project_qs = Project.objects.filter( - is_visualization_enabled=True - ) + project_qs = Project.objects.filter(is_visualization_enabled=True) for project in project_qs: generate_viz_stats(project.id, force=True) diff --git a/apps/project/mixins.py b/apps/project/mixins.py index 3c4be0c3f0..16f02ab977 100644 --- a/apps/project/mixins.py +++ b/apps/project/mixins.py @@ -6,17 +6,18 @@ class ProjectEntityMixin: Mixin with built in permission methods for project entities like lead, entry, assessments, etc. """ + def __getattr__(self, name): - if not name.startswith('can_'): + if not name.startswith("can_"): # super() does not have __getattr__ so call __getattribute__ return super().__getattribute__(name) try: - _, action = name.split('_') # Example: can_modify + _, action = name.split("_") # Example: can_modify except ValueError: return super().__getattribute__(name) selfname = self.__class__.__name__.lower() - roleattr = '{}_{}'.format(name, selfname) # eg: can_modify_entry + roleattr = "{}_{}".format(name, selfname) # eg: can_modify_entry def permission_function(user): role = self.project.get_role(user) @@ -29,4 +30,4 @@ def permission_function(user): @classmethod def get_for(cls, user): - return get_project_entities(cls, user, action='view').distinct() + return get_project_entities(cls, user, action="view").distinct() diff --git a/apps/project/models.py b/apps/project/models.py index 9ee905695f..eaa8bbbb02 100644 --- a/apps/project/models.py +++ b/apps/project/models.py @@ -1,36 +1,34 @@ import uuid +from datetime import timedelta +from analysis_framework.models import AnalysisFramework +from category_editor.models import CategoryEditor from dateutil.relativedelta import relativedelta -from django.urls import reverse -from django.core.exceptions import ValidationError +from django.contrib.auth.models import User from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields.jsonb import KeyTextTransform from django.core.cache import cache -from django.utils.functional import cached_property -from django.db.models.functions import Cast -from django.contrib.auth.models import User +from django.core.exceptions import ValidationError +from django.db import connection as django_db_connection from django.db import models from django.db.models import Q -from django.db.models.functions import JSONObject -from django.db import connection as django_db_connection - -from deep.caches import CacheKey -from user_resource.models import UserResource +from django.db.models.functions import Cast, JSONObject +from django.urls import reverse +from django.utils import timezone +from django.utils.functional import cached_property from geo.models import Region -from user_group.models import UserGroup -from analysis_framework.models import AnalysisFramework -from category_editor.models import CategoryEditor -from project.permissions import PROJECT_PERMISSIONS, PROJECT_PERMISSION_MODEL_MAP from organization.models import Organization +from project.permissions import PROJECT_PERMISSION_MODEL_MAP, PROJECT_PERMISSIONS +from user_group.models import UserGroup +from user_resource.models import UserResource -from django.utils import timezone -from datetime import timedelta +from deep.caches import CacheKey class RecentActivityType(models.TextChoices): - LEAD = 'lead', 'Source' - ENTRY = 'entry', 'Entry' - ENTRY_COMMENT = 'entry-comment', 'Entry Comment' + LEAD = "lead", "Source" + ENTRY = "entry", "Entry" + ENTRY_COMMENT = "entry-comment", "Entry Comment" class Project(UserResource): @@ -40,8 +38,8 @@ class Project(UserResource): # Status Choices class Status(models.TextChoices): - ACTIVE = 'active', 'Active' - INACTIVE = 'inactive', 'Inactive' + ACTIVE = "active", "Active" + INACTIVE = "inactive", "Inactive" PROJECT_INACTIVE_AFTER_MONTHS = 12 @@ -51,29 +49,32 @@ class Status(models.TextChoices): start_date = models.DateField(default=None, null=True, blank=True) end_date = models.DateField(default=None, null=True, blank=True) - members = models.ManyToManyField(User, blank=True, - through_fields=('project', 'member'), - through='ProjectMembership') + members = models.ManyToManyField(User, blank=True, through_fields=("project", "member"), through="ProjectMembership") regions = models.ManyToManyField(Region, blank=True) user_groups = models.ManyToManyField( UserGroup, blank=True, - through='ProjectUserGroupMembership', - through_fields=('project', 'usergroup'), + through="ProjectUserGroupMembership", + through_fields=("project", "usergroup"), ) analysis_framework = models.ForeignKey( - AnalysisFramework, blank=True, - default=None, null=True, + AnalysisFramework, + blank=True, + default=None, + null=True, on_delete=models.SET_NULL, ) category_editor = models.ForeignKey( - CategoryEditor, blank=True, - default=None, null=True, + CategoryEditor, + blank=True, + default=None, + null=True, on_delete=models.SET_NULL, ) assessment_template = models.ForeignKey( - 'ary.AssessmentTemplate', - blank=True, default=None, + "ary.AssessmentTemplate", + blank=True, + default=None, null=True, on_delete=models.SET_NULL, ) @@ -94,8 +95,8 @@ class Status(models.TextChoices): organizations = models.ManyToManyField( Organization, - through='ProjectOrganization', - through_fields=('project', 'organization'), + through="ProjectOrganization", + through_fields=("project", "organization"), blank=True, ) @@ -110,7 +111,7 @@ class Status(models.TextChoices): stats_cache = models.JSONField(default=dict, blank=True) # Stores the geo locations data as cache. geo_cache_hash = models.CharField(max_length=256, null=True, blank=True) - geo_cache_file = models.FileField(upload_to='project-geo-cache/', null=True, blank=True) + geo_cache_file = models.FileField(upload_to="project-geo-cache/", null=True, blank=True) # this is used for project deletion is_deleted = models.BooleanField(default=False) @@ -124,12 +125,13 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.analysis_framework_id: int self.current_user_membership_data = getattr( - self, 'current_user_membership_data', + self, + "current_user_membership_data", dict( user_id=None, role=None, badges=[], - ) + ), ) def __str__(self): @@ -143,28 +145,24 @@ def project_stats(self): def is_visualization_available(self): af = self.analysis_framework is_viz_enabled = self.is_visualization_enabled - return ( - is_viz_enabled and - af is not None and - af.properties is not None and - af.properties.get('stats_config') is not None - ) + return is_viz_enabled and af is not None and af.properties is not None and af.properties.get("stats_config") is not None def soft_delete(self, deleted_at=None, commit=True): self.is_deleted = True self.deleted_at = deleted_at or timezone.now() if commit: - self.save(update_fields=('is_deleted', 'deleted_at',)) + self.save( + update_fields=( + "is_deleted", + "deleted_at", + ) + ) def get_all_members(self): - return User.objects.filter( - projectmembership__project=self - ) + return User.objects.filter(projectmembership__project=self) def get_direct_members(self): - return self.get_all_members().filter( - projectmembership__linked_group__isnull=True - ) + return self.get_all_members().filter(projectmembership__linked_group__isnull=True) @staticmethod def base_queryset(): @@ -174,16 +172,16 @@ def base_queryset(): def get_annotated(cls): return cls.base_queryset().annotate( **{ - key: Cast(KeyTextTransform(key, 'stats_cache'), models.IntegerField()) + key: Cast(KeyTextTransform(key, "stats_cache"), models.IntegerField()) for key in [ - ('number_of_leads'), - ('number_of_leads_tagged'), - ('number_of_leads_tagged_and_controlled'), - ('number_of_entries'), - ('number_of_users'), + ("number_of_leads"), + ("number_of_leads_tagged"), + ("number_of_leads_tagged_and_controlled"), + ("number_of_entries"), + ("number_of_users"), # NOTE: Used for sorting in discover projects - ('leads_activity'), - ('entries_activity'), + ("leads_activity"), + ("entries_activity"), ] } ) @@ -202,43 +200,48 @@ def get_for_gq(cls, user, only_member=False): """ current_user_role_subquery = models.Subquery( ProjectMembership.objects.filter( - project=models.OuterRef('pk'), + project=models.OuterRef("pk"), member=user, - ).order_by('role__level').values('role__type')[:1], + ) + .order_by("role__level") + .values("role__type")[:1], output_field=models.CharField(), ) current_user_membership_data_subquery = JSONObject( user_id=models.Value(user.id), - role=models.F('current_user_role'), + role=models.F("current_user_role"), badges=models.Subquery( ProjectMembership.objects.filter( - project=models.OuterRef('pk'), + project=models.OuterRef("pk"), member=user, - ).order_by('badges').values('badges')[:1], + ) + .order_by("badges") + .values("badges")[:1], output_field=ArrayField(models.CharField()), ), ) visible_projects = cls.base_queryset() - visible_projects = visible_projects\ - .annotate( + visible_projects = ( + visible_projects.annotate( # For using within query filters current_user_role=current_user_role_subquery, - ).annotate( + ) + .annotate( # NOTE: This is used by permission module current_user_membership_data=current_user_membership_data_subquery, # NOTE: Exclude if project is private + user is not a member - ).exclude( + ) + .exclude( is_private=True, current_user_role__isnull=True, ) + ) if only_member: return visible_projects.filter(current_user_role__isnull=False) return visible_projects def fetch_current_user_membership_data(self, user): - membership = ProjectMembership.objects\ - .select_related('role')\ - .filter(project=self, member=user).first() + membership = ProjectMembership.objects.select_related("role").filter(project=self, member=user).first() current_user_role = None badges = [] if membership: @@ -258,17 +261,17 @@ def get_current_user_attr(self, user, attr): if user is None: return - if self.current_user_membership_data.get('user_id') == user.id: + if self.current_user_membership_data.get("user_id") == user.id: return self.current_user_membership_data.get(attr) self.fetch_current_user_membership_data(user) return self.current_user_membership_data.get(attr) def get_current_user_role(self, user): - return self.get_current_user_attr(user, 'role') + return self.get_current_user_attr(user, "role") def get_current_user_badges(self, user): - return self.get_current_user_attr(user, 'badges') + return self.get_current_user_attr(user, "badges") @classmethod def get_recent_activities(cls, user): @@ -279,32 +282,58 @@ def get_recent_activities(cls, user): project_qs = cls.get_for_member(user) created_by_expression = models.functions.Coalesce( models.Func( - models.Value(' '), models.F('created_by_id__first_name'), models.F('created_by_id__last_name'), - function='CONCAT_WS' - ), models.F('created_by_id__email'), output_field=models.CharField() + models.Value(" "), + models.F("created_by_id__first_name"), + models.F("created_by_id__last_name"), + function="CONCAT_WS", + ), + models.F("created_by_id__email"), + output_field=models.CharField(), ) leads_qs = Lead.objects.filter(project__in=project_qs).values_list( - 'id', 'created_at', 'project_id', 'project__title', - 'created_by_id', 'created_by__profile__display_picture__file', - models.Value('lead', output_field=models.CharField()), - created_by_expression, 'id', 'id', # here id has no use, it is added to resolve error for union + "id", + "created_at", + "project_id", + "project__title", + "created_by_id", + "created_by__profile__display_picture__file", + models.Value("lead", output_field=models.CharField()), + created_by_expression, + "id", + "id", # here id has no use, it is added to resolve error for union ) entry_qs = Entry.objects.filter(project__in=project_qs).values_list( - 'id', 'created_at', 'project_id', 'project__title', - 'created_by_id', 'created_by__profile__display_picture__file', - models.Value('entry', output_field=models.CharField()), - created_by_expression, 'lead__id', 'id', + "id", + "created_at", + "project_id", + "project__title", + "created_by_id", + "created_by__profile__display_picture__file", + models.Value("entry", output_field=models.CharField()), + created_by_expression, + "lead__id", + "id", + ) + entry_comment_qs = ( + EntryReviewComment.objects.filter(entry__project__in=project_qs) + .values_list( + "id", + "created_at", + "entry__project_id", + "entry__project__title", + "created_by_id", + "created_by__profile__display_picture__file", + models.Value("entry-comment", output_field=models.CharField()), + created_by_expression, + "entry__lead__id", + "entry_id", + ) + .distinct("id") ) - entry_comment_qs = EntryReviewComment.objects.filter(entry__project__in=project_qs).values_list( - 'id', 'created_at', 'entry__project_id', 'entry__project__title', - 'created_by_id', 'created_by__profile__display_picture__file', - models.Value('entry-comment', output_field=models.CharField()), - created_by_expression, 'entry__lead__id', 'entry_id', - ).distinct('id') def _get_activities(): - return list(entry_qs.union(entry_comment_qs).union(leads_qs).order_by('-created_at')[:30]) + return list(entry_qs.union(entry_comment_qs).union(leads_qs).order_by("-created_at")[:30]) activities = cache.get_or_set( CacheKey.RECENT_ACTIVITIES_KEY_FORMAT.format(user.pk), @@ -314,18 +343,20 @@ def _get_activities(): return [ { field: item[index] - for index, field in enumerate([ - 'id', - 'created_at', - 'project', - 'project_display_name', - 'created_by', - 'created_by_display_picture', - 'type', - 'created_by_display_name', - 'lead_id', - 'entry_id', - ]) + for index, field in enumerate( + [ + "id", + "created_at", + "project", + "project_display_name", + "created_by", + "created_by_display_picture", + "type", + "created_by_display_name", + "lead_id", + "entry_id", + ] + ) } for item in activities ] @@ -335,25 +366,27 @@ def get_recent_active_projects(user, qs=None, max=3): # NOTE: to avoid circular import from entry.models import Entry from lead.models import Lead + # NOTE: Django ORM union don't allow annotation # TODO: Need to refactor this with django_db_connection.cursor() as cursor: select_sql = [ - f''' + f""" SELECT tb."project_id" AS "project", MAX(tb."{field}_at") AS "date" FROM "{Model._meta.db_table}" AS tb WHERE tb."{field}_by_id" = {user.pk} GROUP BY tb."project_id" - ''' for Model, field in [ - (Lead, 'created'), - (Lead, 'modified'), - (Entry, 'created'), - (Entry, 'modified'), + """ + for Model, field in [ + (Lead, "created"), + (Lead, "modified"), + (Entry, "created"), + (Entry, "modified"), ] ] - union_sql = '(' + ') UNION ('.join(select_sql) + ')' + union_sql = "(" + ") UNION (".join(select_sql) + ")" cursor.execute( f'SELECT DISTINCT(entities."project"), MAX("date") as "date" FROM ({union_sql}) as entities' f' GROUP BY entities."project" ORDER BY "date" DESC' @@ -362,46 +395,38 @@ def get_recent_active_projects(user, qs=None, max=3): if qs is None: qs = Project.get_for_member(user) # only the projects user is member among the recent projects - current_users_project_id = set(qs.filter(pk__in=recent_projects_id).values_list('pk', flat=True)) - recent_projects_id = [ - pk - for pk in recent_projects_id - if pk in current_users_project_id # filter out user project - ][:max] - projects_map = { - project.pk: project - for project in qs.filter(pk__in=recent_projects_id) - } + current_users_project_id = set(qs.filter(pk__in=recent_projects_id).values_list("pk", flat=True)) + recent_projects_id = [pk for pk in recent_projects_id if pk in current_users_project_id][:max] # filter out user project + projects_map = {project.pk: project for project in qs.filter(pk__in=recent_projects_id)} # Maintain the order - recent_projects = [ - projects_map[id] - for id in recent_projects_id if projects_map.get(id) - ] + recent_projects = [projects_map[id] for id in recent_projects_id if projects_map.get(id)] return recent_projects def get_recent_active_users_id_and_date(self, max_users=3): # NOTE: to avoid circular import from entry.models import Entry from lead.models import Lead + # NOTE: Django ORM union don't allow annotation # TODO: Need to refactor this with django_db_connection.cursor() as cursor: select_sql = [ - f''' + f""" SELECT tb."{field}_by_id" AS "user", MAX(tb."{field}_at") AS "date" FROM "{Model._meta.db_table}" AS tb WHERE tb."project_id" = {self.pk} GROUP BY tb."{field}_by_id" - ''' for Model, field in [ - (Lead, 'created'), - (Lead, 'modified'), - (Entry, 'created'), - (Entry, 'modified'), + """ + for Model, field in [ + (Lead, "created"), + (Lead, "modified"), + (Entry, "created"), + (Entry, "modified"), ] ] - union_sql = '(' + ') UNION ('.join(select_sql) + ')' + union_sql = "(" + ") UNION (".join(select_sql) + ")" cursor.execute( f'SELECT DISTINCT(entities."user"), MAX("date") as "date" FROM ({union_sql}) as entities' f' GROUP BY entities."user" ORDER BY "date" DESC Limit {max_users}' @@ -411,9 +436,7 @@ def get_recent_active_users_id_and_date(self, max_users=3): @staticmethod def get_for_public(requestUser, user): - return Project\ - .get_for_member(user)\ - .exclude(models.Q(is_private=True) & ~models.Q(members=requestUser)) + return Project.get_for_member(user).exclude(models.Q(is_private=True) & ~models.Q(members=requestUser)) @staticmethod def get_for_member(user, annotated=False, exclude=False): @@ -432,16 +455,17 @@ def get_query_for_member(user): @staticmethod def get_modifiable_for(user): permission = PROJECT_PERMISSIONS.setup.modify - return Project.get_annotated().filter( - projectmembership__in=ProjectMembership.objects.filter( - member=user, - ).annotate( - new_setup_permission=models.F('role__setup_permissions') - .bitand(permission) - ).filter( - new_setup_permission=permission + return ( + Project.get_annotated() + .filter( + projectmembership__in=ProjectMembership.objects.filter( + member=user, + ) + .annotate(new_setup_permission=models.F("role__setup_permissions").bitand(permission)) + .filter(new_setup_permission=permission) ) - ).distinct() + .distinct() + ) @property def has_assessments(self): @@ -470,7 +494,8 @@ def can_delete(self, user): return role is not None and role.can_delete_setup def add_member( - self, user, + self, + user, role=None, added_by=None, linked_group=None, @@ -489,10 +514,10 @@ def add_member( ) def get_entries_activity(self): - return (self.stats_cache or {}).get('entries_activities') or [] + return (self.stats_cache or {}).get("entries_activities") or [] def get_leads_activity(self): - return (self.stats_cache or {}).get('leads_activities') or [] + return (self.stats_cache or {}).get("leads_activities") or [] def get_admins(self): return User.objects.filter( @@ -507,64 +532,72 @@ def get_default_role_id(): class ProjectOrganization(models.Model): class Type(models.TextChoices): - LEAD_ORGANIZATION = 'lead_organization', 'Lead Organization' # Project Owner - INTERNATIONAL_PARTNER = 'international_partner', 'International Partner' - NATIONAL_PARTNER = 'national_partner', 'National Partner' - DONOR = 'donor', 'Donor' - GOVERNMENT = 'government', 'Government' + LEAD_ORGANIZATION = "lead_organization", "Lead Organization" # Project Owner + INTERNATIONAL_PARTNER = "international_partner", "International Partner" + NATIONAL_PARTNER = "national_partner", "National Partner" + DONOR = "donor", "Donor" + GOVERNMENT = "government", "Government" organization_type = models.CharField(max_length=30, choices=Type.choices) organization = models.ForeignKey(Organization, on_delete=models.CASCADE) project = models.ForeignKey(Project, on_delete=models.CASCADE) class Meta: - unique_together = ('project', 'organization_type', 'organization') + unique_together = ("project", "organization_type", "organization") class ProjectMembership(models.Model): """ Project-Member relationship attributes """ + class BadgeType(models.IntegerChoices): - QA = 0, 'Quality Assurance' + QA = 0, "Quality Assurance" member = models.ForeignKey(User, on_delete=models.CASCADE) project = models.ForeignKey(Project, on_delete=models.CASCADE) role = models.ForeignKey( - 'project.ProjectRole', + "project.ProjectRole", default=get_default_role_id, on_delete=models.CASCADE, ) linked_group = models.ForeignKey( - UserGroup, on_delete=models.CASCADE, - default=None, null=True, blank=True, + UserGroup, + on_delete=models.CASCADE, + default=None, + null=True, + blank=True, ) joined_at = models.DateTimeField(auto_now_add=True) added_by = models.ForeignKey( - User, on_delete=models.CASCADE, - null=True, blank=True, default=None, - related_name='added_project_memberships', + User, + on_delete=models.CASCADE, + null=True, + blank=True, + default=None, + related_name="added_project_memberships", ) # Represents additional permission like QA badges = ArrayField(models.IntegerField(choices=BadgeType.choices), default=list, blank=True) class Meta: - unique_together = ('member', 'project') + unique_together = ("member", "project") def __str__(self): - return '{} @ {}'.format(str(self.member), - self.project.title) + return "{} @ {}".format(str(self.member), self.project.title) def save(self, *args, **kwargs): super().save(*args, **kwargs) - group_membership = self.linked_group and \ - ProjectUserGroupMembership.objects.filter( + group_membership = ( + self.linked_group + and ProjectUserGroupMembership.objects.filter( usergroup=self.linked_group, project=self.project, ).first() + ) if group_membership: role = group_membership.role or ProjectRole.get_default_role() if self.role != role: @@ -589,27 +622,32 @@ class ProjectUserGroupMembership(models.Model): """ Project user group membership model """ + project = models.ForeignKey(Project, on_delete=models.CASCADE) # FIXME: use user_group instead of usergroup for consistency usergroup = models.ForeignKey(UserGroup, on_delete=models.CASCADE) role = models.ForeignKey( - 'project.ProjectRole', on_delete=models.CASCADE, + "project.ProjectRole", + on_delete=models.CASCADE, default=get_default_role_id, ) joined_at = models.DateTimeField(auto_now_add=True) added_by = models.ForeignKey( - User, on_delete=models.CASCADE, - null=True, blank=True, default=None, - related_name='added_project_usergroups', + User, + on_delete=models.CASCADE, + null=True, + blank=True, + default=None, + related_name="added_project_usergroups", ) # Represents additional permission like QA (UserGroup level, we define additionaly in UserMembersip level as well) badges = ArrayField(models.IntegerField(choices=ProjectMembership.BadgeType.choices), default=list, blank=True) class Meta: - unique_together = ('usergroup', 'project') + unique_together = ("usergroup", "project") def __str__(self): - return 'Group {} @ {}'.format(self.usergroup.title, self.project.title) + return "Group {} @ {}".format(self.usergroup.title, self.project.title) @staticmethod def get_for(user): @@ -623,7 +661,7 @@ def can_modify(self, user): def get_default_join_request_data(): - return dict(reason='') + return dict(reason="") class ProjectJoinRequest(models.Model): @@ -632,36 +670,40 @@ class ProjectJoinRequest(models.Model): """ class Status(models.TextChoices): - PENDING = 'pending', 'Pending' - ACCEPTED = 'accepted', 'Accepted' - REJECTED = 'rejected', 'Rejected' + PENDING = "pending", "Pending" + ACCEPTED = "accepted", "Accepted" + REJECTED = "rejected", "Rejected" project = models.ForeignKey(Project, on_delete=models.CASCADE) requested_by = models.ForeignKey( - User, on_delete=models.CASCADE, - related_name='project_join_requests', + User, + on_delete=models.CASCADE, + related_name="project_join_requests", ) requested_at = models.DateTimeField(auto_now_add=True) status = models.CharField(max_length=48, choices=Status.choices, default=Status.PENDING) - role = models.ForeignKey('project.ProjectRole', on_delete=models.CASCADE) + role = models.ForeignKey("project.ProjectRole", on_delete=models.CASCADE) responded_by = models.ForeignKey( - User, on_delete=models.CASCADE, - null=True, blank=True, default=None, - related_name='project_join_responses', + User, + on_delete=models.CASCADE, + null=True, + blank=True, + default=None, + related_name="project_join_responses", ) responded_at = models.DateTimeField(null=True, blank=True, default=None) data = models.JSONField(default=get_default_join_request_data, blank=True, null=True) def __str__(self): - return 'Join request for {} by {} ({})'.format( + return "Join request for {} by {} ({})".format( self.project.title, self.requested_by.profile.get_display_name(), self.status, ) class Meta: - ordering = ('-requested_at',) - unique_together = ('project', 'requested_by') + ordering = ("-requested_at",) + unique_together = ("project", "requested_by") class ProjectRole(models.Model): @@ -670,12 +712,12 @@ class ProjectRole(models.Model): """ class Type(models.TextChoices): - PROJECT_OWNER = 'project_owner', 'Project Owner' - ADMIN = 'admin', 'Admin' - MEMBER = 'member', 'Member' - READER = 'reader', 'Reader' - READER_NON_CONFIDENTIAL = 'reader_non_confidential', 'Reader (Non-confidential)' - UNKNOWN = 'unknown', 'Unknown' + PROJECT_OWNER = "project_owner", "Project Owner" + ADMIN = "admin", "Admin" + MEMBER = "member", "Member" + READER = "reader", "Reader" + READER_NON_CONFIDENTIAL = "reader_non_confidential", "Reader (Non-confidential)" + UNKNOWN = "unknown", "Unknown" title = models.CharField(max_length=255, unique=True) type = models.CharField(choices=Type.choices, default=Type.UNKNOWN, max_length=50) @@ -714,23 +756,21 @@ def __str__(self): return self.title def __getattr__(self, name): - if not name.startswith('can_'): + if not name.startswith("can_"): # super() does not have __getattr__ return super().__getattribute__(name) else: try: - _, action, _item = name.split('_') # Example: can_create_lead + _, action, _item = name.split("_") # Example: can_create_lead # TODO: Better approach item = PROJECT_PERMISSION_MODEL_MAP[_item] except ValueError: return super().__getattribute__(name) try: - item_permissions = self.__getattr__(item + '_permissions') + item_permissions = self.__getattr__(item + "_permissions") except Exception: - raise AttributeError( - 'No permission defined for "{}"'.format(item) - ) + raise AttributeError('No permission defined for "{}"'.format(item)) permission_bit = PROJECT_PERMISSIONS.get(item, {}).get(action) @@ -742,30 +782,28 @@ def __getattr__(self, name): def clean(self): if self.type != self.Type.UNKNOWN and ProjectRole.objects.filter(type=self.type).exclude(pk=self.pk).count() > 0: - raise ValidationError({ - 'type': f'Type: {self.type} is already assigned!!' - }) + raise ValidationError({"type": f"Type: {self.type} is already assigned!!"}) class ProjectStats(models.Model): class Status(models.TextChoices): - PENDING = 'pending', 'Pending' - STARTED = 'started', 'Started' - SUCCESS = 'success', 'Success' - FAILURE = 'failure', 'Failure' + PENDING = "pending", "Pending" + STARTED = "started", "Started" + SUCCESS = "success", "Success" + FAILURE = "failure", "Failure" class Action(models.TextChoices): - NEW = 'new', 'New' - ON = 'on', 'On' - OFF = 'off', 'Off' + NEW = "new", "New" + ON = "on", "On" + OFF = "off", "Off" THRESHOLD_SECONDS = 60 * 20 - project = models.OneToOneField(Project, on_delete=models.CASCADE, related_name='entry_stats') + project = models.OneToOneField(Project, on_delete=models.CASCADE, related_name="entry_stats") modified_at = models.DateTimeField(auto_now=True) status = models.CharField(max_length=30, choices=Status.choices, default=Status.PENDING) - file = models.FileField(upload_to='project-stats/', max_length=255, null=True, blank=True) - confidential_file = models.FileField(upload_to='project-stats/', max_length=255, null=True, blank=True) + file = models.FileField(upload_to="project-stats/", max_length=255, null=True, blank=True) + confidential_file = models.FileField(upload_to="project-stats/", max_length=255, null=True, blank=True) # Token is used to retrive the viz data (non-confidential) public_share = models.BooleanField(default=False) @@ -781,10 +819,7 @@ def get_activity_timeframe(now=None): @classmethod def get_for(cls, user): - return cls.objects.filter( - models.Q(project__members=user) | - models.Q(project__user_groups__members=user) - ).distinct() + return cls.objects.filter(models.Q(project__members=user) | models.Q(project__user_groups__members=user)).distinct() def update_public_share_configuration(self, action: Action, commit=True): if action == self.Action.NEW: @@ -796,39 +831,43 @@ def update_public_share_configuration(self, action: Action, commit=True): elif action == self.Action.OFF: self.public_share = False if commit: - self.save(update_fields=('public_share', 'token',)) + self.save( + update_fields=( + "public_share", + "token", + ) + ) return self def get_public_url(self, request=None): if self.token: - url = reverse('project-stat-viz-public', kwargs={ - 'project_stat_id': self.id, - 'token': self.token, - }) + url = reverse( + "project-stat-viz-public", + kwargs={ + "project_stat_id": self.id, + "token": self.token, + }, + ) if request: url = request.build_absolute_uri(url) return url def is_ready(self): time_threshold = timezone.now() - timedelta(seconds=self.THRESHOLD_SECONDS) - if ( - self.status == ProjectStats.Status.SUCCESS and - self.modified_at > time_threshold and - self.file - ): + if self.status == ProjectStats.Status.SUCCESS and self.modified_at > time_threshold and self.file: return True return False class ProjectChangeLog(models.Model): class Action(models.IntegerChoices): - PROJECT_CREATE = 1, 'Project Create' - PROJECT_DETAILS = 2, 'Project Details' - ORGANIZATION = 3, 'Organization' - REGION = 4, 'Region' - MEMBERSHIP = 5, 'Membership' - FRAMEWORK = 6, 'Framework' - MULTIPLE = 7, 'Multiple fields' + PROJECT_CREATE = 1, "Project Create" + PROJECT_DETAILS = 2, "Project Details" + ORGANIZATION = 3, "Organization" + REGION = 4, "Region" + MEMBERSHIP = 5, "Membership" + FRAMEWORK = 6, "Framework" + MULTIPLE = 7, "Multiple fields" created_at = models.DateTimeField(auto_now_add=True) project = models.ForeignKey(Project, on_delete=models.CASCADE) @@ -845,4 +884,4 @@ class ProjectPinned(models.Model): modified_at = models.DateTimeField(auto_now=True) class Meta: - unique_together = ('project', 'user') + unique_together = ("project", "user") diff --git a/apps/project/mutation.py b/apps/project/mutation.py index 506affa4f2..d7f6b7fd9a 100644 --- a/apps/project/mutation.py +++ b/apps/project/mutation.py @@ -1,54 +1,40 @@ +import graphene +from analysis.mutation import Mutation as AnalysisMutation +from ary.mutation import Mutation as AryMutation +from assessment_registry.mutation import ProjectMutation as AssessmentRegistryMutation +from assisted_tagging.mutation import AssistedTaggingMutationType +from django.core.exceptions import PermissionDenied from django.db import transaction from django.utils.translation import gettext - -import graphene +from entry.mutation import Mutation as EntryMutation +from export.mutation import ProjectMutation as ExportMutation +from geo.models import Region +from geo.schema import RegionDetailType from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField +from lead.mutation import Mutation as LeadMutation +from quality_assurance.mutation import Mutation as QualityAssuranceMutation +from unified_connector.mutation import UnifiedConnectorMutationType -from django.core.exceptions import PermissionDenied - +from deep.permissions import ProjectPermissions as PP +from deep.trackers import TrackerAction, track_project +from utils.graphene.error_types import CustomErrorType, mutation_is_not_valid from utils.graphene.mutation import ( - generate_input_type_for_serializer, + DeleteMutation, GrapheneMutation, - PsGrapheneMutation, PsBulkGrapheneMutation, - DeleteMutation, + PsGrapheneMutation, + generate_input_type_for_serializer, ) -from utils.graphene.error_types import mutation_is_not_valid, CustomErrorType - -from deep.permissions import ProjectPermissions as PP -from deep.trackers import TrackerAction, track_project - -from geo.models import Region -from geo.schema import RegionDetailType -from lead.mutation import Mutation as LeadMutation -from entry.mutation import Mutation as EntryMutation -from assessment_registry.mutation import ProjectMutation as AssessmentRegistryMutation -from quality_assurance.mutation import Mutation as QualityAssuranceMutation -from ary.mutation import Mutation as AryMutation -from export.mutation import ProjectMutation as ExportMutation -from analysis.mutation import Mutation as AnalysisMutation -from unified_connector.mutation import UnifiedConnectorMutationType -from assisted_tagging.mutation import AssistedTaggingMutationType from .models import ( Project, - ProjectStats, ProjectJoinRequest, ProjectMembership, - ProjectUserGroupMembership, + ProjectPinned, ProjectRole, - ProjectPinned -) -from .serializers import ( - ProjectGqSerializer, - ProjectJoinGqSerializer, - ProjectAcceptRejectSerializer, - ProjectMembershipGqlSerializer as ProjectMembershipSerializer, - ProjectUserGroupMembershipGqlSerializer as ProjectUserGroupMembershipSerializer, - ProjectVizConfigurationSerializer, - UserPinnedProjectSerializer, - BulkProjectPinnedSerializer + ProjectStats, + ProjectUserGroupMembership, ) from .schema import ( ProjectDetailType, @@ -56,53 +42,62 @@ ProjectMembershipType, ProjectUserGroupMembershipType, ProjectVizDataType, - UserPinnedProjectType + UserPinnedProjectType, ) - +from .serializers import ( + BulkProjectPinnedSerializer, + ProjectAcceptRejectSerializer, + ProjectGqSerializer, + ProjectJoinGqSerializer, +) +from .serializers import ProjectMembershipGqlSerializer as ProjectMembershipSerializer +from .serializers import ( + ProjectUserGroupMembershipGqlSerializer as ProjectUserGroupMembershipSerializer, +) +from .serializers import ProjectVizConfigurationSerializer, UserPinnedProjectSerializer ProjectCreateInputType = generate_input_type_for_serializer( - 'ProjectCreateInputType', + "ProjectCreateInputType", serializer_class=ProjectGqSerializer, ) ProjectUpdateInputType = generate_input_type_for_serializer( - 'ProjectUpdateInputType', + "ProjectUpdateInputType", serializer_class=ProjectGqSerializer, partial=True, ) ProjectJoinRequestInputType = generate_input_type_for_serializer( - 'ProjectJoinRequestInputType', + "ProjectJoinRequestInputType", serializer_class=ProjectJoinGqSerializer, ) ProjectAcceptRejectInputType = generate_input_type_for_serializer( - 'ProjectAcceptRejectInputType', + "ProjectAcceptRejectInputType", serializer_class=ProjectAcceptRejectSerializer, ) ProjectMembershipInputType = generate_input_type_for_serializer( - 'ProjectMembershipInputType', + "ProjectMembershipInputType", serializer_class=ProjectMembershipSerializer, ) ProjectUserGroupMembershipInputType = generate_input_type_for_serializer( - 'ProjectUserGroupMembershipInputType', + "ProjectUserGroupMembershipInputType", serializer_class=ProjectUserGroupMembershipSerializer, ) ProjectVizConfigurationInputType = generate_input_type_for_serializer( - 'ProjectVizConfigurationInputType', + "ProjectVizConfigurationInputType", serializer_class=ProjectVizConfigurationSerializer, ) ProjectPinnedInputType = generate_input_type_for_serializer( - 'ProjectPinnedInputType', - serializer_class=UserPinnedProjectSerializer + "ProjectPinnedInputType", serializer_class=UserPinnedProjectSerializer ) UserPinnedProjectReOrderInputType = generate_input_type_for_serializer( - 'UserPinnedProjectReOrderInputType', + "UserPinnedProjectReOrderInputType", serializer_class=BulkProjectPinnedSerializer, ) @@ -133,7 +128,7 @@ class Arguments: @classmethod def perform_mutate(cls, root, info, **kwargs): - kwargs['id'] = info.context.active_project.id + kwargs["id"] = info.context.active_project.id return super().perform_mutate(root, info, **kwargs) @@ -165,12 +160,15 @@ def mutate(root, info, project_id): project=project_id, ) except ProjectJoinRequest.DoesNotExist: - return ProjectJoinRequestDelete(errors=[ - dict( - field='nonFieldErrors', - messages=gettext('ProjectJoinRequest does not exist for project(id:%s)' % project_id) - ) - ], ok=False) + return ProjectJoinRequestDelete( + errors=[ + dict( + field="nonFieldErrors", + messages=gettext("ProjectJoinRequest does not exist for project(id:%s)" % project_id), + ) + ], + ok=False, + ) instance.delete() instance.id = id return ProjectJoinRequestDelete(result=instance, errors=None, ok=True) @@ -188,15 +186,17 @@ def mutate(root, info, **kwargs): role__type=ProjectRole.Type.PROJECT_OWNER, ) if not membership_qs.exists(): - return ProjectDelete(errors=[ - dict( - field='nonFieldErrors', - messages=gettext( - 'You should be Project Owner to delete this project(id:%s)' - % info.context.active_project.id - ), - ) - ], ok=False) + return ProjectDelete( + errors=[ + dict( + field="nonFieldErrors", + messages=gettext( + "You should be Project Owner to delete this project(id:%s)" % info.context.active_project.id + ), + ) + ], + ok=False, + ) root.soft_delete() return ProjectDelete(result=root, errors=None, ok=True) @@ -211,7 +211,7 @@ class Arguments: @staticmethod def mutate(root, info, data): - serializer = ProjectJoinGqSerializer(data=data, context={'request': info.context.request}) + serializer = ProjectJoinGqSerializer(data=data, context={"request": info.context.request}) if errors := mutation_is_not_valid(serializer): return CreateProjectJoin(errors=errors, ok=False) instance = serializer.save() @@ -301,17 +301,17 @@ class Arguments: @classmethod def perform_mutate(cls, _, info, **kwargs): project = info.context.active_project - regions_to_add = kwargs.get('regions_to_add') or [] - regions_to_remove = kwargs.get('regions_to_remove') or [] + regions_to_add = kwargs.get("regions_to_add") or [] + regions_to_remove = kwargs.get("regions_to_remove") or [] existing_regions = project.regions.all() added_regions = [ region - for region in Region.objects.filter(id__in=regions_to_add).exclude( - id__in=existing_regions.values('id') - ).order_by('id') + for region in Region.objects.filter(id__in=regions_to_add) + .exclude(id__in=existing_regions.values("id")) + .order_by("id") if region.public or region.can_modify(info.context.user) ] - deleted_regions = list(existing_regions.filter(id__in=regions_to_remove).order_by('id')) + deleted_regions = list(existing_regions.filter(id__in=regions_to_remove).order_by("id")) assert len(added_regions) <= len(regions_to_add) assert len(deleted_regions) <= len(regions_to_remove) # Remove regions @@ -334,6 +334,7 @@ class Arguments: class CreateUserPinnedProject(PsGrapheneMutation): class Arguments: data = ProjectPinnedInputType(required=True) + model = ProjectPinned result = graphene.Field(UserPinnedProjectType) serializer_class = UserPinnedProjectSerializer @@ -350,15 +351,16 @@ class ProjectMutationType( AnalysisMutation, AssessmentRegistryMutation, # --End Project Scoped Mutation - DjangoObjectType + DjangoObjectType, ): """ This mutation is for other scoped objects """ + class Meta: model = Project skip_registry = True - fields = ('id', 'title') + fields = ("id", "title") project_update = UpdateProject.Field() project_delete = ProjectDelete.Field() @@ -394,6 +396,7 @@ def resolve_assisted_tagging(root, info, **kwargs): class ReorderPinnedProjects(PsGrapheneMutation): class Arguments: items = graphene.List(graphene.NonNull(UserPinnedProjectReOrderInputType)) + model = ProjectPinned result = graphene.List(UserPinnedProjectType) serializer_class = BulkProjectPinnedSerializer @@ -405,11 +408,11 @@ def perform_mutate(cls, root, info, **kwargs): errors_data = [] serializers_data = [] results = [] - for data in kwargs['items']: - instance, errors = cls.get_object(info, id=data['id']) + for data in kwargs["items"]: + instance, errors = cls.get_object(info, id=data["id"]) if errors: errors_data.append(errors) - serializer = cls.serializer_class(data=data, instance=instance, context={'request': info.context.request}) + serializer = cls.serializer_class(data=data, instance=instance, context={"request": info.context.request}) errors_data.append(mutation_is_not_valid(serializer)) # errors_data also add empty list serializers_data.append(serializer) errors_data = [items for items in errors_data if items] # list comprehension removing empty list @@ -423,6 +426,7 @@ def perform_mutate(cls, root, info, **kwargs): class DeleteUserPinnedProject(DeleteMutation): class Arguments: id = graphene.ID(required=True) + model = ProjectPinned result = graphene.Field(UserPinnedProjectType) permissions = [] @@ -430,19 +434,17 @@ class Arguments: @staticmethod def mutate(root, info, id): - project_pinned_qs = ProjectPinned.objects.filter( - id=id, - user=info.context.user - ) + project_pinned_qs = ProjectPinned.objects.filter(id=id, user=info.context.user) if not project_pinned_qs.exists(): - return DeleteUserPinnedProject(errors=[ - dict( - field='nonFieldErrors', - messages=gettext( - 'Not authorize the unpinned project ' - ), - ) - ], ok=False) + return DeleteUserPinnedProject( + errors=[ + dict( + field="nonFieldErrors", + messages=gettext("Not authorize the unpinned project "), + ) + ], + ok=False, + ) project_pinned_qs.delete() return DeleteUserPinnedProject(result=root, errors=None, ok=True) diff --git a/apps/project/permissions.py b/apps/project/permissions.py index 601930b8c4..d8853e75da 100644 --- a/apps/project/permissions.py +++ b/apps/project/permissions.py @@ -1,4 +1,5 @@ from functools import reduce + from django.db import models from rest_framework import permissions @@ -6,13 +7,13 @@ # NOTE: Defined such that two model can share same permission model PROJECT_PERMISSION_MODEL_MAP = { - 'lead': 'lead', - 'entry': 'entry', - 'analysis': 'entry', - 'setup': 'setup', - 'export': 'export', - 'assessment': 'assessment', - 'plannedassessment': 'assessment', + "lead": "lead", + "entry": "entry", + "analysis": "entry", + "setup": "setup", + "export": "export", + "assessment": "assessment", + "plannedassessment": "assessment", } @@ -24,27 +25,28 @@ delete=1 << 3, view_only_unprotected=1 << 4, ), - entry=Dict({ - 'view': 1 << 0, - 'create': 1 << 1, - 'modify': 1 << 2, - 'delete': 1 << 3, - 'view_only_unprotected': 1 << 4, - }), - setup=Dict({ - 'modify': 1 << 0, - 'delete': 1 << 1, - }), - export=Dict({ - 'create': 1 << 0, - 'create_only_unprotected': 1 << 1, - }), - assessment=Dict({ - 'view': 1 << 0, - 'create': 1 << 1, - 'modify': 1 << 2, - 'delete': 1 << 3 - }) + entry=Dict( + { + "view": 1 << 0, + "create": 1 << 1, + "modify": 1 << 2, + "delete": 1 << 3, + "view_only_unprotected": 1 << 4, + } + ), + setup=Dict( + { + "modify": 1 << 0, + "delete": 1 << 1, + } + ), + export=Dict( + { + "create": 1 << 0, + "create_only_unprotected": 1 << 1, + } + ), + assessment=Dict({"view": 1 << 0, "create": 1 << 1, "modify": 1 << 2, "delete": 1 << 3}), ) @@ -55,7 +57,7 @@ def get_project_permissions_value(_item, actions=[]): 1 << 0 | 1 << 3 """ item = PROJECT_PERMISSION_MODEL_MAP[_item] - if actions == '__all__': + if actions == "__all__": # set all bits to 1 return reduce(lambda a, e: a | e, PROJECT_PERMISSIONS[item].values()) permissions = 0 @@ -71,14 +73,14 @@ def get_project_permissions_value(_item, actions=[]): class JoinPermission(permissions.BasePermission): def has_object_permission(self, request, view, obj): from project.models import ProjectJoinRequest + # User should not already be a member # and there should not be existing request by this user # to this project (whether pending, accepted or rejected). return ( - not obj.is_member(request.user) and - not ProjectJoinRequest.objects.filter( - models.Q(status='pending') | - models.Q(status='rejected'), + not obj.is_member(request.user) + and not ProjectJoinRequest.objects.filter( + models.Q(status="pending") | models.Q(status="rejected"), project=obj, requested_by=request.user, ).exists() @@ -88,6 +90,7 @@ def has_object_permission(self, request, view, obj): class AcceptRejectPermission(permissions.BasePermission): def has_object_permission(self, request, view, obj): from project.models import ProjectMembership, ProjectRole + return ProjectMembership.objects.filter( project=obj, member=request.user, @@ -98,6 +101,7 @@ def has_object_permission(self, request, view, obj): class MembershipModifyPermission(permissions.BasePermission): def has_object_permission(self, request, view, obj): from project.models import ProjectMembership + if request.method in permissions.SAFE_METHODS: return True @@ -105,10 +109,7 @@ def has_object_permission(self, request, view, obj): return True project = obj.project - user = ProjectMembership.objects.filter( - project=project, - member=request.user - ).first() + user = ProjectMembership.objects.filter(project=project, member=request.user).first() user_role = user and user.role if not user_role or user_role.level > obj.role.level: return False @@ -135,16 +136,17 @@ def get_project_entities(Entity, user, action=None): # TODO: camelcase to snakecase instead of just lower() item = PROJECT_PERMISSION_MODEL_MAP[Entity.__name__.lower()] - item_permissions = item + '_permissions' + item_permissions = item + "_permissions" permission = PROJECT_PERMISSIONS.get(item, {}).get(action) if permission is None: return Entity.objects.none() - fieldname = 'project__projectmembership__role__{}'.format(item_permissions) - return Entity.objects.filter( - project__projectmembership__member=user, - ).annotate( - new_permission_col=models.F(fieldname).bitand(permission) - ).filter( - new_permission_col=permission - ).distinct() + fieldname = "project__projectmembership__role__{}".format(item_permissions) + return ( + Entity.objects.filter( + project__projectmembership__member=user, + ) + .annotate(new_permission_col=models.F(fieldname).bitand(permission)) + .filter(new_permission_col=permission) + .distinct() + ) diff --git a/apps/project/public_schema.py b/apps/project/public_schema.py index 4f14aa491c..c70a76fc60 100644 --- a/apps/project/public_schema.py +++ b/apps/project/public_schema.py @@ -1,17 +1,15 @@ import graphene - -from graphene_django import DjangoObjectType from django.contrib.postgres.aggregates import StringAgg from django.contrib.postgres.fields.jsonb import KeyTextTransform -from django.db.models.functions import Cast, Coalesce from django.db import models - -from utils.graphene.types import CustomDjangoListObjectType +from django.db.models.functions import Cast, Coalesce +from graphene_django import DjangoObjectType from deep.serializers import URLCachedFileField +from utils.graphene.types import CustomDjangoListObjectType -from .models import Project from .filter_set import PublicProjectGqlFilterSet +from .models import Project class PublicProjectType(DjangoObjectType): @@ -19,13 +17,13 @@ class Meta: model = Project skip_registry = True fields = ( - 'id', - 'title', - 'description', - 'created_at', + "id", + "title", + "description", + "created_at", ) - analysis_framework = graphene.ID(source='analysis_framework_id') + analysis_framework = graphene.ID(source="analysis_framework_id") analysis_framework_title = graphene.String() regions_title = graphene.String() organizations_title = graphene.String() @@ -37,9 +35,7 @@ class Meta: @staticmethod def resolve_analysis_framework_preview_image(root, info, **kwargs): if root.preview_image: - return info.context.request.build_absolute_uri( - URLCachedFileField.name_to_representation(root.preview_image) - ) + return info.context.request.build_absolute_uri(URLCachedFileField.name_to_representation(root.preview_image)) return None @@ -51,60 +47,59 @@ class Meta: @classmethod def queryset(cls): - return Project.objects.filter( - is_deleted=False, - is_private=False, - is_test=False, - ).annotate( - analysis_framework_title=models.Case( - models.When( - analysis_framework__is_private=False, - then=models.F('analysis_framework__title') + return ( + Project.objects.filter( + is_deleted=False, + is_private=False, + is_test=False, + ) + .annotate( + analysis_framework_title=models.Case( + models.When(analysis_framework__is_private=False, then=models.F("analysis_framework__title")), + default=None, ), - default=None, - ), - preview_image=models.Case( - models.When( - analysis_framework__is_private=False, - then=models.F('analysis_framework__preview_image') + preview_image=models.Case( + models.When(analysis_framework__is_private=False, then=models.F("analysis_framework__preview_image")), + default=None, ), - default=None - ), - regions_title=StringAgg( - 'regions__title', - ', ', - filter=models.Q( - ~models.Q(regions__title=''), - regions__public=True, - regions__title__isnull=False, + regions_title=StringAgg( + "regions__title", + ", ", + filter=models.Q( + ~models.Q(regions__title=""), + regions__public=True, + regions__title__isnull=False, + ), + distinct=True, ), - distinct=True, - ), - organizations_title=StringAgg( - models.Case( - models.When( - projectorganization__organization__parent__isnull=False, - then='projectorganization__organization__parent__title' + organizations_title=StringAgg( + models.Case( + models.When( + projectorganization__organization__parent__isnull=False, + then="projectorganization__organization__parent__title", + ), + default="projectorganization__organization__title", ), - default='projectorganization__organization__title', + ", ", + distinct=True, ), - ', ', - distinct=True, - ), - **{ - key: Coalesce( - Cast(KeyTextTransform(key, 'stats_cache'), models.IntegerField()), - 0, - ) - for key in ['number_of_leads', 'number_of_users', 'number_of_entries'] - }, - ).only( - 'id', - 'title', - 'description', - 'analysis_framework_id', - 'created_at', - ).distinct() + **{ + key: Coalesce( + Cast(KeyTextTransform(key, "stats_cache"), models.IntegerField()), + 0, + ) + for key in ["number_of_leads", "number_of_users", "number_of_entries"] + }, + ) + .only( + "id", + "title", + "description", + "analysis_framework_id", + "created_at", + ) + .distinct() + ) class PublicProjectWithMembershipData(graphene.ObjectType): diff --git a/apps/project/receivers.py b/apps/project/receivers.py index e8510fe94c..5bd2ac2861 100644 --- a/apps/project/receivers.py +++ b/apps/project/receivers.py @@ -1,12 +1,11 @@ from django.db import models from django.dispatch import receiver - -from user.models import User from project.models import ( + ProjectJoinRequest, ProjectMembership, ProjectUserGroupMembership, - ProjectJoinRequest, ) +from user.models import User @receiver(models.signals.post_save, sender=ProjectUserGroupMembership) @@ -37,9 +36,7 @@ def refresh_project_memberships_usergroup_removed(sender, instance, **kwargs): ) for membership in remove_memberships: - other_user_groups = membership.get_user_group_options().exclude( - id=user_group.id - ) + other_user_groups = membership.get_user_group_options().exclude(id=user_group.id) if other_user_groups.count() > 0: membership.linked_group = other_user_groups.first() membership.save() @@ -52,13 +49,13 @@ def refresh_project_memberships_usergroup_removed(sender, instance, **kwargs): @receiver(models.signals.post_save, sender=ProjectMembership) def on_membership_saved(sender, **kwargs): # if kwargs.get('created'): - instance = kwargs.get('instance') + instance = kwargs.get("instance") ProjectJoinRequest.objects.filter( project=instance.project, requested_by=instance.member, - status='pending', + status="pending", ).update( - status='accepted', + status="accepted", responded_by=instance.added_by, responded_at=instance.joined_at, ) diff --git a/apps/project/schema.py b/apps/project/schema.py index c3627800b3..62aea6f9d4 100644 --- a/apps/project/schema.py +++ b/apps/project/schema.py @@ -1,203 +1,187 @@ from typing import List import graphene -from django.db import transaction, models +from analysis.schema import Query as AnalysisQuery +from ary.schema import Query as AryQuery +from assessment_registry.dashboard_schema import ( + Query as AssessmentRegistryDashboardQuery, +) +from assessment_registry.schema import ProjectQuery as AssessmentRegistryQuery +from assisted_tagging.schema import AssistedTaggingQueryType +from dateutil.relativedelta import relativedelta +from django.contrib.postgres.fields.jsonb import KeyTextTransform +from django.db import models, transaction from django.db.models import QuerySet from django.db.models.functions import Cast from django.utils import timezone -from django.contrib.postgres.fields.jsonb import KeyTextTransform -from dateutil.relativedelta import relativedelta -from graphene_django import DjangoObjectType, DjangoListField +from entry.models import Entry +from entry.schema import Query as EntryQuery +from export.schema import ProjectQuery as ExportQuery +from geo.models import Region +from geo.schema import ProjectScopeQuery as GeoQuery +from geo.schema import RegionDetailType from graphene.types import generic +from graphene_django import DjangoListField, DjangoObjectType from graphene_django_extras import DjangoObjectField, PageGraphqlPagination +from lead.filter_set import LeadsFilterDataInputType +from lead.models import Lead +from lead.schema import Query as LeadQuery +from quality_assurance.schema import Query as QualityAssuranceQuery +from unified_connector.schema import UnifiedConnectorQueryType +from user.models import User +from user.schema import UserType +from user_resource.schema import UserResourceMixin - -from utils.graphene.geo_scalars import PointScalar +from deep.permissions import ProjectPermissions as PP +from deep.serializers import URLCachedFileField +from deep.trackers import TrackerAction, track_project from utils.graphene.enums import EnumDescription +from utils.graphene.fields import DjangoPaginatedListObjectField +from utils.graphene.geo_scalars import PointScalar from utils.graphene.pagination import NoOrderingPageGraphqlPagination from utils.graphene.types import ( - CustomDjangoListObjectType, ClientIdMixin, + CustomDjangoListObjectType, DateCountType, UserEntityCountType, UserEntityDateType, ) -from utils.graphene.fields import ( - DjangoPaginatedListObjectField, -) -from deep.permissions import ProjectPermissions as PP -from deep.serializers import URLCachedFileField -from deep.trackers import TrackerAction, track_project -from user_resource.schema import UserResourceMixin -from user.models import User -from user.schema import UserType -from lead.schema import Query as LeadQuery -from entry.schema import Query as EntryQuery -from export.schema import ProjectQuery as ExportQuery -from geo.schema import RegionDetailType, ProjectScopeQuery as GeoQuery -from quality_assurance.schema import Query as QualityAssuranceQuery -from ary.schema import Query as AryQuery -from analysis.schema import Query as AnalysisQuery -from assessment_registry.schema import ProjectQuery as AssessmentRegistryQuery -from unified_connector.schema import UnifiedConnectorQueryType -from assisted_tagging.schema import AssistedTaggingQueryType -from assessment_registry.dashboard_schema import Query as AssessmentRegistryDashboardQuery -from lead.models import Lead -from entry.models import Entry -from geo.models import Region - -from lead.filter_set import LeadsFilterDataInputType - -from .models import ( - Project, - ProjectRole, - ProjectMembership, - ProjectUserGroupMembership, - ProjectJoinRequest, - ProjectOrganization, - ProjectStats, - RecentActivityType as ActivityTypes, - ProjectPinned -) +from .activity import project_activity_log from .enums import ( - ProjectPermissionEnum, - ProjectStatusEnum, - ProjectRoleTypeEnum, ProjectJoinRequestStatusEnum, - ProjectOrganizationTypeEnum, ProjectMembershipBadgeTypeEnum, + ProjectOrganizationTypeEnum, + ProjectPermissionEnum, + ProjectRoleTypeEnum, + ProjectStatusEnum, RecentActivityTypeEnum, ) - from .filter_set import ( + ProjectByRegionGqlFilterSet, ProjectGqlFilterSet, ProjectMembershipGqlFilterSet, ProjectUserGroupMembershipGqlFilterSet, - ProjectByRegionGqlFilterSet, PublicProjectByRegionGqlFileterSet, ) -from .activity import project_activity_log -from .tasks import generate_viz_stats, get_project_stats +from .models import ( + Project, + ProjectJoinRequest, + ProjectMembership, + ProjectOrganization, + ProjectPinned, + ProjectRole, + ProjectStats, + ProjectUserGroupMembership, +) +from .models import RecentActivityType as ActivityTypes from .public_schema import PublicProjectListType +from .tasks import generate_viz_stats, get_project_stats def get_recent_active_users(project, max_users=3): # id, date users_activity = project.get_recent_active_users_id_and_date(max_users=max_users) - recent_active_users_map = { - user.pk: user - for user in User.objects.filter(pk__in=[id for id, _ in users_activity]) - } - recent_active_users = [ - (recent_active_users_map[id], date) - for id, date in users_activity - if id in recent_active_users_map - ] + recent_active_users_map = {user.pk: user for user in User.objects.filter(pk__in=[id for id, _ in users_activity])} + recent_active_users = [(recent_active_users_map[id], date) for id, date in users_activity if id in recent_active_users_map] return [ { - 'user_id': user.id, - 'name': user.get_display_name(), - 'date': date, - } for user, date in recent_active_users + "user_id": user.id, + "name": user.get_display_name(), + "date": date, + } + for user, date in recent_active_users ] def get_top_entity_contributor(project, Entity): - contributors = ProjectMembership.objects.filter( - project=project, - ).annotate( - entity_count=models.functions.Coalesce(models.Subquery( - Entity.objects.filter( - project=project, - created_by=models.OuterRef('member'), - ).order_by().values('project') - .annotate(cnt=models.Count('*')).values('cnt')[:1], - output_field=models.IntegerField(), - ), 0), - ).order_by('-entity_count').select_related('member')[:5] + contributors = ( + ProjectMembership.objects.filter( + project=project, + ) + .annotate( + entity_count=models.functions.Coalesce( + models.Subquery( + Entity.objects.filter( + project=project, + created_by=models.OuterRef("member"), + ) + .order_by() + .values("project") + .annotate(cnt=models.Count("*")) + .values("cnt")[:1], + output_field=models.IntegerField(), + ), + 0, + ), + ) + .order_by("-entity_count") + .select_related("member")[:5] + ) return [ { - 'name': contributor.member.get_display_name(), - 'user_id': contributor.member.id, - 'count': contributor.entity_count, - } for contributor in contributors + "name": contributor.member.get_display_name(), + "user_id": contributor.member.id, + "count": contributor.entity_count, + } + for contributor in contributors ] def get_project_stats_summary(self): - projects = Project.get_for_member(self.context.request.user).only('id') + projects = Project.get_for_member(self.context.request.user).only("id") # Lead stats leads = Lead.objects.filter(project__in=projects) - total_leads_tagged_count = ( - leads - .annotate(entries_count=models.Count('entry')) - .filter(entries_count__gt=0).count() - ) + total_leads_tagged_count = leads.annotate(entries_count=models.Count("entry")).filter(entries_count__gt=0).count() total_leads_tagged_and_controlled_count = ( leads.annotate( - entries_count=models.Count('entry'), + entries_count=models.Count("entry"), controlled_entries_count=models.Count( - 'entry', + "entry", filter=models.Q(entry__controlled=True), - ) - ).filter( + ), + ) + .filter( entries_count__gt=0, - entries_count=models.F('controlled_entries_count'), - ).count() + entries_count=models.F("controlled_entries_count"), + ) + .count() ) # Entries activity recent_projects_id = list( - projects.annotate( - entries_count=Cast( - KeyTextTransform('entries_activity', 'stats_cache'), - models.IntegerField() - ) - ) + projects.annotate(entries_count=Cast(KeyTextTransform("entries_activity", "stats_cache"), models.IntegerField())) .filter(entries_count__gt=0) - .order_by('-entries_count') - .values_list('id', flat=True)[:3] + .order_by("-entries_count") + .values_list("id", flat=True)[:3] ) recent_entries = Entry.objects.filter( - project__in=recent_projects_id, - created_at__gte=(timezone.now() + relativedelta(months=-3)) + project__in=recent_projects_id, created_at__gte=(timezone.now() + relativedelta(months=-3)) ) recent_entries_activity = ( - recent_entries - .order_by('created_at__date') - .values('created_at__date').annotate( - count=models.Count('*') - ) - .values( - 'project_id', - 'count', - date=models.Func(models.F('created_at__date'), function='DATE') - ) + recent_entries.order_by("created_at__date") + .values("created_at__date") + .annotate(count=models.Count("*")) + .values("project_id", "count", date=models.Func(models.F("created_at__date"), function="DATE")) ) recent_entries_project_details = ( - recent_entries - .order_by() - .values('project') - .annotate(count=models.Count('*')) + recent_entries.order_by() + .values("project") + .annotate(count=models.Count("*")) .filter(count__gt=0) - .values( - 'count', - id=models.F('project'), - title=models.F('project__title') - ) + .values("count", id=models.F("project"), title=models.F("project__title")) ) return { - 'projects_count': projects.count(), - 'total_leads_count': leads.count(), - 'total_leads_tagged_count': total_leads_tagged_count, - 'total_leads_tagged_and_controlled_count': total_leads_tagged_and_controlled_count, - 'recent_entries_project_details': recent_entries_project_details, - 'recent_entries_activities': recent_entries_activity + "projects_count": projects.count(), + "total_leads_count": leads.count(), + "total_leads_tagged_count": total_leads_tagged_count, + "total_leads_tagged_and_controlled_count": total_leads_tagged_and_controlled_count, + "recent_entries_project_details": recent_entries_project_details, + "recent_entries_activities": recent_entries_activity, } @@ -212,22 +196,30 @@ class ProjectExploreStatType(graphene.ObjectType): generated_exports_monthly = graphene.Int() top_active_projects = graphene.List( graphene.NonNull( - type('ExploreProjectStatTopActiveProjectsType', (graphene.ObjectType,), { - 'project_id': graphene.Field(graphene.NonNull(graphene.ID)), - 'project_title': graphene.String(), - 'analysis_framework_id': graphene.ID(), - 'analysis_framework_title': graphene.String(), - }) + type( + "ExploreProjectStatTopActiveProjectsType", + (graphene.ObjectType,), + { + "project_id": graphene.Field(graphene.NonNull(graphene.ID)), + "project_title": graphene.String(), + "analysis_framework_id": graphene.ID(), + "analysis_framework_title": graphene.String(), + }, + ) ) ) top_active_frameworks = graphene.List( graphene.NonNull( - type('ExploreProjectStatTopActiveFrameworksType', (graphene.ObjectType,), { - 'analysis_framework_id': graphene.Field(graphene.NonNull(graphene.ID)), - 'analysis_framework_title': graphene.String(), - 'project_count': graphene.NonNull(graphene.Int), - 'source_count': graphene.NonNull(graphene.Int) - }) + type( + "ExploreProjectStatTopActiveFrameworksType", + (graphene.ObjectType,), + { + "analysis_framework_id": graphene.Field(graphene.NonNull(graphene.ID)), + "analysis_framework_title": graphene.String(), + "project_count": graphene.NonNull(graphene.Int), + "source_count": graphene.NonNull(graphene.Int), + }, + ) ) ) @@ -255,20 +247,24 @@ class ProjectStatType(graphene.ObjectType): @staticmethod def resolve_leads_activity(root, info, **kwargs): - return (root.stats_cache or {}).get('leads_activities') or [] + return (root.stats_cache or {}).get("leads_activities") or [] @staticmethod def resolve_entries_activity(root, info, **kwargs): - return (root.stats_cache or {}).get('entries_activities') or [] + return (root.stats_cache or {}).get("entries_activities") or [] class ProjectOrganizationType(DjangoObjectType, UserResourceMixin, ClientIdMixin): class Meta: model = ProjectOrganization - only_fields = ('id', 'client_id', 'organization',) + only_fields = ( + "id", + "client_id", + "organization", + ) organization_type = graphene.Field(ProjectOrganizationTypeEnum, required=True) - organization_type_display = EnumDescription(source='get_organization_type_display', required=True) + organization_type_display = EnumDescription(source="get_organization_type_display", required=True) @staticmethod def resolve_organization(root, info): @@ -278,7 +274,7 @@ def resolve_organization(root, info): class ProjectRoleType(DjangoObjectType): class Meta: model = ProjectRole - only_fields = ('id', 'title', 'level') + only_fields = ("id", "title", "level") type = graphene.Field(ProjectRoleTypeEnum, required=True) @@ -287,8 +283,12 @@ class ProjectMembershipType(ClientIdMixin, DjangoObjectType): class Meta: model = ProjectMembership only_fields = ( - 'id', 'member', 'linked_group', - 'role', 'joined_at', 'added_by', + "id", + "member", + "linked_group", + "role", + "joined_at", + "added_by", ) badges = graphene.List(graphene.NonNull(ProjectMembershipBadgeTypeEnum)) @@ -298,8 +298,11 @@ class ProjectUserGroupMembershipType(ClientIdMixin, DjangoObjectType): class Meta: model = ProjectUserGroupMembership only_fields = ( - 'id', 'usergroup', - 'role', 'joined_at', 'added_by', + "id", + "usergroup", + "role", + "joined_at", + "added_by", ) badges = graphene.List(graphene.NonNull(ProjectMembershipBadgeTypeEnum)) @@ -309,26 +312,37 @@ class ProjectType(UserResourceMixin, DjangoObjectType): class Meta: model = Project only_fields = ( - 'id', 'title', 'description', 'start_date', 'end_date', - 'analysis_framework', 'assessment_template', - 'is_default', 'is_private', 'is_test', 'is_visualization_enabled', - 'is_assessment_enabled', - 'created_at', 'created_by', - 'modified_at', 'modified_by', + "id", + "title", + "description", + "start_date", + "end_date", + "analysis_framework", + "assessment_template", + "is_default", + "is_private", + "is_test", + "is_visualization_enabled", + "is_assessment_enabled", + "created_at", + "created_by", + "modified_at", + "modified_by", ) current_user_role = graphene.Field(ProjectRoleTypeEnum) allowed_permissions = graphene.List( graphene.NonNull( ProjectPermissionEnum, - ), required=True + ), + required=True, ) stats = graphene.Field(ProjectStatType) membership_pending = graphene.Boolean(required=True) is_rejected = graphene.Boolean(required=True) regions = DjangoListField(RegionDetailType) status = graphene.Field(ProjectStatusEnum, required=True) - status_display = EnumDescription(source='get_status_display', required=True) + status_display = EnumDescription(source="get_status_display", required=True) organizations = graphene.List(graphene.NonNull(ProjectOrganizationType)) has_analysis_framework = graphene.Boolean(required=True) has_assessment_template = graphene.Boolean(required=True) @@ -381,10 +395,7 @@ def resolve_regions(root, info, **kwargs): return info.context.dl.project.public_geo_region.load(root.pk) def resolve_is_project_pinned(root, info, **kwargs): - return ProjectPinned.objects.filter( - project=root, - user=info.context.request.user - ).exists() + return ProjectPinned.objects.filter(project=root, user=info.context.request.user).exists() class RecentActivityType(graphene.ObjectType): @@ -398,31 +409,27 @@ class RecentActivityType(graphene.ObjectType): entry_id = graphene.ID() def resolve_created_by(root, info, **kwargs): - id = int(root['created_by']) + id = int(root["created_by"]) return info.context.dl.project.users.load(id) def resolve_project(root, info, **kwargs): - id = int(root['project']) + id = int(root["project"]) return info.context.dl.project.projects.load(id) def resolve_type_display(root, info, **kwargs): - return ActivityTypes(root['type']).label + return ActivityTypes(root["type"]).label def resolve_entry_id(root, info, **kwargs): - if root['type'] == ActivityTypes.LEAD: + if root["type"] == ActivityTypes.LEAD: return - return root['entry_id'] + return root["entry_id"] class AnalysisFrameworkVisibleProjectType(DjangoObjectType): class Meta: model = Project skip_registry = True - only_fields = ( - 'id', - 'title', - 'is_private' - ) + only_fields = ("id", "title", "is_private") class ProjectMembershipListType(CustomDjangoListObjectType): @@ -441,9 +448,9 @@ class ProjectVizDataType(DjangoObjectType): class Meta: model = ProjectStats only_fields = ( - 'modified_at', - 'status', - 'public_share', + "modified_at", + "status", + "public_share", ) data_url = graphene.String() @@ -457,7 +464,7 @@ def resolve_status(root, info, **_): # NOTE: Not changing modified_at if already pending if root.status != ProjectStats.Status.PENDING: root.status = ProjectStats.Status.PENDING - root.save(update_fields=('status',)) + root.save(update_fields=("status",)) return root.status @staticmethod @@ -493,16 +500,28 @@ class Meta: model = Project skip_registry = True only_fields = ( - 'id', 'title', 'description', 'start_date', 'end_date', 'analysis_framework', - 'category_editor', 'assessment_template', 'data', - 'created_at', 'created_by', - 'modified_at', 'modified_by', - 'is_default', 'is_private', 'is_test', 'is_visualization_enabled', - 'is_assessment_enabled', - 'has_publicly_viewable_unprotected_leads', - 'has_publicly_viewable_restricted_leads', - 'has_publicly_viewable_confidential_leads', - 'enable_publicly_viewable_analysis_report_snapshot', + "id", + "title", + "description", + "start_date", + "end_date", + "analysis_framework", + "category_editor", + "assessment_template", + "data", + "created_at", + "created_by", + "modified_at", + "modified_by", + "is_default", + "is_private", + "is_test", + "is_visualization_enabled", + "is_assessment_enabled", + "has_publicly_viewable_unprotected_leads", + "has_publicly_viewable_restricted_leads", + "has_publicly_viewable_confidential_leads", + "enable_publicly_viewable_analysis_report_snapshot", ) analysis_framework = graphene.Field(AnalysisFrameworkDetailType) @@ -512,20 +531,14 @@ class Meta: top_taggers = graphene.List(graphene.NonNull(UserEntityCountType)) user_members = DjangoPaginatedListObjectField( - ProjectMembershipListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + ProjectMembershipListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) user_group_members = DjangoPaginatedListObjectField( - ProjectUserGroupMembershipListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + ProjectUserGroupMembershipListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) is_visualization_available = graphene.Boolean( required=True, - description='Checks if visualization is enabled and analysis framework is configured.', + description="Checks if visualization is enabled and analysis framework is configured.", ) stats = graphene.Field( ProjectStatType, @@ -535,10 +548,7 @@ class Meta: # Other scoped queries unified_connector = graphene.Field(UnifiedConnectorQueryType) assisted_tagging = graphene.Field(AssistedTaggingQueryType) - is_project_pinned = graphene.Boolean( - required=True, - description='Check if user have pinned the project' - ) + is_project_pinned = graphene.Boolean(required=True, description="Check if user have pinned the project") @staticmethod def resolve_user_members(root, info, **kwargs): @@ -590,27 +600,25 @@ def resolve_assisted_tagging(root, info, **kwargs): @staticmethod def resolve_is_project_pinned(root, info, **kwargs): - return ProjectPinned.objects.filter( - project=root, - user=info.context.request.user - ).exists() + return ProjectPinned.objects.filter(project=root, user=info.context.request.user).exists() class UserPinnedProjectType(ClientIdMixin, DjangoObjectType): class Meta: model = ProjectPinned only_fields = ( - 'id', + "id", "project", "user", "order", "client_id", ) + project = graphene.Field(graphene.NonNull(ProjectDetailType)) class ProjectByRegion(graphene.ObjectType): - id = graphene.ID(required=True, description='Region\'s ID') + id = graphene.ID(required=True, description="Region's ID") # NOTE: Annotated by ProjectByRegionGqlFilterSet/PublicProjectByRegionGqlFileterSet projects_id = graphene.List(graphene.NonNull(graphene.ID)) centroid = PointScalar() @@ -620,11 +628,11 @@ class ProjectJoinRequestType(DjangoObjectType): class Meta: model = ProjectJoinRequest only_fields = ( - 'id', - 'data', - 'requested_by', - 'responded_by', - 'project', + "id", + "data", + "requested_by", + "responded_by", + "project", ) status = graphene.Field(ProjectJoinRequestStatusEnum, required=True) @@ -635,8 +643,10 @@ class Meta: model = Region skip_registry = True only_fields = ( - 'id', 'centroid', + "id", + "centroid", ) + # NOTE: Annotated by ProjectByRegionGqlFilterSet/PublicProjectByRegionGqlFileterSet projects_id = graphene.List(graphene.NonNull(graphene.ID)) @@ -685,10 +695,7 @@ class UserProjectSummaryStatType(graphene.ObjectType): class Query: project = DjangoObjectField(ProjectDetailType) projects = DjangoPaginatedListObjectField( - ProjectListType, - pagination=NoOrderingPageGraphqlPagination( - page_size_query_param='pageSize' - ) + ProjectListType, pagination=NoOrderingPageGraphqlPagination(page_size_query_param="pageSize") ) recent_projects = graphene.List(graphene.NonNull(ProjectDetailType)) recent_activities = graphene.List(graphene.NonNull(RecentActivityType)) @@ -700,16 +707,10 @@ class Query: # PUBLIC NODES public_projects = DjangoPaginatedListObjectField( - PublicProjectListType, - pagination=NoOrderingPageGraphqlPagination( - page_size_query_param='pageSize' - ) + PublicProjectListType, pagination=NoOrderingPageGraphqlPagination(page_size_query_param="pageSize") ) public_projects_by_region = DjangoPaginatedListObjectField( - PublicProjectByRegionListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + PublicProjectByRegionListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) user_pinned_projects = DjangoListField(UserPinnedProjectType, required=True) user_project_stat_summary = graphene.Field(UserProjectSummaryStatType, required=True) @@ -733,9 +734,7 @@ def resolve_recent_projects(root, info, **kwargs) -> QuerySet: @staticmethod def resolve_projects_by_region(root, info, **kwargs): - return Region.objects\ - .filter(centroid__isnull=False)\ - .order_by('centroid') + return Region.objects.filter(centroid__isnull=False).order_by("centroid") @staticmethod def resolve_project_explore_stats(root, info, **kwargs): diff --git a/apps/project/serializers.py b/apps/project/serializers.py index 4e50d4a555..30c08191c0 100644 --- a/apps/project/serializers.py +++ b/apps/project/serializers.py @@ -1,74 +1,68 @@ -from django.db import models +from analysis_framework.models import AnalysisFrameworkMembership +from ary.models import AssessmentTemplate +from django.db import models, transaction +from django.shortcuts import get_object_or_404 from django.utils import timezone from django.utils.functional import cached_property -from django.shortcuts import get_object_or_404 -from django.db import transaction from django.utils.translation import gettext - from drf_dynamic_fields import DynamicFieldsMixin -from rest_framework import serializers -from rest_framework.exceptions import PermissionDenied - -from deep.permissions import AnalysisFrameworkPermissions as AfP -from deep.serializers import ( - RemoveNullFieldsMixin, - URLCachedFileField, - IntegerIDField, - TempClientIdMixin, - ProjectPropertySerializerMixin -) +from entry.models import Entry, Lead from geo.models import Region from geo.serializers import SimpleRegionSerializer -from entry.models import Lead, Entry -from analysis_framework.models import AnalysisFrameworkMembership +from organization.serializers import SimpleOrganizationSerializer +from rest_framework import serializers +from rest_framework.exceptions import PermissionDenied from user.models import Feature from user.serializers import SimpleUserSerializer -from user_group.models import UserGroup from user.utils import ( - send_project_join_request_emails, send_project_accept_email, - send_project_reject_email + send_project_join_request_emails, + send_project_reject_email, ) +from user_group.models import UserGroup from user_group.serializers import SimpleUserGroupSerializer -from user_resource.serializers import UserResourceSerializer, DeprecatedUserResourceSerializer -from ary.models import AssessmentTemplate +from user_resource.serializers import ( + DeprecatedUserResourceSerializer, + UserResourceSerializer, +) +from deep.permissions import AnalysisFrameworkPermissions as AfP +from deep.serializers import ( + IntegerIDField, + ProjectPropertySerializerMixin, + RemoveNullFieldsMixin, + TempClientIdMixin, + URLCachedFileField, +) + +from .activity import project_activity_log from .change_log import ProjectChangeManager from .models import ( Project, - ProjectMembership, ProjectJoinRequest, + ProjectMembership, + ProjectOrganization, ProjectPinned, ProjectRole, - ProjectUserGroupMembership, - ProjectOrganization, ProjectStats, + ProjectUserGroupMembership, ) - -from organization.serializers import ( - SimpleOrganizationSerializer -) - from .permissions import PROJECT_PERMISSIONS -from .activity import project_activity_log -class SimpleProjectSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): +class SimpleProjectSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): class Meta: model = Project - fields = ('id', 'title', 'is_private') + fields = ("id", "title", "is_private") class ProjectNotificationSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): class Meta: model = Project - fields = ('id', 'title') + fields = ("id", "title") -class ProjectRoleSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, - serializers.ModelSerializer): +class ProjectRoleSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): lead_permissions = serializers.SerializerMethodField() entry_permissions = serializers.SerializerMethodField() setup_permissions = serializers.SerializerMethodField() @@ -77,156 +71,127 @@ class ProjectRoleSerializer(RemoveNullFieldsMixin, class Meta: model = ProjectRole - fields = '__all__' + fields = "__all__" def get_lead_permissions(self, roleobj): - return [ - k - for k, v in PROJECT_PERMISSIONS['lead'].items() - if roleobj.lead_permissions & v != 0 - ] + return [k for k, v in PROJECT_PERMISSIONS["lead"].items() if roleobj.lead_permissions & v != 0] def get_entry_permissions(self, roleobj): - return [ - k - for k, v in PROJECT_PERMISSIONS['entry'].items() - if roleobj.entry_permissions & v != 0 - ] + return [k for k, v in PROJECT_PERMISSIONS["entry"].items() if roleobj.entry_permissions & v != 0] def get_setup_permissions(self, roleobj): - return [ - k - for k, v in PROJECT_PERMISSIONS['setup'].items() - if roleobj.setup_permissions & v != 0 - ] + return [k for k, v in PROJECT_PERMISSIONS["setup"].items() if roleobj.setup_permissions & v != 0] def get_export_permissions(self, roleobj): - return [ - k - for k, v in PROJECT_PERMISSIONS['export'].items() - if roleobj.export_permissions & v != 0 - ] + return [k for k, v in PROJECT_PERMISSIONS["export"].items() if roleobj.export_permissions & v != 0] def get_assessment_permissions(self, roleobj): - return [ - k - for k, v in PROJECT_PERMISSIONS['assessment'].items() - if roleobj.assessment_permissions & v != 0 - ] + return [k for k, v in PROJECT_PERMISSIONS["assessment"].items() if roleobj.assessment_permissions & v != 0] -class SimpleProjectRoleSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, - serializers.ModelSerializer): +class SimpleProjectRoleSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): class Meta: model = ProjectRole - fields = ('id', 'title', 'level') + fields = ("id", "title", "level") -class ProjectOrganizationSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, - UserResourceSerializer, - serializers.ModelSerializer): - organization_type_display = serializers.CharField(source='get_organization_type_display', read_only=True) - organization_details = SimpleOrganizationSerializer(source='organization', read_only=True) +class ProjectOrganizationSerializer( + RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer, serializers.ModelSerializer +): + organization_type_display = serializers.CharField(source="get_organization_type_display", read_only=True) + organization_details = SimpleOrganizationSerializer(source="organization", read_only=True) class Meta: model = ProjectOrganization - fields = ('id', 'organization', 'organization_details', 'organization_type', 'organization_type_display') + fields = ("id", "organization", "organization_details", "organization_type", "organization_type_display") -class ProjectMembershipSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, - serializers.ModelSerializer): - member_email = serializers.CharField(source='member.email', read_only=True) - member_name = serializers.CharField( - source='member.profile.get_display_name', read_only=True) +class ProjectMembershipSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): + member_email = serializers.CharField(source="member.email", read_only=True) + member_name = serializers.CharField(source="member.profile.get_display_name", read_only=True) added_by_name = serializers.CharField( - source='added_by.profile.get_display_name', + source="added_by.profile.get_display_name", read_only=True, ) member_status = serializers.SerializerMethodField() member_organization = serializers.CharField( - source='member.profile.organization', + source="member.profile.organization", read_only=True, ) user_group_options = SimpleUserGroupSerializer( - source='get_user_group_options', + source="get_user_group_options", read_only=True, many=True, ) - role_details = SimpleProjectRoleSerializer(source='role', read_only=True) + role_details = SimpleProjectRoleSerializer(source="role", read_only=True) class Meta: model = ProjectMembership - fields = '__all__' - read_only_fields = ('project',) + fields = "__all__" + read_only_fields = ("project",) def get_member_status(self, membership): - if ProjectRole.get_admin_roles().filter( - id=membership.role.id - ).exists(): - return 'admin' - return 'member' + if ProjectRole.get_admin_roles().filter(id=membership.role.id).exists(): + return "admin" + return "member" # Validations def validate_project(self, project): - if not project.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid project') + if not project.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid project") return project def project_member_validation(self, project, member): - if ProjectMembership.objects.filter( - project=project, - member=member - ).exists(): - raise serializers.ValidationError({'member': 'Member already exist'}) + if ProjectMembership.objects.filter(project=project, member=member).exists(): + raise serializers.ValidationError({"member": "Member already exist"}) def validate(self, data): - data['project_id'] = int(self.context['view'].kwargs['project_id']) - member = data.get('member') + data["project_id"] = int(self.context["view"].kwargs["project_id"]) + member = data.get("member") if not self.instance: - self.project_member_validation(data['project_id'], member) - role = data.get('role') + self.project_member_validation(data["project_id"], member) + role = data.get("role") if not role: return data - user = self.context['request'].user - user_role = ProjectMembership.objects.filter( - project=data['project_id'], - member=user, - ).first().role + user = self.context["request"].user + user_role = ( + ProjectMembership.objects.filter( + project=data["project_id"], + member=user, + ) + .first() + .role + ) if role.level < user_role.level: - raise serializers.ValidationError('Invalid role') + raise serializers.ValidationError("Invalid role") return data def create(self, validated_data): resource = super().create(validated_data) - resource.added_by = self.context['request'].user + resource.added_by = self.context["request"].user resource.save() return resource -class ProjectUsergroupMembershipSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, - serializers.ModelSerializer): - group_title = serializers.CharField(source='usergroup.title') +class ProjectUsergroupMembershipSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): + group_title = serializers.CharField(source="usergroup.title") class Meta: model = ProjectUserGroupMembership - fields = '__all__' + fields = "__all__" def get_unique_together_validators(self): return [] # Validations def validate_project(self, project): - if not project.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid project') + if not project.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid project") return project def create(self, validated_data): resource = super().create(validated_data) - resource.added_by = self.context['request'].user + resource.added_by = self.context["request"].user resource.save() return resource @@ -234,7 +199,7 @@ def create(self, validated_data): class ProjectSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, DeprecatedUserResourceSerializer): organizations = ProjectOrganizationSerializer( - source='projectorganization_set', + source="projectorganization_set", many=True, ) @@ -244,15 +209,15 @@ class ProjectSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, DeprecatedUse member_status = serializers.SerializerMethodField() analysis_framework_title = serializers.CharField( - source='analysis_framework.title', + source="analysis_framework.title", read_only=True, ) assessment_template_title = serializers.CharField( - source='assessment_template.title', + source="assessment_template.title", read_only=True, ) category_editor_title = serializers.CharField( - source='category_editor.title', + source="category_editor.title", read_only=True, ) @@ -261,24 +226,20 @@ class ProjectSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, DeprecatedUse number_of_users = serializers.IntegerField(read_only=True) is_visualization_enabled = serializers.SerializerMethodField(read_only=True) has_assessments = serializers.BooleanField(required=False) - status_display = serializers.CharField(source='get_status_display', read_only=True) + status_display = serializers.CharField(source="get_status_display", read_only=True) class Meta: model = Project - exclude = ('members', 'stats_cache') + exclude = ("members", "stats_cache") def create(self, validated_data): - member = self.context['request'].user - is_private = validated_data.get('is_private', False) + member = self.context["request"].user + is_private = validated_data.get("is_private", False) - private_access = member.profile.get_accessible_features().filter( - key=Feature.FeatureKey.PRIVATE_PROJECT - ).exists() + private_access = member.profile.get_accessible_features().filter(key=Feature.FeatureKey.PRIVATE_PROJECT).exists() if is_private and not private_access: - raise PermissionDenied( - {'message': "You don't have permission to create private project"} - ) + raise PermissionDenied({"message": "You don't have permission to create private project"}) project = super().create(validated_data) ProjectMembership.objects.create( @@ -292,18 +253,17 @@ def create(self, validated_data): def update(self, instance, validated_data): # TODO; might need to check for private project feature access, # But that might be redundant, since checked in creation, I don't know - framework = validated_data.get('analysis_framework') - user = self.context['request'].user + framework = validated_data.get("analysis_framework") + user = self.context["request"].user - if 'is_private' in validated_data and\ - validated_data['is_private'] != instance.is_private: - raise PermissionDenied('Cannot change privacy of project') + if "is_private" in validated_data and validated_data["is_private"] != instance.is_private: + raise PermissionDenied("Cannot change privacy of project") if framework is None or not framework.is_private: return super().update(instance, validated_data) if not instance.is_private and framework.is_private: - raise PermissionDenied('Cannot use private framework in public project') + raise PermissionDenied("Cannot use private framework in public project") memberships = AnalysisFrameworkMembership.objects.filter( framework=framework, @@ -311,104 +271,88 @@ def update(self, instance, validated_data): ) if not memberships.exists(): # Send a bad request, use should not know if the framework exists - raise serializers.ValidationError('Invalid Analysis Framework') + raise serializers.ValidationError("Invalid Analysis Framework") if memberships.filter(role__can_use_in_other_projects=True).exists(): return super().update(instance, validated_data) - raise PermissionDenied( - {'message': "You don't have permissions to use the analysis framework in the project"} - ) + raise PermissionDenied({"message": "You don't have permissions to use the analysis framework in the project"}) def validate(self, data): - has_assessments = data.pop('has_assessments', None) + has_assessments = data.pop("has_assessments", None) if has_assessments is not None: - data['assessment_template'] = AssessmentTemplate.objects.first() if has_assessments else None + data["assessment_template"] = AssessmentTemplate.objects.first() if has_assessments else None return data def get_is_visualization_enabled(self, project): af = project.analysis_framework is_viz_enabled = project.is_visualization_enabled - entry_viz_enabled = ( - is_viz_enabled and - af.properties is not None and - af.properties.get('stats_config') is not None - ) + entry_viz_enabled = is_viz_enabled and af.properties is not None and af.properties.get("stats_config") is not None # Entry viz data is required by ARY VIZ ary_viz_enabled = entry_viz_enabled return { - 'entry': entry_viz_enabled, - 'assessment': ary_viz_enabled, + "entry": entry_viz_enabled, + "assessment": ary_viz_enabled, } def get_member_status(self, project): - request = self.context['request'] - user = request.GET.get('user', request.user) + request = self.context["request"] + user = request.GET.get("user", request.user) role = project.get_role(user) if role: if ProjectRole.get_admin_roles().filter(id=role.id).exists(): - return 'admin' - return 'member' + return "admin" + return "member" join_request = ProjectJoinRequest.objects.filter( project=project, requested_by=user, ).first() - if join_request and ( - join_request.status == 'pending' or - join_request.status == 'rejected' - ): + if join_request and (join_request.status == "pending" or join_request.status == "rejected"): return join_request.status - return 'none' + return "none" def get_role(self, project): - request = self.context['request'] - user = request.GET.get('user', request.user) + request = self.context["request"] + user = request.GET.get("user", request.user) - membership = ProjectMembership.objects.filter( - project=project, - member=user - ).first() + membership = ProjectMembership.objects.filter(project=project, member=user).first() if membership: return membership.role.id return None # Validations def validate_user_groups(self, user_groups): - for user_group_obj in self.initial_data['user_groups']: - user_group = UserGroup.objects.get(id=user_group_obj['id']) + for user_group_obj in self.initial_data["user_groups"]: + user_group = UserGroup.objects.get(id=user_group_obj["id"]) if self.instance and user_group in self.instance.user_groups.all(): continue - if not user_group.can_modify(self.context['request'].user): - raise serializers.ValidationError( - 'Invalid user group: {}'.format(user_group.id)) + if not user_group.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid user group: {}".format(user_group.id)) return user_groups def validate_regions(self, data): - for region_obj in self.initial_data['regions']: - region = Region.objects.get(id=region_obj.get('id')) + for region_obj in self.initial_data["regions"]: + region = Region.objects.get(id=region_obj.get("id")) if self.instance and region in self.instance.regions.all(): continue - if not region.public and \ - not region.can_modify(self.context['request'].user): - raise serializers.ValidationError( - 'Invalid region: {}'.format(region.id)) + if not region.public and not region.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid region: {}".format(region.id)) return data def validate_analysis_framework(self, analysis_framework): - if not analysis_framework.can_get(self.context['request'].user): - raise serializers.ValidationError( - 'Invalid analysis framework: {}'.format(analysis_framework.id)) + if not analysis_framework.can_get(self.context["request"].user): + raise serializers.ValidationError("Invalid analysis framework: {}".format(analysis_framework.id)) return analysis_framework class ProjectMemberViewSerializer(ProjectSerializer): memberships = ProjectMembershipSerializer( - source='projectmembership_set', + source="projectmembership_set", many=True, read_only=True, ) @@ -420,34 +364,46 @@ class ProjectStatSerializer(ProjectSerializer): number_of_leads_tagged_and_controlled = serializers.IntegerField(read_only=True) number_of_entries = serializers.IntegerField(read_only=True) - leads_activity = serializers.ReadOnlyField(source='get_leads_activity') - entries_activity = serializers.ReadOnlyField(source='get_entries_activity') + leads_activity = serializers.ReadOnlyField(source="get_leads_activity") + entries_activity = serializers.ReadOnlyField(source="get_entries_activity") top_sourcers = serializers.SerializerMethodField() top_taggers = serializers.SerializerMethodField() activity_log = serializers.SerializerMethodField() def _get_top_entity_contributer(self, project, Entity): - contributers = ProjectMembership.objects.filter( - project=project, - ).annotate( - entity_count=models.functions.Coalesce(models.Subquery( - Entity.objects.filter( - project=project, - created_by=models.OuterRef('member'), - ).order_by().values('project') - .annotate(cnt=models.Count('*')).values('cnt')[:1], - output_field=models.IntegerField(), - ), 0), - ).order_by('-entity_count').select_related('member', 'member__profile')[:5] + contributers = ( + ProjectMembership.objects.filter( + project=project, + ) + .annotate( + entity_count=models.functions.Coalesce( + models.Subquery( + Entity.objects.filter( + project=project, + created_by=models.OuterRef("member"), + ) + .order_by() + .values("project") + .annotate(cnt=models.Count("*")) + .values("cnt")[:1], + output_field=models.IntegerField(), + ), + 0, + ), + ) + .order_by("-entity_count") + .select_related("member", "member__profile")[:5] + ) return [ { - 'id': contributer.id, - 'name': contributer.member.profile.get_display_name(), - 'user_id': contributer.member.id, - 'count': contributer.entity_count, - } for contributer in contributers + "id": contributer.id, + "name": contributer.member.profile.get_display_name(), + "user_id": contributer.member.id, + "count": contributer.entity_count, + } + for contributer in contributers ] def get_top_sourcers(self, project): @@ -460,48 +416,45 @@ def get_activity_log(self, project): return list(project_activity_log(project)) -class ProjectJoinRequestSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, - serializers.ModelSerializer): +class ProjectJoinRequestSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): project = SimpleProjectSerializer(read_only=True) requested_by = SimpleUserSerializer(read_only=True) responded_by = SimpleUserSerializer(read_only=True) # `reason` will be stored into json field - reason = serializers.CharField(source='data.reason', required=True) + reason = serializers.CharField(source="data.reason", required=True) class Meta: model = ProjectJoinRequest - fields = '__all__' + fields = "__all__" def create(self, validated_data): - validated_data['project'] = self.context['project'] - validated_data['requested_by'] = self.context['request'].user - validated_data['status'] = 'pending' + validated_data["project"] = self.context["project"] + validated_data["requested_by"] = self.context["request"].user + validated_data["status"] = "pending" return super(ProjectJoinRequestSerializer, self).create(validated_data) class ProjectUserGroupSerializer(serializers.ModelSerializer): - title = serializers.CharField(source='usergroup.title', read_only=True) - role_details = SimpleProjectRoleSerializer(source='role', read_only=True) - added_by_name = serializers.CharField(source='added_by.profile.get_display_name', read_only=True) + title = serializers.CharField(source="usergroup.title", read_only=True) + role_details = SimpleProjectRoleSerializer(source="role", read_only=True) + added_by_name = serializers.CharField(source="added_by.profile.get_display_name", read_only=True) class Meta: model = ProjectUserGroupMembership - fields = '__all__' - read_only_fields = ('project',) + fields = "__all__" + read_only_fields = ("project",) def validate(self, data): - data['project_id'] = int(self.context['view'].kwargs['project_id']) - usergroup = data.get('usergroup') - if usergroup and ProjectUserGroupMembership.objects.filter(project=data['project_id'], - usergroup=usergroup).exists(): - raise serializers.ValidationError({'usergroup': 'Usergroup already exist in the project'}) + data["project_id"] = int(self.context["view"].kwargs["project_id"]) + usergroup = data.get("usergroup") + if usergroup and ProjectUserGroupMembership.objects.filter(project=data["project_id"], usergroup=usergroup).exists(): + raise serializers.ValidationError({"usergroup": "Usergroup already exist in the project"}) return data def create(self, validated_data): project_user_group_membership = super().create(validated_data) - project_user_group_membership.added_by = self.context['request'].user - project_user_group_membership.save(update_fields=['added_by']) + project_user_group_membership.added_by = self.context["request"].user + project_user_group_membership.save(update_fields=["added_by"]) return project_user_group_membership @@ -516,8 +469,8 @@ class ProjectRecentActivitySerializer(serializers.Serializer): created_by_display_name = serializers.CharField() def get_created_by_display_picture(self, instance): - name = instance['created_by_display_picture'] - return name and self.context['request'].build_absolute_uri(URLCachedFileField.name_to_representation(name)) + name = instance["created_by_display_picture"] + return name and self.context["request"].build_absolute_uri(URLCachedFileField.name_to_representation(name)) # -------Graphql Serializer @@ -526,7 +479,7 @@ class ProjectJoinGqSerializer(serializers.ModelSerializer): DESCRIPTION_MAX_LENGTH = 500 project = serializers.CharField(required=True) - reason = serializers.CharField(source='data.reason', required=True) + reason = serializers.CharField(source="data.reason", required=True) role = serializers.CharField(required=False) requested_by = serializers.CharField(read_only=True) responded_by = serializers.CharField(read_only=True) @@ -534,42 +487,33 @@ class ProjectJoinGqSerializer(serializers.ModelSerializer): class Meta: model = ProjectJoinRequest - fields = ( - 'id', - 'reason', - 'role', - 'requested_by', - 'responded_by', - 'project', - 'status', - 'data' - ) + fields = ("id", "reason", "role", "requested_by", "responded_by", "project", "status", "data") def create(self, validated_data): - validated_data['requested_by'] = self.context['request'].user - validated_data['status'] = ProjectJoinRequest.Status.PENDING - validated_data['role_id'] = ProjectRole.get_default_role().id + validated_data["requested_by"] = self.context["request"].user + validated_data["status"] = ProjectJoinRequest.Status.PENDING + validated_data["role_id"] = ProjectRole.get_default_role().id instance = super().create(validated_data) - transaction.on_commit( - lambda: send_project_join_request_emails.delay(instance.id) - ) + transaction.on_commit(lambda: send_project_join_request_emails.delay(instance.id)) return instance def validate_project(self, project): project = get_object_or_404(Project, id=project) if project.is_private: raise serializers.ValidationError("Cannot join private project") - if ProjectMembership.objects.filter(project=project, member=self.context['request'].user).exists(): + if ProjectMembership.objects.filter(project=project, member=self.context["request"].user).exists(): raise serializers.ValidationError("Already a member") - if ProjectJoinRequest.objects.filter(project=project, requested_by=self.context['request'].user).exists(): + if ProjectJoinRequest.objects.filter(project=project, requested_by=self.context["request"].user).exists(): raise serializers.ValidationError("Already sent project join request for project %s" % project.title) return project def validate_reason(self, reason): if not (self.DESCRIPTION_MIN_LENGTH <= len(reason) <= self.DESCRIPTION_MAX_LENGTH): raise serializers.ValidationError( - gettext("Must be at least %s characters and at most %s characters") % ( - self.DESCRIPTION_MIN_LENGTH, self.DESCRIPTION_MAX_LENGTH, + gettext("Must be at least %s characters and at most %s characters") + % ( + self.DESCRIPTION_MIN_LENGTH, + self.DESCRIPTION_MAX_LENGTH, ) ) return reason @@ -580,25 +524,21 @@ class ProjectAcceptRejectSerializer(serializers.ModelSerializer): class Meta: model = ProjectJoinRequest - fields = ( - 'id', - 'status', - 'role' - ) + fields = ("id", "status", "role") @staticmethod def _accept_request(responded_by, join_request, role): - if not role or role == 'normal': + if not role or role == "normal": role = ProjectRole.get_default_role() - elif role == 'admin': + elif role == "admin": role = ProjectRole.get_admin_role() else: role_qs = ProjectRole.objects.filter(id=role) if not role_qs.exists(): - raise serializers.ValidationError('Role doesnot exist') + raise serializers.ValidationError("Role doesnot exist") role = role_qs.first() - join_request.status = 'accepted' + join_request.status = "accepted" join_request.responded_by = responded_by join_request.responded_at = timezone.now() join_request.role = role @@ -608,35 +548,29 @@ def _accept_request(responded_by, join_request, role): project=join_request.project, member=join_request.requested_by, defaults={ - 'role': role, - 'added_by': responded_by, + "role": role, + "added_by": responded_by, }, ) - transaction.on_commit( - lambda: send_project_accept_email.delay(join_request.id) - ) + transaction.on_commit(lambda: send_project_accept_email.delay(join_request.id)) @staticmethod def _reject_request(responded_by, join_request): - join_request.status = 'rejected' + join_request.status = "rejected" join_request.responded_by = responded_by join_request.responded_at = timezone.now() join_request.save() - transaction.on_commit( - lambda: send_project_reject_email.delay(join_request.id) - ) + transaction.on_commit(lambda: send_project_reject_email.delay(join_request.id)) def update(self, instance, validated_data): - validated_data['project'] = self.context['request'].active_project - role = validated_data.pop('role', None) - if instance.status in ['accepted', 'rejected']: - raise serializers.ValidationError( - 'This request has already been {}'.format(instance.status) - ) - if validated_data['status'] == 'accepted': - ProjectAcceptRejectSerializer._accept_request(self.context['request'].user, instance, role) - elif validated_data['status'] == 'rejected': - ProjectAcceptRejectSerializer._reject_request(self.context['request'].user, instance) + validated_data["project"] = self.context["request"].active_project + role = validated_data.pop("role", None) + if instance.status in ["accepted", "rejected"]: + raise serializers.ValidationError("This request has already been {}".format(instance.status)) + if validated_data["status"] == "accepted": + ProjectAcceptRejectSerializer._accept_request(self.context["request"].user, instance, role) + elif validated_data["status"] == "rejected": + ProjectAcceptRejectSerializer._reject_request(self.context["request"].user, instance) return instance @@ -647,63 +581,66 @@ class ProjectMembershipGqlSerializer(TempClientIdMixin, serializers.ModelSeriali class Meta: model = ProjectMembership fields = ( - 'id', 'member', 'role', 'badges', - 'client_id', + "id", + "member", + "role", + "badges", + "client_id", ) @cached_property def project(self): - project = self.context['request'].active_project + project = self.context["request"].active_project # This is a rare case, just to make sure this is validated if self.instance and self.instance.project != project: - raise serializers.ValidationError('Invalid access') + raise serializers.ValidationError("Invalid access") return project @cached_property def current_user_role(self): return ProjectMembership.objects.get( project=self.project, - member=self.context['request'].user, + member=self.context["request"].user, ).role def validate_member(self, member): if self.instance: # Update if self.instance.member != member: # Changing member not allowed - raise serializers.ValidationError('Changing member is not allowed!') + raise serializers.ValidationError("Changing member is not allowed!") return member # Create current_members = ProjectMembership.objects.filter(project=self.project, member=member) if current_members.exclude(pk=self.instance and self.instance.pk).exists(): - raise serializers.ValidationError('User is already a member!') + raise serializers.ValidationError("User is already a member!") return member def validate_role(self, new_role): # Make sure higher role are never allowed if new_role.level < self.current_user_role.level: - raise serializers.ValidationError('Access is denied for higher role assignment.') + raise serializers.ValidationError("Access is denied for higher role assignment.") if ( - self.instance and # For Update - self.instance.role != new_role and # For changed role - ( - self.instance.role.level == self.current_user_role.level and # Requesting user role == current member role - self.instance.role.level < new_role.level # New role is lower then current role + self.instance # For Update + and self.instance.role != new_role # For changed role + and ( + self.instance.role.level == self.current_user_role.level # Requesting user role == current member role + and self.instance.role.level < new_role.level # New role is lower then current role ) ): - raise serializers.ValidationError('Changing same level role is not allowed!') + raise serializers.ValidationError("Changing same level role is not allowed!") return new_role def validate(self, data): - linked_group = (self.instance and self.instance.linked_group) + linked_group = self.instance and self.instance.linked_group if linked_group: raise serializers.ValidationError( - f'This user is added through usergroup: {linked_group}. Please update the respective usergroup.' + f"This user is added through usergroup: {linked_group}. Please update the respective usergroup." ) return data def create(self, validated_data): - validated_data['added_by'] = self.context['request'].user - validated_data['project'] = self.project + validated_data["added_by"] = self.context["request"].user + validated_data["project"] = self.project return super().create(validated_data) @@ -711,76 +648,79 @@ class ProjectUserGroupMembershipGqlSerializer(TempClientIdMixin, serializers.Mod class Meta: model = ProjectUserGroupMembership fields = ( - 'id', 'usergroup', 'role', 'badges', - 'client_id', + "id", + "usergroup", + "role", + "badges", + "client_id", ) @cached_property def project(self): - project = self.context['request'].active_project + project = self.context["request"].active_project # This is a rare case, just to make sure this is validated if self.instance and self.instance.project != project: - raise serializers.ValidationError('Invalid access') + raise serializers.ValidationError("Invalid access") return project @cached_property def current_user_role(self): return ProjectMembership.objects.get( project=self.project, - member=self.context['request'].user, + member=self.context["request"].user, ).role def validate_usergroup(self, usergroup): if self.instance: # Update if self.instance.usergroup != usergroup: # Changing usergroup not allowed - raise serializers.ValidationError('Changing usergroup is not allowed!') + raise serializers.ValidationError("Changing usergroup is not allowed!") return usergroup # Create current_usergroup_members = ProjectUserGroupMembership.objects.filter(project=self.project, usergroup=usergroup) if current_usergroup_members.exclude(pk=self.instance and self.instance.pk).exists(): - raise serializers.ValidationError('UserGroup already a member!') + raise serializers.ValidationError("UserGroup already a member!") return usergroup def validate_role(self, new_role): if new_role.level < self.current_user_role.level: - raise serializers.ValidationError('Access is denied for higher role assignment.') + raise serializers.ValidationError("Access is denied for higher role assignment.") if ( - self.instance and # Update - self.instance.role != new_role and # Role is changed - ( - self.instance.role.level == self.current_user_role.level and # Requesting user role == current member role - self.instance.role.level < new_role.level # New role is lower then current role + self.instance # Update + and self.instance.role != new_role # Role is changed + and ( + self.instance.role.level == self.current_user_role.level # Requesting user role == current member role + and self.instance.role.level < new_role.level # New role is lower then current role ) ): - raise serializers.ValidationError('Changing same level role is not allowed!') + raise serializers.ValidationError("Changing same level role is not allowed!") return new_role def create(self, validated_data): - validated_data['project'] = self.project - validated_data['added_by'] = self.context['request'].user + validated_data["project"] = self.project + validated_data["added_by"] = self.context["request"].user return super().create(validated_data) class ProjectVizConfigurationSerializer(ProjectPropertySerializerMixin, serializers.ModelSerializer): class Action(models.TextChoices): - NEW = 'new', 'New' - ON = 'on', 'On' - OFF = 'off', 'Off' + NEW = "new", "New" + ON = "on", "On" + OFF = "off", "Off" class Meta: model = ProjectStats - fields = ('action',) + fields = ("action",) action = serializers.ChoiceField(choices=Action.choices) def validate(self, data): if not self.project.is_visualization_available: - raise serializers.ValidationError('Visualization is not available for this project') + raise serializers.ValidationError("Visualization is not available for this project") return data def save(self): - action = self.validated_data and self.validated_data['action'] + action = self.validated_data and self.validated_data["action"] return self.project.project_stats.update_public_share_configuration(action) @@ -790,32 +730,34 @@ class ProjectOrganizationGqSerializer(TempClientIdMixin, serializers.ModelSerial class Meta: model = ProjectOrganization fields = ( - 'id', 'organization', 'organization_type', - 'client_id', + "id", + "organization", + "organization_type", + "client_id", ) class ProjectGqSerializer(DeprecatedUserResourceSerializer): - organizations = ProjectOrganizationGqSerializer(source='projectorganization_set', many=True, required=False) + organizations = ProjectOrganizationGqSerializer(source="projectorganization_set", many=True, required=False) class Meta: model = Project fields = ( - 'title', - 'description', - 'start_date', - 'end_date', - 'status', - 'is_private', - 'is_test', - 'is_assessment_enabled', - 'analysis_framework', - 'is_visualization_enabled', - 'has_publicly_viewable_unprotected_leads', - 'has_publicly_viewable_restricted_leads', - 'has_publicly_viewable_confidential_leads', - 'enable_publicly_viewable_analysis_report_snapshot', - 'organizations', + "title", + "description", + "start_date", + "end_date", + "status", + "is_private", + "is_test", + "is_assessment_enabled", + "analysis_framework", + "is_visualization_enabled", + "has_publicly_viewable_unprotected_leads", + "has_publicly_viewable_restricted_leads", + "has_publicly_viewable_confidential_leads", + "enable_publicly_viewable_analysis_report_snapshot", + "organizations", ) # NOTE: This is a custom function (apps/user_resource/serializers.py::UserResourceSerializer) @@ -827,17 +769,16 @@ def _get_prefetch_related_instances_qs(self, qs): @cached_property def current_user(self): - return self.context['request'].user + return self.context["request"].user def validate_is_private(self, is_private): if self.instance: # For update, don't allow changing privacy. if self.instance.is_private != is_private: - raise serializers.ValidationError('Cannot change privacy of project.') + raise serializers.ValidationError("Cannot change privacy of project.") # For create, make sure user can feature permission to create private project. else: - private_access = self.current_user.profile.\ - get_accessible_features().filter(key=Feature.FeatureKey.PRIVATE_PROJECT) + private_access = self.current_user.profile.get_accessible_features().filter(key=Feature.FeatureKey.PRIVATE_PROJECT) if is_private and not private_access.exists(): raise serializers.ValidationError("You don't have permission to create private project") return is_private @@ -846,9 +787,7 @@ def validate_analysis_framework(self, framework): if (self.instance and self.instance.analysis_framework) == framework: return framework if not framework.can_get(self.current_user): - raise serializers.ValidationError( - "Given framework either doesn't exists or you don't have access to it" - ) + raise serializers.ValidationError("Given framework either doesn't exists or you don't have access to it") if not framework.is_private: return framework # Check membership+permissions if private @@ -860,16 +799,18 @@ def validate_analysis_framework(self, framework): return framework def validate(self, data): - is_private = data.get('is_private', self.instance and self.instance.is_private) - framework = data.get('analysis_framework', self.instance and self.instance.analysis_framework) + is_private = data.get("is_private", self.instance and self.instance.is_private) + framework = data.get("analysis_framework", self.instance and self.instance.analysis_framework) # Analysis Frameowrk check if (self.instance and self.instance.analysis_framework) != framework: # Check private if not is_private and framework.is_private: - raise serializers.ValidationError({ - 'analysis_framework': 'Cannot use private framework in public project', - }) + raise serializers.ValidationError( + { + "analysis_framework": "Cannot use private framework in public project", + } + ) return data def validate_title(self, title): @@ -897,13 +838,11 @@ class UserPinnedProjectSerializer(serializers.ModelSerializer): class Meta: model = ProjectPinned - fields = ( - 'project', - ) + fields = ("project",) @cached_property def current_user(self): - return self.context['request'].user + return self.context["request"].user @cached_property def get_queryset(self): @@ -911,18 +850,18 @@ def get_queryset(self): return pinned_project def validate(self, data): - if (self.get_queryset.count() >= 5): + if self.get_queryset.count() >= 5: raise serializers.ValidationError("User can pinned 5 project only!!!") return data def create(self, validated_data): - if self.get_queryset.filter(project=validated_data['project']).exists(): + if self.get_queryset.filter(project=validated_data["project"]).exists(): raise serializers.ValidationError("Project already pinned!!") - validated_data['user'] = self.current_user + validated_data["user"] = self.current_user if self.get_queryset: - validated_data['order'] = self.get_queryset.latest('order').order + 1 + validated_data["order"] = self.get_queryset.latest("order").order + 1 return super().create(validated_data) - validated_data['order'] = 1 + validated_data["order"] = 1 return super().create(validated_data) def update(self): @@ -934,15 +873,11 @@ class BulkProjectPinnedSerializer(TempClientIdMixin, UserResourceSerializer): class Meta: model = ProjectPinned - fields = ( - 'order', - 'client_id', - 'id' - ) + fields = ("order", "client_id", "id") @cached_property def current_user(self): - return self.context['request'].user + return self.context["request"].user @cached_property def get_queryset(self): @@ -950,7 +885,7 @@ def get_queryset(self): return pinned_project def validate(self, data): - if (self.get_queryset.count() >= 5): + if self.get_queryset.count() >= 5: raise serializers.ValidationError("User can pinned 5 project only!!!") return data diff --git a/apps/project/tasks.py b/apps/project/tasks.py index f1b2fbb665..9682bf26da 100644 --- a/apps/project/tasks.py +++ b/apps/project/tasks.py @@ -2,32 +2,27 @@ from collections import defaultdict from datetime import timedelta +from ary.stats import get_project_ary_entry_stats from celery import shared_task +from django.conf import settings from django.core.files.base import ContentFile -from django.utils import timezone from django.db import models -from redis_store import redis +from django.utils import timezone from djangorestframework_camel_case.render import CamelCaseJSONRenderer -from django.conf import settings - -from utils.files import generate_json_file_for_upload -from ary.stats import get_project_ary_entry_stats -from lead.models import Lead +from entry.filter_set import EntryGQFilterSet from entry.models import Entry - from lead.filter_set import LeadGQFilterSet -from entry.filter_set import EntryGQFilterSet +from lead.models import Lead +from redis_store import redis + +from utils.files import generate_json_file_for_upload -from .models import ( - Project, - ProjectStats, - ProjectMembership, -) +from .models import Project, ProjectMembership, ProjectStats logger = logging.getLogger(__name__) -VIZ_STATS_WAIT_LOCK_KEY = 'generate_project_viz_stats__wait_lock__{0}' -STATS_WAIT_LOCK_KEY = 'generate_project_stats__wait_lock' +VIZ_STATS_WAIT_LOCK_KEY = "generate_project_viz_stats__wait_lock__{0}" +STATS_WAIT_LOCK_KEY = "generate_project_stats__wait_lock" STATS_WAIT_TIMEOUT = ProjectStats.THRESHOLD_SECONDS @@ -46,25 +41,22 @@ def _generate_project_viz_stats(project_id): project_stats.file.delete() project_stats.confidential_file.delete() # Save new file - project_stats.file.save(f'project-stats-{project_id}.json', stats_content) - project_stats.confidential_file.save(f'project-stats-confidential-{project_id}.json', confidential_stats_content) + project_stats.file.save(f"project-stats-{project_id}.json", stats_content) + project_stats.confidential_file.save(f"project-stats-confidential-{project_id}.json", confidential_stats_content) project_stats.save() except Exception: - logger.warning(f'Ary Stats Generation Failed ({project_id})!!', exc_info=True) + logger.warning(f"Ary Stats Generation Failed ({project_id})!!", exc_info=True) project_stats.status = ProjectStats.Status.FAILURE project_stats.save() def get_project_stats(project, info, filters): # XXX: Circular dependency - from lead.schema import get_lead_qs from entry.schema import get_entry_qs + from lead.schema import get_lead_qs def _count_by_project(qs): - return qs\ - .filter(project=project)\ - .order_by().values('project')\ - .aggregate(count=models.Count('id', distinct=True))['count'] + return qs.filter(project=project).order_by().values("project").aggregate(count=models.Count("id", distinct=True))["count"] if info.context.active_project: lead_qs = get_lead_qs(info) @@ -74,7 +66,7 @@ def _count_by_project(qs): entry_qs = Entry.objects.filter(project=project, analysis_framework=project.analysis_framework_id) filters_counts = {} if filters: - entry_filter_data = filters.get('entries_filter_data') or {} + entry_filter_data = filters.get("entries_filter_data") or {} filtered_lead_qs = LeadGQFilterSet(request=info.context.request, queryset=lead_qs, data=filters).qs filtered_entry_qs = EntryGQFilterSet( request=info.context.request, @@ -110,52 +102,64 @@ def _generate_project_stats_cache(): def _count_by_project_qs(qs): return { project: count - for project, count in qs.order_by().values('project').annotate( - count=models.Count('id', distinct=True) - ).values_list('project', 'count') + for project, count in qs.order_by() + .values("project") + .annotate(count=models.Count("id", distinct=True)) + .values_list("project", "count") } def _count_by_project_date_qs(qs): data = defaultdict(list) for project, count, date in ( - qs - .order_by('project', 'created_at__date') - .values('project', 'created_at__date') - .annotate(count=models.Count('id', distinct=True)) - .values_list('project', 'count', models.Func(models.F('created_at__date'), function='DATE')) + qs.order_by("project", "created_at__date") + .values("project", "created_at__date") + .annotate(count=models.Count("id", distinct=True)) + .values_list("project", "count", models.Func(models.F("created_at__date"), function="DATE")) ): - data[project].append({ - 'date': date and date.strftime('%Y-%m-%d'), - 'count': count, - }) + data[project].append( + { + "date": date and date.strftime("%Y-%m-%d"), + "count": count, + } + ) return data current_time = timezone.now() threshold = ProjectStats.get_activity_timeframe(current_time) # Make sure to only look for entries which have same AF as Project's AF - all_entries_qs = Entry.objects.filter(analysis_framework=models.F('project__analysis_framework')) + all_entries_qs = Entry.objects.filter(analysis_framework=models.F("project__analysis_framework")) recent_leads = Lead.objects.filter(created_at__gte=threshold) recent_entries = all_entries_qs.filter(created_at__gte=threshold) # Calculate leads_count_map = _count_by_project_qs(Lead.objects.all()) leads_tagged_and_controlled_count_map = _count_by_project_qs( - Lead.objects.filter(status=Lead.Status.TAGGED).annotate( + Lead.objects.filter(status=Lead.Status.TAGGED) + .annotate( entries_count=models.Subquery( all_entries_qs.filter( - lead=models.OuterRef('pk'), - ).order_by().values('lead').annotate(count=models.Count('id')).values('count')[:1], - output_field=models.IntegerField() + lead=models.OuterRef("pk"), + ) + .order_by() + .values("lead") + .annotate(count=models.Count("id")) + .values("count")[:1], + output_field=models.IntegerField(), ), entries_controlled_count=models.Subquery( all_entries_qs.filter( - lead=models.OuterRef('pk'), + lead=models.OuterRef("pk"), controlled=True, - ).order_by().values('lead').annotate(count=models.Count('id')).values('count')[:1], - output_field=models.IntegerField() + ) + .order_by() + .values("lead") + .annotate(count=models.Count("id")) + .values("count")[:1], + output_field=models.IntegerField(), ), - ).filter(entries_count__gt=0, entries_count=models.F('entries_controlled_count')) + ) + .filter(entries_count__gt=0, entries_count=models.F("entries_controlled_count")) ) leads_not_tagged_count_map = _count_by_project_qs(Lead.objects.filter(status=Lead.Status.NOT_TAGGED)) leads_in_progress_count_map = _count_by_project_qs(Lead.objects.filter(status=Lead.Status.IN_PROGRESS)) @@ -193,7 +197,7 @@ def _count_by_project_date_qs(qs): leads_activities=leads_activity_map.get(pk, []), entries_activities=entries_activity_map.get(pk, []), ) - project.save(update_fields=['stats_cache']) + project.save(update_fields=["stats_cache"]) @shared_task @@ -205,10 +209,10 @@ def generate_viz_stats(project_id, force=False): lock = redis.get_lock(key, STATS_WAIT_TIMEOUT) have_lock = lock.acquire(blocking=False) if not have_lock and not force: - logger.warning(f'GENERATE_PROJECT_VIZ_STATS:: Waiting for timeout {key}') + logger.warning(f"GENERATE_PROJECT_VIZ_STATS:: Waiting for timeout {key}") return False - logger.info(f'GENERATE_PROJECT_STATS:: Processing for {key}') + logger.info(f"GENERATE_PROJECT_STATS:: Processing for {key}") _generate_project_viz_stats(project_id) # NOTE: lock.release() is not called so that another process waits for timeout return True @@ -223,17 +227,17 @@ def generate_project_stats_cache(force=False): lock = redis.get_lock(key, STATS_WAIT_TIMEOUT) have_lock = lock.acquire(blocking=False) if not have_lock and not force: - logger.warning(f'GENERATE_PROJECT_STATS:: Waiting for timeout {key}') + logger.warning(f"GENERATE_PROJECT_STATS:: Waiting for timeout {key}") return False - logger.info(f'GENERATE_PROJECT_STATS:: Processing for {key}') + logger.info(f"GENERATE_PROJECT_STATS:: Processing for {key}") _generate_project_stats_cache() lock.release() return True def generate_project_geo_region_cache(project): - region_qs = project.regions.defer('geo_options', 'centroid') + region_qs = project.regions.defer("geo_options", "centroid") geo_options = {} for region in region_qs: @@ -242,30 +246,28 @@ def generate_project_geo_region_cache(project): geo_options[region.pk] = region.geo_options project.geo_cache_file.save( - f'project-geo-cache-{project.pk}.json', + f"project-geo-cache-{project.pk}.json", ContentFile(CamelCaseJSONRenderer().render(geo_options)), save=False, ) - project.geo_cache_hash = hash(tuple(region_qs.order_by('id').values_list('cache_index', flat=True))) - project.save(update_fields=('geo_cache_hash', 'geo_cache_file')) + project.geo_cache_hash = hash(tuple(region_qs.order_by("id").values_list("cache_index", flat=True))) + project.save(update_fields=("geo_cache_hash", "geo_cache_file")) @shared_task def permanently_delete_projects(): # check every project if there `is_deleted` is set True # if greater than settings.USER_AND_PROJECT_DELETE_IN_DAYS days delete those projects - logger.info('[Project Delete] Checking project to delete.') - threshold = ( - timezone.now() - timedelta(days=settings.USER_AND_PROJECT_DELETE_IN_DAYS) - ) + logger.info("[Project Delete] Checking project to delete.") + threshold = timezone.now() - timedelta(days=settings.USER_AND_PROJECT_DELETE_IN_DAYS) project_qs = Project.objects.filter( is_deleted=True, deleted_at__isnull=False, deleted_at__lt=threshold, ) - logger.info(f'[Project Delete] Found {project_qs.count()} projects to delete.') + logger.info(f"[Project Delete] Found {project_qs.count()} projects to delete.") for project in project_qs: - _meta = f'{project.id}::{project.title}' - logger.info(f'[Project Delete] Deleting {_meta}') + _meta = f"{project.id}::{project.title}" + logger.info(f"[Project Delete] Deleting {_meta}") project_delete_response = project.delete() - logger.info(f'[Project Delete] Deleted {_meta}:: {project_delete_response}') + logger.info(f"[Project Delete] Deleted {_meta}:: {project_delete_response}") diff --git a/apps/project/tests/entry_stats_data.py b/apps/project/tests/entry_stats_data.py index da38250436..39f3bff74e 100644 --- a/apps/project/tests/entry_stats_data.py +++ b/apps/project/tests/entry_stats_data.py @@ -1,187 +1,183 @@ -WIDGET_DATA = { -} +WIDGET_DATA = {} # NOTE: This structure and value are set through https://github.com/the-deep/client WIDGET_DATA = { - 'multiselectWidget': { - 'options': [ - {'key': 'option-1', 'label': 'Option 1'}, - {'key': 'option-2', 'label': 'Option 2'}, - {'key': 'option-3', 'label': 'Option 3'} + "multiselectWidget": { + "options": [ + {"key": "option-1", "label": "Option 1"}, + {"key": "option-2", "label": "Option 2"}, + {"key": "option-3", "label": "Option 3"}, ] }, - 'scaleWidget': { - 'options': [ - {'key': 'scale-1', 'color': '#470000', 'label': 'Scale 1'}, - {'key': 'scale-2', 'color': '#a40000', 'label': 'Scale 2'}, - {'key': 'scale-3', 'color': '#d40000', 'label': 'Scale 3'} + "scaleWidget": { + "options": [ + {"key": "scale-1", "color": "#470000", "label": "Scale 1"}, + {"key": "scale-2", "color": "#a40000", "label": "Scale 2"}, + {"key": "scale-3", "color": "#d40000", "label": "Scale 3"}, ] }, - - 'matrix1dWidget': { - 'rows': [ + "matrix1dWidget": { + "rows": [ { - 'key': 'pillar-1', - 'cells': [ - {'key': 'subpillar-1', 'value': 'Politics'}, - {'key': 'subpillar-2', 'value': 'Security'}, - {'key': 'subpillar-3', 'value': 'Legal & Policy'}, - {'key': 'subpillar-4', 'value': 'Demography'}, - {'key': 'subpillar-5', 'value': 'Economy'}, - {'key': 'subpillar-5', 'value': 'Socio Cultural'}, - {'key': 'subpillar-7', 'value': 'Environment'}, + "key": "pillar-1", + "cells": [ + {"key": "subpillar-1", "value": "Politics"}, + {"key": "subpillar-2", "value": "Security"}, + {"key": "subpillar-3", "value": "Legal & Policy"}, + {"key": "subpillar-4", "value": "Demography"}, + {"key": "subpillar-5", "value": "Economy"}, + {"key": "subpillar-5", "value": "Socio Cultural"}, + {"key": "subpillar-7", "value": "Environment"}, ], - 'color': '#c26b27', - 'label': 'Context', - 'tooltip': 'Information about the environment in which humanitarian actors operates and the crisis happen', # noqa E501 - }, { - 'key': 'pillar-2', - 'cells': [ - {'key': 'subpillar-8', 'value': 'Affected Groups'}, - {'key': 'subpillar-9', 'value': 'Population Movement'}, - {'key': 'subpillar-10', 'value': 'Push/Pull Factors'}, - {'key': 'subpillar-11', 'value': 'Casualties'}, + "color": "#c26b27", + "label": "Context", + "tooltip": "Information about the environment in which humanitarian actors operates and the crisis happen", # noqa E501 + }, + { + "key": "pillar-2", + "cells": [ + {"key": "subpillar-8", "value": "Affected Groups"}, + {"key": "subpillar-9", "value": "Population Movement"}, + {"key": "subpillar-10", "value": "Push/Pull Factors"}, + {"key": "subpillar-11", "value": "Casualties"}, ], - 'color': '#efaf78', - 'label': 'Humanitarian Profile', - 'tooltip': 'Information related to the population affected, including affected residents and displaced people', # noqa E501 - }, { - 'key': 'pillar-3', - 'cells': [ - {'key': 'subpillar-12', 'value': 'Relief to Beneficiaries'}, - {'key': 'subpillar-13', 'value': 'Beneficiaries to Relief'}, - {'key': 'subpillar-14', 'value': 'Physical Constraints'}, - {'key': 'subpillar-15', 'value': 'Humanitarian Access Gaps'}, + "color": "#efaf78", + "label": "Humanitarian Profile", + "tooltip": "Information related to the population affected, including affected residents and displaced people", # noqa E501 + }, + { + "key": "pillar-3", + "cells": [ + {"key": "subpillar-12", "value": "Relief to Beneficiaries"}, + {"key": "subpillar-13", "value": "Beneficiaries to Relief"}, + {"key": "subpillar-14", "value": "Physical Constraints"}, + {"key": "subpillar-15", "value": "Humanitarian Access Gaps"}, ], - 'color': '#b9b2a5', - 'label': 'Humanitarian Access', - 'tooltip': 'Information related to restrictions and constraints in accessing or being accessed by people in need', # noqa E501 - }, { - 'key': 'pillar-4', - 'cells': [ - {'key': 'subpillar-16', 'value': 'Communication Means & Channels'}, - {'key': 'subpillar-17', 'value': 'Information Challenges'}, - {'key': 'subpillar-18', 'value': 'Information Needs & Gaps'}, + "color": "#b9b2a5", + "label": "Humanitarian Access", + "tooltip": "Information related to restrictions and constraints in accessing or being accessed by people in need", # noqa E501 + }, + { + "key": "pillar-4", + "cells": [ + {"key": "subpillar-16", "value": "Communication Means & Channels"}, + {"key": "subpillar-17", "value": "Information Challenges"}, + {"key": "subpillar-18", "value": "Information Needs & Gaps"}, ], - 'color': '#9bd65b', - 'label': 'Information', - 'tooltip': 'Information about information, including communication means, information challenges and information needs', # noqa E501 - }] + "color": "#9bd65b", + "label": "Information", + "tooltip": "Information about information, including communication means, information challenges and information needs", # noqa E501 + }, + ] }, - - 'matrix2dWidget': { - 'columns': [ - {'key': 'sector-9', 'label': 'Cross', 'tooltip': 'Cross sectoral information', 'subColumns': []}, - {'key': 'sector-0', 'label': 'Food', 'tooltip': '...', 'subColumns': []}, - {'key': 'sector-1', 'label': 'Livelihoods', 'tooltip': '...', 'subColumns': []}, - {'key': 'sector-2', 'label': 'Health', 'tooltip': '...', 'subColumns': []}, - {'key': 'sector-3', 'label': 'Nutrition', 'tooltip': '...', 'subColumns': []}, + "matrix2dWidget": { + "columns": [ + {"key": "sector-9", "label": "Cross", "tooltip": "Cross sectoral information", "subColumns": []}, + {"key": "sector-0", "label": "Food", "tooltip": "...", "subColumns": []}, + {"key": "sector-1", "label": "Livelihoods", "tooltip": "...", "subColumns": []}, + {"key": "sector-2", "label": "Health", "tooltip": "...", "subColumns": []}, + {"key": "sector-3", "label": "Nutrition", "tooltip": "...", "subColumns": []}, { - 'key': 'sector-4', - 'label': 'WASH', - 'tooltip': '...', - 'subColumns': [ - {'key': 'subsector-1', 'label': 'Water'}, - {'key': 'subsector-2', 'label': 'Sanitation'}, - {'key': 'subsector-3', 'label': 'Hygiene'}, - {'key': 'subsector-4', 'label': 'Waste management', 'tooltip': ''}, - {'key': 'subsector-5', 'label': 'Vector control', 'tooltip': ''} - ] + "key": "sector-4", + "label": "WASH", + "tooltip": "...", + "subColumns": [ + {"key": "subsector-1", "label": "Water"}, + {"key": "subsector-2", "label": "Sanitation"}, + {"key": "subsector-3", "label": "Hygiene"}, + {"key": "subsector-4", "label": "Waste management", "tooltip": ""}, + {"key": "subsector-5", "label": "Vector control", "tooltip": ""}, + ], }, - {'key': 'sector-5', 'label': 'Shelter', 'tooltip': '...', 'subColumns': []}, + {"key": "sector-5", "label": "Shelter", "tooltip": "...", "subColumns": []}, { - 'key': 'sector-7', - 'label': 'Education', - 'tooltip': '.....', - 'subColumns': [ - {'key': 'subsector-6', 'label': 'Learning Environment', 'tooltip': ''}, - {'key': 'subsector-7', 'label': 'Teaching and Learning', 'tooltip': ''}, - {'key': 'subsector-8', 'label': 'Teachers and Education Personnel', 'tooltip': ''}, - ] + "key": "sector-7", + "label": "Education", + "tooltip": ".....", + "subColumns": [ + {"key": "subsector-6", "label": "Learning Environment", "tooltip": ""}, + {"key": "subsector-7", "label": "Teaching and Learning", "tooltip": ""}, + {"key": "subsector-8", "label": "Teachers and Education Personnel", "tooltip": ""}, + ], }, - {'key': 'sector-8', 'label': 'Protection', 'tooltip': '', 'subColumns': []}, - {'key': 'sector-10', 'label': 'Agriculture', 'tooltip': '...', 'subColumns': []}, - {'key': 'sector-11', 'label': 'Logistics', 'tooltip': '...', 'subColumns': []} + {"key": "sector-8", "label": "Protection", "tooltip": "", "subColumns": []}, + {"key": "sector-10", "label": "Agriculture", "tooltip": "...", "subColumns": []}, + {"key": "sector-11", "label": "Logistics", "tooltip": "...", "subColumns": []}, ], - 'rows': [ + "rows": [ { - 'key': 'dimension-0', - 'color': '#eae285', - 'label': 'Scope & Scale', - 'tooltip': 'Information about the direct and indirect impact of the disaster or crisis', - 'subRows': [ - {'key': 'subdimension-0', 'label': 'Drivers/Aggravating Factors', 'tooltip': '...'}, - {'key': 'subdimension-3', 'label': 'System Disruption', 'tooltip': '...'}, - {'key': 'subdimension-4', 'label': 'Damages & Losses', 'tooltip': '...'}, - {'key': 'subdimension-6', 'label': 'Lessons Learnt', 'tooltip': '...'} - ] + "key": "dimension-0", + "color": "#eae285", + "label": "Scope & Scale", + "tooltip": "Information about the direct and indirect impact of the disaster or crisis", + "subRows": [ + {"key": "subdimension-0", "label": "Drivers/Aggravating Factors", "tooltip": "..."}, + {"key": "subdimension-3", "label": "System Disruption", "tooltip": "..."}, + {"key": "subdimension-4", "label": "Damages & Losses", "tooltip": "..."}, + {"key": "subdimension-6", "label": "Lessons Learnt", "tooltip": "..."}, + ], }, { - 'key': 'dimension-1', - 'color': '#fba855', - 'label': 'Humanitarian Conditions', - 'tooltip': '...', - 'subRows': [ - {'key': 'subdimension-1', 'label': 'Living Standards', 'tooltip': '...'}, - {'key': 'us9kizxxwha7cpgb', 'label': 'Coping Mechanisms', 'tooltip': ''}, - {'key': 'subdimension-7', 'label': 'Physical & mental wellbeing', 'tooltip': '..'}, - {'key': 'subdimension-8', 'label': 'Risks & Vulnerabilities', 'tooltip': '...'}, - {'key': 'ejve4vklgge9ysxm', 'label': 'People with Specific Needs', 'tooltip': ''}, - {'key': 'subdimension-10', 'label': 'Unmet Needs', 'tooltip': '...'}, - {'key': 'subdimension-16', 'label': 'Lessons Learnt', 'tooltip': '...'}, - ] + "key": "dimension-1", + "color": "#fba855", + "label": "Humanitarian Conditions", + "tooltip": "...", + "subRows": [ + {"key": "subdimension-1", "label": "Living Standards", "tooltip": "..."}, + {"key": "us9kizxxwha7cpgb", "label": "Coping Mechanisms", "tooltip": ""}, + {"key": "subdimension-7", "label": "Physical & mental wellbeing", "tooltip": ".."}, + {"key": "subdimension-8", "label": "Risks & Vulnerabilities", "tooltip": "..."}, + {"key": "ejve4vklgge9ysxm", "label": "People with Specific Needs", "tooltip": ""}, + {"key": "subdimension-10", "label": "Unmet Needs", "tooltip": "..."}, + {"key": "subdimension-16", "label": "Lessons Learnt", "tooltip": "..."}, + ], }, { - 'key': 'dimension-2', - 'color': '#92c5f6', - 'label': 'Capacities & Response', - 'tooltip': '...', - 'subRows': [ - {'key': '7iiastsikxackbrt', 'label': 'System Functionality', 'tooltip': '...'}, - {'key': 'subdimension-11', 'label': 'Government', 'tooltip': '...'}, - {'key': 'drk4j92jwvmck7dc', 'label': 'LNGO', 'tooltip': '...'}, - {'key': 'subdimension-12', 'label': 'International', 'tooltip': '...'}, - {'key': 'subdimension-14', 'label': 'Response Gaps', 'tooltip': '...'}, - {'key': 'subdimension-15', 'label': 'Lessons Learnt', 'tooltip': '...'}, - ] - } - ] + "key": "dimension-2", + "color": "#92c5f6", + "label": "Capacities & Response", + "tooltip": "...", + "subRows": [ + {"key": "7iiastsikxackbrt", "label": "System Functionality", "tooltip": "..."}, + {"key": "subdimension-11", "label": "Government", "tooltip": "..."}, + {"key": "drk4j92jwvmck7dc", "label": "LNGO", "tooltip": "..."}, + {"key": "subdimension-12", "label": "International", "tooltip": "..."}, + {"key": "subdimension-14", "label": "Response Gaps", "tooltip": "..."}, + {"key": "subdimension-15", "label": "Lessons Learnt", "tooltip": "..."}, + ], + }, + ], }, - - 'geoWidget': {}, + "geoWidget": {}, } # NOTE: This structure and value are set through https://github.com/the-deep/client ATTRIBUTE_DATA = { - 'geoWidget': {}, - - 'multiselectWidget': { - 'data': {'value': ['option-3', 'option-1']}, + "geoWidget": {}, + "multiselectWidget": { + "data": {"value": ["option-3", "option-1"]}, }, - - 'scaleWidget': { - 'data': {'value': 'scale-1'}, + "scaleWidget": { + "data": {"value": "scale-1"}, }, - - 'matrix1dWidget': { - 'data': { - 'value': { - 'pillar-2': {'subpillar-8': True}, - 'pillar-1': {'subpillar-7': False}, - 'pillar-4': {'subpillar-18': True}, + "matrix1dWidget": { + "data": { + "value": { + "pillar-2": {"subpillar-8": True}, + "pillar-1": {"subpillar-7": False}, + "pillar-4": {"subpillar-18": True}, }, }, }, - - 'matrix2dWidget': { - 'data': { - 'value': { - 'dimension-0': { - 'subdimension-4': { - 'sector-1': [], - 'sector-4': ['subsector-2', 'subsector-4'], - 'sector-7': ['subsector-8', 'subsector-6'] + "matrix2dWidget": { + "data": { + "value": { + "dimension-0": { + "subdimension-4": { + "sector-1": [], + "sector-4": ["subsector-2", "subsector-4"], + "sector-7": ["subsector-8", "subsector-6"], } } }, diff --git a/apps/project/tests/test_apis.py b/apps/project/tests/test_apis.py index bd74ee47a9..ebe1e88208 100644 --- a/apps/project/tests/test_apis.py +++ b/apps/project/tests/test_apis.py @@ -1,56 +1,37 @@ import uuid +from analysis_framework.models import AnalysisFramework, AnalysisFrameworkRole, Widget +from ary.models import AssessmentTemplate from dateutil.relativedelta import relativedelta from django.utils import timezone from django.utils.hashable import make_hashable - -from user.models import ( - User, - Feature, -) -from deep.tests import TestCase -from entry.models import ( - Lead, - Entry, - Attribute, -) -from quality_assurance.models import EntryReviewComment -from analysis_framework.models import ( - AnalysisFramework, - AnalysisFrameworkRole, - Widget, -) -from lead.models import LeadGroup +from entry.models import Attribute, Entry, Lead from geo.models import Region -from project.tasks import ( - _generate_project_viz_stats, - _generate_project_stats_cache, -) -from ary.models import AssessmentTemplate +from lead.models import LeadGroup +from organization.models import Organization from project.models import ( Project, - ProjectRole, - ProjectMembership, ProjectJoinRequest, - ProjectUserGroupMembership, + ProjectMembership, ProjectOrganization, + ProjectRole, ProjectStats, + ProjectUserGroupMembership, ) - -from organization.models import ( - Organization -) - +from project.tasks import _generate_project_stats_cache, _generate_project_viz_stats +from quality_assurance.models import EntryReviewComment +from user.models import Feature, User from user_group.models import UserGroup -from . import entry_stats_data +from deep.tests import TestCase +from . import entry_stats_data # TODO Document properly some of the following complex tests class ProjectApiTest(TestCase): - fixtures = ['ary_template_data.json'] + fixtures = ["ary_template_data.json"] def setUp(self): super().setUp() @@ -59,25 +40,25 @@ def setUp(self): self.user2 = self.create(User) self.user3 = self.create(User) # and some user groups - self.ug1 = self.create(UserGroup, role='admin') + self.ug1 = self.create(UserGroup, role="admin") self.ug1.add_member(self.user1) self.ug1.add_member(self.user2) - self.ug2 = self.create(UserGroup, role='admin') + self.ug2 = self.create(UserGroup, role="admin") self.ug2.add_member(self.user2) self.ug2.add_member(self.user3) - self.org1 = self.create(Organization, title='Test Organization') - self.region1 = self.create(Region, title='ACU') - self.region2 = self.create(Region, title='NSW') + self.org1 = self.create(Organization, title="Test Organization") + self.region1 = self.create(Region, title="ACU") + self.region2 = self.create(Region, title="NSW") def test_create_project(self): project_count = Project.objects.count() - url = '/api/v1/projects/' + url = "/api/v1/projects/" data = { - 'title': 'Test project', - 'data': {'testKey': 'testValue'}, - 'organizations': [ - {'organization': self.org1.id, 'organization_type': ProjectOrganization.Type.DONOR}, + "title": "Test project", + "data": {"testKey": "testValue"}, + "organizations": [ + {"organization": self.org1.id, "organization_type": ProjectOrganization.Type.DONOR}, ], } @@ -86,19 +67,19 @@ def test_create_project(self): self.assert_201(response) self.assertEqual(Project.objects.count(), project_count + 1) - self.assertEqual(response.data['title'], data['title']) + self.assertEqual(response.data["title"], data["title"]) def test_check_assessment_template_in_project_create(self): project_count = Project.objects.count() assessment = self.create(AssessmentTemplate) - url = '/api/v1/projects/' + url = "/api/v1/projects/" data = { - 'title': 'Test project', - 'data': {'testKey': 'testValue'}, - 'organizations': [ - {'organization': self.org1.id, 'organization_type': ProjectOrganization.Type.DONOR}, + "title": "Test project", + "data": {"testKey": "testValue"}, + "organizations": [ + {"organization": self.org1.id, "organization_type": ProjectOrganization.Type.DONOR}, ], - 'has_assessments': True + "has_assessments": True, } self.authenticate() @@ -106,27 +87,27 @@ def test_check_assessment_template_in_project_create(self): self.assert_201(response) self.assertEqual(Project.objects.count(), project_count + 1) - self.assertEqual(response.data['assessment_template'], assessment.id) + self.assertEqual(response.data["assessment_template"], assessment.id) # providing `has_assessments=False` - data['has_assessments'] = False + data["has_assessments"] = False self.authenticate() response = self.client.post(url, data) self.assert_201(response) - self.assertNotIn('assessment_template', response.data) + self.assertNotIn("assessment_template", response.data) # providing `has_assessments=None` - data['has_assessments'] = None + data["has_assessments"] = None self.authenticate() response = self.client.post(url, data) self.assert_400(response) def create_project_api(self, **kwargs): - url = '/api/v1/projects/' + url = "/api/v1/projects/" data = { - 'title': kwargs.get('title'), - 'is_private': kwargs.get('is_private'), - 'organizations': kwargs.get('organizations', []) + "title": kwargs.get("title"), + "is_private": kwargs.get("is_private"), + "organizations": kwargs.get("organizations", []), } response = self.client.post(url, data) @@ -134,19 +115,24 @@ def create_project_api(self, **kwargs): def test_get_projects(self): user_fhx = self.create(User) - self.create(Feature, feature_type=Feature.FeatureType.GENERAL_ACCESS, - key=Feature.FeatureKey.PRIVATE_PROJECT, title='Private project', - users=[user_fhx], email_domains=[]) + self.create( + Feature, + feature_type=Feature.FeatureType.GENERAL_ACCESS, + key=Feature.FeatureKey.PRIVATE_PROJECT, + title="Private project", + users=[user_fhx], + email_domains=[], + ) self.authenticate(user_fhx) - self.create_project_api(title='Project 1', is_private=False) - self.create_project_api(title='Project 2', is_private=False) - self.create_project_api(title='Project 3', is_private=False) - self.create_project_api(title='Project 4', is_private=False) - self.create_project_api(title='Private Project 1', is_private=True) + self.create_project_api(title="Project 1", is_private=False) + self.create_project_api(title="Project 2", is_private=False) + self.create_project_api(title="Project 3", is_private=False) + self.create_project_api(title="Project 4", is_private=False) + self.create_project_api(title="Private Project 1", is_private=True) - response = self.client.get('/api/v1/projects/') - self.assertEqual(len(response.data['results']), 5) + response = self.client.get("/api/v1/projects/") + self.assertEqual(len(response.data["results"]), 5) other_user = self.create(User) self.authenticate(other_user) @@ -154,8 +140,8 @@ def test_get_projects(self): # self.create_project_api(title='Project 5', is_private=False) # self.create_project_api(title='Private Project 3', is_private=True) - response = self.client.get('/api/v1/projects/') - self.assertEqual(len(response.data['results']), 4) + response = self.client.get("/api/v1/projects/") + self.assertEqual(len(response.data["results"]), 4) def test_get_project_members(self): user1 = self.create(User) @@ -172,17 +158,15 @@ def test_get_project_members(self): project = self.create(Project) project.add_member(user1) - ProjectUserGroupMembership.objects.create( - project=project, usergroup=usergroup, badges=[ProjectMembership.BadgeType.QA] - ) + ProjectUserGroupMembership.objects.create(project=project, usergroup=usergroup, badges=[ProjectMembership.BadgeType.QA]) - url = f'/api/v1/projects/{project.id}/members/' + url = f"/api/v1/projects/{project.id}/members/" self.authenticate(user1) # autheniticate with the members only resp = self.client.get(url) self.assert_200(resp) - userids = [x['id'] for x in resp.data['results']] + userids = [x["id"] for x in resp.data["results"]] assert user1.id in userids assert user2.id not in userids assert userg1.id in userids @@ -191,23 +175,28 @@ def test_get_project_members(self): def test_create_private_project(self): # project_count = Project.objects.count() - url = '/api/v1/projects/' + url = "/api/v1/projects/" data = { - 'title': 'Test private project', - 'is_private': 'true', - 'organizations': [], + "title": "Test private project", + "is_private": "true", + "organizations": [], } - user_fhx = self.create(User, email='fhx@togglecorp.com') - self.create(Feature, feature_type=Feature.FeatureType.GENERAL_ACCESS, - key=Feature.FeatureKey.PRIVATE_PROJECT, title='Private project', - users=[user_fhx], email_domains=[]) + user_fhx = self.create(User, email="fhx@togglecorp.com") + self.create( + Feature, + feature_type=Feature.FeatureType.GENERAL_ACCESS, + key=Feature.FeatureKey.PRIVATE_PROJECT, + title="Private project", + users=[user_fhx], + email_domains=[], + ) self.authenticate(user_fhx) response = self.client.post(url, data) self.assert_201(response) - self.assertEqual(response.data['is_private'], True) + self.assertEqual(response.data["is_private"], True) self.assertEqual(Project.objects.last().is_private, True) def test_change_private_project_to_public(self): @@ -222,37 +211,47 @@ def test_change_private_project_to_public(self): self._change_project_privacy_test(public_project, 403, self.user) def test_create_private_project_unauthorized(self): - user_fhx = self.create(User, email='fhx@togglecorp.com') - user_dummy = self.create(User, email='dummy@test.com') - - self.create(Feature, feature_type=Feature.FeatureType.GENERAL_ACCESS, - key=Feature.FeatureKey.PRIVATE_PROJECT, title='Private project', - users=[user_dummy], email_domains=[]) + user_fhx = self.create(User, email="fhx@togglecorp.com") + user_dummy = self.create(User, email="dummy@test.com") + + self.create( + Feature, + feature_type=Feature.FeatureType.GENERAL_ACCESS, + key=Feature.FeatureKey.PRIVATE_PROJECT, + title="Private project", + users=[user_dummy], + email_domains=[], + ) self.authenticate(user_fhx) - self.assert_403(self.create_project_api(title='Private test', is_private=True)) + self.assert_403(self.create_project_api(title="Private test", is_private=True)) self.authenticate(user_dummy) - self.assert_201(self.create_project_api(title='Private test', is_private=True)) + self.assert_201(self.create_project_api(title="Private test", is_private=True)) def test_get_private_project_detail_unauthorized(self): - user_fhx = self.create(User, email='fhx@togglecorp.com') - self.create(Feature, feature_type=Feature.FeatureType.GENERAL_ACCESS, - key=Feature.FeatureKey.PRIVATE_PROJECT, title='Private project', - users=[user_fhx], email_domains=[]) + user_fhx = self.create(User, email="fhx@togglecorp.com") + self.create( + Feature, + feature_type=Feature.FeatureType.GENERAL_ACCESS, + key=Feature.FeatureKey.PRIVATE_PROJECT, + title="Private project", + users=[user_fhx], + email_domains=[], + ) self.authenticate(user_fhx) - response = self.create_project_api(title='Test private project', is_private=True) + response = self.create_project_api(title="Test private project", is_private=True) self.assert_201(response) - self.assertEqual(response.data['is_private'], True) + self.assertEqual(response.data["is_private"], True) self.assertEqual(Project.objects.last().is_private, True) other_user = self.create(User) self.authenticate(other_user) - new_private_project_id = response.data['id'] - response = self.client.get(f'/api/v1/projects/{new_private_project_id}/') + new_private_project_id = response.data["id"] + response = self.client.get(f"/api/v1/projects/{new_private_project_id}/") self.assert_404(response) @@ -267,11 +266,11 @@ def test_private_project_use_public_framework(self): ProjectRole.get_owner_role(), ) - url = f'/api/v1/projects/{private_project.id}/' + url = f"/api/v1/projects/{private_project.id}/" data = { - 'title': private_project.title, - 'analysis_framework': public_framework.id, - 'organizations': [], + "title": private_project.title, + "analysis_framework": public_framework.id, + "organizations": [], # ... don't care other fields } self.authenticate() @@ -283,10 +282,7 @@ def test_private_project_use_private_framework_if_framework_member(self): private_project = self.create(Project, is_private=True, organizations=[]) private_framework = self.create(AnalysisFramework, is_private=False) - private_framework.add_member( - self.user, - private_framework.get_or_create_default_role() - ) + private_framework.add_member(self.user, private_framework.get_or_create_default_role()) private_project.add_member( self.user, @@ -294,11 +290,11 @@ def test_private_project_use_private_framework_if_framework_member(self): ProjectRole.get_owner_role(), ) - url = f'/api/v1/projects/{private_project.id}/' + url = f"/api/v1/projects/{private_project.id}/" data = { - 'title': private_project.title, - 'analysis_framework': private_framework.id, - 'organizations': [], + "title": private_project.title, + "analysis_framework": private_framework.id, + "organizations": [], # ... don't care other fields } self.authenticate() @@ -316,11 +312,11 @@ def test_private_project_use_private_framework_if_not_framework_member(self): ProjectRole.get_owner_role(), ) - url = f'/api/v1/projects/{private_project.id}/' + url = f"/api/v1/projects/{private_project.id}/" data = { - 'title': private_project.title, - 'analysis_framework': private_framework.id, - 'organizations': [], + "title": private_project.title, + "analysis_framework": private_framework.id, + "organizations": [], # ... don't care other fields } self.authenticate() @@ -335,10 +331,7 @@ def test_private_project_use_private_framework_if_framework_member_no_can_use(se private_framework = self.create(AnalysisFramework, is_private=True) framework_role_no_permissions = AnalysisFrameworkRole.objects.create() - private_framework.add_member( - self.user, - framework_role_no_permissions - ) + private_framework.add_member(self.user, framework_role_no_permissions) private_project.add_member( self.user, @@ -346,11 +339,11 @@ def test_private_project_use_private_framework_if_framework_member_no_can_use(se ProjectRole.get_owner_role(), ) - url = f'/api/v1/projects/{private_project.id}/' + url = f"/api/v1/projects/{private_project.id}/" data = { - 'title': private_project.title, - 'analysis_framework': private_framework.id, - 'organizations': [], + "title": private_project.title, + "analysis_framework": private_framework.id, + "organizations": [], # ... don't care other fields } self.authenticate() @@ -370,11 +363,11 @@ def test_public_project_use_public_framework(self): ProjectRole.get_owner_role(), ) - url = f'/api/v1/projects/{public_project.id}/' + url = f"/api/v1/projects/{public_project.id}/" data = { - 'title': public_project.title, - 'analysis_framework': public_framework.id, - 'organizations': [], + "title": public_project.title, + "analysis_framework": public_framework.id, + "organizations": [], # ... don't care other fields } self.authenticate() @@ -397,11 +390,11 @@ def test_public_project_use_private_framework(self): # has can_use_in_other_projects True ) - url = f'/api/v1/projects/{public_project.id}/' + url = f"/api/v1/projects/{public_project.id}/" data = { - 'title': public_project.title, - 'analysis_framework': private_framework.id, - 'organizations': [], + "title": public_project.title, + "analysis_framework": private_framework.id, + "organizations": [], # ... don't care other fields } self.authenticate() @@ -410,47 +403,38 @@ def test_public_project_use_private_framework(self): def test_project_get_with_user_group_field(self): # TODO: can make this more generic for other fields as well - project = self.create( - Project, - user_groups=[], - title='TestProject', - role=self.admin_role, - organizations=[] - ) + project = self.create(Project, user_groups=[], title="TestProject", role=self.admin_role, organizations=[]) # Add usergroup - ProjectUserGroupMembership.objects.create( - usergroup=self.ug1, - project=project - ) + ProjectUserGroupMembership.objects.create(usergroup=self.ug1, project=project) # Now get project and validate fields - url = '/api/v1/projects/{}/'.format(project.pk) + url = "/api/v1/projects/{}/".format(project.pk) self.authenticate() response = self.client.get(url) self.assert_200(response) project = response.json() - assert 'id' in project - assert 'userGroups' in project - assert len(project['userGroups']) > 0 - for ug in project['userGroups']: + assert "id" in project + assert "userGroups" in project + assert len(project["userGroups"]) > 0 + for ug in project["userGroups"]: assert isinstance(ug, dict) - assert 'id' in ug - assert 'title' in ug + assert "id" in ug + assert "title" in ug def test_update_project_organizations(self): - org1 = self.create(Organization, title='Test Organization 1') - org2 = self.create(Organization, title='Test Organization 2') - org3 = self.create(Organization, title='Test Organization 3') - org4 = self.create(Organization, title='Test Organization 4') - org5 = self.create(Organization, title='Test Organization 5') + org1 = self.create(Organization, title="Test Organization 1") + org2 = self.create(Organization, title="Test Organization 2") + org3 = self.create(Organization, title="Test Organization 3") + org4 = self.create(Organization, title="Test Organization 4") + org5 = self.create(Organization, title="Test Organization 5") - url = '/api/v1/projects/' + url = "/api/v1/projects/" data = { - 'title': 'TestProject', - 'organizations': [ - {'organization': org1.id, 'organization_type': ProjectOrganization.Type.DONOR}, - {'organization': org2.id, 'organization_type': ProjectOrganization.Type.GOVERNMENT}, - {'organization': org3.id, 'organization_type': ProjectOrganization.Type.GOVERNMENT}, + "title": "TestProject", + "organizations": [ + {"organization": org1.id, "organization_type": ProjectOrganization.Type.DONOR}, + {"organization": org2.id, "organization_type": ProjectOrganization.Type.GOVERNMENT}, + {"organization": org3.id, "organization_type": ProjectOrganization.Type.GOVERNMENT}, ], } @@ -458,35 +442,27 @@ def test_update_project_organizations(self): response = self.client.post(url, data) self.assert_201(response) - url = '/api/v1/projects/{}/'.format(response.json()['id']) + url = "/api/v1/projects/{}/".format(response.json()["id"]) data = { - 'organizations': [ - {'organization': org4.id, 'organization_type': ProjectOrganization.Type.DONOR}, - {'organization': org5.id, 'organization_type': ProjectOrganization.Type.GOVERNMENT}, + "organizations": [ + {"organization": org4.id, "organization_type": ProjectOrganization.Type.DONOR}, + {"organization": org5.id, "organization_type": ProjectOrganization.Type.GOVERNMENT}, ], } response = self.client.patch(url, data) self.assert_200(response) - assert len(response.json()['organizations']) == 2 + assert len(response.json()["organizations"]) == 2 def test_update_project_add_user_group(self): - project = self.create( - Project, - user_groups=[], - title='TestProject', - role=self.admin_role - ) + project = self.create(Project, user_groups=[], title="TestProject", role=self.admin_role) memberships = ProjectMembership.objects.filter(project=project) initial_member_count = memberships.count() - url = f'/api/v1/projects/{project.id}/project-usergroups/' - data = { - 'usergroup': self.ug1.id, - 'role': self.normal_role.id - } + url = f"/api/v1/projects/{project.id}/project-usergroups/" + data = {"usergroup": self.ug1.id, "role": self.normal_role.id} self.authenticate() response = self.client.post(url, data) @@ -498,187 +474,104 @@ def test_update_project_add_user_group(self): self.assertEqual( initial_member_count + self.ug1.members.all().count() - 1, # -1 because usergroup admin and project admin is common - final_member_count + final_member_count, ) - self.assertEqual(response.data['role_details']['title'], self.normal_role.title) - self.assertEqual(response.data['project'], project.id) + self.assertEqual(response.data["role_details"]["title"], self.normal_role.title) + self.assertEqual(response.data["project"], project.id) def test_update_project_remove_ug(self): - project = self.create( - Project, - title='TestProject', - user_groups=[], - role=self.admin_role - ) + project = self.create(Project, title="TestProject", user_groups=[], role=self.admin_role) # Add usergroups - ProjectUserGroupMembership.objects.create( - usergroup=self.ug1, - project=project - ) - project_ug2 = ProjectUserGroupMembership.objects.create( - usergroup=self.ug2, - project=project - ) + ProjectUserGroupMembership.objects.create(usergroup=self.ug1, project=project) + project_ug2 = ProjectUserGroupMembership.objects.create(usergroup=self.ug2, project=project) - initial_member_count = ProjectMembership.objects.filter( - project=project - ).count() + initial_member_count = ProjectMembership.objects.filter(project=project).count() # We keep just ug1, and remove ug2 - url = f'/api/v1/projects/{project.id}/project-usergroups/{project_ug2.id}/' + url = f"/api/v1/projects/{project.id}/project-usergroups/{project_ug2.id}/" self.authenticate() response = self.client.delete(url) self.assert_204(response) - final_member_count = ProjectMembership.objects.filter( - project=project - ).count() + final_member_count = ProjectMembership.objects.filter(project=project).count() # now check for members self.assertEqual( # Subtract all members from second group except # the two users that are common in both user groups initial_member_count - self.ug2.members.all().count() + 2, - final_member_count + final_member_count, ) def test_duplicate_usergroup_add_in_project(self): - project = self.create( - Project, - title='For test', - user_groups=[], - role=self.admin_role - ) + project = self.create(Project, title="For test", user_groups=[], role=self.admin_role) # add usergroup to the project - ProjectUserGroupMembership.objects.create( - usergroup=self.ug1, - project=project - ) - membership_count = ProjectUserGroupMembership.objects.filter( - project=project - ).count() + ProjectUserGroupMembership.objects.create(usergroup=self.ug1, project=project) + membership_count = ProjectUserGroupMembership.objects.filter(project=project).count() # now try to create same usergroup from api level - data = { - 'usergroup': self.ug1.id, - 'role': self.normal_role.id - } - url = f'/api/v1/projects/{project.id}/project-usergroups/' + data = {"usergroup": self.ug1.id, "role": self.normal_role.id} + url = f"/api/v1/projects/{project.id}/project-usergroups/" self.authenticate() response = self.client.post(url, data) self.assert_400(response) - assert 'errors' in response.data - assert 'usergroup' in response.data['errors'] + assert "errors" in response.data + assert "usergroup" in response.data["errors"] # try deleting the usergroup - ProjectUserGroupMembership.objects.filter( - usergroup=self.ug1, - project=project - ).delete() - self.assertEqual(ProjectUserGroupMembership.objects.filter( - project=project - ).count(), membership_count - 1) + ProjectUserGroupMembership.objects.filter(usergroup=self.ug1, project=project).delete() + self.assertEqual(ProjectUserGroupMembership.objects.filter(project=project).count(), membership_count - 1) # now try to add the same usergroup - data = { - 'usergroup': self.ug1.id, - 'role': self.normal_role.id - } - url = f'/api/v1/projects/{project.id}/project-usergroups/' + data = {"usergroup": self.ug1.id, "role": self.normal_role.id} + url = f"/api/v1/projects/{project.id}/project-usergroups/" self.authenticate() response = self.client.post(url, data) self.assert_201(response) - self.assertEqual(response.data['added_by'], self.user.id) + self.assertEqual(response.data["added_by"], self.user.id) def test_add_user_to_usergroup(self): - project = self.create( - Project, - title='TestProject', - user_groups=[], - role=self.admin_role - ) + project = self.create(Project, title="TestProject", user_groups=[], role=self.admin_role) # Add usergroups - project_ug1 = ProjectUserGroupMembership.objects.create( - usergroup=self.ug1, - project=project - ) - initial_member_count = ProjectMembership.objects.filter( - project=project - ).count() + project_ug1 = ProjectUserGroupMembership.objects.create(usergroup=self.ug1, project=project) + initial_member_count = ProjectMembership.objects.filter(project=project).count() # Create a new user and add it to project_ug1 newUser = self.create(User) from user_group.models import GroupMembership - GroupMembership.objects.create( - member=newUser, - group=project_ug1.usergroup - ) - final_member_count = ProjectMembership.objects.filter( - project=project - ).count() + + GroupMembership.objects.create(member=newUser, group=project_ug1.usergroup) + final_member_count = ProjectMembership.objects.filter(project=project).count() self.assertEqual(initial_member_count + 1, final_member_count) def test_remove_user_in_only_one_usergroup(self): - project = self.create( - Project, - title='TestProject', - user_groups=[], - role=self.admin_role - ) + project = self.create(Project, title="TestProject", user_groups=[], role=self.admin_role) # Add usergroups - project_ug1 = ProjectUserGroupMembership.objects.create( - usergroup=self.ug1, - project=project - ) + project_ug1 = ProjectUserGroupMembership.objects.create(usergroup=self.ug1, project=project) - initial_member_count = ProjectMembership.objects.filter( - project=project - ).count() + initial_member_count = ProjectMembership.objects.filter(project=project).count() from user_group.models import GroupMembership - GroupMembership.objects.filter( - member=self.user1, # user1 belongs to ug1 - group=project_ug1.usergroup - ).delete() + GroupMembership.objects.filter(member=self.user1, group=project_ug1.usergroup).delete() # user1 belongs to ug1 - final_member_count = ProjectMembership.objects.filter( - project=project - ).count() + final_member_count = ProjectMembership.objects.filter(project=project).count() self.assertEqual(initial_member_count - 1, final_member_count) def test_remove_user_in_only_multiple_usergroups(self): - project = self.create( - Project, - title='TestProject', - user_groups=[], - role=self.admin_role - ) + project = self.create(Project, title="TestProject", user_groups=[], role=self.admin_role) # Add usergroups - project_ug1 = ProjectUserGroupMembership.objects.create( - usergroup=self.ug1, - project=project - ) - ProjectUserGroupMembership.objects.create( - usergroup=self.ug2, - project=project - ) + project_ug1 = ProjectUserGroupMembership.objects.create(usergroup=self.ug1, project=project) + ProjectUserGroupMembership.objects.create(usergroup=self.ug2, project=project) - initial_member_count = ProjectMembership.objects.filter( - project=project - ).count() + initial_member_count = ProjectMembership.objects.filter(project=project).count() from user_group.models import GroupMembership - GroupMembership.objects.filter( - member=self.user2, # user1 belongs to ug1 and ug2 - group=project_ug1.usergroup - ).delete() + GroupMembership.objects.filter(member=self.user2, group=project_ug1.usergroup).delete() # user1 belongs to ug1 and ug2 - final_member_count = ProjectMembership.objects.filter( - project=project - ).count() + final_member_count = ProjectMembership.objects.filter(project=project).count() # Should be no change in membeship as user2 is member from ug2 as well self.assertEqual(initial_member_count, final_member_count) @@ -686,40 +579,40 @@ def test_member_of(self): project = self.create(Project, role=self.admin_role) test_user = self.create(User) - url = '/api/v1/projects/member-of/' + url = "/api/v1/projects/member-of/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 1) - self.assertEqual(response.data['results'][0]['id'], project.id) + self.assertEqual(response.data["count"], 1) + self.assertEqual(response.data["results"][0]["id"], project.id) - url = '/api/v1/projects/member-of/?user={}'.format(test_user.id) + url = "/api/v1/projects/member-of/?user={}".format(test_user.id) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 0) + self.assertEqual(response.data["count"], 0) def test_project_of_user(self): test_user = self.create(User) - url = '/api/v1/projects/member-of/?user={}'.format(test_user.id) + url = "/api/v1/projects/member-of/?user={}".format(test_user.id) self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 0) + self.assertEqual(response.data["count"], 0) - url = '/api/v1/projects/member-of/' + url = "/api/v1/projects/member-of/" # authenticate test_user self.authenticate(test_user) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 0) + self.assertEqual(response.data["count"], 0) # Create another project and add test_user to the project project1 = self.create(Project, role=self.admin_role) @@ -730,8 +623,8 @@ def test_project_of_user(self): response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 1) - self.assertEqual(response.data['results'][0]['id'], project1.id) + self.assertEqual(response.data["count"], 1) + self.assertEqual(response.data["results"][0]["id"], project1.id) def test_project_members_view(self): # NOTE: Can only get if member of project @@ -740,12 +633,12 @@ def test_project_members_view(self): test_dummy = self.create(User) project1.add_member(test_user, role=self.admin_role) - url = f'/api/v1/projects/{project1.pk}/members/' + url = f"/api/v1/projects/{project1.pk}/members/" # authenticate test_user self.authenticate(test_user) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 1) + self.assertEqual(response.data["count"], 1) # authenticate test_dummy user self.authenticate(test_dummy) @@ -756,25 +649,23 @@ def test_add_member(self): project = self.create(Project, role=self.admin_role) test_user = self.create(User) - url = f'/api/v1/projects/{project.id}/project-memberships/' + url = f"/api/v1/projects/{project.id}/project-memberships/" data = { - 'member': test_user.pk, - 'role': self.normal_role.id, + "member": test_user.pk, + "role": self.normal_role.id, } self.authenticate() response = self.client.post(url, data) self.assert_201(response) - self.assertEqual(response.data['role'], data['role']) - self.assertEqual(response.data['member'], data['member']) - self.assertEqual(response.data['project'], project.id) - self.assertEqual(response.data['role_details']['title'], self.normal_role.title) - response_id = response.data['id'] - url = f'/api/v1/projects/{project.id}/project-memberships/{response_id}/' - data = { - 'role': self.admin_role.id - } + self.assertEqual(response.data["role"], data["role"]) + self.assertEqual(response.data["member"], data["member"]) + self.assertEqual(response.data["project"], project.id) + self.assertEqual(response.data["role_details"]["title"], self.normal_role.title) + response_id = response.data["id"] + url = f"/api/v1/projects/{project.id}/project-memberships/{response_id}/" + data = {"role": self.admin_role.id} response = self.client.patch(url, data) self.assert_200(response) @@ -782,41 +673,33 @@ def test_add_member_unexistent_role(self): project = self.create(Project, role=self.admin_role) test_user = self.create(User) - url = f'/api/v1/projects/{project.id}/project-memberships/' - data = { - 'member': test_user.pk, - 'role': 9999 - } + url = f"/api/v1/projects/{project.id}/project-memberships/" + data = {"member": test_user.pk, "role": 9999} self.authenticate() response = self.client.post(url, data) self.assert_400(response) - assert 'errors' in response.data + assert "errors" in response.data def test_add_member_duplicate(self): project = self.create(Project, role=self.admin_role) test_user = self.create(User) project.add_member(test_user) - url = f'/api/v1/projects/{project.id}/project-memberships/' - data = { - 'member': test_user.pk - } + url = f"/api/v1/projects/{project.id}/project-memberships/" + data = {"member": test_user.pk} self.authenticate() response = self.client.post(url, data) self.assert_400(response) - assert 'errors' in response.data - assert 'member' in response.data['errors'] + assert "errors" in response.data + assert "member" in response.data["errors"] # try deleting the members and add back again - ProjectMembership.objects.filter( - project=project, - member=test_user - ).delete() + ProjectMembership.objects.filter(project=project, member=test_user).delete() data = { - 'member': test_user.pk, + "member": test_user.pk, } self.authenticate() response = self.client.post(url, data) @@ -828,9 +711,9 @@ def test_project_membership_edit_normal_role(self): test_user = self.create(User) m1 = project.add_member(test_user, role=self.normal_role) data = { - 'role': self.admin_role.id, + "role": self.admin_role.id, } - url = f'/api/v1/projects/{project.id}/project-memberships/{m1.id}/' + url = f"/api/v1/projects/{project.id}/project-memberships/{m1.id}/" self.authenticate() # authenticate with normal_role response = self.client.patch(url, data) self.assert_403(response) @@ -839,10 +722,8 @@ def test_project_membership_edit_admin_role(self): project = self.create(Project, role=self.admin_role) test_user = self.create(User) m1 = project.add_member(test_user, role=self.normal_role) - data = { - 'role': self.admin_role.id - } - url = f'/api/v1/projects/{project.id}/project-memberships/{m1.id}/' + data = {"role": self.admin_role.id} + url = f"/api/v1/projects/{project.id}/project-memberships/{m1.id}/" self.authenticate() # authenticate with admin_role response = self.client.patch(url, data) self.assert_200(response) @@ -853,35 +734,26 @@ def test_project_membership_add(self): test_user1 = self.create(User) test_user2 = self.create(User) project.add_member(test_user2, role=self.normal_role) - data = { - 'member': test_user1.id, - 'role': self.admin_role.id - } - url = f'/api/v1/projects/{project.id}/project-memberships/' + data = {"member": test_user1.id, "role": self.admin_role.id} + url = f"/api/v1/projects/{project.id}/project-memberships/" self.authenticate(test_user2) # test_user2 has normal_role in project response = self.client.post(url, data) self.assert_400(response) def test_options(self): - url = '/api/v1/project-options/' + url = "/api/v1/project-options/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertIn('regions', response.data) - self.assertIn(self.region2.id, [item['key'] for item in response.data['regions']]) - self.assertIn(self.region2.title, [item['value'] for item in response.data['regions']]) - self.assertEqual( - set([item['key'] for item in response.data['regions']]), - set([self.region1.id, self.region2.id]) - ) - self.assertIn('user_groups', response.data) - self.assertIn(self.ug1.id, [item['key'] for item in response.data['user_groups']]) - self.assertIn(self.ug1.title, [item['value'] for item in response.data['user_groups']]) - self.assertEqual( - set([item['key'] for item in response.data['user_groups']]), - set([self.ug1.id, self.ug2.id]) - ) + self.assertIn("regions", response.data) + self.assertIn(self.region2.id, [item["key"] for item in response.data["regions"]]) + self.assertIn(self.region2.title, [item["value"] for item in response.data["regions"]]) + self.assertEqual(set([item["key"] for item in response.data["regions"]]), set([self.region1.id, self.region2.id])) + self.assertIn("user_groups", response.data) + self.assertIn(self.ug1.id, [item["key"] for item in response.data["user_groups"]]) + self.assertIn(self.ug1.title, [item["value"] for item in response.data["user_groups"]]) + self.assertEqual(set([item["key"] for item in response.data["user_groups"]]), set([self.ug1.id, self.ug2.id])) def test_particular_project_in_project_options(self): user = self.create_user() @@ -897,81 +769,71 @@ def test_particular_project_in_project_options(self): ProjectUserGroupMembership.objects.create(project=project, usergroup=usergroup1) ProjectUserGroupMembership.objects.create(project=project, usergroup=usergroup2) - url = f'/api/v1/project-options/?project={project.id}' + url = f"/api/v1/project-options/?project={project.id}" self.authenticate(user) response = self.client.get(url) self.assert_200(response) - self.assertIn(region2.id, [item['key'] for item in response.data['regions']]) - self.assertIn(region2.title, [item['value'] for item in response.data['regions']]) + self.assertIn(region2.id, [item["key"] for item in response.data["regions"]]) + self.assertIn(region2.title, [item["value"] for item in response.data["regions"]]) # here response consists of the regions for the user # which are public or project__member or created_by self.assertEqual( - set([item['key'] for item in response.data['regions']]), - set([region2.id, region1.id, self.region1.id, self.region2.id]) - ) - self.assertIn('user_groups', response.data) - self.assertIn(usergroup2.id, [item['key'] for item in response.data['user_groups']]) - self.assertIn(usergroup2.title, [item['value'] for item in response.data['user_groups']]) - self.assertEqual( - set([item['key'] for item in response.data['user_groups']]), - set([usergroup1.id, usergroup2.id]) + set([item["key"] for item in response.data["regions"]]), + set([region2.id, region1.id, self.region1.id, self.region2.id]), ) + self.assertIn("user_groups", response.data) + self.assertIn(usergroup2.id, [item["key"] for item in response.data["user_groups"]]) + self.assertIn(usergroup2.title, [item["value"] for item in response.data["user_groups"]]) + self.assertEqual(set([item["key"] for item in response.data["user_groups"]]), set([usergroup1.id, usergroup2.id])) def test_project_status_in_project_options(self): choices = dict(make_hashable(Project.Status.choices)) - url = '/api/v1/project-options/' + url = "/api/v1/project-options/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertIn('project_status', response.data) - self.assertEqual(response.data['project_status'][0]['key'], Project.Status.ACTIVE) - self.assertEqual(response.data['project_status'][0]['value'], choices[Project.Status.ACTIVE]) - self.assertEqual(response.data['project_status'][1]['key'], Project.Status.INACTIVE) - self.assertEqual(response.data['project_status'][1]['value'], choices[Project.Status.INACTIVE]) + self.assertIn("project_status", response.data) + self.assertEqual(response.data["project_status"][0]["key"], Project.Status.ACTIVE) + self.assertEqual(response.data["project_status"][0]["value"], choices[Project.Status.ACTIVE]) + self.assertEqual(response.data["project_status"][1]["key"], Project.Status.INACTIVE) + self.assertEqual(response.data["project_status"][1]["value"], choices[Project.Status.INACTIVE]) def test_join_request(self): project = self.create(Project, role=self.admin_role) test_user = self.create(User) data = dict( - reason='bla', + reason="bla", ) - url = '/api/v1/projects/{}/join/'.format(project.id) + url = "/api/v1/projects/{}/join/".format(project.id) self.authenticate(test_user) response = self.client.post(url, data) self.assert_201(response) - self.assertEqual(response.data['project']['id'], project.id) - self.assertEqual(response.data['requested_by']['id'], test_user.id) - self.assertEqual( - ProjectJoinRequest.objects.get(id=response.data['id']).data['reason'], - data['reason'] - ) + self.assertEqual(response.data["project"]["id"], project.id) + self.assertEqual(response.data["requested_by"]["id"], test_user.id) + self.assertEqual(ProjectJoinRequest.objects.get(id=response.data["id"]).data["reason"], data["reason"]) def test_invalid_join_request(self): project = self.create(Project, role=self.admin_role) test_user = self.create(User) - url = '/api/v1/projects/{}/join/'.format(project.id) + url = "/api/v1/projects/{}/join/".format(project.id) self.authenticate(test_user) response = self.client.post(url) self.assert_400(response) - self.assertIn('reason', response.data['errors']) + self.assertIn("reason", response.data["errors"]) def test_accept_request(self): project = self.create(Project, role=self.admin_role) test_user = self.create(User) - request = ProjectJoinRequest.objects.create( - project=project, - requested_by=test_user, - role=self.admin_role - ) + request = ProjectJoinRequest.objects.create(project=project, requested_by=test_user, role=self.admin_role) - url = '/api/v1/projects/{}/requests/{}/accept/'.format( + url = "/api/v1/projects/{}/requests/{}/accept/".format( project.id, request.id, ) @@ -980,8 +842,8 @@ def test_accept_request(self): response = self.client.post(url) self.assert_200(response) - self.assertEqual(response.data['responded_by']['id'], self.user.id) - self.assertEqual(response.data['status'], 'accepted') + self.assertEqual(response.data["responded_by"]["id"], self.user.id) + self.assertEqual(response.data["status"], "accepted") membership = ProjectMembership.objects.filter( project=project, member=test_user, @@ -992,13 +854,9 @@ def test_accept_request(self): def test_reject_request(self): project = self.create(Project, role=self.admin_role) test_user = self.create(User) - request = ProjectJoinRequest.objects.create( - project=project, - requested_by=test_user, - role=self.admin_role - ) + request = ProjectJoinRequest.objects.create(project=project, requested_by=test_user, role=self.admin_role) - url = '/api/v1/projects/{}/requests/{}/reject/'.format( + url = "/api/v1/projects/{}/requests/{}/reject/".format( project.id, request.id, ) @@ -1007,25 +865,17 @@ def test_reject_request(self): response = self.client.post(url) self.assert_200(response) - self.assertEqual(response.data['responded_by']['id'], self.user.id) - self.assertEqual(response.data['status'], 'rejected') - membership = ProjectMembership.objects.filter( - project=project, - member=test_user, - role=self.normal_role - ) + self.assertEqual(response.data["responded_by"]["id"], self.user.id) + self.assertEqual(response.data["status"], "rejected") + membership = ProjectMembership.objects.filter(project=project, member=test_user, role=self.normal_role) self.assertEqual(membership.count(), 0) def test_cancel_request(self): project = self.create(Project, role=self.admin_role) test_user = self.create(User) - request = ProjectJoinRequest.objects.create( - project=project, - requested_by=test_user, - role=self.admin_role - ) + request = ProjectJoinRequest.objects.create(project=project, requested_by=test_user, role=self.admin_role) - url = '/api/v1/projects/{}/join/cancel/'.format(project.id) + url = "/api/v1/projects/{}/join/cancel/".format(project.id) self.authenticate(test_user) response = self.client.post(url) @@ -1040,18 +890,18 @@ def test_list_request(self): self.create(ProjectJoinRequest, project=project) self.create(ProjectJoinRequest, project=project) - url = '/api/v1/projects/{}/requests/'.format(project.id) + url = "/api/v1/projects/{}/requests/".format(project.id) self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 3) - self.assertEqual(response.data['count'], 3) + self.assertEqual(len(response.data["results"]), 3) + self.assertEqual(response.data["count"], 3) def test_delete_project_admin(self): project = self.create(Project, role=self.admin_role) - url = '/api/v1/projects/{}/'.format(project.id) + url = "/api/v1/projects/{}/".format(project.id) self.authenticate() response = self.client.delete(url) self.assert_204(response) @@ -1062,7 +912,7 @@ def test_delete_project_normal(self): project.add_member(user) - url = '/api/v1/projects/{}/'.format(project.id) + url = "/api/v1/projects/{}/".format(project.id) self.authenticate(user) response = self.client.delete(url) @@ -1073,7 +923,7 @@ def test_get_project_role(self): user = self.create(User) project.add_member(user) - url = '/api/v1/project-roles/' + url = "/api/v1/project-roles/" self.authenticate() @@ -1101,45 +951,39 @@ def test_auto_accept(self): # request for that user, auto accept that request project = self.create(Project, role=self.admin_role) test_user = self.create(User) - request = ProjectJoinRequest.objects.create( - project=project, - requested_by=test_user, - role=self.admin_role - ) + request = ProjectJoinRequest.objects.create(project=project, requested_by=test_user, role=self.admin_role) project.add_member(test_user, self.normal_role, self.user) request = ProjectJoinRequest.objects.get(id=request.id) - self.assertEqual(request.status, 'accepted') + self.assertEqual(request.status, "accepted") self.assertEqual(request.responded_by, self.user) def test_status_filter(self): - project1 = self.create(Project, role=self.admin_role, status='active') - self.create(Project, role=self.admin_role, status='inactive') - self.create(Project, role=self.admin_role, status='inactive') + project1 = self.create(Project, role=self.admin_role, status="active") + self.create(Project, role=self.admin_role, status="inactive") + self.create(Project, role=self.admin_role, status="inactive") test_user = self.create(User) project1.add_member(test_user, role=self.admin_role) - url = '/api/v1/projects/?status=inactive' + url = "/api/v1/projects/?status=inactive" self.authenticate(test_user) response = self.client.get(url) - self.assertEqual(response.data['count'], 2) + self.assertEqual(response.data["count"], 2) # try filtering out the active status - url = '/api/v1/projects/?status=active' + url = "/api/v1/projects/?status=active" self.authenticate(test_user) response = self.client.get(url) - self.assertEqual(response.data['count'], 1) + self.assertEqual(response.data["count"], 1) # try to update the status of the project - data = { - 'status': 'active' - } - url1 = f'/api/v1/projects/{project1.id}/' + data = {"status": "active"} + url1 = f"/api/v1/projects/{project1.id}/" self.authenticate(test_user) response = self.client.patch(url1, data) self.assert_200(response) - self.assertEqual(response.data['status'], project1.status) + self.assertEqual(response.data["status"], project1.status) def test_involvment_filter(self): project1 = self.create(Project, role=self.admin_role) @@ -1150,22 +994,19 @@ def test_involvment_filter(self): project1.add_member(test_user, role=self.normal_role) project2.add_member(test_user, role=self.normal_role) - url = '/api/v1/projects/?involvement=my_projects' + url = "/api/v1/projects/?involvement=my_projects" self.authenticate(test_user) response = self.client.get(url) self.assert_200(response) - expected = [ - project1.id, - project2.id - ] - obtained = [r['id'] for r in response.data['results']] + expected = [project1.id, project2.id] + obtained = [r["id"] for r in response.data["results"]] - self.assertEqual(response.data['count'], len(expected)) + self.assertEqual(response.data["count"], len(expected)) self.assertTrue(sorted(expected) == sorted(obtained)) - url = '/api/v1/projects/?involvement=not_my_projects' + url = "/api/v1/projects/?involvement=not_my_projects" self.authenticate(test_user) response = self.client.get(url) @@ -1174,9 +1015,9 @@ def test_involvment_filter(self): expected = [ project3.id, ] - obtained = [r['id'] for r in response.data['results']] + obtained = [r["id"] for r in response.data["results"]] - self.assertEqual(response.data['count'], len(expected)) + self.assertEqual(response.data["count"], len(expected)) self.assertTrue(sorted(expected) == sorted(obtained)) def test_project_role_level(self): @@ -1186,8 +1027,8 @@ def test_project_role_level(self): m1 = project.add_member(test_user1, role=self.normal_role) m2 = project.add_member(test_user2, role=self.admin_role) - url1 = f'/api/v1/projects/{project.id}/project-memberships/{m1.id}/' - url2 = f'/api/v1/projects/{project.id}/project-memberships/{m2.id}/' + url1 = f"/api/v1/projects/{project.id}/project-memberships/{m1.id}/" + url2 = f"/api/v1/projects/{project.id}/project-memberships/{m2.id}/" # Initial condition: We are Admin self.authenticate() @@ -1195,7 +1036,7 @@ def test_project_role_level(self): # Condition 1: We are trying to change a normal # user's role to Clairvaoyant One data = { - 'role': self.admin_role.id, + "role": self.admin_role.id, } response = self.client.patch(url1, data) self.assert_400(response) @@ -1203,7 +1044,7 @@ def test_project_role_level(self): # Condition 2: We are trying to change a normal # user's role to Admin data = { - 'role': self.smaller_admin_role.id, + "role": self.smaller_admin_role.id, } response = self.client.patch(url1, data) self.assert_200(response) @@ -1211,7 +1052,7 @@ def test_project_role_level(self): # Condition 3: We are trying to change a CO user # when he/she is the only CO user in the project data = { - 'role': self.smaller_admin_role.id, + "role": self.smaller_admin_role.id, } response = self.client.patch(url2, data) self.assert_403(response) @@ -1242,13 +1083,13 @@ def test_project_role_level(self): self.assert_204(response) def _change_project_privacy_test(self, project, status=403, user=None): - url = f'/api/v1/projects/{project.id}/' + url = f"/api/v1/projects/{project.id}/" changed_privacy = not project.is_private put_data = { - 'title': project.title, - 'is_private': changed_privacy, - 'organizations': [], + "title": project.title, + "is_private": changed_privacy, + "organizations": [], # Other fields we don't care } self.authenticate(user) @@ -1256,7 +1097,7 @@ def _change_project_privacy_test(self, project, status=403, user=None): self.assertEqual(response.status_code, status) # Try patching, should give 403 as well - patch_data = {'is_private': changed_privacy} + patch_data = {"is_private": changed_privacy} response = self.client.patch(url, patch_data) self.assertEqual(response.status_code, status) @@ -1274,41 +1115,46 @@ def test_project_stats(self): lead = self.create(Lead, project=project) entry = self.create( Entry, - project=project, analysis_framework=af, lead=lead, entry_type=Entry.TagType.EXCERPT, + project=project, + analysis_framework=af, + lead=lead, + entry_type=Entry.TagType.EXCERPT, ) # Create widgets, attributes and configs invalid_stat_config = {} valid_stat_config = {} - for index, (title, widget_identifier, data_identifier, config_kwargs) in enumerate([ - ('widget 1d', 'widget_1d', 'matrix1dWidget', {}), - ('widget 2d', 'widget_2d', 'matrix2dWidget', {}), - ('geo widget', 'geo_widget', 'geoWidget', {}), - ('reliability widget', 'reliability_widget', 'scaleWidget', {}), - ('affected groups widget', 'affected_groups_widget', 'multiselectWidget', {}), - ('specific needs groups widget', 'specific_needs_groups_widget', 'multiselectWidget', {}), - ]): + for index, (title, widget_identifier, data_identifier, config_kwargs) in enumerate( + [ + ("widget 1d", "widget_1d", "matrix1dWidget", {}), + ("widget 2d", "widget_2d", "matrix2dWidget", {}), + ("geo widget", "geo_widget", "geoWidget", {}), + ("reliability widget", "reliability_widget", "scaleWidget", {}), + ("affected groups widget", "affected_groups_widget", "multiselectWidget", {}), + ("specific needs groups widget", "specific_needs_groups_widget", "multiselectWidget", {}), + ] + ): widget = self.create( Widget, analysis_framework=af, section=None, title=title, widget_id=data_identifier, - key=f'{data_identifier}-{index}', + key=f"{data_identifier}-{index}", properties=w_data[data_identifier], ) self.create(Attribute, entry=entry, widget=widget, data=a_data[data_identifier]) valid_stat_config[widget_identifier] = { - 'pk': widget.pk, + "pk": widget.pk, **config_kwargs, } - invalid_stat_config[widget_identifier] = {'pk': 0} - if data_identifier in ['matrix1dWidget', 'matrix2dWidget', 'multiselectWidget']: + invalid_stat_config[widget_identifier] = {"pk": 0} + if data_identifier in ["matrix1dWidget", "matrix2dWidget", "multiselectWidget"]: valid_stat_config[widget_identifier] = [valid_stat_config[widget_identifier]] invalid_stat_config[widget_identifier] = [invalid_stat_config[widget_identifier]] - url = f'/api/v1/projects/{project.pk}/project-viz/' + url = f"/api/v1/projects/{project.pk}/project-viz/" # 404 for non project user self.authenticate(non_project_user) response = self.client.get(url) @@ -1320,7 +1166,7 @@ def test_project_stats(self): response = self.client.get(url) self.assert_404(response) - af.properties = {'stats_config': invalid_stat_config} + af.properties = {"stats_config": invalid_stat_config} af.save() # 202 if config is set @@ -1331,16 +1177,16 @@ def test_project_stats(self): _generate_project_viz_stats(project.pk) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.json()['status'], 'failure') + self.assertEqual(response.json()["status"], "failure") - af.properties = {'stats_config': valid_stat_config} + af.properties = {"stats_config": valid_stat_config} af.save() # 302 (Redirect to data file) if valid config is set and stat is generated _generate_project_viz_stats(project.pk) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.json()['status'], 'success') + self.assertEqual(response.json()["status"], "success") return project def test_project_lead_groups_api(self): @@ -1348,14 +1194,13 @@ def test_project_lead_groups_api(self): lead_group1 = self.create(LeadGroup, project=project) lead_group2 = self.create(LeadGroup, project=project) - url = f'/api/v1/projects/{project.pk}/lead-groups/' + url = f"/api/v1/projects/{project.pk}/lead-groups/" self.authenticate() response = self.client.get(url) self.assert_200(response) # Only provide projects leads-group [Pagination is done for larger dataset] - assert set([lg['id'] for lg in response.json()['results']]) ==\ - set([lead_group1.pk, lead_group2.pk]) + assert set([lg["id"] for lg in response.json()["results"]]) == set([lead_group1.pk, lead_group2.pk]) def test_project_memberships_if_not_in_project(self): """ @@ -1366,13 +1211,13 @@ def test_project_memberships_if_not_in_project(self): user2 = self.create(User) project.add_member(user1, role=self.admin_role) - url = '/api/v1/projects/' + url = "/api/v1/projects/" self.authenticate(user2) # authenticate with another user that is not project member response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 1) # there should be one project - self.assertEqual(response.data['results'][0]['id'], project.id) - self.assertNotIn('memberships', response.data['results'][0]) # No memberships field should be shown + self.assertEqual(response.data["count"], 1) # there should be one project + self.assertEqual(response.data["results"][0]["id"], project.id) + self.assertNotIn("memberships", response.data["results"][0]) # No memberships field should be shown def test_project_memberships_in_particluar_project(self): project1 = self.create(Project, is_private=False) @@ -1380,29 +1225,29 @@ def test_project_memberships_in_particluar_project(self): user2 = self.create(User) project1.add_member(user1, role=self.admin_role) - url = f'/api/v1/projects/{project1.id}/' + url = f"/api/v1/projects/{project1.id}/" self.authenticate(user1) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['id'], project1.id) - self.assertIn('memberships', response.data) - self.assertEqual(response.data['memberships'][0]['member'], user1.id) + self.assertEqual(response.data["id"], project1.id) + self.assertIn("memberships", response.data) + self.assertEqual(response.data["memberships"][0]["member"], user1.id) # same project authenticate with not member user self.authenticate(user2) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['id'], project1.id) - self.assertNotIn('memberships', response.data) # `membership` field shouldnot be present + self.assertEqual(response.data["id"], project1.id) + self.assertNotIn("memberships", response.data) # `membership` field shouldnot be present def test_project_summary_api(self): user = self.create_user() - project1 = self.create_project(title='Project 1') - project2 = self.create_project(title='Project 2') - project3 = self.create_project(title='Project 3') - project4 = self.create_project(title='Project 4') - project5 = self.create_project(title='Project 5') + project1 = self.create_project(title="Project 1") + project2 = self.create_project(title="Project 2") + project3 = self.create_project(title="Project 3") + project4 = self.create_project(title="Project 4") + project5 = self.create_project(title="Project 5") project1.add_member(user) project2.add_member(user) project3.add_member(user) @@ -1420,66 +1265,16 @@ def test_project_summary_api(self): self.create_lead(project=project5) data = [ - { - "lead": lead1, - "controlled": True, - "months": -3, - "days": -1 - }, - { - "lead": lead1, - "controlled": True, - "months": -2, - "days": -1 - }, - { - "lead": lead2, - "controlled": False, - "months": -3, - "days": -1 - }, - { - "lead": lead2, - "controlled": True, - "months": -3, - "days": -1 - }, - { - "lead": lead2, - "controlled": True, - "months": -3, - "days": -1 - }, - { - "lead": lead3, - "controlled": True, - "months": -1, - "days": -10 - }, - { - "lead": lead3, - "controlled": True, - "months": -1, - "days": -20 - }, - { - "lead": lead3, - "controlled": True, - "months": -1, - "days": -30 - }, - { - "lead": lead3, - "controlled": True, - "months": -1, - "days": -40 - }, - { - "lead": lead4, - "controlled": False, - "months": -3, - "days": -1 - }, + {"lead": lead1, "controlled": True, "months": -3, "days": -1}, + {"lead": lead1, "controlled": True, "months": -2, "days": -1}, + {"lead": lead2, "controlled": False, "months": -3, "days": -1}, + {"lead": lead2, "controlled": True, "months": -3, "days": -1}, + {"lead": lead2, "controlled": True, "months": -3, "days": -1}, + {"lead": lead3, "controlled": True, "months": -1, "days": -10}, + {"lead": lead3, "controlled": True, "months": -1, "days": -20}, + {"lead": lead3, "controlled": True, "months": -1, "days": -30}, + {"lead": lead3, "controlled": True, "months": -1, "days": -40}, + {"lead": lead4, "controlled": False, "months": -3, "days": -1}, { "lead": lead5, "controlled": True, @@ -1492,12 +1287,7 @@ def test_project_summary_api(self): "months": -2, "days": -1, }, - { - "lead": lead6, - "controlled": True, - "months": -3, - "days": -1 - }, + {"lead": lead6, "controlled": True, "months": -3, "days": -1}, { "lead": lead7, "controlled": True, @@ -1508,35 +1298,35 @@ def test_project_summary_api(self): now = timezone.now() for item in data: self.update_obj( - self.create_entry(lead=item['lead'], controlled=item['controlled'], created_by=user), - created_at=now + relativedelta(months=item['months'], days=item['days']) + self.create_entry(lead=item["lead"], controlled=item["controlled"], created_by=user), + created_at=now + relativedelta(months=item["months"], days=item["days"]), ) # Run the caching process _generate_project_stats_cache() self.authenticate(user) - url = '/api/v1/projects-stat/summary/' + url = "/api/v1/projects-stat/summary/" response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['projects_count'], 5) - self.assertEqual(response.data['total_leads_count'], 9) - self.assertEqual(response.data['total_leads_tagged_count'], 7) - self.assertEqual(response.data['total_leads_tagged_and_controlled_count'], 5) - self.assertEqual(len(response.data['recent_entries_activity']['projects']), 3) - self.assertEqual(response.data['recent_entries_activity']['projects'][0]['id'], project1.id) - self.assertEqual(response.data['recent_entries_activity']['projects'][0]['count'], 1) - self.assertEqual(response.data['recent_entries_activity']['projects'][1]['id'], project2.id) - self.assertEqual(response.data['recent_entries_activity']['projects'][1]['count'], 4) - self.assertEqual(len(response.data['recent_entries_activity']['activities']), 6) + self.assertEqual(response.data["projects_count"], 5) + self.assertEqual(response.data["total_leads_count"], 9) + self.assertEqual(response.data["total_leads_tagged_count"], 7) + self.assertEqual(response.data["total_leads_tagged_and_controlled_count"], 5) + self.assertEqual(len(response.data["recent_entries_activity"]["projects"]), 3) + self.assertEqual(response.data["recent_entries_activity"]["projects"][0]["id"], project1.id) + self.assertEqual(response.data["recent_entries_activity"]["projects"][0]["count"], 1) + self.assertEqual(response.data["recent_entries_activity"]["projects"][1]["id"], project2.id) + self.assertEqual(response.data["recent_entries_activity"]["projects"][1]["count"], 4) + self.assertEqual(len(response.data["recent_entries_activity"]["activities"]), 6) def test_project_recent_api(self): user = self.create_user() - project1 = self.create_project(title='Project 1') - project2 = self.create_project(title='Project 2') - project3 = self.create_project(title='Project 3') - project4 = self.create_project(title='Project 4') + project1 = self.create_project(title="Project 1") + project2 = self.create_project(title="Project 2") + project3 = self.create_project(title="Project 3") + project4 = self.create_project(title="Project 4") project1.add_member(user) project2.add_member(user) project3.add_member(user) @@ -1548,27 +1338,27 @@ def test_project_recent_api(self): self.create_lead(project=project4, created_by=user) self.authenticate(user) - url = '/api/v1/projects-stat/recent/' + url = "/api/v1/projects-stat/recent/" response = self.client.get(url) self.assert_200(response) self.assertEqual(len(response.data), 3) - self.assertEqual(response.data[0]['id'], project3.pk) - self.assertEqual(response.data[1]['id'], project1.pk) - self.assertEqual(response.data[2]['id'], project2.pk) + self.assertEqual(response.data[0]["id"], project3.pk) + self.assertEqual(response.data[1]["id"], project1.pk) + self.assertEqual(response.data[2]["id"], project2.pk) lead2.modified_by = user lead2.save() response = self.client.get(url) self.assert_200(response) self.assertEqual(len(response.data), 3) - self.assertEqual(response.data[0]['id'], project2.pk) + self.assertEqual(response.data[0]["id"], project2.pk) def test_project_stats_api(self): user = self.create_user() - project1 = self.create_project(title='Project 1') - project2 = self.create_project(title='Project 2') - project3 = self.create_project(title='Project 3') + project1 = self.create_project(title="Project 1") + project2 = self.create_project(title="Project 2") + project3 = self.create_project(title="Project 3") project1.add_member(user) project2.add_member(user) @@ -1596,28 +1386,26 @@ def test_project_stats_api(self): # number_of_leads_tagged lead1_2.status = lead1_1.status = Lead.Status.TAGGED - lead1_2.save(update_fields=('status',)) - lead1_1.save(update_fields=('status',)) + lead1_2.save(update_fields=("status",)) + lead1_1.save(update_fields=("status",)) # Run the caching process _generate_project_stats_cache() self.authenticate(user) - url = '/api/v1/projects-stat/?involvement=my_projects' + url = "/api/v1/projects-stat/?involvement=my_projects" response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 2) + self.assertEqual(len(response.data["results"]), 2) # Check response for Project 1 - project_1_data = next( - project for project in response.data['results'] if project['id'] == project1.pk - ) - self.assertEqual(project_1_data['id'], project1.pk) - self.assertEqual(project_1_data['number_of_leads'], 5) - self.assertEqual(project_1_data['number_of_leads_tagged'], 2) - self.assertEqual(project_1_data['number_of_leads_tagged_and_controlled'], 1) - self.assertEqual(project_1_data['number_of_entries'], 9) - self.assertEqual(len(project_1_data['leads_activity']), 2) - self.assertEqual(len(project_1_data['entries_activity']), 3) + project_1_data = next(project for project in response.data["results"] if project["id"] == project1.pk) + self.assertEqual(project_1_data["id"], project1.pk) + self.assertEqual(project_1_data["number_of_leads"], 5) + self.assertEqual(project_1_data["number_of_leads_tagged"], 2) + self.assertEqual(project_1_data["number_of_leads_tagged_and_controlled"], 1) + self.assertEqual(project_1_data["number_of_entries"], 9) + self.assertEqual(len(project_1_data["leads_activity"]), 2) + self.assertEqual(len(project_1_data["entries_activity"]), 3) def test_project_stats_public_api(self): normal_user = self.create_user() @@ -1628,10 +1416,10 @@ def test_project_stats_public_api(self): project.add_member(admin_user, role=self.admin_role) project.add_member(member_user, role=self.normal_role) - url = f'/api/v1/projects/{project.pk}/public-viz/' + url = f"/api/v1/projects/{project.pk}/public-viz/" # Check permission for token generation - for action in ['new', 'off', 'new', 'on', 'random']: + for action in ["new", "off", "new", "on", "random"]: for user, assertLogic in [ (normal_user, self.assert_403), (member_user, self.assert_403), @@ -1639,34 +1427,26 @@ def test_project_stats_public_api(self): ]: self.authenticate(user) current_stats = ProjectStats.objects.get(project=project) - response = self.client.post(url, data={'action': action}) - if action == 'random' and assertLogic == self.assert_200: + response = self.client.post(url, data={"action": action}) + if action == "random" and assertLogic == self.assert_200: self.assert_400(response) else: assertLogic(response) if assertLogic == self.assert_200: - if action == 'new': - assert response.data['public_url'] != current_stats.token + if action == "new": + assert response.data["public_url"] != current_stats.token # Logout and check if response is okay self.client.logout() response = self.client.get(f"{response.data['public_url']}?format=json") self.assert_200(response) - elif action == 'on': - assert ( - response.data['public_url'] is not None - ) or ( - response.data['public_url'] == current_stats.token - ) + elif action == "on": + assert (response.data["public_url"] is not None) or (response.data["public_url"] == current_stats.token) # Logout and check if response is not okay self.client.logout() response = self.client.get(f"{response.data['public_url']}?format=json") self.assert_200(response) - elif action == 'off': - assert ( - response.data['public_url'] is not None - ) or ( - response.data['public_url'] == current_stats.token - ) + elif action == "off": + assert (response.data["public_url"] is not None) or (response.data["public_url"] == current_stats.token) # Logout and check if response is not okay self.client.logout() response = self.client.get(f"{response.data['public_url']}?format=json") @@ -1697,7 +1477,7 @@ def test_project_recent_activities_api(self): normal_user = self.create_user() member_user = self.create_user() - project = self.create_project(title='Project 1') + project = self.create_project(title="Project 1") project.add_member(member_user) now = timezone.now() @@ -1718,14 +1498,14 @@ def test_project_recent_activities_api(self): # Entries Comments self.create(EntryReviewComment, entry=entry) - url = '/api/v1/projects/recent-activities/' + url = "/api/v1/projects/recent-activities/" self.authenticate(normal_user) response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 0) + self.assertEqual(len(response.data["results"]), 0) self.authenticate(member_user) response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 11) + self.assertEqual(len(response.data["results"]), 11) diff --git a/apps/project/tests/test_filters.py b/apps/project/tests/test_filters.py index cecbb02b4a..7c8e17d836 100644 --- a/apps/project/tests/test_filters.py +++ b/apps/project/tests/test_filters.py @@ -1,12 +1,11 @@ -from utils.graphene.tests import GraphQLTestCase - -from project.factories import ProjectFactory -from organization.factories import OrganizationFactory from analysis_framework.factories import AnalysisFrameworkFactory - +from organization.factories import OrganizationFactory +from project.factories import ProjectFactory from project.filter_set import ProjectGqlFilterSet from project.models import Project +from utils.graphene.tests import GraphQLTestCase + class TestProjectFilter(GraphQLTestCase): def setUp(self) -> None: @@ -20,43 +19,27 @@ def test_organization_filter(self): p2.organizations.set([org2, org3]) p3.organizations.add(org1) - obtained = self.filter_class(data=dict( - organizations=[org3.pk, org2.pk] - )).qs + obtained = self.filter_class(data=dict(organizations=[org3.pk, org2.pk])).qs expected = [p2, p1] - self.assertQuerySetIdEqual( - expected, - obtained - ) + self.assertQuerySetIdEqual(expected, obtained) def test_search_filter(self): - ProjectFactory.create(title='one') - p2 = ProjectFactory.create(title='two') - p3 = ProjectFactory.create(title='twoo') - obtained = self.filter_class(data=dict( - search='tw' - )).qs + ProjectFactory.create(title="one") + p2 = ProjectFactory.create(title="two") + p3 = ProjectFactory.create(title="twoo") + obtained = self.filter_class(data=dict(search="tw")).qs expected = [p2, p3] - self.assertQuerySetIdEqual( - expected, - obtained - ) + self.assertQuerySetIdEqual(expected, obtained) def test_status_filter(self): p1, p2 = ProjectFactory.create_batch(2, status=Project.Status.ACTIVE) p3 = ProjectFactory.create(status=Project.Status.INACTIVE) obtained = self.filter_class(data=dict(status=Project.Status.ACTIVE.value)).qs expected = [p1, p2] - self.assertQuerySetIdEqual( - expected, - obtained - ) + self.assertQuerySetIdEqual(expected, obtained) obtained = self.filter_class(data=dict(status=Project.Status.INACTIVE.value)).qs expected = [p3] - self.assertQuerySetIdEqual( - expected, - obtained - ) + self.assertQuerySetIdEqual(expected, obtained) def test_analysis_framework_filter(self): af1, af2, af3 = AnalysisFrameworkFactory.create_batch(3) @@ -64,11 +47,6 @@ def test_analysis_framework_filter(self): p2 = ProjectFactory.create(analysis_framework=af2) ProjectFactory.create(analysis_framework=af3) - obtained = self.filter_class(data=dict( - analysis_frameworks=[af1.id, af2.id] - )).qs + obtained = self.filter_class(data=dict(analysis_frameworks=[af1.id, af2.id])).qs expected = [p2, p1] - self.assertQuerySetIdEqual( - expected, - obtained - ) + self.assertQuerySetIdEqual(expected, obtained) diff --git a/apps/project/tests/test_migration.py b/apps/project/tests/test_migration.py index d91544a67f..5a09eea52c 100644 --- a/apps/project/tests/test_migration.py +++ b/apps/project/tests/test_migration.py @@ -1,51 +1,54 @@ -from utils.graphene.tests import GraphQLTestCase from project.factories import ProjectFactory from project.migrations.rename_duplicate_project_name import _rename_duplicate_name -from project.migrations.set_istest_true_for_test_projects import _set_istest_true_for_test_projects +from project.migrations.set_istest_true_for_test_projects import ( + _set_istest_true_for_test_projects, +) from project.models import Project +from utils.graphene.tests import GraphQLTestCase + class TestProjectMigrations(GraphQLTestCase): def test_rename_duplicate_projects(self): project1, project2 = ProjectFactory.create_batch(2, title="Ukraine war") project3, project4 = ProjectFactory.create_batch(2, title="Nepal Food Crisis") project5, project6 = ProjectFactory.create_batch(2, title="Iran Bombblast") - project9 = ProjectFactory(title='Iran Bombblast (2)') - project10 = ProjectFactory(title='Iran Bombblast (3)') - project11 = ProjectFactory(title='Japan Earthquake') - project12 = ProjectFactory(title='Japan Hurricane') + project9 = ProjectFactory(title="Iran Bombblast (2)") + project10 = ProjectFactory(title="Iran Bombblast (3)") + project11 = ProjectFactory(title="Japan Earthquake") + project12 = ProjectFactory(title="Japan Hurricane") excepted_projects_name = { - project1.pk: 'Ukraine war (1)', - project2.pk: 'Ukraine war (2)', - project3.pk: 'Nepal Food Crisis (1)', - project4.pk: 'Nepal Food Crisis (2)', - project5.pk: 'Iran Bombblast (1)', - project6.pk: 'Iran Bombblast (4)', - project9.pk: 'Iran Bombblast (2)', - project10.pk: 'Iran Bombblast (3)', - project11.pk: 'Japan Earthquake', - project12.pk: 'Japan Hurricane', + project1.pk: "Ukraine war (1)", + project2.pk: "Ukraine war (2)", + project3.pk: "Nepal Food Crisis (1)", + project4.pk: "Nepal Food Crisis (2)", + project5.pk: "Iran Bombblast (1)", + project6.pk: "Iran Bombblast (4)", + project9.pk: "Iran Bombblast (2)", + project10.pk: "Iran Bombblast (3)", + project11.pk: "Japan Earthquake", + project12.pk: "Japan Hurricane", } _rename_duplicate_name(Project) - for id, title in Project.objects.values_list('id', 'title'): + for id, title in Project.objects.values_list("id", "title"): assert excepted_projects_name[id] == title def test_set_istest_true_for_test_projects(self): project_titles = { - 'test project': True, - 'Test': True, - 'test project': True, - 'Testing project': True, - 'Test1': True, - 'Test2': True, - 'TestTestTest': True, - 'testing project': True, - 'test1 project': True, - 'UNHCR': False, - 'Relief Web': False, + "test project": True, + "Test": True, + "test project": True, + "Testing project": True, + "Test1": True, + "Test2": True, + "TestTestTest": True, + "testing project": True, + "test1 project": True, + "UNHCR": False, + "Relief Web": False, } for title in project_titles.keys(): ProjectFactory(title=title) diff --git a/apps/project/tests/test_mutations.py b/apps/project/tests/test_mutations.py index e6c5b2a68e..4586b4d243 100644 --- a/apps/project/tests/test_mutations.py +++ b/apps/project/tests/test_mutations.py @@ -1,43 +1,42 @@ -from unittest import mock from datetime import timedelta -from factory import fuzzy - -from utils.graphene.tests import GraphQLTestCase, GraphQLSnapShotTestCase -from user.utils import ( - send_project_join_request_emails, - send_project_accept_email, - send_project_reject_email, -) - -from user.models import Feature -from notification.models import Notification -from project.models import ( - get_default_role_id, - ProjectRole, - ProjectJoinRequest, - ProjectMembership, - ProjectUserGroupMembership, - ProjectStats, - ProjectChangeLog, -) +from unittest import mock -from user.factories import UserFactory, FeatureFactory -from lead.factories import LeadFactory -from entry.factories import EntryFactory, EntryAttributeFactory from analysis_framework.factories import AnalysisFrameworkFactory, WidgetFactory -from user_group.factories import UserGroupFactory +from entry.factories import EntryAttributeFactory, EntryFactory +from factory import fuzzy +from geo.factories import RegionFactory +from lead.factories import LeadFactory +from notification.models import Notification +from organization.factories import OrganizationFactory from project.factories import ( ProjectFactory, ProjectJoinRequestFactory, ProjectOrganizationFactory, ProjectPinnedFactory, ) -from organization.factories import OrganizationFactory -from geo.factories import RegionFactory - +from project.models import ( + Project, + ProjectChangeLog, + ProjectJoinRequest, + ProjectMembership, + ProjectOrganization, + ProjectRole, + ProjectStats, + ProjectUserGroupMembership, + get_default_role_id, +) from project.tasks import _generate_project_viz_stats, permanently_delete_projects +from user.factories import FeatureFactory, UserFactory +from user.models import Feature +from user.utils import ( + send_project_accept_email, + send_project_join_request_emails, + send_project_reject_email, +) +from user_group.factories import UserGroupFactory + +from utils.graphene.tests import GraphQLSnapShotTestCase, GraphQLTestCase -from project.models import Project, ProjectOrganization from . import entry_stats_data @@ -57,37 +56,39 @@ def set_project_viz_configuration(project): invalid_stat_config = {} valid_stat_config = {} - for index, (title, widget_identifier, data_identifier, config_kwargs) in enumerate([ - ('widget 1d', 'widget_1d', 'matrix1dWidget', {}), - ('widget 2d', 'widget_2d', 'matrix2dWidget', {}), - ('geo widget', 'geo_widget', 'geoWidget', {}), - ('reliability widget', 'reliability_widget', 'scaleWidget', {}), - ('affected groups widget', 'affected_groups_widget', 'multiselectWidget', {}), - ('specific needs groups widget', 'specific_needs_groups_widget', 'multiselectWidget', {}), - ]): + for index, (title, widget_identifier, data_identifier, config_kwargs) in enumerate( + [ + ("widget 1d", "widget_1d", "matrix1dWidget", {}), + ("widget 2d", "widget_2d", "matrix2dWidget", {}), + ("geo widget", "geo_widget", "geoWidget", {}), + ("reliability widget", "reliability_widget", "scaleWidget", {}), + ("affected groups widget", "affected_groups_widget", "multiselectWidget", {}), + ("specific needs groups widget", "specific_needs_groups_widget", "multiselectWidget", {}), + ] + ): widget = WidgetFactory.create( analysis_framework=af, section=None, title=title, widget_id=data_identifier, - key=f'{data_identifier}-{index}', - properties={'data': w_data[data_identifier]}, + key=f"{data_identifier}-{index}", + properties={"data": w_data[data_identifier]}, ) EntryAttributeFactory.create(entry=entry, widget=widget, data=a_data[data_identifier]) valid_stat_config[widget_identifier] = { - 'pk': widget.pk, + "pk": widget.pk, **config_kwargs, } - invalid_stat_config[widget_identifier] = {'pk': 0} + invalid_stat_config[widget_identifier] = {"pk": 0} - af.properties = {'stats_config': invalid_stat_config} - af.save(update_fields=('properties',)) + af.properties = {"stats_config": invalid_stat_config} + af.save(update_fields=("properties",)) project.is_visualization_enabled = True - project.save(update_fields=('is_visualization_enabled',)) + project.save(update_fields=("is_visualization_enabled",)) def test_projects_viz_configuration_update(self): - query = ''' + query = """ mutation MyMutation($id: ID!, $input: ProjectVizConfigurationInputType!) { project(id: $id) { projectVizConfigurationUpdate(data: $input) { @@ -103,7 +104,7 @@ def test_projects_viz_configuration_update(self): } } } - ''' + """ normal_user = UserFactory.create() admin_user = UserFactory.create() @@ -121,8 +122,8 @@ def _query_check(**kwargs): return self.query_check( query, minput=minput, - mnested=['project'], - variables={'id': project.id}, + mnested=["project"], + variables={"id": project.id}, **kwargs, ) @@ -137,46 +138,36 @@ def _query_check(**kwargs): ]: self.force_login(user) current_stats = project.project_stats - minput['action'] = self.genum(action) + minput["action"] = self.genum(action) if assertLogic == self.assert_200: content = _query_check(okay=True) else: _query_check(assert_for_error=True) continue - response = content['data']['project']['projectVizConfigurationUpdate']['result'] + response = content["data"]["project"]["projectVizConfigurationUpdate"]["result"] if assertLogic == self.assert_200: - if action == 'new': - assert response['publicUrl'] != current_stats.token + if action == "new": + assert response["publicUrl"] != current_stats.token # Logout and check if response is okay self.client.logout() rest_response = self.client.get(f"{response['publicUrl']}?format=json") self.assert_200(rest_response) - elif action == 'on': - assert ( - response['publicUrl'] is not None - ) or ( - response['publicUrl'] == current_stats.token - ) + elif action == "on": + assert (response["publicUrl"] is not None) or (response["publicUrl"] == current_stats.token) # Logout and check if response is not okay self.client.logout() rest_response = self.client.get(f"{response['publicUrl']}?format=json") self.assert_200(rest_response) - elif action == 'off': - assert ( - response['publicUrl'] is not None - ) or ( - response['publicUrl'] == current_stats.token - ) + elif action == "off": + assert (response["publicUrl"] is not None) or (response["publicUrl"] == current_stats.token) # Logout and check if response is not okay self.client.logout() rest_response = self.client.get(f"{response['publicUrl']}?format=json") self.assert_403(rest_response) # Check Project change logs self.assertMatchSnapshot( - list( - ProjectChangeLog.objects.filter(project=project).order_by('id').values('action', 'diff') - ), - 'project-change-log', + list(ProjectChangeLog.objects.filter(project=project).order_by("id").values("action", "diff")), + "project-change-log", ) @@ -190,7 +181,7 @@ class ProjectMutationSnapshotTest(GraphQLSnapShotTestCase): ] def test_project_create_mutation(self): - query = ''' + query = """ mutation MyMutation($input: ProjectCreateInputType!) { __typename projectCreate(data: $input) { @@ -224,7 +215,7 @@ def test_project_create_mutation(self): } } } - ''' + """ user = UserFactory.create() af = AnalysisFrameworkFactory.create() @@ -236,11 +227,11 @@ def test_project_create_mutation(self): org1 = OrganizationFactory.create() minput = dict( - title='Project 1', + title="Project 1", analysisFramework=str(private_af.id), - description='Project description 101', - startDate='2020-01-01', - endDate='2021-01-01', + description="Project description 101", + startDate="2020-01-01", + endDate="2021-01-01", status=self.genum(Project.Status.ACTIVE), isPrivate=True, hasPubliclyViewableUnprotectedLeads=False, @@ -261,8 +252,8 @@ def _query_check(**kwargs): minput=minput, **kwargs, ) - if kwargs.get('okay'): - project_log = ProjectChangeLog.objects.get(project=response['data']['projectCreate']['result']['id']) + if kwargs.get("okay"): + project_log = ProjectChangeLog.objects.get(project=response["data"]["projectCreate"]["result"]["id"]) assert project_log.action == ProjectChangeLog.Action.PROJECT_CREATE return response @@ -274,35 +265,35 @@ def _query_check(**kwargs): response = _query_check(okay=False) # invalid [private AF with memership] + public project - minput['analysisFramework'] = str(private_af_w_membership.pk) - minput['isPrivate'] = False + minput["analysisFramework"] = str(private_af_w_membership.pk) + minput["isPrivate"] = False response = _query_check(okay=False) # invalid [private AF with memership] + private project + without feature permission - minput['isPrivate'] = True + minput["isPrivate"] = True response = _query_check(okay=False) # invalid [private AF with memership] + private project + with feature permission private_project_feature.users.add(user) - response = _query_check(okay=True)['data']['projectCreate'] - self.assertMatchSnapshot(response, 'private-af-private-project-success') + response = _query_check(okay=True)["data"]["projectCreate"] + self.assertMatchSnapshot(response, "private-af-private-project-success") # Valid [public AF] + private project - minput['title'] = "Project 2" - minput['analysisFramework'] = str(af.pk) - minput['isPrivate'] = True - response = _query_check(okay=True)['data']['projectCreate'] - self.assertMatchSnapshot(response, 'public-af-private-project-success') + minput["title"] = "Project 2" + minput["analysisFramework"] = str(af.pk) + minput["isPrivate"] = True + response = _query_check(okay=True)["data"]["projectCreate"] + self.assertMatchSnapshot(response, "public-af-private-project-success") # Valid [public AF] + private project - minput['title'] = "Project 3" - minput['analysisFramework'] = str(af.pk) - minput['isPrivate'] = False - response = _query_check(okay=True)['data']['projectCreate'] - self.assertMatchSnapshot(response, 'public-af-public-project-success') + minput["title"] = "Project 3" + minput["analysisFramework"] = str(af.pk) + minput["isPrivate"] = False + response = _query_check(okay=True)["data"]["projectCreate"] + self.assertMatchSnapshot(response, "public-af-public-project-success") - minput['title'] = 'Project 1' + minput["title"] = "Project 1" response = _query_check(okay=False) def test_project_update_mutation(self): - query = ''' + query = """ mutation MyMutation($projectId: ID!, $input: ProjectUpdateInputType!) { __typename project(id: $projectId) { @@ -338,7 +329,7 @@ def test_project_update_mutation(self): } } } - ''' + """ user = UserFactory.create() normal_user = UserFactory.create() @@ -354,11 +345,11 @@ def test_project_update_mutation(self): private_af_w_membership.add_member(user) public_project = ProjectFactory.create( - title='Public Project 101', + title="Public Project 101", analysis_framework=af, ) private_project = ProjectFactory.create( - title='Private Project 101', + title="Private Project 101", analysis_framework=private_af, is_private=True, ) @@ -384,7 +375,7 @@ def test_project_update_mutation(self): public_project.add_member(normal_user) public_minput = dict( - title=f'{public_project.title} (Updated)', + title=f"{public_project.title} (Updated)", analysisFramework=str(public_project.analysis_framework.id), isTest=True, isPrivate=False, @@ -401,7 +392,7 @@ def test_project_update_mutation(self): private_minput = dict( title=private_project.title, - description='Added some description', + description="Added some description", analysisFramework=str(private_project.analysis_framework.id), isPrivate=True, organizations=[ @@ -416,8 +407,8 @@ def _query_check(project, minput, **kwargs): return self.query_check( query, minput=minput, - mnested=['project'], - variables={'projectId': str(project.pk)}, + mnested=["project"], + variables={"projectId": str(project.pk)}, **kwargs, ) @@ -447,41 +438,41 @@ def _private_query_check(**kwargs): # WITH ACCESS # ----- isPrivate attribute # [changing private status) [public project] - public_minput['isPrivate'] = True - self.assertMatchSnapshot(_public_query_check(okay=False), 'public-project:is-private-change-error') - public_minput['isPrivate'] = False + public_minput["isPrivate"] = True + self.assertMatchSnapshot(_public_query_check(okay=False), "public-project:is-private-change-error") + public_minput["isPrivate"] = False # [changing private status) [public project] - private_minput['isPrivate'] = False - self.assertMatchSnapshot(_private_query_check(okay=False), 'private-project:is-private-change-error') - private_minput['isPrivate'] = True + private_minput["isPrivate"] = False + self.assertMatchSnapshot(_private_query_check(okay=False), "private-project:is-private-change-error") + private_minput["isPrivate"] = True # ----- AF attribute # [changing private status) [public project] - public_minput['analysisFramework'] = str(private_af.id) - self.assertMatchSnapshot(_public_query_check(okay=False), 'public-project:private-af') - public_minput['analysisFramework'] = str(private_af_w_membership.id) - self.assertMatchSnapshot(_public_query_check(okay=False), 'public-project:private-af-with-membership') - public_minput['analysisFramework'] = str(public_project.analysis_framework_id) + public_minput["analysisFramework"] = str(private_af.id) + self.assertMatchSnapshot(_public_query_check(okay=False), "public-project:private-af") + public_minput["analysisFramework"] = str(private_af_w_membership.id) + self.assertMatchSnapshot(_public_query_check(okay=False), "public-project:private-af-with-membership") + public_minput["analysisFramework"] = str(public_project.analysis_framework_id) # [changing private status) [private project] - private_minput['analysisFramework'] = str(private_af_2.id) - self.assertMatchSnapshot(_private_query_check(okay=False), 'private-project:private-af') - private_minput['analysisFramework'] = str(private_af_w_membership.id) + private_minput["analysisFramework"] = str(private_af_2.id) + self.assertMatchSnapshot(_private_query_check(okay=False), "private-project:private-af") + private_minput["analysisFramework"] = str(private_af_w_membership.id) _private_query_check(okay=True) - private_minput['analysisFramework'] = str(private_project.analysis_framework_id) + private_minput["analysisFramework"] = str(private_project.analysis_framework_id) # Check Project change logs project_log = ProjectChangeLog.objects.get(project=public_project) assert project_log.action == ProjectChangeLog.Action.MULTIPLE - self.assertMatchSnapshot(project_log.diff, 'public-project:project-change:diff') - project_logs = list(ProjectChangeLog.objects.filter(project=private_project).order_by('id')) + self.assertMatchSnapshot(project_log.diff, "public-project:project-change:diff") + project_logs = list(ProjectChangeLog.objects.filter(project=private_project).order_by("id")) assert project_logs[0].action == ProjectChangeLog.Action.MULTIPLE - self.assertMatchSnapshot(project_logs[0].diff, 'private-project-0:project-change:diff') + self.assertMatchSnapshot(project_logs[0].diff, "private-project-0:project-change:diff") assert project_logs[1].action == ProjectChangeLog.Action.FRAMEWORK - self.assertMatchSnapshot(project_logs[1].diff, 'private-project-1:project-change:diff') + self.assertMatchSnapshot(project_logs[1].diff, "private-project-1:project-change:diff") def test_project_region_action_mutation(self): - query = ''' + query = """ mutation MyMutation ($projectId: ID!, $regionsToAdd: [ID!], $regionsToRemove: [ID!]) { project(id: $projectId) { projectRegionBulk(regionsToAdd: $regionsToAdd, regionsToRemove: $regionsToRemove) { @@ -496,27 +487,27 @@ def test_project_region_action_mutation(self): } } } - ''' + """ user = UserFactory.create() normal_user = UserFactory.create() another_user = UserFactory.create() af = AnalysisFrameworkFactory.create() - project = ProjectFactory.create(title='Project 101', analysis_framework=af) + project = ProjectFactory.create(title="Project 101", analysis_framework=af) project.add_member(user, role=self.project_role_owner) project.add_member(normal_user) - region_public_zero = RegionFactory.create(title='public-region-zero') - region_public = RegionFactory.create(title='public-region') - region_private = RegionFactory.create(title='private-region', public=False) - region_private_owner = RegionFactory.create(title='private-region-owner', public=False, created_by=user) + region_public_zero = RegionFactory.create(title="public-region-zero") + region_public = RegionFactory.create(title="public-region") + region_private = RegionFactory.create(title="private-region", public=False) + region_private_owner = RegionFactory.create(title="private-region-owner", public=False, created_by=user) # Region with project membership # -- Normal - region_private_with_membership = RegionFactory.create(title='private-region-with-membership', public=False) + region_private_with_membership = RegionFactory.create(title="private-region-with-membership", public=False) another_project_for_membership = ProjectFactory.create() another_project_for_membership.regions.add(region_private_with_membership) another_project_for_membership.add_member(user, role=self.project_role_admin) # -- Admin - region_private_with_membership_admin = RegionFactory.create(title='private-region-with-membership', public=False) + region_private_with_membership_admin = RegionFactory.create(title="private-region-with-membership", public=False) another_project_for_membership_admin = ProjectFactory.create() another_project_for_membership_admin.regions.add(region_private_with_membership_admin) another_project_for_membership_admin.add_member(user, role=self.project_role_admin) @@ -526,11 +517,11 @@ def test_project_region_action_mutation(self): def _query_check(add, remove, **kwargs): return self.query_check( query, - mnested=['project'], + mnested=["project"], variables={ - 'projectId': str(project.pk), - 'regionsToAdd': add, - 'regionsToRemove': remove, + "projectId": str(project.pk), + "regionsToAdd": add, + "regionsToRemove": remove, }, **kwargs, ) @@ -547,10 +538,13 @@ def _query_check(add, remove, **kwargs): self.force_login(user) # Simple checkup response = _query_check([], []) - self.assertEqual(response['data']['project']['projectRegionBulk'], { - 'deletedResult': [], - 'result': [], - }) + self.assertEqual( + response["data"]["project"]["projectRegionBulk"], + { + "deletedResult": [], + "result": [], + }, + ) # Add response = _query_check( @@ -564,23 +558,26 @@ def _query_check(add, remove, **kwargs): str(region_public_zero.pk), ], ) - self.assertEqual(response['data']['project']['projectRegionBulk'], { - 'deletedResult': [ - dict(id=str(region_public_zero.pk), title=region_public_zero.title), - ], - 'result': [ - dict(id=str(region_public.pk), title=region_public.title), - dict(id=str(region_private_owner.pk), title=region_private_owner.title), - dict(id=str(region_private_with_membership.pk), title=region_private_with_membership.title), - ], - }) self.assertEqual( - list(project.regions.values_list('id', flat=True).order_by('id')), + response["data"]["project"]["projectRegionBulk"], + { + "deletedResult": [ + dict(id=str(region_public_zero.pk), title=region_public_zero.title), + ], + "result": [ + dict(id=str(region_public.pk), title=region_public.title), + dict(id=str(region_private_owner.pk), title=region_private_owner.title), + dict(id=str(region_private_with_membership.pk), title=region_private_with_membership.title), + ], + }, + ) + self.assertEqual( + list(project.regions.values_list("id", flat=True).order_by("id")), [ region_public.pk, region_private_owner.pk, region_private_with_membership.pk, - ] + ], ) # Delete @@ -593,27 +590,28 @@ def _query_check(add, remove, **kwargs): str(region_private_with_membership.pk), ], ) - self.assertEqual(response['data']['project']['projectRegionBulk'], { - 'deletedResult': [ - dict(id=str(region_public.pk), title=region_public.title), - dict(id=str(region_private_owner.pk), title=region_private_owner.title), - dict(id=str(region_private_with_membership.pk), title=region_private_with_membership.title), - ], - 'result': [], - }) - self.assertEqual(list(project.regions.values_list('id', flat=True).order_by('id')), []) + self.assertEqual( + response["data"]["project"]["projectRegionBulk"], + { + "deletedResult": [ + dict(id=str(region_public.pk), title=region_public.title), + dict(id=str(region_private_owner.pk), title=region_private_owner.title), + dict(id=str(region_private_with_membership.pk), title=region_private_with_membership.title), + ], + "result": [], + }, + ) + self.assertEqual(list(project.regions.values_list("id", flat=True).order_by("id")), []) # Check Project change logs self.assertMatchSnapshot( - list( - ProjectChangeLog.objects.filter(project=project).order_by('id').values('action', 'diff') - ), - 'project-change-log', + list(ProjectChangeLog.objects.filter(project=project).order_by("id").values("action", "diff")), + "project-change-log", ) class TestProjectJoinMutation(GraphQLTestCase): def setUp(self): - self.project_join_mutation = ''' + self.project_join_mutation = """ mutation Mutation($input: ProjectJoinRequestInputType!) { joinProject(data: $input) { ok @@ -631,13 +629,10 @@ def setUp(self): } } } - ''' + """ super().setUp() - @mock.patch( - 'project.serializers.send_project_join_request_emails.delay', - side_effect=send_project_join_request_emails.delay - ) + @mock.patch("project.serializers.send_project_join_request_emails.delay", side_effect=send_project_join_request_emails.delay) def test_valid_project_join(self, send_project_join_request_email_mock): user = UserFactory.create() admin_user = UserFactory.create() @@ -647,15 +642,13 @@ def test_valid_project_join(self, send_project_join_request_email_mock): minput = dict(project=project.id, reason=reason) self.force_login(user) notification_qs = Notification.objects.filter( - receiver=admin_user, - project=project, - notification_type=Notification.Type.PROJECT_JOIN_REQUEST + receiver=admin_user, project=project, notification_type=Notification.Type.PROJECT_JOIN_REQUEST ) old_count = notification_qs.count() with self.captureOnCommitCallbacks(execute=True): content = self.query_check(self.project_join_mutation, minput=minput, okay=True) - self.assertEqual(content['data']['joinProject']['result']['requestedBy']['id'], str(user.id), content) - self.assertEqual(content['data']['joinProject']['result']['project']['id'], str(project.id), content) + self.assertEqual(content["data"]["joinProject"]["result"]["requestedBy"]["id"], str(user.id), content) + self.assertEqual(content["data"]["joinProject"]["result"]["project"]["id"], str(project.id), content) send_project_join_request_email_mock.assert_called_once() # confirm that the notification is also created assert notification_qs.count() > old_count @@ -668,7 +661,7 @@ def test_already_member_project(self): minput = dict(project=project.id, reason=reason) self.force_login(user) content = self.query_check(self.project_join_mutation, minput=minput, okay=False) - self.assertEqual(len(content['data']['joinProject']['errors']), 1, content) + self.assertEqual(len(content["data"]["joinProject"]["errors"]), 1, content) def test_project_join_reason_length(self): user = UserFactory.create() @@ -678,20 +671,20 @@ def test_project_join_reason_length(self): self.force_login(user) # Invalid content = self.query_check(self.project_join_mutation, minput=minput, okay=False) - self.assertEqual(len(content['data']['joinProject']['errors']), 1, content) + self.assertEqual(len(content["data"]["joinProject"]["errors"]), 1, content) # Invalid - minput['reason'] = fuzzy.FuzzyText(length=501).fuzz() + minput["reason"] = fuzzy.FuzzyText(length=501).fuzz() content = self.query_check(self.project_join_mutation, minput=minput, okay=False) - self.assertEqual(len(content['data']['joinProject']['errors']), 1, content) + self.assertEqual(len(content["data"]["joinProject"]["errors"]), 1, content) # Valid (Project 1) max=500 - minput['reason'] = fuzzy.FuzzyText(length=500).fuzz() + minput["reason"] = fuzzy.FuzzyText(length=500).fuzz() content = self.query_check(self.project_join_mutation, minput=minput, okay=True) - self.assertEqual(content['data']['joinProject']['errors'], None, content) + self.assertEqual(content["data"]["joinProject"]["errors"], None, content) # Valid (Project 2) min=50 - minput['reason'] = fuzzy.FuzzyText(length=50).fuzz() - minput['project'] = project2.pk + minput["reason"] = fuzzy.FuzzyText(length=50).fuzz() + minput["project"] = project2.pk content = self.query_check(self.project_join_mutation, minput=minput, okay=True) - self.assertEqual(content['data']['joinProject']['errors'], None, content) + self.assertEqual(content["data"]["joinProject"]["errors"], None, content) def test_join_private_project(self): user = UserFactory.create() @@ -700,7 +693,7 @@ def test_join_private_project(self): minput = dict(project=project.id, reason=reason) self.force_login(user) content = self.query_check(self.project_join_mutation, minput=minput, okay=False) - self.assertEqual(len(content['data']['joinProject']['errors']), 1, content) + self.assertEqual(len(content["data"]["joinProject"]["errors"]), 1, content) def test_already_request_sent_for_project(self): user = UserFactory.create() @@ -716,12 +709,12 @@ def test_already_request_sent_for_project(self): minput = dict(project=project.id, reason=reason) self.force_login(user) content = self.query_check(self.project_join_mutation, minput=minput, okay=False) - self.assertEqual(len(content['data']['joinProject']['errors']), 1, content) + self.assertEqual(len(content["data"]["joinProject"]["errors"]), 1, content) class TestProjectJoinDeleteMutation(GraphQLTestCase): def setUp(self): - self.project_join_request_delete_mutation = ''' + self.project_join_request_delete_mutation = """ mutation Mutation($projectId: ID!) { projectJoinRequestDelete(projectId: $projectId) { ok @@ -738,7 +731,7 @@ def setUp(self): } } } - ''' + """ super().setUp() def test_delete_project_join_request(self): @@ -754,13 +747,13 @@ def test_delete_project_join_request(self): old_join_request_count = join_request_qs.count() self.force_login(user) - self.query_check(self.project_join_request_delete_mutation, variables={'projectId': project.id}, okay=True) + self.query_check(self.project_join_request_delete_mutation, variables={"projectId": project.id}, okay=True) self.assertEqual(join_request_qs.count(), old_join_request_count - 1) class TestProjectJoinAcceptRejectMutation(GraphQLSnapShotTestCase): def setUp(self): - self.projet_accept_reject_mutation = ''' + self.projet_accept_reject_mutation = """ mutation MyMutation ($projectId: ID! $joinRequestId: ID! $input: ProjectAcceptRejectInputType!) { project(id: $projectId) { acceptRejectProject(id: $joinRequestId, data: $input) { @@ -782,37 +775,28 @@ def setUp(self): } } } - ''' + """ super().setUp() - @mock.patch( - 'project.serializers.send_project_accept_email.delay', - side_effect=send_project_accept_email.delay - ) + @mock.patch("project.serializers.send_project_accept_email.delay", side_effect=send_project_accept_email.delay) def test_project_join_request_accept(self, send_project_accept_email_mock): user = UserFactory.create() user2 = UserFactory.create() project = ProjectFactory.create() project.add_member(user, role=self.project_role_admin) join_request = ProjectJoinRequestFactory.create( - requested_by=user2, - project=project, - role=ProjectRole.get_default_role(), - status=ProjectJoinRequest.Status.PENDING + requested_by=user2, project=project, role=ProjectRole.get_default_role(), status=ProjectJoinRequest.Status.PENDING ) - minput = dict(status=self.genum(ProjectJoinRequest.Status.ACCEPTED), role='normal') + minput = dict(status=self.genum(ProjectJoinRequest.Status.ACCEPTED), role="normal") # without login self.query_check( self.projet_accept_reject_mutation, minput=minput, - variables={'projectId': project.id, 'joinRequestId': join_request.id}, - assert_for_error=True - ) - notification_qs = Notification.objects.filter( - receiver=user, - notification_type=Notification.Type.PROJECT_JOIN_RESPONSE + variables={"projectId": project.id, "joinRequestId": join_request.id}, + assert_for_error=True, ) + notification_qs = Notification.objects.filter(receiver=user, notification_type=Notification.Type.PROJECT_JOIN_RESPONSE) old_count = notification_qs.count() # with login @@ -821,53 +805,41 @@ def test_project_join_request_accept(self, send_project_accept_email_mock): content = self.query_check( self.projet_accept_reject_mutation, minput=minput, - variables={'projectId': project.id, 'joinRequestId': join_request.id} + variables={"projectId": project.id, "joinRequestId": join_request.id}, ) + self.assertEqual(content["data"]["project"]["acceptRejectProject"]["result"]["requestedBy"]["id"], str(user2.id), content) + self.assertEqual(content["data"]["project"]["acceptRejectProject"]["result"]["respondedBy"]["id"], str(user.id), content) self.assertEqual( - content['data']['project']['acceptRejectProject']['result']['requestedBy']['id'], - str(user2.id), content - ) - self.assertEqual( - content['data']['project']['acceptRejectProject']['result']['respondedBy']['id'], - str(user.id), content - ) - self.assertEqual( - content['data']['project']['acceptRejectProject']['result']['status'], + content["data"]["project"]["acceptRejectProject"]["result"]["status"], self.genum(ProjectJoinRequest.Status.ACCEPTED), - content + content, ) # make sure memberships is created - self.assertIn(user2.id, ProjectMembership.objects.filter(project=project).values_list('member', flat=True)) + self.assertIn(user2.id, ProjectMembership.objects.filter(project=project).values_list("member", flat=True)) assert notification_qs.count() > old_count send_project_accept_email_mock.assert_called_once() # Check Project change logs self.assertMatchSnapshot( - list( - ProjectChangeLog.objects.filter(project=project).order_by('id').values('action', 'diff') - ), - 'project-change-log', + list(ProjectChangeLog.objects.filter(project=project).order_by("id").values("action", "diff")), + "project-change-log", ) - @mock.patch( - 'project.serializers.send_project_reject_email.delay', - side_effect=send_project_reject_email.delay - ) + @mock.patch("project.serializers.send_project_reject_email.delay", side_effect=send_project_reject_email.delay) def test_project_join_request_reject(self, send_project_reject_email_mock): user = UserFactory.create() user2 = UserFactory.create() project = ProjectFactory.create() project.add_member(user, role=self.project_role_admin) - join_request = ProjectJoinRequestFactory.create(requested_by=user2, - project=project, - role=ProjectRole.get_default_role(), - status=ProjectJoinRequest.Status.PENDING) + join_request = ProjectJoinRequestFactory.create( + requested_by=user2, project=project, role=ProjectRole.get_default_role(), status=ProjectJoinRequest.Status.PENDING + ) minput = dict(status=self.genum(ProjectJoinRequest.Status.REJECTED)) # without login self.query_check( self.projet_accept_reject_mutation, minput=minput, - variables={'projectId': project.id, 'joinRequestId': join_request.id}, - assert_for_error=True + variables={"projectId": project.id, "joinRequestId": join_request.id}, + assert_for_error=True, ) # with login @@ -876,12 +848,12 @@ def test_project_join_request_reject(self, send_project_reject_email_mock): content = self.query_check( self.projet_accept_reject_mutation, minput=minput, - variables={'projectId': project.id, 'joinRequestId': join_request.id} + variables={"projectId": project.id, "joinRequestId": join_request.id}, ) self.assertEqual( - content['data']['project']['acceptRejectProject']['result']['status'], + content["data"]["project"]["acceptRejectProject"]["result"]["status"], self.genum(ProjectJoinRequest.Status.REJECTED), - content + content, ) send_project_reject_email_mock.assert_called_once() # Check project change logs @@ -892,7 +864,7 @@ class TestProjectMembershipMutation(GraphQLSnapShotTestCase): ENABLE_NOW_PATCHER = True def _user_membership_bulk(self, user_role): - query = ''' + query = """ mutation MyMutation( $id: ID!, $projectMembership: [BulkProjectMembershipInputType!], @@ -943,7 +915,7 @@ def _user_membership_bulk(self, user_role): } } } - ''' + """ creater_user = UserFactory.create() user = UserFactory.create() low_permission_user = UserFactory.create() @@ -961,7 +933,7 @@ def _user_membership_bulk(self, user_role): ) = UserFactory.create_batch(8) project = ProjectFactory.create(created_by=creater_user) - user_group = UserGroupFactory.create(title='Group-1') + user_group = UserGroupFactory.create(title="Group-1") membership1 = project.add_member(member_user1, badges=[ProjectMembership.BadgeType.QA]) membership2 = project.add_member(member_user2) membership_using_user_group = project.add_member(member_user7, linked_group=user_group) @@ -1012,17 +984,18 @@ def _user_membership_bulk(self, user_role): clientId="member-user-2-with-user-group", role=self.project_role_member.pk, badges=[self.genum(ProjectMembership.BadgeType.QA)], - ) + ), ], ) def _query_check(**kwargs): return self.query_check( query, - mnested=['project'], - variables={'id': project.id, **minput}, + mnested=["project"], + variables={"id": project.id, **minput}, **kwargs, ) + # ---------- Without login _query_check(assert_for_error=True) # ---------- With login (with non-member) @@ -1034,33 +1007,33 @@ def _query_check(**kwargs): # ---------- With login (with higher permission) self.force_login(user) # ----------------- Some Invalid input - response = _query_check()['data']['project']['projectUserMembershipBulk'] - self.assertMatchSnapshot(response, 'try 1') + response = _query_check()["data"]["project"]["projectUserMembershipBulk"] + self.assertMatchSnapshot(response, "try 1") # ----------------- Another try - minput['projectMembership'].pop(1) - minput['projectMembership'].extend([ - # Invalid (changing member) - dict( - member=member_user6.pk, - clientId="member-user-2", - role=self.project_role_owner.pk, - id=membership2.pk, - ), - dict( - member=member_user2.pk, - clientId="member-user-2", - role=self.project_role_admin.pk, - id=membership2.pk, - ), - ]) - response = _query_check()['data']['project']['projectUserMembershipBulk'] - self.assertMatchSnapshot(response, 'try 2') + minput["projectMembership"].pop(1) + minput["projectMembership"].extend( + [ + # Invalid (changing member) + dict( + member=member_user6.pk, + clientId="member-user-2", + role=self.project_role_owner.pk, + id=membership2.pk, + ), + dict( + member=member_user2.pk, + clientId="member-user-2", + role=self.project_role_admin.pk, + id=membership2.pk, + ), + ] + ) + response = _query_check()["data"]["project"]["projectUserMembershipBulk"] + self.assertMatchSnapshot(response, "try 2") # Check project change logs self.assertMatchSnapshot( - list( - ProjectChangeLog.objects.filter(project=project).order_by('id').values('action', 'diff') - ), - 'project-change-log', + list(ProjectChangeLog.objects.filter(project=project).order_by("id").values("action", "diff")), + "project-change-log", ) def test_user_membership_using_clairvoyan_one_bulk(self): @@ -1070,7 +1043,7 @@ def test_user_membership_admin_bulk(self): self._user_membership_bulk(self.project_role_admin) def _user_group_membership_bulk(self, user_role): - query = ''' + query = """ mutation MyMutation( $id: ID!, $projectMembership: [BulkProjectUserGroupMembershipInputType!], @@ -1121,7 +1094,7 @@ def _user_group_membership_bulk(self, user_role): } } } - ''' + """ project = ProjectFactory.create() def _add_member(usergroup, role=None, badges=[]): @@ -1143,7 +1116,7 @@ def _add_member(usergroup, role=None, badges=[]): member_user_group3, member_user_group4, member_user_group5, - member_user_group6 + member_user_group6, ) = UserGroupFactory.create_batch(7) membership1 = _add_member(member_user_group1, badges=[ProjectMembership.BadgeType.QA]) @@ -1192,40 +1165,41 @@ def _add_member(usergroup, role=None, badges=[]): def _query_check(**kwargs): return self.query_check( query, - mnested=['project'], - variables={'id': project.id, **minput}, + mnested=["project"], + variables={"id": project.id, **minput}, **kwargs, ) + # ---------- With login (with higher permission) self.force_login(user) # ----------------- Some Invalid input - response = _query_check()['data']['project']['projectUserGroupMembershipBulk'] - self.assertMatchSnapshot(response, 'try 1') + response = _query_check()["data"]["project"]["projectUserGroupMembershipBulk"] + self.assertMatchSnapshot(response, "try 1") # ----------------- Another try - minput['projectMembership'].pop(1) - minput['projectMembership'].extend([ - # Invalid (changing member) - dict( - usergroup=member_user_group6.pk, - clientId="member-user-2", - role=self.project_role_owner.pk, - id=membership2.pk, - ), - dict( - usergroup=member_user_group2.pk, - clientId="member-user-2", - role=self.project_role_admin.pk, - id=membership2.pk, - ), - ]) - response = _query_check()['data']['project']['projectUserGroupMembershipBulk'] - self.assertMatchSnapshot(response, 'try 2') + minput["projectMembership"].pop(1) + minput["projectMembership"].extend( + [ + # Invalid (changing member) + dict( + usergroup=member_user_group6.pk, + clientId="member-user-2", + role=self.project_role_owner.pk, + id=membership2.pk, + ), + dict( + usergroup=member_user_group2.pk, + clientId="member-user-2", + role=self.project_role_admin.pk, + id=membership2.pk, + ), + ] + ) + response = _query_check()["data"]["project"]["projectUserGroupMembershipBulk"] + self.assertMatchSnapshot(response, "try 2") # Check project change logs self.assertMatchSnapshot( - list( - ProjectChangeLog.objects.filter(project=project).order_by('id').values('action', 'diff') - ), - 'project-change-log', + list(ProjectChangeLog.objects.filter(project=project).order_by("id").values("action", "diff")), + "project-change-log", ) def test_user_group_membership_using_clairvoyan_one_bulk(self): @@ -1235,7 +1209,7 @@ def test_user_group_membership_admin_bulk(self): self._user_group_membership_bulk(self.project_role_admin) def test_project_deletion(self): - query = ''' + query = """ mutation MyMutation($projectId: ID!) { __typename project(id: $projectId) { @@ -1249,7 +1223,7 @@ def test_project_deletion(self): } } } - ''' + """ normal_user = UserFactory.create() admin_user = UserFactory.create() member_user = UserFactory.create() @@ -1277,10 +1251,11 @@ def _assert_project_soft_delete_status(is_deleted): def _query_check(**kwargs): return self.query_check( query, - mnested=['project'], - variables={'projectId': project.id}, + mnested=["project"], + variables={"projectId": project.id}, **kwargs, ) + # without login _query_check(assert_for_error=True) _assert_project_soft_delete_status(False) @@ -1311,17 +1286,13 @@ def _query_check(**kwargs): _assert_project_soft_delete_status(True) # Check project change logs self.assertMatchSnapshot( - list( - ProjectChangeLog.objects.filter(project=project).order_by('id').values('action', 'diff') - ), - 'project-change-log', + list(ProjectChangeLog.objects.filter(project=project).order_by("id").values("action", "diff")), + "project-change-log", ) def test_project_deletion_celery_task(self): def _get_project_ids(): - return list( - Project.objects.values_list('id', flat=True) - ) + return list(Project.objects.values_list("id", flat=True)) # Check with single project project = ProjectFactory.create() @@ -1334,32 +1305,30 @@ def _get_project_ids(): # Check with multiple projects project1 = ProjectFactory.create( - title='Test Project 1', - is_deleted=True, - deleted_at=self.now_datetime - timedelta(days=32) + title="Test Project 1", is_deleted=True, deleted_at=self.now_datetime - timedelta(days=32) ) - project2 = ProjectFactory.create(title='Test Project 2') + project2 = ProjectFactory.create(title="Test Project 2") project2_1 = ProjectFactory.create( title="Test Project 2 [Don't Delete']", deleted_at=self.now_datetime - timedelta(days=32), ) project3 = ProjectFactory.create( - title='Test Project 3', + title="Test Project 3", is_deleted=True, deleted_at=self.now_datetime - timedelta(days=42), ) project4 = ProjectFactory.create( - title='Test Project 4', + title="Test Project 4", is_deleted=True, deleted_at=self.now_datetime - timedelta(days=20), ) project5 = ProjectFactory.create( - title='Test Project 5', + title="Test Project 5", is_deleted=True, deleted_at=self.now_datetime - timedelta(days=30), ) project6 = ProjectFactory.create( - title='Test Project 6', + title="Test Project 6", is_deleted=True, deleted_at=self.now_datetime - timedelta(days=29), ) @@ -1374,7 +1343,7 @@ def _get_project_ids(): [ project1.id, project3.id, - ] + ], ) self.assertEqual( project_ids, @@ -1384,11 +1353,11 @@ def _get_project_ids(): project4.id, project5.id, project6.id, - ] + ], ) def test_create_user_pinned_project(self): - query = ''' + query = """ mutation MyMutation($project: ID!) { createUserPinnedProject(data: {project: $project}) { ok @@ -1405,20 +1374,18 @@ def test_create_user_pinned_project(self): } } } - ''' + """ project1 = ProjectFactory.create( - title='Test Project 1', + title="Test Project 1", ) project2 = ProjectFactory.create( - title='Test Project 2', + title="Test Project 2", ) member_user = UserFactory.create() owner_user = UserFactory.create() project1.add_member(member_user, role=self.project_role_member) project2.add_member(owner_user, role=self.project_role_owner) - minput = dict( - project=project1.id - ) + minput = dict(project=project1.id) def _query_check(**kwargs): return self.query_check( @@ -1426,35 +1393,33 @@ def _query_check(**kwargs): variables=minput, **kwargs, ) + self.force_login(member_user) - response = _query_check()['data']['createUserPinnedProject']['result'] - self.assertEqual(response['clientId'], str(project1.id)) - self.assertEqual(response['order'], 1) - self.assertEqual(response['user']['id'], str(member_user.id)) - self.assertEqual(response['project']['id'], str(project1.id)) + response = _query_check()["data"]["createUserPinnedProject"]["result"] + self.assertEqual(response["clientId"], str(project1.id)) + self.assertEqual(response["order"], 1) + self.assertEqual(response["user"]["id"], str(member_user.id)) + self.assertEqual(response["project"]["id"], str(project1.id)) # pin project which is already pinned by user - response = _query_check(assert_for_error=True)['errors'] - self.assertIn("Project already pinned!!", response[0]['message']) + response = _query_check(assert_for_error=True)["errors"] + self.assertIn("Project already pinned!!", response[0]["message"]) # pin another project - minput['project'] = project2.id - response = _query_check()['data']['createUserPinnedProject']['result'] - self.assertEqual(response['clientId'], str(project2.id)) - self.assertEqual(response['order'], 2) - self.assertEqual(response['project']['id'], str(project2.id)) + minput["project"] = project2.id + response = _query_check()["data"]["createUserPinnedProject"]["result"] + self.assertEqual(response["clientId"], str(project2.id)) + self.assertEqual(response["order"], 2) + self.assertEqual(response["project"]["id"], str(project2.id)) def test_bulk_reorder_pinned_project(self): - project1 = ProjectFactory.create(title='Test project 3') - project2 = ProjectFactory.create(title='Test project 4') + project1 = ProjectFactory.create(title="Test project 3") + project2 = ProjectFactory.create(title="Test project 4") member_user = UserFactory.create() project1.add_member(member_user, role=self.project_role_member) project2.add_member(member_user, role=self.project_role_member) pinned_project1 = ProjectPinnedFactory.create(project=project1, user=member_user, order=10) # pinned_project2 = ProjectPinnedFactory.create(project=project2, user=member_user, order=12) - minput = dict( - order=14, - id=pinned_project1.id - ) - query = ''' + minput = dict(order=14, id=pinned_project1.id) + query = """ mutation MyMutation($bulkReorder: UserPinnedProjectReOrderInputType!) { reorderPinnedProjects(items: $bulkReorder) { errors @@ -1472,12 +1437,9 @@ def test_bulk_reorder_pinned_project(self): } } } - ''' + """ def _query_check(**kwargs): - return self.query_check( - query, - variable=minput, - **kwargs - ) + return self.query_check(query, variable=minput, **kwargs) + self.force_login(member_user) diff --git a/apps/project/tests/test_schemas.py b/apps/project/tests/test_schemas.py index 88735954b7..4546d31b10 100644 --- a/apps/project/tests/test_schemas.py +++ b/apps/project/tests/test_schemas.py @@ -1,46 +1,47 @@ -import pytz from datetime import datetime -from dateutil.relativedelta import relativedelta from unittest.mock import patch -from django.utils import timezone +import pytz +from analysis_framework.factories import AnalysisFrameworkFactory +from ary.factories import AssessmentTemplateFactory +from dateutil.relativedelta import relativedelta from django.contrib.gis.geos import Point from django.core.cache import cache - -from utils.graphene.tests import GraphQLTestCase, GraphQLSnapShotTestCase - +from django.utils import timezone +from entry.factories import EntryFactory +from export.factories import ExportFactory +from geo.enums import GeoAreaOrderingEnum +from geo.factories import AdminLevelFactory, GeoAreaFactory, RegionFactory +from lead.factories import LeadFactory from lead.models import Lead +from project.factories import ( + ProjectFactory, + ProjectJoinRequestFactory, + ProjectPinnedFactory, +) from project.models import ( - ProjectMembership, - ProjectUserGroupMembership, - ProjectStats, Project, + ProjectMembership, ProjectRole, + ProjectStats, + ProjectUserGroupMembership, ) -from deep.permissions import ProjectPermissions as PP -from deep.caches import CacheKey -from deep.trackers import schedule_tracker_data_handler - +from project.tasks import _generate_project_stats_cache +from quality_assurance.factories import EntryReviewCommentFactory from user.factories import UserFactory from user_group.factories import UserGroupFactory -from lead.factories import LeadFactory -from entry.factories import EntryFactory -from project.factories import ProjectFactory, ProjectJoinRequestFactory, ProjectPinnedFactory -from analysis_framework.factories import AnalysisFrameworkFactory -from geo.factories import RegionFactory, AdminLevelFactory, GeoAreaFactory -from ary.factories import AssessmentTemplateFactory -from export.factories import ExportFactory -from quality_assurance.factories import EntryReviewCommentFactory -from project.tasks import _generate_project_stats_cache -from geo.enums import GeoAreaOrderingEnum +from deep.caches import CacheKey +from deep.permissions import ProjectPermissions as PP +from deep.trackers import schedule_tracker_data_handler +from utils.graphene.tests import GraphQLSnapShotTestCase, GraphQLTestCase from .test_mutations import TestProjectGeneralMutationSnapshotTest class TestProjectSchema(GraphQLTestCase): def test_project_recent_activities(self): - query = ''' + query = """ query MyQuery { recentActivities { createdAt @@ -71,7 +72,7 @@ def test_project_recent_activities(self): leadId } } - ''' + """ normal_user, member_user = UserFactory.create_batch(2) af = AnalysisFrameworkFactory.create() @@ -91,17 +92,17 @@ def test_project_recent_activities(self): self.force_login(normal_user) response = self.query_check(query) - self.assertEqual(len(response['data']['recentActivities']), 0) + self.assertEqual(len(response["data"]["recentActivities"]), 0) self.force_login(member_user) response = self.query_check(query) - self.assertEqual(len(response['data']['recentActivities']), 12) + self.assertEqual(len(response["data"]["recentActivities"]), 12) def test_project_query(self): """ Test private + non-private project behaviour """ - query = ''' + query = """ query MyQuery ($id: ID!) { project(id: $id) { id @@ -140,13 +141,12 @@ def test_project_query(self): } } } - ''' + """ user = UserFactory.create() analysis_framework = AnalysisFrameworkFactory.create() public_project, public_project2, public_project3, public_project4 = ProjectFactory.create_batch( - 4, - analysis_framework=analysis_framework + 4, analysis_framework=analysis_framework ) now = timezone.now() lead1_1 = self.update_obj(LeadFactory.create(project=public_project), created_at=now + relativedelta(months=-1)) @@ -191,14 +191,17 @@ def test_project_query(self): "controlled": True, "months": -3, }, - ] now = timezone.now() for item in data: self.update_obj( - EntryFactory.create(lead=item['lead'], controlled=item['controlled'], - project=public_project, analysis_framework=analysis_framework), - created_at=now + relativedelta(months=item['months']) + EntryFactory.create( + lead=item["lead"], + controlled=item["controlled"], + project=public_project, + analysis_framework=analysis_framework, + ), + created_at=now + relativedelta(months=item["months"]), ) EntryFactory.create(lead=lead1_3, project=public_project, controlled=True, analysis_framework=analysis_framework) EntryFactory.create(lead=lead1_4, project=public_project, controlled=True, analysis_framework=analysis_framework) @@ -212,13 +215,11 @@ def test_project_query(self): public_project = ProjectFactory.create(analysis_framework=analysis_framework) private_project = ProjectFactory.create(is_private=True, analysis_framework=analysis_framework) ProjectJoinRequestFactory.create( - project=public_project, requested_by=request_user, - status='pending', role=self.project_role_admin + project=public_project, requested_by=request_user, status="pending", role=self.project_role_admin ) # create projectJoinRequest(status='rejected') ProjectJoinRequestFactory.create( - project=public_project4, requested_by=request_user, - status='rejected', role=self.project_role_admin + project=public_project4, requested_by=request_user, status="rejected", role=self.project_role_admin ) # add some project member public_project.add_member(user) @@ -232,12 +233,8 @@ def test_project_query(self): LeadFactory.create(project=private_project) # add some entry for the project - EntryFactory.create_batch( - 4, - project=public_project, analysis_framework=analysis_framework, lead=lead - ) - entry2_1 = EntryFactory.create( - project=public_project, analysis_framework=analysis_framework, lead=lead2, controlled=True) + EntryFactory.create_batch(4, project=public_project, analysis_framework=analysis_framework, lead=lead) + entry2_1 = EntryFactory.create(project=public_project, analysis_framework=analysis_framework, lead=lead2, controlled=True) entry2_2 = EntryFactory.create(project=public_project, analysis_framework=analysis_framework, lead=lead2) EntryFactory.create(project=private_project, analysis_framework=analysis_framework, lead=lead) @@ -249,7 +246,7 @@ def test_project_query(self): # NOTE: Right noe only IN_PROGRESS status is set automatically # Control one lead lead2.status = Lead.Status.TAGGED - lead2.save(update_fields=('status',)) + lead2.save(update_fields=("status",)) # lets add some regions to project region1, region2, region3 = RegionFactory.create_batch(3) @@ -261,64 +258,64 @@ def test_project_query(self): _generate_project_stats_cache() # -- Without login - self.query_check(query, assert_for_error=True, variables={'id': public_project.id}) - self.query_check(query, assert_for_error=True, variables={'id': private_project.id}) + self.query_check(query, assert_for_error=True, variables={"id": public_project.id}) + self.query_check(query, assert_for_error=True, variables={"id": private_project.id}) # -- With login self.force_login(user) # --- non-member user - content = self.query_check(query, variables={'id': public_project.id}) - self.assertNotEqual(content['data']['project'], None, content) - content = self.query_check(query, variables={'id': private_project.id}) - self.assertEqual(content['data']['project'], None, content) + content = self.query_check(query, variables={"id": public_project.id}) + self.assertNotEqual(content["data"]["project"], None, content) + content = self.query_check(query, variables={"id": private_project.id}) + self.assertEqual(content["data"]["project"], None, content) # login with non_member self.force_login(non_member_user) - content = self.query_check(query, variables={'id': public_project.id}) - self.assertNotEqual(content['data']['project'], None, content) - self.assertEqual(content['data']['project']['membershipPending'], False) + content = self.query_check(query, variables={"id": public_project.id}) + self.assertNotEqual(content["data"]["project"], None, content) + self.assertEqual(content["data"]["project"]["membershipPending"], False) # login with request_user self.force_login(request_user) - content = self.query_check(query, variables={'id': public_project4.id}) - self.assertNotEqual(content['data']['project'], None, content) - self.assertEqual(content['data']['project']['isRejected'], True) + content = self.query_check(query, variables={"id": public_project4.id}) + self.assertNotEqual(content["data"]["project"], None, content) + self.assertEqual(content["data"]["project"]["isRejected"], True) # --- member user # ---- (public-project) self.force_login(user) - content = self.query_check(query, variables={'id': public_project.id}) - self.assertNotEqual(content['data']['project'], None, content) - self.assertEqual(content['data']['project']['stats']['numberOfLeads'], 5, content) - self.assertEqual(content['data']['project']['stats']['numberOfLeadsNotTagged'], 3, content) - self.assertEqual(content['data']['project']['stats']['numberOfLeadsInProgress'], 1, content) - self.assertEqual(content['data']['project']['stats']['numberOfLeadsTagged'], 1, content) - self.assertEqual(content['data']['project']['stats']['numberOfEntries'], 6, content) - self.assertEqual(content['data']['project']['stats']['numberOfEntriesVerified'], 2, content) - self.assertEqual(content['data']['project']['stats']['numberOfEntriesControlled'], 1, content) - self.assertEqual(content['data']['project']['stats']['numberOfUsers'], 3, content) - self.assertEqual(len(content['data']['project']['stats']['leadsActivity']), 1, content) - self.assertEqual(len(content['data']['project']['stats']['entriesActivity']), 1, content) - self.assertEqual(len(content['data']['project']['regions']), 2, content) - self.assertListIds(content['data']['project']['regions'], [region1, region2], content) + content = self.query_check(query, variables={"id": public_project.id}) + self.assertNotEqual(content["data"]["project"], None, content) + self.assertEqual(content["data"]["project"]["stats"]["numberOfLeads"], 5, content) + self.assertEqual(content["data"]["project"]["stats"]["numberOfLeadsNotTagged"], 3, content) + self.assertEqual(content["data"]["project"]["stats"]["numberOfLeadsInProgress"], 1, content) + self.assertEqual(content["data"]["project"]["stats"]["numberOfLeadsTagged"], 1, content) + self.assertEqual(content["data"]["project"]["stats"]["numberOfEntries"], 6, content) + self.assertEqual(content["data"]["project"]["stats"]["numberOfEntriesVerified"], 2, content) + self.assertEqual(content["data"]["project"]["stats"]["numberOfEntriesControlled"], 1, content) + self.assertEqual(content["data"]["project"]["stats"]["numberOfUsers"], 3, content) + self.assertEqual(len(content["data"]["project"]["stats"]["leadsActivity"]), 1, content) + self.assertEqual(len(content["data"]["project"]["stats"]["entriesActivity"]), 1, content) + self.assertEqual(len(content["data"]["project"]["regions"]), 2, content) + self.assertListIds(content["data"]["project"]["regions"], [region1, region2], content) # login with request user self.force_login(request_user) - content = self.query_check(query, variables={'id': public_project.id}) - self.assertNotEqual(content['data']['project'], None, content) - self.assertEqual(content['data']['project']['membershipPending'], True) + content = self.query_check(query, variables={"id": public_project.id}) + self.assertNotEqual(content["data"]["project"], None, content) + self.assertEqual(content["data"]["project"]["membershipPending"], True) # ---- (private-project) self.force_login(user) private_project.add_member(user) - content = self.query_check(query, variables={'id': private_project.id}) - self.assertNotEqual(content['data']['project'], None, content) - self.assertEqual(len(content['data']['project']['regions']), 1, content) - self.assertListIds(content['data']['project']['regions'], [region3], content) + content = self.query_check(query, variables={"id": private_project.id}) + self.assertNotEqual(content["data"]["project"], None, content) + self.assertEqual(len(content["data"]["project"]["regions"]), 1, content) + self.assertListIds(content["data"]["project"]["regions"], [region3], content) def test_project_query_has_assesment_af(self): - query = ''' + query = """ query MyQuery { projects(ordering: ASC_TITLE) { results { @@ -327,29 +324,33 @@ def test_project_query_has_assesment_af(self): } } } - ''' + """ user = UserFactory.create() analysis_framework = AnalysisFrameworkFactory.create() assessment_template = AssessmentTemplateFactory.create() project1 = ProjectFactory.create(analysis_framework=analysis_framework, assessment_template=assessment_template) - project2 = ProjectFactory.create(analysis_framework=analysis_framework,) + project2 = ProjectFactory.create( + analysis_framework=analysis_framework, + ) project3 = ProjectFactory.create(assessment_template=assessment_template) self.force_login(user) - projects = self.query_check(query)['data']['projects']['results'] - for index, (_id, has_af) in enumerate([ - (project1.pk, True), - (project2.pk, True), - (project3.pk, False), - ]): - self.assertIdEqual(projects[index]['id'], _id, projects) - self.assertEqual(projects[index]['hasAnalysisFramework'], has_af, projects) + projects = self.query_check(query)["data"]["projects"]["results"] + for index, (_id, has_af) in enumerate( + [ + (project1.pk, True), + (project2.pk, True), + (project3.pk, False), + ] + ): + self.assertIdEqual(projects[index]["id"], _id, projects) + self.assertEqual(projects[index]["hasAnalysisFramework"], has_af, projects) def test_projects_query(self): """ Test private + non-private project list behaviour """ - query = ''' + query = """ query MyQuery { projects (ordering: ASC_TITLE) { page @@ -365,7 +366,7 @@ def test_projects_query(self): } } } - ''' + """ user = UserFactory.create() public_project = ProjectFactory.create() @@ -379,19 +380,19 @@ def test_projects_query(self): # --- non-member user (only public project is listed) content = self.query_check(query) - self.assertEqual(content['data']['projects']['totalCount'], 1, content) - self.assertEqual(content['data']['projects']['results'][0]['id'], str(public_project.pk), content) + self.assertEqual(content["data"]["projects"]["totalCount"], 1, content) + self.assertEqual(content["data"]["projects"]["results"][0]["id"], str(public_project.pk), content) # --- member user (all public project is listed) public_project.add_member(user) private_project.add_member(user) content = self.query_check(query) - self.assertEqual(content['data']['projects']['totalCount'], 2, content) - self.assertEqual(content['data']['projects']['results'][0]['id'], str(public_project.pk), content) - self.assertEqual(content['data']['projects']['results'][1]['id'], str(private_project.pk), content) + self.assertEqual(content["data"]["projects"]["totalCount"], 2, content) + self.assertEqual(content["data"]["projects"]["results"][0]["id"], str(public_project.pk), content) + self.assertEqual(content["data"]["projects"]["results"][1]["id"], str(private_project.pk), content) def test_public_projects(self): - query = ''' + query = """ query MyQuery { publicProjects (ordering: ASC_TITLE) { page @@ -411,20 +412,14 @@ def test_public_projects(self): } } } - ''' + """ # Lets create some analysis_framework(private + publice) - public_af = AnalysisFrameworkFactory.create( - is_private=False, - title='Public Analysis Framework Title' - ) - private_af = AnalysisFrameworkFactory.create( - title='Private Analysis Framework Title', - is_private=True - ) + public_af = AnalysisFrameworkFactory.create(is_private=False, title="Public Analysis Framework Title") + private_af = AnalysisFrameworkFactory.create(title="Private Analysis Framework Title", is_private=True) # lets create some regions(private + public) - public_region = RegionFactory.create(public=True, title='Public Region') - private_region = RegionFactory.create(public=False, title='Private Region') + public_region = RegionFactory.create(public=True, title="Public Region") + private_region = RegionFactory.create(public=False, title="Private Region") # deleted_project ProjectFactory.create(analysis_framework=public_af, regions=[public_region], is_deleted=True) public_project1 = ProjectFactory.create(analysis_framework=public_af, regions=[public_region]) @@ -433,39 +428,25 @@ def test_public_projects(self): public_project4 = ProjectFactory.create(analysis_framework=private_af, regions=[public_region]) private_project = ProjectFactory.create(is_private=True) content = self.query_check(query) - self.assertEqual(content['data']['publicProjects']['totalCount'], 4, content) + self.assertEqual(content["data"]["publicProjects"]["totalCount"], 4, content) self.assertListIds( - content['data']['publicProjects']['results'], + content["data"]["publicProjects"]["results"], [public_project1, public_project2, public_project3, public_project4], - content + content, ) # some checks for analysis_framework private and public self.assertEqual( - content['data']['publicProjects']['results'][0]['analysisFrameworkTitle'], - 'Public Analysis Framework Title', - content - ) - self.assertEqual( - content['data']['publicProjects']['results'][2]['analysisFrameworkTitle'], - None, - content + content["data"]["publicProjects"]["results"][0]["analysisFrameworkTitle"], "Public Analysis Framework Title", content ) + self.assertEqual(content["data"]["publicProjects"]["results"][2]["analysisFrameworkTitle"], None, content) # some check for regions private and public - self.assertEqual( - content['data']['publicProjects']['results'][2]['regionsTitle'], - 'Public Region', - content - ) - self.assertEqual( - content['data']['publicProjects']['results'][1]['regionsTitle'], - '', - content - ) + self.assertEqual(content["data"]["publicProjects"]["results"][2]["regionsTitle"], "Public Region", content) + self.assertEqual(content["data"]["publicProjects"]["results"][1]["regionsTitle"], "", content) # make sure private projects are not visible here - self.assertNotListIds(content['data']['publicProjects']['results'], [private_project], content) + self.assertNotListIds(content["data"]["publicProjects"]["results"], [private_project], content) def test_project_geoareas(self): - query = ''' + query = """ query MyQuery( $projectID: ID!, $ids: [ID!], @@ -490,18 +471,18 @@ def test_project_geoareas(self): } } } - ''' + """ user = UserFactory.create() - region = RegionFactory.create(title='Nepal', is_published=True) + region = RegionFactory.create(title="Nepal", is_published=True) project = ProjectFactory.create() project.add_member(user) project.regions.add(region) - admin_level = AdminLevelFactory.create(title='District', region=region) - geo1 = GeoAreaFactory.create(admin_level=admin_level, title='Kathmandu') - geo2 = GeoAreaFactory.create(admin_level=admin_level, title='Lalitpur') - GeoAreaFactory.create(admin_level=admin_level, title='Bhaktapur') + admin_level = AdminLevelFactory.create(title="District", region=region) + geo1 = GeoAreaFactory.create(admin_level=admin_level, title="Kathmandu") + geo2 = GeoAreaFactory.create(admin_level=admin_level, title="Lalitpur") + GeoAreaFactory.create(admin_level=admin_level, title="Bhaktapur") geo1_data = dict( id=str(geo1.pk), @@ -519,7 +500,7 @@ def test_project_geoareas(self): def _query_check(variables={}, **kwargs): return self.query_check( query, - variables={'projectID': project.id, **variables}, + variables={"projectID": project.id, **variables}, **kwargs, ) @@ -529,31 +510,31 @@ def _query_check(variables={}, **kwargs): # -- With login self.force_login(user) - content = _query_check()['data']['project']['geoAreas']['results'] + content = _query_check()["data"]["project"]["geoAreas"]["results"] self.assertEqual(len(content), 3, content) - filters = {'ids': [str(geo1.pk), str(geo2.pk)], 'ordering': self.genum(GeoAreaOrderingEnum.ASC_ID)} - content = _query_check(variables=filters)['data']['project']['geoAreas']['results'] + filters = {"ids": [str(geo1.pk), str(geo2.pk)], "ordering": self.genum(GeoAreaOrderingEnum.ASC_ID)} + content = _query_check(variables=filters)["data"]["project"]["geoAreas"]["results"] self.assertEqual(len(content), 2, content) self.assertEqual(content, [geo1_data, geo2_data], content) - filters = {'search': 'kathm', 'ordering': self.genum(GeoAreaOrderingEnum.ASC_ID)} - content = _query_check(variables=filters)['data']['project']['geoAreas']['results'] + filters = {"search": "kathm", "ordering": self.genum(GeoAreaOrderingEnum.ASC_ID)} + content = _query_check(variables=filters)["data"]["project"]["geoAreas"]["results"] self.assertEqual(len(content), 1, content) self.assertEqual(content, [geo1_data], content) - filters = {'titles': ['Kathmandu', 'lalitpur'], 'ordering': self.genum(GeoAreaOrderingEnum.ASC_ID)} - content = _query_check(variables=filters)['data']['project']['geoAreas']['results'] + filters = {"titles": ["Kathmandu", "lalitpur"], "ordering": self.genum(GeoAreaOrderingEnum.ASC_ID)} + content = _query_check(variables=filters)["data"]["project"]["geoAreas"]["results"] self.assertEqual(len(content), 2, content) self.assertEqual(content, [geo1_data, geo2_data], content) - filters = {'titles': ['Kath', 'lal'], 'ordering': self.genum(GeoAreaOrderingEnum.ASC_ID)} - content = _query_check(variables=filters)['data']['project']['geoAreas']['results'] + filters = {"titles": ["Kath", "lal"], "ordering": self.genum(GeoAreaOrderingEnum.ASC_ID)} + content = _query_check(variables=filters)["data"]["project"]["geoAreas"]["results"] self.assertEqual(len(content), 0, content) self.assertEqual(content, [], content) def test_project_stat_recent(self): - query = ''' + query = """ query MyQuery { recentProjects { id @@ -564,7 +545,7 @@ def test_project_stat_recent(self): currentUserRole } } - ''' + """ user = UserFactory.create() analysis_framework = AnalysisFrameworkFactory.create() @@ -575,9 +556,9 @@ def test_project_stat_recent(self): lead1 = LeadFactory.create(project=public_project1, created_by=user) LeadFactory.create(project=public_project2, created_by=user) - EntryFactory.create(lead=lead1, controlled=False, - created_by=user, project=public_project1, - analysis_framework=analysis_framework) + EntryFactory.create( + lead=lead1, controlled=False, created_by=user, project=public_project1, analysis_framework=analysis_framework + ) LeadFactory.create(project=public_project3, created_by=user) LeadFactory.create(project=public_project4, created_by=user) # -- Without login @@ -587,13 +568,13 @@ def test_project_stat_recent(self): self.force_login(user) content = self.query_check(query) - self.assertEqual(len(content['data']['recentProjects']), 3, content) - self.assertEqual(content['data']['recentProjects'][0]['id'], str(public_project3.pk), content) - self.assertEqual(content['data']['recentProjects'][1]['id'], str(public_project1.pk), content) - self.assertEqual(content['data']['recentProjects'][2]['id'], str(public_project2.pk), content) + self.assertEqual(len(content["data"]["recentProjects"]), 3, content) + self.assertEqual(content["data"]["recentProjects"][0]["id"], str(public_project3.pk), content) + self.assertEqual(content["data"]["recentProjects"][1]["id"], str(public_project1.pk), content) + self.assertEqual(content["data"]["recentProjects"][2]["id"], str(public_project2.pk), content) def test_project_allowed_permissions(self): - query = ''' + query = """ query MyQuery { projects { results { @@ -602,25 +583,22 @@ def test_project_allowed_permissions(self): } } } - ''' + """ project1, project2 = ProjectFactory.create_batch(2) user = UserFactory.create() project1.add_member(user, badges=[]) project2.add_member(user, badges=[ProjectMembership.BadgeType.QA]) self.force_login(user) - content_projects = self.query_check(query)['data']['projects']['results'] + content_projects = self.query_check(query)["data"]["projects"]["results"] QA_PERMISSION = self.genum(PP.Permission.CAN_QUALITY_CONTROL) - content_projects_permissions = { - int(pdata['id']): pdata['allowedPermissions'] - for pdata in content_projects - } + content_projects_permissions = {int(pdata["id"]): pdata["allowedPermissions"] for pdata in content_projects} self.assertEqual(len(content_projects), 2, content_projects) self.assertNotIn(QA_PERMISSION, content_projects_permissions[project1.pk], content_projects) self.assertIn(QA_PERMISSION, content_projects_permissions[project2.pk], content_projects) def test_projects_by_region(self): - query = ''' + query = """ query MyQuery ($projectFilter: RegionProjectFilterData) { projectsByRegion (projectFilter: $projectFilter) { totalCount @@ -631,102 +609,98 @@ def test_projects_by_region(self): } } } - ''' + """ user = UserFactory.create() region1 = RegionFactory.create() region2 = RegionFactory.create() - project1 = ProjectFactory.create(regions=[region1], title='Test Nepal') - project2 = ProjectFactory.create(is_private=True, regions=[region1, region2], title='Test USA') + project1 = ProjectFactory.create(regions=[region1], title="Test Nepal") + project2 = ProjectFactory.create(is_private=True, regions=[region1, region2], title="Test USA") # This two projects willn't be shown ProjectFactory.create(is_private=True, regions=[region1, region2]) # private + no member access ProjectFactory.create() # no regions attached project2.add_member(user) self.force_login(user) - content = self.query_check(query)['data']['projectsByRegion']['results'] + content = self.query_check(query)["data"]["projectsByRegion"]["results"] self.assertEqual(content, [], content) # only save region2 centroid. region2.centroid = Point(1, 2) - region2.save(update_fields=('centroid',)) - content = self.query_check(query)['data']['projectsByRegion'] - self.assertEqual(content['totalCount'], 1, content) + region2.save(update_fields=("centroid",)) + content = self.query_check(query)["data"]["projectsByRegion"] + self.assertEqual(content["totalCount"], 1, content) self.assertEqual( - content['results'], [ + content["results"], + [ { - 'id': str(region2.pk), - 'centroid': { - 'coordinates': [region2.centroid.x, region2.centroid.y], - 'type': 'Point' - }, - 'projectsId': [str(project2.pk)] + "id": str(region2.pk), + "centroid": {"coordinates": [region2.centroid.x, region2.centroid.y], "type": "Point"}, + "projectsId": [str(project2.pk)], } - ], content) + ], + content, + ) # Now save region1 centroid as well. region1.centroid = Point(2, 3) - region1.save(update_fields=('centroid',)) - content = self.query_check(query)['data']['projectsByRegion'] - self.assertEqual(content['totalCount'], 2, content) + region1.save(update_fields=("centroid",)) + content = self.query_check(query)["data"]["projectsByRegion"] + self.assertEqual(content["totalCount"], 2, content) self.assertEqual( - content['results'], [ + content["results"], + [ { - 'id': str(region2.pk), - 'centroid': { - 'coordinates': [region2.centroid.x, region2.centroid.y], - 'type': 'Point' - }, - 'projectsId': [str(project2.pk)] - }, { - 'id': str(region1.pk), - 'centroid': { - 'coordinates': [region1.centroid.x, region1.centroid.y], - 'type': 'Point' - }, - 'projectsId': [str(project1.pk), str(project2.pk)] - } - ], content) + "id": str(region2.pk), + "centroid": {"coordinates": [region2.centroid.x, region2.centroid.y], "type": "Point"}, + "projectsId": [str(project2.pk)], + }, + { + "id": str(region1.pk), + "centroid": {"coordinates": [region1.centroid.x, region1.centroid.y], "type": "Point"}, + "projectsId": [str(project1.pk), str(project2.pk)], + }, + ], + content, + ) # Now using filters - project_filter = {'search': 'USA'} - content = self.query_check(query, variables={'projectFilter': project_filter})['data']['projectsByRegion'] - self.assertEqual(content['totalCount'], 2, content) + project_filter = {"search": "USA"} + content = self.query_check(query, variables={"projectFilter": project_filter})["data"]["projectsByRegion"] + self.assertEqual(content["totalCount"], 2, content) self.assertEqual( - content['results'], [ + content["results"], + [ { - 'id': str(region2.pk), - 'centroid': { - 'coordinates': [region2.centroid.x, region2.centroid.y], - 'type': 'Point' - }, - 'projectsId': [str(project2.pk)] - }, { - 'id': str(region1.pk), - 'centroid': { - 'coordinates': [region1.centroid.x, region1.centroid.y], - 'type': 'Point' - }, - 'projectsId': [str(project2.pk)] - } - ], content) + "id": str(region2.pk), + "centroid": {"coordinates": [region2.centroid.x, region2.centroid.y], "type": "Point"}, + "projectsId": [str(project2.pk)], + }, + { + "id": str(region1.pk), + "centroid": {"coordinates": [region1.centroid.x, region1.centroid.y], "type": "Point"}, + "projectsId": [str(project2.pk)], + }, + ], + content, + ) - project_filter = {'ids': [project1.pk]} - content = self.query_check(query, variables={'projectFilter': project_filter})['data']['projectsByRegion'] - self.assertEqual(content['totalCount'], 1, content) + project_filter = {"ids": [project1.pk]} + content = self.query_check(query, variables={"projectFilter": project_filter})["data"]["projectsByRegion"] + self.assertEqual(content["totalCount"], 1, content) self.assertEqual( - content['results'], [ + content["results"], + [ { - 'id': str(region1.pk), - 'centroid': { - 'coordinates': [region1.centroid.x, region1.centroid.y], - 'type': 'Point' - }, - 'projectsId': [str(project1.pk)] + "id": str(region1.pk), + "centroid": {"coordinates": [region1.centroid.x, region1.centroid.y], "type": "Point"}, + "projectsId": [str(project1.pk)], } - ], content) + ], + content, + ) def test_public_projects_by_region(self): - query = ''' + query = """ query MyQuery ($projectFilter: RegionProjectFilterData) { publicProjectsByRegion (projectFilter: $projectFilter) { totalCount @@ -737,7 +711,7 @@ def test_public_projects_by_region(self): } } } - ''' + """ fake_centroid = Point(1, 2) region1 = RegionFactory.create(public=False, centroid=fake_centroid) region2 = RegionFactory.create(centroid=fake_centroid) @@ -745,32 +719,32 @@ def test_public_projects_by_region(self): region4 = RegionFactory.create(public=False, centroid=fake_centroid) RegionFactory.create() # No Centroid ( This will not show) # Deleted project - ProjectFactory.create(is_private=False, is_deleted=True, regions=[region1, region2], title='Test Nepal') - project1 = ProjectFactory.create(is_private=False, regions=[region1, region2], title='Test Nepal') - ProjectFactory.create(is_private=False, regions=[region3], title='Test Canada') - project2 = ProjectFactory.create(is_private=True, regions=[region4], title='Test Brazil') + ProjectFactory.create(is_private=False, is_deleted=True, regions=[region1, region2], title="Test Nepal") + project1 = ProjectFactory.create(is_private=False, regions=[region1, region2], title="Test Nepal") + ProjectFactory.create(is_private=False, regions=[region3], title="Test Canada") + project2 = ProjectFactory.create(is_private=True, regions=[region4], title="Test Brazil") def _query_check(project_filter): - return self.query_check(query, variables={'projectFilter': project_filter}) + return self.query_check(query, variables={"projectFilter": project_filter}) content = self.query_check(query) - self.assertEqual(content['data']['publicProjectsByRegion']['totalCount'], 3, content) + self.assertEqual(content["data"]["publicProjectsByRegion"]["totalCount"], 3, content) # test for project filter - content = _query_check({'ids': [project1.pk]})['data']['publicProjectsByRegion'] - self.assertEqual(content['totalCount'], 2, content) + content = _query_check({"ids": [project1.pk]})["data"]["publicProjectsByRegion"] + self.assertEqual(content["totalCount"], 2, content) - content = _query_check({'ids': [project1.pk, project2.pk]})['data']['publicProjectsByRegion'] - self.assertEqual(content['totalCount'], 2, content) + content = _query_check({"ids": [project1.pk, project2.pk]})["data"]["publicProjectsByRegion"] + self.assertEqual(content["totalCount"], 2, content) - content = _query_check({'search': 'Canada'})['data']['publicProjectsByRegion'] - self.assertEqual(content['totalCount'], 1, content) + content = _query_check({"search": "Canada"})["data"]["publicProjectsByRegion"] + self.assertEqual(content["totalCount"], 1, content) - content = _query_check({'search': 'Brazil'})['data']['publicProjectsByRegion'] - self.assertEqual(content['totalCount'], 0, content) # Private projects are not shown + content = _query_check({"search": "Brazil"})["data"]["publicProjectsByRegion"] + self.assertEqual(content["totalCount"], 0, content) # Private projects are not shown def test_project_stats_with_filter(self): - query = ''' + query = """ query MyQuery ($projectId: ID! $leadFilters: LeadsFilterDataInputType) { project(id: $projectId) { stats(filters: $leadFilters) { @@ -792,7 +766,7 @@ def test_project_stats_with_filter(self): } } } - ''' + """ non_member_user, member_user = UserFactory.create_batch(2) af = AnalysisFrameworkFactory.create() @@ -803,16 +777,16 @@ def test_project_stats_with_filter(self): EntryFactory.create_batch(2, lead=lead1, controlled=True) EntryFactory.create(lead=lead2, verified_by=[member_user]) lead2.status = Lead.Status.TAGGED - lead2.save(update_fields=('status',)) + lead2.save(update_fields=("status",)) def _query_check(filters=None, **kwargs): return self.query_check( query, variables={ - 'projectId': project.id, - 'leadFilters': filters, + "projectId": project.id, + "leadFilters": filters, }, - **kwargs + **kwargs, ) def _expected_response( @@ -847,41 +821,37 @@ def _expected_response( # With login - non-member zero count self.force_login(non_member_user) - content = _query_check()['data']['project']['stats'] + content = _query_check()["data"]["project"]["stats"] self.assertIsNone(content, content) # With login - member self.force_login(member_user) self.maxDiff = None - for index, (filters, _expected) in enumerate([ - ( - {'confidentiality': self.genum(Lead.Confidentiality.CONFIDENTIAL)}, - [1, 1, 0, 0, 2, 2, 0] - ), + for index, (filters, _expected) in enumerate( + [ + ({"confidentiality": self.genum(Lead.Confidentiality.CONFIDENTIAL)}, [1, 1, 0, 0, 2, 2, 0]), ( - {'entriesFilterData': {'leadConfidentialities': self.genum(Lead.Confidentiality.CONFIDENTIAL)}}, - [1, 1, 0, 0, 2, 2, 0] + {"entriesFilterData": {"leadConfidentialities": self.genum(Lead.Confidentiality.CONFIDENTIAL)}}, + [1, 1, 0, 0, 2, 2, 0], ), + ({"confidentiality": self.genum(Lead.Confidentiality.UNPROTECTED)}, [1, 0, 0, 1, 1, 0, 1]), ( - {'confidentiality': self.genum(Lead.Confidentiality.UNPROTECTED)}, - [1, 0, 0, 1, 1, 0, 1] + {"entriesFilterData": {"leadConfidentialities": self.genum(Lead.Confidentiality.UNPROTECTED)}}, + [1, 0, 0, 1, 1, 0, 1], ), - ( - {'entriesFilterData': {'leadConfidentialities': self.genum(Lead.Confidentiality.UNPROTECTED)}}, - [1, 0, 0, 1, 1, 0, 1] - ), - ]): - content = _query_check(filters=filters)['data']['project']['stats'] + ] + ): + content = _query_check(filters=filters)["data"]["project"]["stats"] self.assertEqual(_expected_response(*_expected), content, index) def test_project_last_read_access(self): - QUERY = ''' + QUERY = """ query MyQuery ($projectId: ID!) { project(id: $projectId) { id } } - ''' + """ user = UserFactory.create() projects = ProjectFactory.create_batch(2) @@ -893,13 +863,13 @@ def test_project_last_read_access(self): project.add_member(user, role=self.project_role_member) def _query_check(project_id): - return self.query_check(QUERY, variables={'projectId': project_id}) + return self.query_check(QUERY, variables={"projectId": project_id}) self.force_login(user) # Run/try query and check if last_read_access are changing properly base_now = datetime(2021, 1, 1, 0, 0, 0, 123456, tzinfo=pytz.UTC) - with patch('deep.trackers.timezone.now') as timezone_now_mock: + with patch("deep.trackers.timezone.now") as timezone_now_mock: timezone_now = None old_timezone_now = None for timezone_now in [ @@ -920,7 +890,7 @@ def _query_check(project_id): else: # Public project have readaccess for some nodes assert project.last_read_access == old_timezone_now - _query_check(project.id)['data']['project'] + _query_check(project.id)["data"]["project"] with self.captureOnCommitCallbacks(execute=True): schedule_tracker_data_handler() project.refresh_from_db() @@ -937,19 +907,17 @@ def _query_check(project_id): old_timezone_now = timezone_now def test_project_last_write_access(self): - MUTATION = ''' + MUTATION = """ mutation MyMutation ($projectId: ID!) { project(id: $projectId) { id } } - ''' + """ user = UserFactory.create() projects = ProjectFactory.create_batch(2) - projects.extend( - ProjectFactory.create_batch(2, is_private=True) - ) + projects.extend(ProjectFactory.create_batch(2, is_private=True)) project_with_access = [projects[0], projects[2]] @@ -957,13 +925,13 @@ def test_project_last_write_access(self): project.add_member(user, role=self.project_role_member) def _query_check(project_id, **kwargs): - return self.query_check(MUTATION, variables={'projectId': project_id}, **kwargs) + return self.query_check(MUTATION, variables={"projectId": project_id}, **kwargs) self.force_login(user) # Run/try mutations and check if last_write_access and project.status are changing properly base_now = datetime(2021, 1, 1, 0, 0, 0, 123456, tzinfo=pytz.UTC) - with patch('deep.trackers.timezone.now') as timezone_now_mock: + with patch("deep.trackers.timezone.now") as timezone_now_mock: timezone_now = None old_timezone_now = None for timezone_now in [ @@ -978,7 +946,7 @@ def _query_check(project_id, **kwargs): if project in project_with_access: # Existing state assert project.last_write_access == old_timezone_now - _query_check(project.id)['data']['project'] + _query_check(project.id)["data"]["project"] with self.captureOnCommitCallbacks(execute=True): schedule_tracker_data_handler() project.refresh_from_db() @@ -1010,7 +978,7 @@ def _query_check(project_id, **kwargs): assert project.status == Project.Status.INACTIVE def test_project_role(self): - query = ''' + query = """ query MyQuery { projectRoles{ id @@ -1020,7 +988,7 @@ def test_project_role(self): } } - ''' + """ user = UserFactory.create() # without login @@ -1029,10 +997,10 @@ def test_project_role(self): self.force_login(user) project_role_count = ProjectRole.objects.count() content = self.query_check(query) - self.assertEqual(len(content['data']['projectRoles']), project_role_count) + self.assertEqual(len(content["data"]["projectRoles"]), project_role_count) def test_user_pinned_projects_query(self): - query = ''' + query = """ query MyQuery { userPinnedProjects { clientId @@ -1049,7 +1017,7 @@ def test_user_pinned_projects_query(self): } } } - ''' + """ user1 = UserFactory.create() user2 = UserFactory.create() @@ -1057,11 +1025,7 @@ def test_user_pinned_projects_query(self): project_with_access = [project[0], project[2]] for idx, project in enumerate(project_with_access): project.add_member(user1) - ProjectPinnedFactory.create( - project=project, - user=user1, - order=idx - ) + ProjectPinnedFactory.create(project=project, user=user1, order=idx) # -- Without login self.query_check(query, assert_for_error=True) @@ -1069,19 +1033,19 @@ def test_user_pinned_projects_query(self): self.force_login(user1) content = self.query_check(query) - self.assertEqual(len(content['data']['userPinnedProjects']), 2) + self.assertEqual(len(content["data"]["userPinnedProjects"]), 2) # -- With non member user self.force_login(user2) content = self.query_check(query) - self.assertEqual(len(content['data']['userPinnedProjects']), 0) + self.assertEqual(len(content["data"]["userPinnedProjects"]), 0) class TestProjectViz(GraphQLTestCase): ENABLE_NOW_PATCHER = True def test_projects_viz_node(self): - query = ''' + query = """ query MyQuery ($id: ID!) { project(id: $id) { vizData { @@ -1096,10 +1060,10 @@ def test_projects_viz_node(self): isVisualizationEnabled } } - ''' + """ af = AnalysisFrameworkFactory.create() - member_user = UserFactory.create() # with confidential access + member_user = UserFactory.create() # with confidential access non_confidential_user = UserFactory.create() non_member_user = UserFactory.create() project = ProjectFactory.create(analysis_framework=af) @@ -1107,7 +1071,7 @@ def test_projects_viz_node(self): project.add_member(non_confidential_user, role=self.project_role_reader_non_confidential) def _query_check(**kwargs): - return self.query_check(query, variables={'id': project.pk}, **kwargs) + return self.query_check(query, variables={"id": project.pk}, **kwargs) # -- Without login _query_check(assert_for_error=True) @@ -1117,33 +1081,33 @@ def _query_check(**kwargs): # --- non-member user self.force_login(non_member_user) content = _query_check() - self.assertEqual(content['data']['project']['vizData'], None, content) - self.assertEqual(content['data']['project']['isVisualizationEnabled'], False, content) - self.assertEqual(content['data']['project']['isVisualizationAvailable'], False, content) + self.assertEqual(content["data"]["project"]["vizData"], None, content) + self.assertEqual(content["data"]["project"]["isVisualizationEnabled"], False, content) + self.assertEqual(content["data"]["project"]["isVisualizationAvailable"], False, content) # --- member user self.force_login(member_user) content = _query_check() - self.assertEqual(content['data']['project']['vizData'], None, content) - self.assertEqual(content['data']['project']['isVisualizationEnabled'], False, content) - self.assertEqual(content['data']['project']['isVisualizationAvailable'], False, content) + self.assertEqual(content["data"]["project"]["vizData"], None, content) + self.assertEqual(content["data"]["project"]["isVisualizationEnabled"], False, content) + self.assertEqual(content["data"]["project"]["isVisualizationAvailable"], False, content) # Only enabling project viz settings (not configuring AF). project.is_visualization_enabled = True - project.save(update_fields=('is_visualization_enabled',)) + project.save(update_fields=("is_visualization_enabled",)) # --- non-member user self.force_login(non_member_user) content = _query_check() - self.assertEqual(content['data']['project']['vizData'], None, content) - self.assertEqual(content['data']['project']['isVisualizationEnabled'], True, content) - self.assertEqual(content['data']['project']['isVisualizationAvailable'], False, content) + self.assertEqual(content["data"]["project"]["vizData"], None, content) + self.assertEqual(content["data"]["project"]["isVisualizationEnabled"], True, content) + self.assertEqual(content["data"]["project"]["isVisualizationAvailable"], False, content) # --- member user self.force_login(member_user) content = _query_check() - self.assertEqual(content['data']['project']['vizData'], None, content) - self.assertEqual(content['data']['project']['isVisualizationEnabled'], True, content) - self.assertEqual(content['data']['project']['isVisualizationAvailable'], False, content) + self.assertEqual(content["data"]["project"]["vizData"], None, content) + self.assertEqual(content["data"]["project"]["isVisualizationEnabled"], True, content) + self.assertEqual(content["data"]["project"]["isVisualizationAvailable"], False, content) # Configure/Enable viz. TestProjectGeneralMutationSnapshotTest.set_project_viz_configuration(project) @@ -1151,24 +1115,24 @@ def _query_check(**kwargs): # --- non-member project self.force_login(non_member_user) content = _query_check() - self.assertEqual(content['data']['project']['vizData'], None, content) - self.assertEqual(content['data']['project']['isVisualizationEnabled'], True, content) - self.assertEqual(content['data']['project']['isVisualizationAvailable'], True, content) + self.assertEqual(content["data"]["project"]["vizData"], None, content) + self.assertEqual(content["data"]["project"]["isVisualizationEnabled"], True, content) + self.assertEqual(content["data"]["project"]["isVisualizationAvailable"], True, content) # --- member project self.force_login(member_user) content = _query_check() - self.assertNotEqual(content['data']['project']['vizData'], None, content) - self.assertEqual(content['data']['project']['isVisualizationEnabled'], True, content) - self.assertEqual(content['data']['project']['isVisualizationAvailable'], True, content) + self.assertNotEqual(content["data"]["project"]["vizData"], None, content) + self.assertEqual(content["data"]["project"]["isVisualizationEnabled"], True, content) + self.assertEqual(content["data"]["project"]["isVisualizationAvailable"], True, content) self.assertEqual( - content['data']['project']['vizData'], + content["data"]["project"]["vizData"], { - 'dataUrl': '', - 'modifiedAt': self.now_datetime_str(), - 'publicShare': False, - 'publicUrl': None, - 'status': self.genum(ProjectStats.Status.PENDING), + "dataUrl": "", + "modifiedAt": self.now_datetime_str(), + "publicShare": False, + "publicUrl": None, + "status": self.genum(ProjectStats.Status.PENDING), }, content, ) @@ -1177,17 +1141,17 @@ def _query_check(**kwargs): project_stats = project.project_stats.update_public_share_configuration(ProjectStats.Action.ON) content = _query_check() - self.assertNotEqual(content['data']['project']['vizData'], None, content) - self.assertEqual(content['data']['project']['isVisualizationEnabled'], True, content) - self.assertEqual(content['data']['project']['isVisualizationAvailable'], True, content) + self.assertNotEqual(content["data"]["project"]["vizData"], None, content) + self.assertEqual(content["data"]["project"]["isVisualizationEnabled"], True, content) + self.assertEqual(content["data"]["project"]["isVisualizationAvailable"], True, content) self.assertEqual( - content['data']['project']['vizData'], + content["data"]["project"]["vizData"], { - 'dataUrl': '', - 'modifiedAt': self.now_datetime_str(), - 'publicShare': True, - 'publicUrl': 'http://testserver' + project_stats.get_public_url(), - 'status': self.genum(ProjectStats.Status.PENDING), + "dataUrl": "", + "modifiedAt": self.now_datetime_str(), + "publicShare": True, + "publicUrl": "http://testserver" + project_stats.get_public_url(), + "status": self.genum(ProjectStats.Status.PENDING), }, content, ) @@ -1195,7 +1159,7 @@ def _query_check(**kwargs): class TestProjectFilterSchema(GraphQLTestCase): def test_project_query_filter(self): - query = ''' + query = """ query MyQuery ($isCurrentUserMember: Boolean!) { projects(isCurrentUserMember: $isCurrentUserMember) { page @@ -1209,7 +1173,7 @@ def test_project_query_filter(self): } } } - ''' + """ user = UserFactory.create() project1 = ProjectFactory.create() @@ -1222,22 +1186,22 @@ def test_project_query_filter(self): project2.add_member(user) # -- Without login - self.query_check(query, variables={'isCurrentUserMember': True}, assert_for_error=True) + self.query_check(query, variables={"isCurrentUserMember": True}, assert_for_error=True) # -- With login self.force_login(user) # project without membership - content = self.query_check(query, variables={'isCurrentUserMember': True}) - self.assertEqual(content['data']['projects']['totalCount'], 2, content) - self.assertListIds(content['data']['projects']['results'], [project1, project2], content) + content = self.query_check(query, variables={"isCurrentUserMember": True}) + self.assertEqual(content["data"]["projects"]["totalCount"], 2, content) + self.assertListIds(content["data"]["projects"]["results"], [project1, project2], content) # project with membership - content = self.query_check(query, variables={'isCurrentUserMember': False}) - self.assertEqual(content['data']['projects']['totalCount'], 1, content) # Private will not show here - self.assertListIds(content['data']['projects']['results'], [project3], content) + content = self.query_check(query, variables={"isCurrentUserMember": False}) + self.assertEqual(content["data"]["projects"]["totalCount"], 1, content) # Private will not show here + self.assertListIds(content["data"]["projects"]["results"], [project3], content) def test_query_test_projects_filter(self): - query = ''' + query = """ query MyQuery ($isTest: Boolean!) { projects(isTest: $isTest) { page @@ -1251,31 +1215,31 @@ def test_query_test_projects_filter(self): } } } - ''' + """ user = UserFactory.create() project1, project2 = ProjectFactory.create_batch(2, is_test=True) project3 = ProjectFactory.create() # -- Without login - self.query_check(query, variables={'isTest': True}, assert_for_error=True) + self.query_check(query, variables={"isTest": True}, assert_for_error=True) # -- With login self.force_login(user) # test projects - content = self.query_check(query, variables={'isTest': True}) - self.assertEqual(content['data']['projects']['totalCount'], 2, content) - self.assertListIds(content['data']['projects']['results'], [project1, project2], content) + content = self.query_check(query, variables={"isTest": True}) + self.assertEqual(content["data"]["projects"]["totalCount"], 2, content) + self.assertListIds(content["data"]["projects"]["results"], [project1, project2], content) # except test projects - content = self.query_check(query, variables={'isTest': False}) - self.assertEqual(content['data']['projects']['totalCount'], 1, content) - self.assertListIds(content['data']['projects']['results'], [project3], content) + content = self.query_check(query, variables={"isTest": False}) + self.assertEqual(content["data"]["projects"]["totalCount"], 1, content) + self.assertListIds(content["data"]["projects"]["results"], [project3], content) class TestProjectMembersFilterSchema(GraphQLTestCase): def test_project(self): - query = ''' + query = """ query MyQuery ($id: ID!, $user_search: String, $usergroup_search: String) { project(id: $id) { userMembers(search: $user_search) { @@ -1310,13 +1274,13 @@ def test_project(self): } } } - ''' + """ - user, user1, user2, user3, _ = UserFactory.create_batch(5, first_name='Ram') - usergroup1, usergroup2, _ = UserGroupFactory.create_batch(3, title='UserGroup YYY') - usergroup4 = UserGroupFactory.create(title='UserGroup ZZZ') + user, user1, user2, user3, _ = UserFactory.create_batch(5, first_name="Ram") + usergroup1, usergroup2, _ = UserGroupFactory.create_batch(3, title="UserGroup YYY") + usergroup4 = UserGroupFactory.create(title="UserGroup ZZZ") - user5 = UserFactory.create(first_name='Nam') + user5 = UserFactory.create(first_name="Nam") project = ProjectFactory.create() # Add user to project1 only (one normal + one private) @@ -1333,25 +1297,25 @@ def test_project(self): self.force_login(user) # project without membership - content = self.query_check(query, variables={'id': project.id, 'user_search': user.first_name}) - self.assertEqual(content['data']['project']['userMembers']['totalCount'], 4, content) - self.assertEqual(len(content['data']['project']['userMembers']['results']), 4, content) - self.assertEqual(content['data']['project']['userGroupMembers']['totalCount'], 3, content) - self.assertEqual(len(content['data']['project']['userGroupMembers']['results']), 3, content) + content = self.query_check(query, variables={"id": project.id, "user_search": user.first_name}) + self.assertEqual(content["data"]["project"]["userMembers"]["totalCount"], 4, content) + self.assertEqual(len(content["data"]["project"]["userMembers"]["results"]), 4, content) + self.assertEqual(content["data"]["project"]["userGroupMembers"]["totalCount"], 3, content) + self.assertEqual(len(content["data"]["project"]["userGroupMembers"]["results"]), 3, content) # project without membership - content = self.query_check(query, variables={'id': project.id, 'usergroup_search': usergroup1.title}) - self.assertEqual(content['data']['project']['userGroupMembers']['totalCount'], 2, content) - self.assertEqual(len(content['data']['project']['userGroupMembers']['results']), 2, content) - self.assertEqual(content['data']['project']['userMembers']['totalCount'], 5, content) - self.assertEqual(len(content['data']['project']['userMembers']['results']), 5, content) + content = self.query_check(query, variables={"id": project.id, "usergroup_search": usergroup1.title}) + self.assertEqual(content["data"]["project"]["userGroupMembers"]["totalCount"], 2, content) + self.assertEqual(len(content["data"]["project"]["userGroupMembers"]["results"]), 2, content) + self.assertEqual(content["data"]["project"]["userMembers"]["totalCount"], 5, content) + self.assertEqual(len(content["data"]["project"]["userMembers"]["results"]), 5, content) class TestProjectExploreStats(GraphQLSnapShotTestCase): factories_used = [ProjectFactory, AnalysisFrameworkFactory] def test_snapshot(self): - query = ''' + query = """ query MyQuery { projectExploreStats { totalProjects @@ -1370,7 +1334,7 @@ def test_snapshot(self): calculatedAt } } - ''' + """ def _cache_clear(): cache.delete(CacheKey.PROJECT_EXPLORE_STATS_LOADER_KEY) # Delete cache @@ -1383,15 +1347,14 @@ def _cache_clear(): _cache_clear() previous_content = content = self.query_check(query) - self.assertMatchSnapshot(content, 'no-data') + self.assertMatchSnapshot(content, "no-data") UserFactory.create_batch(3, is_active=False) # Some Inactive users analysis_framework = AnalysisFrameworkFactory.create() projects = ProjectFactory.create_batch(3) projects_with_af = ProjectFactory.create_batch(3, analysis_framework=analysis_framework) # This shouldn't show in top projects but leads/entries count should - private_project = ProjectFactory.create( - title='Private Project', is_private=True, analysis_framework=analysis_framework) + private_project = ProjectFactory.create(title="Private Project", is_private=True, analysis_framework=analysis_framework) now = timezone.now() # Generate project cache @@ -1401,7 +1364,7 @@ def _cache_clear(): self.assertEqual(content, previous_content) # Test for cache _cache_clear() previous_content = content = self.query_check(query) # Pull latest data - self.assertMatchSnapshot(content, 'only-project') + self.assertMatchSnapshot(content, "only-project") self.update_obj(LeadFactory.create(project=projects[0]), created_at=now + relativedelta(weeks=-1)) self.update_obj(LeadFactory.create(project=projects[0]), created_at=now + relativedelta(weeks=-1)) @@ -1428,20 +1391,26 @@ def _cache_clear(): _generate_project_stats_cache() self.update_obj( - ExportFactory.create(project=projects_with_af[0], exported_by=user), exported_at=now + relativedelta(months=-1)) + ExportFactory.create(project=projects_with_af[0], exported_by=user), exported_at=now + relativedelta(months=-1) + ) self.update_obj( - ExportFactory.create(project=projects_with_af[0], exported_by=user), exported_at=now + relativedelta(months=-1)) + ExportFactory.create(project=projects_with_af[0], exported_by=user), exported_at=now + relativedelta(months=-1) + ) self.update_obj( - ExportFactory.create(project=projects_with_af[0], exported_by=user), exported_at=now + relativedelta(months=-2)) + ExportFactory.create(project=projects_with_af[0], exported_by=user), exported_at=now + relativedelta(months=-2) + ) self.update_obj( - ExportFactory.create(project=projects_with_af[1], exported_by=user), exported_at=now + relativedelta(months=-2)) + ExportFactory.create(project=projects_with_af[1], exported_by=user), exported_at=now + relativedelta(months=-2) + ) self.update_obj( - ExportFactory.create(project=projects_with_af[2], exported_by=user), exported_at=now + relativedelta(months=-3)) + ExportFactory.create(project=projects_with_af[2], exported_by=user), exported_at=now + relativedelta(months=-3) + ) self.update_obj( - ExportFactory.create(project=private_project, exported_by=user), exported_at=now + relativedelta(months=-1)) + ExportFactory.create(project=private_project, exported_by=user), exported_at=now + relativedelta(months=-1) + ) content = self.query_check(query) self.assertEqual(content, previous_content) # Test for cache _cache_clear() previous_content = content = self.query_check(query) # Pull latest data - self.assertMatchSnapshot(content, 'with-data') + self.assertMatchSnapshot(content, "with-data") diff --git a/apps/project/tests/test_utils.py b/apps/project/tests/test_utils.py index 6e826f2df8..0f519deaa4 100644 --- a/apps/project/tests/test_utils.py +++ b/apps/project/tests/test_utils.py @@ -1,9 +1,6 @@ import unittest -from project.change_log import ( - get_flat_dict_diff, - get_list_diff, -) +from project.change_log import get_flat_dict_diff, get_list_diff class ProjectChangeLog(unittest.TestCase): @@ -16,28 +13,28 @@ def _obj(value1, value2, value3): ) list1 = [ - _obj('a', 'b', 'c'), - _obj('b', 'c', 'd'), - _obj('a', 'b', 'z'), + _obj("a", "b", "c"), + _obj("b", "c", "d"), + _obj("a", "b", "z"), ] list2 = [ - _obj('a', 'b', 'c'), - _obj('b', 'c', 'f'), - _obj('a', 'b', 'i'), + _obj("a", "b", "c"), + _obj("b", "c", "f"), + _obj("a", "b", "i"), ] diff = get_flat_dict_diff( list1, list2, - fields=('key1', 'key2', 'key3'), + fields=("key1", "key2", "key3"), ) assert diff == { - 'add': [ - _obj('a', 'b', 'i'), - _obj('b', 'c', 'f'), + "add": [ + _obj("a", "b", "i"), + _obj("b", "c", "f"), ], - 'remove': [ - _obj('a', 'b', 'z'), - _obj('b', 'c', 'd'), + "remove": [ + _obj("a", "b", "z"), + _obj("b", "c", "d"), ], } @@ -46,14 +43,14 @@ def test_get_list_diff(self): list2 = [5, 4, 3] diff = get_list_diff(list1, list2) assert diff == { - 'add': [4, 5], - 'remove': [1, 2], + "add": [4, 5], + "remove": [1, 2], } - list1 = ['dfs', 'deep'] - list2 = ['toggle', 'deep', 'nepal'] + list1 = ["dfs", "deep"] + list2 = ["toggle", "deep", "nepal"] diff = get_list_diff(list1, list2) assert diff == { - 'add': ['nepal', 'toggle'], - 'remove': ['dfs'], + "add": ["nepal", "toggle"], + "remove": ["dfs"], } diff --git a/apps/project/token.py b/apps/project/token.py index f1fefb508f..d97553b1b5 100644 --- a/apps/project/token.py +++ b/apps/project/token.py @@ -1,4 +1,5 @@ from django.conf import settings + from deep.token import DeepTokenGenerator @@ -7,6 +8,7 @@ class ProjectRequestTokenGenerator(DeepTokenGenerator): Strategy object used to generate and check tokens for the project request mechanism. """ + key_salt = "projects.token.ProjectRequestTokenGenerator" secret = settings.SECRET_KEY reset_timeout_days = settings.PROJECT_REQUEST_RESET_TIMEOUT_DAYS @@ -23,17 +25,13 @@ def _make_hash_value(self, project_join_request, timestamp): Failing those things, settings.PROJECT_REQUEST_RESET_TIMEOUT_DAYS eventually invalidates the token. """ - join_request = project_join_request['join_request'] - user = project_join_request['will_responded_by'] + join_request = project_join_request["join_request"] + user = project_join_request["will_responded_by"] # Truncate microseconds so that tokens are consistent even if the # database doesn't support microseconds. - responded_at = '' if join_request.responded_at is None else\ - join_request.responded_at.replace(microsecond=0, tzinfo=None) - return ( - str(join_request.pk) + str(user.pk) + join_request.status + - str(responded_at) + str(timestamp) - ) + responded_at = "" if join_request.responded_at is None else join_request.responded_at.replace(microsecond=0, tzinfo=None) + return str(join_request.pk) + str(user.pk) + join_request.status + str(responded_at) + str(timestamp) project_request_token_generator = ProjectRequestTokenGenerator() diff --git a/apps/project/views.py b/apps/project/views.py index 979de1c2c0..8f69dc66c1 100644 --- a/apps/project/views.py +++ b/apps/project/views.py @@ -1,19 +1,26 @@ import logging import uuid -from dateutil.relativedelta import relativedelta +import ary.serializers as arys import django_filters +from analysis.models import Analysis, AnalyticalStatementEntry, DiscardedEntry +from dateutil.relativedelta import relativedelta from django.conf import settings -from django.http import Http404 -from django.db import transaction, models -from django.utils import timezone -from django.utils.http import urlsafe_base64_decode -from django.utils.encoding import force_text from django.contrib.postgres.fields.jsonb import KeyTextTransform +from django.db import models, transaction from django.db.models.functions import Cast +from django.http import Http404 from django.template.response import TemplateResponse -from deep.permalinks import Permalink -from rest_framework.exceptions import PermissionDenied +from django.utils import timezone +from django.utils.encoding import force_text +from django.utils.http import urlsafe_base64_decode +from docs.utils import mark_as_delete, mark_as_list +from entry.models import Entry +from entry.views import ComprehensiveEntriesViewSet +from geo.models import Region +from geo.serializers import RegionSerializer +from lead.models import Lead +from lead.views import ProjectLeadGroupViewSet from rest_framework import ( exceptions, filters, @@ -24,70 +31,54 @@ viewsets, ) from rest_framework.decorators import action +from rest_framework.exceptions import PermissionDenied from rest_framework.generics import get_object_or_404 - -from docs.utils import mark_as_list, mark_as_delete -import ary.serializers as arys - -from deep.views import get_frontend_url -from deep.permissions import ( - ModifyPermission, - IsProjectMember, -) -from deep.serializers import URLCachedFileField -from deep.paginations import SmallSizeSetPagination from tabular.models import Field - -from user.utils import send_project_join_request_emails -from user.serializers import SimpleUserSerializer from user.models import User -from lead.models import Lead -from lead.views import ProjectLeadGroupViewSet -from geo.models import Region +from user.serializers import SimpleUserSerializer +from user.utils import send_project_join_request_emails from user_group.models import UserGroup -from geo.serializers import RegionSerializer -from entry.models import Entry -from entry.views import ComprehensiveEntriesViewSet -from analysis.models import ( - Analysis, - AnalyticalStatementEntry, - DiscardedEntry -) +from deep.paginations import SmallSizeSetPagination +from deep.permalinks import Permalink +from deep.permissions import IsProjectMember, ModifyPermission +from deep.serializers import URLCachedFileField +from deep.views import get_frontend_url + +from .filter_set import ( + ProjectFilterSet, + ProjectMembershipFilterSet, + ProjectUserGroupMembershipFilterSet, + get_filtered_projects, +) from .models import ( Project, - ProjectRole, - ProjectMembership, ProjectJoinRequest, - ProjectUserGroupMembership, + ProjectMembership, + ProjectOrganization, + ProjectRole, ProjectStats, - ProjectOrganization -) -from .serializers import ( - ProjectSerializer, - ProjectStatSerializer, - ProjectRoleSerializer, - ProjectMembershipSerializer, - ProjectJoinRequestSerializer, - ProjectUserGroupSerializer, - ProjectMemberViewSerializer, - ProjectRecentActivitySerializer, + ProjectUserGroupMembership, ) from .permissions import ( - JoinPermission, + PROJECT_PERMISSIONS, AcceptRejectPermission, + JoinPermission, MembershipModifyPermission, - PROJECT_PERMISSIONS, ) -from .filter_set import ( - ProjectFilterSet, - get_filtered_projects, - ProjectMembershipFilterSet, - ProjectUserGroupMembershipFilterSet, +from .serializers import ( + ProjectJoinRequestSerializer, + ProjectMembershipSerializer, + ProjectMemberViewSerializer, + ProjectRecentActivitySerializer, + ProjectRoleSerializer, + ProjectSerializer, + ProjectStatSerializer, + ProjectUserGroupSerializer, ) from .tasks import generate_viz_stats - from .token import project_request_token_generator + logger = logging.getLogger(__name__) @@ -96,41 +87,34 @@ def _get_viz_data(request, project, can_view_confidential, token=None): Util function to trigger and serve Project entry/ary viz data """ if ( - project.analysis_framework is None or - project.analysis_framework.properties is None or - project.analysis_framework.properties.get('stats_config') is None + project.analysis_framework is None + or project.analysis_framework.properties is None + or project.analysis_framework.properties.get("stats_config") is None ): return { - 'error': f'No configuration provided for current Project: {project.title}, Contact Admin', + "error": f"No configuration provided for current Project: {project.title}, Contact Admin", }, status.HTTP_404_NOT_FOUND stats, created = ProjectStats.objects.get_or_create(project=project) - if token and ( - not stats.public_share or token != str(stats.token) - ): - return { - 'error': 'Token is invalid or sharing is disabled. Please contact project\'s admin.' - }, status.HTTP_403_FORBIDDEN + if token and (not stats.public_share or token != str(stats.token)): + return {"error": "Token is invalid or sharing is disabled. Please contact project's admin."}, status.HTTP_403_FORBIDDEN stat_file = stats.confidential_file if can_view_confidential else stats.file - file_url = ( - request.build_absolute_uri(URLCachedFileField().to_representation(stat_file)) - if stat_file else None - ) + file_url = request.build_absolute_uri(URLCachedFileField().to_representation(stat_file)) if stat_file else None stats_meta = { - 'data': file_url, - 'modified_at': stats.modified_at, - 'status': stats.status, - 'public_share': stats.public_share, - 'public_url': stats.get_public_url(request), + "data": file_url, + "modified_at": stats.modified_at, + "status": stats.status, + "public_share": stats.public_share, + "public_url": stats.get_public_url(request), } if stats.is_ready(): return stats_meta, status.HTTP_200_OK elif stats.status == ProjectStats.Status.FAILURE: return { - 'error': 'Failed to generate stats, Contact Admin', + "error": "Failed to generate stats, Contact Admin", **stats_meta, }, status.HTTP_200_OK transaction.on_commit(lambda: generate_viz_stats.delay(project.pk)) @@ -139,18 +123,19 @@ def _get_viz_data(request, project, can_view_confidential, token=None): stats.status = ProjectStats.Status.PENDING stats.save() return { - 'message': 'Processing the request, try again later', + "message": "Processing the request, try again later", **stats_meta, }, status.HTTP_202_ACCEPTED class ProjectViewSet(viewsets.ModelViewSet): - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] - filter_backends = (django_filters.rest_framework.DjangoFilterBackend, - filters.SearchFilter, filters.OrderingFilter) + permission_classes = [permissions.IsAuthenticated, ModifyPermission] + filter_backends = (django_filters.rest_framework.DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter) filterset_class = ProjectFilterSet - search_fields = ('title', 'description',) + search_fields = ( + "title", + "description", + ) def get_queryset(self): return get_filtered_projects(self.request.user, self.request.GET) @@ -168,32 +153,36 @@ def get_project_object(self): """ Return project same as get_object without any other filters """ - if self.kwargs.get('pk') is not None: - return get_object_or_404(self.get_queryset(), pk=self.kwargs['pk']) + if self.kwargs.get("pk") is not None: + return get_object_or_404(self.get_queryset(), pk=self.kwargs["pk"]) raise Http404 @action( detail=False, - url_path='recent-activities', + url_path="recent-activities", ) def get_recent_activities(self, request, version=None): - return response.Response({ - 'results': ProjectRecentActivitySerializer( - Project.get_recent_activities(request.user), - context={'request': request}, many=True, - ).data - }) + return response.Response( + { + "results": ProjectRecentActivitySerializer( + Project.get_recent_activities(request.user), + context={"request": request}, + many=True, + ).data + } + ) """ Get list of projects that user is member of """ + @action( detail=False, permission_classes=[permissions.IsAuthenticated], - url_path='member-of', + url_path="member-of", ) def get_for_member(self, request, version=None): - user = self.request.GET.get('user') + user = self.request.GET.get("user") projects = Project.get_for_member(user) if user is None or request.user == user: @@ -201,9 +190,9 @@ def get_for_member(self, request, version=None): else: projects = Project.get_for_public(request.user, user) - user_group = request.GET.get('user_group') + user_group = request.GET.get("user_group") if user_group: - user_group = user_group.split(',') + user_group = user_group.split(",") projects = projects.filter(user_groups__id__in=user_group) self.page = self.paginate_queryset(projects) @@ -213,46 +202,44 @@ def get_for_member(self, request, version=None): """ Generate project public VIZ URL """ + @action( detail=True, - methods=['post'], - url_path='public-viz', + methods=["post"], + url_path="public-viz", ) def generate_public_viz(self, request, pk=None, version=None): project = self.get_object() - action = request.data.get('action', 'new') + action = request.data.get("action", "new") stats, created = ProjectStats.objects.get_or_create(project=project) - if action == 'new': + if action == "new": stats.public_share = True stats.token = uuid.uuid4() - elif action == 'on': + elif action == "on": stats.public_share = True stats.token = stats.token or uuid.uuid4() - elif action == 'off': + elif action == "off": stats.public_share = False else: - raise exceptions.ValidationError({'action': f'Invalid action {action}'}) - stats.save(update_fields=['token', 'public_share']) - return response.Response({'public_url': stats.get_public_url(request)}) + raise exceptions.ValidationError({"action": f"Invalid action {action}"}) + stats.save(update_fields=["token", "public_share"]) + return response.Response({"public_url": stats.get_public_url(request)}) """ Get analysis framework for this project """ - @action( - detail=True, - permission_classes=[permissions.IsAuthenticated], - url_path='analysis-framework' - ) + + @action(detail=True, permission_classes=[permissions.IsAuthenticated], url_path="analysis-framework") def get_framework(self, request, pk=None, version=None): from analysis_framework.serializers import AnalysisFrameworkSerializer project = self.get_object() if not project.analysis_framework: - raise exceptions.NotFound('Resource not found') + raise exceptions.NotFound("Resource not found") serializer = AnalysisFrameworkSerializer( project.analysis_framework, - context={'request': request}, + context={"request": request}, ) return response.Response(serializer.data) @@ -260,36 +247,39 @@ def get_framework(self, request, pk=None, version=None): """ Get regions assigned to this project """ + @action( detail=True, - url_path='regions', + url_path="regions", permission_classes=[permissions.IsAuthenticated], ) def get_regions(self, request, pk=None, version=None): instance = self.get_object() serializer = RegionSerializer( instance.regions, - many=True, context={'request': request}, + many=True, + context={"request": request}, ) - return response.Response({'regions': serializer.data}) + return response.Response({"regions": serializer.data}) """ Get assessment template for this project """ + @action( detail=True, permission_classes=[permissions.IsAuthenticated], serializer_class=arys.AssessmentTemplateSerializer, - url_path='assessment-template', + url_path="assessment-template", ) def get_assessment_template(self, request, pk=None, version=None): project = self.get_object() if not project.assessment_template: - raise exceptions.NotFound('Resource not found') + raise exceptions.NotFound("Resource not found") serializer = arys.AssessmentTemplateSerializer( project.assessment_template, - context={'request': request}, + context={"request": request}, ) return response.Response(serializer.data) @@ -298,11 +288,12 @@ def get_assessment_template(self, request, pk=None, version=None): Get status for export: - tabular chart generation status """ + @action( detail=True, permission_classes=[permissions.IsAuthenticated], serializer_class=ProjectJoinRequestSerializer, - url_path='export-status', + url_path="export-status", ) def get_export_status(self, request, pk=None, version=None): project = self.get_object() @@ -310,14 +301,16 @@ def get_export_status(self, request, pk=None, version=None): cache__image_status=Field.CACHE_PENDING, sheet__book__project=project, ).count() - return response.Response({ - 'tabular_pending_fields_count': fields_pending_count, - }) + return response.Response( + { + "tabular_pending_fields_count": fields_pending_count, + } + ) @action( detail=True, permission_classes=[permissions.IsAuthenticated], - url_path='project-viz', + url_path="project-viz", ) def get_project_viz_data(self, request, pk=None, version=None): """ @@ -325,11 +318,8 @@ def get_project_viz_data(self, request, pk=None, version=None): """ project = self.get_object() can_view_confidential = ( - ProjectMembership.objects - .filter(member=request.user, project=project) - .annotate( - view_all=models.F('role__lead_permissions').bitand(PROJECT_PERMISSIONS.lead.view) - ) + ProjectMembership.objects.filter(member=request.user, project=project) + .annotate(view_all=models.F("role__lead_permissions").bitand(PROJECT_PERMISSIONS.lead.view)) .filter(view_all=PROJECT_PERMISSIONS.lead.view) .exists() ) @@ -339,34 +329,33 @@ def get_project_viz_data(self, request, pk=None, version=None): """ Join request to this project """ + @action( detail=True, permission_classes=[permissions.IsAuthenticated, JoinPermission], - methods=['post'], - url_path='join', + methods=["post"], + url_path="join", ) def join_project(self, request, pk=None, version=None): project = self.get_object() # Forbid join requests for private project - if (project.is_private): - raise PermissionDenied( - {'message': "You cannot send join request to the private project"} - ) + if project.is_private: + raise PermissionDenied({"message": "You cannot send join request to the private project"}) serializer = ProjectJoinRequestSerializer( data={ - 'role': ProjectRole.get_default_role().id, + "role": ProjectRole.get_default_role().id, **request.data, }, - context={'request': request, 'project': project} + context={"request": request, "project": project}, ) serializer.is_valid(raise_exception=True) join_request = serializer.save() serializer = ProjectJoinRequestSerializer( join_request, - context={'request': request}, + context={"request": request}, ) if settings.TESTING: @@ -379,27 +368,24 @@ def join_project(self, request, pk=None, version=None): # while the emails are being sent in the background. def send_mail(): send_project_join_request_emails.delay(join_request.id) + transaction.on_commit(send_mail) - return response.Response(serializer.data, - status=status.HTTP_201_CREATED) + return response.Response(serializer.data, status=status.HTTP_201_CREATED) @staticmethod def _accept_request(responded_by, join_request, role): - if not role or role == 'normal': + if not role or role == "normal": role = ProjectRole.get_default_role() - elif role == 'admin': + elif role == "admin": role = ProjectRole.get_admin_role() else: role_qs = ProjectRole.objects.filter(id=role) if not role_qs.exists(): - return response.Response( - {'errors': 'Role id \'{}\' does not exist'.format(role)}, - status=status.HTTP_404_NOT_FOUND - ) + return response.Response({"errors": "Role id '{}' does not exist".format(role)}, status=status.HTTP_404_NOT_FOUND) role = role_qs.first() - join_request.status = 'accepted' + join_request.status = "accepted" join_request.responded_by = responded_by join_request.responded_at = timezone.now() join_request.role = role @@ -409,14 +395,14 @@ def _accept_request(responded_by, join_request, role): project=join_request.project, member=join_request.requested_by, defaults={ - 'role': role, - 'added_by': responded_by, + "role": role, + "added_by": responded_by, }, ) @staticmethod def _reject_request(responded_by, join_request): - join_request.status = 'rejected' + join_request.status = "rejected" join_request.responded_by = responded_by join_request.responded_at = timezone.now() join_request.save() @@ -425,85 +411,77 @@ def _reject_request(responded_by, join_request): Accept a join request to this project, creating the membership while doing so. """ + @action( detail=True, permission_classes=[ - permissions.IsAuthenticated, AcceptRejectPermission, + permissions.IsAuthenticated, + AcceptRejectPermission, ], - methods=['post'], - url_path=r'requests/(?P\d+)/accept', + methods=["post"], + url_path=r"requests/(?P\d+)/accept", ) def accept_request(self, request, pk=None, version=None, request_id=None): project = self.get_object() - join_request = get_object_or_404(ProjectJoinRequest, - id=request_id, - project=project) + join_request = get_object_or_404(ProjectJoinRequest, id=request_id, project=project) - if join_request.status in ['accepted', 'rejected']: - raise exceptions.ValidationError( - 'This request has already been {}'.format(join_request.status) - ) + if join_request.status in ["accepted", "rejected"]: + raise exceptions.ValidationError("This request has already been {}".format(join_request.status)) - role = request.data.get('role') + role = request.data.get("role") ProjectViewSet._accept_request(request.user, join_request, role) serializer = ProjectJoinRequestSerializer( join_request, - context={'request': request}, + context={"request": request}, ) return response.Response(serializer.data) """ Reject a join request to this project """ + @action( detail=True, permission_classes=[ - permissions.IsAuthenticated, AcceptRejectPermission, + permissions.IsAuthenticated, + AcceptRejectPermission, ], - methods=['post'], - url_path=r'requests/(?P\d+)/reject', + methods=["post"], + url_path=r"requests/(?P\d+)/reject", ) def reject_request(self, request, pk=None, version=None, request_id=None): project = self.get_object() - join_request = get_object_or_404(ProjectJoinRequest, - id=request_id, - project=project) + join_request = get_object_or_404(ProjectJoinRequest, id=request_id, project=project) - if join_request.status in ['accepted', 'rejected']: - raise exceptions.ValidationError( - 'This request has already been {}'.format(join_request.status) - ) + if join_request.status in ["accepted", "rejected"]: + raise exceptions.ValidationError("This request has already been {}".format(join_request.status)) ProjectViewSet._reject_request(request.user, join_request) serializer = ProjectJoinRequestSerializer( join_request, - context={'request': request}, + context={"request": request}, ) return response.Response(serializer.data) """ Cancel a join request to this project """ + @mark_as_delete() @action( detail=True, permission_classes=[permissions.IsAuthenticated], - methods=['post'], - url_path=r'join/cancel', + methods=["post"], + url_path=r"join/cancel", ) def cancel_request(self, request, pk=None, version=None, request_id=None): project = self.get_object() - join_request = get_object_or_404(ProjectJoinRequest, - requested_by=request.user, - status='pending', - project=project) - - if join_request.status in ['accepted', 'rejected']: - raise exceptions.ValidationError( - 'This request has already been {}'.format(join_request.status) - ) + join_request = get_object_or_404(ProjectJoinRequest, requested_by=request.user, status="pending", project=project) + + if join_request.status in ["accepted", "rejected"]: + raise exceptions.ValidationError("This request has already been {}".format(join_request.status)) join_request.delete() return response.Response(status=status.HTTP_204_NO_CONTENT) @@ -511,13 +489,15 @@ def cancel_request(self, request, pk=None, version=None, request_id=None): """ Get list of join requests for this project """ + @mark_as_list() @action( detail=True, permission_classes=[ - permissions.IsAuthenticated, ModifyPermission, + permissions.IsAuthenticated, + ModifyPermission, ], - url_path='requests', + url_path="requests", ) def get_requests(self, request, pk=None, version=None): project = self.get_object() @@ -529,29 +509,25 @@ def get_requests(self, request, pk=None, version=None): """ Comprehensive Entries """ + @action( detail=True, permission_classes=[permissions.IsAuthenticated], - methods=['get'], - url_path=r'comprehensive-entries', + methods=["get"], + url_path=r"comprehensive-entries", ) def comprehensive_entries(self, request, *args, **kwargs): project = self.get_project_object() - viewfn = ComprehensiveEntriesViewSet.as_view({'get': 'list'}) + viewfn = ComprehensiveEntriesViewSet.as_view({"get": "list"}) request._request.GET = request._request.GET.copy() - request._request.GET['project'] = project.pk + request._request.GET["project"] = project.pk return viewfn(request._request, *args, **kwargs) - @action( - detail=True, - permission_classes=[permissions.IsAuthenticated, IsProjectMember], - url_path='members' - ) + @action(detail=True, permission_classes=[permissions.IsAuthenticated, IsProjectMember], url_path="members") def get_members(self, request, pk=None, version=None): project = self.get_object() members = User.objects.filter( - models.Q(projectmembership__project=project) | - models.Q(usergroup__projectusergroupmembership__project=project) + models.Q(projectmembership__project=project) | models.Q(usergroup__projectusergroupmembership__project=project) ).distinct() self.page = self.paginate_queryset(members) serializer = SimpleUserSerializer(self.page, many=True) @@ -560,37 +536,40 @@ def get_members(self, request, pk=None, version=None): """ Project Lead-Groups """ + @action( detail=True, permission_classes=[permissions.IsAuthenticated], - methods=['get'], - url_path=r'lead-groups', + methods=["get"], + url_path=r"lead-groups", ) def get_lead_groups(self, request, *args, **kwargs): project = self.get_project_object() - viewfn = ProjectLeadGroupViewSet.as_view({'get': 'list'}) + viewfn = ProjectLeadGroupViewSet.as_view({"get": "list"}) request._request.GET = request._request.GET.copy() - request._request.GET['project'] = project.pk + request._request.GET["project"] = project.pk return viewfn(request._request) """ Project Questionnaire Meta """ + @action( detail=True, permission_classes=[permissions.IsAuthenticated], - methods=['get'], - url_path=r'questionnaire-meta', + methods=["get"], + url_path=r"questionnaire-meta", ) def get_questionnaire_meta(self, request, *args, **kwargs): project = self.get_project_object() af = project.analysis_framework meta = { - 'active_count': project.questionnaire_set.filter(is_archived=False).count(), - 'archived_count': project.questionnaire_set.filter(is_archived=True).count(), - 'analysis_framework': af and { - 'id': af.id, - 'title': af.title, + "active_count": project.questionnaire_set.filter(is_archived=False).count(), + "archived_count": project.questionnaire_set.filter(is_archived=True).count(), + "analysis_framework": af + and { + "id": af.id, + "title": af.title, }, } return response.Response(meta) @@ -598,72 +577,88 @@ def get_questionnaire_meta(self, request, *args, **kwargs): """ Get analysis for this project """ - @action( - detail=True, - permission_classes=[permissions.IsAuthenticated, IsProjectMember], - url_path='analysis-overview' - ) + + @action(detail=True, permission_classes=[permissions.IsAuthenticated, IsProjectMember], url_path="analysis-overview") def get_analysis(self, request, pk=None, version=None): project = self.get_object() # get all the analysis in the project # TODO: Remove this later and let client handle this using graphql - analysis_list = Analysis.objects.filter(project=project).values('id', 'title', 'created_at') + analysis_list = Analysis.objects.filter(project=project).values("id", "title", "created_at") - total_sources = Lead.objects\ - .filter(project=project)\ - .annotate(entries_count=models.Count('entry'))\ - .filter(entries_count__gt=0)\ - .count() + total_sources = ( + Lead.objects.filter(project=project).annotate(entries_count=models.Count("entry")).filter(entries_count__gt=0).count() + ) entries_total = Entry.objects.filter(project=project).count() - entries_dragged = AnalyticalStatementEntry.objects\ - .filter(analytical_statement__analysis_pillar__analysis__project=project)\ - .order_by().values('entry').distinct() - entries_discarded = DiscardedEntry.objects\ - .filter(analysis_pillar__analysis__project=project)\ - .order_by().values('entry').distinct() + entries_dragged = ( + AnalyticalStatementEntry.objects.filter(analytical_statement__analysis_pillar__analysis__project=project) + .order_by() + .values("entry") + .distinct() + ) + entries_discarded = ( + DiscardedEntry.objects.filter(analysis_pillar__analysis__project=project).order_by().values("entry").distinct() + ) total_analyzed_entries = entries_discarded.union(entries_dragged).count() - sources_discarded = DiscardedEntry.objects\ - .filter(analysis_pillar__analysis__project=project)\ - .order_by().values('entry__lead_id').distinct() - sources_dragged = AnalyticalStatementEntry.objects\ - .filter(analytical_statement__analysis_pillar__analysis__project=project)\ - .order_by().values('entry__lead_id').distinct() + sources_discarded = ( + DiscardedEntry.objects.filter(analysis_pillar__analysis__project=project) + .order_by() + .values("entry__lead_id") + .distinct() + ) + sources_dragged = ( + AnalyticalStatementEntry.objects.filter(analytical_statement__analysis_pillar__analysis__project=project) + .order_by() + .values("entry__lead_id") + .distinct() + ) total_analyzed_sources = sources_dragged.union(sources_discarded).count() - lead_qs = Lead.objects\ - .filter(project=project, authors__organization_type__isnull=False)\ + lead_qs = ( + Lead.objects.filter(project=project, authors__organization_type__isnull=False) .annotate( - entries_count=models.functions.Coalesce(models.Subquery( - AnalyticalStatementEntry.objects.filter( - entry__lead_id=models.OuterRef('pk') - ).order_by().values('entry__lead_id').annotate(count=models.Count('*')) - .values('count')[:1], - output_field=models.IntegerField(), - ), 0) - ).filter(entries_count__gt=0) - authoring_organizations = Lead.objects\ - .filter(id__in=lead_qs)\ - .order_by('authors__organization_type').values('authors__organization_type')\ + entries_count=models.functions.Coalesce( + models.Subquery( + AnalyticalStatementEntry.objects.filter(entry__lead_id=models.OuterRef("pk")) + .order_by() + .values("entry__lead_id") + .annotate(count=models.Count("*")) + .values("count")[:1], + output_field=models.IntegerField(), + ), + 0, + ) + ) + .filter(entries_count__gt=0) + ) + authoring_organizations = ( + Lead.objects.filter(id__in=lead_qs) + .order_by("authors__organization_type") + .values("authors__organization_type") .annotate( - count=models.Count('id'), + count=models.Count("id"), organization_type_title=models.functions.Coalesce( - models.F('authors__organization_type__title'), - models.Value(''), - )).values( - 'count', - 'organization_type_title', - organization_type_id=models.F('authors__organization_type'), + models.F("authors__organization_type__title"), + models.Value(""), + ), ) + .values( + "count", + "organization_type_title", + organization_type_id=models.F("authors__organization_type"), + ) + ) - return response.Response({ - 'analysis_list': analysis_list, - 'entries_total': entries_total, - 'analyzed_entries_count': total_analyzed_entries, - 'sources_total': total_sources, - 'analyzed_source_count': total_analyzed_sources, - 'authoring_organizations': authoring_organizations - }) + return response.Response( + { + "analysis_list": analysis_list, + "entries_total": entries_total, + "analyzed_entries_count": total_analyzed_entries, + "sources_total": total_sources, + "analyzed_source_count": total_analyzed_sources, + "authoring_organizations": authoring_organizations, + } + ) class ProjectStatViewSet(ProjectViewSet): @@ -673,20 +668,20 @@ def get_serializer_class(self): return ProjectStatSerializer def get_queryset(self): - return get_filtered_projects( - self.request.user, self.request.GET, - annotate=True, - ).prefetch_related( - 'regions', 'organizations', - ).select_related( - 'created_by__profile', 'modified_by__profile' + return ( + get_filtered_projects( + self.request.user, + self.request.GET, + annotate=True, + ) + .prefetch_related( + "regions", + "organizations", + ) + .select_related("created_by__profile", "modified_by__profile") ) - @action( - detail=False, - permission_classes=[permissions.IsAuthenticated], - url_path='recent' - ) + @action(detail=False, permission_classes=[permissions.IsAuthenticated], url_path="recent") def get_recent_projects(self, request, *args, **kwargs): # Only pull project data for which user is member of qs = self.get_queryset().filter(Project.get_query_for_member(request.user)) @@ -699,67 +694,68 @@ def get_recent_projects(self, request, *args, **kwargs): ).data ) - @action( - detail=False, - permission_classes=[permissions.IsAuthenticated], - url_path='summary' - ) + @action(detail=False, permission_classes=[permissions.IsAuthenticated], url_path="summary") def get_projects_summary(self, request, pk=None, version=None): projects = Project.get_for_member(request.user) # Lead stats leads = Lead.objects.filter(project__in=projects) - total_leads_tagged_count = leads.annotate(entries_count=models.Count('entry')).filter(entries_count__gt=0).count() - total_leads_tagged_and_controlled_count = leads.annotate( - entries_count=models.Count('entry'), - controlled_entries_count=models.Count( - 'entry', filter=models.Q(entry__controlled=True) - ), - ).filter(entries_count__gt=0, entries_count=models.F('controlled_entries_count')).count() + total_leads_tagged_count = leads.annotate(entries_count=models.Count("entry")).filter(entries_count__gt=0).count() + total_leads_tagged_and_controlled_count = ( + leads.annotate( + entries_count=models.Count("entry"), + controlled_entries_count=models.Count("entry", filter=models.Q(entry__controlled=True)), + ) + .filter(entries_count__gt=0, entries_count=models.F("controlled_entries_count")) + .count() + ) # Entries activity recent_projects_id = list( - projects.annotate( - entries_count=Cast(KeyTextTransform('entries_activity', 'stats_cache'), models.IntegerField()) - ).filter(entries_count__gt=0).order_by('-entries_count').values_list('id', flat=True)[:3]) + projects.annotate(entries_count=Cast(KeyTextTransform("entries_activity", "stats_cache"), models.IntegerField())) + .filter(entries_count__gt=0) + .order_by("-entries_count") + .values_list("id", flat=True)[:3] + ) recent_entries = Entry.objects.filter( - project__in=recent_projects_id, - created_at__gte=(timezone.now() + relativedelta(months=-3)) + project__in=recent_projects_id, created_at__gte=(timezone.now() + relativedelta(months=-3)) ) recent_entries_activity = { - 'projects': ( - recent_entries.order_by().values('project') - .annotate(count=models.Count('*')) + "projects": ( + recent_entries.order_by() + .values("project") + .annotate(count=models.Count("*")) .filter(count__gt=0) - .values('count', id=models.F('project'), title=models.F('project__title')) + .values("count", id=models.F("project"), title=models.F("project__title")) ), - 'activities': ( - recent_entries.order_by('project', 'created_at__date').values('project', 'created_at__date') - .annotate(count=models.Count('*')) - .values('project', 'count', date=models.Func(models.F('created_at__date'), function='DATE')) + "activities": ( + recent_entries.order_by("project", "created_at__date") + .values("project", "created_at__date") + .annotate(count=models.Count("*")) + .values("project", "count", date=models.Func(models.F("created_at__date"), function="DATE")) ), } - return response.Response({ - 'projects_count': projects.count(), - 'total_leads_count': leads.count(), - 'total_leads_tagged_count': total_leads_tagged_count, - 'total_leads_tagged_and_controlled_count': total_leads_tagged_and_controlled_count, - 'recent_entries_activity': recent_entries_activity, - }) + return response.Response( + { + "projects_count": projects.count(), + "total_leads_count": leads.count(), + "total_leads_tagged_count": total_leads_tagged_count, + "total_leads_tagged_and_controlled_count": total_leads_tagged_and_controlled_count, + "recent_entries_activity": recent_entries_activity, + } + ) class ProjectMembershipViewSet(viewsets.ModelViewSet): serializer_class = ProjectMembershipSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission, MembershipModifyPermission] - filter_backends = (django_filters.rest_framework.DjangoFilterBackend, - filters.SearchFilter, filters.OrderingFilter) + permission_classes = [permissions.IsAuthenticated, ModifyPermission, MembershipModifyPermission] + filter_backends = (django_filters.rest_framework.DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter) filterset_class = ProjectMembershipFilterSet def get_serializer(self, *args, **kwargs): - data = kwargs.get('data') - list = data and data.get('list') + data = kwargs.get("data") + list = data and data.get("list") if list: - kwargs.pop('data') - kwargs.pop('many', None) + kwargs.pop("data") + kwargs.pop("many", None) return super().get_serializer( data=list, many=True, @@ -772,47 +768,47 @@ def get_serializer(self, *args, **kwargs): ) def finalize_response(self, request, response, *args, **kwargs): - if request.method == 'POST' and isinstance(response.data, list): + if request.method == "POST" and isinstance(response.data, list): response.data = { - 'results': response.data, + "results": response.data, } return super().finalize_response( - request, response, - *args, **kwargs, + request, + response, + *args, + **kwargs, ) def get_queryset(self): - return ProjectMembership.get_for(self.request.user).filter(project=self.kwargs['project_id']).select_related( - 'role' - ) + return ProjectMembership.get_for(self.request.user).filter(project=self.kwargs["project_id"]).select_related("role") class ProjectOptionsView(views.APIView): """ Options for various attributes related to project """ + permission_classes = [permissions.IsAuthenticated] def get(self, request, version=None): - project_query = request.GET.get('project') - fields_query = request.GET.get('fields') + project_query = request.GET.get("project") + fields_query = request.GET.get("fields") projects = None if project_query: - projects = Project.get_for(request.user).filter( - id__in=project_query.split(',') - ) + projects = Project.get_for(request.user).filter(id__in=project_query.split(",")) fields = None if fields_query: - fields = fields_query.split(',') + fields = fields_query.split(",") options = { - 'project_organization_types': [ + "project_organization_types": [ { - 'key': s[0], - 'value': s[1], - } for s in ProjectOrganization.Type.choices + "key": s[0], + "value": s[1], + } + for s in ProjectOrganization.Type.choices ], } @@ -821,57 +817,54 @@ def _filter_by_projects(qs, projects): qs = qs.filter(project=p) return qs - if (fields is None or 'regions' in fields): + if fields is None or "regions" in fields: if projects: project_regions = _filter_by_projects(Region.objects, projects).distinct() else: project_regions = Region.objects.none() user_regions = Region.get_for(request.user) - regions = Region.objects.filter(id__in=(project_regions | user_regions).values('id')).distinct() + regions = Region.objects.filter(id__in=(project_regions | user_regions).values("id")).distinct() # regions = regions1.union(regions2).distinct() - options['regions'] = [ + options["regions"] = [ { - 'key': region.id, - 'value': region.get_verbose_title(), - } for region in regions + "key": region.id, + "value": region.get_verbose_title(), + } + for region in regions ] - if (fields is None or 'user_groups' in fields): + if fields is None or "user_groups" in fields: if projects: project_user_groups = _filter_by_projects(UserGroup.objects, projects).distinct() else: project_user_groups = UserGroup.objects.none() - user_user_groups = UserGroup.get_modifiable_for(request.user)\ - .distinct() - user_groups = UserGroup.objects.filter(id__in=(project_user_groups | user_user_groups).values('id')).distinct() + user_user_groups = UserGroup.get_modifiable_for(request.user).distinct() + user_groups = UserGroup.objects.filter(id__in=(project_user_groups | user_user_groups).values("id")).distinct() # user_groups = user_groups1.union(user_groups2) - options['user_groups'] = user_groups.distinct().annotate( - key=models.F('id'), - value=models.F('title') - ).values('key', 'value') + options["user_groups"] = ( + user_groups.distinct().annotate(key=models.F("id"), value=models.F("title")).values("key", "value") + ) - if (fields is None or 'involvement' in fields): - options['involvement'] = [ - {'key': 'my_projects', 'value': 'My projects'}, - {'key': 'not_my_projects', 'value': 'Not my projects'} + if fields is None or "involvement" in fields: + options["involvement"] = [ + {"key": "my_projects", "value": "My projects"}, + {"key": "not_my_projects", "value": "Not my projects"}, ] - options['project_status'] = [ - { - 'key': value, - 'value': label - } for value, label in Project.Status.choices - ] + options["project_status"] = [{"key": value, "value": label} for value, label in Project.Status.choices] return response.Response(options) def accept_project_confirm( - request, uidb64, pidb64, token, - template_name='project/project_join_request_confirm.html', + request, + uidb64, + pidb64, + token, + template_name="project/project_join_request_confirm.html", ): - accept = request.GET.get('accept', 'True').lower() == 'true' - role = request.GET.get('role', 'normal') + accept = request.GET.get("accept", "True").lower() == "true" + role = request.GET.get("role", "normal") try: uid = force_text(urlsafe_base64_decode(uidb64)) pid = force_text(urlsafe_base64_decode(pidb64)) @@ -888,27 +881,26 @@ def accept_project_confirm( join_request = None request_data = { - 'join_request': join_request, - 'will_responded_by': user, + "join_request": join_request, + "will_responded_by": user, } context = { - 'title': 'Project Join Request', - 'success': True, - 'accept': accept, - 'role': role, - 'frontend_url': get_frontend_url(''), - 'join_request': join_request, - 'project_url': Permalink.project(join_request.project.id) if join_request else None, + "title": "Project Join Request", + "success": True, + "accept": accept, + "role": role, + "frontend_url": get_frontend_url(""), + "join_request": join_request, + "project_url": Permalink.project(join_request.project.id) if join_request else None, } - if (join_request and user) is not None and\ - project_request_token_generator.check_token(request_data, token): + if (join_request and user) is not None and project_request_token_generator.check_token(request_data, token): if accept: ProjectViewSet._accept_request(user, join_request, role) else: ProjectViewSet._reject_request(user, join_request) else: - context['success'] = False + context["success"] = False return TemplateResponse(request, template_name, context) @@ -916,18 +908,15 @@ def accept_project_confirm( class ProjectRoleViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = ProjectRoleSerializer permission_classes = [permissions.IsAuthenticated] - queryset = ProjectRole.objects.order_by('level') + queryset = ProjectRole.objects.order_by("level") class ProjectUserGroupViewSet(viewsets.ModelViewSet): serializer_class = ProjectUserGroupSerializer permission_classes = [permissions.IsAuthenticated, ModifyPermission] - filter_backends = (django_filters.rest_framework.DjangoFilterBackend, - filters.SearchFilter, filters.OrderingFilter) + filter_backends = (django_filters.rest_framework.DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter) queryset = ProjectUserGroupMembership.objects.all() filterset_class = ProjectUserGroupMembershipFilterSet def get_queryset(self): - return ProjectUserGroupMembership.objects.filter(project=self.kwargs['project_id']).select_related( - 'role' - ) + return ProjectUserGroupMembership.objects.filter(project=self.kwargs["project_id"]).select_related("role") diff --git a/apps/project/widgets.py b/apps/project/widgets.py index f97bc9c6b1..632b57a14c 100644 --- a/apps/project/widgets.py +++ b/apps/project/widgets.py @@ -13,24 +13,18 @@ def value_from_datadict(self, data, files, name): """ Get the checkbox values and or them to get the final value """ - to_or_vals = [ - v if data.get(f'{self.widget_name}_{k}') == 'on' else 0 - for k, v in self.permission_values.items() - ] - return reduce( - lambda acc, x: acc | x, - to_or_vals - ) + to_or_vals = [v if data.get(f"{self.widget_name}_{k}") == "on" else 0 for k, v in self.permission_values.items()] + return reduce(lambda acc, x: acc | x, to_or_vals) def render(self, name, value, attrs=None, renderer=None): - html = '' + html = "" for k, v in self.permission_values.items(): checked = value & v == v - html += f''' + html += f""" {k} - ''' + """ return html diff --git a/apps/quality_assurance/admin.py b/apps/quality_assurance/admin.py index 204a114397..5646916e6f 100644 --- a/apps/quality_assurance/admin.py +++ b/apps/quality_assurance/admin.py @@ -6,14 +6,15 @@ class EntryReviewCommentTextInline(admin.StackedInline): model = EntryReviewCommentText extra = 0 - readonly_fields = ('created_at',) + readonly_fields = ("created_at",) @admin.register(EntryReviewComment) class EntryReviewCommentAdmin(admin.ModelAdmin): inlines = [EntryReviewCommentTextInline] - list_display = ('id', 'created_by', 'created_at') - readonly_fields = ('created_at', 'entry_comment',) - autocomplete_fields = ( - 'created_by', 'mentioned_users', 'entry' + list_display = ("id", "created_by", "created_at") + readonly_fields = ( + "created_at", + "entry_comment", ) + autocomplete_fields = ("created_by", "mentioned_users", "entry") diff --git a/apps/quality_assurance/apps.py b/apps/quality_assurance/apps.py index 56a96a2436..260ad41146 100644 --- a/apps/quality_assurance/apps.py +++ b/apps/quality_assurance/apps.py @@ -2,4 +2,4 @@ class QualityAssuranceConfig(AppConfig): - name = 'quality_assurance' + name = "quality_assurance" diff --git a/apps/quality_assurance/dataloaders.py b/apps/quality_assurance/dataloaders.py index ab5dfa4d37..f7ff1f8396 100644 --- a/apps/quality_assurance/dataloaders.py +++ b/apps/quality_assurance/dataloaders.py @@ -1,12 +1,11 @@ from collections import defaultdict -from promise import Promise from django.utils.functional import cached_property +from promise import Promise +from quality_assurance.models import EntryReviewCommentText from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin -from quality_assurance.models import EntryReviewCommentText - class EntryReviewCommentTextLoader(DataLoaderWithContext): def batch_load_fn(self, keys): diff --git a/apps/quality_assurance/enums.py b/apps/quality_assurance/enums.py index 26d8791f55..34f23e6795 100644 --- a/apps/quality_assurance/enums.py +++ b/apps/quality_assurance/enums.py @@ -1,29 +1,27 @@ import graphene +from quality_assurance.models import EntryReviewComment + from utils.graphene.enums import ( convert_enum_to_graphene_enum, get_enum_name_from_django_field, ) -from quality_assurance.models import EntryReviewComment - -EntryReviewCommentTypeEnum = convert_enum_to_graphene_enum(EntryReviewComment.CommentType, name='EntryReviewCommentTypeEnum') +EntryReviewCommentTypeEnum = convert_enum_to_graphene_enum(EntryReviewComment.CommentType, name="EntryReviewCommentTypeEnum") enum_map = { get_enum_name_from_django_field(field): enum - for field, enum in ( - (EntryReviewComment.comment_type, EntryReviewCommentTypeEnum), - ) + for field, enum in ((EntryReviewComment.comment_type, EntryReviewCommentTypeEnum),) } class EntryReviewCommentOrderingEnum(graphene.Enum): # ASC - ASC_ID = 'id' - ASC_CREATED_AT = 'created_at' - ASC_COMMENT_TYPE = 'comment_type' - ASC_ENTRY = 'entry' + ASC_ID = "id" + ASC_CREATED_AT = "created_at" + ASC_COMMENT_TYPE = "comment_type" + ASC_ENTRY = "entry" # DESC - DESC_ID = f'-{ASC_ID}' - DESC_CREATED_AT = f'-{ASC_CREATED_AT}' - DESC_COMMENT_TYPE = f'-{ASC_COMMENT_TYPE}' - DESC_ENTRY = f'-{ASC_ENTRY}' + DESC_ID = f"-{ASC_ID}" + DESC_CREATED_AT = f"-{ASC_CREATED_AT}" + DESC_COMMENT_TYPE = f"-{ASC_COMMENT_TYPE}" + DESC_ENTRY = f"-{ASC_ENTRY}" diff --git a/apps/quality_assurance/factories.py b/apps/quality_assurance/factories.py index 7d07db49f8..7ef19f118a 100644 --- a/apps/quality_assurance/factories.py +++ b/apps/quality_assurance/factories.py @@ -10,7 +10,7 @@ class Meta: class EntryReviewCommentTextFactory(DjangoModelFactory): - text = factory.Sequence(lambda n: f'Text-{n}') + text = factory.Sequence(lambda n: f"Text-{n}") class Meta: model = EntryReviewCommentText diff --git a/apps/quality_assurance/filters.py b/apps/quality_assurance/filters.py index 7ec77b2ea4..909643d8ac 100644 --- a/apps/quality_assurance/filters.py +++ b/apps/quality_assurance/filters.py @@ -1,16 +1,14 @@ import django_filters -from utils.graphene.filters import ( - IDFilter, - MultipleInputFilter, -) -from .models import EntryReviewComment +from utils.graphene.filters import IDFilter, MultipleInputFilter + from .enums import EntryReviewCommentOrderingEnum +from .models import EntryReviewComment class EntryReviewCommentGQFilterSet(django_filters.FilterSet): entry = IDFilter() - ordering = MultipleInputFilter(EntryReviewCommentOrderingEnum, method='ordering_filter') + ordering = MultipleInputFilter(EntryReviewCommentOrderingEnum, method="ordering_filter") class Meta: model = EntryReviewComment diff --git a/apps/quality_assurance/models.py b/apps/quality_assurance/models.py index 8414f30c6a..a4412123f3 100644 --- a/apps/quality_assurance/models.py +++ b/apps/quality_assurance/models.py @@ -1,21 +1,20 @@ +from django.contrib.contenttypes.fields import GenericRelation from django.db import models from django.utils.functional import cached_property -from django.contrib.contenttypes.fields import GenericRelation - -from notification.models import Assignment from entry.models import Entry, EntryComment +from notification.models import Assignment from user.models import User # ---------------------------------------------- Abstract Table --------------------------------------- class BaseReviewComment(models.Model): - created_by = models.ForeignKey(User, related_name='%(class)s_created', on_delete=models.CASCADE) + created_by = models.ForeignKey(User, related_name="%(class)s_created", on_delete=models.CASCADE) created_at = models.DateTimeField(auto_now_add=True) mentioned_users = models.ManyToManyField(User, blank=True) class Meta: abstract = True - ordering = ('-id',) + ordering = ("-id",) def can_delete(self, user): return self.can_modify(user) @@ -27,31 +26,31 @@ def can_modify(self, user): def get_for(cls, user): return ( cls.objects.select_related( - 'entry', - 'created_by', - 'created_by__profile', - 'created_by__profile__display_picture', - ).prefetch_related( - 'comment_texts', - 'mentioned_users', - 'mentioned_users__profile', - 'mentioned_users__profile__display_picture', - ).filter( - models.Q(entry__lead__project__members=user) | - models.Q(entry__lead__project__user_groups__members=user) - ).distinct() + "entry", + "created_by", + "created_by__profile", + "created_by__profile__display_picture", + ) + .prefetch_related( + "comment_texts", + "mentioned_users", + "mentioned_users__profile", + "mentioned_users__profile__display_picture", + ) + .filter(models.Q(entry__lead__project__members=user) | models.Q(entry__lead__project__user_groups__members=user)) + .distinct() ) @cached_property def text(self): - last_comment_text = self.comment_texts.order_by('-id').first() + last_comment_text = self.comment_texts.order_by("-id").first() if last_comment_text: return last_comment_text.text def save(self, *args, **kwargs): super().save(*args, **kwargs) # NOTE: Clear text if cached - if hasattr(self, 'text'): + if hasattr(self, "text"): del self.text @@ -60,49 +59,47 @@ class BaseReviewCommentText(models.Model): NOTE: Define comment comment = models.ForeignKey(BaseReviewComment, related_name='comment_texts', on_delete=models.CASCADE) """ + created_at = models.DateTimeField(auto_now_add=True) text = models.TextField() class Meta: abstract = True - ordering = ('-id',) + ordering = ("-id",) # ---------------------------------------------- Non-Abstract Table ------------------------------------- + class EntryReviewComment(BaseReviewComment): class CommentType(models.IntegerChoices): - COMMENT = 0, 'Comment' - VERIFY = 1, 'Verify' - UNVERIFY = 2, 'Unverify' - CONTROL = 3, 'Control' - UNCONTROL = 4, 'UnControl' + COMMENT = 0, "Comment" + VERIFY = 1, "Verify" + UNVERIFY = 2, "Unverify" + CONTROL = 3, "Control" + UNCONTROL = 4, "UnControl" - entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name='review_comments') + entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="review_comments") comment_type = models.IntegerField(choices=CommentType.choices, default=CommentType.COMMENT) entry_comment = models.ForeignKey(EntryComment, on_delete=models.SET_NULL, null=True, blank=True) - assignments = GenericRelation(Assignment, related_query_name='entry_review_comment') + assignments = GenericRelation(Assignment, related_query_name="entry_review_comment") class Meta(BaseReviewComment.Meta): abstract = False def __str__(self): - return f'{self.entry}: {self.text}' + return f"{self.entry}: {self.text}" def can_delete(self, user): return self.comment_type == self.CommentType.COMMENT and self.can_modify(user) def get_related_users(self, skip_owner_user=True): users = list( - self.mentioned_users.through.objects - .filter(entryreviewcomment__entry=self.entry) - .values_list('user', flat=True).distinct() - ) - users.extend( - type(self).objects - .filter(entry=self.entry) - .values_list('created_by_id', flat=True).distinct() + self.mentioned_users.through.objects.filter(entryreviewcomment__entry=self.entry) + .values_list("user", flat=True) + .distinct() ) + users.extend(type(self).objects.filter(entry=self.entry).values_list("created_by_id", flat=True).distinct()) queryset = User.objects.filter(pk__in=set(users)) if skip_owner_user: queryset = queryset.exclude(pk=self.created_by_id) @@ -110,9 +107,7 @@ def get_related_users(self, skip_owner_user=True): class EntryReviewCommentText(BaseReviewCommentText): - comment = models.ForeignKey( - EntryReviewComment, related_name='comment_texts', on_delete=models.CASCADE - ) + comment = models.ForeignKey(EntryReviewComment, related_name="comment_texts", on_delete=models.CASCADE) class Meta(BaseReviewCommentText.Meta): abstract = False diff --git a/apps/quality_assurance/mutation.py b/apps/quality_assurance/mutation.py index e09d545afb..eab0cd32e0 100644 --- a/apps/quality_assurance/mutation.py +++ b/apps/quality_assurance/mutation.py @@ -1,26 +1,23 @@ import graphene +from deep.permissions import ProjectPermissions as PP from utils.graphene.mutation import ( - generate_input_type_for_serializer, - PsGrapheneMutation, PsDeleteMutation, + PsGrapheneMutation, + generate_input_type_for_serializer, ) -from deep.permissions import ProjectPermissions as PP from .models import EntryReviewComment from .schema import EntryReviewCommentDetailType -from .serializers import ( - EntryReviewCommentGqlSerializer as EntryReviewCommentSerializer, -) - +from .serializers import EntryReviewCommentGqlSerializer as EntryReviewCommentSerializer EntryReviewCommentInputType = generate_input_type_for_serializer( - 'EntryReviewCommentInputType', + "EntryReviewCommentInputType", serializer_class=EntryReviewCommentSerializer, ) -class EntryReviewCommentMutationMixin(): +class EntryReviewCommentMutationMixin: @classmethod def filter_queryset(cls, qs, info): return qs.filter(created_by=info.context.user) @@ -29,6 +26,7 @@ def filter_queryset(cls, qs, info): class CreateEntryReviewComment(EntryReviewCommentMutationMixin, PsGrapheneMutation): class Arguments: data = EntryReviewCommentInputType(required=True) + model = EntryReviewComment serializer_class = EntryReviewCommentSerializer result = graphene.Field(EntryReviewCommentDetailType) @@ -39,6 +37,7 @@ class UpdateEntryReviewComment(EntryReviewCommentMutationMixin, PsGrapheneMutati class Arguments: data = EntryReviewCommentInputType(required=True) id = graphene.ID(required=True) + model = EntryReviewComment serializer_class = EntryReviewCommentSerializer result = graphene.Field(EntryReviewCommentDetailType) @@ -48,12 +47,13 @@ class Arguments: class DeleteEntryReviewComment(EntryReviewCommentMutationMixin, PsDeleteMutation): class Arguments: id = graphene.ID(required=True) + model = EntryReviewComment result = graphene.Field(EntryReviewCommentDetailType) permissions = [PP.Permission.CREATE_ENTRY, PP.Permission.UPDATE_ENTRY] -class Mutation(): +class Mutation: entry_review_comment_create = CreateEntryReviewComment.Field() entry_review_comment_update = UpdateEntryReviewComment.Field() entry_review_comment_delete = DeleteEntryReviewComment.Field() diff --git a/apps/quality_assurance/schema.py b/apps/quality_assurance/schema.py index b1a1aa8742..ff87640913 100644 --- a/apps/quality_assurance/schema.py +++ b/apps/quality_assurance/schema.py @@ -1,16 +1,15 @@ import graphene - from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField +from lead.models import Lead +from quality_assurance.models import EntryReviewComment, EntryReviewCommentText -from utils.graphene.types import CustomDjangoListObjectType +from deep.permissions import ProjectPermissions as PP +from utils.graphene.enums import EnumDescription from utils.graphene.fields import DjangoPaginatedListObjectField from utils.graphene.pagination import NoOrderingPageGraphqlPagination -from utils.graphene.enums import EnumDescription -from deep.permissions import ProjectPermissions as PP -from lead.models import Lead +from utils.graphene.types import CustomDjangoListObjectType -from quality_assurance.models import EntryReviewComment, EntryReviewCommentText from .enums import EntryReviewCommentTypeEnum from .filters import EntryReviewCommentGQFilterSet @@ -19,9 +18,7 @@ def get_entry_comment_qs(info): """ NOTE: To be used in EntryReviewCommentDetailType """ - entry_comment_qs = EntryReviewComment.objects.filter( - entry__project=info.context.active_project - ) + entry_comment_qs = EntryReviewComment.objects.filter(entry__project=info.context.active_project) # Generate queryset according to permission if PP.check_permission(info, PP.Permission.VIEW_ENTRY): if PP.check_permission(info, PP.Permission.VIEW_ALL_LEAD): @@ -35,39 +32,28 @@ class EntryReviewCommentTextType(DjangoObjectType): class Meta: model = EntryReviewCommentText only_fields = ( - 'id', - 'created_at', - 'text', + "id", + "created_at", + "text", ) class EntryReviewCommentType(DjangoObjectType): class Meta: model = EntryReviewComment - only_fields = ( - 'id', - 'created_by', - 'created_at', - 'mentioned_users' - ) + only_fields = ("id", "created_by", "created_at", "mentioned_users") comment_type = graphene.Field(EntryReviewCommentTypeEnum, required=True) - comment_type_display = EnumDescription(source='get_comment_type_display', required=True) + comment_type_display = EnumDescription(source="get_comment_type_display", required=True) text = graphene.String() - entry = graphene.ID(source='entry_id', required=True) + entry = graphene.ID(source="entry_id", required=True) class EntryReviewCommentDetailType(EntryReviewCommentType): class Meta: model = EntryReviewComment skip_registry = True - only_fields = ( - 'id', - 'entry', - 'created_by', - 'created_at', - 'mentioned_users' - ) + only_fields = ("id", "entry", "created_by", "created_at", "mentioned_users") text_history = graphene.List(graphene.NonNull(EntryReviewCommentTextType)) @@ -89,10 +75,7 @@ class Meta: class Query: review_comment = DjangoObjectField(EntryReviewCommentDetailType) review_comments = DjangoPaginatedListObjectField( - EntryReviewCommentListType, - pagination=NoOrderingPageGraphqlPagination( - page_size_query_param='pageSize' - ) + EntryReviewCommentListType, pagination=NoOrderingPageGraphqlPagination(page_size_query_param="pageSize") ) @staticmethod diff --git a/apps/quality_assurance/serializers.py b/apps/quality_assurance/serializers.py index 19ed8ae3de..0bc10800af 100644 --- a/apps/quality_assurance/serializers.py +++ b/apps/quality_assurance/serializers.py @@ -1,46 +1,42 @@ -from django.utils.functional import cached_property from django.db import transaction +from django.utils.functional import cached_property +from entry.models import Entry +from notification.models import Notification +from notification.tasks import send_notifications_for_comment +from project.models import ProjectMembership +from project.serializers import ProjectNotificationSerializer from rest_framework import serializers +from user.serializers import EntryCommentUserSerializer, UserNotificationSerializer from deep.middleware import get_current_user from deep.permissions import ProjectPermissions as PP from deep.serializers import ProjectPropertySerializerMixin -from user.serializers import EntryCommentUserSerializer, UserNotificationSerializer -from project.serializers import ProjectNotificationSerializer -from entry.models import Entry -from project.models import ProjectMembership -from notification.models import Notification -from notification.tasks import send_notifications_for_comment - -from .models import ( - EntryReviewComment, - EntryReviewCommentText, -) +from .models import EntryReviewComment, EntryReviewCommentText class EntryReviewCommentTextSerializer(serializers.ModelSerializer): class Meta: model = EntryReviewCommentText - exclude = ('id', 'comment') + exclude = ("id", "comment") class EntryReviewCommentSerializer(serializers.ModelSerializer): text = serializers.CharField(write_only=True, required=False) - text_history = EntryReviewCommentTextSerializer(source='comment_texts', read_only=True, many=True) - lead = serializers.IntegerField(source='entry.lead_id', read_only=True) - created_by_details = EntryCommentUserSerializer(source='created_by', read_only=True) - mentioned_users_details = EntryCommentUserSerializer(source='mentioned_users', read_only=True, many=True) - comment_type_display = serializers.CharField(source='get_comment_type_display', read_only=True) + text_history = EntryReviewCommentTextSerializer(source="comment_texts", read_only=True, many=True) + lead = serializers.IntegerField(source="entry.lead_id", read_only=True) + created_by_details = EntryCommentUserSerializer(source="created_by", read_only=True) + mentioned_users_details = EntryCommentUserSerializer(source="mentioned_users", read_only=True, many=True) + comment_type_display = serializers.CharField(source="get_comment_type_display", read_only=True) class Meta: model = EntryReviewComment - fields = '__all__' - read_only_fields = ('entry', 'is_resolved', 'created_by', 'resolved_at') + fields = "__all__" + read_only_fields = ("entry", "is_resolved", "created_by", "resolved_at") def _get_entry(self): - if not hasattr(self, '_entry'): - entry = Entry.objects.get(pk=int(self.context['entry_id'])) + if not hasattr(self, "_entry"): + entry = Entry.objects.get(pk=int(self.context["entry_id"])) self._entry = entry return self._entry @@ -56,52 +52,52 @@ def validate_comment_type(self, comment_type): verified_by_qs = Entry.verified_by.through.objects.filter(entry=entry, user=current_user) if ( - comment_type in [ + comment_type + in [ EntryReviewComment.CommentType.CONTROL, EntryReviewComment.CommentType.UNCONTROL, - ] and - not ProjectMembership.objects.filter( + ] + and not ProjectMembership.objects.filter( project=entry.project, - member=self.context['request'].user, + member=self.context["request"].user, badges__contains=[ProjectMembership.BadgeType.QA.value], ).exists() ): - raise serializers.ValidationError({ - 'comment_type': 'Controlled/UnControlled comment are only allowd by QA', - }) + raise serializers.ValidationError( + { + "comment_type": "Controlled/UnControlled comment are only allowd by QA", + } + ) if comment_type == EntryReviewComment.CommentType.VERIFY: if verified_by_qs.exists(): - raise serializers.ValidationError({'comment_type': 'Already verified'}) + raise serializers.ValidationError({"comment_type": "Already verified"}) entry.verified_by.add(current_user) elif comment_type == EntryReviewComment.CommentType.UNVERIFY: if not verified_by_qs.exists(): - raise serializers.ValidationError({'comment_type': 'Need to be verified first'}) + raise serializers.ValidationError({"comment_type": "Need to be verified first"}) entry.verified_by.remove(current_user) elif comment_type == EntryReviewComment.CommentType.CONTROL: if entry.controlled: - raise serializers.ValidationError({'comment_type': 'Already controlled'}) + raise serializers.ValidationError({"comment_type": "Already controlled"}) entry.control(current_user) elif comment_type == EntryReviewComment.CommentType.UNCONTROL: if not entry.controlled: - raise serializers.ValidationError({'comment_type': 'Need to be controlled first'}) + raise serializers.ValidationError({"comment_type": "Need to be controlled first"}) entry.control(current_user, controlled=False) return comment_type def validate(self, data): - mentioned_users = data.get('mentioned_users') - data['entry'] = entry = self._get_entry() + mentioned_users = data.get("mentioned_users") + data["entry"] = entry = self._get_entry() # Check if all assignes are members if mentioned_users: selected_existing_members_count = ( - ProjectMembership.objects.filter(project=entry.project, member__in=mentioned_users) - .distinct('member').count() + ProjectMembership.objects.filter(project=entry.project, member__in=mentioned_users).distinct("member").count() ) if selected_existing_members_count != len(mentioned_users): - raise serializers.ValidationError( - {'mentioned_users': "Selected mentioned users don't belong to this project"} - ) - data['created_by'] = get_current_user() + raise serializers.ValidationError({"mentioned_users": "Selected mentioned users don't belong to this project"}) + data["created_by"] = get_current_user() return data def _add_comment_text(self, comment, text): @@ -114,39 +110,40 @@ def comment_save(self, validated_data, instance=None): """ Comment Middleware save logic """ - text = validated_data.pop('text', '').strip() - comment_type = validated_data.get('comment_type', EntryReviewComment.CommentType.COMMENT) + text = validated_data.pop("text", "").strip() + comment_type = validated_data.get("comment_type", EntryReviewComment.CommentType.COMMENT) # Make sure to check text required - if not text and not (instance and instance.text) and comment_type in [ - EntryReviewComment.CommentType.COMMENT, - EntryReviewComment.CommentType.UNVERIFY, - EntryReviewComment.CommentType.UNCONTROL, - ]: - raise serializers.ValidationError({'text': 'Text is required'}) + if ( + not text + and not (instance and instance.text) + and comment_type + in [ + EntryReviewComment.CommentType.COMMENT, + EntryReviewComment.CommentType.UNVERIFY, + EntryReviewComment.CommentType.UNCONTROL, + ] + ): + raise serializers.ValidationError({"text": "Text is required"}) current_text = instance and instance.text text_changed = current_text != text - notify_meta = {'text_changed': text_changed} + notify_meta = {"text_changed": text_changed} if instance is None: # Create - notify_meta['notification_type'] = Notification.Type.ENTRY_REVIEW_COMMENT_ADD - notify_meta['text_changed'] = True + notify_meta["notification_type"] = Notification.Type.ENTRY_REVIEW_COMMENT_ADD + notify_meta["text_changed"] = True instance = super().create(validated_data) else: # Update - notify_meta['notification_type'] = Notification.Type.ENTRY_REVIEW_COMMENT_MODIFY - current_mentioned_users_pk = list(instance.mentioned_users.values_list('pk', flat=True)) - notify_meta['new_mentioned_users'] = [ - user - for user in validated_data.get('mentioned_users', []) - if user.pk not in current_mentioned_users_pk + notify_meta["notification_type"] = Notification.Type.ENTRY_REVIEW_COMMENT_MODIFY + current_mentioned_users_pk = list(instance.mentioned_users.values_list("pk", flat=True)) + notify_meta["new_mentioned_users"] = [ + user for user in validated_data.get("mentioned_users", []) if user.pk not in current_mentioned_users_pk ] instance = super().update(instance, validated_data) if text and text_changed: self._add_comment_text(instance, text) - transaction.on_commit( - lambda: send_notifications_for_comment(instance.pk, notify_meta) - ) + transaction.on_commit(lambda: send_notifications_for_comment(instance.pk, notify_meta)) return instance def create(self, validated_data): @@ -158,17 +155,24 @@ def update(self, instance, validated_data): class EntryReviewCommentNotificationSerializer(serializers.ModelSerializer): text = serializers.CharField(read_only=True) - lead = serializers.IntegerField(source='entry.lead_id', read_only=True) - project_details = ProjectNotificationSerializer(source='entry.project', read_only=True) - created_by_details = UserNotificationSerializer(source='created_by', read_only=True) - comment_type_display = serializers.CharField(source='get_comment_type_display', read_only=True) + lead = serializers.IntegerField(source="entry.lead_id", read_only=True) + project_details = ProjectNotificationSerializer(source="entry.project", read_only=True) + created_by_details = UserNotificationSerializer(source="created_by", read_only=True) + comment_type_display = serializers.CharField(source="get_comment_type_display", read_only=True) class Meta: model = EntryReviewComment fields = ( - 'id', 'entry', 'created_at', - 'text', 'lead', 'project_details', 'created_by_details', - 'comment_type', 'comment_type_display', 'mentioned_users', + "id", + "entry", + "created_at", + "text", + "lead", + "project_details", + "created_by_details", + "comment_type", + "comment_type_display", + "mentioned_users", ) @@ -183,13 +187,13 @@ class EntryReviewCommentGqlSerializer(ProjectPropertySerializerMixin, serializer class Meta: model = EntryReviewComment fields = ( - 'entry', - 'comment_type', - 'text', - 'mentioned_users', + "entry", + "comment_type", + "text", + "mentioned_users", ) - project_property_attribute = 'entry' + project_property_attribute = "entry" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -197,83 +201,73 @@ def __init__(self, *args, **kwargs): @cached_property def entry(self): - entry = ( - (self.instance and self.instance.entry) or - Entry.objects.filter(id=self.initial_data.get('entry')).first() - ) + entry = (self.instance and self.instance.entry) or Entry.objects.filter(id=self.initial_data.get("entry")).first() if entry is None: - raise serializers.ValidationError('Entry is not defined. Invalid request.') + raise serializers.ValidationError("Entry is not defined. Invalid request.") if entry.project != self.project: raise serializers.ValidationError("Entry from another project isn't allowed!!") return entry def validate_entry(self, entry): if self.instance and self.instance.entry != entry: - raise serializers.ValidationError('Changing comment entry is not allowed.') + raise serializers.ValidationError("Changing comment entry is not allowed.") return entry def validate_comment_type(self, comment_type): # No validation needed for edit since we don't allow changing it if self.instance: if self.instance.comment_type != comment_type: - raise serializers.ValidationError('Changing comment type is not allowed') + raise serializers.ValidationError("Changing comment type is not allowed") return comment_type if comment_type == EntryReviewComment.CommentType.COMMENT: return comment_type # No additional validation/action required - current_user = self.context['request'].user + current_user = self.context["request"].user verified_by_qs = Entry.verified_by.through.objects.filter(entry=self.entry, user=current_user) - if ( - comment_type in [ - EntryReviewComment.CommentType.CONTROL, - EntryReviewComment.CommentType.UNCONTROL, - ] and - not PP.check_permission_from_serializer( - self.context['request'], - PP.Permission.CAN_QUALITY_CONTROL, - ) + if comment_type in [ + EntryReviewComment.CommentType.CONTROL, + EntryReviewComment.CommentType.UNCONTROL, + ] and not PP.check_permission_from_serializer( + self.context["request"], + PP.Permission.CAN_QUALITY_CONTROL, ): - raise serializers.ValidationError('Controlled/Uncontrolled comment are only allowd by QA!!') + raise serializers.ValidationError("Controlled/Uncontrolled comment are only allowd by QA!!") if comment_type == EntryReviewComment.CommentType.VERIFY: if verified_by_qs.exists(): - raise serializers.ValidationError('Already verified!!') + raise serializers.ValidationError("Already verified!!") self.pending_commits.append(lambda: self.entry.verified_by.add(current_user)) elif comment_type == EntryReviewComment.CommentType.UNVERIFY: if not verified_by_qs.exists(): - raise serializers.ValidationError('Need to be verified first!!') + raise serializers.ValidationError("Need to be verified first!!") self.pending_commits.append(lambda: self.entry.verified_by.remove(current_user)) elif comment_type == EntryReviewComment.CommentType.CONTROL: if self.entry.controlled: - raise serializers.ValidationError('Already controlled!!') + raise serializers.ValidationError("Already controlled!!") self.pending_commits.append(lambda: self.entry.control(current_user)) elif comment_type == EntryReviewComment.CommentType.UNCONTROL: if not self.entry.controlled: - raise serializers.ValidationError('Need to be controlled first!!') + raise serializers.ValidationError("Need to be controlled first!!") self.pending_commits.append(lambda: self.entry.control(current_user, controlled=False)) return comment_type def validate_mentioned_users(self, mentioned_users): if mentioned_users: selected_existing_members_count = ( - ProjectMembership.objects.filter( - project=self.project, - member__in=mentioned_users - ) - .distinct('member').count() + ProjectMembership.objects.filter(project=self.project, member__in=mentioned_users).distinct("member").count() ) if selected_existing_members_count != len(mentioned_users): raise serializers.ValidationError("Selected mentioned users don't belong to this project") return mentioned_users def validate(self, validated_data): - text = validated_data['text'] = validated_data.pop('text', '').strip() or (self.instance and self.instance.text) + text = validated_data["text"] = validated_data.pop("text", "").strip() or (self.instance and self.instance.text) comment_type = ( - validated_data.get('comment_type') or - (self.instance and self.instance.comment_type) or - EntryReviewComment.CommentType.COMMENT + validated_data.get("comment_type") + or (self.instance and self.instance.comment_type) + or EntryReviewComment.CommentType.COMMENT ) # Make sure to check text required if not text and comment_type in [ @@ -281,40 +275,41 @@ def validate(self, validated_data): EntryReviewComment.CommentType.UNVERIFY, EntryReviewComment.CommentType.UNCONTROL, ]: - raise serializers.ValidationError({ - 'text': 'Text is required for comment type', - }) + raise serializers.ValidationError( + { + "text": "Text is required for comment type", + } + ) # Only creator can update - if self.instance and self.instance.created_by != self.context['request'].user: - raise serializers.ValidationError('Only comment creator can update.') + if self.instance and self.instance.created_by != self.context["request"].user: + raise serializers.ValidationError("Only comment creator can update.") return validated_data def comment_save(self, validated_data, instance=None): """ Comment Middleware save logic """ + def _add_comment_text(comment, text): return EntryReviewCommentText.objects.create( comment=comment, text=text, ) - text = validated_data.pop('text') # Is available from validate() + text = validated_data.pop("text") # Is available from validate() current_text = instance and instance.text text_changed = current_text != text - notify_meta = {'text_changed': text_changed} + notify_meta = {"text_changed": text_changed} if instance is None: # Create - notify_meta['notification_type'] = Notification.Type.ENTRY_REVIEW_COMMENT_ADD - notify_meta['text_changed'] = True + notify_meta["notification_type"] = Notification.Type.ENTRY_REVIEW_COMMENT_ADD + notify_meta["text_changed"] = True instance = super().create(validated_data) else: # Update - notify_meta['notification_type'] = Notification.Type.ENTRY_REVIEW_COMMENT_MODIFY - current_mentioned_users_pk = list(instance.mentioned_users.values_list('pk', flat=True)) - notify_meta['new_mentioned_users'] = [ - user - for user in validated_data.get('mentioned_users', []) - if user.pk not in current_mentioned_users_pk + notify_meta["notification_type"] = Notification.Type.ENTRY_REVIEW_COMMENT_MODIFY + current_mentioned_users_pk = list(instance.mentioned_users.values_list("pk", flat=True)) + notify_meta["new_mentioned_users"] = [ + user for user in validated_data.get("mentioned_users", []) if user.pk not in current_mentioned_users_pk ] instance = super().update(instance, validated_data) instance.save() @@ -323,13 +318,11 @@ def _add_comment_text(comment, text): pending_commit() if text and text_changed: _add_comment_text(instance, text) - transaction.on_commit( - lambda: send_notifications_for_comment(instance.pk, notify_meta) - ) + transaction.on_commit(lambda: send_notifications_for_comment(instance.pk, notify_meta)) return instance def create(self, validated_data): - validated_data['created_by'] = self.context['request'].user + validated_data["created_by"] = self.context["request"].user return self.comment_save(validated_data) def update(self, instance, validated_data): diff --git a/apps/quality_assurance/tests/test_apis.py b/apps/quality_assurance/tests/test_apis.py index 6df009e2a6..c27b8cf703 100644 --- a/apps/quality_assurance/tests/test_apis.py +++ b/apps/quality_assurance/tests/test_apis.py @@ -1,11 +1,9 @@ -from deep.tests import TestCase from entry.models import Entry from notification.models import Notification from project.models import ProjectMembership -from quality_assurance.models import ( - # EntryReviewComment, - EntryReviewComment, -) +from quality_assurance.models import EntryReviewComment # EntryReviewComment, + +from deep.tests import TestCase VerifiedByQs = Entry.verified_by.through.objects @@ -25,50 +23,50 @@ def test_entry_review_comment_basic_api(self): self.authenticate(user1) data = { - 'text': 'This is a test comment', - 'comment_type': EntryReviewComment.CommentType.COMMENT, - 'mentioned_users': [user1.pk, user2.pk, user3.pk], + "text": "This is a test comment", + "comment_type": EntryReviewComment.CommentType.COMMENT, + "mentioned_users": [user1.pk, user2.pk, user3.pk], } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) - review_comment1_pk = response.data['id'] + review_comment1_pk = response.data["id"] - response = self.client.get(f'/api/v1/entries/{entry.pk}/review-comments/') + response = self.client.get(f"/api/v1/entries/{entry.pk}/review-comments/") self.assert_200(response) - assert len(response.data['results']) == 1 + assert len(response.data["results"]) == 1 # Update only allowd by comment creater - data['text'] = 'This is updated text comment' - response = self.client.put(f'/api/v1/entries/{entry.pk}/review-comments/{review_comment1_pk}/', data=data) + data["text"] = "This is updated text comment" + response = self.client.put(f"/api/v1/entries/{entry.pk}/review-comments/{review_comment1_pk}/", data=data) self.assert_200(response) - self.assertEqual(response.data['text_history'][0]['text'], data['text']) + self.assertEqual(response.data["text_history"][0]["text"], data["text"]) self.authenticate(user2) - response = self.client.put(f'/api/v1/entries/{entry.pk}/review-comments/{review_comment1_pk}/', data=data) + response = self.client.put(f"/api/v1/entries/{entry.pk}/review-comments/{review_comment1_pk}/", data=data) self.assert_403(response) self.authenticate(user2) data = { - 'text': 'This is a test comment', - 'comment_type': EntryReviewComment.CommentType.COMMENT, - 'mentioned_users': [user1.pk, user2.pk, user3.pk], + "text": "This is a test comment", + "comment_type": EntryReviewComment.CommentType.COMMENT, + "mentioned_users": [user1.pk, user2.pk, user3.pk], } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) - review_comment1_pk = response.data['id'] + review_comment1_pk = response.data["id"] - response = self.client.get(f'/api/v1/entries/{entry.pk}/review-comments/') + response = self.client.get(f"/api/v1/entries/{entry.pk}/review-comments/") self.assert_200(response) - assert len(response.data['results']) == 2 + assert len(response.data["results"]) == 2 self.authenticate(user4) data = { - 'text': 'This is a test comment', - 'comment_type': EntryReviewComment.CommentType.COMMENT, - 'mentioned_users': [user1.pk, user2.pk, user3.pk], + "text": "This is a test comment", + "comment_type": EntryReviewComment.CommentType.COMMENT, + "mentioned_users": [user1.pk, user2.pk, user3.pk], } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_403(response) - response = self.client.get(f'/api/v1/entries/{entry.pk}/review-comments/') + response = self.client.get(f"/api/v1/entries/{entry.pk}/review-comments/") self.assert_403(response) def test_entry_review_comment_verify_api(self): @@ -83,47 +81,47 @@ def test_entry_review_comment_verify_api(self): self.authenticate(user1) data = { - 'text': 'This is a test comment', - 'comment_type': EntryReviewComment.CommentType.COMMENT, + "text": "This is a test comment", + "comment_type": EntryReviewComment.CommentType.COMMENT, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) assert VerifiedByQs.filter(entry=entry).count() == 0 # Should include is_verified_by_current_user as False - response = self.client.post('/api/v1/entries/filter/', data={'project': project.pk}) + response = self.client.post("/api/v1/entries/filter/", data={"project": project.pk}) self.assert_200(response) - assert not response.data['results'][0]['is_verified_by_current_user'] + assert not response.data["results"][0]["is_verified_by_current_user"] # Verify data = { - 'text': 'This is a test comment for approvable', - 'comment_type': EntryReviewComment.CommentType.VERIFY, + "text": "This is a test comment for approvable", + "comment_type": EntryReviewComment.CommentType.VERIFY, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) assert VerifiedByQs.filter(entry=entry).count() == 1 # Should include is_verified_by_current_user as True - response = self.client.post('/api/v1/entries/filter/', data={'project': project.pk}) + response = self.client.post("/api/v1/entries/filter/", data={"project": project.pk}) self.assert_200(response) - assert response.data['results'][0]['is_verified_by_current_user'] + assert response.data["results"][0]["is_verified_by_current_user"] self.authenticate(user2) data = { - 'text': 'This is a test comment', - 'comment_type': EntryReviewComment.CommentType.VERIFY, + "text": "This is a test comment", + "comment_type": EntryReviewComment.CommentType.VERIFY, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) assert VerifiedByQs.filter(entry=entry).count() == 2 # Unverify data = { - 'text': 'This is a test comment for unapprovable', - 'comment_type': EntryReviewComment.CommentType.UNVERIFY, + "text": "This is a test comment for unapprovable", + "comment_type": EntryReviewComment.CommentType.UNVERIFY, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) assert VerifiedByQs.filter(entry=entry).count() == 1 @@ -131,10 +129,10 @@ def test_entry_review_comment_verify_api(self): # Can't verify already verify self.authenticate(user1) data = { - 'text': 'This is a test comment', - 'comment_type': EntryReviewComment.CommentType.VERIFY, + "text": "This is a test comment", + "comment_type": EntryReviewComment.CommentType.VERIFY, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_400(response) assert VerifiedByQs.filter(entry=entry).count() == 1 @@ -142,17 +140,17 @@ def test_entry_review_comment_verify_api(self): # Can't unverify not verify self.authenticate(user2) data = { - 'text': 'This is a test comment', - 'comment_type': EntryReviewComment.CommentType.UNVERIFY, + "text": "This is a test comment", + "comment_type": EntryReviewComment.CommentType.UNVERIFY, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_400(response) self.authenticate(user3) data = { - 'text': 'This is a test comment', - 'comment_type': EntryReviewComment.CommentType.UNVERIFY, + "text": "This is a test comment", + "comment_type": EntryReviewComment.CommentType.UNVERIFY, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_400(response) assert VerifiedByQs.filter(entry=entry).count() == 1 @@ -169,16 +167,16 @@ def test_entry_review_comment_project_qa_badge_api(self): user1_membership.save() data = { - 'text': 'This is a test comment', - 'comment_type': comment_type, + "text": "This is a test comment", + "comment_type": comment_type, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_400(response) user1_membership.badges = [ProjectMembership.BadgeType.QA] user1_membership.save() - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) def test_entry_review_comment_control_api(self): @@ -193,18 +191,18 @@ def test_entry_review_comment_control_api(self): self.authenticate(user1) data = { - 'text': 'This is a test comment', - 'comment_type': EntryReviewComment.CommentType.COMMENT, + "text": "This is a test comment", + "comment_type": EntryReviewComment.CommentType.COMMENT, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) # Control data = { - 'text': 'This is a test comment for control/verify', - 'comment_type': EntryReviewComment.CommentType.CONTROL, + "text": "This is a test comment for control/verify", + "comment_type": EntryReviewComment.CommentType.CONTROL, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) entry.refresh_from_db() assert entry.controlled @@ -212,10 +210,10 @@ def test_entry_review_comment_control_api(self): # Control using same user again data = { - 'text': 'This is a test comment to again control already verified', - 'comment_type': EntryReviewComment.CommentType.CONTROL, + "text": "This is a test comment to again control already verified", + "comment_type": EntryReviewComment.CommentType.CONTROL, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_400(response) entry.refresh_from_db() assert entry.controlled @@ -224,10 +222,10 @@ def test_entry_review_comment_control_api(self): # Control using another user again self.authenticate(user2) data = { - 'text': 'This is a test comment to again control already verified', - 'comment_type': EntryReviewComment.CommentType.CONTROL, + "text": "This is a test comment to again control already verified", + "comment_type": EntryReviewComment.CommentType.CONTROL, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_400(response) entry.refresh_from_db() assert entry.controlled @@ -236,10 +234,10 @@ def test_entry_review_comment_control_api(self): # Uncontrol (any users can also uncontrol) self.authenticate(user2) data = { - 'text': 'This is a test comment for uncontrol', - 'comment_type': EntryReviewComment.CommentType.UNCONTROL, + "text": "This is a test comment for uncontrol", + "comment_type": EntryReviewComment.CommentType.UNCONTROL, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) entry.refresh_from_db() assert not entry.controlled @@ -249,10 +247,10 @@ def test_entry_review_comment_control_api(self): self.authenticate(user) # Can't uncontrol already uncontrol data = { - 'text': 'This is a test comment', - 'comment_type': EntryReviewComment.CommentType.UNVERIFY, + "text": "This is a test comment", + "comment_type": EntryReviewComment.CommentType.UNVERIFY, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_400(response) entry.refresh_from_db() assert not entry.controlled @@ -272,34 +270,34 @@ def test_entry_review_comment_summary_api(self): self.authenticate(user1) data = { - 'text': 'This is a comment', - 'comment_type': EntryReviewComment.CommentType.COMMENT, + "text": "This is a comment", + "comment_type": EntryReviewComment.CommentType.COMMENT, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) for user in [user2, user3]: self.authenticate(user) data = { - 'text': 'This is a verify comment', - 'comment_type': EntryReviewComment.CommentType.VERIFY, + "text": "This is a verify comment", + "comment_type": EntryReviewComment.CommentType.VERIFY, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) self.authenticate(user4) data = { - 'text': 'This is a control comment', - 'comment_type': EntryReviewComment.CommentType.CONTROL, + "text": "This is a control comment", + "comment_type": EntryReviewComment.CommentType.CONTROL, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) - response = self.client.get(f'/api/v1/entries/{entry.pk}/review-comments/') - assert 'summary' in response.data - assert len(response.data['summary']['verified_by']) == 2 - assert response.data['summary']['controlled'] - assert response.data['summary']['controlled_changed_by']['id'] == user4.pk + response = self.client.get(f"/api/v1/entries/{entry.pk}/review-comments/") + assert "summary" in response.data + assert len(response.data["summary"]["verified_by"]) == 2 + assert response.data["summary"]["controlled"] + assert response.data["summary"]["controlled_changed_by"]["id"] == user4.pk def test_entry_filter_verified_count_api(self): project = self.create_project() @@ -310,29 +308,29 @@ def test_entry_filter_verified_count_api(self): project.add_member(user, role=self.normal_role, badges=[ProjectMembership.BadgeType.QA]) self.authenticate(user) data = { - 'text': 'This is a verify comment', - 'comment_type': EntryReviewComment.CommentType.VERIFY, + "text": "This is a verify comment", + "comment_type": EntryReviewComment.CommentType.VERIFY, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) - response = self.client.post('/api/v1/entries/filter/', data={'project': project.pk}) + response = self.client.post("/api/v1/entries/filter/", data={"project": project.pk}) self.assert_200(response) - assert response.data['results'][0]['verified_by_count'] == 3 - assert not response.data['results'][0]['controlled'] + assert response.data["results"][0]["verified_by_count"] == 3 + assert not response.data["results"][0]["controlled"] self.authenticate(user) data = { - 'text': 'This is a control comment', - 'comment_type': EntryReviewComment.CommentType.CONTROL, + "text": "This is a control comment", + "comment_type": EntryReviewComment.CommentType.CONTROL, } - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) - response = self.client.post('/api/v1/entries/filter/', data={'project': project.pk}) + response = self.client.post("/api/v1/entries/filter/", data={"project": project.pk}) self.assert_200(response) - assert response.data['results'][0]['verified_by_count'] == 3 - assert response.data['results'][0]['controlled'] + assert response.data["results"][0]["verified_by_count"] == 3 + assert response.data["results"][0]["controlled"] def test_entry_review_comment_text_required_api(self): project = self.create_project() @@ -341,31 +339,29 @@ def test_entry_review_comment_text_required_api(self): project.add_member(user1, role=self.normal_role, badges=[ProjectMembership.BadgeType.QA]) for comment_type, text_required in [ - (None, True), # Default is CommentType.COMMENT - (EntryReviewComment.CommentType.COMMENT, True), - (EntryReviewComment.CommentType.VERIFY, False), - (EntryReviewComment.CommentType.UNVERIFY, True), - (EntryReviewComment.CommentType.CONTROL, False), - (EntryReviewComment.CommentType.UNCONTROL, True), + (None, True), # Default is CommentType.COMMENT + (EntryReviewComment.CommentType.COMMENT, True), + (EntryReviewComment.CommentType.VERIFY, False), + (EntryReviewComment.CommentType.UNVERIFY, True), + (EntryReviewComment.CommentType.CONTROL, False), + (EntryReviewComment.CommentType.UNCONTROL, True), ]: self.authenticate(user1) data = {} if comment_type: - data['comment_type'] = comment_type - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + data["comment_type"] = comment_type + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) if text_required: self.assert_400(response) - data['text'] = 'This is a comment' - response = self.client.post(f'/api/v1/entries/{entry.pk}/review-comments/', data=data) + data["text"] = "This is a comment" + response = self.client.post(f"/api/v1/entries/{entry.pk}/review-comments/", data=data) self.assert_201(response) else: self.assert_201(response) def test_entry_review_comment_notification(self): def _get_comment_users_pk(pk): - return set( - EntryReviewComment.objects.get(pk=pk).get_related_users().values_list('pk', flat=True) - ) + return set(EntryReviewComment.objects.get(pk=pk).get_related_users().values_list("pk", flat=True)) def _clean_comments(project): return EntryReviewComment.objects.filter(entry__project=project).delete() @@ -374,10 +370,8 @@ def _clear_notifications(): return Notification.objects.all().delete() def _get_notifications_receivers(): - return set( - Notification.objects.values_list('receiver', flat=True) - ), set( - Notification.objects.values_list('notification_type', flat=True).distinct() + return set(Notification.objects.values_list("receiver", flat=True)), set( + Notification.objects.values_list("notification_type", flat=True).distinct() ) project = self.create_project() @@ -390,20 +384,20 @@ def _get_notifications_receivers(): project.add_member(user2, role=self.normal_role, badges=[ProjectMembership.BadgeType.QA]) project.add_member(user3, role=self.normal_role, badges=[ProjectMembership.BadgeType.QA]) project.add_member(user4, role=self.normal_role, badges=[ProjectMembership.BadgeType.QA]) - url = f'/api/v1/entries/{entry.pk}/review-comments/' + url = f"/api/v1/entries/{entry.pk}/review-comments/" self.authenticate(user1) # Create a commit _clear_notifications() data = { - 'text': 'This is a comment', - 'comment_type': EntryReviewComment.CommentType.COMMENT, - 'mentioned_users': [user2.pk], + "text": "This is a comment", + "comment_type": EntryReviewComment.CommentType.COMMENT, + "mentioned_users": [user2.pk], } # Need self.captureOnCommitCallbacks as this API uses transation.on_commit with self.captureOnCommitCallbacks(execute=True): - comment_id = self.client.post(url, data=data).json()['id'] + comment_id = self.client.post(url, data=data).json()["id"] assert _get_comment_users_pk(comment_id) == set([user2.pk]) assert _get_notifications_receivers() == ( set([user2.pk]), @@ -413,13 +407,13 @@ def _get_notifications_receivers(): # Create a commit (multiple mentioned_users) _clear_notifications() data = { - 'text': 'This is a comment', - 'comment_type': EntryReviewComment.CommentType.COMMENT, - 'mentioned_users': [user2.pk, user3.pk, user1.pk], + "text": "This is a comment", + "comment_type": EntryReviewComment.CommentType.COMMENT, + "mentioned_users": [user2.pk, user3.pk, user1.pk], } # Need self.captureOnCommitCallbacks as this API uses transation.on_commit with self.captureOnCommitCallbacks(execute=True): - comment_id = self.client.post(url, data=data).json()['id'] + comment_id = self.client.post(url, data=data).json()["id"] assert _get_comment_users_pk(comment_id) == set([user2.pk, user3.pk]) assert _get_notifications_receivers() == ( set([user2.pk, user3.pk]), @@ -428,19 +422,21 @@ def _get_notifications_receivers(): # Create a commit different comment_type for comment_type in [ - EntryReviewComment.CommentType.VERIFY, EntryReviewComment.CommentType.UNVERIFY, - EntryReviewComment.CommentType.CONTROL, EntryReviewComment.CommentType.UNCONTROL, + EntryReviewComment.CommentType.VERIFY, + EntryReviewComment.CommentType.UNVERIFY, + EntryReviewComment.CommentType.CONTROL, + EntryReviewComment.CommentType.UNCONTROL, ]: _clean_comments(project) _clear_notifications() data = { - 'text': 'This is a comment', - 'comment_type': comment_type, - 'mentioned_users': [user1.pk, user2.pk, user3.pk], + "text": "This is a comment", + "comment_type": comment_type, + "mentioned_users": [user1.pk, user2.pk, user3.pk], } # Need self.captureOnCommitCallbacks as this API uses transation.on_commit with self.captureOnCommitCallbacks(execute=True): - comment_id = self.client.post(url, data=data).json()['id'] + comment_id = self.client.post(url, data=data).json()["id"] assert _get_comment_users_pk(comment_id) == set([user2.pk, user3.pk]) assert _get_notifications_receivers() == ( set([user2.pk, user3.pk]), @@ -450,16 +446,16 @@ def _get_notifications_receivers(): _clear_notifications() # Need self.captureOnCommitCallbacks as this API uses transation.on_commit with self.captureOnCommitCallbacks(execute=True): - resp = self.client.patch(f'{url}{comment_id}/', data=data) + resp = self.client.patch(f"{url}{comment_id}/", data=data) self.assert_200(resp) assert _get_comment_users_pk(comment_id) == set([user2.pk, user3.pk]) assert _get_notifications_receivers() == (set(), set()) # No new notifications are created _clear_notifications() - data['text'] = 'this is a new comment text' + data["text"] = "this is a new comment text" # Need self.captureOnCommitCallbacks as this API uses transation.on_commit with self.captureOnCommitCallbacks(execute=True): - resp = self.client.patch(f'{url}{comment_id}/', data=data) + resp = self.client.patch(f"{url}{comment_id}/", data=data) self.assert_200(resp) assert _get_comment_users_pk(comment_id) == set([user2.pk, user3.pk]) assert _get_notifications_receivers() == ( @@ -468,10 +464,10 @@ def _get_notifications_receivers(): ) # New notifications are created _clear_notifications() - data['mentioned_users'].append(user4.pk) + data["mentioned_users"].append(user4.pk) # Need self.captureOnCommitCallbacks as this API uses transation.on_commit with self.captureOnCommitCallbacks(execute=True): - resp = self.client.patch(f'{url}{comment_id}/', data=data) + resp = self.client.patch(f"{url}{comment_id}/", data=data) self.assert_200(resp) assert _get_comment_users_pk(comment_id) == set([user4.pk, user2.pk, user3.pk]) assert _get_notifications_receivers() == ( diff --git a/apps/quality_assurance/tests/test_mutations.py b/apps/quality_assurance/tests/test_mutations.py index a67d2a8f85..fc9efb1056 100644 --- a/apps/quality_assurance/tests/test_mutations.py +++ b/apps/quality_assurance/tests/test_mutations.py @@ -1,25 +1,22 @@ -from utils.graphene.tests import GraphQLTestCase - -from quality_assurance.models import EntryReviewComment -from project.models import ProjectMembership -from notification.models import Notification -from entry.models import Entry - -from user.factories import UserFactory from analysis_framework.factories import AnalysisFrameworkFactory -from project.factories import ProjectFactory -from lead.factories import LeadFactory from entry.factories import EntryFactory - +from entry.models import Entry +from lead.factories import LeadFactory +from notification.models import Notification +from project.factories import ProjectFactory +from project.models import ProjectMembership from quality_assurance.factories import EntryReviewCommentFactory +from quality_assurance.models import EntryReviewComment +from user.factories import UserFactory +from utils.graphene.tests import GraphQLTestCase VerifiedByQs = Entry.verified_by.through.objects class TestQualityAssuranceMutation(GraphQLTestCase): - CREATE_ENTRY_REVIEW_COMMENT_QUERY = ''' + CREATE_ENTRY_REVIEW_COMMENT_QUERY = """ mutation MyMutation ($projectId: ID!, $input: EntryReviewCommentInputType!) { project(id: $projectId) { entryReviewCommentCreate(data: $input) { @@ -47,9 +44,9 @@ class TestQualityAssuranceMutation(GraphQLTestCase): } } } - ''' + """ - UPDATE_ENTRY_REVIEW_COMMENT_QUERY = ''' + UPDATE_ENTRY_REVIEW_COMMENT_QUERY = """ mutation MyMutation ($projectId: ID!, $reviewCommentId: ID!, $input: EntryReviewCommentInputType!) { project(id: $projectId) { entryReviewCommentUpdate(id: $reviewCommentId data: $input) { @@ -77,9 +74,9 @@ class TestQualityAssuranceMutation(GraphQLTestCase): } } } - ''' + """ - DELETE_ENTRY_REVIEW_COMMENT_QUERY = ''' + DELETE_ENTRY_REVIEW_COMMENT_QUERY = """ mutation MyMutation ($projectId: ID!, $commentId: ID!) { project(id: $projectId) { entryReviewCommentDelete(id: $commentId) { @@ -107,7 +104,7 @@ class TestQualityAssuranceMutation(GraphQLTestCase): } } } - ''' + """ def setUp(self): super().setUp() @@ -125,31 +122,23 @@ def setUp(self): self.project.add_member(self.qa_member_user, role=self.project_role_member, badges=[ProjectMembership.BadgeType.QA]) def _query_check(self, mutation_input, review_comment_id=None, **kwargs): - variables = {'projectId': self.project.id} + variables = {"projectId": self.project.id} query = self.CREATE_ENTRY_REVIEW_COMMENT_QUERY if review_comment_id: query = self.UPDATE_ENTRY_REVIEW_COMMENT_QUERY - variables['reviewCommentId'] = review_comment_id - return self.query_check( - query, - minput=mutation_input, - mnested=['project'], - variables=variables, - **kwargs - ) + variables["reviewCommentId"] = review_comment_id + return self.query_check(query, minput=mutation_input, mnested=["project"], variables=variables, **kwargs) def test_entry_review_comment_create(self): minput = { - 'entry': self.entry.id, - 'commentType': self.genum(EntryReviewComment.CommentType.COMMENT), + "entry": self.entry.id, + "commentType": self.genum(EntryReviewComment.CommentType.COMMENT), # 'mentionedUsers': [self.readonly_member_user.pk, self.qa_member_user.pk], } def _get_notifications_receivers(): - return set( - Notification.objects.values_list('receiver', flat=True) - ), set( - Notification.objects.values_list('notification_type', flat=True).distinct() + return set(Notification.objects.values_list("receiver", flat=True)), set( + Notification.objects.values_list("notification_type", flat=True).distinct() ) # -- Without login @@ -168,44 +157,44 @@ def _get_notifications_receivers(): # Invalid input (Comment without text) self.entry.controlled = True - self.entry.save(update_fields=('controlled',)) + self.entry.save(update_fields=("controlled",)) minput = { - 'entry': self.entry.id, - 'commentType': self.genum(EntryReviewComment.CommentType.CONTROL), + "entry": self.entry.id, + "commentType": self.genum(EntryReviewComment.CommentType.CONTROL), } self._query_check(minput, okay=False) # Control self.entry.controlled = False - self.entry.save(update_fields=('controlled',)) + self.entry.save(update_fields=("controlled",)) self._query_check(minput, okay=False) self.force_login(self.qa_member_user) - minput['commentType'] = self.genum(EntryReviewComment.CommentType.UNCONTROL) + minput["commentType"] = self.genum(EntryReviewComment.CommentType.UNCONTROL) self._query_check(minput, okay=False) - minput['commentType'] = self.genum(EntryReviewComment.CommentType.CONTROL) + minput["commentType"] = self.genum(EntryReviewComment.CommentType.CONTROL) self._query_check(minput, okay=True) # If request by a QA User - minput['commentType'] = self.genum(EntryReviewComment.CommentType.UNCONTROL) + minput["commentType"] = self.genum(EntryReviewComment.CommentType.UNCONTROL) self.force_login(self.member_user) self._query_check(minput, okay=False) self.force_login(self.qa_member_user) self._query_check(minput, okay=False) # Text is required - minput['text'] = 'sample text' + minput["text"] = "sample text" self._query_check(minput, okay=True) # If request by a QA User # Verify self.force_login(self.member_user) - minput.pop('text') - minput['commentType'] = self.genum(EntryReviewComment.CommentType.VERIFY) + minput.pop("text") + minput["commentType"] = self.genum(EntryReviewComment.CommentType.VERIFY) self._query_check(minput, okay=True) - minput['commentType'] = self.genum(EntryReviewComment.CommentType.VERIFY) + minput["commentType"] = self.genum(EntryReviewComment.CommentType.VERIFY) self._query_check(minput, okay=False) - minput['commentType'] = self.genum(EntryReviewComment.CommentType.UNVERIFY) + minput["commentType"] = self.genum(EntryReviewComment.CommentType.UNVERIFY) self._query_check(minput, okay=False) - minput['text'] = 'sample text' + minput["text"] = "sample text" self._query_check(minput, okay=True) - minput['commentType'] = self.genum(EntryReviewComment.CommentType.UNVERIFY) + minput["commentType"] = self.genum(EntryReviewComment.CommentType.UNVERIFY) self._query_check(minput, okay=False) def test_entry_review_comment_basic_api(self): @@ -219,30 +208,29 @@ def test_entry_review_comment_basic_api(self): self.force_login(user1) data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment', - 'commentType': self.genum(EntryReviewComment.CommentType.COMMENT), - 'mentionedUsers': [user1.pk, user2.pk, user3.pk], + "entry": self.entry.pk, + "text": "This is a test comment", + "commentType": self.genum(EntryReviewComment.CommentType.COMMENT), + "mentionedUsers": [user1.pk, user2.pk, user3.pk], } - comment_pk = self._query_check(data, okay=True)['data']['project']['entryReviewCommentCreate']['result']['id'] + comment_pk = self._query_check(data, okay=True)["data"]["project"]["entryReviewCommentCreate"]["result"]["id"] assert self.entry.review_comments.count() == 1 # Update only allowd by comment creater - data['text'] = 'This is updated text comment' - content = self._query_check( - data, review_comment_id=comment_pk, okay=True)['data']['project']['entryReviewCommentUpdate'] - self.assertEqual(content['result']['textHistory'][0]['text'], data['text']) - self.assertEqual(content['result']['text'], data['text']) + data["text"] = "This is updated text comment" + content = self._query_check(data, review_comment_id=comment_pk, okay=True)["data"]["project"]["entryReviewCommentUpdate"] + self.assertEqual(content["result"]["textHistory"][0]["text"], data["text"]) + self.assertEqual(content["result"]["text"], data["text"]) self.force_login(user2) self._query_check(data, review_comment_id=comment_pk, okay=False) self.force_login(user2) data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment', - 'commentType': self.genum(EntryReviewComment.CommentType.COMMENT), - 'mentionedUsers': [user1.pk, user2.pk, user3.pk], + "entry": self.entry.pk, + "text": "This is a test comment", + "commentType": self.genum(EntryReviewComment.CommentType.COMMENT), + "mentionedUsers": [user1.pk, user2.pk, user3.pk], } self._query_check(data, okay=True) @@ -250,10 +238,10 @@ def test_entry_review_comment_basic_api(self): self.force_login(user4) data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment', - 'commentType': self.genum(EntryReviewComment.CommentType.COMMENT), - 'mentionedUsers': [user1.pk, user2.pk, user3.pk], + "entry": self.entry.pk, + "text": "This is a test comment", + "commentType": self.genum(EntryReviewComment.CommentType.COMMENT), + "mentionedUsers": [user1.pk, user2.pk, user3.pk], } self._query_check(data, assert_for_error=True) @@ -267,36 +255,36 @@ def test_entry_review_comment_verify_api(self): self.force_login(user1) data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment', - 'commentType': self.genum(EntryReviewComment.CommentType.COMMENT), + "entry": self.entry.pk, + "text": "This is a test comment", + "commentType": self.genum(EntryReviewComment.CommentType.COMMENT), } self._query_check(data, okay=True) assert VerifiedByQs.filter(entry=self.entry).count() == 0 # Verify data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment for approvable', - 'commentType': self.genum(EntryReviewComment.CommentType.VERIFY), + "entry": self.entry.pk, + "text": "This is a test comment for approvable", + "commentType": self.genum(EntryReviewComment.CommentType.VERIFY), } self._query_check(data, okay=True) assert VerifiedByQs.filter(entry=self.entry).count() == 1 self.force_login(user2) data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment', - 'commentType': self.genum(EntryReviewComment.CommentType.VERIFY), + "entry": self.entry.pk, + "text": "This is a test comment", + "commentType": self.genum(EntryReviewComment.CommentType.VERIFY), } self._query_check(data, okay=True) assert VerifiedByQs.filter(entry=self.entry).count() == 2 # Unverify data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment for unapprovable', - 'commentType': self.genum(EntryReviewComment.CommentType.UNVERIFY), + "entry": self.entry.pk, + "text": "This is a test comment for unapprovable", + "commentType": self.genum(EntryReviewComment.CommentType.UNVERIFY), } self._query_check(data, okay=True) @@ -305,9 +293,9 @@ def test_entry_review_comment_verify_api(self): # Can't verify already verify self.force_login(user1) data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment', - 'commentType': self.genum(EntryReviewComment.CommentType.VERIFY), + "entry": self.entry.pk, + "text": "This is a test comment", + "commentType": self.genum(EntryReviewComment.CommentType.VERIFY), } self._query_check(data, okay=False) @@ -316,17 +304,17 @@ def test_entry_review_comment_verify_api(self): # Can't unverify not verify self.force_login(user2) data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment', - 'commentType': self.genum(EntryReviewComment.CommentType.UNVERIFY), + "entry": self.entry.pk, + "text": "This is a test comment", + "commentType": self.genum(EntryReviewComment.CommentType.UNVERIFY), } self._query_check(data, okay=False) self.force_login(user3) data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment', - 'commentType': self.genum(EntryReviewComment.CommentType.UNVERIFY), + "entry": self.entry.pk, + "text": "This is a test comment", + "commentType": self.genum(EntryReviewComment.CommentType.UNVERIFY), } self._query_check(data, okay=False) @@ -345,9 +333,9 @@ def test_entry_review_comment_project_qa_badge_api(self): user1_membership.save() data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment', - 'commentType': self.genum(comment_type), + "entry": self.entry.pk, + "text": "This is a test comment", + "commentType": self.genum(comment_type), } self._query_check(data, okay=False) @@ -366,17 +354,17 @@ def test_entry_review_comment_control_api(self): self.force_login(user1) data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment', - 'commentType': self.genum(EntryReviewComment.CommentType.COMMENT), + "entry": self.entry.pk, + "text": "This is a test comment", + "commentType": self.genum(EntryReviewComment.CommentType.COMMENT), } self._query_check(data, okay=True) # Control data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment for control/verify', - 'commentType': self.genum(EntryReviewComment.CommentType.CONTROL), + "entry": self.entry.pk, + "text": "This is a test comment for control/verify", + "commentType": self.genum(EntryReviewComment.CommentType.CONTROL), } self._query_check(data, okay=True) self.entry.refresh_from_db() @@ -385,9 +373,9 @@ def test_entry_review_comment_control_api(self): # Control using same user again data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment to again control already verified', - 'commentType': self.genum(EntryReviewComment.CommentType.CONTROL), + "entry": self.entry.pk, + "text": "This is a test comment to again control already verified", + "commentType": self.genum(EntryReviewComment.CommentType.CONTROL), } self._query_check(data, okay=False) self.entry.refresh_from_db() @@ -397,9 +385,9 @@ def test_entry_review_comment_control_api(self): # Control using another user again self.force_login(user2) data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment to again control already verified', - 'commentType': self.genum(EntryReviewComment.CommentType.CONTROL), + "entry": self.entry.pk, + "text": "This is a test comment to again control already verified", + "commentType": self.genum(EntryReviewComment.CommentType.CONTROL), } self._query_check(data, okay=False) self.entry.refresh_from_db() @@ -409,9 +397,9 @@ def test_entry_review_comment_control_api(self): # Uncontrol (any users can also uncontrol) self.force_login(user2) data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment for uncontrol', - 'commentType': self.genum(EntryReviewComment.CommentType.UNCONTROL), + "entry": self.entry.pk, + "text": "This is a test comment for uncontrol", + "commentType": self.genum(EntryReviewComment.CommentType.UNCONTROL), } self._query_check(data, okay=True) self.entry.refresh_from_db() @@ -422,9 +410,9 @@ def test_entry_review_comment_control_api(self): self.force_login(user) # Can't uncontrol already uncontrol data = { - 'entry': self.entry.pk, - 'text': 'This is a test comment', - 'commentType': self.genum(EntryReviewComment.CommentType.UNVERIFY), + "entry": self.entry.pk, + "text": "This is a test comment", + "commentType": self.genum(EntryReviewComment.CommentType.UNVERIFY), } self._query_check(data, okay=False) self.entry.refresh_from_db() @@ -436,28 +424,26 @@ def test_entry_review_comment_create_text_required(self): # Text required for comment_type, text_required in [ - (None, True), # Default is CommentType.COMMENT - (EntryReviewComment.CommentType.COMMENT, True), - (EntryReviewComment.CommentType.VERIFY, False), - (EntryReviewComment.CommentType.UNVERIFY, True), - (EntryReviewComment.CommentType.CONTROL, False), - (EntryReviewComment.CommentType.UNCONTROL, True), + (None, True), # Default is CommentType.COMMENT + (EntryReviewComment.CommentType.COMMENT, True), + (EntryReviewComment.CommentType.VERIFY, False), + (EntryReviewComment.CommentType.UNVERIFY, True), + (EntryReviewComment.CommentType.CONTROL, False), + (EntryReviewComment.CommentType.UNCONTROL, True), ]: - _minput = {'entry': self.entry.pk} + _minput = {"entry": self.entry.pk} if comment_type: - _minput['commentType'] = self.genum(comment_type) + _minput["commentType"] = self.genum(comment_type) if text_required: self._query_check(_minput, okay=False) - _minput['text'] = 'This is a comment' + _minput["text"] = "This is a comment" self._query_check(_minput, okay=True) else: self._query_check(_minput, okay=True) def test_entry_review_comment_notification(self): def _get_comment_users_pk(pk): - return set( - EntryReviewComment.objects.get(pk=pk).get_related_users().values_list('pk', flat=True) - ) + return set(EntryReviewComment.objects.get(pk=pk).get_related_users().values_list("pk", flat=True)) def _clean_comments(project): return EntryReviewComment.objects.filter(entry__project=project).delete() @@ -466,10 +452,8 @@ def _clear_notifications(): return Notification.objects.all().delete() def _get_notifications_receivers(): - return set( - Notification.objects.values_list('receiver', flat=True) - ), set( - Notification.objects.values_list('notification_type', flat=True).distinct() + return set(Notification.objects.values_list("receiver", flat=True)), set( + Notification.objects.values_list("notification_type", flat=True).distinct() ) user1 = UserFactory.create() @@ -486,14 +470,14 @@ def _get_notifications_receivers(): # Create a commit _clear_notifications() minput = { - 'entry': self.entry.id, - 'text': 'This is a comment', - 'commentType': self.genum(EntryReviewComment.CommentType.COMMENT), - 'mentionedUsers': [user2.pk], + "entry": self.entry.id, + "text": "This is a comment", + "commentType": self.genum(EntryReviewComment.CommentType.COMMENT), + "mentionedUsers": [user2.pk], } # Need self.captureOnCommitCallbacks as this API uses transation.on_commit with self.captureOnCommitCallbacks(execute=True): - comment_id = self._query_check(minput, okay=True)['data']['project']['entryReviewCommentCreate']['result']['id'] + comment_id = self._query_check(minput, okay=True)["data"]["project"]["entryReviewCommentCreate"]["result"]["id"] assert _get_comment_users_pk(comment_id) == set([user2.pk]) assert _get_notifications_receivers() == ( set([user2.pk]), @@ -503,14 +487,14 @@ def _get_notifications_receivers(): # Create a commit (multiple mentionedUsers) _clear_notifications() minput = { - 'entry': self.entry.id, - 'text': 'This is a comment', - 'commentType': self.genum(EntryReviewComment.CommentType.COMMENT), - 'mentionedUsers': [user2.pk, user3.pk, self.qa_member_user.pk], + "entry": self.entry.id, + "text": "This is a comment", + "commentType": self.genum(EntryReviewComment.CommentType.COMMENT), + "mentionedUsers": [user2.pk, user3.pk, self.qa_member_user.pk], } # Need self.captureOnCommitCallbacks as this API uses transation.on_commit with self.captureOnCommitCallbacks(execute=True): - comment_id = self._query_check(minput, okay=True)['data']['project']['entryReviewCommentCreate']['result']['id'] + comment_id = self._query_check(minput, okay=True)["data"]["project"]["entryReviewCommentCreate"]["result"]["id"] assert _get_comment_users_pk(comment_id) == set([user2.pk, user3.pk]) assert _get_notifications_receivers() == ( set([user2.pk, user3.pk]), @@ -519,21 +503,22 @@ def _get_notifications_receivers(): # Create a commit different comment_type for comment_type in [ - EntryReviewComment.CommentType.VERIFY, EntryReviewComment.CommentType.UNVERIFY, - EntryReviewComment.CommentType.CONTROL, EntryReviewComment.CommentType.UNCONTROL, + EntryReviewComment.CommentType.VERIFY, + EntryReviewComment.CommentType.UNVERIFY, + EntryReviewComment.CommentType.CONTROL, + EntryReviewComment.CommentType.UNCONTROL, ]: _clean_comments(self.project) _clear_notifications() minput = { - 'entry': self.entry.id, - 'text': 'This is a comment', - 'commentType': self.genum(comment_type), - 'mentionedUsers': [self.qa_member_user.pk, user2.pk, user3.pk], + "entry": self.entry.id, + "text": "This is a comment", + "commentType": self.genum(comment_type), + "mentionedUsers": [self.qa_member_user.pk, user2.pk, user3.pk], } # Need self.captureOnCommitCallbacks as this API uses transation.on_commit with self.captureOnCommitCallbacks(execute=True): - comment_id = self._query_check( - minput, okay=True)['data']['project']['entryReviewCommentCreate']['result']['id'] + comment_id = self._query_check(minput, okay=True)["data"]["project"]["entryReviewCommentCreate"]["result"]["id"] assert _get_comment_users_pk(comment_id) == set([user2.pk, user3.pk]) assert _get_notifications_receivers() == ( set([user2.pk, user3.pk]), @@ -548,7 +533,7 @@ def _get_notifications_receivers(): assert _get_notifications_receivers() == (set(), set()) # No new notifications are created _clear_notifications() - minput['text'] = 'this is a new comment text' + minput["text"] = "this is a new comment text" # Need self.captureOnCommitCallbacks as this API uses transation.on_commit with self.captureOnCommitCallbacks(execute=True): self._query_check(minput, review_comment_id=comment_id, okay=True) @@ -559,7 +544,7 @@ def _get_notifications_receivers(): ) # New notifications are created _clear_notifications() - minput['mentionedUsers'].append(user4.pk) + minput["mentionedUsers"].append(user4.pk) # Need self.captureOnCommitCallbacks as this API uses transation.on_commit with self.captureOnCommitCallbacks(execute=True): self._query_check(minput, review_comment_id=comment_id, okay=True) @@ -571,13 +556,8 @@ def _get_notifications_receivers(): def test_entry_review_comment_delete(self): def _query_check(review_comment_id, **kwargs): - variables = {'projectId': self.project.id, 'commentId': review_comment_id} - return self.query_check( - self.DELETE_ENTRY_REVIEW_COMMENT_QUERY, - mnested=['project'], - variables=variables, - **kwargs - ) + variables = {"projectId": self.project.id, "commentId": review_comment_id} + return self.query_check(self.DELETE_ENTRY_REVIEW_COMMENT_QUERY, mnested=["project"], variables=variables, **kwargs) member_user2 = UserFactory.create() self.project.add_member(member_user2, role=self.project_role_member) @@ -608,7 +588,9 @@ def _query_check(review_comment_id, **kwargs): self.force_login(member_user2) [ ( - _query_check(comment.pk, okay=True) if comment.comment_type == EntryReviewComment.CommentType.COMMENT + _query_check(comment.pk, okay=True) + if comment.comment_type == EntryReviewComment.CommentType.COMMENT else _query_check(comment.pk, okay=False) - )for index, comment in enumerate(comments) + ) + for index, comment in enumerate(comments) ] diff --git a/apps/quality_assurance/tests/test_schemas.py b/apps/quality_assurance/tests/test_schemas.py index f6316f7367..b52f13ca4a 100644 --- a/apps/quality_assurance/tests/test_schemas.py +++ b/apps/quality_assurance/tests/test_schemas.py @@ -1,21 +1,20 @@ -from utils.graphene.tests import GraphQLTestCase - -from user.factories import UserFactory -from project.factories import ProjectFactory -from lead.factories import LeadFactory -from entry.factories import EntryFactory from analysis_framework.factories import AnalysisFrameworkFactory +from entry.factories import EntryFactory +from lead.factories import LeadFactory from lead.models import Lead - +from project.factories import ProjectFactory from quality_assurance.factories import ( EntryReviewCommentFactory, - EntryReviewCommentTextFactory + EntryReviewCommentTextFactory, ) +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLTestCase class TestReviewCommentQuery(GraphQLTestCase): def test_review_comments_query(self): - query = ''' + query = """ query MyQuery ($projectId: ID! $entryId: ID!) { project(id: $projectId) { entry (id: $entryId) { @@ -46,7 +45,7 @@ def test_review_comments_query(self): } } } - ''' + """ user = UserFactory.create() user2, user3 = UserFactory.create_batch(2) @@ -54,13 +53,16 @@ def test_review_comments_query(self): project = ProjectFactory.create(analysis_framework=analysis_framework) lead = LeadFactory.create(project=project) entry = EntryFactory.create( - project=project, analysis_framework=analysis_framework, - lead=lead, controlled=True, + project=project, + analysis_framework=analysis_framework, + lead=lead, + controlled=True, controlled_changed_by=user2, - verified_by=[user2, user3] + verified_by=[user2, user3], ) entry1 = EntryFactory.create( - project=project, analysis_framework=analysis_framework, + project=project, + analysis_framework=analysis_framework, lead=lead, ) @@ -69,56 +71,48 @@ def test_review_comments_query(self): review_text1 = EntryReviewCommentTextFactory.create(comment=review_comment1) # -- Without login - self.query_check(query, assert_for_error=True, variables={'projectId': project.id, 'entryId': entry.id}) + self.query_check(query, assert_for_error=True, variables={"projectId": project.id, "entryId": entry.id}) # -- With login self.force_login(user) # --- non-member user - content = self.query_check(query, variables={'projectId': project.id, 'entryId': entry.id}) - self.assertEqual(content['data']['project']['entry'], None, content) + content = self.query_check(query, variables={"projectId": project.id, "entryId": entry.id}) + self.assertEqual(content["data"]["project"]["entry"], None, content) # --- add-member in project project.add_member(user) - content = self.query_check(query, variables={'projectId': project.id, 'entryId': entry.id}) - self.assertEqual(content['data']['project']['entry']['reviewCommentsCount'], 2, content) - self.assertEqual(content['data']['project']['reviewComments']['totalCount'], 2, content) - self.assertListIds( - content['data']['project']['reviewComments']['results'], - [review_comment1, review_comment2], - content - ) - self.assertEqual( - content['data']['project']['reviewComments']['results'][1]['text'], - review_text1.text, - content - ) + content = self.query_check(query, variables={"projectId": project.id, "entryId": entry.id}) + self.assertEqual(content["data"]["project"]["entry"]["reviewCommentsCount"], 2, content) + self.assertEqual(content["data"]["project"]["reviewComments"]["totalCount"], 2, content) + self.assertListIds(content["data"]["project"]["reviewComments"]["results"], [review_comment1, review_comment2], content) + self.assertEqual(content["data"]["project"]["reviewComments"]["results"][1]["text"], review_text1.text, content) # add another review_text for same review_comment review_text2 = EntryReviewCommentTextFactory.create(comment=review_comment1) - content = self.query_check(query, variables={'projectId': project.id, 'entryId': entry.id}) - self.assertEqual(content['data']['project']['entry']['reviewCommentsCount'], 2, content) + content = self.query_check(query, variables={"projectId": project.id, "entryId": entry.id}) + self.assertEqual(content["data"]["project"]["entry"]["reviewCommentsCount"], 2, content) self.assertEqual( - content['data']['project']['reviewComments']['results'][1]['text'], + content["data"]["project"]["reviewComments"]["results"][1]["text"], review_text2.text, # here latest text should be present - content + content, ) # lets check for the contolled in entry - self.assertEqual(content['data']['project']['entry']['controlled'], True, content) - self.assertEqual(content['data']['project']['entry']['controlledChangedBy']['id'], str(user2.id), content) - self.assertEqual(len(content['data']['project']['entry']['verifiedBy']), 2, content) + self.assertEqual(content["data"]["project"]["entry"]["controlled"], True, content) + self.assertEqual(content["data"]["project"]["entry"]["controlledChangedBy"]["id"], str(user2.id), content) + self.assertEqual(len(content["data"]["project"]["entry"]["verifiedBy"]), 2, content) # lets query for another entry - content = self.query_check(query, variables={'projectId': project.id, 'entryId': entry1.id}) - self.assertEqual(content['data']['project']['entry']['reviewCommentsCount'], 1, content) - self.assertEqual(content['data']['project']['reviewComments']['totalCount'], 1, content) + content = self.query_check(query, variables={"projectId": project.id, "entryId": entry1.id}) + self.assertEqual(content["data"]["project"]["entry"]["reviewCommentsCount"], 1, content) + self.assertEqual(content["data"]["project"]["reviewComments"]["totalCount"], 1, content) def test_review_comments_project_scope_query(self): """ Include permission check """ - query = ''' + query = """ query MyQuery ($projectId: ID! $reviewId: ID!) { project(id: $projectId) { reviewComment(id: $reviewId) { @@ -143,7 +137,7 @@ def test_review_comments_project_scope_query(self): } } } - ''' + """ user = UserFactory.create() analysis_framework = AnalysisFrameworkFactory.create() @@ -159,7 +153,7 @@ def test_review_comments_project_scope_query(self): review_text_conf1, review_text_conf2 = EntryReviewCommentTextFactory.create_batch(2, comment=conf_review_comment) def _query_check(review_comment, **kwargs): - return self.query_check(query, variables={'projectId': project.id, 'reviewId': review_comment.id}, **kwargs) + return self.query_check(query, variables={"projectId": project.id, "reviewId": review_comment.id}, **kwargs) # Without login _query_check(review_comment, assert_for_error=True) @@ -168,29 +162,23 @@ def _query_check(review_comment, **kwargs): self.force_login(user) # -- Without membership content = _query_check(review_comment) - self.assertEqual(content['data']['project']['reviewComment'], None, content) + self.assertEqual(content["data"]["project"]["reviewComment"], None, content) # -- Without membership (confidential only) current_membership = project.add_member(user, role=self.project_role_reader_non_confidential) content = _query_check(review_comment) - self.assertNotEqual(content['data']['project']['reviewComment'], None, content) - self.assertEqual(len(content['data']['project']['reviewComment']['textHistory']), 2, content) - self.assertListIds( - content['data']['project']['reviewComment']['textHistory'], - [review_text1, review_text2], - content - ) + self.assertNotEqual(content["data"]["project"]["reviewComment"], None, content) + self.assertEqual(len(content["data"]["project"]["reviewComment"]["textHistory"]), 2, content) + self.assertListIds(content["data"]["project"]["reviewComment"]["textHistory"], [review_text1, review_text2], content) content = _query_check(conf_review_comment) - self.assertEqual(content['data']['project']['reviewComment'], None, content) + self.assertEqual(content["data"]["project"]["reviewComment"], None, content) # -- With membership (non-confidential only) current_membership.delete() project.add_member(user, role=self.project_role_reader) content = _query_check(review_comment) - self.assertNotEqual(content['data']['project']['reviewComment'], None, content) + self.assertNotEqual(content["data"]["project"]["reviewComment"], None, content) content = _query_check(conf_review_comment) - self.assertNotEqual(content['data']['project']['reviewComment'], None, content) - self.assertEqual(len(content['data']['project']['reviewComment']['textHistory']), 2, content) + self.assertNotEqual(content["data"]["project"]["reviewComment"], None, content) + self.assertEqual(len(content["data"]["project"]["reviewComment"]["textHistory"]), 2, content) self.assertListIds( - content['data']['project']['reviewComment']['textHistory'], - [review_text_conf1, review_text_conf2], - content + content["data"]["project"]["reviewComment"]["textHistory"], [review_text_conf1, review_text_conf2], content ) diff --git a/apps/quality_assurance/views.py b/apps/quality_assurance/views.py index 50a70ff241..2bae983957 100644 --- a/apps/quality_assurance/views.py +++ b/apps/quality_assurance/views.py @@ -1,30 +1,20 @@ -from rest_framework import ( - mixins, - permissions, - response, - viewsets, -) import django_filters +from entry.models import Entry +from rest_framework import mixins, permissions, response, viewsets from deep.paginations import SmallSizeSetPagination -from deep.permissions import ModifyPermission, IsProjectMember -from entry.models import Entry +from deep.permissions import IsProjectMember, ModifyPermission -from .serializers import ( - EntryReviewCommentSerializer, - VerifiedBySerializer, -) -from .models import ( - EntryReviewComment, -) +from .models import EntryReviewComment +from .serializers import EntryReviewCommentSerializer, VerifiedBySerializer class EntryReviewCommentViewSet( - mixins.CreateModelMixin, - mixins.RetrieveModelMixin, - mixins.ListModelMixin, - mixins.UpdateModelMixin, - viewsets.GenericViewSet, + mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.ListModelMixin, + mixins.UpdateModelMixin, + viewsets.GenericViewSet, ): serializer_class = EntryReviewCommentSerializer permission_classes = [permissions.IsAuthenticated, ModifyPermission, IsProjectMember] @@ -32,22 +22,24 @@ class EntryReviewCommentViewSet( pagination_class = SmallSizeSetPagination def get_queryset(self): - return EntryReviewComment.get_for(self.request.user).filter(entry=self.kwargs['entry_id']) + return EntryReviewComment.get_for(self.request.user).filter(entry=self.kwargs["entry_id"]) def get_serializer_context(self): return { **super().get_serializer_context(), - 'entry_id': self.kwargs.get('entry_id'), + "entry_id": self.kwargs.get("entry_id"), } def get_paginated_response(self, data): - entry = Entry.objects.get(pk=self.kwargs['entry_id']) + entry = Entry.objects.get(pk=self.kwargs["entry_id"]) summary_data = { - 'verified_by': VerifiedBySerializer(entry.verified_by.all(), many=True).data, - 'controlled': entry.controlled, - 'controlled_changed_by': VerifiedBySerializer(entry.controlled_changed_by).data, + "verified_by": VerifiedBySerializer(entry.verified_by.all(), many=True).data, + "controlled": entry.controlled, + "controlled_changed_by": VerifiedBySerializer(entry.controlled_changed_by).data, } - return response.Response({ - **super().get_paginated_response(data).data, - 'summary': summary_data, - }) + return response.Response( + { + **super().get_paginated_response(data).data, + "summary": summary_data, + } + ) diff --git a/apps/questionnaire/admin.py b/apps/questionnaire/admin.py index 7dfef788fb..3898b03dde 100644 --- a/apps/questionnaire/admin.py +++ b/apps/questionnaire/admin.py @@ -1,8 +1,6 @@ from django.contrib import admin -from .models import ( - CrisisType, -) +from .models import CrisisType @admin.register(CrisisType) diff --git a/apps/questionnaire/apps.py b/apps/questionnaire/apps.py index 53e842ea00..ea09dd8a44 100644 --- a/apps/questionnaire/apps.py +++ b/apps/questionnaire/apps.py @@ -2,4 +2,4 @@ class QuestionnaireConfig(AppConfig): - name = 'questionnaire' + name = "questionnaire" diff --git a/apps/questionnaire/filter_set.py b/apps/questionnaire/filter_set.py index 6d2ca9b638..655f0c8f1e 100644 --- a/apps/questionnaire/filter_set.py +++ b/apps/questionnaire/filter_set.py @@ -1,21 +1,18 @@ import django_filters -from .models import ( - Questionnaire, - QuestionBase, -) +from .models import QuestionBase, Questionnaire class QuestionnaireFilterSet(django_filters.rest_framework.FilterSet): data_collection_techniques = django_filters.MultipleChoiceFilter( choices=QuestionBase.DATA_COLLECTION_TECHNIQUE_OPTIONS, widget=django_filters.widgets.CSVWidget, - method='filter_data_collection_techniques', + method="filter_data_collection_techniques", ) class Meta: model = Questionnaire - fields = '__all__' + fields = "__all__" def filter_data_collection_techniques(self, queryset, name, value): if len(value): diff --git a/apps/questionnaire/models.py b/apps/questionnaire/models.py index 84650de1ca..e723715a71 100644 --- a/apps/questionnaire/models.py +++ b/apps/questionnaire/models.py @@ -1,13 +1,11 @@ +from analysis_framework.models import AnalysisFramework +from django.contrib.postgres.fields import ArrayField, HStoreField from django.db import models from django.db.models import JSONField -from django.contrib.postgres.fields import ArrayField, HStoreField -from django.utils.hashable import make_hashable from django.utils.encoding import force_str - +from django.utils.hashable import make_hashable from ordered_model.models import OrderedModel - from project.models import Project -from analysis_framework.models import AnalysisFramework from user_resource.models import UserResource @@ -21,28 +19,28 @@ def __str__(self): class QuestionBase(OrderedModel): # https://xlsform.org/en/#question-types - TYPE_INTEGER = 'integer' - TYPE_DECIMAL = 'decimal' - TYPE_TEXT = 'text' - TYPE_RANGE = 'range' + TYPE_INTEGER = "integer" + TYPE_DECIMAL = "decimal" + TYPE_TEXT = "text" + TYPE_RANGE = "range" - TYPE_SELECT_ONE = 'select_one' - TYPE_SELECT_MULTIPLE = 'select_multiple' - TYPE_RANK = 'rank' + TYPE_SELECT_ONE = "select_one" + TYPE_SELECT_MULTIPLE = "select_multiple" + TYPE_RANK = "rank" - TYPE_GEOPOINT = 'geopoint' - TYPE_GEOTRACE = 'geotrace' - TYPE_GEOSHAPE = 'geoshape' + TYPE_GEOPOINT = "geopoint" + TYPE_GEOTRACE = "geotrace" + TYPE_GEOSHAPE = "geoshape" - TYPE_DATE = 'date' - TYPE_TIME = 'time' - TYPE_DATETIME = 'dateTime' + TYPE_DATE = "date" + TYPE_TIME = "time" + TYPE_DATETIME = "dateTime" - TYPE_FILE = 'file' - TYPE_IMAGE = 'image' - TYPE_AUDIO = 'audio' - TYPE_VIDEO = 'video' - TYPE_BARCODE = 'barcode' + TYPE_FILE = "file" + TYPE_IMAGE = "image" + TYPE_AUDIO = "audio" + TYPE_VIDEO = "video" + TYPE_BARCODE = "barcode" # TYPE_CALCULATE = 'calculate' # TYPE_NOTE = 'note' @@ -50,71 +48,66 @@ class QuestionBase(OrderedModel): # TYPE_HIDDEN = 'hidden' TYPE_OPTIONS = ( - (TYPE_TEXT, 'Text'), - (TYPE_INTEGER, 'Integer'), - (TYPE_DECIMAL, 'Decimal'), - - (TYPE_DATE, 'Date'), - (TYPE_TIME, 'Time'), - (TYPE_DATETIME, 'Date and time'), - - (TYPE_SELECT_ONE, 'Select one'), - (TYPE_SELECT_MULTIPLE, 'Select multiple'), - (TYPE_RANK, 'Rank'), - - (TYPE_GEOPOINT, 'Geopoint'), - (TYPE_GEOTRACE, 'Geotrace'), - (TYPE_GEOSHAPE, 'Geoshape'), - - (TYPE_IMAGE, 'Image'), - (TYPE_AUDIO, 'Audio'), - (TYPE_VIDEO, 'Video'), - (TYPE_FILE, 'Generic File'), - (TYPE_BARCODE, 'Barcode'), - (TYPE_RANGE, 'Range'), - + (TYPE_TEXT, "Text"), + (TYPE_INTEGER, "Integer"), + (TYPE_DECIMAL, "Decimal"), + (TYPE_DATE, "Date"), + (TYPE_TIME, "Time"), + (TYPE_DATETIME, "Date and time"), + (TYPE_SELECT_ONE, "Select one"), + (TYPE_SELECT_MULTIPLE, "Select multiple"), + (TYPE_RANK, "Rank"), + (TYPE_GEOPOINT, "Geopoint"), + (TYPE_GEOTRACE, "Geotrace"), + (TYPE_GEOSHAPE, "Geoshape"), + (TYPE_IMAGE, "Image"), + (TYPE_AUDIO, "Audio"), + (TYPE_VIDEO, "Video"), + (TYPE_FILE, "Generic File"), + (TYPE_BARCODE, "Barcode"), + (TYPE_RANGE, "Range"), # (TYPE_CALCULATE, 'Calculate'), # (TYPE_NOTE, 'Note'), # (TYPE_ACKNOWLEDGE, 'Acknowledge'), # (TYPE_HIDDEN, 'Hidden'), ) - IMPORTANCE_1 = '1' - IMPORTANCE_2 = '2' - IMPORTANCE_3 = '3' - IMPORTANCE_4 = '4' - IMPORTANCE_5 = '5' + IMPORTANCE_1 = "1" + IMPORTANCE_2 = "2" + IMPORTANCE_3 = "3" + IMPORTANCE_4 = "4" + IMPORTANCE_5 = "5" IMPORTANCE_OPTIONS = ( - (IMPORTANCE_1, '1'), - (IMPORTANCE_2, '2'), - (IMPORTANCE_3, '3'), - (IMPORTANCE_4, '4'), - (IMPORTANCE_5, '5'), + (IMPORTANCE_1, "1"), + (IMPORTANCE_2, "2"), + (IMPORTANCE_3, "3"), + (IMPORTANCE_4, "4"), + (IMPORTANCE_5, "5"), ) # Data collection technique choices - DIRECT = 'direct' - FOCUS_GROUP = 'focus_group' - ONE_ON_ONE_INTERVIEW = 'one_on_one_interviews' - OPEN_ENDED_SURVEY = 'open_ended_survey' - CLOSED_ENDED_SURVEY = 'closed_ended_survey' + DIRECT = "direct" + FOCUS_GROUP = "focus_group" + ONE_ON_ONE_INTERVIEW = "one_on_one_interviews" + OPEN_ENDED_SURVEY = "open_ended_survey" + CLOSED_ENDED_SURVEY = "closed_ended_survey" DATA_COLLECTION_TECHNIQUE_OPTIONS = ( - (DIRECT, 'Direct observation'), - (FOCUS_GROUP, 'Focus group'), - (ONE_ON_ONE_INTERVIEW, '1-on-1 interviews'), - (OPEN_ENDED_SURVEY, 'Open-ended survey'), - (CLOSED_ENDED_SURVEY, 'Closed-ended survey'), + (DIRECT, "Direct observation"), + (FOCUS_GROUP, "Focus group"), + (ONE_ON_ONE_INTERVIEW, "1-on-1 interviews"), + (OPEN_ENDED_SURVEY, "Open-ended survey"), + (CLOSED_ENDED_SURVEY, "Closed-ended survey"), ) # Enumerator skill choices - BASIC = 'basic' - MEDIUM = 'medium' + BASIC = "basic" + MEDIUM = "medium" ENUMERATOR_SKILL_OPTIONS = ( - (BASIC, 'Basic'), - (MEDIUM, 'Medium'), + (BASIC, "Basic"), + (MEDIUM, "Medium"), ) name = models.CharField(max_length=255) @@ -149,8 +142,7 @@ class Questionnaire(UserResource): models.CharField(max_length=56, choices=QuestionBase.DATA_COLLECTION_TECHNIQUE_OPTIONS), default=list, ) - enumerator_skill = models.CharField( - max_length=56, blank=True, choices=QuestionBase.ENUMERATOR_SKILL_OPTIONS) + enumerator_skill = models.CharField(max_length=56, blank=True, choices=QuestionBase.ENUMERATOR_SKILL_OPTIONS) # required duration in seconds required_duration = models.PositiveIntegerField(blank=True, null=True) @@ -171,9 +163,10 @@ def can_modify(self, user): class FrameworkQuestion(QuestionBase): analysis_framework = models.ForeignKey( - AnalysisFramework, on_delete=models.CASCADE, + AnalysisFramework, + on_delete=models.CASCADE, ) - order_with_respect_to = 'analysis_framework' + order_with_respect_to = "analysis_framework" def can_modify(self, user): return self.analysis_framework.can_modify(user) @@ -186,10 +179,10 @@ class Question(QuestionBase): analysis_framework = models.ForeignKey(AnalysisFramework, on_delete=models.SET_NULL, null=True) cloned_from = models.ForeignKey(FrameworkQuestion, on_delete=models.SET_NULL, null=True) questionnaire = models.ForeignKey(Questionnaire, on_delete=models.CASCADE) - order_with_respect_to = 'questionnaire' + order_with_respect_to = "questionnaire" class Meta: - unique_together = ('questionnaire', 'name') + unique_together = ("questionnaire", "name") def can_modify(self, user): return self.questionnaire.project.can_modify(user) diff --git a/apps/questionnaire/serializers.py b/apps/questionnaire/serializers.py index e9af6a38f7..0bfa1fbca1 100644 --- a/apps/questionnaire/serializers.py +++ b/apps/questionnaire/serializers.py @@ -1,20 +1,14 @@ import re -from django.shortcuts import get_object_or_404 + from django.db import models -from rest_framework import serializers, exceptions +from django.shortcuts import get_object_or_404 from drf_dynamic_fields import DynamicFieldsMixin +from rest_framework import exceptions, serializers from user_resource.serializers import UserResourceSerializer -from deep.serializers import ( - RemoveNullFieldsMixin, -) +from deep.serializers import RemoveNullFieldsMixin -from .models import ( - Questionnaire, - Question, - FrameworkQuestion, - CrisisType, -) +from .models import CrisisType, FrameworkQuestion, Question, Questionnaire class CrisisTypeSerializer( @@ -22,45 +16,47 @@ class CrisisTypeSerializer( RemoveNullFieldsMixin, ): class Meta: - fields = '__all__' + fields = "__all__" model = CrisisType class QuestionBaseSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): - enumerator_skill_display = serializers.CharField(source='get_enumerator_skill_display', read_only=True) - data_collection_technique_display = serializers.CharField( - source='get_data_collection_technique_display', read_only=True) - importance_display = serializers.CharField(source='get_importance_display', read_only=True) - crisis_type_detail = CrisisTypeSerializer(source='crisis_type', read_only=True) + enumerator_skill_display = serializers.CharField(source="get_enumerator_skill_display", read_only=True) + data_collection_technique_display = serializers.CharField(source="get_data_collection_technique_display", read_only=True) + importance_display = serializers.CharField(source="get_importance_display", read_only=True) + crisis_type_detail = CrisisTypeSerializer(source="crisis_type", read_only=True) @staticmethod def apply_order_action(question, action_meta, default_action=None): - action = action_meta.get('action') or default_action - value = action_meta.get('value') + action = action_meta.get("action") or default_action + value = action_meta.get("value") # NOTE: These methods (top, bottom, below and above) are provided by django-ordered-model - if action in ['top', 'bottom']: + if action in ["top", "bottom"]: getattr(question, action)() - elif action in ['below', 'above']: + elif action in ["below", "above"]: if value is None: - raise exceptions.ValidationError('Value is required for below|above actions') + raise exceptions.ValidationError("Value is required for below|above actions") related_question = get_object_or_404(question._meta.model, pk=value) getattr(question, action)(related_question) def validate_name(self, value): if re.match("^[a-zA-Z_][A-Za-z0-9._-]*$", value) is None: raise exceptions.ValidationError( - 'Names have to start with a letter or an underscore' - ' and can only contain letters, digits, hyphens, underscores, and periods' + "Names have to start with a letter or an underscore" + " and can only contain letters, digits, hyphens, underscores, and periods" ) return value def validate(self, data): - data['questionnaire_id'] = int(self.context['questionnaire_id']) + data["questionnaire_id"] = int(self.context["questionnaire_id"]) # The order will not work properly if order is same for multiple questions within questionnare if self.instance is None: - data['order'] = self.Meta.model.objects.filter(questionnaire=data['questionnaire_id']).aggregate( - order=models.functions.Coalesce(models.Max('order'), 0) - )['order'] + 1 + data["order"] = ( + self.Meta.model.objects.filter(questionnaire=data["questionnaire_id"]).aggregate( + order=models.functions.Coalesce(models.Max("order"), 0) + )["order"] + + 1 + ) return data def create(self, data): @@ -68,8 +64,8 @@ def create(self, data): # For handling order actions QuestionSerializer.apply_order_action( question, - self.initial_data.get('order_action') or {}, - 'bottom', + self.initial_data.get("order_action") or {}, + "bottom", ) return question @@ -78,28 +74,28 @@ class QuestionSerializer(QuestionBaseSerializer): def validate(self, data): data = super().validate(data) # Duplicate questions - d_qs = Question.objects.filter(questionnaire=data['questionnaire_id'], name=data['name']) + d_qs = Question.objects.filter(questionnaire=data["questionnaire_id"], name=data["name"]) if self.instance is not None: d_qs = d_qs.exclude(pk=self.instance.pk) if d_qs.exists(): - raise exceptions.ValidationError('Name should be unique') + raise exceptions.ValidationError("Name should be unique") return data class Meta: model = Question - fields = '__all__' - read_only_fields = ('questionnaire',) + fields = "__all__" + read_only_fields = ("questionnaire",) class FrameworkQuestionSerializer(QuestionBaseSerializer): def validate(self, data): - data['analysis_framework_id'] = int(self.context['af_id']) + data["analysis_framework_id"] = int(self.context["af_id"]) return data class Meta: model = FrameworkQuestion - fields = ('__all__') - read_only_fields = ('analysis_framework',) + fields = "__all__" + read_only_fields = ("analysis_framework",) class MiniQuestionnaireSerializer( @@ -107,15 +103,14 @@ class MiniQuestionnaireSerializer( DynamicFieldsMixin, UserResourceSerializer, ): - enumerator_skill_display = serializers.CharField(source='get_enumerator_skill_display', read_only=True) - data_collection_techniques_display = serializers.ListField( - source='get_data_collection_techniques_display', read_only=True) - crisis_types_detail = CrisisTypeSerializer(source='crisis_types', many=True, read_only=True) + enumerator_skill_display = serializers.CharField(source="get_enumerator_skill_display", read_only=True) + data_collection_techniques_display = serializers.ListField(source="get_data_collection_techniques_display", read_only=True) + crisis_types_detail = CrisisTypeSerializer(source="crisis_types", many=True, read_only=True) active_questions_count = serializers.IntegerField(read_only=True) class Meta: model = Questionnaire - fields = '__all__' + fields = "__all__" def create(self, validated_data): questionnaire = super().create(validated_data) @@ -123,7 +118,7 @@ def create(self, validated_data): class QuestionnaireSerializer(MiniQuestionnaireSerializer): - questions = QuestionSerializer(source='question_set', many=True, required=False) + questions = QuestionSerializer(source="question_set", many=True, required=False) class XFormSerializer(serializers.Serializer): diff --git a/apps/questionnaire/tests.py b/apps/questionnaire/tests.py index fbbc20c2e4..d1b4dd2ed5 100644 --- a/apps/questionnaire/tests.py +++ b/apps/questionnaire/tests.py @@ -1,13 +1,14 @@ -from deep.tests import TestCase from analysis_framework.models import AnalysisFramework from questionnaire.models import ( + CrisisType, + FrameworkQuestion, + Question, QuestionBase, Questionnaire, - Question, - FrameworkQuestion, - CrisisType, ) +from deep.tests import TestCase + # TODO: This tests will fail with --reuse-db. Make sure HStoreExtension is loaded for --reuse-db # This might be helpfull https://pytest-django.readthedocs.io/en/latest/configuring_django.html @@ -19,9 +20,9 @@ def test_questionnaire_get_api(self): self.create(Questionnaire, project=project) self.authenticate() - response = self.client.get(f'/api/v1/questionnaires/?project={project.pk}') + response = self.client.get(f"/api/v1/questionnaires/?project={project.pk}") self.assert_200(response) - assert len(response.json()['results']) == 3 + assert len(response.json()["results"]) == 3 # Custom filter test response = self.client.get( @@ -29,42 +30,45 @@ def test_questionnaire_get_api(self): f"&data_collection_techniques={','.join([QuestionBase.DIRECT, QuestionBase.FOCUS_GROUP])}" ) self.assert_200(response) - assert len(response.json()['results']) == 2 + assert len(response.json()["results"]) == 2 def test_questionnaire_post_api(self): project = self.create_project() - title = 'Test Questionnaire' + title = "Test Questionnaire" self.authenticate() - response = self.client.post('/api/v1/questionnaires/', data={ - 'title': title, - 'project': project.pk, - }) + response = self.client.post( + "/api/v1/questionnaires/", + data={ + "title": title, + "project": project.pk, + }, + ) self.assert_201(response) - created_questionnaire = Questionnaire.objects.get(pk=response.json()['id']) + created_questionnaire = Questionnaire.objects.get(pk=response.json()["id"]) assert created_questionnaire.title == title assert created_questionnaire.project == project def test_questionnaire_options_api(self): - self.create(CrisisType, title='Crisis 1') - self.create(CrisisType, title='Crisis 2') - self.create(CrisisType, title='Crisis 3') + self.create(CrisisType, title="Crisis 1") + self.create(CrisisType, title="Crisis 2") + self.create(CrisisType, title="Crisis 3") self.authenticate() - response = self.client.get('/api/v1/questionnaires/options/') + response = self.client.get("/api/v1/questionnaires/options/") self.assert_200(response) - assert len(response.json()['crisisTypeOptions']) == 3 + assert len(response.json()["crisisTypeOptions"]) == 3 def test_questionnaire_clone_api(self): - questionnaire = self.create(Questionnaire, title='Test Questionnaire', project=self.create_project()) + questionnaire = self.create(Questionnaire, title="Test Questionnaire", project=self.create_project()) self.create(Question, questionnaire=questionnaire) self.create(Question, questionnaire=questionnaire) self.create(Question, questionnaire=questionnaire) self.authenticate() - response = self.client.post(f'/api/v1/questionnaires/{questionnaire.pk}/clone/') + response = self.client.post(f"/api/v1/questionnaires/{questionnaire.pk}/clone/") self.assert_200(response) - cloned_questionnaire = Questionnaire.objects.get(pk=response.json()['id']) + cloned_questionnaire = Questionnaire.objects.get(pk=response.json()["id"]) assert cloned_questionnaire.title == questionnaire.title assert cloned_questionnaire.question_set.count() == questionnaire.question_set.count() @@ -76,93 +80,97 @@ def test_question_api(self): self.create(Question, questionnaire=questionnaire) self.authenticate() - response = self.client.get(f'/api/v1/questionnaires/{questionnaire.pk}/questions/') + response = self.client.get(f"/api/v1/questionnaires/{questionnaire.pk}/questions/") self.assert_200(response) - assert len(response.json()['results']) == 3 + assert len(response.json()["results"]) == 3 def test_question_post_api(self): questionnaire = self.create(Questionnaire, project=self.create_project()) - title = 'Test Question' + title = "Test Question" more_titles = { - 'en': title, - 'np': 'Test Question in Nepali', + "en": title, + "np": "Test Question in Nepali", } self.authenticate() - response = self.client.post(f'/api/v1/questionnaires/{questionnaire.pk}/questions/', data={ - 'title': title, - 'name': 'question-1', - 'more_titles': more_titles, - }) + response = self.client.post( + f"/api/v1/questionnaires/{questionnaire.pk}/questions/", + data={ + "title": title, + "name": "question-1", + "more_titles": more_titles, + }, + ) self.assert_201(response) - new_question = Question.objects.get(pk=response.json()['id']) + new_question = Question.objects.get(pk=response.json()["id"]) assert new_question.title == title assert new_question.more_titles == more_titles - response = self.client.post(f'/api/v1/questionnaires/{questionnaire.pk}/questions/', data={ - 'title': title, - 'name': 'question-1', - 'more_titles': more_titles, - }) + response = self.client.post( + f"/api/v1/questionnaires/{questionnaire.pk}/questions/", + data={ + "title": title, + "name": "question-1", + "more_titles": more_titles, + }, + ) # Duplicate name self.assert_400(response) def test_question_clone_api(self): question = self.create( - Question, title='Test Question', - questionnaire=self.create(Questionnaire, project=self.create_project()) + Question, title="Test Question", questionnaire=self.create(Questionnaire, project=self.create_project()) ) self.authenticate() - response = self.client.post( - f'/api/v1/questionnaires/{question.questionnaire.pk}/questions/{question.pk}/clone/') + response = self.client.post(f"/api/v1/questionnaires/{question.questionnaire.pk}/questions/{question.pk}/clone/") self.assert_200(response) - cloned_question = Question.objects.get(pk=response.json()['id']) + cloned_question = Question.objects.get(pk=response.json()["id"]) assert cloned_question.title == question.title def test_question_bulk_actions_api(self): questionnaire = self.create(Questionnaire, project=self.create_project()) - q1 = self.create(Question, title='Test Question', questionnaire=questionnaire) - q2 = self.create(Question, title='Test Question', questionnaire=questionnaire) - q3 = self.create(Question, title='Test Question', questionnaire=questionnaire) - q4 = self.create(Question, title='Test Question', questionnaire=questionnaire) + q1 = self.create(Question, title="Test Question", questionnaire=questionnaire) + q2 = self.create(Question, title="Test Question", questionnaire=questionnaire) + q3 = self.create(Question, title="Test Question", questionnaire=questionnaire) + q4 = self.create(Question, title="Test Question", questionnaire=questionnaire) def get_bulk_data(questions): - return [{'id': q.pk} for q in questions] + return [{"id": q.pk} for q in questions] self.authenticate() # TODO: Detail test for action, data, state, excepted_state in [ - ( - 'bulk-archive', get_bulk_data([q1, q2, q3]), - lambda: Question.objects.filter(questionnaire=questionnaire, is_archived=True).count(), 3 - ), - ( - 'bulk-unarchive', get_bulk_data([q1, q2, q3, q4]), - lambda: Question.objects.filter(questionnaire=questionnaire, is_archived=False).count(), 4 - ), - ( - 'bulk-delete', get_bulk_data([q1, q2]), - lambda: Question.objects.filter(questionnaire=questionnaire).count(), 2 - ), + ( + "bulk-archive", + get_bulk_data([q1, q2, q3]), + lambda: Question.objects.filter(questionnaire=questionnaire, is_archived=True).count(), + 3, + ), + ( + "bulk-unarchive", + get_bulk_data([q1, q2, q3, q4]), + lambda: Question.objects.filter(questionnaire=questionnaire, is_archived=False).count(), + 4, + ), + ("bulk-delete", get_bulk_data([q1, q2]), lambda: Question.objects.filter(questionnaire=questionnaire).count(), 2), ]: - response = self.client.post(f'/api/v1/questionnaires/{questionnaire.pk}/questions/{action}/', data=data) + response = self.client.post(f"/api/v1/questionnaires/{questionnaire.pk}/questions/{action}/", data=data) self.assert_200(response) - assert state() == excepted_state, f'For {action} {response.json()}' + assert state() == excepted_state, f"For {action} {response.json()}" def test_question_order_api(self): questionnaire = self.create(Questionnaire, project=self.create_project()) - q1 = self.create(Question, title='Test Question', questionnaire=questionnaire, order=1) - q2 = self.create(Question, title='Test Question', questionnaire=questionnaire, order=2) - q3 = self.create(Question, title='Test Question', questionnaire=questionnaire, order=3) - q4 = self.create(Question, title='Test Question', questionnaire=questionnaire, order=4) + q1 = self.create(Question, title="Test Question", questionnaire=questionnaire, order=1) + q2 = self.create(Question, title="Test Question", questionnaire=questionnaire, order=2) + q3 = self.create(Question, title="Test Question", questionnaire=questionnaire, order=3) + q4 = self.create(Question, title="Test Question", questionnaire=questionnaire, order=4) questions = [q1, q2, q3, q4] self.authenticate() response = self.client.post( - f'/api/v1/questionnaires/{questionnaire.pk}/questions/{q3.pk}/order/', - data={'action': 'below', 'value': q1.pk} + f"/api/v1/questionnaires/{questionnaire.pk}/questions/{q3.pk}/order/", data={"action": "below", "value": q1.pk} ) self.assert_200(response) @@ -177,7 +185,7 @@ def test_framework_question_api(self): self.create(FrameworkQuestion, analysis_framework=af) self.authenticate() - response = self.client.get(f'/api/v1/analysis-frameworks/{af.pk}/questions/') + response = self.client.get(f"/api/v1/analysis-frameworks/{af.pk}/questions/") self.assert_200(response) def test_framework_question_post_api(self): @@ -186,16 +194,23 @@ def test_framework_question_post_api(self): q1 = self.create(FrameworkQuestion, analysis_framework=af) self.authenticate() - response = self.client.post(f'/api/v1/analysis-frameworks/{af.pk}/questions/', data={ - 'title': 'Test Framework Questions', - 'name': 'framework-question-1', - }) + response = self.client.post( + f"/api/v1/analysis-frameworks/{af.pk}/questions/", + data={ + "title": "Test Framework Questions", + "name": "framework-question-1", + }, + ) self.assert_201(response) - q2_id = response.json()['id'] + q2_id = response.json()["id"] - response = self.client.post(f'/api/v1/analysis-frameworks/{af.pk}/questions/{q2_id}/order/', data={ - 'action': 'above', 'value': q1.pk, - }) + response = self.client.post( + f"/api/v1/analysis-frameworks/{af.pk}/questions/{q2_id}/order/", + data={ + "action": "above", + "value": q1.pk, + }, + ) self.assert_200(response) def test_framework_question_copy_api(self): @@ -205,24 +220,27 @@ def test_framework_question_copy_api(self): self.create(Question, questionnaire=questionnaire) self.authenticate() - response = self.client.post(f'/api/v1/questionnaires/{questionnaire.pk}/questions/af-question-copy/', data={ - 'framework_question_id': fq.pk, - 'order_action': { - 'action': 'bottom', + response = self.client.post( + f"/api/v1/questionnaires/{questionnaire.pk}/questions/af-question-copy/", + data={ + "framework_question_id": fq.pk, + "order_action": { + "action": "bottom", + }, }, - }) + ) self.assert_200(response) - assert response.json()['questionnaire'] == questionnaire.pk - assert response.json()['order'] == 2 + assert response.json()["questionnaire"] == questionnaire.pk + assert response.json()["order"] == 2 def test_xform_view(self): # Just checking API Endpoint. Requires xform file for test self.authenticate() - response = self.client.post('/api/v1/xlsform-to-xform/') + response = self.client.post("/api/v1/xlsform-to-xform/") self.assert_400(response) def test_kobo_toolbox_export(self): # Just checking API Endpoint. Requires oauth for test self.authenticate() - response = self.client.post('/api/v1/import-to-kobotoolbox/') + response = self.client.post("/api/v1/import-to-kobotoolbox/") self.assert_400(response) diff --git a/apps/questionnaire/utils/kobo_toolbox.py b/apps/questionnaire/utils/kobo_toolbox.py index f107335bc9..07bc15f1de 100644 --- a/apps/questionnaire/utils/kobo_toolbox.py +++ b/apps/questionnaire/utils/kobo_toolbox.py @@ -1,8 +1,9 @@ import base64 + import requests -class KoboToolbox(): +class KoboToolbox: def __init__(self, username=None, password=None): self.username = username self.password = password @@ -10,11 +11,11 @@ def __init__(self, username=None, password=None): @property def auth(self): - params = {'headers': {'Accept': 'application/json'}} + params = {"headers": {"Accept": "application/json"}} if self.access_token: - params['headers']['Authorization'] = f"Bearer {self.access_token}" + params["headers"]["Authorization"] = f"Bearer {self.access_token}" else: - params['auth'] = (self.username, self.password) + params["auth"] = (self.username, self.password) return params def getEncodedFile(self, file): @@ -23,26 +24,30 @@ def getEncodedFile(self, file): return b"base64," + base64.b64encode(file.read()) def export(self, file): - assest = requests.post('https://kf.kobotoolbox.org/api/v2/assets/', data={ - 'name': "Untitled (IMPORTED FROM DEEP)", - 'asset_type': 'survey', - }, **self.auth).json() + assest = requests.post( + "https://kf.kobotoolbox.org/api/v2/assets/", + data={ + "name": "Untitled (IMPORTED FROM DEEP)", + "asset_type": "survey", + }, + **self.auth, + ).json() import_trigger = requests.post( - 'https://kf.kobotoolbox.org/imports/', + "https://kf.kobotoolbox.org/imports/", data={ - 'totalFiles': 1, - 'destination': assest['url'], - 'assetUid': assest['uid'], - 'name': file.name, - 'base64Encoded': self.getEncodedFile(file), + "totalFiles": 1, + "destination": assest["url"], + "assetUid": assest["uid"], + "name": file.name, + "base64Encoded": self.getEncodedFile(file), }, **self.auth, ).json() return { - 'assert_settings': f"https://kf.kobotoolbox.org/#/forms/{assest['uid']}/settings", - 'assert_form': f"https://kf.kobotoolbox.org/#/forms/{assest['uid']}/edit", - 'assert': assest, - 'import': import_trigger, + "assert_settings": f"https://kf.kobotoolbox.org/#/forms/{assest['uid']}/settings", + "assert_form": f"https://kf.kobotoolbox.org/#/forms/{assest['uid']}/edit", + "assert": assest, + "import": import_trigger, } diff --git a/apps/questionnaire/utils/xls_form.py b/apps/questionnaire/utils/xls_form.py index 6a017ad13a..315ba2971c 100644 --- a/apps/questionnaire/utils/xls_form.py +++ b/apps/questionnaire/utils/xls_form.py @@ -1,14 +1,14 @@ import os from tempfile import NamedTemporaryFile -from pyxform import create_survey_from_xls from lxml import etree as ET +from pyxform import create_survey_from_xls class XLSForm: @classmethod def create_xform(cls, xlsx_file): - with NamedTemporaryFile(suffix='.xlsx') as tmp: + with NamedTemporaryFile(suffix=".xlsx") as tmp: tmp.write(xlsx_file.read()) tmp.seek(0) survey = create_survey_from_xls(tmp) @@ -19,8 +19,8 @@ def create_xform(cls, xlsx_file): def create_enketo_form(cls, xlsx_file): tree = ET.fromstring(cls.create_xform(xlsx_file)) - form_xslt = ET.parse(os.path.join(os.path.dirname(__file__), 'openrosa2html5form.xsl')) - model_xslt = ET.parse(os.path.join(os.path.dirname(__file__), 'openrosa2xmlmodel.xsl')) + form_xslt = ET.parse(os.path.join(os.path.dirname(__file__), "openrosa2html5form.xsl")) + model_xslt = ET.parse(os.path.join(os.path.dirname(__file__), "openrosa2xmlmodel.xsl")) form_transform = ET.XSLT(form_xslt) model_transform = ET.XSLT(model_xslt) @@ -29,6 +29,6 @@ def create_enketo_form(cls, xlsx_file): model = model_transform(tree) return { - 'form': ET.tostring(form.getroot()[0]).decode(), - 'model': ET.tostring(model.getroot()[0]).decode(), + "form": ET.tostring(form.getroot()[0]).decode(), + "model": ET.tostring(model.getroot()[0]).decode(), } diff --git a/apps/questionnaire/views.py b/apps/questionnaire/views.py index 090434c746..d65a98473a 100644 --- a/apps/questionnaire/views.py +++ b/apps/questionnaire/views.py @@ -1,38 +1,24 @@ -import django_filters import logging + +import django_filters from django.db import models +from rest_framework import exceptions, permissions, response, views, viewsets from rest_framework.decorators import action -from rest_framework import ( - views, - viewsets, - response, - permissions, - exceptions, -) from deep.permissions import ModifyPermission -from .utils import xls_form, kobo_toolbox - -from .models import ( - QuestionBase, - FrameworkQuestion, - Questionnaire, - Question, - CrisisType, -) - +from .filter_set import QuestionnaireFilterSet +from .models import CrisisType, FrameworkQuestion, Question, QuestionBase, Questionnaire from .serializers import ( CrisisTypeSerializer, + FrameworkQuestionSerializer, + KoboToolboxExportSerializer, MiniQuestionnaireSerializer, QuestionnaireSerializer, QuestionSerializer, - FrameworkQuestionSerializer, XFormSerializer, - KoboToolboxExportSerializer, ) - -from .filter_set import QuestionnaireFilterSet +from .utils import kobo_toolbox, xls_form logger = logging.getLogger(__name__) @@ -47,46 +33,41 @@ class QuestionnaireViewSet(viewsets.ModelViewSet): def get_queryset(self): return Questionnaire.objects.annotate( - active_questions_count=models.Count( - 'question', filter=models.Q(question__is_archived=False), distinct=True - ) - ).prefetch_related('crisis_types') + active_questions_count=models.Count("question", filter=models.Q(question__is_archived=False), distinct=True) + ).prefetch_related("crisis_types") def get_serializer_context(self): context = super().get_serializer_context() - if 'pk' in self.kwargs: - context['questionnaire_id'] = self.kwargs['pk'] + if "pk" in self.kwargs: + context["questionnaire_id"] = self.kwargs["pk"] return context def get_serializer_class(self): - if self.action == 'list': + if self.action == "list": return MiniQuestionnaireSerializer return super().get_serializer_class() @action( detail=False, - url_path='options', + url_path="options", ) def get_options(self, request, version=None): options = { - field: [ - {'key': key, 'value': value} - for key, value in values - ] + field: [{"key": key, "value": value} for key, value in values] for field, values in ( - ('enumerator_skill_options', QuestionBase.ENUMERATOR_SKILL_OPTIONS), - ('data_collection_technique_options', QuestionBase.DATA_COLLECTION_TECHNIQUE_OPTIONS), - ('question_importance_options', QuestionBase.IMPORTANCE_OPTIONS), - ('question_type_options', QuestionBase.TYPE_OPTIONS), + ("enumerator_skill_options", QuestionBase.ENUMERATOR_SKILL_OPTIONS), + ("data_collection_technique_options", QuestionBase.DATA_COLLECTION_TECHNIQUE_OPTIONS), + ("question_importance_options", QuestionBase.IMPORTANCE_OPTIONS), + ("question_type_options", QuestionBase.TYPE_OPTIONS), ) } - options['crisis_type_options'] = CrisisTypeSerializer( + options["crisis_type_options"] = CrisisTypeSerializer( CrisisType.objects.all(), many=True, ).data return response.Response(options) - @action(detail=True, methods=['post'], url_path='clone') + @action(detail=True, methods=["post"], url_path="clone") def create_clone(self, request, *args, **kwargs): """ Clone questionnaire (also questions) @@ -105,17 +86,14 @@ def create_clone(self, request, *args, **kwargs): # Override fields value if supplied [ setattr(obj, field, value) - for field in [ - 'title', 'project_id', 'required_duration', - 'data_collection_technique', 'enumerator_skill' - ] + for field in ["title", "project_id", "required_duration", "data_collection_technique", "enumerator_skill"] for value in [request.data.get(field)] if value is not None ] obj.save() # Override crisis types - override_crisis_types_id = request.data.get('crisis_types_id') + override_crisis_types_id = request.data.get("crisis_types_id") if override_crisis_types_id is not None: old_crisis_types = CrisisType.objects.filter(pk__in=override_crisis_types_id) obj.crisis_types.set(old_crisis_types, clear=True) @@ -128,18 +106,18 @@ def create_clone(self, request, *args, **kwargs): return response.Response(self.get_serializer_class()(obj).data) -class QuestionBaseViewMixin(): - @action(detail=False, methods=['post'], url_path='bulk-delete') +class QuestionBaseViewMixin: + @action(detail=False, methods=["post"], url_path="bulk-delete") def bulk_delete(self, *args, **kwargs): """{"id": number}""" return self.bulk_action() - @action(detail=False, methods=['post'], url_path='bulk-archive') + @action(detail=False, methods=["post"], url_path="bulk-archive") def bulk_archive(self, *args, **kwargs): """{"id": number}""" return self.bulk_action() - @action(detail=False, methods=['post'], url_path='bulk-unarchive') + @action(detail=False, methods=["post"], url_path="bulk-unarchive") def bulk_unarchive(self, *args, **kwargs): """{"id": number}""" return self.bulk_action() @@ -147,31 +125,31 @@ def bulk_unarchive(self, *args, **kwargs): def bulk_action(self): # TODO: Permission try: - question_body = {q['id']: q for q in self.request.data} + question_body = {q["id"]: q for q in self.request.data} except (TypeError, KeyError): - raise exceptions.ValidationError('Invalid request. Check and try again!!') + raise exceptions.ValidationError("Invalid request. Check and try again!!") questions = self.get_queryset().filter(id__in=question_body.keys()) - response_body = list(questions.values_list('id', flat=True)) + response_body = list(questions.values_list("id", flat=True)) - if self.action == 'bulk_delete': + if self.action == "bulk_delete": questions.all().delete() - elif self.action == 'bulk_archive': + elif self.action == "bulk_archive": questions.update(is_archived=True) - elif self.action == 'bulk_unarchive': + elif self.action == "bulk_unarchive": questions.update(is_archived=False) - elif self.action == 'bulk_order': + elif self.action == "bulk_order": # TODO: Use bulk update after django upgrade updated_questions = [] for question in questions.all(): - question.order = question_body.get(question.id).get('order') + question.order = question_body.get(question.id).get("order") if question.order is None: continue question.save() - updated_questions.append({'id': question.pk, 'new_order': question.order}) + updated_questions.append({"id": question.pk, "new_order": question.order}) response_body = updated_questions return response.Response(response_body) - @action(detail=True, methods=['post'], url_path='order') + @action(detail=True, methods=["post"], url_path="order") def order(self, request, *args, **kwargs): """ ```json @@ -185,14 +163,16 @@ def order(self, request, *args, **kwargs): question = self.get_object() QuestionSerializer.apply_order_action(question, request.data) if isinstance(question, FrameworkQuestion): - return response.Response({ - 'new_order': question.analysis_framework.question_set.order_by('order').values('pk', 'order') - }) - return response.Response({ - 'new_order': question.questionnaire.question_set.order_by('order').values('pk', 'order'), - }) - - @action(detail=True, methods=['post'], url_path='clone') + return response.Response( + {"new_order": question.analysis_framework.question_set.order_by("order").values("pk", "order")} + ) + return response.Response( + { + "new_order": question.questionnaire.question_set.order_by("order").values("pk", "order"), + } + ) + + @action(detail=True, methods=["post"], url_path="clone") def create_clone(self, request, *args, **kwargs): """ TODO: Remove this @@ -203,10 +183,10 @@ def create_clone(self, request, *args, **kwargs): """ obj = self.get_object() obj.pk = None - obj.name += 'prefix' + obj.name += "prefix" obj.order = None obj.save() - QuestionSerializer.apply_order_action(obj, request.data.get('order_action', {}), 'bottom') + QuestionSerializer.apply_order_action(obj, request.data.get("order_action", {}), "bottom") return response.Response(self.get_serializer_class()(obj).data) @@ -216,15 +196,15 @@ class QuestionViewSet(QuestionBaseViewMixin, viewsets.ModelViewSet): permission_classes = (permissions.IsAuthenticated, ModifyPermission) def get_queryset(self): - return Question.objects.filter(questionnaire=self.kwargs['questionnaire_id']).all() + return Question.objects.filter(questionnaire=self.kwargs["questionnaire_id"]).all() def get_serializer_context(self): return { **super().get_serializer_context(), - 'questionnaire_id': self.kwargs.get('questionnaire_id'), + "questionnaire_id": self.kwargs.get("questionnaire_id"), } - @action(detail=False, methods=['post'], url_path=r'af-question-copy') + @action(detail=False, methods=["post"], url_path=r"af-question-copy") def copy_from_af_question(self, request, *args, **kwargs): """ Copy from framework question to Questionnaire question @@ -233,10 +213,10 @@ def copy_from_af_question(self, request, *args, **kwargs): ``` """ try: - fq = FrameworkQuestion.objects.get(id=request.data['framework_question_id']) - questionnaire = Questionnaire.objects.get(id=self.kwargs['questionnaire_id']) + fq = FrameworkQuestion.objects.get(id=request.data["framework_question_id"]) + questionnaire = Questionnaire.objects.get(id=self.kwargs["questionnaire_id"]) except (TypeError, KeyError): - raise exceptions.ValidationError('Invalid request. Check and try again!!') + raise exceptions.ValidationError("Invalid request. Check and try again!!") if not (fq.can_get(request.user) and questionnaire.can_modify(request.user)): return exceptions.PermissionDenied() @@ -246,12 +226,9 @@ def copy_from_af_question(self, request, *args, **kwargs): cloned_from=fq, questionnaire=questionnaire, order=None, - **{ - field.name: getattr(fq, field.name) - for field in fq._meta.fields if field.name not in ['id', 'order'] - }, + **{field.name: getattr(fq, field.name) for field in fq._meta.fields if field.name not in ["id", "order"]}, ) - QuestionSerializer.apply_order_action(new_question, request.data.get('order_action', {}), 'bottom') + QuestionSerializer.apply_order_action(new_question, request.data.get("order_action", {}), "bottom") return response.Response(self.get_serializer_class()(new_question).data) @@ -263,11 +240,11 @@ class FrameworkQuestionViewSet(QuestionBaseViewMixin, viewsets.ModelViewSet): def get_serializer_context(self): return { **super().get_serializer_context(), - 'af_id': self.kwargs.get('af_id'), + "af_id": self.kwargs.get("af_id"), } def get_queryset(self): - return FrameworkQuestion.objects.filter(analysis_framework=self.kwargs['af_id']).all() + return FrameworkQuestion.objects.filter(analysis_framework=self.kwargs["af_id"]).all() class XFormView(views.APIView): @@ -277,12 +254,12 @@ def get_serializer(self, *args, **kwargs): def post(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) - xlsform_file = serializer.validated_data['file'] + xlsform_file = serializer.validated_data["file"] try: return response.Response(xls_form.XLSForm.create_enketo_form(xlsform_file)) except Exception: - logger.error('Failed to create enketo form', exc_info=True) - raise exceptions.ValidationError('Failed to create enketo form') + logger.error("Failed to create enketo form", exc_info=True) + raise exceptions.ValidationError("Failed to create enketo form") class KoboToolboxExport(views.APIView): @@ -292,12 +269,11 @@ def get_serializer(self, *args, **kwargs): def post(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) - xlsform_file = serializer.validated_data['file'] + xlsform_file = serializer.validated_data["file"] req_vd = serializer.validated_data - kt = kobo_toolbox.KoboToolbox(username=req_vd['username'], password=req_vd['password']) + kt = kobo_toolbox.KoboToolbox(username=req_vd["username"], password=req_vd["password"]) try: return response.Response(kt.export(xlsform_file)) except Exception: - raise exceptions.ValidationError( - 'Invalid request. Please provide valid XLSForm file and valid access token!!') + raise exceptions.ValidationError("Invalid request. Please provide valid XLSForm file and valid access token!!") diff --git a/apps/redis_store/__init__.py b/apps/redis_store/__init__.py index 4564ca6754..66ad67950e 100644 --- a/apps/redis_store/__init__.py +++ b/apps/redis_store/__init__.py @@ -1 +1 @@ -default_app_config = 'redis_store.apps.RedisStoreConfig' +default_app_config = "redis_store.apps.RedisStoreConfig" diff --git a/apps/redis_store/apps.py b/apps/redis_store/apps.py index 204206ce00..8dc6520bba 100644 --- a/apps/redis_store/apps.py +++ b/apps/redis_store/apps.py @@ -3,7 +3,7 @@ class RedisStoreConfig(AppConfig): - name = 'redis_store' + name = "redis_store" def ready(self): redis.init() diff --git a/apps/redis_store/redis.py b/apps/redis_store/redis.py index 5134e806f6..37c93f2bca 100644 --- a/apps/redis_store/redis.py +++ b/apps/redis_store/redis.py @@ -2,16 +2,15 @@ from deep.celery import app as celery_app - """ Redis connection pool """ pool = None SSL_REQ_MAP = { - 'CERT_NONE': 'none', - 'CERT_OPTIONAL': 'optional', - 'CERT_REQUIRED': 'required', + "CERT_NONE": "none", + "CERT_OPTIONAL": "optional", + "CERT_REQUIRED": "required", } @@ -25,7 +24,7 @@ def init(): kconn = celery_app.connection() url = kconn.as_uri(include_password=True) ssl = kconn.ssl - if ssl is not False and 'ssl_cert_reqs' in ssl: + if ssl is not False and "ssl_cert_reqs" in ssl: url += f"?ssl_cert_reqs:{SSL_REQ_MAP.get(ssl['ssl_cert_reqs'].name, 'optional')}" pool = redis.ConnectionPool.from_url(url=url) diff --git a/apps/redis_store/tests/test_redis.py b/apps/redis_store/tests/test_redis.py index 0fc0ef9b8d..65133050d3 100644 --- a/apps/redis_store/tests/test_redis.py +++ b/apps/redis_store/tests/test_redis.py @@ -12,7 +12,7 @@ def test_set_and_del(self): Test redis connection by writing, reading and deleting a value """ r = redis.get_connection() - r.set('foo', 'bar') - self.assertEqual(r.get('foo'), b'bar') - r.delete('foo') - self.assertEqual(r.get('foo'), None) + r.set("foo", "bar") + self.assertEqual(r.get("foo"), b"bar") + r.delete("foo") + self.assertEqual(r.get("foo"), None) diff --git a/apps/tabular/__init__.py b/apps/tabular/__init__.py index f43e28b96a..1e297efe35 100644 --- a/apps/tabular/__init__.py +++ b/apps/tabular/__init__.py @@ -1 +1 @@ -default_app_config = 'tabular.apps.TabularConfig' +default_app_config = "tabular.apps.TabularConfig" diff --git a/apps/tabular/admin.py b/apps/tabular/admin.py index 41e4a2ee61..d947fe27f0 100644 --- a/apps/tabular/admin.py +++ b/apps/tabular/admin.py @@ -1,11 +1,10 @@ -from django.contrib import admin +from django.contrib import admin, messages from django.utils.safestring import mark_safe -from django.contrib import messages from deep.admin import VersionAdmin -from .models import Book, Sheet, Field, Geodata from .filters import CacheStatusListFilter +from .models import Book, Field, Geodata, Sheet from .tasks import tabular_generate_columns_image @@ -23,41 +22,48 @@ class GeodataInline(admin.StackedInline): @admin.register(Book) class BookAdmin(VersionAdmin): - search_fields = ('title',) + search_fields = ("title",) inlines = [SheetInline] - autocomplete_fields = ('project', 'created_by', 'modified_by', 'file',) + autocomplete_fields = ( + "project", + "created_by", + "modified_by", + "file", + ) @admin.register(Sheet) class SheetAdmin(VersionAdmin): - search_fields = ('title',) + search_fields = ("title",) inlines = [FieldInline] - autocomplete_fields = ('book',) + autocomplete_fields = ("book",) def trigger_cache_reset(modeladmin, request, queryset): messages.add_message( - request, messages.INFO, + request, + messages.INFO, mark_safe( - 'Successfully triggerd fields:

' + '
'.join( - '* {} : {}'.format(value[0], value[1]) - for value in queryset.values_list('id', 'title').distinct() - ) - ) - ) - tabular_generate_columns_image.delay( - list(queryset.values_list('id', flat=True).distinct()) + "Successfully triggerd fields:

" + + "
".join("* {} : {}".format(value[0], value[1]) for value in queryset.values_list("id", "title").distinct()) + ), ) + tabular_generate_columns_image.delay(list(queryset.values_list("id", flat=True).distinct())) -trigger_cache_reset.short_description = 'Trigger cache reset for selected Fields' +trigger_cache_reset.short_description = "Trigger cache reset for selected Fields" @admin.register(Field) class FieldAdmin(VersionAdmin): inlines = [GeodataInline] - list_display = ('pk', 'title', 'sheet', 'type',) - list_filter = ('type', CacheStatusListFilter) - search_fields = ['title'] + list_display = ( + "pk", + "title", + "sheet", + "type", + ) + list_filter = ("type", CacheStatusListFilter) + search_fields = ["title"] actions = [trigger_cache_reset] - autocomplete_fields = ('sheet',) + autocomplete_fields = ("sheet",) diff --git a/apps/tabular/apps.py b/apps/tabular/apps.py index 9589ece9c9..5f3d39afef 100644 --- a/apps/tabular/apps.py +++ b/apps/tabular/apps.py @@ -2,7 +2,7 @@ class TabularConfig(AppConfig): - name = 'tabular' + name = "tabular" def ready(self): import tabular.receivers # noqa diff --git a/apps/tabular/extractor/csv.py b/apps/tabular/extractor/csv.py index 39a0a3dfcf..e9daa4ed1e 100644 --- a/apps/tabular/extractor/csv.py +++ b/apps/tabular/extractor/csv.py @@ -1,10 +1,11 @@ -import io import csv +import io from itertools import chain -from ..models import Sheet, Field from utils.common import LogTime +from ..models import Field, Sheet + @LogTime() def extract(book): @@ -16,13 +17,13 @@ def extract(book): book=book, ) reader = csv.reader( - io.StringIO(csv_file.read().decode('utf-8')), - delimiter=options.get('delimiter', ','), - quotechar=options.get('quotechar', '"'), + io.StringIO(csv_file.read().decode("utf-8")), + delimiter=options.get("delimiter", ","), + quotechar=options.get("quotechar", '"'), skipinitialspace=True, ) - no_headers = options.get('no_headers', False) + no_headers = options.get("no_headers", False) data_index = 0 if no_headers else 1 fields = [] @@ -33,8 +34,7 @@ def extract(book): for header in first_row: fields.append( Field( - title=(header if not no_headers - else 'Column ' + str(ordering)), + title=(header if not no_headers else "Column " + str(ordering)), sheet=sheet, ordering=ordering, ) @@ -50,18 +50,20 @@ def extract(book): try: for index, field in enumerate(fields): field_data = fields_data.get(field.id, []) - field_data.append({ - 'value': _row[index], - 'invalid': False, - 'empty': False, - }) + field_data.append( + { + "value": _row[index], + "invalid": False, + "empty": False, + } + ) fields_data[field.id] = field_data except Exception: pass for field in sheet.field_set.all(): field.data = fields_data.get(field.id, []) - block_name = 'Field Save csv extract {}'.format(field.title) + block_name = "Field Save csv extract {}".format(field.title) with LogTime(block_name=block_name): field.save() diff --git a/apps/tabular/extractor/ods.py b/apps/tabular/extractor/ods.py index c5f6f2c8f9..3089bfc86e 100644 --- a/apps/tabular/extractor/ods.py +++ b/apps/tabular/extractor/ods.py @@ -1,11 +1,12 @@ -import pyexcel_ods -from datetime import datetime import logging +from datetime import datetime -from ..models import Sheet, Field +import pyexcel_ods from utils.common import LogTime +from ..models import Field, Sheet + logger = logging.getLogger(__name__) date_type = type(datetime.now().date()) @@ -19,15 +20,15 @@ def extract(book): workbook = pyexcel_ods.get_data(ods_file) for sheet_key in workbook: wb_sheet = workbook[sheet_key] - sheet_options = options.get('sheets', {}).get(str(sheet_key), {}) - if sheet_options.get('skip', False): + sheet_options = options.get("sheets", {}).get(str(sheet_key), {}) + if sheet_options.get("skip", False): continue sheet = Sheet.objects.create( title=sheet_key, book=book, ) - header_index = sheet_options.get('header_row', 1) - 1 - no_headers = sheet_options.get('no_headers', False) + header_index = sheet_options.get("header_row", 1) - 1 + no_headers = sheet_options.get("no_headers", False) data_index = header_index + 1 if no_headers: @@ -41,15 +42,10 @@ def extract(book): for value in header_row: fields.append( Field( - title=(value if not no_headers - else 'Column ' + str(ordering)), + title=(value if not no_headers else "Column " + str(ordering)), sheet=sheet, ordering=ordering, - data=[{ - 'value': value, - 'empty': False, - 'invalid': False - }] + data=[{"value": value, "empty": False, "invalid": False}], ) ) ordering += 1 @@ -64,11 +60,7 @@ def extract(book): value = _row[index] if isinstance(value, (datetime, date_type)): value = _row[index].isoformat() - field_data.append({ - 'value': value, - 'empty': False, - 'invalid': False - }) + field_data.append({"value": value, "empty": False, "invalid": False}) fields_data[field.id] = field_data except Exception: pass @@ -76,7 +68,7 @@ def extract(book): # Save field for field in sheet.field_set.all(): field.data.extend(fields_data.get(field.id, [])) - block_name = 'Field Save ods extract {}'.format(field.title) + block_name = "Field Save ods extract {}".format(field.title) with LogTime(block_name=block_name): field.save() diff --git a/apps/tabular/extractor/xls.py b/apps/tabular/extractor/xls.py index 832f68be92..386608db68 100644 --- a/apps/tabular/extractor/xls.py +++ b/apps/tabular/extractor/xls.py @@ -1,10 +1,11 @@ -import random import os +import random import re import string -from django.conf import settings from subprocess import call +from django.conf import settings + from utils.common import LogTime from .xlsx import extract as xlsx_extract @@ -12,26 +13,28 @@ @LogTime() def extract(book): - tmp_filepath = '/tmp/{}'.format( - ''.join(random.sample(string.ascii_lowercase, 10)) + '.xls' - ) + tmp_filepath = "/tmp/{}".format("".join(random.sample(string.ascii_lowercase, 10)) + ".xls") - with open(tmp_filepath, 'wb') as tmpxls: + with open(tmp_filepath, "wb") as tmpxls: tmpxls.write(book.file.file.read()) tmpxls.flush() - call([ - 'libreoffice', '--headless', '--convert-to', 'xlsx', - tmp_filepath, '--outdir', settings.TEMP_DIR, - ]) - - xlsx_filename = os.path.join( - settings.TEMP_DIR, - re.sub(r'xls$', 'xlsx', os.path.basename(tmp_filepath)) + call( + [ + "libreoffice", + "--headless", + "--convert-to", + "xlsx", + tmp_filepath, + "--outdir", + settings.TEMP_DIR, + ] ) + xlsx_filename = os.path.join(settings.TEMP_DIR, re.sub(r"xls$", "xlsx", os.path.basename(tmp_filepath))) + response = xlsx_extract(book, filename=xlsx_filename) # Clean up converted xlsx file - call(['rm', '-f', xlsx_filename, tmp_filepath]) + call(["rm", "-f", xlsx_filename, tmp_filepath]) return response diff --git a/apps/tabular/extractor/xlsx.py b/apps/tabular/extractor/xlsx.py index 0e56ee61af..6ca44645ef 100644 --- a/apps/tabular/extractor/xlsx.py +++ b/apps/tabular/extractor/xlsx.py @@ -1,14 +1,11 @@ import logging +from datetime import datetime + from openpyxl import load_workbook -from ..models import Sheet, Field -from datetime import datetime +from utils.common import LogTime, excel_to_python_date_format, format_date_or_iso -from utils.common import ( - LogTime, - excel_to_python_date_format, - format_date_or_iso, -) +from ..models import Field, Sheet logger = logging.getLogger(__name__) @@ -34,12 +31,8 @@ def get_excel_value(cell): if value is not None and isinstance(value, datetime): dateformat = cell.number_format # try casting to python format - python_format = excel_to_python_date_format( - dateformat - ) - return format_date_or_iso( - cell.value, python_format - ) + python_format = excel_to_python_date_format(dateformat) + return format_date_or_iso(cell.value, python_format) elif value is not None and not isinstance(value, str): return str(cell.internal_value) return str(value) @@ -49,11 +42,11 @@ def get_excel_value(cell): def extract(book, filename=None): options = book.options if book.options else {} Sheet.objects.filter(book=book).delete() # Delete all previous sheets - with open(filename, 'rb') if filename else book.get_file() as xlsx_file: + with open(filename, "rb") if filename else book.get_file() as xlsx_file: workbook = load_workbook(xlsx_file, data_only=True, read_only=True) for sheet_key, wb_sheet in enumerate(workbook.worksheets): - sheet_options = options.get('sheets', {}).get(str(sheet_key), {}) - if sheet_options.get('skip', False): + sheet_options = options.get("sheets", {}).get(str(sheet_key), {}) + if sheet_options.get("skip", False): continue sheet = Sheet.objects.create( title=wb_sheet.title, @@ -62,7 +55,7 @@ def extract(book, filename=None): sheet_rows = [] - no_headers = sheet_options.get('no_headers', False) + no_headers = sheet_options.get("no_headers", False) max_col_length = 1 for row in wb_sheet.iter_rows(): @@ -78,23 +71,16 @@ def extract(book, filename=None): return if no_headers: - fields = [ - Field(title=f'Column {x}', sheet=sheet, ordering=x, data=[]) - for x in range(max_col_length) - ] + fields = [Field(title=f"Column {x}", sheet=sheet, ordering=x, data=[]) for x in range(max_col_length)] else: fields = [] for x in range(max_col_length): row_len = len(sheet_rows[0]) - title_val = sheet_rows[0][x]['value'] if row_len > x else None - title = title_val or f'Column {x}' + title_val = sheet_rows[0][x]["value"] if row_len > x else None + title = title_val or f"Column {x}" fields.append(Field(title=title, sheet=sheet, ordering=x, data=[])) - empty_value = { - 'value': None, - 'invalid': False, - 'empty': True - } + empty_value = {"value": None, "invalid": False, "empty": True} # Now append data to fields for row in sheet_rows: row_len = len(row) @@ -120,11 +106,7 @@ def get_row_data(row): if cell.value is not None: max_data_col = curr_col value = get_excel_value(cell) - data.append({ - 'value': value, - 'empty': value is None, - 'invalid': False - }) + data.append({"value": value, "empty": value is None, "invalid": False}) curr_col += 1 # Now clip the data beyond which there is nothing - return data[:max_data_col + 1] + return data[: max_data_col + 1] diff --git a/apps/tabular/filters.py b/apps/tabular/filters.py index 42b0ca51a0..e8e99ed75a 100644 --- a/apps/tabular/filters.py +++ b/apps/tabular/filters.py @@ -7,10 +7,10 @@ class CacheStatusListFilter(admin.SimpleListFilter): # Human-readable title which will be displayed in the # right admin sidebar just above the filter options. - title = _('Cache Status') + title = _("Cache Status") # Parameter for the filter that will be used in the URL query. - parameter_name = 'is_cache_status' + parameter_name = "is_cache_status" def lookups(self, request, model_admin): return Field.CACHE_STATUS_TYPES diff --git a/apps/tabular/models.py b/apps/tabular/models.py index b50a1a625a..ea866a14e0 100644 --- a/apps/tabular/models.py +++ b/apps/tabular/models.py @@ -1,39 +1,39 @@ import time from django.db import models, transaction -from user_resource.models import UserResource from gallery.models import File from project.models import Project -from utils.common import get_file_from_url - from tabular.utils import get_cast_function +from user_resource.models import UserResource + +from utils.common import get_file_from_url class Book(UserResource): # STATUS TYPES - INITIAL = 'initial' - PENDING = 'pending' - SUCCESS = 'success' - FAILED = 'failed' + INITIAL = "initial" + PENDING = "pending" + SUCCESS = "success" + FAILED = "failed" STATUS_TYPES = ( - (INITIAL, 'Initial (Book Just Added)'), - (PENDING, 'Pending'), - (SUCCESS, 'Success'), - (FAILED, 'Failed'), + (INITIAL, "Initial (Book Just Added)"), + (PENDING, "Pending"), + (SUCCESS, "Success"), + (FAILED, "Failed"), ) # FILE TYPES - CSV = 'csv' - XLSX = 'xlsx' - XLS = 'xls' - ODS = 'ods' + CSV = "csv" + XLSX = "xlsx" + XLS = "xls" + ODS = "ods" FILE_TYPES = ( - (CSV, 'CSV'), - (XLSX, 'XLSX'), - (XLS, 'XLS'), - (ODS, 'ODS'), + (CSV, "CSV"), + (XLSX, "XLSX"), + (XLS, "XLS"), + (ODS, "ODS"), ) META_REQUIRED_FILE_TYPES = [XLSX, XLS] @@ -43,16 +43,22 @@ class Book(UserResource): FILE_TYPE_ERROR = 101 ERROR_TYPES = ( - (UNKNOWN_ERROR, 'Unknown error'), - (FILE_TYPE_ERROR, 'File type error'), + (UNKNOWN_ERROR, "Unknown error"), + (FILE_TYPE_ERROR, "File type error"), ) title = models.CharField(max_length=255) file = models.OneToOneField( - File, null=True, blank=True, on_delete=models.SET_NULL, + File, + null=True, + blank=True, + on_delete=models.SET_NULL, ) project = models.ForeignKey( - Project, null=True, default=None, on_delete=models.CASCADE, + Project, + null=True, + default=None, + on_delete=models.CASCADE, ) url = models.TextField(null=True, blank=True) status = models.CharField( @@ -63,7 +69,8 @@ class Book(UserResource): error = models.CharField( max_length=30, choices=ERROR_TYPES, - blank=True, null=True, + blank=True, + null=True, ) file_type = models.CharField( max_length=30, @@ -79,23 +86,32 @@ def get_file(self): return get_file_from_url(self.url) def get_pending_fields_id(self): - return Field.objects.filter( - sheet__book=self, - cache__status=Field.CACHE_PENDING, - ).distinct().values_list('id', flat=True) + return ( + Field.objects.filter( + sheet__book=self, + cache__status=Field.CACHE_PENDING, + ) + .distinct() + .values_list("id", flat=True) + ) def get_status(self): - return Field.objects.filter( - sheet__book=self, - cache__status=Field.CACHE_PENDING, - ).count() == 0 + return ( + Field.objects.filter( + sheet__book=self, + cache__status=Field.CACHE_PENDING, + ).count() + == 0 + ) def get_processed_fields(self, fields=[]): """ Return success cached fields """ return Field.objects.filter( - sheet__book=self, cache__status=Field.CACHE_SUCCESS, id__in=fields, + sheet__book=self, + cache__status=Field.CACHE_SUCCESS, + id__in=fields, ).distinct() def __str__(self): @@ -121,53 +137,50 @@ def save(self, *args, **kwargs): # Re-Trigger column generation if data_row_index changed if self.data_row_index != self.current_data_row_index: - from tabular.tasks import tabular_generate_columns_image # to prevent circular import + from tabular.tasks import ( + tabular_generate_columns_image, # to prevent circular import + ) + # First set cache pending to all fields for field in self.field_set.all(): - field.cache['status'] = Field.CACHE_PENDING - field.cache['time'] = time.time() + field.cache["status"] = Field.CACHE_PENDING + field.cache["time"] = time.time() # Update the field title if self.data_row_index > 0: - field.title = str(field.data[self.data_row_index - 1]['value']) + field.title = str(field.data[self.data_row_index - 1]["value"]) field.save() - field_ids = self.field_set.values_list('id', flat=True) - transaction.on_commit( - lambda: tabular_generate_columns_image.delay(list(field_ids)) - ) + field_ids = self.field_set.values_list("id", flat=True) + transaction.on_commit(lambda: tabular_generate_columns_image.delay(list(field_ids))) # Update current_options value self.current_data_row_index = self.data_row_index class Field(models.Model): - CACHE_PENDING = 'pending' - CACHE_SUCCESS = 'success' - CACHE_ERROR = 'error' + CACHE_PENDING = "pending" + CACHE_SUCCESS = "success" + CACHE_ERROR = "error" - NUMBER = 'number' - STRING = 'string' - DATETIME = 'datetime' - GEO = 'geo' + NUMBER = "number" + STRING = "string" + DATETIME = "datetime" + GEO = "geo" CACHE_STATUS_TYPES = ( - (CACHE_PENDING, 'Pending'), - (CACHE_SUCCESS, 'Success'), - (CACHE_ERROR, 'Error'), + (CACHE_PENDING, "Pending"), + (CACHE_SUCCESS, "Success"), + (CACHE_ERROR, "Error"), ) FIELD_TYPES = ( - (NUMBER, 'Number'), - (STRING, 'String'), - (DATETIME, 'Datetime'), - (GEO, 'Geo'), + (NUMBER, "Number"), + (STRING, "String"), + (DATETIME, "Datetime"), + (GEO, "Geo"), ) title = models.CharField(max_length=255) sheet = models.ForeignKey(Sheet, on_delete=models.CASCADE) - type = models.CharField( - max_length=30, - choices=FIELD_TYPES, - default=STRING - ) + type = models.CharField(max_length=30, choices=FIELD_TYPES, default=STRING) hidden = models.BooleanField(default=False) options = models.JSONField(default=None, blank=True, null=True) cache = models.JSONField(default=dict, blank=True, null=True) @@ -185,7 +198,7 @@ def __init__(self, *args, **kwargs): self.current_options = self.options def __str__(self): - return '{}:{}:{} '.format(self.pk, self.title, self.type) + return "{}:{}:{} ".format(self.pk, self.title, self.type) def cast_data(self, geos_names={}, geos_codes={}): """ @@ -202,39 +215,34 @@ def cast_data(self, geos_names={}, geos_codes={}): # Now iterate through every item to find empty/invalid values for i, value in enumerate(values): - val = value['value'] + val = value["value"] - value.pop('invalid', None) - value.pop('empty', None) - value.pop('processed_value', None) + value.pop("invalid", None) + value.pop("empty", None) + value.pop("processed_value", None) - if val is None or val == '': - value['empty'] = True + if val is None or val == "": + value["empty"] = True continue casted = cast_func(val, **self.options) if casted is None: - value['invalid'] = True + value["invalid"] = True elif type == Field.GEO: - value['processed_value'] = casted['id'] - regions[casted['region']] = casted['region_title'] + value["processed_value"] = casted["id"] + regions[casted["region"]] = casted["region_title"] elif type == Field.NUMBER: - value['processed_value'] = casted[0] # (number, separator) + value["processed_value"] = casted[0] # (number, separator) elif type == Field.DATETIME: - value['processed_value'] = casted.isoformat() # (parsed_date) + value["processed_value"] = casted.isoformat() # (parsed_date) if type == Field.GEO and regions: - options['regions'] = [ - {'id': k, 'title': v} for k, v in regions.items() - ] + options["regions"] = [{"id": k, "title": v} for k, v in regions.items()] - return { - 'values': values, - 'options': options - } + return {"values": values, "options": options} def save(self, *args, **kwargs): - if hasattr(self, 'geodata'): + if hasattr(self, "geodata"): self.geodata.delete() super().save(*args, **kwargs) self.current_type = self.type @@ -245,27 +253,23 @@ def get_option(self, key, default_value=None): return options.get(key, default_value) class Meta: - ordering = ['ordering'] + ordering = ["ordering"] class Geodata(models.Model): # STATUS TYPES - PENDING = 'pending' - SUCCESS = 'success' - FAILED = 'failed' + PENDING = "pending" + SUCCESS = "success" + FAILED = "failed" STATUS_TYPES = ( - (PENDING, 'Pending'), - (SUCCESS, 'Success'), - (FAILED, 'Failed'), + (PENDING, "Pending"), + (SUCCESS, "Success"), + (FAILED, "Failed"), ) data = models.JSONField(default=None, blank=True, null=True) - field = models.OneToOneField( - Field, - on_delete=models.CASCADE, - related_name='geodata' - ) + field = models.OneToOneField(Field, on_delete=models.CASCADE, related_name="geodata") status = models.CharField( max_length=30, choices=STATUS_TYPES, @@ -273,4 +277,4 @@ class Geodata(models.Model): ) def __str__(self): - return '{} (Geodata)'.format(self.field.title) + return "{} (Geodata)".format(self.field.title) diff --git a/apps/tabular/receivers.py b/apps/tabular/receivers.py index be039cc3b6..b9a5d91b82 100644 --- a/apps/tabular/receivers.py +++ b/apps/tabular/receivers.py @@ -1,8 +1,7 @@ from django.db import models from django.dispatch import receiver - from tabular.models import Field -from tabular.utils import get_geos_dict, get_geos_codes_from_geos_names +from tabular.utils import get_geos_codes_from_geos_names, get_geos_dict @receiver(models.signals.pre_save, sender=Field) @@ -11,12 +10,11 @@ def on_field_saved(sender, **kwargs): The purpose of this receiver is to update the row value types in tabular sheet model whenever field type changes """ - field = kwargs.get('instance') + field = kwargs.get("instance") if field is None or not field.id: return - if field.type == field.current_type and \ - field.options == field.current_options: + if field.type == field.current_type and field.options == field.current_options: return geos_names = geos_codes = {} @@ -25,8 +23,8 @@ def on_field_saved(sender, **kwargs): geos_codes = get_geos_codes_from_geos_names(geos_names) cast_info = field.cast_data(geos_names, geos_codes) - field.data = cast_info['values'] + field.data = cast_info["values"] - field.options = cast_info['options'] + field.options = cast_info["options"] # But don't save here, will cause recursion # field.save() diff --git a/apps/tabular/serializers.py b/apps/tabular/serializers.py index b3b8f40315..fa539e8b39 100644 --- a/apps/tabular/serializers.py +++ b/apps/tabular/serializers.py @@ -1,21 +1,16 @@ import time + from django.db import transaction -from rest_framework import serializers from drf_dynamic_fields import DynamicFieldsMixin -from drf_writable_nested.serializers import ( - NestedCreateMixin, - NestedUpdateMixin, -) - -from deep.serializers import RemoveNullFieldsMixin - +from drf_writable_nested.serializers import NestedCreateMixin, NestedUpdateMixin +from entry.models import Entry +from geo.serializers import AdminLevel, Region, SimpleRegionSerializer +from rest_framework import serializers from user_resource.serializers import UserResourceSerializer -from geo.serializers import SimpleRegionSerializer, Region, AdminLevel - -from entry.models import Entry +from deep.serializers import RemoveNullFieldsMixin -from .models import Book, Sheet, Field, Geodata +from .models import Book, Field, Geodata, Sheet from .tasks import tabular_generate_column_image @@ -25,50 +20,43 @@ class GeodataSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): class Meta: model = Geodata - exclude = ('field',) + exclude = ("field",) def get_regions(self, geodata): if not geodata.data: return [] - area_ids = [d['selected_id'] for d in geodata.data] - regions = Region.objects.filter( - adminlevel__geoarea__id__in=area_ids - ).distinct() + area_ids = [d["selected_id"] for d in geodata.data] + regions = Region.objects.filter(adminlevel__geoarea__id__in=area_ids).distinct() return SimpleRegionSerializer(regions, many=True).data def get_admin_levels(self, geodata): if not geodata.data: return [] - area_ids = [d['selected_id'] for d in geodata.data] - admin_levels = AdminLevel.objects.filter( - geoarea__id__in=area_ids - ).distinct() - return admin_levels.values_list('id', flat=True) + area_ids = [d["selected_id"] for d in geodata.data] + admin_levels = AdminLevel.objects.filter(geoarea__id__in=area_ids).distinct() + return admin_levels.values_list("id", flat=True) -class FieldSerializer( - RemoveNullFieldsMixin, - DynamicFieldsMixin, - serializers.ModelSerializer -): +class FieldSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): geodata = serializers.SerializerMethodField() class Meta: model = Field - ref_name = 'TabularFieldSerializer' - exclude = ('sheet', 'cache',) + ref_name = "TabularFieldSerializer" + exclude = ( + "sheet", + "cache", + ) def get_geodata(self, obj): - if obj.type == Field.GEO and hasattr(obj, 'geodata'): + if obj.type == Field.GEO and hasattr(obj, "geodata"): return GeodataSerializer(obj.geodata).data return None def update(self, instance, validated_data): - validated_data['cache'] = {'status': Field.CACHE_PENDING, 'time': time.time()} + validated_data["cache"] = {"status": Field.CACHE_PENDING, "time": time.time()} instance = super().update(instance, validated_data) - transaction.on_commit( - lambda: tabular_generate_column_image.delay(instance.id) - ) + transaction.on_commit(lambda: tabular_generate_column_image.delay(instance.id)) return instance @@ -77,51 +65,57 @@ class FieldMetaSerializer(FieldSerializer): class Meta: model = Field - exclude = ('sheet', 'data', 'cache',) + exclude = ( + "sheet", + "data", + "cache", + ) class FieldProcessedOnlySerializer(FieldSerializer): class Meta: model = Field - exclude = ('data',) + exclude = ("data",) class SheetSerializer( - RemoveNullFieldsMixin, - DynamicFieldsMixin, - NestedCreateMixin, - NestedUpdateMixin, + RemoveNullFieldsMixin, + DynamicFieldsMixin, + NestedCreateMixin, + NestedUpdateMixin, ): - fields = FieldSerializer(many=True, source='field_set', required=False) + fields = FieldSerializer(many=True, source="field_set", required=False) class Meta: model = Sheet - exclude = ('book',) + exclude = ("book",) class SheetMetaSerializer(SheetSerializer): - fields = FieldMetaSerializer(many=True, source='field_set', required=False) + fields = FieldMetaSerializer(many=True, source="field_set", required=False) class SheetProcessedOnlySerializer(SheetSerializer): fields = FieldProcessedOnlySerializer( - many=True, source='field_set', required=False, + many=True, + source="field_set", + required=False, ) class BookSerializer( - RemoveNullFieldsMixin, - DynamicFieldsMixin, - UserResourceSerializer, - NestedCreateMixin, - NestedUpdateMixin, + RemoveNullFieldsMixin, + DynamicFieldsMixin, + UserResourceSerializer, + NestedCreateMixin, + NestedUpdateMixin, ): - sheets = SheetSerializer(many=True, source='sheet_set', required=False) + sheets = SheetSerializer(many=True, source="sheet_set", required=False) entry_count = serializers.SerializerMethodField() class Meta: model = Book - fields = '__all__' + fields = "__all__" def get_entry_count(self, instance): return Entry.objects.filter( @@ -130,12 +124,14 @@ def get_entry_count(self, instance): class BookMetaSerializer(BookSerializer): - sheets = SheetMetaSerializer(many=True, source='sheet_set', required=False) + sheets = SheetMetaSerializer(many=True, source="sheet_set", required=False) class BookProcessedOnlySerializer(BookSerializer): sheets = SheetProcessedOnlySerializer( - many=True, source='sheet_set', required=False, + many=True, + source="sheet_set", + required=False, ) pending_fields = serializers.SerializerMethodField() diff --git a/apps/tabular/tasks.py b/apps/tabular/tasks.py index 465696807d..cd06aabdd3 100644 --- a/apps/tabular/tasks.py +++ b/apps/tabular/tasks.py @@ -1,26 +1,22 @@ -import time import logging +import time from celery import shared_task -from redis_store import redis -from django.db import transaction from django.contrib.postgres import search +from django.db import transaction +from geo.models import GeoArea, models +from redis_store import redis -from geo.models import models, GeoArea - -from utils.common import redis_lock, LogTime +from utils.common import LogTime, redis_lock -from .models import Book, Geodata, Field -from .extractor import csv, xls, xlsx, ods -from .viz.renderer import ( - calc_preprocessed_data, - render_field_chart, -) +from .extractor import csv, ods, xls, xlsx +from .models import Book, Field, Geodata from .utils import ( + get_geos_codes_from_geos_names, get_geos_dict, sample_and_detect_type_and_options, - get_geos_codes_from_geos_names, ) +from .viz.renderer import calc_preprocessed_data, render_field_chart logger = logging.getLogger(__name__) @@ -49,7 +45,7 @@ def auto_detect_and_update_fields(book): geos_codes = get_geos_codes_from_geos_names(geos_names) def isValueNotEmpty(v): - return v.get('value') + return v.get("value") generate_column_columns = [] @@ -62,20 +58,18 @@ def isValueNotEmpty(v): for field in fields: data = field.data[row_index:] emptyFiltered = list(filter(isValueNotEmpty, data)) - detected_info = sample_and_detect_type_and_options( - emptyFiltered, geos_names, geos_codes - ) - field.type = detected_info['type'] - field.options = detected_info['options'] + detected_info = sample_and_detect_type_and_options(emptyFiltered, geos_names, geos_codes) + field.type = detected_info["type"] + field.options = detected_info["options"] cast_info = field.cast_data(geos_names, geos_codes) - field.data = cast_info['values'] - field.options = cast_info['options'] + field.data = cast_info["values"] + field.options = cast_info["options"] field.cache = { - 'status': Field.CACHE_PENDING, - 'image_status': Field.CACHE_PENDING, - 'time': time.time(), + "status": Field.CACHE_PENDING, + "image_status": Field.CACHE_PENDING, + "time": time.time(), } field.save() @@ -88,17 +82,13 @@ def isValueNotEmpty(v): def _tabular_meta_extract_geo(geodata): field = geodata.field project = field.sheet.book.project - project_geoareas = GeoArea.objects.filter( - admin_level__region__project=project - ) + project_geoareas = GeoArea.objects.filter(admin_level__region__project=project) geodata_data = [] - is_code = field.get_option('geo_type', 'name') == 'code' - admin_level = field.get_option('admin_level') + is_code = field.get_option("geo_type", "name") == "code" + admin_level = field.get_option("admin_level") if admin_level: - project_geoareas = project_geoareas.filter( - admin_level__level=admin_level - ) + project_geoareas = project_geoareas.filter(admin_level__level=admin_level) for row in geodata.field.sheet.data: similar_areas = [] @@ -109,30 +99,38 @@ def _tabular_meta_extract_geo(geodata): similarity=models.Value(1, models.FloatField()), ) else: - geoareas = project_geoareas.annotate( - similarity=search.TrigramSimilarity('title', query), - ).filter(similarity__gt=0.2).order_by('-similarity') + geoareas = ( + project_geoareas.annotate( + similarity=search.TrigramSimilarity("title", query), + ) + .filter(similarity__gt=0.2) + .order_by("-similarity") + ) for geoarea in geoareas: - similar_areas.append({ - 'id': geoarea.pk, - 'similarity': geoarea.similarity, - }) - geodata_data.append({ - 'similar_areas': similar_areas, - 'selected_id': geoareas.first().pk if geoareas.exists() else None, - }) + similar_areas.append( + { + "id": geoarea.pk, + "similarity": geoarea.similarity, + } + ) + geodata_data.append( + { + "similar_areas": similar_areas, + "selected_id": geoareas.first().pk if geoareas.exists() else None, + } + ) geodata.data = geodata_data geodata.save() return True @shared_task -@redis_lock('tabular_generate_column_image__{0}') +@redis_lock("tabular_generate_column_image__{0}") def tabular_generate_column_image(field_id): field = Field.objects.filter(pk=field_id).first() if field is None: - logger.warning('Field ({}) doesn\'t exists'.format(field_id)) + logger.warning("Field ({}) doesn't exists".format(field_id)) calc_preprocessed_data(field) return render_field_chart(field) @@ -149,7 +147,7 @@ def tabular_generate_columns_image(fields_id): @shared_task @LogTime() def tabular_extract_book(book_pk): - key = 'tabular_extract_book_{}'.format(book_pk) + key = "tabular_extract_book_{}".format(book_pk) lock = redis.get_lock(key, 60 * 60 * 24) # Lock lifetime 24 hours have_lock = lock.acquire(blocking=False) if not have_lock: @@ -161,7 +159,7 @@ def tabular_extract_book(book_pk): return_value = _tabular_extract_book(book) book.status = Book.SUCCESS except Exception: - logger.error('Tabular Extract Book Failed!!', exc_info=True) + logger.error("Tabular Extract Book Failed!!", exc_info=True) book.status = Book.FAILED book.error = Book.UNKNOWN_ERROR # TODO: handle all type of error return_value = False @@ -174,7 +172,7 @@ def tabular_extract_book(book_pk): @shared_task def tabular_extract_geo(geodata_pk): - key = 'tabular_meta_extract_geo_{}'.format(geodata_pk) + key = "tabular_meta_extract_geo_{}".format(geodata_pk) lock = redis.get_lock(key, 60 * 60 * 24) # Lock lifetime 24 hours have_lock = lock.acquire(blocking=False) if not have_lock: @@ -186,7 +184,7 @@ def tabular_extract_geo(geodata_pk): return_value = _tabular_meta_extract_geo(geodata) geodata.status = Geodata.SUCCESS except Exception: - logger.error('Tabular Extract Geo Failed!!', exc_info=True) + logger.error("Tabular Extract Geo Failed!!", exc_info=True) geodata.status = Geodata.FAILED return_value = False @@ -202,15 +200,18 @@ def remaining_tabular_generate_columns_image(): Scheduled task NOTE: Only use it through schedular """ - key = 'remaining_tabular_generate_columns_image' + key = "remaining_tabular_generate_columns_image" lock = redis.get_lock(key, 60 * 60 * 2) # Lock lifetime 2 hours have_lock = lock.acquire(blocking=False) if not have_lock: - return '{} Locked'.format(key) + return "{} Locked".format(key) tabular_generate_columns_image( Field.objects.filter( cache__status=Field.CACHE_PENDING, - ).distinct().order_by('id').values_list('id', flat=True)[:300] + ) + .distinct() + .order_by("id") + .values_list("id", flat=True)[:300] ) lock.release() return True diff --git a/apps/tabular/tests/test_unit.py b/apps/tabular/tests/test_unit.py index 9ee2c0a106..43f0534620 100644 --- a/apps/tabular/tests/test_unit.py +++ b/apps/tabular/tests/test_unit.py @@ -1,25 +1,24 @@ import os -from autofixture.base import AutoFixture from tempfile import NamedTemporaryFile -from deep.tests import TestCase, TEST_MEDIA_ROOT -from utils.common import makedirs - +from autofixture.base import AutoFixture from gallery.models import File -from geo.models import GeoArea, Region, AdminLevel +from geo.models import AdminLevel, GeoArea, Region from project.models import Project - -from tabular.tasks import auto_detect_and_update_fields from tabular.extractor import csv from tabular.models import Book, Field, Sheet +from tabular.tasks import auto_detect_and_update_fields from tabular.utils import ( + auto_detect_datetime, parse_comma_separated, parse_dot_separated, parse_space_separated, - auto_detect_datetime, ) -consistent_csv_data = '''id,age,name,date,place +from deep.tests import TEST_MEDIA_ROOT, TestCase +from utils.common import makedirs + +consistent_csv_data = """id,age,name,date,place 1,10,john,2018 october 28,Kathmandu 1,10,john,2018 october 28,Kathmandu 1,10,john,2018 october 28,Kathmandu @@ -29,41 +28,41 @@ 1,10,john,2018 october 28,banana 1,10,john,2018 october 28, abc,10,john,2018 october 28,mango -abc,30,doe,10 Nevem 2018,Kathmandu''' +abc,30,doe,10 Nevem 2018,Kathmandu""" -inconsistent_csv_data = '''id,age,name,date,place +inconsistent_csv_data = """id,age,name,date,place 1,10,john,1994 December 29,Kathmandu abc,10,john,1994 Deer 29,Kathmandu a,10,john,199 Dmber 29,Kathmandu 1,10,john,1994 December 29,Kathmandu abc,10,john,14 Dber 29,Kathmandu -abc,30,doe,10 Nevem 2018,Mango''' +abc,30,doe,10 Nevem 2018,Mango""" -geo_data_type_code = '''id,age,name,date,place +geo_data_type_code = """id,age,name,date,place 1,10,john,1994 December 29,KAT abc,10,john,1994 Deer 29,KAT 1,10,john,199 Dmber 29,KAT 1,10,john,1994 December 29,KAT abc,10,john,14 Dber 29, KAT -abc,30,doe,10 Nevem 2018,KAT''' +abc,30,doe,10 Nevem 2018,KAT""" -geo_data_type_name = '''id,age,name,date,place +geo_data_type_name = """id,age,name,date,place 1,10,john,1994 December 29,Kathmandu abc,10,john,1994 Deer 29,Kathmandu 1,10,john,199 Dmber 29,Kathmandu 1,10,john,1994 December 29,Kathmandu abc,10,john,14 Dber 29,Kathmandu -abc,30,doe,10 Nevem 2018,''' +abc,30,doe,10 Nevem 2018,""" def check_invalid(index, data): - assert 'invalid' in data[index] - assert data[index]['invalid'] is True + assert "invalid" in data[index] + assert data[index]["invalid"] is True def check_empty(index, data): - assert 'empty' in data[index] - assert data[index]['empty'] is True + assert "empty" in data[index] + assert data[index]["empty"] is True class TestTabularExtraction(TestCase): @@ -79,22 +78,18 @@ def setUp(self): # NOTE: Using choices created random values, and thus error occured self.project = self.create(Project) # Create region - self.region = Region.objects.create(code='RG', title='region') + self.region = Region.objects.create(code="RG", title="region") # Create admin levels - self.admin1 = AdminLevel.objects.create(region=self.region, level=1, title='level1') - self.admin2 = AdminLevel.objects.create(region=self.region, level=2, title='level2') + self.admin1 = AdminLevel.objects.create(region=self.region, level=1, title="level1") + self.admin2 = AdminLevel.objects.create(region=self.region, level=2, title="level2") # Create GeoArea - self.geo = GeoArea.objects.create( - admin_level=self.admin1, - title='Kathmandu', - code='KAT' - ) + self.geo = GeoArea.objects.create(admin_level=self.admin1, title="Kathmandu", code="KAT") # Just create multiple geo in different admin to check detection consistency GeoArea.objects.create( admin_level=self.admin2, - title='Central', - code='CTR', + title="Central", + code="CTR", ) self.project.regions.add(self.region) self.project.save() @@ -114,38 +109,36 @@ def test_auto_detection_consistent(self): for field in Field.objects.filter(sheet=sheet): assert len(field.actual_data) == 10 - if field.title == 'id': - assert field.type == Field.NUMBER, 'id is number' - assert 'separator' in field.options - assert field.options['separator'] == 'none' + if field.title == "id": + assert field.type == Field.NUMBER, "id is number" + assert "separator" in field.options + assert field.options["separator"] == "none" self.validate_number_field(field.data) # Check invalid values check_invalid(8, field.actual_data) check_invalid(9, field.actual_data) - elif field.title == 'age': - assert field.type == Field.NUMBER, 'age is number' - assert 'separator' in field.options - assert field.options['separator'] == 'none' + elif field.title == "age": + assert field.type == Field.NUMBER, "age is number" + assert "separator" in field.options + assert field.options["separator"] == "none" self.validate_number_field(field.data) - elif field.title == 'name': - assert field.type == Field.STRING, 'name is string' - elif field.title == 'date': - assert field.type == Field.DATETIME, 'date is datetime' + elif field.title == "name": + assert field.type == Field.STRING, "name is string" + elif field.title == "date": + assert field.type == Field.DATETIME, "date is datetime" assert field.options is not None - assert 'date_format' in field.options + assert "date_format" in field.options for datum in field.data: - assert datum.get('invalid') is not None or \ - datum.get('empty') is not None or \ - 'processed_value' in datum + assert datum.get("invalid") is not None or datum.get("empty") is not None or "processed_value" in datum check_invalid(9, field.actual_data) - elif field.title == 'place': - assert field.type == Field.GEO, 'place is geo' + elif field.title == "place": + assert field.type == Field.GEO, "place is geo" assert field.options is not None - assert 'regions' in field.options - assert 'admin_level' in field.options - for x in field.options['regions']: - assert 'id' in x - assert 'title' in x + assert "regions" in field.options + assert "admin_level" in field.options + for x in field.options["regions"]: + assert "id" in x + assert "title" in x check_invalid(6, field.actual_data) check_empty(7, field.actual_data) @@ -166,25 +159,21 @@ def test_auto_detection_inconsistent(self): for v in field.data: assert isinstance(v, dict) - if field.title == 'id': - assert field.type == Field.STRING, \ - 'id is string as it is inconsistent' + if field.title == "id": + assert field.type == Field.STRING, "id is string as it is inconsistent" # Verify that being string, no value is invalid for v in field.data: - assert not v.get('invalid'), \ - "Since string, shouldn't be invalid" - elif field.title == 'age': - assert field.type == Field.NUMBER, 'age is number' - assert 'separator' in field.options - assert field.options['separator'] == 'none' - elif field.title == 'name': - assert field.type == Field.STRING, 'name is string' - elif field.title == 'date': - assert field.type == Field.STRING, \ - 'date is string: only less than 80% rows are of date type' - elif field.title == 'place': - assert field.type == Field.GEO, \ - 'place is geo: more than 80% rows are of geo type' + assert not v.get("invalid"), "Since string, shouldn't be invalid" + elif field.title == "age": + assert field.type == Field.NUMBER, "age is number" + assert "separator" in field.options + assert field.options["separator"] == "none" + elif field.title == "name": + assert field.type == Field.STRING, "name is string" + elif field.title == "date": + assert field.type == Field.STRING, "date is string: only less than 80% rows are of date type" + elif field.title == "place": + assert field.type == Field.GEO, "place is geo: more than 80% rows are of geo type" def test_auto_detection_geo_type_name(self): """ @@ -197,30 +186,27 @@ def test_auto_detection_geo_type_name(self): # now validate auto detected fields geofield = None for field in Field.objects.all(): - if field.title == 'place': + if field.title == "place": geofield = field - assert field.type == Field.GEO, \ - 'place is geo: more than 80% rows are of geo type' + assert field.type == Field.GEO, "place is geo: more than 80% rows are of geo type" assert field.options != {} - assert 'regions' in field.options - assert 'admin_level' in field.options - for x in field.options['regions']: - assert 'id' in x - assert 'title' in x - assert 'geo_type' in field.options - assert field.options['geo_type'] == 'name' + assert "regions" in field.options + assert "admin_level" in field.options + for x in field.options["regions"]: + assert "id" in x + assert "title" in x + assert "geo_type" in field.options + assert field.options["geo_type"] == "name" if not geofield: return - kathmandu_geo = GeoArea.objects.filter(code='KAT')[0] + kathmandu_geo = GeoArea.objects.filter(code="KAT")[0] for v in geofield.data: - assert v.get('invalid') or v.get('empty') or 'processed_value' in v - assert 'value' in v - assert v.get('empty') \ - or v.get('invalid') \ - or v['processed_value'] == kathmandu_geo.id + assert v.get("invalid") or v.get("empty") or "processed_value" in v + assert "value" in v + assert v.get("empty") or v.get("invalid") or v["processed_value"] == kathmandu_geo.id def test_auto_detection_geo_type_code(self): """ @@ -237,30 +223,27 @@ def test_auto_detection_geo_type_code(self): geofield = None # now validate auto detected fields for field in Field.objects.all(): - if field.title == 'place': + if field.title == "place": geofield = field - assert field.type == Field.GEO, \ - 'place is geo: more than 80% rows are of geo type' + assert field.type == Field.GEO, "place is geo: more than 80% rows are of geo type" assert field.options != {} - assert 'regions' in field.options - assert 'admin_level' in field.options - for x in field.options['regions']: - assert 'id' in x - assert 'title' in x - assert 'geo_type' in field.options - assert field.options['geo_type'] == 'code' + assert "regions" in field.options + assert "admin_level" in field.options + for x in field.options["regions"]: + assert "id" in x + assert "title" in x + assert "geo_type" in field.options + assert field.options["geo_type"] == "code" if not geofield: return - kathmandu_geo = GeoArea.objects.filter(code='KAT')[0] + kathmandu_geo = GeoArea.objects.filter(code="KAT")[0] for v in geofield.data: - assert v.get('invalid') or v.get('empty') or 'processed_value' in v - assert 'value' in v - assert v.get('empty') \ - or v.get('invalid') \ - or v['processed_value'] == kathmandu_geo.id + assert v.get("invalid") or v.get("empty") or "processed_value" in v + assert "value" in v + assert v.get("empty") or v.get("invalid") or v["processed_value"] == kathmandu_geo.id def test_sheet_data_change_on_datefield_change_to_string(self): """ @@ -275,16 +258,13 @@ def test_sheet_data_change_on_datefield_change_to_string(self): sheet = book.sheet_set.all()[0] # Now update date_field to string - field = Field.objects.get( - sheet=sheet, - type=Field.DATETIME - ) + field = Field.objects.get(sheet=sheet, type=Field.DATETIME) field.type = Field.STRING field.save() # no vlaue should be invalid for v in field.data: - assert not v.get('invalid', None) + assert not v.get("invalid", None) def test_sheet_data_change_on_string_change_to_geo(self): """ @@ -299,10 +279,7 @@ def test_sheet_data_change_on_string_change_to_geo(self): # We first cast geo field to string because initially it will be auto # detected as geo - field = Field.objects.get( - sheet=sheet, - type=Field.GEO - ) + field = Field.objects.get(sheet=sheet, type=Field.GEO) options = field.options fid = str(field.id) @@ -312,29 +289,29 @@ def test_sheet_data_change_on_string_change_to_geo(self): # no value should be invalid for v in field.data: - assert not v.get('invalid') + assert not v.get("invalid") # Now change type to Geo field.type = Field.GEO # Try removing region, and check if it's automatically added from admin # level - options.pop('region', {}) + options.pop("region", {}) field.options = { **options, } field.save() - kat_geo = GeoArea.objects.filter(code='KAT')[0] + kat_geo = GeoArea.objects.filter(code="KAT")[0] # Check if field has region field = Field.objects.get(id=fid) - assert 'regions' in field.options - regions = field.options['regions'] + assert "regions" in field.options + regions = field.options["regions"] for x in regions: - assert 'id' in x - assert 'title' in x - assert field.options['admin_level'] == kat_geo.admin_level.level - assert regions[0]['id'] == kat_geo.admin_level.region.id + assert "id" in x + assert "title" in x + assert field.options["admin_level"] == kat_geo.admin_level.level + assert regions[0]["id"] == kat_geo.admin_level.region.id # Get sheet again, which should be updated @@ -354,7 +331,7 @@ def test_sheet_option_change_data_row_index(self): for field in sheet.field_set.all(): # Also check field title - assert field.title == field.data[sheet.data_row_index - 1]['value'] + assert field.title == field.data[sheet.data_row_index - 1]["value"] assert len(field.data) == 11, "Data includes the column names as well" assert len(field.actual_data) == 10 @@ -364,34 +341,29 @@ def test_sheet_option_change_data_row_index(self): # check if field actual_data changed or not for field in sheet.field_set.all(): - assert field.title == field.data[sheet.data_row_index - 1]['value'] + assert field.title == field.data[sheet.data_row_index - 1]["value"] # check if Re-triggered or not - assert field.cache['status'] == Field.CACHE_PENDING + assert field.cache["status"] == Field.CACHE_PENDING assert len(field.data) == 11, "Data includes the column names as well" assert len(field.actual_data) == 9 def initialize_data_and_basic_test(self, csv_data): makedirs(TEST_MEDIA_ROOT) - file = NamedTemporaryFile('w', dir=TEST_MEDIA_ROOT, delete=False) + file = NamedTemporaryFile("w", dir=TEST_MEDIA_ROOT, delete=False) self.files.append(file.name) - for x in csv_data.split('\n'): - file.write('{}\n'.format(x)) + for x in csv_data.split("\n"): + file.write("{}\n".format(x)) file.close() # create a book - csvfile = AutoFixture( - File, - field_values={ - 'file': file.name - } - ).create_one() + csvfile = AutoFixture(File, field_values={"file": file.name}).create_one() book = AutoFixture( Book, field_values={ - 'file': csvfile, - 'project': self.project, - } + "file": csvfile, + "project": self.project, + }, ).create_one() csv.extract(book) assert Field.objects.count() == 5 @@ -400,39 +372,38 @@ def initialize_data_and_basic_test(self, csv_data): for field in Field.objects.all(): fieldnames[field.title] = True assert field.type == Field.STRING, "Initial type is string" - assert 'id' in fieldnames, 'id should be a fieldname' - assert 'age' in fieldnames, 'age should be a fieldname' - assert 'name' in fieldnames, 'name should be a field name' - assert 'date' in fieldnames, 'date should be a field name' - assert 'place' in fieldnames, 'place should be a field name' + assert "id" in fieldnames, "id should be a fieldname" + assert "age" in fieldnames, "age should be a fieldname" + assert "name" in fieldnames, "name should be a field name" + assert "date" in fieldnames, "date should be a field name" + assert "place" in fieldnames, "place should be a field name" # check structure of data in sheet for sheet in book.sheet_set.all(): fields = sheet.field_set.all() size = len(fields[0].data) - assert all([len(x.data) == size for x in fields]), \ - "All columns should have same size" + assert all([len(x.data) == size for x in fields]), "All columns should have same size" for field in fields: v = field.data assert isinstance(v, list) for x in v: - assert 'value' in x - assert 'empty' in x - assert isinstance(x['empty'], bool) - assert 'invalid' in x - assert isinstance(x['invalid'], bool) + assert "value" in x + assert "empty" in x + assert isinstance(x["empty"], bool) + assert "invalid" in x + assert isinstance(x["invalid"], bool) return book def validate_number_field(self, items): for i, item in enumerate(items): - assert 'value' in item - assert item.get('invalid') \ - or item.get('empty') \ - or ('processed_value' in item) - assert not item.get('processed_value') \ - or isinstance(item['processed_value'], int)\ - or isinstance(item['processed_value'], float) + assert "value" in item + assert item.get("invalid") or item.get("empty") or ("processed_value" in item) + assert ( + not item.get("processed_value") + or isinstance(item["processed_value"], int) + or isinstance(item["processed_value"], float) + ) def tearDown(self): """Remove temp files""" @@ -443,70 +414,70 @@ def tearDown(self): def test_comma_separated_numbers(): - assert parse_comma_separated('1') == (1.0, 'comma') - assert parse_comma_separated('12') == (12.0, 'comma') - assert parse_comma_separated('100') == (100.0, 'comma') - assert parse_comma_separated('1,200') == (1200.0, 'comma') - assert parse_comma_separated('11,200') == (11200.0, 'comma') - assert parse_comma_separated('111,200') == (111200.0, 'comma') - assert parse_comma_separated('5,111,200') == (5111200.0, 'comma') - assert parse_comma_separated('54,111,200') == (54111200.0, 'comma') - assert parse_comma_separated('543,111,200') == (543111200.0, 'comma') - assert parse_comma_separated('543111,200') is None - assert parse_comma_separated('1,200.35') == (1200.35, 'comma') - assert parse_comma_separated('1,200.35.3') is None - assert parse_comma_separated('') is None + assert parse_comma_separated("1") == (1.0, "comma") + assert parse_comma_separated("12") == (12.0, "comma") + assert parse_comma_separated("100") == (100.0, "comma") + assert parse_comma_separated("1,200") == (1200.0, "comma") + assert parse_comma_separated("11,200") == (11200.0, "comma") + assert parse_comma_separated("111,200") == (111200.0, "comma") + assert parse_comma_separated("5,111,200") == (5111200.0, "comma") + assert parse_comma_separated("54,111,200") == (54111200.0, "comma") + assert parse_comma_separated("543,111,200") == (543111200.0, "comma") + assert parse_comma_separated("543111,200") is None + assert parse_comma_separated("1,200.35") == (1200.35, "comma") + assert parse_comma_separated("1,200.35.3") is None + assert parse_comma_separated("") is None assert parse_comma_separated(None) is None - assert parse_comma_separated('abc,123') is None - assert parse_comma_separated('123,abc,123') is None + assert parse_comma_separated("abc,123") is None + assert parse_comma_separated("123,abc,123") is None def test_dot_separated_numbers(): - assert parse_dot_separated('1') == (1.0, 'dot') - assert parse_dot_separated('12') == (12.0, 'dot') - assert parse_dot_separated('100') == (100.0, 'dot') - assert parse_dot_separated('1.200') == (1200.0, 'dot') - assert parse_dot_separated('11.200') == (11200.0, 'dot') - assert parse_dot_separated('111.200') == (111200.0, 'dot') - assert parse_dot_separated('5.111.200') == (5111200.0, 'dot') - assert parse_dot_separated('54.111.200') == (54111200.0, 'dot') - assert parse_dot_separated('543.111.200') == (543111200.0, 'dot') - assert parse_dot_separated('543111.200') is None - assert parse_dot_separated('1.200,35') == (1200.35, 'dot') - assert parse_dot_separated('1.200,35,3') is None - assert parse_dot_separated('') is None + assert parse_dot_separated("1") == (1.0, "dot") + assert parse_dot_separated("12") == (12.0, "dot") + assert parse_dot_separated("100") == (100.0, "dot") + assert parse_dot_separated("1.200") == (1200.0, "dot") + assert parse_dot_separated("11.200") == (11200.0, "dot") + assert parse_dot_separated("111.200") == (111200.0, "dot") + assert parse_dot_separated("5.111.200") == (5111200.0, "dot") + assert parse_dot_separated("54.111.200") == (54111200.0, "dot") + assert parse_dot_separated("543.111.200") == (543111200.0, "dot") + assert parse_dot_separated("543111.200") is None + assert parse_dot_separated("1.200,35") == (1200.35, "dot") + assert parse_dot_separated("1.200,35,3") is None + assert parse_dot_separated("") is None assert parse_dot_separated(None) is None - assert parse_dot_separated('abc.123') is None - assert parse_dot_separated('123.abc.123') is None + assert parse_dot_separated("abc.123") is None + assert parse_dot_separated("123.abc.123") is None def test_space_separated_numbers(): - assert parse_space_separated('1') == (1.0, 'space') - assert parse_space_separated('12') == (12.0, 'space') - assert parse_space_separated('100') == (100.0, 'space') - assert parse_space_separated('1 200') == (1200.0, 'space') - assert parse_space_separated('11 200') == (11200.0, 'space') - assert parse_space_separated('111 200') == (111200.0, 'space') - assert parse_space_separated('5 111 200') == (5111200.0, 'space') - assert parse_space_separated('54 111 200') == (54111200.0, 'space') - assert parse_space_separated('543 111 200') == (543111200.0, 'space') - assert parse_space_separated('543111 200') is None - assert parse_space_separated('1 200.35') == (1200.35, 'space') - assert parse_space_separated('1 200.35.3') is None - assert parse_space_separated('') is None + assert parse_space_separated("1") == (1.0, "space") + assert parse_space_separated("12") == (12.0, "space") + assert parse_space_separated("100") == (100.0, "space") + assert parse_space_separated("1 200") == (1200.0, "space") + assert parse_space_separated("11 200") == (11200.0, "space") + assert parse_space_separated("111 200") == (111200.0, "space") + assert parse_space_separated("5 111 200") == (5111200.0, "space") + assert parse_space_separated("54 111 200") == (54111200.0, "space") + assert parse_space_separated("543 111 200") == (543111200.0, "space") + assert parse_space_separated("543111 200") is None + assert parse_space_separated("1 200.35") == (1200.35, "space") + assert parse_space_separated("1 200.35.3") is None + assert parse_space_separated("") is None assert parse_space_separated(None) is None - assert parse_space_separated('abc 123') is None - assert parse_space_separated('123 abc 123') is None + assert parse_space_separated("abc 123") is None + assert parse_space_separated("123 abc 123") is None def test_parse_date(): - assert auto_detect_datetime('2019-03-15') is not None - assert auto_detect_datetime('2019-Dec-15') is not None - assert auto_detect_datetime('2019-Oct-15') is not None + assert auto_detect_datetime("2019-03-15") is not None + assert auto_detect_datetime("2019-Dec-15") is not None + assert auto_detect_datetime("2019-Oct-15") is not None - assert auto_detect_datetime('2019-03-15') is not None - assert auto_detect_datetime('2019-Dec-15') is not None - assert auto_detect_datetime('2019-Oct-15') is not None + assert auto_detect_datetime("2019-03-15") is not None + assert auto_detect_datetime("2019-Dec-15") is not None + assert auto_detect_datetime("2019-Oct-15") is not None - assert auto_detect_datetime('2019-December-15') is not None - assert auto_detect_datetime('2019 October 15') is not None + assert auto_detect_datetime("2019-December-15") is not None + assert auto_detect_datetime("2019 October 15") is not None diff --git a/apps/tabular/utils.py b/apps/tabular/utils.py index 474476158a..deedfd4e95 100644 --- a/apps/tabular/utils.py +++ b/apps/tabular/utils.py @@ -1,55 +1,51 @@ -import re import random +import re from datetime import datetime + from geo.models import GeoArea from utils.common import calculate_sample_size, get_max_occurence_and_count - DATE_FORMATS = [ - '%m-%d-%Y', - '%m/%d/%Y', - '%m.%d.%Y', - '%m %d %Y', - - '%Y-%m-%d', - '%Y/%m/%d', - '%Y.%m.%d', - '%Y %m %d', - - '%d %b %Y', # 12 Jan 2019 - '%d-%b-%Y', - '%d/%b/%Y', - '%d.%b.%Y', - - '%Y %b %d', # 2019 Jan 12 - '%Y-%b-%d', # 2019-Jan-12 - '%Y/%b/%d', # 2019/Jan/12 - '%Y %B %d', # 2019 January 12 - '%Y-%B-%d', # 2019-January-12 - '%d %B %Y', # 12 January 2019 - - '%d-%m-%Y', - '%d/%m/%Y', - '%d.%m.%Y', - '%d %m %Y', + "%m-%d-%Y", + "%m/%d/%Y", + "%m.%d.%Y", + "%m %d %Y", + "%Y-%m-%d", + "%Y/%m/%d", + "%Y.%m.%d", + "%Y %m %d", + "%d %b %Y", # 12 Jan 2019 + "%d-%b-%Y", + "%d/%b/%Y", + "%d.%b.%Y", + "%Y %b %d", # 2019 Jan 12 + "%Y-%b-%d", # 2019-Jan-12 + "%Y/%b/%d", # 2019/Jan/12 + "%Y %B %d", # 2019 January 12 + "%Y-%B-%d", # 2019-January-12 + "%d %B %Y", # 12 January 2019 + "%d-%m-%Y", + "%d/%m/%Y", + "%d.%m.%Y", + "%d %m %Y", ] -COMMA_SEPARATED_NUMBER = re.compile(r'^(\d{1,3})(,\d{3})*(\.\d+)?$') -SPACE_SEPARATED_NUMBER = re.compile(r'^(\d{1,3})( \d{3})*(\.\d+)?$') -DOT_SEPARATED_NUMBER = re.compile(r'^(\d{1,3})(\.\d{3})*(,\d+)?$') +COMMA_SEPARATED_NUMBER = re.compile(r"^(\d{1,3})(,\d{3})*(\.\d+)?$") +SPACE_SEPARATED_NUMBER = re.compile(r"^(\d{1,3})( \d{3})*(\.\d+)?$") +DOT_SEPARATED_NUMBER = re.compile(r"^(\d{1,3})(\.\d{3})*(,\d+)?$") def parse_number(val, **kwargs): val = str(val) - separator = kwargs.get('separator') - if separator == 'comma': + separator = kwargs.get("separator") + if separator == "comma": return parse_comma_separated(val) - elif separator == 'dot': + elif separator == "dot": return parse_dot_separated(val) - elif separator == 'space': + elif separator == "space": return parse_space_separated(val) - elif separator == 'none': + elif separator == "none": return parse_none_separated(val) elif separator is None: return parse_no_separator(val) @@ -57,17 +53,13 @@ def parse_number(val, **kwargs): def parse_no_separator(val): return ( - parse_none_separated(val) or - parse_comma_separated(val) or - parse_dot_separated(val) or - parse_space_separated(val) or - None + parse_none_separated(val) or parse_comma_separated(val) or parse_dot_separated(val) or parse_space_separated(val) or None ) def parse_none_separated(numstring): try: - return float(numstring), 'none' + return float(numstring), "none" except (TypeError, ValueError): return None @@ -76,8 +68,8 @@ def parse_comma_separated(numstring): try: if not COMMA_SEPARATED_NUMBER.match(numstring.strip()): return None - comma_removed = numstring.replace(',', '') - return float(comma_removed), 'comma' + comma_removed = numstring.replace(",", "") + return float(comma_removed), "comma" except (ValueError, TypeError, AttributeError): # Attribute error is raised by numstring.replace if numstring is None return None @@ -88,10 +80,10 @@ def parse_dot_separated(numstring): if not DOT_SEPARATED_NUMBER.match(numstring.strip()): return None # first, remove dot - dot_removed = numstring.replace('.', '') + dot_removed = numstring.replace(".", "") # now replace comma with dot, to make it parseable - comma_replaced = dot_removed.replace(',', '.') - return float(comma_replaced), 'dot' + comma_replaced = dot_removed.replace(",", ".") + return float(comma_replaced), "dot" except (ValueError, TypeError, AttributeError): # Attribute error is raised by numstring.replace if numstring is None return None @@ -102,8 +94,8 @@ def parse_space_separated(numstring): if not SPACE_SEPARATED_NUMBER.match(numstring.strip()): return None # first, remove space - space_removed = numstring.replace(' ', '') - return float(space_removed), 'space' + space_removed = numstring.replace(" ", "") + return float(space_removed), "space" except (ValueError, TypeError, AttributeError): # Attribute error is raised by numstring.replace if numstring is None return None @@ -141,31 +133,34 @@ def get_geos_dict(project=None, **kwargs): if project is None: return {} - geos = GeoArea.objects.filter( - admin_level__region__project=project - ).values( - 'id', 'code', 'admin_level__level', 'title', 'admin_level_id', - 'admin_level__region', 'admin_level__region__title', + geos = GeoArea.objects.filter(admin_level__region__project=project).values( + "id", + "code", + "admin_level__level", + "title", + "admin_level_id", + "admin_level__region", + "admin_level__region__title", ) admin_levels_areas = {} for geo in geos: - admin_level_data = admin_levels_areas.get(geo['admin_level__level'], {}) - admin_level_data[geo['title'].lower()] = { - "admin_level": geo['admin_level__level'], - "admin_level_id": geo['admin_level_id'], - "title": geo['title'], - "code": geo['code'], - "id": geo['id'], - "region": geo['admin_level__region'], - "region_title": geo['admin_level__region__title'], + admin_level_data = admin_levels_areas.get(geo["admin_level__level"], {}) + admin_level_data[geo["title"].lower()] = { + "admin_level": geo["admin_level__level"], + "admin_level_id": geo["admin_level_id"], + "title": geo["title"], + "code": geo["code"], + "id": geo["id"], + "region": geo["admin_level__region"], + "region_title": geo["admin_level__region__title"], } - admin_levels_areas[geo['admin_level__level']] = admin_level_data + admin_levels_areas[geo["admin_level__level"]] = admin_level_data return admin_levels_areas def parse_geo(value, geos_names={}, geos_codes={}, **kwargs): val = str(value).lower() - admin_level = kwargs.get('admin_level') + admin_level = kwargs.get("admin_level") name_matched = None for level, geos in geos_names.items(): @@ -176,7 +171,7 @@ def parse_geo(value, geos_names={}, geos_codes={}, **kwargs): break if name_matched: - return {**name_matched, 'geo_type': 'name'} + return {**name_matched, "geo_type": "name"} code_matched = None for level, geos in geos_codes.items(): @@ -185,7 +180,7 @@ def parse_geo(value, geos_names={}, geos_codes={}, **kwargs): code_matched = None if code_matched: break - return code_matched and {**code_matched, 'geo_type': 'code'} + return code_matched and {**code_matched, "geo_type": "code"} def sample_and_detect_type_and_options(values, geos_names={}, geos_codes={}): @@ -193,10 +188,7 @@ def sample_and_detect_type_and_options(values, geos_names={}, geos_codes={}): from .models import Field # noqa if not values: - return { - 'type': Field.STRING, - 'options': {} - } + return {"type": Field.STRING, "options": {}} length = len(values) sample_size = calculate_sample_size(length, 95, prob=0.8) @@ -211,7 +203,7 @@ def sample_and_detect_type_and_options(values, geos_names={}, geos_codes={}): number_options = [] for sample in samples: - value = sample['value'] + value = sample["value"] number_parsed = parse_number(value) if number_parsed: types.append(Field.NUMBER) @@ -222,17 +214,19 @@ def sample_and_detect_type_and_options(values, geos_names={}, geos_codes={}): if formats_parsed: types.append(Field.DATETIME) # Append all detected formats - date_options.extend([{'date_format': x[1]} for x in formats_parsed]) + date_options.extend([{"date_format": x[1]} for x in formats_parsed]) continue geo_parsed = parse_geo(value, geos_names, geos_codes) if geo_parsed is not None: types.append(Field.GEO) - geo_options.append({ - 'geo_type': geo_parsed['geo_type'], - 'admin_level': geo_parsed['admin_level'], - 'region': geo_parsed['region'], - }) + geo_options.append( + { + "geo_type": geo_parsed["geo_type"], + "admin_level": geo_parsed["admin_level"], + "region": geo_parsed["region"], + } + ) continue types.append(Field.STRING) @@ -243,24 +237,20 @@ def sample_and_detect_type_and_options(values, geos_names={}, geos_codes={}): # Now find dominant option value if max_type == Field.DATETIME: - max_format, max_count = get_max_occurence_and_count([ - x['date_format'] for x in date_options - ]) - max_options = {'date_format': max_format} + max_format, max_count = get_max_occurence_and_count([x["date_format"] for x in date_options]) + max_options = {"date_format": max_format} elif max_type == Field.NUMBER: max_format, max_count = get_max_occurence_and_count(number_options) - max_options = {'separator': max_format} + max_options = {"separator": max_format} elif max_type == Field.GEO: max_options = get_geo_options(geo_options) - return { - 'type': max_type, - 'options': max_options - } + return {"type": max_type, "options": max_options} def get_cast_function(type, geos_names, geos_codes): from .models import Field + if type == Field.STRING: cast_func = parse_string elif type == Field.NUMBER: @@ -273,27 +263,12 @@ def get_cast_function(type, geos_names, geos_codes): def get_geo_options(geo_options): - max_geo, max_count = get_max_occurence_and_count([ - x['geo_type'] for x in geo_options - ]) - max_admin, max_count = get_max_occurence_and_count([ - x['admin_level'] for x in geo_options - ]) - - max_region, max_count = get_max_occurence_and_count([ - x['region'] for x in geo_options - ]) - return { - 'geo_type': max_geo, - 'region': max_region, - 'admin_level': max_admin - } + max_geo, max_count = get_max_occurence_and_count([x["geo_type"] for x in geo_options]) + max_admin, max_count = get_max_occurence_and_count([x["admin_level"] for x in geo_options]) + + max_region, max_count = get_max_occurence_and_count([x["region"] for x in geo_options]) + return {"geo_type": max_geo, "region": max_region, "admin_level": max_admin} def get_geos_codes_from_geos_names(geos_names): - return { - level: { - v['code'].lower(): v for k, v in admin_level_data.items() - } - for level, admin_level_data in geos_names.items() - } + return {level: {v["code"].lower(): v for k, v in admin_level_data.items()} for level, admin_level_data in geos_names.items()} diff --git a/apps/tabular/views.py b/apps/tabular/views.py index 6095bd44af..6ec92ba48a 100644 --- a/apps/tabular/views.py +++ b/apps/tabular/views.py @@ -1,28 +1,21 @@ from django.conf import settings from django.db import transaction -from rest_framework.decorators import action -from rest_framework import ( - viewsets, - exceptions, - response, - permissions, - views, -) - from entry.models import Entry +from rest_framework import exceptions, permissions, response, views, viewsets +from rest_framework.decorators import action -from .models import Book, Sheet, Field, Geodata -from .tasks import tabular_extract_book, tabular_extract_geo +from .models import Book, Field, Geodata, Sheet from .serializers import ( - BookSerializer, BookMetaSerializer, BookProcessedOnlySerializer, - SheetSerializer, - SheetMetaSerializer, - FieldSerializer, + BookSerializer, FieldProcessedOnlySerializer, + FieldSerializer, GeodataSerializer, + SheetMetaSerializer, + SheetSerializer, ) +from .tasks import tabular_extract_book, tabular_extract_geo class BookViewSet(viewsets.ModelViewSet): @@ -31,13 +24,13 @@ class BookViewSet(viewsets.ModelViewSet): permission_classes = [permissions.IsAuthenticated] def get_serializer_class(self): - if self.action == 'list': + if self.action == "list": return BookMetaSerializer return super().get_serializer_class() @action( detail=True, - url_path='processed', + url_path="processed", serializer_class=BookProcessedOnlySerializer, ) def get_processed_only(self, request, pk=None, version=None): @@ -47,60 +40,61 @@ def get_processed_only(self, request, pk=None, version=None): @action( detail=True, - url_path='fields', - methods=['post'], + url_path="fields", + methods=["post"], serializer_class=FieldProcessedOnlySerializer, ) def get_fields(self, request, pk=None, version=None): instance = self.get_object() - fields = request.data.get('fields', []) + fields = request.data.get("fields", []) pending_fields = instance.get_pending_fields_id() fields = instance.get_processed_fields(fields) serializer = self.get_serializer(fields, many=True) - return response.Response({ - 'pending_fields': pending_fields, - 'fields': serializer.data, - }) + return response.Response( + { + "pending_fields": pending_fields, + "fields": serializer.data, + } + ) @action( detail=True, - url_path='entry-count', + url_path="entry-count", ) def get_entry_count(self, request, pk=None, version=None): instance = self.get_object() count = Entry.objects.filter( tabular_field__sheet__book=instance.id, ).count() - return response.Response({ - 'count': count, - }) + return response.Response( + { + "count": count, + } + ) @action( detail=True, - url_path='sheets', - methods=['patch'], + url_path="sheets", + methods=["patch"], ) def update_sheets(self, request, pk=None, version=None): instance = self.get_object() - sheets = request.data.get('sheets', []) - sheet_maps = {x['id']: x for x in sheets} + sheets = request.data.get("sheets", []) + sheet_maps = {x["id"]: x for x in sheets} - sheet_objs = Sheet.objects.filter( - book=instance, - id__in=[x['id'] for x in sheets] - ) + sheet_objs = Sheet.objects.filter(book=instance, id__in=[x["id"] for x in sheets]) for sheet in sheet_objs: serializer = SheetSerializer( sheet, data=sheet_maps[sheet.id], - context={'request': request}, + context={"request": request}, partial=True, ) serializer.is_valid() serializer.update(sheet, sheet_maps[sheet.id]) return response.Response( - BookMetaSerializer(instance, context={'request': request}).data, + BookMetaSerializer(instance, context={"request": request}).data, ) @@ -111,31 +105,28 @@ class SheetViewSet(viewsets.ModelViewSet): @action( detail=True, - url_path='fields', - methods=['patch'], + url_path="fields", + methods=["patch"], ) def update_fields(self, request, pk=None, version=None): instance = self.get_object() - fields = request.data.get('fields', []) - field_maps = {x['id']: x for x in fields} + fields = request.data.get("fields", []) + field_maps = {x["id"]: x for x in fields} - field_objs = Field.objects.filter( - sheet=instance, - id__in=[x['id'] for x in fields] - ) + field_objs = Field.objects.filter(sheet=instance, id__in=[x["id"] for x in fields]) for field in field_objs: serializer = FieldSerializer( field, data=field_maps[field.id], - context={'request': request}, + context={"request": request}, partial=True, ) serializer.is_valid() serializer.update(field, field_maps[field.id]) return response.Response( - SheetMetaSerializer(instance, context={'request': request}).data, + SheetMetaSerializer(instance, context={"request": request}).data, ) @@ -155,6 +146,7 @@ class TabularExtractionTriggerView(views.APIView): """ A trigger for extracting tabular data for book """ + permission_classes = [permissions.IsAuthenticated] def post(self, request, book_id, version=None): @@ -164,22 +156,21 @@ def post(self, request, book_id, version=None): book = Book.objects.get(id=book_id) if book.status == Book.SUCCESS: - return response.Response({'book_id': book.pk}) + return response.Response({"book_id": book.pk}) if not settings.TESTING: - transaction.on_commit( - lambda: tabular_extract_book.delay(book.pk) - ) + transaction.on_commit(lambda: tabular_extract_book.delay(book.pk)) book.status = Book.PENDING book.save() - return response.Response({'book_id': book.pk}) + return response.Response({"book_id": book.pk}) class TabularGeoProcessTriggerView(views.APIView): """ A trigger for processing geo data for given field """ + permission_classes = [permissions.IsAuthenticated] def post(self, request, field_id, version=None): @@ -192,13 +183,11 @@ def post(self, request, field_id, version=None): geodata = Geodata.objects.create(field=field) if geodata.status == Geodata.SUCCESS: - return response.Response({'geodata_id': geodata.pk}) + return response.Response({"geodata_id": geodata.pk}) if not settings.TESTING: - transaction.on_commit( - lambda: tabular_extract_geo.delay(geodata.pk) - ) + transaction.on_commit(lambda: tabular_extract_geo.delay(geodata.pk)) geodata.status = geodata.PENDING geodata.save() - return response.Response({'geodata_id': geodata.pk}) + return response.Response({"geodata_id": geodata.pk}) diff --git a/apps/tabular/viz/barchart.py b/apps/tabular/viz/barchart.py index 3e94c72722..a0334a52ee 100644 --- a/apps/tabular/viz/barchart.py +++ b/apps/tabular/viz/barchart.py @@ -7,21 +7,21 @@ import matplotlib.pyplot as plt import plotly.graph_objs as go except ImportError as e: - logger.warning(f'ImportError: {e}') + logger.warning(f"ImportError: {e}") @create_plot_image def plot(x_label, y_label, data, horizontal=False): chart_basic_config = { - 'width': 0.8, - 'color': 'teal', + "width": 0.8, + "color": "teal", } if horizontal: data.plot.barh(**chart_basic_config) - plt.locator_params(axis='y', nbins=24) + plt.locator_params(axis="y", nbins=24) else: data.plot.bar(**chart_basic_config) - plt.locator_params(axis='x', nbins=39) + plt.locator_params(axis="x", nbins=39) plt.xlabel(x_label) plt.ylabel(y_label) plt.gca().get_legend().remove() @@ -30,10 +30,10 @@ def plot(x_label, y_label, data, horizontal=False): @create_plotly_image def plotly(data, horizontal=False): bar = go.Bar( - x=data['count'] if horizontal else data['value'], - y=data['value'] if horizontal else data['count'], + x=data["count"] if horizontal else data["value"], + y=data["value"] if horizontal else data["count"], marker=create_plotly_image.marker, - orientation='h' if horizontal else 'v', - opacity=0.8 + orientation="h" if horizontal else "v", + opacity=0.8, ) return [bar], None diff --git a/apps/tabular/viz/histograms.py b/apps/tabular/viz/histograms.py index 0ae09b1cee..d2f8fc630d 100644 --- a/apps/tabular/viz/histograms.py +++ b/apps/tabular/viz/histograms.py @@ -1,19 +1,19 @@ import logging -from utils.common import create_plot_image, create_plotly_image +from utils.common import create_plot_image, create_plotly_image logger = logging.getLogger(__name__) try: import matplotlib.pyplot as plt import plotly.graph_objs as go except ImportError as e: - logger.warning(f'ImportError: {e}') + logger.warning(f"ImportError: {e}") @create_plot_image def plot(x_label, y_label, data): # data.plot.hist(color='teal', edgecolor='white', linewidth=0.4) - data.plot.hist(color='teal') + data.plot.hist(color="teal") plt.ylabel(y_label) plt.xlabel(x_label) diff --git a/apps/tabular/viz/map.py b/apps/tabular/viz/map.py index 0697158b26..2772655ce5 100644 --- a/apps/tabular/viz/map.py +++ b/apps/tabular/viz/map.py @@ -1,19 +1,20 @@ -from utils.common import create_plot_image, make_colormap -from geo.models import GeoArea, AdminLevel, Region -import logging import json +import logging +from geo.models import AdminLevel, GeoArea, Region + +from utils.common import create_plot_image, make_colormap logger = logging.getLogger(__name__) try: - from shapely.geometry import shape - import matplotlib.pyplot as plt - import matplotlib.colors as mcolors import geopandas as gpd + import matplotlib.colors as mcolors + import matplotlib.pyplot as plt + from shapely.geometry import shape except ImportError as e: - logger.warning(f'ImportError: {e}') + logger.warning(f"ImportError: {e}") def get_geoareas(selected_geoareas, admin_levels=None, regions=None): @@ -30,10 +31,14 @@ def get_geoareas(selected_geoareas, admin_levels=None, regions=None): geoareas = GeoArea.objects.filter( admin_level__level__in=AdminLevel.objects.filter( geoarea__in=selected_geoareas, - ).distinct().values_list('level', flat=True), + ) + .distinct() + .values_list("level", flat=True), admin_level__region__in=Region.objects.filter( adminlevel__geoarea__in=selected_geoareas, - ).distinct().values_list('pk', flat=True), + ) + .distinct() + .values_list("pk", flat=True), ) return geoareas @@ -41,36 +46,36 @@ def get_geoareas(selected_geoareas, admin_levels=None, regions=None): @create_plot_image def plot(*args, **kwargs): # NOTE: this are admin_level.level list not pk or admin_levels objects - admin_levels = kwargs.get('admin_levels') + admin_levels = kwargs.get("admin_levels") # NOTE: this are region pks/objects - regions = kwargs.get('regions') - df = kwargs.get('data').rename(columns={'value': 'geoarea_id'}) + regions = kwargs.get("regions") + df = kwargs.get("data").rename(columns={"value": "geoarea_id"}) shapes = [] geoareas = get_geoareas( - df['geoarea_id'].values.tolist(), + df["geoarea_id"].values.tolist(), admin_levels, regions, ) if len(geoareas) == 0: - logger.warning('Empty geoareas found') + logger.warning("Empty geoareas found") return for geoarea in geoareas: s = shape(json.loads(geoarea.polygons.geojson)) - shapes.append({'geoarea_id': geoarea.id, 'geometry': s}) - shapes_frame = gpd.GeoDataFrame(shapes, geometry='geometry') - data = shapes_frame.merge(df, on='geoarea_id', how='outer').fillna(0) + shapes.append({"geoarea_id": geoarea.id, "geometry": s}) + shapes_frame = gpd.GeoDataFrame(shapes, geometry="geometry") + data = shapes_frame.merge(df, on="geoarea_id", how="outer").fillna(0) c = mcolors.ColorConverter().to_rgb - rvb = make_colormap([c('white'), c('teal')]) + rvb = make_colormap([c("white"), c("teal")]) data.plot( - column='count', + column="count", cmap=rvb, legend=True, linewidth=0.4, - edgecolor='0.5', + edgecolor="0.5", ) - plt.axis('off') + plt.axis("off") diff --git a/apps/tabular/viz/renderer.py b/apps/tabular/viz/renderer.py index f0083b89c1..24d348d849 100644 --- a/apps/tabular/viz/renderer.py +++ b/apps/tabular/viz/renderer.py @@ -1,37 +1,34 @@ -import os import logging +import os from datetime import datetime from django.conf import settings +from gallery.models import File +from tabular.models import Field +from tabular.viz import barchart, histograms +from tabular.viz import map as mapViz +from tabular.viz import wordcloud from deep.documents_types import CHART_IMAGE_MIME from utils.common import deep_date_format -from gallery.models import File -from tabular.models import Field -from tabular.viz import ( - barchart, - wordcloud, - histograms, - map as mapViz, -) logger = logging.getLogger(__name__) try: import pandas as pd except ImportError as e: - logger.warning(f'ImportError: {e}') + logger.warning(f"ImportError: {e}") def DEFAULT_CHART_RENDER(*args, **kwargs): return None -BARCHART = 'barchart' -BARCHARTH = 'barcharth' -MAP = 'map' -HISTOGRAM = 'histogram' -WORDCLOUD = 'wordcloud' +BARCHART = "barchart" +BARCHARTH = "barcharth" +MAP = "map" +HISTOGRAM = "histogram" +WORDCLOUD = "wordcloud" DEFAULT_CHART_TYPE_FIELD_MAP = BARCHART CHART_TYPE_FIELD_MAP = { @@ -39,14 +36,13 @@ def DEFAULT_CHART_RENDER(*args, **kwargs): Field.NUMBER: HISTOGRAM, } -DEFAULT_IMAGE_PATH = os.path.join(settings.BASE_DIR, 'apps/static/image/deep_chart_preview.png') +DEFAULT_IMAGE_PATH = os.path.join(settings.BASE_DIR, "apps/static/image/deep_chart_preview.png") CHART_RENDER = { # frequency data required BARCHART: barchart.plotly, BARCHARTH: lambda *args, **kwargs: barchart.plotly(*args, **kwargs, horizontal=True), MAP: mapViz.plot, - # Frequency data not required HISTOGRAM: histograms.plotly, WORDCLOUD: wordcloud.plot, @@ -55,8 +51,8 @@ def DEFAULT_CHART_RENDER(*args, **kwargs): def get_val_column(field): if field.type in [Field.GEO, Field.DATETIME, Field.NUMBER]: - return 'processed_value' - return 'value' + return "processed_value" + return "value" def clean_real_data(data, val_column): @@ -67,18 +63,20 @@ def clean_real_data(data, val_column): # TODO: Handle the following case from pandas itself formatted_data = [] for datum in data: - formatted_data.append({ - **datum, - 'empty': datum.get('empty', False), - 'invalid': datum.get('invalid', False), - }) + formatted_data.append( + { + **datum, + "empty": datum.get("empty", False), + "invalid": datum.get("invalid", False), + } + ) df = pd.DataFrame(formatted_data) if df.empty: return df, df - filterd_df = df[~(df['empty'] == True) & ~(df['invalid'] == True)] # noqa + filterd_df = df[~(df["empty"] == True) & ~(df["invalid"] == True)] # noqa return filterd_df, df @@ -88,70 +86,71 @@ def calc_data(field): data, df = clean_real_data(field.actual_data, val_column) if data.empty: - logger.warning('Empty DataFrame: no numeric data to calculate for field ({})'.format(field.pk)) + logger.warning("Empty DataFrame: no numeric data to calculate for field ({})".format(field.pk)) return [], {} if val_column not in data.columns: - logger.warning('{} not present in field ({})'.format(val_column, field.pk)) + logger.warning("{} not present in field ({})".format(val_column, field.pk)) return None, {} - data = data.groupby(val_column).count()['empty'].sort_values().to_frame() - data = data.rename(columns={'empty': 'count', val_column: 'value'}) + data = data.groupby(val_column).count()["empty"].sort_values().to_frame() + data = data.rename(columns={"empty": "count", val_column: "value"}) - data['value'] = data.index + data["value"] = data.index health_stats = { - 'empty': int(df[df['empty'] == True]['empty'].count()), # noqa - 'invalid': int(df[df['invalid'] == True]['invalid'].count()), # noqa - 'total': len(df.index), + "empty": int(df[df["empty"] == True]["empty"].count()), # noqa + "invalid": int(df[df["invalid"] == True]["invalid"].count()), # noqa + "total": len(df.index), } - return data.to_dict(orient='records'), health_stats + return data.to_dict(orient="records"), health_stats -def generate_chart(field, chart_type, images_format=['svg']): +def generate_chart(field, chart_type, images_format=["svg"]): params = { - 'x_label': field.title, - 'y_label': 'count', - 'x_params': {}, - 'chart_size': (8, 4), - 'format': images_format, + "x_label": field.title, + "y_label": "count", + "x_params": {}, + "chart_size": (8, 4), + "format": images_format, # data will be added according to chart type } if chart_type not in [HISTOGRAM, WORDCLOUD]: - df = pd.DataFrame(field.cache.get('series')) - if df.empty or 'value' not in df.columns: + df = pd.DataFrame(field.cache.get("series")) + if df.empty or "value" not in df.columns: return None - params['data'] = df + params["data"] = df if field.type == Field.STRING: # NOTE: revered is used for ascending order - params['x_params']['autorange'] = 'reversed' - params['data']['value'] = params['data']['value'].str.slice(0, 30) + '...' # Pre slice with ellipses + params["x_params"]["autorange"] = "reversed" + params["data"]["value"] = params["data"]["value"].str.slice(0, 30) + "..." # Pre slice with ellipses elif field.type == Field.GEO: - adjust_df = pd.DataFrame([ - {'value': 0, 'count': 0}, # Count 0 is min's max value - {'value': 0, 'count': 5}, # Count 5 is max's min value - ]) - params['data'] = params['data'].append(adjust_df, ignore_index=True) + adjust_df = pd.DataFrame( + [ + {"value": 0, "count": 0}, # Count 0 is min's max value + {"value": 0, "count": 5}, # Count 5 is max's min value + ] + ) + params["data"] = params["data"].append(adjust_df, ignore_index=True) elif field.type == Field.DATETIME: - if df['value'].count() > 10: - params['x_params']['tickformat'] = '%d-%m-%Y' + if df["value"].count() > 10: + params["x_params"]["tickformat"] = "%d-%m-%Y" else: - params['x_params']['type'] = 'category' - params['x_params']['ticktext'] = [ - deep_date_format(datetime.strptime(value, '%Y-%m-%dT%H:%M:%S')) - for value in df['value'] + params["x_params"]["type"] = "category" + params["x_params"]["ticktext"] = [ + deep_date_format(datetime.strptime(value, "%Y-%m-%dT%H:%M:%S")) for value in df["value"] ] - params['x_params']['tickvals'] = df['value'] + params["x_params"]["tickvals"] = df["value"] else: val_column = get_val_column(field) df, _ = clean_real_data(field.actual_data, val_column) if chart_type == HISTOGRAM: - params['data'] = pd.to_numeric(df[val_column]) + params["data"] = pd.to_numeric(df[val_column]) elif chart_type == WORDCLOUD: - params['data'] = ' '.join(df[val_column].values) + params["data"] = " ".join(df[val_column].values) - if isinstance(params['data'], pd.DataFrame) and params['data'].empty: - logger.warning('Empty DataFrame: no numeric data to plot for field ({})'.format(field.pk)) + if isinstance(params["data"], pd.DataFrame) and params["data"].empty: + logger.warning("Empty DataFrame: no numeric data to plot for field ({})".format(field.pk)) return None chart_render = CHART_RENDER.get(chart_type) @@ -168,41 +167,41 @@ def calc_preprocessed_data(field): try: series, health_stats = calc_data(field) cache = { - 'status': Field.CACHE_SUCCESS, - 'series': series, - 'health_stats': health_stats, + "status": Field.CACHE_SUCCESS, + "series": series, + "health_stats": health_stats, } # NOTE: Geo Field cache success after chart generation if field.type == Field.GEO: - cache['status'] = Field.CACHE_PENDING + cache["status"] = Field.CACHE_PENDING except Exception: cache = { - 'status': Field.CACHE_ERROR, - 'image_status': Field.CACHE_ERROR, + "status": Field.CACHE_ERROR, + "image_status": Field.CACHE_ERROR, } logger.error( - 'Tabular Processed Data Calculation Error!!', + "Tabular Processed Data Calculation Error!!", exc_info=True, - extra={'data': {'field_id': field.pk}}, + extra={"data": {"field_id": field.pk}}, ) field.cache = cache field.save() - return field.cache['status'] + return field.cache["status"] def _add_image_to_gallery(image_name, image, mime_type, project): file = File.objects.create( title=image_name, mime_type=mime_type, - metadata={'tabular': True}, + metadata={"tabular": True}, is_public=False, ) file.file.save(image_name, image) if project: file.projects.add(project) logger.info( - 'Added image to tabular gallery {}(id={})'.format(image_name, file.id), + "Added image to tabular gallery {}(id={})".format(image_name, file.id), ) return file @@ -211,68 +210,68 @@ def render_field_chart(field): """ Save normalized data to field """ - images_format = ['png', 'svg'] if field.type == Field.GEO else ['png'] + images_format = ["png", "svg"] if field.type == Field.GEO else ["png"] chart_type = CHART_TYPE_FIELD_MAP.get(field.type, DEFAULT_CHART_TYPE_FIELD_MAP) try: images = generate_chart(field, chart_type, images_format=images_format) except Exception: - logger.error( - 'Tabular Chart Render Error!!', - exc_info=True, - extra={'data': {'field_id': field.pk}} - ) + logger.error("Tabular Chart Render Error!!", exc_info=True, extra={"data": {"field_id": field.pk}}) images = [] project = field.sheet.book.project if images and len(images) > 0: field_images = [] for image in images: - file_format = image['format'] - file_content = image['image'] + file_format = image["format"] + file_content = image["image"] file_mime = CHART_IMAGE_MIME[file_format] file = _add_image_to_gallery( - 'tabular_{}_{}.{}'.format(field.sheet.id, field.id, file_format), + "tabular_{}_{}.{}".format(field.sheet.id, field.id, file_format), file_content, file_mime, project, ) - field_images.append({ - 'id': file.id, 'chart_type': chart_type, 'format': file_format, - }) - field.cache['image_status'] = Field.CACHE_SUCCESS + field_images.append( + { + "id": file.id, + "chart_type": chart_type, + "format": file_format, + } + ) + field.cache["image_status"] = Field.CACHE_SUCCESS if field.type == Field.GEO: - field.cache['status'] = Field.CACHE_SUCCESS + field.cache["status"] = Field.CACHE_SUCCESS else: field_images = [] for image_format in images_format: - field_images.append({'id': None, 'chart_type': chart_type, 'format': image_format}) - field.cache['image_status'] = Field.CACHE_ERROR + field_images.append({"id": None, "chart_type": chart_type, "format": image_format}) + field.cache["image_status"] = Field.CACHE_ERROR if field.type == Field.GEO: - field.cache['status'] = Field.CACHE_ERROR - field.cache['images'] = field_images + field.cache["status"] = Field.CACHE_ERROR + field.cache["images"] = field_images field.save() - return field.cache['images'] + return field.cache["images"] def get_entry_image(entry): """ Use cached Graph for given entry """ - default_image = open(DEFAULT_IMAGE_PATH, 'rb') + default_image = open(DEFAULT_IMAGE_PATH, "rb") if not entry.tabular_field: return default_image field = entry.tabular_field - images = field.cache.get('images') + images = field.cache.get("images") if not images or not len(images) > 0: return default_image for image in images: - if image.get('id') is not None and image.get('format') == 'png': - file_id = images[0].get('id') + if image.get("id") is not None and image.get("format") == "png": + file_id = images[0].get("id") return File.objects.get(pk=file_id).file return default_image diff --git a/apps/tabular/viz/wordcloud.py b/apps/tabular/viz/wordcloud.py index ea8a142b3f..c5ccd8a683 100644 --- a/apps/tabular/viz/wordcloud.py +++ b/apps/tabular/viz/wordcloud.py @@ -8,7 +8,7 @@ import matplotlib.pyplot as plt from wordcloud import WordCloud except ImportError as e: - logger.warning(f'ImportError: {e}') + logger.warning(f"ImportError: {e}") @create_plot_image @@ -17,5 +17,5 @@ def plot(x_label, y_label, data): wordcloud = WordCloud(background_color="white").generate(data) # Display the generated image: - plt.imshow(wordcloud, interpolation='bilinear') + plt.imshow(wordcloud, interpolation="bilinear") plt.axis("off") diff --git a/apps/unified_connector/admin.py b/apps/unified_connector/admin.py index b0bd4a7814..a843680b33 100644 --- a/apps/unified_connector/admin.py +++ b/apps/unified_connector/admin.py @@ -1,10 +1,10 @@ from django.contrib import admin from .models import ( - UnifiedConnector, + ConnectorLead, ConnectorSource, ConnectorSourceLead, - ConnectorLead, + UnifiedConnector, ) @@ -26,11 +26,17 @@ class ConnectorSourceLeadAdmin(admin.ModelAdmin): @admin.register(ConnectorLead) class ConnectorLeadAdmin(admin.ModelAdmin): list_display = [ - 'id', - 'title', - 'created_at', - 'modified_at', + "id", + "title", + "created_at", + "modified_at", ] - readonly_fields = ('created_at', 'modified_at',) - autocomplete_fields = ('authors', 'source',) - search_fields = ('title',) + readonly_fields = ( + "created_at", + "modified_at", + ) + autocomplete_fields = ( + "authors", + "source", + ) + search_fields = ("title",) diff --git a/apps/unified_connector/apps.py b/apps/unified_connector/apps.py index 8ae418eaf4..7f8401fdf5 100644 --- a/apps/unified_connector/apps.py +++ b/apps/unified_connector/apps.py @@ -2,4 +2,4 @@ class UnifiedConnectorConfig(AppConfig): - name = 'unified_connector' + name = "unified_connector" diff --git a/apps/unified_connector/dataloaders.py b/apps/unified_connector/dataloaders.py index 514c0e4e9a..276cc2ea98 100644 --- a/apps/unified_connector/dataloaders.py +++ b/apps/unified_connector/dataloaders.py @@ -1,33 +1,31 @@ from collections import defaultdict -from promise import Promise -from django.utils.functional import cached_property + from django.db import models +from django.utils.functional import cached_property +from organization.dataloaders import OrganizationLoader +from organization.models import Organization +from promise import Promise from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin -from organization.models import Organization -from organization.dataloaders import OrganizationLoader - -from .models import ( - ConnectorLead, - ConnectorSourceLead, - ConnectorSource, -) +from .models import ConnectorLead, ConnectorSource, ConnectorSourceLead -DEFAULT_SOURCE_LEAD_COUNT = {'total': 0, 'already_added': 0, 'blocked': 0} +DEFAULT_SOURCE_LEAD_COUNT = {"total": 0, "already_added": 0, "blocked": 0} class UnifiedConnectorLeadsCount(DataLoaderWithContext): def batch_load_fn(self, keys): - connector_leads_qs = ConnectorSourceLead.objects\ - .filter(source__unified_connector__in=keys)\ - .order_by().values('source__unified_connector')\ + connector_leads_qs = ( + ConnectorSourceLead.objects.filter(source__unified_connector__in=keys) + .order_by() + .values("source__unified_connector") .annotate( - count=models.Count('connector_lead', distinct=True), - already_count=models.Count('connector_lead', distinct=True, filter=models.Q(already_added=True)), - blocked_count=models.Count('connector_lead', distinct=True, filter=models.Q(blocked=True)), - )\ - .values_list('source__unified_connector', 'count', 'already_count', 'blocked_count') + count=models.Count("connector_lead", distinct=True), + already_count=models.Count("connector_lead", distinct=True, filter=models.Q(already_added=True)), + blocked_count=models.Count("connector_lead", distinct=True, filter=models.Q(blocked=True)), + ) + .values_list("source__unified_connector", "count", "already_count", "blocked_count") + ) _map = { uc: dict( total=count or 0, @@ -41,9 +39,7 @@ def batch_load_fn(self, keys): class UnifiedConnectorSources(DataLoaderWithContext): def batch_load_fn(self, keys): - connector_source_qs = ConnectorSource.objects\ - .filter(unified_connector__in=keys)\ - .order_by('id') + connector_source_qs = ConnectorSource.objects.filter(unified_connector__in=keys).order_by("id") _map = defaultdict(list) for connector_source in connector_source_qs: _map[connector_source.unified_connector_id].append(connector_source) @@ -52,15 +48,17 @@ def batch_load_fn(self, keys): class ConnectorSourceLeadsCount(DataLoaderWithContext): def batch_load_fn(self, keys): - connector_leads_qs = ConnectorSourceLead.objects\ - .filter(source__in=keys)\ - .order_by().values('source')\ + connector_leads_qs = ( + ConnectorSourceLead.objects.filter(source__in=keys) + .order_by() + .values("source") .annotate( - count=models.Count('connector_lead', distinct=True), - already_count=models.Count('connector_lead', distinct=True, filter=models.Q(already_added=True)), - blocked_count=models.Count('connector_lead', distinct=True, filter=models.Q(blocked=True)), - )\ - .values_list('source', 'count', 'already_count', 'blocked_count') + count=models.Count("connector_lead", distinct=True), + already_count=models.Count("connector_lead", distinct=True, filter=models.Q(already_added=True)), + blocked_count=models.Count("connector_lead", distinct=True, filter=models.Q(blocked=True)), + ) + .values_list("source", "count", "already_count", "blocked_count") + ) _map = { uc: dict( total=count or 0, @@ -74,22 +72,18 @@ def batch_load_fn(self, keys): class ConnectorSourceLeadLead(DataLoaderWithContext): def batch_load_fn(self, keys): - connector_leads_qs = ConnectorLead.objects\ - .filter(id__in=keys)\ - .order_by('id') - _map = { - connector_lead.pk: connector_lead - for connector_lead in connector_leads_qs - } + connector_leads_qs = ConnectorLead.objects.filter(id__in=keys).order_by("id") + _map = {connector_lead.pk: connector_lead for connector_lead in connector_leads_qs} return Promise.resolve([_map[key] for key in keys]) class ConnectorLeadAuthors(DataLoaderWithContext): def batch_load_fn(self, keys): - connector_lead_author_qs = ConnectorLead.objects\ - .filter(id__in=keys, authors__isnull=False)\ - .order_by('authors__id')\ - .values_list('id', 'authors__id') + connector_lead_author_qs = ( + ConnectorLead.objects.filter(id__in=keys, authors__isnull=False) + .order_by("authors__id") + .values_list("id", "authors__id") + ) connector_lead_authors_ids = defaultdict(list) organizations_id = set() for connector_lead_id, author_id in connector_lead_author_qs: @@ -97,16 +91,10 @@ def batch_load_fn(self, keys): organizations_id.add(author_id) organization_qs = Organization.objects.filter(id__in=organizations_id) - organizations_map = { - org.id: org for org in organization_qs - } - return Promise.resolve([ - [ - organizations_map.get(author) - for author in connector_lead_authors_ids.get(key, []) - ] - for key in keys - ]) + organizations_map = {org.id: org for org in organization_qs} + return Promise.resolve( + [[organizations_map.get(author) for author in connector_lead_authors_ids.get(key, [])] for key in keys] + ) class DataLoaders(WithContextMixin): diff --git a/apps/unified_connector/enums.py b/apps/unified_connector/enums.py index 44397d1bb7..225179da1f 100644 --- a/apps/unified_connector/enums.py +++ b/apps/unified_connector/enums.py @@ -5,12 +5,13 @@ get_enum_name_from_django_field, ) -from .models import ConnectorSource, ConnectorLead +from .models import ConnectorLead, ConnectorSource -ConnectorSourceSourceEnum = convert_enum_to_graphene_enum(ConnectorSource.Source, name='ConnectorSourceSourceEnum') +ConnectorSourceSourceEnum = convert_enum_to_graphene_enum(ConnectorSource.Source, name="ConnectorSourceSourceEnum") ConnectorLeadExtractionStatusEnum = convert_enum_to_graphene_enum( - ConnectorLead.ExtractionStatus, name='ConnectorLeadExtractionStatusEnum') -ConnectorSourceStatusEnum = convert_enum_to_graphene_enum(ConnectorSource.Status, name='ConnectorSourceStatusEnum') + ConnectorLead.ExtractionStatus, name="ConnectorLeadExtractionStatusEnum" +) +ConnectorSourceStatusEnum = convert_enum_to_graphene_enum(ConnectorSource.Status, name="ConnectorSourceStatusEnum") enum_map = { get_enum_name_from_django_field(field): enum @@ -24,34 +25,34 @@ class UnifiedConnectorOrderingEnum(graphene.Enum): # ASC - ASC_ID = 'id' - ASC_CREATED_AT = 'created_at' - ASC_TITLE = 'title' + ASC_ID = "id" + ASC_CREATED_AT = "created_at" + ASC_TITLE = "title" # DESC - DESC_ID = f'-{ASC_ID}' - DESC_CREATED_AT = f'-{ASC_CREATED_AT}' - DESC_TITLE = f'-{ASC_TITLE}' + DESC_ID = f"-{ASC_ID}" + DESC_CREATED_AT = f"-{ASC_CREATED_AT}" + DESC_TITLE = f"-{ASC_TITLE}" class ConnectorSourceOrderingEnum(graphene.Enum): # ASC - ASC_ID = 'id' - ASC_CREATED_AT = 'created_at' - ASC_TITLE = 'title' - ASC_SOURCE = 'source' + ASC_ID = "id" + ASC_CREATED_AT = "created_at" + ASC_TITLE = "title" + ASC_SOURCE = "source" # DESC - DESC_ID = f'-{ASC_ID}' - DESC_CREATED_AT = f'-{ASC_CREATED_AT}' - DESC_TITLE = f'-{ASC_TITLE}' - DESC_SOURCE = f'-{ASC_SOURCE}' + DESC_ID = f"-{ASC_ID}" + DESC_CREATED_AT = f"-{ASC_CREATED_AT}" + DESC_TITLE = f"-{ASC_TITLE}" + DESC_SOURCE = f"-{ASC_SOURCE}" class ConnectorSourceLeadOrderingEnum(graphene.Enum): # ASC - ASC_ID = 'id' - ASC_LEAD_CREATED_AT = 'connector_lead__created_at' - ASC_LEAD_TITLE = 'connector_lead__title' + ASC_ID = "id" + ASC_LEAD_CREATED_AT = "connector_lead__created_at" + ASC_LEAD_TITLE = "connector_lead__title" # DESC - DESC_ID = f'-{ASC_ID}' - DESC_LEAD_CREATED_AT = f'-{ASC_LEAD_CREATED_AT}' - DESC_LEAD_TITLE = f'-{ASC_LEAD_TITLE}' + DESC_ID = f"-{ASC_ID}" + DESC_LEAD_CREATED_AT = f"-{ASC_LEAD_CREATED_AT}" + DESC_LEAD_TITLE = f"-{ASC_LEAD_TITLE}" diff --git a/apps/unified_connector/factories.py b/apps/unified_connector/factories.py index dbef764b3d..3580eed27b 100644 --- a/apps/unified_connector/factories.py +++ b/apps/unified_connector/factories.py @@ -1,31 +1,30 @@ import factory from factory.django import DjangoModelFactory - from unified_connector.models import ( - UnifiedConnector, - ConnectorSource, ConnectorLead, + ConnectorSource, ConnectorSourceLead, + UnifiedConnector, ) class UnifiedConnectorFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'Unified-Connector-{n}') + title = factory.Sequence(lambda n: f"Unified-Connector-{n}") class Meta: model = UnifiedConnector class ConnectorSourceFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'Connector-Source-{n}') + title = factory.Sequence(lambda n: f"Connector-Source-{n}") class Meta: model = ConnectorSource class ConnectorLeadFactory(DjangoModelFactory): - title = factory.Sequence(lambda n: f'Connector-Lead-{n}') - url = factory.Sequence(lambda n: f'https://example.com/path-{n}') + title = factory.Sequence(lambda n: f"Connector-Lead-{n}") + url = factory.Sequence(lambda n: f"https://example.com/path-{n}") class Meta: model = ConnectorLead diff --git a/apps/unified_connector/filters.py b/apps/unified_connector/filters.py index fba36bc7d7..fdc9e544cc 100644 --- a/apps/unified_connector/filters.py +++ b/apps/unified_connector/filters.py @@ -1,33 +1,29 @@ -from django.db import models import django_filters +from django.db import models from deep.filter_set import OrderEnumMixin from utils.graphene.filters import ( - MultipleInputFilter, - IDListFilter, DateGteFilter, DateLteFilter, + IDListFilter, + MultipleInputFilter, ) -from .models import ( - ConnectorSource, - ConnectorSourceLead, - UnifiedConnector, -) from .enums import ( - ConnectorSourceSourceEnum, ConnectorLeadExtractionStatusEnum, ConnectorSourceLeadOrderingEnum, ConnectorSourceOrderingEnum, + ConnectorSourceSourceEnum, ConnectorSourceStatusEnum, UnifiedConnectorOrderingEnum, ) +from .models import ConnectorSource, ConnectorSourceLead, UnifiedConnector # ------------------------------ Graphql filters ----------------------------------- class UnifiedConnectorGQFilterSet(OrderEnumMixin, django_filters.FilterSet): - search = django_filters.CharFilter(field_name='title', lookup_expr='icontains') - ordering = MultipleInputFilter(UnifiedConnectorOrderingEnum, method='ordering_filter') + search = django_filters.CharFilter(field_name="title", lookup_expr="icontains") + ordering = MultipleInputFilter(UnifiedConnectorOrderingEnum, method="ordering_filter") is_active = django_filters.BooleanFilter() class Meta: @@ -36,11 +32,11 @@ class Meta: class ConnectorSourceGQFilterSet(OrderEnumMixin, django_filters.FilterSet): - search = django_filters.CharFilter(field_name='title', lookup_expr='icontains') - ordering = MultipleInputFilter(ConnectorSourceOrderingEnum, method='ordering_filter') - sources = MultipleInputFilter(ConnectorSourceSourceEnum, field_name='source') - statuses = MultipleInputFilter(ConnectorSourceStatusEnum, field_name='status') - unified_connectors = IDListFilter(field_name='unified_connector') + search = django_filters.CharFilter(field_name="title", lookup_expr="icontains") + ordering = MultipleInputFilter(ConnectorSourceOrderingEnum, method="ordering_filter") + sources = MultipleInputFilter(ConnectorSourceSourceEnum, field_name="source") + statuses = MultipleInputFilter(ConnectorSourceStatusEnum, field_name="status") + unified_connectors = IDListFilter(field_name="unified_connector") class Meta: model = ConnectorSource @@ -48,18 +44,17 @@ class Meta: class ConnectorSourceLeadGQFilterSet(OrderEnumMixin, django_filters.FilterSet): - ordering = MultipleInputFilter(ConnectorSourceLeadOrderingEnum, method='ordering_filter') - sources = IDListFilter(field_name='source') + ordering = MultipleInputFilter(ConnectorSourceLeadOrderingEnum, method="ordering_filter") + sources = IDListFilter(field_name="source") blocked = django_filters.BooleanFilter() already_added = django_filters.BooleanFilter() - extraction_status = MultipleInputFilter( - ConnectorLeadExtractionStatusEnum, field_name='connector_lead__extraction_status') + extraction_status = MultipleInputFilter(ConnectorLeadExtractionStatusEnum, field_name="connector_lead__extraction_status") - search = django_filters.CharFilter(method='search_filter') - author_organizations = IDListFilter(field_name='connector_lead__authors') - published_on = django_filters.DateFilter(field_name='connector_lead__published_on') - published_on_gte = DateGteFilter(field_name='connector_lead__published_on') - published_on_lte = DateLteFilter(field_name='connector_lead__published_on') + search = django_filters.CharFilter(method="search_filter") + author_organizations = IDListFilter(field_name="connector_lead__authors") + published_on = django_filters.DateFilter(field_name="connector_lead__published_on") + published_on_gte = DateGteFilter(field_name="connector_lead__published_on") + published_on_lte = DateLteFilter(field_name="connector_lead__published_on") class Meta: model = ConnectorSourceLead @@ -71,16 +66,19 @@ def search_filter(self, qs, _, value): return qs return qs.filter( # By title - models.Q(connector_lead__title__icontains=value) | + models.Q(connector_lead__title__icontains=value) + | # By source - models.Q(connector_lead__source_raw__icontains=value) | - models.Q(connector_lead__source__title__icontains=value) | - models.Q(connector_lead__source__parent__title__icontains=value) | + models.Q(connector_lead__source_raw__icontains=value) + | models.Q(connector_lead__source__title__icontains=value) + | models.Q(connector_lead__source__parent__title__icontains=value) + | # By author - models.Q(connector_lead__author_raw__icontains=value) | - models.Q(connector_lead__authors__title__icontains=value) | - models.Q(connector_lead__authors__parent__title__icontains=value) | + models.Q(connector_lead__author_raw__icontains=value) + | models.Q(connector_lead__authors__title__icontains=value) + | models.Q(connector_lead__authors__parent__title__icontains=value) + | # By URL - models.Q(connector_lead__url__icontains=value) | - models.Q(connector_lead__website__icontains=value) + models.Q(connector_lead__url__icontains=value) + | models.Q(connector_lead__website__icontains=value) ).distinct() diff --git a/apps/unified_connector/models.py b/apps/unified_connector/models.py index efb04a7293..513f28fd69 100644 --- a/apps/unified_connector/models.py +++ b/apps/unified_connector/models.py @@ -1,30 +1,29 @@ from typing import Union -from django.db import models -from django.db import transaction -from user_resource.models import UserResource - +from django.db import models, transaction from lead.models import Lead from organization.models import Organization from project.models import Project, ProjectStats +from user_resource.models import UserResource + from .sources import ( atom_feed, - rss_feed, - unhcr_portal, - relief_web, + emm, humanitarian_response, pdna, - emm, + relief_web, + rss_feed, + unhcr_portal, ) class ConnectorLead(models.Model): class ExtractionStatus(models.IntegerChoices): - PENDING = 0, 'Pending' - RETRYING = 1, 'Retrying' - STARTED = 2, 'Started' - SUCCESS = 3, 'Success' - FAILED = 4, 'Failed' + PENDING = 0, "Pending" + RETRYING = 1, "Retrying" + STARTED = 2, "Started" + SUCCESS = 3, "Success" + FAILED = 4, "Failed" id: Union[int, None] url = models.TextField(unique=True) @@ -33,8 +32,8 @@ class ExtractionStatus(models.IntegerChoices): published_on = models.DateField(default=None, null=True, blank=True) source_raw = models.CharField(max_length=255, blank=True) author_raw = models.CharField(max_length=255, blank=True) - authors = models.ManyToManyField(Organization, blank=True, related_name='+') - source = models.ForeignKey(Organization, related_name='+', on_delete=models.SET_NULL, null=True, blank=True) + authors = models.ManyToManyField(Organization, blank=True, related_name="+") + source = models.ForeignKey(Organization, related_name="+", on_delete=models.SET_NULL, null=True, blank=True) # Extracted data simplified_text = models.TextField(blank=True) @@ -44,9 +43,7 @@ class ExtractionStatus(models.IntegerChoices): created_at = models.DateTimeField(auto_now_add=True) modified_at = models.DateTimeField(auto_now=True) - extraction_status = models.SmallIntegerField( - choices=ExtractionStatus.choices, default=ExtractionStatus.PENDING - ) + extraction_status = models.SmallIntegerField(choices=ExtractionStatus.choices, default=ExtractionStatus.PENDING) def __init__(self, *args, **kwargs): self.preview_images: models.QuerySet[ConnectorLeadPreviewImage] @@ -59,15 +56,15 @@ def get_or_create_from_lead(cls, lead: Lead): defaults=dict( title=lead.title, published_on=lead.published_on, - source_raw=lead.source_raw or '', - author_raw=lead.author_raw or '', + source_raw=lead.source_raw or "", + author_raw=lead.author_raw or "", source=lead.source, ), ) if not created: return instance, False # NOTE: Custom attributes from connector - authors = getattr(lead, '_authors', None) + authors = getattr(lead, "_authors", None) if authors: instance.authors.set(authors) return instance, True @@ -75,18 +72,19 @@ def get_or_create_from_lead(cls, lead: Lead): def update_extraction_status(self, new_status, commit=True): self.extraction_status = new_status if commit: - self.save(update_fields=('extraction_status',)) + self.save(update_fields=("extraction_status",)) class ConnectorLeadPreviewImage(models.Model): - connector_lead = models.ForeignKey(ConnectorLead, on_delete=models.CASCADE, related_name='preview_images') - image = models.FileField(upload_to='connector-lead/preview-images/', max_length=255) + connector_lead = models.ForeignKey(ConnectorLead, on_delete=models.CASCADE, related_name="preview_images") + image = models.FileField(upload_to="connector-lead/preview-images/", max_length=255) class UnifiedConnector(UserResource): """ Unified Connector: Contains source level connector """ + title = models.CharField(max_length=255) project = models.ForeignKey(Project, on_delete=models.CASCADE) is_active = models.BooleanField(default=False) @@ -102,19 +100,19 @@ def can_delete(self, _): class ConnectorSource(UserResource): class Source(models.TextChoices): - ATOM_FEED = 'atom-feed', 'Atom Feed' - RELIEF_WEB = 'relief-web', 'Relifweb' - RSS_FEED = 'rss-feed', 'RSS Feed' - UNHCR = 'unhcr-portal', 'UNHCR Portal' - HUMANITARIAN_RESP = 'humanitarian-resp', 'Humanitarian Response' - PDNA = 'pdna', 'Post Disaster Needs Assessments' - EMM = 'emm', 'European Media Monitor' + ATOM_FEED = "atom-feed", "Atom Feed" + RELIEF_WEB = "relief-web", "Relifweb" + RSS_FEED = "rss-feed", "RSS Feed" + UNHCR = "unhcr-portal", "UNHCR Portal" + HUMANITARIAN_RESP = "humanitarian-resp", "Humanitarian Response" + PDNA = "pdna", "Post Disaster Needs Assessments" + EMM = "emm", "European Media Monitor" class Status(models.IntegerChoices): - PENDING = 0, 'Pending' - PROCESSING = 1, 'Processing' - SUCCESS = 2, 'success' - FAILURE = 3, 'failure' + PENDING = 0, "Pending" + PROCESSING = 1, "Processing" + SUCCESS = 2, "success" + FAILURE = 3, "failure" SOURCE_FETCHER_MAP = { Source.ATOM_FEED: atom_feed.AtomFeed, @@ -127,15 +125,16 @@ class Status(models.IntegerChoices): } title = models.CharField(max_length=255) - unified_connector = models.ForeignKey(UnifiedConnector, on_delete=models.CASCADE, related_name='sources') + unified_connector = models.ForeignKey(UnifiedConnector, on_delete=models.CASCADE, related_name="sources") source = models.CharField(max_length=20, choices=Source.choices) params = models.JSONField(default=dict) client_id = None leads = models.ManyToManyField( - ConnectorLead, blank=True, - through_fields=('source', 'connector_lead'), - through='ConnectorSourceLead', + ConnectorLead, + blank=True, + through_fields=("source", "connector_lead"), + through="ConnectorSourceLead", ) last_fetched_at = models.DateTimeField(blank=True, null=True) stats = models.JSONField(default=dict) # {published_dates: ['date': <>, 'count': <>]} @@ -158,18 +157,23 @@ def source_fetcher(self): def generate_stats(self, commit=True): threshold = ProjectStats.get_activity_timeframe() self.stats = { - 'published_dates': [ + "published_dates": [ { - 'date': str(date), - 'count': count, - } for count, date in self.leads.filter( + "date": str(date), + "count": count, + } + for count, date in self.leads.filter( published_on__isnull=False, published_on__gte=threshold, - ).order_by().values('published_on').annotate( - count=models.Count('*'), - ).values_list('count', models.F('published_on')) + ) + .order_by() + .values("published_on") + .annotate( + count=models.Count("*"), + ) + .values_list("count", models.F("published_on")) ], - 'leads_count': self.leads.count(), + "leads_count": self.leads.count(), } if commit: self.save() @@ -184,9 +188,8 @@ def add_lead(self, lead, **kwargs): ) def save(self, *args, **kwargs): - params_changed = ( - self.old_params != self.params and - ('params' in kwargs['update_fields'] if 'update_fields' in kwargs else True) + params_changed = self.old_params != self.params and ( + "params" in kwargs["update_fields"] if "update_fields" in kwargs else True ) if params_changed: # Reset attributes if params are changed @@ -194,19 +197,14 @@ def save(self, *args, **kwargs): self.last_fetched_at = None self.stats = {} self.status = ConnectorSource.Status.PENDING - if 'update_fields' in kwargs: - kwargs['update_fields'] = list(set([ - 'stats', - 'status', - 'last_fetched_at', - *kwargs['update_fields'] - ])) + if "update_fields" in kwargs: + kwargs["update_fields"] = list(set(["stats", "status", "last_fetched_at", *kwargs["update_fields"]])) super().save(*args, **kwargs) self.old_params = self.params class ConnectorSourceLead(models.Model): # ConnectorSource's Leads - source = models.ForeignKey(ConnectorSource, on_delete=models.CASCADE, related_name='source_leads') + source = models.ForeignKey(ConnectorSource, on_delete=models.CASCADE, related_name="source_leads") connector_lead = models.ForeignKey(ConnectorLead, on_delete=models.CASCADE) blocked = models.BooleanField(default=False) already_added = models.BooleanField(default=False) diff --git a/apps/unified_connector/mutation.py b/apps/unified_connector/mutation.py index 54c8bc070e..47a82efc60 100644 --- a/apps/unified_connector/mutation.py +++ b/apps/unified_connector/mutation.py @@ -1,43 +1,36 @@ import graphene +from deep.permissions import ProjectPermissions as PP from utils.graphene.mutation import ( - generate_input_type_for_serializer, - PsGrapheneMutation, PsDeleteMutation, + PsGrapheneMutation, + generate_input_type_for_serializer, ) -from deep.permissions import ProjectPermissions as PP -from .models import ( - UnifiedConnector, - ConnectorSourceLead, -) -from .schema import ( - UnifiedConnectorType, - ConnectorSourceLeadType, -) +from .models import ConnectorSourceLead, UnifiedConnector +from .schema import ConnectorSourceLeadType, UnifiedConnectorType from .serializers import ( + ConnectorSourceLeadGqSerializer, UnifiedConnectorGqSerializer, UnifiedConnectorWithSourceGqSerializer, - ConnectorSourceLeadGqSerializer, ) from .tasks import process_unified_connector - UnifiedConnectorInputType = generate_input_type_for_serializer( - 'UnifiedConnectorInputType', + "UnifiedConnectorInputType", serializer_class=UnifiedConnectorGqSerializer, ) UnifiedConnectorWithSourceInputType = generate_input_type_for_serializer( - 'UnifiedConnectorWithSourceInputType', + "UnifiedConnectorWithSourceInputType", serializer_class=UnifiedConnectorWithSourceGqSerializer, ) ConnectorSourceLeadInputType = generate_input_type_for_serializer( - 'ConnectorSourceLeadInputType', + "ConnectorSourceLeadInputType", serializer_class=ConnectorSourceLeadGqSerializer, ) -class UnifiedConnectorMixin(): +class UnifiedConnectorMixin: @classmethod def filter_queryset(cls, qs, info): return qs.filter(project=info.context.active_project) @@ -46,6 +39,7 @@ def filter_queryset(cls, qs, info): class CreateUnifiedConnector(UnifiedConnectorMixin, PsGrapheneMutation): class Arguments: data = UnifiedConnectorWithSourceInputType(required=True) + model = UnifiedConnector serializer_class = UnifiedConnectorWithSourceGqSerializer result = graphene.Field(UnifiedConnectorType) @@ -56,6 +50,7 @@ class UpdateUnifiedConnector(UnifiedConnectorMixin, PsGrapheneMutation): class Arguments: id = graphene.ID(required=True) data = UnifiedConnectorInputType(required=True) + model = UnifiedConnector serializer_class = UnifiedConnectorGqSerializer result = graphene.Field(UnifiedConnectorType) @@ -66,6 +61,7 @@ class UpdateUnifiedConnectorWithSource(UnifiedConnectorMixin, PsGrapheneMutation class Arguments: id = graphene.ID(required=True) data = UnifiedConnectorWithSourceInputType(required=True) + model = UnifiedConnector serializer_class = UnifiedConnectorWithSourceGqSerializer result = graphene.Field(UnifiedConnectorType) @@ -75,6 +71,7 @@ class Arguments: class DeleteUnifiedConnector(UnifiedConnectorMixin, PsDeleteMutation): class Arguments: id = graphene.ID(required=True) + model = UnifiedConnector result = graphene.Field(UnifiedConnectorType) permissions = [PP.Permission.DELETE_UNIFIED_CONNECTOR] @@ -83,6 +80,7 @@ class Arguments: class TriggerUnifiedConnector(UnifiedConnectorMixin, PsGrapheneMutation): class Arguments: id = graphene.ID(required=True) + model = UnifiedConnector serializer_class = UnifiedConnectorGqSerializer permissions = [PP.Permission.VIEW_UNIFIED_CONNECTOR] @@ -95,7 +93,7 @@ def perform_mutate(cls, _, info, **kwargs): if instance.is_active: process_unified_connector.delay(instance.pk) return cls(errors=None, ok=True) - errors = [dict(field='nonFieldErrors', message='Inactive unified connector!!')] + errors = [dict(field="nonFieldErrors", message="Inactive unified connector!!")] return cls(errors=errors, ok=False) @@ -103,6 +101,7 @@ class UpdateConnectorSourceLead(PsGrapheneMutation): class Arguments: id = graphene.ID(required=True) data = ConnectorSourceLeadInputType(required=True) + model = ConnectorSourceLead serializer_class = ConnectorSourceLeadGqSerializer permissions = [PP.Permission.VIEW_UNIFIED_CONNECTOR] @@ -110,9 +109,7 @@ class Arguments: @classmethod def filter_queryset(cls, qs, info): - return qs.filter( - source__unified_connector__project=info.context.active_project - ) + return qs.filter(source__unified_connector__project=info.context.active_project) class UnifiedConnectorMutationType(graphene.ObjectType): diff --git a/apps/unified_connector/schema.py b/apps/unified_connector/schema.py index 3d281b7719..0c7f16afe7 100644 --- a/apps/unified_connector/schema.py +++ b/apps/unified_connector/schema.py @@ -1,33 +1,34 @@ -import graphene import datetime -from django.db.models import QuerySet, Q + +import graphene +from django.db.models import Q, QuerySet from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField +from unified_connector.sources.atom_feed import AtomFeed +from unified_connector.sources.rss_feed import RssFeed +from user_resource.schema import UserResourceMixin +from deep.permissions import ProjectPermissions as PP from utils.graphene.enums import EnumDescription -from utils.graphene.pagination import NoOrderingPageGraphqlPagination -from utils.graphene.types import CustomDjangoListObjectType, ClientIdMixin from utils.graphene.fields import DjangoPaginatedListObjectField -from user_resource.schema import UserResourceMixin -from deep.permissions import ProjectPermissions as PP -from unified_connector.sources.rss_feed import RssFeed -from unified_connector.sources.atom_feed import AtomFeed +from utils.graphene.pagination import NoOrderingPageGraphqlPagination +from utils.graphene.types import ClientIdMixin, CustomDjangoListObjectType +from .enums import ( + ConnectorLeadExtractionStatusEnum, + ConnectorSourceSourceEnum, + ConnectorSourceStatusEnum, +) from .filters import ( ConnectorSourceGQFilterSet, ConnectorSourceLeadGQFilterSet, UnifiedConnectorGQFilterSet, ) from .models import ( - UnifiedConnector, ConnectorLead, ConnectorSource, ConnectorSourceLead, -) -from .enums import ( - ConnectorSourceSourceEnum, - ConnectorSourceStatusEnum, - ConnectorLeadExtractionStatusEnum, + UnifiedConnector, ) @@ -55,20 +56,20 @@ def get_connector_source_lead_qs(info): # NOTE: This is not used directly class ConnectorLeadType(DjangoObjectType): extraction_status = graphene.Field(ConnectorLeadExtractionStatusEnum, required=True) - extraction_status_display = EnumDescription(source='get_extraction_status_display', required=True) + extraction_status_display = EnumDescription(source="get_extraction_status_display", required=True) class Meta: model = ConnectorLead only_fields = ( - 'id', - 'url', - 'website', - 'title', - 'published_on', - 'source_raw', - 'author_raw', - 'source', - 'authors', + "id", + "url", + "website", + "title", + "published_on", + "source_raw", + "author_raw", + "source", + "authors", ) @staticmethod @@ -82,14 +83,14 @@ def resolve_authors(root, info, **_): class ConnectorSourceLeadType(DjangoObjectType): connector_lead = graphene.Field(ConnectorLeadType, required=True) - source = graphene.ID(required=True, source='source_id') + source = graphene.ID(required=True, source="source_id") class Meta: model = ConnectorSourceLead only_fields = ( - 'id', - 'blocked', - 'already_added', + "id", + "blocked", + "already_added", ) @staticmethod @@ -113,7 +114,7 @@ class ConnectorSourceStatsType(graphene.ObjectType): @staticmethod def resolve_date(root, info, **kwargs): - return datetime.datetime.strptime(root['date'], '%Y-%m-%d') + return datetime.datetime.strptime(root["date"], "%Y-%m-%d") class ConnectorSourceLeadCountType(graphene.ObjectType): @@ -124,21 +125,21 @@ class ConnectorSourceLeadCountType(graphene.ObjectType): class ConnectorSourceType(UserResourceMixin, ClientIdMixin, DjangoObjectType): source = graphene.Field(ConnectorSourceSourceEnum, required=True) - source_display = EnumDescription(source='get_source_display', required=True) - unified_connector = graphene.ID(required=True, source='unified_connector_id') + source_display = EnumDescription(source="get_source_display", required=True) + unified_connector = graphene.ID(required=True, source="unified_connector_id") stats = graphene.List(ConnectorSourceStatsType) leads_count = graphene.NonNull(ConnectorSourceLeadCountType) status = graphene.Field(ConnectorSourceStatusEnum, required=True) - status_display = EnumDescription(source='get_status_display', required=True) + status_display = EnumDescription(source="get_status_display", required=True) class Meta: model = ConnectorSource only_fields = ( - 'id', - 'title', - 'unified_connector', - 'last_fetched_at', - 'params', + "id", + "title", + "unified_connector", + "last_fetched_at", + "params", ) @staticmethod @@ -147,7 +148,7 @@ def get_custom_queryset(queryset, info, **_): @staticmethod def resolve_stats(root, info, **_): - return (root.stats or {}).get('published_dates') or [] + return (root.stats or {}).get("published_dates") or [] @staticmethod def resolve_leads_count(root, info, **_): @@ -161,16 +162,16 @@ class Meta: class UnifiedConnectorType(UserResourceMixin, ClientIdMixin, DjangoObjectType): - project = graphene.ID(required=True, source='project_id') + project = graphene.ID(required=True, source="project_id") sources = graphene.List(graphene.NonNull(ConnectorSourceType)) leads_count = graphene.NonNull(ConnectorSourceLeadCountType) class Meta: model = UnifiedConnector only_fields = ( - 'id', - 'title', - 'is_active', + "id", + "title", + "is_active", ) @staticmethod @@ -198,22 +199,22 @@ class UnifiedConnectorQueryType(graphene.ObjectType): unified_connectors = DjangoPaginatedListObjectField( UnifiedConnectorListType, pagination=NoOrderingPageGraphqlPagination( - page_size_query_param='pageSize', - ) + page_size_query_param="pageSize", + ), ) connector_source = DjangoObjectField(ConnectorSourceType) connector_sources = DjangoPaginatedListObjectField( ConnectorSourceListType, pagination=NoOrderingPageGraphqlPagination( - page_size_query_param='pageSize', - ) + page_size_query_param="pageSize", + ), ) connector_source_lead = DjangoObjectField(ConnectorSourceLeadType) connector_source_leads = DjangoPaginatedListObjectField( ConnectorSourceLeadListType, pagination=NoOrderingPageGraphqlPagination( - page_size_query_param='pageSize', - ) + page_size_query_param="pageSize", + ), ) source_count_without_ingnored_and_added = graphene.Field(graphene.Int) @@ -222,10 +223,7 @@ def resolve_source_count_without_ingnored_and_added(root, info, **kwargs): qs = ConnectorSourceLead.objects.filter( source__unified_connector__project=info.context.active_project, source__unified_connector__is_active=True, - ).exclude( - Q(blocked=True) | - Q(already_added=True) - ) + ).exclude(Q(blocked=True) | Q(already_added=True)) if PP.check_permission(info, PP.Permission.VIEW_UNIFIED_CONNECTOR): return qs.count() return diff --git a/apps/unified_connector/serializers.py b/apps/unified_connector/serializers.py index 20b99a6c8c..2c17612f49 100644 --- a/apps/unified_connector/serializers.py +++ b/apps/unified_connector/serializers.py @@ -1,19 +1,16 @@ import logging -from rest_framework import serializers + from django.db import transaction +from rest_framework import serializers +from user_resource.serializers import UserResourceSerializer from deep.serializers import ( - TempClientIdMixin, - ProjectPropertySerializerMixin, IntegerIDField, + ProjectPropertySerializerMixin, + TempClientIdMixin, ) -from user_resource.serializers import UserResourceSerializer -from .models import ( - UnifiedConnector, - ConnectorSource, - ConnectorSourceLead, -) +from .models import ConnectorSource, ConnectorSourceLead, UnifiedConnector from .tasks import process_unified_connector logger = logging.getLogger(__name__) @@ -22,16 +19,16 @@ # ------------------- Graphql Serializers ------------------------------------ class ConnectorSourceGqSerializer(ProjectPropertySerializerMixin, TempClientIdMixin, UserResourceSerializer): id = IntegerIDField(required=False) - project_property_attribute = 'unified_connector' + project_property_attribute = "unified_connector" class Meta: model = ConnectorSource fields = ( - 'id', - 'title', - 'source', - 'params', - 'client_id', # From TempClientIdMixin + "id", + "title", + "source", + "params", + "client_id", # From TempClientIdMixin ) @@ -39,20 +36,18 @@ class UnifiedConnectorGqSerializer(ProjectPropertySerializerMixin, TempClientIdM class Meta: model = UnifiedConnector fields = ( - 'title', - 'is_active', - 'client_id', # From TempClientIdMixin + "title", + "is_active", + "client_id", # From TempClientIdMixin ) def validate(self, data): - data['project'] = self.project + data["project"] = self.project return data def create(self, data): instance = super().create(data) - transaction.on_commit( - lambda: process_unified_connector.delay(instance.pk) - ) + transaction.on_commit(lambda: process_unified_connector.delay(instance.pk)) return instance @@ -63,7 +58,7 @@ class Meta: model = UnifiedConnector fields = [ *UnifiedConnectorGqSerializer.Meta.fields, - 'sources', + "sources", ] # NOTE: This is a custom function (apps/user_resource/serializers.py::UserResourceSerializer) @@ -77,9 +72,9 @@ def validate_sources(self, sources): source_found = set() # Only allow unique source per unified connectors for source in sources: - source_type = source['source'] + source_type = source["source"] if source_type in source_found: - raise serializers.ValidationError(f'Multiple connector found for {source_type}') + raise serializers.ValidationError(f"Multiple connector found for {source_type}") source_found.add(source_type) return sources @@ -87,6 +82,4 @@ def validate_sources(self, sources): class ConnectorSourceLeadGqSerializer(serializers.ModelSerializer): class Meta: model = ConnectorSourceLead - fields = ( - 'blocked', - ) + fields = ("blocked",) diff --git a/apps/unified_connector/sources/acaps_briefing_notes.py b/apps/unified_connector/sources/acaps_briefing_notes.py index 6ae2e972b7..16921d3d98 100644 --- a/apps/unified_connector/sources/acaps_briefing_notes.py +++ b/apps/unified_connector/sources/acaps_briefing_notes.py @@ -1,141 +1,137 @@ -import logging -from bs4 import BeautifulSoup as Soup -import requests import datetime +import logging -from .base import Source +import requests +from bs4 import BeautifulSoup as Soup from connector.utils import ConnectorWrapper from lead.models import Lead +from .base import Source + logger = logging.getLogger(__name__) COUNTRIES_OPTIONS = [ - {'key': 'All', 'label': 'Any'}, - {'key': '196', 'label': 'Afghanistan'}, - {'key': '202', 'label': 'Angola'}, - {'key': '214', 'label': 'Bangladesh'}, - {'key': '219', 'label': 'Benin'}, - {'key': '222', 'label': 'Bolivia'}, - {'key': '224', 'label': 'Bosnia and Herzegovina'}, - {'key': '226', 'label': 'Brazil'}, - {'key': '230', 'label': 'Burkina Faso'}, - {'key': '231', 'label': 'Burundi'}, - {'key': '233', 'label': 'Cambodia'}, - {'key': '234', 'label': 'Cameroon'}, - {'key': '237', 'label': 'CAR'}, - {'key': '239', 'label': 'Chad'}, - {'key': '242', 'label': 'China'}, - {'key': '248', 'label': 'Colombia'}, - {'key': '250', 'label': 'Congo'}, - {'key': '253', 'label': "Côte d'Ivoire"}, - {'key': '254', 'label': 'Croatia'}, - {'key': '262', 'label': 'Djibouti'}, - {'key': '263', 'label': 'Dominica'}, - {'key': '264', 'label': 'Dominican Republic'}, - {'key': '259', 'label': 'DPRK'}, - {'key': '260', 'label': 'DRC'}, - {'key': '266', 'label': 'Ecuador'}, - {'key': '267', 'label': 'Egypt'}, - {'key': '268', 'label': 'El Salvador'}, - {'key': '270', 'label': 'Eritrea'}, - {'key': '272', 'label': 'Ethiopia'}, - {'key': '275', 'label': 'Fiji'}, - {'key': '277', 'label': 'France'}, - {'key': '287', 'label': 'Greece'}, - {'key': '292', 'label': 'Guatemala'}, - {'key': '293', 'label': 'Guinea'}, - {'key': '642', 'label': 'Haiti'}, - {'key': '302', 'label': 'India'}, - {'key': '303', 'label': 'Indonesia'}, - {'key': '304', 'label': 'Iran'}, - {'key': '305', 'label': 'Iraq'}, - {'key': '312', 'label': 'Jordan'}, - {'key': '314', 'label': 'Kenya'}, - {'key': '320', 'label': 'Lebanon'}, - {'key': '321', 'label': 'Lesotho'}, - {'key': '322', 'label': 'Liberia'}, - {'key': '323', 'label': 'Libya'}, - {'key': '327', 'label': 'Madagascar'}, - {'key': '329', 'label': 'Malawi'}, - {'key': '332', 'label': 'Mali'}, - {'key': '336', 'label': 'Mauritania'}, - {'key': '339', 'label': 'Mexico'}, - {'key': '343', 'label': 'Mongolia'}, - {'key': '346', 'label': 'Morocco'}, - {'key': '347', 'label': 'Mozambique'}, - {'key': '348', 'label': 'Myanmar'}, - {'key': '349', 'label': 'Namibia'}, - {'key': '351', 'label': 'Nepal'}, - {'key': '356', 'label': 'Nicaragua'}, - {'key': '357', 'label': 'Niger'}, - {'key': '358', 'label': 'Nigeria'}, - {'key': '365', 'label': 'Pakistan'}, - {'key': '363', 'label': 'Palestine'}, - {'key': '368', 'label': 'Papua New Guinea'}, - {'key': '370', 'label': 'Peru'}, - {'key': '371', 'label': 'Philippines'}, - {'key': '381', 'label': 'Rwanda'}, - {'key': '393', 'label': 'Senegal'}, - {'key': '394', 'label': 'Serbia'}, - {'key': '396', 'label': 'Sierra Leone'}, - {'key': '400', 'label': 'Slovenia'}, - {'key': '402', 'label': 'Somalia'}, - {'key': '404', 'label': 'South Sudan'}, - {'key': '406', 'label': 'Sri Lanka'}, - {'key': '407', 'label': 'Sudan'}, - {'key': '410', 'label': 'Swaziland'}, - {'key': '194', 'label': 'Syria'}, - {'key': '413', 'label': 'Tajikistan'}, - {'key': '415', 'label': 'the former Yugoslav Republic of Macedonia'}, - {'key': '282', 'label': 'The Gambia'}, - {'key': '416', 'label': 'Timor-Leste'}, - {'key': '419', 'label': 'Tonga'}, - {'key': '422', 'label': 'Turkey'}, - {'key': '426', 'label': 'Uganda'}, - {'key': '427', 'label': 'Ukraine'}, - {'key': '435', 'label': 'Vanuatu'}, - {'key': '436', 'label': 'Venezuela'}, - {'key': '437', 'label': 'Vietnam'}, - {'key': '457', 'label': 'Yemen'}, - {'key': '442', 'label': 'Zambia'}, - {'key': '443', 'label': 'Zimbabwe'} + {"key": "All", "label": "Any"}, + {"key": "196", "label": "Afghanistan"}, + {"key": "202", "label": "Angola"}, + {"key": "214", "label": "Bangladesh"}, + {"key": "219", "label": "Benin"}, + {"key": "222", "label": "Bolivia"}, + {"key": "224", "label": "Bosnia and Herzegovina"}, + {"key": "226", "label": "Brazil"}, + {"key": "230", "label": "Burkina Faso"}, + {"key": "231", "label": "Burundi"}, + {"key": "233", "label": "Cambodia"}, + {"key": "234", "label": "Cameroon"}, + {"key": "237", "label": "CAR"}, + {"key": "239", "label": "Chad"}, + {"key": "242", "label": "China"}, + {"key": "248", "label": "Colombia"}, + {"key": "250", "label": "Congo"}, + {"key": "253", "label": "Côte d'Ivoire"}, + {"key": "254", "label": "Croatia"}, + {"key": "262", "label": "Djibouti"}, + {"key": "263", "label": "Dominica"}, + {"key": "264", "label": "Dominican Republic"}, + {"key": "259", "label": "DPRK"}, + {"key": "260", "label": "DRC"}, + {"key": "266", "label": "Ecuador"}, + {"key": "267", "label": "Egypt"}, + {"key": "268", "label": "El Salvador"}, + {"key": "270", "label": "Eritrea"}, + {"key": "272", "label": "Ethiopia"}, + {"key": "275", "label": "Fiji"}, + {"key": "277", "label": "France"}, + {"key": "287", "label": "Greece"}, + {"key": "292", "label": "Guatemala"}, + {"key": "293", "label": "Guinea"}, + {"key": "642", "label": "Haiti"}, + {"key": "302", "label": "India"}, + {"key": "303", "label": "Indonesia"}, + {"key": "304", "label": "Iran"}, + {"key": "305", "label": "Iraq"}, + {"key": "312", "label": "Jordan"}, + {"key": "314", "label": "Kenya"}, + {"key": "320", "label": "Lebanon"}, + {"key": "321", "label": "Lesotho"}, + {"key": "322", "label": "Liberia"}, + {"key": "323", "label": "Libya"}, + {"key": "327", "label": "Madagascar"}, + {"key": "329", "label": "Malawi"}, + {"key": "332", "label": "Mali"}, + {"key": "336", "label": "Mauritania"}, + {"key": "339", "label": "Mexico"}, + {"key": "343", "label": "Mongolia"}, + {"key": "346", "label": "Morocco"}, + {"key": "347", "label": "Mozambique"}, + {"key": "348", "label": "Myanmar"}, + {"key": "349", "label": "Namibia"}, + {"key": "351", "label": "Nepal"}, + {"key": "356", "label": "Nicaragua"}, + {"key": "357", "label": "Niger"}, + {"key": "358", "label": "Nigeria"}, + {"key": "365", "label": "Pakistan"}, + {"key": "363", "label": "Palestine"}, + {"key": "368", "label": "Papua New Guinea"}, + {"key": "370", "label": "Peru"}, + {"key": "371", "label": "Philippines"}, + {"key": "381", "label": "Rwanda"}, + {"key": "393", "label": "Senegal"}, + {"key": "394", "label": "Serbia"}, + {"key": "396", "label": "Sierra Leone"}, + {"key": "400", "label": "Slovenia"}, + {"key": "402", "label": "Somalia"}, + {"key": "404", "label": "South Sudan"}, + {"key": "406", "label": "Sri Lanka"}, + {"key": "407", "label": "Sudan"}, + {"key": "410", "label": "Swaziland"}, + {"key": "194", "label": "Syria"}, + {"key": "413", "label": "Tajikistan"}, + {"key": "415", "label": "the former Yugoslav Republic of Macedonia"}, + {"key": "282", "label": "The Gambia"}, + {"key": "416", "label": "Timor-Leste"}, + {"key": "419", "label": "Tonga"}, + {"key": "422", "label": "Turkey"}, + {"key": "426", "label": "Uganda"}, + {"key": "427", "label": "Ukraine"}, + {"key": "435", "label": "Vanuatu"}, + {"key": "436", "label": "Venezuela"}, + {"key": "437", "label": "Vietnam"}, + {"key": "457", "label": "Yemen"}, + {"key": "442", "label": "Zambia"}, + {"key": "443", "label": "Zimbabwe"}, ] @ConnectorWrapper class AcapsBriefingNotes(Source): - URL = 'https://www.acaps.org/special-reports' - title = 'ACAPS Briefing Notes' - key = 'acaps-briefing-notes' + URL = "https://www.acaps.org/special-reports" + title = "ACAPS Briefing Notes" + key = "acaps-briefing-notes" options = [ { - 'key': 'field_product_status_value', - 'field_type': 'select', - 'title': 'Published date', - 'options': [ - {'key': 'All', 'label': 'Any'}, - {'key': 'upcoming', 'label': 'Upcoming'}, - {'key': 'published', 'label': 'Published'}, + "key": "field_product_status_value", + "field_type": "select", + "title": "Published date", + "options": [ + {"key": "All", "label": "Any"}, + {"key": "upcoming", "label": "Upcoming"}, + {"key": "published", "label": "Published"}, ], }, + {"key": "field_countries_target_id", "field_type": "select", "title": "Country", "options": COUNTRIES_OPTIONS}, { - 'key': 'field_countries_target_id', - 'field_type': 'select', - 'title': 'Country', - 'options': COUNTRIES_OPTIONS + "key": "field_product_category_target_id", + "field_type": "select", + "title": "Type of Report", + "options": [ + {"key": "All", "label": "Any"}, + {"key": "281", "label": "Short notes"}, + {"key": "279", "label": "Briefing notes"}, + {"key": "280", "label": "Crisis profiles"}, + {"key": "282", "label": "Thematic reports"}, + ], }, - { - 'key': 'field_product_category_target_id', - 'field_type': 'select', - 'title': 'Type of Report', - 'options': [ - {'key': 'All', 'label': 'Any'}, - {'key': '281', 'label': 'Short notes'}, - {'key': '279', 'label': 'Briefing notes'}, - {'key': '280', 'label': 'Crisis profiles'}, - {'key': '282', 'label': 'Thematic reports'}, - ] - } ] def get_content(self, url, params): @@ -145,34 +141,29 @@ def get_content(self, url, params): def fetch(self, params): results = [] content = self.get_content(self.URL, params) - soup = Soup(content, 'html.parser') - contents = soup.findAll('div', {'class': 'wrapper-type'}) + soup = Soup(content, "html.parser") + contents = soup.findAll("div", {"class": "wrapper-type"}) if not contents: return results, 0 content = contents[0] - for item in content.findAll('div', {'class': 'views-row'}): + for item in content.findAll("div", {"class": "views-row"}): try: - bottomcontent = item.find('div', {'class': 'content-bottom'}) - topcontent = item.find('div', {'class': 'content-top'}) - date = topcontent.find('span', {'class': 'updated-date'}).text - date = datetime.datetime.strptime(date, '%d/%m/%Y') - title = topcontent.find('div', {'class': 'field-item'}).text - link_elem = bottomcontent.find( - 'div', {'class': 'field-item'} - ) - link = link_elem.find('a') + bottomcontent = item.find("div", {"class": "content-bottom"}) + topcontent = item.find("div", {"class": "content-top"}) + date = topcontent.find("span", {"class": "updated-date"}).text + date = datetime.datetime.strptime(date, "%d/%m/%Y") + title = topcontent.find("div", {"class": "field-item"}).text + link_elem = bottomcontent.find("div", {"class": "field-item"}) + link = link_elem.find("a") data = { - 'title': title.strip(), - 'published_on': date.date(), - 'url': link['href'], - 'source': 'Briefing Notes', - 'source_type': Lead.SourceType.WEBSITE, + "title": title.strip(), + "published_on": date.date(), + "url": link["href"], + "source": "Briefing Notes", + "source_type": Lead.SourceType.WEBSITE, } results.append(data) except Exception as e: - logger.warning( - "Exception parsing {} with params {}: {}".format( - self.URL, params, e.args) - ) + logger.warning("Exception parsing {} with params {}: {}".format(self.URL, params, e.args)) return results, len(results) diff --git a/apps/unified_connector/sources/atom_feed.py b/apps/unified_connector/sources/atom_feed.py index cc96302b6e..8e0e258ca0 100644 --- a/apps/unified_connector/sources/atom_feed.py +++ b/apps/unified_connector/sources/atom_feed.py @@ -1,35 +1,33 @@ import time + import feedparser import requests -from rest_framework import serializers - -from lead.models import Lead from connector.utils import ConnectorWrapper +from lead.models import Lead +from rest_framework import serializers from .rss_feed import RssFeed @ConnectorWrapper class AtomFeed(RssFeed): - title = 'Atom Feed' - key = 'atom-feed' + title = "Atom Feed" + key = "atom-feed" def get_content(self, url, params): resp = requests.get(url) return resp.content def query_fields(self, params): - if not params or not params.get('feed-url'): + if not params or not params.get("feed-url"): return [] - feed_url = params['feed-url'] + feed_url = params["feed-url"] feed = feedparser.parse(feed_url) items = feed.entries - if feed.get('bozo_exception'): - raise serializers.ValidationError({ - 'feed-url': 'Could not fetch/parse atom feed' - }) + if feed.get("bozo_exception"): + raise serializers.ValidationError({"feed-url": "Could not fetch/parse atom feed"}) if not items: return [] @@ -44,18 +42,18 @@ def query_fields(self, params): fields[key] = {} # Ignore this fields else: fields[key] = { - 'key': key, - 'label': key.replace('_', ' ').title(), + "key": key, + "label": key.replace("_", " ").title(), } return [option for option in fields.values() if option] def fetch(self, params): results = [] - if not params or not params.get('feed-url'): + if not params or not params.get("feed-url"): return results, 0 - feed_url = params['feed-url'] + feed_url = params["feed-url"] content = self.get_content(feed_url, {}) feed = feedparser.parse(content) @@ -66,7 +64,7 @@ def fetch(self, params): for item in limited_items: data = { - 'source_type': Lead.SourceType.RSS, + "source_type": Lead.SourceType.RSS, **{ lead_field: (item or {}).get(params.get(param_key)) for lead_field, param_key in self._option_lead_field_map.items() diff --git a/apps/unified_connector/sources/base.py b/apps/unified_connector/sources/base.py index 9bf5728f7b..823f96ffc8 100644 --- a/apps/unified_connector/sources/base.py +++ b/apps/unified_connector/sources/base.py @@ -1,19 +1,18 @@ import copy import datetime - -from typing import List, Tuple, Union -from functools import reduce from abc import ABC, abstractmethod -from django.db.models import Q +from functools import reduce +from typing import List, Tuple, Union +from django.db.models import Q +from lead.models import Lead from organization.models import Organization + from utils.common import random_key from utils.date_extractor import str_to_date -from lead.models import Lead - -class OrganizationSearch(): +class OrganizationSearch: def __init__(self, texts, source_type, creator): self.source_type = source_type self.creator = creator @@ -29,10 +28,7 @@ def create_organization(self, text): ) def fetch(self, texts): - text_queries = [ - (text, text.lower()) - for text in set(texts) if text - ] + text_queries = [(text, text.lower()) for text in set(texts) if text] if len(text_queries) == 0: # Nothing to do here @@ -41,14 +37,9 @@ def fetch(self, texts): exact_query = reduce( lambda acc, item: acc | item, - [ - Q(title__iexact=d) | - Q(short_name__iexact=d) | - Q(long_name__iexact=d) - for _, d in text_queries - ], + [Q(title__iexact=d) | Q(short_name__iexact=d) | Q(long_name__iexact=d) for _, d in text_queries], ) - exact_organizations = Organization.objects.filter(exact_query).select_related('parent').all() + exact_organizations = Organization.objects.filter(exact_query).select_related("parent").all() organization_map = { # NOTE: organization.data will return itself or it's parent organization (handling merged organizations) key.lower(): organization.data @@ -75,12 +66,8 @@ class Source(ABC): UNIFIED_CONNECTOR_SOURCE_MAX_PAGE_NUMBER = 100 def __init__(self): - if ( - not hasattr(self, 'title') or - not hasattr(self, 'key') or - not hasattr(self, 'options') - ): - raise Exception('Source not defined properly') + if not hasattr(self, "title") or not hasattr(self, "key") or not hasattr(self, "options"): + raise Exception("Source not defined properly") @abstractmethod def fetch(self, params): @@ -102,40 +89,34 @@ def _parse_date(date_raw) -> Union[None, datetime.date]: return [], total_count organization_search = OrganizationSearch( - [ - label - for d in leads_data - for label in [d['source'], d['author']] - ], + [label for d in leads_data for label in [d["source"], d["author"]]], Organization.SourceType.CONNECTOR, request_user, ) leads = [] for ldata in leads_data: - published_on = _parse_date(ldata['published_on']) + published_on = _parse_date(ldata["published_on"]) lead = Lead( - id=ldata.get('id', random_key()), - title=ldata['title'], + id=ldata.get("id", random_key()), + title=ldata["title"], published_on=published_on, - url=ldata['url'], - source_raw=ldata['source'], - author_raw=ldata['author'], - source=organization_search.get(ldata['source']), - author=organization_search.get(ldata['author']), - source_type=ldata['source_type'], + url=ldata["url"], + source_raw=ldata["source"], + author_raw=ldata["author"], + source=organization_search.get(ldata["source"]), + author=organization_search.get(ldata["author"]), + source_type=ldata["source_type"], ) - if ldata.get('author') is not None: - lead._authors = list( - filter(None, [organization_search.get(ldata['author'])]) - ) + if ldata.get("author") is not None: + lead._authors = list(filter(None, [organization_search.get(ldata["author"])])) # Add emm info - if ldata.get('emm_triggers') is not None: - lead._emm_triggers = ldata['emm_triggers'] - if ldata.get('emm_entities') is not None: - lead._emm_entities = ldata['emm_entities'] + if ldata.get("emm_triggers") is not None: + lead._emm_triggers = ldata["emm_triggers"] + if ldata.get("emm_entities") is not None: + lead._emm_entities = ldata["emm_entities"] leads.append(lead) return leads, total_count diff --git a/apps/unified_connector/sources/emm.py b/apps/unified_connector/sources/emm.py index 7c3ba952ed..664e9d2d60 100644 --- a/apps/unified_connector/sources/emm.py +++ b/apps/unified_connector/sources/emm.py @@ -1,24 +1,24 @@ +import logging import re -import requests -from lxml import etree +import requests +from connector.utils import ConnectorWrapper, get_rss_fields from django.db import transaction - -from utils.common import random_key, get_ns_tag +from lead.models import EMMEntity, Lead, LeadEMMTrigger +from lxml import etree from rest_framework import serializers -from lead.models import Lead, LeadEMMTrigger, EMMEntity +from utils.common import get_ns_tag, random_key + from .rss_feed import RssFeed -from connector.utils import get_rss_fields, ConnectorWrapper -import logging logger = logging.getLogger(__name__) @ConnectorWrapper class EMM(RssFeed): - title = 'European Media Monitor' - key = 'emm' + title = "European Media Monitor" + key = "emm" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -29,16 +29,17 @@ def __init__(self, *args, **kwargs): def initialize(self): from connector.models import EMMConfig + conf = EMMConfig.objects.all().first() if not conf: - msg = 'There is no configuration for emm connector' + msg = "There is no configuration for emm connector" logger.error(msg) raise Exception(msg) self.conf = conf self.conf.trigger_regex = re.compile(self.conf.trigger_regex) def update_emm_items_present(self, xml): - trigger_query = './/category[@emm:trigger]' + trigger_query = ".//category[@emm:trigger]" try: items = xml.xpath(trigger_query, namespaces=self.nsmap) if len(items) > 0: @@ -47,7 +48,7 @@ def update_emm_items_present(self, xml): # Means no such tag pass - entity_query = './/emm:entity' + entity_query = ".//emm:entity" try: items = xml.xpath(entity_query, namespaces=self.nsmap) if len(items) > 0: @@ -57,22 +58,18 @@ def update_emm_items_present(self, xml): pass def query_fields(self, params): - if not params or not params.get('feed-url'): + if not params or not params.get("feed-url"): return [] try: - r = requests.get(params['feed-url']) + r = requests.get(params["feed-url"]) xml = etree.fromstring(r.content) except requests.exceptions.RequestException: - raise serializers.ValidationError({ - 'feed-url': 'Could not fetch rss feed' - }) + raise serializers.ValidationError({"feed-url": "Could not fetch rss feed"}) except etree.XMLSyntaxError: - raise serializers.ValidationError({ - 'feed-url': 'Invalid XML' - }) + raise serializers.ValidationError({"feed-url": "Invalid XML"}) - items = xml.find('channel/item') + items = xml.find("channel/item") self.nsmap = xml.nsmap @@ -80,7 +77,7 @@ def query_fields(self, params): self.update_emm_items_present(xml) fields = [] - for field in items.findall('./'): + for field in items.findall("./"): fields.extend(get_rss_fields(field, self.nsmap)) # Remove fields that are present more than once, @@ -96,12 +93,12 @@ def get_content(self, url, params): return resp.content def fetch(self, params): - if not params or not params.get('feed-url'): + if not params or not params.get("feed-url"): return [], 0 self.params = params - content = self.get_content(self.params['feed-url'], {}) + content = self.get_content(self.params["feed-url"], {}) return self.parse(content) @@ -118,7 +115,7 @@ def parse(self, content): # Check if trigger and entities exist self.update_emm_items_present(xml) - items = xml.findall('channel/item') + items = xml.findall("channel/item") total_count = len(items) @@ -128,18 +125,20 @@ def parse(self, content): # Extract info from item lead_info = self.parse_emm_item(item) - item_entities = lead_info.pop('entities', {}) - item_triggers = lead_info.pop('triggers', []) + item_entities = lead_info.pop("entities", {}) + item_triggers = lead_info.pop("triggers", []) entities.update(item_entities) - leads_infos.append({ - 'id': random_key(), - 'source_type': Lead.SourceType.EMM, - 'emm_triggers': [LeadEMMTrigger(**x) for x in item_triggers], - 'emm_entities': item_entities, - **lead_info, - }) + leads_infos.append( + { + "id": random_key(), + "source_type": Lead.SourceType.EMM, + "emm_triggers": [LeadEMMTrigger(**x) for x in item_triggers], + "emm_entities": item_entities, + **lead_info, + } + ) # Get or create EMM entities with transaction.atomic(): @@ -148,7 +147,7 @@ def parse(self, content): entities[eid] = obj for leadinfo in leads_infos: - leadinfo['emm_entities'] = [entities[eid] for eid, _ in leadinfo['emm_entities'].items()] + leadinfo["emm_entities"] = [entities[eid] for eid, _ in leadinfo["emm_entities"].items()] return leads_infos, total_count @@ -158,28 +157,25 @@ def parse_emm_item(self, item): if not self.params.get(field): field_value = None else: - element = item.find(self.params.get(field, '')) - field_value = (element.text or element.get('href')) if element is not None else None + element = item.find(self.params.get(field, "")) + field_value = (element.text or element.get("href")) if element is not None else None info[lead_field] = field_value # Parse entities - info['entities'] = self.get_entities(item) + info["entities"] = self.get_entities(item) # Parse Triggers - info['triggers'] = self.get_triggers(item) - if info['entities']: + info["triggers"] = self.get_triggers(item) + if info["entities"]: self.has_emm_entities = True - if info['triggers']: + if info["triggers"]: self.has_emm_triggers = True return info def get_entities(self, item): entities = item.findall(self.entity_tag_ns) or [] - return { - x.get('id'): x.get('name') - for x in entities - } + return {x.get("id"): x.get("name") for x in entities} def get_triggers(self, item): trigger_elems = item.findall(self.trigger_tag_ns) or [] @@ -199,22 +195,22 @@ def parse_trigger(self, raw): match = self.conf.trigger_regex.match(raw) if match: return { - 'emm_risk_factor': match['risk_factor'], - 'emm_keyword': match['keyword'], - 'count': match['count'], + "emm_risk_factor": match["risk_factor"], + "emm_keyword": match["keyword"], + "count": match["count"], } return None def test_emm(): - with open('/tmp/rss.xml') as f: + with open("/tmp/rss.xml") as f: e = EMM() params = { - 'url-field': 'link', - 'date-field': 'pubDate', - 'source-field': 'source', - 'author-field': 'source', - 'title-field': 'title', + "url-field": "link", + "date-field": "pubDate", + "source-field": "source", + "author-field": "source", + "title-field": "title", } - data = e.parse_xml(bytes(f.read(), 'utf-8'), params, 0, 10) + data = e.parse_xml(bytes(f.read(), "utf-8"), params, 0, 10) return data diff --git a/apps/unified_connector/sources/humanitarian_response.py b/apps/unified_connector/sources/humanitarian_response.py index 7af9419175..43c8bad144 100644 --- a/apps/unified_connector/sources/humanitarian_response.py +++ b/apps/unified_connector/sources/humanitarian_response.py @@ -1,12 +1,13 @@ import logging -import requests -from bs4 import BeautifulSoup as Soup from datetime import datetime -from .base import Source +import requests +from bs4 import BeautifulSoup as Soup from connector.utils import ConnectorWrapper from lead.models import Lead +from .base import Source + logger = logging.getLogger(__name__) @@ -157,24 +158,17 @@ {"key": "world", "label": "World"}, {"key": "yemen", "label": "Yemen"}, {"key": "zambia", "label": "Zambia"}, - {"key": "zimbabwe", "label": "Zimbabwe"} + {"key": "zimbabwe", "label": "Zimbabwe"}, ] @ConnectorWrapper class HumanitarianResponse(Source): - URL = 'https://www.humanitarianresponse.info/en/documents/table' - title = 'Humanitarian Response' - key = 'humanitarian-response' + URL = "https://www.humanitarianresponse.info/en/documents/table" + title = "Humanitarian Response" + key = "humanitarian-response" - options = [ - { - 'key': 'country', # key is not used - 'field_type': 'select', - 'title': 'Country', - 'options': COUNTRIES_OPTIONS - } - ] + options = [{"key": "country", "field_type": "select", "title": "Country", "options": COUNTRIES_OPTIONS}] # key is not used def get_content(self, url, params): resp = requests.get(url, params={}) @@ -183,31 +177,28 @@ def get_content(self, url, params): def fetch(self, params): results = [] url = self.URL - if params.get('country'): - url = self.URL + '/locations/' + params['country'] + if params.get("country"): + url = self.URL + "/locations/" + params["country"] content = self.get_content(url, {}) - soup = Soup(content, 'html.parser') - contents = soup.find('div', {'id': 'content'}).find('tbody') - for row in contents.findAll('tr'): + soup = Soup(content, "html.parser") + contents = soup.find("div", {"id": "content"}).find("tbody") + for row in contents.findAll("tr"): try: - tds = row.findAll('td') - title = tds[0].find('a').get_text().strip() + tds = row.findAll("td") + title = tds[0].find("a").get_text().strip() datestr = tds[3].get_text().strip() - date = datetime.strptime(datestr, '%m/%d/%Y') - url = tds[4].find('a')['href'] + date = datetime.strptime(datestr, "%m/%d/%Y") + url = tds[4].find("a")["href"] data = { - 'id': url, - 'title': title.replace('\u200b', ''), - 'published_on': date.date(), - 'url': url, - 'source': 'Humanitarian Response', - 'author': 'Humanitarian Response', - 'source_type': Lead.SourceType.WEBSITE + "id": url, + "title": title.replace("\u200b", ""), + "published_on": date.date(), + "url": url, + "source": "Humanitarian Response", + "author": "Humanitarian Response", + "source_type": Lead.SourceType.WEBSITE, } results.append(data) except Exception as e: - logger.warning( - "Exception parsing humanitarian response connector: " + - str(e.args) - ) + logger.warning("Exception parsing humanitarian response connector: " + str(e.args)) return results, len(results) diff --git a/apps/unified_connector/sources/pdna.py b/apps/unified_connector/sources/pdna.py index 8a7a7255ee..e11ad182cd 100644 --- a/apps/unified_connector/sources/pdna.py +++ b/apps/unified_connector/sources/pdna.py @@ -1,127 +1,118 @@ import logging -from bs4 import BeautifulSoup as Soup -import requests -from .base import Source +import requests +from bs4 import BeautifulSoup as Soup from connector.utils import ConnectorWrapper from lead.models import Lead +from .base import Source + logger = logging.getLogger(__name__) COUNTRIES_OPTIONS = [ - {'key': 'Somalia', 'label': 'Somalia'}, - {'key': 'Dominica', 'label': 'Dominica'}, - {'key': 'Sri Lanka', 'label': 'Sri Lanka'}, - {'key': 'Sierra Leone', 'label': 'Sierra Leone'}, - {'key': 'Saint Vincent and the Grenadines', 'label': 'Saint Vincent and the Grenadines'}, # noqa - {'key': 'Vietnam', 'label': 'Vietnam'}, - {'key': 'Seychelles', 'label': 'Seychelles'}, - {'key': 'Fiji', 'label': 'Fiji'}, - {'key': 'Myanmar', 'label': 'Myanmar'}, - {'key': 'Georgia', 'label': 'Georgia'}, - {'key': 'Nepal', 'label': 'Nepal'}, - {'key': 'Vanuatu', 'label': 'Vanuatu'}, - {'key': 'Malawi', 'label': 'Malawi'}, - {'key': 'Cabo Verde', 'label': 'Cabo Verde'}, - {'key': 'St. Vincent and the Grenadines', 'label': 'St. Vincent and the Grenadines'}, # noqa - {'key': 'Bosnia and Herzegovena', 'label': 'Bosnia and Herzegovena'}, - {'key': 'Burundi ', 'label': 'Burundi '}, - {'key': 'Solomon Islands', 'label': 'Solomon Islands'}, - {'key': 'Burundi ', 'label': 'Burundi '}, - {'key': 'Seychelles', 'label': 'Seychelles'}, - {'key': 'Nigeria', 'label': 'Nigeria'}, - {'key': 'Fiji', 'label': 'Fiji'}, - {'key': 'Samoa', 'label': 'Samoa'}, - {'key': 'Malawi', 'label': 'Malawi'}, - {'key': 'Bhutan', 'label': 'Bhutan'}, - {'key': 'Pakistan', 'label': 'Pakistan'}, - {'key': 'Thailand', 'label': 'Thailand'}, - {'key': 'Djibouti', 'label': 'Djibouti'}, - {'key': 'Kenya', 'label': 'Kenya'}, - {'key': 'Lao PDR', 'label': 'Lao PDR'}, - {'key': 'Lesotho', 'label': 'Lesotho'}, - {'key': 'Uganda', 'label': 'Uganda'}, - {'key': 'Benin', 'label': 'Benin'}, - {'key': 'Guatemala', 'label': 'Guatemala'}, - {'key': 'Togo', 'label': 'Togo'}, - {'key': 'Pakistan', 'label': 'Pakistan'}, - {'key': 'Moldova', 'label': 'Moldova'}, - {'key': 'Haiti', 'label': 'Haiti'}, - {'key': 'El Salvador', 'label': 'El Salvador'}, - {'key': 'Cambodia', 'label': 'Cambodia'}, - {'key': 'Lao PDR', 'label': 'Lao PDR'}, - {'key': 'Indonesia', 'label': 'Indonesia'}, - {'key': 'Samoa', 'label': 'Samoa'}, - {'key': 'Philippines', 'label': 'Philippines'}, - {'key': 'Bhutan', 'label': 'Bhutan'}, - {'key': 'Burkina Faso ', 'label': 'Burkina Faso '}, - {'key': 'Senegal', 'label': 'Senegal'}, - {'key': 'Central African Republic', 'label': 'Central African Republic'}, - {'key': 'Namibia', 'label': 'Namibia'}, - {'key': 'Yemen', 'label': 'Yemen'}, - {'key': 'Haiti', 'label': 'Haiti'}, - {'key': 'India', 'label': 'India'}, - {'key': 'Myanmar', 'label': 'Myanmar'}, - {'key': 'Bolivia', 'label': 'Bolivia'}, - {'key': 'Madagascar', 'label': 'Madagascar'}, - {'key': 'Bangladesh', 'label': 'Bangladesh'} + {"key": "Somalia", "label": "Somalia"}, + {"key": "Dominica", "label": "Dominica"}, + {"key": "Sri Lanka", "label": "Sri Lanka"}, + {"key": "Sierra Leone", "label": "Sierra Leone"}, + {"key": "Saint Vincent and the Grenadines", "label": "Saint Vincent and the Grenadines"}, # noqa + {"key": "Vietnam", "label": "Vietnam"}, + {"key": "Seychelles", "label": "Seychelles"}, + {"key": "Fiji", "label": "Fiji"}, + {"key": "Myanmar", "label": "Myanmar"}, + {"key": "Georgia", "label": "Georgia"}, + {"key": "Nepal", "label": "Nepal"}, + {"key": "Vanuatu", "label": "Vanuatu"}, + {"key": "Malawi", "label": "Malawi"}, + {"key": "Cabo Verde", "label": "Cabo Verde"}, + {"key": "St. Vincent and the Grenadines", "label": "St. Vincent and the Grenadines"}, # noqa + {"key": "Bosnia and Herzegovena", "label": "Bosnia and Herzegovena"}, + {"key": "Burundi ", "label": "Burundi "}, + {"key": "Solomon Islands", "label": "Solomon Islands"}, + {"key": "Burundi ", "label": "Burundi "}, + {"key": "Seychelles", "label": "Seychelles"}, + {"key": "Nigeria", "label": "Nigeria"}, + {"key": "Fiji", "label": "Fiji"}, + {"key": "Samoa", "label": "Samoa"}, + {"key": "Malawi", "label": "Malawi"}, + {"key": "Bhutan", "label": "Bhutan"}, + {"key": "Pakistan", "label": "Pakistan"}, + {"key": "Thailand", "label": "Thailand"}, + {"key": "Djibouti", "label": "Djibouti"}, + {"key": "Kenya", "label": "Kenya"}, + {"key": "Lao PDR", "label": "Lao PDR"}, + {"key": "Lesotho", "label": "Lesotho"}, + {"key": "Uganda", "label": "Uganda"}, + {"key": "Benin", "label": "Benin"}, + {"key": "Guatemala", "label": "Guatemala"}, + {"key": "Togo", "label": "Togo"}, + {"key": "Pakistan", "label": "Pakistan"}, + {"key": "Moldova", "label": "Moldova"}, + {"key": "Haiti", "label": "Haiti"}, + {"key": "El Salvador", "label": "El Salvador"}, + {"key": "Cambodia", "label": "Cambodia"}, + {"key": "Lao PDR", "label": "Lao PDR"}, + {"key": "Indonesia", "label": "Indonesia"}, + {"key": "Samoa", "label": "Samoa"}, + {"key": "Philippines", "label": "Philippines"}, + {"key": "Bhutan", "label": "Bhutan"}, + {"key": "Burkina Faso ", "label": "Burkina Faso "}, + {"key": "Senegal", "label": "Senegal"}, + {"key": "Central African Republic", "label": "Central African Republic"}, + {"key": "Namibia", "label": "Namibia"}, + {"key": "Yemen", "label": "Yemen"}, + {"key": "Haiti", "label": "Haiti"}, + {"key": "India", "label": "India"}, + {"key": "Myanmar", "label": "Myanmar"}, + {"key": "Bolivia", "label": "Bolivia"}, + {"key": "Madagascar", "label": "Madagascar"}, + {"key": "Bangladesh", "label": "Bangladesh"}, ] @ConnectorWrapper class PDNA(Source): - URL = 'https://www.gfdrr.org/post-disaster-needs-assessments' - title = 'Post Disaster Needs Assessment' - key = 'post-disaster-needs-assessment' - website = 'http://www.gfdrr.org' + URL = "https://www.gfdrr.org/post-disaster-needs-assessments" + title = "Post Disaster Needs Assessment" + key = "post-disaster-needs-assessment" + website = "http://www.gfdrr.org" - options = [ - { - 'key': 'country', - 'field_type': 'select', - 'title': 'Country', - 'options': COUNTRIES_OPTIONS - } - ] + options = [{"key": "country", "field_type": "select", "title": "Country", "options": COUNTRIES_OPTIONS}] def get_content(self, url, params): resp = requests.get(url) return resp.text def fetch(self, params): - country = params.get('country') + country = params.get("country") if not country: return [], 0 results = [] content = self.get_content(self.URL, {}) - soup = Soup(content, 'html.parser') + soup = Soup(content, "html.parser") - contents = soup.findAll('tbody') + contents = soup.findAll("tbody") for content in contents: - for row in content.findAll('tr'): + for row in content.findAll("tr"): try: - elem = row.find('a') + elem = row.find("a") name = elem.get_text() - title = row.findAll('td')[-1].get_text() - published_on = row.findAll('td')[1].get_text() + title = row.findAll("td")[-1].get_text() + published_on = row.findAll("td")[1].get_text() if name.strip() == country.strip(): # add as lead - url = elem['href'] - if url[0] == '/': # means relative path + url = elem["href"] + if url[0] == "/": # means relative path url = self.website + url data = { - 'title': title.strip(), - 'url': url, - 'source': 'PDNA portal', - 'author': 'PDNA portal', - 'published_on': published_on, - 'source_type': Lead.SourceType.WEBSITE, + "title": title.strip(), + "url": url, + "source": "PDNA portal", + "author": "PDNA portal", + "published_on": published_on, + "source_type": Lead.SourceType.WEBSITE, } results.append(data) except Exception as e: - logger.warning( - "Exception parsing {} with params {}: {}".format( - self.URL, params, e.args) - ) + logger.warning("Exception parsing {} with params {}: {}".format(self.URL, params, e.args)) return results, len(results) diff --git a/apps/unified_connector/sources/relief_web.py b/apps/unified_connector/sources/relief_web.py index 83a0292c14..4ab296313e 100644 --- a/apps/unified_connector/sources/relief_web.py +++ b/apps/unified_connector/sources/relief_web.py @@ -1,11 +1,11 @@ -import requests import json -from lead.models import Lead -from .base import Source +import requests from connector.utils import ConnectorWrapper - from django.conf import settings +from lead.models import Lead + +from .base import Source # NOTE: Generated using scripts/list_relief_web_countries.sh COUNTRIES_LIST = [ @@ -271,43 +271,35 @@ def _format_date(datestr): - return datestr + 'T00:00:00+00:00' + return datestr + "T00:00:00+00:00" @ConnectorWrapper class ReliefWeb(Source): - URL = f'https://api.reliefweb.int/v1/reports?appname={settings.RELIEFWEB_APPNAME}' - title = 'ReliefWeb Reports' - key = 'relief-web' + URL = f"https://api.reliefweb.int/v1/reports?appname={settings.RELIEFWEB_APPNAME}" + title = "ReliefWeb Reports" + key = "relief-web" options = [ { - 'key': 'primary-country', - 'field_type': 'select', - 'title': 'Primary Country', - 'options': COUNTRIES, + "key": "primary-country", + "field_type": "select", + "title": "Primary Country", + "options": COUNTRIES, }, { - 'key': 'country', - 'field_type': 'select', - 'title': 'Country', - 'options': COUNTRIES, + "key": "country", + "field_type": "select", + "title": "Country", + "options": COUNTRIES, }, - { - 'key': 'from', - 'field_type': 'date', - 'title': 'Reports since' - }, - { - 'key': 'to', - 'field_type': 'date', - 'title': 'Reports until' - } + {"key": "from", "field_type": "date", "title": "Reports since"}, + {"key": "to", "field_type": "date", "title": "Reports until"}, ] filters = [ { - 'key': 'search', - 'field_type': 'string', - 'title': 'Search', + "key": "search", + "field_type": "string", + "title": "Search", }, ] @@ -318,46 +310,42 @@ def get_content(self, url, params): def parse_filter_params(self, params): filters = [] - if params.get('country'): - filters.append({'field': 'country.iso3', 'value': params['country']}) - if params.get('primary-country'): - filters.append({'field': 'primary_country.iso3', 'value': params['primary-country']}) + if params.get("country"): + filters.append({"field": "country.iso3", "value": params["country"]}) + if params.get("primary-country"): + filters.append({"field": "primary_country.iso3", "value": params["primary-country"]}) date_filter = {} # If date is obtained, it must be formatted to the ISO string with timezone info # the _format_date just appends 00:00:00 Time and +00:00 tz info - if params.get('from'): - date_filter['from'] = _format_date(params['from']) - if params.get('to'): - date_filter["to"] = _format_date(params['to']) + if params.get("from"): + date_filter["from"] = _format_date(params["from"]) + if params.get("to"): + date_filter["to"] = _format_date(params["to"]) if date_filter: - filters.append({'field': 'date.original', 'value': date_filter}) + filters.append({"field": "date.original", "value": date_filter}) if filters: - return {'operator': 'AND', 'conditions': filters} + return {"operator": "AND", "conditions": filters} return {} def fetch(self, params): results = [] post_params = {} - post_params['fields'] = { - 'include': [ - 'url_alias', 'title', 'date.original', 'file', 'source', 'source.homepage' - ] - } + post_params["fields"] = {"include": ["url_alias", "title", "date.original", "file", "source", "source.homepage"]} - post_params['filter'] = self.parse_filter_params(params) + post_params["filter"] = self.parse_filter_params(params) - if params.get('search'): - post_params['query'] = { - 'value': params['search'], - 'fields': ['title'], - 'operator': 'AND', + if params.get("search"): + post_params["query"] = { + "value": params["search"], + "fields": ["title"], + "operator": "AND", } - post_params['limit'] = 1000 - post_params['sort'] = ['date.original:desc', 'title:asc'] + post_params["limit"] = 1000 + post_params["sort"] = ["date.original:desc", "title:asc"] relief_url = self.URL total_count = 0 @@ -365,23 +353,23 @@ def fetch(self, params): while relief_url is not None: content = self.get_content(relief_url, post_params) resp = json.loads(content) - total_count += resp['totalCount'] + total_count += resp["totalCount"] - for datum in resp['data']: - fields = datum['fields'] - url = fields['file'][0]['url'] if fields.get('file') else fields['url_alias'] - title = fields['title'] - published_on = (fields.get('date') or {}).get('original') - author = ((fields.get('source') or [{}])[0] or {}).get('name') + for datum in resp["data"]: + fields = datum["fields"] + url = fields["file"][0]["url"] if fields.get("file") else fields["url_alias"] + title = fields["title"] + published_on = (fields.get("date") or {}).get("original") + author = ((fields.get("source") or [{}])[0] or {}).get("name") lead = { - 'id': str(datum['id']), - 'title': title, - 'published_on': published_on, - 'url': url, - 'source': 'reliefweb', - 'source_type': Lead.SourceType.WEBSITE.value, - 'author': author, + "id": str(datum["id"]), + "title": title, + "published_on": published_on, + "url": url, + "source": "reliefweb", + "source_type": Lead.SourceType.WEBSITE.value, + "author": author, } results.append(lead) - relief_url = ((resp.get('links') or {}).get('next') or {}).get('href') + relief_url = ((resp.get("links") or {}).get("next") or {}).get("href") return results, total_count diff --git a/apps/unified_connector/sources/research_center.py b/apps/unified_connector/sources/research_center.py index d2bbd5e0bc..690a9ad44f 100644 --- a/apps/unified_connector/sources/research_center.py +++ b/apps/unified_connector/sources/research_center.py @@ -1,72 +1,65 @@ -import requests import datetime -from bs4 import BeautifulSoup as Soup -from .base import Source +import requests +from bs4 import BeautifulSoup as Soup from connector.utils import ConnectorWrapper from lead.models import Lead +from .base import Source COUNTRIES_OPTIONS = [ - {'key': 'AF', 'label': 'Afghanistan'}, - {'key': 'BD', 'label': 'Bangladesh'}, - {'key': 'BW', 'label': 'Botswana'}, - {'key': 'BR', 'label': 'Brazil'}, - {'key': 'CF', 'label': 'Central African Republic'}, - {'key': 'TD', 'label': 'Chad'}, - {'key': 'HR', 'label': 'Croatia'}, - {'key': 'CD', 'label': 'Democratic Republic of the Congo'}, - {'key': 'GR', 'label': 'Greece'}, - {'key': 'HT', 'label': 'Haiti'}, - {'key': 'HU', 'label': 'Hungary'}, - {'key': 'IN', 'label': 'India'}, - {'key': 'IQ', 'label': 'Iraq'}, - {'key': 'IT', 'label': 'Italy'}, - {'key': 'JO', 'label': 'Jordan'}, - {'key': 'KE', 'label': 'Kenya'}, - {'key': 'KG', 'label': 'Kyrgyzstan'}, - {'key': 'LB', 'label': 'Lebanon'}, - {'key': 'LY', 'label': 'Libya'}, - {'key': 'MK', 'label': 'Macedonia'}, - {'key': 'ML', 'label': 'Mali'}, - {'key': 'MM', 'label': 'Myanmar'}, - {'key': 'NP', 'label': 'Nepal'}, - {'key': 'NE', 'label': 'Niger'}, - {'key': 'NG', 'label': 'Nigeria'}, - {'key': 'PS', 'label': 'Palestinian Territory'}, - {'key': 'PE', 'label': 'Peru'}, - {'key': 'PH', 'label': 'Philippines'}, - {'key': 'RS', 'label': 'Serbia'}, - {'key': 'SI', 'label': 'Slovenia'}, - {'key': 'SO', 'label': 'Somalia'}, - {'key': 'SS', 'label': 'South Sudan'}, - {'key': 'ES', 'label': 'Spain'}, - {'key': 'SD', 'label': 'Sudan'}, - {'key': 'SY', 'label': 'Syria'}, - {'key': 'CG', 'label': 'The Republic of the Congo'}, - {'key': 'TL', 'label': 'Timor-Leste'}, - {'key': 'TR', 'label': 'Turkey'}, - {'key': 'UG', 'label': 'Uganda'}, - {'key': 'UA', 'label': 'Ukraine'}, - {'key': 'VU', 'label': 'Vanuatu'}, - {'key': 'YE', 'label': 'Yemen'} + {"key": "AF", "label": "Afghanistan"}, + {"key": "BD", "label": "Bangladesh"}, + {"key": "BW", "label": "Botswana"}, + {"key": "BR", "label": "Brazil"}, + {"key": "CF", "label": "Central African Republic"}, + {"key": "TD", "label": "Chad"}, + {"key": "HR", "label": "Croatia"}, + {"key": "CD", "label": "Democratic Republic of the Congo"}, + {"key": "GR", "label": "Greece"}, + {"key": "HT", "label": "Haiti"}, + {"key": "HU", "label": "Hungary"}, + {"key": "IN", "label": "India"}, + {"key": "IQ", "label": "Iraq"}, + {"key": "IT", "label": "Italy"}, + {"key": "JO", "label": "Jordan"}, + {"key": "KE", "label": "Kenya"}, + {"key": "KG", "label": "Kyrgyzstan"}, + {"key": "LB", "label": "Lebanon"}, + {"key": "LY", "label": "Libya"}, + {"key": "MK", "label": "Macedonia"}, + {"key": "ML", "label": "Mali"}, + {"key": "MM", "label": "Myanmar"}, + {"key": "NP", "label": "Nepal"}, + {"key": "NE", "label": "Niger"}, + {"key": "NG", "label": "Nigeria"}, + {"key": "PS", "label": "Palestinian Territory"}, + {"key": "PE", "label": "Peru"}, + {"key": "PH", "label": "Philippines"}, + {"key": "RS", "label": "Serbia"}, + {"key": "SI", "label": "Slovenia"}, + {"key": "SO", "label": "Somalia"}, + {"key": "SS", "label": "South Sudan"}, + {"key": "ES", "label": "Spain"}, + {"key": "SD", "label": "Sudan"}, + {"key": "SY", "label": "Syria"}, + {"key": "CG", "label": "The Republic of the Congo"}, + {"key": "TL", "label": "Timor-Leste"}, + {"key": "TR", "label": "Turkey"}, + {"key": "UG", "label": "Uganda"}, + {"key": "UA", "label": "Ukraine"}, + {"key": "VU", "label": "Vanuatu"}, + {"key": "YE", "label": "Yemen"}, ] @ConnectorWrapper class ResearchResourceCenter(Source): - URL = 'http://www.reachresourcecentre.info/advanced-search' - title = 'Research Resource Center' - key = 'research-resource-center' + URL = "http://www.reachresourcecentre.info/advanced-search" + title = "Research Resource Center" + key = "research-resource-center" - options = [ - { - 'key': 'name_list[]', - 'field_type': 'select', - 'title': 'Country', - 'options': COUNTRIES_OPTIONS - } - ] + options = [{"key": "name_list[]", "field_type": "select", "title": "Country", "options": COUNTRIES_OPTIONS}] def get_content(self, url, params): resp = requests.get(self.URL, params=params) @@ -75,24 +68,24 @@ def get_content(self, url, params): def fetch(self, params): results = [] content = self.get_content(self.URL, params) - soup = Soup(content, 'html.parser') - contents = soup.find('table').find('tbody').findAll('tr') + soup = Soup(content, "html.parser") + contents = soup.find("table").find("tbody").findAll("tr") total_count = len(contents) limited_contents = contents for row in limited_contents: - tds = row.findAll('td') - title = tds[0].get_text().replace('_', ' ') - date = tds[1].find('span').attrs['content'][:10] # just date str # noqa - date = datetime.datetime.strptime(date, '%Y-%m-%d') - url = tds[0].find('a').attrs['href'] + tds = row.findAll("td") + title = tds[0].get_text().replace("_", " ") + date = tds[1].find("span").attrs["content"][:10] # just date str # noqa + date = datetime.datetime.strptime(date, "%Y-%m-%d") + url = tds[0].find("a").attrs["href"] data = { - 'title': title.strip(), - 'published_on': date.date(), - 'url': url, - 'source': "Research Resource Center", - 'author': "Research Resource Center", - 'source_type': Lead.SourceType.WEBSITE, + "title": title.strip(), + "published_on": date.date(), + "url": url, + "source": "Research Resource Center", + "author": "Research Resource Center", + "source_type": Lead.SourceType.WEBSITE, } results.append(data) return results, total_count diff --git a/apps/unified_connector/sources/rss_feed.py b/apps/unified_connector/sources/rss_feed.py index 694fc5fd62..f0e036793c 100644 --- a/apps/unified_connector/sources/rss_feed.py +++ b/apps/unified_connector/sources/rss_feed.py @@ -1,71 +1,65 @@ -from rest_framework import serializers -from lxml import etree import requests +from connector.utils import ConnectorWrapper, get_rss_fields +from lead.models import Lead +from lxml import etree +from rest_framework import serializers from utils.common import DEFAULT_HEADERS -from lead.models import Lead + from .base import Source -from connector.utils import get_rss_fields, ConnectorWrapper def _get_field_value(item, field): if not field: - return '' + return "" element = item.find(field) - return '' if element is None else element.text or element.get('href') + return "" if element is None else element.text or element.get("href") @ConnectorWrapper class RssFeed(Source): - title = 'RSS Feed' - key = 'rss-feed' + title = "RSS Feed" + key = "rss-feed" options = [ + {"key": "feed-url", "field_type": "url", "title": "Feed URL"}, { - 'key': 'feed-url', - 'field_type': 'url', - 'title': 'Feed URL' - }, - { - 'key': 'title-field', - 'field_type': 'select', - 'lead_field': 'title', - 'title': 'Title field', - 'options': [], + "key": "title-field", + "field_type": "select", + "lead_field": "title", + "title": "Title field", + "options": [], }, { - 'key': 'date-field', - 'field_type': 'select', - 'lead_field': 'published_on', - 'title': 'Published on field', - 'options': [], + "key": "date-field", + "field_type": "select", + "lead_field": "published_on", + "title": "Published on field", + "options": [], }, { - 'key': 'source-field', - 'field_type': 'select', - 'lead_field': 'source', - 'title': 'Publisher field', - 'options': [], + "key": "source-field", + "field_type": "select", + "lead_field": "source", + "title": "Publisher field", + "options": [], }, { - 'key': 'author-field', - 'field_type': 'select', - 'lead_field': 'author', - 'title': 'Author field', - 'options': [], + "key": "author-field", + "field_type": "select", + "lead_field": "author", + "title": "Author field", + "options": [], }, { - 'key': 'url-field', - 'field_type': 'select', - 'lead_field': 'url', - 'title': 'URL field', - 'options': [], + "key": "url-field", + "field_type": "select", + "lead_field": "url", + "title": "URL field", + "options": [], }, ] - _option_lead_field_map = { - option['lead_field']: option['key'] - for option in options if option.get('lead_field') - } + _option_lead_field_map = {option["lead_field"]: option["key"] for option in options if option.get("lead_field")} dynamic_fields = [1, 2, 3, 4, 5] @@ -74,29 +68,25 @@ def get_content(self, url, params): return resp.content def query_fields(self, params): - if not params or not params.get('feed-url'): + if not params or not params.get("feed-url"): return [] try: - r = requests.get(params['feed-url'], headers=DEFAULT_HEADERS) + r = requests.get(params["feed-url"], headers=DEFAULT_HEADERS) xml = etree.fromstring(r.content) except requests.exceptions.RequestException: - raise serializers.ValidationError({ - 'feed-url': 'Could not fetch rss feed' - }) + raise serializers.ValidationError({"feed-url": "Could not fetch rss feed"}) except etree.XMLSyntaxError: - raise serializers.ValidationError({ - 'feed-url': 'Invalid XML' - }) + raise serializers.ValidationError({"feed-url": "Invalid XML"}) - item = xml.find('channel/item') + item = xml.find("channel/item") if not item: return [] nsmap = xml.nsmap fields = [] - for field in item.findall('./'): + for field in item.findall("./"): fields.extend(get_rss_fields(field, nsmap)) # Remove fields that are present more than once, @@ -109,16 +99,16 @@ def query_fields(self, params): def fetch(self, params): results = [] - if not params or not params.get('feed-url'): + if not params or not params.get("feed-url"): return results, 0 - content = self.get_content(params['feed-url'], {}) + content = self.get_content(params["feed-url"], {}) xml = etree.fromstring(content) - items = xml.findall('channel/item') + items = xml.findall("channel/item") for item in items: data = { - 'source_type': Lead.SourceType.RSS, + "source_type": Lead.SourceType.RSS, **{ lead_field: _get_field_value(item, params.get(param_key)) for lead_field, param_key in self._option_lead_field_map.items() diff --git a/apps/unified_connector/sources/unhcr_portal.py b/apps/unified_connector/sources/unhcr_portal.py index 211f3b1c8b..a43ec113ff 100644 --- a/apps/unified_connector/sources/unhcr_portal.py +++ b/apps/unified_connector/sources/unhcr_portal.py @@ -1,16 +1,15 @@ -import json import copy -import requests import datetime +import json +import requests from bs4 import BeautifulSoup as Soup +from connector.utils import ConnectorWrapper from utils.common import deep_date_format -from connector.utils import ConnectorWrapper from .base import Source - COUNTRIES_OPTIONS = [ {"label": "All", "key": ""}, {"label": "Afghanistan", "key": "575"}, @@ -217,37 +216,32 @@ def _format_date_or_none(iso_datestr): try: - return deep_date_format(datetime.datetime.strptime(iso_datestr, '%Y-%m-%d')) + return deep_date_format(datetime.datetime.strptime(iso_datestr, "%Y-%m-%d")) except Exception: return None @ConnectorWrapper class UNHCRPortal(Source): - URL = 'https://data2.unhcr.org/en/search' - title = 'UNHCR Portal' - key = 'unhcr-portal' + URL = "https://data2.unhcr.org/en/search" + title = "UNHCR Portal" + key = "unhcr-portal" options = [ + {"key": "country", "field_type": "select", "title": "Country", "options": COUNTRIES_OPTIONS}, { - 'key': 'country', - 'field_type': 'select', - 'title': 'Country', - 'options': COUNTRIES_OPTIONS - }, - { - 'key': 'date_from', - 'field_type': 'date', - 'title': 'From', + "key": "date_from", + "field_type": "date", + "title": "From", }, { - 'key': 'date_to', - 'field_type': 'date', - 'title': 'To', + "key": "date_to", + "field_type": "date", + "title": "To", }, ] params = { - 'type[]': [ - 'document', + "type[]": [ + "document", # NOTE: for now have only documents as results, other do not seem # to be parsable # 'link', 'news', 'highlight' @@ -261,16 +255,16 @@ def fetch(self, params): results = [] updated_params = copy.deepcopy(params) - country = updated_params.pop('country', None) - date_from = _format_date_or_none(updated_params.pop('date_from', None)) - date_to = _format_date_or_none(updated_params.pop('date_to', None)) + country = updated_params.pop("country", None) + date_from = _format_date_or_none(updated_params.pop("date_from", None)) + date_to = _format_date_or_none(updated_params.pop("date_to", None)) if country: - updated_params['country_json'] = json.dumps({'0': country}) - updated_params['country'] = country + updated_params["country_json"] = json.dumps({"0": country}) + updated_params["country"] = country if date_from: - updated_params['date_from'] = date_from + updated_params["date_from"] = date_from if date_to: - updated_params['date_to'] = date_to + updated_params["date_to"] = date_to updated_params.update(self.params) # type is default @@ -278,47 +272,38 @@ def fetch(self, params): while True: if page > self.UNIFIED_CONNECTOR_SOURCE_MAX_PAGE_NUMBER: break - updated_params['page'] = page + updated_params["page"] = page content = self.get_content(self.URL, updated_params) - soup = Soup(content, 'html.parser') - contents = soup.findAll('ul', {'class': 'searchResults'}) + soup = Soup(content, "html.parser") + contents = soup.findAll("ul", {"class": "searchResults"}) if not contents: return results, 0 content = contents[0] - items = content.findAll('li', {'class': ['searchResultItem']}) + items = content.findAll("li", {"class": ["searchResultItem"]}) for item in items: - itemcontent = item.find( - 'div', - {'class': ['searchResultItem_content', 'media_body']} - ) - urlcontent = item.find( - 'div', - {'class': 'searchResultItem_download'} - ) - datecontent = item.find( - 'span', - {'class': 'searchResultItem_date'} - ) - title = itemcontent.find('a').get_text() - pdfurl = urlcontent.find('a')['href'] - raw_date = datecontent.find('b').get_text() # 4 July 2018 - date = datetime.datetime.strptime(raw_date, '%d %B %Y') + itemcontent = item.find("div", {"class": ["searchResultItem_content", "media_body"]}) + urlcontent = item.find("div", {"class": "searchResultItem_download"}) + datecontent = item.find("span", {"class": "searchResultItem_date"}) + title = itemcontent.find("a").get_text() + pdfurl = urlcontent.find("a")["href"] + raw_date = datecontent.find("b").get_text() # 4 July 2018 + date = datetime.datetime.strptime(raw_date, "%d %B %Y") data = { - 'title': title and title.strip(), - 'published_on': date.date(), - 'url': pdfurl, - 'source': 'UNHCR Portal', - 'author': '', - 'source_type': '', + "title": title and title.strip(), + "published_on": date.date(), + "url": pdfurl, + "source": "UNHCR Portal", + "author": "", + "source_type": "", } results.append(data) - footer = soup.find('div', {'class': 'pgSearch_results_footer'}) + footer = soup.find("div", {"class": "pgSearch_results_footer"}) if not footer: break - next_url = footer.find('a', {'rel': 'next'}) - if next_url and next_url['href']: + next_url = footer.find("a", {"rel": "next"}) + if next_url and next_url["href"]: page += 1 else: break diff --git a/apps/unified_connector/sources/wpf.py b/apps/unified_connector/sources/wpf.py index 300227642a..57bb49faf6 100644 --- a/apps/unified_connector/sources/wpf.py +++ b/apps/unified_connector/sources/wpf.py @@ -1,10 +1,10 @@ import requests from bs4 import BeautifulSoup as Soup - -from .base import Source from connector.utils import ConnectorWrapper from lead.models import Lead +from .base import Source + COUNTRIES_OPTIONS = [ {"key": "All", "label": "Any"}, {"key": "120", "label": "Afghanistan"}, @@ -139,22 +139,12 @@ @ConnectorWrapper class WorldFoodProgramme(Source): - URL = 'https://www.wfp.org/food-security/assessment-bank' - title = 'WFP Assessments' - key = 'world-food-programme' + URL = "https://www.wfp.org/food-security/assessment-bank" + title = "WFP Assessments" + key = "world-food-programme" options = [ - { - 'key': 'tid_1', - 'field_type': 'select', - 'title': 'Country', - 'options': COUNTRIES_OPTIONS - }, - { - 'key': 'tid_6', - 'field_type': 'select', - 'title': 'Year', - 'options': YEAR_OPTIONS - }, + {"key": "tid_1", "field_type": "select", "title": "Country", "options": COUNTRIES_OPTIONS}, + {"key": "tid_6", "field_type": "select", "title": "Year", "options": YEAR_OPTIONS}, ] def get_content(self, url, params): @@ -165,21 +155,21 @@ def fetch(self, params): results = [] content = self.get_content(self.URL, params) - soup = Soup(content, 'html.parser') + soup = Soup(content, "html.parser") - contents = soup.find('div', {'class': 'view-content'}) + contents = soup.find("div", {"class": "view-content"}) if not contents: return results, len(results) # iterate and get leads - for row in contents.findAll('div', {'class': 'views-row'}): - content = row.find('h3').find('a') + for row in contents.findAll("div", {"class": "views-row"}): + content = row.find("h3").find("a") title = content.get_text() - url = content['href'] + url = content["href"] data = { - 'title': title.strip(), - 'url': url, - 'source': 'WFP Assessments', - 'source_type': Lead.SourceType.WEBSITE, + "title": title.strip(), + "url": url, + "source": "WFP Assessments", + "source_type": Lead.SourceType.WEBSITE, } results.append(data) return results, len(results) diff --git a/apps/unified_connector/tasks.py b/apps/unified_connector/tasks.py index ba83b67ee9..b2b65616bb 100644 --- a/apps/unified_connector/tasks.py +++ b/apps/unified_connector/tasks.py @@ -2,52 +2,51 @@ from datetime import timedelta from celery import shared_task -from django.utils import timezone +from deepl_integration.handlers import UnifiedConnectorLeadHandler from django.db import models +from django.utils import timezone from utils.common import redis_lock -from deepl_integration.handlers import UnifiedConnectorLeadHandler - -from .models import ( - ConnectorLead, - ConnectorSource, -) +from .models import ConnectorLead, ConnectorSource logger = logging.getLogger(__name__) @shared_task -@redis_lock('process_unified_connector_{0}', 60 * 60 * 0.5) +@redis_lock("process_unified_connector_{0}", 60 * 60 * 0.5) def process_unified_connector(_id): try: return UnifiedConnectorLeadHandler.process_unified_connector(_id) except Exception: - logger.error('Unified connector process failed', exc_info=True) + logger.error("Unified connector process failed", exc_info=True) @shared_task -@redis_lock('retry_connector_leads', 60 * 60 * 0.5) +@redis_lock("retry_connector_leads", 60 * 60 * 0.5) def retry_connector_leads(): try: - return UnifiedConnectorLeadHandler.send_retry_trigger_request_to_extractor( - ConnectorLead.objects.all() - ) + return UnifiedConnectorLeadHandler.send_retry_trigger_request_to_extractor(ConnectorLead.objects.all()) except Exception: - logger.error('Retry connector lead failed', exc_info=True) + logger.error("Retry connector lead failed", exc_info=True) def trigger_connector_sources(max_execution_time, threshold, limit): - sources_qs = ConnectorSource.objects.annotate( - execution_time=models.F('end_date') - models.F('start_date'), - ).exclude( - execution_time__isnull=False, - status=ConnectorSource.Status.PROCESSING, - ).filter( - unified_connector__is_active=True, - execution_time__lte=max_execution_time, - last_fetched_at__lte=timezone.now() - threshold, - ).order_by('execution_time') + sources_qs = ( + ConnectorSource.objects.annotate( + execution_time=models.F("end_date") - models.F("start_date"), + ) + .exclude( + execution_time__isnull=False, + status=ConnectorSource.Status.PROCESSING, + ) + .filter( + unified_connector__is_active=True, + execution_time__lte=max_execution_time, + last_fetched_at__lte=timezone.now() - threshold, + ) + .order_by("execution_time") + ) processed_unified_connectors = set() for source in sources_qs.all()[:limit]: @@ -55,18 +54,16 @@ def trigger_connector_sources(max_execution_time, threshold, limit): UnifiedConnectorLeadHandler.process_unified_connector_source(source) processed_unified_connectors.add(source.unified_connector_id) except Exception: - logger.error('Failed to trigger connector source', exc_info=True) + logger.error("Failed to trigger connector source", exc_info=True) # Trigger connector leads for unified_connector_id in processed_unified_connectors: UnifiedConnectorLeadHandler.send_trigger_request_to_extractor( - ConnectorLead.objects.filter( - connectorsourcelead__source__unified_connector=unified_connector_id - ) + ConnectorLead.objects.filter(connectorsourcelead__source__unified_connector=unified_connector_id) ) @shared_task -@redis_lock('schedule_trigger_quick_unified_connectors', 60 * 60) +@redis_lock("schedule_trigger_quick_unified_connectors", 60 * 60) def schedule_trigger_quick_unified_connectors(): # NOTE: Process connectors sources which have runtime <= 3 min and was processed 3 hours before. trigger_connector_sources( @@ -77,7 +74,7 @@ def schedule_trigger_quick_unified_connectors(): @shared_task -@redis_lock('schedule_trigger_heavy_unified_connectors', 60 * 60) +@redis_lock("schedule_trigger_heavy_unified_connectors", 60 * 60) def schedule_trigger_heavy_unified_connectors(): # NOTE: Process connectors sources which have runtime <= 10 min and was processed 3 hours ago. trigger_connector_sources( @@ -88,7 +85,7 @@ def schedule_trigger_heavy_unified_connectors(): @shared_task -@redis_lock('schedule_trigger_super_heavy_unified_connectors', 60 * 60) +@redis_lock("schedule_trigger_super_heavy_unified_connectors", 60 * 60) def schedule_trigger_super_heavy_unified_connectors(): # NOTE: Process connectors sources which have runtime <= 1 hour and was processed 24 hours ago. trigger_connector_sources( diff --git a/apps/unified_connector/tests/mock_data/atom_feed_mock_data.py b/apps/unified_connector/tests/mock_data/atom_feed_mock_data.py index f4a4868fca..4909d768c7 100644 --- a/apps/unified_connector/tests/mock_data/atom_feed_mock_data.py +++ b/apps/unified_connector/tests/mock_data/atom_feed_mock_data.py @@ -1,7 +1,8 @@ import datetime + from lead.models import Lead -ATOM_FEED_MOCK_DATA_RAW = ''' +ATOM_FEED_MOCK_DATA_RAW = """ @@ -74,7 +75,7 @@ Time Series Modelling -''' +""" ATOM_FEED_PARAMS = { "feed-url": "test-url", @@ -82,7 +83,7 @@ "author-field": "author", "source-field": "source", "date-field": "published", - "title-field": "title" + "title-field": "title", } ATOM_FEED_MOCK_EXCEPTED_LEADS = [ @@ -91,34 +92,34 @@ "author_raw": "/u/Im__Joseph", "published_on": datetime.date(2022, 5, 15), "title": "Sunday Daily Thread: What's everyone working on this week?", - "source_type": Lead.SourceType.RSS + "source_type": Lead.SourceType.RSS, }, { "url": "https://www.reddit.com/r/Python/comments/usqy75/best_methods_to_run_a_continuous_python_script_on/", "author_raw": "/u/BigdadEdge", "published_on": datetime.date(2022, 5, 19), "title": "Best methods to run a continuous Python script on a server", - "source_type": Lead.SourceType.RSS + "source_type": Lead.SourceType.RSS, }, { "url": "https://www.reddit.com/r/Python/comments/usjg8k/arcade_2614_has_been_released_2d_game_library/", "author_raw": "/u/pvc", "published_on": datetime.date(2022, 5, 18), "title": "Arcade 2.6.14 has been released (2D game library)", - "source_type": Lead.SourceType.RSS + "source_type": Lead.SourceType.RSS, }, { "url": "https://www.reddit.com/r/Python/comments/usdhpf/i_made_a_browser_extension_for_quick_nested/", "author_raw": "/u/jabza_", "published_on": datetime.date(2022, 5, 18), "title": "I made a browser extension for quick nested browsing of the Python docs (and others)", - "source_type": Lead.SourceType.RSS + "source_type": Lead.SourceType.RSS, }, { "url": "https://www.reddit.com/r/Python/comments/usqwnc/time_series_modelling/", "author_raw": "/u/badassbilla", "published_on": datetime.date(2022, 5, 19), "title": "Time Series Modelling", - "source_type": Lead.SourceType.RSS + "source_type": Lead.SourceType.RSS, }, ] diff --git a/apps/unified_connector/tests/mock_data/emm_mock_data.py b/apps/unified_connector/tests/mock_data/emm_mock_data.py index b85b8e21a9..9cb226c608 100644 --- a/apps/unified_connector/tests/mock_data/emm_mock_data.py +++ b/apps/unified_connector/tests/mock_data/emm_mock_data.py @@ -1,6 +1,6 @@ import datetime -EMM_MOCK_DATA_RAW = ''' +EMM_MOCK_DATA_RAW = """ Latest news clusters for en @@ -83,15 +83,17 @@ -'''.encode('utf-8') +""".encode( + "utf-8" +) EMM_PARAMS = { - 'feed-url': "test-url", - 'url-field': 'link', - 'date-field': 'pubDate', - 'source-field': 'source', - 'author-field': 'source', - 'title-field': 'title', + "feed-url": "test-url", + "url-field": "link", + "date-field": "pubDate", + "source-field": "source", + "author-field": "source", + "title-field": "title", } EMM_MOCK_EXCEPTED_LEADS = [ @@ -99,24 +101,24 @@ "url": "https://www.heraldsun.com.au/business/rba-governor-says-he-doesnt-see-a-recession-on-the-horizon/video/56f4558c345418ad2964232cdbd06ca7?nk=0ad220a3bbd23f948b26d2da2bb5ac09-1655775162", "published_on": datetime.date(2022, 6, 21), "title": "RBA Governor says he ‘doesn’t see a recession on the horizon’", - "source_type": 'emm', - 'author_raw': 'heraldsun', - 'source_raw': 'heraldsun', + "source_type": "emm", + "author_raw": "heraldsun", + "source_raw": "heraldsun", }, { "url": "https://www.business-standard.com/article/current-affairs/goyal-calls-for-digital-media-use-for-speedy-consumer-complaint-redressal-122062100114_1.html", "published_on": datetime.date(2022, 6, 21), "title": "Goyal calls for digital media use for speedy consumer complaint redressal", - "source_type": 'emm', - 'author_raw': 'business-standard', - 'source_raw': 'business-standard', + "source_type": "emm", + "author_raw": "business-standard", + "source_raw": "business-standard", }, { "url": "https://www.longbeachstar.com/news/272592709/top-saints-laud-pm-modi-in-karnataka-mysuru", "published_on": datetime.date(2022, 6, 21), "title": "Top saints laud PM Modi in Karnataka's Mysuru", - "source_type": 'emm', - 'author_raw': 'longbeachstar', - 'source_raw': 'longbeachstar', - } + "source_type": "emm", + "author_raw": "longbeachstar", + "source_raw": "longbeachstar", + }, ] diff --git a/apps/unified_connector/tests/mock_data/humanitarian_response_mock_data.py b/apps/unified_connector/tests/mock_data/humanitarian_response_mock_data.py index f7efe4ac97..0dfc32a5d1 100644 --- a/apps/unified_connector/tests/mock_data/humanitarian_response_mock_data.py +++ b/apps/unified_connector/tests/mock_data/humanitarian_response_mock_data.py @@ -2,7 +2,7 @@ from lead.models import Lead -HUMANITARIAN_RESPONSE_MOCK_DATA_RAW = ''' +HUMANITARIAN_RESPONSE_MOCK_DATA_RAW = """ @@ -160,61 +160,61 @@ -''' +""" HUMANITARIAN_RESPONSE_MOCK_EXCEPTED_LEADS = [ { - 'id': 'https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/compte_rendu_-_avril_2021.pdf', - 'title': 'Compte rendu – avril 2021', - 'published_on': datetime.date(2028, 4, 28), - 'url': 'https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/compte_rendu_-_avril_2021.pdf', - 'source_raw': 'Humanitarian Response', - 'author_raw': 'Humanitarian Response', - 'source_type': Lead.SourceType.WEBSITE + "id": "https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/compte_rendu_-_avril_2021.pdf", + "title": "Compte rendu – avril 2021", + "published_on": datetime.date(2028, 4, 28), + "url": "https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/compte_rendu_-_avril_2021.pdf", + "source_raw": "Humanitarian Response", + "author_raw": "Humanitarian Response", + "source_type": Lead.SourceType.WEBSITE, }, { - 'id': 'https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/compte_rendu_-_janvier_2021.pdf', - 'title': 'Compte rendu – janvier 2021', - 'published_on': datetime.date(2027, 1, 27), - 'url': 'https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/compte_rendu_-_janvier_2021.pdf', - 'source_raw': 'Humanitarian Response', - 'author_raw': 'Humanitarian Response', - 'source_type': Lead.SourceType.WEBSITE + "id": "https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/compte_rendu_-_janvier_2021.pdf", + "title": "Compte rendu – janvier 2021", + "published_on": datetime.date(2027, 1, 27), + "url": "https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/compte_rendu_-_janvier_2021.pdf", + "source_raw": "Humanitarian Response", + "author_raw": "Humanitarian Response", + "source_type": Lead.SourceType.WEBSITE, }, { - 'id': 'https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/2021-12_sudan_2022_humanitarian_response_plan_january_-_december_2022.pdf', - 'title': 'Sudan 2022 Humanitarian Response Plan (January - December 2022)', - 'published_on': datetime.date(2022, 12, 20), - 'url': 'https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/2021-12_sudan_2022_humanitarian_response_plan_january_-_december_2022.pdf', - 'source_raw': 'Humanitarian Response', - 'author_raw': 'Humanitarian Response', - 'source_type': Lead.SourceType.WEBSITE + "id": "https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/2021-12_sudan_2022_humanitarian_response_plan_january_-_december_2022.pdf", + "title": "Sudan 2022 Humanitarian Response Plan (January - December 2022)", + "published_on": datetime.date(2022, 12, 20), + "url": "https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/2021-12_sudan_2022_humanitarian_response_plan_january_-_december_2022.pdf", + "source_raw": "Humanitarian Response", + "author_raw": "Humanitarian Response", + "source_type": Lead.SourceType.WEBSITE, }, { - 'id': 'https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/10_27_2021_bay_south-west_wash_cluster_meeting_minutes.pdf', - 'title': '10_27_2021_Bay_South-West WASH Cluster Meeting Minutes', - 'published_on': datetime.date(2022, 10, 27), - 'url': 'https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/10_27_2021_bay_south-west_wash_cluster_meeting_minutes.pdf', - 'source_raw': 'Humanitarian Response', - 'author_raw': 'Humanitarian Response', - 'source_type': Lead.SourceType.WEBSITE + "id": "https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/10_27_2021_bay_south-west_wash_cluster_meeting_minutes.pdf", + "title": "10_27_2021_Bay_South-West WASH Cluster Meeting Minutes", + "published_on": datetime.date(2022, 10, 27), + "url": "https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/10_27_2021_bay_south-west_wash_cluster_meeting_minutes.pdf", + "source_raw": "Humanitarian Response", + "author_raw": "Humanitarian Response", + "source_type": Lead.SourceType.WEBSITE, }, { - 'id': 'https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/flood_rna.zip', - 'title': 'NWS AAWG - Flood RNA - 2021', - 'published_on': datetime.date(2022, 10, 15), - 'url': 'https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/flood_rna.zip', - 'source_raw': 'Humanitarian Response', - 'author_raw': 'Humanitarian Response', - 'source_type': Lead.SourceType.WEBSITE + "id": "https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/flood_rna.zip", + "title": "NWS AAWG - Flood RNA - 2021", + "published_on": datetime.date(2022, 10, 15), + "url": "https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/flood_rna.zip", + "source_raw": "Humanitarian Response", + "author_raw": "Humanitarian Response", + "source_type": Lead.SourceType.WEBSITE, }, { - 'id': 'https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/chapeau_isp_-_cameroon_sept_2021_v2.pdf', - 'title': 'Cameroon: Chapeau - Information Sharing Protocols for Data Responsibility, September 2021', - 'published_on': datetime.date(2022, 10, 3), - 'url': 'https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/chapeau_isp_-_cameroon_sept_2021_v2.pdf', - 'source_raw': 'Humanitarian Response', - 'author_raw': 'Humanitarian Response', - 'source_type': Lead.SourceType.WEBSITE - } + "id": "https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/chapeau_isp_-_cameroon_sept_2021_v2.pdf", + "title": "Cameroon: Chapeau - Information Sharing Protocols for Data Responsibility, September 2021", + "published_on": datetime.date(2022, 10, 3), + "url": "https://www.humanitarianresponse.info/sites/www.humanitarianresponse.info/files/documents/files/chapeau_isp_-_cameroon_sept_2021_v2.pdf", + "source_raw": "Humanitarian Response", + "author_raw": "Humanitarian Response", + "source_type": Lead.SourceType.WEBSITE, + }, ] diff --git a/apps/unified_connector/tests/mock_data/pdna_mock_data.py b/apps/unified_connector/tests/mock_data/pdna_mock_data.py index 02ab6759bf..292717ee45 100644 --- a/apps/unified_connector/tests/mock_data/pdna_mock_data.py +++ b/apps/unified_connector/tests/mock_data/pdna_mock_data.py @@ -2,7 +2,7 @@ from lead.models import Lead -PDNA_MOCK_DATA_RAW = ''' +PDNA_MOCK_DATA_RAW = """ @@ -46,7 +46,7 @@ -''' +""" PDNA_PARAMS = { "country": "Nepal", @@ -54,11 +54,11 @@ PDNA_MOCK_EXCEPTED_LEADS = [ { - 'title': 'Earthquake', - 'url': 'https://www.gfdrr.org/sites/default/files/publication/Nepal%20Earthquake%202015%20Post-Disaster%20Needs%20Assessment%20Vol%20A.pdf', - 'source_raw': 'PDNA portal', - 'author_raw': 'PDNA portal', - 'published_on': datetime.date(2015, datetime.date.today().month, datetime.date.today().day), - 'source_type': Lead.SourceType.WEBSITE, + "title": "Earthquake", + "url": "https://www.gfdrr.org/sites/default/files/publication/Nepal%20Earthquake%202015%20Post-Disaster%20Needs%20Assessment%20Vol%20A.pdf", + "source_raw": "PDNA portal", + "author_raw": "PDNA portal", + "published_on": datetime.date(2015, datetime.date.today().month, datetime.date.today().day), + "source_type": Lead.SourceType.WEBSITE, } ] diff --git a/apps/unified_connector/tests/mock_data/relief_web_mock_data.py b/apps/unified_connector/tests/mock_data/relief_web_mock_data.py index 4f774f7e5f..1740995753 100644 --- a/apps/unified_connector/tests/mock_data/relief_web_mock_data.py +++ b/apps/unified_connector/tests/mock_data/relief_web_mock_data.py @@ -1,6 +1,6 @@ import datetime -RELIEF_WEB_MOCK_DATA_PAGE_1_RAW = ''' +RELIEF_WEB_MOCK_DATA_PAGE_1_RAW = """ { "time": 11, "href": "https://api.reliefweb.int/v1/reports?appname=thedeep.io", @@ -124,9 +124,9 @@ } ] } -''' +""" -RELIEF_WEB_MOCK_DATA_PAGE_2_RAW = ''' +RELIEF_WEB_MOCK_DATA_PAGE_2_RAW = """ { "time": 11, "href": "https://api.reliefweb.int/v1/reports?appname=thedeep.io", @@ -247,7 +247,7 @@ } ] } -''' +""" RELIEF_WEB_MOCK_EXCEPTED_LEADS = [ { @@ -256,8 +256,8 @@ "published_on": datetime.date(2020, 9, 17), "url": "https://reliefweb.int/report/nepal/nepal-makes-progress-human-capital-development", "source_raw": "reliefweb", - "source_type": 'website', - "author_raw": "World Bank" + "source_type": "website", + "author_raw": "World Bank", }, { "id": "3670541", @@ -265,8 +265,8 @@ "published_on": datetime.date(2020, 9, 16), "url": "https://reliefweb.int/sites/reliefweb.int/files/resources/roap_covid_response_sitrep_18.pdf", "source_raw": "reliefweb", - "source_type": 'website', - "author_raw": "International Organization for Migration" + "source_type": "website", + "author_raw": "International Organization for Migration", }, { "id": "3670672", @@ -274,8 +274,8 @@ "published_on": datetime.date(2020, 9, 16), "url": "https://reliefweb.int/report/nepal/nepal-earthquake-national-seismological", "source_raw": "reliefweb", - "source_type": 'website', - "author_raw": "European Commission's Directorate-General for European Civil Protection" + "source_type": "website", + "author_raw": "European Commission's Directorate-General for European Civil Protection", }, { "id": "3670318", @@ -283,8 +283,8 @@ "published_on": datetime.date(2020, 9, 15), "url": "https://reliefweb.int/sites/reliefweb.int/files/resources/ROAP_Snapshot_200915.pdf", "source_raw": "reliefweb", - "source_type": 'website', - "author_raw": "UN Office for the Coordination of Humanitarian Affairs" + "source_type": "website", + "author_raw": "UN Office for the Coordination of Humanitarian Affairs", }, { "id": "3670885", @@ -292,8 +292,8 @@ "published_on": datetime.date(2020, 9, 17), "url": "https://reliefweb.int/report/nepal/nepal-makes-progress-human-capital-development", "source_raw": "reliefweb", - "source_type": 'website', - "author_raw": "World Bank" + "source_type": "website", + "author_raw": "World Bank", }, { "id": "3670541", @@ -301,8 +301,8 @@ "published_on": datetime.date(2020, 9, 16), "url": "https://reliefweb.int/sites/reliefweb.int/files/resources/roap_covid_response_sitrep_18.pdf", "source_raw": "reliefweb", - "source_type": 'website', - "author_raw": "International Organization for Migration" + "source_type": "website", + "author_raw": "International Organization for Migration", }, { "id": "3670672", @@ -310,8 +310,8 @@ "published_on": datetime.date(2020, 9, 16), "url": "https://reliefweb.int/report/nepal/nepal-earthquake-national-seismological", "source_raw": "reliefweb", - "source_type": 'website', - "author_raw": "European Commission's Directorate-General for European Civil Protection" + "source_type": "website", + "author_raw": "European Commission's Directorate-General for European Civil Protection", }, { "id": "3670318", @@ -319,7 +319,7 @@ "published_on": datetime.date(2020, 9, 15), "url": "https://reliefweb.int/sites/reliefweb.int/files/resources/ROAP_Snapshot_200915.pdf", "source_raw": "reliefweb", - "source_type": 'website', - "author_raw": "UN Office for the Coordination of Humanitarian Affairs" - } + "source_type": "website", + "author_raw": "UN Office for the Coordination of Humanitarian Affairs", + }, ] diff --git a/apps/unified_connector/tests/mock_data/rss_feed_mock_data.py b/apps/unified_connector/tests/mock_data/rss_feed_mock_data.py index 7157e911e7..77458fb1de 100644 --- a/apps/unified_connector/tests/mock_data/rss_feed_mock_data.py +++ b/apps/unified_connector/tests/mock_data/rss_feed_mock_data.py @@ -1,6 +1,6 @@ import datetime -RSS_FEED_MOCK_DATA_RAW = ''' +RSS_FEED_MOCK_DATA_RAW = """ @@ -139,7 +139,9 @@ -'''.encode('utf-8') +""".encode( + "utf-8" +) RSS_PARAMS = { @@ -148,7 +150,7 @@ "author-field": "author", "source-field": "source", "date-field": "pubDate", - "title-field": "title" + "title-field": "title", } @@ -159,7 +161,7 @@ "source_raw": "ReliefWeb - Updates", "published_on": datetime.date(2020, 9, 17), "title": "Ukraine: DRC / DDG Legal Alert: Issue 55 - August 2020 [EN/RU/UK]", - "source_type": 'rss' + "source_type": "rss", }, { "url": "https://reliefweb.int/report/ukraine/osce-special-monitoring", @@ -167,7 +169,7 @@ "source_raw": "ReliefWeb - Updates", "published_on": datetime.date(2020, 9, 16), "title": "OSCE Special Monitoring Mission to Ukraine (SMM) Daily Report 221/2020 issued on 16 September 2020", - "source_type": 'rss' + "source_type": "rss", }, { "url": "https://reliefweb.int/report/ukraine/crossing-contact-line-august-2020-snapshot-enuk", @@ -175,7 +177,7 @@ "source_raw": "ReliefWeb - Updates", "published_on": datetime.date(2020, 9, 16), "title": "Ukraine: Crossing the Contact Line, August 2020 Snapshot [EN/UK]", - "source_type": 'rss' + "source_type": "rss", }, { "url": "https://reliefweb.int/report/ukraine/eu-and-undp-supply-protective-respirators-medical", @@ -183,7 +185,7 @@ "source_raw": "ReliefWeb - Updates", "published_on": datetime.date(2020, 9, 16), "title": "Ukraine: EU and UNDP supply protective respirators to medical workers in Donetsk Oblast", - "source_type": 'rss' + "source_type": "rss", }, { "url": "https://reliefweb.int/report/ukraine/osce-special-monitoring-mission-ukraine-smm-daily", @@ -191,6 +193,6 @@ "source_raw": "ReliefWeb - Updates", "published_on": datetime.date(2020, 9, 15), "title": "OSCE Special Monitoring Mission to Ukraine (SMM) Daily Report 220/2020 issued on 15 September 2020", - "source_type": 'rss' + "source_type": "rss", }, ] diff --git a/apps/unified_connector/tests/mock_data/store.py b/apps/unified_connector/tests/mock_data/store.py index 6a96fee153..58ae101be4 100644 --- a/apps/unified_connector/tests/mock_data/store.py +++ b/apps/unified_connector/tests/mock_data/store.py @@ -1,61 +1,43 @@ from unified_connector.models import ConnectorSource -from .relief_web_mock_data import ( - RELIEF_WEB_MOCK_DATA_PAGE_1_RAW, - RELIEF_WEB_MOCK_DATA_PAGE_2_RAW, - RELIEF_WEB_MOCK_EXCEPTED_LEADS, -) -from .unhcr_mock_data import ( - UNHCR_MOCK_DATA_PAGE_1_RAW, - UNHCR_MOCK_DATA_PAGE_2_RAW, - UNHCR_WEB_MOCK_EXCEPTED_LEADS, -) -from .rss_feed_mock_data import ( - RSS_FEED_MOCK_DATA_RAW, - RSS_PARAMS, - RSS_FEED_MOCK_EXCEPTED_LEADS, -) + from .atom_feed_mock_data import ( ATOM_FEED_MOCK_DATA_RAW, - ATOM_FEED_PARAMS, ATOM_FEED_MOCK_EXCEPTED_LEADS, + ATOM_FEED_PARAMS, ) +from .emm_mock_data import EMM_MOCK_DATA_RAW, EMM_MOCK_EXCEPTED_LEADS, EMM_PARAMS from .humanitarian_response_mock_data import ( HUMANITARIAN_RESPONSE_MOCK_DATA_RAW, HUMANITARIAN_RESPONSE_MOCK_EXCEPTED_LEADS, ) -from .pdna_mock_data import ( - PDNA_MOCK_DATA_RAW, - PDNA_MOCK_EXCEPTED_LEADS, - PDNA_PARAMS, +from .pdna_mock_data import PDNA_MOCK_DATA_RAW, PDNA_MOCK_EXCEPTED_LEADS, PDNA_PARAMS +from .relief_web_mock_data import ( + RELIEF_WEB_MOCK_DATA_PAGE_1_RAW, + RELIEF_WEB_MOCK_DATA_PAGE_2_RAW, + RELIEF_WEB_MOCK_EXCEPTED_LEADS, ) -from .emm_mock_data import ( - EMM_MOCK_DATA_RAW, - EMM_MOCK_EXCEPTED_LEADS, - EMM_PARAMS, +from .rss_feed_mock_data import ( + RSS_FEED_MOCK_DATA_RAW, + RSS_FEED_MOCK_EXCEPTED_LEADS, + RSS_PARAMS, +) +from .unhcr_mock_data import ( + UNHCR_MOCK_DATA_PAGE_1_RAW, + UNHCR_MOCK_DATA_PAGE_2_RAW, + UNHCR_WEB_MOCK_EXCEPTED_LEADS, ) CONNECTOR_SOURCE_MOCK_DATA = { - ConnectorSource.Source.UNHCR: ( - (UNHCR_MOCK_DATA_PAGE_1_RAW, UNHCR_MOCK_DATA_PAGE_2_RAW), UNHCR_WEB_MOCK_EXCEPTED_LEADS - ), + ConnectorSource.Source.UNHCR: ((UNHCR_MOCK_DATA_PAGE_1_RAW, UNHCR_MOCK_DATA_PAGE_2_RAW), UNHCR_WEB_MOCK_EXCEPTED_LEADS), ConnectorSource.Source.RELIEF_WEB: ( - (RELIEF_WEB_MOCK_DATA_PAGE_1_RAW, RELIEF_WEB_MOCK_DATA_PAGE_2_RAW), RELIEF_WEB_MOCK_EXCEPTED_LEADS - ), - ConnectorSource.Source.RSS_FEED: ( - (RSS_FEED_MOCK_DATA_RAW,), RSS_FEED_MOCK_EXCEPTED_LEADS - ), - ConnectorSource.Source.ATOM_FEED: ( - (ATOM_FEED_MOCK_DATA_RAW,), ATOM_FEED_MOCK_EXCEPTED_LEADS - ), - ConnectorSource.Source.HUMANITARIAN_RESP: ( - (HUMANITARIAN_RESPONSE_MOCK_DATA_RAW,), HUMANITARIAN_RESPONSE_MOCK_EXCEPTED_LEADS - ), - ConnectorSource.Source.PDNA: ( - (PDNA_MOCK_DATA_RAW,), PDNA_MOCK_EXCEPTED_LEADS - ), - ConnectorSource.Source.EMM: ( - (EMM_MOCK_DATA_RAW,), EMM_MOCK_EXCEPTED_LEADS + (RELIEF_WEB_MOCK_DATA_PAGE_1_RAW, RELIEF_WEB_MOCK_DATA_PAGE_2_RAW), + RELIEF_WEB_MOCK_EXCEPTED_LEADS, ), + ConnectorSource.Source.RSS_FEED: ((RSS_FEED_MOCK_DATA_RAW,), RSS_FEED_MOCK_EXCEPTED_LEADS), + ConnectorSource.Source.ATOM_FEED: ((ATOM_FEED_MOCK_DATA_RAW,), ATOM_FEED_MOCK_EXCEPTED_LEADS), + ConnectorSource.Source.HUMANITARIAN_RESP: ((HUMANITARIAN_RESPONSE_MOCK_DATA_RAW,), HUMANITARIAN_RESPONSE_MOCK_EXCEPTED_LEADS), + ConnectorSource.Source.PDNA: ((PDNA_MOCK_DATA_RAW,), PDNA_MOCK_EXCEPTED_LEADS), + ConnectorSource.Source.EMM: ((EMM_MOCK_DATA_RAW,), EMM_MOCK_EXCEPTED_LEADS), } CONNECTOR_SOURCE_MOCK_PARAMS = { @@ -66,7 +48,7 @@ } -class ConnectorSourceResponseMock(): +class ConnectorSourceResponseMock: def __init__(self, source_type): self.raw_pages_data, self.expected_data = CONNECTOR_SOURCE_MOCK_DATA[source_type] self.params = CONNECTOR_SOURCE_MOCK_PARAMS.get(source_type, {}) diff --git a/apps/unified_connector/tests/mock_data/unhcr_mock_data.py b/apps/unified_connector/tests/mock_data/unhcr_mock_data.py index 5daefe1216..7d9af7ccc5 100644 --- a/apps/unified_connector/tests/mock_data/unhcr_mock_data.py +++ b/apps/unified_connector/tests/mock_data/unhcr_mock_data.py @@ -1,6 +1,6 @@ import datetime -UNHCR_MOCK_DATA_PAGE_1_RAW = ''' +UNHCR_MOCK_DATA_PAGE_1_RAW = """ @@ -320,9 +320,9 @@ -''' +""" -UNHCR_MOCK_DATA_PAGE_2_RAW = ''' +UNHCR_MOCK_DATA_PAGE_2_RAW = """ @@ -608,7 +608,7 @@ -''' +""" UNHCR_WEB_MOCK_EXCEPTED_LEADS = [ diff --git a/apps/unified_connector/tests/test_mutation.py b/apps/unified_connector/tests/test_mutation.py index 8d36c0ae54..b56d5ba6b9 100644 --- a/apps/unified_connector/tests/test_mutation.py +++ b/apps/unified_connector/tests/test_mutation.py @@ -1,28 +1,29 @@ from unittest.mock import patch -from django.core.files.uploadedfile import SimpleUploadedFile -from django.test import override_settings - -from utils.graphene.tests import GraphQLSnapShotTestCase -from deep.tests.test_case import DUMMY_TEST_CACHES -from project.factories import ProjectFactory -from user.factories import UserFactory - -from deep.tests import TestCase -from unified_connector.models import ( - ConnectorLead, - ConnectorSource, - ConnectorLeadPreviewImage, -) from deepl_integration.handlers import UnifiedConnectorLeadHandler from deepl_integration.serializers import DeeplServerBaseCallbackSerializer +from django.core.files.uploadedfile import SimpleUploadedFile +from django.test import override_settings +from project.factories import ProjectFactory from unified_connector.factories import ( ConnectorLeadFactory, ConnectorSourceFactory, ConnectorSourceLeadFactory, UnifiedConnectorFactory, ) -from unified_connector.tests.mock_data.relief_web_mock_data import RELIEF_WEB_MOCK_DATA_PAGE_2_RAW +from unified_connector.models import ( + ConnectorLead, + ConnectorLeadPreviewImage, + ConnectorSource, +) +from unified_connector.tests.mock_data.relief_web_mock_data import ( + RELIEF_WEB_MOCK_DATA_PAGE_2_RAW, +) +from user.factories import UserFactory + +from deep.tests import TestCase +from deep.tests.test_case import DUMMY_TEST_CACHES +from utils.graphene.tests import GraphQLSnapShotTestCase @override_settings( @@ -31,7 +32,7 @@ class TestLeadMutationSchema(GraphQLSnapShotTestCase): factories_used = [ProjectFactory, UserFactory, UnifiedConnectorFactory, ConnectorSourceFactory] - CREATE_UNIFIED_CONNECTOR_MUTATION = ''' + CREATE_UNIFIED_CONNECTOR_MUTATION = """ mutation MyMutation ($projectId: ID! $input: UnifiedConnectorWithSourceInputType!) { project(id: $projectId) { unifiedConnector { @@ -57,9 +58,9 @@ class TestLeadMutationSchema(GraphQLSnapShotTestCase): } } } - ''' + """ - UPDATE_UNIFIED_CONNECTOR_MUTATION = ''' + UPDATE_UNIFIED_CONNECTOR_MUTATION = """ mutation MyMutation ($projectId: ID! $unifiedConnectorId: ID! $input: UnifiedConnectorWithSourceInputType!) { project(id: $projectId) { unifiedConnector { @@ -85,9 +86,9 @@ class TestLeadMutationSchema(GraphQLSnapShotTestCase): } } } - ''' + """ - DELETE_UNIFIED_CONNECTOR_MUTATION = ''' + DELETE_UNIFIED_CONNECTOR_MUTATION = """ mutation MyMutation ($projectId: ID! $unifiedConnectorId: ID!) { project(id: $projectId) { unifiedConnector { @@ -113,9 +114,9 @@ class TestLeadMutationSchema(GraphQLSnapShotTestCase): } } } - ''' + """ - TRIGGER_UNIFIED_CONNECTOR_MUTATION = ''' + TRIGGER_UNIFIED_CONNECTOR_MUTATION = """ mutation MyMutation ($projectId: ID! $unifiedConnectorId: ID!) { project(id: $projectId) { unifiedConnector { @@ -126,9 +127,9 @@ class TestLeadMutationSchema(GraphQLSnapShotTestCase): } } } - ''' + """ - UPDATE_CONNECTOR_SOURCE_LEAD_MUTATION = ''' + UPDATE_CONNECTOR_SOURCE_LEAD_MUTATION = """ mutation MyMutation ($projectId: ID!, $connectorSourceLeadId: ID!, $input: ConnectorSourceLeadInputType!) { project(id: $projectId) { unifiedConnector { @@ -145,7 +146,7 @@ class TestLeadMutationSchema(GraphQLSnapShotTestCase): } } } - ''' + """ def setUp(self): super().setUp() @@ -159,43 +160,40 @@ def setUp(self): def test_unified_connector_create(self): def _query_check(minput, **kwargs): return self.query_check( - self.CREATE_UNIFIED_CONNECTOR_MUTATION, - minput=minput, - variables={'projectId': self.project.id}, - **kwargs + self.CREATE_UNIFIED_CONNECTOR_MUTATION, minput=minput, variables={"projectId": self.project.id}, **kwargs ) minput = dict( - title='unified-connector-s-1', - clientId='u-connector-101', + title="unified-connector-s-1", + clientId="u-connector-101", isActive=False, sources=[ dict( - title='connector-s-1', + title="connector-s-1", source=self.genum(ConnectorSource.Source.ATOM_FEED), - clientId='connector-source-101', + clientId="connector-source-101", params={ - 'sample-key': 'sample-value', + "sample-key": "sample-value", }, ), dict( - title='connector-s-2', + title="connector-s-2", source=self.genum(ConnectorSource.Source.RELIEF_WEB), - clientId='connector-source-102', + clientId="connector-source-102", params={ - 'sample-key': 'sample-value', + "sample-key": "sample-value", }, ), dict( - title='connector-s-3', + title="connector-s-3", # Same as previouse -> Throw error source=self.genum(ConnectorSource.Source.RELIEF_WEB), - clientId='connector-source-103', + clientId="connector-source-103", params={ - 'sample-key': 'sample-value', + "sample-key": "sample-value", }, - ) - ] + ), + ], ) # -- Without login @@ -213,11 +211,11 @@ def _query_check(minput, **kwargs): self.force_login(self.member_user) _query_check(minput, okay=False) - minput['sources'].pop(-1) + minput["sources"].pop(-1) # --- member user self.force_login(self.member_user) - content = _query_check(minput)['data']['project']['unifiedConnector']['unifiedConnectorCreate']['result'] - self.assertMatchSnapshot(content, 'success') + content = _query_check(minput)["data"]["project"]["unifiedConnector"]["unifiedConnectorCreate"]["result"] + self.assertMatchSnapshot(content, "success") def test_unified_connector_update(self): uc = UnifiedConnectorFactory.create(project=self.project) @@ -226,47 +224,47 @@ def _query_check(minput, **kwargs): return self.query_check( self.UPDATE_UNIFIED_CONNECTOR_MUTATION, minput=minput, - variables={'projectId': self.project.id, 'unifiedConnectorId': uc.id}, - **kwargs + variables={"projectId": self.project.id, "unifiedConnectorId": uc.id}, + **kwargs, ) source1 = ConnectorSourceFactory.create(unified_connector=uc, source=ConnectorSource.Source.ATOM_FEED) source2 = ConnectorSourceFactory.create(unified_connector=uc, source=ConnectorSource.Source.RELIEF_WEB) source3 = ConnectorSourceFactory.create(unified_connector=uc, source=ConnectorSource.Source.RSS_FEED) minput = dict( - clientId='u-connector-101', + clientId="u-connector-101", isActive=uc.is_active, title=uc.title, sources=[ dict( id=str(source1.pk), - clientId='connector-source-101', + clientId="connector-source-101", params=source1.params, source=self.genum(source1.source), title=source1.title, ), dict( id=str(source2.pk), - clientId='connector-source-102', + clientId="connector-source-102", params=source2.params, source=self.genum(source2.source), title=source2.title, ), dict( # Remove id. This will create a new source - clientId='connector-source-103', + clientId="connector-source-103", params=source3.params, source=self.genum(source3.source), title=source3.title, ), dict( # New source with duplicate source - clientId='connector-source-103', + clientId="connector-source-103", params=source1.params, source=self.genum(source1.source), title=source1.title, ), - ] + ], ) # -- Without login @@ -284,9 +282,9 @@ def _query_check(minput, **kwargs): self.force_login(self.member_user) _query_check(minput, okay=False) - minput['sources'].pop(-1) - content = _query_check(minput)['data']['project']['unifiedConnector']['unifiedConnectorWithSourceUpdate']['result'] - self.assertMatchSnapshot(content, 'success-1') + minput["sources"].pop(-1) + content = _query_check(minput)["data"]["project"]["unifiedConnector"]["unifiedConnectorWithSourceUpdate"]["result"] + self.assertMatchSnapshot(content, "success-1") def test_unified_connector_delete(self): admin_user = UserFactory.create() @@ -296,8 +294,8 @@ def test_unified_connector_delete(self): def _query_check(**kwargs): return self.query_check( self.DELETE_UNIFIED_CONNECTOR_MUTATION, - variables={'projectId': self.project.id, 'unifiedConnectorId': uc.id}, - **kwargs + variables={"projectId": self.project.id, "unifiedConnectorId": uc.id}, + **kwargs, ) for source in [ConnectorSource.Source.ATOM_FEED, ConnectorSource.Source.RSS_FEED, ConnectorSource.Source.RELIEF_WEB]: @@ -318,10 +316,10 @@ def _query_check(**kwargs): # --- member user self.force_login(admin_user) - _query_check(okay=True, mnested=['project', 'unifiedConnector']) + _query_check(okay=True, mnested=["project", "unifiedConnector"]) - @patch('unified_connector.sources.relief_web.requests') - @patch('deepl_integration.handlers.requests') + @patch("unified_connector.sources.relief_web.requests") + @patch("deepl_integration.handlers.requests") def test_unified_connector_trigger(self, extractor_response_mock, reliefweb_requests_mock): uc = UnifiedConnectorFactory.create(project=self.project) ConnectorSourceFactory.create(unified_connector=uc, source=ConnectorSource.Source.RELIEF_WEB) @@ -329,12 +327,12 @@ def test_unified_connector_trigger(self, extractor_response_mock, reliefweb_requ def _query_check(**kwargs): return self.query_check( self.TRIGGER_UNIFIED_CONNECTOR_MUTATION, - variables={'projectId': self.project.id, 'unifiedConnectorId': uc.id}, - **kwargs + variables={"projectId": self.project.id, "unifiedConnectorId": uc.id}, + **kwargs, ) def _query_okay_check(): - return _query_check(okay=True, mnested=['project', 'unifiedConnector']) + return _query_check(okay=True, mnested=["project", "unifiedConnector"]) # -- With login (non-member) self.force_login(self.non_member_user) @@ -346,43 +344,43 @@ def _query_okay_check(): # --- member user (inactive) self.force_login(self.member_user) - _query_check(okay=False, mnested=['project', 'unifiedConnector']) + _query_check(okay=False, mnested=["project", "unifiedConnector"]) # --- member user (active) uc.is_active = True - uc.save(update_fields=('is_active',)) + uc.save(update_fields=("is_active",)) self.force_login(self.member_user) connector_lead_qs = ConnectorLead.objects.filter(connectorsourcelead__source__unified_connector=uc) for label, source_response, extractor_response, expected_source_status, expected_lead_status in [ - ( - 'both-invalid', - [500, 'invalid-content'], - [500, {'error_message': 'Mock error message'}], - ConnectorSource.Status.FAILURE, - [], - ), - ( - 'extractor-invalid', - [200, RELIEF_WEB_MOCK_DATA_PAGE_2_RAW], - [500, {'error_message': 'Mock error message'}], - ConnectorSource.Status.SUCCESS, - [ConnectorLead.ExtractionStatus.RETRYING], - ), - ( - 'source-invalid', - [500, 'invalid-content'], - [202, {}], - ConnectorSource.Status.FAILURE, - [], - ), - ( - 'all-good', - [200, RELIEF_WEB_MOCK_DATA_PAGE_2_RAW], - [202, {}], - ConnectorSource.Status.SUCCESS, - [ConnectorLead.ExtractionStatus.STARTED], - ), + ( + "both-invalid", + [500, "invalid-content"], + [500, {"error_message": "Mock error message"}], + ConnectorSource.Status.FAILURE, + [], + ), + ( + "extractor-invalid", + [200, RELIEF_WEB_MOCK_DATA_PAGE_2_RAW], + [500, {"error_message": "Mock error message"}], + ConnectorSource.Status.SUCCESS, + [ConnectorLead.ExtractionStatus.RETRYING], + ), + ( + "source-invalid", + [500, "invalid-content"], + [202, {}], + ConnectorSource.Status.FAILURE, + [], + ), + ( + "all-good", + [200, RELIEF_WEB_MOCK_DATA_PAGE_2_RAW], + [202, {}], + ConnectorSource.Status.SUCCESS, + [ConnectorLead.ExtractionStatus.STARTED], + ), ]: reliefweb_requests_mock.post.return_value.status_code = source_response[0] reliefweb_requests_mock.post.return_value.text = source_response[1] @@ -391,14 +389,14 @@ def _query_okay_check(): with self.captureOnCommitCallbacks(execute=True): _query_okay_check() self.assertEqual( - list(uc.sources.values_list('status', flat=True)), + list(uc.sources.values_list("status", flat=True)), [expected_source_status.value], - f'{label}: {expected_source_status.label}' + f"{label}: {expected_source_status.label}", ) self.assertEqual( - list(connector_lead_qs.distinct().values_list('extraction_status', flat=True)), + list(connector_lead_qs.distinct().values_list("extraction_status", flat=True)), [status.value for status in expected_lead_status], - f'{label}: {[status.label for status in expected_lead_status]}', + f"{label}: {[status.label for status in expected_lead_status]}", ) connector_lead_qs.delete() # Clear all connector leads @@ -435,8 +433,8 @@ def _query_check(minput, source_lead, **kwargs): return self.query_check( self.UPDATE_CONNECTOR_SOURCE_LEAD_MUTATION, minput=minput, - variables={'projectId': self.project.id, 'connectorSourceLeadId': source_lead.id}, - **kwargs + variables={"projectId": self.project.id, "connectorSourceLeadId": source_lead.id}, + **kwargs, ) minput = dict(blocked=True) @@ -463,7 +461,7 @@ def _query_check(minput, source_lead, **kwargs): _query_check( dict(blocked=updated_to), source_lead, - mnested=['project', 'unifiedConnector'], + mnested=["project", "unifiedConnector"], okay=True, ) source_lead.refresh_from_db() @@ -472,32 +470,30 @@ def _query_check(minput, source_lead, **kwargs): class UnifiedConnectorCallbackApiTest(TestCase): - @patch('deepl_integration.handlers.RequestHelper') + @patch("deepl_integration.handlers.RequestHelper") def test_create_connector(self, RequestHelperMock): def _check_connector_lead_status(connector_lead, status): connector_lead.refresh_from_db() self.assertEqual(connector_lead.extraction_status, status) - url = '/api/v1/callback/unified-connector-lead-extract/' - connector_lead1 = ConnectorLeadFactory.create(url='https://example.com/some-random-url-01') - connector_lead2 = ConnectorLeadFactory.create(url='https://example.com/some-random-url-02') + url = "/api/v1/callback/unified-connector-lead-extract/" + connector_lead1 = ConnectorLeadFactory.create(url="https://example.com/some-random-url-01") + connector_lead2 = ConnectorLeadFactory.create(url="https://example.com/some-random-url-02") - SAMPLE_SIMPLIFIED_TEXT = 'Sample text' + SAMPLE_SIMPLIFIED_TEXT = "Sample text" RequestHelperMock.return_value.get_text.return_value = SAMPLE_SIMPLIFIED_TEXT # Mock file - file_1 = SimpleUploadedFile( - name='test_image.jpg', content=b'', content_type='image/jpeg' - ) + file_1 = SimpleUploadedFile(name="test_image.jpg", content=b"", content_type="image/jpeg") RequestHelperMock.return_value.get_file.return_value = file_1 # ------ Extraction FAILED data = dict( - client_id='some-random-client-id', - images_path=['https://example.com/sample-file-1.jpg'], - text_path='https://example.com/url-where-data-is-fetched-from-mock-response', + client_id="some-random-client-id", + images_path=["https://example.com/sample-file-1.jpg"], + text_path="https://example.com/url-where-data-is-fetched-from-mock-response", total_words_count=100, total_pages=10, status=DeeplServerBaseCallbackSerializer.Status.FAILED.value, - text_extraction_id='c4c3c256-f307-4a85-a50e-5516a6f1ce8e', + text_extraction_id="c4c3c256-f307-4a85-a50e-5516a6f1ce8e", ) response = self.client.post(url, data) @@ -506,37 +502,37 @@ def _check_connector_lead_status(connector_lead, status): connector_lead1.refresh_from_db() assert connector_lead1.text_extraction_id is None - data['client_id'] = UnifiedConnectorLeadHandler.get_client_id(connector_lead1) - data['status'] = DeeplServerBaseCallbackSerializer.Status.FAILED.value + data["client_id"] = UnifiedConnectorLeadHandler.get_client_id(connector_lead1) + data["status"] = DeeplServerBaseCallbackSerializer.Status.FAILED.value response = self.client.post(url, data) self.assert_200(response) connector_lead1.refresh_from_db() _check_connector_lead_status(connector_lead1, ConnectorLead.ExtractionStatus.FAILED) assert connector_lead1.text_extraction_id is None - assert connector_lead1.simplified_text == '' + assert connector_lead1.simplified_text == "" assert connector_lead1.word_count is None assert connector_lead1.page_count is None # ------ Extraction SUCCESS data = dict( - client_id='some-random-client-id', - images_path=['https://example.com/sample-file-1.jpg', 'https://example.com/sample-file-2.jpg'], - text_path='https://example.com/url-where-data-is-fetched-from-mock-response', + client_id="some-random-client-id", + images_path=["https://example.com/sample-file-1.jpg", "https://example.com/sample-file-2.jpg"], + text_path="https://example.com/url-where-data-is-fetched-from-mock-response", total_words_count=100, total_pages=10, status=DeeplServerBaseCallbackSerializer.Status.SUCCESS.value, - text_extraction_id='c4c3c256-f307-4a85-a50e-5516a6f1ce8e', + text_extraction_id="c4c3c256-f307-4a85-a50e-5516a6f1ce8e", ) response = self.client.post(url, data) self.assert_400(response) _check_connector_lead_status(connector_lead2, ConnectorLead.ExtractionStatus.PENDING) - data['client_id'] = UnifiedConnectorLeadHandler.get_client_id(connector_lead2) + data["client_id"] = UnifiedConnectorLeadHandler.get_client_id(connector_lead2) response = self.client.post(url, data) self.assert_200(response) connector_lead2.refresh_from_db() _check_connector_lead_status(connector_lead2, ConnectorLead.ExtractionStatus.SUCCESS) - assert str(connector_lead2.text_extraction_id) == data['text_extraction_id'] + assert str(connector_lead2.text_extraction_id) == data["text_extraction_id"] assert connector_lead2.simplified_text is not None assert connector_lead2.word_count == 100 assert connector_lead2.page_count == 10 diff --git a/apps/unified_connector/tests/test_query.py b/apps/unified_connector/tests/test_query.py index 6b5de994f1..a665af5b4c 100644 --- a/apps/unified_connector/tests/test_query.py +++ b/apps/unified_connector/tests/test_query.py @@ -1,25 +1,22 @@ -from unified_connector.models import ConnectorSourceLead from django.db.models import Q -from utils.graphene.tests import GraphQLTestCase - -from unified_connector.models import ConnectorSource - -from project.factories import ProjectFactory -from user.factories import UserFactory from organization.factories import OrganizationFactory - +from project.factories import ProjectFactory from unified_connector.factories import ( ConnectorLeadFactory, ConnectorSourceFactory, ConnectorSourceLeadFactory, UnifiedConnectorFactory, ) +from unified_connector.models import ConnectorSource, ConnectorSourceLead +from user.factories import UserFactory + +from utils.graphene.tests import GraphQLTestCase class TestUnifiedConnectorQuery(GraphQLTestCase): ENABLE_NOW_PATCHER = True - UNIFIED_CONNECTORS_QUERY = ''' + UNIFIED_CONNECTORS_QUERY = """ query MyQuery ($id: ID!) { project(id: $id) { unifiedConnector { @@ -56,9 +53,9 @@ class TestUnifiedConnectorQuery(GraphQLTestCase): } } } - ''' + """ - UNIFIED_CONNECTOR_QUERY = ''' + UNIFIED_CONNECTOR_QUERY = """ query MyQuery ($id: ID! $connectorId: ID!) { project(id: $id) { unifiedConnector { @@ -92,9 +89,9 @@ class TestUnifiedConnectorQuery(GraphQLTestCase): } } } - ''' + """ - SOURCE_CONNECTORS_QUERY = ''' + SOURCE_CONNECTORS_QUERY = """ query MyQuery ($id: ID!) { project(id: $id) { unifiedConnector { @@ -120,9 +117,9 @@ class TestUnifiedConnectorQuery(GraphQLTestCase): } } } - ''' + """ - SOURCE_CONNECTOR_QUERY = ''' + SOURCE_CONNECTOR_QUERY = """ query MyQuery ($id: ID! $connectorSourceId: ID!) { project(id: $id) { unifiedConnector { @@ -145,9 +142,9 @@ class TestUnifiedConnectorQuery(GraphQLTestCase): } } } - ''' + """ - SOURCE_CONNECTOR_LEADS_QUERY = ''' + SOURCE_CONNECTOR_LEADS_QUERY = """ query MyQuery ($id: ID!) { project(id: $id) { unifiedConnector { @@ -173,9 +170,9 @@ class TestUnifiedConnectorQuery(GraphQLTestCase): } } } - ''' + """ - SOURCE_CONNECTOR_LEAD_QUERY = ''' + SOURCE_CONNECTOR_LEAD_QUERY = """ query MyQuery ($id: ID! $connectorSourceLeadId: ID!) { project(id: $id) { unifiedConnector { @@ -198,8 +195,8 @@ class TestUnifiedConnectorQuery(GraphQLTestCase): } } } - ''' - SOURCE_COUNT_EXCLUDING_ADDED_AND_IGNORED_QUERY = ''' + """ + SOURCE_COUNT_EXCLUDING_ADDED_AND_IGNORED_QUERY = """ query MyQuery ($id: ID!) { project(id: $id) { unifiedConnector { @@ -207,7 +204,7 @@ class TestUnifiedConnectorQuery(GraphQLTestCase): } } } - ''' + """ def setUp(self): super().setUp() @@ -230,14 +227,14 @@ def setUp(self): def test_unified_connector_query(self): # -- non member user self.force_login(self.another_user) - content = self.query_check( - self.UNIFIED_CONNECTORS_QUERY, variables=dict(id=self.project.id) - )['data']['project']['unifiedConnector'] + content = self.query_check(self.UNIFIED_CONNECTORS_QUERY, variables=dict(id=self.project.id))["data"]["project"][ + "unifiedConnector" + ] self.assertEqual(content, None) # Single content = self.query_check( self.UNIFIED_CONNECTOR_QUERY, variables=dict(id=self.project.id, connectorId=str(self.uc1.pk)) - )['data']['project']['unifiedConnector'] + )["data"]["project"]["unifiedConnector"] self.assertEqual(content, None) # -- member user @@ -245,32 +242,40 @@ def test_unified_connector_query(self): content = self.query_check( self.UNIFIED_CONNECTORS_QUERY, variables=dict(id=self.project.id), - )['data']['project']['unifiedConnector']['unifiedConnectors'] - self.assertEqual(content['totalCount'], 2) - self.assertEqual(content['results'], [ - dict( - id=str(self.uc1.pk), - isActive=False, - project=str(self.project.pk), - title=self.uc1.title, - leadsCount=dict(alreadyAdded=0, blocked=0, total=0), - sources=[], - ), - dict( - id=str(self.uc2.pk), - isActive=False, - project=str(self.project.pk), - title=self.uc2.title, - leadsCount=dict(alreadyAdded=0, blocked=0, total=0), - sources=[], - ), - ]) + )[ + "data" + ]["project"][ + "unifiedConnector" + ]["unifiedConnectors"] + self.assertEqual(content["totalCount"], 2) + self.assertEqual( + content["results"], + [ + dict( + id=str(self.uc1.pk), + isActive=False, + project=str(self.project.pk), + title=self.uc1.title, + leadsCount=dict(alreadyAdded=0, blocked=0, total=0), + sources=[], + ), + dict( + id=str(self.uc2.pk), + isActive=False, + project=str(self.project.pk), + title=self.uc2.title, + leadsCount=dict(alreadyAdded=0, blocked=0, total=0), + sources=[], + ), + ], + ) # Single content = self.query_check( self.UNIFIED_CONNECTOR_QUERY, variables=dict(id=self.project.id, connectorId=str(self.uc1.pk)) - )['data']['project']['unifiedConnector']['unifiedConnector'] + )["data"]["project"]["unifiedConnector"]["unifiedConnector"] self.assertEqual( - content, dict( + content, + dict( id=str(self.uc1.pk), isActive=False, project=str(self.project.pk), @@ -295,14 +300,14 @@ def test_connector_source_query(self): # -- non member user self.force_login(self.another_user) - content = self.query_check( - self.SOURCE_CONNECTORS_QUERY, variables=dict(id=self.project.id) - )['data']['project']['unifiedConnector'] + content = self.query_check(self.SOURCE_CONNECTORS_QUERY, variables=dict(id=self.project.id))["data"]["project"][ + "unifiedConnector" + ] self.assertEqual(content, None) # Single content = self.query_check( self.SOURCE_CONNECTOR_QUERY, variables=dict(id=self.project.id, connectorSourceId=str(self.uc1.pk)) - )['data']['project']['unifiedConnector'] + )["data"]["project"]["unifiedConnector"] self.assertEqual(content, None) ec_source1_1 = dict( @@ -312,7 +317,7 @@ def test_connector_source_query(self): unifiedConnector=str(self.uc1.pk), params={}, leadsCount=dict(alreadyAdded=0, blocked=0, total=2), - stats=[{'count': 2, 'date': self.now_datetime.strftime('%Y-%m-%d')}], + stats=[{"count": 2, "date": self.now_datetime.strftime("%Y-%m-%d")}], ) ec_source1_2 = dict( id=str(source1_2.pk), @@ -330,7 +335,7 @@ def test_connector_source_query(self): unifiedConnector=str(self.uc2.pk), params={}, leadsCount=dict(alreadyAdded=0, blocked=0, total=1), - stats=[{'count': 1, 'date': self.now_datetime.strftime('%Y-%m-%d')}], + stats=[{"count": 1, "date": self.now_datetime.strftime("%Y-%m-%d")}], ) ec_source2_2 = dict( id=str(source2_2.pk), @@ -346,46 +351,58 @@ def test_connector_source_query(self): self.force_login(self.user) content = self.query_check( self.SOURCE_CONNECTORS_QUERY, - variables={'id': self.project.id}, - )['data']['project']['unifiedConnector']['connectorSources'] - self.assertEqual(content['totalCount'], 4) - self.assertEqual(content['results'], [ec_source1_1, ec_source1_2, ec_source2_1, ec_source2_2]) + variables={"id": self.project.id}, + )[ + "data" + ]["project"][ + "unifiedConnector" + ]["connectorSources"] + self.assertEqual(content["totalCount"], 4) + self.assertEqual(content["results"], [ec_source1_1, ec_source1_2, ec_source2_1, ec_source2_2]) # Single content = self.query_check( self.SOURCE_CONNECTOR_QUERY, variables=dict(id=self.project.id, connectorSourceId=str(source1_2.pk)) - )['data']['project']['unifiedConnector']['connectorSource'] + )["data"]["project"]["unifiedConnector"]["connectorSource"] self.assertEqual(content, ec_source1_2) # -- Unified connector -> Sources content = self.query_check( self.UNIFIED_CONNECTORS_QUERY, variables=dict(id=self.project.id), - )['data']['project']['unifiedConnector']['unifiedConnectors'] - self.assertEqual(content['totalCount'], 2) - self.assertEqual(content['results'], [ - dict( - id=str(self.uc1.pk), - isActive=False, - project=str(self.project.pk), - title=self.uc1.title, - sources=[ec_source1_1, ec_source1_2], - leadsCount=dict(alreadyAdded=0, blocked=0, total=2), - ), - dict( - id=str(self.uc2.pk), - isActive=False, - project=str(self.project.pk), - title=self.uc2.title, - sources=[ec_source2_1, ec_source2_2], - leadsCount=dict(alreadyAdded=0, blocked=0, total=1), - ), - ]) + )[ + "data" + ]["project"][ + "unifiedConnector" + ]["unifiedConnectors"] + self.assertEqual(content["totalCount"], 2) + self.assertEqual( + content["results"], + [ + dict( + id=str(self.uc1.pk), + isActive=False, + project=str(self.project.pk), + title=self.uc1.title, + sources=[ec_source1_1, ec_source1_2], + leadsCount=dict(alreadyAdded=0, blocked=0, total=2), + ), + dict( + id=str(self.uc2.pk), + isActive=False, + project=str(self.project.pk), + title=self.uc2.title, + sources=[ec_source2_1, ec_source2_2], + leadsCount=dict(alreadyAdded=0, blocked=0, total=1), + ), + ], + ) # Single content = self.query_check( self.UNIFIED_CONNECTOR_QUERY, variables=dict(id=self.project.id, connectorId=str(self.uc1.pk)) - )['data']['project']['unifiedConnector']['unifiedConnector'] + )["data"]["project"]["unifiedConnector"]["unifiedConnector"] self.assertEqual( - content, dict( + content, + dict( id=str(self.uc1.pk), isActive=False, project=str(self.project.pk), @@ -410,37 +427,41 @@ def test_connector_source_leads_query(self): self.maxDiff = None self.force_login(self.user) - content = self.query_check( - self.SOURCE_CONNECTOR_LEADS_QUERY, variables=dict(id=self.project.id) - )['data']['project']['unifiedConnector']['connectorSourceLeads'] - self.assertEqual(content['totalCount'], 15) - self.assertEqual(content['results'], [ - dict( - id=str(lead.pk), - alreadyAdded=False, - blocked=False, - connectorLead=dict( - id=str(clead1.pk), - title=clead1.title, - source=dict(id=str(org1.pk)), - authors=[dict(id=str(org2.pk)), dict(id=str(org3.pk))], - ), - source=str(lead.source_id), - ) - for lead in [ - *source1_1_leads, - *source1_2_leads, - *source2_1_leads, - *source2_2_leads, - ] - ]) + content = self.query_check(self.SOURCE_CONNECTOR_LEADS_QUERY, variables=dict(id=self.project.id))["data"]["project"][ + "unifiedConnector" + ]["connectorSourceLeads"] + self.assertEqual(content["totalCount"], 15) + self.assertEqual( + content["results"], + [ + dict( + id=str(lead.pk), + alreadyAdded=False, + blocked=False, + connectorLead=dict( + id=str(clead1.pk), + title=clead1.title, + source=dict(id=str(org1.pk)), + authors=[dict(id=str(org2.pk)), dict(id=str(org3.pk))], + ), + source=str(lead.source_id), + ) + for lead in [ + *source1_1_leads, + *source1_2_leads, + *source2_1_leads, + *source2_2_leads, + ] + ], + ) lead = source1_1_leads[0] content = self.query_check( self.SOURCE_CONNECTOR_LEAD_QUERY, variables=dict(id=self.project.id, connectorSourceLeadId=str(lead.pk)) - )['data']['project']['unifiedConnector']['connectorSourceLead'] + )["data"]["project"]["unifiedConnector"]["connectorSourceLead"] self.assertEqual( - content, dict( + content, + dict( id=str(lead.pk), alreadyAdded=False, blocked=False, @@ -451,7 +472,7 @@ def test_connector_source_leads_query(self): authors=[dict(id=str(org2.pk)), dict(id=str(org3.pk))], ), source=str(lead.source_id), - ) + ), ) # check for total sources count excluding already_added, blocked for enabled unified connector. @@ -461,15 +482,16 @@ def test_connector_source_leads_query(self): ConnectorSourceLeadFactory.create_batch(2, source=source, connector_lead=lead) ConnectorSourceLeadFactory.create_batch(2, source=source, connector_lead=self.fake_lead, blocked=True) ConnectorSourceLeadFactory.create_batch(2, source=source, connector_lead=self.fake_lead, already_added=True) - total_source_count = ConnectorSourceLead.objects.filter( - source__unified_connector__project=self.project, - source__unified_connector__is_active=True, - ).exclude( - Q(blocked=True) | - Q(already_added=True) - ).count() - content = self.query_check( - self.SOURCE_COUNT_EXCLUDING_ADDED_AND_IGNORED_QUERY, variables=dict(id=self.project.id) - )['data']['project']['unifiedConnector']['sourceCountWithoutIngnoredAndAdded'] + total_source_count = ( + ConnectorSourceLead.objects.filter( + source__unified_connector__project=self.project, + source__unified_connector__is_active=True, + ) + .exclude(Q(blocked=True) | Q(already_added=True)) + .count() + ) + content = self.query_check(self.SOURCE_COUNT_EXCLUDING_ADDED_AND_IGNORED_QUERY, variables=dict(id=self.project.id))[ + "data" + ]["project"]["unifiedConnector"]["sourceCountWithoutIngnoredAndAdded"] self.assertEqual(content, total_source_count) diff --git a/apps/unified_connector/tests/test_source.py b/apps/unified_connector/tests/test_source.py index 19a579f666..d23bc2cdcf 100644 --- a/apps/unified_connector/tests/test_source.py +++ b/apps/unified_connector/tests/test_source.py @@ -1,17 +1,14 @@ -from parameterized import parameterized from unittest.mock import patch -from utils.graphene.tests import GraphQLTestCase +from organization.models import Organization +from parameterized import parameterized from project.factories import ProjectFactory - -from unified_connector.factories import ( - ConnectorSourceFactory, - UnifiedConnectorFactory, -) +from unified_connector.factories import ConnectorSourceFactory, UnifiedConnectorFactory from unified_connector.models import ConnectorSource -from unified_connector.tests.mock_data.store import ConnectorSourceResponseMock from unified_connector.sources.base import OrganizationSearch -from organization.models import Organization +from unified_connector.tests.mock_data.store import ConnectorSourceResponseMock + +from utils.graphene.tests import GraphQLTestCase class TestUnifiedConnectorResponse(GraphQLTestCase): @@ -42,16 +39,20 @@ def _connector_response_check(self, source_type, response_mock): self.assertEqual(len(leads_result), count) self._assert_lead_equal_to_expected_data(leads_result, mock_data.expected_data) - @parameterized.expand([ - [ConnectorSource.Source.UNHCR, 'unified_connector.sources.unhcr_portal.UNHCRPortal.get_content'], - [ConnectorSource.Source.RELIEF_WEB, 'unified_connector.sources.relief_web.ReliefWeb.get_content'], - [ConnectorSource.Source.RSS_FEED, 'unified_connector.sources.rss_feed.RssFeed.get_content'], - [ConnectorSource.Source.ATOM_FEED, 'unified_connector.sources.atom_feed.AtomFeed.get_content'], - [ConnectorSource.Source.PDNA, 'unified_connector.sources.pdna.PDNA.get_content'], - [ConnectorSource.Source.HUMANITARIAN_RESP, - 'unified_connector.sources.humanitarian_response.HumanitarianResponse.get_content'], - [ConnectorSource.Source.EMM, 'unified_connector.sources.emm.EMM.get_content'], - ]) + @parameterized.expand( + [ + [ConnectorSource.Source.UNHCR, "unified_connector.sources.unhcr_portal.UNHCRPortal.get_content"], + [ConnectorSource.Source.RELIEF_WEB, "unified_connector.sources.relief_web.ReliefWeb.get_content"], + [ConnectorSource.Source.RSS_FEED, "unified_connector.sources.rss_feed.RssFeed.get_content"], + [ConnectorSource.Source.ATOM_FEED, "unified_connector.sources.atom_feed.AtomFeed.get_content"], + [ConnectorSource.Source.PDNA, "unified_connector.sources.pdna.PDNA.get_content"], + [ + ConnectorSource.Source.HUMANITARIAN_RESP, + "unified_connector.sources.humanitarian_response.HumanitarianResponse.get_content", + ], + [ConnectorSource.Source.EMM, "unified_connector.sources.emm.EMM.get_content"], + ] + ) def test_connector_source_(self, source_type, response_mock_path): response_mock_patch = patch(response_mock_path) response_mock = response_mock_patch.start() @@ -63,14 +64,14 @@ def _get_orgs(titles): qs = Organization.objects.filter(title__in=titles) return qs - Organization.objects.create(title='Organization 1', short_name='Organization 1', long_name='Organization 1') + Organization.objects.create(title="Organization 1", short_name="Organization 1", long_name="Organization 1") raw_text_labels = [ # Existing - 'Organization 1', + "Organization 1", # New - 'Relief Web', - 'reliefweb', - 'the relief web', + "Relief Web", + "reliefweb", + "the relief web", ] # Fetch/Create using raw_text_labels @@ -82,8 +83,8 @@ def _get_orgs(titles): self.assertEqual(qs.count(), len(raw_text_labels)) # Set Parent Organizations - parent_org = Organization.objects.get(title='Relief Web') - child_titles = ['reliefweb', 'the relief web'] + parent_org = Organization.objects.get(title="Relief Web") + child_titles = ["reliefweb", "the relief web"] qs = _get_orgs(child_titles) self.assertEqual(qs.count(), len(child_titles)) qs.update(parent=parent_org) @@ -94,9 +95,9 @@ def _get_orgs(titles): self.assertEqual(search_organizaton.get(title), parent_org) raw_text_labels += [ - 'Organization 1', # We have a duplicate title here, using set for count now - 'the relief web', - 'the relief web2', + "Organization 1", # We have a duplicate title here, using set for count now + "the relief web", + "the relief web2", ] # Fetch/Create using raw_text_labels @@ -108,8 +109,8 @@ def _get_orgs(titles): self.assertEqual(qs.count(), len(set(raw_text_labels))) # Update newly created child relif web2 parent - Organization.objects.filter(title='the relief web2').update(parent=parent_org) + Organization.objects.filter(title="the relief web2").update(parent=parent_org) # Fetch latest search_organizaton = OrganizationSearch(raw_text_labels, None, None) - self.assertEqual(search_organizaton.get('the relief web2'), parent_org) + self.assertEqual(search_organizaton.get("the relief web2"), parent_org) diff --git a/apps/user/admin.py b/apps/user/admin.py index 1d3ae285ab..65d3070098 100644 --- a/apps/user/admin.py +++ b/apps/user/admin.py @@ -1,11 +1,8 @@ from django.contrib import admin -from django.db import models from django.contrib.auth.admin import UserAdmin -from .models import ( - Profile, User, Feature, EmailDomain, - OTP_MODELS, OTP_PROXY_MODELS -) +from django.db import models +from .models import OTP_MODELS, OTP_PROXY_MODELS, EmailDomain, Feature, Profile, User admin.site.unregister(User) for _, model, _ in OTP_MODELS: @@ -15,24 +12,36 @@ class ProfileInline(admin.StackedInline): model = Profile can_delete = False - verbose_name_plural = 'Profile' - fk_name = 'user' - autocomplete_fields = ('display_picture', 'last_active_project',) + verbose_name_plural = "Profile" + fk_name = "user" + autocomplete_fields = ( + "display_picture", + "last_active_project", + ) @admin.register(User) class CustomUserAdmin(UserAdmin): inlines = [ProfileInline] search_fields = ( - 'username', 'first_name', 'last_name', 'email', 'profile__language', - 'profile__organization', + "username", + "first_name", + "last_name", + "email", + "profile__language", + "profile__organization", ) list_display = ( - 'username', 'email', 'first_name', 'last_name', 'is_staff', - 'get_organization', 'get_language', + "username", + "email", + "first_name", + "last_name", + "is_staff", + "get_organization", + "get_language", ) - list_select_related = ('profile', ) - list_filter = UserAdmin.list_filter + ('profile__invalid_email', ) + list_select_related = ("profile",) + list_filter = UserAdmin.list_filter + ("profile__invalid_email",) def get_organization(self, instance): return instance.profile.organization @@ -45,29 +54,38 @@ def get_inline_instances(self, request, obj=None): return list() return super().get_inline_instances(request, obj) - get_organization.short_description = 'Organization' - get_language.short_description = 'Language' + get_organization.short_description = "Organization" + get_language.short_description = "Language" # Register OTP Proxy Model Dynamically for model, model_admin in OTP_PROXY_MODELS: + class DjangoOTPAdmin(model_admin): - search_fields = [f'user__{user_prop}' for user_prop in CustomUserAdmin.search_fields] - list_display = ('user', 'name', 'confirmed') if len(model_admin.list_display) <= 1 else model_admin.list_display - autocomplete_fields = ('user',) + search_fields = [f"user__{user_prop}" for user_prop in CustomUserAdmin.search_fields] + list_display = ("user", "name", "confirmed") if len(model_admin.list_display) <= 1 else model_admin.list_display + autocomplete_fields = ("user",) + admin.site.register(model, DjangoOTPAdmin) @admin.register(Feature) class CustomFeature(admin.ModelAdmin): - search_fields = ('title',) - list_display = ('title', 'feature_type', 'user_count',) - filter_horizontal = ('users', 'email_domains',) + search_fields = ("title",) + list_display = ( + "title", + "feature_type", + "user_count", + ) + filter_horizontal = ( + "users", + "email_domains", + ) def get_readonly_fields(self, request, obj=None): # editing an existing object if obj: - return self.readonly_fields + ('key', ) + return self.readonly_fields + ("key",) return self.readonly_fields def has_add_permission(self, request): @@ -79,26 +97,30 @@ def has_delete_permission(self, request, obj=None): def user_count(self, instance): if not instance: return - query = models.Q(pk__in=instance.users.values_list('pk', flat=True)) + query = models.Q(pk__in=instance.users.values_list("pk", flat=True)) for item in [ models.Q(username__iendswith=domain_name) - for domain_name in instance.email_domains.values_list('domain_name', flat=True) + for domain_name in instance.email_domains.values_list("domain_name", flat=True) ]: query |= item return User.objects.filter(query).distinct().count() - user_count.short_description = 'User Count' + user_count.short_description = "User Count" @admin.register(EmailDomain) class EmailDoaminAdmin(admin.ModelAdmin): - search_fields = ('title', 'domain_name') - list_display = ('title', 'domain_name', 'user_count') + search_fields = ("title", "domain_name") + list_display = ("title", "domain_name", "user_count") def user_count(self, instance): if instance: - return User.objects.filter( - username__iendswith=instance.domain_name, - ).distinct().count() - - user_count.short_description = 'User Count' + return ( + User.objects.filter( + username__iendswith=instance.domain_name, + ) + .distinct() + .count() + ) + + user_count.short_description = "User Count" diff --git a/apps/user/apps.py b/apps/user/apps.py index 56ea53c660..63f5031dc6 100644 --- a/apps/user/apps.py +++ b/apps/user/apps.py @@ -5,8 +5,9 @@ def device_classes(): """ Returns an iterable of all loaded device models. """ - from django_otp.models import Device from django.apps import apps + from django_otp.models import Device + for config in apps.get_app_configs(): for model in config.get_models(): if issubclass(model, Device) and not model._meta.proxy: @@ -14,11 +15,12 @@ def device_classes(): class UserConfig(AppConfig): - name = 'user' - verbose_name = '[DEEP] Authentication and Authorization' + name = "user" + verbose_name = "[DEEP] Authentication and Authorization" def ready(self): - import user.receivers # noqa import django_otp + import user.receivers # noqa + # Override to avoid capturing proxy models django_otp.device_classes = device_classes diff --git a/apps/user/dataloaders.py b/apps/user/dataloaders.py index 6afa92b512..d74e0e6f4f 100644 --- a/apps/user/dataloaders.py +++ b/apps/user/dataloaders.py @@ -1,27 +1,26 @@ -from promise import Promise from collections import defaultdict + from django.utils.functional import cached_property +from promise import Promise +from user.models import User from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin -from user.models import User from .models import Profile class UserProfileLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - profile_qs = Profile.objects\ - .filter(user__in=keys)\ - .select_related('display_picture')\ + profile_qs = ( + Profile.objects.filter(user__in=keys) + .select_related("display_picture") .only( - 'user_id', - 'organization', - 'display_picture__file', + "user_id", + "organization", + "display_picture__file", ) - _map = { - profile.user_id: profile - for profile in profile_qs - } + ) + _map = {profile.user_id: profile for profile in profile_qs} return Promise.resolve([_map.get(key) for key in keys]) diff --git a/apps/user/enums.py b/apps/user/enums.py index d55ff9512f..8bb971d75f 100644 --- a/apps/user/enums.py +++ b/apps/user/enums.py @@ -3,24 +3,22 @@ get_enum_name_from_django_field, ) -from .models import User, Profile +from .models import Profile, User -UserEmailConditionOptOutEnum = convert_enum_to_graphene_enum( - Profile.EmailConditionOptOut, name='UserEmailConditionOptOutEnum') +UserEmailConditionOptOutEnum = convert_enum_to_graphene_enum(Profile.EmailConditionOptOut, name="UserEmailConditionOptOutEnum") enum_map = { - get_enum_name_from_django_field(field): enum - for field, enum in ( - (Profile.email_opt_outs, UserEmailConditionOptOutEnum), - ) + get_enum_name_from_django_field(field): enum for field, enum in ((Profile.email_opt_outs, UserEmailConditionOptOutEnum),) } # Additional enums which doesn't have a field in model but are used in serializer -enum_map.update({ - get_enum_name_from_django_field( - None, - field_name='email_opt_outs', # UserMeSerializer.email_opt_outs - model_name=User.__name__, - ): UserEmailConditionOptOutEnum, -}) +enum_map.update( + { + get_enum_name_from_django_field( + None, + field_name="email_opt_outs", # UserMeSerializer.email_opt_outs + model_name=User.__name__, + ): UserEmailConditionOptOutEnum, + } +) diff --git a/apps/user/factories.py b/apps/user/factories.py index 92ded19a93..dfbe5ebba7 100644 --- a/apps/user/factories.py +++ b/apps/user/factories.py @@ -2,17 +2,16 @@ from factory import fuzzy from factory.django import DjangoModelFactory -from .models import User, Feature +from .models import Feature, User from .serializers import UserSerializer - -PROFILE_FIELDS = ['display_picture', 'organization', 'language', 'email_opt_outs', 'last_active_project'] +PROFILE_FIELDS = ["display_picture", "organization", "language", "email_opt_outs", "last_active_project"] class UserFactory(DjangoModelFactory): - first_name = factory.Faker('first_name') - last_name = factory.Faker('last_name') - email = factory.Sequence(lambda n: f'{n}@xyz.com') + first_name = factory.Faker("first_name") + last_name = factory.Faker("last_name") + email = factory.Sequence(lambda n: f"{n}@xyz.com") username = factory.LazyAttribute(lambda user: user.email) password_text = fuzzy.FuzzyText(length=15) password = factory.PostGeneration(lambda user, *args, **kwargs: user.set_password(user.password_text)) @@ -22,10 +21,8 @@ class Meta: @classmethod def _create(cls, model_class, *args, **kwargs): - password_text = kwargs.pop('password_text') - profile_data = { - key: kwargs.pop(key) for key in PROFILE_FIELDS if key in kwargs - } + password_text = kwargs.pop("password_text") + profile_data = {key: kwargs.pop(key) for key in PROFILE_FIELDS if key in kwargs} user = super()._create(model_class, *args, **kwargs) UserSerializer.update_or_create_profile(user, profile_data) user.profile.refresh_from_db() @@ -34,7 +31,7 @@ def _create(cls, model_class, *args, **kwargs): class FeatureFactory(DjangoModelFactory): - title = factory.PostGeneration(lambda feature, *args, **kwargs: f'Feature {feature.key}') + title = factory.PostGeneration(lambda feature, *args, **kwargs: f"Feature {feature.key}") feature_type = fuzzy.FuzzyChoice(Feature.FeatureType.choices, getter=lambda c: c[0]) class Meta: diff --git a/apps/user/filters.py b/apps/user/filters.py index 790dc81529..3de19f5806 100644 --- a/apps/user/filters.py +++ b/apps/user/filters.py @@ -1,8 +1,8 @@ import django_filters from django.db import models from django.db.models.functions import Concat -from utils.graphene.filters import IDFilter +from utils.graphene.filters import IDFilter from .models import User @@ -10,39 +10,33 @@ class UserFilterSet(django_filters.FilterSet): class Meta: model = User - fields = ['id'] + fields = ["id"] # -------------------- Graphql Filter --------------------------------- class UserGqlFilterSet(django_filters.FilterSet): - search = django_filters.CharFilter(method='filter_search') - members_exclude_project = IDFilter(method='filter_exclude_project') - members_exclude_framework = IDFilter(method='filter_exclude_framework') - members_exclude_usergroup = IDFilter(method='filter_exclude_usergroup') + search = django_filters.CharFilter(method="filter_search") + members_exclude_project = IDFilter(method="filter_exclude_project") + members_exclude_framework = IDFilter(method="filter_exclude_framework") + members_exclude_usergroup = IDFilter(method="filter_exclude_usergroup") class Meta: model = User - fields = ('id',) + fields = ("id",) def filter_exclude_project(self, qs, name, value): if value: - qs = qs.filter( - ~models.Q(projectmembership__project_id=value) - ).distinct() + qs = qs.filter(~models.Q(projectmembership__project_id=value)).distinct() return qs def filter_exclude_framework(self, qs, name, value): if value: - qs = qs.filter( - ~models.Q(framework_membership__framework_id=value) - ) + qs = qs.filter(~models.Q(framework_membership__framework_id=value)) return qs def filter_exclude_usergroup(self, qs, name, value): if value: - qs = qs.filter( - ~models.Q(groupmembership__group_id=value) - ) + qs = qs.filter(~models.Q(groupmembership__group_id=value)) return qs def filter_search(self, qs, name, value): @@ -55,11 +49,11 @@ def filter_search(self, qs, name, value): output_field=models.CharField(), ) ).filter( - models.Q(full_name__icontains=value) | - models.Q(first_name__icontains=value) | - models.Q(last_name__icontains=value) | - models.Q(email__icontains=value) | - models.Q(username__icontains=value) + models.Q(full_name__icontains=value) + | models.Q(first_name__icontains=value) + | models.Q(last_name__icontains=value) + | models.Q(email__icontains=value) + | models.Q(username__icontains=value) ) return qs diff --git a/apps/user/models.py b/apps/user/models.py index f315520a7d..409ebc7d8b 100644 --- a/apps/user/models.py +++ b/apps/user/models.py @@ -1,31 +1,26 @@ -from django.contrib.postgres.fields import ArrayField +from django.conf import settings from django.contrib.auth.models import User +from django.contrib.postgres.fields import ArrayField from django.db import models from django.db.models import Q -from django.conf import settings from django.utils import timezone - -from django_otp.plugins import ( - otp_static, - otp_totp, - otp_email, -) +from django_otp.plugins import otp_email, otp_static, otp_totp +from django_otp.plugins.otp_email.admin import EmailDeviceAdmin from django_otp.plugins.otp_static.admin import StaticDeviceAdmin from django_otp.plugins.otp_totp.admin import TOTPDeviceAdmin -from django_otp.plugins.otp_email.admin import EmailDeviceAdmin +from gallery.models import File from utils.common import camelcase_to_titlecase -from gallery.models import File class EmailCondition(models.TextChoices): - JOIN_REQUESTS = 'join_requests', 'Project join requests' - NEWS_AND_UPDATES = 'news_and_updates', 'News and updates' - EMAIL_COMMENT = 'email_comment', 'Entry comment updates' + JOIN_REQUESTS = "join_requests", "Project join requests" + NEWS_AND_UPDATES = "news_and_updates", "News and updates" + EMAIL_COMMENT = "email_comment", "Entry comment updates" # Always send - ACCOUNT_ACTIVATION = 'account_activation', 'Account Activation' - PASSWORD_RESET = 'password_reset', 'Password Reset' - PASSWORD_CHANGED = 'password_changed', 'Password Changed' + ACCOUNT_ACTIVATION = "account_activation", "Account Activation" + PASSWORD_RESET = "password_reset", "Password Reset" + PASSWORD_CHANGED = "password_changed", "Password Changed" class Profile(models.Model): @@ -35,6 +30,7 @@ class Profile(models.Model): Extra attributes for the user besides the django provided ones. """ + class EmailConditionOptOut(models.TextChoices): JOIN_REQUESTS = EmailCondition.JOIN_REQUESTS NEWS_AND_UPDATES = EmailCondition.NEWS_AND_UPDATES @@ -45,7 +41,7 @@ class EmailConditionOptOut(models.TextChoices): ALWAYS_SEND_EMAIL_CONDITIONS = [ EmailCondition.ACCOUNT_ACTIVATION, EmailCondition.PASSWORD_RESET, - EmailCondition.PASSWORD_CHANGED + EmailCondition.PASSWORD_CHANGED, ] user = models.OneToOneField(User, on_delete=models.CASCADE) @@ -53,12 +49,18 @@ class EmailConditionOptOut(models.TextChoices): hid = models.TextField(default=None, null=True, blank=True) # country = models.ForeignKey(Country, on_delete=models.SET_NULL) display_picture = models.ForeignKey( - File, on_delete=models.SET_NULL, null=True, blank=True, default=None, + File, + on_delete=models.SET_NULL, + null=True, + blank=True, + default=None, ) last_active_project = models.ForeignKey( - 'project.Project', null=True, - blank=True, default=None, + "project.Project", + null=True, + blank=True, + default=None, on_delete=models.SET_NULL, ) @@ -69,7 +71,7 @@ class EmailConditionOptOut(models.TextChoices): ) login_attempts = models.IntegerField(default=0) - invalid_email = models.BooleanField(default=False, help_text='Flagged as bounce email') + invalid_email = models.BooleanField(default=False, help_text="Flagged as bounce email") email_opt_outs = ArrayField( models.CharField(max_length=128, choices=EmailConditionOptOut.choices), default=list, @@ -92,12 +94,14 @@ def __str__(self): @staticmethod def get_user_accessible_features(user): try: - user_domain = (user.email or user.username).split('@')[1] - return Feature.objects.filter( - Q(is_available_for_all=True) | - Q(users=user) | - Q(email_domains__domain_name__exact=user_domain) - ).order_by('key').distinct() + user_domain = (user.email or user.username).split("@")[1] + return ( + Feature.objects.filter( + Q(is_available_for_all=True) | Q(users=user) | Q(email_domains__domain_name__exact=user_domain) + ) + .order_by("key") + .distinct() + ) except IndexError: return Feature.objects.none() @@ -110,24 +114,18 @@ def have_feature_access_for_user(user, feature): @staticmethod def get_display_name_for_user(user): - return user.get_full_name() or f'User#{user.pk}' + return user.get_full_name() or f"User#{user.pk}" def get_display_name(self): return self.get_display_name_for_user(self.user) def unsubscribe_email(self, email_type): - if ( - email_type not in self.ALWAYS_SEND_EMAIL_CONDITIONS and - self.is_email_subscribed_for(email_type) - ): + if email_type not in self.ALWAYS_SEND_EMAIL_CONDITIONS and self.is_email_subscribed_for(email_type): self.email_opt_outs.append(email_type) def is_email_subscribed_for(self, email_type): - if ( - email_type in self.ALWAYS_SEND_EMAIL_CONDITIONS or ( - email_type in Profile.EMAIL_CONDITIONS_TYPES and - email_type not in self.email_opt_outs - ) + if email_type in self.ALWAYS_SEND_EMAIL_CONDITIONS or ( + email_type in Profile.EMAIL_CONDITIONS_TYPES and email_type not in self.email_opt_outs ): return True return False @@ -156,7 +154,7 @@ def soft_delete(self, deleted_at=None, commit=True): user.is_active = False user.first_name = settings.DELETED_USER_FIRST_NAME user.last_name = settings.DELETED_USER_LAST_NAME - user.email = user.username = f'user-{user.id}@{settings.DELETED_USER_EMAIL_DOMAIN}' + user.email = user.username = f"user-{user.id}@{settings.DELETED_USER_EMAIL_DOMAIN}" # Profile Data self.deleted_at = deleted_at or timezone.now() self.original_data = original_data @@ -168,23 +166,23 @@ def soft_delete(self, deleted_at=None, commit=True): user.save( update_fields=( # User Data - 'first_name', - 'last_name', - 'email', - 'username', - 'is_active', + "first_name", + "last_name", + "email", + "username", + "is_active", ) ) self.save( update_fields=( # Deleted metadata - 'deleted_at', - 'original_data', + "deleted_at", + "original_data", # Profile Data - 'invalid_email', - 'organization', - 'hid', - 'display_picture', + "invalid_email", + "organization", + "hid", + "display_picture", ) ) @@ -209,9 +207,8 @@ def user_get_display_email(user): def get_for_project(project): - return User.objects.prefetch_related('profile').filter( - models.Q(projectmembership__project=project) | - models.Q(usergroup__in=project.user_groups.all()) + return User.objects.prefetch_related("profile").filter( + models.Q(projectmembership__project=project) | models.Q(usergroup__in=project.user_groups.all()) ) @@ -224,28 +221,28 @@ class EmailDomain(models.Model): domain_name = models.CharField(max_length=255) def __str__(self): - return f'{self.title}({self.domain_name})' + return f"{self.title}({self.domain_name})" class Feature(models.Model): class FeatureType(models.TextChoices): - GENERAL_ACCESS = 'general_access', 'General access' - EXPERIMENTAL = 'experimental', 'Experimental' - EARLY_ACCESS = 'early_access', 'Early access' + GENERAL_ACCESS = "general_access", "General access" + EXPERIMENTAL = "experimental", "Experimental" + EARLY_ACCESS = "early_access", "Early access" class FeatureKey(models.TextChoices): - PRIVATE_PROJECT = 'private_project', 'Private projects' - TABULAR = 'tabular', 'Tabular' - ZOOMABLE_IMAGE = 'zoomable_image', 'Zoomable image' - POLYGON_SUPPORT_GEO = 'polygon_support_geo', 'Polygon support geo' - ENTRY_VISUALIZATION_CONFIGURATION = 'entry_visualization_configuration', 'Entry visualization configuration' - CONNECTORS = 'connectors', 'Unified Connectors' - ASSISTED = 'assisted', 'Assisted Tagging' + PRIVATE_PROJECT = "private_project", "Private projects" + TABULAR = "tabular", "Tabular" + ZOOMABLE_IMAGE = "zoomable_image", "Zoomable image" + POLYGON_SUPPORT_GEO = "polygon_support_geo", "Polygon support geo" + ENTRY_VISUALIZATION_CONFIGURATION = "entry_visualization_configuration", "Entry visualization configuration" + CONNECTORS = "connectors", "Unified Connectors" + ASSISTED = "assisted", "Assisted Tagging" # Deprecated keys - QUALITY_CONTROL = 'quality_control', 'Quality Control (Deprecated)' - NEW_UI = 'new_ui', 'New UI (Deprecated)' - ANALYSIS = 'analysis', 'Analysis (Deprecated)' - QUESTIONNAIRE = 'questionnaire', 'Questionnaire Builder' + QUALITY_CONTROL = "quality_control", "Quality Control (Deprecated)" + NEW_UI = "new_ui", "New UI (Deprecated)" + ANALYSIS = "analysis", "Analysis (Deprecated)" + QUESTIONNAIRE = "questionnaire", "Questionnaire Builder" key = models.CharField(max_length=255, unique=True, choices=FeatureKey.choices) title = models.CharField(max_length=255) @@ -265,26 +262,33 @@ def gen_auth_proxy_model(ModelClass, _label=None): class Meta: proxy = True - app_label = 'user' - verbose_name = f'{t_label}' - verbose_name_plural = f'{t_label}s' - - model = type(f"{label.replace(' ', '_')}", (ModelClass,), { - '__module__': __name__, - 'Meta': Meta, - }) + app_label = "user" + verbose_name = f"{t_label}" + verbose_name_plural = f"{t_label}s" + + model = type( + f"{label.replace(' ', '_')}", + (ModelClass,), + { + "__module__": __name__, + "Meta": Meta, + }, + ) return model OTP_MODELS = ( - ('OTP Static', otp_static.models.StaticDevice, StaticDeviceAdmin), - ('OTP TOTP', otp_totp.models.TOTPDevice, TOTPDeviceAdmin), - ('OTP Email', otp_email.models.EmailDevice, EmailDeviceAdmin), + ("OTP Static", otp_static.models.StaticDevice, StaticDeviceAdmin), + ("OTP TOTP", otp_totp.models.TOTPDevice, TOTPDeviceAdmin), + ("OTP Email", otp_email.models.EmailDevice, EmailDeviceAdmin), ) OTP_PROXY_MODELS = [] # Create OTP Proxy Model Dynamically for label, model, model_admin in OTP_MODELS: - OTP_PROXY_MODELS.append([ - gen_auth_proxy_model(model, label), model_admin, - ]) + OTP_PROXY_MODELS.append( + [ + gen_auth_proxy_model(model, label), + model_admin, + ] + ) diff --git a/apps/user/mutation.py b/apps/user/mutation.py index 5550cce823..b73df149f4 100644 --- a/apps/user/mutation.py +++ b/apps/user/mutation.py @@ -1,28 +1,26 @@ import graphene -from django.contrib.auth import login, logout -from django.contrib.auth import update_session_auth_hash +from django.contrib.auth import login, logout, update_session_auth_hash from django.db import models -from utils.graphene.error_types import mutation_is_not_valid, CustomErrorType +from utils.graphene.error_types import CustomErrorType, mutation_is_not_valid from utils.graphene.mutation import generate_input_type_for_serializer +from .schema import UserMeType +from .serializers import GqPasswordResetSerializer as ResetPasswordSerializer from .serializers import ( + HIDLoginSerializer, LoginSerializer, - RegisterSerializer, - GqPasswordResetSerializer as ResetPasswordSerializer, PasswordChangeSerializer, + RegisterSerializer, UserMeSerializer, - HIDLoginSerializer, ) -from .schema import UserMeType - -LoginInputType = generate_input_type_for_serializer('LoginInputType', LoginSerializer) -HIDLoginInputType = generate_input_type_for_serializer('HIDLoginInputType', HIDLoginSerializer) -RegisterInputType = generate_input_type_for_serializer('RegisterInputType', RegisterSerializer) -ResetPasswordInputType = generate_input_type_for_serializer('ResetPasswordInputType', ResetPasswordSerializer) -PasswordChangeInputType = generate_input_type_for_serializer('PasswordChangeInputType', PasswordChangeSerializer) -UserMeInputType = generate_input_type_for_serializer('UserMeInputType', UserMeSerializer) +LoginInputType = generate_input_type_for_serializer("LoginInputType", LoginSerializer) +HIDLoginInputType = generate_input_type_for_serializer("HIDLoginInputType", HIDLoginSerializer) +RegisterInputType = generate_input_type_for_serializer("RegisterInputType", RegisterSerializer) +ResetPasswordInputType = generate_input_type_for_serializer("ResetPasswordInputType", ResetPasswordSerializer) +PasswordChangeInputType = generate_input_type_for_serializer("PasswordChangeInputType", PasswordChangeSerializer) +UserMeInputType = generate_input_type_for_serializer("UserMeInputType", UserMeSerializer) class Login(graphene.Mutation): @@ -36,20 +34,16 @@ class Arguments: @staticmethod def mutate(root, info, data): - serializer = LoginSerializer(data=data, context={'request': info.context.request}) + serializer = LoginSerializer(data=data, context={"request": info.context.request}) if errors := mutation_is_not_valid(serializer): return Login( errors=errors, ok=False, - captcha_required=LoginSerializer.is_captcha_required(email=data['email']), + captcha_required=LoginSerializer.is_captcha_required(email=data["email"]), ) - if user := serializer.validated_data.get('user'): + if user := serializer.validated_data.get("user"): login(info.context.request, user) - return Login( - result=user, - errors=None, - ok=True - ) + return Login(result=user, errors=None, ok=True) class LoginWithHID(graphene.Mutation): @@ -62,16 +56,12 @@ class Arguments: @staticmethod def mutate(root, info, data): - serializer = HIDLoginSerializer(data=data, context={'request': info.context.request}) + serializer = HIDLoginSerializer(data=data, context={"request": info.context.request}) if errors := mutation_is_not_valid(serializer): return LoginWithHID(errors=errors, ok=False) - if user := serializer.validated_data.get('user'): + if user := serializer.validated_data.get("user"): login(info.context.request, user) - return LoginWithHID( - result=user, - errors=None, - ok=True - ) + return LoginWithHID(result=user, errors=None, ok=True) class Logout(graphene.Mutation): @@ -93,17 +83,14 @@ class Arguments: @staticmethod def mutate(root, info, data): - serializer = RegisterSerializer(data=data, context={'request': info.context.request}) + serializer = RegisterSerializer(data=data, context={"request": info.context.request}) if errors := mutation_is_not_valid(serializer): return Register( errors=errors, ok=False, ) serializer.save() - return Register( - errors=None, - ok=True - ) + return Register(errors=None, ok=True) class ResetPassword(graphene.Mutation): @@ -116,17 +103,14 @@ class Arguments: @staticmethod def mutate(root, info, data): - serializer = ResetPasswordSerializer(data=data, context={'request': info.context.request}) + serializer = ResetPasswordSerializer(data=data, context={"request": info.context.request}) if errors := mutation_is_not_valid(serializer): return ResetPassword( errors=errors, ok=False, ) serializer.save() - return ResetPassword( - errors=None, - ok=True - ) + return ResetPassword(errors=None, ok=True) class ChangeUserPassword(graphene.Mutation): @@ -138,7 +122,7 @@ class Arguments: @staticmethod def mutate(root, info, data): - serializer = PasswordChangeSerializer(data=data, context={'request': info.context.request}) + serializer = PasswordChangeSerializer(data=data, context={"request": info.context.request}) if errors := mutation_is_not_valid(serializer): return ChangeUserPassword(errors=errors, ok=False) serializer.save() @@ -159,7 +143,7 @@ def mutate(root, info, data): serializer = UserMeSerializer( instance=info.context.user, data=data, - context={'request': info.context.request}, + context={"request": info.context.request}, ) if errors := mutation_is_not_valid(serializer): return UpdateMe(errors=errors, ok=False) @@ -178,40 +162,41 @@ def mutate(_, info): current_user = info.context.user if current_user.profile.deleted_at: - return UserDelete( - errors=[ - dict( - field='nonFieldErrors', - messages='Already deleted.' - ) - ] - ) + return UserDelete(errors=[dict(field="nonFieldErrors", messages="Already deleted.")]) def user_member_project_ids(owner=False): # Member in Projects - project_ids = ProjectMembership.objects.filter(member=current_user).values('project') + project_ids = ProjectMembership.objects.filter(member=current_user).values("project") extra_filter = {} if owner: project_ids = project_ids.filter(role__type=ProjectRole.Type.PROJECT_OWNER) - extra_filter['role__type'] = ProjectRole.Type.PROJECT_OWNER - return ProjectMembership.objects.filter( - member__profile__deleted_at__isnull=True, # Exclude already deleted users - project__in=project_ids, - **extra_filter, - ).order_by().values('project').annotate( - member_count=models.Count('member', distinct=True), - ).filter(member_count=1).values_list('project', 'project__title') + extra_filter["role__type"] = ProjectRole.Type.PROJECT_OWNER + return ( + ProjectMembership.objects.filter( + member__profile__deleted_at__isnull=True, # Exclude already deleted users + project__in=project_ids, + **extra_filter, + ) + .order_by() + .values("project") + .annotate( + member_count=models.Count("member", distinct=True), + ) + .filter(member_count=1) + .values_list("project", "project__title") + ) only_user_member_projects = user_member_project_ids() if only_user_member_projects: return UserDelete( errors=[ dict( - field='nonFieldErrors', - messages='You are only the member in Projects %s. Choose other members before you delete yourself.' - % ', '.join([f'[{_id}]{title}' for _id, title in only_user_member_projects]), + field="nonFieldErrors", + messages="You are only the member in Projects %s. Choose other members before you delete yourself." + % ", ".join([f"[{_id}]{title}" for _id, title in only_user_member_projects]), ) - ], ok=False + ], + ok=False, ) # user only the owner in the project @@ -220,17 +205,18 @@ def user_member_project_ids(owner=False): return UserDelete( errors=[ dict( - field='nonFieldErrors', - messages='You are Owner in Projects %s. Choose another Project Owner before you delete yourself.' - % ', '.join([f'[{_id}]{title}' for _id, title in only_user_owner_role_in_projects]), + field="nonFieldErrors", + messages="You are Owner in Projects %s. Choose another Project Owner before you delete yourself." + % ", ".join([f"[{_id}]{title}" for _id, title in only_user_owner_role_in_projects]), ) - ], ok=False + ], + ok=False, ) current_user.soft_delete() return UserDelete(result=current_user, errors=None, ok=True) -class Mutation(): +class Mutation: login = Login.Field() login_with_hid = LoginWithHID.Field() logout = Logout.Field() diff --git a/apps/user/notifications.py b/apps/user/notifications.py index 380ac1c267..9db5b83bf9 100644 --- a/apps/user/notifications.py +++ b/apps/user/notifications.py @@ -1,12 +1,9 @@ -from project.models import ( - Project, - ProjectJoinRequest, -) +from project.models import Project, ProjectJoinRequest from project.serializers import ProjectJoinRequestSerializer class Notification: - PROJECT_JOIN_REQUEST = 'project_join_request' + PROJECT_JOIN_REQUEST = "project_join_request" def __init__(self, date, notification_type): self.date = date @@ -15,8 +12,7 @@ def __init__(self, date, notification_type): def _get_project_join_requests(user): - admin_projects = Project.get_modifiable_for(user)\ - .values_list('id', flat=True) + admin_projects = Project.get_modifiable_for(user).values_list("id", flat=True) join_requests = ProjectJoinRequest.objects.filter( project__id__in=admin_projects, diff --git a/apps/user/permissions.py b/apps/user/permissions.py index e2092739be..2ce4cf9582 100644 --- a/apps/user/permissions.py +++ b/apps/user/permissions.py @@ -6,7 +6,7 @@ def _is_authenticated(self, rq): return rq.user.is_authenticated def has_permission(self, request, view): - if self._is_authenticated(request) or view.action == 'create': + if self._is_authenticated(request) or view.action == "create": # NOTE:for create user using same api, so return True for `create` return True return False diff --git a/apps/user/receivers.py b/apps/user/receivers.py index 00c19be054..c38c927696 100644 --- a/apps/user/receivers.py +++ b/apps/user/receivers.py @@ -1,8 +1,7 @@ -from django.dispatch import receiver from django.db.models.signals import post_save - +from django.dispatch import receiver from project.models import Project, ProjectMembership -from user.models import User, Profile +from user.models import Profile, User def assign_to_default_project(user): @@ -14,7 +13,7 @@ def assign_to_default_project(user): ProjectMembership.objects.create( member=user, project=default_project, - role='normal', + role="normal", ) diff --git a/apps/user/schema.py b/apps/user/schema.py index fd7c3722ab..f80d2d7592 100644 --- a/apps/user/schema.py +++ b/apps/user/schema.py @@ -1,28 +1,22 @@ -import time import datetime - -from typing import Union, List +import time +from typing import List, Union import graphene +from django.db import models +from django.utils import timezone from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField, PageGraphqlPagination -from django.utils import timezone -from django.db import models - -from utils.graphene.types import CustomDjangoListObjectType -from utils.graphene.fields import DjangoPaginatedListObjectField from jwt_auth.token import AccessToken -from deep.serializers import URLCachedFileField +from project.models import Project, ProjectMembership, ProjectRole -from project.models import ( - Project, - ProjectMembership, - ProjectRole, -) +from deep.serializers import URLCachedFileField +from utils.graphene.fields import DjangoPaginatedListObjectField +from utils.graphene.types import CustomDjangoListObjectType -from .models import User, Profile, Feature from .enums import UserEmailConditionOptOutEnum from .filters import UserGqlFilterSet +from .models import Feature, Profile, User from .utils import generate_hidden_email @@ -30,31 +24,45 @@ def only_me(func): def wrapper(root, info, *args, **kwargs): if root == info.context.user: return func(root, info, *args, **kwargs) + return wrapper def user_member_project_ids(current_user, owner=False): # Member in Projects - project_ids = ProjectMembership.objects.filter(member=current_user).values('project') + project_ids = ProjectMembership.objects.filter(member=current_user).values("project") if owner: project_ids = project_ids.filter(role__type=ProjectRole.Type.PROJECT_OWNER) - project_members = ProjectMembership.objects.filter( - member__profile__deleted_at__isnull=True, # Exclude already deleted users - project__in=project_ids, - ).order_by().values('project').annotate( - member_count=models.Count('member', distinct=True), - ).filter(member_count=1).values_list('project', 'project__title') + project_members = ( + ProjectMembership.objects.filter( + member__profile__deleted_at__isnull=True, # Exclude already deleted users + project__in=project_ids, + ) + .order_by() + .values("project") + .annotate( + member_count=models.Count("member", distinct=True), + ) + .filter(member_count=1) + .values_list("project", "project__title") + ) else: - project_members = ProjectMembership.objects.filter( - ~models.Q(role__type=ProjectRole.Type.PROJECT_OWNER), - member__profile__deleted_at__isnull=True, # Exclude already deleted users - project__in=project_ids, - ).order_by().values('project').values_list('project', 'project__title') + project_members = ( + ProjectMembership.objects.filter( + ~models.Q(role__type=ProjectRole.Type.PROJECT_OWNER), + member__profile__deleted_at__isnull=True, # Exclude already deleted users + project__in=project_ids, + ) + .order_by() + .values("project") + .values_list("project", "project__title") + ) return [ { - 'id': project_id, - 'title': project_title, - } for project_id, project_title in project_members + "id": project_id, + "title": project_title, + } + for project_id, project_title in project_members ] @@ -66,7 +74,7 @@ class JwtTokenType(graphene.ObjectType): class UserFeatureAccessType(DjangoObjectType): class Meta: model = Feature - only_fields = ('key', 'title', 'feature_type') + only_fields = ("key", "title", "feature_type") class UserProfileType(graphene.ObjectType): @@ -77,16 +85,15 @@ class UserProfileType(graphene.ObjectType): @staticmethod def resolve_display_picture_url(root, info, **kwargs) -> Union[str, None]: if root.display_picture: - return info.context.request.build_absolute_uri( - URLCachedFileField().to_representation(root.display_picture.file) - ) + return info.context.request.build_absolute_uri(URLCachedFileField().to_representation(root.display_picture.file)) class UserType(DjangoObjectType): class Meta: model = User only_fields = ( - 'id', 'is_active', + "id", + "is_active", ) display_name = graphene.String() @@ -118,8 +125,12 @@ class Meta: model = User skip_registry = True only_fields = ( - 'id', 'first_name', 'last_name', 'is_active', - 'email', 'last_login', + "id", + "first_name", + "last_name", + "is_active", + "email", + "last_login", ) display_name = graphene.String() @@ -131,7 +142,7 @@ class Meta: language = graphene.String() email_opt_outs = graphene.List(graphene.NonNull(UserEmailConditionOptOutEnum)) jwt_token = graphene.Field(JwtTokenType) - last_active_project = graphene.Field('project.schema.ProjectDetailType') + last_active_project = graphene.Field("project.schema.ProjectDetailType") accessible_features = graphene.List(graphene.NonNull(UserFeatureAccessType), required=True) deleted_at = graphene.Date() sole_projects = graphene.List(UserMeProjectType) @@ -159,7 +170,7 @@ def resolve_last_active_project(root, info, **kwargs) -> Union[int, None]: if project and project.get_current_user_role(info.context.user): return project # As a fallback return last created member project - return Project.get_for_gq(info.context.user, only_member=True).order_by('-id').first() + return Project.get_for_gq(info.context.user, only_member=True).order_by("-id").first() @staticmethod def resolve_organization(root, info, **kwargs) -> Union[str, None]: @@ -224,12 +235,7 @@ class Meta: class Query: me = graphene.Field(UserMeType) user = DjangoObjectField(UserType) - users = DjangoPaginatedListObjectField( - UserListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) - ) + users = DjangoPaginatedListObjectField(UserListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize")) @staticmethod def resolve_me(root, info, **kwargs) -> Union[User, None]: diff --git a/apps/user/serializers.py b/apps/user/serializers.py index 5df10705be..f2b8a23d34 100644 --- a/apps/user/serializers.py +++ b/apps/user/serializers.py @@ -1,13 +1,23 @@ import logging +from django.conf import settings +from django.contrib.auth import authenticate from django.contrib.auth.models import User from django.contrib.auth.password_validation import validate_password -from django.contrib.auth import authenticate -from django.conf import settings -from django.db import transaction, models - +from django.db import models, transaction from drf_dynamic_fields import DynamicFieldsMixin +from gallery.models import File +from jwt_auth.captcha import validate_hcaptcha +from jwt_auth.errors import UserNotFoundError +from project.models import Project from rest_framework import serializers +from user.models import Feature, Profile +from user.utils import ( + get_client_ip, + get_device_type, + send_password_changed_notification, + send_password_reset, +) from deep.serializers import ( RemoveNullFieldsMixin, @@ -15,18 +25,6 @@ WriteOnlyOnCreateSerializerMixin, ) from utils.hid import hid -from user.models import Profile, Feature -from user.utils import ( - send_password_reset, - send_password_changed_notification, - get_client_ip, - get_device_type -) -from project.models import Project -from gallery.models import File - -from jwt_auth.captcha import validate_hcaptcha -from jwt_auth.errors import UserNotFoundError from .utils import send_account_activation from .validators import CustomMaximumLengthValidator @@ -36,75 +34,68 @@ class NanoUserSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): display_name = serializers.CharField( - source='profile.get_display_name', + source="profile.get_display_name", read_only=True, ) class Meta: model = User - fields = ('id', 'display_name') + fields = ("id", "display_name") class SimpleUserSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): display_name = serializers.CharField( - source='profile.get_display_name', + source="profile.get_display_name", read_only=True, ) display_picture = serializers.PrimaryKeyRelatedField( - source='profile.display_picture', + source="profile.display_picture", read_only=True, ) display_picture_url = URLCachedFileField( - source='profile.display_picture.file', + source="profile.display_picture.file", read_only=True, ) - organization_title = serializers.CharField( - source='profile.organization', - read_only=True - ) + organization_title = serializers.CharField(source="profile.organization", read_only=True) class Meta: model = User - fields = ('id', 'display_name', 'email', - 'display_picture', 'display_picture_url', - 'organization_title') + fields = ("id", "display_name", "email", "display_picture", "display_picture_url", "organization_title") -class UserSerializer(RemoveNullFieldsMixin, WriteOnlyOnCreateSerializerMixin, - DynamicFieldsMixin, serializers.ModelSerializer): - organization = serializers.CharField(source='profile.organization', - allow_blank=True) +class UserSerializer(RemoveNullFieldsMixin, WriteOnlyOnCreateSerializerMixin, DynamicFieldsMixin, serializers.ModelSerializer): + organization = serializers.CharField(source="profile.organization", allow_blank=True) language = serializers.CharField( - source='profile.language', + source="profile.language", allow_null=True, required=False, ) display_picture = serializers.PrimaryKeyRelatedField( - source='profile.display_picture', + source="profile.display_picture", queryset=File.objects.all(), allow_null=True, required=False, ) display_picture_url = URLCachedFileField( - source='profile.display_picture.file', + source="profile.display_picture.file", read_only=True, ) display_name = serializers.CharField( - source='profile.get_display_name', + source="profile.get_display_name", read_only=True, ) last_active_project = serializers.PrimaryKeyRelatedField( - source='profile.last_active_project', + source="profile.last_active_project", queryset=Project.objects.all(), allow_null=True, required=False, ) email_opt_outs = serializers.ListField( - source='profile.email_opt_outs', + source="profile.email_opt_outs", required=False, ) login_attempts = serializers.IntegerField( - source='profile.login_attempts', + source="profile.login_attempts", read_only=True, ) @@ -112,32 +103,41 @@ class UserSerializer(RemoveNullFieldsMixin, WriteOnlyOnCreateSerializerMixin, class Meta: model = User - fields = ('id', 'username', 'first_name', 'last_name', - 'display_name', 'last_active_project', - 'login_attempts', 'hcaptcha_response', - 'email', 'organization', 'display_picture', - 'display_picture_url', 'language', 'email_opt_outs') - write_only_on_create_fields = ('email', 'username') + fields = ( + "id", + "username", + "first_name", + "last_name", + "display_name", + "last_active_project", + "login_attempts", + "hcaptcha_response", + "email", + "organization", + "display_picture", + "display_picture_url", + "language", + "email_opt_outs", + ) + write_only_on_create_fields = ("email", "username") @classmethod def update_or_create_profile(cls, user, profile_data): - profile, created = Profile.objects.update_or_create( - user=user, defaults=profile_data - ) + profile, created = Profile.objects.update_or_create(user=user, defaults=profile_data) return profile def validate_hcaptcha_response(self, captcha): validate_hcaptcha(captcha) def validate_last_active_project(self, project): - if project and not project.is_member(self.context['request'].user): - raise serializers.ValidationError('Invalid project') + if project and not project.is_member(self.context["request"].user): + raise serializers.ValidationError("Invalid project") return project def create(self, validated_data): - profile_data = validated_data.pop('profile', None) - validated_data.pop('hcaptcha_response', None) - validated_data['email'] = validated_data['username'] = (validated_data['email'] or validated_data['email']).lower() + profile_data = validated_data.pop("profile", None) + validated_data.pop("hcaptcha_response", None) + validated_data["email"] = validated_data["username"] = (validated_data["email"] or validated_data["email"]).lower() user = super().create(validated_data) user.save() user.profile = self.update_or_create_profile(user, profile_data) @@ -145,74 +145,76 @@ def create(self, validated_data): return user def update(self, instance, validated_data): - profile_data = validated_data.pop('profile', None) + profile_data = validated_data.pop("profile", None) user = super().update(instance, validated_data) - if 'password' in validated_data: - user.set_password(validated_data['password']) + if "password" in validated_data: + user.set_password(validated_data["password"]) user.save() user.profile = self.update_or_create_profile(user, profile_data) return user -class FeatureSerializer(RemoveNullFieldsMixin, - DynamicFieldsMixin, serializers.ModelSerializer): +class FeatureSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): class Meta: model = Feature - fields = ('key', 'title', 'feature_type') + fields = ("key", "title", "feature_type") -class UserPreferencesSerializer(RemoveNullFieldsMixin, - serializers.ModelSerializer): +class UserPreferencesSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): display_picture = serializers.PrimaryKeyRelatedField( - source='profile.display_picture', + source="profile.display_picture", queryset=File.objects.all(), allow_null=True, required=False, ) display_picture_url = URLCachedFileField( - source='profile.display_picture.file', + source="profile.display_picture.file", read_only=True, ) display_name = serializers.CharField( - source='profile.get_display_name', + source="profile.get_display_name", read_only=True, ) last_active_project = serializers.PrimaryKeyRelatedField( - source='profile.last_active_project', + source="profile.last_active_project", queryset=Project.objects.all(), allow_null=True, required=False, ) - language = serializers.CharField(source='profile.language', - read_only=True) + language = serializers.CharField(source="profile.language", read_only=True) fallback_language = serializers.CharField( - source='profile.get_fallback_language', + source="profile.get_fallback_language", read_only=True, ) accessible_features = FeatureSerializer( - source='profile.get_accessible_features', + source="profile.get_accessible_features", many=True, read_only=True, ) class Meta: model = User - fields = ('display_name', 'username', 'email', 'last_active_project', - 'display_picture', 'display_picture_url', 'is_superuser', - 'language', 'accessible_features', 'fallback_language',) + fields = ( + "display_name", + "username", + "email", + "last_active_project", + "display_picture", + "display_picture_url", + "is_superuser", + "language", + "accessible_features", + "fallback_language", + ) -class PasswordResetSerializer(RemoveNullFieldsMixin, - serializers.Serializer): +class PasswordResetSerializer(RemoveNullFieldsMixin, serializers.Serializer): hcaptcha_response = serializers.CharField(required=True) email = serializers.EmailField(required=True) def get_user(self, email): - users = User.objects.filter( - models.Q(email=email.lower()) | - models.Q(email=email) - ) + users = User.objects.filter(models.Q(email=email.lower()) | models.Q(email=email)) if not users.exists(): raise UserNotFoundError return users.first() @@ -225,36 +227,46 @@ def save(self): send_password_reset(user=self.get_user(email)) -class NotificationSerializer(RemoveNullFieldsMixin, - serializers.Serializer): +class NotificationSerializer(RemoveNullFieldsMixin, serializers.Serializer): date = serializers.DateTimeField() - type = serializers.CharField(source='notification_type') + type = serializers.CharField(source="notification_type") details = serializers.ReadOnlyField() class Meta: - ref_name = 'UserNotificationSerializer' + ref_name = "UserNotificationSerializer" class ComprehensiveUserSerializer(serializers.ModelSerializer): - name = serializers.CharField(source='profile.get_display_name', read_only=True) - organization = serializers.CharField(source='profile.organization', read_only=True) + name = serializers.CharField(source="profile.get_display_name", read_only=True) + organization = serializers.CharField(source="profile.organization", read_only=True) class Meta: model = User - fields = ('id', 'name', 'email', 'organization',) + fields = ( + "id", + "name", + "email", + "organization", + ) class EntryCommentUserSerializer(serializers.ModelSerializer): - name = serializers.CharField(source='profile.get_display_name', read_only=True) + name = serializers.CharField(source="profile.get_display_name", read_only=True) display_picture_url = URLCachedFileField( - source='profile.display_picture.file', + source="profile.display_picture.file", read_only=True, ) - organization = serializers.CharField(source='profile.organization', read_only=True) + organization = serializers.CharField(source="profile.organization", read_only=True) class Meta: model = User - fields = ('id', 'name', 'email', 'organization', 'display_picture_url',) + fields = ( + "id", + "name", + "email", + "organization", + "display_picture_url", + ) class PasswordChangeSerializer(serializers.Serializer): @@ -262,9 +274,9 @@ class PasswordChangeSerializer(serializers.Serializer): new_password = serializers.CharField(required=True, write_only=True) def validate_old_password(self, password): - user = self.context['request'].user + user = self.context["request"].user if not user.check_password(password): - raise serializers.ValidationError('Invalid Old Password') + raise serializers.ValidationError("Invalid Old Password") return password def validate_new_password(self, password): @@ -272,28 +284,26 @@ def validate_new_password(self, password): return password def save(self): - user = self.context['request'].user - user.set_password(self.validated_data['new_password']) + user = self.context["request"].user + user.set_password(self.validated_data["new_password"]) user.save() - client_ip = get_client_ip(self.context['request']) - device_type = get_device_type(self.context['request']) + client_ip = get_client_ip(self.context["request"]) + device_type = get_device_type(self.context["request"]) transaction.on_commit( - lambda: send_password_changed_notification.delay( - user_id=user.id, - client_ip=client_ip, - device_type=device_type) + lambda: send_password_changed_notification.delay(user_id=user.id, client_ip=client_ip, device_type=device_type) ) + # ----------------------- NEW GRAPHQL SCHEME Serializers ---------------------------------- class UserNotificationSerializer(serializers.ModelSerializer): - name = serializers.CharField(source='profile.get_display_name', read_only=True) + name = serializers.CharField(source="profile.get_display_name", read_only=True) # display_picture = URLCachedFileField(source='profile.display_picture.file', read_only=True) class Meta: model = User - fields = ('id', 'name', 'email') + fields = ("id", "name", "email") class LoginSerializer(serializers.Serializer): @@ -304,10 +314,7 @@ class LoginSerializer(serializers.Serializer): @classmethod def is_captcha_required(cls, user=None, email=None): _user = user or User.objects.filter(email=email).first() - return ( - _user is not None and - _user.profile.login_attempts >= settings.MAX_LOGIN_ATTEMPTS_FOR_CAPTCHA - ) + return _user is not None and _user.profile.login_attempts >= settings.MAX_LOGIN_ATTEMPTS_FOR_CAPTCHA def validate_password(self, password): # this will now only handle max-length in the login @@ -317,39 +324,36 @@ def validate_password(self, password): def validate(self, data): def _set_user_login_attempts(user, login_attempts): user.profile.login_attempts = login_attempts - user.profile.save(update_fields=['login_attempts']) + user.profile.save(update_fields=["login_attempts"]) - email = data['email'] + email = data["email"] # NOTE: authenticate only works for active users # NOTE: username should be equal to email - authenticate_user = authenticate(username=email.lower(), password=data['password']) + authenticate_user = authenticate(username=email.lower(), password=data["password"]) # Try again without lower (for legacy users, TODO: Migrate this users) if authenticate_user is None: - authenticate_user = authenticate(username=email, password=data['password']) - captcha = data.get('captcha') - user = User.objects.filter( - models.Q(email=email.lower()) | - models.Q(email=email) - ).first() + authenticate_user = authenticate(username=email, password=data["password"]) + captcha = data.get("captcha") + user = User.objects.filter(models.Q(email=email.lower()) | models.Q(email=email)).first() # User doesn't exists in the system. if user is None: - raise serializers.ValidationError('No active account found with the given credentials') + raise serializers.ValidationError("No active account found with the given credentials") # Validate captcha if required for requested user if self.is_captcha_required(user=user): if not captcha: - raise serializers.ValidationError({'captcha': 'Captcha is required'}) + raise serializers.ValidationError({"captcha": "Captcha is required"}) if not validate_hcaptcha(captcha, raise_on_error=False): - raise serializers.ValidationError({'captcha': 'Invalid captcha! Please, Try Again'}) + raise serializers.ValidationError({"captcha": "Invalid captcha! Please, Try Again"}) # Let user retry until max login attempts is reached if user.profile.login_attempts < settings.MAX_LOGIN_ATTEMPTS: if authenticate_user is None: _set_user_login_attempts(user, user.profile.login_attempts + 1) raise serializers.ValidationError( - 'No active account found with the given credentials.' - f' You have {settings.MAX_LOGIN_ATTEMPTS - user.profile.login_attempts} login attempts remaining' + "No active account found with the given credentials." + f" You have {settings.MAX_LOGIN_ATTEMPTS - user.profile.login_attempts} login attempts remaining" ) else: # Lock account after to many attempts @@ -357,12 +361,12 @@ def _set_user_login_attempts(user, login_attempts): # Send email before locking account. _set_user_login_attempts(user, user.profile.login_attempts + 1) send_account_activation(user) - raise serializers.ValidationError('Account is locked, check your email.') + raise serializers.ValidationError("Account is locked, check your email.") # Clear login_attempts after success authentication if user.profile.login_attempts > 0: _set_user_login_attempts(user, 0) - return {'user': authenticate_user} + return {"user": authenticate_user} class CaptchaSerializerMixin(serializers.ModelSerializer): @@ -370,7 +374,7 @@ class CaptchaSerializerMixin(serializers.ModelSerializer): def validate_captcha(self, captcha): if not validate_hcaptcha(captcha, raise_on_error=False): - raise serializers.ValidationError('Invalid captcha! Please, Try Again') + raise serializers.ValidationError("Invalid captcha! Please, Try Again") class RegisterSerializer(CaptchaSerializerMixin, serializers.ModelSerializer): @@ -380,35 +384,35 @@ class RegisterSerializer(CaptchaSerializerMixin, serializers.ModelSerializer): class Meta: model = User fields = ( - 'email', 'first_name', 'last_name', - 'organization', 'captcha', + "email", + "first_name", + "last_name", + "organization", + "captcha", ) def validate_email(self, email): email = email.lower() existing_users_qs = User.objects.filter( - models.Q(email=email) | - models.Q(username=email) | + models.Q(email=email) + | models.Q(username=email) + | # Partially deleted users - models.Q(profile__original_data__email=email) | - models.Q(profile__original_data__username=email) + models.Q(profile__original_data__email=email) + | models.Q(profile__original_data__username=email) ) if existing_users_qs.exists(): - raise serializers.ValidationError('User with that email already exists!!') + raise serializers.ValidationError("User with that email already exists!!") return email # Only this method is used for Register def create(self, validated_data): - validated_data.pop('captcha') - validated_data['username'] = validated_data['email'].lower() - profile_data = { - 'organization': validated_data.pop('organization') - } + validated_data.pop("captcha") + validated_data["username"] = validated_data["email"].lower() + profile_data = {"organization": validated_data.pop("organization")} user = super().create(validated_data) user.profile = UserSerializer.update_or_create_profile(user, profile_data) - transaction.on_commit( - lambda: send_password_reset(user=user, welcome=True) - ) + transaction.on_commit(lambda: send_password_reset(user=user, welcome=True)) return user @@ -418,12 +422,12 @@ class GqPasswordResetSerializer(CaptchaSerializerMixin, serializers.ModelSeriali class Meta: model = User - fields = ('email', 'captcha') + fields = ("email", "captcha") def validate_email(self, email): if user := User.objects.filter(email=email.lower()).first(): return user - raise serializers.ValidationError('There is no user with that email.') + raise serializers.ValidationError("There is no user with that email.") def save(self): user = self.validated_data["email"] # validate_email returning user instance @@ -431,21 +435,21 @@ def save(self): class UserMeSerializer(serializers.ModelSerializer): - organization = serializers.CharField(source='profile.organization', allow_blank=True, required=False) - language = serializers.CharField(source='profile.language', allow_null=True, required=False) + organization = serializers.CharField(source="profile.organization", allow_blank=True, required=False) + language = serializers.CharField(source="profile.language", allow_null=True, required=False) email_opt_outs = serializers.ListField( child=serializers.ChoiceField(choices=Profile.EmailConditionOptOut.choices), - source='profile.email_opt_outs', + source="profile.email_opt_outs", required=False, ) last_active_project = serializers.PrimaryKeyRelatedField( - source='profile.last_active_project', + source="profile.last_active_project", queryset=Project.objects.all(), allow_null=True, required=False, ) display_picture = serializers.PrimaryKeyRelatedField( - source='profile.display_picture', + source="profile.display_picture", queryset=File.objects.all(), allow_null=True, required=False, @@ -454,25 +458,30 @@ class UserMeSerializer(serializers.ModelSerializer): class Meta: model = User fields = ( - 'first_name', 'last_name', 'organization', 'display_picture', - 'language', 'email_opt_outs', 'last_active_project' + "first_name", + "last_name", + "organization", + "display_picture", + "language", + "email_opt_outs", + "last_active_project", ) def validate_last_active_project(self, project): - if project and not project.is_member(self.context['request'].user): - raise serializers.ValidationError('Invalid project') + if project and not project.is_member(self.context["request"].user): + raise serializers.ValidationError("Invalid project") return project def validate_display_picture(self, display_picture): - if display_picture and display_picture.created_by != self.context['request'].user: - raise serializers.ValidationError('Display picture not found!') + if display_picture and display_picture.created_by != self.context["request"].user: + raise serializers.ValidationError("Display picture not found!") return display_picture def update(self, instance, validated_data): - profile_data = validated_data.pop('profile', None) + profile_data = validated_data.pop("profile", None) user = super().update(instance, validated_data) - if 'password' in validated_data: - user.set_password(validated_data['password']) + if "password" in validated_data: + user.set_password(validated_data["password"]) user.save() user.profile = UserSerializer.update_or_create_profile(user, profile_data) return user @@ -485,14 +494,12 @@ class HIDLoginSerializer(serializers.Serializer): state = serializers.IntegerField(required=False) def validate(self, data): - humanitarian_id = hid.HumanitarianId(data['access_token']) + humanitarian_id = hid.HumanitarianId(data["access_token"]) try: - return { - 'user': humanitarian_id.get_user() - } + return {"user": humanitarian_id.get_user()} except hid.HIDBaseException as e: raise serializers.ValidationError(e.message) except Exception: - logger.error('HID error', exc_info=True) - raise serializers.ValidationError('Unexpected Error') + logger.error("HID error", exc_info=True) + raise serializers.ValidationError("Unexpected Error") diff --git a/apps/user/tasks.py b/apps/user/tasks.py index 76c34288d1..340c53a96a 100644 --- a/apps/user/tasks.py +++ b/apps/user/tasks.py @@ -2,10 +2,9 @@ from datetime import timedelta from celery import shared_task - -from django.utils import timezone -from django.contrib.auth.models import User from django.conf import settings +from django.contrib.auth.models import User +from django.utils import timezone logger = logging.getLogger(__name__) @@ -14,18 +13,16 @@ def permanently_delete_users(): # get all the user whose_deleted_at is null # and check if there deleted days greater than 30 days - logger.info('[User Delete] Querying all the user with deleted_at') - threshold = ( - timezone.now() - timedelta(days=settings.USER_AND_PROJECT_DELETE_IN_DAYS) - ) + logger.info("[User Delete] Querying all the user with deleted_at") + threshold = timezone.now() - timedelta(days=settings.USER_AND_PROJECT_DELETE_IN_DAYS) user_qs = User.objects.filter( profile__original_data__isnull=False, profile__deleted_at__isnull=False, profile__deleted_at__lt=threshold, ) - logger.info(f'[User Delete] Found {user_qs.count()} users to delete.') + logger.info(f"[User Delete] Found {user_qs.count()} users to delete.") for user in user_qs: - logger.info(f'[User Delete] Cleaning up user original data {user.id}') + logger.info(f"[User Delete] Cleaning up user original data {user.id}") user.profile.original_data = None - user.profile.save(update_fields=('original_data',)) - logger.info(f'[User Delete] Successfully deleted all user data from the system {user.id}') + user.profile.save(update_fields=("original_data",)) + logger.info(f"[User Delete] Successfully deleted all user data from the system {user.id}") diff --git a/apps/user/tests/test_apis.py b/apps/user/tests/test_apis.py index 20e581cd42..f14d5e5ebf 100644 --- a/apps/user/tests/test_apis.py +++ b/apps/user/tests/test_apis.py @@ -1,29 +1,20 @@ -from deep.tests import TestCase -from project.models import ( - Project, - ProjectMembership, - ProjectJoinRequest, -) -from user.models import ( - User, - EmailDomain, - Feature, -) +from project.models import Project, ProjectJoinRequest, ProjectMembership +from user.models import EmailDomain, Feature, User from user.notifications import Notification +from deep.tests import TestCase + class UserApiTests(TestCase): def test_active_project(self): # Create a project with self.user as member # and test setting it as active project through the API - project = Project.objects.create(title='Test') - ProjectMembership.objects.create(project=project, - member=self.user, - role=self.admin_role) + project = Project.objects.create(title="Test") + ProjectMembership.objects.create(project=project, member=self.user, role=self.admin_role) - url = '/api/v1/users/{}/'.format(self.user.pk) + url = "/api/v1/users/{}/".format(self.user.pk) data = { - 'last_active_project': project.id, + "last_active_project": project.id, } self.authenticate() @@ -37,10 +28,10 @@ def test_active_project(self): def test_patch_user(self): # TODO: Add old_password to change password - url = '/api/v1/users/{}/'.format(self.user.pk) + url = "/api/v1/users/{}/".format(self.user.pk) data = { - 'password': 'newpassword', - 'receive_email': False, + "password": "newpassword", + "receive_email": False, } self.authenticate() @@ -48,9 +39,9 @@ def test_patch_user(self): self.assert_405(response) def test_authentication_in_users_instance(self): - user = self.create(User, first_name='hello', last_name='bye') + user = self.create(User, first_name="hello", last_name="bye") - url = f'/api/v1/users/{user.id}/' + url = f"/api/v1/users/{user.id}/" # try to get with no authentication response = self.client.get(url) @@ -60,36 +51,31 @@ def test_authentication_in_users_instance(self): self.authenticate(user) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['first_name'], user.first_name) + self.assertEqual(response.data["first_name"], user.first_name) def test_get_me(self): - url = '/api/v1/users/me/' + url = "/api/v1/users/me/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['username'], self.user.username) + self.assertEqual(response.data["username"], self.user.username) def test_search_user_without_exclude(self): project = self.create(Project) - user1 = self.create(User, first_name='search', last_name='user') - user2 = self.create(User, first_name='user', last_name='search') - user3 = self.create(User, first_name='my search', last_name='user') - user4 = self.create(User, email='search@toggle.com') + user1 = self.create(User, first_name="search", last_name="user") + user2 = self.create(User, first_name="user", last_name="search") + user3 = self.create(User, first_name="my search", last_name="user") + user4 = self.create(User, email="search@toggle.com") # Create another non matching user, just to make sure it doesn't appear # in result - self.create( - User, - first_name='abc', - last_name='xyz', - email='something@toggle.com' - ) + self.create(User, first_name="abc", last_name="xyz", email="something@toggle.com") # Add members to project project.add_member(user1) # Search query is 'search' - url = '/api/v1/users/?search=search' + url = "/api/v1/users/?search=search" self.authenticate() response = self.client.get(url) @@ -97,25 +83,24 @@ def test_search_user_without_exclude(self): data = response.json() - assert data['count'] == 4 + assert data["count"] == 4 # user1 is most matching and user4 is the least matching, # user5 does not match - assert data['results'][0]['id'] == user1.id - assert data['results'][1]['id'] == user3.id - assert data['results'][2]['id'] == user2.id - assert data['results'][3]['id'] == user4.id + assert data["results"][0]["id"] == user1.id + assert data["results"][1]["id"] == user3.id + assert data["results"][2]["id"] == user2.id + assert data["results"][3]["id"] == user4.id def test_search_user_with_exclude(self): project = self.create(Project) - user1 = self.create(User, first_name='search', last_name='user') - user2 = self.create(User, first_name='user', last_name='search') - user3 = self.create(User, first_name='my search', last_name='user') + user1 = self.create(User, first_name="search", last_name="user") + user2 = self.create(User, first_name="user", last_name="search") + user3 = self.create(User, first_name="my search", last_name="user") # Add members to project project.add_member(user1) # Search query is 'search' - url = '/api/v1/users/?search=search&members_exclude_project=' \ - + str(project.id) + url = "/api/v1/users/?search=search&members_exclude_project=" + str(project.id) self.authenticate() response = self.client.get(url) @@ -123,108 +108,106 @@ def test_search_user_with_exclude(self): data = response.json() - assert data['count'] == 2, "user 1 is in the project, so one less" + assert data["count"] == 2, "user 1 is in the project, so one less" # user3 is most matching and user2 is the least matching - assert data['results'][0]['id'] == user3.id - assert data['results'][1]['id'] == user2.id + assert data["results"][0]["id"] == user3.id + assert data["results"][1]["id"] == user2.id def test_notifications(self): test_project = self.create(Project, role=self.admin_role) test_user = self.create(User) - request = ProjectJoinRequest.objects.create( - project=test_project, - requested_by=test_user, - role=self.admin_role - ) + request = ProjectJoinRequest.objects.create(project=test_project, requested_by=test_user, role=self.admin_role) - url = '/api/v1/users/me/notifications/' + url = "/api/v1/users/me/notifications/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 1) - result = response.data['results'][0] - self.assertEqual(result['type'], Notification.PROJECT_JOIN_REQUEST) - self.assertEqual(result['details']['id'], request.id) + self.assertEqual(response.data["count"], 1) + result = response.data["results"][0] + self.assertEqual(result["type"], Notification.PROJECT_JOIN_REQUEST) + self.assertEqual(result["details"]["id"], request.id) def test_user_preference_feature_access(self): - user_fhx = self.create(User, email='fhx@togglecorp.com') - user_az273 = self.create(User, email='az273@tc.com') - user_dummy = self.create(User, email='dummy@test.com') + user_fhx = self.create(User, email="fhx@togglecorp.com") + user_az273 = self.create(User, email="az273@tc.com") + user_dummy = self.create(User, email="dummy@test.com") - test_domain = self.create(EmailDomain, title='Togglecorp', domain_name='togglecorp.com') - self.create(Feature, feature_type=Feature.FeatureType.GENERAL_ACCESS, - key=Feature.FeatureKey.PRIVATE_PROJECT, title='Private project', - email_domains=[test_domain], users=[user_dummy]) + test_domain = self.create(EmailDomain, title="Togglecorp", domain_name="togglecorp.com") + self.create( + Feature, + feature_type=Feature.FeatureType.GENERAL_ACCESS, + key=Feature.FeatureKey.PRIVATE_PROJECT, + title="Private project", + email_domains=[test_domain], + users=[user_dummy], + ) self.authenticate(user_fhx) - response = self.client.get('/api/v1/users/me/preferences/') - self.assertEqual(len(response.data['accessible_features']), 1) + response = self.client.get("/api/v1/users/me/preferences/") + self.assertEqual(len(response.data["accessible_features"]), 1) self.authenticate(user_az273) - response = self.client.get('/api/v1/users/me/preferences/') - self.assertEqual(len(response.data['accessible_features']), 0) + response = self.client.get("/api/v1/users/me/preferences/") + self.assertEqual(len(response.data["accessible_features"]), 0) self.authenticate(user_dummy) - response = self.client.get('/api/v1/users/me/preferences/') - self.assertEqual(len(response.data['accessible_features']), 1) + response = self.client.get("/api/v1/users/me/preferences/") + self.assertEqual(len(response.data["accessible_features"]), 1) def test_user_preference_feature_available_for_all(self): - user_fhx = self.create(User, email='fhx@togglecorp.com') - - feature = self.create(Feature, feature_type=Feature.FeatureType.GENERAL_ACCESS, - key=Feature.FeatureKey.PRIVATE_PROJECT, title='Private project', - email_domains=[], users=[], is_available_for_all=False) + user_fhx = self.create(User, email="fhx@togglecorp.com") + + feature = self.create( + Feature, + feature_type=Feature.FeatureType.GENERAL_ACCESS, + key=Feature.FeatureKey.PRIVATE_PROJECT, + title="Private project", + email_domains=[], + users=[], + is_available_for_all=False, + ) self.authenticate(user_fhx) - response = self.client.get('/api/v1/users/me/preferences/') - self.assertEqual(len(response.data['accessible_features']), 0) + response = self.client.get("/api/v1/users/me/preferences/") + self.assertEqual(len(response.data["accessible_features"]), 0) feature.is_available_for_all = True feature.save() self.authenticate(user_fhx) - response = self.client.get('/api/v1/users/me/preferences/') - self.assertEqual(len(response.data['accessible_features']), 1) + response = self.client.get("/api/v1/users/me/preferences/") + self.assertEqual(len(response.data["accessible_features"]), 1) def test_password_change(self): - self.user_password = 'joHnDave!@#123' + self.user_password = "joHnDave!@#123" user = User.objects.create_user( - username='ram@dave.com', - first_name='Ram', - last_name='Dave', + username="ram@dave.com", + first_name="Ram", + last_name="Dave", password=self.user_password, - email='ram@dave.com', + email="ram@dave.com", ) new_pass = "nepal!@#RRFASF" - data = { - "old_password": self.user_password, - "new_password": new_pass - } - url = '/api/v1/users/me/change-password/' + data = {"old_password": self.user_password, "new_password": new_pass} + url = "/api/v1/users/me/change-password/" self.authenticate(user) response = self.client.post(url, data) self.assert_200(response) user.refresh_from_db() - assert user.check_password(data['new_password']) + assert user.check_password(data["new_password"]) # now try with posting diferent `new_password` that doesnot follow django password validation - data = { - "old_password": new_pass, # since password is already changed in the database level - "new_password": "nepa" - } + data = {"old_password": new_pass, "new_password": "nepa"} # since password is already changed in the database level self.authenticate(user) response = self.client.post(url, data) self.assert_400(response) # now try with posting different `old_password` - data = { - "old_password": "hahahmeme", - "new_password": "nepa" - } + data = {"old_password": "hahahmeme", "new_password": "nepa"} self.authenticate(user) response = self.client.post(url, data) self.assert_400(response) diff --git a/apps/user/tests/test_password.py b/apps/user/tests/test_password.py index 9c58cc5792..f8cba6154c 100644 --- a/apps/user/tests/test_password.py +++ b/apps/user/tests/test_password.py @@ -1,16 +1,15 @@ from django.core.exceptions import ValidationError +from user.validators import CustomMaximumLengthValidator from deep.tests import TestCase -from user.validators import CustomMaximumLengthValidator - class PasswordCheckerTest(TestCase): def test_password_greater_than_128_characters(self): - self.assertIsNone(CustomMaximumLengthValidator().validate('12345678')) - self.assertIsNone(CustomMaximumLengthValidator(max_length=20).validate('123')) + self.assertIsNone(CustomMaximumLengthValidator().validate("12345678")) + self.assertIsNone(CustomMaximumLengthValidator(max_length=20).validate("123")) with self.assertRaises(ValidationError) as vd: - CustomMaximumLengthValidator(max_length=128).validate('12' * 129) - self.assertEqual(vd.exception.error_list[0].code, 'password_too_long') + CustomMaximumLengthValidator(max_length=128).validate("12" * 129) + self.assertEqual(vd.exception.error_list[0].code, "password_too_long") diff --git a/apps/user/tests/test_schemas.py b/apps/user/tests/test_schemas.py index 800de37c64..508eeeafc2 100644 --- a/apps/user/tests/test_schemas.py +++ b/apps/user/tests/test_schemas.py @@ -1,30 +1,26 @@ -import pytz +from datetime import datetime, timedelta from unittest import mock -from datetime import timedelta, datetime -from dateutil.relativedelta import relativedelta +import pytz +from analysis_framework.factories import AnalysisFrameworkFactory +from dateutil.relativedelta import relativedelta from django.conf import settings from django.utils import timezone - -from deep.trackers import schedule_tracker_data_handler -from utils.graphene.tests import GraphQLTestCase - from gallery.factories import FileFactory from project.factories import ProjectFactory -from analysis_framework.factories import AnalysisFrameworkFactory -from user.models import User, Feature, EmailCondition, Profile -from user.factories import UserFactory, FeatureFactory +from user.factories import FeatureFactory, UserFactory +from user.models import EmailCondition, Feature, Profile, User +from user.tasks import permanently_delete_users from user.utils import ( + generate_hidden_email, + send_account_activation, send_password_changed_notification, send_password_reset, - send_account_activation, - generate_hidden_email, -) -from utils.hid.tests.test_hid import ( - HIDIntegrationTest, - HID_EMAIL ) -from user.tasks import permanently_delete_users + +from deep.trackers import schedule_tracker_data_handler +from utils.graphene.tests import GraphQLTestCase +from utils.hid.tests.test_hid import HID_EMAIL, HIDIntegrationTest class TestUserSchema(GraphQLTestCase): @@ -32,7 +28,7 @@ class TestUserSchema(GraphQLTestCase): def setUp(self): # This is used in 2 test - self.login_mutation = ''' + self.login_mutation = """ mutation Mutation($input: LoginInputType!) { login(data: $input) { ok @@ -45,36 +41,37 @@ def setUp(self): } } } - ''' + """ super().setUp() def test_login(self): # Try with random user - minput = dict(email='xyz@xyz.com', password='pasword-xyz') + minput = dict(email="xyz@xyz.com", password="pasword-xyz") self.query_check(self.login_mutation, minput=minput, okay=False) # Try with real user - user = UserFactory.create(email=minput['email']) + user = UserFactory.create(email=minput["email"]) minput = dict(email=user.email, password=user.password_text) content = self.query_check(self.login_mutation, minput=minput, okay=True) # FIXME: Maybe ['id'] should be string? - self.assertEqual(content['data']['login']['result']['id'], str(user.id), content) - self.assertEqual(content['data']['login']['result']['email'], user.email, content) + self.assertEqual(content["data"]["login"]["result"]["id"], str(user.id), content) + self.assertEqual(content["data"]["login"]["result"]["email"], user.email, content) - @mock.patch('jwt_auth.captcha.requests') - @mock.patch('user.serializers.send_account_activation', side_effect=send_account_activation) + @mock.patch("jwt_auth.captcha.requests") + @mock.patch("user.serializers.send_account_activation", side_effect=send_account_activation) def test_login_captcha(self, send_account_activation_mock, captch_requests_mock): """ - Test captcha response. - Test account block behaviour """ + def _invalid_login(): content = self.query_check( self.login_mutation, minput=dict( email=user.email, - password='wrong-password', - captcha='captcha', + password="wrong-password", + captcha="captcha", ), okay=False, ) @@ -87,12 +84,12 @@ def _valid_login(okay): minput=dict( email=user.email, password=user.password_text, - captcha='captcha', + captcha="captcha", ), okay=okay, ) - captch_requests_mock.post.return_value.json.return_value = {'success': False} + captch_requests_mock.post.return_value.json.return_value = {"success": False} user = UserFactory.create() # For MAX_LOGIN_ATTEMPTS_FOR_CAPTCHA count failed login attempt for attempt in range(1, 5): @@ -103,7 +100,7 @@ def _valid_login(okay): # Count stoped (when valid captch is not provided) self.assertEqual(user.profile.login_attempts, settings.MAX_LOGIN_ATTEMPTS_FOR_CAPTCHA, content) # After MAX_LOGIN_ATTEMPTS_FOR_CAPTCHA count failed captcha is required - captch_requests_mock.post.return_value.json.return_value = {'success': True} + captch_requests_mock.post.return_value.json.return_value = {"success": True} for attempt in range( settings.MAX_LOGIN_ATTEMPTS_FOR_CAPTCHA + 1, settings.MAX_LOGIN_ATTEMPTS + 2, @@ -120,18 +117,18 @@ def _valid_login(okay): # Count all failed count (when valid captch is provided) # Still can't login (with right password). - captch_requests_mock.post.return_value.json.return_value = {'success': True} + captch_requests_mock.post.return_value.json.return_value = {"success": True} content = _valid_login(okay=False) # mock activation link logic user.profile.login_attempts = 0 - user.profile.save(update_fields=['login_attempts']) + user.profile.save(update_fields=["login_attempts"]) content = _valid_login(okay=True) - @mock.patch('utils.hid.hid.requests') + @mock.patch("utils.hid.hid.requests") def test_login_with_hid(self, mock_requests): - query = ''' + query = """ mutation Mutation($input: HIDLoginInputType!) { loginWithHid(data: $input) { ok @@ -144,11 +141,11 @@ def test_login_with_hid(self, mock_requests): } } } - ''' + """ mock_return_value = HIDIntegrationTest()._setup_mock_hid_requests(mock_requests) - minput = dict(accessToken='xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx') + minput = dict(accessToken="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") content = self.query_check(query, minput=minput, okay=True) - self.assertEqual(content['data']['loginWithHid']['result']['email'], HID_EMAIL) + self.assertEqual(content["data"]["loginWithHid"]["result"]["email"], HID_EMAIL) # let the response be `400` and look for the error mock_return_value.status_code = 400 @@ -156,14 +153,14 @@ def test_login_with_hid(self, mock_requests): mock_return_value.status_code = 200 # pass not verified email - mock_return_value.json.return_value['email_verified'] = False + mock_return_value.json.return_value["email_verified"] = False self.query_check(query, minput=minput, assert_for_error=True) - mock_return_value.json.return_value['email_verified'] = True + mock_return_value.json.return_value["email_verified"] = True - @mock.patch('jwt_auth.captcha.requests') - @mock.patch('user.serializers.send_password_reset', side_effect=send_password_reset) + @mock.patch("jwt_auth.captcha.requests") + @mock.patch("user.serializers.send_password_reset", side_effect=send_password_reset) def test_register(self, send_password_reset_mock, captch_requests_mock): - query = ''' + query = """ mutation Mutation($input: RegisterInputType!) { register(data: $input) { ok @@ -171,35 +168,35 @@ def test_register(self, send_password_reset_mock, captch_requests_mock): errors } } - ''' + """ # input without email minput = dict( - email='invalid-email', - firstName='john', - lastName='cena', - organization='the-deep', - captcha='captcha', + email="invalid-email", + firstName="john", + lastName="cena", + organization="the-deep", + captcha="captcha", ) # With invalid captcha - captch_requests_mock.post.return_value.json.return_value = {'success': False} + captch_requests_mock.post.return_value.json.return_value = {"success": False} content = self.query_check(query, minput=minput, okay=False) # With valid captcha now - captch_requests_mock.post.return_value.json.return_value = {'success': True} + captch_requests_mock.post.return_value.json.return_value = {"success": True} # With invalid email content = self.query_check(query, minput=minput, okay=False) - self.assertEqual(len(content['data']['register']['errors']), 1, content) + self.assertEqual(len(content["data"]["register"]["errors"]), 1, content) # With valid input - minput['email'] = 'john@Cena.com' + minput["email"] = "john@Cena.com" with self.captureOnCommitCallbacks(execute=True): content = self.query_check(query, minput=minput, okay=True) # Make sure password reset message is send - user = User.objects.get(email=minput['email'].lower()) + user = User.objects.get(email=minput["email"].lower()) send_password_reset_mock.assert_called_once_with(user=user, welcome=True) self.assertEqual(user.username, user.email) - self.assertEqual(user.email, minput['email'].lower()) + self.assertEqual(user.email, minput["email"].lower()) # Try again with same data self.query_check(query, minput=minput, okay=False) @@ -210,27 +207,27 @@ def test_register(self, send_password_reset_mock, captch_requests_mock): # Now permanently delete user data user.profile.original_data = None - user.profile.save(update_fields=('original_data',)) + user.profile.save(update_fields=("original_data",)) # Should work now self.query_check(query, minput=minput, okay=True) def test_logout(self): - query = ''' + query = """ query Query { me { id email } } - ''' - logout_mutation = ''' + """ + logout_mutation = """ mutation Mutation { logout { ok } } - ''' + """ user = UserFactory.create() # # Without Login session self.query_check(query, assert_for_error=True) @@ -240,17 +237,17 @@ def test_logout(self): # Query Me (Success) content = self.query_check(query) - self.assertEqual(content['data']['me']['id'], str(user.id), content) - self.assertEqual(content['data']['me']['email'], user.email, content) + self.assertEqual(content["data"]["me"]["id"], str(user.id), content) + self.assertEqual(content["data"]["me"]["email"], user.email, content) # # Logout self.query_check(logout_mutation, okay=True) # Query Me (with error again) self.query_check(query, assert_for_error=True) - @mock.patch('jwt_auth.captcha.requests') - @mock.patch('user.serializers.send_password_reset', side_effect=send_password_reset) + @mock.patch("jwt_auth.captcha.requests") + @mock.patch("user.serializers.send_password_reset", side_effect=send_password_reset) def test_password_reset(self, send_password_reset_mock, captch_requests_mock): - query = ''' + query = """ mutation Mutation($input: ResetPasswordInputType!) { resetPassword(data: $input) { ok @@ -258,50 +255,50 @@ def test_password_reset(self, send_password_reset_mock, captch_requests_mock): errors } } - ''' + """ # input without email minput = dict( - email='invalid-email', - captcha='captcha', + email="invalid-email", + captcha="captcha", ) # With invalid captcha - captch_requests_mock.post.return_value.json.return_value = {'success': False} + captch_requests_mock.post.return_value.json.return_value = {"success": False} content = self.query_check(query, minput=minput, okay=False) # With valid captcha now - captch_requests_mock.post.return_value.json.return_value = {'success': True} + captch_requests_mock.post.return_value.json.return_value = {"success": True} # With invalid email content = self.query_check(query, minput=minput, okay=False) - self.assertEqual(len(content['data']['resetPassword']['errors']), 1, content) + self.assertEqual(len(content["data"]["resetPassword"]["errors"]), 1, content) # With unknown user email - minput['email'] = 'john@cena.com' + minput["email"] = "john@cena.com" content = self.query_check(query, minput=minput, okay=False) - self.assertEqual(len(content['data']['resetPassword']['errors']), 1, content) + self.assertEqual(len(content["data"]["resetPassword"]["errors"]), 1, content) # With known user email - UserFactory.create(email=minput['email']) + UserFactory.create(email=minput["email"]) content = self.query_check(query, minput=minput, okay=True) # Make sure password reset message is send - user = User.objects.get(email=minput['email']) + user = User.objects.get(email=minput["email"]) send_password_reset_mock.assert_called_once_with(user=user) @mock.patch( - 'user.serializers.send_password_changed_notification.delay', + "user.serializers.send_password_changed_notification.delay", side_effect=send_password_changed_notification.delay, ) def test_password_change(self, send_password_changed_notification_mock): - query = ''' + query = """ mutation Mutation($input: PasswordChangeInputType!) { changePassword(data: $input) { ok errors } } - ''' + """ # input without email - minput = dict(oldPassword='', newPassword='new-password-123') + minput = dict(oldPassword="", newPassword="new-password-123") # Without authentication -- content = self.query_check(query, minput=minput, assert_for_error=True) # With authentication @@ -309,34 +306,34 @@ def test_password_change(self, send_password_changed_notification_mock): self.force_login(user) # With invalid old password -- content = self.query_check(query, minput=minput, okay=False) - self.assertEqual(len(content['data']['changePassword']['errors']), 1, content) + self.assertEqual(len(content["data"]["changePassword"]["errors"]), 1, content) # With valid password -- - minput['oldPassword'] = user.password_text + minput["oldPassword"] = user.password_text with self.captureOnCommitCallbacks(execute=True): content = self.query_check(query, minput=minput, okay=True) # Make sure password reset message is send send_password_changed_notification_mock.assert_called_once() send_password_changed_notification_mock.assert_called_once_with( user_id=user.pk, - client_ip='127.0.0.1', + client_ip="127.0.0.1", device_type=None, ) def test_update_me(self): - query = ''' + query = """ mutation Mutation($input: UserMeInputType!) { updateMe(data: $input) { ok errors } } - ''' + """ user = UserFactory.create() project = ProjectFactory.create() gallery_file = FileFactory.create() minput = dict( - emailOptOuts=[''], + emailOptOuts=[""], displayPicture=gallery_file.pk, # File without access lastActiveProject=project.pk, # Non-member Project language="en-us", @@ -344,7 +341,7 @@ def test_update_me(self): lastName="Deep", organization="DFS", ) - minput['emailOptOuts'] = [ + minput["emailOptOuts"] = [ self.genum(EmailCondition.NEWS_AND_UPDATES), self.genum(EmailCondition.JOIN_REQUESTS), ] @@ -353,7 +350,7 @@ def test_update_me(self): # With authentication ----- self.force_login(user) content = self.query_check(query, minput=minput, okay=False) - self.assertEqual(len(content['data']['updateMe']['errors']), 2, content) + self.assertEqual(len(content["data"]["updateMe"]["errors"]), 2, content) # With valid ----- # Remove invalid option # Add ownership to file @@ -364,7 +361,7 @@ def test_update_me(self): content = self.query_check(query, minput=minput, okay=True) def test_me_last_active_project(self): - query = ''' + query = """ query Query { me { lastActiveProject { @@ -373,7 +370,7 @@ def test_me_last_active_project(self): } } } - ''' + """ user = UserFactory.create() project1 = ProjectFactory.create() @@ -384,23 +381,23 @@ def test_me_last_active_project(self): self.force_login(user) # --- Without any project membership content = self.query_check(query) - self.assertEqual(content['data']['me']['lastActiveProject'], None, content) + self.assertEqual(content["data"]["me"]["lastActiveProject"], None, content) # --- With a project membership + But no lastActiveProject set in profile project1.add_member(user) content = self.query_check(query) - self.assertIdEqual(content['data']['me']['lastActiveProject']['id'], project1.pk, content) + self.assertIdEqual(content["data"]["me"]["lastActiveProject"]["id"], project1.pk, content) # --- With a project membership + lastActiveProject is set in profile project2.add_member(user) user.last_active_project = project2 content = self.query_check(query) - self.assertIdEqual(content['data']['me']['lastActiveProject']['id'], project2.pk, content) + self.assertIdEqual(content["data"]["me"]["lastActiveProject"]["id"], project2.pk, content) # --- With a project membership + (non-member) lastActiveProject is set in profile user.last_active_project = project3 content = self.query_check(query) - self.assertIdEqual(content['data']['me']['lastActiveProject']['id'], project2.pk, content) + self.assertIdEqual(content["data"]["me"]["lastActiveProject"]["id"], project2.pk, content) def test_me_allowed_features(self): - query = ''' + query = """ query MyQuery { me { accessibleFeatures { @@ -410,7 +407,7 @@ def test_me_allowed_features(self): } } } - ''' + """ feature1 = FeatureFactory.create(key=Feature.FeatureKey.ANALYSIS) feature2 = FeatureFactory.create(key=Feature.FeatureKey.POLYGON_SUPPORT_GEO) @@ -421,17 +418,17 @@ def test_me_allowed_features(self): self.force_login(user) # --- Without any features content = self.query_check(query) - self.assertEqual(len(content['data']['me']['accessibleFeatures']), 0, content) + self.assertEqual(len(content["data"]["me"]["accessibleFeatures"]), 0, content) # --- With a project membership + But no lastActiveProject set in profile feature1.users.add(user) feature2.users.add(user) content = self.query_check(query) - self.assertEqual(len(content['data']['me']['accessibleFeatures']), 2, content) - self.assertEqual(content['data']['me']['accessibleFeatures'][0]['key'], self.genum(feature1.key), content) - self.assertEqual(content['data']['me']['accessibleFeatures'][1]['key'], self.genum(feature2.key), content) + self.assertEqual(len(content["data"]["me"]["accessibleFeatures"]), 2, content) + self.assertEqual(content["data"]["me"]["accessibleFeatures"][0]["key"], self.genum(feature1.key), content) + self.assertEqual(content["data"]["me"]["accessibleFeatures"][1]["key"], self.genum(feature2.key), content) def test_me_only_fields(self): - query = ''' + query = """ query UserQuery($id: ID!) { me { id @@ -484,16 +481,16 @@ def test_me_only_fields(self): pageSize } } - ''' + """ User.objects.all().delete() # Clear all users if exists project = ProjectFactory.create() display_picture = FileFactory.create() # Create some users user = UserFactory.create( # Will use this as requesting user - organization='Deep', - language='en-us', - email_opt_outs=['join_requests'], + organization="Deep", + language="en-us", + email_opt_outs=["join_requests"], last_login=timezone.now(), last_active_project=project, display_picture=display_picture, @@ -502,9 +499,9 @@ def test_me_only_fields(self): # Other users for i in range(0, 3): other_last_user = UserFactory.create( - organization=f'Deep {i}', - language='en-us', - email_opt_outs=['join_requests'], + organization=f"Deep {i}", + language="en-us", + email_opt_outs=["join_requests"], last_login=timezone.now(), last_active_project=project, display_picture=display_picture, @@ -512,28 +509,33 @@ def test_me_only_fields(self): # This fields are only meant for `Me` only_me_fields = [ - 'displayPicture', 'lastActiveProject', 'language', 'emailOptOuts', - 'email', 'lastLogin', 'jwtToken', + "displayPicture", + "lastActiveProject", + "language", + "emailOptOuts", + "email", + "lastLogin", + "jwtToken", ] # Without authentication ----- - content = self.query_check(query, assert_for_error=True, variables={'id': str(other_last_user.pk)}) + content = self.query_check(query, assert_for_error=True, variables={"id": str(other_last_user.pk)}) # With authentication ----- self.force_login(user) - content = self.query_check(query, variables={'id': str(other_last_user.pk)}) - self.assertEqual(len(content['data']['users']['results']), 4, content) # 1 me + 3 others + content = self.query_check(query, variables={"id": str(other_last_user.pk)}) + self.assertEqual(len(content["data"]["users"]["results"]), 4, content) # 1 me + 3 others for field in only_me_fields: self.assertNotEqual( - content['data']['me'].get(field), None, (field, content['data']['me'][field]) + content["data"]["me"].get(field), None, (field, content["data"]["me"][field]) ) # Shouldn't be None self.assertEqual( - content['data']['user'].get(field), None, (field, content['data']['user'].get(field)) + content["data"]["user"].get(field), None, (field, content["data"]["user"].get(field)) ) # Should be None # check for display_picture_url - self.assertNotEqual(content['data']['me']['displayPictureUrl'], None, content) + self.assertNotEqual(content["data"]["me"]["displayPictureUrl"], None, content) def test_user_filters(self): - query = ''' + query = """ query UserQuery($membersExcludeFramework: ID, $membersExcludeProject: ID, $search: String) { users( membersExcludeFramework: $membersExcludeFramework, @@ -556,12 +558,12 @@ def test_user_filters(self): pageSize } } - ''' + """ project1, project2 = ProjectFactory.create_batch(2) af1, af2 = AnalysisFrameworkFactory.create_batch(2) - user = UserFactory.create(first_name='Normal', last_name='Guy', email='test@testing.com') - user1 = UserFactory.create(first_name='Admin', last_name='Guy', email='admin@testing.com') + user = UserFactory.create(first_name="Normal", last_name="Guy", email="test@testing.com") + user1 = UserFactory.create(first_name="Admin", last_name="Guy", email="admin@testing.com") user2, user3 = UserFactory.create_batch(2) project1.add_member(user1) project1.add_member(user2) @@ -583,28 +585,28 @@ def _query_check(filters, **kwargs): # Without any filters for name, filters, users in ( - ('no-filter', dict(), [user, user1, user2, user3]), - ('exclude-project-1', dict(membersExcludeProject=project1.pk), [user, user3]), - ('exclude-project-2', dict(membersExcludeProject=project2.pk), [user1, user3]), - ('exclude-af-1', dict(membersExcludeFramework=af1.pk), [user3]), - ('exclude-af-2', dict(membersExcludeFramework=af2.pk), [user, user1, user3]), - ('search-fist_name', dict(search='Normal'), [user]), - ('search-last_name', dict(search='Guy'), [user, user1]), - ('search-email', dict(search='test@testing.com'), [user]), - ('search-partial_email-01', dict(search='test@'), [user]), - ('search-partial_email-02', dict(search='@testing.com'), [user, user1]), - ('search-full_name', dict(search='Normal Guy'), [user]), - ('search-with-space-after-first_name', dict(search='Normal '), [user]), - ('search-with-space-before-first_name', dict(search=' Normal'), [user]), - ('search-with-space-before-after-last_name', dict(search=' Guy '), [user, user1]), - ('search-with-space-after-full-name', dict(search='Normal Guy '), [user]), + ("no-filter", dict(), [user, user1, user2, user3]), + ("exclude-project-1", dict(membersExcludeProject=project1.pk), [user, user3]), + ("exclude-project-2", dict(membersExcludeProject=project2.pk), [user1, user3]), + ("exclude-af-1", dict(membersExcludeFramework=af1.pk), [user3]), + ("exclude-af-2", dict(membersExcludeFramework=af2.pk), [user, user1, user3]), + ("search-fist_name", dict(search="Normal"), [user]), + ("search-last_name", dict(search="Guy"), [user, user1]), + ("search-email", dict(search="test@testing.com"), [user]), + ("search-partial_email-01", dict(search="test@"), [user]), + ("search-partial_email-02", dict(search="@testing.com"), [user, user1]), + ("search-full_name", dict(search="Normal Guy"), [user]), + ("search-with-space-after-first_name", dict(search="Normal "), [user]), + ("search-with-space-before-first_name", dict(search=" Normal"), [user]), + ("search-with-space-before-after-last_name", dict(search=" Guy "), [user, user1]), + ("search-with-space-after-full-name", dict(search="Normal Guy "), [user]), ): - content = _query_check(filters)['data']['users']['results'] + content = _query_check(filters)["data"]["users"]["results"] self.assertEqual(len(content), len(users), (name, content)) self.assertListIds(content, users, (name, content)) def test_get_user_hidden_email(self): - query_single_user = ''' + query_single_user = """ query MyQuery($id: ID!) { user(id: $id) { id @@ -621,9 +623,9 @@ def test_get_user_hidden_email(self): } } - ''' + """ - query_all_users = ''' + query_all_users = """ query MyQuery { users { results { @@ -643,35 +645,35 @@ def test_get_user_hidden_email(self): } } - ''' - user = UserFactory.create(email='testuser@deep.com') - UserFactory.create(email='testuser2@deep.com') - UserFactory.create(email='testuser3@deep.com') + """ + user = UserFactory.create(email="testuser@deep.com") + UserFactory.create(email="testuser2@deep.com") + UserFactory.create(email="testuser3@deep.com") # # Without Login session - self.query_check(query_single_user, variables={'id': str(user.id)}, assert_for_error=True) + self.query_check(query_single_user, variables={"id": str(user.id)}, assert_for_error=True) # # Login self.force_login(user) # Query User (Success) - content = self.query_check(query_single_user, variables={'id': str(user.id)}) - self.assertEqual(content['data']['user']['id'], str(user.id), content) - self.assertEqual(content['data']['user']['emailDisplay'], 't***r@deep.com') + content = self.query_check(query_single_user, variables={"id": str(user.id)}) + self.assertEqual(content["data"]["user"]["id"], str(user.id), content) + self.assertEqual(content["data"]["user"]["emailDisplay"], "t***r@deep.com") # Query Users (Success) content = self.query_check(query_all_users) - email_display_list = [result['emailDisplay'] for result in content['data']['users']['results']] - self.assertTrue(set(['t***r@deep.com', 't***2@deep.com']).issubset(set(email_display_list))) + email_display_list = [result["emailDisplay"] for result in content["data"]["users"]["results"]] + self.assertTrue(set(["t***r@deep.com", "t***2@deep.com"]).issubset(set(email_display_list))) def test_generate_hidden_email(self): - deleted_email = f'test123@{settings.DELETED_USER_EMAIL_DOMAIN}' + deleted_email = f"test123@{settings.DELETED_USER_EMAIL_DOMAIN}" for original, expected in [ - ('testuser1@deep.com', 't***1@deep.com'), - ('testuser2@deep.com', 't***2@deep.com'), - ('abcd@deep.com', 'a***d@deep.com'), - ('abc@deep.com', 'a***c@deep.com'), - ('xy@deep.com', 'x***y@deep.com'), - ('a@deep.com', 'a***a@deep.com'), + ("testuser1@deep.com", "t***1@deep.com"), + ("testuser2@deep.com", "t***2@deep.com"), + ("abcd@deep.com", "a***d@deep.com"), + ("abc@deep.com", "a***c@deep.com"), + ("xy@deep.com", "x***y@deep.com"), + ("a@deep.com", "a***a@deep.com"), (deleted_email, deleted_email), ]: self.assertEqual(expected, generate_hidden_email(original)) @@ -693,9 +695,9 @@ def test_user_deletion_project_check(self): another_deleted_owner_user.soft_delete() af = AnalysisFrameworkFactory.create() - project1 = ProjectFactory.create(analysis_framework=af, title='Project 1') - project2 = ProjectFactory.create(analysis_framework=af, title='Project 2') - project3 = ProjectFactory.create(analysis_framework=af, title='Project 3') + project1 = ProjectFactory.create(analysis_framework=af, title="Project 1") + project2 = ProjectFactory.create(analysis_framework=af, title="Project 2") + project3 = ProjectFactory.create(analysis_framework=af, title="Project 3") project1.add_member(admin_user, role=self.project_role_admin) project1.add_member(member_user, role=self.project_role_member) @@ -705,7 +707,7 @@ def test_user_deletion_project_check(self): project3.add_member(owner_user, role=self.project_role_owner) - delete_mutation = ''' + delete_mutation = """ mutation Mutation { deleteUser { ok @@ -716,7 +718,7 @@ def test_user_deletion_project_check(self): } } } - ''' + """ def _query_check(**kwargs): return self.query_check(delete_mutation, **kwargs) @@ -767,7 +769,7 @@ def _query_check(**kwargs): self.assertEqual(active_users_qs.count(), 2) # another owner only def test_user_deletion(self): - users_query = ''' + users_query = """ query Query($id: ID!) { user(id: $id) { id @@ -780,66 +782,63 @@ def test_user_deletion(self): } } } - ''' + """ deleted_user = UserFactory.create() deleted_user.soft_delete() another_user = UserFactory.create() # now try to get users data from another user self.force_login(another_user) - user_data = self.query_check(users_query, variables={'id': deleted_user.id})['data']['user'] - self.assertEqual(user_data, dict( - id=str(deleted_user.id), - displayName=f'{settings.DELETED_USER_FIRST_NAME} {settings.DELETED_USER_LAST_NAME}', - firstName=settings.DELETED_USER_FIRST_NAME, - lastName=settings.DELETED_USER_LAST_NAME, - profile=dict( - id=str(deleted_user.profile.id), - organization=settings.DELETED_USER_ORGANIZATION, - ) - )) + user_data = self.query_check(users_query, variables={"id": deleted_user.id})["data"]["user"] + self.assertEqual( + user_data, + dict( + id=str(deleted_user.id), + displayName=f"{settings.DELETED_USER_FIRST_NAME} {settings.DELETED_USER_LAST_NAME}", + firstName=settings.DELETED_USER_FIRST_NAME, + lastName=settings.DELETED_USER_LAST_NAME, + profile=dict( + id=str(deleted_user.profile.id), + organization=settings.DELETED_USER_ORGANIZATION, + ), + ), + ) def test_user_deletion_celery_method(self): def _get_user_data(user): profile = user.profile return { - 'first_name': user.first_name, - 'last_name': user.last_name, - 'email': user.email, - 'username': user.username, - 'is_active': user.is_active, - 'profile': { - 'invalid_email': profile.invalid_email, - 'organization': profile.organization, - 'hid': profile.hid, - 'display_picture': profile.display_picture, + "first_name": user.first_name, + "last_name": user.last_name, + "email": user.email, + "username": user.username, + "is_active": user.is_active, + "profile": { + "invalid_email": profile.invalid_email, + "organization": profile.organization, + "hid": profile.hid, + "display_picture": profile.display_picture, }, } def _get_anonymized_user_data(user): return { - 'first_name': settings.DELETED_USER_FIRST_NAME, - 'last_name': settings.DELETED_USER_LAST_NAME, - 'email': f'user-{user.id}@deleted.thedeep.io', - 'username': f'user-{user.id}@deleted.thedeep.io', - 'is_active': False, - 'profile': { - 'invalid_email': True, - 'organization': settings.DELETED_USER_ORGANIZATION, - 'hid': None, - 'display_picture': None, + "first_name": settings.DELETED_USER_FIRST_NAME, + "last_name": settings.DELETED_USER_LAST_NAME, + "email": f"user-{user.id}@deleted.thedeep.io", + "username": f"user-{user.id}@deleted.thedeep.io", + "is_active": False, + "profile": { + "invalid_email": True, + "organization": settings.DELETED_USER_ORGANIZATION, + "hid": None, + "display_picture": None, }, } user1, user2, user3, user4, user5 = all_users = UserFactory.create_batch(5) - users_data = { - user.id: _get_user_data(user) - for user in all_users - } - anonymized_users_data = { - user.id: _get_anonymized_user_data(user) - for user in all_users - } + users_data = {user.id: _get_user_data(user) for user in all_users} + anonymized_users_data = {user.id: _get_anonymized_user_data(user) for user in all_users} user1.soft_delete(deleted_at=self.now_datetime - timedelta(days=32)) user2.soft_delete(deleted_at=self.now_datetime - timedelta(days=10)) @@ -887,7 +886,7 @@ def _get_anonymized_user_data(user): self.assertEqual(users_data[user.pk], user.profile.original_data) def test_user_query_db_queries(self): - QUERY_ALL_USERS = ''' + QUERY_ALL_USERS = """ query MyQuery { users { results { @@ -907,10 +906,10 @@ def test_user_query_db_queries(self): } } - ''' - user = UserFactory.create(email='testuser@deep.com') - UserFactory.create(email='testuser2@deep.com') - UserFactory.create(email='testuser3@deep.com') + """ + user = UserFactory.create(email="testuser@deep.com") + UserFactory.create(email="testuser2@deep.com") + UserFactory.create(email="testuser3@deep.com") self.force_login(user) """ @@ -924,7 +923,7 @@ def test_user_query_db_queries(self): self.query_check(QUERY_ALL_USERS) def test_user_last_active(self): - QUERY_NOTIFICATIONS = ''' + QUERY_NOTIFICATIONS = """ query MyQuery { notifications { results { @@ -932,7 +931,7 @@ def test_user_last_active(self): } } } - ''' + """ def assert_user_last_activity(user, last_active_date, is_active): user.refresh_from_db() diff --git a/apps/user/token.py b/apps/user/token.py index 69dd24f276..9c56fe09b1 100644 --- a/apps/user/token.py +++ b/apps/user/token.py @@ -1,4 +1,5 @@ from django.conf import settings + from deep.token import DeepTokenGenerator @@ -7,6 +8,7 @@ class UnsubscribeEmailTokenGenerator(DeepTokenGenerator): Strategy object used to generate and check tokens for the unsubscribing user from receving email. """ + key_salt = "user.token.UnsubscribeEmailTokenGenerator" secret = settings.SECRET_KEY reset_timeout_days = 100 @@ -23,8 +25,9 @@ def _make_hash_value(self, user, timestamp): """ return ( # FIXME: Add str(user.receive_email) here - str(user.pk) + user.password + - str(timestamp) + str(user.pk) + + user.password + + str(timestamp) ) diff --git a/apps/user/utils.py b/apps/user/utils.py index 798167aaa6..31c517ce9c 100644 --- a/apps/user/utils.py +++ b/apps/user/utils.py @@ -1,43 +1,37 @@ -import logging import datetime +import logging from celery import shared_task -from user_agents import parse - -from django.contrib.auth.models import User -from django.utils.encoding import force_bytes from django.conf import settings -from django.utils.http import urlsafe_base64_encode -from django.template import loader +from django.contrib.auth.models import User from django.contrib.auth.tokens import default_token_generator from django.core.mail import EmailMultiAlternatives - -from .token import unsubscribe_email_token_generator +from django.template import loader +from django.utils.encoding import force_bytes +from django.utils.http import urlsafe_base64_encode from project.models import ProjectJoinRequest from project.token import project_request_token_generator -from .models import Profile, EmailCondition +from user_agents import parse +from .models import EmailCondition, Profile +from .token import unsubscribe_email_token_generator logger = logging.getLogger(__name__) -def _send_mail(subject_template_name, email_template_name, - context, from_email, to_email, - html_email_template_name=None): +def _send_mail(subject_template_name, email_template_name, context, from_email, to_email, html_email_template_name=None): """ Send a django.core.mail.EmailMultiAlternatives to `to_email`. """ subject = loader.render_to_string(subject_template_name, context) # Email subject *must not* contain newlines - subject = ''.join(subject.splitlines()) + subject = "".join(subject.splitlines()) body = loader.render_to_string(email_template_name, context) - email_message = EmailMultiAlternatives( - subject, body, from_email, [to_email]) + email_message = EmailMultiAlternatives(subject, body, from_email, [to_email]) email_message.attach_alternative(body, "text/html") if html_email_template_name is not None: - html_email = loader.render_to_string( - html_email_template_name, context) - email_message.attach_alternative(html_email, 'text/html') + html_email = loader.render_to_string(html_email_template_name, context) + email_message.attach_alternative(html_email, "text/html") email_message.send() @@ -48,35 +42,40 @@ def send_mail_to_user(user, email_type, context={}, *args, **kwargs): """ if user.profile.invalid_email: logger.warning( - '[{}] Email not sent: User <{}>({}) email flagged as invalid email!!'.format( - email_type, user.email, user.pk, + "[{}] Email not sent: User <{}>({}) email flagged as invalid email!!".format( + email_type, + user.email, + user.pk, ) ) return elif not user.profile.is_email_subscribed_for(email_type): logger.warning( - '[{}] Email not sent: User <{}>({}) has not subscribed!!'.format( - email_type, user.email, user.pk, + "[{}] Email not sent: User <{}>({}) has not subscribed!!".format( + email_type, + user.email, + user.pk, ) ) return - context.update({ - 'client_domain': settings.DEEPER_FRONTEND_HOST, - 'protocol': settings.HTTP_PROTOCOL, - 'site_name': settings.DEEPER_SITE_NAME, - 'domain': settings.DJANGO_API_HOST, - 'user': user, - 'email_type': email_type, - 'unsubscribe_email_types': Profile.EMAIL_CONDITIONS_TYPES, - 'unsubscribe_email_token': - unsubscribe_email_token_generator.make_token(user), - 'unsubscribe_email_id': - urlsafe_base64_encode(force_bytes(user.pk)), - }) + context.update( + { + "client_domain": settings.DEEPER_FRONTEND_HOST, + "protocol": settings.HTTP_PROTOCOL, + "site_name": settings.DEEPER_SITE_NAME, + "domain": settings.DJANGO_API_HOST, + "user": user, + "email_type": email_type, + "unsubscribe_email_types": Profile.EMAIL_CONDITIONS_TYPES, + "unsubscribe_email_token": unsubscribe_email_token_generator.make_token(user), + "unsubscribe_email_id": urlsafe_base64_encode(force_bytes(user.pk)), + } + ) _send_mail( - *args, **kwargs, + *args, + **kwargs, context=context, from_email=settings.EMAIL_FROM, to_email=user.email, @@ -89,10 +88,12 @@ def get_users(email): that prevent inactive users and users with unusable passwords from resetting their password. """ - active_users = User._default_manager.filter(**{ - '%s__iexact' % User.get_email_field_name(): email, - 'is_active': True, - }) + active_users = User._default_manager.filter( + **{ + "%s__iexact" % User.get_email_field_name(): email, + "is_active": True, + } + ) return (u for u in active_users) @@ -102,15 +103,16 @@ def send_password_reset(user, welcome=False): user. """ context = { - 'uid': urlsafe_base64_encode(force_bytes(user.pk)), - 'token': default_token_generator.make_token(user), - 'welcome': welcome, + "uid": urlsafe_base64_encode(force_bytes(user.pk)), + "token": default_token_generator.make_token(user), + "welcome": welcome, } send_mail_to_user( - user, EmailCondition.PASSWORD_RESET, + user, + EmailCondition.PASSWORD_RESET, context=context, - subject_template_name='registration/password_reset_subject.txt', - email_template_name='registration/password_reset_email.html', + subject_template_name="registration/password_reset_subject.txt", + email_template_name="registration/password_reset_email.html", ) @@ -120,14 +122,15 @@ def send_account_activation(user): user. """ context = { - 'uid': urlsafe_base64_encode(force_bytes(user.pk)), - 'token': default_token_generator.make_token(user), + "uid": urlsafe_base64_encode(force_bytes(user.pk)), + "token": default_token_generator.make_token(user), } send_mail_to_user( - user, EmailCondition.ACCOUNT_ACTIVATION, + user, + EmailCondition.ACCOUNT_ACTIVATION, context=context, - subject_template_name='registration/user_activation_subject.txt', - email_template_name='registration/user_activation_email.html', + subject_template_name="registration/user_activation_subject.txt", + email_template_name="registration/user_activation_email.html", ) @@ -139,30 +142,32 @@ def send_project_join_request_emails(join_request_id): join_request = ProjectJoinRequest.objects.get(id=join_request_id) project = join_request.project request_by = join_request.requested_by - reason = join_request.data['reason'] - request_data = {'join_request': join_request} + reason = join_request.data["reason"] + request_data = {"join_request": join_request} email_type = EmailCondition.JOIN_REQUESTS context = { - 'request_by': request_by, - 'project': project, - 'reason': reason, - 'pid': urlsafe_base64_encode(force_bytes(join_request.pk)), + "request_by": request_by, + "project": project, + "reason": reason, + "pid": urlsafe_base64_encode(force_bytes(join_request.pk)), } for user in project.get_admins(): - request_data.update({'will_responded_by': user}) - context.update({ - 'uid': urlsafe_base64_encode(force_bytes(user.pk)), - 'token': - project_request_token_generator.make_token(request_data) - }) + request_data.update({"will_responded_by": user}) + context.update( + { + "uid": urlsafe_base64_encode(force_bytes(user.pk)), + "token": project_request_token_generator.make_token(request_data), + } + ) send_mail_to_user( - user, email_type, + user, + email_type, context=context, - subject_template_name='project/project_join_request.txt', - email_template_name='project/project_join_request_email.html', + subject_template_name="project/project_join_request.txt", + email_template_name="project/project_join_request_email.html", ) @@ -173,9 +178,9 @@ def project_context(join_request_id): role = join_request.role context = { - 'project': project, - 'user': user, - 'role': role, + "project": project, + "user": user, + "role": role, } return context @@ -185,11 +190,11 @@ def send_project_accept_email(join_request_id): context = project_context(join_request_id) send_mail_to_user( - context['user'], + context["user"], EmailCondition.JOIN_REQUESTS, context=context, - subject_template_name='project/project_join_accept.txt', - email_template_name='project/project_join_accept_email.html', + subject_template_name="project/project_join_accept.txt", + email_template_name="project/project_join_accept_email.html", ) @@ -198,28 +203,28 @@ def send_project_reject_email(join_request_id): context = project_context(join_request_id) send_mail_to_user( - context['user'], + context["user"], EmailCondition.JOIN_REQUESTS, context=context, - subject_template_name='project/project_join_reject.txt', - email_template_name='project/project_join_reject_email.html', + subject_template_name="project/project_join_reject.txt", + email_template_name="project/project_join_reject_email.html", ) def get_client_ip(request): - x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR') + x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR") if x_forwarded_for: - ip = x_forwarded_for.split(',')[-1].strip() + ip = x_forwarded_for.split(",")[-1].strip() else: - ip = request.META.get('REMOTE_ADDR') + ip = request.META.get("REMOTE_ADDR") return ip def get_device_type(request): - http_agent = request.META.get('HTTP_USER_AGENT') + http_agent = request.META.get("HTTP_USER_AGENT") if http_agent: user_agent = parse(http_agent) - return user_agent.browser.family + ',' + user_agent.os.family + return user_agent.browser.family + "," + user_agent.os.family return @@ -227,22 +232,23 @@ def get_device_type(request): def send_password_changed_notification(user_id, client_ip, device_type): user = User.objects.get(pk=user_id) context = { - 'time': datetime.datetime.now(), - 'location': client_ip, - 'device': device_type, + "time": datetime.datetime.now(), + "location": client_ip, + "device": device_type, } send_mail_to_user( - user, email_type=EmailCondition.PASSWORD_CHANGED, + user, + email_type=EmailCondition.PASSWORD_CHANGED, context=context, - subject_template_name='password_changed/subject.txt', - email_template_name='password_changed/email.html', + subject_template_name="password_changed/subject.txt", + email_template_name="password_changed/email.html", ) def generate_hidden_email(email): - email_name, email_domain = email.split('@') + email_name, email_domain = email.split("@") # For deleted emails no need to hide. if email_domain == settings.DELETED_USER_EMAIL_DOMAIN: return email email_name_first_char, email_name_last_char = email_name[:1], email_name[-1:] - return f'{email_name_first_char}***{email_name_last_char}@{email_domain}' + return f"{email_name_first_char}***{email_name_last_char}@{email_domain}" diff --git a/apps/user/validators.py b/apps/user/validators.py index 9b5f6f1f0f..f90284c00d 100644 --- a/apps/user/validators.py +++ b/apps/user/validators.py @@ -2,7 +2,7 @@ from django.utils.translation import gettext as _ -class CustomMaximumLengthValidator(): +class CustomMaximumLengthValidator: def __init__(self, max_length=128): self.max_length = max_length @@ -11,11 +11,8 @@ def validate(self, password, user=None): raise ValidationError( _("This password has exceed the limit of %(max_length)d characters"), code="password_too_long", - params={'max_length': self.max_length}, + params={"max_length": self.max_length}, ) def get_help_text(self): - return _( - "Your password must contain less than %(max_length)d characters." - % {'max_length': self.max_length} - ) + return _("Your password must contain less than %(max_length)d characters." % {"max_length": self.max_length}) diff --git a/apps/user/views.py b/apps/user/views.py index eaf2ee6bbf..eb02845691 100644 --- a/apps/user/views.py +++ b/apps/user/views.py @@ -1,10 +1,10 @@ +from django.contrib.auth import update_session_auth_hash from django.contrib.auth.models import User +from django.contrib.auth.tokens import default_token_generator from django.db import models from django.template.response import TemplateResponse -from django.contrib.auth import update_session_auth_hash -from django.contrib.auth.tokens import default_token_generator -from django.utils.http import urlsafe_base64_decode from django.utils.encoding import force_text +from django.utils.http import urlsafe_base64_decode from rest_framework import ( exceptions, filters, @@ -16,18 +16,18 @@ ) from rest_framework.decorators import action -from utils.db.functions import StrPos from deep.views import get_frontend_url +from utils.db.functions import StrPos -from .token import unsubscribe_email_token_generator +from .permissions import UserPermission from .serializers import ( - UserSerializer, - UserPreferencesSerializer, NotificationSerializer, + PasswordChangeSerializer, PasswordResetSerializer, - PasswordChangeSerializer + UserPreferencesSerializer, + UserSerializer, ) -from .permissions import UserPermission +from .token import unsubscribe_email_token_generator class UserViewSet(viewsets.ReadOnlyModelViewSet): @@ -51,15 +51,15 @@ class UserViewSet(viewsets.ReadOnlyModelViewSet): Modify an existing user partially """ - queryset = User.objects.filter(is_active=True).order_by('-date_joined') + queryset = User.objects.filter(is_active=True).order_by("-date_joined") serializer_class = UserSerializer permission_classes = [permissions.IsAuthenticated, UserPermission] filter_backends = (filters.SearchFilter, filters.OrderingFilter) def get_object(self): - pk = self.kwargs['pk'] - if pk == 'me': + pk = self.kwargs["pk"] + if pk == "me": return self.request.user else: return super().get_object() @@ -68,42 +68,43 @@ def filter_queryset(self, queryset): queryset = super().filter_queryset(queryset) # Check if project/framework exclusion query is present - exclude_project = self.request.query_params.get( - 'members_exclude_project') - exclude_framework = self.request.query_params.get( - 'members_exclude_framework') + exclude_project = self.request.query_params.get("members_exclude_project") + exclude_framework = self.request.query_params.get("members_exclude_framework") if exclude_project: - queryset = queryset.filter( - ~models.Q(projectmembership__project=exclude_project) - ).distinct() + queryset = queryset.filter(~models.Q(projectmembership__project=exclude_project)).distinct() if exclude_framework: - queryset = queryset.filter( - ~models.Q(framework_membership__framework_id=exclude_framework) - ) + queryset = queryset.filter(~models.Q(framework_membership__framework_id=exclude_framework)) - search_str = self.request.query_params.get('search') + search_str = self.request.query_params.get("search") if search_str is None or not search_str.strip(): return queryset - return queryset.annotate( - strpos=StrPos( - models.functions.Lower( - models.functions.Concat( - 'first_name', models.Value(' '), 'last_name', - models.Value(' '), 'email', - output_field=models.CharField() - ) - ), - models.Value(search_str.lower(), models.CharField()) + return ( + queryset.annotate( + strpos=StrPos( + models.functions.Lower( + models.functions.Concat( + "first_name", + models.Value(" "), + "last_name", + models.Value(" "), + "email", + output_field=models.CharField(), + ) + ), + models.Value(search_str.lower(), models.CharField()), + ) ) - ).filter(strpos__gte=1).order_by('strpos') + .filter(strpos__gte=1) + .order_by("strpos") + ) @action( detail=True, permission_classes=[permissions.IsAuthenticated], - url_path='preferences', + url_path="preferences", serializer_class=UserPreferencesSerializer, ) def get_preferences(self, request, pk=None, version=None): @@ -117,11 +118,12 @@ def get_preferences(self, request, pk=None, version=None): @action( detail=True, permission_classes=[permissions.IsAuthenticated], - url_path='notifications', + url_path="notifications", serializer_class=NotificationSerializer, ) def get_notifications(self, request, pk=None, version=None): from user.notifications import generate_notifications + user = self.get_object() if user != request.user: raise exceptions.PermissionDenied() @@ -135,17 +137,12 @@ def get_notifications(self, request, pk=None, version=None): detail=False, permission_classes=[permissions.IsAuthenticated], url_name="change_password", - url_path='me/change-password', + url_path="me/change-password", serializer_class=PasswordChangeSerializer, - methods=['POST'] + methods=["POST"], ) def change_password(self, request, pk=None, version=None): - serializer = PasswordChangeSerializer( - data=request.data, - context={ - 'request': request - } - ) + serializer = PasswordChangeSerializer(data=request.data, context={"request": request}) serializer.is_valid(raise_exception=True) serializer.save() update_session_auth_hash(request, request.user) @@ -157,13 +154,14 @@ def post(self, request, version=None): serializer = PasswordResetSerializer(data=request.data) serializer.is_valid(raise_exception=True) serializer.save() - return response.Response( - serializer.data, status=status.HTTP_201_CREATED) + return response.Response(serializer.data, status=status.HTTP_201_CREATED) def user_activate_confirm( - request, uidb64, token, - template_name='registration/user_activation_confirm.html', + request, + uidb64, + token, + template_name="registration/user_activation_confirm.html", token_generator=default_token_generator, ): try: @@ -178,9 +176,9 @@ def user_activate_confirm( user = None context = { - 'success': True, - 'login_url': get_frontend_url('login/'), - 'title': 'Account Activation', + "success": True, + "login_url": get_frontend_url("login/"), + "title": "Account Activation", } if user is not None and token_generator.check_token(user, token): @@ -188,14 +186,17 @@ def user_activate_confirm( user.profile.login_attempts = 0 user.save() else: - context['success'] = False + context["success"] = False return TemplateResponse(request, template_name, context) def unsubscribe_email( - request, uidb64, token, email_type, - template_name='user/unsubscribe_email__confirm.html', + request, + uidb64, + token, + email_type, + template_name="user/unsubscribe_email__confirm.html", token_generator=unsubscribe_email_token_generator, ): try: @@ -210,14 +211,14 @@ def unsubscribe_email( user = None context = { - 'success': True, - 'title': 'Unsubscribe Email', + "success": True, + "title": "Unsubscribe Email", } if user is not None and token_generator.check_token(user, token): user.profile.unsubscribe_email(email_type) user.save() else: - context['success'] = False + context["success"] = False return TemplateResponse(request, template_name, context) diff --git a/apps/user_group/__init__.py b/apps/user_group/__init__.py index ec04c1dc2e..792c4ae110 100644 --- a/apps/user_group/__init__.py +++ b/apps/user_group/__init__.py @@ -1 +1 @@ -default_app_config = 'user_group.apps.UserGroupConfig' +default_app_config = "user_group.apps.UserGroupConfig" diff --git a/apps/user_group/admin.py b/apps/user_group/admin.py index 13b897014e..230a3ca8d6 100644 --- a/apps/user_group/admin.py +++ b/apps/user_group/admin.py @@ -1,25 +1,26 @@ from django.contrib import admin from django.db.models import Count -from .models import UserGroup, GroupMembership +from .models import GroupMembership, UserGroup class UserGroupInline(admin.TabularInline): model = GroupMembership - autocomplete_fields = ('member', 'added_by',) + autocomplete_fields = ( + "member", + "added_by", + ) @admin.register(UserGroup) class UserGroupAdmin(admin.ModelAdmin): - list_display = ('title', 'member_count') + list_display = ("title", "member_count") inlines = [UserGroupInline] - search_fields = ('title',) - autocomplete_fields = ('display_picture',) + search_fields = ("title",) + autocomplete_fields = ("display_picture",) def get_queryset(self, request): - return super().get_queryset(request).annotate( - member_count=Count('members', distinct=True) - ) + return super().get_queryset(request).annotate(member_count=Count("members", distinct=True)) def member_count(self, instance): return instance.member_count diff --git a/apps/user_group/apps.py b/apps/user_group/apps.py index 2079712525..5e044ed355 100644 --- a/apps/user_group/apps.py +++ b/apps/user_group/apps.py @@ -2,7 +2,7 @@ class UserGroupConfig(AppConfig): - name = 'user_group' + name = "user_group" def ready(self): - from . import receivers # noqa + from . import receivers # noqa diff --git a/apps/user_group/dataloaders.py b/apps/user_group/dataloaders.py index 4481bc9937..674c8c8799 100644 --- a/apps/user_group/dataloaders.py +++ b/apps/user_group/dataloaders.py @@ -1,14 +1,12 @@ from collections import defaultdict -from promise import Promise -from django.utils.functional import cached_property + from django.db import models +from django.utils.functional import cached_property +from promise import Promise from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin -from .models import ( - UserGroup, - GroupMembership, -) +from .models import GroupMembership, UserGroup class UserGroupMembershipsLoader(DataLoaderWithContext): @@ -16,7 +14,7 @@ def batch_load_fn(self, keys): membership_qs = GroupMembership.objects.filter( # Only fetch for user_group where current user is member + ids (keys) group__in=UserGroup.get_for_member(self.context.user).filter(id__in=keys) - ).select_related('member', 'added_by') + ).select_related("member", "added_by") # Membership map memberships_map = defaultdict(list) for membership in membership_qs: @@ -26,11 +24,9 @@ def batch_load_fn(self, keys): class UserGroupMembershipsCountLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - membership_count_qs = GroupMembership.objects\ - .order_by()\ - .values('group')\ - .annotate(count=models.Count('*'))\ - .values_list('group', 'count') + membership_count_qs = ( + GroupMembership.objects.order_by().values("group").annotate(count=models.Count("*")).values_list("group", "count") + ) # Membership map _map = defaultdict(int) for group, count in membership_count_qs: @@ -40,9 +36,7 @@ def batch_load_fn(self, keys): class UserGroupCurrentUserRoleLoader(DataLoaderWithContext): def batch_load_fn(self, keys): - membership_qs = GroupMembership.objects\ - .filter(group__in=keys, member=self.context.user)\ - .values_list('group_id', 'role') + membership_qs = GroupMembership.objects.filter(group__in=keys, member=self.context.user).values_list("group_id", "role") # Role map role_map = {} for group_id, role in membership_qs: diff --git a/apps/user_group/enums.py b/apps/user_group/enums.py index 20d6f9b123..25ba9e0183 100644 --- a/apps/user_group/enums.py +++ b/apps/user_group/enums.py @@ -5,11 +5,6 @@ from .models import GroupMembership -GroupMembershipRoleEnum = convert_enum_to_graphene_enum(GroupMembership.Role, name='GroupMembershipRoleEnum') +GroupMembershipRoleEnum = convert_enum_to_graphene_enum(GroupMembership.Role, name="GroupMembershipRoleEnum") -enum_map = { - get_enum_name_from_django_field(field): enum - for field, enum in ( - (GroupMembership.role, GroupMembershipRoleEnum), - ) -} +enum_map = {get_enum_name_from_django_field(field): enum for field, enum in ((GroupMembership.role, GroupMembershipRoleEnum),)} diff --git a/apps/user_group/factories.py b/apps/user_group/factories.py index b523809a10..b88ab9bdf0 100644 --- a/apps/user_group/factories.py +++ b/apps/user_group/factories.py @@ -9,11 +9,11 @@ class UserGroupFactory(DjangoModelFactory): class Meta: model = UserGroup - title = factory.Sequence(lambda n: f'Group-{n}') + title = factory.Sequence(lambda n: f"Group-{n}") description = fuzzy.FuzzyText(length=15) - display_picture = factory.SubFactory('gallery.factories.FileFactory') - global_crisis_monitoring = factory.Faker('pybool') - custom_project_fields = factory.Dict({'custom-field': 'custom-value'}) + display_picture = factory.SubFactory("gallery.factories.FileFactory") + global_crisis_monitoring = factory.Faker("pybool") + custom_project_fields = factory.Dict({"custom-field": "custom-value"}) @factory.post_generation def members(self, create, extracted, **kwargs): diff --git a/apps/user_group/filters.py b/apps/user_group/filters.py index 2e567a2c63..c3bf7cca47 100644 --- a/apps/user_group/filters.py +++ b/apps/user_group/filters.py @@ -1,21 +1,18 @@ import django_filters from django.db import models -from utils.graphene.filters import IDFilter from utils.db.functions import StrPos +from utils.graphene.filters import IDFilter -from .models import ( - UserGroup, -) +from .models import UserGroup class UserGroupFilterSet(django_filters.FilterSet): - is_current_user_member = django_filters.BooleanFilter( - field_name='is_current_user_member', method='filter_with_membership') + is_current_user_member = django_filters.BooleanFilter(field_name="is_current_user_member", method="filter_with_membership") class Meta: model = UserGroup - fields = ['id'] + fields = ["id"] def filter_with_membership(self, queryset, name, value): if value is not None: @@ -29,9 +26,9 @@ def filter_with_membership(self, queryset, name, value): class UserGroupGQFilterSet(UserGroupFilterSet): - search = django_filters.CharFilter(method='filter_search') - members_include_project = IDFilter(method='filter_include_project') - members_exclude_project = IDFilter(method='filter_exclude_project') + search = django_filters.CharFilter(method="filter_search") + members_include_project = IDFilter(method="filter_include_project") + members_exclude_project = IDFilter(method="filter_exclude_project") class Meta: model = UserGroup @@ -39,19 +36,16 @@ class Meta: def filter_search(self, qs, name, value): if value: - return qs.annotate( - strpos=StrPos( - models.functions.Lower('title'), - models.Value(value.lower(), models.CharField()) - ) - ).filter(strpos__gte=1).order_by('strpos') + return ( + qs.annotate(strpos=StrPos(models.functions.Lower("title"), models.Value(value.lower(), models.CharField()))) + .filter(strpos__gte=1) + .order_by("strpos") + ) return qs def filter_exclude_project(self, qs, name, value): if value: - qs = qs.filter( - ~models.Q(projectusergroupmembership__project_id=value) - ).distinct() + qs = qs.filter(~models.Q(projectusergroupmembership__project_id=value)).distinct() return qs def filter_include_project(self, qs, name, value): diff --git a/apps/user_group/models.py b/apps/user_group/models.py index 22a84be601..950868a6b5 100644 --- a/apps/user_group/models.py +++ b/apps/user_group/models.py @@ -1,6 +1,5 @@ from django.contrib.auth.models import User from django.db import models - from user_resource.models import UserResource @@ -8,17 +7,21 @@ class UserGroup(UserResource): """ User group model """ + title = models.CharField(max_length=255, blank=True) description = models.TextField(blank=True) display_picture = models.ForeignKey( - 'gallery.File', + "gallery.File", on_delete=models.SET_NULL, - null=True, blank=True, default=None, + null=True, + blank=True, + default=None, ) members = models.ManyToManyField( - User, blank=True, - through_fields=('group', 'member'), - through='GroupMembership', + User, + blank=True, + through_fields=("group", "member"), + through="GroupMembership", ) global_crisis_monitoring = models.BooleanField(default=False) @@ -33,17 +36,18 @@ def get_for(user): @classmethod def get_for_gq(cls, user, only_member=False): - qs = cls.objects\ - .annotate( - # NOTE: This is used by permission module - current_user_role=models.Subquery( - GroupMembership.objects.filter( - group=models.OuterRef('pk'), - member=user, - ).order_by('role').values('role')[:1], - output_field=models.CharField() + qs = cls.objects.annotate( + # NOTE: This is used by permission module + current_user_role=models.Subquery( + GroupMembership.objects.filter( + group=models.OuterRef("pk"), + member=user, ) + .order_by("role") + .values("role")[:1], + output_field=models.CharField(), ) + ) if only_member: return qs.exclude(current_user_role__isnull=True) return qs @@ -86,40 +90,41 @@ def add_member(self, user, role=None, added_by=None): role=_role, group=self, defaults={ - 'added_by': added_by or user, + "added_by": added_by or user, }, ) def get_current_user_role(self, user): - return GroupMembership.objects\ - .filter(group=self, member=user)\ - .values_list('role', flat=True).first() + return GroupMembership.objects.filter(group=self, member=user).values_list("role", flat=True).first() class GroupMembership(models.Model): """ User group-Member relationship attributes """ + class Role(models.TextChoices): - NORMAL = 'normal', 'Normal' - ADMIN = 'admin', 'Admin' + NORMAL = "normal", "Normal" + ADMIN = "admin", "Admin" member = models.ForeignKey(User, on_delete=models.CASCADE) group = models.ForeignKey(UserGroup, on_delete=models.CASCADE) role = models.CharField(max_length=96, choices=Role.choices, default=Role.NORMAL) joined_at = models.DateTimeField(auto_now_add=True) added_by = models.ForeignKey( - User, on_delete=models.CASCADE, - null=True, blank=True, default=None, - related_name='added_group_memberships', + User, + on_delete=models.CASCADE, + null=True, + blank=True, + default=None, + related_name="added_group_memberships", ) def __str__(self): - return '{} @ {}'.format(str(self.member), - self.group.title) + return "{} @ {}".format(str(self.member), self.group.title) class Meta: - unique_together = ('member', 'group') + unique_together = ("member", "group") @staticmethod def get_for(user): @@ -133,6 +138,4 @@ def can_modify(self, user): @staticmethod def get_member_for_user_group(user_group): - return GroupMembership.objects.filter( - group=user_group - ).distinct() + return GroupMembership.objects.filter(group=user_group).distinct() diff --git a/apps/user_group/mutation.py b/apps/user_group/mutation.py index 802b407469..92a09bcdfd 100644 --- a/apps/user_group/mutation.py +++ b/apps/user_group/mutation.py @@ -1,32 +1,28 @@ import graphene +from django.core.exceptions import PermissionDenied from graphene_django import DjangoObjectType from graphene_django_extras import DjangoObjectField -from django.core.exceptions import PermissionDenied from deep.permissions import UserGroupPermissions as UgP from utils.graphene.mutation import ( - generate_input_type_for_serializer, GrapheneMutation, + UserGroupBulkGrapheneMutation, UserGroupDeleteMutation, UserGroupGrapheneMutation, - UserGroupBulkGrapheneMutation, -) - -from .models import UserGroup, GroupMembership -from .schema import UserGroupType, GroupMembershipType -from .serializers import ( - UserGroupGqSerializer, - UserGroupMembershipGqlSerializer, + generate_input_type_for_serializer, ) +from .models import GroupMembership, UserGroup +from .schema import GroupMembershipType, UserGroupType +from .serializers import UserGroupGqSerializer, UserGroupMembershipGqlSerializer UserGroupInputType = generate_input_type_for_serializer( - 'UserGroupInputType', + "UserGroupInputType", serializer_class=UserGroupGqSerializer, ) UserGroupMembershipInputType = generate_input_type_for_serializer( - 'UserGroupMembershipInputType', + "UserGroupMembershipInputType", serializer_class=UserGroupMembershipGqlSerializer, ) @@ -34,6 +30,7 @@ class CreateUserGroup(GrapheneMutation): class Arguments: data = UserGroupInputType(required=True) + model = UserGroup serializer_class = UserGroupGqSerializer result = graphene.Field(UserGroupType) @@ -55,11 +52,11 @@ class Arguments: @classmethod def check_permissions(cls, info, **_): if info.context.user != info.context.active_ug.created_by: - raise PermissionDenied('Only creater have permission to update user group') + raise PermissionDenied("Only creater have permission to update user group") @classmethod def perform_mutate(cls, root, info, **kwargs): - kwargs['id'] = info.context.active_ug.id + kwargs["id"] = info.context.active_ug.id return super().perform_mutate(root, info, **kwargs) @@ -71,11 +68,11 @@ class DeleteUserGroup(UserGroupDeleteMutation): @classmethod def check_permissions(cls, info, **_): if info.context.user != info.context.active_ug.created_by: - raise PermissionDenied('Only creater have permission to delete user group') + raise PermissionDenied("Only creater have permission to delete user group") @classmethod def perform_mutate(cls, root, info, **kwargs): - kwargs['id'] = info.context.active_ug.id + kwargs["id"] = info.context.active_ug.id return super().perform_mutate(root, info, **kwargs) @@ -111,6 +108,7 @@ class UserGroupMutationType(DjangoObjectType): """ This mutation is for other scoped objects """ + user_group_update = UpdateUserGroup.Field() user_group_delete = DeleteUserGroup.Field() user_group_membership_bulk = BulkUpdateUserGroupMembership.Field() @@ -118,7 +116,7 @@ class UserGroupMutationType(DjangoObjectType): class Meta: model = UserGroup skip_registry = True - fields = ('id', 'title') + fields = ("id", "title") @staticmethod def get_custom_node(_, info, id): @@ -130,6 +128,6 @@ def get_custom_node(_, info, id): raise PermissionDenied() -class Mutation(): +class Mutation: user_group_create = CreateUserGroup.Field() user_group = DjangoObjectField(UserGroupMutationType) diff --git a/apps/user_group/receivers.py b/apps/user_group/receivers.py index d62314fe8f..ac84a97d10 100644 --- a/apps/user_group/receivers.py +++ b/apps/user_group/receivers.py @@ -1,11 +1,6 @@ -from django.dispatch import receiver from django.db.models.signals import post_save, pre_delete - -from project.models import ( - Project, - ProjectMembership, - ProjectUserGroupMembership, -) +from django.dispatch import receiver +from project.models import Project, ProjectMembership, ProjectUserGroupMembership from user_group.models import GroupMembership @@ -25,7 +20,7 @@ def refresh_group_membership_updated(sender, instance, **kwargs): ) # Create memberships, only if this is not signal from delete - if kwargs.get('delete', False) is False: + if kwargs.get("delete", False) is False: project_members = project.get_all_members() new_users = user_group_members.difference(project_members) for user in new_users: @@ -42,9 +37,7 @@ def refresh_group_membership_updated(sender, instance, **kwargs): ).exclude(member__in=user_group_members) for membership in remove_memberships: - other_user_groups = membership.get_user_group_options().exclude( - id=user_group.id - ) + other_user_groups = membership.get_user_group_options().exclude(id=user_group.id) if other_user_groups.count() > 0: membership.linked_group = other_user_groups.first() membership.save() @@ -70,9 +63,7 @@ def refresh_group_membership_deleted(sender, instance, **kwargs): if not membership: continue - other_user_groups = membership.get_user_group_options().exclude( - id=user_group.id - ) + other_user_groups = membership.get_user_group_options().exclude(id=user_group.id) if other_user_groups.count() > 0: membership.linked_group = other_user_groups.first() membership.save() diff --git a/apps/user_group/schema.py b/apps/user_group/schema.py index 1789449745..5a42170952 100644 --- a/apps/user_group/schema.py +++ b/apps/user_group/schema.py @@ -1,46 +1,48 @@ import graphene - -from graphene_django import DjangoObjectType, DjangoListField +from graphene_django import DjangoListField, DjangoObjectType from graphene_django_extras import DjangoObjectField, PageGraphqlPagination from utils.graphene.enums import EnumDescription -from utils.graphene.types import CustomDjangoListObjectType, ClientIdMixin from utils.graphene.fields import DjangoPaginatedListObjectField +from utils.graphene.types import ClientIdMixin, CustomDjangoListObjectType -from .models import UserGroup, GroupMembership -from .filters import UserGroupGQFilterSet from .enums import GroupMembershipRoleEnum +from .filters import UserGroupGQFilterSet +from .models import GroupMembership, UserGroup class GroupMembershipType(ClientIdMixin, DjangoObjectType): class Meta: model = GroupMembership only_fields = ( - 'id', 'member', 'joined_at', 'added_by', + "id", + "member", + "joined_at", + "added_by", ) role = graphene.Field(GroupMembershipRoleEnum, required=True) - role_display = EnumDescription(source='get_role_display', required=True) + role_display = EnumDescription(source="get_role_display", required=True) class UserGroupType(DjangoObjectType): class Meta: model = UserGroup only_fields = ( - 'id', - 'title', - 'description', - 'created_at', - 'created_by', - 'modified_at', - 'modified_by', - 'client_id', - 'custom_project_fields', - 'global_crisis_monitoring', + "id", + "title", + "description", + "created_at", + "created_by", + "modified_at", + "modified_by", + "client_id", + "custom_project_fields", + "global_crisis_monitoring", ) current_user_role = graphene.Field(GroupMembershipRoleEnum) - current_user_role_display = EnumDescription(source='get_current_user_role_display') + current_user_role_display = EnumDescription(source="get_current_user_role_display") memberships_count = graphene.Int(required=True) memberships = DjangoListField(GroupMembershipType) @@ -67,8 +69,5 @@ class Meta: class Query: user_group = DjangoObjectField(UserGroupType) user_groups = DjangoPaginatedListObjectField( - UserGroupListType, - pagination=PageGraphqlPagination( - page_size_query_param='pageSize' - ) + UserGroupListType, pagination=PageGraphqlPagination(page_size_query_param="pageSize") ) diff --git a/apps/user_group/serializers.py b/apps/user_group/serializers.py index 78081959b9..715e5a0599 100644 --- a/apps/user_group/serializers.py +++ b/apps/user_group/serializers.py @@ -1,55 +1,45 @@ from django.utils.functional import cached_property from drf_dynamic_fields import DynamicFieldsMixin from rest_framework import serializers - -from deep.serializers import RemoveNullFieldsMixin, TempClientIdMixin, IntegerIDField -from user_group.models import UserGroup, GroupMembership +from user_group.models import GroupMembership, UserGroup from user_resource.serializers import UserResourceSerializer +from deep.serializers import IntegerIDField, RemoveNullFieldsMixin, TempClientIdMixin + class SimpleUserGroupSerializer(RemoveNullFieldsMixin, serializers.ModelSerializer): class Meta: model = UserGroup - fields = ('id', 'title') + fields = ("id", "title") -class GroupMembershipSerializer( - RemoveNullFieldsMixin, - DynamicFieldsMixin, - serializers.ModelSerializer -): - member_email = serializers.CharField(source='member.email', read_only=True) +class GroupMembershipSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): + member_email = serializers.CharField(source="member.email", read_only=True) member_name = serializers.SerializerMethodField() class Meta: model = GroupMembership - fields = ('id', 'member', 'member_name', 'member_email', - 'group', 'role', 'joined_at') + fields = ("id", "member", "member_name", "member_email", "group", "role", "joined_at") def get_member_name(self, membership): return membership.member.profile.get_display_name() # Validations def validate_group(self, group): - if not group.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid user group') + if not group.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid user group") return group def create(self, validated_data): - resource = super()\ - .create(validated_data) - resource.added_by = self.context['request'].user + resource = super().create(validated_data) + resource.added_by = self.context["request"].user resource.save() return resource -class UserGroupSerializer( - RemoveNullFieldsMixin, - DynamicFieldsMixin, - UserResourceSerializer -): +class UserGroupSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): memberships = GroupMembershipSerializer( - source='groupmembership_set', + source="groupmembership_set", many=True, required=False, ) @@ -59,89 +49,84 @@ class UserGroupSerializer( class Meta: model = UserGroup fields = ( - 'id', 'title', 'description', 'display_picture', 'role', - 'memberships', 'global_crisis_monitoring', - 'custom_project_fields', 'created_at', 'modified_at', - 'created_by', 'modified_by', 'members_count' + "id", + "title", + "description", + "display_picture", + "role", + "memberships", + "global_crisis_monitoring", + "custom_project_fields", + "created_at", + "modified_at", + "created_by", + "modified_by", + "members_count", ) def create(self, validated_data): user_group = super().create(validated_data) - GroupMembership.objects.create( - group=user_group, - member=self.context['request'].user, - role='admin' - ) + GroupMembership.objects.create(group=user_group, member=self.context["request"].user, role="admin") return user_group def get_role(self, user_group): - request = self.context['request'] - user = request.GET.get('user', request.user) + request = self.context["request"] + user = request.GET.get("user", request.user) - membership = GroupMembership.objects.filter( - group=user_group, - member=user - ).first() + membership = GroupMembership.objects.filter(group=user_group, member=user).first() if membership: return membership.role - return 'null' + return "null" # ------------------------ Graphql mutation serializers ------------------------------- -class GroupMembershipGqSerializer( - RemoveNullFieldsMixin, - DynamicFieldsMixin, - serializers.ModelSerializer -): +class GroupMembershipGqSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, serializers.ModelSerializer): class Meta: model = GroupMembership - fields = ('id', 'member', 'role',) + fields = ( + "id", + "member", + "role", + ) # Validations def validate_group(self, group): # TODO: Use permission check in mutation - if not group.can_modify(self.context['request'].user): - raise serializers.ValidationError('Invalid user group') + if not group.can_modify(self.context["request"].user): + raise serializers.ValidationError("Invalid user group") return group def create(self, validated_data): resource = super().create(validated_data) - resource.added_by = self.context['request'].user + resource.added_by = self.context["request"].user resource.save() return resource -class UserGroupGqSerializer( - RemoveNullFieldsMixin, - DynamicFieldsMixin, - UserResourceSerializer -): +class UserGroupGqSerializer(RemoveNullFieldsMixin, DynamicFieldsMixin, UserResourceSerializer): class Meta: model = UserGroup fields = ( - 'id', 'title', 'description', 'display_picture', 'global_crisis_monitoring', 'custom_project_fields', + "id", + "title", + "description", + "display_picture", + "global_crisis_monitoring", + "custom_project_fields", ) def create(self, validated_data): user_group = super().create(validated_data) - GroupMembership.objects.create( - group=user_group, - member=self.context['request'].user, - role='admin' - ) + GroupMembership.objects.create(group=user_group, member=self.context["request"].user, role="admin") return user_group def update(self, instance, validated_data): user_group = super().update(instance, validated_data) # FIXME: Adding created_by as admin if removed after update if user_group.created_by and not user_group.members.filter(pk=user_group.created_by_id).exists(): - GroupMembership.objects.create( - group=user_group, - member=user_group.created_by, - role='admin' - ) + GroupMembership.objects.create(group=user_group, member=user_group.created_by, role="admin") return user_group @@ -150,29 +135,24 @@ class UserGroupMembershipGqlSerializer(TempClientIdMixin, serializers.ModelSeria class Meta: model = GroupMembership - fields = ( - 'id', - 'member', - 'role', - 'client_id' - ) + fields = ("id", "member", "role", "client_id") @cached_property def usergroup(self): - usergroup = self.context['request'].active_ug + usergroup = self.context["request"].active_ug # This is a rare case, just to make sure this is validated if self.instance and self.instance.group != usergroup: - raise serializers.ValidationError('Invalid access') + raise serializers.ValidationError("Invalid access") return usergroup def validate_member(self, member): current_members = GroupMembership.objects.filter(group=self.usergroup, member=member) if current_members.exclude(pk=self.instance and self.instance.pk).exists(): - raise serializers.ValidationError('User is already a member!') + raise serializers.ValidationError("User is already a member!") return member def create(self, validated_data): # make request user to be added_by by default - validated_data['added_by'] = self.context['request'].user - validated_data['group'] = self.usergroup + validated_data["added_by"] = self.context["request"].user + validated_data["group"] = self.usergroup return super().create(validated_data) diff --git a/apps/user_group/tests/test_apis.py b/apps/user_group/tests/test_apis.py index a83ddc36a4..8f6b20f9d2 100644 --- a/apps/user_group/tests/test_apis.py +++ b/apps/user_group/tests/test_apis.py @@ -1,20 +1,17 @@ -from deep.tests import TestCase +from project.models import Project, ProjectMembership, ProjectUserGroupMembership from user.models import User -from user_group.models import UserGroup, GroupMembership -from project.models import ( - Project, - ProjectMembership, - ProjectUserGroupMembership -) +from user_group.models import GroupMembership, UserGroup + +from deep.tests import TestCase class UserGroupApiTest(TestCase): def test_create_user_group(self): user_group_count = UserGroup.objects.count() - url = '/api/v1/user-groups/' + url = "/api/v1/user-groups/" data = { - 'title': 'Test user group', + "title": "Test user group", } self.authenticate() @@ -22,131 +19,108 @@ def test_create_user_group(self): self.assert_201(response) self.assertEqual(UserGroup.objects.count(), user_group_count + 1) - self.assertEqual(response.data['title'], data['title']) + self.assertEqual(response.data["title"], data["title"]) # Test that the user has been made admin - self.assertEqual(len(response.data['memberships']), 1) - self.assertEqual(response.data['memberships'][0]['member'], - self.user.pk) + self.assertEqual(len(response.data["memberships"]), 1) + self.assertEqual(response.data["memberships"][0]["member"], self.user.pk) - membership = GroupMembership.objects.get( - pk=response.data['memberships'][0]['id']) + membership = GroupMembership.objects.get(pk=response.data["memberships"][0]["id"]) self.assertEqual(membership.member.pk, self.user.pk) - self.assertEqual(membership.role, 'admin') + self.assertEqual(membership.role, "admin") def test_usergoup_fields(self): user = self.create(User) - self.create( - UserGroup, - role='admin', - created_by=user - ) - url = '/api/v1/user-groups/' + self.create(UserGroup, role="admin", created_by=user) + url = "/api/v1/user-groups/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertIn('created_at', response.data['results'][0]) - self.assertEqual(response.data['results'][0]['created_by'], user.id) + self.assertIn("created_at", response.data["results"][0]) + self.assertEqual(response.data["results"][0]["created_by"], user.id) def test_member_of(self): - user_group1 = self.create(UserGroup, role='admin') - user_group2 = self.create(UserGroup, role='admin') + user_group1 = self.create(UserGroup, role="admin") + user_group2 = self.create(UserGroup, role="admin") test_user1 = self.create(User) test_user2 = self.create(User) GroupMembership.objects.create(member=test_user1, group=user_group1) GroupMembership.objects.create(member=test_user2, group=user_group1) GroupMembership.objects.create(member=test_user2, group=user_group2) - url = '/api/v1/user-groups/member-of/' + url = "/api/v1/user-groups/member-of/" self.authenticate() response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 2) + self.assertEqual(response.data["count"], 2) self.assertEqual( - set([user_group['id'] for user_group in response.data['results']]), - set([user_group1.id, user_group2.id]) + set([user_group["id"] for user_group in response.data["results"]]), set([user_group1.id, user_group2.id]) ) - url = '/api/v1/user-groups/member-of/?user={}'.format(test_user1.id) + url = "/api/v1/user-groups/member-of/?user={}".format(test_user1.id) response = self.client.get(url) self.assert_200(response) - self.assertEqual(response.data['count'], 1) - self.assertEqual(response.data['results'][0]['id'], user_group1.id) + self.assertEqual(response.data["count"], 1) + self.assertEqual(response.data["results"][0]["id"], user_group1.id) # check for the member count # NOTE: count 3 since a member is created whenever a usergroup is created - self.assertEqual(response.data['results'][0]['members_count'], 3) + self.assertEqual(response.data["results"][0]["members_count"], 3) def test_search_user_group_without_exclusion(self): project = self.create(Project) user_group1 = self.create(UserGroup, title="MyTestUserGroup") user_group2 = self.create(UserGroup, title="MyUserTestGroup") - url = '/api/v1/user-groups/?search=test' + url = "/api/v1/user-groups/?search=test" - ProjectUserGroupMembership.objects.create( - project=project, - usergroup=user_group1 - ) + ProjectUserGroupMembership.objects.create(project=project, usergroup=user_group1) self.authenticate() response = self.client.get(url) self.assert_200(response) data = response.json() - assert data['count'] == 2 - assert data['results'][0]['id'] == user_group1.id, \ - "'MyTestUserGroup' matches more to search query 'test'" - assert data['results'][1]['id'] == user_group2.id + assert data["count"] == 2 + assert data["results"][0]["id"] == user_group1.id, "'MyTestUserGroup' matches more to search query 'test'" + assert data["results"][1]["id"] == user_group2.id def test_search_user_group_with_exclusion(self): project = self.create(Project) user_group1 = self.create(UserGroup, title="MyTestUserGroup") user_group2 = self.create(UserGroup, title="MyUserTestGroup") - url = '/api/v1/user-groups/?search=test&members_exclude_project=' \ - + str(project.id) + url = "/api/v1/user-groups/?search=test&members_exclude_project=" + str(project.id) - ProjectUserGroupMembership.objects.create( - project=project, - usergroup=user_group1 - ) + ProjectUserGroupMembership.objects.create(project=project, usergroup=user_group1) self.authenticate() response = self.client.get(url) self.assert_200(response) data = response.json() - assert data['count'] == 1, "user group 1 is added to project" - assert data['results'][0]['id'] == user_group2.id + assert data["count"] == 1, "user group 1 is added to project" + assert data["results"][0]["id"] == user_group2.id def test_add_member(self): # check if project membership changes or not - project = self.create( - Project, - user_groups=[], - title='TestProject', - role=self.admin_role - ) - user_group = self.create(UserGroup, role='admin') + project = self.create(Project, user_groups=[], title="TestProject", role=self.admin_role) + user_group = self.create(UserGroup, role="admin") test_user = self.create(User) - ProjectUserGroupMembership.objects.create( - usergroup=user_group, - project=project - ) + ProjectUserGroupMembership.objects.create(usergroup=user_group, project=project) memberships = ProjectMembership.objects.filter(project=project) initial_member_count = memberships.count() - url = '/api/v1/group-memberships/' + url = "/api/v1/group-memberships/" data = { - 'member': test_user.pk, - 'group': user_group.pk, - 'role': 'normal', + "member": test_user.pk, + "group": user_group.pk, + "role": "normal", } self.authenticate() response = self.client.post(url, data) self.assert_201(response) - self.assertEqual(response.data['role'], data['role']) - self.assertEqual(response.data['member'], data['member']) - self.assertEqual(response.data['group'], data['group']) + self.assertEqual(response.data["role"], data["role"]) + self.assertEqual(response.data["member"], data["member"]) + self.assertEqual(response.data["group"], data["group"]) # check for project memberships final_memberships = ProjectMembership.objects.filter(project=project) @@ -164,13 +138,10 @@ def test_delete_user_group(self): self.assertEqual(mems_count, 1) # create user group - user_group = self.create(UserGroup, role='admin') + user_group = self.create(UserGroup, role="admin") # add usergroup to project - ProjectUserGroupMembership.objects.create( - usergroup=user_group, - project=project - ) + ProjectUserGroupMembership.objects.create(usergroup=user_group, project=project) test_user = self.create(User) gm = GroupMembership.objects.create(member=test_user, group=user_group) @@ -179,7 +150,7 @@ def test_delete_user_group(self): self.assertEqual(mems_count, 2) # we added a user # now delete - url = '/api/v1/group-memberships/{}/'.format(gm.id) + url = "/api/v1/group-memberships/{}/".format(gm.id) self.authenticate() response = self.client.delete(url) @@ -199,18 +170,15 @@ def test_user_group_memberships(self): GroupMembership.objects.create(member=test_user2, group=user_group1) GroupMembership.objects.create(member=test_user1, group=user_group2) - url = f'/api/v1/user-groups/{user_group1.id}/memberships/' + url = f"/api/v1/user-groups/{user_group1.id}/memberships/" self.authenticate(test_user1) response = self.client.get(url) self.assert_200(response) - self.assertEqual(len(response.data['results']), 2) - self.assertEqual( - set(members['member'] for members in response.data['results']), - set([test_user1.id, test_user2.id]) - ) + self.assertEqual(len(response.data["results"]), 2) + self.assertEqual(set(members["member"] for members in response.data["results"]), set([test_user1.id, test_user2.id])) # request by user who is not member of usergroup - url = f'/api/v1/user-groups/{user_group2.id}/memberships/' + url = f"/api/v1/user-groups/{user_group2.id}/memberships/" self.authenticate(test_user2) response = self.client.get(url) self.assert_403(response) diff --git a/apps/user_group/tests/test_mutations.py b/apps/user_group/tests/test_mutations.py index 95c9a531f1..183efea1cc 100644 --- a/apps/user_group/tests/test_mutations.py +++ b/apps/user_group/tests/test_mutations.py @@ -1,16 +1,15 @@ -from utils.graphene.tests import GraphQLSnapShotTestCase - -from user_group.models import GroupMembership - from user.factories import UserFactory from user_group.factories import UserGroupFactory +from user_group.models import GroupMembership + +from utils.graphene.tests import GraphQLSnapShotTestCase class TestUserGroupMutationSnapShotTestCase(GraphQLSnapShotTestCase): factories_used = [UserGroupFactory] def test_usergroup_membership_bulk(self): - query = ''' + query = """ mutation MyMutation( $id: ID!, $ugMembership: [BulkUserGroupMembershipInputType!]!, @@ -53,7 +52,7 @@ def test_usergroup_membership_bulk(self): } } } - ''' + """ creater_user = UserFactory.create() user = UserFactory.create() low_permission_user = UserFactory.create() @@ -111,10 +110,11 @@ def test_usergroup_membership_bulk(self): def _query_check(**kwargs): return self.query_check( query, - mnested=['userGroup'], - variables={'id': ug.id, **minput}, + mnested=["userGroup"], + variables={"id": ug.id, **minput}, **kwargs, ) + # ---------- Without login _query_check(assert_for_error=True) # ---------- With login (with non-member) @@ -126,9 +126,9 @@ def _query_check(**kwargs): # ---------- With login (with higher permission) self.force_login(user) # ----------------- Some Invalid input - response = _query_check()['data']['userGroup']['userGroupMembershipBulk'] - self.assertMatchSnapshot(response, 'try 1') + response = _query_check()["data"]["userGroup"]["userGroupMembershipBulk"] + self.assertMatchSnapshot(response, "try 1") # ----------------- All valid input - minput['ugMembership'].pop(1) - response = _query_check()['data']['userGroup']['userGroupMembershipBulk'] - self.assertMatchSnapshot(response, 'try 2') + minput["ugMembership"].pop(1) + response = _query_check()["data"]["userGroup"]["userGroupMembershipBulk"] + self.assertMatchSnapshot(response, "try 2") diff --git a/apps/user_group/tests/test_schemas.py b/apps/user_group/tests/test_schemas.py index 532e66696a..21651b2ec8 100644 --- a/apps/user_group/tests/test_schemas.py +++ b/apps/user_group/tests/test_schemas.py @@ -1,14 +1,14 @@ -from utils.graphene.tests import GraphQLTestCase - from user.factories import UserFactory from user_group.factories import UserGroupFactory -from user_group.models import UserGroup, GroupMembership +from user_group.models import GroupMembership, UserGroup + +from utils.graphene.tests import GraphQLTestCase class TestUserGroupSchema(GraphQLTestCase): def test_user_groups_query(self): # Try with random user - query = ''' + query = """ query Query { userGroups(ordering: "id") { results { @@ -40,7 +40,7 @@ def test_user_groups_query(self): } } } - ''' + """ # Without login, throw error self.query_check(query, assert_for_error=True) @@ -51,7 +51,7 @@ def test_user_groups_query(self): # with login, return empty list content = self.query_check(query) - self.assertEqual(len(content['data']['userGroups']['results']), 0, content) + self.assertEqual(len(content["data"]["userGroups"]["results"]), 0, content) # -- Create new user groups w/wo user as member # Try with real user @@ -60,24 +60,26 @@ def test_user_groups_query(self): ug_with_admin_membership.add_member(user, role=GroupMembership.Role.ADMIN) ug_without_membership = UserGroupFactory.create(members=[another_user]) - results = self.query_check(query)['data']['userGroups']['results'] + results = self.query_check(query)["data"]["userGroups"]["results"] self.assertEqual(len(results), 3, results) - for index, (user_group, memberships_count, real_memberships_count, current_user_role) in enumerate([ - # as normal member - (ug_with_membership, 2, 2, self.genum(GroupMembership.Role.NORMAL)), - # as admin member - (ug_with_admin_membership, 1, 1, self.genum(GroupMembership.Role.ADMIN)), - # as non member - (ug_without_membership, 0, 1, None), - ]): - self.assertEqual(results[index]['id'], str(user_group.pk), results[index]) - self.assertEqual(len(results[index]['memberships']), memberships_count, results[index]) - self.assertEqual(results[index]['membershipsCount'], real_memberships_count, results[index]) - self.assertEqual(results[index]['currentUserRole'], current_user_role, results[index]) + for index, (user_group, memberships_count, real_memberships_count, current_user_role) in enumerate( + [ + # as normal member + (ug_with_membership, 2, 2, self.genum(GroupMembership.Role.NORMAL)), + # as admin member + (ug_with_admin_membership, 1, 1, self.genum(GroupMembership.Role.ADMIN)), + # as non member + (ug_without_membership, 0, 1, None), + ] + ): + self.assertEqual(results[index]["id"], str(user_group.pk), results[index]) + self.assertEqual(len(results[index]["memberships"]), memberships_count, results[index]) + self.assertEqual(results[index]["membershipsCount"], real_memberships_count, results[index]) + self.assertEqual(results[index]["currentUserRole"], current_user_role, results[index]) def test_user_group_query(self): # Try with random user - query = ''' + query = """ query Query($id: ID!) { userGroup(id: $id) { id @@ -107,32 +109,32 @@ def test_user_group_query(self): } } } - ''' + """ another_user = UserFactory.create() ug_without_membership = UserGroupFactory.create(members=[another_user]) # Without login, throw error - self.query_check(query, assert_for_error=True, variables={'id': str(ug_without_membership.pk)}) + self.query_check(query, assert_for_error=True, variables={"id": str(ug_without_membership.pk)}) # -- Create new user and login -- user = UserFactory.create() self.force_login(user) # with login, non-member usergroup will give zero members but membershipsCount 1 - content = self.query_check(query, variables={'id': str(ug_without_membership.pk)}) - self.assertEqual(content['data']['userGroup']['membershipsCount'], 1, content) - self.assertEqual(len(content['data']['userGroup']['memberships']), 0, content) + content = self.query_check(query, variables={"id": str(ug_without_membership.pk)}) + self.assertEqual(content["data"]["userGroup"]["membershipsCount"], 1, content) + self.assertEqual(len(content["data"]["userGroup"]["memberships"]), 0, content) # -- Create new user groups w/wo user as member # with login, non-member usergroup will give real members ug_with_membership = UserGroupFactory.create(members=[user, another_user]) - content = self.query_check(query, variables={'id': str(ug_with_membership.pk)}) - self.assertEqual(content['data']['userGroup']['membershipsCount'], 2, content) - self.assertEqual(len(content['data']['userGroup']['memberships']), 2, content) - self.assertEqual(content['data']['userGroup']['currentUserRole'], self.genum(GroupMembership.Role.NORMAL), content) + content = self.query_check(query, variables={"id": str(ug_with_membership.pk)}) + self.assertEqual(content["data"]["userGroup"]["membershipsCount"], 2, content) + self.assertEqual(len(content["data"]["userGroup"]["memberships"]), 2, content) + self.assertEqual(content["data"]["userGroup"]["currentUserRole"], self.genum(GroupMembership.Role.NORMAL), content) def test_user_group_create_mutation(self): - query = ''' + query = """ mutation MyMutation($input: UserGroupInputType!) { userGroupCreate(data: $input) { ok @@ -165,21 +167,21 @@ def test_user_group_create_mutation(self): } } } - ''' + """ user = UserFactory.create() - minput = dict(title='New user group from mutation') + minput = dict(title="New user group from mutation") self.query_check(query, minput=minput, assert_for_error=True) self.force_login(user) # TODO: Add permission check # content = self.query_check(query, minput=minput, okay=False) # Response with new user group - result = self.query_check(query, minput=minput, okay=True)['data']['userGroupCreate']['result'] - self.assertEqual(result['title'], minput['title'], result) - self.assertEqual(result['membershipsCount'], 1, result) + result = self.query_check(query, minput=minput, okay=True)["data"]["userGroupCreate"]["result"] + self.assertEqual(result["title"], minput["title"], result) + self.assertEqual(result["membershipsCount"], 1, result) def test_user_group_update_mutation(self): - query = ''' + query = """ mutation MyMutation($input: UserGroupInputType! $id: ID!) { userGroup(id: $id) { userGroupUpdate(data: $input) { @@ -214,46 +216,40 @@ def test_user_group_update_mutation(self): } } } - ''' + """ user = UserFactory.create() member_user = UserFactory.create() guest_user = UserFactory.create() - ug = UserGroupFactory.create(title='User-Group 101', members=[member_user], created_by=user) + ug = UserGroupFactory.create(title="User-Group 101", members=[member_user], created_by=user) ug.add_member(user, role=GroupMembership.Role.ADMIN) minput = dict( - title='User-Group 101 (Updated)', + title="User-Group 101 (Updated)", ) - self.query_check(query, minput=minput, assert_for_error=True, variables={'id': str(ug.pk)}) + self.query_check(query, minput=minput, assert_for_error=True, variables={"id": str(ug.pk)}) for _user in [guest_user, member_user]: self.force_login(_user) - self.query_check( - query, minput=minput, assert_for_error=True, variables={'id': str(ug.pk)} - ) + self.query_check(query, minput=minput, assert_for_error=True, variables={"id": str(ug.pk)}) self.force_login(user) - result = self.query_check( - query, minput=minput, okay=True, mnested=['userGroup'], variables={'id': str(ug.pk)} - )['data']['userGroup']['userGroupUpdate']['result'] + result = self.query_check(query, minput=minput, okay=True, mnested=["userGroup"], variables={"id": str(ug.pk)})["data"][ + "userGroup" + ]["userGroupUpdate"]["result"] - self.assertEqual(result['title'], minput['title'], result) - self.assertEqual(result['membershipsCount'], 2, result) - self.assertEqual(len(result['memberships']), 2, result) + self.assertEqual(result["title"], minput["title"], result) + self.assertEqual(result["membershipsCount"], 2, result) + self.assertEqual(len(result["memberships"]), 2, result) - result = self.query_check( - query, - minput=minput, - okay=True, - mnested=['userGroup'], - variables={'id': str(ug.pk)} - )['data']['userGroup']['userGroupUpdate']['result'] - self.assertEqual(result['membershipsCount'], 2, result) - self.assertEqual(len(result['memberships']), 2, result) - self.assertEqual(result['memberships'][1]['member']['id'], str(user.pk), result) - self.assertEqual(result['memberships'][1]['role'], self.genum(GroupMembership.Role.ADMIN), result) + result = self.query_check(query, minput=minput, okay=True, mnested=["userGroup"], variables={"id": str(ug.pk)})["data"][ + "userGroup" + ]["userGroupUpdate"]["result"] + self.assertEqual(result["membershipsCount"], 2, result) + self.assertEqual(len(result["memberships"]), 2, result) + self.assertEqual(result["memberships"][1]["member"]["id"], str(user.pk), result) + self.assertEqual(result["memberships"][1]["role"], self.genum(GroupMembership.Role.ADMIN), result) def test_user_group_delete_mutation(self): - query = ''' + query = """ mutation MyMutation($id: ID!) { userGroup(id: $id) { userGroupDelete { @@ -266,28 +262,32 @@ def test_user_group_delete_mutation(self): } } } - ''' + """ user = UserFactory.create() guest_user = UserFactory.create() member_user = UserFactory.create() - ug = UserGroupFactory.create(title='User-Group 101', created_by=user) + ug = UserGroupFactory.create(title="User-Group 101", created_by=user) ug.add_member(user, role=GroupMembership.Role.ADMIN) ug.add_member(member_user, role=GroupMembership.Role.ADMIN) - self.query_check(query, assert_for_error=True, variables={'id': str(ug.pk)}) + self.query_check(query, assert_for_error=True, variables={"id": str(ug.pk)}) for _user in [guest_user, member_user]: self.force_login(_user) - self.query_check(query, assert_for_error=True, variables={'id': str(ug.pk)}) + self.query_check(query, assert_for_error=True, variables={"id": str(ug.pk)}) self.force_login(user) result = self.query_check( query, okay=True, - mnested=['userGroup'], - variables={'id': str(ug.pk)}, - )['data']['userGroup']['userGroupDelete']['result'] + mnested=["userGroup"], + variables={"id": str(ug.pk)}, + )[ + "data" + ]["userGroup"][ + "userGroupDelete" + ]["result"] - self.assertEqual(result['id'], str(ug.id), result) - self.assertEqual(result['title'], ug.title, result) + self.assertEqual(result["id"], str(ug.id), result) + self.assertEqual(result["title"], ug.title, result) with self.assertRaises(UserGroup.DoesNotExist): # Make sure user_group doesn't exists anymore ug.refresh_from_db() diff --git a/apps/user_group/views.py b/apps/user_group/views.py index 7897a4836d..99847b9344 100644 --- a/apps/user_group/views.py +++ b/apps/user_group/views.py @@ -1,52 +1,41 @@ from django.db import models - -from rest_framework import ( - permissions, - viewsets, -) +from rest_framework import permissions, viewsets from rest_framework.decorators import action -from deep.permissions import ( - ModifyPermission, - IsUserGroupMember -) +from deep.permissions import IsUserGroupMember, ModifyPermission from utils.db.functions import StrPos -from .models import ( - GroupMembership, - UserGroup, -) -from .serializers import ( - GroupMembershipSerializer, - UserGroupSerializer, -) + +from .models import GroupMembership, UserGroup +from .serializers import GroupMembershipSerializer, UserGroupSerializer class UserGroupViewSet(viewsets.ModelViewSet): serializer_class = UserGroupSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_queryset(self): return UserGroup.get_for(self.request.user) @action( - detail=False, permission_classes=[permissions.IsAuthenticated], + detail=False, + permission_classes=[permissions.IsAuthenticated], serializer_class=UserGroupSerializer, - url_path='member-of', + url_path="member-of", ) def get_for_member(self, request, version=None): - user = self.request.GET.get('user', self.request.user) + user = self.request.GET.get("user", self.request.user) user_groups = UserGroup.get_for_member(user).annotate( members_count=models.functions.Coalesce( models.Subquery( - GroupMembership.objects.filter( - group=models.OuterRef('pk') - ).order_by().values('group').annotate(count=models.Count( - 'member', - distinct=True)) - .values('count')[:1], + GroupMembership.objects.filter(group=models.OuterRef("pk")) + .order_by() + .values("group") + .annotate(count=models.Count("member", distinct=True)) + .values("count")[:1], output_field=models.IntegerField(), - ), 0) + ), + 0, + ) ) self.page = self.paginate_queryset(user_groups) serializer = self.get_serializer(self.page, many=True) @@ -56,7 +45,7 @@ def get_for_member(self, request, version=None): detail=True, permission_classes=[permissions.IsAuthenticated, IsUserGroupMember], serializer_class=GroupMembershipSerializer, - url_path='memberships', + url_path="memberships", ) def get_usergroup_member(self, request, pk, version=None): user_group = self.get_object() @@ -69,36 +58,33 @@ def filter_queryset(self, queryset): queryset = super().filter_queryset(queryset) # Check if project exclusion query is present - exclude_project = self.request.query_params.get( - 'members_exclude_project') + exclude_project = self.request.query_params.get("members_exclude_project") if exclude_project: - queryset = queryset.filter( - ~models.Q(projectusergroupmembership__project=exclude_project) - ).distinct() + queryset = queryset.filter(~models.Q(projectusergroupmembership__project=exclude_project)).distinct() - search_str = self.request.query_params.get('search') + search_str = self.request.query_params.get("search") if search_str is None or not search_str.strip(): return queryset - return queryset.annotate( - strpos=StrPos( - models.functions.Lower('title'), - models.Value(search_str.lower(), models.CharField()) + return ( + queryset.annotate( + strpos=StrPos(models.functions.Lower("title"), models.Value(search_str.lower(), models.CharField())) ) - ).filter(strpos__gte=1).order_by('strpos') + .filter(strpos__gte=1) + .order_by("strpos") + ) class GroupMembershipViewSet(viewsets.ModelViewSet): serializer_class = GroupMembershipSerializer - permission_classes = [permissions.IsAuthenticated, - ModifyPermission] + permission_classes = [permissions.IsAuthenticated, ModifyPermission] def get_serializer(self, *args, **kwargs): - data = kwargs.get('data') - list = data and data.get('list') + data = kwargs.get("data") + list = data and data.get("list") if list: - kwargs.pop('data') - kwargs.pop('many', None) + kwargs.pop("data") + kwargs.pop("many", None) return super().get_serializer( data=list, many=True, @@ -111,13 +97,15 @@ def get_serializer(self, *args, **kwargs): ) def finalize_response(self, request, response, *args, **kwargs): - if request.method == 'POST' and isinstance(response.data, list): + if request.method == "POST" and isinstance(response.data, list): response.data = { - 'results': response.data, + "results": response.data, } return super().finalize_response( - request, response, - *args, **kwargs, + request, + response, + *args, + **kwargs, ) def get_queryset(self): diff --git a/apps/user_resource/apps.py b/apps/user_resource/apps.py index 034622d925..a6fb0c14bf 100644 --- a/apps/user_resource/apps.py +++ b/apps/user_resource/apps.py @@ -2,4 +2,4 @@ class UserResourceConfig(AppConfig): - name = 'user_resource' + name = "user_resource" diff --git a/apps/user_resource/filters.py b/apps/user_resource/filters.py index b3b5c8876b..3a5b14a368 100644 --- a/apps/user_resource/filters.py +++ b/apps/user_resource/filters.py @@ -2,46 +2,28 @@ from django.contrib.auth.models import User from utils.graphene.filters import ( - IDListFilter, DateTimeFilter, DateTimeGteFilter, DateTimeLteFilter, + IDListFilter, ) class UserResourceFilterSet(django_filters.FilterSet): - created_at__lt = django_filters.DateFilter( - field_name='created_at', - lookup_expr='lte', - input_formats=['%Y-%m-%d%z'] - ) - created_at__gte = django_filters.DateFilter( - field_name='created_at', - lookup_expr='gte', - input_formats=['%Y-%m-%d%z'] - ) - modified_at__lt = django_filters.DateFilter( - field_name='modified_at', - lookup_expr='lte', - input_formats=['%Y-%m-%d%z'] - ) - modified_at__gt = django_filters.DateFilter( - field_name='modified_at', - lookup_expr='gte', - input_formats=['%Y-%m-%d%z'] - ) - created_by = django_filters.ModelMultipleChoiceFilter( - queryset=User.objects.all()) - modified_by = django_filters.ModelMultipleChoiceFilter( - queryset=User.objects.all()) + created_at__lt = django_filters.DateFilter(field_name="created_at", lookup_expr="lte", input_formats=["%Y-%m-%d%z"]) + created_at__gte = django_filters.DateFilter(field_name="created_at", lookup_expr="gte", input_formats=["%Y-%m-%d%z"]) + modified_at__lt = django_filters.DateFilter(field_name="modified_at", lookup_expr="lte", input_formats=["%Y-%m-%d%z"]) + modified_at__gt = django_filters.DateFilter(field_name="modified_at", lookup_expr="gte", input_formats=["%Y-%m-%d%z"]) + created_by = django_filters.ModelMultipleChoiceFilter(queryset=User.objects.all()) + modified_by = django_filters.ModelMultipleChoiceFilter(queryset=User.objects.all()) class UserResourceGqlFilterSet(django_filters.FilterSet): created_at = DateTimeFilter() - created_at_gte = DateTimeGteFilter(field_name='created_at') - created_at_lte = DateTimeLteFilter(field_name='created_at') + created_at_gte = DateTimeGteFilter(field_name="created_at") + created_at_lte = DateTimeLteFilter(field_name="created_at") modified_at = DateTimeFilter() - modified_at_gte = DateTimeGteFilter(field_name='modified_at') - modified_at_lte = DateTimeLteFilter(field_name='modified_at') + modified_at_gte = DateTimeGteFilter(field_name="modified_at") + modified_at_lte = DateTimeLteFilter(field_name="modified_at") created_by = IDListFilter() modified_by = IDListFilter() diff --git a/apps/user_resource/models.py b/apps/user_resource/models.py index 2feaff17ce..02adb167ea 100644 --- a/apps/user_resource/models.py +++ b/apps/user_resource/models.py @@ -1,19 +1,21 @@ -from django.db import models from django.contrib.auth.models import User +from django.db import models class UserResourceCreated(models.Model): created_at = models.DateTimeField(auto_now_add=True) created_by = models.ForeignKey( User, - related_name='%(class)s_created', - default=None, blank=True, null=True, + related_name="%(class)s_created", + default=None, + blank=True, + null=True, on_delete=models.SET_NULL, ) class Meta: abstract = True - ordering = ['-created_at'] + ordering = ["-created_at"] class UserResource(models.Model): @@ -21,14 +23,18 @@ class UserResource(models.Model): modified_at = models.DateTimeField(auto_now=True) created_by = models.ForeignKey( User, - related_name='%(class)s_created', - default=None, blank=True, null=True, + related_name="%(class)s_created", + default=None, + blank=True, + null=True, on_delete=models.SET_NULL, ) modified_by = models.ForeignKey( User, - related_name='%(class)s_modified', - default=None, blank=True, null=True, + related_name="%(class)s_modified", + default=None, + blank=True, + null=True, on_delete=models.SET_NULL, ) @@ -36,4 +42,4 @@ class UserResource(models.Model): class Meta: abstract = True - ordering = ['-created_at'] + ordering = ["-created_at"] diff --git a/apps/user_resource/schema.py b/apps/user_resource/schema.py index 0f75bc9429..085c85d984 100644 --- a/apps/user_resource/schema.py +++ b/apps/user_resource/schema.py @@ -1,10 +1,9 @@ import graphene - -from graphql.execution.base import ResolveInfo from django.db import models +from graphql.execution.base import ResolveInfo +from user.schema import UserType from utils.common import has_select_related -from user.schema import UserType class UserResourceMixin(graphene.ObjectType): @@ -14,10 +13,10 @@ class UserResourceMixin(graphene.ObjectType): modified_by = graphene.Field(UserType) def resolve_created_by(root, info, **kwargs): - return resolve_user_field(root, info, 'created_by') + return resolve_user_field(root, info, "created_by") def resolve_modified_by(root, info, **kwargs): - return resolve_user_field(root, info, 'modified_by') + return resolve_user_field(root, info, "modified_by") def resolve_user_field(root: models.Model, info: ResolveInfo, field: str): diff --git a/apps/user_resource/serializers.py b/apps/user_resource/serializers.py index 3f72ce2abe..665bb0f934 100644 --- a/apps/user_resource/serializers.py +++ b/apps/user_resource/serializers.py @@ -1,37 +1,30 @@ +import reversion from drf_writable_nested.serializers import WritableNestedModelSerializer from rest_framework import serializers from reversion.models import Version -import reversion -from deep.writable_nested_serializers import ( - NestedCreateMixin, - NestedUpdateMixin, -) +from deep.writable_nested_serializers import NestedCreateMixin, NestedUpdateMixin class UserResourceBaseSerializer(serializers.Serializer): modified_at = serializers.DateTimeField(read_only=True) modified_by = serializers.PrimaryKeyRelatedField(read_only=True) - created_by_name = serializers.CharField( - source='created_by.profile.get_display_name', - read_only=True) - modified_by_name = serializers.CharField( - source='modified_by.profile.get_display_name', - read_only=True) + created_by_name = serializers.CharField(source="created_by.profile.get_display_name", read_only=True) + modified_by_name = serializers.CharField(source="modified_by.profile.get_display_name", read_only=True) client_id = serializers.CharField(required=False) version_id = serializers.SerializerMethodField() def create(self, validated_data): - if 'created_by' in self.Meta.model._meta._forward_fields_map: - validated_data['created_by'] = self.context['request'].user - if 'modified_by' in self.Meta.model._meta._forward_fields_map: - validated_data['modified_by'] = self.context['request'].user + if "created_by" in self.Meta.model._meta._forward_fields_map: + validated_data["created_by"] = self.context["request"].user + if "modified_by" in self.Meta.model._meta._forward_fields_map: + validated_data["modified_by"] = self.context["request"].user return super().create(validated_data) def update(self, instance, validated_data): - if 'modified_by' in self.Meta.model._meta._forward_fields_map: - validated_data['modified_by'] = self.context['request'].user + if "modified_by" in self.Meta.model._meta._forward_fields_map: + validated_data["modified_by"] = self.context["request"].user return super().update(instance, validated_data) def get_version_id(self, resource): @@ -39,9 +32,9 @@ def get_version_id(self, resource): return None version_id = Version.objects.get_for_object(resource).count() - request = self.context['request'] - if request.method in ['POST', 'PUT', 'PATCH']: - if not (request.method == 'POST' and self.context.get('post_is_used_for_filter', False)): + request = self.context["request"] + if request.method in ["POST", "PUT", "PATCH"]: + if not (request.method == "POST" and self.context.get("post_is_used_for_filter", False)): version_id += 1 return version_id @@ -57,10 +50,7 @@ def _prefetch_related_instances(self, field, related_data): pk_list = self._extract_related_pks(field, related_data) qs = self._get_prefetch_related_instances_qs(model_class.objects) # Modification added - instances = { - str(related_instance.pk): related_instance - for related_instance in qs.filter(pk__in=pk_list) - } + instances = {str(related_instance.pk): related_instance for related_instance in qs.filter(pk__in=pk_list)} return instances @@ -70,7 +60,7 @@ class UserResourceCreatedMixin(serializers.Serializer): created_by = serializers.PrimaryKeyRelatedField(read_only=True) def create(self, validated_data): - validated_data['created_by'] = self.context['request'].user + validated_data["created_by"] = self.context["request"].user return super().create(validated_data) diff --git a/deep/__init__.py b/deep/__init__.py index 1a6c551dd5..600d25e061 100644 --- a/deep/__init__.py +++ b/deep/__init__.py @@ -2,4 +2,4 @@ # Django starts so that shared_task will use this app. from .celery import app as celery_app -__all__ = ['celery_app'] +__all__ = ["celery_app"] diff --git a/deep/admin.py b/deep/admin.py index fa8ec83d4d..d7998bc972 100644 --- a/deep/admin.py +++ b/deep/admin.py @@ -1,42 +1,42 @@ -from django.utils.html import format_html -from django.utils.safestring import mark_safe -from django.urls import reverse -from django.contrib.postgres import fields -from django.contrib import admin -from django.conf import settings from urllib.parse import quote -from reversion.admin import VersionAdmin as _VersionAdmin +from django.conf import settings +from django.contrib import admin +from django.contrib.postgres import fields +from django.urls import reverse +from django.utils.html import format_html +from django.utils.safestring import mark_safe from jsoneditor.forms import JSONEditor as _JSONEditor - +from reversion.admin import VersionAdmin as _VersionAdmin site = admin.site def get_site_string(title): - return f'{title} ({settings.DEEP_ENVIRONMENT.title()})' + return f"{title} ({settings.DEEP_ENVIRONMENT.title()})" # Text to put at the end of each page's . -site.site_title = get_site_string('DEEP site admin') +site.site_title = get_site_string("DEEP site admin") # Text to put in each page's <h1> (and above login form). -site.site_header = get_site_string('DEEP Administration') +site.site_header = get_site_string("DEEP Administration") # Text to put at the top of the admin index page. -site.index_title = get_site_string('DEEP Administration') +site.index_title = get_site_string("DEEP Administration") class JSONEditor(_JSONEditor): class Media: js = [ # NOTE: Not using this breaks autocomplete - 'admin/js/vendor/jquery/jquery%s.js' % ('' if settings.DEBUG else '.min') + "admin/js/vendor/jquery/jquery%s.js" + % ("" if settings.DEBUG else ".min") ] + list(_JSONEditor.Media.js[1:]) css = _JSONEditor.Media.css -class JSONFieldMixin(): +class JSONFieldMixin: formfield_overrides = { - fields.JSONField: {'widget': JSONEditor}, + fields.JSONField: {"widget": JSONEditor}, } @@ -52,7 +52,7 @@ class VersionAdmin(JSONFieldMixin, _VersionAdmin): pass -class ReadOnlyMixin(): +class ReadOnlyMixin: def has_add_permission(self, *args, **kwargs): return False @@ -74,7 +74,7 @@ def linkify(field_name, label=None): def _linkify(obj): try: linked_obj = obj - for _field_name in field_name.split('.'): + for _field_name in field_name.split("."): linked_obj = getattr(linked_obj, _field_name, None) if linked_obj: app_label = linked_obj._meta.app_label @@ -84,10 +84,10 @@ def _linkify(obj): return format_html(f'<a href="{link_url}">{linked_obj}</a>') except Exception: pass - return '-' + return "-" - _linkify.short_description = label or ' '.join(field_name.split('.')) - _linkify.admin_order_field = '__'.join(field_name.split('.')) + _linkify.short_description = label or " ".join(field_name.split(".")) + _linkify.admin_order_field = "__".join(field_name.split(".")) return _linkify @@ -98,37 +98,40 @@ def query_buttons(description, queries): If field_name is 'parent', link text will be str(obj.parent) Link will be admin url for the admin url for obj.parent.id:change """ + def _query_buttons(obj): app_label = obj._meta.app_label model_name = obj._meta.model_name - view_name = f'admin:{app_label}_{model_name}_change' + view_name = f"admin:{app_label}_{model_name}_change" buttons = [] for query in queries: - link_url = f'{reverse(view_name, args=[obj.pk])}?show_{query}=true' + link_url = f"{reverse(view_name, args=[obj.pk])}?show_{query}=true" buttons.append(f'<a class="changelink" href="{link_url}">{query.title()}</a>') - return mark_safe(''.join(buttons)) + return mark_safe("".join(buttons)) _query_buttons.short_description = description return _query_buttons -def document_preview(field_name, max_height='600px', max_width='800px', label=None): +def document_preview(field_name, max_height="600px", max_width="800px", label=None): """ Show document preview for file fields """ + def _document_preview(obj): file = getattr(obj, field_name) if file and file.url: try: - if file.name.split('?')[0].split('.')[-1] in ['docx', 'xlsx', 'pptx', 'ods', 'doc']: + if file.name.split("?")[0].split(".")[-1] in ["docx", "xlsx", "pptx", "ods", "doc"]: return mark_safe( - f''' + f""" <iframe src="https://docs.google.com/viewer?url={quote(file.url)}&embedded=true"></iframe> - ''' + """ ) except Exception: pass - return mark_safe(f""" + return mark_safe( + f""" <object data="{file.url}" style="display: block; max-width:{max_width}; max-height:{max_height}; width: auto; height: auto;" @@ -136,8 +139,10 @@ def _document_preview(obj): <img style="max-height:{max_height};max-width:{max_width}" src="{file.url}"/> <iframe src="https://docs.google.com/viewer?url={quote(file.url)}&embedded=true"></iframe> </object> - """) - return 'N/A' - _document_preview.short_description = label or 'Document Preview' + """ + ) + return "N/A" + + _document_preview.short_description = label or "Document Preview" _document_preview.allow_tags = True return _document_preview diff --git a/deep/caches.py b/deep/caches.py index 2d6595c560..9a68d6713c 100644 --- a/deep/caches.py +++ b/deep/caches.py @@ -1,19 +1,16 @@ -import json import hashlib +import json from typing import Union from django.core.cache import cache, caches from django.core.serializers.json import DjangoJSONEncoder - -local_cache = caches['local-memory'] +local_cache = caches["local-memory"] def clear_cache(prefix): try: - cache.delete_many( - cache.keys(prefix + '*') - ) + cache.delete_many(cache.keys(prefix + "*")) return True except Exception: pass @@ -21,95 +18,95 @@ def clear_cache(prefix): class CacheKey: # Redis Cache - URL_CACHED_FILE_FIELD_KEY_FORMAT = 'url_cache_{}' - CONNECTOR_KEY_FORMAT = 'connector_{}' - EXPORT_TASK_CACHE_KEY_FORMAT = 'EXPORT-{}-TASK-ID' - GENERIC_EXPORT_TASK_CACHE_KEY_FORMAT = 'GENERIC-EXPORT-{}-TASK-ID' - PROJECT_EXPLORE_STATS_LOADER_KEY = 'project-explore-stats-loader' - RECENT_ACTIVITIES_KEY_FORMAT = 'user-recent-activities-{}' + URL_CACHED_FILE_FIELD_KEY_FORMAT = "url_cache_{}" + CONNECTOR_KEY_FORMAT = "connector_{}" + EXPORT_TASK_CACHE_KEY_FORMAT = "EXPORT-{}-TASK-ID" + GENERIC_EXPORT_TASK_CACHE_KEY_FORMAT = "GENERIC-EXPORT-{}-TASK-ID" + PROJECT_EXPLORE_STATS_LOADER_KEY = "project-explore-stats-loader" + RECENT_ACTIVITIES_KEY_FORMAT = "user-recent-activities-{}" # Local (RAM) Cache - TEMP_CLIENT_ID_KEY_FORMAT = 'client-id-mixin-{request_hash}-{instance_type}-{instance_id}' - TEMP_CUSTOM_CLIENT_ID_KEY_FORMAT = '{prefix}-client-id-mixin-{request_hash}-{instance_type}-{instance_id}' + TEMP_CLIENT_ID_KEY_FORMAT = "client-id-mixin-{request_hash}-{instance_type}-{instance_id}" + TEMP_CUSTOM_CLIENT_ID_KEY_FORMAT = "{prefix}-client-id-mixin-{request_hash}-{instance_type}-{instance_id}" class ExploreDeep: - BASE = 'EXPLORE-DEEP-' - _PREFIX = BASE + '{}-' + BASE = "EXPLORE-DEEP-" + _PREFIX = BASE + "{}-" # Dynamic - TOTAL_PROJECTS_COUNT = _PREFIX + 'TOTAL-PROJECTS' - TOTAL_REGISTERED_USERS_COUNT = _PREFIX + 'TOTAL-REGISTERED-USERS' - TOTAL_LEADS_COUNT = _PREFIX + 'TOTAL-LEADS' - TOTAL_ENTRIES_COUNT = _PREFIX + 'TOTAL-ENTRIES' - TOTAL_ACTIVE_USERS_COUNT = _PREFIX + 'TOTAL-ACTIVE-USERS' - TOTAL_AUTHORS_COUNT = _PREFIX + 'TOTAL-AUTHORS' - TOTAL_PUBLISHERS_COUNT = _PREFIX + 'TOTAL-PUBLISHERS' - TOP_TEN_AUTHORS_LIST = _PREFIX + 'TOP-TEN-AUTHORS' - TOP_TEN_PUBLISHERS_LIST = _PREFIX + 'TOP-TEN-PUBLISHERS' - TOP_TEN_FRAMEWORKS_LIST = _PREFIX + 'TOP-TEN-FRAMEWORKS' - TOP_TEN_PROJECTS_BY_USERS_LIST = _PREFIX + 'TOP-TEN-PROJECTS-BY-USERS' - TOP_TEN_PROJECTS_BY_ENTRIES_LIST = _PREFIX + 'TOP-TEN-PROJECTS-BY-ENTRIES' - TOP_TEN_PROJECTS_BY_SOURCES_LIST = _PREFIX + 'TOP-TEN-PROJECTS-BY-SOURCES' + TOTAL_PROJECTS_COUNT = _PREFIX + "TOTAL-PROJECTS" + TOTAL_REGISTERED_USERS_COUNT = _PREFIX + "TOTAL-REGISTERED-USERS" + TOTAL_LEADS_COUNT = _PREFIX + "TOTAL-LEADS" + TOTAL_ENTRIES_COUNT = _PREFIX + "TOTAL-ENTRIES" + TOTAL_ACTIVE_USERS_COUNT = _PREFIX + "TOTAL-ACTIVE-USERS" + TOTAL_AUTHORS_COUNT = _PREFIX + "TOTAL-AUTHORS" + TOTAL_PUBLISHERS_COUNT = _PREFIX + "TOTAL-PUBLISHERS" + TOP_TEN_AUTHORS_LIST = _PREFIX + "TOP-TEN-AUTHORS" + TOP_TEN_PUBLISHERS_LIST = _PREFIX + "TOP-TEN-PUBLISHERS" + TOP_TEN_FRAMEWORKS_LIST = _PREFIX + "TOP-TEN-FRAMEWORKS" + TOP_TEN_PROJECTS_BY_USERS_LIST = _PREFIX + "TOP-TEN-PROJECTS-BY-USERS" + TOP_TEN_PROJECTS_BY_ENTRIES_LIST = _PREFIX + "TOP-TEN-PROJECTS-BY-ENTRIES" + TOP_TEN_PROJECTS_BY_SOURCES_LIST = _PREFIX + "TOP-TEN-PROJECTS-BY-SOURCES" # Static - TOTAL_ENTRIES_ADDED_LAST_WEEK_COUNT = BASE + 'TOTAL_ENTRIES_ADDED_LAST_WEEK_COUNT' + TOTAL_ENTRIES_ADDED_LAST_WEEK_COUNT = BASE + "TOTAL_ENTRIES_ADDED_LAST_WEEK_COUNT" @classmethod def clear_cache(cls): return clear_cache(cls.BASE) class Tracker: - BASE = 'DEEP-TRACKER-' + BASE = "DEEP-TRACKER-" # Dynamic - LAST_PROJECT_READ_ACCESS_DATETIME = BASE + 'LAST-PROJECT-READ-ACCESS-DATETIME-' - LAST_PROJECT_WRITE_ACCESS_DATETIME = BASE + 'LAST-PROJECT-WRITE-ACCESS-DATETIME-' - LAST_USER_ACTIVE_DATETIME = BASE + 'LAST-USER-ACTIVE-DATETIME-' + LAST_PROJECT_READ_ACCESS_DATETIME = BASE + "LAST-PROJECT-READ-ACCESS-DATETIME-" + LAST_PROJECT_WRITE_ACCESS_DATETIME = BASE + "LAST-PROJECT-WRITE-ACCESS-DATETIME-" + LAST_USER_ACTIVE_DATETIME = BASE + "LAST-USER-ACTIVE-DATETIME-" class AssessmentDashboard: - BASE = 'ASSESSMENT-DASHBOARD-' - _PREFIX = BASE + '{}-' - - TOTAL_ASSESSMENT_COUNT = _PREFIX + 'TOTAL-ASSESSMENT' - TOTAL_STAKEHOLDER_COUNT = _PREFIX + 'TOTAL-STAKEHOLDER' - TOTAL_COLLECTION_TECHNIQUE_COUNT = _PREFIX + 'TOTAL-COLLECTION_TECHNIQUE' - OVER_THE_TIME = _PREFIX + 'OVER-THE-TIME' - ASSESSMENT_COUNT = _PREFIX + 'ASSESSMENT-COUNT' - STAKEHOLDER_COUNT = _PREFIX + 'STAKEHOLDER-COUNT' - TOTAL_MULTISECTOR_ASSESSMENT_COUNT = _PREFIX + 'TOTAL-MULTISECTOR-ASSESSMENT-COUNT' - TOTAL_SINGLESECTOR_ASSESSMENT_COUNT = _PREFIX + 'TOTAL-SINGLESECTOR-ASSESSMENT-COUNT' - COLLECTION_TECHNIQUE_COUNT = _PREFIX + 'COLLECTION-TECHNIQUE-COUNT' - ASSESSMENT_PER_FRAMEWORK_PILLAR = _PREFIX + 'ASSESSMENT-PER-FRAMEWORK-PILLAR' - ASSESSMENT_PER_AFFECTED_GROUP = _PREFIX + 'ASSESSMENT-PER-AFFECTED-GROUP' - ASSESSMENT_PER_HUMANITRATION_SECTOR = _PREFIX + 'ASSESSMENT-PER-HUMANITRATION-SECTOR' - ASSESSMENT_PER_PROTECTION_MANAGEMENT = _PREFIX + 'ASSESSMENT-PER-PROTECTION-MANAGEMENT' - ASSESSMENT_SECTOR_AND_GEOAREA = _PREFIX + 'ASSESSMENT-SECTOR-AND-GEOAREA' - ASSESSMENT_AFFECTED_GROUP_AND_GEOAREA = _PREFIX + 'ASSESSMENT-AFFECTED-GROUP-AND-GEOAREA' - ASSESSMENT_AFFECTED_GROUP_AND_SECTOR = _PREFIX + 'ASSESSMENT-AFFECTED-GROUP-AND-SECTOR' - ASSESSMENT_BY_LEAD_ORGANIZATION = _PREFIX + 'ASSESSMENT-BY-LEAD-ORGANIZATION' - ASSESSMENT_PER_DATA_COLLECTION_TECHNIQUE = _PREFIX + 'ASSESSMENT-PER-DATA-COLLECTION-TECHNIQUE' - ASSESSMENT_PER_UNIT_ANALYSIS = _PREFIX + 'ASSESSMENT-PER-UNIT-ANALYSIS' - ASSESSMENT_PER_UNIT_REPORTING = _PREFIX + 'ASSESSMENT-PER-UNIT-REPORTING' - ASSESSMENT_PER_SAMPLE_APPROACH = _PREFIX + 'ASSESSMENT-PER-SAMPLE-APPROACH' - ASSESSMENT_PER_PROXIMITY = _PREFIX + 'ASSESSMENT-PER-PROXIMITY' - ASSESSMENT_BY_GEOAREA = _PREFIX + 'ASSESSMENT-BY-GEOAREA' - SAMPLE_SIZE_PER_DATA_COLLECTION_TECHNIQUE = _PREFIX + 'SAMPLE-SIZE-PER-DATA-COLLECTION-TECHNIQUE' - DATA_COLLECTION_TECHNIQUE_AND_GEOLOCATION = _PREFIX + 'DATA-COLLECTION-TECHNIQUE-AND-GEOLOCATION' - MEDIAN_SCORE_BY_SECTOR_AND_AFFECTED_GROUP_BY_MONTH = _PREFIX + 'MEDIAN-SCORE-BY-SECTOR-AND-AFFECTED-GROUP-BY-MONTH' - MEDIAN_SCORE_BY_SECTOR_AND_AFFECTED_GROUP = _PREFIX + 'MEDIAN-SCORE-BY-SECTOR-AND-AFFECTED-GROUP' - MEDIAN_QUALITY_SCORE_BY_GEOAREA_AND_SECTOR_BY_MONTH = _PREFIX + 'MEDIAN_SCORE_BY_SECTOR_AND_AFFECTED_GROUP_BY_MONTH' - MEDIAN_QUALITY_SCORE_BY_GEOAREA_AND_AFFECTED_GROUP = _PREFIX + 'MEDIAN-QUALITY-SCORE-BY-GEOAREA-AND-AFFECTED-GROUP' - MEDIAN_QUALITY_SCORE_BY_GEOAREA_AND_SECTOR = _PREFIX + 'MEDIAN-QUALIRY-SCORE-BY-GEOAREA-AND-SECTOR' - MEDIAN_QUALITY_SCORE_BY_ANALYTICAL_DENSITY_DATE = _PREFIX + 'MEDIAN_QUALITY_SCORE_BY_ANALYTICAL_DENSITY_DATE_MONTH' - MEDIAN_QUALITY_SCORE_BY_ANALYTICAL_DENSITY_DATE_MONTH = _PREFIX + 'MEDIAN-QUALIRY-SCORE-BY-ANALYTICAL-DENSITY-DATE' - MEDIAN_QUALITY_SCORE_OF_ANALYTICAL_DENSITY = _PREFIX + 'MEDIAN-QUALITY-SCORE-OF_ANALYTICAL-DENSITY' - MEDIAN_QUALITY_SCORE_OF_EACH_DIMENSION_BY_DATE_MONTH = _PREFIX + 'MEDIAN-QUALITY-SCORE-EACH-DIMENSION-BY-DATE-MONTH' - MEDIAN_QUALITY_SCORE_OF_EACH_DIMENSION_BY_DATE = _PREFIX + 'MEDIAN-QUALITY-SCORE-EACH-DIMENSION-BY-DATE' - MEDIAN_QUALITY_SCORE_OVER_TIME = _PREFIX + 'MEDIAN-QUALITY-SCORE-OVER-TIME' - MEDIAN_QUALITY_SCORE_BY_GEO_AREA = _PREFIX + 'MEDIAN-QUALITY-SCORE-GEO-AREA' - UNIT_REPORTING_AND_GEOLOCATION = _PREFIX + 'UNIT-REPORTING-AND-GEOLOCATION' - UNIT_OF_ANALYSIS_AND_GEOLOCATION = _PREFIX + 'UNIT-OF-ANALYSIS-AND-GEOLOCATION' - PROXIMITY_AND_GEOLOCATION = _PREFIX + 'PROXIMITY-AND-GEOLOCATION' - SAMPLING_APPROACH_AND_GEOLOCATION = _PREFIX + 'SAMPLING-APPROCACH-AND-GEOLOCATION' - MEDIAN_QUALITY_SCORE_OVER_TIME_BY_MONTH = _PREFIX + 'MEDIAN-QUALITY-SCORE-OVER-TIME-MONTH' - MEDIAN_QUALITY_SCORE_OF_EACH_DIMENSION = _PREFIX + 'MEDIAN-QUALITY-SCORE-EACH-DIMENSION' + BASE = "ASSESSMENT-DASHBOARD-" + _PREFIX = BASE + "{}-" + + TOTAL_ASSESSMENT_COUNT = _PREFIX + "TOTAL-ASSESSMENT" + TOTAL_STAKEHOLDER_COUNT = _PREFIX + "TOTAL-STAKEHOLDER" + TOTAL_COLLECTION_TECHNIQUE_COUNT = _PREFIX + "TOTAL-COLLECTION_TECHNIQUE" + OVER_THE_TIME = _PREFIX + "OVER-THE-TIME" + ASSESSMENT_COUNT = _PREFIX + "ASSESSMENT-COUNT" + STAKEHOLDER_COUNT = _PREFIX + "STAKEHOLDER-COUNT" + TOTAL_MULTISECTOR_ASSESSMENT_COUNT = _PREFIX + "TOTAL-MULTISECTOR-ASSESSMENT-COUNT" + TOTAL_SINGLESECTOR_ASSESSMENT_COUNT = _PREFIX + "TOTAL-SINGLESECTOR-ASSESSMENT-COUNT" + COLLECTION_TECHNIQUE_COUNT = _PREFIX + "COLLECTION-TECHNIQUE-COUNT" + ASSESSMENT_PER_FRAMEWORK_PILLAR = _PREFIX + "ASSESSMENT-PER-FRAMEWORK-PILLAR" + ASSESSMENT_PER_AFFECTED_GROUP = _PREFIX + "ASSESSMENT-PER-AFFECTED-GROUP" + ASSESSMENT_PER_HUMANITRATION_SECTOR = _PREFIX + "ASSESSMENT-PER-HUMANITRATION-SECTOR" + ASSESSMENT_PER_PROTECTION_MANAGEMENT = _PREFIX + "ASSESSMENT-PER-PROTECTION-MANAGEMENT" + ASSESSMENT_SECTOR_AND_GEOAREA = _PREFIX + "ASSESSMENT-SECTOR-AND-GEOAREA" + ASSESSMENT_AFFECTED_GROUP_AND_GEOAREA = _PREFIX + "ASSESSMENT-AFFECTED-GROUP-AND-GEOAREA" + ASSESSMENT_AFFECTED_GROUP_AND_SECTOR = _PREFIX + "ASSESSMENT-AFFECTED-GROUP-AND-SECTOR" + ASSESSMENT_BY_LEAD_ORGANIZATION = _PREFIX + "ASSESSMENT-BY-LEAD-ORGANIZATION" + ASSESSMENT_PER_DATA_COLLECTION_TECHNIQUE = _PREFIX + "ASSESSMENT-PER-DATA-COLLECTION-TECHNIQUE" + ASSESSMENT_PER_UNIT_ANALYSIS = _PREFIX + "ASSESSMENT-PER-UNIT-ANALYSIS" + ASSESSMENT_PER_UNIT_REPORTING = _PREFIX + "ASSESSMENT-PER-UNIT-REPORTING" + ASSESSMENT_PER_SAMPLE_APPROACH = _PREFIX + "ASSESSMENT-PER-SAMPLE-APPROACH" + ASSESSMENT_PER_PROXIMITY = _PREFIX + "ASSESSMENT-PER-PROXIMITY" + ASSESSMENT_BY_GEOAREA = _PREFIX + "ASSESSMENT-BY-GEOAREA" + SAMPLE_SIZE_PER_DATA_COLLECTION_TECHNIQUE = _PREFIX + "SAMPLE-SIZE-PER-DATA-COLLECTION-TECHNIQUE" + DATA_COLLECTION_TECHNIQUE_AND_GEOLOCATION = _PREFIX + "DATA-COLLECTION-TECHNIQUE-AND-GEOLOCATION" + MEDIAN_SCORE_BY_SECTOR_AND_AFFECTED_GROUP_BY_MONTH = _PREFIX + "MEDIAN-SCORE-BY-SECTOR-AND-AFFECTED-GROUP-BY-MONTH" + MEDIAN_SCORE_BY_SECTOR_AND_AFFECTED_GROUP = _PREFIX + "MEDIAN-SCORE-BY-SECTOR-AND-AFFECTED-GROUP" + MEDIAN_QUALITY_SCORE_BY_GEOAREA_AND_SECTOR_BY_MONTH = _PREFIX + "MEDIAN_SCORE_BY_SECTOR_AND_AFFECTED_GROUP_BY_MONTH" + MEDIAN_QUALITY_SCORE_BY_GEOAREA_AND_AFFECTED_GROUP = _PREFIX + "MEDIAN-QUALITY-SCORE-BY-GEOAREA-AND-AFFECTED-GROUP" + MEDIAN_QUALITY_SCORE_BY_GEOAREA_AND_SECTOR = _PREFIX + "MEDIAN-QUALIRY-SCORE-BY-GEOAREA-AND-SECTOR" + MEDIAN_QUALITY_SCORE_BY_ANALYTICAL_DENSITY_DATE = _PREFIX + "MEDIAN_QUALITY_SCORE_BY_ANALYTICAL_DENSITY_DATE_MONTH" + MEDIAN_QUALITY_SCORE_BY_ANALYTICAL_DENSITY_DATE_MONTH = _PREFIX + "MEDIAN-QUALIRY-SCORE-BY-ANALYTICAL-DENSITY-DATE" + MEDIAN_QUALITY_SCORE_OF_ANALYTICAL_DENSITY = _PREFIX + "MEDIAN-QUALITY-SCORE-OF_ANALYTICAL-DENSITY" + MEDIAN_QUALITY_SCORE_OF_EACH_DIMENSION_BY_DATE_MONTH = _PREFIX + "MEDIAN-QUALITY-SCORE-EACH-DIMENSION-BY-DATE-MONTH" + MEDIAN_QUALITY_SCORE_OF_EACH_DIMENSION_BY_DATE = _PREFIX + "MEDIAN-QUALITY-SCORE-EACH-DIMENSION-BY-DATE" + MEDIAN_QUALITY_SCORE_OVER_TIME = _PREFIX + "MEDIAN-QUALITY-SCORE-OVER-TIME" + MEDIAN_QUALITY_SCORE_BY_GEO_AREA = _PREFIX + "MEDIAN-QUALITY-SCORE-GEO-AREA" + UNIT_REPORTING_AND_GEOLOCATION = _PREFIX + "UNIT-REPORTING-AND-GEOLOCATION" + UNIT_OF_ANALYSIS_AND_GEOLOCATION = _PREFIX + "UNIT-OF-ANALYSIS-AND-GEOLOCATION" + PROXIMITY_AND_GEOLOCATION = _PREFIX + "PROXIMITY-AND-GEOLOCATION" + SAMPLING_APPROACH_AND_GEOLOCATION = _PREFIX + "SAMPLING-APPROCACH-AND-GEOLOCATION" + MEDIAN_QUALITY_SCORE_OVER_TIME_BY_MONTH = _PREFIX + "MEDIAN-QUALITY-SCORE-OVER-TIME-MONTH" + MEDIAN_QUALITY_SCORE_OF_EACH_DIMENSION = _PREFIX + "MEDIAN-QUALITY-SCORE-EACH-DIMENSION" @classmethod def clear_cache(cls): @@ -126,7 +123,7 @@ def calculate_md5_str(string): @classmethod def generate_hash(cls, item: Union[None, str, dict]) -> str: if item is None: - return '' + return "" hashable = None if isinstance(item, str): hashable = item @@ -136,9 +133,9 @@ def generate_hash(cls, item: Union[None, str, dict]) -> str: sort_keys=True, indent=2, cls=DjangoJSONEncoder, - ).encode('utf-8') + ).encode("utf-8") else: - raise Exception(f'Unknown Type: {type(item)}') + raise Exception(f"Unknown Type: {type(item)}") return cls.calculate_md5_str(hashable) @staticmethod @@ -146,9 +143,7 @@ def gql_cache(cache_key, cache_key_gen=None, timeout=60): def _dec(func): def _caller(*args, **kwargs): if cache_key_gen: - _cache_key = cache_key.format( - cache_key_gen(*args, **kwargs) - ) + _cache_key = cache_key.format(cache_key_gen(*args, **kwargs)) else: _cache_key = cache_key return cache.get_or_set( @@ -156,7 +151,9 @@ def _caller(*args, **kwargs): lambda: func(*args, **kwargs), timeout, ) + _caller.__name__ = func.__name__ _caller.__module__ = func.__module__ return _caller + return _dec diff --git a/deep/celery.py b/deep/celery.py index e4749c0e1e..8c339a5959 100644 --- a/deep/celery.py +++ b/deep/celery.py @@ -1,7 +1,7 @@ import os import sys -import celery +import celery from django.conf import settings from utils import sentry @@ -10,33 +10,30 @@ class Celery(celery.Celery): def on_configure(self): if settings.SENTRY_DSN: - sentry.init_sentry( - app_type='WORKER', - **settings.SENTRY_CONFIG - ) + sentry.init_sentry(app_type="WORKER", **settings.SENTRY_CONFIG) # set the default Django settings module for the 'celery' program. -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'deep.settings') +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "deep.settings") sys.path.append(settings.APPS_DIR) -app = Celery('deep') +app = Celery("deep") # Using a string here means the worker doesn't have to serialize # the configuration object to child processes. # - namespace='CELERY' means all celery-related configuration keys # should have a `CELERY_` prefix. -app.config_from_object('django.conf:settings', namespace='CELERY') +app.config_from_object("django.conf:settings", namespace="CELERY") # Load task modules from all registered Django app configs. app.autodiscover_tasks() -app.autodiscover_tasks(['deep']) +app.autodiscover_tasks(["deep"]) # This is used for ECS Cluster (Each queue needs it's own clusters) -class CeleryQueue(): - DEFAULT = 'CELERY-DEFAULT-QUEUE' - EXPORT_HEAVY = 'CELERY-EXPORT-HEAVY-QUEUE' +class CeleryQueue: + DEFAULT = "CELERY-DEFAULT-QUEUE" + EXPORT_HEAVY = "CELERY-EXPORT-HEAVY-QUEUE" ALL_QUEUES = ( DEFAULT, @@ -49,4 +46,4 @@ class CeleryQueue(): @app.task(bind=True) def debug_task(self): - print('Request: {0!r}'.format(self.request)) + print("Request: {0!r}".format(self.request)) diff --git a/deep/compiler.py b/deep/compiler.py index 4a135bb1a8..3834293521 100644 --- a/deep/compiler.py +++ b/deep/compiler.py @@ -18,7 +18,7 @@ def get_group_by(self, select, order_by): # when we have public API way of forcing the GROUP BY clause. # Converts string references to expressions. for expr in self.query.group_by: - if not hasattr(expr, 'as_sql'): + if not hasattr(expr, "as_sql"): expressions.append(self.query.resolve_ref(expr)) else: expressions.append(expr) diff --git a/deep/context_processor.py b/deep/context_processor.py index 8754a5e398..3c4dc3cf1b 100644 --- a/deep/context_processor.py +++ b/deep/context_processor.py @@ -3,6 +3,6 @@ def deep(request): return { - 'request': request, - 'DEEP_ENVIRONMENT': settings.DEEP_ENVIRONMENT, + "request": request, + "DEEP_ENVIRONMENT": settings.DEEP_ENVIRONMENT, } diff --git a/deep/converters.py b/deep/converters.py index b9f8e2115a..a33f72a51a 100644 --- a/deep/converters.py +++ b/deep/converters.py @@ -1,5 +1,5 @@ class FileNameRegex: - regex = '.*' + regex = ".*" def to_python(self, value): return value diff --git a/deep/dataloaders.py b/deep/dataloaders.py index 05109008ce..6ec7e129e3 100644 --- a/deep/dataloaders.py +++ b/deep/dataloaders.py @@ -1,22 +1,21 @@ +from analysis.dataloaders import DataLoaders as AnalysisDataLoaders +from analysis_framework.dataloaders import DataLoaders as AfDataloaders +from assessment_registry.dataloaders import DataLoaders as AssessmentRegistryDataLoaders +from assisted_tagging.dataloaders import DataLoaders as AssistedTaggingLoaders from django.utils.functional import cached_property - -from utils.graphene.dataloaders import WithContextMixin - -from project.dataloaders import DataLoaders as ProjectDataLoaders -from user.dataloaders import DataLoaders as UserDataLoaders -from user_group.dataloaders import DataLoaders as UserGroupDataLoaders -from lead.dataloaders import DataLoaders as LeadDataLoaders from entry.dataloaders import DataLoaders as EntryDataloaders +from gallery.dataloaders import DataLoaders as DeepGalleryDataLoaders +from geo.dataloaders import DataLoaders as GeoDataLoaders +from lead.dataloaders import DataLoaders as LeadDataLoaders +from notification.dataloaders import DataLoaders as AssignmentLoaders from organization.dataloaders import DataLoaders as OrganizationDataLoaders -from analysis_framework.dataloaders import DataLoaders as AfDataloaders +from project.dataloaders import DataLoaders as ProjectDataLoaders from quality_assurance.dataloaders import DataLoaders as QADataLoaders -from geo.dataloaders import DataLoaders as GeoDataLoaders from unified_connector.dataloaders import DataLoaders as UnifiedConnectorDataLoaders -from analysis.dataloaders import DataLoaders as AnalysisDataLoaders -from gallery.dataloaders import DataLoaders as DeepGalleryDataLoaders -from assessment_registry.dataloaders import DataLoaders as AssessmentRegistryDataLoaders -from assisted_tagging.dataloaders import DataLoaders as AssistedTaggingLoaders -from notification.dataloaders import DataLoaders as AssignmentLoaders +from user.dataloaders import DataLoaders as UserDataLoaders +from user_group.dataloaders import DataLoaders as UserGroupDataLoaders + +from utils.graphene.dataloaders import WithContextMixin class GlobalDataLoaders(WithContextMixin): diff --git a/deep/deepl.py b/deep/deepl.py index 78e47a5ab6..cdabbde4f2 100644 --- a/deep/deepl.py +++ b/deep/deepl.py @@ -1,21 +1,20 @@ from django.conf import settings - DEEPL_SERVICE_DOMAIN = settings.DEEPL_SERVICE_DOMAIN DEEPL_SERVER_DOMAIN = settings.DEEPL_SERVER_DOMAIN -class DeeplServiceEndpoint(): +class DeeplServiceEndpoint: # DEEPL Service Endpoints (Existing/Legacy) # NOTE: This will be moved to server endpoints in near future - ASSISTED_TAGGING_MODELS_ENDPOINT = f'{DEEPL_SERVICE_DOMAIN}/model_info' + ASSISTED_TAGGING_MODELS_ENDPOINT = f"{DEEPL_SERVICE_DOMAIN}/model_info" # DEEPL Server Endpoints (New) - ASSISTED_TAGGING_TAGS_ENDPOINT = f'{DEEPL_SERVER_DOMAIN}/api/v1/nlp-tags/' - DOCS_EXTRACTOR_ENDPOINT = f'{DEEPL_SERVER_DOMAIN}/api/v1/text-extraction/' - ANALYSIS_TOPIC_MODEL = f'{DEEPL_SERVER_DOMAIN}/api/v1/topicmodel/' - ANALYSIS_AUTOMATIC_SUMMARY = f'{DEEPL_SERVER_DOMAIN}/api/v1/summarization/' - ANALYSIS_AUTOMATIC_NGRAM = f'{DEEPL_SERVER_DOMAIN}/api/v1/ngrams/' - ANALYSIS_GEO = f'{DEEPL_SERVER_DOMAIN}/api/v1/geolocation/' - ASSISTED_TAGGING_ENTRY_PREDICT_ENDPOINT = f'{DEEPL_SERVER_DOMAIN}/api/v1/entry-classification/' - ENTRY_EXTRACTION_CLASSIFICATION = f'{DEEPL_SERVER_DOMAIN}/api/v1/entry-extraction-classification/' + ASSISTED_TAGGING_TAGS_ENDPOINT = f"{DEEPL_SERVER_DOMAIN}/api/v1/nlp-tags/" + DOCS_EXTRACTOR_ENDPOINT = f"{DEEPL_SERVER_DOMAIN}/api/v1/text-extraction/" + ANALYSIS_TOPIC_MODEL = f"{DEEPL_SERVER_DOMAIN}/api/v1/topicmodel/" + ANALYSIS_AUTOMATIC_SUMMARY = f"{DEEPL_SERVER_DOMAIN}/api/v1/summarization/" + ANALYSIS_AUTOMATIC_NGRAM = f"{DEEPL_SERVER_DOMAIN}/api/v1/ngrams/" + ANALYSIS_GEO = f"{DEEPL_SERVER_DOMAIN}/api/v1/geolocation/" + ASSISTED_TAGGING_ENTRY_PREDICT_ENDPOINT = f"{DEEPL_SERVER_DOMAIN}/api/v1/entry-classification/" + ENTRY_EXTRACTION_CLASSIFICATION = f"{DEEPL_SERVER_DOMAIN}/api/v1/entry-extraction-classification/" diff --git a/deep/documents_types.py b/deep/documents_types.py index ffe7128a47..dfeb648bf9 100644 --- a/deep/documents_types.py +++ b/deep/documents_types.py @@ -2,43 +2,64 @@ # List of mime types supported in deep # NOTE: also change in frontend -PDF_MIME_TYPES = ['application/pdf'] +PDF_MIME_TYPES = ["application/pdf"] DOCX_MIME_TYPES = [ - 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', - 'application/wps-office.docx', + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/wps-office.docx", ] MSWORD_MIME_TYPES = [ - 'application/msword', 'application/wps-office.doc', + "application/msword", + "application/wps-office.doc", ] POWERPOINT_MIME_TYPES = [ - 'application/vnd.openxmlformats-officedocument.presentationml.presentation', - 'application/vnd.ms-powerpoint', + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + "application/vnd.ms-powerpoint", ] SHEET_MIME_TYPES = [ - 'application/vnd.ms-excel', - 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', - 'application/wps-office.xlsx', + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "application/wps-office.xlsx", ] -ODS_MIME_TYPES = ['application/vnd.oasis.opendocument.spreadsheet'] -IMAGE_MIME_TYPES = ['image/png', 'image/jpeg', 'image/fig', 'image/gif'] +ODS_MIME_TYPES = ["application/vnd.oasis.opendocument.spreadsheet"] +IMAGE_MIME_TYPES = ["image/png", "image/jpeg", "image/fig", "image/gif"] CHART_IMAGE_MIME = { - 'png': 'image/png', - 'svg': 'image/svg+xml', + "png": "image/png", + "svg": "image/svg+xml", } # Overall Supported Mime Types DEEP_SUPPORTED_MIME_TYPES = [ - 'application/rtf', 'text/plain', 'font/otf', 'text/csv', - 'application/json', 'application/xml', + "application/rtf", + "text/plain", + "font/otf", + "text/csv", + "application/json", + "application/xml", ] + ( - DOCX_MIME_TYPES + MSWORD_MIME_TYPES + PDF_MIME_TYPES + - POWERPOINT_MIME_TYPES + SHEET_MIME_TYPES + ODS_MIME_TYPES + - IMAGE_MIME_TYPES + DOCX_MIME_TYPES + + MSWORD_MIME_TYPES + + PDF_MIME_TYPES + + POWERPOINT_MIME_TYPES + + SHEET_MIME_TYPES + + ODS_MIME_TYPES + + IMAGE_MIME_TYPES ) DEEP_SUPPORTED_EXTENSIONS = [ - 'docx', 'xlsx', 'pdf', 'pptx', - 'json', 'png', 'jpg', 'jpeg', 'csv', 'txt', - 'geojson', 'zip', 'ods', 'doc', 'xls', + "docx", + "xlsx", + "pdf", + "pptx", + "json", + "png", + "jpg", + "jpeg", + "csv", + "txt", + "geojson", + "zip", + "ods", + "doc", + "xls", ] diff --git a/deep/enums.py b/deep/enums.py index c74a1f67b9..4aa400f05c 100644 --- a/deep/enums.py +++ b/deep/enums.py @@ -1,19 +1,18 @@ import graphene - -from user.enums import enum_map as user_enum_map -from user_group.enums import enum_map as user_group_enum_map -from project.enums import enum_map as project_enum_map +from analysis.enums import enum_map as analysis_enum_map from analysis_framework.enums import enum_map as analysis_framework_enum_map -from lead.enums import enum_map as lead_enum_map +from ary.enums import enum_map as ary_enum_map +from assessment_registry.enums import enum_map as assessment_reg_enum_map +from assisted_tagging.enums import enum_map as assisted_tagging_enum_map from entry.enums import enum_map as entry_enum_map from export.enums import enum_map as export_enum_map -from quality_assurance.enums import enum_map as quality_assurance_enum_map -from analysis.enums import enum_map as analysis_enum_map +from lead.enums import enum_map as lead_enum_map from notification.enums import enum_map as notification_enum_map +from project.enums import enum_map as project_enum_map +from quality_assurance.enums import enum_map as quality_assurance_enum_map from unified_connector.enums import enum_map as unified_connector_enum_map -from assisted_tagging.enums import enum_map as assisted_tagging_enum_map -from ary.enums import enum_map as ary_enum_map -from assessment_registry.enums import enum_map as assessment_reg_enum_map +from user.enums import enum_map as user_enum_map +from user_group.enums import enum_map as user_group_enum_map ENUM_TO_GRAPHENE_ENUM_MAP = { **user_enum_map, @@ -33,20 +32,19 @@ } ENUM_TO_GRAPHENE_ENUM_DESCRIPTION_MAP = { - enum: getattr(enum._meta.enum, '__description__', {}) - for enum in ENUM_TO_GRAPHENE_ENUM_MAP.values() + enum: getattr(enum._meta.enum, "__description__", {}) for enum in ENUM_TO_GRAPHENE_ENUM_MAP.values() } def generate_type_for_enum(name, Enum): EnumMetaType = type( - f'AppEnumCollection{name}', + f"AppEnumCollection{name}", (graphene.ObjectType,), { - 'enum': graphene.NonNull(Enum), - 'label': graphene.NonNull(graphene.String), - 'description': graphene.String(), - } + "enum": graphene.NonNull(Enum), + "label": graphene.NonNull(graphene.String), + "description": graphene.String(), + }, ) return graphene.Field( graphene.List( @@ -59,7 +57,7 @@ def generate_type_for_enum(name, Enum): description=ENUM_TO_GRAPHENE_ENUM_DESCRIPTION_MAP[Enum].get((e.value, e.label)), ) for e in Enum._meta.enum - ] + ], ) @@ -77,4 +75,4 @@ def generate_type_for_enums(name): ) -AppEnumCollection = generate_type_for_enums('AppEnumCollection') +AppEnumCollection = generate_type_for_enums("AppEnumCollection") diff --git a/deep/errors.py b/deep/errors.py index bc9957c781..76241d6003 100644 --- a/deep/errors.py +++ b/deep/errors.py @@ -1,11 +1,11 @@ -from deep import error_codes -from jwt_auth.errors import WARN_EXCEPTIONS as JWT_WARN_EXCEPTIONS from entry.errors import EntryValidationVersionMismatchError +from jwt_auth.errors import WARN_EXCEPTIONS as JWT_WARN_EXCEPTIONS +from deep import error_codes error_code_map = { - 'not_authenticated': error_codes.NOT_AUTHENTICATED, - 'authentication_failed': error_codes.AUTHENTICATION_FAILED, + "not_authenticated": error_codes.NOT_AUTHENTICATED, + "authentication_failed": error_codes.AUTHENTICATION_FAILED, } @@ -23,7 +23,7 @@ def map_error_codes(codes, default=None): if isinstance(codes, str): return error_code_map.get(codes, default) - if codes == {'non_field_errors': ['invalid']}: + if codes == {"non_field_errors": ["invalid"]}: return error_codes.TOKEN_INVALID return default diff --git a/deep/exception_handler.py b/deep/exception_handler.py index 2e9ab9c28c..518d827f4d 100644 --- a/deep/exception_handler.py +++ b/deep/exception_handler.py @@ -1,19 +1,15 @@ -from django.utils import timezone -import sentry_sdk +import logging -from rest_framework.views import exception_handler +import sentry_sdk +from django.utils import timezone +from rest_framework import exceptions, status from rest_framework.response import Response -from rest_framework import status, exceptions - -from deep.errors import map_error_codes, WARN_EXCEPTIONS +from rest_framework.views import exception_handler -import logging +from deep.errors import WARN_EXCEPTIONS, map_error_codes logger = logging.getLogger(__name__) -standard_error_string = ( - 'Something unexpected has occured. ' - 'Please contact an admin to fix this issue.' -) +standard_error_string = "Something unexpected has occured. " "Please contact an admin to fix this issue." def custom_exception_handler(exc, context): @@ -22,14 +18,14 @@ def custom_exception_handler(exc, context): # For 500 errors, we create new response if not response: - request = context.get('request') + request = context.get("request") if request and request.user and request.user.id: with sentry_sdk.configure_scope() as scope: scope.user = { - 'id': request.user.id, - 'email': request.user.email, + "id": request.user.id, + "email": request.user.email, } - scope.set_extra('is_superuser', request.user.is_superuser) + scope.set_extra("is_superuser", request.user.is_superuser) sentry_sdk.capture_exception() response = Response({}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) @@ -37,26 +33,25 @@ def custom_exception_handler(exc, context): response.data = {} # Timestamp of exception - response.data['timestamp'] = timezone.now() + response.data["timestamp"] = timezone.now() if isinstance(exc, (exceptions.NotAuthenticated,)): response.status_code = status.HTTP_401_UNAUTHORIZED - elif hasattr(exc, 'status_code'): + elif hasattr(exc, "status_code"): response.status_code = exc.status_code - if hasattr(exc, 'code'): + if hasattr(exc, "code"): # If the raised exception defines a code, send it as # internal error code - response.data['error_code'] = exc.code - elif hasattr(exc, 'get_codes'): + response.data["error_code"] = exc.code + elif hasattr(exc, "get_codes"): # Otherwise, try to map the exception.get_codes() value to an # internal error code. # If no internal code available, return http status code as # internal error code by default. - response.data['error_code'] = map_error_codes( - exc.get_codes(), response.status_code) + response.data["error_code"] = map_error_codes(exc.get_codes(), response.status_code) else: - response.data['error_code'] = response.status_code + response.data["error_code"] = response.status_code # Error message can be defined by the exception as message # or detail attributres @@ -65,53 +60,53 @@ def custom_exception_handler(exc, context): errors = None user_error = None - if hasattr(exc, 'message'): + if hasattr(exc, "message"): errors = exc.message - elif hasattr(exc, 'detail'): + elif hasattr(exc, "detail"): if isinstance(exc.detail, list): errors = [str(error) for error in exc.detail] else: errors = exc.detail - elif hasattr(exc, 'default_detail'): + elif hasattr(exc, "default_detail"): errors = exc.default_detail elif response.status_code == 404: - errors = 'Resource not found' + errors = "Resource not found" else: errors = str(exc) user_error = standard_error_string - if hasattr(exc, 'user_message'): + if hasattr(exc, "user_message"): user_error = exc.user_message # Wrap up string error inside non-field-errors if isinstance(errors, str): errors = { - 'non_field_errors': [errors], + "non_field_errors": [errors], } elif isinstance(errors, list) and all([isinstance(error, str) for error in errors]): errors = { - 'non_field_errors': errors, + "non_field_errors": errors, } if user_error: - errors['internal_non_field_errors'] = errors.get('non_field_errors') - errors['non_field_errors'] = [user_error] + errors["internal_non_field_errors"] = errors.get("non_field_errors") + errors["non_field_errors"] = [user_error] - response.data['errors'] = errors + response.data["errors"] = errors # If there is a link available for the exception, # send back the link as well. - if hasattr(exc, 'link'): - response.data['link'] = exc.link + if hasattr(exc, "link"): + response.data["link"] = exc.link # Logging if any([isinstance(exc, exception) for exception in WARN_EXCEPTIONS]): - logger.warning('API Exception Warning!!', exc_info=True) + logger.warning("API Exception Warning!!", exc_info=True) else: logger.error( - '{}.{}'.format(type(exc).__module__, type(exc).__name__), + "{}.{}".format(type(exc).__module__, type(exc).__name__), exc_info=True, - extra={'request': context.get('request')}, + extra={"request": context.get("request")}, ) return response diff --git a/deep/exceptions.py b/deep/exceptions.py index 1237876c40..79ffb5b45e 100644 --- a/deep/exceptions.py +++ b/deep/exceptions.py @@ -3,21 +3,21 @@ class DeepBaseException(Exception): - default_message = _('Unexpected exception. Contact admin.') + default_message = _("Unexpected exception. Contact admin.") def __init__(self, msg=None, *args, **kwargs): super().__init__(msg or self.default_message, *args, **kwargs) class CustomException(Exception): - default_message = _('You do not have permission to perform this action.') + default_message = _("You do not have permission to perform this action.") def __init__(self, msg=None, *args, **kwargs): super().__init__(msg or self.default_message, *args, **kwargs) class UnauthorizedException(CustomException): - default_message = _('You are not authenticated') + default_message = _("You are not authenticated") code = status.HTTP_401_UNAUTHORIZED diff --git a/deep/filter_set.py b/deep/filter_set.py index 5965685dbf..3ee224a1f1 100644 --- a/deep/filter_set.py +++ b/deep/filter_set.py @@ -1,13 +1,14 @@ +from typing import Tuple + import django_filters import graphene -from typing import Tuple from django import forms from django.db import models - from graphene_django.filter.utils import get_filtering_args_from_filterset + from utils.graphene.fields import ( - generate_object_field_from_input_type, compare_input_output_type_fields, + generate_object_field_from_input_type, ) @@ -16,20 +17,20 @@ def value_from_datadict(self, data, files, name): value = forms.Widget.value_from_datadict(self, data, files, name) if value is not None: - if value == '': # parse empty value as an empty list + if value == "": # parse empty value as an empty list return [] # if value is already list(by POST) elif isinstance(value, list): return value - return [x.strip() for x in value.strip().split(',') if x.strip()] + return [x.strip() for x in value.strip().split(",") if x.strip()] return None -class OrderEnumMixin(): +class OrderEnumMixin: def ordering_filter(self, qs, _, value): for ordering in value: if isinstance(ordering, str): - if ordering.startswith('-'): + if ordering.startswith("-"): _ordering = models.F(ordering[1:]).desc() else: _ordering = models.F(ordering).asc() @@ -43,7 +44,8 @@ def ordering_filter(self, qs, _, value): def get_dummy_request(**kwargs): return type( - 'DummyRequest', (object,), + "DummyRequest", + (object,), kwargs, )() @@ -61,6 +63,7 @@ def generate_type_for_filter_set( - LeadGqlFilterSetInputType - LeadGqlFilterSetType """ + def generate_type_from_input_type(input_type): new_fields_map = generate_object_field_from_input_type(input_type) if custom_new_fields_map: @@ -69,10 +72,6 @@ def generate_type_from_input_type(input_type): compare_input_output_type_fields(input_type, new_type) return new_type - input_type = type( - input_type_name, - (graphene.InputObjectType,), - get_filtering_args_from_filterset(filter_set, used_node) - ) + input_type = type(input_type_name, (graphene.InputObjectType,), get_filtering_args_from_filterset(filter_set, used_node)) _type = generate_type_from_input_type(input_type) return _type, input_type diff --git a/deep/graphene_context.py b/deep/graphene_context.py index c5d742a671..6e6ebfce8e 100644 --- a/deep/graphene_context.py +++ b/deep/graphene_context.py @@ -1,11 +1,9 @@ from django.utils.functional import cached_property -from deep.permissions import ( - ProjectPermissions as PP, - AnalysisFrameworkPermissions as AfP, - UserGroupPermissions as UgP, -) from deep.dataloaders import GlobalDataLoaders +from deep.permissions import AnalysisFrameworkPermissions as AfP +from deep.permissions import ProjectPermissions as PP +from deep.permissions import UserGroupPermissions as UgP class GQLContext: @@ -27,9 +25,7 @@ def set_active_project(self, project): def set_active_af(self, af): self.active_af = self.request.active_af = af - self.af_permissions = AfP.get_permissions( - af.get_current_user_role(self.request.user) - ) + self.af_permissions = AfP.get_permissions(af.get_current_user_role(self.request.user)) def set_active_usergroup(self, user_group): self.active_ug = self.request.active_ug = user_group diff --git a/deep/graphene_converter.py b/deep/graphene_converter.py index 73dca50ff9..7afa76d5f7 100644 --- a/deep/graphene_converter.py +++ b/deep/graphene_converter.py @@ -1,13 +1,15 @@ -from aniso8601 import parse_date, parse_datetime, parse_time import graphene -from graphene.types.generic import GenericScalar +from aniso8601 import parse_date, parse_datetime, parse_time +# For Geo Fields +from django.contrib.gis.db import models as gis_models +from graphene.types.generic import GenericScalar from graphene_django.compat import HStoreField, JSONField, PGJSONField from graphene_django.converter import convert_django_field -from graphene_django_extras.converter import convert_django_field as extra_convert_django_field +from graphene_django_extras.converter import ( + convert_django_field as extra_convert_django_field, +) -# For Geo Fields -from django.contrib.gis.db import models as gis_models from utils.graphene import geo_scalars @@ -30,7 +32,7 @@ def custom_convert_json_field_to_scalar(field, register=None): "LineStringField": geo_scalars.LineStringScalar, "PolygonField": geo_scalars.PolygonScalar, "MultiPolygonField": geo_scalars.MultiPolygonScalar, - "GeometryField": geo_scalars.GISScalar + "GeometryField": geo_scalars.GISScalar, } @@ -41,9 +43,7 @@ def custom_convert_json_field_to_scalar(field, register=None): @convert_django_field.register(gis_models.PointField) def gis_converter(field, registry=None): class_name = field.__class__.__name__ - return GIS_FIELD_SCALAR[class_name]( - required=not field.null, description=field.help_text - ) + return GIS_FIELD_SCALAR[class_name](required=not field.null, description=field.help_text) original_time_serialize = graphene.Time.serialize @@ -52,7 +52,7 @@ def gis_converter(field, registry=None): # Add option to add string as well. -class CustomSerialize(): +class CustomSerialize: @staticmethod def _parse(dt, parse_func): if isinstance(dt, str): @@ -61,21 +61,15 @@ def _parse(dt, parse_func): @classmethod def time(cls, time) -> str: - return original_time_serialize( - cls._parse(time, parse_time) - ) + return original_time_serialize(cls._parse(time, parse_time)) @classmethod def date(cls, date) -> str: - return original_date_serialize( - cls._parse(date, parse_date) - ) + return original_date_serialize(cls._parse(date, parse_date)) @classmethod def datetime(cls, dt) -> str: - return original_datetime_serialize( - cls._parse(dt, parse_datetime) - ) + return original_datetime_serialize(cls._parse(dt, parse_datetime)) graphene.Time.serialize = CustomSerialize.time diff --git a/deep/managers.py b/deep/managers.py index 6f7a66b4bc..c7bb4f8083 100644 --- a/deep/managers.py +++ b/deep/managers.py @@ -1,4 +1,5 @@ from collections import defaultdict + from django.apps import apps @@ -17,7 +18,7 @@ def __init__(self, chunk_size=100): self.chunk_size = chunk_size def _commit(self, _): - raise Exception('This is not implemented yet.') + raise Exception("This is not implemented yet.") def _process_obj(self, obj): return obj @@ -58,7 +59,7 @@ def __init__(self, update_fields, *args, **kwargs): def _process_obj(self, obj): if obj.pk is None: - raise Exception(f'Only object with pk is allowed: {obj}') + raise Exception(f"Only object with pk is allowed: {obj}") return obj def _commit(self, model_class): diff --git a/deep/middleware.py b/deep/middleware.py index 7c0e235c01..816a6906af 100644 --- a/deep/middleware.py +++ b/deep/middleware.py @@ -1,23 +1,22 @@ import logging -import requests import threading -from reversion.views import create_revision -from django.utils import timezone +import requests from django.conf import settings from django.contrib.auth.models import AnonymousUser from django.core.files.storage import get_storage_class +from django.utils import timezone +from reversion.views import create_revision from utils.date_extractor import str_to_date - logger = logging.getLogger(__name__) _threadlocal = threading.local() class RevisionMiddleware: skip_paths = [ - '/api/v1/token/', + "/api/v1/token/", ] def __init__(self, get_response): @@ -43,24 +42,22 @@ def get_s3_signed_url_ttl(): class DeepInnerCacheMiddleware: - EC2_META_URL = 'http://169.254.169.254/latest/meta-data/iam/security-credentials/' - THREAD_S3_SIGNED_URL_TTL_ATTRIBUTE = 'URLCachedFileField__get_cache_ttl' + EC2_META_URL = "http://169.254.169.254/latest/meta-data/iam/security-credentials/" + THREAD_S3_SIGNED_URL_TTL_ATTRIBUTE = "URLCachedFileField__get_cache_ttl" @classmethod def get_cache_ttl(cls): - if getattr(get_storage_class()(), 'access_key', None) is not None: + if getattr(get_storage_class()(), "access_key", None) is not None: return settings.MAX_FILE_CACHE_AGE # Assume IAM Role is being used try: iam_role_resp = requests.get(cls.EC2_META_URL, timeout=0.01) if iam_role_resp.status_code == 200: - expiration = str_to_date( - requests.get(cls.EC2_META_URL + iam_role_resp.text, timeout=0.01).json()['Expiration'] - ) + expiration = str_to_date(requests.get(cls.EC2_META_URL + iam_role_resp.text, timeout=0.01).json()["Expiration"]) return max(0, expiration.timestamp() - timezone.now().timestamp()) except requests.exceptions.RequestException: - logger.error('Failed to retrive IAM Role session expiration.', exc_info=True) + logger.error("Failed to retrive IAM Role session expiration.", exc_info=True) # Avoid cache for now (This shouldn't happen) return 0 @@ -75,15 +72,15 @@ def process_view(self, request, view_function, *args, **kwargs): def _do_set_current_request(request_fun): - setattr(_threadlocal, 'request', request_fun.__get__(request_fun, threading.local)) + setattr(_threadlocal, "request", request_fun.__get__(request_fun, threading.local)) def _set_current_request(request=None): - ''' + """ Sets current user in local thread. Can be used as a hook e.g. for shell jobs (when request object is not available). - ''' + """ _do_set_current_request(lambda self: request) diff --git a/deep/models.py b/deep/models.py index 53560a52a6..115f32d240 100644 --- a/deep/models.py +++ b/deep/models.py @@ -8,29 +8,29 @@ class Field(models.Model): is_required = models.BooleanField(default=True) # Fields - STRING = 'string' - NUMBER = 'number' - DATE = 'date' - DATERANGE = 'daterange' - SELECT = 'select' - MULTISELECT = 'multiselect' + STRING = "string" + NUMBER = "number" + DATE = "date" + DATERANGE = "daterange" + SELECT = "select" + MULTISELECT = "multiselect" FIELD_TYPES = ( - (STRING, 'String'), - (NUMBER, 'Number'), - (DATE, 'Date'), - (DATERANGE, 'Date Range'), - (SELECT, 'Select'), - (MULTISELECT, 'Multiselect'), + (STRING, "String"), + (NUMBER, "Number"), + (DATE, "Date"), + (DATERANGE, "Date Range"), + (SELECT, "Select"), + (MULTISELECT, "Multiselect"), ) # Sources - COUNTRIES = 'countries' - ORGANIZATIONS = 'organizations' + COUNTRIES = "countries" + ORGANIZATIONS = "organizations" SOURCE_TYPES = ( - (COUNTRIES, 'Countries'), - (ORGANIZATIONS, 'Organizations'), + (COUNTRIES, "Countries"), + (ORGANIZATIONS, "Organizations"), ) field_type = models.CharField( @@ -42,7 +42,8 @@ class Field(models.Model): source_type = models.CharField( max_length=50, choices=SOURCE_TYPES, - null=True, blank=True, + null=True, + blank=True, default=None, ) @@ -54,11 +55,11 @@ class Meta: def get_options(self): if self.source_type in [type[0] for type in Field.SOURCE_TYPES]: return [] - return [{'key': x.key, 'title': x.title} for x in self.options.all()] + return [{"key": x.key, "title": x.title} for x in self.options.all()] def get_value(self, raw_value): value = raw_value - options = {x['key']: x['title'] for x in self.get_options()} + options = {x["key"]: x["title"] for x in self.get_options()} if self.field_type == Field.SELECT: value = options.get(raw_value, raw_value) elif self.field_type == Field.MULTISELECT: diff --git a/deep/number_generator.py b/deep/number_generator.py index 77030d3577..549cb7f4ae 100644 --- a/deep/number_generator.py +++ b/deep/number_generator.py @@ -1,6 +1,6 @@ -import string import random +import string def client_id_generator(size=16, chars=string.ascii_uppercase + string.digits): - return ''.join(random.choice(chars) for _ in range(size)) + return "".join(random.choice(chars) for _ in range(size)) diff --git a/deep/permalinks.py b/deep/permalinks.py index d17092070e..18a8da5ebd 100644 --- a/deep/permalinks.py +++ b/deep/permalinks.py @@ -4,40 +4,40 @@ class Permalink: # TODO: Add test for permalink generation - BASE_URL = f'{settings.HTTP_PROTOCOL}://{settings.DEEPER_FRONTEND_HOST}/permalink' + BASE_URL = f"{settings.HTTP_PROTOCOL}://{settings.DEEPER_FRONTEND_HOST}/permalink" @classmethod def project(cls, _id): - return f'{cls.BASE_URL}/projects/{_id}' + return f"{cls.BASE_URL}/projects/{_id}" @classmethod def lead(cls, project_id, _id): - return f'{cls.project(project_id)}/leads/{_id}' + return f"{cls.project(project_id)}/leads/{_id}" @classmethod def lead_share_view(cls, uuid): - return f'{cls.BASE_URL}/leads-uuid/{uuid}' + return f"{cls.BASE_URL}/leads-uuid/{uuid}" @classmethod def entry(cls, project_id, lead_id, _id): - return f'{cls.lead(project_id, lead_id)}/entries/{_id}' + return f"{cls.lead(project_id, lead_id)}/entries/{_id}" @classmethod def ientry(cls, entry): - return f'{cls.lead(entry.project_id, entry.lead_id)}/entries/{entry.id}' + return f"{cls.lead(entry.project_id, entry.lead_id)}/entries/{entry.id}" @classmethod def entry_comments(cls, project_id, lead_id, _id): - return f'{cls.entry(project_id, lead_id, _id)}/comments/' + return f"{cls.entry(project_id, lead_id, _id)}/comments/" @classmethod def ientry_comments(cls, entry): - return f'{cls.ientry(entry)}/comments/' + return f"{cls.ientry(entry)}/comments/" @classmethod def entry_comment(cls, project_id, lead_id, entry_id, _id): - return f'{cls.entry(project_id, lead_id, entry_id)}/review-comments/{_id}/' + return f"{cls.entry(project_id, lead_id, entry_id)}/review-comments/{_id}/" @classmethod def ientry_comment(cls, comment): - return f'{cls.ientry(comment.entry)}/review-comments/{comment.id}/' + return f"{cls.ientry(comment.entry)}/review-comments/{comment.id}/" diff --git a/deep/permissions.py b/deep/permissions.py index eb989b6bb5..5398257ed5 100644 --- a/deep/permissions.py +++ b/deep/permissions.py @@ -1,29 +1,23 @@ import logging -from typing import List -from enum import Enum, auto, unique from collections import defaultdict +from enum import Enum, auto, unique +from typing import List +from analysis.models import AnalysisPillar +from analysis_framework.models import AnalysisFrameworkRole from django.db.models import F +from entry.models import Entry +from lead.models import Lead +from project.models import Project, ProjectMembership, ProjectRole +from project.permissions import PROJECT_PERMISSIONS from rest_framework import permissions +from user_group.models import GroupMembership, UserGroup from deep.exceptions import PermissionDeniedException -from project.models import Project, ProjectRole, ProjectMembership -from analysis_framework.models import AnalysisFrameworkRole -from project.permissions import PROJECT_PERMISSIONS -from lead.models import Lead -from entry.models import Entry -from analysis.models import AnalysisPillar -from user_group.models import UserGroup, GroupMembership logger = logging.getLogger(__name__) -METHOD_ACTION_MAP = { - 'PUT': 'modify', - 'PATCH': 'modify', - 'GET': 'view', - 'POST': 'create', - 'DELETE': 'delete' -} +METHOD_ACTION_MAP = {"PUT": "modify", "PATCH": "modify", "GET": "view", "POST": "create", "DELETE": "delete"} class ModifyPermission(permissions.BasePermission): @@ -32,7 +26,7 @@ def has_object_permission(self, request, view, obj): return True action = METHOD_ACTION_MAP[request.method] - objmethod = 'can_{}'.format(action) + objmethod = "can_{}".format(action) if hasattr(obj, objmethod): return getattr(obj, objmethod)(request.user) @@ -41,11 +35,12 @@ def has_object_permission(self, request, view, obj): class CreateLeadPermission(permissions.BasePermission): """Permission class to check if user can create Lead""" + def has_permission(self, request, view): - if request.method != 'POST': + if request.method != "POST": return True # Check project and all - project_id = request.data.get('project') + project_id = request.data.get("project") # If there is no project id, the serializers will give 400 error, no need to forbid here if project_id is None: @@ -62,25 +57,26 @@ def has_permission(self, request, view): # Check if the user has create permissions on all projects # To do this, filter projects in which user has permissions and check if # the returned result length equals the queried projects length - projects_count = Project.objects.filter( - id__in=project_ids, - projectmembership__member=request.user - ).annotate( - create_lead=F('projectmembership__role__lead_permissions').bitand(create_lead_perm_value) - ).filter( - create_lead__gt=0, - ).count() + projects_count = ( + Project.objects.filter(id__in=project_ids, projectmembership__member=request.user) + .annotate(create_lead=F("projectmembership__role__lead_permissions").bitand(create_lead_perm_value)) + .filter( + create_lead__gt=0, + ) + .count() + ) return projects_count == len(project_ids) class DeleteLeadPermission(permissions.BasePermission): """Checks if user can delete lead(s)""" + def has_permission(self, request, view): - if request.method not in ('POST', 'DELETE'): + if request.method not in ("POST", "DELETE"): return True - project_id = view.kwargs.get('project_id') + project_id = view.kwargs.get("project_id") if not project_id: return False @@ -88,31 +84,32 @@ def has_permission(self, request, view): delete_lead_perm_value = PROJECT_PERMISSIONS.lead.delete # Check if the user has delete permissions on all projects - return Project.objects.filter( - id=project_id, - projectmembership__member=request.user - ).annotate( - delete_lead=F('projectmembership__role__lead_permissions').bitand(delete_lead_perm_value) - ).filter( - delete_lead__gt=0, - ).exists() + return ( + Project.objects.filter(id=project_id, projectmembership__member=request.user) + .annotate(delete_lead=F("projectmembership__role__lead_permissions").bitand(delete_lead_perm_value)) + .filter( + delete_lead__gt=0, + ) + .exists() + ) class CreateEntryPermission(permissions.BasePermission): """Permission class to check if user can create Lead""" + def get_project_id(self, request): """Try getting project id first from the data itself, if not try to get it from lead """ - project_id = request.data.get('project') + project_id = request.data.get("project") if project_id: return project_id # Else, get it from lead - lead = Lead.objects.filter(id=request.data.get('lead')).first() + lead = Lead.objects.filter(id=request.data.get("lead")).first() return lead and lead.project.id def has_permission(self, request, view): - if request.method != 'POST': + if request.method != "POST": return True # Get project id from request @@ -123,34 +120,39 @@ def has_permission(self, request, view): return False create_entry_perm_value = PROJECT_PERMISSIONS.entry.create - return ProjectRole.objects.annotate( - create_entry=F('entry_permissions').bitand(create_entry_perm_value) - ).filter( - projectmembership__project_id=project_id, - projectmembership__member=request.user, - create_entry__gt=0, - ).exists() + return ( + ProjectRole.objects.annotate(create_entry=F("entry_permissions").bitand(create_entry_perm_value)) + .filter( + projectmembership__project_id=project_id, + projectmembership__member=request.user, + create_entry__gt=0, + ) + .exists() + ) class CreateAssessmentPermission(permissions.BasePermission): """Permission class to check if user can create Lead""" + def has_permission(self, request, view): - if request.method != 'POST': + if request.method != "POST": return True # Check project and all - project_id = request.data.get('project') + project_id = request.data.get("project") # If there is no project id, the serializers will give 400 error, no need to forbid here if project_id is None: return True create_assmt_perm_value = PROJECT_PERMISSIONS.assessment.create - return ProjectRole.objects.annotate( - create_entry=F('assessment_permissions').bitand(create_assmt_perm_value) - ).filter( - projectmembership__project_id=project_id, - projectmembership__member=request.user, - create_entry__gt=0, - ).exists() + return ( + ProjectRole.objects.annotate(create_entry=F("assessment_permissions").bitand(create_assmt_perm_value)) + .filter( + projectmembership__project_id=project_id, + projectmembership__member=request.user, + create_entry__gt=0, + ) + .exists() + ) class IsSuperAdmin(permissions.BasePermission): @@ -162,13 +164,13 @@ def has_object_permission(self, request, view, obj): class IsProjectMember(permissions.BasePermission): - message = 'Only allowed for Project members' + message = "Only allowed for Project members" def has_permission(self, request, view): - project_id = view.kwargs.get('project_id') - lead_id = view.kwargs.get('lead_id') - entry_id = view.kwargs.get('entry_id') - analysis_pillar_id = view.kwargs.get('analysis_pillar_id') + project_id = view.kwargs.get("project_id") + lead_id = view.kwargs.get("lead_id") + entry_id = view.kwargs.get("entry_id") + analysis_pillar_id = view.kwargs.get("analysis_pillar_id") if project_id: return Project.get_for_member(request.user).filter(id=project_id).exists() @@ -178,8 +180,7 @@ def has_permission(self, request, view): return Entry.get_for(request.user).filter(id=entry_id).exists() elif analysis_pillar_id: return AnalysisPillar.objects.filter( - analysis__project__projectmembership__member=request.user, - id=analysis_pillar_id + analysis__project__projectmembership__member=request.user, id=analysis_pillar_id ).exists() return True @@ -190,10 +191,10 @@ def has_object_permission(self, request, view, obj): class IsUserGroupMember(permissions.BasePermission): - message = 'Only allowed for UserGroup members' + message = "Only allowed for UserGroup members" def has_permission(self, request, view): - user_group_id = view.kwargs.get('pk') + user_group_id = view.kwargs.get("pk") if user_group_id: return UserGroup.get_for_member(request.user).filter(id=user_group_id).exists() return True @@ -201,15 +202,17 @@ def has_permission(self, request, view): # ---------------------------- GRAPHQL Permissions ------------------------------ -class BasePermissions(): + +class BasePermissions: # ------------ Define this after using this as base ----------- @unique class Permission(Enum): pass + __error_message__ = {} PERMISSION_MAP = {} - CONTEXT_PERMISSION_ATTR = '' + CONTEXT_PERMISSION_ATTR = "" # ------------ Define this after using this as base ----------- DEFAULT_PERMISSION_DENIED_MESSAGE = PermissionDeniedException.default_message @@ -273,7 +276,7 @@ class Permission(Enum): CREATE_ASSESSMENT_REGISTRY = auto() UPDATE_ASSESSMENT_REGISTRY = auto() - Permission.__name__ = 'ProjectPermission' + Permission.__name__ = "ProjectPermission" __error_message__ = { Permission.UPDATE_PROJECT: "You don't have permission to update project", @@ -357,7 +360,7 @@ class Permission(Enum): REVERSE_PERMISSION_MAP[permission].append(_role_type) REVERSE_PERMISSION_MAP[permission.value].append(_role_type) - CONTEXT_PERMISSION_ATTR = 'project_permissions' + CONTEXT_PERMISSION_ATTR = "project_permissions" @classmethod def get_permissions(cls, project, user) -> List[Permission]: @@ -365,11 +368,7 @@ def get_permissions(cls, project, user) -> List[Permission]: badges = project.get_current_user_badges(user) or [] if role is None: return [] - badges_permissions = [ - cls.BADGES_PERMISSION_MAP[badge] - for badge in badges - if badge in cls.BADGES_PERMISSION_MAP - ] + badges_permissions = [cls.BADGES_PERMISSION_MAP[badge] for badge in badges if badge in cls.BADGES_PERMISSION_MAP] return [ *cls.PERMISSION_MAP.get(role, []), *badges_permissions, @@ -386,7 +385,7 @@ class Permission(Enum): CAN_USE_IN_OTHER_PROJECTS = auto() DELETE_FRAMEWORK = auto() - Permission.__name__ = 'AnalysisFrameworkPermission' + Permission.__name__ = "AnalysisFrameworkPermission" __error_message__ = { Permission.CAN_ADD_USER: "You don't have permission to add user", @@ -413,10 +412,9 @@ class Permission(Enum): AnalysisFrameworkRole.Type.PRIVATE_EDITOR: PRIVATE_EDITOR, AnalysisFrameworkRole.Type.PRIVATE_OWNER: PRIVATE_OWNER, AnalysisFrameworkRole.Type.PRIVATE_VIEWER: PRIVATE_VIEWER, - } - CONTEXT_PERMISSION_ATTR = 'af_permissions' + CONTEXT_PERMISSION_ATTR = "af_permissions" @classmethod def get_permissions(cls, role, is_public=False): @@ -431,7 +429,7 @@ class UserGroupPermissions(BasePermissions): class Permission(Enum): CAN_ADD_USER = auto() - Permission.__name__ = 'UserGroupPermission' + Permission.__name__ = "UserGroupPermission" __error_message__ = { Permission.CAN_ADD_USER: "You don't have permission to update memberships", @@ -445,7 +443,7 @@ class Permission(Enum): GroupMembership.Role.NORMAL: NORMAL, } - CONTEXT_PERMISSION_ATTR = 'ug_permissions' + CONTEXT_PERMISSION_ATTR = "ug_permissions" @classmethod def get_permissions(cls, role): diff --git a/deep/s3_storages.py b/deep/s3_storages.py index 803e45caaa..77378e8c8d 100644 --- a/deep/s3_storages.py +++ b/deep/s3_storages.py @@ -4,7 +4,7 @@ class StaticStorage(S3Boto3Storage): location = settings.STATICFILES_LOCATION - default_acl = 'public-read' + default_acl = "public-read" bucket_name = settings.AWS_STORAGE_BUCKET_NAME_STATIC querystring_auth = False diff --git a/deep/schema.py b/deep/schema.py index dd3fcdb751..95e2e71a30 100644 --- a/deep/schema.py +++ b/deep/schema.py @@ -1,6 +1,6 @@ import graphene -from graphene_django.debug import DjangoDebug from django.conf import settings +from graphene_django.debug import DjangoDebug # Importing for initialization (Make sure to import this before apps.<>) """ @@ -8,27 +8,34 @@ Make sure use string import outside graphene files. For eg: In filters.py use 'entry.schema.EntryListType' instead of `from entry.schema import EntryListType' """ -from .graphene_converter import * # type: ignore # noqa F401 -from utils.graphene.resolver import * # type: ignore # noqa F401 - -from project import schema as pj_schema, mutation as pj_mutation -from lead import public_schema as lead_public_schema -from analysis_framework import mutation as af_mutation, schema as af_schema from analysis import public_schema as analysis_public_schema -from user import mutation as user_mutation, schema as user_schema -from user_group import mutation as user_group_mutation, schema as user_group_schema -from organization import schema as organization_schema, mutation as organization_mutation -from geo import schema as geo_schema -from notification import schema as notification_schema, mutation as notification_mutation -from assisted_tagging import schema as assisted_tagging_schema -from unified_connector import schema as unified_connector_schema -from export import schema as export_schema, mutation as export_mutation +from analysis_framework import mutation as af_mutation +from analysis_framework import schema as af_schema from assessment_registry import mutation as assessment_registry_mutation from assessment_registry import schema as assessment_registry_schema +from assisted_tagging import schema as assisted_tagging_schema from deep_explore import schema as deep_explore_schema +from export import mutation as export_mutation +from export import schema as export_schema from gallery import mutations as gallery_mutation +from geo import schema as geo_schema +from lead import public_schema as lead_public_schema +from notification import mutation as notification_mutation +from notification import schema as notification_schema +from organization import mutation as organization_mutation +from organization import schema as organization_schema +from project import mutation as pj_mutation +from project import schema as pj_schema +from unified_connector import schema as unified_connector_schema +from user import mutation as user_mutation +from user import schema as user_schema +from user_group import mutation as user_group_mutation +from user_group import schema as user_group_schema from deep.enums import AppEnumCollection +from utils.graphene.resolver import * # type: ignore # noqa F401 + +from .graphene_converter import * # type: ignore # noqa F401 class Query( @@ -46,7 +53,7 @@ class Query( analysis_public_schema.Query, assessment_registry_schema.Query, # -- - graphene.ObjectType + graphene.ObjectType, ): assisted_tagging = graphene.Field(assisted_tagging_schema.AssistedTaggingRootQueryType) enums = graphene.Field(AppEnumCollection) @@ -74,7 +81,7 @@ class Mutation( assessment_registry_mutation.Mutation, organization_mutation.Mutation, # -- - graphene.ObjectType + graphene.ObjectType, ): pass diff --git a/deep/serializers.py b/deep/serializers.py index e0419b8799..957057e793 100644 --- a/deep/serializers.py +++ b/deep/serializers.py @@ -1,12 +1,16 @@ import json -from django.utils.functional import cached_property -from django.core.files.storage import FileSystemStorage, get_storage_class, default_storage -from django.core.serializers.json import DjangoJSONEncoder from django.core.cache import cache +from django.core.files.storage import ( + FileSystemStorage, + default_storage, + get_storage_class, +) +from django.core.serializers.json import DjangoJSONEncoder +from django.utils.functional import cached_property from rest_framework import serializers -from deep.caches import local_cache, CacheKey +from deep.caches import CacheKey, local_cache from deep.middleware import get_s3_signed_url_ttl StorageClass = get_storage_class() @@ -19,17 +23,10 @@ def remove_null(d): if isinstance(d, list): return [v for v in (remove_null(v) for v in d) if v is not None] - return { - k: v - for k, v in ( - (k, remove_null(v)) - for k, v in d.items() - ) - if v is not None - } + return {k: v for k, v in ((k, remove_null(v)) for k, v in d.items()) if v is not None} -class RemoveNullFieldsMixin(): +class RemoveNullFieldsMixin: def to_representation(self, instance): rep = super().to_representation(instance) return remove_null(rep) @@ -40,7 +37,7 @@ def to_internal_value(self, data): for field, field_type in self.fields.items(): if isinstance(field_type, serializers.CharField): if field in data and not data.get(field): - data[field] = '' + data[field] = "" return super().to_internal_value(data) @@ -129,15 +126,16 @@ def StringListField(): ) -class WriteOnlyOnCreateSerializerMixin(): +class WriteOnlyOnCreateSerializerMixin: """ Allow to define fields only writable on creation """ + def get_fields(self, *args, **kwargs): fields = super().get_fields(*args, **kwargs) - write_only_on_create_fields = getattr(self.Meta, 'write_only_on_create_fields', []) - request = self.context.get('request', None) - if request and getattr(request, 'method', None) != 'POST': + write_only_on_create_fields = getattr(self.Meta, "write_only_on_create_fields", []) + request = self.context.get("request", None) + if request and getattr(request, "method", None) != "POST": for field in write_only_on_create_fields: fields[field].read_only = True return fields @@ -147,6 +145,7 @@ class TempClientIdMixin(serializers.ModelSerializer): """ ClientId for serializer level only, storing to database is optional (if field exists). """ + client_id = serializers.CharField(required=False) @staticmethod @@ -159,14 +158,14 @@ def get_cache_key(instance, request): def _get_temp_client_id(self, validated_data): # For now, let's not save anything. Look at history if not. - return validated_data.pop('client_id', None) + return validated_data.pop("client_id", None) def create(self, validated_data): temp_client_id = self._get_temp_client_id(validated_data) instance = super().create(validated_data) if temp_client_id: instance.client_id = temp_client_id - local_cache.set(self.get_cache_key(instance, self.context['request']), temp_client_id, 60) + local_cache.set(self.get_cache_key(instance, self.context["request"]), temp_client_id, 60) return instance def update(self, instance, validated_data): @@ -174,7 +173,7 @@ def update(self, instance, validated_data): instance = super().update(instance, validated_data) if temp_client_id: instance.client_id = temp_client_id - local_cache.set(self.get_cache_key(instance, self.context['request']), temp_client_id, 60) + local_cache.set(self.get_cache_key(instance, self.context["request"]), temp_client_id, 60) return instance @@ -183,19 +182,19 @@ class ProjectPropertySerializerMixin(serializers.Serializer): @cached_property def project(self): - project = self.context['request'].active_project + project = self.context["request"].active_project # This is a rare case, just to make sure this is validated if self.instance: model_with_project = self.instance if self.project_property_attribute: model_with_project = getattr(self.instance, self.project_property_attribute) if model_with_project is None or model_with_project.project != project: - raise serializers.ValidationError('Invalid access') + raise serializers.ValidationError("Invalid access") return project @cached_property def current_user(self): - return self.context['request'].user + return self.context["request"].user class IntegerIDField(serializers.IntegerField): @@ -203,6 +202,7 @@ class IntegerIDField(serializers.IntegerField): This field is created to override the graphene conversion of the integerfield -> graphene.ID check out utils/graphene/mutation.py """ + pass @@ -211,6 +211,7 @@ class StringIDField(serializers.CharField): This field is created to override the graphene conversion of the charField -> graphene.ID check out utils/graphene/mutation.py """ + pass @@ -221,12 +222,12 @@ def __init__(self, **kwargs): def to_internal_value(self, data): try: - if self.binary or getattr(data, 'is_json_string', False): + if self.binary or getattr(data, "is_json_string", False): if isinstance(data, bytes): data = data.decode() return json.loads(data, cls=self.decoder) else: data = json.loads(json.dumps(data, cls=self.encoder)) except (TypeError, ValueError): - self.fail('invalid') + self.fail("invalid") return data diff --git a/deep/ses.py b/deep/ses.py index 9138775cde..b501f10737 100644 --- a/deep/ses.py +++ b/deep/ses.py @@ -1,21 +1,18 @@ import json -import typing import logging +import typing from django.http import JsonResponse from django.views.decorators.csrf import csrf_exempt - from sns_message_validator import ( - InvalidMessageTypeException, InvalidCertURLException, + InvalidMessageTypeException, InvalidSignatureVersionException, SignatureVerificationFailureException, SNSMessageValidator, ) - from user.models import Profile - logger = logging.getLogger(__name__) @@ -24,47 +21,47 @@ def verify_sns_payload(request) -> typing.Tuple[str, int]: # Validate message type from header without having to parse the request body. - message_type = request.headers.get('x-amz-sns-message-type') + message_type = request.headers.get("x-amz-sns-message-type") try: sns_message_validator.validate_message_type(message_type) - message = json.loads(request.body.decode('utf-8')) + message = json.loads(request.body.decode("utf-8")) sns_message_validator.validate_message(message=message) except InvalidMessageTypeException: - return 'Invalid message type.', 400 + return "Invalid message type.", 400 except json.decoder.JSONDecodeError: - return 'Request body is not in json format.', 400 + return "Request body is not in json format.", 400 except InvalidCertURLException: - return 'Invalid certificate URL.', 400 + return "Invalid certificate URL.", 400 except InvalidSignatureVersionException: - return 'Unexpected signature version.', 400 + return "Unexpected signature version.", 400 except SignatureVerificationFailureException: - return 'Failed to verify signature.', 400 - return 'Success', 200 + return "Failed to verify signature.", 400 + return "Success", 200 @csrf_exempt def ses_bounce_handler_view(request): - if request.method != 'POST': - return JsonResponse({'message': f'{request.method} Method not allowed'}, status=405) + if request.method != "POST": + return JsonResponse({"message": f"{request.method} Method not allowed"}, status=405) error_message, status_code = verify_sns_payload(request) if status_code != 200: - logger.warning(f'Failed to handle bounce request: {error_message}') - return JsonResponse({'message': error_message}, status=status_code) + logger.warning(f"Failed to handle bounce request: {error_message}") + return JsonResponse({"message": error_message}, status=status_code) - body = json.loads(request.body.decode('utf-8')) - if 'SubscribeURL' in body: + body = json.loads(request.body.decode("utf-8")) + if "SubscribeURL" in body: logger.warning(f'Verify subscription using this url: {body["SubscribeURL"]}') - return JsonResponse({'message': 'Logged'}, status=200) + return JsonResponse({"message": "Logged"}, status=200) - message = json.loads(body['Message']) - notification_type = message['notificationType'] - if notification_type == 'Bounce': - recipients = message['bounce']['bouncedRecipients'] - bounce_type = message['bounce']['bounceType'] - if bounce_type == 'Permanent': + message = json.loads(body["Message"]) + notification_type = message["notificationType"] + if notification_type == "Bounce": + recipients = message["bounce"]["bouncedRecipients"] + bounce_type = message["bounce"]["bounceType"] + if bounce_type == "Permanent": for recipient in recipients: - email_address = recipient['emailAddress'] + email_address = recipient["emailAddress"] Profile.objects.filter(user__email__iexact=email_address).update(invalid_email=True) - logger.warning(f'Flagged {email_address} as invalid email') - return JsonResponse({'message': 'Success'}, status=200) + logger.warning(f"Flagged {email_address} as invalid email") + return JsonResponse({"message": "Success"}, status=200) diff --git a/deep/settings.py b/deep/settings.py index cc6b76b84d..1d98826418 100644 --- a/deep/settings.py +++ b/deep/settings.py @@ -1,34 +1,36 @@ """ Django settings for deep project. """ + +import json +import logging import os import sys -import logging -import json +from email.utils import parseaddr + import environ from celery.schedules import crontab -from email.utils import parseaddr from utils import sentry -from utils.aws import fetch_db_credentials_from_secret_arn, get_internal_ip as get_aws_internal_ip - +from utils.aws import fetch_db_credentials_from_secret_arn +from utils.aws import get_internal_ip as get_aws_internal_ip # Build paths inside the project like this: os.path.join(BASE_DIR, ...) BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -APPS_DIR = os.path.join(BASE_DIR, 'apps') -TEMP_DIR = '/tmp' +APPS_DIR = os.path.join(BASE_DIR, "apps") +TEMP_DIR = "/tmp" # TODO: Make sure to pull as much from env then default values. env = environ.Env( DJANGO_DEBUG=(bool, False), DJANGO_SECRET_KEY=str, - DEEP_ENVIRONMENT=(str, 'development'), + DEEP_ENVIRONMENT=(str, "development"), SERVICE_ENVIRONMENT_TYPE=str, DEEP_FRONTEND_ARY_HOST=str, DEEP_FRONTEND_HOST=str, DEEP_BACKEND_HOST=str, DJANGO_ALLOWED_HOST=str, - DEEPER_SITE_NAME=(str, 'DEEPER'), + DEEPER_SITE_NAME=(str, "DEEPER"), CORS_ALLOWED_ORIGINS=(list, []), # Database DATABASE_NAME=str, @@ -36,7 +38,7 @@ DATABASE_PASSWORD=str, DATABASE_PORT=str, DATABASE_HOST=str, - DATABASE_SSL_MODE=(str, 'prefer'), # Use `require` in production + DATABASE_SSL_MODE=(str, "prefer"), # Use `require` in production # S3 DJANGO_USE_S3=(bool, False), S3_AWS_ACCESS_KEY_ID=(str, None), @@ -55,7 +57,7 @@ HID_AUTH_URI=str, # Email EMAIL_FROM=str, - DJANGO_ADMINS=(list, ['Admin <admin@thedeep.io>']), + DJANGO_ADMINS=(list, ["Admin <admin@thedeep.io>"]), USE_SES_EMAIL_CONFIG=(bool, False), SES_AWS_ACCESS_KEY_ID=(str, None), SES_AWS_SECRET_ACCESS_KEY=(str, None), @@ -66,12 +68,12 @@ SMTP_EMAIL_USERNAME=str, SMTP_EMAIL_PASSWORD=str, # Hcaptcha - HCAPTCHA_SECRET=(str, '0x0000000000000000000000000000000000000000'), + HCAPTCHA_SECRET=(str, "0x0000000000000000000000000000000000000000"), # Sentry SENTRY_DSN=(str, None), SENTRY_SAMPLE_RATE=(float, 0.2), # Security settings - DEEP_HTTPS=(str, 'http'), + DEEP_HTTPS=(str, "http"), # CSRF_TRUSTED_ORIGINS=(bool, False), SESSION_COOKIE_DOMAIN=str, CSRF_COOKIE_DOMAIN=str, @@ -107,202 +109,202 @@ # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/ # SECURITY WARNING: keep the secret key used in production secret! -SECRET_KEY = env('DJANGO_SECRET_KEY') +SECRET_KEY = env("DJANGO_SECRET_KEY") # SECURITY WARNING: don't run with debug turned on in production! -DEBUG = env('DJANGO_DEBUG') +DEBUG = env("DJANGO_DEBUG") -DEEP_ENVIRONMENT = env('COPILOT_ENVIRONMENT_NAME') or env('DEEP_ENVIRONMENT') +DEEP_ENVIRONMENT = env("COPILOT_ENVIRONMENT_NAME") or env("DEEP_ENVIRONMENT") -ALLOWED_HOSTS = ['web', env('DJANGO_ALLOWED_HOST')] +ALLOWED_HOSTS = ["web", env("DJANGO_ALLOWED_HOST")] -DEEPER_FRONTEND_HOST = env('DEEP_FRONTEND_HOST') -DEEPER_FRONTEND_ARY_HOST = env('DEEP_FRONTEND_ARY_HOST') # TODO: Remove this later -DJANGO_API_HOST = env('DEEP_BACKEND_HOST') +DEEPER_FRONTEND_HOST = env("DEEP_FRONTEND_HOST") +DEEPER_FRONTEND_ARY_HOST = env("DEEP_FRONTEND_ARY_HOST") # TODO: Remove this later +DJANGO_API_HOST = env("DEEP_BACKEND_HOST") -DEEPER_SITE_NAME = env('DEEPER_SITE_NAME') -HTTP_PROTOCOL = env('DEEP_HTTPS') +DEEPER_SITE_NAME = env("DEEPER_SITE_NAME") +HTTP_PROTOCOL = env("DEEP_HTTPS") # See if we are inside a test environment (pytest) -PYTEST_XDIST_WORKER = env('PYTEST_XDIST_WORKER') -TESTING = any([ - arg in sys.argv for arg in [ - 'test', - 'pytest', '/usr/local/bin/pytest', - 'py.test', '/usr/local/bin/py.test', - '/usr/local/lib/python3.6/dist-packages/py/test.py', - ] - # Provided by pytest-xdist -]) or PYTEST_XDIST_WORKER is not None -TEST_RUNNER = 'snapshottest.django.TestRunner' -TEST_DIR = os.path.join(BASE_DIR, 'deep/test_files') +PYTEST_XDIST_WORKER = env("PYTEST_XDIST_WORKER") +TESTING = ( + any( + [ + arg in sys.argv + for arg in [ + "test", + "pytest", + "/usr/local/bin/pytest", + "py.test", + "/usr/local/bin/py.test", + "/usr/local/lib/python3.6/dist-packages/py/test.py", + ] + # Provided by pytest-xdist + ] + ) + or PYTEST_XDIST_WORKER is not None +) +TEST_RUNNER = "snapshottest.django.TestRunner" +TEST_DIR = os.path.join(BASE_DIR, "deep/test_files") -PROFILE = env('PROFILE') +PROFILE = env("PROFILE") # Application definition LOCAL_APPS = [ # DEEP APPS - 'analysis', - 'analysis_framework', - 'ary', - 'assessment_registry', - 'category_editor', - 'connector', - 'deep_migration', - 'entry', - 'export', - 'gallery', - 'geo', - 'lang', - 'lead', - 'organization', - 'project', - 'user', - 'user_group', - 'user_resource', - 'tabular', - 'notification', - 'client_page_meta', - 'questionnaire', - 'quality_assurance', - 'unified_connector', - 'assisted_tagging', - 'deep_explore', - 'deepl_integration', - + "analysis", + "analysis_framework", + "ary", + "assessment_registry", + "category_editor", + "connector", + "deep_migration", + "entry", + "export", + "gallery", + "geo", + "lang", + "lead", + "organization", + "project", + "user", + "user_group", + "user_resource", + "tabular", + "notification", + "client_page_meta", + "questionnaire", + "quality_assurance", + "unified_connector", + "assisted_tagging", + "deep_explore", + "deepl_integration", # MISC DEEP APPS - 'bulk_data_migration', - 'profiling', - 'commons', - 'redis_store', - 'jwt_auth', - 'deduplication', + "bulk_data_migration", + "profiling", + "commons", + "redis_store", + "jwt_auth", + "deduplication", ] INSTALLED_APPS = [ # DJANGO APPS - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.messages', - 'django.contrib.sessions', - 'django.contrib.staticfiles', - 'django.contrib.gis', - 'django.contrib.postgres', - + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.messages", + "django.contrib.sessions", + "django.contrib.staticfiles", + "django.contrib.gis", + "django.contrib.postgres", # LIBRARIES - # -> 2-factor-auth - 'django_otp', - 'django_otp.plugins.otp_static', - 'django_otp.plugins.otp_email', - 'django_otp.plugins.otp_totp', - - 'admin_auto_filters', - 'ordered_model', - 'fixture_magic', - 'autofixture', - 'corsheaders', - 'crispy_forms', - 'django_filters', - 'djangorestframework_camel_case', - 'drf_dynamic_fields', - 'rest_framework', - 'generic_relations', # DRF Generic relations - 'reversion', - 'storages', - 'django_premailer', - 'django_celery_beat', - 'jsoneditor', - 'drf_yasg', # API Documentation - 'graphene_django', - 'graphene_graphiql_explorer', + "django_otp", + "django_otp.plugins.otp_static", + "django_otp.plugins.otp_email", + "django_otp.plugins.otp_totp", + "admin_auto_filters", + "ordered_model", + "fixture_magic", + "autofixture", + "corsheaders", + "crispy_forms", + "django_filters", + "djangorestframework_camel_case", + "drf_dynamic_fields", + "rest_framework", + "generic_relations", # DRF Generic relations + "reversion", + "storages", + "django_premailer", + "django_celery_beat", + "jsoneditor", + "drf_yasg", # API Documentation + "graphene_django", + "graphene_graphiql_explorer", ] + [ - '{}.{}.apps.{}Config'.format( - APPS_DIR.split('/')[-1], + "{}.{}.apps.{}Config".format( + APPS_DIR.split("/")[-1], app, - ''.join([word.title() for word in app.split('_')]), - ) for app in LOCAL_APPS + "".join([word.title() for word in app.split("_")]), + ) + for app in LOCAL_APPS ] MIDDLEWARE = [ - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'corsheaders.middleware.CorsMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django_otp.middleware.OTPMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', - 'deep.middleware.RevisionMiddleware', - 'deep.middleware.DeepInnerCacheMiddleware', - 'deep.middleware.RequestMiddleware', + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "corsheaders.middleware.CorsMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django_otp.middleware.OTPMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", + "deep.middleware.RevisionMiddleware", + "deep.middleware.DeepInnerCacheMiddleware", + "deep.middleware.RequestMiddleware", ] -ROOT_URLCONF = 'deep.urls' +ROOT_URLCONF = "deep.urls" TEMPLATES = [ { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [os.path.join(APPS_DIR, 'templates')], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', - 'deep.context_processor.deep', + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [os.path.join(APPS_DIR, "templates")], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + "deep.context_processor.deep", ], }, }, ] -WSGI_APPLICATION = 'deep.wsgi.application' +WSGI_APPLICATION = "deep.wsgi.application" -IN_AWS_COPILOT_ECS = not not env('COPILOT_SERVICE_NAME') +IN_AWS_COPILOT_ECS = not not env("COPILOT_SERVICE_NAME") -if IN_AWS_COPILOT_ECS and env('SERVICE_ENVIRONMENT_TYPE') == 'web': - ALLOWED_HOSTS.append( - get_aws_internal_ip(env('SERVICE_ENVIRONMENT_TYPE')) - ) +if IN_AWS_COPILOT_ECS and env("SERVICE_ENVIRONMENT_TYPE") == "web": + ALLOWED_HOSTS.append(get_aws_internal_ip(env("SERVICE_ENVIRONMENT_TYPE"))) # Database # https://docs.djangoproject.com/en/1.11/ref/settings/#databases if IN_AWS_COPILOT_ECS: - DBCLUSTER_SECRET = ( - env.json('DEEP_DATABASE_SECRET') or - fetch_db_credentials_from_secret_arn(env('DEEP_DATABASE_SECRET_ARN')) - ) + DBCLUSTER_SECRET = env.json("DEEP_DATABASE_SECRET") or fetch_db_credentials_from_secret_arn(env("DEEP_DATABASE_SECRET_ARN")) DATABASES = { - 'default': { - 'ENGINE': 'django.contrib.gis.db.backends.postgis', + "default": { + "ENGINE": "django.contrib.gis.db.backends.postgis", # in the workflow environment - 'NAME': DBCLUSTER_SECRET['dbname'], - 'USER': DBCLUSTER_SECRET['username'], - 'PASSWORD': DBCLUSTER_SECRET['password'], - 'HOST': DBCLUSTER_SECRET['host'], - 'PORT': DBCLUSTER_SECRET['port'], - 'OPTIONS': { - 'sslmode': 'require', + "NAME": DBCLUSTER_SECRET["dbname"], + "USER": DBCLUSTER_SECRET["username"], + "PASSWORD": DBCLUSTER_SECRET["password"], + "HOST": DBCLUSTER_SECRET["host"], + "PORT": DBCLUSTER_SECRET["port"], + "OPTIONS": { + "sslmode": "require", }, } } else: DATABASES = { - 'default': { - 'ENGINE': 'django.contrib.gis.db.backends.postgis', - 'NAME': env('DATABASE_NAME'), - 'USER': env('DATABASE_USER'), - 'PASSWORD': env('DATABASE_PASSWORD'), - 'PORT': env('DATABASE_PORT'), - 'HOST': env('DATABASE_HOST'), - 'OPTIONS': { - 'sslmode': env('DATABASE_SSL_MODE'), + "default": { + "ENGINE": "django.contrib.gis.db.backends.postgis", + "NAME": env("DATABASE_NAME"), + "USER": env("DATABASE_USER"), + "PASSWORD": env("DATABASE_PASSWORD"), + "PORT": env("DATABASE_PORT"), + "HOST": env("DATABASE_HOST"), + "OPTIONS": { + "sslmode": env("DATABASE_SSL_MODE"), }, } } @@ -312,76 +314,63 @@ AUTH_PASSWORD_VALIDATORS = [ { - 'NAME': 'django.contrib.auth.password_validation.' - 'UserAttributeSimilarityValidator', + "NAME": "django.contrib.auth.password_validation." "UserAttributeSimilarityValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.' - 'MinimumLengthValidator', + "NAME": "django.contrib.auth.password_validation." "MinimumLengthValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.' - 'CommonPasswordValidator', + "NAME": "django.contrib.auth.password_validation." "CommonPasswordValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.' - 'NumericPasswordValidator', + "NAME": "django.contrib.auth.password_validation." "NumericPasswordValidator", }, # NOTE: Using django admin panel for password reset/change { - 'NAME': 'user.validators.CustomMaximumLengthValidator', + "NAME": "user.validators.CustomMaximumLengthValidator", }, ] # Authentication REST_FRAMEWORK = { - 'DEFAULT_AUTHENTICATION_CLASSES': ( + "DEFAULT_AUTHENTICATION_CLASSES": ( # TODO: REMOVE THIS!! User client to authenticate. - 'rest_framework.authentication.BasicAuthentication', - 'rest_framework.authentication.SessionAuthentication', + "rest_framework.authentication.BasicAuthentication", + "rest_framework.authentication.SessionAuthentication", # 'jwt_auth.authentication.JwtAuthentication', ), - 'EXCEPTION_HANDLER': 'deep.exception_handler.custom_exception_handler', - 'DEFAULT_RENDERER_CLASSES': [ - 'djangorestframework_camel_case.render.CamelCaseJSONRenderer', + "EXCEPTION_HANDLER": "deep.exception_handler.custom_exception_handler", + "DEFAULT_RENDERER_CLASSES": [ + "djangorestframework_camel_case.render.CamelCaseJSONRenderer", ], - 'DEFAULT_PARSER_CLASSES': ( - 'djangorestframework_camel_case.parser.CamelCaseJSONParser', - 'djangorestframework_camel_case.parser.CamelCaseFormParser', - 'djangorestframework_camel_case.parser.CamelCaseMultiPartParser', + "DEFAULT_PARSER_CLASSES": ( + "djangorestframework_camel_case.parser.CamelCaseJSONParser", + "djangorestframework_camel_case.parser.CamelCaseFormParser", + "djangorestframework_camel_case.parser.CamelCaseMultiPartParser", ), - 'JSON_UNDERSCOREIZE': { - 'no_underscore_before_number': True, + "JSON_UNDERSCOREIZE": { + "no_underscore_before_number": True, }, - - 'DEFAULT_VERSIONING_CLASS': - 'rest_framework.versioning.URLPathVersioning', - 'DEFAULT_FILTER_BACKENDS': ( - 'django_filters.rest_framework.DjangoFilterBackend', - ), - - 'DEFAULT_PAGINATION_CLASS': - 'rest_framework.pagination.LimitOffsetPagination', - 'PAGE_SIZE': 10000, - - 'TEST_REQUEST_DEFAULT_FORMAT': 'json', + "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning", + "DEFAULT_FILTER_BACKENDS": ("django_filters.rest_framework.DjangoFilterBackend",), + "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.LimitOffsetPagination", + "PAGE_SIZE": 10000, + "TEST_REQUEST_DEFAULT_FORMAT": "json", } if DEBUG: - REST_FRAMEWORK['DEFAULT_RENDERER_CLASSES'].append( - 'rest_framework.renderers.BrowsableAPIRenderer' - ) + REST_FRAMEWORK["DEFAULT_RENDERER_CLASSES"].append("rest_framework.renderers.BrowsableAPIRenderer") # Crispy forms for better django filters rendering -CRISPY_TEMPLATE_PACK = 'bootstrap3' +CRISPY_TEMPLATE_PACK = "bootstrap3" -DEFAULT_VERSION = 'v1' +DEFAULT_VERSION = "v1" # Internationalization # https://docs.djangoproject.com/en/1.11/topics/i18n/ -LANGUAGE_CODE = 'en-us' +LANGUAGE_CODE = "en-us" -TIME_ZONE = 'UTC' +TIME_ZONE = "UTC" USE_I18N = True @@ -390,9 +379,9 @@ USE_TZ = True LANGUAGES = ( - ('en-us', 'English (US)'), - ('es-ES', 'Spanish'), - ('np', 'Nepali'), + ("en-us", "English (US)"), + ("es-ES", "Spanish"), + ("np", "Nepali"), ) @@ -403,199 +392,201 @@ # NOTE: S3 have max 7 days for signed url (https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html) # noqa GALLERY_FILE_EXPIRE = 60 * 60 * 24 * 2 -if env('DJANGO_USE_S3'): +if env("DJANGO_USE_S3"): # AWS S3 Bucket Credentials - AWS_STORAGE_BUCKET_NAME_STATIC = env('AWS_STORAGE_BUCKET_NAME_STATIC') - AWS_STORAGE_BUCKET_NAME_MEDIA = env('AWS_STORAGE_BUCKET_NAME_MEDIA') + AWS_STORAGE_BUCKET_NAME_STATIC = env("AWS_STORAGE_BUCKET_NAME_STATIC") + AWS_STORAGE_BUCKET_NAME_MEDIA = env("AWS_STORAGE_BUCKET_NAME_MEDIA") # If environment variable are not provided, then EC2 Role will be used. - AWS_S3_SECRET = ( - env.json('DEEP_BUCKET_ACCESS_USER_SECRET') or - ( - env('DEEP_BUCKET_ACCESS_USER_SECRET_ARN') and - fetch_db_credentials_from_secret_arn(env('DEEP_BUCKET_ACCESS_USER_SECRET_ARN'), ignore_error=True) - ) + AWS_S3_SECRET = env.json("DEEP_BUCKET_ACCESS_USER_SECRET") or ( + env("DEEP_BUCKET_ACCESS_USER_SECRET_ARN") + and fetch_db_credentials_from_secret_arn(env("DEEP_BUCKET_ACCESS_USER_SECRET_ARN"), ignore_error=True) ) if AWS_S3_SECRET: - AWS_ACCESS_KEY_ID = AWS_S3_SECRET['AccessKeyId'] - AWS_SECRET_ACCESS_KEY = AWS_S3_SECRET['SecretAccessKey'] + AWS_ACCESS_KEY_ID = AWS_S3_SECRET["AccessKeyId"] + AWS_SECRET_ACCESS_KEY = AWS_S3_SECRET["SecretAccessKey"] else: - AWS_ACCESS_KEY_ID = env('S3_AWS_ACCESS_KEY_ID') - AWS_SECRET_ACCESS_KEY = env('S3_AWS_SECRET_ACCESS_KEY') - AWS_S3_ENDPOINT_URL = env('S3_AWS_ENDPOINT_URL') if DEBUG else None + AWS_ACCESS_KEY_ID = env("S3_AWS_ACCESS_KEY_ID") + AWS_SECRET_ACCESS_KEY = env("S3_AWS_SECRET_ACCESS_KEY") + AWS_S3_ENDPOINT_URL = env("S3_AWS_ENDPOINT_URL") if DEBUG else None AWS_S3_FILE_OVERWRITE = False - AWS_DEFAULT_ACL = 'private' + AWS_DEFAULT_ACL = "private" AWS_QUERYSTRING_AUTH = True AWS_S3_CUSTOM_DOMAIN = None AWS_QUERYSTRING_EXPIRE = GALLERY_FILE_EXPIRE - AWS_S3_SIGNATURE_VERSION = 's3v4' + AWS_S3_SIGNATURE_VERSION = "s3v4" AWS_IS_GZIPPED = True GZIP_CONTENT_TYPES = [ - 'text/css', 'text/javascript', 'application/javascript', 'application/x-javascript', 'image/svg+xml', - 'application/json', + "text/css", + "text/javascript", + "application/javascript", + "application/x-javascript", + "image/svg+xml", + "application/json", ] # Static configuration - STATICFILES_LOCATION = 'static' + STATICFILES_LOCATION = "static" STATIC_URL = "https://%s/%s/" % (AWS_S3_CUSTOM_DOMAIN, STATICFILES_LOCATION) - STATICFILES_STORAGE = 'deep.s3_storages.StaticStorage' + STATICFILES_STORAGE = "deep.s3_storages.StaticStorage" # Media configuration - MEDIAFILES_LOCATION = 'media' + MEDIAFILES_LOCATION = "media" MEDIA_URL = "https://%s/%s/" % (AWS_S3_CUSTOM_DOMAIN, MEDIAFILES_LOCATION) - DEFAULT_FILE_STORAGE = 'deep.s3_storages.MediaStorage' + DEFAULT_FILE_STORAGE = "deep.s3_storages.MediaStorage" else: - STATIC_URL = '/static/' - STATIC_ROOT = '/static' + STATIC_URL = "/static/" + STATIC_ROOT = "/static" - MEDIA_URL = '/media/' - MEDIA_ROOT = '/media' + MEDIA_URL = "/media/" + MEDIA_ROOT = "/media" STATICFILES_DIRS = [ - os.path.join(APPS_DIR, 'static'), + os.path.join(APPS_DIR, "static"), ] if IN_AWS_COPILOT_ECS: ELASTIC_REDIS_URL = f"redis://{env('ELASTI_CACHE_ADDRESS')}:{env('ELASTI_CACHE_PORT')}" - CELERY_REDIS_URL = f'{ELASTIC_REDIS_URL}/0' - DJANGO_CACHE_REDIS_URL = f'{ELASTIC_REDIS_URL}/1' + CELERY_REDIS_URL = f"{ELASTIC_REDIS_URL}/0" + DJANGO_CACHE_REDIS_URL = f"{ELASTIC_REDIS_URL}/1" else: - CELERY_REDIS_URL = env('CELERY_REDIS_URL') - DJANGO_CACHE_REDIS_URL = env('DJANGO_CACHE_REDIS_URL') + CELERY_REDIS_URL = env("CELERY_REDIS_URL") + DJANGO_CACHE_REDIS_URL = env("DJANGO_CACHE_REDIS_URL") -TEST_DJANGO_CACHE_REDIS_URL = env('TEST_DJANGO_CACHE_REDIS_URL') +TEST_DJANGO_CACHE_REDIS_URL = env("TEST_DJANGO_CACHE_REDIS_URL") # CELERY CONFIG "redis://:{password}@{host}:{port}/{db}" CELERY_BROKER_URL = CELERY_REDIS_URL CELERY_RESULT_BACKEND = CELERY_REDIS_URL CELERY_TIMEZONE = TIME_ZONE -CELERY_EVENT_QUEUE_PREFIX = 'deep-celery-' +CELERY_EVENT_QUEUE_PREFIX = "deep-celery-" CELERY_ACKS_LATE = True CELERY_BEAT_SCHEDULE = { - 'retry_connector_leads': { - 'task': 'unified_connector.tasks.retry_connector_leads', + "retry_connector_leads": { + "task": "unified_connector.tasks.retry_connector_leads", # Every 2 hour - 'schedule': crontab(minute=0, hour='*/2'), + "schedule": crontab(minute=0, hour="*/2"), }, - 'sync_tag_data_with_deepl': { - 'task': 'assisted_tagging.tasks.sync_tags_with_deepl_task', + "sync_tag_data_with_deepl": { + "task": "assisted_tagging.tasks.sync_tags_with_deepl_task", # Every 6 hour - 'schedule': crontab(minute=0, hour='*/6'), + "schedule": crontab(minute=0, hour="*/6"), }, - 'remaining_tabular_generate_columns_image': { - 'task': 'tabular.tasks.remaining_tabular_generate_columns_image', + "remaining_tabular_generate_columns_image": { + "task": "tabular.tasks.remaining_tabular_generate_columns_image", # Every 6 hour - 'schedule': crontab(minute=0, hour='*/6'), + "schedule": crontab(minute=0, hour="*/6"), }, - 'project_generate_stats': { - 'task': 'project.tasks.generate_project_stats_cache', + "project_generate_stats": { + "task": "project.tasks.generate_project_stats_cache", # Every 5 min - 'schedule': crontab(minute="*/5"), + "schedule": crontab(minute="*/5"), }, # UNIFIED CONNECTORS - 'schedule_trigger_quick_unified_connectors': { - 'task': 'unified_connector.tasks.schedule_trigger_quick_unified_connectors', + "schedule_trigger_quick_unified_connectors": { + "task": "unified_connector.tasks.schedule_trigger_quick_unified_connectors", # Every 1 hour - 'schedule': crontab(hour="*/1"), + "schedule": crontab(hour="*/1"), }, - 'schedule_trigger_heavy_unified_connectors': { - 'task': 'unified_connector.tasks.schedule_trigger_heavy_unified_connectors', + "schedule_trigger_heavy_unified_connectors": { + "task": "unified_connector.tasks.schedule_trigger_heavy_unified_connectors", # Every 1 hour - 'schedule': crontab(hour="*/1"), + "schedule": crontab(hour="*/1"), }, - 'schedule_trigger_super_heavy_unified_connectors': { - 'task': 'unified_connector.tasks.schedule_trigger_super_heavy_unified_connectors', + "schedule_trigger_super_heavy_unified_connectors": { + "task": "unified_connector.tasks.schedule_trigger_super_heavy_unified_connectors", # Every 6 hours - 'schedule': crontab(hour="*/6"), + "schedule": crontab(hour="*/6"), }, - 'schedule_trigger_remaining_lead_extract': { - 'task': 'lead.tasks.remaining_lead_extract', + "schedule_trigger_remaining_lead_extract": { + "task": "lead.tasks.remaining_lead_extract", # Every 6 hours - 'schedule': crontab(hour="*/6"), + "schedule": crontab(hour="*/6"), }, # Project Deletion - 'permanently_delete_projects': { - 'task': 'project.tasks.permanently_delete_projects', - 'schedule': crontab(minute=0, hour=0), # execute every day + "permanently_delete_projects": { + "task": "project.tasks.permanently_delete_projects", + "schedule": crontab(minute=0, hour=0), # execute every day }, # User Deletion - 'permanently_delete_users': { - 'task': 'project.tasks.permanently_delete_users', - 'schedule': crontab(minute=0, hour=0), + "permanently_delete_users": { + "task": "project.tasks.permanently_delete_users", + "schedule": crontab(minute=0, hour=0), }, # Organization - 'update_organization_popularity': { - 'task': 'organization.tasks.update_organization_popularity', - 'schedule': crontab(minute=0, hour=0), # execute every day + "update_organization_popularity": { + "task": "organization.tasks.update_organization_popularity", + "schedule": crontab(minute=0, hour=0), # execute every day }, # Lead indexing for deduplication - 'index_leads': { - 'task': 'deduplication.tasks.indexing.create_indices', - 'schedule': crontab(minute=0, hour=2), # execute every second hour of the day + "index_leads": { + "task": "deduplication.tasks.indexing.create_indices", + "schedule": crontab(minute=0, hour=2), # execute every second hour of the day }, # Deep Explore - 'update_deep_explore_entries_count_by_geo_aggreagate_task': { - 'task': 'deep_explore.tasks.update_deep_explore_entries_count_by_geo_aggreagate_task', + "update_deep_explore_entries_count_by_geo_aggreagate_task": { + "task": "deep_explore.tasks.update_deep_explore_entries_count_by_geo_aggreagate_task", # Every day at 01:00 - 'schedule': crontab(minute=0, hour=1), + "schedule": crontab(minute=0, hour=1), }, - 'update_public_deep_explore_snapshot': { - 'task': 'deep_explore.tasks.update_public_deep_explore_snapshot', + "update_public_deep_explore_snapshot": { + "task": "deep_explore.tasks.update_public_deep_explore_snapshot", # Every day at 01:00 - 'schedule': crontab(minute=0, hour=1), + "schedule": crontab(minute=0, hour=1), }, - 'schedule_tracker_data_handler': { - 'task': 'deep.trackers.schedule_tracker_data_handler', + "schedule_tracker_data_handler": { + "task": "deep.trackers.schedule_tracker_data_handler", # Every 6 hours - 'schedule': crontab(hour="*/6"), + "schedule": crontab(hour="*/6"), }, } -CELERY_BEAT_SCHEDULER = 'django_celery_beat.schedulers:DatabaseScheduler' +CELERY_BEAT_SCHEDULER = "django_celery_beat.schedulers:DatabaseScheduler" if IN_AWS_COPILOT_ECS: - CELERY_BEAT_SCHEDULE.update({ - 'push_celery_cloudwatch_metric': { - 'task': 'deep.tasks.put_celery_query_metric', - # Every minute - 'schedule': crontab(minute='*/1'), - }, - }) + CELERY_BEAT_SCHEDULE.update( + { + "push_celery_cloudwatch_metric": { + "task": "deep.tasks.put_celery_query_metric", + # Every minute + "schedule": crontab(minute="*/1"), + }, + } + ) CACHES = { - 'default': { - 'BACKEND': 'django_redis.cache.RedisCache', - 'LOCATION': DJANGO_CACHE_REDIS_URL, - 'OPTIONS': { - 'CLIENT_CLASS': 'django_redis.client.DefaultClient', + "default": { + "BACKEND": "django_redis.cache.RedisCache", + "LOCATION": DJANGO_CACHE_REDIS_URL, + "OPTIONS": { + "CLIENT_CLASS": "django_redis.client.DefaultClient", }, - 'KEY_PREFIX': 'dj_cache-', + "KEY_PREFIX": "dj_cache-", + }, + "local-memory": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", }, - 'local-memory': { - 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', - } } # RELIEF WEB -RELIEFWEB_APPNAME = 'thedeep.io' +RELIEFWEB_APPNAME = "thedeep.io" # HID CONFIGS [NOTE: Update config in React too] -HID_CLIENT_ID = env('HID_CLIENT_ID') -HID_CLIENT_REDIRECT_URL = env('HID_CLIENT_REDIRECT_URL') -HID_AUTH_URI = env('HID_AUTH_URI') +HID_CLIENT_ID = env("HID_CLIENT_ID") +HID_CLIENT_REDIRECT_URL = env("HID_CLIENT_REDIRECT_URL") +HID_AUTH_URI = env("HID_AUTH_URI") def add_username_attribute(record): """ Append username(email) to logs """ - record.username = 'UNK_USER' - if hasattr(record, 'request'): - if hasattr(record.request, 'user') and\ - not record.request.user.is_anonymous: + record.username = "UNK_USER" + if hasattr(record, "request"): + if hasattr(record.request, "user") and not record.request.user.is_anonymous: record.username = record.request.user.username else: - record.username = 'Anonymous_User' + record.username = "Anonymous_User" return True @@ -603,115 +594,116 @@ def add_username_attribute(record): logging.getLogger("pdfminer").setLevel(logging.WARNING) if IN_AWS_COPILOT_ECS: - format_args = [env('SERVICE_ENVIRONMENT_TYPE')] + format_args = [env("SERVICE_ENVIRONMENT_TYPE")] LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'filters': { - 'add_username_attribute': { - '()': 'django.utils.log.CallbackFilter', - 'callback': add_username_attribute, + "version": 1, + "disable_existing_loggers": False, + "filters": { + "add_username_attribute": { + "()": "django.utils.log.CallbackFilter", + "callback": add_username_attribute, } }, - 'formatters': { - 'simple': { - 'format': '%(asctime)s DJANGO-{}: - %(levelname)s - %(name)s - [%(username)s] %(message)s'.format( + "formatters": { + "simple": { + "format": "%(asctime)s DJANGO-{}: - %(levelname)s - %(name)s - [%(username)s] %(message)s".format( *format_args, ), - 'datefmt': '%Y-%m-%dT%H:%M:%S', + "datefmt": "%Y-%m-%dT%H:%M:%S", }, - 'profiling': { - 'format': '%(asctime)s PROFILING-{}: %(message)s'.format(*format_args), - 'datefmt': '%Y-%m-%dT%H:%M:%S', + "profiling": { + "format": "%(asctime)s PROFILING-{}: %(message)s".format(*format_args), + "datefmt": "%Y-%m-%dT%H:%M:%S", }, }, - 'handlers': { - 'SysLog': { - 'level': 'INFO', - 'class': 'logging.StreamHandler', - 'filters': ['add_username_attribute'], - 'formatter': 'simple', + "handlers": { + "SysLog": { + "level": "INFO", + "class": "logging.StreamHandler", + "filters": ["add_username_attribute"], + "formatter": "simple", }, - 'ProfilingSysLog': { - 'level': 'INFO', - 'class': 'logging.StreamHandler', - 'formatter': 'profiling', + "ProfilingSysLog": { + "level": "INFO", + "class": "logging.StreamHandler", + "formatter": "profiling", }, }, - 'loggers': { + "loggers": { **{ app: { - 'handlers': ['SysLog'], - 'propagate': True, + "handlers": ["SysLog"], + "propagate": True, } - for app in LOCAL_APPS + ['deep', 'utils', 'celery', 'django'] + for app in LOCAL_APPS + ["deep", "utils", "celery", "django"] }, - 'profiling': { - 'handlers': ['ProfilingSysLog'], - 'level': 'INFO', - 'propagate': True, + "profiling": { + "handlers": ["ProfilingSysLog"], + "level": "INFO", + "propagate": True, }, - } + }, } else: + def log_render_extra_context(record): """ Append extra->context to logs """ - if hasattr(record, 'context'): - record.context = f' - {str(record.context)}' + if hasattr(record, "context"): + record.context = f" - {str(record.context)}" else: - record.context = '' + record.context = "" return True LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'filters': { - 'render_extra_context': { - '()': 'django.utils.log.CallbackFilter', - 'callback': log_render_extra_context, + "version": 1, + "disable_existing_loggers": False, + "filters": { + "render_extra_context": { + "()": "django.utils.log.CallbackFilter", + "callback": log_render_extra_context, } }, - 'formatters': { - 'colored_verbose': { - '()': 'colorlog.ColoredFormatter', - 'format': ( + "formatters": { + "colored_verbose": { + "()": "colorlog.ColoredFormatter", + "format": ( "%(log_color)s%(levelname)-8s%(red)s%(module)-8s%(reset)s %(asctime)s %(blue)s%(message)s %(context)s" - ) + ), }, }, - 'handlers': { - 'console': { - 'level': 'INFO', - 'class': 'logging.StreamHandler', - 'filters': ['render_extra_context'], + "handlers": { + "console": { + "level": "INFO", + "class": "logging.StreamHandler", + "filters": ["render_extra_context"], }, - 'colored_console': { - 'level': 'INFO', - 'class': 'logging.StreamHandler', - 'formatter': 'colored_verbose', - 'filters': ['render_extra_context'], + "colored_console": { + "level": "INFO", + "class": "logging.StreamHandler", + "formatter": "colored_verbose", + "filters": ["render_extra_context"], }, }, - 'loggers': { + "loggers": { **{ app: { - 'handlers': ['colored_console'], - 'level': 'INFO', - 'propagate': True, + "handlers": ["colored_console"], + "level": "INFO", + "propagate": True, } - for app in LOCAL_APPS + ['deep', 'utils', 'celery', 'django'] + for app in LOCAL_APPS + ["deep", "utils", "celery", "django"] }, - 'profiling': { - 'handlers': ['colored_console'], - 'level': 'DEBUG', - 'propagate': True, + "profiling": { + "handlers": ["colored_console"], + "level": "DEBUG", + "propagate": True, }, }, } -CORS_ALLOWED_ORIGINS = env('CORS_ALLOWED_ORIGINS') +CORS_ALLOWED_ORIGINS = env("CORS_ALLOWED_ORIGINS") # CORS CONFIGS if DEBUG and not CORS_ALLOWED_ORIGINS: @@ -721,57 +713,57 @@ def log_render_extra_context(record): r"^https://[\w-]+\.thedeep\.io$", ] -CORS_URLS_REGEX = r'(^/api/.*$)|(^/media/.*$)|(^/graphql$)' +CORS_URLS_REGEX = r"(^/api/.*$)|(^/media/.*$)|(^/graphql$)" CORS_ALLOW_CREDENTIALS = True CORS_ALLOW_METHODS = ( - 'DELETE', - 'GET', - 'OPTIONS', - 'PATCH', - 'POST', - 'PUT', + "DELETE", + "GET", + "OPTIONS", + "PATCH", + "POST", + "PUT", ) CORS_ALLOW_HEADERS = ( - 'accept', - 'accept-encoding', - 'authorization', - 'content-type', - 'dnt', - 'origin', - 'user-agent', - 'x-csrftoken', - 'x-requested-with', - 'sentry-trace', + "accept", + "accept-encoding", + "authorization", + "content-type", + "dnt", + "origin", + "user-agent", + "x-csrftoken", + "x-requested-with", + "sentry-trace", ) # Email CONFIGS -USE_SES_EMAIL_CONFIG = env('USE_SES_EMAIL_CONFIG') -USE_SMTP_EMAIL_CONFIG = env('USE_SMTP_EMAIL_CONFIG') -DEFAULT_FROM_EMAIL = EMAIL_FROM = env('EMAIL_FROM') +USE_SES_EMAIL_CONFIG = env("USE_SES_EMAIL_CONFIG") +USE_SMTP_EMAIL_CONFIG = env("USE_SMTP_EMAIL_CONFIG") +DEFAULT_FROM_EMAIL = EMAIL_FROM = env("EMAIL_FROM") -ADMINS = tuple(parseaddr(email) for email in env.list('DJANGO_ADMINS')) +ADMINS = tuple(parseaddr(email) for email in env.list("DJANGO_ADMINS")) if USE_SES_EMAIL_CONFIG and not TESTING: """ Use AWS SES """ - EMAIL_BACKEND = 'django_ses.SESBackend' + EMAIL_BACKEND = "django_ses.SESBackend" # If environment variable are not provided, then EC2 Role will be used. - AWS_SES_ACCESS_KEY_ID = env('SES_AWS_ACCESS_KEY_ID') - AWS_SES_SECRET_ACCESS_KEY = env('SES_AWS_SECRET_ACCESS_KEY') + AWS_SES_ACCESS_KEY_ID = env("SES_AWS_ACCESS_KEY_ID") + AWS_SES_SECRET_ACCESS_KEY = env("SES_AWS_SECRET_ACCESS_KEY") elif USE_SMTP_EMAIL_CONFIG: # Use SMTP instead https://docs.djangoproject.com/en/3.2/topics/email/#smtp-backend - EMAIL_BACKEND = 'django.core.mail.backends.smtp.EmailBackend' - EMAIL_HOST = env('SMTP_EMAIL_HOST') - EMAIL_PORT = env('SMTP_EMAIL_PORT') - EMAIL_HOST_USER = env('SMTP_EMAIL_USERNAME') - EMAIL_HOST_PASSWORD = env('SMTP_EMAIL_PASSWORD') + EMAIL_BACKEND = "django.core.mail.backends.smtp.EmailBackend" + EMAIL_HOST = env("SMTP_EMAIL_HOST") + EMAIL_PORT = env("SMTP_EMAIL_PORT") + EMAIL_HOST_USER = env("SMTP_EMAIL_USERNAME") + EMAIL_HOST_PASSWORD = env("SMTP_EMAIL_PASSWORD") else: """ DUMP THE EMAIL TO CONSOLE """ - EMAIL_BACKEND = 'django.core.mail.backends.console.EmailBackend' + EMAIL_BACKEND = "django.core.mail.backends.console.EmailBackend" # Gallery files Cache-control max-age - 1hr from s3 @@ -789,25 +781,25 @@ def log_render_extra_context(record): MAX_LOGIN_ATTEMPTS = 10 # https://docs.hcaptcha.com/#integration-testing-test-keys -HCAPTCHA_SECRET = env('HCAPTCHA_SECRET') +HCAPTCHA_SECRET = env("HCAPTCHA_SECRET") # Sentry Config -SENTRY_DSN = env('SENTRY_DSN') -SENTRY_SAMPLE_RATE = env('SENTRY_SAMPLE_RATE') +SENTRY_DSN = env("SENTRY_DSN") +SENTRY_SAMPLE_RATE = env("SENTRY_SAMPLE_RATE") if SENTRY_DSN: SENTRY_CONFIG = { - 'dsn': SENTRY_DSN, - 'send_default_pii': True, - 'release': sentry.fetch_git_sha(BASE_DIR), - 'environment': DEEP_ENVIRONMENT, - 'debug': DEBUG, - 'tags': { - 'site': DJANGO_API_HOST, + "dsn": SENTRY_DSN, + "send_default_pii": True, + "release": sentry.fetch_git_sha(BASE_DIR), + "environment": DEEP_ENVIRONMENT, + "debug": DEBUG, + "tags": { + "site": DJANGO_API_HOST, }, } sentry.init_sentry( - app_type='API', + app_type="API", **SENTRY_CONFIG, ) @@ -818,33 +810,33 @@ def log_render_extra_context(record): DRAFT_ENTRY_EXTRACTION_TIMEOUT_DAYS = 1 CONNECTOR_LEAD_EXTRACTION_TOKEN_RESET_TIMEOUT_DAYS = 1 -JSON_EDITOR_INIT_JS = 'js/jsoneditor-init.js' -LOGIN_URL = '/admin/login' +JSON_EDITOR_INIT_JS = "js/jsoneditor-init.js" +LOGIN_URL = "/admin/login" -OTP_TOTP_ISSUER = f'Deep Admin {DEEP_ENVIRONMENT.title()}' +OTP_TOTP_ISSUER = f"Deep Admin {DEEP_ENVIRONMENT.title()}" OTP_EMAIL_SENDER = EMAIL_FROM -OTP_EMAIL_SUBJECT = 'Deep Admin OTP Token' +OTP_EMAIL_SUBJECT = "Deep Admin OTP Token" REDOC_SETTINGS = { - 'LAZY_RENDERING': True, - 'HIDE_HOSTNAME': True, - 'NATIVE_SCROLLBARS': True, - 'EXPAND_RESPONSES': [], + "LAZY_RENDERING": True, + "HIDE_HOSTNAME": True, + "NATIVE_SCROLLBARS": True, + "EXPAND_RESPONSES": [], } OPEN_API_DOCS_TIMEOUT = 86400 # 24 Hours ANALYTICAL_STATEMENT_COUNT = 30 # max no of analytical statement that can be created ANALYTICAL_ENTRIES_COUNT = 50 # max no of entries that can be created in analytical_statement -DEFAULT_AUTO_FIELD = 'django.db.models.AutoField' +DEFAULT_AUTO_FIELD = "django.db.models.AutoField" # DEBUG TOOLBAR CONFIGURATION DEBUG_TOOLBAR_CONFIG = { - 'DISABLE_PANELS': [ - 'debug_toolbar.panels.sql.SQLPanel', - 'debug_toolbar.panels.staticfiles.StaticFilesPanel', - 'debug_toolbar.panels.redirects.RedirectsPanel', - 'debug_toolbar.panels.templates.TemplatesPanel', + "DISABLE_PANELS": [ + "debug_toolbar.panels.sql.SQLPanel", + "debug_toolbar.panels.staticfiles.StaticFilesPanel", + "debug_toolbar.panels.redirects.RedirectsPanel", + "debug_toolbar.panels.templates.TemplatesPanel", ], } DEBUG_TOOLBAR_PANELS = [ @@ -863,11 +855,11 @@ def log_render_extra_context(record): "debug_toolbar.panels.profiling.ProfilingPanel", ] -if DEBUG and env('DOCKER_HOST_IP') and not TESTING: +if DEBUG and env("DOCKER_HOST_IP") and not TESTING: # https://github.com/flavors/django-graphiql-debug-toolbar#installation # FIXME: If mutation are triggered twice https://github.com/flavors/django-graphiql-debug-toolbar/pull/12/files # FIXME: All request are triggered twice. Creating multiple entries in admin panel as well. - INTERNAL_IPS = [env('DOCKER_HOST_IP')] + INTERNAL_IPS = [env("DOCKER_HOST_IP")] # # JUST FOR Graphiql # INSTALLED_APPS += ['debug_toolbar', 'graphiql_debug_toolbar'] # MIDDLEWARE = ['graphiql_debug_toolbar.middleware.DebugToolbarMiddleware'] + MIDDLEWARE @@ -880,22 +872,22 @@ def log_render_extra_context(record): APPEND_SLASH = True # Security Header configuration -SESSION_COOKIE_NAME = f'deep-{DEEP_ENVIRONMENT}-sessionid' -CSRF_COOKIE_NAME = f'deep-{DEEP_ENVIRONMENT}-csrftoken' +SESSION_COOKIE_NAME = f"deep-{DEEP_ENVIRONMENT}-sessionid" +CSRF_COOKIE_NAME = f"deep-{DEEP_ENVIRONMENT}-csrftoken" SECURE_BROWSER_XSS_FILTER = True SECURE_CONTENT_TYPE_NOSNIFF = True -X_FRAME_OPTIONS = 'DENY' +X_FRAME_OPTIONS = "DENY" CSP_DEFAULT_SRC = ["'self'"] -SECURE_REFERRER_POLICY = 'same-origin' -if HTTP_PROTOCOL == 'https': - SESSION_COOKIE_NAME = f'__Secure-{SESSION_COOKIE_NAME}' +SECURE_REFERRER_POLICY = "same-origin" +if HTTP_PROTOCOL == "https": + SESSION_COOKIE_NAME = f"__Secure-{SESSION_COOKIE_NAME}" SESSION_COOKIE_SECURE = True SESSION_COOKIE_HTTPONLY = True # SECURE_SSL_REDIRECT = True SECURE_HSTS_SECONDS = 30 # TODO: Increase this slowly SECURE_HSTS_INCLUDE_SUBDOMAINS = True SECURE_HSTS_PRELOAD = True - SECURE_PROXY_SSL_HEADER = ('HTTP_X_FORWARDED_PROTO', 'https') + SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https") # NOTE: Client needs to read CSRF COOKIE. # CSRF_COOKIE_NAME = f'__Secure-{CSRF_COOKIE_NAME}' # CSRF_COOKIE_SECURE = True @@ -909,75 +901,75 @@ def log_render_extra_context(record): # https://docs.djangoproject.com/en/3.2/ref/settings/#std:setting-CSRF_USE_SESSIONS # CSRF_USE_SESSIONS = env('CSRF_TRUSTED_ORIGINS') # https://docs.djangoproject.com/en/3.2/ref/settings/#std:setting-SESSION_COOKIE_DOMAIN -SESSION_COOKIE_DOMAIN = env('SESSION_COOKIE_DOMAIN') +SESSION_COOKIE_DOMAIN = env("SESSION_COOKIE_DOMAIN") # https://docs.djangoproject.com/en/3.2/ref/settings/#csrf-cookie-domain -CSRF_COOKIE_DOMAIN = env('CSRF_COOKIE_DOMAIN') +CSRF_COOKIE_DOMAIN = env("CSRF_COOKIE_DOMAIN") # DEEPL Service Config (Existing/Legacy) -DEEPL_SERVICE_DOMAIN = env('DEEPL_SERVICE_DOMAIN') -DEEPL_SERVICE_CALLBACK_DOMAIN = env('DEEPL_SERVICE_CALLBACK_DOMAIN') +DEEPL_SERVICE_DOMAIN = env("DEEPL_SERVICE_DOMAIN") +DEEPL_SERVICE_CALLBACK_DOMAIN = env("DEEPL_SERVICE_CALLBACK_DOMAIN") # DEEPL Server Config (New) -DEEPL_SERVER_TOKEN = env('DEEPL_SERVER_TOKEN') -DEEPL_SERVER_DOMAIN = env('DEEPL_SERVER_DOMAIN') -DEEPL_SERVER_AS_MOCK = env('DEEPL_SERVER_AS_MOCK') -DEEPL_SERVER_CALLBACK_DOMAIN = env('DEEPL_SERVER_CALLBACK_DOMAIN') +DEEPL_SERVER_TOKEN = env("DEEPL_SERVER_TOKEN") +DEEPL_SERVER_DOMAIN = env("DEEPL_SERVER_DOMAIN") +DEEPL_SERVER_AS_MOCK = env("DEEPL_SERVER_AS_MOCK") +DEEPL_SERVER_CALLBACK_DOMAIN = env("DEEPL_SERVER_CALLBACK_DOMAIN") # Graphene configs # WHITELIST following nodes from authentication checks GRAPHENE_NODES_WHITELIST = ( - '__schema', - '__type', - '__typename', + "__schema", + "__type", + "__typename", # custom nodes... - 'enums', - 'login', - 'loginWithHid', - 'register', - 'resetPassword', - 'projectExploreStats', - 'publicProjects', - 'publicProjectsByRegion', - 'publicAnalysisFrameworks', - 'publicOrganizations', - 'publicLead', - 'publicDeepExploreYearlySnapshots', - 'publicDeepExploreGlobalSnapshots', - 'publicAnalysisReportSnapshot', + "enums", + "login", + "loginWithHid", + "register", + "resetPassword", + "projectExploreStats", + "publicProjects", + "publicProjectsByRegion", + "publicAnalysisFrameworks", + "publicOrganizations", + "publicLead", + "publicDeepExploreYearlySnapshots", + "publicDeepExploreGlobalSnapshots", + "publicAnalysisReportSnapshot", ) # https://docs.graphene-python.org/projects/django/en/latest/settings/ GRAPHENE = { - 'ATOMIC_MUTATIONS': True, - 'SCHEMA': 'deep.schema.schema', - 'SCHEMA_OUTPUT': 'schema.json', # defaults to schema.json, - 'CAMELCASE_ERRORS': True, - 'SCHEMA_INDENT': 2, # Defaults to None (displays all data on a single line) - 'MIDDLEWARE': [ - 'utils.graphene.middleware.DisableIntrospectionSchemaMiddleware', - 'utils.sentry.SentryGrapheneMiddleware', - 'utils.graphene.middleware.WhiteListMiddleware', - 'utils.graphene.middleware.ProjectLogMiddleware', + "ATOMIC_MUTATIONS": True, + "SCHEMA": "deep.schema.schema", + "SCHEMA_OUTPUT": "schema.json", # defaults to schema.json, + "CAMELCASE_ERRORS": True, + "SCHEMA_INDENT": 2, # Defaults to None (displays all data on a single line) + "MIDDLEWARE": [ + "utils.graphene.middleware.DisableIntrospectionSchemaMiddleware", + "utils.sentry.SentryGrapheneMiddleware", + "utils.graphene.middleware.WhiteListMiddleware", + "utils.graphene.middleware.ProjectLogMiddleware", ], } if DEBUG: - GRAPHENE['MIDDLEWARE'].append('graphene_django.debug.DjangoDebugMiddleware') + GRAPHENE["MIDDLEWARE"].append("graphene_django.debug.DjangoDebugMiddleware") GRAPHENE_DJANGO_EXTRAS = { - 'DEFAULT_PAGINATION_CLASS': 'graphene_django_extras.paginations.PageGraphqlPagination', - 'DEFAULT_PAGE_SIZE': 20, - 'MAX_PAGE_SIZE': 50, + "DEFAULT_PAGINATION_CLASS": "graphene_django_extras.paginations.PageGraphqlPagination", + "DEFAULT_PAGE_SIZE": 20, + "MAX_PAGE_SIZE": 50, } -UNHCR_PORTAL_API_KEY = env('UNHCR_PORTAL_API_KEY') +UNHCR_PORTAL_API_KEY = env("UNHCR_PORTAL_API_KEY") # Used for project and user deletion -DELETED_USER_FIRST_NAME = 'The Deep' -DELETED_USER_LAST_NAME = 'User' +DELETED_USER_FIRST_NAME = "The Deep" +DELETED_USER_LAST_NAME = "User" USER_AND_PROJECT_DELETE_IN_DAYS = 30 -DELETED_USER_ORGANIZATION = 'The Deep Organization' -DELETED_USER_EMAIL_DOMAIN = 'deleted.thedeep.io' +DELETED_USER_ORGANIZATION = "The Deep Organization" +DELETED_USER_EMAIL_DOMAIN = "deleted.thedeep.io" # MISC -ALLOW_DUMMY_DATA_GENERATION = env('ALLOW_DUMMY_DATA_GENERATION') +ALLOW_DUMMY_DATA_GENERATION = env("ALLOW_DUMMY_DATA_GENERATION") diff --git a/deep/tasks.py b/deep/tasks.py index 0499835542..72c56322cd 100644 --- a/deep/tasks.py +++ b/deep/tasks.py @@ -2,10 +2,11 @@ from collections import defaultdict import boto3 -from django.conf import settings from celery import shared_task +from django.conf import settings -from deep.celery import app as celery_app, CeleryQueue +from deep.celery import CeleryQueue +from deep.celery import app as celery_app def _get_celery_queue_length_metric(): @@ -17,7 +18,7 @@ def _get_celery_queue_length_metric(): ping_response = celery_app.control.inspect().ping() if ping_response is not None: for worker, resp in ping_response.items(): - if resp.get('ok') == 'pong': + if resp.get("ok") == "pong": active_workers.append(worker) # Fetch queue task lengths @@ -29,7 +30,7 @@ def _get_celery_queue_length_metric(): if worker not in active_workers: continue for q in queues: - queues_worker_count[q['name']].append(worker) + queues_worker_count[q["name"]].append(worker) current_timestamp = int(datetime.datetime.now().timestamp()) for queue in CeleryQueue.ALL_QUEUES: @@ -37,21 +38,21 @@ def _get_celery_queue_length_metric(): worker_count = len(queues_worker_count.get(queue, [])) backlog_per_worker = task_count if worker_count != 0: - backlog_per_worker = (task_count / worker_count) + backlog_per_worker = task_count / worker_count yield { - 'MetricName': 'celery-queue-backlog-per-worker', - 'Value': backlog_per_worker, - 'Unit': 'Percent', - 'Timestamp': current_timestamp, - 'Dimensions': [ + "MetricName": "celery-queue-backlog-per-worker", + "Value": backlog_per_worker, + "Unit": "Percent", + "Timestamp": current_timestamp, + "Dimensions": [ { - 'Name': 'Environment', - 'Value': settings.DEEP_ENVIRONMENT, + "Name": "Environment", + "Value": settings.DEEP_ENVIRONMENT, }, { - 'Name': 'Queue', - 'Value': queue, - } + "Name": "Queue", + "Value": queue, + }, ], } @@ -62,8 +63,8 @@ def put_celery_query_metric(): *_get_celery_queue_length_metric(), ] - cloudwatch = boto3.client('cloudwatch') + cloudwatch = boto3.client("cloudwatch") cloudwatch.put_metric_data( - Namespace='DEEP', + Namespace="DEEP", MetricData=metrics, ) diff --git a/deep/tests/__init__.py b/deep/tests/__init__.py index b28537fb62..03ef42c584 100644 --- a/deep/tests/__init__.py +++ b/deep/tests/__init__.py @@ -1 +1 @@ -from .test_case import * # noqa +from .test_case import * # noqa diff --git a/deep/tests/test_api_exception.py b/deep/tests/test_api_exception.py index c902b193f4..3821f78a1d 100644 --- a/deep/tests/test_api_exception.py +++ b/deep/tests/test_api_exception.py @@ -1,15 +1,15 @@ -from deep.tests import TestCase from deep.error_codes import NOT_AUTHENTICATED +from deep.tests import TestCase class ApiExceptionTests(TestCase): def test_notoken_exception(self): - url = '/api/v1/users/{}/'.format(self.user.pk) + url = "/api/v1/users/{}/".format(self.user.pk) data = { - 'password': 'newpassword', + "password": "newpassword", } response = self.client.patch(url, data) self.assertEqual(response.status_code, 401) - self.assertIsNotNone(response.data['timestamp']) - self.assertEqual(response.data['error_code'], NOT_AUTHENTICATED) + self.assertIsNotNone(response.data["timestamp"]) + self.assertEqual(response.data["error_code"], NOT_AUTHENTICATED) diff --git a/deep/tests/test_case.py b/deep/tests/test_case.py index 54d34445c7..d04dbdcd82 100644 --- a/deep/tests/test_case.py +++ b/deep/tests/test_case.py @@ -1,51 +1,47 @@ +import datetime import os import shutil -import autofixture -from rest_framework import ( - test, - status, -) -import datetime +import autofixture +from analysis_framework.models import AnalysisFramework +from ary.models import Assessment, AssessmentTemplate +from django.conf import settings from django.test import override_settings from django.utils import timezone -from django.conf import settings - -from deep.middleware import _set_current_request as _set_middleware_current_request -from user.models import User -from project.models import ProjectRole, Project -from project.permissions import get_project_permissions_value -from lead.models import Lead from entry.models import Entry from gallery.models import File -from analysis_framework.models import AnalysisFramework -from ary.models import AssessmentTemplate, Assessment +from lead.models import Lead +from project.models import Project, ProjectRole +from project.permissions import get_project_permissions_value +from rest_framework import status, test +from user.models import User +from deep.middleware import _set_current_request as _set_middleware_current_request -TEST_MEDIA_ROOT = 'rest-media-temp' +TEST_MEDIA_ROOT = "rest-media-temp" if settings.PYTEST_XDIST_WORKER: - TEST_MEDIA_ROOT = f'rest-media-temp/{settings.PYTEST_XDIST_WORKER}' + TEST_MEDIA_ROOT = f"rest-media-temp/{settings.PYTEST_XDIST_WORKER}" -TEST_EMAIL_BACKEND = 'django.core.mail.backends.console.EmailBackend' -TEST_FILE_STORAGE = 'django.core.files.storage.FileSystemStorage' +TEST_EMAIL_BACKEND = "django.core.mail.backends.console.EmailBackend" +TEST_FILE_STORAGE = "django.core.files.storage.FileSystemStorage" TEST_CACHES = { - 'default': { - 'BACKEND': 'django_redis.cache.RedisCache', - 'LOCATION': settings.TEST_DJANGO_CACHE_REDIS_URL, - 'OPTIONS': { - 'CLIENT_CLASS': 'django_redis.client.DefaultClient', + "default": { + "BACKEND": "django_redis.cache.RedisCache", + "LOCATION": settings.TEST_DJANGO_CACHE_REDIS_URL, + "OPTIONS": { + "CLIENT_CLASS": "django_redis.client.DefaultClient", }, - 'KEY_PREFIX': 'test_dj_cache-', + "KEY_PREFIX": "test_dj_cache-", + }, + "local-memory": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", }, - 'local-memory': { - 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', - } } DUMMY_TEST_CACHES = { - 'default': { - 'BACKEND': 'django.core.cache.backends.dummy.DummyCache', - 'LOCATION': 'unique-snowflake', + "default": { + "BACKEND": "django.core.cache.backends.dummy.DummyCache", + "LOCATION": "unique-snowflake", } } @@ -66,32 +62,32 @@ def clean_up_test_media_files(path): DEFAULT_FILE_STORAGE=TEST_FILE_STORAGE, CACHES=TEST_CACHES, CELERY_TASK_ALWAYS_EAGER=True, - DEEPL_SERVER_CALLBACK_DOMAIN='http://testserver', + DEEPL_SERVER_CALLBACK_DOMAIN="http://testserver", ) class TestCase(test.APITestCase): def setUp(self): self.root_user = User.objects.create_user( - username='root@test.com', - first_name='Root', - last_name='Toot', - password='admin123', - email='root@test.com', + username="root@test.com", + first_name="Root", + last_name="Toot", + password="admin123", + email="root@test.com", is_superuser=True, is_staff=True, ) self.user = User.objects.create_user( - username='jon@dave.com', - first_name='Jon', - last_name='Mon', - password='test123', - email='jon@dave.com', + username="jon@dave.com", + first_name="Jon", + last_name="Mon", + password="test123", + email="jon@dave.com", ) # This should be called here to access roles later self.create_project_roles() self.deep_test_files_path = [] # NOTE: CI will clean itself - if os.environ.get('CI', '').lower() != 'true' and not os.path.exists(TEST_MEDIA_ROOT): + if os.environ.get("CI", "").lower() != "true" and not os.path.exists(TEST_MEDIA_ROOT): os.makedirs(TEST_MEDIA_ROOT) super().setUp() @@ -121,14 +117,15 @@ def assertEqualWithWarning(self, expected, real): self.assertEqual(expected, real) except AssertionError: import logging + logger = logging.getLogger(__name__) - logger.warning('', exc_info=True) + logger.warning("", exc_info=True) def assert_http_code(self, response, status_code, msg=None): - error_resp = getattr(response, 'data', None) + error_resp = getattr(response, "data", None) mesg = msg or error_resp - if isinstance(error_resp, dict) and 'errors' in error_resp: - mesg = error_resp['errors'] + if isinstance(error_resp, dict) and "errors" in error_resp: + mesg = error_resp["errors"] return self.assertEqual(response.status_code, status_code, mesg) def assert_200(self, response): @@ -165,21 +162,18 @@ def assert_500(self, response): self.assert_http_code(response, status.HTTP_500_INTERNAL_SERVER_ERROR) def create(self, model, **kwargs): - if not kwargs.get('created_by'): - kwargs['created_by'] = self.user - if not kwargs.get('modified_by'): - kwargs['modified_by'] = self.user + if not kwargs.get("created_by"): + kwargs["created_by"] = self.user + if not kwargs.get("modified_by"): + kwargs["modified_by"] = self.user obj = autofixture.base.AutoFixture( - model, field_values=kwargs, - generate_fk=True, - follow_fk=False, - follow_m2m=False + model, field_values=kwargs, generate_fk=True, follow_fk=False, follow_m2m=False ).create_one() - role = kwargs.get('role') + role = kwargs.get("role") - if role and hasattr(obj, 'add_member'): + if role and hasattr(obj, "add_member"): obj.add_member(self.user, role=role) return obj @@ -189,70 +183,48 @@ def create_project_roles(self): ProjectRole.objects.all().delete() # Creator role self.admin_role = ProjectRole.objects.create( - title='Clairvoyant One', + title="Clairvoyant One", type=ProjectRole.Type.PROJECT_OWNER, - lead_permissions=get_project_permissions_value('lead', '__all__'), - entry_permissions=get_project_permissions_value( - 'entry', '__all__'), - setup_permissions=get_project_permissions_value( - 'setup', '__all__'), - export_permissions=get_project_permissions_value( - 'export', '__all__'), - assessment_permissions=get_project_permissions_value( - 'assessment', '__all__'), + lead_permissions=get_project_permissions_value("lead", "__all__"), + entry_permissions=get_project_permissions_value("entry", "__all__"), + setup_permissions=get_project_permissions_value("setup", "__all__"), + export_permissions=get_project_permissions_value("export", "__all__"), + assessment_permissions=get_project_permissions_value("assessment", "__all__"), is_creator_role=True, level=1, ) # Smaller admin role self.smaller_admin_role = ProjectRole.objects.create( - title='Admin', + title="Admin", type=ProjectRole.Type.ADMIN, - lead_permissions=get_project_permissions_value('lead', '__all__'), - entry_permissions=get_project_permissions_value( - 'entry', '__all__'), - setup_permissions=get_project_permissions_value( - 'setup', ['modify']), - export_permissions=get_project_permissions_value( - 'export', '__all__'), - assessment_permissions=get_project_permissions_value( - 'assessment', '__all__'), + lead_permissions=get_project_permissions_value("lead", "__all__"), + entry_permissions=get_project_permissions_value("entry", "__all__"), + setup_permissions=get_project_permissions_value("setup", ["modify"]), + export_permissions=get_project_permissions_value("export", "__all__"), + assessment_permissions=get_project_permissions_value("assessment", "__all__"), is_creator_role=True, level=100, ) # Default role self.normal_role = ProjectRole.objects.create( - title='Analyst', + title="Analyst", type=ProjectRole.Type.MEMBER, - lead_permissions=get_project_permissions_value( - 'lead', '__all__'), - entry_permissions=get_project_permissions_value( - 'entry', '__all__'), - setup_permissions=get_project_permissions_value('setup', []), - export_permissions=get_project_permissions_value( - 'export', ['create']), - assessment_permissions=get_project_permissions_value( - 'assessment', '__all__'), + lead_permissions=get_project_permissions_value("lead", "__all__"), + entry_permissions=get_project_permissions_value("entry", "__all__"), + setup_permissions=get_project_permissions_value("setup", []), + export_permissions=get_project_permissions_value("export", ["create"]), + assessment_permissions=get_project_permissions_value("assessment", "__all__"), is_default_role=True, level=100, ) self.view_only_role = ProjectRole.objects.create( - title='Viewer', + title="Viewer", type=ProjectRole.Type.READER, - lead_permissions=get_project_permissions_value( - 'lead', ['view'] - ), - entry_permissions=get_project_permissions_value( - 'entry', ['view'] - ), - setup_permissions=get_project_permissions_value( - 'setup', [] - ), - export_permissions=get_project_permissions_value( - 'export', [] - ), - assessment_permissions=get_project_permissions_value( - 'assessment', ['view'] - ), + lead_permissions=get_project_permissions_value("lead", ["view"]), + entry_permissions=get_project_permissions_value("entry", ["view"]), + setup_permissions=get_project_permissions_value("setup", []), + export_permissions=get_project_permissions_value("export", []), + assessment_permissions=get_project_permissions_value("assessment", ["view"]), ) def post_and_check_201(self, url, data, model, fields): @@ -262,8 +234,7 @@ def post_and_check_201(self, url, data, model, fields): response = self.client.post(url, data) self.assert_201(response) - self.assertEqual(model.objects.count(), model_count + 1), \ - f'One more {model} should have been created' + self.assertEqual(model.objects.count(), model_count + 1), f"One more {model} should have been created" for field in fields: self.assertEqual(response.data[field], data[field]) @@ -276,52 +247,43 @@ def create_user(self, **fields): def create_project(self, **fields): data = { **fields, - 'analysis_framework': fields.pop('analysis_framework', None) or self.create(AnalysisFramework), - 'role': fields.pop('role', self.admin_role), + "analysis_framework": fields.pop("analysis_framework", None) or self.create(AnalysisFramework), + "role": fields.pop("role", self.admin_role), } - if fields.pop('create_assessment_template', False): - data['assessment_template'] = self.create(AssessmentTemplate) + if fields.pop("create_assessment_template", False): + data["assessment_template"] = self.create(AssessmentTemplate) return self.create(Project, **data) def create_gallery_file(self): - url = '/api/v1/files/' + url = "/api/v1/files/" - path = os.path.join(settings.TEST_DIR, 'documents') - self.supported_file = os.path.join(path, 'doc.docx') + path = os.path.join(settings.TEST_DIR, "documents") + self.supported_file = os.path.join(path, "doc.docx") data = { - 'title': 'Test file', - 'file': open(self.supported_file, 'rb'), - 'isPublic': True, + "title": "Test file", + "file": open(self.supported_file, "rb"), + "isPublic": True, } self.authenticate() - self.client.post(url, data, format='multipart') + self.client.post(url, data, format="multipart") file = File.objects.last() self.deep_test_files_path.append(file.file.path) return file def create_lead(self, **fields): - project = fields.pop('project', None) or self.create_project() + project = fields.pop("project", None) or self.create_project() return self.create(Lead, project=project, **fields) def create_entry(self, **fields): - project = fields.pop('project', None) or self.create_project() - lead = fields.pop('lead', None) or self.create_lead(project=project) - return self.create( - Entry, lead=lead, project=lead.project, - analysis_framework=lead.project.analysis_framework, - **fields - ) + project = fields.pop("project", None) or self.create_project() + lead = fields.pop("lead", None) or self.create_lead(project=project) + return self.create(Entry, lead=lead, project=lead.project, analysis_framework=lead.project.analysis_framework, **fields) def create_assessment(self, **fields): - lead = fields.pop('lead', None) or self.create_lead() - return self.create( - Assessment, - lead=lead, - project=lead.project, - **fields - ) + lead = fields.pop("lead", None) or self.create_lead() + return self.create(Assessment, lead=lead, project=lead.project, **fields) def update_obj(self, obj, **fields): for key, value in fields.items(): @@ -330,9 +292,7 @@ def update_obj(self, obj, **fields): return obj def post_filter_test(self, url, filters, count=1, skip_auth=False): - params = { - 'filters': [[k, v] for k, v in filters.items()] - } + params = {"filters": [[k, v] for k, v in filters.items()]} if skip_auth: self.authenticate() @@ -340,14 +300,14 @@ def post_filter_test(self, url, filters, count=1, skip_auth=False): self.assert_200(response) r_data = response.json() - self.assertEqual(len(r_data['results']), count, f'Filters: {filters}') + self.assertEqual(len(r_data["results"]), count, f"Filters: {filters}") return response def get_datetime_str(self, _datetime): - return _datetime.strftime('%Y-%m-%d%z') + return _datetime.strftime("%Y-%m-%d%z") def get_date_str(self, _datetime): - return _datetime.strftime('%Y-%m-%d') + return _datetime.strftime("%Y-%m-%d") def get_aware_datetime(self, *args, **kwargs): return timezone.make_aware(datetime.datetime(*args, **kwargs)) diff --git a/deep/tests/test_fake.py b/deep/tests/test_fake.py index 9ee9759e99..8a218272ac 100644 --- a/deep/tests/test_fake.py +++ b/deep/tests/test_fake.py @@ -6,5 +6,6 @@ class FakeTest(TestCase): This test is for running migrations only docker-compose run --rm server ./manage.py test -v 2 --pattern="deep/tests/test_fake.py" """ + def test_fake(self): pass diff --git a/deep/tests/test_views.py b/deep/tests/test_views.py index a949e5455e..cde160c585 100644 --- a/deep/tests/test_views.py +++ b/deep/tests/test_views.py @@ -1,36 +1,38 @@ import uuid -from django.test import TestCase, Client +from django.test import Client, TestCase from django.urls import reverse +from project.factories import ProjectFactory +from project.models import ProjectStats from deep.views import FrontendView -from project.models import ProjectStats -from project.factories import ProjectFactory class TestIndexView(TestCase): def test_index_view(self): client = Client() - response = client.get('/') - self.assertEqual(response.resolver_match.func.__name__, - FrontendView.as_view().__name__) + response = client.get("/") + self.assertEqual(response.resolver_match.func.__name__, FrontendView.as_view().__name__) class ProjectVizView(TestCase): def test_x_frame_headers(self): client = Client() - url = reverse('server-frontend') + url = reverse("server-frontend") response = client.get(url) # There should be x-frame-options by default in views - assert 'X-Frame-Options' in response.headers + assert "X-Frame-Options" in response.headers project = ProjectFactory.create() stat = ProjectStats.objects.create(project=project, token=uuid.uuid4()) - url = reverse('project-stat-viz-public', kwargs={ - 'project_stat_id': stat.id, - 'token': stat.token, - }) + url = reverse( + "project-stat-viz-public", + kwargs={ + "project_stat_id": stat.id, + "token": stat.token, + }, + ) response = client.get(url) # There should not be x-frame-options in specific views like project-stat-viz-public - assert 'X-Frame-Options' not in response.headers + assert "X-Frame-Options" not in response.headers diff --git a/deep/token.py b/deep/token.py index 07f9fcfade..8fa3bf720e 100644 --- a/deep/token.py +++ b/deep/token.py @@ -1,8 +1,8 @@ from datetime import date from django.conf import settings -from django.db import models from django.contrib.auth.tokens import PasswordResetTokenGenerator +from django.db import models from django.utils.crypto import constant_time_compare from django.utils.http import base36_to_int @@ -12,6 +12,7 @@ class DeepTokenGenerator(PasswordResetTokenGenerator): Strategy object used to generate and check tokens for the deep models mechanism. """ + # key_salt = "deep.token.DeepTokenGenerator" reset_timeout_days = settings.TOKEN_DEFAULT_RESET_TIMEOUT_DAYS secret = settings.SECRET_KEY @@ -40,10 +41,7 @@ def check_token(self, model, token): return False # Check that the timestamp/uid has not been tampered with - if not constant_time_compare( - self._make_token_with_timestamp(model, ts), - token - ): + if not constant_time_compare(self._make_token_with_timestamp(model, ts), token): return False # Check TIMEOUT @@ -53,9 +51,7 @@ def check_token(self, model, token): return True def _make_hash_value(self, model, timestamp): - raise Exception( - "No _make_hash_value defined for Class: " + type(self).__name__ - ) + raise Exception("No _make_hash_value defined for Class: " + type(self).__name__) def _num_days(self, dt): return (dt - date(2001, 1, 1)).days diff --git a/deep/trackers.py b/deep/trackers.py index 6b1fd87be6..8023e2a67a 100644 --- a/deep/trackers.py +++ b/deep/trackers.py @@ -1,17 +1,16 @@ -from typing import Type from enum import Enum, auto, unique +from typing import Type -from django.core.cache import cache -from django.utils import timezone -from django.db import transaction, models from celery import shared_task from dateutil.relativedelta import relativedelta +from django.core.cache import cache +from django.db import models, transaction +from django.utils import timezone +from project.models import Project +from user.models import Profile from deep.caches import CacheKey -from user.models import Profile -from project.models import Project - class TrackerAction: @unique @@ -45,20 +44,20 @@ def update_entity_data_in_bulk( cache_key_prefix: str, field: str, ): - cache_keys = cache.keys(cache_key_prefix + '*') + cache_keys = cache.keys(cache_key_prefix + "*") entities_update = [] for key, value in cache.get_many(cache_keys).items(): entities_update.append( - Model(**{ - 'id': key.split(cache_key_prefix)[1], - field: value, - }) + Model( + **{ + "id": key.split(cache_key_prefix)[1], + field: value, + } + ) ) if entities_update: Model.objects.bulk_update(entities_update, fields=[field], batch_size=200) - transaction.on_commit( - lambda: cache.delete_many(cache_keys) - ) + transaction.on_commit(lambda: cache.delete_many(cache_keys)) def update_project_data_in_bulk(): @@ -66,13 +65,13 @@ def update_project_data_in_bulk(): update_entity_data_in_bulk( Project, CacheKey.Tracker.LAST_PROJECT_READ_ACCESS_DATETIME, - 'last_read_access', + "last_read_access", ) # -- Write update_entity_data_in_bulk( Project, CacheKey.Tracker.LAST_PROJECT_WRITE_ACCESS_DATETIME, - 'last_write_access', + "last_write_access", ) # -- Update project->status using last_write_access # -- -- To active @@ -97,7 +96,7 @@ def update_user_data_in_bulk(): update_entity_data_in_bulk( Profile, CacheKey.Tracker.LAST_USER_ACTIVE_DATETIME, - 'last_active', + "last_active", ) # -- Update user->profile-is_active using last_active # -- -- To active diff --git a/deep/urls.py b/deep/urls.py index 88b1f0ca7b..ebe0e7f144 100644 --- a/deep/urls.py +++ b/deep/urls.py @@ -1,162 +1,157 @@ """deep URL Configuration """ -from django.views.decorators.clickjacking import xframe_options_exempt -from django.views.generic.base import RedirectView -from django.conf.urls import include, static -from django.views.static import serve -from django.contrib.auth import views as auth_views -from django.contrib import admin -from django.conf import settings -from django.urls import path, register_converter, re_path -from django.views.decorators.csrf import csrf_exempt -from rest_framework import routers -from rest_framework import permissions -from drf_yasg.views import get_schema_view -from drf_yasg import openapi -from django_otp.admin import OTPAdminSite - -from . import converters - -# import autofixture - -from user.views import ( - UserViewSet, - PasswordResetView, - user_activate_confirm, - unsubscribe_email, -) -from gallery.views import ( - FileView, - FileViewSet, - GoogleDriveFileViewSet, - DropboxFileViewSet, - FilePreviewViewSet, - FileExtractionTriggerView, - MetaExtractionView, - PrivateFileView, - DeprecatedPrivateFileView, - PublicFileView, -) -from tabular.views import ( - BookViewSet, - SheetViewSet, - FieldViewSet, - GeodataViewSet, - TabularExtractionTriggerView, - TabularGeoProcessTriggerView, -) -from user_group.views import ( - GroupMembershipViewSet, - UserGroupViewSet, -) -from project.views import ( - ProjectMembershipViewSet, - ProjectUserGroupViewSet, - ProjectOptionsView, - ProjectRoleViewSet, - ProjectViewSet, - ProjectStatViewSet, - accept_project_confirm, -) -from geo.views import ( - AdminLevelViewSet, - RegionCloneView, - RegionViewSet, - GeoAreasLoadTriggerView, - GeoJsonView, - GeoBoundsView, - GeoOptionsView, - GeoAreaView -) -from questionnaire.views import ( - QuestionnaireViewSet, - QuestionViewSet, - FrameworkQuestionViewSet, - XFormView, - KoboToolboxExport, -) -from lead.views import ( - LeadGroupViewSet, - LeadViewSet, - LeadBulkDeleteViewSet, - LeadPreviewViewSet, - LeadOptionsView, - LeadExtractionTriggerView, - LeadWebsiteFetch, - LeadCopyView, - WebInfoExtractView, - WebInfoDataView, -) -from entry.views import ( - EntryViewSet, - AttributeViewSet, - FilterDataViewSet, - EntryFilterView, - ExportDataViewSet, - EntryOptionsView, - EditEntriesDataViewSet, - EntryCommentViewSet, - # Entry Grouping - ProjectEntryLabelViewSet, - LeadEntryGroupViewSet, -) from analysis.views import ( - AnalysisViewSet, - AnalysisPillarViewSet, - AnalyticalStatementViewSet, AnalysisPillarDiscardedEntryViewSet, AnalysisPillarEntryViewSet, - DiscardedEntryOptionsView -) -from quality_assurance.views import ( - EntryReviewCommentViewSet, + AnalysisPillarViewSet, + AnalysisViewSet, + AnalyticalStatementViewSet, + DiscardedEntryOptionsView, ) from analysis_framework.views import ( AnalysisFrameworkCloneView, - AnalysisFrameworkViewSet, - PrivateAnalysisFrameworkRoleViewSet, - PublicAnalysisFrameworkRoleViewSet, AnalysisFrameworkMembershipViewSet, + AnalysisFrameworkViewSet, ExportableViewSet, FilterViewSet, + PrivateAnalysisFrameworkRoleViewSet, + PublicAnalysisFrameworkRoleViewSet, WidgetViewSet, ) from ary.views import ( - AssessmentViewSet, - PlannedAssessmentViewSet, + AssessmentCopyView, AssessmentOptionsView, AssessmentTemplateViewSet, + AssessmentViewSet, LeadAssessmentViewSet, LeadGroupAssessmentViewSet, - AssessmentCopyView, + PlannedAssessmentViewSet, ) from category_editor.views import ( - CategoryEditorViewSet, - CategoryEditorCloneView, CategoryEditorClassifyView, + CategoryEditorCloneView, + CategoryEditorViewSet, ) +from client_page_meta.views import PageViewSet +from commons.views import RenderChart from connector.views import ( - SourceViewSet, - SourceQueryView, - ConnectorViewSet, - ConnectorUserViewSet, ConnectorProjectViewSet, -) -from export.views import ( - ExportTriggerView, - ExportViewSet, + ConnectorUserViewSet, + ConnectorViewSet, + SourceQueryView, + SourceViewSet, ) from deepl_integration.views import ( + AnalysisAutomaticSummaryCallbackView, + AnalysisTopicModelCallbackView, + AnalyticalStatementGeoCallbackView, + AnalyticalStatementNGramCallbackView, AssistedTaggingDraftEntryPredictionCallbackView, AutoTaggingDraftEntryPredictionCallbackView, LeadExtractCallbackView, UnifiedConnectorLeadExtractCallbackView, - AnalysisTopicModelCallbackView, - AnalysisAutomaticSummaryCallbackView, - AnalyticalStatementNGramCallbackView, - AnalyticalStatementGeoCallbackView, ) +from django.conf import settings +from django.conf.urls import ( # handler403, handler400, handler500 + handler404, + include, + static, +) +from django.contrib import admin +from django.contrib.auth import views as auth_views +from django.urls import path, re_path, register_converter +from django.views.decorators.clickjacking import xframe_options_exempt +from django.views.decorators.csrf import csrf_exempt +from django.views.generic.base import RedirectView +from django.views.static import serve +from django_otp.admin import OTPAdminSite +from drf_yasg import openapi +from drf_yasg.views import get_schema_view +from entry.views import ( # Entry Grouping + AttributeViewSet, + EditEntriesDataViewSet, + EntryCommentViewSet, + EntryFilterView, + EntryOptionsView, + EntryViewSet, + ExportDataViewSet, + FilterDataViewSet, + LeadEntryGroupViewSet, + ProjectEntryLabelViewSet, +) +from export.views import ExportTriggerView, ExportViewSet +from gallery.views import ( + DeprecatedPrivateFileView, + DropboxFileViewSet, + FileExtractionTriggerView, + FilePreviewViewSet, + FileView, + FileViewSet, + GoogleDriveFileViewSet, + MetaExtractionView, + PrivateFileView, + PublicFileView, +) +from geo.views import ( + AdminLevelViewSet, + GeoAreasLoadTriggerView, + GeoAreaView, + GeoBoundsView, + GeoJsonView, + GeoOptionsView, + RegionCloneView, + RegionViewSet, +) +from jwt_auth.views import HIDTokenObtainPairView, TokenObtainPairView, TokenRefreshView +from lang.views import LanguageViewSet +from lead.views import ( + LeadBulkDeleteViewSet, + LeadCopyView, + LeadExtractionTriggerView, + LeadGroupViewSet, + LeadOptionsView, + LeadPreviewViewSet, + LeadViewSet, + LeadWebsiteFetch, + WebInfoDataView, + WebInfoExtractView, +) +from notification.views import AssignmentViewSet, NotificationViewSet +from organization.views import OrganizationTypeViewSet, OrganizationViewSet +from project.views import ( + ProjectMembershipViewSet, + ProjectOptionsView, + ProjectRoleViewSet, + ProjectStatViewSet, + ProjectUserGroupViewSet, + ProjectViewSet, + accept_project_confirm, +) +from quality_assurance.views import EntryReviewCommentViewSet +from questionnaire.views import ( + FrameworkQuestionViewSet, + KoboToolboxExport, + QuestionnaireViewSet, + QuestionViewSet, + XFormView, +) +from rest_framework import permissions, routers +from tabular.views import ( + BookViewSet, + FieldViewSet, + GeodataViewSet, + SheetViewSet, + TabularExtractionTriggerView, + TabularGeoProcessTriggerView, +) +from user.views import ( + PasswordResetView, + UserViewSet, + unsubscribe_email, + user_activate_confirm, +) +from user_group.views import GroupMembershipViewSet, UserGroupViewSet from deep.ses import ses_bounce_handler_view from deep.views import ( @@ -167,45 +162,20 @@ EntryCommentEmail, EntryReviewCommentEmail, FrontendView, + PasswordChanged, PasswordReset, ProjectJoinRequest, ProjectPublicVizView, - PasswordChanged, get_frontend_url, - graphql_docs -) -from organization.views import ( - OrganizationViewSet, - OrganizationTypeViewSet, -) -from lang.views import ( - LanguageViewSet, -) -from client_page_meta.views import ( - PageViewSet, -) - -from notification.views import ( - NotificationViewSet, - AssignmentViewSet + graphql_docs, ) -from jwt_auth.views import ( - HIDTokenObtainPairView, - TokenObtainPairView, - TokenRefreshView, -) -from commons.views import ( - RenderChart, -) +from . import converters -from django.conf.urls import ( - handler404 - # handler403, handler400, handler500 -) +# import autofixture -register_converter(converters.FileNameRegex, 'filename') +register_converter(converters.FileNameRegex, "filename") handler404 = Api_404View # noqa @@ -215,7 +185,7 @@ api_schema_view = get_schema_view( openapi.Info( title="DEEP API", - default_version='v1', + default_version="v1", description="DEEP API", contact=openapi.Contact(email="admin@thedeep.io"), ), @@ -225,175 +195,129 @@ # User routers -router.register(r'users', UserViewSet, - basename='user') +router.register(r"users", UserViewSet, basename="user") # File routers -router.register(r'files', FileViewSet, - basename='file') -router.register(r'files-google-drive', GoogleDriveFileViewSet, - basename='file_google_drive') -router.register(r'files-dropbox', DropboxFileViewSet, - basename='file_dropbox') -router.register(r'file-previews', FilePreviewViewSet, - basename='file_preview') +router.register(r"files", FileViewSet, basename="file") +router.register(r"files-google-drive", GoogleDriveFileViewSet, basename="file_google_drive") +router.register(r"files-dropbox", DropboxFileViewSet, basename="file_dropbox") +router.register(r"file-previews", FilePreviewViewSet, basename="file_preview") # Tabular routers -router.register(r'tabular-books', BookViewSet, - basename='tabular_book') -router.register(r'tabular-sheets', SheetViewSet, - basename='tabular_sheet') -router.register(r'tabular-fields', FieldViewSet, - basename='tabular_field') -router.register(r'tabular-geodatas', GeodataViewSet, - basename='tabular_geodata') +router.register(r"tabular-books", BookViewSet, basename="tabular_book") +router.register(r"tabular-sheets", SheetViewSet, basename="tabular_sheet") +router.register(r"tabular-fields", FieldViewSet, basename="tabular_field") +router.register(r"tabular-geodatas", GeodataViewSet, basename="tabular_geodata") # User group registers -router.register(r'user-groups', UserGroupViewSet, - basename='user_group') -router.register(r'group-memberships', GroupMembershipViewSet, - basename='group_membership') +router.register(r"user-groups", UserGroupViewSet, basename="user_group") +router.register(r"group-memberships", GroupMembershipViewSet, basename="group_membership") # Project routers -router.register(r'projects', ProjectViewSet, - basename='project') -router.register(r'projects-stat', ProjectStatViewSet, - basename='project-stat') -router.register(r'project-roles', ProjectRoleViewSet, - basename='project_role') -router.register(r'projects/(?P<project_id>\d+)/project-memberships', ProjectMembershipViewSet, - basename='project_membership') -router.register(r'projects/(?P<project_id>\d+)/project-usergroups', ProjectUserGroupViewSet, - basename='project_usergroup') +router.register(r"projects", ProjectViewSet, basename="project") +router.register(r"projects-stat", ProjectStatViewSet, basename="project-stat") +router.register(r"project-roles", ProjectRoleViewSet, basename="project_role") +router.register(r"projects/(?P<project_id>\d+)/project-memberships", ProjectMembershipViewSet, basename="project_membership") +router.register(r"projects/(?P<project_id>\d+)/project-usergroups", ProjectUserGroupViewSet, basename="project_usergroup") # Geo routers -router.register(r'regions', RegionViewSet, - basename='region') -router.register(r'admin-levels', AdminLevelViewSet, - basename='admin_level') -router.register(r'projects/(?P<project_id>\d+)/geo-area', GeoAreaView, - basename='geo_area') +router.register(r"regions", RegionViewSet, basename="region") +router.register(r"admin-levels", AdminLevelViewSet, basename="admin_level") +router.register(r"projects/(?P<project_id>\d+)/geo-area", GeoAreaView, basename="geo_area") # Lead routers -router.register(r'lead-groups', LeadGroupViewSet, - basename='lead_group') -router.register(r'leads', LeadViewSet, - basename='lead') -router.register(r'project/(?P<project_id>\d+)/leads', LeadBulkDeleteViewSet, - basename='leads-bulk') -router.register(r'lead-previews', LeadPreviewViewSet, - basename='lead_preview') +router.register(r"lead-groups", LeadGroupViewSet, basename="lead_group") +router.register(r"leads", LeadViewSet, basename="lead") +router.register(r"project/(?P<project_id>\d+)/leads", LeadBulkDeleteViewSet, basename="leads-bulk") +router.register(r"lead-previews", LeadPreviewViewSet, basename="lead_preview") # Questionnaire routers -router.register(r'questionnaires/(?P<questionnaire_id>\d+)/questions', - QuestionViewSet, basename='question') -router.register(r'questionnaires', QuestionnaireViewSet, - basename='questionnaire') +router.register(r"questionnaires/(?P<questionnaire_id>\d+)/questions", QuestionViewSet, basename="question") +router.register(r"questionnaires", QuestionnaireViewSet, basename="questionnaire") # Entry routers -router.register(r'entries', EntryViewSet, - basename='entry_lead') -router.register(r'entry-attributes', AttributeViewSet, - basename='entry_attribute') -router.register(r'entry-filter-data', FilterDataViewSet, - basename='entry_filter_data') -router.register(r'entry-export-data', ExportDataViewSet, - basename='entry_export_data') -router.register(r'edit-entries-data', EditEntriesDataViewSet, - basename='edit_entries_data') - -router.register(r'entries/(?P<entry_id>\d+)/entry-comments', EntryCommentViewSet, basename='entry-comment') -router.register(r'projects/(?P<project_id>\d+)/entry-labels', ProjectEntryLabelViewSet, basename='entry-labels') -router.register(r'leads/(?P<lead_id>\d+)/entry-groups', LeadEntryGroupViewSet, basename='entry-groups') +router.register(r"entries", EntryViewSet, basename="entry_lead") +router.register(r"entry-attributes", AttributeViewSet, basename="entry_attribute") +router.register(r"entry-filter-data", FilterDataViewSet, basename="entry_filter_data") +router.register(r"entry-export-data", ExportDataViewSet, basename="entry_export_data") +router.register(r"edit-entries-data", EditEntriesDataViewSet, basename="edit_entries_data") + +router.register(r"entries/(?P<entry_id>\d+)/entry-comments", EntryCommentViewSet, basename="entry-comment") +router.register(r"projects/(?P<project_id>\d+)/entry-labels", ProjectEntryLabelViewSet, basename="entry-labels") +router.register(r"leads/(?P<lead_id>\d+)/entry-groups", LeadEntryGroupViewSet, basename="entry-groups") # Analysis routers -router.register(r'projects/(?P<project_id>\d+)/analysis', AnalysisViewSet, - basename='analysis') -router.register(r'projects/(?P<project_id>\d+)/analysis/(?P<analysis_id>\d+)/pillars', - AnalysisPillarViewSet, basename='analysis_analysis_pillar') +router.register(r"projects/(?P<project_id>\d+)/analysis", AnalysisViewSet, basename="analysis") +router.register( + r"projects/(?P<project_id>\d+)/analysis/(?P<analysis_id>\d+)/pillars", + AnalysisPillarViewSet, + basename="analysis_analysis_pillar", +) router.register( - r'projects/(?P<project_id>\d+)/analysis/(?P<analysis_id>\d+)/pillars/(?P<analysis_pillar_id>\d+)/analytical-statement', - AnalyticalStatementViewSet, basename='analytical_statement') + r"projects/(?P<project_id>\d+)/analysis/(?P<analysis_id>\d+)/pillars/(?P<analysis_pillar_id>\d+)/analytical-statement", + AnalyticalStatementViewSet, + basename="analytical_statement", +) router.register( - r'analysis-pillar/(?P<analysis_pillar_id>\d+)/discarded-entries', - AnalysisPillarDiscardedEntryViewSet, basename='analysis_pillar_discarded_entries' + r"analysis-pillar/(?P<analysis_pillar_id>\d+)/discarded-entries", + AnalysisPillarDiscardedEntryViewSet, + basename="analysis_pillar_discarded_entries", ) # QA routers -router.register( - r'entries/(?P<entry_id>\d+)/review-comments', EntryReviewCommentViewSet, basename='entry-review-comment') +router.register(r"entries/(?P<entry_id>\d+)/review-comments", EntryReviewCommentViewSet, basename="entry-review-comment") # Analysis framework routers -router.register(r'analysis-frameworks/(?P<af_id>\d+)/questions', - FrameworkQuestionViewSet, basename='framework-question') -router.register(r'analysis-frameworks', AnalysisFrameworkViewSet, - basename='analysis_framework') -router.register(r'analysis-framework-widgets', WidgetViewSet, - basename='analysis_framework_widget') -router.register(r'analysis-framework-filters', FilterViewSet, - basename='analysis_framework_filter') -router.register(r'analysis-framework-exportables', ExportableViewSet, - basename='analysis_framework_exportable') -router.register(r'framework-memberships', AnalysisFrameworkMembershipViewSet, - basename='framework_memberships') -router.register(r'private-framework-roles', PrivateAnalysisFrameworkRoleViewSet, - basename='framework_roles') -router.register(r'public-framework-roles', PublicAnalysisFrameworkRoleViewSet, - basename='framework_roles') +router.register(r"analysis-frameworks/(?P<af_id>\d+)/questions", FrameworkQuestionViewSet, basename="framework-question") +router.register(r"analysis-frameworks", AnalysisFrameworkViewSet, basename="analysis_framework") +router.register(r"analysis-framework-widgets", WidgetViewSet, basename="analysis_framework_widget") +router.register(r"analysis-framework-filters", FilterViewSet, basename="analysis_framework_filter") +router.register(r"analysis-framework-exportables", ExportableViewSet, basename="analysis_framework_exportable") +router.register(r"framework-memberships", AnalysisFrameworkMembershipViewSet, basename="framework_memberships") +router.register(r"private-framework-roles", PrivateAnalysisFrameworkRoleViewSet, basename="framework_roles") +router.register(r"public-framework-roles", PublicAnalysisFrameworkRoleViewSet, basename="framework_roles") # Assessment registry -router.register(r'assessments', AssessmentViewSet, - basename='assessment') -router.register(r'planned-assessments', PlannedAssessmentViewSet, - basename='planned-assessment') - -router.register(r'lead-assessments', LeadAssessmentViewSet, - basename='lead_assessment') -router.register(r'lead-group-assessments', LeadGroupAssessmentViewSet, - basename='lead_group_assessment') -router.register(r'assessment-templates', AssessmentTemplateViewSet, - basename='assessment_template') +router.register(r"assessments", AssessmentViewSet, basename="assessment") +router.register(r"planned-assessments", PlannedAssessmentViewSet, basename="planned-assessment") + +router.register(r"lead-assessments", LeadAssessmentViewSet, basename="lead_assessment") +router.register(r"lead-group-assessments", LeadGroupAssessmentViewSet, basename="lead_group_assessment") +router.register(r"assessment-templates", AssessmentTemplateViewSet, basename="assessment_template") # Category editor routers -router.register(r'category-editors', CategoryEditorViewSet, - basename='category_editor') +router.register(r"category-editors", CategoryEditorViewSet, basename="category_editor") # Connector routers -router.register(r'connector-sources', SourceViewSet, - basename='connector_source') -router.register(r'connectors', ConnectorViewSet, - basename='connector') -router.register(r'connector-users', ConnectorUserViewSet, - basename='connector_users') -router.register(r'connector-projects', ConnectorProjectViewSet, - basename='connector_projects') +router.register(r"connector-sources", SourceViewSet, basename="connector_source") +router.register(r"connectors", ConnectorViewSet, basename="connector") +router.register(r"connector-users", ConnectorUserViewSet, basename="connector_users") +router.register(r"connector-projects", ConnectorProjectViewSet, basename="connector_projects") # Organization routers -router.register(r'organizations', OrganizationViewSet, basename='organization') -router.register(r'organization-types', OrganizationTypeViewSet, basename='organization-type') +router.register(r"organizations", OrganizationViewSet, basename="organization") +router.register(r"organization-types", OrganizationTypeViewSet, basename="organization-type") # Export routers -router.register(r'exports', ExportViewSet, basename='export') +router.register(r"exports", ExportViewSet, basename="export") # Notification routers -router.register(r'notifications', - NotificationViewSet, basename='notification') -router.register(r'assignments', - AssignmentViewSet, basename='assignments') +router.register(r"notifications", NotificationViewSet, basename="notification") +router.register(r"assignments", AssignmentViewSet, basename="assignments") # Language routers -router.register(r'languages', LanguageViewSet, basename='language') +router.register(r"languages", LanguageViewSet, basename="language") # Page routers -router.register(r'pages', PageViewSet, basename='page') +router.register(r"pages", PageViewSet, basename="page") # Versioning : (v1|v2|v3) -API_PREFIX = r'^api/(?P<version>(v1|v2))/' +API_PREFIX = r"^api/(?P<version>(v1|v2))/" def get_api_path(path): - return '{}{}'.format(API_PREFIX, path) + return "{}{}".format(API_PREFIX, path) CustomGraphQLView.graphiql_template = "graphene_graphiql_explorer/graphiql.html" @@ -402,245 +326,203 @@ def get_api_path(path): if not settings.DEBUG: admin.site.__class__ = OTPAdminSite -urlpatterns = [ - re_path(r'^$', FrontendView.as_view(), name='server-frontend'), - re_path(r'^admin/', admin.site.urls), - re_path(r'^graphql-docs/$', graphql_docs, name='graphql_docs'), - re_path(r'^api-docs(?P<format>\.json|\.yaml)$', - api_schema_view.without_ui(cache_timeout=settings.OPEN_API_DOCS_TIMEOUT), name='schema-json'), - re_path(r'^api-docs/$', api_schema_view.with_ui('swagger', cache_timeout=settings.OPEN_API_DOCS_TIMEOUT), - name='schema-swagger-ui'), - re_path(r'^redoc/$', api_schema_view.with_ui('redoc', cache_timeout=settings.OPEN_API_DOCS_TIMEOUT), - name='schema-redoc'), - - # JWT Authentication - re_path(get_api_path(r'token/$'), - TokenObtainPairView.as_view()), - - re_path(get_api_path(r'token/hid/$'), - HIDTokenObtainPairView.as_view()), - - re_path(get_api_path(r'token/refresh/$'), - TokenRefreshView.as_view()), - - # Gallery - re_path(r'^file/(?P<file_id>\d+)/$', FileView.as_view(), name='file'), - path( - 'private-file/<uuid:uuid>/<filename:filename>', - PrivateFileView.as_view(), - name='gallery_private_url', - ), - path( - 'deprecated-private-file/<uuid:uuid>/<filename:filename>', - DeprecatedPrivateFileView.as_view(), - name='deprecated_gallery_private_url', - ), - re_path( - r'^public-file/(?P<fidb64>[0-9A-Za-z]+)/(?P<token>.+)/(?P<filename>.*)$', - PublicFileView.as_view(), - name='gallery_public_url', - ), - - # Activate User - re_path(r'^user/activate/(?P<uidb64>[0-9A-Za-z]+)-(?P<token>.+)/$', - user_activate_confirm, - name='user_activate_confirm'), - - # Unsubscribe User Email - re_path(r'^user/unsubscribe/email/(?P<uidb64>[0-9A-Za-z]+)-(?P<token>.+)/' - '(?P<email_type>[A-Za-z_]+)/$', +urlpatterns = ( + [ + re_path(r"^$", FrontendView.as_view(), name="server-frontend"), + re_path(r"^admin/", admin.site.urls), + re_path(r"^graphql-docs/$", graphql_docs, name="graphql_docs"), + re_path( + r"^api-docs(?P<format>\.json|\.yaml)$", + api_schema_view.without_ui(cache_timeout=settings.OPEN_API_DOCS_TIMEOUT), + name="schema-json", + ), + re_path( + r"^api-docs/$", + api_schema_view.with_ui("swagger", cache_timeout=settings.OPEN_API_DOCS_TIMEOUT), + name="schema-swagger-ui", + ), + re_path(r"^redoc/$", api_schema_view.with_ui("redoc", cache_timeout=settings.OPEN_API_DOCS_TIMEOUT), name="schema-redoc"), + # JWT Authentication + re_path(get_api_path(r"token/$"), TokenObtainPairView.as_view()), + re_path(get_api_path(r"token/hid/$"), HIDTokenObtainPairView.as_view()), + re_path(get_api_path(r"token/refresh/$"), TokenRefreshView.as_view()), + # Gallery + re_path(r"^file/(?P<file_id>\d+)/$", FileView.as_view(), name="file"), + path( + "private-file/<uuid:uuid>/<filename:filename>", + PrivateFileView.as_view(), + name="gallery_private_url", + ), + path( + "deprecated-private-file/<uuid:uuid>/<filename:filename>", + DeprecatedPrivateFileView.as_view(), + name="deprecated_gallery_private_url", + ), + re_path( + r"^public-file/(?P<fidb64>[0-9A-Za-z]+)/(?P<token>.+)/(?P<filename>.*)$", + PublicFileView.as_view(), + name="gallery_public_url", + ), + # Activate User + re_path(r"^user/activate/(?P<uidb64>[0-9A-Za-z]+)-(?P<token>.+)/$", user_activate_confirm, name="user_activate_confirm"), + # Unsubscribe User Email + re_path( + r"^user/unsubscribe/email/(?P<uidb64>[0-9A-Za-z]+)-(?P<token>.+)/" "(?P<email_type>[A-Za-z_]+)/$", unsubscribe_email, - name='unsubscribe_email'), - # Project Request Accept - re_path(r'^project/join-request/' - '(?P<uidb64>[0-9A-Za-z]+)-(?P<pidb64>[0-9A-Za-z]+)-(?P<token>.+)/$', + name="unsubscribe_email", + ), + # Project Request Accept + re_path( + r"^project/join-request/" "(?P<uidb64>[0-9A-Za-z]+)-(?P<pidb64>[0-9A-Za-z]+)-(?P<token>.+)/$", accept_project_confirm, - name='accept_project_confirm'), - - # password reset API - re_path(get_api_path(r'password/reset/$'), - PasswordResetView.as_view()), - - # Password Reset - re_path(r'^password/reset/(?P<uidb64>[0-9A-Za-z]+)-(?P<token>.+)/$', + name="accept_project_confirm", + ), + # password reset API + re_path(get_api_path(r"password/reset/$"), PasswordResetView.as_view()), + # Password Reset + re_path( + r"^password/reset/(?P<uidb64>[0-9A-Za-z]+)-(?P<token>.+)/$", auth_views.PasswordResetConfirmView.as_view( - success_url='{}://{}/login/'.format( - settings.HTTP_PROTOCOL, settings.DEEPER_FRONTEND_HOST, + success_url="{}://{}/login/".format( + settings.HTTP_PROTOCOL, + settings.DEEPER_FRONTEND_HOST, ) ), - name='password_reset_confirm'), - - # Attribute options for various models - re_path(get_api_path(r'lead-options/$'), - LeadOptionsView.as_view()), - re_path(get_api_path(r'assessment-options/$'), - AssessmentOptionsView.as_view()), - re_path(get_api_path(r'entry-options/$'), - EntryOptionsView.as_view()), - re_path(get_api_path(r'project-options/$'), - ProjectOptionsView.as_view()), - re_path(get_api_path(r'discarded-entry-options/$'), - DiscardedEntryOptionsView.as_view()), - - # Triggering api - re_path(get_api_path(r'lead-extraction-trigger/(?P<lead_id>\d+)/$'), - LeadExtractionTriggerView.as_view()), - - re_path(get_api_path(r'file-extraction-trigger/$'), - FileExtractionTriggerView.as_view()), - - re_path(get_api_path(r'meta-extraction/(?P<file_id>\d+)/$'), - MetaExtractionView.as_view()), - - re_path(get_api_path(r'geo-areas-load-trigger/(?P<region_id>\d+)/$'), - GeoAreasLoadTriggerView.as_view()), - - re_path(get_api_path(r'export-trigger/$'), - ExportTriggerView.as_view()), - - re_path(get_api_path(r'tabular-extraction-trigger/(?P<book_id>\d+)/$'), - TabularExtractionTriggerView.as_view()), - - re_path(get_api_path(r'tabular-geo-extraction-trigger/(?P<field_id>\d+)/$'), - TabularGeoProcessTriggerView.as_view()), - - # Website fetch api - re_path(get_api_path(r'lead-website-fetch/$'), LeadWebsiteFetch.as_view()), - - re_path(get_api_path(r'web-info-data/$'), WebInfoDataView.as_view()), - re_path(get_api_path(r'web-info-extract/$'), WebInfoExtractView.as_view()), - - # Questionnaire utils api - re_path(get_api_path(r'xlsform-to-xform/$'), XFormView.as_view()), - re_path(get_api_path(r'import-to-kobotoolbox/$'), KoboToolboxExport.as_view()), - - # Lead copy - re_path(get_api_path(r'lead-copy/$'), LeadCopyView.as_view()), - # Assessment copy - re_path(get_api_path(r'assessment-copy/$'), AssessmentCopyView.as_view()), - - # Filter apis - re_path(get_api_path(r'entries/filter/'), EntryFilterView.as_view()), - re_path( - get_api_path(r'analysis-pillar/(?P<analysis_pillar_id>\d+)/entries'), - AnalysisPillarEntryViewSet.as_view(), - name='analysis_pillar_entries', - ), - - re_path(get_api_path( - r'projects/(?P<project_id>\d+)/category-editor/classify/' - ), CategoryEditorClassifyView.as_view()), - - # Source query api - re_path(get_api_path( - r'connector-sources/(?P<source_type>[-\w]+)/(?P<query>[-\w]+)/', - ), SourceQueryView.as_view()), - - # Geojson api - re_path(get_api_path(r'admin-levels/(?P<admin_level_id>\d+)/geojson/$'), - GeoJsonView.as_view()), - re_path(get_api_path(r'admin-levels/(?P<admin_level_id>\d+)/geojson/bounds/$'), - GeoBoundsView.as_view()), - re_path(get_api_path(r'geo-options/$'), - GeoOptionsView.as_view()), - - # Clone apis - re_path(get_api_path(r'clone-region/(?P<region_id>\d+)/$'), - RegionCloneView.as_view()), - re_path(get_api_path(r'clone-analysis-framework/(?P<af_id>\d+)/$'), - AnalysisFrameworkCloneView.as_view()), - re_path(get_api_path(r'clone-category-editor/(?P<ce_id>\d+)/$'), - CategoryEditorCloneView.as_view()), - - # NLP Callback endpoints - re_path( - get_api_path(r'callback/lead-extract/$'), - LeadExtractCallbackView.as_view(), - name='lead_extract_callback', - ), - re_path( - get_api_path(r'callback/unified-connector-lead-extract/$'), - UnifiedConnectorLeadExtractCallbackView.as_view(), - name='unified_connector_lead_extract_callback', - ), - re_path( - get_api_path(r'callback/assisted-tagging-draft-entry-prediction/$'), - AssistedTaggingDraftEntryPredictionCallbackView.as_view(), - name='assisted_tagging_draft_entry_prediction_callback', - ), - re_path( - get_api_path(r'callback/auto-assisted-tagging-draft-entry-prediction/$'), - AutoTaggingDraftEntryPredictionCallbackView.as_view(), - name='auto-assisted_tagging_draft_entry_prediction_callback', - ), - - re_path( - get_api_path(r'callback/analysis-topic-model/$'), - AnalysisTopicModelCallbackView.as_view(), - name='analysis_topic_model_callback', - ), - re_path( - get_api_path(r'callback/analysis-automatic-summary/$'), - AnalysisAutomaticSummaryCallbackView.as_view(), - name='analysis_automatic_summary_callback', - ), - re_path( - get_api_path(r'callback/analysis-automatic-ngram/$'), - AnalyticalStatementNGramCallbackView.as_view(), - name='analysis_automatic_ngram_callback', - ), - re_path( - get_api_path(r'callback/analysis-geo/$'), - AnalyticalStatementGeoCallbackView.as_view(), - name='analysis_geo_callback', - ), - - # Combined API View - re_path(get_api_path(r'combined/$'), CombinedView.as_view()), - - # Viewsets - re_path(get_api_path(''), include(router.urls)), - - # DRF auth, TODO: logout - re_path(r'^api-auth/', include('rest_framework.urls', - namespace='rest_framework')), - - re_path(r'^project-viz/(?P<project_stat_id>\d+)/(?P<token>[0-9a-f-]+)/$', - ProjectPublicVizView.as_view(), name='project-stat-viz-public'), - - re_path(r'^favicon.ico$', - RedirectView.as_view(url=get_frontend_url('favicon.ico')), - name="favicon"), - - re_path('ses-bounce/?$', ses_bounce_handler_view, name='ses_bounce'), -] + [ - # graphql patterns - re_path('^graphql/?$', csrf_exempt(CustomGraphQLView.as_view())), - re_path(r'^favicon.ico$', + name="password_reset_confirm", + ), + # Attribute options for various models + re_path(get_api_path(r"lead-options/$"), LeadOptionsView.as_view()), + re_path(get_api_path(r"assessment-options/$"), AssessmentOptionsView.as_view()), + re_path(get_api_path(r"entry-options/$"), EntryOptionsView.as_view()), + re_path(get_api_path(r"project-options/$"), ProjectOptionsView.as_view()), + re_path(get_api_path(r"discarded-entry-options/$"), DiscardedEntryOptionsView.as_view()), + # Triggering api + re_path(get_api_path(r"lead-extraction-trigger/(?P<lead_id>\d+)/$"), LeadExtractionTriggerView.as_view()), + re_path(get_api_path(r"file-extraction-trigger/$"), FileExtractionTriggerView.as_view()), + re_path(get_api_path(r"meta-extraction/(?P<file_id>\d+)/$"), MetaExtractionView.as_view()), + re_path(get_api_path(r"geo-areas-load-trigger/(?P<region_id>\d+)/$"), GeoAreasLoadTriggerView.as_view()), + re_path(get_api_path(r"export-trigger/$"), ExportTriggerView.as_view()), + re_path(get_api_path(r"tabular-extraction-trigger/(?P<book_id>\d+)/$"), TabularExtractionTriggerView.as_view()), + re_path(get_api_path(r"tabular-geo-extraction-trigger/(?P<field_id>\d+)/$"), TabularGeoProcessTriggerView.as_view()), + # Website fetch api + re_path(get_api_path(r"lead-website-fetch/$"), LeadWebsiteFetch.as_view()), + re_path(get_api_path(r"web-info-data/$"), WebInfoDataView.as_view()), + re_path(get_api_path(r"web-info-extract/$"), WebInfoExtractView.as_view()), + # Questionnaire utils api + re_path(get_api_path(r"xlsform-to-xform/$"), XFormView.as_view()), + re_path(get_api_path(r"import-to-kobotoolbox/$"), KoboToolboxExport.as_view()), + # Lead copy + re_path(get_api_path(r"lead-copy/$"), LeadCopyView.as_view()), + # Assessment copy + re_path(get_api_path(r"assessment-copy/$"), AssessmentCopyView.as_view()), + # Filter apis + re_path(get_api_path(r"entries/filter/"), EntryFilterView.as_view()), + re_path( + get_api_path(r"analysis-pillar/(?P<analysis_pillar_id>\d+)/entries"), + AnalysisPillarEntryViewSet.as_view(), + name="analysis_pillar_entries", + ), + re_path(get_api_path(r"projects/(?P<project_id>\d+)/category-editor/classify/"), CategoryEditorClassifyView.as_view()), + # Source query api + re_path( + get_api_path( + r"connector-sources/(?P<source_type>[-\w]+)/(?P<query>[-\w]+)/", + ), + SourceQueryView.as_view(), + ), + # Geojson api + re_path(get_api_path(r"admin-levels/(?P<admin_level_id>\d+)/geojson/$"), GeoJsonView.as_view()), + re_path(get_api_path(r"admin-levels/(?P<admin_level_id>\d+)/geojson/bounds/$"), GeoBoundsView.as_view()), + re_path(get_api_path(r"geo-options/$"), GeoOptionsView.as_view()), + # Clone apis + re_path(get_api_path(r"clone-region/(?P<region_id>\d+)/$"), RegionCloneView.as_view()), + re_path(get_api_path(r"clone-analysis-framework/(?P<af_id>\d+)/$"), AnalysisFrameworkCloneView.as_view()), + re_path(get_api_path(r"clone-category-editor/(?P<ce_id>\d+)/$"), CategoryEditorCloneView.as_view()), + # NLP Callback endpoints + re_path( + get_api_path(r"callback/lead-extract/$"), + LeadExtractCallbackView.as_view(), + name="lead_extract_callback", + ), + re_path( + get_api_path(r"callback/unified-connector-lead-extract/$"), + UnifiedConnectorLeadExtractCallbackView.as_view(), + name="unified_connector_lead_extract_callback", + ), + re_path( + get_api_path(r"callback/assisted-tagging-draft-entry-prediction/$"), + AssistedTaggingDraftEntryPredictionCallbackView.as_view(), + name="assisted_tagging_draft_entry_prediction_callback", + ), + re_path( + get_api_path(r"callback/auto-assisted-tagging-draft-entry-prediction/$"), + AutoTaggingDraftEntryPredictionCallbackView.as_view(), + name="auto-assisted_tagging_draft_entry_prediction_callback", + ), + re_path( + get_api_path(r"callback/analysis-topic-model/$"), + AnalysisTopicModelCallbackView.as_view(), + name="analysis_topic_model_callback", + ), + re_path( + get_api_path(r"callback/analysis-automatic-summary/$"), + AnalysisAutomaticSummaryCallbackView.as_view(), + name="analysis_automatic_summary_callback", + ), + re_path( + get_api_path(r"callback/analysis-automatic-ngram/$"), + AnalyticalStatementNGramCallbackView.as_view(), + name="analysis_automatic_ngram_callback", + ), + re_path( + get_api_path(r"callback/analysis-geo/$"), + AnalyticalStatementGeoCallbackView.as_view(), + name="analysis_geo_callback", + ), + # Combined API View + re_path(get_api_path(r"combined/$"), CombinedView.as_view()), + # Viewsets + re_path(get_api_path(""), include(router.urls)), + # DRF auth, TODO: logout + re_path(r"^api-auth/", include("rest_framework.urls", namespace="rest_framework")), + re_path( + r"^project-viz/(?P<project_stat_id>\d+)/(?P<token>[0-9a-f-]+)/$", + ProjectPublicVizView.as_view(), + name="project-stat-viz-public", + ), + re_path(r"^favicon.ico$", RedirectView.as_view(url=get_frontend_url("favicon.ico")), name="favicon"), + re_path("ses-bounce/?$", ses_bounce_handler_view, name="ses_bounce"), + ] + + [ + # graphql patterns + re_path("^graphql/?$", csrf_exempt(CustomGraphQLView.as_view())), + re_path( + r"^favicon.ico$", RedirectView.as_view( - url=get_frontend_url('favicon.ico'), + url=get_frontend_url("favicon.ico"), ), - name="favicon"), -] + static.static( - settings.MEDIA_URL, view=xframe_options_exempt(serve), - document_root=settings.MEDIA_ROOT + name="favicon", + ), + ] + + static.static(settings.MEDIA_URL, view=xframe_options_exempt(serve), document_root=settings.MEDIA_ROOT) ) if settings.DEBUG: import debug_toolbar - if 'debug_toolbar' in settings.INSTALLED_APPS: + + if "debug_toolbar" in settings.INSTALLED_APPS: urlpatterns += [ - re_path('__debug__/', include(debug_toolbar.urls)), + re_path("__debug__/", include(debug_toolbar.urls)), ] urlpatterns += [ - re_path('^graphiql/?$', csrf_exempt(CustomGraphQLView.as_view(graphiql=True))), - re_path(r'^pr-email/$', PasswordReset.as_view()), - re_path(r'^pc-email/$', PasswordChanged.as_view()), - re_path(r'^aa-email/$', AccountActivate.as_view()), - re_path(r'^pj-email/$', ProjectJoinRequest.as_view()), - re_path(r'^ec-email/$', EntryCommentEmail.as_view()), - re_path(r'^erc-email/$', EntryReviewCommentEmail.as_view()), - re_path(r'^render-debug/$', RenderChart.as_view()), + re_path("^graphiql/?$", csrf_exempt(CustomGraphQLView.as_view(graphiql=True))), + re_path(r"^pr-email/$", PasswordReset.as_view()), + re_path(r"^pc-email/$", PasswordChanged.as_view()), + re_path(r"^aa-email/$", AccountActivate.as_view()), + re_path(r"^pj-email/$", ProjectJoinRequest.as_view()), + re_path(r"^ec-email/$", EntryCommentEmail.as_view()), + re_path(r"^erc-email/$", EntryReviewCommentEmail.as_view()), + re_path(r"^render-debug/$", RenderChart.as_view()), ] handler404 = Api_404View.as_view() diff --git a/deep/views.py b/deep/views.py index 34fe80e184..141d3662d9 100644 --- a/deep/views.py +++ b/deep/views.py @@ -1,21 +1,16 @@ import datetime -from rest_framework.exceptions import NotFound, PermissionDenied -from rest_framework import ( - views, - status, - response, -) - +from django.conf import settings from django.core.exceptions import PermissionDenied as DjPermissionDenied -from django.views.decorators.clickjacking import xframe_options_exempt -from django.http import JsonResponse, HttpResponse +from django.http import HttpResponse, JsonResponse +from django.template.response import TemplateResponse from django.urls import resolve +from django.views.decorators.clickjacking import xframe_options_exempt from django.views.generic import View -from django.conf import settings -from django.template.response import TemplateResponse from graphene_django.views import GraphQLView from graphene_file_upload.django import FileUploadGraphQLView +from rest_framework import response, status, views +from rest_framework.exceptions import NotFound, PermissionDenied from sentry_sdk.api import start_transaction as sentry_start_transaction # Importing for initialization (Make sure to import this before apps.<>) @@ -24,27 +19,27 @@ Make sure use string import outside graphene files. For eg: In filters.py use 'entry.schema.EntryListType' instead of `from entry.schema import EntryListType' """ -from deep.graphene_converter import * # type: ignore # noqa F401 - -from deep.graphene_context import GQLContext -from deep.exceptions import PermissionDeniedException -from user.models import User, Profile, EmailCondition -from project.models import Project +import graphdoc from entry.models import EntryComment -from quality_assurance.models import EntryReviewComment from notification.models import Notification -import graphdoc +from project.models import Project +from quality_assurance.models import EntryReviewComment +from user.models import EmailCondition, Profile, User + +from deep.exceptions import PermissionDeniedException +from deep.graphene_context import GQLContext +from deep.graphene_converter import * # type: ignore # noqa F401 def graphql_docs(request): html = graphdoc.to_doc(str(CustomGraphQLView().schema)) - return HttpResponse(html, content_type='text/html') + return HttpResponse(html, content_type="text/html") -def get_frontend_url(path=''): - return '{protocol}://{domain}/{path}'.format( - protocol=settings.HTTP_PROTOCOL or 'http', - domain=settings.DEEPER_FRONTEND_HOST or 'localhost:3000', +def get_frontend_url(path=""): + return "{protocol}://{domain}/{path}".format( + protocol=settings.HTTP_PROTOCOL or "http", + domain=settings.DEEPER_FRONTEND_HOST or "localhost:3000", path=path, ) @@ -53,37 +48,34 @@ class FrontendView(View): def get(self, request): # TODO: make nice redirect page context = { - 'frontend_url': get_frontend_url(), + "frontend_url": get_frontend_url(), } - return TemplateResponse(request, 'home/welcome.html', context) + return TemplateResponse(request, "home/welcome.html", context) class Api_404View(views.APIView): def get(self, request, exception): - raise NotFound(detail="Error 404, page not found", - code=status.HTTP_404_NOT_FOUND) + raise NotFound(detail="Error 404, page not found", code=status.HTTP_404_NOT_FOUND) class CombinedView(views.APIView): def get(self, request, version=None): - apis = request.query_params.get('apis', None) + apis = request.query_params.get("apis", None) if apis is None: return response.Response({}) - apis = apis.split(',') + apis = apis.split(",") results = {} - api_prefix = '/'.join(request.path_info.split('/')[:-2]) + api_prefix = "/".join(request.path_info.split("/")[:-2]) for api in apis: - url = '{}/{}/'.format(api_prefix, api.strip('/')) + url = "{}/{}/".format(api_prefix, api.strip("/")) view, args, kwargs = resolve(url) - kwargs['request'] = request._request + kwargs["request"] = request._request api_response = view(*args, **kwargs) if api_response.status_code >= 400: - return response.Response({ - api: api_response.data - }, status=api_response.status_code) + return response.Response({api: api_response.data}, status=api_response.status_code) results[api] = api_response.data return response.Response(results) @@ -93,34 +85,35 @@ class ProjectPublicVizView(View): """ View for public viz view without user authentication """ + @xframe_options_exempt def get(self, request, project_stat_id, token): from project.views import _get_viz_data - json_only = 'json' in request.GET.get('format', ['html']) + json_only = "json" in request.GET.get("format", ["html"]) project = Project.objects.get(entry_stats__id=project_stat_id) context, status_code = _get_viz_data(request, project, False, token) - context['project_title'] = project.title + context["project_title"] = project.title if json_only: return JsonResponse(context, status=status_code) - context['poll_url'] = f'{request.path}?format=json' - return TemplateResponse(request, 'project/project_viz.html', context, status=status_code) + context["poll_url"] = f"{request.path}?format=json" + return TemplateResponse(request, "project/project_viz.html", context, status=status_code) def get_basic_email_context(): user = User.objects.get(pk=1) context = { - 'client_domain': settings.DEEPER_FRONTEND_HOST, - 'protocol': settings.HTTP_PROTOCOL, - 'site_name': settings.DEEPER_SITE_NAME, - 'domain': settings.DJANGO_API_HOST, - 'uid': 'fakeuid', - 'user': user, - 'unsubscribe_email_types': Profile.EMAIL_CONDITIONS_TYPES, - 'request_by': user, - 'token': 'faketoken', - 'unsubscribe_email_token': 'faketoken', - 'unsubscribe_email_id': 'fakeid', + "client_domain": settings.DEEPER_FRONTEND_HOST, + "protocol": settings.HTTP_PROTOCOL, + "site_name": settings.DEEPER_SITE_NAME, + "domain": settings.DJANGO_API_HOST, + "uid": "fakeuid", + "user": user, + "unsubscribe_email_types": Profile.EMAIL_CONDITIONS_TYPES, + "request_by": user, + "token": "faketoken", + "unsubscribe_email_token": "faketoken", + "unsubscribe_email_id": "fakeid", } return context @@ -130,20 +123,21 @@ class ProjectJoinRequest(View): Template view for project join request email NOTE: Use Only For Debug """ + def get(self, request): project = Project.objects.get(pk=1) context = get_basic_email_context() - context.update({ - 'email_type': 'join_requests', - 'project': project, - 'pid': 'fakeuid', - 'reason': 'I want to join this project \ + context.update( + { + "email_type": "join_requests", + "project": project, + "pid": "fakeuid", + "reason": "I want to join this project \ because this is closely related to my research. \ - Data from this project will help me alot.', - }) - return TemplateResponse( - request, 'project/project_join_request_email.html', context + Data from this project will help me alot.", + } ) + return TemplateResponse(request, "project/project_join_request_email.html", context) class PasswordReset(View): @@ -151,13 +145,12 @@ class PasswordReset(View): Template view for password reset email NOTE: Use Only For Debug """ + def get(self, request): - welcome = request.GET.get('welcome', 'false').upper() == 'TRUE' + welcome = request.GET.get("welcome", "false").upper() == "TRUE" context = get_basic_email_context() - context.update({'welcome': welcome}) - return TemplateResponse( - request, 'registration/password_reset_email.html', context - ) + context.update({"welcome": welcome}) + return TemplateResponse(request, "registration/password_reset_email.html", context) class PasswordChanged(View): @@ -165,18 +158,17 @@ class PasswordChanged(View): Template view for password changed email NOTE: Use Only For Debug """ + def get(self, request): context = get_basic_email_context() context.update( { - 'time': datetime.datetime.now(), - 'location': 'Nepal', - 'device': 'Chrome OS', + "time": datetime.datetime.now(), + "location": "Nepal", + "device": "Chrome OS", } ) - return TemplateResponse( - request, 'password_changed/email.html', context - ) + return TemplateResponse(request, "password_changed/email.html", context) class AccountActivate(View): @@ -184,11 +176,10 @@ class AccountActivate(View): Template view for account activate email NOTE: Use Only For Debug """ + def get(self, request): context = get_basic_email_context() - return TemplateResponse( - request, 'registration/user_activation_email.html', context - ) + return TemplateResponse(request, "registration/user_activation_email.html", context) class EntryCommentEmail(View): @@ -196,25 +187,20 @@ class EntryCommentEmail(View): Template view for entry commit email NOTE: Use Only For Debug """ + def get(self, request): - comment_id = request.GET.get('comment_id') - comment = ( - EntryComment.objects.get(pk=comment_id) - if comment_id else EntryComment - .objects - .filter(parent=None) - .first() - ) + comment_id = request.GET.get("comment_id") + comment = EntryComment.objects.get(pk=comment_id) if comment_id else EntryComment.objects.filter(parent=None).first() context = get_basic_email_context() - context.update({ - 'email_type': EmailCondition.EMAIL_COMMENT, - 'notification_type': Notification.Type.ENTRY_COMMENT_ASSIGNEE_CHANGE, - 'Notification': Notification, - 'comment': comment, - }) - return TemplateResponse( - request, 'entry/comment_notification_email.html', context + context.update( + { + "email_type": EmailCondition.EMAIL_COMMENT, + "notification_type": Notification.Type.ENTRY_COMMENT_ASSIGNEE_CHANGE, + "Notification": Notification, + "comment": comment, + } ) + return TemplateResponse(request, "entry/comment_notification_email.html", context) class EntryReviewCommentEmail(View): @@ -222,24 +208,22 @@ class EntryReviewCommentEmail(View): Template view for entry review commit email NOTE: Use Only For Debug """ + def get(self, request): - comment_id = request.GET.get('comment_id') - notification_type = request.GET.get('notification_type', Notification.Type.ENTRY_REVIEW_COMMENT_ADD) - comment = ( - EntryReviewComment.objects.get(pk=comment_id) - if comment_id else EntryReviewComment.objects.first() - ) + comment_id = request.GET.get("comment_id") + notification_type = request.GET.get("notification_type", Notification.Type.ENTRY_REVIEW_COMMENT_ADD) + comment = EntryReviewComment.objects.get(pk=comment_id) if comment_id else EntryReviewComment.objects.first() context = get_basic_email_context() - context.update({ - 'email_type': EmailCondition.EMAIL_COMMENT, - 'notification_type': notification_type, - 'CommentType': EntryReviewComment.CommentType, - 'Notification': Notification, - 'comment': comment, - }) - return TemplateResponse( - request, 'entry/review_comment_notification_email.html', context + context.update( + { + "email_type": EmailCondition.EMAIL_COMMENT, + "notification_type": notification_type, + "CommentType": EntryReviewComment.CommentType, + "Notification": Notification, + "comment": comment, + } ) + return TemplateResponse(request, "entry/review_comment_notification_email.html", context) class CustomGraphQLView(FileUploadGraphQLView): @@ -255,26 +239,22 @@ def execute_graphql_request( operation_name, show_graphiql, ): - operation_type = self.get_backend(request)\ - .document_from_string(self.schema, query)\ - .get_operation_type(operation_name) + operation_type = self.get_backend(request).document_from_string(self.schema, query).get_operation_type(operation_name) with sentry_start_transaction(op=operation_type, name=operation_name): - return super().execute_graphql_request( - request, data, query, variables, operation_name, show_graphiql - ) + return super().execute_graphql_request(request, data, query, variables, operation_name, show_graphiql) @staticmethod def format_error(error): formatted_error = GraphQLView.format_error(error) - original_error = getattr(error, 'original_error', None) + original_error = getattr(error, "original_error", None) extensions = {} if original_error: - if hasattr(original_error, 'code'): - extensions['code'] = str(error.original_error.code) + if hasattr(original_error, "code"): + extensions["code"] = str(error.original_error.code) elif type(original_error) in [PermissionDenied, DjPermissionDenied]: - extensions['code'] = str(PermissionDeniedException.code) - formatted_error['message'] = str(PermissionDeniedException.default_message) + extensions["code"] = str(PermissionDeniedException.code) + formatted_error["message"] = str(PermissionDeniedException.default_message) else: - extensions['errorCode'] = str(status.HTTP_500_INTERNAL_SERVER_ERROR) - formatted_error['extensions'] = extensions + extensions["errorCode"] = str(status.HTTP_500_INTERNAL_SERVER_ERROR) + formatted_error["extensions"] = extensions return formatted_error diff --git a/deep/writable_nested_serializers.py b/deep/writable_nested_serializers.py index b866e2815f..54b94107da 100644 --- a/deep/writable_nested_serializers.py +++ b/deep/writable_nested_serializers.py @@ -4,8 +4,8 @@ from django.contrib.contenttypes.fields import GenericRelation from django.contrib.contenttypes.models import ContentType -from django.db.models import ProtectedError from django.core.exceptions import FieldDoesNotExist +from django.db.models import ProtectedError from django.db.models.fields.related import ForeignObjectRel from django.utils.translation import gettext_lazy as _ from rest_framework import serializers @@ -15,14 +15,15 @@ class ListToDictField(serializers.Field): """ Represent a list of entities as a dictionary """ + def __init__(self, *args, **kwargs): - self.child = kwargs.pop('child') - self.key = kwargs.pop('key') + self.child = kwargs.pop("child") + self.key = kwargs.pop("key") assert self.child.source is None, ( - 'The `source` argument is not meaningful when ' - 'applied to a `child=` field. ' - 'Remove `source=` from the field declaration.' + "The `source` argument is not meaningful when " + "applied to a `child=` field. " + "Remove `source=` from the field declaration." ) super().__init__(*args, **kwargs) @@ -43,10 +44,12 @@ def to_internal_value(self, data): def to_list_data(self, data): list_data = [] for key, value in data.items(): - list_data.append({ - self.key: key, - **value, - }) + list_data.append( + { + self.key: key, + **value, + } + ) return list_data @@ -70,20 +73,15 @@ def _extract_relations(self, validated_data): continue validated_data.pop(field.source) - reverse_relations[field_name] = ( - related_field, field, field.source - ) + reverse_relations[field_name] = (related_field, field, field.source) - if isinstance(field, serializers.ListSerializer) and \ - isinstance(field.child, serializers.ModelSerializer): + if isinstance(field, serializers.ListSerializer) and isinstance(field.child, serializers.ModelSerializer): if field.source not in validated_data: # Skip field if field is not required continue validated_data.pop(field.source) - reverse_relations[field_name] = ( - related_field, field.child, field.source - ) + reverse_relations[field_name] = (related_field, field.child, field.source) if isinstance(field, serializers.ModelSerializer): if field.source not in validated_data: @@ -102,14 +100,13 @@ def _extract_relations(self, validated_data): if direct: relations[field_name] = (field, field.source) else: - reverse_relations[field_name] = ( - related_field, field, field.source) + reverse_relations[field_name] = (related_field, field, field.source) return relations, reverse_relations def _get_related_field(self, field): model_class = self.Meta.model - if field.source.endswith('_set'): + if field.source.endswith("_set"): related_field = model_class._meta.get_field(field.source[:-4]) else: related_field = model_class._meta.get_field(field.source) @@ -119,16 +116,17 @@ def _get_related_field(self, field): return related_field, True def _get_serializer_for_field(self, field, **kwargs): - kwargs.update({ - 'context': self.context, - 'partial': self.partial, - }) + kwargs.update( + { + "context": self.context, + "partial": self.partial, + } + ) return field.__class__(**kwargs) def _get_generic_lookup(self, instance, related_field): return { - related_field.content_type_field_name: - ContentType.objects.get_for_model(instance), + related_field.content_type_field_name: ContentType.objects.get_for_model(instance), related_field.object_id_field_name: instance.pk, } @@ -141,16 +139,13 @@ def prefetch_related_instances(self, field, related_data): pk_list.append(pk) instances = { - str(related_instance.pk): related_instance - for related_instance in model_class.objects.filter( - pk__in=pk_list - ) + str(related_instance.pk): related_instance for related_instance in model_class.objects.filter(pk__in=pk_list) } return instances def _get_related_pk(self, data, model_class): - pk = data.get('pk') or data.get(model_class._meta.pk.attname) + pk = data.get("pk") or data.get(model_class._meta.pk.attname) if pk: return str(pk) @@ -159,8 +154,7 @@ def _get_related_pk(self, data, model_class): def update_or_create_reverse_relations(self, instance, reverse_relations): # Update or create reverse relations: # many-to-one, many-to-many, reversed one-to-one - for field_name, (related_field, field, field_source) in \ - reverse_relations.items(): + for field_name, (related_field, field, field_source) in reverse_relations.items(): related_data = self.initial_data[field_name] # Expand to array of one item for one-to-one for uniformity if related_field.one_to_one: @@ -185,9 +179,7 @@ def update_or_create_reverse_relations(self, instance, reverse_relations): new_related_instances = [] for data in related_data: - obj = instances.get( - self._get_related_pk(data, field.Meta.model) - ) + obj = instances.get(self._get_related_pk(data, field.Meta.model)) serializer = self._get_serializer_for_field( field, instance=obj, @@ -195,7 +187,7 @@ def update_or_create_reverse_relations(self, instance, reverse_relations): ) serializer.is_valid(raise_exception=True) related_instance = serializer.save(**save_kwargs) - data['pk'] = related_instance.pk + data["pk"] = related_instance.pk new_related_instances.append(related_instance) if related_field.many_to_many: @@ -219,9 +211,7 @@ def update_or_create_direct_relations(self, attrs, relations): data=data, ) serializer.is_valid(raise_exception=True) - attrs[field_source] = serializer.save( - **self.get_save_kwargs(field_name) - ) + attrs[field_source] = serializer.save(**self.get_save_kwargs(field_name)) def save(self, **kwargs): self.save_kwargs = defaultdict(dict, kwargs) @@ -231,9 +221,7 @@ def save(self, **kwargs): def get_save_kwargs(self, field_name): save_kwargs = self.save_kwargs[field_name] if not isinstance(save_kwargs, dict): - raise TypeError( - _("Arguments to nested serializer's `save` must be dict's") - ) + raise TypeError(_("Arguments to nested serializer's `save` must be dict's")) return save_kwargs @@ -252,6 +240,7 @@ class NestedCreateMixin(BaseNestedModelSerializer): """ Mixin adds nested create feature """ + def create(self, validated_data): relations, reverse_relations = self._extract_relations(validated_data) @@ -273,11 +262,8 @@ class NestedUpdateMixin(BaseNestedModelSerializer): """ Mixin adds update nested feature """ - default_error_messages = { - 'cannot_delete_protected': _( - "Cannot delete {instances} because " - "protected relation exists") - } + + default_error_messages = {"cannot_delete_protected": _("Cannot delete {instances} because " "protected relation exists")} def update(self, instance, validated_data): relations, reverse_relations = self._extract_relations(validated_data) @@ -299,12 +285,10 @@ def update(self, instance, validated_data): def delete_reverse_relations_if_need(self, instance, reverse_relations): # Reverse `reverse_relations` for correct delete priority - reverse_relations = OrderedDict( - reversed(list(reverse_relations.items()))) + reverse_relations = OrderedDict(reversed(list(reverse_relations.items()))) # Delete instances which is missed in data - for field_name, (related_field, field, field_source) in \ - reverse_relations.items(): + for field_name, (related_field, field, field_source) in reverse_relations.items(): # related_data = self.initial_data[field_name] related_data = self.get_initial()[field_name] @@ -319,14 +303,12 @@ def delete_reverse_relations_if_need(self, instance, reverse_relations): # M2M relation can be as direct or as reverse. For direct relation # we should use reverse relation name - if related_field.many_to_many and \ - not isinstance(related_field, ForeignObjectRel): + if related_field.many_to_many and not isinstance(related_field, ForeignObjectRel): related_field_lookup = { related_field.remote_field.name: instance, } elif isinstance(related_field, GenericRelation): - related_field_lookup = \ - self._get_generic_lookup(instance, related_field) + related_field_lookup = self._get_generic_lookup(instance, related_field) else: related_field_lookup = { related_field.name: instance, @@ -337,11 +319,7 @@ def delete_reverse_relations_if_need(self, instance, reverse_relations): try: pks_to_delete = list( - model_class.objects.filter( - **related_field_lookup - ).exclude( - pk__in=current_ids - ).values_list('pk', flat=True) + model_class.objects.filter(**related_field_lookup).exclude(pk__in=current_ids).values_list("pk", flat=True) ) if related_field.many_to_many: @@ -353,5 +331,4 @@ def delete_reverse_relations_if_need(self, instance, reverse_relations): except ProtectedError as e: instances = e.args[1] - self.fail('cannot_delete_protected', instances=", ".join([ - str(instance) for instance in instances])) + self.fail("cannot_delete_protected", instances=", ".join([str(instance) for instance in instances])) diff --git a/manage.py b/manage.py index 8ec72d9482..e2dcea0b25 100755 --- a/manage.py +++ b/manage.py @@ -5,8 +5,8 @@ if __name__ == "__main__": os.environ.setdefault("DJANGO_SETTINGS_MODULE", "deep.settings") try: - from django.core.management import execute_from_command_line from django.conf import settings + from django.core.management import execute_from_command_line except ImportError: # The above import may fail for some other reason. Ensure that the # issue is really that Django is missing to avoid masking other diff --git a/utils/aws.py b/utils/aws.py index 1346c13193..e527aca844 100644 --- a/utils/aws.py +++ b/utils/aws.py @@ -1,19 +1,20 @@ +import json +import logging + import boto3 import requests from botocore.exceptions import ClientError -import json -import logging logger = logging.getLogger(__name__) def fetch_db_credentials_from_secret_arn(cluster_secret_arn, ignore_error=False): - logger.warning(f'Fetching db cluster secret using ARN: {cluster_secret_arn}') + logger.warning(f"Fetching db cluster secret using ARN: {cluster_secret_arn}") # the passed secret is the aws arn instead session = boto3.session.Session() client = session.client( - service_name='secretsmanager', + service_name="secretsmanager", # region_name='us-east-1', ) @@ -22,28 +23,28 @@ def fetch_db_credentials_from_secret_arn(cluster_secret_arn, ignore_error=False) except ClientError as e: logger.error(f"Got client error {e.response['Error']['Code']} for {cluster_secret_arn}") else: - logger.info('Found secret...') + logger.info("Found secret...") # Secrets Manager decrypts the secret value using the associated KMS CMK # Depending on whether the secret was a string or binary, only one of these fields will be populated - if 'SecretString' in get_secret_value_response: - text_secret_data = get_secret_value_response['SecretString'] + if "SecretString" in get_secret_value_response: + text_secret_data = get_secret_value_response["SecretString"] return json.loads(text_secret_data) else: # binary_secret_data = get_secret_value_response['SecretBinary'] logger.error("Secret should be decrypted to string but found binary instead") if ignore_error: return - raise Exception('Failed to parse/fetch secret') + raise Exception("Failed to parse/fetch secret") def get_internal_ip(name): try: - resp = requests.get('http://169.254.170.2/v2/metadata', timeout=1).json() + resp = requests.get("http://169.254.170.2/v2/metadata", timeout=1).json() return [ - container['Networks'][0]['IPv4Addresses'][0] - for container in resp['Containers'] + container["Networks"][0]["IPv4Addresses"][0] + for container in resp["Containers"] # 'web' is from Dockerfile + web manifest - if container['DockerName'] == name + if container["DockerName"] == name ][0] except Exception: logger.error(f"Failed to retrieve AWS internal ip, {locals().get('resp')}", exc_info=True) diff --git a/utils/common.py b/utils/common.py index 845e629bee..5adeff05af 100644 --- a/utils/common.py +++ b/utils/common.py @@ -1,34 +1,34 @@ # -*- coding: utf-8 -*- +import datetime import hashlib +import logging import os -import re -import time import random +import re import string import tempfile -import requests -import logging -import datetime -from typing import Union, Optional +import time from collections import Counter from functools import reduce +from typing import Optional, Union +from xml.sax.saxutils import escape as xml_escape +import requests +from django.conf import settings from django.core.cache import cache -from django.utils.hashable import make_hashable -from django.utils.encoding import force_str from django.core.files.storage import FileSystemStorage, get_storage_class -from django.conf import settings -from django.utils.encoding import force_bytes, force_text -from django.utils.http import urlsafe_base64_encode, urlsafe_base64_decode -from xml.sax.saxutils import escape as xml_escape - +from django.utils.encoding import force_bytes, force_str, force_text +from django.utils.hashable import make_hashable +from django.utils.http import urlsafe_base64_decode, urlsafe_base64_encode from redis_store import redis -USER_AGENT = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1)' + \ - ' AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36' +USER_AGENT = ( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1)" + + " AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36" +) DEFAULT_HEADERS = { - 'User-Agent': USER_AGENT, + "User-Agent": USER_AGENT, } ONE_DAY = 24 * 60 * 60 @@ -38,10 +38,10 @@ try: import matplotlib as mp - import plotly.io as pio import plotly.graph_objs as ploty_go + import plotly.io as pio except ImportError as e: - logger.warning(f'ImportError: {e}') + logger.warning(f"ImportError: {e}") StorageClass = get_storage_class() @@ -50,13 +50,13 @@ def sanitize_text(_text): text = _text # Remove NUL (0x00) characters - text = text.replace('\x00', '') + text = text.replace("\x00", "") # Tabs and nbsps to space - text = re.sub(r'(\t| )', ' ', text) + text = re.sub(r"(\t| )", " ", text) # Multiple spaces to single - text = re.sub(r' +', ' ', text) + text = re.sub(r" +", " ", text) # More than 3 line breaks to just 3 line breaks - text = re.sub(r'\n\s*\n\s*(\n\s*)+', '\n\n\n', text) + text = re.sub(r"\n\s*\n\s*(\n\s*)+", "\n\n\n", text) return text.strip() @@ -75,7 +75,7 @@ def write_file(r, fp): return fp -def get_temp_file(dir='/tmp/', mode='w+b', suffix=None): +def get_temp_file(dir="/tmp/", mode="w+b", suffix=None): if suffix: return tempfile.NamedTemporaryFile(dir=dir, suffix=suffix, mode=mode) return tempfile.NamedTemporaryFile(dir=dir, mode=mode) @@ -92,11 +92,11 @@ def get_file_from_url(url): def get_or_write_file(path, text): try: - extracted = open(path, 'r') + extracted = open(path, "r") except FileNotFoundError: - with open(path, 'w') as fp: + with open(path, "w") as fp: fp.write(text) - extracted = open(path, 'r') + extracted = open(path, "r") return extracted @@ -109,15 +109,15 @@ def makedirs(path): def replace_ns(nsmap, tag): for k, v in nsmap.items(): - k = k or '' - tag = tag.replace('{{{}}}'.format(v), '{}:'.format(k)) + k = k or "" + tag = tag.replace("{{{}}}".format(v), "{}:".format(k)) return tag def get_ns_tag(nsmap, tag): for k, v in nsmap.items(): - k = k or '' - tag = tag.replace('{}:'.format(k), '{{{}}}'.format(v)) + k = k or "" + tag = tag.replace("{}:".format(k), "{{{}}}".format(v)) return tag @@ -125,35 +125,29 @@ def is_valid_xml_char_ordinal(c): codepoint = ord(c) # conditions ordered by presumed frequency return ( - 0x20 <= codepoint <= 0xD7FF or - codepoint in (0x9, 0xA, 0xD) or - 0xE000 <= codepoint <= 0xFFFD or - 0x10000 <= codepoint <= 0x10FFFF + 0x20 <= codepoint <= 0xD7FF + or codepoint in (0x9, 0xA, 0xD) + or 0xE000 <= codepoint <= 0xFFFD + or 0x10000 <= codepoint <= 0x10FFFF ) def get_valid_xml_string(string, escape=True): if string: s = xml_escape(string) if escape else string - return ''.join(c for c in s if is_valid_xml_char_ordinal(c)) - return '' - - -def deep_date_format( - date: Optional[Union[datetime.date, datetime.datetime]], - fallback: Optional[str] = '' -) -> Optional[str]: - if date and ( - isinstance(date, datetime.datetime) or - isinstance(date, datetime.date) - ): - return date.strftime('%d-%m-%Y') + return "".join(c for c in s if is_valid_xml_char_ordinal(c)) + return "" + + +def deep_date_format(date: Optional[Union[datetime.date, datetime.datetime]], fallback: Optional[str] = "") -> Optional[str]: + if date and (isinstance(date, datetime.datetime) or isinstance(date, datetime.date)): + return date.strftime("%d-%m-%Y") return fallback def deep_date_parse(date_str: str, raise_exception=True) -> Optional[datetime.date]: try: - return datetime.datetime.strptime(date_str, '%d-%m-%Y').date() + return datetime.datetime.strptime(date_str, "%d-%m-%Y").date() except (ValueError, TypeError) as e: if raise_exception: raise e @@ -161,14 +155,14 @@ def deep_date_parse(date_str: str, raise_exception=True) -> Optional[datetime.da def parse_date(date_str): try: - return date_str and datetime.datetime.strptime(date_str, '%d-%m-%Y') + return date_str and datetime.datetime.strptime(date_str, "%d-%m-%Y") except ValueError: return None def parse_time(time_str): try: - return time_str and datetime.datetime.strptime(time_str, '%H:%M').time() + return time_str and datetime.datetime.strptime(time_str, "%H:%M").time() except ValueError: return None @@ -188,13 +182,13 @@ def identity(x): def underscore_to_title(x): - return ' '.join([y.title() for y in x.split('_')]) + return " ".join([y.title() for y in x.split("_")]) def random_key(length=16): candidates = string.ascii_lowercase + string.digits winners = [random.choice(candidates) for _ in range(length)] - return ''.join(winners) + return "".join(winners) def get_max_occurence_and_count(items): @@ -203,9 +197,7 @@ def get_max_occurence_and_count(items): return 0, None count = Counter(items) return reduce( - lambda a, x: x if x[1] > a[1] else a, - count.items(), # [(item, count)...] - (items[0], -1) # Initial accumulator + lambda a, x: x if x[1] > a[1] else a, count.items(), (items[0], -1) # [(item, count)...] # Initial accumulator ) @@ -227,12 +219,9 @@ def excel_column_name(column_number): class LogTime: - logger = logging.getLogger('profiling') + logger = logging.getLogger("profiling") - def __init__( - self, block_name='', log_args=True, - args_accessor=identity, kwargs_accessor=identity - ): + def __init__(self, block_name="", log_args=True, args_accessor=identity, kwargs_accessor=identity): self.log_args = log_args self.block_name = block_name self.args_accessor = args_accessor @@ -246,8 +235,7 @@ def __exit__(self, *args, **kwds): if not settings.PROFILE: return end = time.time() - LogTime.logger.info("BLOCK: {} TIME {}s.".format( - self.block_name, end - self.start)) + LogTime.logger.info("BLOCK: {} TIME {}s.".format(self.block_name, end - self.start)) def __call__(self, func_to_be_tracked): def wrapper(*args, **kwargs): @@ -260,20 +248,18 @@ def wrapper(*args, **kwargs): fname = func_to_be_tracked.__name__ - str_args = 'args: {}'.format( - self.args_accessor(args) - )[:100] if self.log_args else '' + str_args = "args: {}".format(self.args_accessor(args))[:100] if self.log_args else "" - str_kwargs = 'kwargs: {}'.format( - self.kwargs_accessor(kwargs) - )[:100] if self.log_args else '' + str_kwargs = "kwargs: {}".format(self.kwargs_accessor(kwargs))[:100] if self.log_args else "" log_message = "FUNCTION[{}]: '{}({}, {})' : TIME {}s.".format( - self.block_name, fname, str_args, str_kwargs, end - start) + self.block_name, fname, str_args, str_kwargs, end - start + ) LogTime.logger.info(log_message) return ret + wrapper.__name__ = func_to_be_tracked.__name__ wrapper.__module__ = func_to_be_tracked.__module__ return wrapper @@ -300,28 +286,30 @@ def create_plot_image(func): """ Return tmp file image with func render logic """ + def func_wrapper(*args, **kwargs): - size = kwargs.pop('chart_size', (8, 4)) - if isinstance(kwargs.get('format', 'png'), list): - images_format = kwargs.pop('format') + size = kwargs.pop("chart_size", (8, 4)) + if isinstance(kwargs.get("format", "png"), list): + images_format = kwargs.pop("format") else: - images_format = [kwargs.pop('format', 'png')] + images_format = [kwargs.pop("format", "png")] func(*args, **kwargs) figure = mp.pyplot.gcf() if size: figure.set_size_inches(size) mp.pyplot.draw() - mp.pyplot.gca().spines['top'].set_visible(False) - mp.pyplot.gca().spines['right'].set_visible(False) + mp.pyplot.gca().spines["top"].set_visible(False) + mp.pyplot.gca().spines["right"].set_visible(False) images = [] for image_format in images_format: - fp = get_temp_file(suffix='.{}'.format(image_format)) - figure.savefig(fp, bbox_inches='tight', format=image_format, alpha=True, dpi=300) + fp = get_temp_file(suffix=".{}".format(image_format)) + figure.savefig(fp, bbox_inches="tight", format=image_format, alpha=True, dpi=300) mp.pyplot.close(figure) fp.seek(0) - images.append({'image': fp, 'format': image_format}) + images.append({"image": fp, "format": image_format}) return images + return func_wrapper @@ -329,54 +317,58 @@ def create_plotly_image(func): """ Return tmp file image with func render logic """ + def func_wrapper(*args, **kwargs): - width, height = kwargs.pop('chart_size', (5, 4)) - if isinstance(kwargs.get('format', 'png'), list): - images_format = kwargs.pop('format', 'png') + width, height = kwargs.pop("chart_size", (5, 4)) + if isinstance(kwargs.get("format", "png"), list): + images_format = kwargs.pop("format", "png") else: - images_format = [kwargs.pop('format', 'png')] - x_label = kwargs.pop('x_label') - y_label = kwargs.pop('y_label') - x_params = kwargs.pop('x_params', {}) - y_params = kwargs.pop('y_params', {}) + images_format = [kwargs.pop("format", "png")] + x_label = kwargs.pop("x_label") + y_label = kwargs.pop("y_label") + x_params = kwargs.pop("x_params", {}) + y_params = kwargs.pop("y_params", {}) data, layout = func(*args, **kwargs) if layout is None: - layout = ploty_go.Layout(**{ - 'title': x_label, - 'yaxis': { - **create_plotly_image.axis_config, - **y_params, - 'title': y_label, - }, - 'xaxis': { - **create_plotly_image.axis_config, - **x_params, - 'ticks': 'outside', - }, - }) + layout = ploty_go.Layout( + **{ + "title": x_label, + "yaxis": { + **create_plotly_image.axis_config, + **y_params, + "title": y_label, + }, + "xaxis": { + **create_plotly_image.axis_config, + **x_params, + "ticks": "outside", + }, + } + ) fig = ploty_go.Figure(data=data, layout=layout) images = [] for image_format in images_format: img_bytes = pio.to_image(fig, format=image_format, width=width, height=height, scale=2) - fp = get_temp_file(suffix='.{}'.format(image_format)) + fp = get_temp_file(suffix=".{}".format(image_format)) fp.write(img_bytes) fp.seek(0) - images.append({'image': fp, 'format': image_format}) + images.append({"image": fp, "format": image_format}) return images + return func_wrapper create_plotly_image.axis_config = { - 'automargin': True, - 'tickfont': dict(size=8), - 'separatethousands': True, + "automargin": True, + "tickfont": dict(size=8), + "separatethousands": True, } create_plotly_image.marker = dict( - color='teal', + color="teal", line=dict( - color='white', + color="white", width=0.5, - ) + ), ) @@ -391,24 +383,27 @@ def redis_lock(lock_key, timeout: float = 60 * 60 * 4): """ Default Lock lifetime 4 hours """ + def _dec(func): def _caller(*args, **kwargs): key = lock_key.format(*args, **kwargs) lock = redis.get_lock(key, timeout) have_lock = lock.acquire(blocking=False) if not have_lock: - logger.warning(f'Unable to get lock for {key}(ttl: {get_redis_lock_ttl(lock)})') + logger.warning(f"Unable to get lock for {key}(ttl: {get_redis_lock_ttl(lock)})") return False try: return_value = func(*args, **kwargs) or True except Exception: - logger.error('{}.{}'.format(func.__module__, func.__name__), exc_info=True) + logger.error("{}.{}".format(func.__module__, func.__name__), exc_info=True) return_value = False lock.release() return return_value + _caller.__name__ = func.__name__ _caller.__module__ = func.__module__ return _caller + return _dec @@ -418,35 +413,35 @@ def make_colormap(seq): and in the interval (0,1). """ seq = [(None,) * 3, 0.0] + list(seq) + [1.0, (None,) * 3] - cdict = {'red': [], 'green': [], 'blue': []} + cdict = {"red": [], "green": [], "blue": []} for i, item in enumerate(seq): if isinstance(item, float): r1, g1, b1 = seq[i - 1] r2, g2, b2 = seq[i + 1] - cdict['red'].append([item, r1, r2]) - cdict['green'].append([item, g1, g2]) - cdict['blue'].append([item, b1, b2]) - return mp.colors.LinearSegmentedColormap('CustomMap', cdict) + cdict["red"].append([item, r1, r2]) + cdict["green"].append([item, g1, g2]) + cdict["blue"].append([item, b1, b2]) + return mp.colors.LinearSegmentedColormap("CustomMap", cdict) def excel_to_python_date_format(excel_format): # TODO: support all formats # First replace excel's locale identifiers such as [$-409] by empty string - python_format = re.sub( - r'(\[\\$-\d+\])', '', excel_format.upper() - ).\ - replace('\\', '').\ - replace('YYYY', '%Y').\ - replace('YY', '%y').\ - replace('MMMM', '%m').\ - replace('MMM', '%m').\ - replace('MM', '%m').\ - replace('M', '%m').\ - replace('DD', '%d').\ - replace('D', '%d').\ - replace('HH', '%H').\ - replace('H', '%H').\ - replace('SS', '%S') + python_format = ( + re.sub(r"(\[\\$-\d+\])", "", excel_format.upper()) + .replace("\\", "") + .replace("YYYY", "%Y") + .replace("YY", "%y") + .replace("MMMM", "%m") + .replace("MMM", "%m") + .replace("MM", "%m") + .replace("M", "%m") + .replace("DD", "%d") + .replace("D", "%d") + .replace("HH", "%H") + .replace("H", "%H") + .replace("SS", "%S") + ) return python_format @@ -476,11 +471,11 @@ def calculate_md5(file): def camelcase_to_titlecase(label): - return re.sub(r'((?<=[a-z])[A-Z]|(?<!\A)[A-Z](?=[a-z]))', r' \1', label) + return re.sub(r"((?<=[a-z])[A-Z]|(?<!\A)[A-Z](?=[a-z]))", r" \1", label) def kebabcase_to_titlecase(kebab_str): - return ' '.join([x.title() for x in kebab_str.split('-')]) + return " ".join([x.title() for x in kebab_str.split("-")]) def is_valid_number(value): @@ -493,7 +488,7 @@ def is_valid_number(value): def to_camelcase(snake_str): - components = snake_str.split('_') + components = snake_str.split("_") return components[0] + "".join(x.title() for x in components[1:]) @@ -512,7 +507,7 @@ def has_prefetched(obj, field): """ Checks if field is prefetched. """ - if hasattr(obj, '_prefetched_objects_cache') and field in obj._prefetched_objects_cache: + if hasattr(obj, "_prefetched_objects_cache") and field in obj._prefetched_objects_cache: return True return False @@ -529,7 +524,7 @@ def has_select_related(obj, field): def chunks(lst, n): """Yield successive n-sized chunks from lst.""" for i in range(0, len(lst), n): - yield lst[i:i + n] + yield lst[i : i + n] def get_full_media_url(media_path, file_system_domain=None): @@ -541,7 +536,7 @@ def get_full_media_url(media_path, file_system_domain=None): return media_path -class UidBase64Helper(): +class UidBase64Helper: @staticmethod def encode(integer): return urlsafe_base64_encode(force_bytes(integer)) @@ -578,6 +573,7 @@ def graphene_cache(cache_key, cache_key_gen=None, timeout=60): """ Default Lock lifetime 4 hours """ + def _dec(func): def _caller(*args, **kwargs): if cache_key_gen: @@ -589,15 +585,17 @@ def _caller(*args, **kwargs): lambda: func(*args, **kwargs), timeout, ) + _caller.__name__ = func.__name__ _caller.__module__ = func.__module__ return _caller + return _dec def generate_sha256(text: str): m = hashlib.sha256() - m.update(text.encode('utf-8')) + m.update(text.encode("utf-8")) return m.hexdigest() @@ -605,6 +603,6 @@ def render_string_for_graphql(text): """ Return null if text is empty ("") """ - if text == '': + if text == "": return None return text diff --git a/utils/data_structures.py b/utils/data_structures.py index 6bfdc9e65e..6d51c54a2a 100644 --- a/utils/data_structures.py +++ b/utils/data_structures.py @@ -1,14 +1,14 @@ - class Dict(dict): """ Dict class where items can be accessed/set using dot notation """ + __getattr__ = dict.get __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ def __convert_nested(self, d): - """ convert dictionary(nested) to Dict """ + """convert dictionary(nested) to Dict""" new_kw = {} for k, v in d.items(): if isinstance(v, dict): diff --git a/utils/date_extractor.py b/utils/date_extractor.py index dad46f3b59..ffd059a46a 100644 --- a/utils/date_extractor.py +++ b/utils/date_extractor.py @@ -1,9 +1,8 @@ import datetime - +import re from typing import Union from dateutil.parser import parse -import re def str_to_date(value: str) -> Union[datetime.datetime, None]: @@ -15,7 +14,7 @@ def str_to_date(value: str) -> Union[datetime.datetime, None]: # Simple url check, though very less chance of working def _extract_from_url(url): - regex = r'([\./\-_]{0,1}(19|20)\d{2})[\./\-_]{0,1}(([0-3]{0,1}[0-9][\./\-_])|(\w{3,5}[\./\-_]))([0-3]{0,1}[0-9][\./\-]{0,1})?' # noqa + regex = r"([\./\-_]{0,1}(19|20)\d{2})[\./\-_]{0,1}(([0-3]{0,1}[0-9][\./\-_])|(\w{3,5}[\./\-_]))([0-3]{0,1}[0-9][\./\-]{0,1})?" # noqa m = re.search(regex, url) if m: return str_to_date(m.group(0)) @@ -24,40 +23,40 @@ def _extract_from_url(url): # Some common meta names and properties that may denote date DATE_META_NAMES = [ - 'date', - 'pubdate', - 'publishdate', - 'timestamp', - 'dc.date.issued', - 'sailthru.date', - 'article.published', - 'published-date', - 'article.created', - 'article_date_original', - 'cxenseparse:recs:publishtime', - 'date_published', - 'datepublished', - 'datecreated', - 'article:published_time', - 'bt:pubdate', + "date", + "pubdate", + "publishdate", + "timestamp", + "dc.date.issued", + "sailthru.date", + "article.published", + "published-date", + "article.created", + "article_date_original", + "cxenseparse:recs:publishtime", + "date_published", + "datepublished", + "datecreated", + "article:published_time", + "bt:pubdate", ] def _extract_from_meta(page): meta_date = None - for meta in page.findAll('meta'): - meta_name = meta.get('name', '').lower() - item_prop = meta.get('itemprop', '').lower() - meta_property = meta.get('property', '').lower() - http_equiv = meta.get('http-equiv', '').lower() + for meta in page.findAll("meta"): + meta_name = meta.get("name", "").lower() + item_prop = meta.get("itemprop", "").lower() + meta_property = meta.get("property", "").lower() + http_equiv = meta.get("http-equiv", "").lower() if ( - meta_name in DATE_META_NAMES or - item_prop in DATE_META_NAMES or - meta_property in DATE_META_NAMES or - http_equiv == 'date' + meta_name in DATE_META_NAMES + or item_prop in DATE_META_NAMES + or meta_property in DATE_META_NAMES + or http_equiv == "date" ): - meta_date = str_to_date(meta['content'].strip()) + meta_date = str_to_date(meta["content"].strip()) break return meta_date @@ -66,27 +65,26 @@ def _extract_from_meta(page): # From https://github.com/Webhose/article-date-extractor # Most probably can be optimized def _extract_from_tags(page): - for time in page.findAll('time'): - datetime = time.get('datetime', '') + for time in page.findAll("time"): + datetime = time.get("datetime", "") if len(datetime) > 0: return str_to_date(datetime) - datetime = time.get('class', '') - if len(datetime) > 0 and datetime[0].lower() == 'timestamp': + datetime = time.get("class", "") + if len(datetime) > 0 and datetime[0].lower() == "timestamp": return str_to_date(time.string) - tag = page.find('span', {'itemprop': 'datePublished'}) + tag = page.find("span", {"itemprop": "datePublished"}) if tag is not None: - date_text = tag.get('content') + date_text = tag.get("content") if date_text is None: date_text = tag.text if date_text: return str_to_date(date_text) - regex = 'pubdate|timestamp|article_date|articledate|date' - for tag in page.find_all(['span', 'p', 'div'], - class_=re.compile(regex, re.IGNORECASE)): + regex = "pubdate|timestamp|article_date|articledate|date" + for tag in page.find_all(["span", "p", "div"], class_=re.compile(regex, re.IGNORECASE)): date_text = tag.string if date_text is None: date_text = tag.text diff --git a/utils/db/functions.py b/utils/db/functions.py index 6c5ae19f1f..4b9c6771a8 100644 --- a/utils/db/functions.py +++ b/utils/db/functions.py @@ -1,18 +1,18 @@ -from django.db.models import Func, Transform, BooleanField from django.contrib.gis.db.models.fields import BaseSpatialField from django.contrib.gis.db.models.functions import GeoFuncMixin +from django.db.models import BooleanField, Func, Transform class StrPos(Func): - function = 'POSITION' # MySQL method + function = "POSITION" # MySQL method def as_sqlite(self, compiler, connection): # SQLite method - return self.as_sql(compiler, connection, function='INSTR') + return self.as_sql(compiler, connection, function="INSTR") def as_postgresql(self, compiler, connection): # PostgreSQL method - return self.as_sql(compiler, connection, function='STRPOS') + return self.as_sql(compiler, connection, function="STRPOS") @BaseSpatialField.register_lookup diff --git a/utils/external_storages/dropbox.py b/utils/external_storages/dropbox.py index d643197ce4..d8400af423 100644 --- a/utils/external_storages/dropbox.py +++ b/utils/external_storages/dropbox.py @@ -1,7 +1,9 @@ -import requests -from utils.common import write_file, DEFAULT_HEADERS import tempfile +import requests + +from utils.common import DEFAULT_HEADERS, write_file + def download(file_url, SUPPORTED_MIME_TYPES): """ diff --git a/utils/external_storages/google_drive.py b/utils/external_storages/google_drive.py index bd3238c2ed..f914c86377 100644 --- a/utils/external_storages/google_drive.py +++ b/utils/external_storages/google_drive.py @@ -1,25 +1,22 @@ -import httplib2 - -from rest_framework import serializers +import tempfile -from utils.common import USER_AGENT +import httplib2 from apiclient import discovery +from apiclient.http import MediaIoBaseDownload from oauth2client import client -import tempfile +from rest_framework import serializers -from apiclient.http import MediaIoBaseDownload +from utils.common import USER_AGENT # Google Specific Mimetypes -GDOCS = 'application/vnd.google-apps.document' -GSLIDES = 'application/vnd.google-apps.presentation' -GSHEETS = 'application/vnd.google-apps.spreadsheet' +GDOCS = "application/vnd.google-apps.document" +GSLIDES = "application/vnd.google-apps.presentation" +GSHEETS = "application/vnd.google-apps.spreadsheet" # Standard Mimetypes -DOCX = 'application/vnd.openxmlformats-officedocument.wordprocessingml.'\ - 'document' -PPT = 'application/vnd.openxmlformats-officedocument.presentationml.'\ - 'presentation' -EXCEL = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' +DOCX = "application/vnd.openxmlformats-officedocument.wordprocessingml." "document" +PPT = "application/vnd.openxmlformats-officedocument.presentationml." "presentation" +EXCEL = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" # Goggle Specific mimetypes to Standard Mimetypes mapping GOOGLE_DRIVE_EXPORT_MAP = { @@ -34,12 +31,7 @@ def get_credentials(access_token): return credentials -def download( - file_id, - mime_type, - access_token, - SUPPORTED_MIME_TYPES -): +def download(file_id, mime_type, access_token, SUPPORTED_MIME_TYPES): """ Download/Export file from google drive @@ -53,19 +45,16 @@ def download( credentials = get_credentials(access_token) http = credentials.authorize(httplib2.Http()) - service = discovery.build('drive', 'v3', http=http) + service = discovery.build("drive", "v3", http=http) if mime_type in SUPPORTED_MIME_TYPES: # Directly dowload the file request = service.files().get_media(fileId=file_id) elif mime_type in GOOGLE_DRIVE_EXPORT_MAP: export_mime_type = GOOGLE_DRIVE_EXPORT_MAP.get(mime_type) - request = service.files().export_media( - fileId=file_id, - mimeType=export_mime_type - ) + request = service.files().export_media(fileId=file_id, mimeType=export_mime_type) else: - raise serializers.ValidationError('Unsupported file type {}'.format(mime_type)) + raise serializers.ValidationError("Unsupported file type {}".format(mime_type)) outfp = tempfile.TemporaryFile("wb+") downloader = MediaIoBaseDownload(outfp, request) diff --git a/utils/extractor/document.py b/utils/extractor/document.py index cd807dbc08..76c793433f 100644 --- a/utils/extractor/document.py +++ b/utils/extractor/document.py @@ -1,10 +1,10 @@ from utils.extractor import extractors -HTML = 'html' -PDF = 'pdf' -DOCX = 'docx' -PPTX = 'pptx' -MSWORD = 'doc' +HTML = "html" +PDF = "pdf" +DOCX = "docx" +PPTX = "pptx" +MSWORD = "doc" EXTRACTORS = { HTML: extractors.HtmlExtractor, @@ -36,4 +36,4 @@ def extract(self): extractor = EXTRACTORS.get(self.type) if extractor: return extractor(self.doc, self.params).extract() - return '', [], 1 + return "", [], 1 diff --git a/utils/extractor/extractors.py b/utils/extractor/extractors.py index c27e866f1d..2030148aa2 100644 --- a/utils/extractor/extractors.py +++ b/utils/extractor/extractors.py @@ -1,11 +1,9 @@ from .exception import ExtractError +from .formats.docx import msword_process as msword_extract +from .formats.docx import pptx_process as pptx_extract +from .formats.docx import process as docx_extract from .formats.html import process as html_extract from .formats.pdf import process as pdf_extract -from .formats.docx import ( - process as docx_extract, - pptx_process as pptx_extract, - msword_process as msword_extract -) class BaseExtractor: @@ -15,6 +13,7 @@ class BaseExtractor: Verify Simlify """ + def __init__(self, doc, params=None): self.doc = doc self.params = params @@ -29,23 +28,21 @@ def extract(self): def verify(self): if not self.doc: raise ExtractError(self.ERROR_MSG) - if not hasattr(self.__class__, 'EXTRACT_METHOD'): - raise ExtractError( - "Class '{}' have no EXTRACT_METHOD Method". - format(self.__class__.__name__) - ) + if not hasattr(self.__class__, "EXTRACT_METHOD"): + raise ExtractError("Class '{}' have no EXTRACT_METHOD Method".format(self.__class__.__name__)) class HtmlExtractor(BaseExtractor): """ Extractor class to extract HTML documents. """ + ERROR_MSG = "Not a html document" EXTRACT_METHOD = html_extract def extract(self): self.verify() - url = self.params.get('url') if self.params else None + url = self.params.get("url") if self.params else None return self.__class__.EXTRACT_METHOD(self.doc, url) @@ -53,6 +50,7 @@ class PdfExtractor(BaseExtractor): """ Extractor class to extract PDF documents. """ + ERROR_MSG = "Not a pdf document" EXTRACT_METHOD = pdf_extract @@ -61,6 +59,7 @@ class DocxExtractor(BaseExtractor): """ Extractor class to extract Docx documents. """ + ERROR_MSG = "Not a docx document" EXTRACT_METHOD = docx_extract @@ -69,6 +68,7 @@ class PptxExtractor(BaseExtractor): """ Extractor class to extract PPTX documents. """ + ERROR_MSG = "Not a pptx document" EXTRACT_METHOD = pptx_extract @@ -77,5 +77,6 @@ class MswordExtractor(BaseExtractor): """ Extractor class to extract msword documents. """ + ERROR_MSG = "Not a msword (.doc) document" EXTRACT_METHOD = msword_extract diff --git a/utils/extractor/file_document.py b/utils/extractor/file_document.py index 0254a01352..ca3a5ca79c 100644 --- a/utils/extractor/file_document.py +++ b/utils/extractor/file_document.py @@ -1,8 +1,6 @@ import os -from .document import ( - Document, - HTML, PDF, DOCX, PPTX, MSWORD, -) + +from .document import DOCX, HTML, MSWORD, PDF, PPTX, Document class FileDocument(Document): @@ -11,11 +9,20 @@ class FileDocument(Document): Takes file Gives document and type """ - HTML_TYPES = ['.html', '.htm', '.txt'] - PDF_TYPES = ['.pdf', ] - DOCX_TYPES = ['.docx', ] - MSWORD_TYPES = ['.doc', ] - PPTX_TYPES = ['.pptx', ] + + HTML_TYPES = [".html", ".htm", ".txt"] + PDF_TYPES = [ + ".pdf", + ] + DOCX_TYPES = [ + ".docx", + ] + MSWORD_TYPES = [ + ".doc", + ] + PPTX_TYPES = [ + ".pptx", + ] def __init__(self, file, name): diff --git a/utils/extractor/formats/docx.py b/utils/extractor/formats/docx.py index 2756124a12..79c07c05f3 100644 --- a/utils/extractor/formats/docx.py +++ b/utils/extractor/formats/docx.py @@ -1,17 +1,18 @@ #! /usr/bin/env python3 -import xml.etree.ElementTree as ET -from django.conf import settings import argparse -import tempfile -import zipfile -import sys -import re +import logging import os import random +import re import string +import sys +import tempfile +import xml.etree.ElementTree as ET +import zipfile from subprocess import call -import logging + +from django.conf import settings logger = logging.getLogger(__name__) @@ -21,25 +22,23 @@ text, images = process(doc) -> images for tempfile """ -nsmap = {'w': 'http://schemas.openxmlformats.org/wordprocessingml/2006/main', - 'p': 'http://schemas.openxmlformats.org/presentationml/2006/main', - 'a': 'http://schemas.openxmlformats.org/drawingml/2006/main', - 'wP': 'http://schemas.openxmlformats.org/officeDocument/2006/extended-properties', # noqa - } +nsmap = { + "w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main", + "p": "http://schemas.openxmlformats.org/presentationml/2006/main", + "a": "http://schemas.openxmlformats.org/drawingml/2006/main", + "wP": "http://schemas.openxmlformats.org/officeDocument/2006/extended-properties", # noqa +} def process_args(): - parser = argparse.ArgumentParser(description='A pure python-based utility ' - 'to extract text and images ' - 'from docx files.') + parser = argparse.ArgumentParser(description="A pure python-based utility " "to extract text and images " "from docx files.") parser.add_argument("docx", help="path of the docx file") - parser.add_argument('-i', '--img_dir', help='path of directory ' - 'to extract images') + parser.add_argument("-i", "--img_dir", help="path of directory " "to extract images") args = parser.parse_args() if not os.path.exists(args.docx): - print('File {} does not exist.'.format(args.docx)) + print("File {} does not exist.".format(args.docx)) sys.exit(1) if args.img_dir is not None: @@ -59,9 +58,9 @@ def qn(tag): example, ``qn('p:cSld')`` returns ``'{http://schemas.../main}cSld'``. Source: https://github.com/python-openxml/python-docx/ """ - prefix, tagroot = tag.split(':') + prefix, tagroot = tag.split(":") uri = nsmap[prefix] - return '{{{}}}{}'.format(uri, tagroot) + return "{{{}}}{}".format(uri, tagroot) def xml2text(xml, pptx=False): @@ -71,35 +70,35 @@ def xml2text(xml, pptx=False): equivalent. Adapted from: https://github.com/python-openxml/python-docx/ """ - text = u'' + text = "" root = ET.fromstring(xml) if pptx is False: for child in root.iter(): - if child.tag == qn('w:t'): + if child.tag == qn("w:t"): t_text = child.text - text += t_text if t_text is not None else '' - elif child.tag == qn('w:tab'): - text += '\t' - elif child.tag in (qn('w:br'), qn('w:cr')): - text += '\n' + text += t_text if t_text is not None else "" + elif child.tag == qn("w:tab"): + text += "\t" + elif child.tag in (qn("w:br"), qn("w:cr")): + text += "\n" elif child.tag == qn("w:p"): - text += '\n\n' + text += "\n\n" else: for child in root.iter(): - if child.tag == qn('a:t'): + if child.tag == qn("a:t"): t_text = child.text - text += t_text if t_text is not None else '' - elif child.tag == qn('a:tab'): - text += '\t' - elif child.tag in (qn('a:br'), qn('a:cr')): - text += '\n' + text += t_text if t_text is not None else "" + elif child.tag == qn("a:tab"): + text += "\t" + elif child.tag in (qn("a:br"), qn("a:cr")): + text += "\n" elif child.tag in (qn("a:p"), qn("a:bodyPr"), qn("a:fld")): - text += '\n\n' + text += "\n\n" return text def process(docx, pptx=False, img_dir=None): - text = u'' + text = "" # unzip the docx in memory zipf = zipfile.ZipFile(docx) @@ -109,13 +108,13 @@ def process(docx, pptx=False, img_dir=None): # get header text # there can be 3 header files in the zip - header_xmls = 'ppt/header[0-9]*.xml' if pptx else 'word/header[0-9]*.xml' + header_xmls = "ppt/header[0-9]*.xml" if pptx else "word/header[0-9]*.xml" for fname in filelist: if re.match(header_xmls, fname): text += xml2text(zipf.read(fname)) # get main text - doc_xml = 'ppt/slides/slide[0-9]*.xml' if pptx else 'word/document.xml' + doc_xml = "ppt/slides/slide[0-9]*.xml" if pptx else "word/document.xml" if pptx: for fname in filelist: if re.match(doc_xml, fname): @@ -129,7 +128,7 @@ def process(docx, pptx=False, img_dir=None): # get footer text # there can be 3 footer files in the zip - footer_xmls = 'ppt/footer[0-9]*.xml' if pptx else 'word/footer[0-9]*.xml' + footer_xmls = "ppt/footer[0-9]*.xml" if pptx else "word/footer[0-9]*.xml" for fname in filelist: if re.match(footer_xmls, fname): text += xml2text(zipf.read(fname)) @@ -157,48 +156,53 @@ def pptx_process(docx, img_dir=None): def msword_process(doc, img_dir=None): - tmp_filepath = '/tmp/{}'.format( - ''.join(random.sample(string.ascii_lowercase, 10)) + '.doc' - ) + tmp_filepath = "/tmp/{}".format("".join(random.sample(string.ascii_lowercase, 10)) + ".doc") - with open(tmp_filepath, 'wb') as tmpdoc: + with open(tmp_filepath, "wb") as tmpdoc: tmpdoc.write(doc.read()) tmpdoc.flush() - call([ - 'libreoffice', '--headless', '--convert-to', 'docx', - tmp_filepath, '--outdir', settings.TEMP_DIR, - ]) - - doc_filename = os.path.join( - settings.TEMP_DIR, - re.sub(r'doc$', 'docx', os.path.basename(tmp_filepath)) + call( + [ + "libreoffice", + "--headless", + "--convert-to", + "docx", + tmp_filepath, + "--outdir", + settings.TEMP_DIR, + ] ) + + doc_filename = os.path.join(settings.TEMP_DIR, re.sub(r"doc$", "docx", os.path.basename(tmp_filepath))) # docx = open(doc_filename) response = process(doc_filename) # Clean up converted docx file - call(['rm', '-f', doc_filename, tmp_filepath]) + call(["rm", "-f", doc_filename, tmp_filepath]) return response def get_pages_in_docx(file): with zipfile.ZipFile(file) as zipf: try: - xml = zipf.read('docProps/app.xml') - pages = ET.fromstring(xml).find('wP:Pages', nsmap) + xml = zipf.read("docProps/app.xml") + pages = ET.fromstring(xml).find("wP:Pages", nsmap) # pages could be False or None return int(pages.text) if pages is not None else 0 except KeyError: - logger.warning('Error reading page from docx {}'.format( - file, - ), exc_info=True) + logger.warning( + "Error reading page from docx {}".format( + file, + ), + exc_info=True, + ) return 0 -if __name__ == '__main__': +if __name__ == "__main__": args = process_args() text, images = process(args.docx, args.img_dir) - print(text.encode('utf-8')) + print(text.encode("utf-8")) print(images) diff --git a/utils/extractor/formats/html.py b/utils/extractor/formats/html.py index bbef53fb6f..3fc8914dcb 100644 --- a/utils/extractor/formats/html.py +++ b/utils/extractor/formats/html.py @@ -1,14 +1,13 @@ -from readability.readability import Document -from urllib.parse import urljoin - +import base64 import logging import re -import requests import tempfile -import base64 -from bs4 import BeautifulSoup +from urllib.parse import urljoin +import requests +from bs4 import BeautifulSoup from django.conf import settings +from readability.readability import Document from utils.common import write_file @@ -16,21 +15,21 @@ def _replace_with_newlines(element): - text = '' + text = "" for elem in element.recursiveChildGenerator(): if isinstance(elem, str): text += elem.strip() - elif elem.name == 'br': - text += '\n\n' + elif elem.name == "br": + text += "\n\n" return text.strip() def _get_plain_text(soup): - plain_text = '' - for line in soup.findAll('p'): + plain_text = "" + for line in soup.findAll("p"): line = _replace_with_newlines(line) plain_text += line - plain_text += '\n\n' + plain_text += "\n\n" return plain_text.strip() @@ -40,29 +39,29 @@ def process(doc, url): title = html_body.short_title() images = [] - for img in html_body.reverse_tags(html_body.html, 'img'): + for img in html_body.reverse_tags(html_body.html, "img"): try: fp = tempfile.NamedTemporaryFile(dir=settings.TEMP_DIR) - img_src = urljoin(url, img.get('src')) - if re.search(r'http[s]?://', img_src): + img_src = urljoin(url, img.get("src")) + if re.search(r"http[s]?://", img_src): r = requests.get(img_src, stream=True) write_file(r, fp) else: - image = base64.b64decode(img_src.split(',')[1]) + image = base64.b64decode(img_src.split(",")[1]) fp.write(image) images.append(fp) except Exception: logger.error( - 'extractor.formats.html Image Collector Error!!', + "extractor.formats.html Image Collector Error!!", exc_info=True, - extra={'data': {'url': url}}, + extra={"data": {"url": url}}, ) - html = '<h1>' + title + '</h1>' + summary + html = "<h1>" + title + "</h1>" + summary - regex = re.compile('\n*', flags=re.IGNORECASE) - html = '<p>{}</p>'.format(regex.sub('', html)) + regex = re.compile("\n*", flags=re.IGNORECASE) + html = "<p>{}</p>".format(regex.sub("", html)) - soup = BeautifulSoup(html, 'lxml') + soup = BeautifulSoup(html, "lxml") text = _get_plain_text(soup) return text, images, 1 diff --git a/utils/extractor/formats/ods.py b/utils/extractor/formats/ods.py index 876b2504f6..2555c635ad 100644 --- a/utils/extractor/formats/ods.py +++ b/utils/extractor/formats/ods.py @@ -5,10 +5,12 @@ def extract_meta(xlsx_file): workbook = pyexcel_ods.get_data(xlsx_file) wb_sheets = [] for index, wb_sheet in enumerate(workbook): - wb_sheets.append({ - 'key': str(index), - 'title': wb_sheet, - }) + wb_sheets.append( + { + "key": str(index), + "title": wb_sheet, + } + ) return { - 'sheets': wb_sheets, + "sheets": wb_sheets, } diff --git a/utils/extractor/formats/pdf.py b/utils/extractor/formats/pdf.py index de9344c9b5..9f06ef7dfe 100644 --- a/utils/extractor/formats/pdf.py +++ b/utils/extractor/formats/pdf.py @@ -1,14 +1,11 @@ from io import BytesIO -from pdfminer.pdfparser import PDFParser -from pdfminer.pdfdocument import PDFDocument -from pdfminer.pdfinterp import ( - resolve1, - PDFResourceManager, - PDFPageInterpreter, -) + from pdfminer.converter import TextConverter from pdfminer.layout import LAParams +from pdfminer.pdfdocument import PDFDocument +from pdfminer.pdfinterp import PDFPageInterpreter, PDFResourceManager, resolve1 from pdfminer.pdfpage import PDFPage +from pdfminer.pdfparser import PDFParser def process(doc): @@ -19,15 +16,21 @@ def process(doc): rsrcmgr = PDFResourceManager() laparams = LAParams() with TextConverter( - rsrcmgr, retstr, codec='utf-8', laparams=laparams, + rsrcmgr, + retstr, + codec="utf-8", + laparams=laparams, ) as device: interpreter = PDFPageInterpreter(rsrcmgr, device) maxpages = 0 caching = True pagenos = set() for page in PDFPage.get_pages( - fp, pagenos, maxpages=maxpages, - caching=caching, check_extractable=True, + fp, + pagenos, + maxpages=maxpages, + caching=caching, + check_extractable=True, ): interpreter.process_page(page) content = retstr.getvalue().decode() @@ -37,4 +40,4 @@ def process(doc): def get_pages_in_pdf(file): document = PDFDocument(PDFParser(file)) - return resolve1(document.catalog['Pages'])['Count'] + return resolve1(document.catalog["Pages"])["Count"] diff --git a/utils/extractor/formats/xlsx.py b/utils/extractor/formats/xlsx.py index aeb57fa01a..1ac560e9e4 100644 --- a/utils/extractor/formats/xlsx.py +++ b/utils/extractor/formats/xlsx.py @@ -5,11 +5,13 @@ def extract_meta(xlsx_file): workbook = load_workbook(xlsx_file, data_only=True, read_only=True) wb_sheets = [] for index, wb_sheet in enumerate(workbook.worksheets): - if wb_sheet.sheet_state != 'hidden': - wb_sheets.append({ - 'key': str(index), - 'title': wb_sheet.title, - }) + if wb_sheet.sheet_state != "hidden": + wb_sheets.append( + { + "key": str(index), + "title": wb_sheet.title, + } + ) return { - 'sheets': wb_sheets, + "sheets": wb_sheets, } diff --git a/utils/extractor/tests/test_extractors.py b/utils/extractor/tests/test_extractors.py index 539eb9bf93..fb5ef59270 100644 --- a/utils/extractor/tests/test_extractors.py +++ b/utils/extractor/tests/test_extractors.py @@ -1,12 +1,11 @@ from os.path import join -from django.test import TestCase from django.conf import settings +from django.test import TestCase from utils.common import get_or_write_file -from ..extractors import ( - PdfExtractor, DocxExtractor, PptxExtractor -) + +from ..extractors import DocxExtractor, PdfExtractor, PptxExtractor class ExtractorTest(TestCase): @@ -15,12 +14,13 @@ class ExtractorTest(TestCase): Pdf, Pptx and docx Note: Html test is in WebDocument Test """ + def setUp(self): - self.path = join(settings.TEST_DIR, 'documents') + self.path = join(settings.TEST_DIR, "documents") def extract(self, extractor, path): text, images, page_count = extractor.extract() - extracted = get_or_write_file(path + '.txt', text) + extracted = get_or_write_file(path + ".txt", text) self.assertEqual(text, extracted.read()) # TODO: Verify image @@ -30,22 +30,22 @@ def test_docx(self): """ Test Docx import """ - docx_file = join(self.path, 'doc.docx') - extractor = DocxExtractor(open(docx_file, 'rb+')) + docx_file = join(self.path, "doc.docx") + extractor = DocxExtractor(open(docx_file, "rb+")) self.extract(extractor, docx_file) def test_pptx(self): """ Test pptx import """ - pptx_file = join(self.path, 'doc.pptx') - extractor = PptxExtractor(open(pptx_file, 'rb+')) + pptx_file = join(self.path, "doc.pptx") + extractor = PptxExtractor(open(pptx_file, "rb+")) self.extract(extractor, pptx_file) def test_pdf(self): """ Test Pdf import """ - pdf_file = join(self.path, 'doc.pdf') - extractor = PdfExtractor(open(pdf_file, 'rb+')) + pdf_file = join(self.path, "doc.pdf") + extractor = PdfExtractor(open(pdf_file, "rb+")) self.extract(extractor, pdf_file) diff --git a/utils/extractor/tests/test_file_document.py b/utils/extractor/tests/test_file_document.py index 6a6f9f921e..9d8ba10960 100644 --- a/utils/extractor/tests/test_file_document.py +++ b/utils/extractor/tests/test_file_document.py @@ -1,20 +1,18 @@ -from os.path import ( - join, - # isfile, -) import json import logging +from os.path import join # isfile, -from django.test import TestCase from django.conf import settings +from django.test import TestCase + +from utils.common import get_or_write_file, makedirs -from utils.common import (get_or_write_file, makedirs) from ..file_document import FileDocument # TODO: Review/Add better urls -DOCX_FILE = 'doc.docx' -PPTX_FILE = 'doc.pptx' -PDF_FILE = 'doc.pdf' +DOCX_FILE = "doc.docx" +PPTX_FILE = "doc.pptx" +PDF_FILE = "doc.pdf" logger = logging.getLogger(__name__) @@ -24,28 +22,26 @@ class FileDocumentTest(TestCase): Import Test using files Html, Pdf, Pptx and docx """ + def setUp(self): - self.path = join(settings.TEST_DIR, 'documents_attachment') - self.documents = join(settings.TEST_DIR, 'documents') + self.path = join(settings.TEST_DIR, "documents_attachment") + self.documents = join(settings.TEST_DIR, "documents") - with open(join(self.documents, 'pages.json'), 'r') as pages: + with open(join(self.documents, "pages.json"), "r") as pages: self.pages = json.load(pages) makedirs(self.path) def extract(self, path): - file = open(join(self.documents, path), 'rb') - filename = file.name.split('/')[-1] - text, images, page_count = FileDocument( - file, - filename - ).extract() + file = open(join(self.documents, path), "rb") + filename = file.name.split("/")[-1] + text, images, page_count = FileDocument(file, filename).extract() path = join(self.path, filename) - extracted = get_or_write_file(path + '.txt', text) + extracted = get_or_write_file(path + ".txt", text) self.assertEqual(text, extracted.read()) - self.assertEqual(page_count, self.pages[filename.split('.')[-1]]) + self.assertEqual(page_count, self.pages[filename.split(".")[-1]]) # TODO: Verify image # self.assertEqual(len(images), 4) diff --git a/utils/extractor/tests/test_web_document.py b/utils/extractor/tests/test_web_document.py index c4dadf2f60..a95d788327 100644 --- a/utils/extractor/tests/test_web_document.py +++ b/utils/extractor/tests/test_web_document.py @@ -1,22 +1,22 @@ -import os -import logging import json +import logging +import os -from django.test import TestCase from django.conf import settings +from django.test import TestCase + +from utils.common import get_or_write_file, makedirs -from utils.common import (get_or_write_file, makedirs) from ..web_document import WebDocument logger = logging.getLogger(__name__) # TODO: Review/Add better urls -REDHUM_URL = 'https://redhum.org/documento/3227553' -HTML_URL = 'https://reliefweb.int/report/occupied-palestinian-territory/rehabilitation-services-urgently-needed-prevent-disability' # noqa -DOCX_URL = 'https://calibre-ebook.com/downloads/demos/demo.docx' -PPTX_URL = 'https://www.mhc.ab.ca/-/media/Files/PDF/Services/Online/'\ - 'BBSamples/powerpoint.pptx' -PDF_URL = 'http://che.org.il/wp-content/uploads/2016/12/pdf-sample.pdf' +REDHUM_URL = "https://redhum.org/documento/3227553" +HTML_URL = "https://reliefweb.int/report/occupied-palestinian-territory/rehabilitation-services-urgently-needed-prevent-disability" # noqa +DOCX_URL = "https://calibre-ebook.com/downloads/demos/demo.docx" +PPTX_URL = "https://www.mhc.ab.ca/-/media/Files/PDF/Services/Online/" "BBSamples/powerpoint.pptx" +PDF_URL = "http://che.org.il/wp-content/uploads/2016/12/pdf-sample.pdf" class WebDocumentTest(TestCase): @@ -24,9 +24,10 @@ class WebDocumentTest(TestCase): Import Test using urls Html, Pdf, Pptx and docx """ + def setUp(self): - self.path = os.path.join(settings.TEST_DIR, 'documents_urls') - with open(os.path.join(self.path, 'pages.json'), 'r') as pages: + self.path = os.path.join(settings.TEST_DIR, "documents_urls") + with open(os.path.join(self.path, "pages.json"), "r") as pages: self.pages = json.load(pages) makedirs(self.path) @@ -35,14 +36,15 @@ def extract(self, url, type): text, images, page_count = WebDocument(url).extract() except Exception: import traceback - logger.warning('\n' + ('*' * 30)) - logger.warning('EXTRACTOR ERROR: WEBDOCUMENT: ' + type.upper()) + + logger.warning("\n" + ("*" * 30)) + logger.warning("EXTRACTOR ERROR: WEBDOCUMENT: " + type.upper()) logger.warning(traceback.format_exc()) return - path = os.path.join(self.path, '.'.join(url.split('/')[-1:])) + path = os.path.join(self.path, ".".join(url.split("/")[-1:])) - extracted = get_or_write_file(path + '.txt', text) + extracted = get_or_write_file(path + ".txt", text) try: # TODO: Better way to handle the errors @@ -50,8 +52,9 @@ def extract(self, url, type): self.assertEqual(page_count, self.pages[type]) except AssertionError: import traceback - logger.warning('\n' + ('*' * 30)) - logger.warning('EXTRACTOR ERROR: WEBDOCUMENT: ' + type.upper()) + + logger.warning("\n" + ("*" * 30)) + logger.warning("EXTRACTOR ERROR: WEBDOCUMENT: " + type.upper()) logger.warning(traceback.format_exc()) # TODO: Verify image # self.assertEqual(len(images), 4) @@ -60,22 +63,22 @@ def test_html(self): """ Test html import """ - self.extract(HTML_URL, 'html') + self.extract(HTML_URL, "html") def test_docx(self): """ Test Docx import """ - self.extract(DOCX_URL, 'docx') + self.extract(DOCX_URL, "docx") def test_pptx(self): """ Test pptx import """ - self.extract(PPTX_URL, 'pptx') + self.extract(PPTX_URL, "pptx") def test_pdf(self): """ Test Pdf import """ - self.extract(PDF_URL, 'pdf') + self.extract(PDF_URL, "pdf") diff --git a/utils/extractor/web_document.py b/utils/extractor/web_document.py index c00882c887..0a5a877ba0 100644 --- a/utils/extractor/web_document.py +++ b/utils/extractor/web_document.py @@ -1,13 +1,12 @@ -import requests import tempfile + +import requests from django.conf import settings -from utils.common import (write_file, DEFAULT_HEADERS) +from utils.common import DEFAULT_HEADERS, write_file from utils.web_info_extractor import get_web_info_extractor -from .document import ( - Document, - HTML, PDF, DOCX, PPTX, -) + +from .document import DOCX, HTML, PDF, PPTX, Document class WebDocument(Document): @@ -15,18 +14,23 @@ class WebDocument(Document): Web documents can be html or pdf. Taks url Gives document and type """ - HTML_TYPES = ['text/html', 'text/plain'] - PDF_TYPES = ['application/pdf', ] - DOCX_TYPES = ['application/vnd.openxmlformats-officedocument' - '.wordprocessingml.document', ] - PPTX_TYPES = ['application/vnd.openxmlformats-officedocument' - '.presentationml.presentation', ] + + HTML_TYPES = ["text/html", "text/plain"] + PDF_TYPES = [ + "application/pdf", + ] + DOCX_TYPES = [ + "application/vnd.openxmlformats-officedocument" ".wordprocessingml.document", + ] + PPTX_TYPES = [ + "application/vnd.openxmlformats-officedocument" ".presentationml.presentation", + ] def __init__(self, url): type = HTML doc = None - params = {'url': url} + params = {"url": url} try: r = requests.head(url, headers=DEFAULT_HEADERS, verify=False) @@ -37,26 +41,21 @@ def __init__(self, url): super().__init__(doc, type, params=params) return - if not r.headers.get('content-type') or \ - any(x in r.headers["content-type"] for x in self.HTML_TYPES): + if not r.headers.get("content-type") or any(x in r.headers["content-type"] for x in self.HTML_TYPES): doc = get_web_info_extractor(url).get_content() else: - fp = tempfile.NamedTemporaryFile( - dir=settings.TEMP_DIR, delete=False) + fp = tempfile.NamedTemporaryFile(dir=settings.TEMP_DIR, delete=False) r = requests.get(url, stream=True, headers=DEFAULT_HEADERS, verify=False) write_file(r, fp) doc = fp - if any(x in r.headers["content-type"] - for x in self.PDF_TYPES): + if any(x in r.headers["content-type"] for x in self.PDF_TYPES): type = PDF - elif any(x in r.headers["content-type"] - for x in self.DOCX_TYPES): + elif any(x in r.headers["content-type"] for x in self.DOCX_TYPES): type = DOCX - elif any(x in r.headers["content-type"] - for x in self.PPTX_TYPES): + elif any(x in r.headers["content-type"] for x in self.PPTX_TYPES): type = PPTX super().__init__(doc, type, params=params) diff --git a/utils/files.py b/utils/files.py index bd51040a94..87b673aeb4 100644 --- a/utils/files.py +++ b/utils/files.py @@ -1,5 +1,5 @@ -from typing import Dict, List, Tuple, Union, IO import json +from typing import IO, Dict, List, Tuple, Union from django.core.files.base import ContentFile from django.core.serializers.json import DjangoJSONEncoder @@ -7,9 +7,7 @@ def generate_file_for_upload(file: IO): file.seek(0) - return ContentFile( - file.read().encode('utf-8') - ) + return ContentFile(file.read().encode("utf-8")) def generate_json_file_for_upload(data: Union[Dict, List, Tuple], **kwargs) -> ContentFile: @@ -18,5 +16,5 @@ def generate_json_file_for_upload(data: Union[Dict, List, Tuple], **kwargs) -> C data, cls=DjangoJSONEncoder, **kwargs, - ).encode('utf-8'), + ).encode("utf-8"), ) diff --git a/utils/graphene/dataloaders.py b/utils/graphene/dataloaders.py index f4e1feb85c..c982c3b03b 100644 --- a/utils/graphene/dataloaders.py +++ b/utils/graphene/dataloaders.py @@ -1,9 +1,9 @@ from promise.dataloader import DataLoader -class WithContextMixin(): +class WithContextMixin: def __init__(self, *args, **kwargs): - self.context = kwargs.pop('context') + self.context = kwargs.pop("context") super().__init__(*args, **kwargs) diff --git a/utils/graphene/enums.py b/utils/graphene/enums.py index f00a5d0cdc..3fa81e3c8c 100644 --- a/utils/graphene/enums.py +++ b/utils/graphene/enums.py @@ -1,10 +1,10 @@ from typing import Union import graphene +from django.contrib.postgres.fields import ArrayField +from django.db import models from django_enumfield import enum from rest_framework import serializers -from django.db import models -from django.contrib.postgres.fields import ArrayField from utils.common import to_camelcase @@ -43,11 +43,11 @@ def get_enum_name_from_django_field( serializer_name=None, ): def _have_model(_field): - if hasattr(_field, 'model') or hasattr(getattr(_field, 'Meta', None), 'model'): + if hasattr(_field, "model") or hasattr(getattr(_field, "Meta", None), "model"): return True def _get_serializer_name(_field): - if hasattr(_field, 'parent'): + if hasattr(_field, "parent"): return type(_field.parent).__name__ if field_name is None or model_name is None: @@ -85,12 +85,12 @@ def _get_serializer_name(_field): serializer_name = _get_serializer_name(field) field_name = field_name or field.name if field_name is None: - raise Exception(f'{field=} should have a name') + raise Exception(f"{field=} should have a name") if model_name: - return f'{model_name}{to_camelcase(field_name.title())}' + return f"{model_name}{to_camelcase(field_name.title())}" if serializer_name: - return f'{serializer_name}{to_camelcase(field_name.title())}' - raise Exception(f'{serializer_name=} should have a value') + return f"{serializer_name}{to_camelcase(field_name.title())}" + raise Exception(f"{serializer_name=} should have a value") class EnumDescription(graphene.Scalar): @@ -111,7 +111,7 @@ def coerce_string(value): _value = value if callable(value): _value = value() - return _value or '' + return _value or "" serialize = coerce_string parse_value = coerce_string diff --git a/utils/graphene/error_types.py b/utils/graphene/error_types.py index db5a8437f1..66ca04d89e 100644 --- a/utils/graphene/error_types.py +++ b/utils/graphene/error_types.py @@ -6,7 +6,7 @@ from graphene.utils.str_converters import to_snake_case from graphene_django.utils.utils import _camelize_django_str -ARRAY_NON_MEMBER_ERRORS = 'nonMemberErrors' +ARRAY_NON_MEMBER_ERRORS = "nonMemberErrors" CustomErrorType = GenericScalar @@ -16,11 +16,11 @@ class ArrayNestedErrorType(ObjectType): object_errors = graphene.List(graphene.NonNull(GenericScalar)) def keys(self): - return ['clientId', 'messages', 'objectErrors'] + return ["clientId", "messages", "objectErrors"] def __getitem__(self, key): key = to_snake_case(key) - if key in ('object_errors',) and getattr(self, key): + if key in ("object_errors",) and getattr(self, key): return [dict(each) for each in getattr(self, key)] return getattr(self, key) @@ -33,44 +33,51 @@ class _CustomErrorType(ObjectType): array_errors = graphene.List(graphene.NonNull(ArrayNestedErrorType)) def keys(self): - return ['clientId', 'field', 'messages', 'objectErrors', 'arrayErrors'] + return ["clientId", "field", "messages", "objectErrors", "arrayErrors"] def __getitem__(self, key): key = to_snake_case(key) - if key in ('object_errors', 'array_errors') and getattr(self, key): + if key in ("object_errors", "array_errors") and getattr(self, key): return [dict(each) for each in getattr(self, key)] return getattr(self, key) def serializer_error_to_error_types(errors: dict, initial_data: dict = None) -> List: initial_data = initial_data or dict() - node_client_id = initial_data.get('client_id') + node_client_id = initial_data.get("client_id") error_types = list() for field, value in errors.items(): if isinstance(value, dict): - error_types.append(_CustomErrorType( - client_id=node_client_id, - field=_camelize_django_str(field), - object_errors=serializer_error_to_error_types(value) - )) + error_types.append( + _CustomErrorType( + client_id=node_client_id, + field=_camelize_django_str(field), + object_errors=serializer_error_to_error_types(value), + ) + ) elif isinstance(value, list): if isinstance(value[0], str): if isinstance(initial_data.get(field), list): # we have found an array input with top level error - error_types.append(_CustomErrorType( - client_id=node_client_id, - field=_camelize_django_str(field), - array_errors=[ArrayNestedErrorType( - client_id=ARRAY_NON_MEMBER_ERRORS, - messages=''.join(str(msg) for msg in value) - )] - )) + error_types.append( + _CustomErrorType( + client_id=node_client_id, + field=_camelize_django_str(field), + array_errors=[ + ArrayNestedErrorType( + client_id=ARRAY_NON_MEMBER_ERRORS, messages="".join(str(msg) for msg in value) + ) + ], + ) + ) else: - error_types.append(_CustomErrorType( - client_id=node_client_id, - field=_camelize_django_str(field), - messages=''.join(str(msg) for msg in value) - )) + error_types.append( + _CustomErrorType( + client_id=node_client_id, + field=_camelize_django_str(field), + messages="".join(str(msg) for msg in value), + ) + ) elif isinstance(value[0], dict): array_errors = [] for pos, array_item in enumerate(value): @@ -82,22 +89,18 @@ def serializer_error_to_error_types(errors: dict, initial_data: dict = None) -> initial_data_field_pos = initial_data[field][pos] or {} except (KeyError, IndexError): initial_data_field_pos = {} - client_id = initial_data_field_pos.get('client_id', f'NOT_FOUND_{pos}') - array_errors.append(ArrayNestedErrorType( - client_id=client_id, - object_errors=serializer_error_to_error_types(array_item, initial_data_field_pos) - )) - error_types.append(_CustomErrorType( - client_id=node_client_id, - field=_camelize_django_str(field), - array_errors=array_errors - )) + client_id = initial_data_field_pos.get("client_id", f"NOT_FOUND_{pos}") + array_errors.append( + ArrayNestedErrorType( + client_id=client_id, object_errors=serializer_error_to_error_types(array_item, initial_data_field_pos) + ) + ) + error_types.append( + _CustomErrorType(client_id=node_client_id, field=_camelize_django_str(field), array_errors=array_errors) + ) else: # fallback - error_types.append(_CustomErrorType( - field=_camelize_django_str(field), - messages=' '.join(str(msg) for msg in value) - )) + error_types.append(_CustomErrorType(field=_camelize_django_str(field), messages=" ".join(str(msg) for msg in value))) return error_types diff --git a/utils/graphene/fields.py b/utils/graphene/fields.py index ca66e84e14..14e3ff6633 100644 --- a/utils/graphene/fields.py +++ b/utils/graphene/fields.py @@ -1,15 +1,18 @@ import inspect -from functools import partial from collections import OrderedDict -from typing import Type, Optional +from functools import partial +from typing import Optional, Type -from django.db.models import QuerySet import graphene +from django.db.models import QuerySet from graphene.types.structures import Structure from graphene.utils.str_converters import to_snake_case from graphene_django.filter.utils import get_filtering_args_from_filterset -from graphene_django.utils import maybe_queryset, is_valid_django_model from graphene_django.registry import get_global_registry +from graphene_django.rest_framework.serializer_converter import ( + get_graphene_type_from_serializer_field, +) +from graphene_django.utils import is_valid_django_model, maybe_queryset from graphene_django_extras import DjangoFilterPaginateListField from graphene_django_extras.base_types import DjangoListObjectBase from graphene_django_extras.fields import DjangoListField @@ -17,10 +20,12 @@ from graphene_django_extras.paginations.pagination import BaseDjangoGraphqlPagination from graphene_django_extras.settings import graphql_api_settings from graphene_django_extras.utils import get_extra_filters -from graphene_django.rest_framework.serializer_converter import get_graphene_type_from_serializer_field from rest_framework import serializers -from utils.graphene.pagination import OrderingOnlyArgumentPagination, NoOrderingPageGraphqlPagination +from utils.graphene.pagination import ( + NoOrderingPageGraphqlPagination, + OrderingOnlyArgumentPagination, +) class CustomDjangoListObjectBase(DjangoListObjectBase): @@ -36,7 +41,7 @@ def to_dict(self): self.results_field_name: [e.to_dict() for e in self.results], "count": self.count, "page": self.page, - "pageSize": self.pageSize + "pageSize": self.pageSize, } @@ -44,16 +49,15 @@ class CustomDjangoListField(DjangoListField): """ Removes the compulsion of using `get_queryset` in the DjangoListField """ + @staticmethod - def list_resolver( - django_object_type, resolver, root, info, **args - ): + def list_resolver(django_object_type, resolver, root, info, **args): queryset = maybe_queryset(resolver(root, info, **args)) if queryset is None: queryset = QuerySet.none() # FIXME: This will throw error if isinstance(queryset, QuerySet): - if hasattr(django_object_type, 'get_queryset'): + if hasattr(django_object_type, "get_queryset"): # Pass queryset to the DjangoObjectType get_queryset method queryset = maybe_queryset(django_object_type.get_queryset(queryset, info)) return queryset @@ -85,9 +89,7 @@ def __init__( filterset_class = filterset_class or _type._meta.filterset_class self.filterset_class = get_filterset_class(filterset_class) - self.filtering_args = get_filtering_args_from_non_model_filterset( - self.filterset_class - ) + self.filtering_args = get_filtering_args_from_non_model_filterset(self.filterset_class) kwargs["args"].update(self.filtering_args) pagination = pagination or OrderingOnlyArgumentPagination() @@ -102,26 +104,22 @@ def __init__( self.pagination = pagination kwargs.update(**pagination_kwargs) - self.accessor = kwargs.pop('accessor', None) - super(DjangoFilterPaginateListField, self).__init__( - _type, *args, **kwargs - ) + self.accessor = kwargs.pop("accessor", None) + super(DjangoFilterPaginateListField, self).__init__(_type, *args, **kwargs) - def list_resolver( - self, filterset_class, filtering_args, root, info, **kwargs - ): + def list_resolver(self, filterset_class, filtering_args, root, info, **kwargs): filter_kwargs = {k: v for k, v in kwargs.items() if k in filtering_args} qs = getattr(root, self.accessor) - if hasattr(qs, 'all'): + if hasattr(qs, "all"): qs = qs.all() qs = filterset_class(data=filter_kwargs, queryset=qs, request=info.context).qs count = qs.count() if getattr(self, "pagination", None): ordering = kwargs.pop(self.pagination.ordering_param, None) or self.pagination.ordering - ordering = ','.join([to_snake_case(each) for each in ordering.strip(',').replace(' ', '').split(',')]) - 'pageSize' in kwargs and kwargs['pageSize'] is None and kwargs.pop('pageSize') + ordering = ",".join([to_snake_case(each) for each in ordering.strip(",").replace(" ", "").split(",")]) + "pageSize" in kwargs and kwargs["pageSize"] is None and kwargs.pop("pageSize") kwargs[self.pagination.ordering_param] = ordering qs = self.pagination.paginate_queryset(qs, **kwargs) @@ -129,11 +127,14 @@ def list_resolver( count=count, results=maybe_queryset(qs), results_field_name=self.type._meta.results_field_name, - page=kwargs.get('page', 1) if hasattr(self.pagination, 'page') else None, - pageSize=kwargs.get( # TODO: Need to add cutoff to send max page size instead of requested - 'pageSize', - graphql_api_settings.DEFAULT_PAGE_SIZE - ) if hasattr(self.pagination, 'page') else None + page=kwargs.get("page", 1) if hasattr(self.pagination, "page") else None, + pageSize=( + kwargs.get( # TODO: Need to add cutoff to send max page size instead of requested + "pageSize", graphql_api_settings.DEFAULT_PAGE_SIZE + ) + if hasattr(self.pagination, "page") + else None + ), ) def get_resolver(self, parent_resolver): @@ -158,11 +159,11 @@ def __init__( *args, **kwargs, ): - ''' + """ If pagination is None, then we will only allow Ordering fields. - The page size will respect the settings. - Client will not be able to add pagination params - ''' + """ _fields = _type._meta.filter_fields _model = _type._meta.model @@ -173,9 +174,7 @@ def __init__( filterset_class = filterset_class or _type._meta.filterset_class self.filterset_class = get_filterset_class(filterset_class, **meta) - self.filtering_args = get_filtering_args_from_filterset( - self.filterset_class, _type - ) + self.filtering_args = get_filtering_args_from_filterset(self.filterset_class, _type) kwargs.setdefault("args", {}) kwargs["args"].update(self.filtering_args) @@ -195,18 +194,14 @@ def __init__( kwargs["description"] = "{} list".format(_type._meta.model.__name__) # accessor will be used with m2m or reverse_fk fields - self.accessor = kwargs.pop('accessor', None) - super(DjangoFilterPaginateListField, self).__init__( - _type, *args, **kwargs - ) + self.accessor = kwargs.pop("accessor", None) + super(DjangoFilterPaginateListField, self).__init__(_type, *args, **kwargs) - def list_resolver( - self, manager, filterset_class, filtering_args, root, info, **kwargs - ): + def list_resolver(self, manager, filterset_class, filtering_args, root, info, **kwargs): filter_kwargs = {k: v for k, v in kwargs.items() if k in filtering_args} if self.accessor: qs = getattr(root, self.accessor) - if hasattr(qs, 'all'): + if hasattr(qs, "all"): qs = qs.all() qs = filterset_class(data=filter_kwargs, queryset=qs, request=info.context).qs else: @@ -224,20 +219,23 @@ def list_resolver( # This is handled in filterset kwargs[self.pagination.ordering_param] = None else: - ordering = ','.join([to_snake_case(each) for each in ordering.strip(',').replace(' ', '').split(',')]) + ordering = ",".join([to_snake_case(each) for each in ordering.strip(",").replace(" ", "").split(",")]) kwargs[self.pagination.ordering_param] = ordering - 'pageSize' in kwargs and kwargs['pageSize'] is None and kwargs.pop('pageSize') + "pageSize" in kwargs and kwargs["pageSize"] is None and kwargs.pop("pageSize") qs = self.pagination.paginate_queryset(qs, **kwargs) return CustomDjangoListObjectBase( count=count, results=maybe_queryset(qs), results_field_name=self.type._meta.results_field_name, - page=kwargs.get('page', 1) if hasattr(self.pagination, 'page_query_param') else None, - pageSize=kwargs.get( # TODO: Need to add cutoff to send max page size instead of requested - 'pageSize', - graphql_api_settings.DEFAULT_PAGE_SIZE - ) if hasattr(self.pagination, 'page_size_query_param') else None + page=kwargs.get("page", 1) if hasattr(self.pagination, "page_query_param") else None, + pageSize=( + kwargs.get( # TODO: Need to add cutoff to send max page size instead of requested + "pageSize", graphql_api_settings.DEFAULT_PAGE_SIZE + ) + if hasattr(self.pagination, "page_size_query_param") + else None + ), ) @@ -272,10 +270,7 @@ def generate_object_field_from_input_type(input_type, skip_fields=[]): if field_key in skip_fields: continue _type = field.type - if inspect.isclass(_type) and ( - issubclass(_type, graphene.Scalar) or - issubclass(_type, graphene.Enum) - ): + if inspect.isclass(_type) and (issubclass(_type, graphene.Scalar) or issubclass(_type, graphene.Enum)): new_fields_map[field_key] = graphene.Field(_type) else: new_fields_map[field_key] = _type @@ -285,15 +280,15 @@ def generate_object_field_from_input_type(input_type, skip_fields=[]): # use this for input type with direct scaler fields only. def generate_simple_object_type_from_input_type(input_type): new_fields_map = generate_object_field_from_input_type(input_type) - return type(input_type._meta.name.replace('Input', ''), (graphene.ObjectType,), new_fields_map) + return type(input_type._meta.name.replace("Input", ""), (graphene.ObjectType,), new_fields_map) def compare_input_output_type_fields(input_type, output_type): if len(output_type._meta.fields) != len(input_type._meta.fields): for field in input_type._meta.fields.keys(): if field not in output_type._meta.fields.keys(): - print('---> [Entry] Missing: ', field) - raise Exception('Conversion failed') + print("---> [Entry] Missing: ", field) + raise Exception("Conversion failed") def convert_serializer_field(field, convert_choices_to_enum=True, force_optional=False): @@ -311,10 +306,7 @@ def convert_serializer_field(field, convert_choices_to_enum=True, force_optional graphql_type = get_graphene_type_from_serializer_field(field) args = [] - kwargs = { - "description": field.help_text, - "required": field.required and not force_optional - } + kwargs = {"description": field.help_text, "required": field.required and not force_optional} # if it is a tuple or a list it means that we are returning # the graphql type and the child type @@ -344,21 +336,16 @@ def convert_serializer_to_type(serializer_class): """ graphene_django.rest_framework.serializer_converter.convert_serializer_to_type """ - cached_type = convert_serializer_to_type.cache.get( - serializer_class.__name__, None - ) + cached_type = convert_serializer_to_type.cache.get(serializer_class.__name__, None) if cached_type: return cached_type serializer = serializer_class() - items = { - name: convert_serializer_field(field) - for name, field in serializer.fields.items() - } + items = {name: convert_serializer_field(field) for name, field in serializer.fields.items()} # Alter naming serializer_name = serializer.__class__.__name__ - serializer_name = ''.join(''.join(serializer_name.split('ModelSerializer')).split('Serializer')) - ref_name = f'{serializer_name}Type' + serializer_name = "".join("".join(serializer_name.split("ModelSerializer")).split("Serializer")) + ref_name = f"{serializer_name}Type" base_classes = (graphene.ObjectType,) @@ -417,6 +404,6 @@ def generate_type_for_serializer( _type = type(name, (graphene.ObjectType,), data_members) if update_cache: if name in convert_serializer_to_type.cache: - raise Exception(f'<{name}> : <{serializer_class.__name__}> Alreay exists') + raise Exception(f"<{name}> : <{serializer_class.__name__}> Alreay exists") convert_serializer_to_type.cache[serializer_class.__name__] = _type return _type diff --git a/utils/graphene/filters.py b/utils/graphene/filters.py index 1d31e46300..8711869bf7 100644 --- a/utils/graphene/filters.py +++ b/utils/graphene/filters.py @@ -1,4 +1,5 @@ from functools import partial + import django_filters import graphene from graphene.types.generic import GenericScalar @@ -30,9 +31,7 @@ def _generate_filter_class(inner_type, filter_type=None, non_null=False): ).format(inner_type.__name__, _filter_type), }, ) - convert_form_field.register(form_field)( - lambda _: graphene.NonNull(inner_type) if non_null else inner_type() - ) + convert_form_field.register(form_field)(lambda _: graphene.NonNull(inner_type) if non_null else inner_type()) return filter_class @@ -69,9 +68,7 @@ def _generate_list_filter_class(inner_type, filter_type=None, field_class=None): ).format(inner_type.__name__, _filter_type), }, ) - convert_form_field.register(form_field)( - lambda _: graphene.List(graphene.NonNull(inner_type)) - ) + convert_form_field.register(form_field)(lambda _: graphene.List(graphene.NonNull(inner_type))) return filter_class @@ -117,14 +114,14 @@ def _get_id_list_filter(**kwargs): ) DateTimeGteFilter = partial( django_filters.DateTimeFilter, - lookup_expr='gte', + lookup_expr="gte", input_formats=[django_filters.fields.IsoDateTimeField.ISO_8601], ) DateTimeLteFilter = partial( django_filters.DateTimeFilter, - lookup_expr='lte', + lookup_expr="lte", input_formats=[django_filters.fields.IsoDateTimeField.ISO_8601], ) -DateGteFilter = partial(django_filters.DateFilter, lookup_expr='gte') -DateLteFilter = partial(django_filters.DateFilter, lookup_expr='lte') +DateGteFilter = partial(django_filters.DateFilter, lookup_expr="gte") +DateLteFilter = partial(django_filters.DateFilter, lookup_expr="lte") diff --git a/utils/graphene/geo_scalars.py b/utils/graphene/geo_scalars.py index 12f0676e88..54c887ec99 100644 --- a/utils/graphene/geo_scalars.py +++ b/utils/graphene/geo_scalars.py @@ -1,10 +1,12 @@ """ Source: https://raw.githubusercontent.com/EverWinter23/graphene-gis/master/graphene_gis/scalars.py """ + import json -from graphql.language import ast -from graphene.types import Scalar + from django.contrib.gis.geos import GEOSGeometry +from graphene.types import Scalar +from graphql.language import ast class GISScalar(Scalar): diff --git a/utils/graphene/middleware.py b/utils/graphene/middleware.py index dcc3f05373..67630a579f 100644 --- a/utils/graphene/middleware.py +++ b/utils/graphene/middleware.py @@ -1,14 +1,15 @@ from django.conf import settings -from deep.exceptions import UnauthorizedException - -from project.models import Project from project.change_log import ProjectChangeManager +from project.models import Project + +from deep.exceptions import UnauthorizedException class WhiteListMiddleware: - ''' + """ Graphql node whitelist for unauthenticated user - ''' + """ + def resolve(self, next, root, info, **args): # if user is not authenticated and user is not accessing # whitelisted nodes, then raise permission denied error @@ -23,37 +24,35 @@ class DisableIntrospectionSchemaMiddleware: """ This middleware disables request with __schema in production. """ + def resolve(self, next, root, info, **args): - if info.field_name == '__schema' and not settings.DEBUG: + if info.field_name == "__schema" and not settings.DEBUG: return None return next(root, info, **args) class ProjectLogMiddleware: - ''' + """ Middleware to track Project changes - ''' + """ + WATCHED_PATH = [ *[ - ['project', path] + ["project", path] for path in [ - 'projectUpdate', - 'projectDelete', - 'projectUserMembershipBulk', - 'projectUserGroupMembershipBulk', - 'projectRegionBulk', - 'projectVizConfigurationUpdate', - 'acceptRejectProject', + "projectUpdate", + "projectDelete", + "projectUserMembershipBulk", + "projectUserGroupMembershipBulk", + "projectRegionBulk", + "projectVizConfigurationUpdate", + "acceptRejectProject", ] ], ] def resolve(self, next, root, info, **args): - if ( - info.operation.operation == 'mutation' and - isinstance(root, Project) and - info.path in self.WATCHED_PATH - ): + if info.operation.operation == "mutation" and isinstance(root, Project) and info.path in self.WATCHED_PATH: with ProjectChangeManager(info.context.request, root.id): return next(root, info, **args) return next(root, info, **args) diff --git a/utils/graphene/mutation.py b/utils/graphene/mutation.py index aa5f1bb99c..8090689d34 100644 --- a/utils/graphene/mutation.py +++ b/utils/graphene/mutation.py @@ -1,5 +1,5 @@ -from typing import Type, List from collections import OrderedDict +from typing import List, Type import graphene import graphene_django @@ -8,20 +8,19 @@ from graphene_django.rest_framework.serializer_converter import ( get_graphene_type_from_serializer_field, ) -from rest_framework import serializers from graphene_file_upload.scalars import Upload +from rest_framework import serializers + +from deep.enums import ENUM_TO_GRAPHENE_ENUM_MAP -from utils.graphene.error_types import mutation_is_not_valid -from utils.graphene.enums import get_enum_name_from_django_field # from utils.common import to_camelcase from deep.exceptions import PermissionDeniedException -from deep.enums import ENUM_TO_GRAPHENE_ENUM_MAP +from deep.permissions import AnalysisFrameworkPermissions as AfP +from deep.permissions import ProjectPermissions as PP +from deep.permissions import UserGroupPermissions as UgP from deep.serializers import IntegerIDField, StringIDField -from deep.permissions import ( - ProjectPermissions as PP, - AnalysisFrameworkPermissions as AfP, - UserGroupPermissions as UgP, -) +from utils.graphene.enums import get_enum_name_from_django_field +from utils.graphene.error_types import mutation_is_not_valid @get_graphene_type_from_serializer_field.register(serializers.ListSerializer) @@ -69,7 +68,7 @@ def convert_serializer_field_to_enum(field): # Try django_enumfield (NOTE: Let's try to avoid this) custom_name = type(list(field.choices.values())[-1]).__name__ if custom_name is None: - raise Exception(f'Enum name generation failed for {field=}') + raise Exception(f"Enum name generation failed for {field=}") return ENUM_TO_GRAPHENE_ENUM_MAP[custom_name] @@ -88,10 +87,7 @@ def convert_serializer_field(field, is_input=True, convert_choices_to_enum=True, graphql_type = get_graphene_type_from_serializer_field(field) args = [] - kwargs = { - "description": field.help_text, - "required": is_input and field.required and not force_optional - } + kwargs = {"description": field.help_text, "required": is_input and field.required and not force_optional} # if it is a tuple or a list it means that we are returning # the graphql type and the child type @@ -127,21 +123,16 @@ def convert_serializer_to_input_type(serializer_class): """ graphene_django.rest_framework.serializer_converter.convert_serializer_to_input_type """ - cached_type = convert_serializer_to_input_type.cache.get( - serializer_class.__name__, None - ) + cached_type = convert_serializer_to_input_type.cache.get(serializer_class.__name__, None) if cached_type: return cached_type serializer = serializer_class() - items = { - name: convert_serializer_field(field) - for name, field in serializer.fields.items() - } + items = {name: convert_serializer_field(field) for name, field in serializer.fields.items()} # Alter naming serializer_name = serializer.__class__.__name__ - serializer_name = ''.join(''.join(serializer_name.split('ModelSerializer')).split('Serializer')) - ref_name = f'{serializer_name}InputType' + serializer_name = "".join("".join(serializer_name.split("ModelSerializer")).split("Serializer")) + ref_name = f"{serializer_name}InputType" base_classes = (graphene.InputObjectType,) @@ -176,10 +167,8 @@ def fields_for_serializer( is_excluded = any( [ name in exclude_fields, - field.write_only and - not is_input, # don't show write_only fields in Query - field.read_only and is_input \ - and lookup_field != name, # don't show read_only fields in Input + field.write_only and not is_input, # don't show write_only fields in Query + field.read_only and is_input and lookup_field != name, # don't show read_only fields in Input ] ) @@ -232,17 +221,17 @@ def get_queryset(cls, info): @classmethod def get_object(cls, info, **kwargs): try: - return cls.get_queryset(info).get(id=kwargs['id']), None + return cls.get_queryset(info).get(id=kwargs["id"]), None except cls.model.DoesNotExist: - return None, [dict(field='nonFieldErrors', messages=f'{cls.model.__name__} does not exist.')] + return None, [dict(field="nonFieldErrors", messages=f"{cls.model.__name__} does not exist.")] @classmethod def check_permissions(cls, info, **kwargs): - raise Exception('This needs to be implemented in inheritances class') + raise Exception("This needs to be implemented in inheritances class") @classmethod def perform_mutate(cls, root, info, **kwargs): - raise Exception('This needs to be implemented in inheritances class') + raise Exception("This needs to be implemented in inheritances class") @classmethod def get_serializer_context(cls, instance, context): @@ -250,10 +239,10 @@ def get_serializer_context(cls, instance, context): @classmethod def _save_item(cls, item, info, **kwargs): - id = kwargs.pop('id', None) + id = kwargs.pop("id", None) base_context = { - 'gql_info': info, - 'request': info.context, + "gql_info": info, + "request": info.context, } if id: instance, errors = cls.get_object(info, id=id, **kwargs) @@ -288,7 +277,7 @@ class GrapheneMutation(BaseGrapheneMutation): @classmethod def perform_mutate(cls, root, info, **kwargs): - data = kwargs['data'] + data = kwargs["data"] instance, errors = cls._save_item(data, info, **kwargs) return cls(result=instance, errors=errors, ok=not errors) @@ -306,8 +295,8 @@ def get_valid_delete_items(cls, info, delete_ids): @classmethod def perform_mutate(cls, root, info, **kwargs): - items = kwargs.get('items') or [] - delete_ids = kwargs.get('delete_ids') + items = kwargs.get("items") or [] + delete_ids = kwargs.get("delete_ids") all_errors = [] all_instances = [] all_deleted_instances = [] @@ -323,7 +312,7 @@ def perform_mutate(cls, root, info, **kwargs): # cls.model.filter(pk__in=validated_delete_ids).delete() # Bulk Create/Update for item in items: - id = item.get('id') + id = item.get("id") instance, errors = cls._save_item(item, info, id=id, **kwargs) all_errors.append(errors) all_instances.append(instance) @@ -348,12 +337,12 @@ def perform_mutate(cls, root, info, **kwargs): result=None, ok=False, errors=[ - dict(field='nonFieldErrors', message='You are not allowed to delete!!'), + dict(field="nonFieldErrors", message="You are not allowed to delete!!"), ], ) -class ProjectScopeMixin(): +class ProjectScopeMixin: permissions: List[PP.Permission] @classmethod @@ -375,7 +364,7 @@ class PsDeleteMutation(ProjectScopeMixin, DeleteMutation): pass -class AfScopeMixin(): +class AfScopeMixin: permissions: List[AfP.Permission] @classmethod @@ -393,7 +382,7 @@ class AfBulkGrapheneMutation(AfScopeMixin, BulkGrapheneMutation): pass -class UgScopeMixin(): +class UgScopeMixin: permissions: List[UgP.Permission] @classmethod diff --git a/utils/graphene/pagination.py b/utils/graphene/pagination.py index 7619dce391..894b20f496 100644 --- a/utils/graphene/pagination.py +++ b/utils/graphene/pagination.py @@ -1,12 +1,13 @@ from graphene import String -from graphene_django_extras.paginations.pagination import BaseDjangoGraphqlPagination from graphene_django_extras import PageGraphqlPagination +from graphene_django_extras.paginations.pagination import BaseDjangoGraphqlPagination class NoOrderingPageGraphqlPagination(PageGraphqlPagination): """ Custom pagination to support enum ordering from filterset """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -21,6 +22,7 @@ class OrderingOnlyArgumentPagination(BaseDjangoGraphqlPagination): Pagination just for ordering. Created for DjangoFilterPaginateListField (or its subclasses) in mind, to remove the page related arguments. """ + __name__ = "OrderingOnlyArgument" def __init__( diff --git a/utils/graphene/resolver.py b/utils/graphene/resolver.py index 6f6d201d0d..57a9b514b8 100644 --- a/utils/graphene/resolver.py +++ b/utils/graphene/resolver.py @@ -1,5 +1,5 @@ -from graphene.types.resolver import dict_or_attr_resolver, set_default_resolver from django.db.models.fields.files import FieldFile as DjangoFieldFile +from graphene.types.resolver import dict_or_attr_resolver, set_default_resolver def custom_dict_or_attr_resolver(*args, **kwargs): diff --git a/utils/graphene/tests.py b/utils/graphene/tests.py index c0dbc890c7..2357ec6222 100644 --- a/utils/graphene/tests.py +++ b/utils/graphene/tests.py @@ -1,42 +1,42 @@ -import os -import json -import pytz -import inspect import datetime +import inspect +import json +import os from enum import Enum -from unittest.mock import patch from typing import Union +from unittest.mock import patch -from factory import random as factory_random -from snapshottest.django import TestCase as SnapShotTextCase -from django.utils import timezone -from django.core import management +import pytz +from analysis_framework.models import AnalysisFramework, AnalysisFrameworkRole from django.conf import settings from django.contrib.auth import get_user_model -from django.test import TestCase, override_settings +from django.core import management from django.db import models +from django.test import TestCase, override_settings +from django.utils import timezone +from factory import random as factory_random + # dramatiq test case: setupclass is not properly called # from django_dramatiq.test import DramatiqTestCase from graphene_django.utils import GraphQLTestCase as BaseGraphQLTestCase +from project.models import ProjectRole +from project.permissions import get_project_permissions_value from rest_framework import status +from snapshottest.django import TestCase as SnapShotTextCase from deep.middleware import _set_current_request from deep.tests.test_case import ( - TEST_CACHES, TEST_AUTH_PASSWORD_VALIDATORS, + TEST_CACHES, TEST_EMAIL_BACKEND, TEST_FILE_STORAGE, clean_up_test_media_files, ) -from analysis_framework.models import AnalysisFramework, AnalysisFrameworkRole -from project.permissions import get_project_permissions_value -from project.models import ProjectRole - User = get_user_model() -TEST_MEDIA_ROOT = 'media-temp' +TEST_MEDIA_ROOT = "media-temp" if settings.PYTEST_XDIST_WORKER: - TEST_MEDIA_ROOT = f'media-temp/{settings.PYTEST_XDIST_WORKER}' + TEST_MEDIA_ROOT = f"media-temp/{settings.PYTEST_XDIST_WORKER}" @override_settings( @@ -46,14 +46,14 @@ CACHES=TEST_CACHES, AUTH_PASSWORD_VALIDATORS=TEST_AUTH_PASSWORD_VALIDATORS, CELERY_TASK_ALWAYS_EAGER=True, - DEEPL_SERVER_CALLBACK_DOMAIN='http://testserver', + DEEPL_SERVER_CALLBACK_DOMAIN="http://testserver", ) class GraphQLTestCase(BaseGraphQLTestCase): """ GraphQLTestCase with custom helper methods """ - GRAPHQL_SCHEMA = 'deep.schema.schema' + GRAPHQL_SCHEMA = "deep.schema.schema" ENABLE_NOW_PATCHER = False PATCHER_NOW_VALUE = datetime.datetime(2021, 1, 1, 0, 0, 0, 123456, tzinfo=pytz.UTC) @@ -64,24 +64,24 @@ def tearDownClass(cls): super().tearDownClass() def _setup_premailer_patcher(self, mock): - mock.get.return_value.text = '' - mock.post.return_value.text = '' + mock.get.return_value.text = "" + mock.post.return_value.text = "" def setUp(self): super().setUp() self.create_project_roles() self.create_af_roles() - self.premailer_patcher_requests = patch('premailer.premailer.requests') + self.premailer_patcher_requests = patch("premailer.premailer.requests") self._setup_premailer_patcher(self.premailer_patcher_requests.start()) if self.ENABLE_NOW_PATCHER: - self.now_patcher = patch('django.utils.timezone.now') + self.now_patcher = patch("django.utils.timezone.now") self.now_datetime = self.PATCHER_NOW_VALUE self.now_datetime_str = lambda: self.now_datetime.isoformat() self.now_patcher.start().side_effect = lambda: self.now_datetime def tearDown(self): _set_current_request() # Clear request - if hasattr(self, 'now_patcher'): + if hasattr(self, "now_patcher"): self.now_patcher.stop() self.premailer_patcher_requests.stop() super().tearDown() @@ -121,14 +121,14 @@ def query_check(self, query, minput=None, mnested=None, assert_for_error=False, else: self.assertResponseNoErrors(response) if okay is not None: - _content = content['data'] + _content = content["data"] if mnested: for key in mnested: _content = _content[key] for key, datum in _content.items(): - if key == '__typename': + if key == "__typename": continue - okay_response = datum.get('ok') + okay_response = datum.get("ok") if okay: self.assertTrue(okay_response, content) else: @@ -144,11 +144,11 @@ def _create_role(title, _type, level=1, is_default_role=False): # TODO: Migrate current dynamic permission to static ones. return ProjectRole.objects.create( title=title, - lead_permissions=get_project_permissions_value('lead', '__all__'), - entry_permissions=get_project_permissions_value('entry', '__all__'), - setup_permissions=get_project_permissions_value('setup', '__all__'), - export_permissions=get_project_permissions_value('export', '__all__'), - assessment_permissions=get_project_permissions_value('assessment', '__all__'), + lead_permissions=get_project_permissions_value("lead", "__all__"), + entry_permissions=get_project_permissions_value("entry", "__all__"), + setup_permissions=get_project_permissions_value("setup", "__all__"), + export_permissions=get_project_permissions_value("export", "__all__"), + assessment_permissions=get_project_permissions_value("assessment", "__all__"), is_creator_role=False, level=level, is_default_role=is_default_role, @@ -158,33 +158,33 @@ def _create_role(title, _type, level=1, is_default_role=False): # TODO: Make sure merge roles have all the permissions # Follow deep.permissions.py PERMISSION_MAP for permitted actions. self.project_role_reader_non_confidential = _create_role( - 'Reader (Non Confidential)', + "Reader (Non Confidential)", ProjectRole.Type.READER_NON_CONFIDENTIAL, level=800, ) self.project_role_reader = _create_role( - 'Reader', + "Reader", ProjectRole.Type.READER, level=400, ) self.project_role_member = _create_role( - 'Member', + "Member", ProjectRole.Type.MEMBER, level=200, is_default_role=True, ) self.project_role_admin = _create_role( - 'Admin', + "Admin", ProjectRole.Type.ADMIN, level=100, ) self.project_role_owner = _create_role( - 'Project Owner', + "Project Owner", ProjectRole.Type.PROJECT_OWNER, level=1, ) self.project_base_access = _create_role( - 'Base Access', + "Base Access", ProjectRole.Type.UNKNOWN, level=999999, ) @@ -209,35 +209,33 @@ def _create_role(title, _type, permissions=dict, is_private_role=False, is_defau private_temp_af = AnalysisFramework(is_private=True) self.af_editor = _create_role( - 'Editor', - AnalysisFrameworkRole.Type.EDITOR, - permissions=public_temp_af.get_editor_permissions() + "Editor", AnalysisFrameworkRole.Type.EDITOR, permissions=public_temp_af.get_editor_permissions() ) self.af_owner = _create_role( - 'Owner', + "Owner", AnalysisFrameworkRole.Type.OWNER, permissions=public_temp_af.get_owner_permissions(), ) self.af_default = _create_role( - 'Default', + "Default", AnalysisFrameworkRole.Type.DEFAULT, permissions=public_temp_af.get_default_permissions(), is_default_role=True, ) self.af_private_editor = _create_role( - 'Private Editor', + "Private Editor", AnalysisFrameworkRole.Type.PRIVATE_EDITOR, permissions=private_temp_af.get_editor_permissions(), is_private_role=True, ) self.af_private_owner = _create_role( - 'Private Owner', + "Private Owner", AnalysisFrameworkRole.Type.PRIVATE_OWNER, permissions=private_temp_af.get_owner_permissions(), is_private_role=True, ) self.af_private_viewer = _create_role( - 'Private Viewer', + "Private Viewer", AnalysisFrameworkRole.Type.PRIVATE_VIEWER, permissions=private_temp_af.get_default_permissions(), is_private_role=True, @@ -246,8 +244,10 @@ def _create_role(title, _type, permissions=dict, is_private_role=False, is_defau def assertListIds( self, - current_list, excepted_list, message=None, - get_current_list_id=lambda x: str(x['id']), + current_list, + excepted_list, + message=None, + get_current_list_id=lambda x: str(x["id"]), get_excepted_list_id=lambda x: str(x.id), ): self.assertEqual( @@ -258,8 +258,10 @@ def assertListIds( def assertNotListIds( self, - current_list, excepted_list, message=None, - get_current_list_id=lambda x: str(x['id']), + current_list, + excepted_list, + message=None, + get_current_list_id=lambda x: str(x["id"]), get_not_excepted_list_id=lambda x: str(x.id), ): self.assertNotEqual( @@ -277,18 +279,15 @@ def _include(key): if exclude: return key not in keys return key in keys - return { - key: value - for key, value in _dict.items() - if _include(key) - } + + return {key: value for key, value in _dict.items() if _include(key)} if only_keys: assert _filter_by_keys(excepted, keys=only_keys) == _filter_by_keys(real, keys=only_keys), message elif ignore_keys: - assert _filter_by_keys(excepted, keys=ignore_keys, exclude=True) \ - == _filter_by_keys(real, keys=ignore_keys, exclude=True), \ - message + assert _filter_by_keys(excepted, keys=ignore_keys, exclude=True) == _filter_by_keys( + real, keys=ignore_keys, exclude=True + ), message else: assert excepted == real, message @@ -299,13 +298,13 @@ def assertQuerySetIdEqual(self, l1, l2): ) def get_media_url(self, file): - return f'http://testserver/media/{file}' + return f"http://testserver/media/{file}" - def get_media_file(self, file, mode='rb') -> bytes: + def get_media_file(self, file, mode="rb") -> bytes: with open(os.path.join(TEST_MEDIA_ROOT, file), mode) as fp: return fp.read() - def get_json_media_file(self, file, mode='rb') -> dict: + def get_json_media_file(self, file, mode="rb") -> dict: return json.loads(self.get_media_file(file, mode=mode)) def update_obj(self, obj, **fields): @@ -318,7 +317,7 @@ def get_datetime_str(self, _datetime): return _datetime.isoformat() def get_date_str(self, _datetime): - return _datetime.strftime('%Y-%m-%d') + return _datetime.strftime("%Y-%m-%d") def get_aware_datetime(self, *args, **kwargs): return timezone.make_aware(datetime.datetime(*args, **kwargs)) @@ -328,10 +327,10 @@ def get_aware_datetime_str(self, *args, **kwargs): # Some Rest helper functions def assert_http_code(self, response, status_code): - error_resp = getattr(response, 'data', None) + error_resp = getattr(response, "data", None) mesg = error_resp - if isinstance(error_resp, dict) and 'errors' in error_resp: - mesg = error_resp['errors'] + if isinstance(error_resp, dict) and "errors" in error_resp: + mesg = error_resp["errors"] return self.assertEqual(response.status_code, status_code, mesg) def assert_400(self, response): @@ -361,6 +360,7 @@ class GraphQLSnapShotTestCase(GraphQLTestCase, SnapShotTextCase): This TestCase can be used with `self.assertMatchSnapshot`. Make sure to only include snapshottests as we are using database flush. """ + maxDiff = None factories_used = [] @@ -373,9 +373,9 @@ def setUp(self): factory.reset_sequence() # XXX: Quick hack to make sure _snapshot_file is always defined. Which seems to be missing when running in CI # https://github.com/syrusakbary/snapshottest/blob/770b8f14cd965d923a0183a0e531e9ec0ba20192/snapshottest/unittest.py#L86 - if not hasattr(self, '_snapshot_file'): + if not hasattr(self, "_snapshot_file"): self._snapshot_file = inspect.getfile(type(self)) - if not hasattr(self, '_snapshot_tests'): + if not hasattr(self, "_snapshot_tests"): self._snapshot_tests = [] super().setUp() diff --git a/utils/graphene/types.py b/utils/graphene/types.py index 42cfd90037..d08dff2b38 100644 --- a/utils/graphene/types.py +++ b/utils/graphene/types.py @@ -1,10 +1,10 @@ from collections import OrderedDict - -import graphene from typing import Union +import graphene from django.db import models -from graphene import ObjectType, Field, Int +from graphene import Field, Int, ObjectType + # we will use graphene_django registry over the one from graphene_django_extras # since it adds information regarding nullability in the schema definition from graphene_django.registry import get_global_registry @@ -13,23 +13,20 @@ from graphene_django_extras.base_types import factory_type from graphene_django_extras.types import DjangoObjectOptions -from deep.serializers import TempClientIdMixin from deep.caches import local_cache -from deep.serializers import URLCachedFileField +from deep.serializers import TempClientIdMixin, URLCachedFileField from utils.graphene.fields import CustomDjangoListField from utils.graphene.options import CustomObjectTypeOptions class ClientIdMixin(graphene.ObjectType): - client_id = graphene.ID(required=True, description='Provides clientID if provided in the mutation. Fallback is id') + client_id = graphene.ID(required=True, description="Provides clientID if provided in the mutation. Fallback is id") @staticmethod def resolve_client_id(root, info): # NOTE: We should always provide non-null client_id client_id = ( - getattr(root, 'client_id', None) or - local_cache.get(TempClientIdMixin.get_cache_key(root, info.context)) or - root.id + getattr(root, "client_id", None) or local_cache.get(TempClientIdMixin.get_cache_key(root, info.context)) or root.id ) if client_id is not None: return client_id @@ -49,15 +46,14 @@ def __init_subclass_with_meta__( **options, ): - assert base_type is not None, ( - 'Base Type of the ListField should be defined in the Meta.' - ) + assert base_type is not None, "Base Type of the ListField should be defined in the Meta." if not DJANGO_FILTER_INSTALLED and filterset_class: raise Exception("Can only set filterset_class if Django-Filter is installed") if not filterset_class: from django_filters import rest_framework as df + filterset_class = df.FilterSet results_field_name = results_field_name or "results" @@ -94,19 +90,18 @@ def __init_subclass_with_meta__( name="pageSize", description="Page Size", ), - ) + ), ] ) - super(CustomListObjectType, cls).__init_subclass_with_meta__( - _meta=_meta, **options - ) + super(CustomListObjectType, cls).__init_subclass_with_meta__(_meta=_meta, **options) class CustomDjangoListObjectType(DjangoListObjectType): """ Updates `DjangoListObjectType` to add page related fields into type definition """ + class Meta: abstract = True @@ -126,20 +121,19 @@ def __init_subclass_with_meta__( **options, ): - assert is_valid_django_model(model), ( - 'You need to pass a valid Django Model in {}.Meta, received "{}".' - ).format(cls.__name__, model) + assert is_valid_django_model(model), ('You need to pass a valid Django Model in {}.Meta, received "{}".').format( + cls.__name__, model + ) assert pagination is None, ( - 'Pagination should be applied on the ListField enclosing {0} rather than its `{0}.Meta`.' + "Pagination should be applied on the ListField enclosing {0} rather than its `{0}.Meta`." ).format(cls.__name__) if not DJANGO_FILTER_INSTALLED and filter_fields: raise Exception("Can only set filter_fields if Django-Filter is installed") assert isinstance(queryset, models.QuerySet) or queryset is None, ( - "The attribute queryset in {} needs to be an instance of " - 'Django model queryset, received "{}".' + "The attribute queryset in {} needs to be an instance of " 'Django model queryset, received "{}".' ).format(cls.__name__, queryset) results_field_name = results_field_name or "results" @@ -199,13 +193,11 @@ def __init_subclass_with_meta__( name="pageSize", description="Page Size", ), - ) + ), ] ) - super(DjangoListObjectType, cls).__init_subclass_with_meta__( - _meta=_meta, **options - ) + super(DjangoListObjectType, cls).__init_subclass_with_meta__(_meta=_meta, **options) class FileFieldType(graphene.ObjectType): @@ -219,9 +211,7 @@ def resolve_name(root, info, **kwargs) -> Union[str, None]: return root.name def resolve_url(root, info, **kwargs) -> Union[str, None]: - return info.context.request.build_absolute_uri( - URLCachedFileField.name_to_representation(root) - ) + return info.context.request.build_absolute_uri(URLCachedFileField.name_to_representation(root)) class DateCountType(graphene.ObjectType): diff --git a/utils/hid/hid.py b/utils/hid/hid.py index 72a6c91436..ecff0dac44 100644 --- a/utils/hid/hid.py +++ b/utils/hid/hid.py @@ -1,8 +1,8 @@ -from django.conf import settings -from django.contrib.auth import get_user_model import logging -import requests +import requests +from django.conf import settings +from django.contrib.auth import get_user_model from user.models import Profile logger = logging.getLogger(__name__) @@ -13,6 +13,7 @@ class HidConfig: """ HID Configs """ + def __init__(self): self.client_id = settings.HID_CLIENT_ID self.redirect_url = settings.HID_CLIENT_REDIRECT_URL @@ -29,34 +30,33 @@ def __init__(self, message=None): class InvalidHIDConfigurationException(HIDBaseException): - message = 'Invalid HID Configuration' + message = "Invalid HID Configuration" class HIDFetchFailedException(HIDBaseException): - message = 'HID User data fetch failed' + message = "HID User data fetch failed" class HIDEmailNotVerifiedException(HIDBaseException): - message = 'Email is not verified in HID' + message = "Email is not verified in HID" class HumanitarianId: """ Handles HID Token """ + def __init__(self, access_token): - self.data = self._process_hid_user_data( - self.get_user_information_from_access_token(access_token) - ) - self.user_id = self.data['hid'] + self.data = self._process_hid_user_data(self.get_user_information_from_access_token(access_token)) + self.user_id = self.data["hid"] @staticmethod def _process_hid_user_data(data): - first_name, *last_name = (data['name'] or '').split(' ') - last_name = ' '.join(last_name) + first_name, *last_name = (data["name"] or "").split(" ") + last_name = " ".join(last_name) return dict( - hid=data['sub'], - email=data['email'], + hid=data["sub"], + email=data["email"], first_name=first_name, last_name=last_name, ) @@ -64,7 +64,7 @@ def _process_hid_user_data(data): def get_user(self): profile = Profile.objects.filter(hid=self.user_id).first() if profile is None: - user = User.objects.filter(email=self.data['email']).first() + user = User.objects.filter(email=self.data["email"]).first() if user: self._save_user(user) return user @@ -76,9 +76,9 @@ def _save_user(self, user): """ Sync data from HID to user """ - user.first_name = self.data['first_name'] - user.last_name = self.data['last_name'] - user.email = self.data['email'] + user.first_name = self.data["first_name"] + user.last_name = self.data["last_name"] + user.email = self.data["email"] user.profile.hid = self.user_id user.save() @@ -86,12 +86,12 @@ def _create_user(self): """ Create User with HID data """ - username = self.data['email'] + username = self.data["email"] user = User.objects.create_user( - first_name=self.data['first_name'], - last_name=self.data['last_name'], - email=self.data['email'], + first_name=self.data["first_name"], + last_name=self.data["last_name"], + email=self.data["email"], username=username, ) @@ -102,14 +102,15 @@ def _create_user(self): def get_user_information_from_access_token(self, access_token): if config.auth_uri: # https://github.com/UN-OCHA/hid_api/blob/363f5a06fe25360515494bce050a6d2987058a2a/api/controllers/UserController.js#L1536-L1546 - url = config.auth_uri + '/account.json' + url = config.auth_uri + "/account.json" r = requests.post( - url, headers={'Authorization': 'Bearer ' + access_token}, + url, + headers={"Authorization": "Bearer " + access_token}, ) if r.status_code == 200: data = r.json() - if not data['email_verified']: - raise HIDEmailNotVerifiedException('Email is not verified in HID') + if not data["email_verified"]: + raise HIDEmailNotVerifiedException("Email is not verified in HID") return data - raise HIDFetchFailedException('HID Get Token Failed!! \n{}'.format(r.json())) - raise InvalidHIDConfigurationException('Invalid HID Configuration') + raise HIDFetchFailedException("HID Get Token Failed!! \n{}".format(r.json())) + raise InvalidHIDConfigurationException("Invalid HID Configuration") diff --git a/utils/hid/tests/test_hid.py b/utils/hid/tests/test_hid.py index 466691980c..b230196059 100644 --- a/utils/hid/tests/test_hid.py +++ b/utils/hid/tests/test_hid.py @@ -1,25 +1,27 @@ -import requests import logging from unittest.mock import patch +import requests + # from rest_framework import status from django.test import TestCase -from utils.hid import hid + from utils.common import DEFAULT_HEADERS +from utils.hid import hid # from urllib.parse import urlparse # from requests.exceptions import ConnectionError # import traceback # MOCK Data -HID_EMAIL = 'dev@togglecorp.com' -HID_PASSWORD = 'XXXXXXXXXXXXXXXX' -HID_FIRSTNAME = 'Togglecorp' -HID_LASTNAME = 'Dev' +HID_EMAIL = "dev@togglecorp.com" +HID_PASSWORD = "XXXXXXXXXXXXXXXX" +HID_FIRSTNAME = "Togglecorp" +HID_LASTNAME = "Dev" HID_LOGIN_URL = ( - f'{hid.config.auth_uri}/oauth/authorize?' - f'response_type=token&client_id={hid.config.client_id}&scope=profile&state=12345&redirect_uri={hid.config.redirect_url}' + f"{hid.config.auth_uri}/oauth/authorize?" + f"response_type=token&client_id={hid.config.client_id}&scope=profile&state=12345&redirect_uri={hid.config.redirect_url}" ) logger = logging.getLogger(__name__) @@ -29,6 +31,7 @@ class HIDIntegrationTest(TestCase): """ Test HID Integration """ + def setUp(self): self.requests = requests.session() self.headers = DEFAULT_HEADERS @@ -56,7 +59,7 @@ def get_access_token(self): Get access token from HID """ # Mocking - return 'XXXXXXXXXXXXXXXXXXXXXXXXXXXX' + return "XXXXXXXXXXXXXXXXXXXXXXXXXXXX" """ # NOTE: LIVE API IS NOT USED FOR TESTING. LEAVING IT HERE FOR REFERENCE ONLY ##### @@ -102,8 +105,8 @@ def _setup_mock_hid_requests(self, mock_requests): mock_requests.post.return_value.status_code = 200 mock_requests.post.return_value.json.return_value = { # Also returns other value, but we don't require it for now - 'id': 'xxxxxxx1234xxxxxxxxxxxx', - 'sub': 'xxxxxxx1234xxxxxxxxxxxx', + "id": "xxxxxxx1234xxxxxxxxxxxx", + "sub": "xxxxxxx1234xxxxxxxxxxxx", # Also returns other value, but we don't require it for now "email_verified": True, "email": HID_EMAIL, @@ -114,7 +117,7 @@ def _setup_mock_hid_requests(self, mock_requests): } return mock_requests.post.return_value - @patch('utils.hid.hid.requests') + @patch("utils.hid.hid.requests") def test_new_user(self, mock_requests): """ Test for new user @@ -122,7 +125,7 @@ def test_new_user(self, mock_requests): mock_return_value = self._setup_mock_hid_requests(mock_requests) access_token = self.get_access_token() user = hid.HumanitarianId(access_token).get_user() - self.assertEqual(getattr(user, 'email', None), HID_EMAIL) + self.assertEqual(getattr(user, "email", None), HID_EMAIL) user.delete() mock_return_value.status_code = 400 @@ -130,43 +133,43 @@ def test_new_user(self, mock_requests): user = hid.HumanitarianId(access_token).get_user() mock_return_value.status_code = 200 - mock_return_value.json.return_value['email_verified'] = False + mock_return_value.json.return_value["email_verified"] = False with self.assertRaises(hid.HIDEmailNotVerifiedException): user = hid.HumanitarianId(access_token).get_user() - mock_return_value.json.return_value['email_verified'] = True + mock_return_value.json.return_value["email_verified"] = True - mock_return_value.json.return_value.pop('name') + mock_return_value.json.return_value.pop("name") with self.assertRaises(KeyError): user = hid.HumanitarianId(access_token).get_user() # ----------- Name attribute change test - sample_first_name = 'Xxxxxx' - sample_last_name = 'Yyyyyy' + sample_first_name = "Xxxxxx" + sample_last_name = "Yyyyyy" # Just FN in name - mock_return_value.json.return_value['name'] = sample_first_name + mock_return_value.json.return_value["name"] = sample_first_name user = hid.HumanitarianId(access_token).get_user() - self.assertEqual(getattr(user, 'first_name'), sample_first_name) - self.assertEqual(getattr(user, 'last_name'), '') + self.assertEqual(getattr(user, "first_name"), sample_first_name) + self.assertEqual(getattr(user, "last_name"), "") user.delete() # Both FN+LN in name - mock_return_value.json.return_value['name'] = f'{sample_first_name} {sample_last_name}' + mock_return_value.json.return_value["name"] = f"{sample_first_name} {sample_last_name}" user = hid.HumanitarianId(access_token).get_user() - self.assertEqual(getattr(user, 'first_name'), sample_first_name) - self.assertEqual(getattr(user, 'last_name'), sample_last_name) + self.assertEqual(getattr(user, "first_name"), sample_first_name) + self.assertEqual(getattr(user, "last_name"), sample_last_name) user.delete() # Name = None - for sample_name in [None, '']: - mock_return_value.json.return_value['name'] = sample_name + for sample_name in [None, ""]: + mock_return_value.json.return_value["name"] = sample_name user = hid.HumanitarianId(access_token).get_user() - self.assertEqual(getattr(user, 'first_name'), '') - self.assertEqual(getattr(user, 'last_name'), '') + self.assertEqual(getattr(user, "first_name"), "") + self.assertEqual(getattr(user, "last_name"), "") user.delete() - mock_return_value.json.return_value['name'] = 'Xxxxxx Xxxxxx' + mock_return_value.json.return_value["name"] = "Xxxxxx Xxxxxx" - @patch('utils.hid.hid.requests') + @patch("utils.hid.hid.requests") def test_link_user(self, mock_requests): """ Test for old user @@ -176,11 +179,7 @@ def test_link_user(self, mock_requests): access_token = self.get_access_token() user = hid.User.objects.create_user( - first_name=HID_FIRSTNAME, - last_name=HID_LASTNAME, - email=HID_EMAIL, - username=HID_EMAIL, - password=HID_PASSWORD + first_name=HID_FIRSTNAME, last_name=HID_LASTNAME, email=HID_EMAIL, username=HID_EMAIL, password=HID_PASSWORD ) hid_user = hid.HumanitarianId(access_token).get_user() diff --git a/utils/image.py b/utils/image.py index bcb8b680f5..41fa02e303 100644 --- a/utils/image.py +++ b/utils/image.py @@ -1,16 +1,17 @@ import base64 -import uuid import imghdr +import uuid + from django.core.files.base import ContentFile def decode_base64_if_possible(data): if not isinstance(data, str): return data, None - if 'data:' not in data or ';base64,' not in data: + if "data:" not in data or ";base64," not in data: return data, None - header, data = data.split(';base64,') + header, data = data.split(";base64,") try: decoded_file = base64.b64decode(data) @@ -19,7 +20,7 @@ def decode_base64_if_possible(data): filename = str(uuid.uuid4())[:12] ext = imghdr.what(filename, decoded_file) - complete_filename = '{}.{}'.format(filename, ext) + complete_filename = "{}.{}".format(filename, ext) data = ContentFile(decoded_file, name=complete_filename) return data, header diff --git a/utils/request.py b/utils/request.py index d535b8e594..725d8882c7 100644 --- a/utils/request.py +++ b/utils/request.py @@ -1,8 +1,8 @@ -import requests import json from dataclasses import dataclass, field -from typing import Union, Dict, Callable +from typing import Callable, Dict, Union +import requests from django.core.files.base import ContentFile from utils.common import sanitize_text @@ -13,6 +13,7 @@ def wrapper(self, *args, **kwargs): if self.ignore_error and self.error_on_response: return return func(self, *args, **kwargs) + return wrapper diff --git a/utils/sentry.py b/utils/sentry.py index f13f637109..9c443d9d58 100644 --- a/utils/sentry.py +++ b/utils/sentry.py @@ -1,21 +1,22 @@ +import logging import os -import logging import sentry_sdk -from django.core.exceptions import PermissionDenied -from django.conf import settings + +# Celery Terminated Exception: The worker processing a job has been terminated by user request. +from billiard.exceptions import Terminated from celery.exceptions import Retry as CeleryRetry -from sentry_sdk.integrations.logging import ignore_logger +from django.conf import settings +from django.core.exceptions import PermissionDenied from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.django import DjangoIntegration +from sentry_sdk.integrations.logging import ignore_logger from sentry_sdk.integrations.redis import RedisIntegration -# Celery Terminated Exception: The worker processing a job has been terminated by user request. -from billiard.exceptions import Terminated -from deep.exceptions import UnauthorizedException from apps.jwt_auth.errors import InvalidCaptchaError +from deep.exceptions import UnauthorizedException -logger = logging.getLogger('deep_sentry.errors.logging') +logger = logging.getLogger("deep_sentry.errors.logging") IGNORED_ERRORS = [ Terminated, @@ -25,8 +26,8 @@ CeleryRetry, ] IGNORED_LOGGERS = [ - 'graphql.execution.utils', - 'deep_sentry.errors.logging', + "graphql.execution.utils", + "deep_sentry.errors.logging", ] for _logger in IGNORED_LOGGERS: @@ -43,46 +44,41 @@ def fetch_git_sha(path, head=None): >>> fetch_git_sha(os.path.dirname(__file__)) """ if not head: - head_path = os.path.join(path, '.git', 'HEAD') + head_path = os.path.join(path, ".git", "HEAD") if not os.path.exists(head_path): - raise InvalidGitRepository( - 'Cannot identify HEAD for git repository at %s' % (path,)) + raise InvalidGitRepository("Cannot identify HEAD for git repository at %s" % (path,)) - with open(head_path, 'r') as fp: + with open(head_path, "r") as fp: head = str(fp.read()).strip() - if head.startswith('ref: '): + if head.startswith("ref: "): head = head[5:] - revision_file = os.path.join( - path, '.git', *head.split('/') - ) + revision_file = os.path.join(path, ".git", *head.split("/")) else: return head else: - revision_file = os.path.join(path, '.git', 'refs', 'heads', head) + revision_file = os.path.join(path, ".git", "refs", "heads", head) if not os.path.exists(revision_file): - if not os.path.exists(os.path.join(path, '.git')): - raise InvalidGitRepository( - '%s does not seem to be the root of a git repository' % (path,)) + if not os.path.exists(os.path.join(path, ".git")): + raise InvalidGitRepository("%s does not seem to be the root of a git repository" % (path,)) # Check for our .git/packed-refs' file since a `git gc` may have run # https://git-scm.com/book/en/v2/Git-Internals-Maintenance-and-Data-Recovery - packed_file = os.path.join(path, '.git', 'packed-refs') + packed_file = os.path.join(path, ".git", "packed-refs") if os.path.exists(packed_file): with open(packed_file) as fh: for line in fh: line = line.rstrip() - if line and line[:1] not in ('#', '^'): + if line and line[:1] not in ("#", "^"): try: - revision, ref = line.split(' ', 1) + revision, ref = line.split(" ", 1) except ValueError: continue if ref == head: return str(revision) - raise InvalidGitRepository( - 'Unable to find ref to head "%s" in repository' % (head,)) + raise InvalidGitRepository('Unable to find ref to head "%s" in repository' % (head,)) with open(revision_file) as fh: return str(fh.read()).strip() @@ -101,7 +97,7 @@ def init_sentry(app_type, tags={}, **config): integrations=integrations, ) with sentry_sdk.configure_scope() as scope: - scope.set_tag('app_type', app_type) + scope.set_tag("app_type", app_type) for tag, value in tags.items(): scope.set_tag(tag, value) @@ -112,6 +108,7 @@ class SentryGrapheneMiddleware(object): Then raise the error again and let Graphene handle it. https://medium.com/open-graphql/monitoring-graphene-django-python-graphql-api-using-sentry-c0b0c07a344f """ + # TODO: This need further work (Use this in GraphqlView instead of middleware) def on_error(self, root, info, **args): @@ -120,15 +117,16 @@ def _on_error(error): user = info.context.user if user and user.id: scope.user = { - 'id': user.id, - 'email': user.email, + "id": user.id, + "email": user.email, } - scope.set_extra('is_superuser', user.is_superuser) - scope.set_tag('kind', info.operation.operation) + scope.set_extra("is_superuser", user.is_superuser) + scope.set_tag("kind", info.operation.operation) sentry_sdk.capture_exception(error) # log to console logger.error(error, exc_info=True) raise error + return _on_error def resolve(self, next, root, info, **args): diff --git a/utils/tests.py b/utils/tests.py index 33660337e3..bb762c4f90 100644 --- a/utils/tests.py +++ b/utils/tests.py @@ -1,8 +1,8 @@ -import unittest import copy +import unittest -from utils.data_structures import Dict from utils.common import remove_empty_keys_from_dict +from utils.data_structures import Dict class TestDict(unittest.TestCase): @@ -10,7 +10,7 @@ def test_creation(self): d = Dict(a=1, b=2) assert isinstance(d, Dict) assert isinstance(d, dict) - d = Dict({'a': 1, 'b': 2}) + d = Dict({"a": 1, "b": 2}) assert isinstance(d, Dict) assert isinstance(d, dict) @@ -18,14 +18,14 @@ def test_access(self): d = Dict(a=1, b=2) assert d.a == 1 assert d.b == 2 - assert d['a'] == d.a - assert d['b'] == d.b + assert d["a"] == d.a + assert d["b"] == d.b def test_set(self): d = Dict() d.b = 3 - assert d['b'] == 3 - d['c'] = 4 + assert d["b"] == 3 + d["c"] = 4 assert d.c == 4 def test_nested(self): @@ -34,49 +34,51 @@ def test_nested(self): assert d.a.b == 1 assert d.a.c == 2 assert d.b == 3 - assert d['a']['b'] == 1 - d = Dict({'a': {'b': 1, 'c': 2}, 'b': 3, 'c': 4}) + assert d["a"]["b"] == 1 + d = Dict({"a": {"b": 1, "c": 2}, "b": 3, "c": 4}) assert isinstance(d.a, Dict) assert d.a.b == 1 assert d.a.c == 2 assert d.b == 3 - assert d['a']['b'] == 1 + assert d["a"]["b"] == 1 def test_other_methods(self): d = Dict(a=2, b=3, c=Dict(a=1)) - assert sorted(d.keys()) == ['a', 'b', 'c'] + assert sorted(d.keys()) == ["a", "b", "c"] def test_remove_empty_keys_from_dict(self): TEST_SET = [ ({}, {}), - ({ - 'key1': None, - 'key2': [], - 'key3': (), - 'key4': {}, - }, {}), ( - {'key1': {}, 'key2': 'value1'}, - {'key2': 'value1'} - ), - ({ - 'key1': { - 'key11': {}, - 'key2': 'value2', + { + "key1": None, + "key2": [], + "key3": (), + "key4": {}, }, - 'key2': 'value2', - 'sample': { - 'sample2': { - 'sample3': { + {}, + ), + ({"key1": {}, "key2": "value1"}, {"key2": "value1"}), + ( + { + "key1": { + "key11": {}, + "key2": "value2", + }, + "key2": "value2", + "sample": { + "sample2": { + "sample3": {}, }, }, }, - }, { - 'key1': { - 'key2': 'value2', + { + "key1": { + "key2": "value2", + }, + "key2": "value2", }, - 'key2': 'value2', - }), + ), ] for obj, expected_obj in TEST_SET: diff --git a/utils/web_info_extractor/__init__.py b/utils/web_info_extractor/__init__.py index 1dca4128c6..7d136db941 100644 --- a/utils/web_info_extractor/__init__.py +++ b/utils/web_info_extractor/__init__.py @@ -3,9 +3,8 @@ from .default import DefaultWebInfoExtractor from .redhum import RedhumWebInfoExtractor - EXTRACTORS = { - 'redhum.org': RedhumWebInfoExtractor, + "redhum.org": RedhumWebInfoExtractor, } diff --git a/utils/web_info_extractor/base.py b/utils/web_info_extractor/base.py index 1f034f250e..f39ffc5de1 100644 --- a/utils/web_info_extractor/base.py +++ b/utils/web_info_extractor/base.py @@ -1,15 +1,16 @@ -from datetime import datetime, date +from datetime import date, datetime class ExtractorMixin: """ Mixin that implements get_date_str and serialized_data """ + # fields are accessed by get_{fielname}. If fieldname is to be reanamed, mention # it as 'source_field:rename_to'. For example: 'date_str:date' will have date_str value # in 'date' field of serialized_data - fields = ['title', 'date_str:date', 'country', 'source', 'author'] + fields = ["title", "date_str:date", "country", "source", "author"] def get_date_str(self): parsed = self.get_date() @@ -20,10 +21,10 @@ def get_date_str(self): def serialized_data(self): data = {} for fieldname in self.fields: - if ':' in fieldname: - source_field, rename_as = fieldname.split(':')[:2] + if ":" in fieldname: + source_field, rename_as = fieldname.split(":")[:2] else: source_field, rename_as = fieldname, fieldname - getter = getattr(self, f'get_{source_field}') + getter = getattr(self, f"get_{source_field}") data[rename_as] = getter and getter() return data diff --git a/utils/web_info_extractor/default.py b/utils/web_info_extractor/default.py index 7a6f6628b3..ab8ffa6a26 100644 --- a/utils/web_info_extractor/default.py +++ b/utils/web_info_extractor/default.py @@ -1,16 +1,16 @@ -from bs4 import BeautifulSoup -from readability.readability import Document from urllib.parse import urlparse -from utils.date_extractor import extract_date import requests import tldextract +from bs4 import BeautifulSoup +from readability.readability import Document -from .base import ExtractorMixin +from utils.date_extractor import extract_date +from .base import ExtractorMixin HEADERS = { - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36', # noqa + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36", # noqa } @@ -26,7 +26,7 @@ def __init__(self, url): except requests.exceptions.RequestException: return - if 'text/html' in head.headers.get('content-type', ''): + if "text/html" in head.headers.get("content-type", ""): try: response = requests.get(url, headers=HEADERS, verify=False) html = response.text @@ -35,7 +35,7 @@ def __init__(self, url): return self.readable = Document(html) - self.page = BeautifulSoup(html, 'lxml') + self.page = BeautifulSoup(html, "lxml") def get_title(self): return self.readable and self.readable.short_title() @@ -50,11 +50,11 @@ def get_date_str(self): def get_country(self): if not self.page: return None - country = self.page.select('.primary-country .country a') + country = self.page.select(".primary-country .country a") if country: return country[0].text.strip() - country = self.page.select('.country') + country = self.page.select(".country") if country: return country[0].text.strip() @@ -65,7 +65,7 @@ def get_source(self): def get_author(self): if self.page: - source = self.page.select('.field-source') + source = self.page.select(".field-source") if source: return source[0].text.strip() @@ -78,10 +78,10 @@ def get_content(self): def serialized_data(self): data = {} for fieldname in self.fields: - if ':' in fieldname: - source_field, rename_as = fieldname.split(':')[:2] + if ":" in fieldname: + source_field, rename_as = fieldname.split(":")[:2] else: source_field, rename_as = fieldname, fieldname - getter = getattr(self, f'get_{source_field}') + getter = getattr(self, f"get_{source_field}") data[rename_as] = getter and getter() return data diff --git a/utils/web_info_extractor/redhum.py b/utils/web_info_extractor/redhum.py index 518ff101c1..61d08330cd 100644 --- a/utils/web_info_extractor/redhum.py +++ b/utils/web_info_extractor/redhum.py @@ -1,12 +1,12 @@ import re from urllib.parse import urlparse + import requests from .base import ExtractorMixin - HEADERS = { - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36', # noqa + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36", # noqa } @@ -14,39 +14,39 @@ class RedhumWebInfoExtractor(ExtractorMixin): def __init__(self, url): self.url = url self.page = {} - url_parse = re.search(r'https://redhum.org/documento/(?P<report_id>\d+)\/?$', url) + url_parse = re.search(r"https://redhum.org/documento/(?P<report_id>\d+)\/?$", url) if not url_parse: return - report_id = url_parse.group('report_id') - rw_url = f'https://api.reliefweb.int/v1/reports/{report_id}' + report_id = url_parse.group("report_id") + rw_url = f"https://api.reliefweb.int/v1/reports/{report_id}" params = { - 'appname': 'redhum', - 'fields[include][]': ['title', 'primary_country', 'source', 'date', 'body-html'], + "appname": "redhum", + "fields[include][]": ["title", "primary_country", "source", "date", "body-html"], } try: response = requests.get(rw_url, headers=HEADERS, params=params) - self.page = response.json()['data'][0]['fields'] + self.page = response.json()["data"][0]["fields"] except Exception: return def get_title(self): - return self.page.get('title') + return self.page.get("title") def get_date(self): - return self.page.get('date', {}).get('created', '').split('T')[0] + return self.page.get("date", {}).get("created", "").split("T")[0] def get_country(self): - return self.page.get('primary_country', {}).get('name') + return self.page.get("primary_country", {}).get("name") def get_source(self): - return 'redhum' + return "redhum" def get_author(self): - return self.page.get('source', [{}])[0].get('longname') + return self.page.get("source", [{}])[0].get("longname") def get_website(self): return urlparse(self.url).netloc def get_content(self): - return self.page.get('body-html', '') + return self.page.get("body-html", "")