Skip to content

Commit

Permalink
Merge pull request rails#52807 from Shopify/async-associations-loading
Browse files Browse the repository at this point in the history
Add an internal API to trigger association loading asynchronously
  • Loading branch information
byroot committed Sep 5, 2024
2 parents f698d73 + fe92250 commit a5e1d2a
Show file tree
Hide file tree
Showing 12 changed files with 200 additions and 60 deletions.
30 changes: 25 additions & 5 deletions activerecord/lib/active_record/associations/association.rb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ module Associations
# the <tt>reflection</tt> object represents a <tt>:has_many</tt> macro.
class Association # :nodoc:
attr_accessor :owner
attr_reader :target, :reflection, :disable_joins
attr_reader :reflection, :disable_joins

delegate :options, to: :reflection

Expand All @@ -50,6 +50,13 @@ def initialize(owner, reflection)
@skip_strict_loading = nil
end

def target
if @target.is_a?(Promise)
@target = @target.value
end
@target
end

# Resets the \loaded flag to +false+ and sets the \target to +nil+.
def reset
@loaded = false
Expand Down Expand Up @@ -172,14 +179,21 @@ def extensions
# ActiveRecord::RecordNotFound is rescued within the method, and it is
# not reraised. The proxy is \reset and +nil+ is the return value.
def load_target
@target = find_target if (@stale_state && stale_target?) || find_target?
@target = find_target(async: false) if (@stale_state && stale_target?) || find_target?

loaded! unless loaded?
target
rescue ActiveRecord::RecordNotFound
reset
end

def async_load_target # :nodoc:
@target = find_target(async: true) if (@stale_state && stale_target?) || find_target?

loaded! unless loaded?
nil
end

# We can't dump @reflection and @through_reflection since it contains the scope proc
def marshal_dump
ivars = (instance_variables - [:@reflection, :@through_reflection]).map { |name| [name, instance_variable_get(name)] }
Expand Down Expand Up @@ -223,13 +237,19 @@ def ensure_klass_exists!
klass
end

def find_target
def find_target(async: false)
if violates_strict_loading?
Base.strict_loading_violation!(owner: owner.class, reflection: reflection)
end

scope = self.scope
return scope.to_a if skip_statement_cache?(scope)
if skip_statement_cache?(scope)
if async
return scope.load_async.then(&:to_a)
else
return scope.to_a
end
end

sc = reflection.association_scope_cache(klass, owner) do |params|
as = AssociationScope.create { params.bind }
Expand All @@ -238,7 +258,7 @@ def find_target

binds = AssociationScope.get_bind_values(owner, reflection.chain)
klass.with_connection do |c|
sc.execute(binds, c) do |record|
sc.execute(binds, c, async: async) do |record|
set_inverse_instance(record)
if owner.strict_loading_n_plus_one_only? && reflection.macro == :has_many
record.strict_loading!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ def delete_through_records(records)
end
end

def find_target
def find_target(async: false)
raise NotImplementedError, "No async loading for HasManyThroughAssociation yet" if async
return [] unless target_reflection_has_associated_record?
return scope.to_a if disable_joins
super
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def reader
def reset
super
@target = nil
@future_target = nil
end

# Implements the writer method, e.g. foo.bar= for Foo.belongs_to :bar
Expand All @@ -43,11 +44,15 @@ def scope_for_create
super.except!(*Array(klass.primary_key))
end

def find_target
def find_target(async: false)
if disable_joins
scope.first
if async
scope.load_async.then(&:first)
else
scope.first
end
else
super.first
super.then(&:first)
end
end

Expand Down
4 changes: 2 additions & 2 deletions activerecord/lib/active_record/core.rb
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,8 @@ def cached_find_by(keys, values)
where(wheres).limit(1)
}

