Skip to content

Route Manager

Route Manager is the core app, and it handles adding and removing routes that will be blocked.

admin

Register models in the Admin site.

ActionTypeAdmin

Bases: SimpleHistoryAdmin

Configure the ActionType and how it shows up in the Admin site.

Source code in scram/route_manager/admin.py
34
35
36
37
38
39
@admin.register(ActionType)
class ActionTypeAdmin(SimpleHistoryAdmin):
    """Configure the ActionType and how it shows up in the Admin site.

    Registered with the admin site via ``@admin.register``; inherits from ``SimpleHistoryAdmin``.
    """

    # Sidebar filter on the ``available`` flag.
    list_filter = ("available",)
    # Columns shown on the changelist page.
    list_display = ("name", "available")

EntryAdmin

Bases: SimpleHistoryAdmin

Configure how Entries show up in the Admin site.

Source code in scram/route_manager/admin.py
42
43
44
45
46
47
48
49
50
51
52
@admin.register(Entry)
class EntryAdmin(SimpleHistoryAdmin):
    """Configure how Entries show up in the Admin site."""

    # Follow related objects in the changelist query (Django admin ``select_related`` behaviour).
    list_select_related = True

    # Sidebar filters: active state, plus a per-user filter limited to users who created entries.
    list_filter = [
        "is_active",
        WhoFilter,
    ]
    # Fields matched by the changelist search box.
    # NOTE(review): elsewhere Entry routes are queried as ``route__route``, which suggests ``route`` is a
    # relation. Searching a relation directly (``route__icontains``) may raise an invalid-lookup
    # error; consider ``"route__route"`` once lookup support on the CIDR field is confirmed.
    search_fields = ["route", "comment"]

WhoFilter

Bases: SimpleListFilter

Only display users who have added entries in the list_filter.

Source code in scram/route_manager/admin.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class WhoFilter(admin.SimpleListFilter):
    """Changelist sidebar filter that only offers users who have added entries."""

    title = "By Username"
    parameter_name = "who"

    # ruff: noqa: PLR6301
    def lookups(self, request, model_admin):
        """Build one ``(value, label)`` choice per distinct ``who`` recorded on an Entry.

        When no entries exist the list is empty, so no user choices are offered.
        """
        distinct_who = Entry.objects.values("who").distinct()
        return [(row["who"], row["who"]) for row in distinct_who]

    def queryset(self, request, queryset):
        """Narrow the changelist to the selected user, or return it untouched when nothing is selected."""
        selected_user = self.value()
        if not selected_user:
            return queryset
        return queryset.filter(who=selected_user)

lookups(request, model_admin)

Return list of users who have added entries.

Source code in scram/route_manager/admin.py
16
17
18
19
20
21
22
23
24
25
def lookups(self, request, model_admin):
    """Collect the filter choices: each distinct ``who`` value found on an Entry, used as both value and label.

    Returns:
        list[tuple]: ``(who, who)`` pairs; empty when no entries exist, so no user choices appear.
    """
    choices = []
    for row in Entry.objects.values("who").distinct():
        username = row["who"]
        choices.append((username, username))
    return choices

queryset(request, queryset)

Queryset for users.

Source code in scram/route_manager/admin.py
27
28
29
30
31
def queryset(self, request, queryset):
    """Return ``queryset`` filtered to the chosen ``who`` value, or unchanged if no value is selected."""
    selected = self.value()
    return queryset.filter(who=selected) if selected else queryset

api

The API, which is built on Django REST Framework (DRF).

exceptions

Custom exceptions for the API.

ActiontypeNotAllowed

Bases: APIException

An operation attempted to perform an action on behalf of a client that is unauthorized to perform that type.

Source code in scram/route_manager/api/exceptions.py
25
26
27
28
29
30
class ActiontypeNotAllowed(APIException):
    """An operation attempted to perform an action on behalf of a client that is unauthorized to perform that type."""

    # DRF turns this into an HTTP 403 Forbidden response.
    status_code = 403
    default_detail = "This client is not allowed to use this actiontype"
    default_code = "actiontype_not_allowed"

IgnoredRoute

Bases: APIException

An operation attempted to add a route that overlaps with a route on the ignore list.

Source code in scram/route_manager/api/exceptions.py
17
18
19
20
21
22
class IgnoredRoute(APIException):
    """An operation attempted to add a route that overlaps with a route on the ignore list."""

    # DRF turns this into an HTTP 400 Bad Request response.
    status_code = 400
    default_detail = "This CIDR is on the ignore list. You are not allowed to add it here."
    default_code = "ignored_route"

PrefixTooLarge

Bases: APIException

The CIDR prefix that was specified is larger than the prefix allowed in the settings.

Source code in scram/route_manager/api/exceptions.py
 7
 8
 9
10
11
12
13
14
class PrefixTooLarge(APIException):
    """The CIDR prefix that was specified is larger than the prefix allowed in the settings."""

    # These are read once, when the class body executes (module import). The message therefore
    # reflects the settings at import time, not any later runtime override.
    v4_min_prefix = getattr(settings, "V4_MINPREFIX", 0)
    v6_min_prefix = getattr(settings, "V6_MINPREFIX", 0)
    # DRF turns this into an HTTP 400 Bad Request response.
    status_code = 400
    default_detail = f"You've supplied too large of a network. settings.V4_MINPREFIX = {v4_min_prefix} settings.V6_MINPREFIX = {v6_min_prefix}"  # noqa: E501
    default_code = "prefix_too_large"

serializers

Serializers provide mappings between the API and the underlying model.

ActionTypeSerializer

Bases: ModelSerializer

Map the serializer to the model via Meta.

Source code in scram/route_manager/api/serializers.py
21
22
23
24
25
26
27
28
class ActionTypeSerializer(serializers.ModelSerializer):
    """Map the serializer to the model via Meta.

    Exposes an ActionType's primary key, name, and availability flag.
    """

    class Meta:
        """Maps to the ActionType model, and specifies the fields exposed by the API."""

        model = ActionType
        fields = ["pk", "name", "available"]
Meta

Maps to the ActionType model, and specifies the fields exposed by the API.

Source code in scram/route_manager/api/serializers.py
24
25
26
27
28
class Meta:
    """Maps to the ActionType model, and specifies the fields exposed by the API."""

    # Model the serializer is generated from.
    model = ActionType
    # Field names exposed through the API.
    fields = ["pk", "name", "available"]

ClientSerializer

Bases: ModelSerializer

Map the serializer to the model via Meta.

Source code in scram/route_manager/api/serializers.py
45
46
47
48
49
50
51
52
class ClientSerializer(serializers.ModelSerializer):
    """Map the serializer to the model via Meta.

    Exposes only a Client's hostname and uuid.
    """

    class Meta:
        """Maps to the Client model, and specifies the fields exposed by the API."""

        model = Client
        fields = ["hostname", "uuid"]
Meta

Maps to the Client model, and specifies the fields exposed by the API.

Source code in scram/route_manager/api/serializers.py
48
49
50
51
52
class Meta:
    """Maps to the Client model, and specifies the fields exposed by the API."""

    # Model the serializer is generated from.
    model = Client
    # Field names exposed through the API.
    fields = ["hostname", "uuid"]

CustomCidrAddressField

Bases: CidrAddressField

Define a wrapper field so swagger can properly handle the inherited field.

Source code in scram/route_manager/api/serializers.py
16
17
18
@extend_schema_field(field={"type": "string", "format": "cidr"})
class CustomCidrAddressField(rest_framework.CidrAddressField):
    """Define a wrapper field so swagger can properly handle the inherited field.

    The decorator declares the schema for this field as a string in ``cidr`` format. All runtime
    behaviour is inherited unchanged from ``rest_framework.CidrAddressField``.
    """

EntrySerializer

Bases: HyperlinkedModelSerializer

Due to the use of ForeignKeys, this follows some relationships to make sense via the API.

Source code in scram/route_manager/api/serializers.py
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
class EntrySerializer(serializers.HyperlinkedModelSerializer):
    """Due to the use of ForeignKeys, this follows some relationships to make sense via the API."""

    url = serializers.HyperlinkedIdentityField(
        view_name="api:v1:entry-detail",
        lookup_url_kwarg="pk",
        lookup_field="route",
    )
    route = CustomCidrAddressField()
    # Defaults to "block" when the request omits an actiontype.
    actiontype = serializers.CharField(default="block")
    # This is set if we are calling this serializer from WUI.
    # A ``CurrentUserDefault`` instance is always truthy, so the former ``if CurrentUserDefault():``
    # guard always took this branch and its ``serializers.CharField()`` fallback was unreachable.
    # NOTE(review): ``CurrentUserDefault`` is a default provider, not a serializer ``Field``, so DRF
    # does not register this attribute as a declared field; "who" in ``Meta.fields`` is then built
    # from the model field. Kept unchanged to preserve the current API behaviour.
    who = CurrentUserDefault()
    comment = serializers.CharField()

    class Meta:
        """Maps to the Entry model, and specifies the fields exposed by the API."""

        model = Entry
        fields = ["route", "actiontype", "url", "comment", "who"]

    @staticmethod
    def get_comment(obj):
        """Provide a nicer name for change reason.

        NOTE(review): ``comment`` is declared as a plain ``CharField`` rather than a
        ``SerializerMethodField``, so DRF does not appear to call this during serialization.

        Returns:
            string: The change reason that modified the Entry.
        """
        return obj.get_change_reason()

    @staticmethod
    def create(validated_data):
        """Implement custom logic and validate creating a new route.

        Reuses an existing Route/Entry pair if one exists, and records ``comment`` as the Entry's
        change reason. Other validated keys (e.g. ``who``) are not used here.

        Returns:
            Entry: The existing or newly created Entry.
        """
        valid_route = validated_data.pop("route")
        actiontype = validated_data.pop("actiontype")
        comment = validated_data.pop("comment")

        route_instance, _ = Route.objects.get_or_create(route=valid_route)
        # Raises ActionType.DoesNotExist if no ActionType has this name.
        actiontype_instance = ActionType.objects.get(name=actiontype)
        entry_instance, _ = Entry.objects.get_or_create(route=route_instance, actiontype=actiontype_instance)

        logger.debug("Created entry with comment: %s", comment)
        update_change_reason(entry_instance, comment)

        return entry_instance
Meta

Maps to the Entry model, and specifies the fields exposed by the API.

Source code in scram/route_manager/api/serializers.py
72
73
74
75
76
class Meta:
    """Maps to the Entry model, and specifies the fields exposed by the API."""

    # Model the serializer is generated from.
    model = Entry
    # Field names exposed through the API; "url" is the hyperlinked identity field.
    fields = ["route", "actiontype", "url", "comment", "who"]
create(validated_data) staticmethod

Implement custom logic and validate creating a new route.

Source code in scram/route_manager/api/serializers.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
@staticmethod
def create(validated_data):
    """Get or create the Route and Entry for the validated data and tag the Entry with its comment.

    ``route``, ``actiontype`` and ``comment`` are popped from ``validated_data``. The Route and Entry
    are fetched if they already exist, or created otherwise.

    Returns:
        Entry: The Entry for the route/actiontype pair.
    """
    route_value = validated_data.pop("route")
    actiontype_name = validated_data.pop("actiontype")
    change_reason = validated_data.pop("comment")

    route_obj = Route.objects.get_or_create(route=route_value)[0]
    actiontype_obj = ActionType.objects.get(name=actiontype_name)
    entry_obj = Entry.objects.get_or_create(route=route_obj, actiontype=actiontype_obj)[0]

    logger.debug("Created entry with comment: %s", change_reason)
    update_change_reason(entry_obj, change_reason)
    return entry_obj
get_comment(obj) staticmethod

Provide a nicer name for change reason.

Returns:

string — The change reason that modified the Entry.

Source code in scram/route_manager/api/serializers.py
78
79
80
81
82
83
84
85
@staticmethod
def get_comment(obj):
    """Expose the Entry's change reason under the friendlier name "comment".

    Returns:
        string: Whatever ``obj.get_change_reason()`` reports.
    """
    change_reason = obj.get_change_reason()
    return change_reason

IgnoreEntrySerializer

Bases: ModelSerializer

Map the route to the right field type.

Source code in scram/route_manager/api/serializers.py
104
105
106
107
108
109
110
111
112
113
class IgnoreEntrySerializer(serializers.ModelSerializer):
    """Map the route to the right field type."""

    # Serialize ``route`` with the schema-friendly CIDR field.
    route = CustomCidrAddressField()

    class Meta:
        """Maps to the IgnoreEntry model, and specifies the fields exposed by the API."""

        model = IgnoreEntry
        fields = ["route", "comment"]
Meta

Maps to the IgnoreEntry model, and specifies the fields exposed by the API.

Source code in scram/route_manager/api/serializers.py
109
110
111
112
113
class Meta:
    """Maps to the IgnoreEntry model, and specifies the fields exposed by the API."""

    # Model the serializer is generated from.
    model = IgnoreEntry
    # Field names exposed through the API.
    fields = ["route", "comment"]

RouteSerializer

Bases: ModelSerializer

Exposes route as a CIDR field.

Source code in scram/route_manager/api/serializers.py
31
32
33
34
35
36
37
38
39
40
41
42
class RouteSerializer(serializers.ModelSerializer):
    """Exposes route as a CIDR field."""

    # Serialize ``route`` with the schema-friendly CIDR field.
    route = CustomCidrAddressField()

    class Meta:
        """Maps to the Route model, and specifies the fields exposed by the API."""

        model = Route
        fields = [
            "route",
        ]
Meta

Maps to the Route model, and specifies the fields exposed by the API.

Source code in scram/route_manager/api/serializers.py
36
37
38
39
40
41
42
class Meta:
    """Maps to the Route model, and specifies the fields exposed by the API."""

    # Model the serializer is generated from.
    model = Route
    # Only the CIDR itself is exposed.
    fields = [
        "route",
    ]

views

Views provide mappings between the underlying model and how they're listed in the API.

ActionTypeViewSet

Bases: ReadOnlyModelViewSet

Lookup ActionTypes by name when authenticated, and bind to the serializer.

Source code in scram/route_manager/api/views.py
26
27
28
29
30
31
32
33
34
35
36
@extend_schema(
    description="API endpoint for actiontypes",
    responses={200: ActionTypeSerializer},
)
class ActionTypeViewSet(viewsets.ReadOnlyModelViewSet):
    """Lookup ActionTypes by name when authenticated, and bind to the serializer."""

    queryset = ActionType.objects.all()
    # Only authenticated users may list or retrieve action types.
    permission_classes = (IsAuthenticated,)
    serializer_class = ActionTypeSerializer
    # Detail URLs identify an ActionType by ``name`` rather than primary key.
    lookup_field = "name"

ClientViewSet

Bases: ModelViewSet

Lookup Client by hostname on POSTs regardless of authentication, and bind to the serializer.

Source code in scram/route_manager/api/views.py
52
53
54
55
56
57
58
59
60
61
62
63
64
@extend_schema(
    description="API endpoint for clients",
    responses={200: ClientSerializer},
)
class ClientViewSet(viewsets.ModelViewSet):
    """Lookup Client by hostname on POSTs regardless of authentication, and bind to the serializer."""

    queryset = Client.objects.all()
    # We want to allow a client to be registered from anywhere
    permission_classes = (AllowAny,)
    serializer_class = ClientSerializer
    lookup_field = "hostname"
    # Registration only: any other HTTP method is rejected as not allowed (405).
    http_method_names = ["post"]

EntryViewSet

Bases: ModelViewSet

Lookup Entry when authenticated, and bind to the serializer.

