Skip to content

Commit 0567531

Browse files
xds: xDS based SNI setting and SAN validation (#12378)
When using xDS credentials make SNI for the Tls handshake to be configured via xDS, rather than use the channel authority as the SNI, and make SAN validation to be able to use the SNI sent when so instructed via xDS. Implements A101.
1 parent 82f9b8e commit 0567531

32 files changed

+1073
-419
lines changed

core/src/main/java/io/grpc/internal/CertificateUtils.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,29 @@
2626
import java.security.cert.CertificateFactory;
2727
import java.security.cert.X509Certificate;
2828
import java.util.Collection;
29+
import java.util.List;
2930
import javax.net.ssl.TrustManager;
3031
import javax.net.ssl.TrustManagerFactory;
32+
import javax.net.ssl.X509TrustManager;
3133
import javax.security.auth.x500.X500Principal;
3234

3335
/**
3436
* Contains certificate/key PEM file utility method(s) for internal usage.
3537
*/
3638
public final class CertificateUtils {
39+
private static final Class<?> x509ExtendedTrustManagerClass;
40+
41+
static {
42+
Class<?> x509ExtendedTrustManagerClass1;
43+
try {
44+
x509ExtendedTrustManagerClass1 = Class.forName("javax.net.ssl.X509ExtendedTrustManager");
45+
} catch (ClassNotFoundException e) {
46+
x509ExtendedTrustManagerClass1 = null;
47+
// Will disallow per-rpc authority override via call option.
48+
}
49+
x509ExtendedTrustManagerClass = x509ExtendedTrustManagerClass1;
50+
}
51+
3752
/**
3853
* Creates X509TrustManagers using the provided CA certs.
3954
*/
@@ -71,6 +86,17 @@ public static TrustManager[] createTrustManager(InputStream rootCerts)
7186
return trustManagerFactory.getTrustManagers();
7287
}
7388

89+
public static X509TrustManager getX509ExtendedTrustManager(List<TrustManager> trustManagers) {
90+
if (x509ExtendedTrustManagerClass != null) {
91+
for (TrustManager trustManager : trustManagers) {
92+
if (x509ExtendedTrustManagerClass.isInstance(trustManager)) {
93+
return (X509TrustManager) trustManager;
94+
}
95+
}
96+
}
97+
return null;
98+
}
99+
74100
private static X509Certificate[] getX509Certificates(InputStream inputStream)
75101
throws CertificateException {
76102
CertificateFactory factory = CertificateFactory.getInstance("X.509");

netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import io.netty.handler.ssl.SslContext;
2727
import io.netty.util.AsciiString;
2828
import java.util.concurrent.Executor;
29+
import javax.net.ssl.X509TrustManager;
2930

3031
/**
3132
* Internal accessor for {@link ProtocolNegotiators}.
@@ -42,9 +43,11 @@ private InternalProtocolNegotiators() {}
4243
*/
4344
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
4445
ObjectPool<? extends Executor> executorPool,
45-
Optional<Runnable> handshakeCompleteRunnable) {
46+
Optional<Runnable> handshakeCompleteRunnable,
47+
X509TrustManager extendedX509TrustManager,
48+
String sni) {
4649
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext,
47-
executorPool, handshakeCompleteRunnable, null);
50+
executorPool, handshakeCompleteRunnable, extendedX509TrustManager, sni);
4851
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
4952

5053
@Override
@@ -62,17 +65,19 @@ public void close() {
6265
negotiator.close();
6366
}
6467
}
65-
68+
6669
return new TlsNegotiator();
6770
}
68-
71+
6972
/**
7073
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
7174
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
7275
* may happen immediately, even before the TLS Handshake is complete.
7376
*/
74-
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
75-
return tls(sslContext, null, Optional.absent());
77+
public static InternalProtocolNegotiator.ProtocolNegotiator tls(
78+
SslContext sslContext, String sni,
79+
X509TrustManager extendedX509TrustManager) {
80+
return tls(sslContext, null, Optional.absent(), extendedX509TrustManager, sni);
7681
}
7782