begin
statement.execute(values.flatten, connection, allow_retry: true).first
statement.execute(values.flatten, connection, allow_retry: true).then do |r|
r.first
rescue TypeError
raise ActiveRecord::StatementInvalid
end
Expand Down
10 changes: 4 additions & 6 deletions activerecord/lib/active_record/querying.rb
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,10 @@ def find_by_sql(sql, binds = [], preparable: nil, allow_retry: false, &block)
end

# Same as <tt>#find_by_sql</tt> but perform the query asynchronously and returns an ActiveRecord::Promise.
def async_find_by_sql(sql, binds = [], preparable: nil, &block)
result = with_connection do |c|
_query_by_sql(c, sql, binds, preparable: preparable, async: true)
end

result.then do |result|
def async_find_by_sql(sql, binds = [], preparable: nil, allow_retry: false, &block)
with_connection do |c|
_query_by_sql(c, sql, binds, preparable: preparable, allow_retry: allow_retry, async: true)
end.then do |result|
_load_from_sql(result, &block)
end
end
Expand Down
10 changes: 10 additions & 0 deletions activerecord/lib/active_record/relation.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,16 @@ def load_async
self
end

def then(&block) # :nodoc:
if @future_result
@future_result.then do
yield self
end
else
super
end
end

# Returns <tt>true</tt> if the relation was scheduled on the background
# thread pool.
def scheduled?
Expand Down
11 changes: 7 additions & 4 deletions activerecord/lib/active_record/statement_cache.rb
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,17 @@ def initialize(query_builder, bind_map, model)
@model = model
end

def execute(params, connection, allow_retry: false, &block)
def execute(params, connection, allow_retry: false, async: false, &block)
bind_values = @bind_map.bind params

sql = @query_builder.sql_for bind_values, connection

@model.find_by_sql(sql, bind_values, preparable: true, allow_retry: allow_retry, &block)
if async
@model.async_find_by_sql(sql, bind_values, preparable: true, allow_retry: allow_retry, &block)
else
@model.find_by_sql(sql, bind_values, preparable: true, allow_retry: allow_retry, &block)
end
rescue ::RangeError
[]
async ? Promise.wrap([]) : []
end