Source code in scram/route_manager/api/views.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
@extend_schema(
    description="API endpoint for entries",
    responses={200: EntrySerializer},
)
class EntryViewSet(viewsets.ModelViewSet):
    """Lookup Entry when authenticated, and bind to the serializer."""

    queryset = Entry.objects.filter(is_active=True)
    permission_classes = (IsAuthenticated,)
    serializer_class = EntrySerializer
    # Accept any characters (including "/" and ".") in the URL lookup segment so that
    # CIDR strings, as well as integer primary keys, can be passed to find_entries.
    lookup_value_regex = ".*"
    http_method_names = ["get", "post", "head", "delete"]

    def get_permissions(self):
        """Override the permissions classes for POST method since we want to accept Entry creates from any client.

        Note: We make authorization decisions on whether to actually create the object in the perform_create method
        later.
        """
        if self.request.method == "POST":
            return [AllowAny()]
        return super().get_permissions()

    def check_client_authorization(self, actiontype):
        """Ensure that a given client is authorized to use a given actiontype.

        When the request body carries a ``uuid``, a Client with that uuid must exist, must have
        ``is_authorized`` set, and must list ``actiontype`` among its authorized action types.
        Without a ``uuid``, the requesting user must hold ``route_manager.can_add_entry``.

        Raises:
            ActiontypeNotAllowed: The client is unknown, not authorized, or not allowed this actiontype.
            PermissionDenied: No ``uuid`` was supplied and the user lacks the add-entry permission.
        """
        uuid = self.request.data.get("uuid")
        if uuid:
            authorized_actiontypes = Client.objects.filter(uuid=uuid).values_list(
                "authorized_actiontypes__name",
                flat=True,
            )
            # Check the value of the ``is_authorized`` flag, not merely that a Client row with this uuid
            # exists. Otherwise a client with is_authorized=False would be allowed through.
            client_is_authorized = Client.objects.filter(uuid=uuid, is_authorized=True).exists()
            if not client_is_authorized or actiontype not in authorized_actiontypes:
                logger.debug("Client: %s, actiontypes: %s", uuid, authorized_actiontypes)
                logger.info("%s is not allowed to add an entry to the %s list.", uuid, actiontype)
                raise ActiontypeNotAllowed
        elif not self.request.user.has_perm("route_manager.can_add_entry"):
            raise PermissionDenied

    @staticmethod
    def check_ignore_list(route):
        """Ensure that we're not trying to block something from the ignore list.

        Raises:
            IgnoredRoute: ``route`` overlaps at least one IgnoreEntry.
        """
        overlapping_ignore = IgnoreEntry.objects.filter(route__net_overlaps=route)
        if overlapping_ignore.count():
            ignore_entries = [str(ignore_entry["route"]) for ignore_entry in overlapping_ignore.values()]
            logger.info("Cannot proceed adding %s. The ignore list contains %s.", route, ignore_entries)
            raise IgnoredRoute

    def perform_create(self, serializer):
        """Create a new Entry, causing that route to receive the actiontype (i.e. block).

        The steps are:
        1. Check the prefix size, the caller's authorization, and the ignore list.
        2. Send each websocket message configured for the actiontype to the
           ``translator_<actiontype>`` channel group.
        3. Save the Entry and set its who/comment/expiration/origin fields.

        Raises:
            PrefixTooLarge: The route's prefix length is below ``V4_MINPREFIX``/``V6_MINPREFIX``.
            ActiontypeNotAllowed: An API client is not authorized for this actiontype.
            PermissionDenied: A non-client request lacks the add-entry permission.
            IgnoredRoute: The route overlaps an IgnoreEntry.
        """
        actiontype = serializer.validated_data["actiontype"]
        route = serializer.validated_data["route"]
        if self.request.user.username:
            # This is set if our request comes through the WUI path
            who = self.request.user.username
        else:
            # This is set if we pass the "who" through the json data in an API call (like from Zeek)
            who = serializer.validated_data["who"]
        comment = serializer.validated_data["comment"]
        tmp_exp = self.request.data.get("expiration", "")

        # Default to "no expiration" so a parse failure cannot leave ``expiration`` unbound below.
        expiration = None
        try:
            expiration = parse_datetime(tmp_exp)
        except (TypeError, ValueError):
            # ValueError: well-formatted but invalid datetime; TypeError: non-string input (e.g. JSON null).
            logger.warning("Could not parse expiration DateTime string: %s", tmp_exp)

        # Make sure we put in an acceptable sized prefix
        min_prefix = getattr(settings, f"V{route.version}_MINPREFIX", 0)
        if route.prefixlen < min_prefix:
            raise PrefixTooLarge

        self.check_client_authorization(actiontype)
        self.check_ignore_list(route)

        elements = WebSocketSequenceElement.objects.filter(action_type__name=actiontype).order_by("order_num")
        if not elements:
            logger.warning("No elements found for actiontype: %s", actiontype)

        for element in elements:
            msg = element.websocketmessage
            msg.msg_data[msg.msg_data_route_field] = str(route)
            # Must match a channel name defined in asgi.py
            async_to_sync(channel_layer.group_send)(
                f"translator_{actiontype}",
                {"type": msg.msg_type, "message": msg.msg_data},
            )

        serializer.save()

        entry = Entry.objects.get(route__route=route, actiontype__name=actiontype)
        if expiration:
            entry.expiration = expiration
        entry.who = who
        entry.is_active = True
        entry.comment = comment
        entry.originating_scram_instance = settings.SCRAM_HOSTNAME
        logger.info("Created entry: %s", entry)
        entry.save()

    @staticmethod
    def find_entries(arg, active_filter=None):
        """Query entries either by pk or overlapping route.

        ``arg`` is first tried as an integer primary key, then as a CIDR (non-strict, so host bits are allowed).
        A non-``None`` ``active_filter`` additionally restricts on ``is_active``.

        Raises:
            ValueError: ``arg`` is neither an integer nor a valid IP network.
            PrefixTooLarge: The CIDR's prefix length is below the configured minimum.
        """
        if not arg:
            return Entry.objects.none()

        # Is our argument an integer?
        try:
            pk = int(arg)
            query = Q(pk=pk)
        except ValueError as exc:
            # Maybe a CIDR? We want the ValueError at this point, if not.
            cidr = ipaddress.ip_network(arg, strict=False)

            min_prefix = getattr(settings, f"V{cidr.version}_MINPREFIX", 0)
            if cidr.prefixlen < min_prefix:
                raise PrefixTooLarge from exc

            query = Q(route__route__net_overlaps=cidr)

        if active_filter is not None:
            query &= Q(is_active=active_filter)

        return Entry.objects.filter(query)

    def retrieve(self, request, pk=None, **kwargs):
        """Retrieve a single route.

        Raises:
            Http404: Unless exactly one active entry matches ``pk``.
        """
        entries = self.find_entries(pk, active_filter=True)
        # TODO: What happens if we get multiple? Is that ok? I think yes, and return them all?
        if entries.count() != 1:
            raise Http404
        serializer = EntrySerializer(entries, many=True, context={"request": request})
        return Response(serializer.data)

    def destroy(self, request, pk=None, *args, **kwargs):
        """Only delete active (e.g. announced) entries.

        Always responds 204, even when nothing matched.
        """
        for entry in self.find_entries(pk, active_filter=True):
            entry.delete()

        return Response(status=status.HTTP_204_NO_CONTENT)
check_client_authorization(actiontype)

Ensure that a given client is authorized to use a given actiontype.

Source code in scram/route_manager/api/views.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def check_client_authorization(self, actiontype):
    """Ensure that a given client is authorized to use a given actiontype.

    When the request body carries a ``uuid``, a Client with that uuid must exist, must have
    ``is_authorized`` set, and must list ``actiontype`` among its authorized action types.
    Without a ``uuid``, the requesting user must hold ``route_manager.can_add_entry``.

    Raises:
        ActiontypeNotAllowed: The client is unknown, not authorized, or not allowed this actiontype.
        PermissionDenied: No ``uuid`` was supplied and the user lacks the add-entry permission.
    """
    uuid = self.request.data.get("uuid")
    if uuid:
        authorized_actiontypes = Client.objects.filter(uuid=uuid).values_list(
            "authorized_actiontypes__name",
            flat=True,
        )
        # Check the value of the ``is_authorized`` flag, not merely that a Client row with this uuid
        # exists. Otherwise a client with is_authorized=False would be allowed through.
        client_is_authorized = Client.objects.filter(uuid=uuid, is_authorized=True).exists()
        if not client_is_authorized or actiontype not in authorized_actiontypes:
            logger.debug("Client: %s, actiontypes: %s", uuid, authorized_actiontypes)
            logger.info("%s is not allowed to add an entry to the %s list.", uuid, actiontype)
            raise ActiontypeNotAllowed
    elif not self.request.user.has_perm("route_manager.can_add_entry"):
        raise PermissionDenied
check_ignore_list(route) staticmethod

Ensure that we're not trying to block something from the ignore list.

Source code in scram/route_manager/api/views.py
106
107
108
109
110
111
112
113
@staticmethod
def check_ignore_list(route):
    """Refuse to act on ``route`` if it overlaps anything on the ignore list.

    Raises:
        IgnoredRoute: At least one IgnoreEntry overlaps ``route``.
    """
    overlapping = IgnoreEntry.objects.filter(route__net_overlaps=route)
    if not overlapping.count():
        return
    blocking_entries = [str(row["route"]) for row in overlapping.values()]
    logger.info("Cannot proceed adding %s. The ignore list contains %s.", route, blocking_entries)
    raise IgnoredRoute
destroy(request, pk=None, *args, **kwargs)

Only delete active (e.g. announced) entries.

Source code in scram/route_manager/api/views.py
200
201
202
203
204
205
def destroy(self, request, pk=None, *args, **kwargs):
    """Delete every active entry matching ``pk`` (a primary key or CIDR); always answer 204."""
    active_matches = self.find_entries(pk, active_filter=True)
    for match in active_matches:
        match.delete()
    return Response(status=status.HTTP_204_NO_CONTENT)
find_entries(arg, active_filter=None) staticmethod

Query entries either by pk or overlapping route.

Source code in scram/route_manager/api/views.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
@staticmethod
def find_entries(arg, active_filter=None):
    """Find entries by integer primary key or, failing that, by overlap with a CIDR.

    Args:
        arg: A primary key or a CIDR string. A falsy value yields an empty queryset.
        active_filter: If not ``None``, also require ``is_active == active_filter``.

    Raises:
        ValueError: ``arg`` is neither an integer nor a valid IP network.
        PrefixTooLarge: The CIDR's prefix length is below the configured minimum.
    """
    if not arg:
        return Entry.objects.none()

    try:
        query = Q(pk=int(arg))
    except ValueError as exc:
        # Not an integer: treat it as a network. An invalid network lets ValueError propagate.
        network = ipaddress.ip_network(arg, strict=False)
        if network.prefixlen < getattr(settings, f"V{network.version}_MINPREFIX", 0):
            raise PrefixTooLarge from exc
        query = Q(route__route__net_overlaps=network)

    if active_filter is not None:
        query &= Q(is_active=active_filter)
    return Entry.objects.filter(query)
get_permissions()

Override the permissions classes for POST method since we want to accept Entry creates from any client.

Note: We make authorization decisions on whether to actually create the object in the perform_create method later.

Source code in scram/route_manager/api/views.py
80
81
82
83
84
85
86
87
88
def get_permissions(self):
    """Let anyone POST (create) entries; other methods use the class's configured permissions.

    Note: Whether a POST may actually create an Entry is decided later, in perform_create.
    """
    is_create = self.request.method == "POST"
    return [AllowAny()] if is_create else super().get_permissions()
perform_create(serializer)

Create a new Entry, causing that route to receive the actiontype (i.e. block).

Source code in scram/route_manager/api/views.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def perform_create(self, serializer):
    """Create a new Entry, causing that route to receive the actiontype (i.e. block).

    The steps are:
    1. Check the prefix size, the caller's authorization, and the ignore list.
    2. Send each websocket message configured for the actiontype to the
       ``translator_<actiontype>`` channel group.
    3. Save the Entry and set its who/comment/expiration/origin fields.

    Raises:
        PrefixTooLarge: The route's prefix length is below ``V4_MINPREFIX``/``V6_MINPREFIX``.
        ActiontypeNotAllowed: An API client is not authorized for this actiontype.
        PermissionDenied: A non-client request lacks the add-entry permission.
        IgnoredRoute: The route overlaps an IgnoreEntry.
    """
    actiontype = serializer.validated_data["actiontype"]
    route = serializer.validated_data["route"]
    if self.request.user.username:
        # This is set if our request comes through the WUI path
        who = self.request.user.username
    else:
        # This is set if we pass the "who" through the json data in an API call (like from Zeek)
        who = serializer.validated_data["who"]
    comment = serializer.validated_data["comment"]
    tmp_exp = self.request.data.get("expiration", "")

    # Default to "no expiration" so a parse failure cannot leave ``expiration`` unbound below.
    expiration = None
    try:
        expiration = parse_datetime(tmp_exp)
    except (TypeError, ValueError):
        # ValueError: well-formatted but invalid datetime; TypeError: non-string input (e.g. JSON null).
        logger.warning("Could not parse expiration DateTime string: %s", tmp_exp)

    # Make sure we put in an acceptable sized prefix
    min_prefix = getattr(settings, f"V{route.version}_MINPREFIX", 0)
    if route.prefixlen < min_prefix:
        raise PrefixTooLarge

    self.check_client_authorization(actiontype)
    self.check_ignore_list(route)

    elements = WebSocketSequenceElement.objects.filter(action_type__name=actiontype).order_by("order_num")
    if not elements:
        logger.warning("No elements found for actiontype: %s", actiontype)

    for element in elements:
        msg = element.websocketmessage
        msg.msg_data[msg.msg_data_route_field] = str(route)
        # Must match a channel name defined in asgi.py
        async_to_sync(channel_layer.group_send)(
            f"translator_{actiontype}",
            {"type": msg.msg_type, "message": msg.msg_data},
        )

    serializer.save()

    entry = Entry.objects.get(route__route=route, actiontype__name=actiontype)
    if expiration:
        entry.expiration = expiration
    entry.who = who
    entry.is_active = True
    entry.comment = comment
    entry.originating_scram_instance = settings.SCRAM_HOSTNAME
    logger.info("Created entry: %s", entry)
    entry.save()
retrieve(request, pk=None, **kwargs)

Retrieve a single route.

Source code in scram/route_manager/api/views.py
191
192
193
194
195
196
197
198
def retrieve(self, request, pk=None, **kwargs):
    """Return the single active entry matching ``pk`` (primary key or CIDR), serialized as a list.

    Raises:
        Http404: Zero or several active entries match.
    """
    matches = self.find_entries(pk, active_filter=True)
    # TODO: What happens if we get multiple? Is that ok? I think yes, and return them all?
    if matches.count() == 1:
        payload = EntrySerializer(matches, many=True, context={"request": request}).data
        return Response(payload)
    raise Http404

IgnoreEntryViewSet

Bases: ModelViewSet

Lookup IgnoreEntries by route when authenticated, and bind to the serializer.

Source code in scram/route_manager/api/views.py
39
40
41
42
43
44
45
46
47
48
49
@extend_schema(
    description="API endpoint for ignore entries",
    responses={200: IgnoreEntrySerializer},
)
class IgnoreEntryViewSet(viewsets.ModelViewSet):
    """Lookup IgnoreEntries by route when authenticated, and bind to the serializer."""

    queryset = IgnoreEntry.objects.all()
    # Every method (including create/delete) requires an authenticated user.
    permission_classes = (IsAuthenticated,)
    serializer_class = IgnoreEntrySerializer
    # Detail URLs identify an IgnoreEntry by its ``route`` rather than primary key.
    lookup_field = "route"

apps

Register ourselves with Django.

RouteManagerConfig

Bases: AppConfig

Define the name of the module that's the main app.

Source code in scram/route_manager/apps.py
6
7
8
9
class RouteManagerConfig(AppConfig):
    """Define the name of the module that's the main app."""

    # Full dotted Python path of the app package.
    name = "scram.route_manager"

authentication_backends

Define one or more custom auth backends.

ESnetAuthBackend

Bases: OIDCAuthenticationBackend

Extend the OIDC backend with a custom permission model.

Source code in scram/route_manager/authentication_backends.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class ESnetAuthBackend(OIDCAuthenticationBackend):
    """OIDC authentication backend that maps the ``groups`` claim onto local groups and admin flags."""

    @staticmethod
    def update_groups(user, claims):
        """Replace the user's groups and staff/superuser flags based on the claimed groups.

        Membership in any ``SCRAM_DENIED_GROUPS`` group yields no groups and no admin rights.
        Otherwise the mapping is:
        - ``SCRAM_ADMIN_GROUPS`` grants staff and superuser.
        - ``SCRAM_READWRITE_GROUPS`` adds the ``readwrite`` Django group.
        - ``SCRAM_READONLY_GROUPS`` adds the ``readonly`` Django group.
        """
        claimed_groups = claims.get("groups", [])
        granted_groups = []
        is_admin = False

        # A denied user gets nothing, regardless of any other memberships.
        if not groups_overlap(claimed_groups, settings.SCRAM_DENIED_GROUPS):
            is_admin = groups_overlap(claimed_groups, settings.SCRAM_ADMIN_GROUPS)
            if groups_overlap(claimed_groups, settings.SCRAM_READWRITE_GROUPS):
                granted_groups.append(Group.objects.get(name="readwrite"))
            if groups_overlap(claimed_groups, settings.SCRAM_READONLY_GROUPS):
                granted_groups.append(Group.objects.get(name="readonly"))

        user.groups.set(granted_groups)
        user.is_staff = is_admin
        user.is_superuser = is_admin
        user.save()

    def create_user(self, claims):
        """Create the user via the parent backend, then apply the claims to it."""
        new_user = super().create_user(claims)
        return self.update_user(new_user, claims)

    def update_user(self, user, claims):
        """Copy display name and username from the claims, refresh groups when present, and save.

        Returns:
            The same ``user`` object, after saving.
        """
        user.name = " ".join((claims.get("given_name", ""), claims.get("family_name", "")))
        user.username = claims.get("preferred_username", "")
        if claims.get("groups"):
            self.update_groups(user, claims)
        user.save()
        return user

create_user(claims)

Wrap the superclass's user creation.

Source code in scram/route_manager/authentication_backends.py
40
41
42
43
def create_user(self, claims):
    """Let the parent backend create the user, then apply the claims via update_user."""
    new_user = super().create_user(claims)
    return self.update_user(new_user, claims)

update_groups(user, claims) staticmethod

Set the user's group(s) to whatever is in the claims.

Source code in scram/route_manager/authentication_backends.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
@staticmethod
def update_groups(user, claims):
    """Replace the user's groups and staff/superuser flags based on the claimed groups.

    Membership in any ``SCRAM_DENIED_GROUPS`` group yields no groups and no admin rights.
    Otherwise the mapping is:
    - ``SCRAM_ADMIN_GROUPS`` grants staff and superuser.
    - ``SCRAM_READWRITE_GROUPS`` adds the ``readwrite`` Django group.
    - ``SCRAM_READONLY_GROUPS`` adds the ``readonly`` Django group.
    """
    claimed_groups = claims.get("groups", [])
    granted_groups = []
    is_admin = False

    # A denied user gets nothing, regardless of any other memberships.
    if not groups_overlap(claimed_groups, settings.SCRAM_DENIED_GROUPS):
        is_admin = groups_overlap(claimed_groups, settings.SCRAM_ADMIN_GROUPS)
        if groups_overlap(claimed_groups, settings.SCRAM_READWRITE_GROUPS):
            granted_groups.append(Group.objects.get(name="readwrite"))
        if groups_overlap(claimed_groups, settings.SCRAM_READONLY_GROUPS):
            granted_groups.append(Group.objects.get(name="readonly"))

    user.groups.set(granted_groups)
    user.is_staff = is_admin
    user.is_superuser = is_admin
    user.save()

update_user(user, claims)

Determine the user name from the claims and update said user's groups.

Source code in scram/route_manager/authentication_backends.py
45
46
47
48
49
50
51
52
53
54
def update_user(self, user, claims):
    """Copy display name and username from the claims, refresh groups when present, and save.

    Returns:
        The same ``user`` object, after saving.
    """
    user.name = " ".join((claims.get("given_name", ""), claims.get("family_name", "")))
    user.username = claims.get("preferred_username", "")
    if claims.get("groups"):
        self.update_groups(user, claims)
    user.save()
    return user

