How to attach RelatedFactoryList result to instance? #1092

Open
albertalexandrov opened this issue Sep 10, 2024 · 6 comments
Comments

@albertalexandrov

albertalexandrov commented Sep 10, 2024

Hi!

I have a question about using RelatedFactoryList with async SQLAlchemy. RelatedFactoryList creates the related instances, but they are not attached to the parent instance.

This is the async base factory I use, with the overrides taken from discussions in this repository:

import inspect

from factory.alchemy import SESSION_PERSISTENCE_COMMIT, SESSION_PERSISTENCE_FLUSH, SQLAlchemyModelFactory
from factory.base import FactoryOptions
from factory.builder import StepBuilder, BuildStep, parse_declarations
from factory import FactoryError, RelatedFactoryList, CREATE_STRATEGY
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError, NoResultFound


def use_postgeneration_results(self, step, instance, results):
    return self.factory._after_postgeneration(
        instance,
        create=step.builder.strategy == CREATE_STRATEGY,
        results=results,
    )


FactoryOptions.use_postgeneration_results = use_postgeneration_results


class SQLAlchemyFactory(SQLAlchemyModelFactory):
    @classmethod
    async def _generate(cls, strategy, params):
        if cls._meta.abstract:
            raise FactoryError(
                "Cannot generate instances of abstract factory %(f)s; "
                "Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract "
                "is either not set or False." % dict(f=cls.__name__)
            )

        step = AsyncStepBuilder(cls._meta, params, strategy)
        return await step.build()

    @classmethod
    async def _create(cls, model_class, *args, **kwargs):
        for key, value in kwargs.items():
            if inspect.isawaitable(value):
                kwargs[key] = await value
        return await super()._create(model_class, *args, **kwargs)

    @classmethod
    async def create_batch(cls, size, **kwargs):
        return [await cls.create(**kwargs) for _ in range(size)]

    @classmethod
    async def _save(cls, model_class, session, args, kwargs):
        session_persistence = cls._meta.sqlalchemy_session_persistence
        obj = model_class(*args, **kwargs)
        session.add(obj)
        if session_persistence == SESSION_PERSISTENCE_FLUSH:
            await session.flush()
        elif session_persistence == SESSION_PERSISTENCE_COMMIT:
            await session.commit()
        return obj

    @classmethod
    async def _get_or_create(cls, model_class, session, args, kwargs):
        key_fields = {}
        for field in cls._meta.sqlalchemy_get_or_create:
            if field not in kwargs:
                raise FactoryError(
                    "sqlalchemy_get_or_create - "
                    "Unable to find initialization value for '%s' in factory %s" % (field, cls.__name__)
                )
            key_fields[field] = kwargs.pop(field)

        obj = (await session.execute(select(model_class).filter_by(*args, **key_fields))).scalars().one_or_none()

        if not obj:
            try:
                obj = await cls._save(model_class, session, args, {**key_fields, **kwargs})
            except IntegrityError as e:
                # rollback() is a coroutine on an async session, so it must be awaited
                await session.rollback()

                if cls._original_params is None:
                    raise e

                get_or_create_params = {
                    lookup: value
                    for lookup, value in cls._original_params.items()
                    if lookup in cls._meta.sqlalchemy_get_or_create
                }
                if get_or_create_params:
                    try:
                        obj = (
                            (await session.execute(select(model_class).filter_by(**get_or_create_params)))
                            .scalars()
                            .one()
                        )
                    except NoResultFound:
                        # Original params are not a valid lookup and triggered a create(),
                        # that resulted in an IntegrityError.
                        raise e
                else:
                    raise e

        return obj


class AsyncStepBuilder(StepBuilder):
    # Override build() to await instance creation and any awaitable post-generation results
    async def build(self, parent_step=None, force_sequence=None):
        """Build a factory instance."""
        # TODO: Handle "batch build" natively
        pre, post = parse_declarations(
            self.extras,
            base_pre=self.factory_meta.pre_declarations,
            base_post=self.factory_meta.post_declarations,
        )

        if force_sequence is not None:
            sequence = force_sequence
        elif self.force_init_sequence is not None:
            sequence = self.force_init_sequence
        else:
            sequence = self.factory_meta.next_sequence()

        step = BuildStep(
            builder=self,
            sequence=sequence,
            parent_step=parent_step,
        )
        step.resolve(pre)

        args, kwargs = self.factory_meta.prepare_arguments(step.attributes)

        instance = await self.factory_meta.instantiate(
            step=step,
            args=args,
            kwargs=kwargs,
        )
        postgen_results = {}
        for declaration_name in post.sorted():
            declaration = post[declaration_name]
            declaration_result = declaration.declaration.evaluate_post(
                instance=instance,
                step=step,
                overrides=declaration.context,
            )
            if inspect.isawaitable(declaration_result):
                declaration_result = await declaration_result
            if isinstance(declaration.declaration, RelatedFactoryList):
                for idx, item in enumerate(declaration_result):
                    if inspect.isawaitable(item):
                        declaration_result[idx] = await item
            postgen_results[declaration_name] = declaration_result
        postgen = self.factory_meta.use_postgeneration_results(
            instance=instance,
            step=step,
            results=postgen_results,
        )
        if inspect.isawaitable(postgen):
            await postgen
        return instance
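
The awaitable-resolution step used in `_create` and `AsyncStepBuilder.build` above can be illustrated in isolation. The following is a minimal sketch using only the standard library; `resolve_awaitables` and `make_parent` are hypothetical names introduced for illustration and are not part of factory_boy.

```python
import asyncio
import inspect


async def resolve_awaitables(kwargs):
    """Return a copy of kwargs with every awaitable value replaced by its result."""
    resolved = {}
    for key, value in kwargs.items():
        # Plain values pass through; coroutines (e.g. from an async sub-factory) are awaited.
        resolved[key] = await value if inspect.isawaitable(value) else value
    return resolved


async def make_parent():
    # Hypothetical stand-in for an async factory call that yields a parent object.
    await asyncio.sleep(0)
    return {"id": 1}


async def main():
    return await resolve_awaitables({"name": "TTZ 1", "ttz": make_parent()})


result = asyncio.run(main())
```

Building a new dict, rather than assigning into `kwargs` while iterating over it, is a design choice of this sketch; the original code only overwrites existing keys, which is also safe.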

models.py

class TtzFile(Base):
    """TTZ file model."""

    __tablename__ = "ttz_files"
    __mapper_args__ = {"eager_defaults": True}

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    ttz_id: Mapped[int] = mapped_column(ForeignKey("ttz.id"))
    attachment_id: Mapped[UUID] = mapped_column()
    ttz: Mapped["Ttz"] = relationship(back_populates="files")


class Ttz(Base):
    __tablename__ = "ttz"

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String(250))
    files: Mapped[list["TtzFile"]] = relationship(cascade="all, delete-orphan", back_populates="ttz")

factories.py

class TtzFactory(SQLAlchemyFactory):
    name = Sequence(lambda n: f"TTZ {n + 1}")
    start_date = FuzzyDate(parse_date("2024-02-23"))
    is_deleted = False
    output_message = None
    input_message = None
    error_output_message = None
    files = RelatedFactoryList("tests.factories.ttz.TtzFileFactory", 'ttz', 2)

    class Meta:
        model = Ttz
        sqlalchemy_get_or_create = ["name"]
        sqlalchemy_session_factory = Session
        sqlalchemy_session_persistence = SESSION_PERSISTENCE_FLUSH

    @classmethod
    def _after_postgeneration(cls, instance, create, results=None):
        session = cls._meta.sqlalchemy_session_factory()
        return session.refresh(instance, attribute_names=["files"])


class TtzFileFactory(SQLAlchemyFactory):
    ttz = SubFactory(TtzFactory)
    file_name = Faker("file_name")
    attachment_id = FuzzyUuid()

    class Meta:
        model = TtzFile
        sqlalchemy_get_or_create = ["attachment_id"]
        sqlalchemy_session_factory = Session
        sqlalchemy_session_persistence = SESSION_PERSISTENCE_FLUSH

To make Ttz.files accessible, I have to refresh the instance:

@classmethod
def _after_postgeneration(cls, instance, create, results=None):
    session = cls._meta.sqlalchemy_session_factory()
    return session.refresh(instance, attribute_names=["files"])

My question: is this the only way to get Ttz.files? In other words, do I have to write an _after_postgeneration method in every factory where I need the related list?
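
One way to avoid repeating the hook in every factory might be to move it into a shared mixin that reads the attribute names from a class attribute. This is only a sketch under assumptions: `RefreshRelationsMixin` and `_refresh_attrs` are hypothetical names, the session is obtained the same way as in the factories above, and the patched `use_postgeneration_results` shown earlier is assumed to await the returned coroutine. It has not been verified against factory_boy itself.

```python
class RefreshRelationsMixin:
    """Hypothetical mixin: refresh the listed relationship attributes after creation."""

    # Names of relationship attributes to reload. The leading underscore is
    # intended to keep factory_boy from treating this as a field declaration.
    _refresh_attrs = ()

    @classmethod
    async def _after_postgeneration(cls, instance, create, results=None):
        # Nothing to reload for build/stub strategies or when no attributes are listed.
        if not create or not cls._refresh_attrs:
            return
        # Same session lookup as in the factories above (assumed to return the shared session).
        session = cls._meta.sqlalchemy_session_factory()
        await session.refresh(instance, attribute_names=list(cls._refresh_attrs))
```

A factory would then presumably declare `class TtzFactory(RefreshRelationsMixin, SQLAlchemyFactory)` with `_refresh_attrs = ("files",)`. Alternatively, the same method could live directly on the shared `SQLAlchemyFactory` base class.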

@rbarrois
Member

Thanks for providing the full code example.

It is, however, quite complex to read without prior knowledge of your project.

By default, with a RelatedFactoryList, the behaviour is akin to:

ttz = Ttz(name="TTZ 1")
session.add(ttz)
for i in range(2):
  session.add(TtzFile(ttz=ttz, file_name="some_file_name", attachment_id=SomeUUID()))

How would you write that piece of code without factories in order to get the files attribute populated?

@albertalexandrov
Author

albertalexandrov commented Sep 12, 2024

Hi, @rbarrois!

I would write like this:

files = []

for i in range(2):
    file = TtzFile(file_name="some_file_name", attachment_id=SomeUUID())
    files.append(file) 

ttz = Ttz(name="TTZ 1", files=files)
session.add(ttz)

As far as I know, factory_boy first creates the main object and then the related list.

@rbarrois
Member

Your snippet wouldn't work, the ttz is not created beforehand!

However, if that's the way you'd write it, I suggest using a factory.List and a factory.SubFactory:

class FileFactory:
  ...

class TtzFactory:
  files = factory.List([
    factory.SubFactory(FileFactory),
    factory.SubFactory(FileFactory),
  ])

This might work, instantiating the two File objects before attaching them.

@albertalexandrov
Author

albertalexandrov commented Sep 12, 2024

There was a copy-paste mistake in my snippet. I have fixed it.

Does SubFactory(FileFactory) return a stub object? As you can see, TtzFile cannot be created without a Ttz.

Sorry, I can't check it right now because I don't have access to my computer. I'll try in a week.

@rbarrois
Member

Thanks! Can you try the approach I suggested above, i.e. a list of subfactories instead of a RelatedFactoryList?

@albertalexandrov
Author

I'll try it in about a week, when I'm back at my computer. Thanks.
