Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AfterEagerFind() #802

Merged
merged 1 commit into from
Jan 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,24 @@ type AfterFindable interface {
AfterFind(*Connection) error
}

func (m *Model) afterFind(c *Connection) error {
if x, ok := m.Value.(AfterFindable); ok {
// AfterEagerFindable callback will be called after a record, or records,
// has been retrieved from the database and their associations have been
// eagerly loaded.
type AfterEagerFindable interface {
AfterEagerFind(*Connection) error
}

func (m *Model) afterFind(c *Connection, eager bool) error {
if x, ok := m.Value.(AfterFindable); ok && !eager {
if err := x.AfterFind(c); err != nil {
return err
}
}
if x, ok := m.Value.(AfterEagerFindable); ok && eager {
if err := x.AfterEagerFind(c); err != nil {
return err
}
}
Copy link
Member

@sio4 sio4 Jan 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be tiny. Since those two blocks are exclusive to each other, I would prefer to rewrite the block for readability as below. What do you think? (same for blocks in line 49)

	if eager {
		if x, ok := m.Value.(AfterFindable); ok {
			if err := x.AfterFind(c); err != nil {
				return err
			}
		}
	} else {
		if x, ok := m.Value.(AfterEagerFindable); ok {
			if err := x.AfterEagerFind(c); err != nil {
				return err
			}
		}
	}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically just a habit of keeping nesting flat - would be fine either way :)


// if the "model" is a slice/array we want
// to loop through each of the elements in the collection
Expand All @@ -34,9 +46,13 @@ func (m *Model) afterFind(c *Connection) error {
wg.Go(func() error {
y := rv.Index(i)
y = y.Addr()
if x, ok := y.Interface().(AfterFindable); ok {
if x, ok := y.Interface().(AfterFindable); ok && !eager {
return x.AfterFind(c)
}

if x, ok := y.Interface().(AfterEagerFindable); ok && eager {
return x.AfterEagerFind(c)
}
return nil
})
}(i)
Expand Down
13 changes: 11 additions & 2 deletions callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ func Test_Callbacks(t *testing.T) {
r.Equal("AF", user.AfterF)
r.NoError(tx.Find(user, user.ID))
r.Equal("AfterFind", user.AfterF)
r.Empty(user.AfterEF)

r.NoError(tx.Eager().Find(user, user.ID))
r.Equal("AfterEagerFind", user.AfterEF)

r.NoError(tx.Destroy(user))

Expand All @@ -70,11 +74,16 @@ func Test_Callbacks_on_Slice(t *testing.T) {

users := CallbacksUsers{}
r.NoError(tx.All(&users))

r.Len(users, 2)

for _, u := range users {
r.Equal("AfterFind", u.AfterF)
r.Empty(u.AfterEF)
}

r.NoError(tx.Eager().All(&users))
r.Len(users, 2)
for _, u := range users {
r.Equal("AfterEagerFind", u.AfterEF)
}
})
}
37 changes: 26 additions & 11 deletions finders.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,29 @@ func (c *Connection) First(model interface{}) error {
//
// q.Where("name = ?", "mark").First(&User{})
func (q *Query) First(model interface{}) error {
var m *Model
err := q.Connection.timeFunc("First", func() error {
q.Limit(1)
m := NewModel(model, q.Connection.Context())
m = NewModel(model, q.Connection.Context())
if err := q.Connection.Dialect.SelectOne(q.Connection, m, *q); err != nil {
return err
}
return m.afterFind(q.Connection)
return m.afterFind(q.Connection, false)
})

if err != nil {
return err
}

if q.eager {
err = q.eagerAssociations(model)
err := q.eagerAssociations(model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious if there is a specific reason for adding ':' here. If there is a specific reason, an inline comment could be useful.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's mostly a coding habit on my end, also fine either way

q.disableEager()
return err
if err != nil {
return err
}
return m.afterFind(q.Connection, true)
}

return nil
}

Expand All @@ -98,14 +103,15 @@ func (c *Connection) Last(model interface{}) error {
//
// q.Where("name = ?", "mark").Last(&User{})
func (q *Query) Last(model interface{}) error {
var m *Model
err := q.Connection.timeFunc("Last", func() error {
q.Limit(1)
q.Order("created_at DESC, id DESC")
m := NewModel(model, q.Connection.Context())
m = NewModel(model, q.Connection.Context())
if err := q.Connection.Dialect.SelectOne(q.Connection, m, *q); err != nil {
return err
}
return m.afterFind(q.Connection)
return m.afterFind(q.Connection, false)
})

if err != nil {
Expand All @@ -115,7 +121,10 @@ func (q *Query) Last(model interface{}) error {
if q.eager {
err = q.eagerAssociations(model)
q.disableEager()
return err
if err != nil {
return err
}
return m.afterFind(q.Connection, true)
}

return nil
Expand All @@ -132,17 +141,20 @@ func (c *Connection) All(models interface{}) error {
//
// q.Where("name = ?", "mark").All(&[]User{})
func (q *Query) All(models interface{}) error {
var m *Model
err := q.Connection.timeFunc("All", func() error {
m := NewModel(models, q.Connection.Context())
m = NewModel(models, q.Connection.Context())
err := q.Connection.Dialect.SelectMany(q.Connection, m, *q)
if err != nil {
return err
}

err = q.paginateModel(models)
if err != nil {
return err
}
return m.afterFind(q.Connection)

return m.afterFind(q.Connection, false)
})

if err != nil {
Expand All @@ -152,7 +164,10 @@ func (q *Query) All(models interface{}) error {
if q.eager {
err = q.eagerAssociations(models)
q.disableEager()
return err
if err != nil {
return err
}
return m.afterFind(q.Connection, true)
}

return nil
Expand Down Expand Up @@ -301,7 +316,7 @@ func (q *Query) eagerDefaultAssociations(model interface{}) error {
// Exists returns true/false if a record exists in the database that matches
// the query.
//
// q.Where("name = ?", "mark").Exists(&User{})
// q.Where("name = ?", "mark").Exists(&User{})
func (q *Query) Exists(model interface{}) (bool, error) {
tmpQuery := Q(q.Connection)
q.Clone(tmpQuery) // avoid meddling with original query
Expand Down
6 changes: 6 additions & 0 deletions pop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ type CallbacksUser struct {
AfterU string `db:"after_u"`
AfterD string `db:"after_d"`
AfterF string `db:"after_f"`
AfterEF string `db:"after_ef"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
Expand Down Expand Up @@ -420,6 +421,11 @@ func (u *CallbacksUser) AfterFind(tx *Connection) error {
return nil
}

func (u *CallbacksUser) AfterEagerFind(tx *Connection) error {
u.AfterEF = "AfterEagerFind"
return nil
}

type Label struct {
ID string `db:"id"`
}
Expand Down
1 change: 1 addition & 0 deletions testdata/migrations/20181104135800_callbacks_users.up.fizz
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ create_table("callbacks_users") {
t.Column("after_u", "string", {})
t.Column("after_d", "string", {})
t.Column("after_f", "string", {})
t.Column("after_ef", "string", {})
t.Column("before_v", "string", {})
t.Timestamps()
}