groups_overlap(a, b)

Helper function to see if a and b have any overlap.

Returns:

bool — True if there's any overlap between a and b.

Source code in scram/route_manager/authentication_backends.py
 8
 9
10
11
12
13
14
def groups_overlap(a, b):
    """Report whether the two group collections share at least one member.

    Returns:
        bool: True if any element of ``a`` also appears in ``b``.
    """
    shared = set(a).intersection(b)
    return bool(shared)

context_processors

Define custom functions that take a request and add to the context before template rendering.

active_count(request)

Grab the active count of blocks.

Returns:

dict — The active and total counts of block entries (empty on admin pages).

Source code in scram/route_manager/context_processors.py
20
21
22
23
24
25
26
27
28
29
30
def active_count(request):
    """Add active/total block-entry counts to the template context.

    Admin pages (any path containing "admin") get an empty dict instead.

    Returns:
        dict: ``active_block_entries`` and ``total_block_entries``, or {} on admin paths.
    """
    if "admin" in request.META["PATH_INFO"]:
        return {}
    active = Entry.objects.filter(is_active=True).count()
    total = Entry.objects.all().count()
    return {"active_block_entries": active, "total_block_entries": total}

login_logout(request)

Pass through the relevant URLs from the settings.

Returns:

Name Type Description
dict

login and logout URLs

Source code in scram/route_manager/context_processors.py
 9
10
11
12
13
14
15
16
17
def login_logout(request):
    """Expose the resolved login and logout URLs from settings to templates.

    Returns:
       dict: ``login`` and ``logout`` keys mapped to the reversed URLs.
    """
    return {
        "login": reverse(settings.LOGIN_URL),
        "logout": reverse(settings.LOGOUT_URL),
    }

models

Define the models used in the route_manager app.

ActionType

Bases: Model

Define a type of action that can be done with a given route. e.g. Block, shunt, redirect, etc.

Source code in scram/route_manager/models.py
33
34
35
36
37
38
39
40
41
42
43
44
class ActionType(models.Model):
    """An action that can be applied to a route, e.g. block, shunt, or redirect."""

    name = models.CharField(help_text="One-word description of the action", max_length=30)
    available = models.BooleanField(help_text="Is this a valid choice for new entries?", default=True)
    history = HistoricalRecords()

    def __str__(self):
        """Return the action name, suffixed with "(Inactive)" when it is not available."""
        return self.name if self.available else f"{self.name} (Inactive)"

__str__()

Display clearly whether the action is currently available.

Source code in scram/route_manager/models.py
40
41
42
43
44
def __str__(self):
    """Display clearly whether the action is currently available."""
    if not self.available:
        return f"{self.name} (Inactive)"
    return self.name

Client

Bases: Model

Any client that would like to hit the API to add entries (e.g. Zeek).

Source code in scram/route_manager/models.py
169
170
171
172
173
174
175
176
177
178
179
180
class Client(models.Model):
    """A client (for example Zeek) that calls the API to add entries."""

    hostname = models.CharField(max_length=50, unique=True)
    uuid = models.UUIDField()

    is_authorized = models.BooleanField(null=True, blank=True, default=False)
    authorized_actiontypes = models.ManyToManyField(ActionType)

    def __str__(self):
        """Represent the client by its hostname alone."""
        host = self.hostname
        return str(host)

__str__()

Only display the hostname.

Source code in scram/route_manager/models.py
178
179
180
def __str__(self):
    """Only display the hostname."""
    return str(self.hostname)

Entry

Bases: Model

An instance of an action taken on a route.

Source code in scram/route_manager/models.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
class Entry(models.Model):
    """An instance of an action taken on a route."""

    route = models.ForeignKey("Route", on_delete=models.PROTECT)
    actiontype = models.ForeignKey("ActionType", on_delete=models.PROTECT)
    comment = models.TextField(blank=True, default="")
    is_active = models.BooleanField(default=True)
    # TODO: fix name if this works
    history = HistoricalRecords()
    when = models.DateTimeField(auto_now_add=True)
    who = models.CharField("Username", default="Unknown", max_length=30)
    originating_scram_instance = models.CharField(default="scram_hostname_not_set", max_length=255)
    # Far-future default means "never expires" unless an expiration is set explicitly.
    expiration = models.DateTimeField(default=datetime.datetime(9999, 12, 31, 0, 0, tzinfo=datetime.UTC))
    expiration_reason = models.CharField(
        help_text="Optional reason for the expiration",
        max_length=200,
        blank=True,
        default="",
    )

    class Meta:
        """Ensure that multiple routes can be added as long as they have different action types."""

        unique_together = ["route", "actiontype"]
        verbose_name_plural = "Entries"

    def __str__(self):
        """Summarize the most important fields to something easily readable."""
        desc = f"{self.route} ({self.actiontype}) from: {self.originating_scram_instance}"
        if not self.is_active:
            desc += " (inactive)"
        return desc

    def delete(self, *args, **kwargs):
        """Set inactive instead of deleting, as we want to ensure a history of entries.

        Already-inactive entries are left alone so no duplicate removal message is
        sent. Extra arguments are accepted for signature compatibility and ignored.
        """
        if not self.is_active:
            # We've already expired this route, don't send another message
            return
        # We don't actually delete records; we set them to inactive and then tell the translator to remove them
        logger.info("Deactivating %s", self.route)
        self.is_active = False
        self.save()

        # Unblock it
        async_to_sync(channel_layer.group_send)(
            f"translator_{self.actiontype}",
            {
                "type": "translator_remove",
                "message": {"route": str(self.route)},
            },
        )

    def get_change_reason(self):
        """Return the change reason stored on this entry's most recent history row.

        Returns:
           str | None: The most recent change reason, or None when no history row
           exists for this entry (previously this raised AttributeError).
        """
        hist_mgr = getattr(self, self._meta.simple_history_manager_attribute)
        latest = hist_mgr.order_by("-history_date").first()
        if latest is None:
            return None
        return latest.history_change_reason

Meta

Ensure that multiple routes can be added as long as they have different action types.

Source code in scram/route_manager/models.py
110
111
112
113
114
class Meta:
    """Allow the same route more than once, provided each row uses a different action type."""

    verbose_name_plural = "Entries"
    unique_together = ("route", "actiontype")

__str__()

Summarize the most important fields to something easily readable.

Source code in scram/route_manager/models.py
116
117
118
119
120
121
def __str__(self):
    """Summarize the most important fields to something easily readable."""
    desc = f"{self.route} ({self.actiontype}) from: {self.originating_scram_instance}"
    if not self.is_active:
        desc += " (inactive)"
    return desc

delete(*args, **kwargs)

Set inactive instead of deleting, as we want to ensure a history of entries.

Source code in scram/route_manager/models.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def delete(self, *args, **kwargs):
    """Soft-delete: mark the entry inactive and ask the translators to withdraw its route.

    The database row is kept so its history survives. Calling this on an entry that
    is already inactive does nothing, so no duplicate removal message is sent.
    Extra positional/keyword arguments are accepted and ignored.
    """
    if self.is_active:
        logger.info("Deactivating %s", self.route)
        self.is_active = False
        self.save()

        # Notify the translator group for this action type so it unblocks the route.
        send_to_group = async_to_sync(channel_layer.group_send)
        send_to_group(
            f"translator_{self.actiontype}",
            {
                "type": "translator_remove",
                "message": {"route": str(self.route)},
            },
        )

get_change_reason()

Traverse some complex relationships to determine the most recent change reason.

Returns:

Name Type Description
str

The most recent change reason

Source code in scram/route_manager/models.py
142
143
144
145
146
147
148
149
def get_change_reason(self):
    """Return the change reason stored on this entry's most recent history row.

    Returns:
       str | None: The most recent change reason, or None when no history row
       exists for this entry (previously this raised AttributeError).
    """
    hist_mgr = getattr(self, self._meta.simple_history_manager_attribute)
    latest = hist_mgr.order_by("-history_date").first()
    if latest is None:
        return None
    return latest.history_change_reason

IgnoreEntry

Bases: Model

Define CIDRs you NEVER want to block (i.e. the "don't shoot yourself in the foot" list).

Source code in scram/route_manager/models.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
class IgnoreEntry(models.Model):
    """A CIDR that must never be blocked (the "don't shoot yourself in the foot" list)."""

    route = CidrAddressField(unique=True)
    comment = models.CharField(max_length=100)
    history = HistoricalRecords()

    class Meta:
        """Give the model a grammatically correct plural display name."""

        verbose_name_plural = "Ignored Entries"

    def __str__(self):
        """Represent the ignore entry by its CIDR alone."""
        cidr = self.route
        return str(cidr)

Meta

Ensure the plural is grammatically correct.

Source code in scram/route_manager/models.py
159
160
161
162
class Meta:
    """Give the model a grammatically correct plural display name."""

    verbose_name_plural = "Ignored Entries"

__str__()

Only display the route.

Source code in scram/route_manager/models.py
164
165
166
def __str__(self):
    """Only display the route."""
    return str(self.route)

Route

Bases: Model

Define a route as a CIDR route and a UUID.

Source code in scram/route_manager/models.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Route(models.Model):
    """A CIDR route paired with a UUID used to identify it on the API side."""

    route = CidrAddressField(unique=True)
    uuid = models.UUIDField(db_index=True, default=uuid_lib.uuid4, editable=False)

    def __str__(self):
        """Show only the route; the UUID is deliberately left out."""
        cidr = self.route
        return str(cidr)

    @staticmethod
    def get_absolute_url():
        """Ensure we use UUID on the API side instead."""
        # NOTE(review): reverse() is given an empty view name here, which does not
        # look like it can resolve to any URL pattern — confirm the intended name.
        view_name = ""
        return reverse(view_name)

__str__()

Don't display the UUID, only the route.

Source code in scram/route_manager/models.py
23
24
25
def __str__(self):
    """Don't display the UUID, only the route."""
    return str(self.route)

get_absolute_url() staticmethod

Ensure we use UUID on the API side instead.

Source code in scram/route_manager/models.py
27
28
29
30
@staticmethod
def get_absolute_url():
    """Ensure we use UUID on the API side instead."""
    # NOTE(review): reverse() is given an empty view name here, which does not
    # look like it can resolve to any URL pattern — confirm the intended name.
    view_name = ""
    return reverse(view_name)

WebSocketMessage

Bases: Model

Define a single message sent to downstream translators via WebSocket.

Source code in scram/route_manager/models.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class WebSocketMessage(models.Model):
    """One message that is sent to downstream translators over a WebSocket."""

    msg_type = models.CharField("The type of the message", max_length=50)
    msg_data = models.JSONField("The JSON payload. See also msg_data_route_field.", default=dict)
    msg_data_route_field = models.CharField(
        "The key in the JSON payload whose value will contain the route being acted on.",
        default="route",
        max_length=25,
    )

    def __str__(self):
        """Spell out the message type, its payload, and which payload key holds the route."""
        kind, payload, route_key = self.msg_type, self.msg_data, self.msg_data_route_field
        return f"{kind}: {payload} with the route in key {route_key}"

__str__()

Display clearly what the fields are used for.

Source code in scram/route_manager/models.py
58
59
60
def __str__(self):
    """Display clearly what the fields are used for."""
    return f"{self.msg_type}: {self.msg_data} with the route in key {self.msg_data_route_field}"

WebSocketSequenceElement

Bases: Model

In a sequence of messages, define a single element.

Source code in scram/route_manager/models.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class WebSocketSequenceElement(models.Model):
    """A single step within an ordered sequence of WebSocket messages."""

    websocketmessage = models.ForeignKey("WebSocketMessage", on_delete=models.CASCADE)
    order_num = models.SmallIntegerField(
        "Sequences are sent from the smallest order_num to the highest. "
        "Messages with the same order_num could be sent in any order",
        default=0,
    )

    VERB_CHOICES = [
        ("A", "Add"),
        ("C", "Check"),
        ("R", "Remove"),
    ]
    verb = models.CharField(max_length=1, choices=VERB_CHOICES)

    action_type = models.ForeignKey("ActionType", on_delete=models.CASCADE)

    def __str__(self):
        """Give a short one-line summary: message, order, verb code and action type."""
        head = f"{self.websocketmessage} as order={self.order_num} for "
        tail = f"{self.verb} actions on actiontype={self.action_type}"
        return head + tail

__str__()

Summarize the fields into something short and readable.

Source code in scram/route_manager/models.py
82
83
84
85
86
87
def __str__(self):
    """Summarize the fields into something short and readable."""
    return (
        f"{self.websocketmessage} as order={self.order_num} for "
        f"{self.verb} actions on actiontype={self.action_type}"
    )

tests

Define tests executed by pytest.

functional_tests

Use the Django web client to perform end-to-end, WebUI-based testing.

HomePageTest

Bases: TestCase

Ensure the home page works.

Source code in scram/route_manager/tests/functional_tests.py
6
7
class HomePageTest(unittest.TestCase):
    """Placeholder suite for checking that the home page works."""

test_admin

Test the WhoFilter functionality of our admin site.

WhoFilterTest

Bases: TestCase

Test that the WhoFilter only shows users who have made entries.

Source code in scram/route_manager/tests/test_admin.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class WhoFilterTest(TestCase):
    """Check that WhoFilter offers only the usernames that actually own entries."""

    def setUp(self):
        """Create one action type plus two entries owned by different users."""
        self.atype = ActionType.objects.create(name="Block")
        first_route = Route.objects.create(route="192.168.1.1")
        second_route = Route.objects.create(route="192.168.1.2")

        self.entry1 = Entry.objects.create(route=first_route, actiontype=self.atype, who="admin")
        self.entry2 = Entry.objects.create(route=second_route, actiontype=self.atype, who="user1")

    def test_who_filter_lookups(self):
        """lookups() should list exactly the two users who created entries."""
        who_filter = WhoFilter(request=None, params={}, model=Entry, model_admin=EntryAdmin)

        choices = who_filter.lookups(MagicMock(), MagicMock(spec=EntryAdmin))

        for expected in (("admin", "admin"), ("user1", "user1")):
            self.assertIn(expected, choices)
        self.assertEqual(len(choices), 2)  # nobody without entries may be offered

    def test_who_filter_queryset_with_value(self):
        """Selecting a username narrows the queryset to that user's entries."""
        who_filter = WhoFilter(request=None, params={"who": "admin"}, model=Entry, model_admin=EntryAdmin)

        narrowed = who_filter.queryset(None, Entry.objects.all())

        self.assertEqual(narrowed.count(), 1)
        self.assertEqual(narrowed.first(), self.entry1)
setUp()

Set up the test environment.

Source code in scram/route_manager/tests/test_admin.py
14
15
16
17
18
19
20
21
def setUp(self):
    """Create one action type plus two entries owned by different users."""
    self.atype = ActionType.objects.create(name="Block")
    first_route = Route.objects.create(route="192.168.1.1")
    second_route = Route.objects.create(route="192.168.1.2")

    self.entry1 = Entry.objects.create(route=first_route, actiontype=self.atype, who="admin")
    self.entry2 = Entry.objects.create(route=second_route, actiontype=self.atype, who="user1")
test_who_filter_lookups()

Test that the WhoFilter returns the correct users who have made entries.

Source code in scram/route_manager/tests/test_admin.py
23
24
25
26
27
28
29
30
31
32
33
34
def test_who_filter_lookups(self):
    """lookups() should list exactly the two users who created entries."""
    who_filter = WhoFilter(request=None, params={}, model=Entry, model_admin=EntryAdmin)

    choices = who_filter.lookups(MagicMock(), MagicMock(spec=EntryAdmin))

    for expected in (("admin", "admin"), ("user1", "user1")):
        self.assertIn(expected, choices)
    self.assertEqual(len(choices), 2)  # nobody without entries may be offered
test_who_filter_queryset_with_value()

Test that the queryset is filtered correctly when a user is selected.

Source code in scram/route_manager/tests/test_admin.py
36
37
38
39
40
41
42
43
44
def test_who_filter_queryset_with_value(self):
    """Selecting a username narrows the queryset to that user's entries."""
    who_filter = WhoFilter(request=None, params={"who": "admin"}, model=Entry, model_admin=EntryAdmin)

    narrowed = who_filter.queryset(None, Entry.objects.all())

    self.assertEqual(narrowed.count(), 1)
    self.assertEqual(narrowed.first(), self.entry1)

test_api

Use pytest to unit test the API.

TestAddRemoveIP

Bases: APITestCase

Ensure that we can block IPs, and that duplicate blocks don't generate an error.

Source code in scram/route_manager/tests/test_api.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class TestAddRemoveIP(APITestCase):
    """Block IPv4/IPv6 addresses via the API and confirm that repeat blocks are accepted."""

    def setUp(self):
        """Log in as a superuser and register an authorized API client."""
        self.url = reverse("api:v1:entry-list")
        self.superuser = get_user_model().objects.create_superuser("admin", "admin@es.net", "admintestpassword")
        self.client.login(username="admin", password="admintestpassword")
        self.authorized_client = Client.objects.create(
            hostname="authorized_client.es.net",
            uuid="0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
            is_authorized=True,
        )
        self.authorized_client.authorized_actiontypes.set([1])

    def _post_block(self, route):
        """POST a block request for *route* on behalf of the authorized client."""
        payload = {
            "route": route,
            "comment": "test",
            "uuid": "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
        }
        return self.client.post(self.url, payload, format="json")

    def test_block_ipv4(self):
        """A single IPv4 block request is created."""
        self.assertEqual(self._post_block("192.0.2.4").status_code, status.HTTP_201_CREATED)

    def test_block_duplicate_ipv4(self):
        """Blocking the same IPv4 address twice still returns 201 the second time."""
        self._post_block("192.0.2.4")
        response = self._post_block("192.0.2.4")
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

    def test_block_ipv6(self):
        """A single IPv6 block request is created."""
        self.assertEqual(self._post_block("1::").status_code, status.HTTP_201_CREATED)

    def test_block_duplicate_ipv6(self):
        """Blocking the same IPv6 address twice still returns 201 the second time."""
        self._post_block("1::")
        response = self._post_block("1::")
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
setUp()

