diff --git a/lib/xpath/dsl.rb b/lib/xpath/dsl.rb index b6de457..21fe426 100644 --- a/lib/xpath/dsl.rb +++ b/lib/xpath/dsl.rb @@ -64,6 +64,10 @@ def union(*expressions) end alias_method :+, :union + def group + Expression.new(:group, current) + end + def last function(:last) end diff --git a/lib/xpath/renderer.rb b/lib/xpath/renderer.rb index 6d1d8aa..b9db890 100644 --- a/lib/xpath/renderer.rb +++ b/lib/xpath/renderer.rb @@ -11,6 +11,8 @@ def initialize(type) end def render(node) + return render_grouped_where(node) if grouped_where_clause?(node) + arguments = node.arguments.map { |argument| convert_argument(argument) } send(node.expression, *arguments) end @@ -107,8 +109,54 @@ def function(name, *arguments) "#{name}(#{arguments.join(', ')})" end + def group(expression) + "(#{expression})" + end + private + def grouped_where_clause?(node) + node.expression == :where && + node.arguments.length == 2 && + node.arguments[0].is_a?(Expression) && + node.arguments[0].expression == :group + end + + def render_grouped_where(node) + group_content = render(node.arguments[0].arguments[0]) + condition = convert_argument(node.arguments[1]) + condition = unwrap_outer_parentheses(condition) + + "(#{group_content})[#{condition}]" + end + + def unwrap_outer_parentheses(condition) + return condition unless wrapped_in_parentheses?(condition) + return condition unless balanced_inner_parentheses?(condition) + + condition[1..-2] + end + + def wrapped_in_parentheses?(string) + string.start_with?('(') && string.end_with?(')') + end + + def balanced_inner_parentheses?(string) + inner_content = string[1..-2] + parentheses_count = 0 + + inner_content.each_char do |char| + case char + when '(' then parentheses_count += 1 + when ')' then parentheses_count -= 1 + end + + return false if parentheses_count.negative? + end + + parentheses_count.zero? + end + def with_element_conditions(expression, element_names) if element_names.length == 1 "#{expression}#{element_names.first}" diff --git a/spec/xpath_spec.rb b/spec/xpath_spec.rb index ae18db3..9b93790 100644 --- a/spec/xpath_spec.rb +++ b/spec/xpath_spec.rb @@ -559,4 +559,15 @@ def xpath(type = nil, &block) expect(@results[2][:id]).to eq 'foo' end end + + describe '#group' do + it 'wraps expressions in parentheses' do + expect(XPath.descendant(:div).group.to_xpath).to eq '(.//div)' + end + + it 'allows predicates to apply to grouped expressions' do + grouped = XPath.descendant(:div).attr(:id).group[XPath.position == XPath.last] + expect(grouped.to_xpath).to eq '(.//div/@id)[position() = last()]' + end + end end