7883
/**

netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ static ProtocolNegotiator createProtocolNegotiatorByType(
652652
case PLAINTEXT_UPGRADE:
653653
return ProtocolNegotiators.plaintextUpgrade();
654654
case TLS:
655-
return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null);
655+
return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null, null);
656656
default:
657657
throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType);
658658
}

netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,6 @@ final class ProtocolNegotiators {
102102
private static final EnumSet<TlsServerCredentials.Feature> understoodServerTlsFeatures =
103103
EnumSet.of(
104104
TlsServerCredentials.Feature.MTLS, TlsServerCredentials.Feature.CUSTOM_MANAGERS);
105-
private static Class<?> x509ExtendedTrustManagerClass;
106-
107-
static {
108-
try {
109-
x509ExtendedTrustManagerClass = Class.forName("javax.net.ssl.X509ExtendedTrustManager");
110-
} catch (ClassNotFoundException e) {
111-
// Will disallow per-rpc authority override via call option.
112-
}
113-
}
114105

115106
private ProtocolNegotiators() {
116107
}
@@ -147,15 +138,8 @@ public static FromChannelCredentialsResult from(ChannelCredentials creds) {
147138
trustManagers = Arrays.asList(tmf.getTrustManagers());
148139
}
149140
builder.trustManager(new FixedTrustManagerFactory(trustManagers));
150-
TrustManager x509ExtendedTrustManager = null;
151-
if (x509ExtendedTrustManagerClass != null) {
152-
for (TrustManager trustManager : trustManagers) {
153-
if (x509ExtendedTrustManagerClass.isInstance(trustManager)) {
154-
x509ExtendedTrustManager = trustManager;
155-
break;
156-
}
157-
}
158-
}
141+
TrustManager x509ExtendedTrustManager =
142+
CertificateUtils.getX509ExtendedTrustManager(trustManagers);
159143
return FromChannelCredentialsResult.negotiator(tlsClientFactory(builder.build(),
160144
(X509TrustManager) x509ExtendedTrustManager));
161145
} catch (SSLException | GeneralSecurityException ex) {
@@ -579,20 +563,22 @@ static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator {
579563

580564
public ClientTlsProtocolNegotiator(SslContext sslContext,
581565
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable,
582-
X509TrustManager x509ExtendedTrustManager) {
566+
X509TrustManager x509ExtendedTrustManager, String sni) {
583567
this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext");
584568
this.executorPool = executorPool;
585569
if (this.executorPool != null) {
586570
this.executor = this.executorPool.getObject();
587571
}
588572
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
589573
this.x509ExtendedTrustManager = x509ExtendedTrustManager;
574+
this.sni = sni;
590575
}
591576

592577
private final SslContext sslContext;
593578
private final ObjectPool<? extends Executor> executorPool;
594579
private final Optional<Runnable> handshakeCompleteRunnable;
595580
private final X509TrustManager x509ExtendedTrustManager;
581+
private final String sni;
596582
private Executor executor;
597583

598584
@Override
@@ -604,9 +590,17 @@ public AsciiString scheme() {
604590
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
605591
ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler);
606592
ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger();
607-
ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(),
608-
this.executor, negotiationLogger, handshakeCompleteRunnable, this,
609-
x509ExtendedTrustManager);
593+
String authority;
594+
if ("".equals(sni)) {
595+
authority = null;
596+
} else if (sni != null) {
597+
authority = sni;
598+
} else {
599+
authority = grpcHandler.getAuthority();
600+
}
601+
ChannelHandler cth = new ClientTlsHandler(gnh, sslContext,
602+
authority, this.executor, negotiationLogger, handshakeCompleteRunnable, this,
603+
x509ExtendedTrustManager);
610604
return new WaitUntilActiveHandler(cth, negotiationLogger);
611605
}
612606

@@ -630,28 +624,40 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler {
630624
private final int port;
631625
private Executor executor;
632626
private final Optional<Runnable> handshakeCompleteRunnable;
633-
private final X509TrustManager x509ExtendedTrustManager;
627+
private final X509TrustManager x509TrustManager;
634628
private SSLEngine sslEngine;
635629

636630
ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority,
637631
Executor executor, ChannelLogger negotiationLogger,
638632
Optional<Runnable> handshakeCompleteRunnable,
639633
ClientTlsProtocolNegotiator clientTlsProtocolNegotiator,
640-
X509TrustManager x509ExtendedTrustManager) {
634+
X509TrustManager x509TrustManager) {
641635
super(next, negotiationLogger);
642636
this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext");
643-
HostPort hostPort = parseAuthority(authority);
644-
this.host = hostPort.host;
645-
this.port = hostPort.port;
637+
// TODO: For empty authority and fallback flag
638+
// GRPC_USE_CHANNEL_AUTHORITY_IF_NO_SNI_APPLICABLE present, we should parse authority
639+
// but prevent it from being used for SAN validation in the TrustManager.
640+
if (authority != null) {
641+
HostPort hostPort = parseAuthority(authority);
642+
this.host = hostPort.host;
643+
this.port = hostPort.port;
644+
} else {
645+
this.host = null;
646+
this.port = 0;
647+
}
646648
this.executor = executor;
647649
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
648-
this.x509ExtendedTrustManager = x509ExtendedTrustManager;
650+
this.x509TrustManager = x509TrustManager;
649651
}
650652

651653
@Override
652654
@IgnoreJRERequirement
653655
protected void handlerAdded0(ChannelHandlerContext ctx) {
654-
sslEngine = sslContext.newEngine(ctx.alloc(), host, port);
656+
if (host != null) {
657+
sslEngine = sslContext.newEngine(ctx.alloc(), host, port);
658+
} else {
659+
sslEngine = sslContext.newEngine(ctx.alloc());
660+
}
655661
SSLParameters sslParams = sslEngine.getSSLParameters();
656662
sslParams.setEndpointIdentificationAlgorithm("HTTPS");
657663
sslEngine.setSSLParameters(sslParams);
@@ -709,7 +715,7 @@ private void propagateTlsComplete(ChannelHandlerContext ctx, SSLSession session)
709715
.set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY)
710716
.set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
711717
.set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, new X509AuthorityVerifier(
712-
sslEngine, x509ExtendedTrustManager))
718+
sslEngine, x509TrustManager))
713719
.build();
714720
replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security));
715721
if (handshakeCompleteRunnable.isPresent()) {
@@ -746,13 +752,14 @@ static HostPort parseAuthority(String authority) {
746752
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
747753
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
748754
* may happen immediately, even before the TLS Handshake is complete.
755+
*
749756
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
750757
*/
751758
public static ProtocolNegotiator tls(SslContext sslContext,
752759
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable,
753-
X509TrustManager x509ExtendedTrustManager) {
760+
X509TrustManager x509ExtendedTrustManager, String sni) {
754761
return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable,
755-
x509ExtendedTrustManager);
762+
x509ExtendedTrustManager, sni);
756763
}
757764

