diff --git a/invenio_accounts/cli.py b/invenio_accounts/cli.py index ff506d2c..fde055e6 100644 --- a/invenio_accounts/cli.py +++ b/invenio_accounts/cli.py @@ -3,6 +3,7 @@ # This file is part of Invenio. # Copyright (C) 2015-2023 CERN. # Copyright (C) 2024 Graz University of Technology. +# Copyright (C) 2024 KTH Royal Institute of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -170,8 +171,14 @@ def users_deactivate(user): @with_appcontext def domains_create(domain): """Create domain.""" + domain = domain.lower() + + if DomainCategory.get(domain): + click.secho(f"Domain {domain} already exists.", fg="red") + return + try: - domain_category = DomainCategory.create(domain.lower()) + domain_category = DomainCategory.create(domain) db.session.merge(domain_category) db.session.commit() except Exception as error: diff --git a/invenio_accounts/models.py b/invenio_accounts/models.py index 7a411567..14163960 100644 --- a/invenio_accounts/models.py +++ b/invenio_accounts/models.py @@ -2,7 +2,7 @@ # # This file is part of Invenio. # Copyright (C) 2015-2024 CERN. -# Copyright (C) 2022 KTH Royal Institute of Technology +# Copyright (C) 2022-2024 KTH Royal Institute of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -526,7 +526,7 @@ class DomainCategory(db.Model): id = db.Column(db.Integer(), primary_key=True, autoincrement=True) - label = db.Column(db.String(255)) + label = db.Column(db.String(255), unique=True) @classmethod def create(cls, label): diff --git a/tests/test_cli.py b/tests/test_cli.py index f1d3dc3b..ae130dce 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,6 +2,7 @@ # # This file is part of Invenio. # Copyright (C) 2015-2018 CERN. +# Copyright (C) 2024 KTH Royal Institute of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -10,6 +11,7 @@ """Module tests.""" from invenio_accounts.cli import ( + domains_create, roles_add, roles_create, roles_remove, @@ -156,3 +158,28 @@ def test_cli_activate_deactivate(app): assert result.exit_code == 0 result = runner.invoke(users_deactivate, ["a@test.org"]) assert result.exit_code == 0 + + +def test_cli_createdomain(app): + """Test create domain CLI.""" + runner = app.test_cli_runner() + + # Create a domain successfully + result = runner.invoke(domains_create, ["mailprovider"]) + assert result.exit_code == 0 + assert "Domain mailprovider created successfully" in result.output + + # Reject Creating the same domain again + result = runner.invoke(domains_create, ["mailprovider"]) + assert result.exit_code == 0 + assert "Domain mailprovider already exists." in result.output + + # Create another domain successfully + result = runner.invoke(domains_create, ["kth.se"]) + assert result.exit_code == 0 + assert "Domain kth.se created successfully" in result.output + + # Create a domain with a fancy case should be treated like others + result = runner.invoke(domains_create, ["MailPrOvIdEr"]) + assert result.exit_code == 0 + assert "Domain mailprovider already exists." in result.output diff --git a/tests/test_models.py b/tests/test_models.py index b5260a21..5bd5f660 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -3,7 +3,7 @@ # This file is part of Invenio. # Copyright (C) 2015-2024 CERN. # Copyright (C) 2022 TU Wien. -# Copyright (C) 2022 KTH Royal Institute of Technology +# Copyright (C) 2022-2024 KTH Royal Institute of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -14,6 +14,7 @@ from invenio_db import db from marshmallow import Schema, fields from sqlalchemy import inspect +from sqlalchemy.exc import IntegrityError from invenio_accounts import testutils from invenio_accounts.models import ( @@ -213,9 +214,25 @@ def test_domain_org(app): def test_domain_category(app): + """Test DomainCategory creation and retrieval.""" c1 = DomainCategory.create("spammer") c2 = DomainCategory.create("organisation") db.session.commit() c = DomainCategory.get("spammer") assert c.label == "spammer" + + # Try to create a duplicate category + with pytest.raises(IntegrityError): + duplicate_category = DomainCategory.create("spammer") + db.session.commit() + # Clean the state after the IntegrityError + db.session.rollback() + + # Assert a valid category can still be created after the error + c3 = DomainCategory.create("company") + db.session.commit() + + # Make sure the new category is correctly stored + c = DomainCategory.get("company") + assert c.label == "company"