Set up the environment for our tests.

Source code in scram/route_manager/tests/test_api.py
14
15
16
17
18
19
20
21
22
23
24
def setUp(self):
    """Log in as a superuser and register an authorized API client."""
    username, password = "admin", "admintestpassword"
    self.url = reverse("api:v1:entry-list")
    self.superuser = get_user_model().objects.create_superuser(username, "admin@es.net", password)
    self.client.login(username=username, password=password)
    self.authorized_client = Client.objects.create(
        hostname="authorized_client.es.net",
        uuid="0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
        is_authorized=True,
    )
    self.authorized_client.authorized_actiontypes.set([1])
test_block_duplicate_ipv4()

Block an existing v4 IP and ensure we don't get an error.

Source code in scram/route_manager/tests/test_api.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def test_block_duplicate_ipv4(self):
    """Blocking the same IPv4 address twice still returns 201 the second time."""
    payload = {
        "route": "192.0.2.4",
        "comment": "test",
        "uuid": "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
    }
    for _attempt in range(2):
        response = self.client.post(self.url, payload, format="json")
    self.assertEqual(response.status_code, status.HTTP_201_CREATED)
test_block_duplicate_ipv6()

Block an existing v6 IP and ensure we don't get an error.

Source code in scram/route_manager/tests/test_api.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def test_block_duplicate_ipv6(self):
    """Blocking the same IPv6 address twice still returns 201 the second time."""
    payload = {
        "route": "1::",
        "comment": "test",
        "uuid": "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
    }
    for _attempt in range(2):
        response = self.client.post(self.url, payload, format="json")
    self.assertEqual(response.status_code, status.HTTP_201_CREATED)
test_block_ipv4()

Block a v4 IP.

Source code in scram/route_manager/tests/test_api.py
26
27
28
29
30
31
32
33
34
35
36
37
def test_block_ipv4(self):
    """A single IPv4 block request is created."""
    payload = {
        "route": "192.0.2.4",
        "comment": "test",
        "uuid": "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
    }
    created = self.client.post(self.url, payload, format="json")
    self.assertEqual(created.status_code, status.HTTP_201_CREATED)
test_block_ipv6()

Block a v6 IP.

Source code in scram/route_manager/tests/test_api.py
61
62
63
64
65
66
67
68
69
70
71
72
def test_block_ipv6(self):
    """A single IPv6 block request is created."""
    payload = {
        "route": "1::",
        "comment": "test",
        "uuid": "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
    }
    created = self.client.post(self.url, payload, format="json")
    self.assertEqual(created.status_code, status.HTTP_201_CREATED)

TestUnauthenticatedAccess

Bases: APITestCase

Ensure that an unauthenticated client can't do anything.

Source code in scram/route_manager/tests/test_api.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class TestUnauthenticatedAccess(APITestCase):
    """Ensure that an unauthenticated client can't do anything."""

    def setUp(self):
        """Resolve the Entry and IgnoreEntry list endpoints used by the tests."""
        self.entry_url = reverse("api:v1:entry-list")
        self.ignore_url = reverse("api:v1:ignoreentry-list")

    def test_unauthenticated_users_have_no_create_access(self):
        """Anonymous POSTs to the Entry endpoint are rejected with 403."""
        payload = {
            "route": "192.0.2.4",
            "comment": "test",
            "uuid": "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
            "who": "person",
        }
        rejected = self.client.post(self.entry_url, payload, format="json")
        self.assertEqual(rejected.status_code, status.HTTP_403_FORBIDDEN)

    def test_unauthenticated_users_have_no_ignore_create_access(self):
        """Anonymous POSTs to the IgnoreEntry endpoint are rejected with 403."""
        rejected = self.client.post(self.ignore_url, {"route": "192.0.2.4"}, format="json")
        self.assertEqual(rejected.status_code, status.HTTP_403_FORBIDDEN)

    def test_unauthenticated_users_have_no_list_access(self):
        """Anonymous GETs of the Entry list are rejected with 403."""
        rejected = self.client.get(self.entry_url, format="json")
        self.assertEqual(rejected.status_code, status.HTTP_403_FORBIDDEN)
setUp()

Define some helper variables.

Source code in scram/route_manager/tests/test_api.py
100
101
102
103
def setUp(self):
    """Resolve the Entry and IgnoreEntry list endpoints used by the tests."""
    entry_url = reverse("api:v1:entry-list")
    ignore_url = reverse("api:v1:ignoreentry-list")
    self.entry_url, self.ignore_url = entry_url, ignore_url
test_unauthenticated_users_have_no_create_access()

Ensure an unauthenticated client can't add an Entry.

Source code in scram/route_manager/tests/test_api.py
105
106
107
108
109
110
111
112
113
114
115
116
117
def test_unauthenticated_users_have_no_create_access(self):
    """Anonymous POSTs to the Entry endpoint are rejected with 403."""
    payload = {
        "route": "192.0.2.4",
        "comment": "test",
        "uuid": "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
        "who": "person",
    }
    rejected = self.client.post(self.entry_url, payload, format="json")
    self.assertEqual(rejected.status_code, status.HTTP_403_FORBIDDEN)
test_unauthenticated_users_have_no_ignore_create_access()

Ensure an unauthenticated client can't add an IgnoreEntry.

Source code in scram/route_manager/tests/test_api.py
119
120
121
122
def test_unauthenticated_users_have_no_ignore_create_access(self):
    """Anonymous POSTs to the IgnoreEntry endpoint are rejected with 403."""
    rejected = self.client.post(self.ignore_url, {"route": "192.0.2.4"}, format="json")
    self.assertEqual(rejected.status_code, status.HTTP_403_FORBIDDEN)
test_unauthenticated_users_have_no_list_access()

Ensure an unauthenticated client can't list Entries.

Source code in scram/route_manager/tests/test_api.py
124
125
126
127
def test_unauthenticated_users_have_no_list_access(self):
    """Anonymous GETs of the Entry list are rejected with 403."""
    rejected = self.client.get(self.entry_url, format="json")
    self.assertEqual(rejected.status_code, status.HTTP_403_FORBIDDEN)

test_authorization

Define tests for authorization and permissions.

AuthzTest

Bases: TestCase

Define tests using the built-in authentication.

Source code in scram/route_manager/tests/test_authorization.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
class AuthzTest(TestCase):
    """Define tests using the built-in authentication."""

    def setUp(self):
        """Define several users for our tests."""
        self.client = Client()
        self.unauthorized_user = User.objects.create(username="unauthorized")

        self.readonly_group = Group.objects.get(name="readonly")
        self.readonly_user = User.objects.create(username="readonly")
        self.readonly_user.groups.set([self.readonly_group])
        self.readonly_user.save()

        self.readwrite_group = Group.objects.get(name="readwrite")
        self.readwrite_user = User.objects.create(username="readwrite")
        self.readwrite_user.groups.set([self.readwrite_group])
        self.readwrite_user.save()

        self.admin_user = User.objects.create(username="admin", is_staff=True, is_superuser=True)

        # None stands for an anonymous (not logged-in) client.
        self.write_blocked_users = [None, self.unauthorized_user, self.readonly_user]
        self.write_allowed_users = [self.readwrite_user, self.admin_user]

        self.detail_blocked_users = [None, self.unauthorized_user]
        self.detail_allowed_users = [
            self.readonly_user,
            self.readwrite_user,
            self.admin_user,
        ]

        self.authorized_client = ClientModel.objects.create(
            hostname="authorized_client.es.net",
            uuid="0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
            is_authorized=True,
        )
        self.authorized_client.authorized_actiontypes.set([1])

        self.unauthorized_client = ClientModel.objects.create(
            hostname="unauthorized_client.es.net",
            uuid="91e134a5-77cf-4560-9797-6bbdbffde9f8",
        )

    def _login_as(self, user):
        """Authenticate the test client as *user*, or log out when *user* is None."""
        # Logging out explicitly stops a previous iteration's session from leaking
        # into an "anonymous" iteration when None is not first in a user list.
        if user is None:
            self.client.logout()
        else:
            self.client.force_login(user)

    def create_entry(self):
        """Ensure the admin user can create an Entry."""
        self.client.force_login(self.admin_user)
        self.client.post(
            reverse("route_manager:add"),
            {
                "route": "192.0.2.199/32",
                "actiontype": "block",
                "comment": "create entry",
                "uuid": "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
            },
        )
        self.client.logout()
        return Entry.objects.latest("id").id

    def test_unauthorized_add_entry(self):
        """Unauthorized users should not be able to add an Entry."""
        for user in self.write_blocked_users:
            self._login_as(user)
            response = self.client.post(
                reverse("route_manager:add"),
                {
                    "route": "192.0.2.4/32",
                    "actiontype": "block",
                    "uuid": "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
                },
            )
            # NOTE(review): the authorized test also expects 302, so the status code
            # alone does not distinguish allowed from denied — consider checking the
            # redirect target or whether an Entry was created.
            self.assertEqual(response.status_code, 302)

    def test_authorized_add_entry(self):
        """Test authorized users with various permissions to ensure they can add an Entry."""
        for user in self.write_allowed_users:
            self._login_as(user)
            response = self.client.post(
                reverse("route_manager:add"),
                {
                    "route": "192.0.2.4/32",
                    "actiontype": "block",
                    "uuid": "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
                },
            )
            self.assertEqual(response.status_code, 302)

    def test_unauthorized_detail_view(self):
        """Ensure that unauthorized users can't view the blocked IPs."""
        pk = self.create_entry()

        for user in self.detail_blocked_users:
            self._login_as(user)
            response = self.client.get(reverse("route_manager:detail", kwargs={"pk": pk}))
            self.assertIn(response.status_code, [302, 403], msg=f"username={user}")

    def test_authorized_detail_view(self):
        """Test authorized users with various permissions to ensure they can view block details."""
        pk = self.create_entry()

        for user in self.detail_allowed_users:
            self._login_as(user)
            response = self.client.get(reverse("route_manager:detail", kwargs={"pk": pk}))
            self.assertEqual(response.status_code, 200, msg=f"username={user}")

    def test_unauthorized_after_group_removal(self):
        """The user has r/w access, then when we remove them from the r/w group, they no longer do."""
        test_user = User.objects.create(username="tmp_readwrite")
        test_user.groups.set([self.readwrite_group])
        test_user.save()

        self.client.force_login(test_user)
        response = self.client.post(reverse("route_manager:add"), {"route": "192.0.2.4/32", "actiontype": "block"})
        self.assertEqual(response.status_code, 302)

        test_user.groups.set([])

        response = self.client.post(reverse("route_manager:add"), {"route": "192.0.2.5/32", "actiontype": "block"})
        # NOTE(review): this asserts the same status as before the group removal, so
        # it cannot detect whether access was actually revoked — strengthen it.
        self.assertEqual(response.status_code, 302)
create_entry()

Ensure the admin user can create an Entry.

Source code in scram/route_manager/tests/test_authorization.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def create_entry(self):
    """Submit the add-entry form as the admin user, log out, and return the id of the newest Entry."""
    self.client.force_login(self.admin_user)
    form_data = {
        "route": "192.0.2.199/32",
        "actiontype": "block",
        "comment": "create entry",
        "uuid": "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
    }
    self.client.post(reverse("route_manager:add"), form_data)
    self.client.logout()
    newest = Entry.objects.latest("id")
    return newest.id
setUp()

Define several users for our tests.

Source code in scram/route_manager/tests/test_authorization.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def setUp(self):
    """Build the test client, one user per permission tier, the user lists, and two API clients."""
    self.client = Client()
    self.unauthorized_user = User.objects.create(username="unauthorized")

    ro_group = Group.objects.get(name="readonly")
    ro_user = User.objects.create(username="readonly")
    ro_user.groups.set([ro_group])
    ro_user.save()
    self.readonly_group = ro_group
    self.readonly_user = ro_user

    rw_group = Group.objects.get(name="readwrite")
    rw_user = User.objects.create(username="readwrite")
    rw_user.groups.set([rw_group])
    rw_user.save()
    self.readwrite_group = rw_group
    self.readwrite_user = rw_user

    self.admin_user = User.objects.create(username="admin", is_staff=True, is_superuser=True)

    # ``None`` stands for an anonymous (not logged-in) visitor.
    anonymous = None
    self.write_blocked_users = [anonymous, self.unauthorized_user, ro_user]
    self.write_allowed_users = [rw_user, self.admin_user]
    self.detail_blocked_users = [anonymous, self.unauthorized_user]
    self.detail_allowed_users = [ro_user, rw_user, self.admin_user]

    authorized = ClientModel.objects.create(
        hostname="authorized_client.es.net",
        uuid="0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
        is_authorized=True,
    )
    authorized.authorized_actiontypes.set([1])
    self.authorized_client = authorized

    self.unauthorized_client = ClientModel.objects.create(
        hostname="unauthorized_client.es.net",
        uuid="91e134a5-77cf-4560-9797-6bbdbffde9f8",
    )
test_authorized_add_entry()

Test authorized users with various permissions to ensure they can add an Entry.

Source code in scram/route_manager/tests/test_authorization.py
86
87
88
89
90
91
92
93
94
95
96
97
98
def test_authorized_add_entry(self):
    """Each user in ``write_allowed_users`` can submit the add-entry form and is redirected."""
    add_url = reverse("route_manager:add")
    payload = {
        "route": "192.0.2.4/32",
        "actiontype": "block",
        "uuid": "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
    }
    for writer in self.write_allowed_users:
        self.client.force_login(writer)
        resp = self.client.post(add_url, payload)
        self.assertEqual(resp.status_code, 302)
test_authorized_detail_view()

Test authorized users with various permissions to ensure they can view block details.

Source code in scram/route_manager/tests/test_authorization.py
110
111
112
113
114
115
116
117
def test_authorized_detail_view(self):
    """Every user in ``detail_allowed_users`` receives a 200 from the entry detail page."""
    entry_pk = self.create_entry()
    detail_url = reverse("route_manager:detail", kwargs={"pk": entry_pk})

    for allowed in self.detail_allowed_users:
        self.client.force_login(allowed)
        resp = self.client.get(detail_url)
        self.assertEqual(resp.status_code, 200, msg=f"username={allowed}")
test_unauthorized_add_entry()

Unauthorized users should not be able to add an Entry.

Source code in scram/route_manager/tests/test_authorization.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def test_unauthorized_add_entry(self):
    """Unauthorized users should not be able to add an Entry.

    Blocked requests are redirected (302) just like successful ones, so the test
    also checks that the number of Entry rows is unchanged afterwards.
    """
    entries_before = Entry.objects.count()

    for user in self.write_blocked_users:
        if user:
            self.client.force_login(user)
        else:
            # Drop any existing session so the anonymous case does not depend on list order.
            self.client.logout()
        response = self.client.post(
            reverse("route_manager:add"),
            {
                "route": "192.0.2.4/32",
                "actiontype": "block",
                "uuid": "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
            },
        )
        self.assertEqual(response.status_code, 302)

    # None of the blocked requests may have created an Entry.
    self.assertEqual(Entry.objects.count(), entries_before)
test_unauthorized_after_group_removal()

The user has r/w access, then when we remove them from the r/w group, they no longer do.

Source code in scram/route_manager/tests/test_authorization.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def test_unauthorized_after_group_removal(self):
    """The user has r/w access, then when we remove them from the r/w group, they no longer do.

    Both an allowed and a denied add request answer with a redirect (302), so the
    status code alone cannot tell them apart. The denied request is therefore also
    checked by verifying that it did not create a new Entry.
    """
    test_user = User.objects.create(username="tmp_readwrite")
    test_user.groups.set([self.readwrite_group])
    test_user.save()

    self.client.force_login(test_user)
    response = self.client.post(reverse("route_manager:add"), {"route": "192.0.2.4/32", "actiontype": "block"})
    self.assertEqual(response.status_code, 302)

    test_user.groups.set([])
    entries_before = Entry.objects.count()

    response = self.client.post(reverse("route_manager:add"), {"route": "192.0.2.5/32", "actiontype": "block"})
    self.assertEqual(response.status_code, 302)
    # Without the r/w group the request must not add anything.
    self.assertEqual(Entry.objects.count(), entries_before)
test_unauthorized_detail_view()

Ensure that unauthorized users can't view the blocked IPs.

Source code in scram/route_manager/tests/test_authorization.py
100
101
102
103
104
105
106
107
108
def test_unauthorized_detail_view(self):
    """Ensure that unauthorized users can't view the blocked IPs.

    Each item in ``detail_blocked_users`` is either ``None`` (an anonymous
    visitor) or a user without view permission; every request to the detail
    page must be redirected (302) or refused (403).
    """
    pk = self.create_entry()

    for user in self.detail_blocked_users:
        if user:
            self.client.force_login(user)
        else:
            # Drop any existing session so the anonymous case does not depend on list order.
            self.client.logout()
        response = self.client.get(reverse("route_manager:detail", kwargs={"pk": pk}))
        self.assertIn(response.status_code, [302, 403], msg=f"username={user}")

ESnetAuthBackendTest

Bases: TestCase

Define tests using OIDC authentication with our ESnetAuthBackend.