758765
/**
@@ -762,7 +769,7 @@ public static ProtocolNegotiator tls(SslContext sslContext,
762769
*/
763770
public static ProtocolNegotiator tls(SslContext sslContext,
764771
X509TrustManager x509ExtendedTrustManager) {
765-
return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager);
772+
return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager, null);
766773
}
767774

768775
public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext,

netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,7 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception {
877877
.keyManager(clientCert, clientKey)
878878
.build();
879879
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool,
880-
Optional.absent(), null);
880+
Optional.absent(), null, null);
881881
// after starting the client, the Executor in the client pool should be used
882882
assertEquals(true, clientExecutorPool.isInUse());
883883
final NettyClientTransport transport = newTransport(negotiator);

netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,7 +1026,7 @@ public void clientTlsHandler_closeDuringNegotiation() throws Exception {
10261026
private ClientTlsProtocolNegotiator getClientTlsProtocolNegotiator() throws SSLException {
10271027
return new ClientTlsProtocolNegotiator(GrpcSslContexts.forClient().trustManager(
10281028
TlsTesting.loadCert("ca.pem")).build(),
1029-
null, Optional.absent(), null);
1029+
null, Optional.absent(), null, "");
10301030
}
10311031

10321032
@Test
@@ -1277,7 +1277,7 @@ public void clientTlsHandler_firesNegotiation() throws Exception {
12771277
}
12781278
FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
12791279
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext,
1280-
null, Optional.absent(), null);
1280+
null, Optional.absent(), null, null);
12811281
WriteBufferingAndExceptionHandler clientWbaeh =
12821282
new WriteBufferingAndExceptionHandler(pn.newHandler(gh));
12831283

