diff --git a/sunburnt/schema.py b/sunburnt/schema.py
index 7ba818f..2008aa6 100644
--- a/sunburnt/schema.py
+++ b/sunburnt/schema.py
@@ -489,6 +489,20 @@ def parse_result_doc(self, doc, name=None):
         elif field_class is None:
             raise SolrError("unexpected field found in result")
         return name, SolrFieldInstance.from_solr(field_class, doc.text or '').to_user_data()
+
+    def parse_group(self, group, value=None):
+        """Parse one <lst> entry of a grouped response's "groups" array.
+
+        Returns a (group value, list of parsed docs) pair. If value is
+        not supplied it is read from the group's "groupValue" child.
+        """
+        if value is None:
+            # groupValue is not always a <str>: it is typed after the
+            # grouping field, and is <null/> for docs lacking the field.
+            value_nodes = group.xpath("*[@name='groupValue']")
+            if value_nodes:
+                value = value_nodes[0].text
+        return value, [self.parse_result_doc(n) for n in group.xpath("result/doc")]
 
 
 class SolrUpdate(object):
@@ -613,8 +627,16 @@ def __init__(self, schema, xmlmsg):
             setattr(self, attr, details['responseHeader'].get(attr))
         if self.status != 0:
             raise ValueError("Response indicates an error")
-        result_node = doc.xpath("/response/result")[0]
-        self.result = SolrResult(schema, result_node)
+        # A grouped query may return <lst name="grouped"> in place of
+        # the usual top-level <result> node.
+        result_node_list = doc.xpath("/response/result")
+        group_node_list = doc.xpath("/response/lst[@name='grouped']")
+        if result_node_list:
+            self.result = SolrResult(schema, result_node_list[0])
+        elif group_node_list:
+            self.result = SolrResult(schema, group_node_list[0])
+        else:
+            raise ValueError("Response contains neither results nor groups")
         self.facet_counts = SolrFacetCounts.from_response(details)
         self.highlighting = dict((k, dict(v))
                                  for k, v in details.get("highlighting", ()))
@@ -641,12 +663,31 @@ def __getitem__(self, key):
 
 class SolrResult(object):
     def __init__(self, schema, node):
+        # A grouped response is rooted at <lst name="grouped"> rather
+        # than at a <result> node.
+        self.grouped = node.tag == 'lst' and node.attrib.get('name') == 'grouped'
         self.schema = schema
         self.name = node.attrib['name']
-        self.numFound = int(node.attrib['numFound'])
-        self.start = int(node.attrib['start'])
+        if self.grouped:
+            self.numFound = int(node.xpath("lst/int[@name='matches']")[0].text)
+            # ngroups is only set when the response includes a count.
+            ngroups = node.xpath("lst/int[@name='ngroups']")
+            if ngroups:
+                self.ngroups = int(ngroups[0].text)
+            self.groupField = node.xpath("lst")[0].attrib['name']
+        else:
+            self.numFound = int(node.attrib['numFound'])
+
+        if 'start' in node.attrib:
+            self.start = int(node.attrib['start'])
+        else:
+            # No start attribute on this node; fall back to the start
+            # parameter echoed in the response header, if any.
+            start_param = node.xpath("../lst[@name='responseHeader']/lst[@name='params']/str[@name='start']")
+            self.start = int(start_param[0].text) if start_param else 0
         self.docs = [schema.parse_result_doc(n) for n in node.xpath("doc")]
-
+        self.groups = [schema.parse_group(n) for n in node.xpath("lst/arr[@name='groups']/lst")]
+
     def __str__(self):
         return "%(numFound)s results found, starting at #%(start)s\n\n" % self.__dict__ + str(self.docs)
 
@@ -720,6 +761,8 @@ def value_from_node(node):
         value = float(node.text)
     elif node.tag == 'date':
         value = solr_date(node.text)
+    elif node.tag == 'result':
+        value = [value_from_node(n) for n in node.getchildren()]
     if name is not None:
         return name, value
     else:
diff --git a/sunburnt/search.py b/sunburnt/search.py
index 8267ac3..5ce83d0 100644
--- a/sunburnt/search.py
+++ b/sunburnt/search.py
@@ -358,7 +358,7 @@ def add_boost(self, kwargs, boost_score):
 
 
 class SolrSearch(object):
-    option_modules = ('query_obj', 'filter_obj', 'paginator', 'more_like_this', 'highlighter', 'faceter', 'sorter', 'facet_querier', 'field_limiter',)
+    option_modules = ('query_obj', 'filter_obj', 'paginator', 'more_like_this', 'highlighter', 'faceter', 'grouper', 'sorter', 'facet_querier', 'field_limiter',)
     def __init__(self, interface, original=None):
         self.interface = interface
         self.schema = interface.schema
@@ -369,6 +369,7 @@ def __init__(self, interface, original=None):
             self.more_like_this = MoreLikeThisOptions(self.schema)
             self.highlighter = HighlightOptions(self.schema)
             self.faceter = FacetOptions(self.schema)
+            self.grouper = GroupOptions(self.schema)
             self.sorter = SortOptions(self.schema)
             self.field_limiter = FieldLimitOptions(self.schema)
             self.facet_querier = FacetQueryOptions(self.schema)
@@ -418,6 +419,18 @@ def facet_by(self, field, **kwargs):
         newself = self.clone()
         newself.faceter.update(field, **kwargs)
         return newself
+
+    def group_by(self, field, **kwargs):
+        """Return a copy of this search that groups results by field.
+
+        Extra keyword arguments are passed through as group options;
+        ngroups defaults to True unless given explicitly.
+        """
+        newself = self.clone()
+        kwargs['field'] = field
+        kwargs.setdefault('ngroups', True)
+        newself.grouper.update(None, **kwargs)
+        return newself
 
     def facet_query(self, *args, **kwargs):
         newself = self.clone()
@@ -554,6 +567,29 @@ def __init__(self, schema, original=None):
     def field_names_in_opts(self, opts, fields):
         if fields:
             opts["facet.field"] = sorted(fields)
+
+
+class GroupOptions(Options):
+    """Options for Solr result grouping (the group.* parameters)."""
+    option_name = "group"
+    opts = {"field": unicode,
+            "limit": int,
+            "main": bool,
+            "ngroups": bool,
+            }
+
+    def __init__(self, schema, original=None):
+        self.schema = schema
+        if original is None:
+            self.fields = collections.defaultdict(dict)
+        else:
+            # Deep copy so the per-field option dicts are not shared
+            # with the instance this one was cloned from.
+            self.fields = copy.deepcopy(original.fields)
+
+    def field_names_in_opts(self, opts, fields):
+        if fields:
+            opts["group.field"] = sorted(fields)
 
 
 class HighlightOptions(Options):