Source code in scram/route_manager/tests/test_authorization.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
class ESnetAuthBackendTest(TestCase):
    """Define tests using OIDC authentication with our ESnetAuthBackend.

    Each test builds a user from a set of OIDC claims and checks the resulting
    staff/superuser flags and ``route_manager`` entry permissions.
    """

    def setUp(self):
        """Create a test client and baseline OIDC claims with no groups."""
        self.client = Client()
        self.claims = {
            "given_name": "Edward",
            "family_name": "Scissorhands",
            "preferred_username": "eddy",
            "groups": [],
        }

    def test_unauthorized(self):
        """A user with no groups should have no access."""
        claims = dict(self.claims)
        user = ESnetAuthBackend().create_user(claims)

        self.assertFalse(user.is_staff)
        self.assertFalse(user.is_superuser)
        self.assertEqual(list(user.user_permissions.all()), [])
        # ``user_permissions`` only holds directly-assigned permissions, not those granted
        # through groups, so also check the effective permissions like the other tests do.
        self.assertFalse(user.has_perm("route_manager.view_entry"))
        self.assertFalse(user.has_perm("route_manager.add_entry"))

    def test_readonly(self):
        """Test r/o groups."""
        claims = dict(self.claims)
        claims["groups"] = [settings.SCRAM_READONLY_GROUPS[0]]
        user = ESnetAuthBackend().create_user(claims)

        self.assertFalse(user.is_staff)
        self.assertFalse(user.is_superuser)
        self.assertTrue(user.has_perm("route_manager.view_entry"))
        self.assertFalse(user.has_perm("route_manager.add_entry"))

    def test_readwrite(self):
        """Test r/w groups."""
        claims = dict(self.claims)
        claims["groups"] = [settings.SCRAM_READWRITE_GROUPS[0]]
        user = ESnetAuthBackend().create_user(claims)

        self.assertFalse(user.is_staff)
        self.assertFalse(user.is_superuser)

        self.assertTrue(user.has_perm("route_manager.view_entry"))
        self.assertTrue(user.has_perm("route_manager.add_entry"))

    def test_admin(self):
        """Test admin_groups."""
        claims = dict(self.claims)
        claims["groups"] = [settings.SCRAM_ADMIN_GROUPS[0]]
        user = ESnetAuthBackend().create_user(claims)

        self.assertTrue(user.is_staff)
        self.assertTrue(user.is_superuser)
        self.assertTrue(user.has_perm("route_manager.view_entry"))
        self.assertTrue(user.has_perm("route_manager.add_entry"))

    def test_authorized_removal(self):
        """Have an authorized user, then downgrade them and make sure they're unauthorized."""
        claims = dict(self.claims)
        claims["groups"] = [settings.SCRAM_ADMIN_GROUPS[0]]
        user = ESnetAuthBackend().create_user(claims)
        pk = user.pk

        self.assertTrue(user.is_staff)
        self.assertTrue(user.is_superuser)
        self.assertTrue(user.has_perm("route_manager.view_entry"))
        self.assertTrue(user.has_perm("route_manager.add_entry"))

        claims["groups"] = [settings.SCRAM_READWRITE_GROUPS[0]]
        ESnetAuthBackend().update_user(user, claims)

        # Bypass cache
        user = User.objects.get(pk=pk)

        self.assertFalse(user.is_staff)
        self.assertFalse(user.is_superuser)
        self.assertTrue(user.has_perm("route_manager.view_entry"))
        self.assertTrue(user.has_perm("route_manager.add_entry"))

        claims["groups"] = [settings.SCRAM_READONLY_GROUPS[0]]
        ESnetAuthBackend().update_user(user, claims)

        # Bypass cache
        user = User.objects.get(pk=pk)

        self.assertFalse(user.is_staff)
        self.assertFalse(user.is_superuser)
        self.assertTrue(user.has_perm("route_manager.view_entry"))
        self.assertFalse(user.has_perm("route_manager.add_entry"))

        claims["groups"] = [settings.SCRAM_DENIED_GROUPS[0]]
        ESnetAuthBackend().update_user(user, claims)

        # Bypass cache
        user = User.objects.get(pk=pk)

        self.assertFalse(user.is_staff)
        self.assertFalse(user.is_superuser)
        self.assertFalse(user.has_perm("route_manager.view_entry"))
        self.assertFalse(user.has_perm("route_manager.add_entry"))

    def test_disabled(self):
        """Pass all the groups; the user should be disabled because the denied group takes precedence."""
        claims = dict(self.claims)
        claims["groups"] = settings.SCRAM_GROUPS
        user = ESnetAuthBackend().create_user(claims)

        self.assertFalse(user.is_staff)
        self.assertFalse(user.is_superuser)
        self.assertFalse(user.has_perm("route_manager.view_entry"))
        self.assertFalse(user.has_perm("route_manager.add_entry"))
setUp()

Create a test client and a sample set of OIDC claims (with no groups) for building users.

Source code in scram/route_manager/tests/test_authorization.py
138
139
140
141
142
143
144
145
146
def setUp(self):
    """Prepare a Django test client and baseline OIDC claims that carry no groups."""
    self.client = Client()
    self.claims = dict(
        given_name="Edward",
        family_name="Scissorhands",
        preferred_username="eddy",
        groups=[],
    )
test_admin()

Test admin_groups.

Source code in scram/route_manager/tests/test_authorization.py
180
181
182
183
184
185
186
187
188
189
def test_admin(self):
    """A member of the first admin group becomes staff and superuser and may view and add entries."""
    admin_claims = {**self.claims, "groups": [settings.SCRAM_ADMIN_GROUPS[0]]}
    admin = ESnetAuthBackend().create_user(admin_claims)

    self.assertTrue(admin.is_staff)
    self.assertTrue(admin.is_superuser)
    self.assertTrue(admin.has_perm("route_manager.view_entry"))
    self.assertTrue(admin.has_perm("route_manager.add_entry"))
test_authorized_removal()

Have an authorized user, then downgrade them and make sure they're unauthorized.

Source code in scram/route_manager/tests/test_authorization.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
def test_authorized_removal(self):
    """Create a user as admin, then demote them step by step (r/w, r/o, denied) and confirm access shrinks."""
    current_claims = dict(self.claims)
    current_claims["groups"] = [settings.SCRAM_ADMIN_GROUPS[0]]
    user = ESnetAuthBackend().create_user(current_claims)
    user_pk = user.pk

    def expect(subject, *, elevated, view, add):
        """Assert the staff/superuser flags and both entry permissions of ``subject``."""
        check_elevated = self.assertTrue if elevated else self.assertFalse
        check_elevated(subject.is_staff)
        check_elevated(subject.is_superuser)
        (self.assertTrue if view else self.assertFalse)(subject.has_perm("route_manager.view_entry"))
        (self.assertTrue if add else self.assertFalse)(subject.has_perm("route_manager.add_entry"))

    expect(user, elevated=True, view=True, add=True)

    # Each step: (settings attribute naming the group list, may view, may add).
    demotions = (
        ("SCRAM_READWRITE_GROUPS", True, True),
        ("SCRAM_READONLY_GROUPS", True, False),
        ("SCRAM_DENIED_GROUPS", False, False),
    )
    for group_setting, view, add in demotions:
        current_claims["groups"] = [getattr(settings, group_setting)[0]]
        ESnetAuthBackend().update_user(user, current_claims)
        # Reload from the database so state cached on the previous instance is not reused.
        user = User.objects.get(pk=user_pk)
        expect(user, elevated=False, view=view, add=add)
test_disabled()

Pass all the groups; the user should be disabled, because the disabled (denied) group takes precedence over the others.

Source code in scram/route_manager/tests/test_authorization.py
236
237
238
239
240
241
242
243
244
245
def test_disabled(self):
    """Claims listing every SCRAM group yield a user with no elevated flags and no entry permissions."""
    all_groups_claims = dict(self.claims)
    all_groups_claims["groups"] = settings.SCRAM_GROUPS
    disabled_user = ESnetAuthBackend().create_user(all_groups_claims)

    # The disabled group must win over every group that would otherwise grant access.
    self.assertFalse(disabled_user.is_staff)
    self.assertFalse(disabled_user.is_superuser)
    for codename in ("view_entry", "add_entry"):
        self.assertFalse(disabled_user.has_perm(f"route_manager.{codename}"))
test_readonly()

Test r/o groups.

Source code in scram/route_manager/tests/test_authorization.py
157
158
159
160
161
162
163
164
165
166
def test_readonly(self):
    """Membership in the first read-only group grants viewing but not adding, and no staff flags."""
    ro_claims = {**self.claims, "groups": [settings.SCRAM_READONLY_GROUPS[0]]}
    reader = ESnetAuthBackend().create_user(ro_claims)

    self.assertFalse(reader.is_staff)
    self.assertFalse(reader.is_superuser)
    self.assertTrue(reader.has_perm("route_manager.view_entry"))
    self.assertFalse(reader.has_perm("route_manager.add_entry"))
test_readwrite()

Test r/w groups.

Source code in scram/route_manager/tests/test_authorization.py
168
169
170
171
172
173
174
175
176
177
178
def test_readwrite(self):
    """Membership in the first read/write group grants viewing and adding, but no staff flags."""
    rw_claims = {**self.claims, "groups": [settings.SCRAM_READWRITE_GROUPS[0]]}
    writer = ESnetAuthBackend().create_user(rw_claims)

    self.assertFalse(writer.is_staff)
    self.assertFalse(writer.is_superuser)
    for codename in ("view_entry", "add_entry"):
        self.assertTrue(writer.has_perm(f"route_manager.{codename}"))
test_unauthorized()

A user with no groups should have no access.

Source code in scram/route_manager/tests/test_authorization.py
148
149
150
151
152
153
154
155
def test_unauthorized(self):
    """A user with no groups should have no access."""
    claims = dict(self.claims)
    user = ESnetAuthBackend().create_user(claims)

    self.assertFalse(user.is_staff)
    self.assertFalse(user.is_superuser)
    self.assertEqual(list(user.user_permissions.all()), [])
    # ``user_permissions`` only holds directly-assigned permissions, not those granted
    # through groups, so also check the effective permissions like the other tests do.
    self.assertFalse(user.has_perm("route_manager.view_entry"))
    self.assertFalse(user.has_perm("route_manager.add_entry"))

test_autocreate_admin

Test the auto-creation of an admin user.

test_autocreate_admin(settings)

Test that an admin user is auto-created when AUTOCREATE_ADMIN is True.

Source code in scram/route_manager/tests/test_autocreate_admin.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
@pytest.mark.django_db
def test_autocreate_admin(settings):
    """With AUTOCREATE_ADMIN on and an empty database, the home page creates one ``admin`` superuser and queues two messages."""
    settings.AUTOCREATE_ADMIN = True
    home = Client().get(reverse("route_manager:home"))
    assert home.status_code == 200

    assert User.objects.count() == 1
    admin = User.objects.get(username="admin")
    assert admin.is_superuser
    assert admin.email == "admin@example.com"

    # Expect a success message followed by an informational one.
    queued = list(get_messages(home.wsgi_request))
    assert len(queued) == 2
    assert [msg.level for msg in queued] == [LEVEL_SUCCESS, LEVEL_INFO]

test_autocreate_admin_disabled(settings)

Test that an admin user is not auto-created when AUTOCREATE_ADMIN is False.

Source code in scram/route_manager/tests/test_autocreate_admin.py
31
32
33
34
35
36
37
38
@pytest.mark.django_db
def test_autocreate_admin_disabled(settings):
    """With AUTOCREATE_ADMIN switched off, loading the home page leaves the user table empty."""
    settings.AUTOCREATE_ADMIN = False
    home = Client().get(reverse("route_manager:home"))

    assert home.status_code == 200
    assert not User.objects.exists()

test_autocreate_admin_existing_user(settings)

Test that an admin user is not auto-created when an existing user is present.

Source code in scram/route_manager/tests/test_autocreate_admin.py
41
42
43
44
45
46
47
48
49
50
@pytest.mark.django_db
def test_autocreate_admin_existing_user(settings):
    """An already-present user suppresses admin auto-creation even when AUTOCREATE_ADMIN is on."""
    settings.AUTOCREATE_ADMIN = True
    User.objects.create_user("testuser", "test@example.com", "password")

    home = Client().get(reverse("route_manager:home"))

    assert home.status_code == 200
    assert User.objects.count() == 1
    assert not User.objects.filter(username="admin").exists()

test_history

Define tests for the history feature.

TestActiontypeHistory

Bases: TestCase

Test the history on an action type.

Source code in scram/route_manager/tests/test_history.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
class TestActiontypeHistory(TestCase):
    """Test the history on an action type."""

    def setUp(self):
        """Create the ActionType whose history is inspected."""
        self.atype = ActionType.objects.create(name="Block")

    def test_comments(self):
        """Ensure a change reason set before saving is stored in the history."""
        reason = "Use more descriptive name"
        self.atype.name = "Nullroute"
        self.atype._change_reason = reason  # noqa SLF001
        self.atype.save()
        self.assertIsNotNone(get_change_reason_from_object(self.atype))
        # get_change_reason_from_object reads the ``_change_reason`` attribute set above, so
        # it cannot show the reason was persisted; check the newest history record too.
        self.assertEqual(self.atype.history.first().history_change_reason, reason)
setUp()

Set up the test environment.

Source code in scram/route_manager/tests/test_history.py
12
13
14
def setUp(self):
    """Create a single ActionType named Block for the history tests."""
    block_type = ActionType.objects.create(name="Block")
    self.atype = block_type
test_comments()

Ensure we can go back and set a reason.

Source code in scram/route_manager/tests/test_history.py
16
17
18
19
20
21
def test_comments(self):
    """Ensure a change reason set before saving is stored in the history."""
    reason = "Use more descriptive name"
    self.atype.name = "Nullroute"
    self.atype._change_reason = reason  # noqa SLF001
    self.atype.save()
    self.assertIsNotNone(get_change_reason_from_object(self.atype))
    # get_change_reason_from_object reads the ``_change_reason`` attribute set above, so
    # it cannot show the reason was persisted; check the newest history record too.
    self.assertEqual(self.atype.history.first().history_change_reason, reason)

TestEntryHistory

Bases: TestCase

Test the history on an Entry.

Source code in scram/route_manager/tests/test_history.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class TestEntryHistory(TestCase):
    """Test the history on an Entry."""

    routes = ["192.0.2.16/32", "198.51.100.16/28"]

    def setUp(self):
        """Create an Entry for each route and record an initial change reason with update_change_reason."""
        self.atype = ActionType.objects.create(name="Block")
        for r in self.routes:
            route = Route.objects.create(route=r)
            entry = Entry.objects.create(route=route, actiontype=self.atype)
            create_reason = "Zeek detected a scan from 192.0.2.1."
            update_change_reason(entry, create_reason)
            self.assertEqual(entry.get_change_reason(), create_reason)

    def test_comments(self):
        """Ensure we can update the reason."""
        for r in self.routes:
            route_old = Route.objects.get(route=r)
            e = Entry.objects.get(route=route_old)
            self.assertEqual(e.get_change_reason(), "Zeek detected a scan from 192.0.2.1.")

            # Assumes str(Route) is the CIDR text; swapping "16" for "32" yields a different route.
            route_new = str(route_old).replace("16", "32")
            e.route = Route.objects.create(route=route_new)

            change_reason = "I meant 32, not 16."
            e._change_reason = change_reason  # noqa SLF001
            e.save()

            # One record from creation in setUp plus one from the save above.
            self.assertEqual(e.history.count(), 2)
            self.assertEqual(e.get_change_reason(), change_reason)
setUp()

Set up the test environment.

Source code in scram/route_manager/tests/test_history.py
29
30
31
32
33
34
35
36
37
def setUp(self):
    """Create a Block actiontype and, for every route in ``routes``, an Entry tagged with a creation reason."""
    self.atype = ActionType.objects.create(name="Block")
    reason = "Zeek detected a scan from 192.0.2.1."
    for cidr in self.routes:
        new_entry = Entry.objects.create(route=Route.objects.create(route=cidr), actiontype=self.atype)
        update_change_reason(new_entry, reason)
        self.assertEqual(new_entry.get_change_reason(), reason)
test_comments()

Ensure we can update the reason.

Source code in scram/route_manager/tests/test_history.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def test_comments(self):
    """Ensure we can update the reason."""
    for r in self.routes:
        route_old = Route.objects.get(route=r)
        e = Entry.objects.get(route=route_old)
        self.assertEqual(e.get_change_reason(), "Zeek detected a scan from 192.0.2.1.")

        # Assumes str(Route) is the CIDR text; swapping "16" for "32" yields a different route.
        route_new = str(route_old).replace("16", "32")
        e.route = Route.objects.create(route=route_new)

        change_reason = "I meant 32, not 16."
        e._change_reason = change_reason  # noqa SLF001
        e.save()

        # One record from creation in setUp plus one from the save above.
        self.assertEqual(e.history.count(), 2)
        self.assertEqual(e.get_change_reason(), change_reason)

test_pagination

Define simple tests for pagination.

TestEntriesListView

Bases: TestCase

Test to make sure our pagination and related scaffolding work.

