Skip to content

Commit dd0d61a

Browse files
Add support for mTLS authentication in Arrow Flight client
1 parent 5bbdf93 commit dd0d61a

15 files changed

+489
-19
lines changed

presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ public enum ArrowErrorCode
2727
ARROW_INTERNAL_ERROR(1, INTERNAL_ERROR),
2828
ARROW_FLIGHT_CLIENT_ERROR(2, EXTERNAL),
2929
ARROW_FLIGHT_METADATA_ERROR(3, EXTERNAL),
30-
ARROW_FLIGHT_TYPE_ERROR(4, EXTERNAL);
30+
ARROW_FLIGHT_TYPE_ERROR(4, EXTERNAL),
31+
ARROW_FLIGHT_INVALID_KEY_ERROR(5, INTERNAL_ERROR),
32+
ARROW_FLIGHT_INVALID_CERT_ERROR(6, INTERNAL_ERROR);
3133

3234
private final ErrorCode errorCode;
3335

presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ public class ArrowFlightConfig
2020
private String server;
2121
private boolean verifyServer = true;
2222
private String flightServerSSLCertificate;
23+
private String flightClientSSLCertificate;
24+
private String flightClientSSLKey;
2325
private boolean arrowFlightServerSslEnabled;
2426
private Integer arrowFlightPort;
2527

@@ -82,4 +84,38 @@ public ArrowFlightConfig setArrowFlightServerSslEnabled(boolean arrowFlightServe
8284
this.arrowFlightServerSslEnabled = arrowFlightServerSslEnabled;
8385
return this;
8486
}
87+
88+
public String getFlightClientSSLCertificate()
89+
{
90+
return flightClientSSLCertificate;
91+
}
92+
93+
/***
94+
* Set the client SSL certificate used for mTLS authentication with Flight server.
95+
* @param flightClientSSLCertificate path to the certificate file
96+
* @return Returns this config instance
97+
*/
98+
@Config("arrow-flight.client-ssl-certificate")
99+
public ArrowFlightConfig setFlightClientSSLCertificate(String flightClientSSLCertificate)
100+
{
101+
this.flightClientSSLCertificate = flightClientSSLCertificate;
102+
return this;
103+
}
104+
105+
public String getFlightClientSSLKey()
106+
{
107+
return flightClientSSLKey;
108+
}
109+
110+
/***
111+
* Set the client SSL key used for mTLS authentication with Flight server
112+
* @param flightClientSSLKey path to the key file
113+
* @return Returns this config instance
114+
*/
115+
@Config("arrow-flight.client-ssl-key")
116+
public ArrowFlightConfig setFlightClientSSLKey(String flightClientSSLKey)
117+
{
118+
this.flightClientSSLKey = flightClientSSLKey;
119+
return this;
120+
}
85121
}

presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
*/
1414
package com.facebook.plugin.arrow;
1515

16+
import com.facebook.airlift.log.Logger;
1617
import com.facebook.presto.spi.ConnectorSession;
1718
import com.facebook.presto.spi.SchemaTableName;
1819
import org.apache.arrow.flight.CallOption;
@@ -30,11 +31,15 @@
3031
import java.net.URISyntaxException;
3132
import java.nio.ByteBuffer;
3233
import java.nio.file.Paths;
34+
import java.security.InvalidKeyException;
35+
import java.security.cert.CertificateException;
3336
import java.util.List;
3437
import java.util.Optional;
3538