s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
3939
import io.grpc.netty.InternalProtocolNegotiators;
4040
import io.grpc.netty.InternalProtocolNegotiators.ProtocolNegotiationHandler;
41-
import io.grpc.s2a.internal.handshaker.S2AIdentity;
4241
import io.netty.channel.ChannelHandler;
4342
import io.netty.channel.ChannelHandlerAdapter;
4443
import io.netty.channel.ChannelHandlerContext;
@@ -259,7 +258,8 @@ public void onSuccess(SslContext sslContext) {
259258
public void run() {
260259
s2aStub.close();
261260
}
262-
}))
261+
}),
262+
null, null)
263263
.newHandler(grpcHandler);
264264

265265
// Delegate the rest of the handshake to the TLS handler. and remove the

xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import io.grpc.xds.client.XdsClient;
5454
import io.grpc.xds.client.XdsLogger;
5555
import io.grpc.xds.client.XdsLogger.XdsLogLevel;
56+
import io.grpc.xds.internal.XdsInternalAttributes;
5657
import io.grpc.xds.internal.security.SecurityProtocolNegotiators;
5758
import io.grpc.xds.internal.security.SslContextProviderSupplier;
5859
import io.grpc.xds.orca.OrcaPerRequestUtil;
@@ -117,12 +118,12 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
117118
logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses);
118119
Attributes attributes = resolvedAddresses.getAttributes();
119120
if (xdsClientPool == null) {
120-
xdsClientPool = attributes.get(XdsAttributes.XDS_CLIENT_POOL);
121+
xdsClientPool = attributes.get(io.grpc.xds.XdsAttributes.XDS_CLIENT_POOL);
121122
assert xdsClientPool != null;
122123
xdsClient = xdsClientPool.getObject();
123124
}
124125
if (callCounterProvider == null) {
125-
callCounterProvider = attributes.get(XdsAttributes.CALL_COUNTER_PROVIDER);
126+
callCounterProvider = attributes.get(io.grpc.xds.XdsAttributes.CALL_COUNTER_PROVIDER);
126127
}
127128

128129
ClusterImplConfig config =
@@ -241,9 +242,9 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) {
241242
.set(ATTR_CLUSTER_LOCALITY, localityAtomicReference);
242243
if (GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false)) {
243244
String hostname = args.getAddresses().get(0).getAttributes()
244-
.get(XdsAttributes.ATTR_ADDRESS_NAME);
245+
.get(XdsInternalAttributes.ATTR_ADDRESS_NAME);
245246
if (hostname != null) {
246-
attrsBuilder.set(XdsAttributes.ATTR_ADDRESS_NAME, hostname);
247+
attrsBuilder.set(XdsInternalAttributes.ATTR_ADDRESS_NAME, hostname);
247248
}
248249
}
249250
args = args.toBuilder().setAddresses(addresses).setAttributes(attrsBuilder.build()).build();
@@ -292,7 +293,7 @@ private List<EquivalentAddressGroup> withAdditionalAttributes(
292293
List<EquivalentAddressGroup> newAddresses = new ArrayList<>();
293294
for (EquivalentAddressGroup eag : addresses) {
294295
Attributes.Builder attrBuilder = eag.getAttributes().toBuilder().set(
295-
XdsAttributes.ATTR_CLUSTER_NAME, cluster);
296+
io.grpc.xds.XdsAttributes.ATTR_CLUSTER_NAME, cluster);
296297
if (sslContextProviderSupplier != null) {
297298
attrBuilder.set(
298299
SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER,
@@ -304,7 +305,7 @@ private List<EquivalentAddressGroup> withAdditionalAttributes(
304305
}
305306

306307
private ClusterLocality createClusterLocalityFromAttributes(Attributes addressAttributes) {
307-
Locality locality = addressAttributes.get(XdsAttributes.ATTR_LOCALITY);
308+
Locality locality = addressAttributes.get(io.grpc.xds.XdsAttributes.ATTR_LOCALITY);
308309
String localityName = addressAttributes.get(EquivalentAddressGroup.ATTR_LOCALITY_NAME);
309310

310311
// Endpoint addresses resolved by ClusterResolverLoadBalancer should always contain
@@ -438,7 +439,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
438439
result = PickResult.withSubchannel(result.getSubchannel(),
439440
result.getStreamTracerFactory(),
440441
result.getSubchannel().getAttributes().get(
441-
XdsAttributes.ATTR_ADDRESS_NAME));
442+
XdsInternalAttributes.ATTR_ADDRESS_NAME));
442443
}
443444
}
444445
return result;

0 commit comments

Comments
 (0)