Source code in scram/route_manager/tests/test_pagination.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
@pytest.mark.django_db
class TestEntriesListView(TestCase):
    """Test to make sure our pagination and related scaffolding work."""

    TEST_PAGINATION_SIZE = 5

    def setUp(self):
        """Create a login user, three actiontypes, and entries that exercise pagination and filtering."""
        self.fake = Faker()
        self.fake.add_provider(internet)
        get_user_model().objects.create_user(username="testuser", password="testpass123")

        self.atype1 = ActionType.objects.create(name="Type1", available=True)
        self.atype2 = ActionType.objects.create(name="Type2", available=True)
        self.atype3 = ActionType.objects.create(name="Type3", available=False)

        # Type1 gets more active entries than fit on one page, so a second page exists.
        entries_type1 = self._create_entries(self.TEST_PAGINATION_SIZE + 3, self.atype1, is_active=True)

        # Type2 gets fewer active entries than a page, to test filtering per actiontype.
        entries_type2 = self._create_entries(3, self.atype2, is_active=True)

        # Inactive Type1 entries must not be counted in Type1's total.
        self._create_entries(3, self.atype1, is_active=False)

        # Entries for the unavailable Type3 must not appear in the context at all.
        self._create_entries(3, self.atype3, is_active=False)

        self.entries = {
            "type1": entries_type1,
            "type2": entries_type2,
        }

    def _create_entries(self, count, actiontype, *, is_active):
        """Bulk-create ``count`` Routes with unique public IPv4 addresses and one Entry per Route."""
        created_routes = Route.objects.bulk_create([
            Route(route=self.fake.unique.ipv4_public()) for _ in range(count)
        ])
        return Entry.objects.bulk_create([
            Entry(route=route, actiontype=actiontype, is_active=is_active) for route in created_routes
        ])

    def test_context(self):
        """Test that the context structure is correctly filled out."""
        self.client.login(username="testuser", password="testpass123")

        url = reverse("route_manager:entry-list")
        response = self.client.get(url)

        assert response.status_code == 200
        assert "entries" in response.context
        entries_context = response.context["entries"]

        assert self.atype1 in entries_context
        assert self.atype2 in entries_context
        assert self.atype3 not in entries_context

    def test_filtering_entries_by_action_type(self):
        """Test that our paginated output has entries for all available actiontypes."""
        self.client.login(username="testuser", password="testpass123")

        url = reverse("route_manager:entry-list")
        response = self.client.get(url)

        entries_context = response.context["entries"]

        assert entries_context[self.atype1]["total"] == len(self.entries["type1"])
        assert entries_context[self.atype2]["total"] == len(self.entries["type2"])

    @override_settings(PAGINATION_SIZE=TEST_PAGINATION_SIZE)
    def test_pagination(self):
        """Test pagination when there are multiple action types."""
        self.client.login(username="testuser", password="testpass123")

        url = reverse("route_manager:entry-list")

        response = self.client.get(url)
        entries_context = response.context["entries"]

        # First page should have PAGINATION_SIZE entries for actiontype with more entries than pagination size
        assert len(entries_context[self.atype1]["objs"]) == settings.PAGINATION_SIZE
        assert entries_context[self.atype1]["page_param"] == "page_type1"
        assert str(entries_context[self.atype1]["page_number"]) == "1"

        # First page should include all entries for actiontype with less entries than pagination size
        assert len(entries_context[self.atype2]["objs"]) == len(self.entries["type2"])

        # Second page should have the rest of the entries for actiontype with more entries than pagination size
        page2_response = self.client.get(f"{url}?page_type1=2")
        page2_context = page2_response.context["entries"]

        assert str(page2_context[self.atype1]["page_number"]) == "2"
        assert len(page2_context[self.atype1]["objs"]) == len(self.entries["type1"]) - settings.PAGINATION_SIZE

    @override_settings(PAGINATION_SIZE=TEST_PAGINATION_SIZE)
    def test_invalid_page_handling(self):
        """Test handling of invalid page numbers."""
        self.client.login(username="testuser", password="testpass123")

        url = reverse("route_manager:entry-list")
        response = self.client.get(f"{url}?page_type1=999")

        entries_context = response.context["entries"]

        # Should default to page 1
        assert entries_context[self.atype1]["objs"].number == 1

    def test_multiple_page_parameters(self):
        """Test that we can have separate pages when we have more than one actiontype."""
        self.client.login(username="testuser", password="testpass123")

        url = reverse("route_manager:entry-list")
        response = self.client.get(f"{url}?page_type1=2&page_type2=1")

        entries_context = response.context["entries"]

        # Each type should have its own page number
        assert str(entries_context[self.atype1]["page_number"]) == "2"
        assert str(entries_context[self.atype2]["page_number"]) == "1"
        assert "page_type1" in entries_context[self.atype1]["current_page_params"]
        assert "page_type2" in entries_context[self.atype1]["current_page_params"]
setUp()

Set up the test environment.

Source code in scram/route_manager/tests/test_pagination.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def setUp(self):
    """Create a login user, three actiontypes, and entries that exercise pagination and filtering."""
    self.fake = Faker()
    self.fake.add_provider(internet)
    get_user_model().objects.create_user(username="testuser", password="testpass123")

    self.atype1 = ActionType.objects.create(name="Type1", available=True)
    self.atype2 = ActionType.objects.create(name="Type2", available=True)
    self.atype3 = ActionType.objects.create(name="Type3", available=False)

    def seed(count, actiontype, *, active):
        """Bulk-create ``count`` unique public IPv4 Routes plus one Entry per Route."""
        routes = Route.objects.bulk_create([Route(route=self.fake.unique.ipv4_public()) for _ in range(count)])
        return Entry.objects.bulk_create([Entry(route=r, actiontype=actiontype, is_active=active) for r in routes])

    # More active Type1 entries than one page holds, so pagination kicks in.
    type1_entries = seed(self.TEST_PAGINATION_SIZE + 3, self.atype1, active=True)
    # A handful of active Type2 entries, to check per-actiontype filtering.
    type2_entries = seed(3, self.atype2, active=True)
    # Inactive Type1 entries, which must not count toward Type1's total.
    seed(3, self.atype1, active=False)
    # Entries for the unavailable Type3, which must be left out of the context.
    seed(3, self.atype3, active=False)

    self.entries = {"type1": type1_entries, "type2": type2_entries}
test_context()

Test that the context structure is correctly filled out.

Source code in scram/route_manager/tests/test_pagination.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def test_context(self):
    """The entry list renders, and its ``entries`` context is keyed only by available actiontypes."""
    self.client.login(username="testuser", password="testpass123")

    response = self.client.get(reverse("route_manager:entry-list"))
    assert response.status_code == 200
    assert "entries" in response.context

    by_type = response.context["entries"]
    assert self.atype1 in by_type
    assert self.atype2 in by_type
    assert self.atype3 not in by_type
test_filtering_entries_by_action_type()

Test that our paginated output has entries for all available actiontypes.

Source code in scram/route_manager/tests/test_pagination.py
76
77
78
79
80
81
82
83
84
85
86
def test_filtering_entries_by_action_type(self):
    """Each available actiontype reports a total equal to the number of active entries created for it."""
    self.client.login(username="testuser", password="testpass123")
    by_type = self.client.get(reverse("route_manager:entry-list")).context["entries"]

    for atype, key in ((self.atype1, "type1"), (self.atype2, "type2")):
        assert by_type[atype]["total"] == len(self.entries[key])
test_invalid_page_handling()

Test handling of invalid page numbers.

Source code in scram/route_manager/tests/test_pagination.py
113
114
115
116
117
118
119
120
121
122
123
124
@override_settings(PAGINATION_SIZE=TEST_PAGINATION_SIZE)
def test_invalid_page_handling(self):
    """An out-of-range page number for Type1 falls back to the first page."""
    self.client.login(username="testuser", password="testpass123")

    list_url = reverse("route_manager:entry-list")
    by_type = self.client.get(f"{list_url}?page_type1=999").context["entries"]

    # Page 999 does not exist, so page 1 is expected instead.
    assert by_type[self.atype1]["objs"].number == 1
test_multiple_page_parameters()

Test that we can have separate pages when we have more than one actiontype.

Source code in scram/route_manager/tests/test_pagination.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def test_multiple_page_parameters(self):
    """Check that each actiontype keeps its own page number and that all page params are carried along."""
    self.client.login(username="testuser", password="testpass123")
    base_url = reverse("route_manager:entry-list")
    grouped = self.client.get(base_url + "?page_type1=2&page_type2=1").context["entries"]

    # page_number is normalised to str before comparing.
    assert str(grouped[self.atype1]["page_number"]) == "2"
    assert str(grouped[self.atype2]["page_number"]) == "1"

    # The atype1 section must preserve both page parameters in its links.
    carried_params = grouped[self.atype1]["current_page_params"]
    assert "page_type1" in carried_params
    assert "page_type2" in carried_params
test_pagination()

Test pagination when there's multiple action types.

Source code in scram/route_manager/tests/test_pagination.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
@override_settings(PAGINATION_SIZE=5)
def test_pagination(self):
    """Test pagination when there's multiple action types.

    atype1 has more entries than one page and spills onto page 2. atype2 has fewer entries than
    one page and fits entirely on page 1. Expected page sizes are derived from the fixture
    (``self.entries``) and ``settings.PAGINATION_SIZE`` rather than hard-coded.
    """
    self.client.login(username="testuser", password="testpass123")

    url = reverse("route_manager:entry-list")
    page_size = settings.PAGINATION_SIZE

    response = self.client.get(url)
    entries_context = response.context["entries"]

    # First page should have PAGINATION_SIZE entries for actiontype with more entries than pagination size
    assert len(entries_context[self.atype1]["objs"]) == page_size
    assert entries_context[self.atype1]["page_param"] == "page_type1"
    assert str(entries_context[self.atype1]["page_number"]) == "1"

    # First page should include all entries for actiontype with less entries than pagination size
    assert len(entries_context[self.atype2]["objs"]) == len(self.entries["type2"])

    # Second page holds whatever did not fit on page 1 (capped at one page). Deriving this
    # keeps the test correct if the number of type1 fixture entries changes.
    page2_response = self.client.get(f"{url}?page_type1=2")
    page2_context = page2_response.context["entries"]
    expected_on_page2 = min(len(self.entries["type1"]) - page_size, page_size)

    assert str(page2_context[self.atype1]["page_number"]) == "2"
    assert len(page2_context[self.atype1]["objs"]) == expected_on_page2

test_performance

Tests for performance (load time, DB queries, etc.).

TestViewNumQueries

Bases: TestCase

Check that key pages stay within a fixed query budget and load in under a second.

Source code in scram/route_manager/tests/test_performance.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class TestViewNumQueries(TestCase):
    """Check that key pages stay within a fixed query budget and load in under a second."""

    # Number of Route/Entry rows bulk-created in setUp.
    NUM_ENTRIES = 100_000

    def setUp(self):
        """Set up the test environment."""
        self.fake = Faker()
        self.fake.add_provider(internet)

        # Query the homepage once to setup the user
        self.client.get(reverse("route_manager:home"))

        self.atype, _ = ActionType.objects.get_or_create(name="block")
        routes = [Route(route=self.fake.unique.ipv4_public()) for _ in range(self.NUM_ENTRIES)]
        created_routes = Route.objects.bulk_create(routes)
        entries = [Entry(route=route, actiontype=self.atype, is_active=True) for route in created_routes]
        Entry.objects.bulk_create(entries)
        # Manually set the when time to be old so that we don't trigger `process_updates`
        # on all 100,000 of the test routes.
        Entry.objects.update(when=datetime.datetime(2000, 1, 1, 0, 0, tzinfo=datetime.UTC))

    def test_home_page(self):
        """Home page requires 10 queries.

        1. create transaction
        2. lookup session
        3. lookup user
        4. filter available actiontypes
        5. count entries with actiontype=1 and is_active
        6. count by user
        7. context processor active_count active blocks
        8. context processor active_count all blocks
        9. first page for actiontype=1
        10. close transaction
        """
        with self.assertNumQueries(10):
            # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
            start = time.perf_counter()
            self.client.get(reverse("route_manager:home"))
            time_taken = time.perf_counter() - start
            self.assertLess(time_taken, 1, "Home page took longer than 1 second")

    def test_entry_view(self):
        """Viewing an entry requires 8 queries.

        1. create transaction savepoint
        2. lookup session
        3. lookup user
        4. get entry
        5. rollback to savepoint
        6. release transaction savepoint
        7. context processor active_count active blocks
        8. context processor active_count all blocks
        """
        with self.assertNumQueries(8):
            start = time.perf_counter()
            self.client.get(reverse("route_manager:detail", kwargs={"pk": 9999}))
            time_taken = time.perf_counter() - start
            self.assertLess(time_taken, 1, "Entry detail page took longer than 1 second")

    def test_admin_entry_page(self):
        """Admin entry list page requires 8 queries.

        1. create transaction
        2. lookup session
        3. lookup user
        4. lookup distinct users for our WhoFilter
        5. count entries
        6. count entries
        7. get first 100 entries (a single query, with select_related)
        8. release transaction
        """
        with self.assertNumQueries(8):
            start = time.perf_counter()
            self.client.get(reverse("admin:route_manager_entry_changelist"))
            time_taken = time.perf_counter() - start
            self.assertLess(time_taken, 1, "Admin entry list page took longer than 1 seconds")

    def test_process_updates(self):
        """Process expired requires 6 queries.

        1. create transaction
        2. get entries_start active entry count
        3. find and delete expired entries
        4. get entries_end active entry count
        5. get new_entries from DB
        6. release transaction
        """
        with self.assertNumQueries(6):
            start = time.perf_counter()
            self.client.get(reverse("route_manager:process-updates"))
            time_taken = time.perf_counter() - start
            self.assertLess(time_taken, 1, "Process expired page took longer than 1 seconds")
setUp()

Set up the test environment.

Source code in scram/route_manager/tests/test_performance.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def setUp(self):
    """Seed the database with NUM_ENTRIES active 'block' entries whose timestamps are backdated."""
    self.fake = Faker()
    self.fake.add_provider(internet)

    # The first home-page visit sets up the user for the tests.
    self.client.get(reverse("route_manager:home"))

    self.atype, _ = ActionType.objects.get_or_create(name="block")
    new_routes = Route.objects.bulk_create(
        [Route(route=self.fake.unique.ipv4_public()) for _ in range(self.NUM_ENTRIES)]
    )
    Entry.objects.bulk_create(
        [Entry(route=new_route, actiontype=self.atype, is_active=True) for new_route in new_routes]
    )
    # Backdate every entry so `process_updates` is not triggered for all of the test routes.
    Entry.objects.update(when=datetime.datetime(2000, 1, 1, 0, 0, tzinfo=datetime.UTC))
test_admin_entry_page()

Admin entry list page requires 8 queries.

  1. create transaction
  2. lookup session
  3. lookup user
  4. lookup distinct users for our WhoFilter
  5. count entries
  6. count entries
  7. get first 100 entries (a single query, with select_related)
  8. release transaction
Source code in scram/route_manager/tests/test_performance.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def test_admin_entry_page(self):
    """Admin entry list page requires 8 queries.

    1. create transaction
    2. lookup session
    3. lookup user
    4. lookup distinct users for our WhoFilter
    5. count entries
    6. count entries
    7. get first 100 entries (a single query, with select_related)
    8. release transaction
    """
    with self.assertNumQueries(8):
        # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
        start = time.perf_counter()
        self.client.get(reverse("admin:route_manager_entry_changelist"))
        time_taken = time.perf_counter() - start
        self.assertLess(time_taken, 1, "Admin entry list page took longer than 1 seconds")
test_entry_view()

Viewing an entry requires 8 queries.

  1. create transaction savepoint
  2. lookup session
  3. lookup user
  4. get entry
  5. rollback to savepoint
  6. release transaction savepoint
  7. context processor active_count active blocks
  8. context processor active_count all blocks
Source code in scram/route_manager/tests/test_performance.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def test_entry_view(self):
    """Viewing an entry requires 8 queries.

    1. create transaction savepoint
    2. lookup session
    3. lookup user
    4. get entry
    5. rollback to savepoint
    6. release transaction savepoint
    7. context processor active_count active blocks
    8. context processor active_count all blocks
    """
    with self.assertNumQueries(8):
        # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
        start = time.perf_counter()
        self.client.get(reverse("route_manager:detail", kwargs={"pk": 9999}))
        time_taken = time.perf_counter() - start
        self.assertLess(time_taken, 1, "Entry detail page took longer than 1 second")
test_home_page()

Home page requires 10 queries.

  1. create transaction
  2. lookup session
  3. lookup user
  4. filter available actiontypes
  5. count entries with actiontype=1 and is_active
  6. count by user
  7. context processor active_count active blocks
  8. context processor active_count all blocks
  9. first page for actiontype=1
  10. close transaction
Source code in scram/route_manager/tests/test_performance.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def test_home_page(self):
    """Home page requires 10 queries.

    1. create transaction
    2. lookup session
    3. lookup user
    4. filter available actiontypes
    5. count entries with actiontype=1 and is_active
    6. count by user
    7. context processor active_count active blocks
    8. context processor active_count all blocks
    9. first page for actiontype=1
    10. close transaction
    """
    with self.assertNumQueries(10):
        # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
        start = time.perf_counter()
        self.client.get(reverse("route_manager:home"))
        time_taken = time.perf_counter() - start
        self.assertLess(time_taken, 1, "Home page took longer than 1 second")
test_process_updates()

Process expired requires 6 queries.

  1. create transaction
  2. get entries_start active entry count
  3. find and delete expired entries
  4. get entries_end active entry count
  5. get new_entries from DB
  6. release transaction
Source code in scram/route_manager/tests/test_performance.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def test_process_updates(self):
    """Process expired requires 6 queries.

    1. create transaction
    2. get entries_start active entry count
    3. find and delete expired entries
    4. get entries_end active entry count
    5. get new_entries from DB
    6. release transaction
    """
    with self.assertNumQueries(6):
        # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
        start = time.perf_counter()
        self.client.get(reverse("route_manager:process-updates"))
        time_taken = time.perf_counter() - start
        self.assertLess(time_taken, 1, "Process expired page took longer than 1 seconds")

test_swagger

Test the swagger API endpoints.

test_redoc_api(client)

Test that the Redoc API endpoint returns a successful response.

Source code in scram/route_manager/tests/test_swagger.py
15
16
17
18
19
20
@pytest.mark.django_db
def test_redoc_api(client):
    """The ReDoc documentation page should render with HTTP 200."""
    redoc_response = client.get(reverse("redoc"))
    assert redoc_response.status_code == 200

test_schema_api(client)

Test that the Schema API endpoint returns a successful response.

Source code in scram/route_manager/tests/test_swagger.py
23
24
25
26
27
28
29
30
@pytest.mark.django_db
def test_schema_api(client):
    """Test that the Schema API endpoint responds with 200 and documents each API collection path."""
    url = reverse("schema")
    response = client.get(url)
    assert response.status_code == 200
    expected_strings = [b"/entries/", b"/actiontypes/", b"/ignore_entries/", b"/users/"]
    # Collect the absent paths so a failure names them, instead of a bare ``all(...)`` failure.
    missing = [string for string in expected_strings if string not in response.content]
    assert not missing, f"schema is missing paths: {missing}"

test_swagger_api(client)

Test that the Swagger API endpoint returns a successful response.

Source code in scram/route_manager/tests/test_swagger.py
 7
 8
 9
10
11
12
@pytest.mark.django_db
def test_swagger_api(client):
    """The Swagger UI page should render with HTTP 200."""
    swagger_response = client.get(reverse("swagger-ui"))
    assert swagger_response.status_code == 200

test_views

Define simple tests for the template-based Views.

HomePageFirstVisitTest

Bases: TestCase

