@@ -131,8 +131,12 @@ func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord
131131 }
132132
133133 for name , recordGroup := range recordsByName {
134+ // We assume that all records in a group have the same name, type, and ttl.
135+ // TODO(herko): Add error checking to ensure that the above is the case.
136+ firstRecord := recordGroup [0 ]
137+
134138 err := n .withRecordLock (name , func () error {
135- existingRecords , err := n .lookupSRVRecords (ctx , name )
139+ existingRecords , err := n .lookupRecords (ctx , firstRecord . Type , name )
136140 if err != nil {
137141 return err
138142 }
@@ -151,9 +155,6 @@ func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord
151155 combinedRecords [record .Data ] = record
152156 }
153157
154- // We assume that all records in a group have the same name, type, and ttl.
155- // TODO(herko): Add error checking to ensure that the above is the case.
156- firstRecord := recordGroup [0 ]
157158 data := maps .Keys (combinedRecords )
158159 sort .Strings (data )
159160 zone := n .managedZone
@@ -194,24 +195,26 @@ func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord
194195}
195196
196197// LookupSRVRecords implements the vm.DNSProvider interface.
197- func (n * dnsProvider ) LookupSRVRecords (ctx context.Context , name string ) ([]vm.DNSRecord , error ) {
198+ func (n * dnsProvider ) LookupRecords (
199+ ctx context.Context , recordType vm.DNSType , name string ,
200+ ) ([]vm.DNSRecord , error ) {
198201 var records []vm.DNSRecord
199202 var err error
200203 err = n .withRecordLock (name , func () error {
201- if config .FastDNS {
204+ if config .FastDNS && recordType == vm . SRV {
202205 rIdx := randutil .FastUint32 () % uint32 (len (n .resolvers ))
203206 records , err = n .fastLookupSRVRecords (ctx , n .resolvers [rIdx ], name , true )
204207 return err
205208 }
206- records , err = n .lookupSRVRecords (ctx , name )
209+ records , err = n .lookupRecords (ctx , recordType , name )
207210 return err
208211 })
209212 return records , err
210213}
211214
212215// ListRecords implements the vm.DNSProvider interface.
213216func (n * dnsProvider ) ListRecords (ctx context.Context ) ([]vm.DNSRecord , error ) {
214- return n .listSRVRecords (ctx , "" , dnsMaxResults )
217+ return n .listRecords (ctx , vm . SRV , "" , dnsMaxResults )
215218}
216219
217220func (n * dnsProvider ) deleteRecords (
@@ -253,7 +256,7 @@ func (n *dnsProvider) DeletePublicRecordsByName(ctx context.Context, names ...st
253256// DeleteRecordsBySubdomain implements the vm.DNSProvider interface.
254257func (n * dnsProvider ) DeleteSRVRecordsBySubdomain (ctx context.Context , subdomain string ) error {
255258 suffix := fmt .Sprintf ("%s.%s." , subdomain , n .Domain ())
256- records , err := n .listSRVRecords (ctx , suffix , dnsMaxResults )
259+ records , err := n .listRecords (ctx , vm . SRV , suffix , dnsMaxResults )
257260 if err != nil {
258261 return err
259262 }
@@ -287,13 +290,15 @@ func (n *dnsProvider) Domain() string {
287290// network problems. For lookups, we prefer this to using the gcloud command as
288291// it is faster, and preferable when service information is being queried
289292// regularly.
290- func (n * dnsProvider ) lookupSRVRecords (ctx context.Context , name string ) ([]vm.DNSRecord , error ) {
293+ func (n * dnsProvider ) lookupRecords (
294+ ctx context.Context , recordType vm.DNSType , name string ,
295+ ) ([]vm.DNSRecord , error ) {
291296 // Check the cache first.
292297 if cachedRecords , ok := n .getCache (name ); ok {
293298 return cachedRecords , nil
294299 }
295300 // Lookup the records, if no records are found in the cache.
296- records , err := n .listSRVRecords (ctx , name , dnsMaxResults )
301+ records , err := n .listRecords (ctx , recordType , name , dnsMaxResults )
297302 if err != nil {
298303 return nil , err
299304 }
@@ -310,16 +315,21 @@ func (n *dnsProvider) lookupSRVRecords(ctx context.Context, name string) ([]vm.D
310315 return filteredRecords , nil
311316}
312317
313- // listSRVRecords returns all SRV records that match the given filter from Google Cloud DNS.
318+ // listRecords returns all records that match the given filter from Google Cloud DNS.
314319// The data field of the records could be a comma-separated list of values if multiple
315320// records are returned for the same name.
316- func (n * dnsProvider ) listSRVRecords (
317- ctx context.Context , filter string , limit int ,
321+ func (n * dnsProvider ) listRecords (
322+ ctx context.Context , recordType vm. DNSType , filter string , limit int ,
318323) ([]vm.DNSRecord , error ) {
324+ zone := n .managedZone
325+ if recordType == vm .A {
326+ zone = n .publicZone
327+ }
328+
319329 args := []string {"--project" , n .dnsProject , "dns" , "record-sets" , "list" ,
320330 "--limit" , strconv .Itoa (limit ),
321331 "--page-size" , strconv .Itoa (limit ),
322- "--zone" , n . managedZone ,
332+ "--zone" , zone ,
323333 "--format" , "json" ,
324334 }
325335 if filter != "" {
@@ -348,11 +358,11 @@ func (n *dnsProvider) listSRVRecords(
348358 if record .Kind != "dns#resourceRecordSet" {
349359 continue
350360 }
351- if record .RecordType != string (vm . SRV ) {
361+ if record .RecordType != string (recordType ) {
352362 continue
353363 }
354364 for _ , data := range record .RRDatas {
355- records = append (records , vm .CreateDNSRecord (record .Name , vm . SRV , data , record .TTL ))
365+ records = append (records , vm .CreateDNSRecord (record .Name , recordType , data , record .TTL ))
356366 }
357367 }
358368 return records , nil
0 commit comments