3639
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_CLIENT_ERROR;
3740
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_INFO_ERROR;
41+
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_INVALID_CERT_ERROR;
42+
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_INVALID_KEY_ERROR;
3843
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_METADATA_ERROR;
3944
import static java.nio.file.Files.newInputStream;
4045
import static java.util.Objects.requireNonNull;
@@ -43,6 +48,7 @@ public abstract class BaseArrowFlightClientHandler
4348
{
4449
private final ArrowFlightConfig config;
4550
private final BufferAllocator allocator;
51+
private static final Logger logger = Logger.get(BaseArrowFlightClientHandler.class);
4652

4753
public BaseArrowFlightClientHandler(BufferAllocator allocator, ArrowFlightConfig config)
4854
{
@@ -64,24 +70,61 @@ protected FlightClient createFlightClient()
6470

6571
protected FlightClient createFlightClient(Location location)
6672
{
73+
Optional<InputStream> trustedCertificate = Optional.empty();
74+
Optional<InputStream> clientCertificate = Optional.empty();
75+
Optional<InputStream> clientKey = Optional.empty();
6776
try {
68-
Optional<InputStream> trustedCertificate = Optional.empty();
6977
FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location);
7078
flightClientBuilder.verifyServer(config.getVerifyServer());
7179
if (config.getFlightServerSSLCertificate() != null) {
7280
trustedCertificate = Optional.of(newInputStream(Paths.get(config.getFlightServerSSLCertificate())));
7381
flightClientBuilder.trustedCertificates(trustedCertificate.get()).useTls();
7482
}
75-
76-
FlightClient flightClient = flightClientBuilder.build();
77-
if (trustedCertificate.isPresent()) {
78-
trustedCertificate.get().close();
83+
if (config.getFlightClientSSLCertificate() != null && config.getFlightClientSSLKey() != null) {
84+
clientCertificate = Optional.of(newInputStream(Paths.get(config.getFlightClientSSLCertificate())));
85+
clientKey = Optional.of(newInputStream(Paths.get(config.getFlightClientSSLKey())));
86+
flightClientBuilder.clientCertificate(clientCertificate.get(), clientKey.get()).useTls();
7987
}
8088

81-
return flightClient;
89+
return flightClientBuilder.build();
8290
}
8391
catch (Exception e) {
84-
throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, "Error creating flight client: " + e.getMessage(), e);
92+
Optional<Throwable> cause = Optional.ofNullable(e.getCause());
93+
if (cause.filter(c -> c instanceof InvalidKeyException).isPresent()) {
94+
throw new ArrowException(ARROW_FLIGHT_INVALID_KEY_ERROR, "Error creating flight client, invalid key file: " + e.getMessage(), e);
95+
}
96+
else if (cause.filter(c -> c instanceof CertificateException).isPresent()) {
97+
throw new ArrowException(ARROW_FLIGHT_INVALID_CERT_ERROR, "Error creating flight client, invalid certificate file: " + e.getMessage(), e);
98+
}
99+
else {
100+
throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, "Error creating flight client: " + e.getMessage(), e);
101+
}
102+
}
103+
finally {
104+
if (trustedCertificate.isPresent()) {
105+
try {
106+
trustedCertificate.get().close();
107+
}
108+
catch (IOException e) {
109+
logger.error("Error closing input stream for server certificate", e);
110+
}
111+
}
112+
if (clientCertificate.isPresent()) {
113+
try {
114+
clientCertificate.get().close();
115+
}
116+
catch (IOException e) {
117+
logger.error("Error closing input stream for client certificate", e);
118+
}
119+
}
120+
if (clientKey.isPresent()) {
121+
try {
122+
clientKey.get().close();
123+
}
124+
catch (IOException e) {
125+
logger.error("Error closing input stream for client key", e);
126+
}
127+
}
85128
}
86129
}
87130

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.plugin.arrow;
15+
16+
import com.facebook.airlift.log.Logger;
17+
import com.facebook.plugin.arrow.testingServer.TestingArrowProducer;
18+
import com.facebook.presto.testing.QueryRunner;
19+
import com.facebook.presto.tests.AbstractTestQueryFramework;
20+
import com.facebook.presto.tests.DistributedQueryRunner;
21+
import com.google.common.collect.ImmutableMap;
22+
import org.apache.arrow.flight.FlightServer;
23+
import org.apache.arrow.flight.Location;
24+
import org.apache.arrow.memory.RootAllocator;
25+
import org.testng.annotations.AfterClass;
26+
import org.testng.annotations.BeforeClass;
27+
28+
import java.io.File;
29+
import java.io.IOException;
30+
import java.util.Map;
31+
import java.util.Optional;
32+
33+
public abstract class AbstractArrowFlightMTLSTestFramework
34+
extends AbstractTestQueryFramework
35+
{
36+
private static final Logger logger = Logger.get(AbstractArrowFlightMTLSTestFramework.class);
37+
private final int serverPort;
38+
private RootAllocator allocator;
39+
private FlightServer server;
40+
private DistributedQueryRunner arrowFlightQueryRunner;
41+
42+
public AbstractArrowFlightMTLSTestFramework()
43+
throws IOException
44+
{
45+
this.serverPort = ArrowFlightQueryRunner.findUnusedPort();
46+
}
47+
48+
@BeforeClass
49+
void setup()
50+
throws Exception
51+
{
52+
arrowFlightQueryRunner = getDistributedQueryRunner();
53+
File certChainFile = new File("src/test/resources/mtls/server.crt");
54+
File privateKeyFile = new File("src/test/resources/mtls/server.key");
55+
File caCertFile = new File("src/test/resources/mtls/ca.crt");
56+
57+
allocator = new RootAllocator(Long.MAX_VALUE);
58+
59+
Location location = Location.forGrpcTls("localhost", serverPort);
60+
server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator))
61+
.useTls(certChainFile, privateKeyFile)
62+
.useMTlsClientVerification(caCertFile)
63+
.build();
64+
65+
server.start();
66+
logger.info("Server listening on port %s", server.getPort());
67+
}
68+
69+
@AfterClass(alwaysRun = true)
70+
void tearDown()
71+
throws InterruptedException
72+
{
73+
arrowFlightQueryRunner.close();
74+
server.close();
75+
allocator.close();
76+
}
77+
78+
@Override
79+
protected QueryRunner createQueryRunner()
80+
throws Exception
81+
{
82+
return ArrowFlightQueryRunner.createQueryRunner(ImmutableMap.of(), getCatalogProperties(), ImmutableMap.of(), Optional.empty());
83+
}
84+
85+
abstract Map<String, String> getCatalogProperties();
86+
87+
int getServerPort()
88+
{
89+
return serverPort;
90+
}
91+
}

presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,15 @@ public static DistributedQueryRunner createQueryRunner(
6464
Optional<BiFunction<Integer, URI, Process>> externalWorkerLauncher)
6565
throws Exception
6666
{
67-
return createQueryRunner(extraProperties, ImmutableMap.of("arrow-flight.server.port", String.valueOf(flightServerPort)), coordinatorProperties, externalWorkerLauncher);
67+
ImmutableMap.Builder<String, String> catalogProperties = ImmutableMap.<String, String>builder()
68+
.put("arrow-flight.server.port", String.valueOf(flightServerPort))
69+
.put("arrow-flight.server", "localhost")
70+
.put("arrow-flight.server-ssl-enabled", "true")
71+
.put("arrow-flight.server-ssl-certificate", "src/test/resources/server.crt");
72+
return createQueryRunner(extraProperties, catalogProperties.build(), coordinatorProperties, externalWorkerLauncher);
6873
}
6974

70-
private static DistributedQueryRunner createQueryRunner(
75+
protected static DistributedQueryRunner createQueryRunner(
7176
Map<String, String> extraProperties,
7277
Map<String, String> catalogProperties,
7378
Map<String, String> coordinatorProperties,
@@ -92,13 +97,7 @@ private static DistributedQueryRunner createQueryRunner(
9297
boolean nativeExecution = externalWorkerLauncher.isPresent();
9398
queryRunner.installPlugin(new TestingArrowFlightPlugin(nativeExecution));
9499

95-
ImmutableMap.Builder<String, String> properties = ImmutableMap.<String, String>builder()
96-
.putAll(catalogProperties)
97-
.put("arrow-flight.server", "localhost")
98-
.put("arrow-flight.server-ssl-enabled", "true")
99-
.put("arrow-flight.server-ssl-certificate", "src/test/resources/server.crt");
100-
101-
queryRunner.createCatalog(ARROW_FLIGHT_CATALOG, ARROW_FLIGHT_CONNECTOR, properties.build());
100+
queryRunner.createCatalog(ARROW_FLIGHT_CATALOG, ARROW_FLIGHT_CONNECTOR, catalogProperties);
102101

103102
return queryRunner;
104103
}
@@ -140,8 +139,8 @@ public static void main(String[] args)
140139
log.info("Server listening on port " + server.getPort());
141140

142141
DistributedQueryRunner queryRunner = createQueryRunner(
142+
9443,
143143
ImmutableMap.of("http-server.http.port", "8080"),
144-
ImmutableMap.of("arrow-flight.server.port", String.valueOf(9443)),
145144
ImmutableMap.of(),
146145
Optional.empty());
147146
Thread.sleep(10);
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.plugin.arrow;
15+
16+
import com.facebook.airlift.log.Logger;
17+
import com.google.common.collect.ImmutableMap;
18+
import org.testng.annotations.Test;
19+
20+
import java.io.IOException;
21+
import java.util.Map;
22+
23+
public class TestArrowFlightMTLS
24+
extends AbstractArrowFlightMTLSTestFramework
25+
{
26+
private static final Logger logger = Logger.get(TestArrowFlightMTLS.class);
27+
28+
public TestArrowFlightMTLS()
29+
throws IOException
30+
{
31+
super();
32+
}
33+
34+
@Override
35+
Map<String, String> getCatalogProperties()
36+
{
37+
ImmutableMap.Builder<String, String> catalogProperties = ImmutableMap.<String, String>builder()
38+
.put("arrow-flight.server.port", String.valueOf(getServerPort()))
39+
.put("arrow-flight.server", "localhost")
40+
.put("arrow-flight.server-ssl-enabled", "true")
41+
.put("arrow-flight.server-ssl-certificate", "src/test/resources/mtls/server.crt")
42+
.put("arrow-flight.client-ssl-certificate", "src/test/resources/mtls/client.crt")
43+
.put("arrow-flight.client-ssl-key", "src/test/resources/mtls/client.key");
44+
return catalogProperties.build();
45+
}
46+
47+
@Test
48+
public void testMtls()
49+
{
50+
assertQuery("SELECT COUNT(*) FROM orders");
51+
}
52+
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.plugin.arrow;
15+
16+
import com.facebook.airlift.log.Logger;
17+
import com.google.common.collect.ImmutableMap;
18+
import org.testng.annotations.Test;
19+
20+
import java.io.IOException;
21+
import java.util.Map;
22+
23+
public class TestArrowFlightMTLSFails
24+
extends AbstractArrowFlightMTLSTestFramework
25+
{
26+
private static final Logger logger = Logger.get(TestArrowFlightMTLSFails.class);
27+
28+
public TestArrowFlightMTLSFails()
29+
throws IOException
30+
{
31+
super();
32+
}
33+
34+
@Override
35+
Map<String, String> getCatalogProperties()
36+
{
37+
ImmutableMap.Builder<String, String> catalogProperties = ImmutableMap.<String, String>builder()
38+
.put("arrow-flight.server.port", String.valueOf(getServerPort()))
39+
.put("arrow-flight.server", "localhost")
40+
.put("arrow-flight.server-ssl-enabled", "true")
41+
.put("arrow-flight.server-ssl-certificate", "src/test/resources/mtls/server.crt");
42+
return catalogProperties.build();
43+
}
44+
45+
@Test
46+
public void testMtlsFailure()
47+
{
48+
assertQueryFails("SELECT COUNT(*) FROM orders", "ssl exception");
49+
}
50+
}

0 commit comments

Comments
 (0)