Test how the home page renders the first time we view it.

Source code in scram/route_manager/tests/test_views.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class HomePageFirstVisitTest(TestCase):
    """Check what the home page shows on the very first visit."""

    def setUp(self):
        """Fetch the home page once and keep the response for the tests."""
        home_url = reverse("route_manager:home")
        self.response = self.client.get(home_url)

    def test_first_homepage_view_has_userinfo(self):
        """A first visit announces that an admin user was created."""
        created_notice = b"An admin user was created for you."
        self.assertContains(self.response, created_notice)

    def test_first_homepage_view_is_logged_in(self):
        """A first visit leaves us logged in, so the Logout button is rendered."""
        logout_button = b'type="submit">Logout'
        self.assertContains(self.response, logout_button)
setUp()

Get the home page.

Source code in scram/route_manager/tests/test_views.py
22
23
24
def setUp(self):
    """Fetch the home page once and keep the response for the tests."""
    home_url = reverse("route_manager:home")
    self.response = self.client.get(home_url)
test_first_homepage_view_has_userinfo()

The first time we view the home page, a user was created for us.

Source code in scram/route_manager/tests/test_views.py
26
27
28
def test_first_homepage_view_has_userinfo(self):
    """A first visit announces that an admin user was created."""
    created_notice = b"An admin user was created for you."
    self.assertContains(self.response, created_notice)
test_first_homepage_view_is_logged_in()

The first time we view the home page, we're logged in.

Source code in scram/route_manager/tests/test_views.py
30
31
32
def test_first_homepage_view_is_logged_in(self):
    """A first visit leaves us logged in, so the Logout button is rendered."""
    logout_button = b'type="submit">Logout'
    self.assertContains(self.response, logout_button)

HomePageLogoutTest

Bases: TestCase

Verify that once logged out, we can't view anything.

Source code in scram/route_manager/tests/test_views.py
35
36
37
38
39
40
41
42
43
44
45
46
47
class HomePageLogoutTest(TestCase):
    """Verify that once logged out, we can't view anything."""

    def test_homepage_logout_links_missing(self):
        """After logout, the home page shows no user notice, no Logout button and no Admin link."""
        # Visit once so the first-visit user setup happens; the response itself is not needed.
        self.client.get(reverse("route_manager:home"))

        logout_response = self.client.post(reverse(settings.LOGOUT_URL), follow=True)
        self.assertEqual(logout_response.status_code, 200)

        response = self.client.get(reverse("route_manager:home"))
        self.assertNotContains(response, b"An admin user was created for you.")
        self.assertNotContains(response, b'type="submit">Logout')
        self.assertNotContains(response, b">Admin</a>")

After logout, we can't see anything.

Source code in scram/route_manager/tests/test_views.py
38
39
40
41
42
43
44
45
46
47
def test_homepage_logout_links_missing(self):
    """After logout, the home page shows no user notice, no Logout button and no Admin link."""
    # Visit once so the first-visit user setup happens; the response itself is not needed.
    self.client.get(reverse("route_manager:home"))

    logout_response = self.client.post(reverse(settings.LOGOUT_URL), follow=True)
    self.assertEqual(logout_response.status_code, 200)

    response = self.client.get(reverse("route_manager:home"))
    self.assertNotContains(response, b"An admin user was created for you.")
    self.assertNotContains(response, b'type="submit">Logout')
    self.assertNotContains(response, b">Admin</a>")

HomePageTest

Bases: TestCase

Test how the home page renders.

Source code in scram/route_manager/tests/test_views.py
10
11
12
13
14
15
16
class HomePageTest(TestCase):
    """Check that the site root is served by the home page view."""

    def test_root_url_resolves_to_home_page_view(self):
        """Resolving "/" must yield the ``home_page`` view function."""
        match = resolve("/")
        self.assertEqual(match.func, home_page)
test_root_url_resolves_to_home_page_view()

Ensure we can find the home page.

Source code in scram/route_manager/tests/test_views.py
13
14
15
16
def test_root_url_resolves_to_home_page_view(self):
    """Resolving "/" must yield the ``home_page`` view function."""
    match = resolve("/")
    self.assertEqual(match.func, home_page)

NotFoundTest

Bases: TestCase

Verify that our custom 404 page is being served up.

Source code in scram/route_manager/tests/test_views.py
50
51
52
53
54
55
56
57
58
class NotFoundTest(TestCase):
    """Make sure unknown URLs are answered with our custom 404 page."""

    def test_404(self):
        """Request a nonexistent path and look for the custom 404 message."""
        not_found_marker = b'<div class="mb-4 lead">The page you are looking for was not found.</div>'
        bad_response = self.client.get("/foobarbaz")
        self.assertContains(bad_response, not_found_marker, status_code=404)
test_404()

Grab a bad URL.

Source code in scram/route_manager/tests/test_views.py
53
54
55
56
57
58
def test_404(self):
    """Request a nonexistent path and look for the custom 404 message."""
    not_found_marker = b'<div class="mb-4 lead">The page you are looking for was not found.</div>'
    bad_response = self.client.get("/foobarbaz")
    self.assertContains(bad_response, not_found_marker, status_code=404)

test_websockets

Define unit tests for the websockets-based communication.

TestTranslatorBaseCase

Bases: TestCase

Base case that other test cases build on top of. Three translators in one group; the tests add several IPv4 and IPv6 routes.

Source code in scram/route_manager/tests/test_websockets.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class TestTranslatorBaseCase(TestCase):
    """Base case that other test cases build on top of.

    By default three translators share the "block" actiontype; the tests add several IPv4 and
    IPv6 routes and check what each translator receives.
    """

    def setUp(self):
        """Set up our test environment."""
        # TODO: This is copied from test_api; should de-dupe this.
        self.url = reverse("api:v1:entry-list")
        self.superuser = get_user_model().objects.create_superuser("admin", "admin@example.net", "admintestpassword")
        self.client.login(username="admin", password="admintestpassword")
        self.uuid = "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3"

        self.action_name = "block"

        self.actiontype, _ = ActionType.objects.get_or_create(name=self.action_name)
        self.actiontype.save()

        # A client authorized for the "block" actiontype; api_create_entry posts with its uuid.
        self.authorized_client = Client.objects.create(
            hostname="authorized_client.example.net",
            uuid=self.uuid,
            is_authorized=True,
        )
        self.authorized_client.authorized_actiontypes.set([self.actiontype])

        # Link a "translator_add" message to this actiontype for verb "A".
        wsm, _ = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route")
        WebSocketSequenceElement.objects.get_or_create(
            websocketmessage=wsm,
            verb="A",
            action_type=self.actiontype,
        )

        # Set some defaults; some child classes override this
        self.actiontypes = ["block"] * 3
        self.should_match = [True] * 3
        # One factory per expected message, in the order the messages should arrive.
        self.generate_add_msgs = [lambda ip, mask: {"type": "translator_add", "message": {"route": f"{ip}/{mask}"}}]

        # Now we run any local setup actions by the child classes
        self.local_setup()

    def local_setup(self):
        """Allow child classes to override this if desired."""
        return

    async def get_messages(self, communicator, messages, should_match):
        """Receive one WebSocket message per entry in ``messages`` and check it (mis)matches as required."""
        for msg in messages:
            response = json.loads(await communicator.receive_from())
            match = response == msg
            assert match == should_match

    async def get_nothings(self, communicator):
        """Check there are no more messages waiting."""
        # NOTE(review): asgiref's ApplicationCommunicator.receive_nothing() returns True when no
        # message arrived, so ``is False`` asserts that a message IS pending, which contradicts
        # this docstring. Confirm the intended check before changing it.
        assert await communicator.receive_nothing(timeout=0.1, interval=0.01) is False

    async def add_ip(self, ip, mask):
        """Ensure we can add an IP to block.

        ``self.generate_add_msgs`` is a list of factories ``(ip, mask) -> message``, one per
        expected message. A single callable returning the whole list is also accepted.
        """
        factories = self.generate_add_msgs
        # setUp/local_setup define a *list* of factories; calling the list directly would raise
        # TypeError, so expand it into the concrete expected messages here.
        if callable(factories):
            expected_msgs = factories(ip, mask)
        else:
            expected_msgs = [factory(ip, mask) for factory in factories]

        async with get_communicators(self.actiontypes, self.should_match) as communicators:
            await self.api_create_entry(ip)

            # A list of that many function calls to verify the response
            get_message_func_calls = [
                self.get_messages(c, expected_msgs, should_match) for c, should_match in communicators
            ]

            # Turn our list into parameters to the function and await them all
            await gather(*get_message_func_calls)

            await self.ensure_no_more_msgs(communicators)

    async def ensure_no_more_msgs(self, communicators):
        """Run through all communicators and ensure they have no messages waiting."""
        get_nothing_func_calls = [self.get_nothings(c) for c, _ in communicators]

        # Ensure we don't receive any other messages
        await gather(*get_nothing_func_calls)

    # Django ensures that the create is synchronous, so we have some extra steps to do
    @sync_to_async
    def api_create_entry(self, route):
        """Ensure we can create an Entry via the API."""
        return self.client.post(
            self.url,
            {
                "route": route,
                "comment": "test",
                "uuid": self.uuid,
                "who": "Test User",
            },
            format="json",
        )

    async def test_add_v4(self):
        """Test adding a few v4 routes."""
        await self.add_ip("192.0.2.224", 32)
        await self.add_ip("192.0.2.225", 32)
        await self.add_ip("192.0.2.226", 32)
        await self.add_ip("198.51.100.224", 32)

    async def test_add_v6(self):
        """Test adding a few v6 routes."""
        await self.add_ip("2001:DB8:FDF0::", 128)
        await self.add_ip("2001:DB8:FDF0::D", 128)
        await self.add_ip("2001:DB8:FDF0::DB", 128)
        await self.add_ip("2001:DB8:FDF0::DB8", 128)
add_ip(ip, mask) async

Ensure we can add an IP to block.

Source code in scram/route_manager/tests/test_websockets.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
async def add_ip(self, ip, mask):
    """Ensure we can add an IP to block.

    ``self.generate_add_msgs`` is a list of factories ``(ip, mask) -> message``, one per
    expected message. A single callable returning the whole list is also accepted.
    """
    factories = self.generate_add_msgs
    # setUp/local_setup define a *list* of factories; calling the list directly would raise
    # TypeError, so expand it into the concrete expected messages here.
    if callable(factories):
        expected_msgs = factories(ip, mask)
    else:
        expected_msgs = [factory(ip, mask) for factory in factories]

    async with get_communicators(self.actiontypes, self.should_match) as communicators:
        await self.api_create_entry(ip)

        # A list of that many function calls to verify the response
        get_message_func_calls = [
            self.get_messages(c, expected_msgs, should_match) for c, should_match in communicators
        ]

        # Turn our list into parameters to the function and await them all
        await gather(*get_message_func_calls)

        await self.ensure_no_more_msgs(communicators)
api_create_entry(route)

Ensure we can create an Entry via the API.

Source code in scram/route_manager/tests/test_websockets.py
124
125
126
127
128
129
130
131
132
133
134
135
136
@sync_to_async
def api_create_entry(self, route):
    """POST a new Entry for ``route`` to the entry-list API endpoint and return the response."""
    payload = {
        "route": route,
        "comment": "test",
        "uuid": self.uuid,
        "who": "Test User",
    }
    return self.client.post(self.url, payload, format="json")
ensure_no_more_msgs(communicators) async

Run through all communicators and ensure they have no messages waiting.

Source code in scram/route_manager/tests/test_websockets.py
116
117
118
119
120
121
async def ensure_no_more_msgs(self, communicators):
    """Run ``get_nothings`` concurrently on every communicator."""
    # Each item is a (communicator, should_match) pair; only the communicator is needed here.
    pending_checks = [self.get_nothings(communicator) for communicator, _unused in communicators]
    await gather(*pending_checks)
get_messages(communicator, messages, should_match) async

Receive a number of messages from the WebSocket and validate them.

Source code in scram/route_manager/tests/test_websockets.py
89
90
91
92
93
94
async def get_messages(self, communicator, messages, should_match):
    """Read one JSON WebSocket message per expected entry and check it (mis)matches as required."""
    for expected in messages:
        raw_payload = await communicator.receive_from()
        received = json.loads(raw_payload)
        assert (received == expected) == should_match
get_nothings(communicator) async

Check there are no more messages waiting.

Source code in scram/route_manager/tests/test_websockets.py
96
97
98
async def get_nothings(self, communicator):
    """Assert that ``receive_nothing`` over a 0.1 s window (polled every 0.01 s) returns exactly ``False``."""
    # NOTE(review): asgiref's ApplicationCommunicator.receive_nothing() returns True when no
    # message arrived, so ``is False`` asserts that a message IS pending, which contradicts the
    # intent "no more messages waiting". Confirm the intended check before changing it.
    outcome = await communicator.receive_nothing(timeout=0.1, interval=0.01)
    assert outcome is False
local_setup()

Allow child classes to override this if desired.

Source code in scram/route_manager/tests/test_websockets.py
85
86
87
def local_setup(self):
    """Hook for subclasses to customise setup; the base implementation does nothing."""
    return None
setUp()

Set up our test environment.

Source code in scram/route_manager/tests/test_websockets.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def setUp(self):
    """Create the admin login, the 'block' actiontype, an authorized client and the default message wiring."""
    # TODO: This is copied from test_api; should de-dupe this.
    self.url = reverse("api:v1:entry-list")
    self.superuser = get_user_model().objects.create_superuser("admin", "admin@example.net", "admintestpassword")
    self.client.login(username="admin", password="admintestpassword")
    self.uuid = "0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3"

    self.action_name = "block"
    self.actiontype, _ = ActionType.objects.get_or_create(name=self.action_name)
    self.actiontype.save()

    self.authorized_client = Client.objects.create(
        hostname="authorized_client.example.net",
        uuid=self.uuid,
        is_authorized=True,
    )
    self.authorized_client.authorized_actiontypes.set([self.actiontype])

    # Wire a "translator_add" message to this actiontype for verb "A".
    add_message, _ = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route")
    WebSocketSequenceElement.objects.get_or_create(
        websocketmessage=add_message,
        verb="A",
        action_type=self.actiontype,
    )

    # Defaults that child classes may override in local_setup().
    self.actiontypes = ["block", "block", "block"]
    self.should_match = [True, True, True]
    self.generate_add_msgs = [lambda ip, mask: {"type": "translator_add", "message": {"route": f"{ip}/{mask}"}}]

    # Give child classes a chance to adjust the defaults.
    self.local_setup()
test_add_v4() async

Test adding a few v4 routes.

Source code in scram/route_manager/tests/test_websockets.py
138
139
140
141
142
143
async def test_add_v4(self):
    """Test adding a few v4 routes."""
    await self.add_ip("192.0.2.224", 32)
    await self.add_ip("192.0.2.225", 32)
    await self.add_ip("192.0.2.226", 32)
    await self.add_ip("198.51.100.224", 32)
test_add_v6() async

Test adding a few v6 routes.

Source code in scram/route_manager/tests/test_websockets.py
145
146
147
148
149
150
async def test_add_v6(self):
    """Test adding a few v6 routes."""
    await self.add_ip("2001:DB8:FDF0::", 128)
    await self.add_ip("2001:DB8:FDF0::D", 128)
    await self.add_ip("2001:DB8:FDF0::DB", 128)
    await self.add_ip("2001:DB8:FDF0::DB8", 128)

TranslatorDontCrossTheStreamsTestCase

Bases: TestTranslatorBaseCase

Two translators in one group, two in another group, single IP, ensure we get only the messages we expect.

Source code in scram/route_manager/tests/test_websockets.py
153
154
155
156
157
158
159
class TranslatorDontCrossTheStreamsTestCase(TestTranslatorBaseCase):
    """Two translators in one group, two in another group, single IP, ensure we get only the messages we expect."""

    def local_setup(self):
        """Two 'block' translators should see matching messages; two 'noop' translators should not."""
        self.actiontypes = ["block"] * 2 + ["noop"] * 2
        self.should_match = [True] * 2 + [False] * 2
local_setup()

Define the actions and what we expect.

Source code in scram/route_manager/tests/test_websockets.py
156
157
158
159
def local_setup(self):
    """Two 'block' translators should see matching messages; two 'noop' translators should not."""
    self.actiontypes = ["block"] * 2 + ["noop"] * 2
    self.should_match = [True] * 2 + [False] * 2

TranslatorParametersTestCase

Bases: TestTranslatorBaseCase

Additional parameters in the JSONField.

Source code in scram/route_manager/tests/test_websockets.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
class TranslatorParametersTestCase(TestTranslatorBaseCase):
    """Additional parameters in the JSONField."""

    def local_setup(self):
        """Attach extra msg_data to the base message and declare the messages we expect."""
        add_message = WebSocketMessage.objects.get(msg_type="translator_add", msg_data_route_field="route")
        # The "route" value here is a placeholder; the test expects it replaced by the real route.
        add_message.msg_data = {"asn": 65550, "community": 100, "route": "Ensure this gets overwritten."}
        add_message.save()

        # NOTE(review): only one WebSocketMessage is configured in this setup, yet two messages are
        # expected; confirm where the asn=64496 / community=4294967295 message comes from.
        self.generate_add_msgs = [
            lambda ip, mask: {
                "type": "translator_add",
                "message": {"asn": 65550, "community": 100, "route": f"{ip}/{mask}"},
            },
            lambda ip, mask: {
                "type": "translator_add",
                "message": {"asn": 64496, "community": 4294967295, "route": f"{ip}/{mask}"},
            },
        ]
local_setup()

Define the message we want to send.

