diff --git a/cassandra/policies.py b/cassandra/policies.py index a679bff877..d681980d77 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -526,14 +526,16 @@ def make_query_plan(self, working_keyspace=None, query=None): if self.shuffle_replicas: shuffle(replicas) - for replica in replicas: - if replica.is_up and child.distance(replica) in [HostDistance.LOCAL, HostDistance.LOCAL_RACK]: - yield replica - - for host in child.make_query_plan(keyspace, query): - # skip if we've already listed this host - if host not in replicas or child.distance(host) == HostDistance.REMOTE: - yield host + def yield_in_order(hosts): + for distance in [HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]: + for replica in hosts: + if replica.is_up and child.distance(replica) == distance: + yield replica + + # yield replicas: local_rack, local, remote + yield from yield_in_order(replicas) + # yield rest of the cluster: local_rack, local, remote + yield from yield_in_order([host for host in child.make_query_plan(keyspace, query) if host not in replicas]) def on_up(self, *args, **kwargs): return self._child_policy.on_up(*args, **kwargs) diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index e65a89bca7..084b2c3137 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -636,7 +636,7 @@ def get_replicas(keyspace, packed_key): cluster.metadata.get_replicas.side_effect = get_replicas - policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1)) + policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=2)) policy.populate(cluster, hosts) for i in range(4): @@ -648,14 +648,75 @@ def get_replicas(keyspace, packed_key): assert qplan[0] in replicas assert qplan[0].datacenter == "dc1" - # then the local non-replica - assert qplan[1] not in replicas + # then the replica from remote DC + assert qplan[1] in replicas + assert qplan[1].datacenter == "dc2" + + # then non-replica from local DC + assert qplan[2] not in replicas + assert qplan[2].datacenter == "dc1" + + # and only then non-replica from remote DC + assert qplan[3] not in replicas + assert qplan[3].datacenter == "dc2" + + assert 4 == len(qplan) + + def test_wrap_rack_aware(self): + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + cluster.metadata._tablets = Mock(spec=Tablets) + cluster.metadata._tablets.table_has_tablets.return_value = [] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(8)] + for host in hosts: + host.set_up() + hosts[0].set_location_info("dc1", "rack1") + hosts[1].set_location_info("dc1", "rack2") + hosts[2].set_location_info("dc2", "rack1") + hosts[3].set_location_info("dc2", "rack2") + hosts[4].set_location_info("dc1", "rack1") + hosts[5].set_location_info("dc1", "rack2") + hosts[6].set_location_info("dc2", "rack1") + hosts[7].set_location_info("dc2", "rack2") + + def get_replicas(keyspace, packed_key): + index = struct.unpack('>i', packed_key)[0] + # return one node from each DC + if index % 2 == 0: + return [hosts[0], hosts[1], hosts[2], hosts[3]] + else: + return [hosts[4], hosts[5], hosts[6], hosts[7]] + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy(RackAwareRoundRobinPolicy("dc1", "rack1", used_hosts_per_remote_dc=4)) + policy.populate(cluster, hosts) + + for i in range(4): + query = Statement(routing_key=struct.pack('>i', i), keyspace='keyspace_name') + qplan = list(policy.make_query_plan(None, query)) + replicas = get_replicas(None, struct.pack('>i', i)) + + print(qplan) + print(replicas) + + # first should be replica from local rack local dc + assert qplan[0] in replicas + assert qplan[0].datacenter == "dc1" + assert qplan[0].rack == "rack1" + + # second should be replica from remote rack local dc + assert qplan[1] in replicas assert qplan[1].datacenter == "dc1" + assert qplan[1].rack == "rack2" - # then one of the remotes (used_hosts_per_remote_dc is 1, so we - # shouldn't see two remotes) + # third and forth should be replica from the remote dcs + assert qplan[2] in replicas assert qplan[2].datacenter == "dc2" - assert 3 == len(qplan) + assert qplan[3] in replicas + assert qplan[3].datacenter == "dc2" + + assert 8 == len(qplan) class FakeCluster: def __init__(self):