From 7e5fb0b3bd759ce80f4ce4341b92261fa21630b1 Mon Sep 17 00:00:00 2001 From: Ernie Miller Date: Wed, 28 Apr 2010 23:14:05 -0400 Subject: [PATCH] Fix eager loading of associations causing table name collisions --- activerecord/lib/active_record/associations.rb | 81 +++++++++++++++----- .../lib/active_record/relation/finder_methods.rb | 1 - .../lib/active_record/relation/query_methods.rb | 83 +++++++++++--------- .../associations/cascaded_eager_loading_test.rb | 11 +++- 4 files changed, 119 insertions(+), 57 deletions(-) diff --git a/activerecord/lib/active_record/associations.rb b/activerecord/lib/active_record/associations.rb index 6dbee9f..f896277 100755 --- a/activerecord/lib/active_record/associations.rb +++ b/activerecord/lib/active_record/associations.rb @@ -1736,6 +1736,14 @@ module ActiveRecord @table_aliases[base.table_name] = 1 build(associations) end + + def graft(*associations) + associations.each do |association| + join_associations.detect {|a| association == a} || + build(association.reflection.name, association.find_parent_in(self), association.join_class) + end + self + end def join_associations @joins[1..-1].to_a @@ -1744,6 +1752,16 @@ module ActiveRecord def join_base @joins[0] end + + def count_aliases_from_table_joins(name) + quoted_name = join_base.active_record.connection.quote_table_name(name.downcase) + join_sql = join_base.table_joins.to_s.downcase + join_sql.blank? ? 0 : + # Table names + join_sql.scan(/join(?:\s+\w+)?\s+#{quoted_name}\son/).size + + # Table aliases + join_sql.scan(/join(?:\s+\w+)?\s+\S+\s+#{quoted_name}\son/).size + end def instantiate(rows) rows.each_with_index do |row, i| @@ -1789,22 +1807,22 @@ module ActiveRecord end protected - def build(associations, parent = nil) + def build(associations, parent = nil, join_class = Arel::InnerJoin) parent ||= @joins.last case associations when Symbol, String reflection = parent.reflections[associations.to_s.intern] or raise ConfigurationError, "Association named '#{ associations }' was not found; perhaps you misspelled it?" @reflections << reflection - @joins << build_join_association(reflection, parent) + @joins << build_join_association(reflection, parent).with_join_class(join_class) when Array associations.each do |association| - build(association, parent) + build(association, parent, join_class) end when Hash associations.keys.sort{|a,b|a.to_s<=>b.to_s}.each do |name| - build(name, parent) - build(associations[name]) + build(name, parent, join_class) + build(associations[name], nil, join_class) end else raise ConfigurationError, associations.inspect @@ -1880,6 +1898,12 @@ module ActiveRecord @cached_record = {} @table_joins = joins end + + def ==(other) + other.is_a?(JoinBase) && + other.active_record == active_record && + other.table_joins == table_joins + end def aliased_prefix "t0" @@ -1946,6 +1970,27 @@ module ActiveRecord end end + def ==(other) + other.is_a?(JoinAssociation) && + other.reflection == reflection && + other.parent == parent + end + + def find_parent_in(other_join_dependency) + other_join_dependency.joins.detect do |join| + self.parent == join + end + end + + def join_class + @join_class ||= Arel::InnerJoin + end + + def with_join_class(join_class) + @join_class = join_class + self + end + def association_join return @join if @join @@ -2045,31 +2090,29 @@ module ActiveRecord end def join_relation(joining_relation, join = nil) - if (relations = relation).is_a?(Array) - joining_relation.joins(Relation::JoinOperation.new(relations.first, Arel::OuterJoin, association_join.first)). - joins(Relation::JoinOperation.new(relations.last, Arel::OuterJoin, association_join.last)) - else - joining_relation.joins(Relation::JoinOperation.new(relations, Arel::OuterJoin, association_join)) - end + joining_relation.joins(self.with_join_class(Arel::OuterJoin)) end protected def aliased_table_name_for(name, suffix = nil) - if !parent.table_joins.blank? && parent.table_joins.to_s.downcase =~ %r{join(\s+\w+)?\s+#{active_record.connection.quote_table_name name.downcase}\son} - @join_dependency.table_aliases[name] += 1 + if @join_dependency.table_aliases[name].zero? + @join_dependency.table_aliases[name] = @join_dependency.count_aliases_from_table_joins(name) end - - unless @join_dependency.table_aliases[name].zero? - # if the table name has been used, then use an alias + + if !@join_dependency.table_aliases[name].zero? # We need an alias name = active_record.connection.table_alias_for "#{pluralize(reflection.name)}_#{parent_table_name}#{suffix}" - table_index = @join_dependency.table_aliases[name] @join_dependency.table_aliases[name] += 1 - name = name[0..active_record.connection.table_alias_length-3] + "_#{table_index+1}" if table_index > 0 + if @join_dependency.table_aliases[name] == 1 # First time we've seen this name + # Also need to count the aliases from the table_aliases to avoid incorrect count + @join_dependency.table_aliases[name] += @join_dependency.count_aliases_from_table_joins(name) + end + table_index = @join_dependency.table_aliases[name] + name = name[0..active_record.connection.table_alias_length-3] + "_#{table_index}" if table_index > 1 else @join_dependency.table_aliases[name] += 1 end - + name end diff --git a/activerecord/lib/active_record/relation/finder_methods.rb b/activerecord/lib/active_record/relation/finder_methods.rb index 3514d0a..d6144dc 100644 --- a/activerecord/lib/active_record/relation/finder_methods.rb +++ b/activerecord/lib/active_record/relation/finder_methods.rb @@ -188,7 +188,6 @@ module ActiveRecord def construct_relation_for_association_calculations including = (@eager_load_values + @includes_values).uniq join_dependency = ActiveRecord::Associations::ClassMethods::JoinDependency.new(@klass, including, arel.joins(arel)) - relation = except(:includes, :eager_load, :preload) apply_join_dependency(relation, join_dependency) end diff --git a/activerecord/lib/active_record/relation/query_methods.rb b/activerecord/lib/active_record/relation/query_methods.rb index 58af930..e53191f 100644 --- a/activerecord/lib/active_record/relation/query_methods.rb +++ b/activerecord/lib/active_record/relation/query_methods.rb @@ -79,6 +79,26 @@ module ActiveRecord def arel @arel ||= build_arel end + + def custom_join_sql(*joins) + arel = table + joins.each do |join| + next if join.blank? + + @implicit_readonly = true + + case join + when Hash, Array, Symbol + if array_of_strings?(join) + join_string = join.join(' ') + arel = arel.join(join_string) + end + else + arel = arel.join(join) + end + end + arel.joins(arel) + end def build_arel arel = table @@ -88,49 +108,40 @@ module ActiveRecord joins = @joins_values.map {|j| j.respond_to?(:strip) ? j.strip : j}.uniq - # Build association joins first joins.each do |join| association_joins << join if [Hash, Array, Symbol].include?(join.class) && !array_of_strings?(join) end - - if association_joins.any? - join_dependency = ActiveRecord::Associations::ClassMethods::JoinDependency.new(@klass, association_joins.uniq, nil) - to_join = [] - - join_dependency.join_associations.each do |association| - if (association_relation = association.relation).is_a?(Array) - to_join << [association_relation.first, association.association_join.first] - to_join << [association_relation.last, association.association_join.last] - else - to_join << [association_relation, association.association_join] - end - end - - to_join.each do |tj| - unless joined_associations.detect {|ja| ja[0] == tj[0] && ja[1] == tj[1] } - joined_associations << tj - arel = arel.join(tj[0]).on(*tj[1]) - end + + stashed_association_joins = joins.select {|j| j.is_a?(ActiveRecord::Associations::ClassMethods::JoinDependency::JoinAssociation)} + + non_association_joins = (joins - association_joins - stashed_association_joins).reject {|j| j.blank?} + custom_joins = custom_join_sql(*non_association_joins) + + join_dependency = ActiveRecord::Associations::ClassMethods::JoinDependency.new(@klass, association_joins, custom_joins) + + join_dependency.graft(*stashed_association_joins) + + @implicit_readonly = true unless association_joins.empty? && stashed_association_joins.empty? + + to_join = [] + + join_dependency.join_associations.each do |association| + if (association_relation = association.relation).is_a?(Array) + to_join << [association_relation.first, association.join_class, association.association_join.first] + to_join << [association_relation.last, association.join_class, association.association_join.last] + else + to_join << [association_relation, association.join_class, association.association_join] end end - - joins.each do |join| - next if join.blank? - - @implicit_readonly = true - - case join - when Relation::JoinOperation - arel = arel.join(join.relation, join.join_class).on(*join.on) - when Hash, Array, Symbol - if array_of_strings?(join) - join_string = join.join(' ') - arel = arel.join(join_string) - end - else - arel = arel.join(join) + + to_join.each do |tj| + unless joined_associations.detect {|ja| ja[0] == tj[0] && ja[1] == tj[1] && ja[2] == tj[2] } + joined_associations << tj + arel = arel.join(tj[0], tj[1]).on(*tj[2]) end end + + arel = arel.join(custom_joins) @where_values.uniq.each do |where| next if where.blank? diff --git a/activerecord/test/cases/associations/cascaded_eager_loading_test.rb b/activerecord/test/cases/associations/cascaded_eager_loading_test.rb index ed2e2e9..ea0d20e 100644 --- a/activerecord/test/cases/associations/cascaded_eager_loading_test.rb +++ b/activerecord/test/cases/associations/cascaded_eager_loading_test.rb @@ -18,7 +18,7 @@ class CascadedEagerLoadingTest < ActiveRecord::TestCase assert_equal 1, authors[1].posts.size assert_equal 9, authors[0].posts.collect{|post| post.comments.size }.inject(0){|sum,i| sum+i} end - + def test_eager_association_loading_with_cascaded_two_levels_and_one_level authors = Author.find(:all, :include=>[{:posts=>:comments}, :categorizations], :order=>"authors.id") assert_equal 2, authors.size @@ -28,6 +28,15 @@ class CascadedEagerLoadingTest < ActiveRecord::TestCase assert_equal 1, authors[0].categorizations.size assert_equal 2, authors[1].categorizations.size end + + def test_eager_association_loading_with_hmt_does_not_table_name_collide_when_joining_associations + assert_nothing_raised do + Author.joins(:posts).eager_load(:comments).where(:posts => {:taggings_count => 1}).all + end + authors = Author.joins(:posts).eager_load(:comments).where(:posts => {:taggings_count => 1}).all + assert_equal 1, assert_no_queries { authors.size } + assert_equal 9, assert_no_queries { authors[0].comments.size } + end def test_eager_association_loading_with_cascaded_two_levels_with_two_has_many_associations authors = Author.find(:all, :include=>{:posts=>[:comments, :categorizations]}, :order=>"authors.id") -- 1.6.4.4