Source code in scram/route_manager/tests/test_websockets.py
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
def local_setup(self):
    """Attach extra msg_data to the base message and declare the messages we expect."""
    add_message = WebSocketMessage.objects.get(msg_type="translator_add", msg_data_route_field="route")
    # The "route" value here is a placeholder; the test expects it replaced by the real route.
    add_message.msg_data = {"asn": 65550, "community": 100, "route": "Ensure this gets overwritten."}
    add_message.save()

    # NOTE(review): only one WebSocketMessage is configured in this setup, yet two messages are
    # expected; confirm where the asn=64496 / community=4294967295 message comes from.
    self.generate_add_msgs = [
        lambda ip, mask: {
            "type": "translator_add",
            "message": {"asn": 65550, "community": 100, "route": f"{ip}/{mask}"},
        },
        lambda ip, mask: {
            "type": "translator_add",
            "message": {"asn": 64496, "community": 4294967295, "route": f"{ip}/{mask}"},
        },
    ]

TranslatorSequenceTestCase

Bases: TestTranslatorBaseCase

Test a sequence of WebSocket messages.

Source code in scram/route_manager/tests/test_websockets.py
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
class TranslatorSequenceTestCase(TestTranslatorBaseCase):
    """Check that a multi-step WebSocket message sequence is delivered in order."""

    def local_setup(self):
        """Register two extra sequence steps and list the expected messages in order."""
        # (route field, order_num) for each additional translator_add step, created in this order.
        for route_field, position in (("foo", 20), ("bar", 2)):
            step_msg = WebSocketMessage.objects.create(msg_type="translator_add", msg_data_route_field=route_field)
            WebSocketSequenceElement.objects.create(
                websocketmessage=step_msg,
                verb="A",
                action_type=self.actiontype,
                order_num=position,
            )

        def expect_route_under(key):
            # Build a generator that places "<ip>/<mask>" under the given message key.
            return lambda ip, mask: {"type": "translator_add", "message": {key: f"{ip}/{mask}"}}

        # Ordered by each step's order_num: "route" (0), "bar" (2), "foo" (20).
        self.generate_add_msgs = [expect_route_under(key) for key in ("route", "bar", "foo")]
local_setup()

Define the messages we want to send.

Source code in scram/route_manager/tests/test_websockets.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def local_setup(self):
    """Register two extra sequence steps and list the expected messages in order."""
    # (route field, order_num) for each additional translator_add step, created in this order.
    for route_field, position in (("foo", 20), ("bar", 2)):
        step_msg = WebSocketMessage.objects.create(msg_type="translator_add", msg_data_route_field=route_field)
        WebSocketSequenceElement.objects.create(
            websocketmessage=step_msg,
            verb="A",
            action_type=self.actiontype,
            order_num=position,
        )

    def expect_route_under(key):
        # Build a generator that places "<ip>/<mask>" under the given message key.
        return lambda ip, mask: {"type": "translator_add", "message": {key: f"{ip}/{mask}"}}

    # Ordered by each step's order_num: "route" (0), "bar" (2), "foo" (20).
    self.generate_add_msgs = [expect_route_under(key) for key in ("route", "bar", "foo")]

get_communicators(actiontypes, should_match, *args, **kwds) async

Create a set of communicators, and then handle tear-down.

Given two lists of the same length (a list of action types and a list of boolean values), creates one communicator for each actiontype-bool pair.

The boolean indicates whether that communicator is expected to receive a message.

Yields the (communicator, should_match bool) pairs.

Source code in scram/route_manager/tests/test_websockets.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
@asynccontextmanager
async def get_communicators(actiontypes, should_match, *args, **kwds):
    """Create a set of communicators, and then handle tear-down.

    Given two lists of the same length, a list of actiontypes and a list of boolean values,
    creates one communicator per actiontype-bool pair and connects each one to
    ``/ws/route_manager/translator_<actiontype>/``.

    The boolean indicates whether that communicator is expected to receive a message.

    Yields a list of (communicator, should_match bool) pairs. On exit, every communicator that
    connected successfully is disconnected, even if connecting a later one failed.

    ``*args`` and ``**kwds`` are accepted but unused.
    """
    router = URLRouter(websocket_urlpatterns)
    communicators = [
        WebsocketCommunicator(router, f"/ws/route_manager/translator_{actiontype}/") for actiontype in actiontypes
    ]
    # Materialize the pairs: a bare zip() is a one-shot iterator, so the connect loop would
    # exhaust it and both the caller and the tear-down loop would then see nothing.
    # strict=True raises ValueError if the two lists differ in length.
    response = list(zip(communicators, should_match, strict=True))

    # Track what actually connected so tear-down only touches those communicators.
    connected_communicators = []
    try:
        for communicator, _ in response:
            connected, _ = await communicator.connect()
            assert connected
            connected_communicators.append(communicator)

        yield response

    finally:
        for communicator in connected_communicators:
            await communicator.disconnect()

urls

Register URLs known to Django, and the View that will handle each.

views

Define the Views that will handle the HTTP requests.

EntryDetailView

Bases: PermissionRequiredMixin, DetailView

Display a single Entry using the route_manager/entry_detail.html template; requires the route_manager.view_entry permission.

Source code in scram/route_manager/views.py
111
112
113
114
115
116
class EntryDetailView(PermissionRequiredMixin, DetailView):
    """Show one Entry via the ``route_manager/entry_detail.html`` template.

    Access is gated by PermissionRequiredMixin on the ``route_manager.view_entry`` permission.
    """

    model = Entry
    template_name = "route_manager/entry_detail.html"
    permission_required = ["route_manager.view_entry"]

EntryListView

Bases: ListView

List active Entries grouped by available ActionType, with a separate paginator for each group.

Source code in scram/route_manager/views.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
class EntryListView(ListView):
    """List active Entries grouped by available ActionType, with one paginator per ActionType.

    Each ActionType gets its own page query parameter (``page_<lowercased name>``), so the
    groups can be paged independently on the same page.
    """

    model = Entry
    template_name = "route_manager/entry_list.html"
    context_object_name = "object_list"
    paginate_by = settings.PAGINATION_SIZE

    def get_context_data(self, **kwargs):
        """Add action type grouping to context with separate paginators.

        Adds ``context["entries"]``, a dict keyed by ActionType (only those with
        ``available=True``). Each value holds:

        - ``total``: number of active entries of that type.
        - ``objs``: the Page being displayed.
        - ``page_param``: the query-string key for this type.
        - ``page_number``: the page number actually displayed.
        - ``current_page_params``: every ``page_*`` value from the request, so links can
          preserve the other groups' pages.
        """
        context = super().get_context_data(**kwargs)

        # Remember every group's current page so paging one group does not reset the others.
        current_page_params = {key: value for key, value in self.request.GET.items() if key.startswith("page_")}

        entries_by_type = {}

        # Get all available action types
        for at in ActionType.objects.filter(available=True):
            queryset = Entry.objects.filter(actiontype=at, is_active=True).order_by("-pk")

            # Create a paginator for this action type
            paginator = Paginator(queryset, settings.PAGINATION_SIZE)

            # Get page number from request with a unique parameter name per type
            page_param = f"page_{at.name.lower()}"
            try:
                page_obj = paginator.page(self.request.GET.get(page_param, 1))
            except (PageNotAnInteger, EmptyPage):
                # Malformed or out-of-range page numbers fall back to the first page.
                page_obj = paginator.page(1)

            entries_by_type[at] = {
                # Paginator.count is cached once page() has validated the number,
                # so this avoids issuing a second COUNT query.
                "total": paginator.count,
                "objs": page_obj,
                "page_param": page_param,
                # Report the page actually shown, not the raw (possibly invalid) request value.
                "page_number": page_obj.number,
                "current_page_params": current_page_params.copy(),
            }

        context["entries"] = entries_by_type
        return context

get_context_data(**kwargs)

Add action type grouping to context with separate paginators.

Source code in scram/route_manager/views.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
def get_context_data(self, **kwargs):
    """Add action type grouping to context with separate paginators.

    Adds ``context["entries"]``, a dict keyed by ActionType (only those with
    ``available=True``). Each value holds:

    - ``total``: number of active entries of that type.
    - ``objs``: the Page being displayed.
    - ``page_param``: the query-string key for this type.
    - ``page_number``: the page number actually displayed.
    - ``current_page_params``: every ``page_*`` value from the request, so links can
      preserve the other groups' pages.
    """
    context = super().get_context_data(**kwargs)

    # Remember every group's current page so paging one group does not reset the others.
    current_page_params = {key: value for key, value in self.request.GET.items() if key.startswith("page_")}

    entries_by_type = {}

    # Get all available action types
    for at in ActionType.objects.filter(available=True):
        queryset = Entry.objects.filter(actiontype=at, is_active=True).order_by("-pk")

        # Create a paginator for this action type
        paginator = Paginator(queryset, settings.PAGINATION_SIZE)

        # Get page number from request with a unique parameter name per type
        page_param = f"page_{at.name.lower()}"
        try:
            page_obj = paginator.page(self.request.GET.get(page_param, 1))
        except (PageNotAnInteger, EmptyPage):
            # Malformed or out-of-range page numbers fall back to the first page.
            page_obj = paginator.page(1)

        entries_by_type[at] = {
            # Paginator.count is cached once page() has validated the number,
            # so this avoids issuing a second COUNT query.
            "total": paginator.count,
            "objs": page_obj,
            "page_param": page_param,
            # Report the page actually shown, not the raw (possibly invalid) request value.
            "page_number": page_obj.number,
            "current_page_params": current_page_params.copy(),
        }

    context["entries"] = entries_by_type
    return context

add_entry(request)

Send a WebSocket message when adding a new entry.

Source code in scram/route_manager/views.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def _format_api_errors(data):
    """Flatten the payload of a 400 response from ``add_entry_api`` into display strings.

    Serializer errors (a ``ReturnDict`` of field -> error list) become ``"'field' error: msg"``.
    Other dict payloads become one ``"error: msg"`` per value. List payloads (which DRF returns
    when a ``ValidationError`` is raised outside a serializer) and scalar payloads are also handled.
    """
    if isinstance(data, rest_framework.utils.serializer_helpers.ReturnDict):
        errors = []
        for field, field_errors in data.items():
            # Wrap non-list values; a bare string would otherwise be iterated character by character.
            if not isinstance(field_errors, (list, tuple)):
                field_errors = [field_errors]
            errors.extend(f"'{field}' error: {error!s}" for error in field_errors)
        return errors
    if isinstance(data, dict):
        return [f"error: {value!s}" for value in data.values()]
    if isinstance(data, (list, tuple)):
        return [f"error: {value!s}" for value in data]
    return [f"error: {data!s}"]


@permission_required(["route_manager.view_entry", "route_manager.add_entry"])
def add_entry(request):
    """Add an Entry through ``add_entry_api`` and report the outcome to the user.

    The API call runs inside a transaction. Its status code is turned into a flash message:
    201 is success, 400 lists the validation errors, 403 is permission denied, and anything
    else is a warning. The user is then redirected to the home page.
    """
    with transaction.atomic():
        res = add_entry_api(request)

    if res.status_code == 201:  # noqa: PLR2004
        messages.add_message(
            request,
            messages.SUCCESS,
            "Entry successfully added",
        )
    elif res.status_code == 400:  # noqa: PLR2004
        # NOTE(review): joining with "<br>" implies the template renders this message as HTML;
        # if so, error text that echoes user input should be escaped — confirm.
        messages.add_message(request, messages.ERROR, "<br>".join(_format_api_errors(res.data)))
    elif res.status_code == 403:  # noqa: PLR2004
        messages.add_message(request, messages.ERROR, "Permission Denied")
    else:
        messages.add_message(request, messages.WARNING, f"Something went wrong: {res.status_code}")
    # NOTE(review): the rendered home page is discarded; the purpose of this call is not
    # visible here — confirm before removing.
    with transaction.atomic():
        home_page(request)
    return redirect("route_manager:home")

delete_entry(request, pk)

Wrap delete via the API and redirect to the home page.

Source code in scram/route_manager/views.py
103
104
105
106
107
108
@require_POST
@permission_required(["route_manager.view_entry", "route_manager.delete_entry"])
def delete_entry(request, pk):
    """Delete the Entry identified by ``pk`` through the API view, then return to the home page.

    The API response is not inspected; the user is always redirected home.
    """
    _ = delete_entry_api(request, pk)
    home_url_name = "route_manager:home"
    return redirect(home_url_name)

home_page(request, prefilter=None)

Return the home page, autocreating an admin user if no users exist and AUTOCREATE_ADMIN is enabled.

Source code in scram/route_manager/views.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def home_page(request, prefilter=None):
    """Return the home page, autocreating an admin user if none exists.

    Args:
        request: the incoming HttpRequest.
        prefilter: optional Entry queryset restricting what is shown (search passes one).
            When ``None``, all Entries are used.

    For each ActionType, the context holds the most recent active entries (up to
    ``settings.RECENT_LIMIT``) and the active count. ``readwrite`` reflects whether the user
    holds both the view and add permissions. If ``settings.AUTOCREATE_ADMIN`` is set and no
    users exist, an ``admin`` superuser with a random password is created and logged in, and
    the password is shown once via the messages framework.
    """
    # Compare against None explicitly. Truth-testing a QuerySet evaluates it (fetching every
    # row), and an empty search result would wrongly fall back to showing all Entries.
    if prefilter is None:
        prefilter = Entry.objects.all().select_related("actiontype", "route")
    num_entries = settings.RECENT_LIMIT
    readwrite = bool(request.user.has_perms(("route_manager.view_entry", "route_manager.add_entry")))
    context = {"entries": {}, "readwrite": readwrite}
    for at in ActionType.objects.all():
        queryset_active = prefilter.filter(actiontype=at, is_active=True).order_by("-pk")
        context["entries"][at] = {
            "objs": queryset_active[:num_entries],
            "active": queryset_active.count(),
        }

    if settings.AUTOCREATE_ADMIN:
        if User.objects.count() == 0:
            password = make_random_password(length=20)
            User.objects.create_superuser("admin", "admin@example.com", password)
            authenticated_admin = authenticate(request, username="admin", password=password)
            login(request, authenticated_admin)
            messages.add_message(
                request,
                messages.SUCCESS,
                f"An admin user was created for you. Please save this password: {password}",
            )
            messages.add_message(
                request,
                messages.INFO,
                "You have been logged in as the admin user",
            )

    return render(request, "route_manager/home.html", context)

process_updates(request)

For entries with an expiration, set them to inactive if expired.

Grab and announce any new entries added to the shared database by other SCRAM instances.

Return some simple stats.

Source code in scram/route_manager/views.py
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def process_updates(request):
    """Expire stale entries and re-announce entries recently added by other SCRAM instances.

    1. Every active Entry whose ``expiration`` is in the past is deleted via ``Entry.delete()``.
    2. Every Entry whose ``when`` falls within the last 2 minutes and whose
       ``originating_scram_instance`` is not this host is sent to the
       ``translator_<actiontype>`` channel group. One message is sent per
       WebSocketSequenceElement of that action type, in ``order_num`` order.

    Returns an ``application/json`` HttpResponse with the keys ``entries_deleted``,
    ``active_entries`` and ``remote_entries_added``.
    """
    logger.debug("Executing process_updates")
    # No explicit transaction here: atomicity is expected to come from ATOMIC_REQUESTS=True.
    now = timezone.now()
    active_before = Entry.objects.filter(is_active=True).count()

    logger.debug("Looking for expired entries")
    # Delete one row at a time so Entry.delete() runs; a bulk queryset delete would bypass it.
    expired = Entry.objects.filter(is_active=True, expiration__lt=now)
    for stale in expired:
        logger.info("Found expired entry: %s. Deleting now", stale)
        stale.delete()
    active_after = Entry.objects.filter(is_active=True).count()

    logger.debug("Looking for new entries from other SCRAM instances")
    recent_remote = Entry.objects.filter(when__gt=now - timedelta(minutes=2)).exclude(
        originating_scram_instance=settings.SCRAM_HOSTNAME,
    )

    # NOTE(review): there is no is_active filter here, so a recent remote entry that has already
    # been deactivated would be announced again — confirm this is intended.
    for remote in recent_remote:
        logger.info("Processing new entry: %s", remote)
        group_name = f"translator_{remote.actiontype}"
        sequence = WebSocketSequenceElement.objects.filter(action_type__name=remote.actiontype).order_by("order_num")
        for step in list(sequence):
            ws_msg = step.websocketmessage
            # Fill the configured route field of this step's payload with the entry's route.
            ws_msg.msg_data[ws_msg.msg_data_route_field] = str(remote.route)
            payload = {"type": ws_msg.msg_type, "message": ws_msg.msg_data}
            async_to_sync(channel_layer.group_send)(group_name, payload)

    stats = {
        "entries_deleted": active_before - active_after,
        "active_entries": active_after,
        "remote_entries_added": recent_remote.count(),
    }
    return HttpResponse(json.dumps(stats), content_type="application/json")

search_entries(request)

Render the home page restricted to Entries whose route falls within the submitted CIDR.

Source code in scram/route_manager/views.py
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def search_entries(request):
    """Render the home page restricted to Entries whose route lies within a submitted CIDR.

    Only POST is handled; any other method redirects to the home page. The ``cidr`` form field
    has surrounding whitespace trimmed and is then parsed non-strictly, so host bits may be set.
    If parsing fails, an error message is queued and the home page template is returned with
    status 400.
    """
    if request.method != "POST":
        return redirect("route_manager:home")

    # Using ipaddress because we needed to turn off strict mode
    # (which netfields uses by default with seemingly no toggle)
    # This caused searches with host bits set to 500 which is bad UX see: 68854ee1ad4789a62863083d521bddbc96ab7025
    # Strip first, because leading/trailing whitespace makes ipaddress reject otherwise valid input.
    # A missing field becomes the string "None", which fails to parse like any other bad value.
    raw_cidr = str(request.POST.get("cidr")).strip()
    try:
        addr = ipaddress.ip_network(raw_cidr, strict=False)
    except ValueError:
        messages.add_message(request, messages.ERROR, "Search query was not a valid CIDR address")

        # Send a 400, but show the home page instead of an error page
        return render(request, "route_manager/home.html", status=400)

    # We call home_page because search is just a more specific case of the same view and template to return.
    # select_related matches home_page's default prefilter and avoids per-row relation queries.
    return home_page(
        request,
        Entry.objects.filter(route__route__net_contained_or_equal=addr).select_related("actiontype", "route"),
    )