def self.unsupported_value?(value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1839,3 +1839,37 @@ def test_destroy_linked_models
assert_not Author.exists?(author.id)
end
end

class AsyncBelongsToAssociationsTest < ActiveRecord::TestCase
include WaitForAsyncTestHelper

self.use_transactional_tests = false

fixtures :companies

unless in_memory_db?
def test_async_load_belongs_to
client = Client.find(3)
first_firm = companies(:first_firm)

client.association(:firm).async_load_target
wait_for_async_query

events = []
callback = -> (event) do
events << event unless event.payload[:name] == "SCHEMA"
end
ActiveSupport::Notifications.subscribed(callback, "sql.active_record") do
client.firm
end

assert_no_queries do
assert_equal first_firm, client.firm
assert_equal first_firm.name, client.firm.name
end

assert_equal 1, events.size
assert_equal true, events.first.payload[:async]
end
end
end
33 changes: 33 additions & 0 deletions activerecord/test/cases/associations/has_many_associations_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3252,3 +3252,36 @@ def force_signal37_to_load_all_clients_of_firm
companies(:first_firm).clients_of_firm.load_target
end
end

class AsyncHasOneAssociationsTest < ActiveRecord::TestCase
include WaitForAsyncTestHelper

self.use_transactional_tests = false

fixtures :companies

unless in_memory_db?
def test_async_load_has_many
firm = companies(:first_firm)

firm.association(:clients).async_load_target
wait_for_async_query

events = []
callback = -> (event) do
events << event unless event.payload[:name] == "SCHEMA"
end

ActiveSupport::Notifications.subscribed(callback, "sql.active_record") do
assert_equal 3, firm.clients.size
end

assert_no_queries do
assert_not_nil firm.clients[2]
end

assert_equal 1, events.size
assert_equal true, events.first.payload[:async]
end
end
end
34 changes: 34 additions & 0 deletions activerecord/test/cases/associations/has_one_associations_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -943,3 +943,37 @@ def test_has_one_with_touch_option_on_nonpersisted_built_associations_doesnt_upd
MESSAGE
end
end

class AsyncHasOneAssociationsTest < ActiveRecord::TestCase
include WaitForAsyncTestHelper

self.use_transactional_tests = false

fixtures :companies, :accounts

unless in_memory_db?
def test_async_load_has_one
firm = companies(:first_firm)
first_account = Account.find(1)

firm.association(:account).async_load_target
wait_for_async_query

events = []
callback = -> (event) do
events << event unless event.payload[:name] == "SCHEMA"
end
ActiveSupport::Notifications.subscribed(callback, "sql.active_record") do
firm.account
end

assert_no_queries do
assert_equal first_account, firm.account
assert_equal first_account.credit_limit, firm.account.credit_limit
end

assert_equal 1, events.size
assert_equal true, events.first.payload[:async]
end
end
end
65 changes: 41 additions & 24 deletions activerecord/test/cases/helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -45,36 +45,53 @@
ActiveRecord::ConnectionAdapters.register("abstract", "ActiveRecord::ConnectionAdapters::AbstractAdapter", "active_record/connection_adapters/abstract_adapter")
ActiveRecord::ConnectionAdapters.register("fake", "FakeActiveRecordAdapter", File.expand_path("../support/fake_adapter.rb", __dir__))

class SQLSubscriber
attr_reader :logged
attr_reader :payloads
class ActiveRecord::TestCase
class SQLSubscriber
attr_reader :logged
attr_reader :payloads

def initialize
@logged = []
@payloads = []
end

def start(name, id, payload)
@payloads << payload
@logged << [payload[:sql].squish, payload[:name], payload[:binds]]
end

def initialize
@logged = []
@payloads = []
def finish(name, id, payload); end
end

def start(name, id, payload)
@payloads << payload
@logged << [payload[:sql].squish, payload[:name], payload[:binds]]
module InTimeZone
private
def in_time_zone(zone)
old_zone = Time.zone
old_tz = ActiveRecord::Base.time_zone_aware_attributes

Time.zone = zone ? ActiveSupport::TimeZone[zone] : nil
ActiveRecord::Base.time_zone_aware_attributes = !zone.nil?
yield
ensure
Time.zone = old_zone
ActiveRecord::Base.time_zone_aware_attributes = old_tz
end
end

def finish(name, id, payload); end
end
module WaitForAsyncTestHelper
private
def wait_for_async_query(connection = ActiveRecord::Base.lease_connection, timeout: 5)
return unless connection.async_enabled?

module InTimeZone
private
def in_time_zone(zone)
old_zone = Time.zone
old_tz = ActiveRecord::Base.time_zone_aware_attributes

Time.zone = zone ? ActiveSupport::TimeZone[zone] : nil
ActiveRecord::Base.time_zone_aware_attributes = !zone.nil?
yield
ensure
Time.zone = old_zone
ActiveRecord::Base.time_zone_aware_attributes = old_tz
end
executor = connection.pool.async_executor
(timeout * 100).times do
return unless executor.scheduled_task_count > executor.completed_task_count
sleep 0.01
end

raise Timeout::Error, "The async executor wasn't drained after #{timeout} seconds"
end
end
end

# Encryption
Expand Down
15 changes: 0 additions & 15 deletions activerecord/test/cases/relation/load_async_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,6 @@
require "models/other_dog"

module ActiveRecord
module WaitForAsyncTestHelper
private
def wait_for_async_query(connection = ActiveRecord::Base.lease_connection, timeout: 5)
return unless connection.async_enabled?

executor = connection.pool.async_executor
(timeout * 100).times do
return unless executor.scheduled_task_count > executor.completed_task_count
sleep 0.01
end

raise Timeout::Error, "The async executor wasn't drained after #{timeout} seconds"
end
end

class LoadAsyncTest < ActiveRecord::TestCase
include WaitForAsyncTestHelper

Expand Down

0 comments on commit a5e1d2a

Please sign in to comment.