Skip to content

Commit

Permalink
Merge pull request #50 from City-of-Turku/feature/subscribe-allow-any
Browse files Browse the repository at this point in the history
Feature/subscribe allow any
  • Loading branch information
juuso-j committed Apr 11, 2024
2 parents 4c1bf66 + 4d95cc9 commit a5726eb
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 45 deletions.
26 changes: 18 additions & 8 deletions account/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@
from rest_framework.response import Response
from rest_framework.throttling import AnonRateThrottle

from account.models import MailingList, MailingListEmail, Profile
from account.models import MailingList, MailingListEmail, Profile, User
from profiles.api.views import update_postal_code_result
from profiles.models import Result

from .serializers import ProfileSerializer, SubscribeSerializer, UnSubscribeSerializer

all_views = []


class UnsubscribeRateThrottle(AnonRateThrottle):
class MailRateThrottle(AnonRateThrottle):
"""
The AnonRateThrottle will only ever throttle unauthenticated users.
The IP address of the incoming request is used to generate a unique key to throttle against.
Expand All @@ -32,7 +31,7 @@ class ProfileViewSet(UpdateModelMixin, viewsets.GenericViewSet):
serializer_class = ProfileSerializer

def get_permissions(self):
if self.action in ["unsubscribe"]:
if self.action in ["unsubscribe", "subscribe"]:
permission_classes = [AllowAny]
else:
permission_classes = [IsAuthenticated]
Expand Down Expand Up @@ -74,11 +73,20 @@ def update(self, request, *args, **kwargs):
@action(
detail=False,
methods=["POST"],
permission_classes=[IsAuthenticated],
permission_classes=[AllowAny],
throttle_classes=[MailRateThrottle],
)
@db.transaction.atomic
def subscribe(self, request):
result = Result.objects.filter(id=request.data.get("result", None)).first()
user = User.objects.filter(id=request.data.get("user", None)).first()
if not user:
return Response("Invalid request", status=status.HTTP_400_BAD_REQUEST)
if user.has_subscribed:
return Response(
"The user has already subscribed", status=status.HTTP_400_BAD_REQUEST
)

result = user.result
if not result:
return Response("'result' not found", status=status.HTTP_400_BAD_REQUEST)

Expand All @@ -97,11 +105,13 @@ def subscribe(self, request):
mailing_list = MailingList.objects.create(result=result)

MailingListEmail.objects.create(mailing_list=mailing_list, email=email)
user.has_subscribed = True
user.save()
return Response("subscribed", status=status.HTTP_201_CREATED)

@extend_schema(
description="Unaubscribe the email from the mailing list attached to the result."
f"Note, there is a rate-limit of {UnsubscribeRateThrottle.rate} requests.",
f"Note, there is a rate-limit of {MailRateThrottle.rate} requests.",
request=UnSubscribeSerializer,
responses={
200: OpenApiResponse(description="unsubscribed"),
Expand All @@ -114,7 +124,7 @@ def subscribe(self, request):
detail=False,
methods=["POST"],
permission_classes=[AllowAny],
throttle_classes=[UnsubscribeRateThrottle],
throttle_classes=[MailRateThrottle],
)
def unsubscribe(self, request):
email = request.data.get("email", None)
Expand Down
18 changes: 18 additions & 0 deletions account/migrations/0013_user_has_subscribed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.1.13 on 2024-04-10 15:59

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("account", "0012_profile_is_interested_in_mobility"),
]

operations = [
migrations.AddField(
model_name="user",
name="has_subscribed",
field=models.BooleanField(default=False),
),
]
1 change: 1 addition & 0 deletions account/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class User(AbstractUser):
email_verified = models.BooleanField(default=False)
is_active = models.BooleanField(default=True)
is_generated = models.BooleanField(default=False)
has_subscribed = models.BooleanField(default=False)
# Flag that is used to ensure the user is only Once calculated to the PostalCodeResults model.
postal_code_result_saved = models.BooleanField(default=False)

Expand Down
10 changes: 7 additions & 3 deletions account/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ def api_client_with_custom_ip_address(ip_address):

@pytest.fixture
def users(results):
User.objects.create(username="test1", result=results.first())
[
User.objects.create(username=f"test{i}", result=results[i % results.count()])
for i in range(20)
]
return User.objects.all()


Expand All @@ -53,9 +56,10 @@ def mailing_lists(results):

@pytest.fixture
def mailing_list_emails(mailing_lists):
for c in range(20):
[
MailingListEmail.objects.create(
email=f"test_{c}@test.com", mailing_list=mailing_lists.first()
)

for c in range(20)
]
return MailingListEmail.objects.all()
121 changes: 89 additions & 32 deletions account/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_unauthenticated_cannot_do_anything(api_client, users):
@pytest.mark.parametrize(
"ip_address",
[
("192.168.1.40"),
("199.168.1.40"),
],
)
def test_mailing_list_unsubscribe_throttling(
Expand Down Expand Up @@ -227,46 +227,80 @@ def test_profile_patch_result_can_be_used(api_client_authenticated, users, profi


@pytest.mark.django_db
def test_mailing_list_unauthenticated_subscribe(api_client, results):
def test_mailing_list_subscribe(api_client, users, results):
url = reverse("account:profiles-subscribe")
response = api_client.post(
url, {"email": "[email protected]", "result": results.first().id}
user = users.get(username="test1")
response = api_client.post(url, {"email": "[email protected]", "user": user.id})
assert response.status_code == 201
assert MailingListEmail.objects.count() == 1
assert MailingListEmail.objects.first().email == "[email protected]"
assert (
MailingList.objects.filter(result=user.result).first().emails.first()
== MailingListEmail.objects.first()
)
assert response.status_code == 401
assert MailingList.objects.first().result == user.result


@pytest.mark.django_db
def test_mailing_list_subscribe(
api_client_authenticated, users, results, mailing_lists
):
def test_mailing_user_has_subscribed(api_client, users, results, mailing_lists):
url = reverse("account:profiles-subscribe")
response = api_client_authenticated.post(
url, {"email": "[email protected]", "result": results.first().id}
response = api_client.post(
url, {"email": "[email protected]", "user": users.first().id}
)
assert response.status_code == 201
assert MailingListEmail.objects.count() == 1
assert MailingListEmail.objects.first().email == "[email protected]"
assert (
MailingList.objects.first().emails.first() == MailingListEmail.objects.first()
response = api_client.post(
url, {"email": "[email protected]", "user": users.first().id}
)
assert response.status_code == 400


@pytest.mark.django_db
@pytest.mark.parametrize(
"ip_address",
[
("92.68.21.220"),
],
)
def test_mailing_list_subscribe_throttling(
api_client_with_custom_ip_address, mailing_list_emails, users
):
num_requests = int(
ProfileViewSet.subscribe.kwargs["throttle_classes"][0].rate.split("/")[0]
)
url = reverse("account:profiles-subscribe")
count = 0
while count < num_requests:
response = api_client_with_custom_ip_address.post(
url,
{
"email": f"throttlling_test_{count}@test.com",
"user": users[count].id,
},
)
assert response.status_code == 201
count += 1

time.sleep(2)
response = api_client_with_custom_ip_address.post(
url, {"email": f"test_{count}@test.com", "user": users[count].id}
)
assert response.status_code == 429


@pytest.mark.django_db
def test_mailing_list_is_created_on_subscribe(api_client_authenticated, users, results):
def test_mailing_list_is_created_on_subscribe(api_client, users, results):
assert MailingList.objects.count() == 0
url = reverse("account:profiles-subscribe")
response = api_client_authenticated.post(
url, {"email": "[email protected]", "result": results.first().id}
response = api_client.post(
url, {"email": "[email protected]", "user": users.first().id}
)
assert response.status_code == 201
assert MailingList.objects.count() == 1
assert MailingListEmail.objects.count() == 1


@pytest.mark.django_db
def test_mailing_list_subscribe_with_invalid_emails(
api_client_authenticated, users, results
):
def test_mailing_list_subscribe_with_invalid_emails(api_client, users, results):
assert MailingList.objects.count() == 0
url = reverse("account:profiles-subscribe")
for email in [
Expand All @@ -276,9 +310,7 @@ def test_mailing_list_subscribe_with_invalid_emails(
"john.doe@example",
"john.doe@example",
]:
response = api_client_authenticated.post(
url, {"email": email, "result": results.first().id}
)
response = api_client.post(url, {"email": email, "user": users.first().id})
assert response.status_code == 400
assert MailingList.objects.count() == 0
assert MailingList.objects.count() == 0
Expand All @@ -290,43 +322,68 @@ def test_mailing_list_subscribe_with_invalid_post_data(
):
url = reverse("account:profiles-subscribe")
# Missing email
response = api_client_authenticated.post(url, {"result": results.first().id})
response = api_client_authenticated.post(url, {"user": users.first().id})
assert response.status_code == 400
assert MailingList.objects.count() == 0
assert MailingList.objects.count() == 0
# Missing result
response = api_client_authenticated.post(url, {"email": "[email protected]"})
assert response.status_code == 400
assert MailingList.objects.count() == 0
assert MailingList.objects.count() == 0


@pytest.mark.django_db
def test_mailing_list_unsubscribe(api_client, mailing_list_emails):
@pytest.mark.parametrize(
"ip_address",
[
("100.1.1.40"),
],
)
def test_mailing_list_unsubscribe(
api_client_with_custom_ip_address, mailing_list_emails
):
num_mailing_list_emails = mailing_list_emails.count()
assert MailingListEmail.objects.count() == num_mailing_list_emails
assert MailingList.objects.first().emails.count() == num_mailing_list_emails
url = reverse("account:profiles-unsubscribe")
response = api_client.post(url, {"email": "[email protected]"})
response = api_client_with_custom_ip_address.post(url, {"email": "[email protected]"})
assert response.status_code == 200
assert MailingListEmail.objects.count() == num_mailing_list_emails - 1
assert MailingList.objects.first().emails.count() == num_mailing_list_emails - 1


@pytest.mark.django_db
def test_mailing_list_unsubscribe_non_existing_email(api_client, mailing_list_emails):
@pytest.mark.django_db
@pytest.mark.parametrize(
"ip_address",
[
("101.1.1.40"),
],
)
def test_mailing_list_unsubscribe_non_existing_email(
api_client_with_custom_ip_address, mailing_list_emails
):
num_mailing_list_emails = mailing_list_emails.count()
assert MailingListEmail.objects.count() == num_mailing_list_emails
assert MailingList.objects.first().emails.count() == num_mailing_list_emails
url = reverse("account:profiles-unsubscribe")
response = api_client.post(url, {"email": "[email protected]"})
response = api_client_with_custom_ip_address.post(
url, {"email": "[email protected]"}
)
assert response.status_code == 400
assert MailingListEmail.objects.count() == num_mailing_list_emails
assert MailingList.objects.first().emails.count() == num_mailing_list_emails


@pytest.mark.django_db
def test_mailing_list_unsubscribe_email_not_provided(api_client, mailing_list_emails):
@pytest.mark.parametrize(
"ip_address",
[
("12.6.121.22"),
],
)
def test_mailing_list_unsubscribe_email_not_provided(
api_client_with_custom_ip_address, mailing_list_emails
):
url = reverse("account:profiles-unsubscribe")
response = api_client.post(url)
response = api_client_with_custom_ip_address.post(url)
assert response.status_code == 400
10 changes: 8 additions & 2 deletions profiles/tests/api/test_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@ def test_answer_post_unauthenticated(api_client):


@pytest.mark.django_db
def test_start_poll(api_client):
@pytest.mark.parametrize(
"ip_address",
[
("28.18.23.111"),
],
)
def test_start_poll(api_client_with_custom_ip_address):
User.objects.all().count() == 0
url = reverse("profiles:question-start-poll")
response = api_client.post(url)
response = api_client_with_custom_ip_address.post(url)
assert response.status_code == 200
assert User.objects.all().count() == 1

Expand Down

0 comments on commit a5726eb

Please sign in to comment.