Skip to content

Commit ac11c50

Browse files
committed
Fix:If multiple targets are part of an OCI provider record operation, create a new record for each target.
1 parent 909519f commit ac11c50

File tree

2 files changed

+137
-10
lines changed

2 files changed

+137
-10
lines changed

provider/oci/oci.go

+62-3
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,39 @@ func (p *OCIProvider) zones(ctx context.Context) (map[string]dns.ZoneSummary, er
170170
return zones, nil
171171
}
172172

173+
// Merge Endpoints with the same Name and Type into a single endpoint with multiple Targets.
174+
func mergeEndpointsMultiTargets(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint {
175+
endpointsByNameType := map[string][]*endpoint.Endpoint{}
176+
177+
for _, ep := range endpoints {
178+
key := fmt.Sprintf("%s-%s", ep.DNSName, ep.RecordType)
179+
endpointsByNameType[key] = append(endpointsByNameType[key], ep)
180+
}
181+
182+
// If there were no merges, return endpoints.
183+
if len(endpointsByNameType) == len(endpoints) {
184+
return endpoints
185+
}
186+
187+
// Otherwise, create a new list of endpoints with the consolidated targets.
188+
var mergedEndpoints []*endpoint.Endpoint
189+
for _, endpoints := range endpointsByNameType {
190+
dnsName := endpoints[0].DNSName
191+
recordType := endpoints[0].RecordType
192+
recordTTL := endpoints[0].RecordTTL
193+
194+
targets := make([]string, len(endpoints))
195+
for i, e := range endpoints {
196+
targets[i] = e.Targets[0]
197+
}
198+
199+
e := endpoint.NewEndpointWithTTL(dnsName, recordType, recordTTL, targets...)
200+
mergedEndpoints = append(mergedEndpoints, e)
201+
}
202+
203+
return mergedEndpoints
204+
}
205+
173206
func (p *OCIProvider) addPaginatedZones(ctx context.Context, zones map[string]dns.ZoneSummary, scope dns.GetZoneScopeEnum) error {
174207
var page *string
175208
// Loop until we have listed all zones.
@@ -200,9 +233,19 @@ func (p *OCIProvider) addPaginatedZones(ctx context.Context, zones map[string]dn
200233

201234
func (p *OCIProvider) newFilteredRecordOperations(endpoints []*endpoint.Endpoint, opType dns.RecordOperationOperationEnum) []dns.RecordOperation {
202235
ops := []dns.RecordOperation{}
203-
for _, endpoint := range endpoints {
204-
if p.domainFilter.Match(endpoint.DNSName) {
205-
ops = append(ops, newRecordOperation(endpoint, opType))
236+
for _, ep := range endpoints {
237+
if p.domainFilter.Match(ep.DNSName) {
238+
for _, t := range ep.Targets {
239+
singleTargetEp := &endpoint.Endpoint{
240+
DNSName: ep.DNSName,
241+
Targets: []string{t},
242+
RecordType: ep.RecordType,
243+
RecordTTL: ep.RecordTTL,
244+
Labels: ep.Labels,
245+
ProviderSpecific: ep.ProviderSpecific,
246+
}
247+
ops = append(ops, newRecordOperation(singleTargetEp, opType))
248+
}
206249
}
207250
}
208251
return ops
@@ -248,6 +291,8 @@ func (p *OCIProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error)
248291
}
249292
}
250293

294+
endpoints = mergeEndpointsMultiTargets(endpoints)
295+
251296
return endpoints, nil
252297
}
253298

@@ -299,6 +344,20 @@ func (p *OCIProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) e
299344
return nil
300345
}
301346

347+
// AdjustEndpoints modifies the endpoints as needed by the specific provider
348+
func (p *OCIProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
349+
adjustedEndpoints := []*endpoint.Endpoint{}
350+
for _, e := range endpoints {
351+
// OCI DNS does not support the set-identifier attribute, so we remove it to avoid plan failure
352+
if e.SetIdentifier != "" {
353+
log.Warnf("Adjusting endpont: %v. Ignoring unsupported annotation 'set-identifier': %s", *e, e.SetIdentifier)
354+
e.SetIdentifier = ""
355+
}
356+
adjustedEndpoints = append(adjustedEndpoints, e)
357+
}
358+
return adjustedEndpoints, nil
359+
}
360+
302361
// newRecordOperation returns a RecordOperation based on a given endpoint.
303362
func newRecordOperation(ep *endpoint.Endpoint, opType dns.RecordOperationOperationEnum) dns.RecordOperation {
304363
targets := make([]string, len(ep.Targets))

provider/oci/oci_test.go

+75-7
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ func newMutableMockOCIDNSClient(zones []dns.ZoneSummary, recordsByZone map[strin
551551

552552
for zoneID, records := range recordsByZone {
553553
for _, record := range records {
554-
c.records[zoneID][ociRecordKey(*record.Rtype, *record.Domain)] = record
554+
c.records[zoneID][ociRecordKey(*record.Rtype, *record.Domain, *record.Rdata)] = record
555555
}
556556
}
557557

@@ -587,8 +587,12 @@ func (c *mutableMockOCIDNSClient) GetZoneRecords(ctx context.Context, request dn
587587
return
588588
}
589589

590-
func ociRecordKey(rType, domain string) string {
591-
return rType + "/" + domain
590+
func ociRecordKey(rType, domain string, ip string) string {
591+
rdata := ""
592+
if rType == "A" { // adds support for multi-targets with same rtype and domain
593+
rdata = "_" + ip
594+
}
595+
return rType + "_" + domain + rdata
592596
}
593597

594598
func (c *mutableMockOCIDNSClient) PatchZoneRecords(ctx context.Context, request dns.PatchZoneRecordsRequest) (response dns.PatchZoneRecordsResponse, err error) {
@@ -609,7 +613,7 @@ func (c *mutableMockOCIDNSClient) PatchZoneRecords(ctx context.Context, request
609613
})
610614

611615
for _, op := range request.Items {
612-
k := ociRecordKey(*op.Rtype, *op.Domain)
616+
k := ociRecordKey(*op.Rtype, *op.Domain, *op.Rdata)
613617
switch op.Operation {
614618
case dns.RecordOperationOperationAdd:
615619
records[k] = dns.Record{
@@ -850,21 +854,26 @@ func TestOCIApplyChanges(t *testing.T) {
850854
Rtype: common.String(endpoint.RecordTypeA),
851855
Ttl: common.Int(ociRecordTTL),
852856
}, {
853-
Domain: common.String("bar.foo.com"),
857+
Domain: common.String("car.foo.com"),
854858
Rdata: common.String("bar.com."),
855859
Rtype: common.String(endpoint.RecordTypeCNAME),
856860
Ttl: common.Int(ociRecordTTL),
861+
}, {
862+
Domain: common.String("bar.foo.com"),
863+
Rdata: common.String("baz.com."),
864+
Rtype: common.String(endpoint.RecordTypeCNAME),
865+
Ttl: common.Int(ociRecordTTL),
857866
}},
858867
},
859868
changes: &plan.Changes{
860869
Delete: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL(
861870
"foo.foo.com",
862871
endpoint.RecordTypeA,
863872
endpoint.TTL(ociRecordTTL),
864-
"baz.com.",
873+
"127.0.0.1",
865874
)},
866875
UpdateOld: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL(
867-
"bar.foo.com",
876+
"car.foo.com",
868877
endpoint.RecordTypeCNAME,
869878
endpoint.TTL(ociRecordTTL),
870879
"baz.com.",
@@ -896,6 +905,65 @@ func TestOCIApplyChanges(t *testing.T) {
896905
"127.0.0.1"),
897906
},
898907
},
908+
{
909+
name: "combine_multi_target",
910+
zones: []dns.ZoneSummary{{
911+
Id: common.String("ocid1.dns-zone.oc1..e1e042ef0bfbb5c251b9713fd7bf8959"),
912+
Name: common.String("foo.com"),
913+
}},
914+
915+
changes: &plan.Changes{
916+
Create: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL(
917+
"foo.foo.com",
918+
endpoint.RecordTypeA,
919+
endpoint.TTL(ociRecordTTL),
920+
"192.168.1.2",
921+
), endpoint.NewEndpointWithTTL(
922+
"foo.foo.com",
923+
endpoint.RecordTypeA,
924+
endpoint.TTL(ociRecordTTL),
925+
"192.168.2.5",
926+
)},
927+
},
928+
expectedEndpoints: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL(
929+
"foo.foo.com",
930+
endpoint.RecordTypeA,
931+
endpoint.TTL(ociRecordTTL), "192.168.1.2", "192.168.2.5",
932+
)},
933+
},
934+
{
935+
name: "remove_from_multi_target",
936+
zones: []dns.ZoneSummary{{
937+
Id: common.String("ocid1.dns-zone.oc1..e1e042ef0bfbb5c251b9713fd7bf8959"),
938+
Name: common.String("foo.com"),
939+
}},
940+
records: map[string][]dns.Record{
941+
"ocid1.dns-zone.oc1..e1e042ef0bfbb5c251b9713fd7bf8959": {{
942+
Domain: common.String("foo.foo.com"),
943+
Rdata: common.String("192.168.1.2"),
944+
Rtype: common.String(endpoint.RecordTypeA),
945+
Ttl: common.Int(ociRecordTTL),
946+
}, {
947+
Domain: common.String("foo.foo.com"),
948+
Rdata: common.String("192.168.2.5"),
949+
Rtype: common.String(endpoint.RecordTypeA),
950+
Ttl: common.Int(ociRecordTTL),
951+
}},
952+
},
953+
changes: &plan.Changes{
954+
Delete: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL(
955+
"foo.foo.com",
956+
endpoint.RecordTypeA,
957+
endpoint.TTL(ociRecordTTL),
958+
"192.168.1.2",
959+
)},
960+
},
961+
expectedEndpoints: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL(
962+
"foo.foo.com",
963+
endpoint.RecordTypeA,
964+
endpoint.TTL(ociRecordTTL), "192.168.2.5",
965+
)},
966+
},
899967
}
900968

901969
for _, tc := range testCases {

0 commit comments

Comments
 (0)