diff --git a/.evergreen/config.yml b/.evergreen/config.yml
index 05832906d1..a72905b43b 100644
--- a/.evergreen/config.yml
+++ b/.evergreen/config.yml
@@ -26,7 +26,7 @@ timeout:
       args: [ls, -la]

 functions:
-  assume-test-secrets-ec2-role: 
+  assume-test-secrets-ec2-role:
     - command: ec2.assume_role
       params:
        role_arn: ${aws_test_secrets_role}
@@ -121,21 +121,21 @@ functions:
      params:
        binary: bash
        add_expansions_to_env: true
-        args: 
+        args:
          - ${DRIVERS_TOOLS}/.evergreen/atlas_data_lake/pull-mongohouse-image.sh
    - command: subprocess.exec
      params:
        binary: bash
-        args: 
+        args:
          - ${DRIVERS_TOOLS}/.evergreen/atlas_data_lake/run-mongohouse-image.sh

  bootstrap-mongo-orchestration:
    - command: subprocess.exec
      params:
        binary: bash
-        env: 
+        env:
          MONGODB_VERSION: ${VERSION}
-        include_expansions_in_env: [TOPOLOGY, AUTH, SSL, ORCHESTRATION_FILE, 
+        include_expansions_in_env: [TOPOLOGY, AUTH, SSL, ORCHESTRATION_FILE,
          REQUIRE_API_VERSION, LOAD_BALANCER]
        args: ["${DRIVERS_TOOLS}/.evergreen/run-orchestration.sh"]
    - command: expansions.update
@@ -146,7 +146,7 @@ functions:
    - command: subprocess.exec
      params:
        binary: bash
-        env: 
+        env:
          MONGODB_VERSION: ${VERSION}
        include_expansions_in_env: [TOPOLOGY, AUTH, SSL, ORCHESTRATION_FILE]
        args: ["${DRIVERS_TOOLS}/.evergreen/run-orchestration.sh"]
@@ -158,23 +158,23 @@ functions:
    - command: subprocess.exec
      params:
        binary: bash
-        args: 
+        args:
          # Ensure the instance profile is reassigned for aws tests.
          - ${DRIVERS_TOOLS}/.evergreen/auth_aws/teardown.sh
    - command: subprocess.exec
      params:
        binary: bash
-        args: 
+        args:
          - ${DRIVERS_TOOLS}/.evergreen/csfle/teardown.sh
    - command: subprocess.exec
      params:
        binary: bash
-        args: 
+        args:
          - ${DRIVERS_TOOLS}/.evergreen/ocsp/teardown.sh
    - command: subprocess.exec
      params:
        binary: bash
-        args: 
+        args:
          - ${DRIVERS_TOOLS}/.evergreen/teardown.sh

  assume-test-secrets-ec2-role:
@@ -200,7 +200,19 @@ functions:
        binary: bash
        env:
          GO_BUILD_TAGS: cse
-        include_expansions_in_env: ["TOPOLOGY", "AUTH", "SSL", "SKIP_CSOT_TESTS", "MONGODB_URI", "CRYPT_SHARED_LIB_PATH", "SKIP_CRYPT_SHARED_LIB", "RACE", "MONGO_GO_DRIVER_COMPRESSOR", "REQUIRE_API_VERSION", "LOAD_BALANCER"]
+        include_expansions_in_env:
+          - "TOPOLOGY"
+          - "AUTH"
+          - "SSL"
+          - "SKIP_CSOT_TESTS"
+          - "MONGODB_URI"
+          - "CRYPT_SHARED_LIB_PATH"
+          - "SKIP_CRYPT_SHARED_LIB"
+          - "RACE"
+          - "MONGO_GO_DRIVER_COMPRESSOR"
+          - "REQUIRE_API_VERSION"
+          - "LOAD_BALANCER"
+          - "MONGOPROXY"
        args: [*task-runner, setup-test]
    - command: subprocess.exec
      type: test
@@ -294,7 +306,7 @@ functions:
      params:
        binary: bash
        include_expansions_in_env: [AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN]
-        env: 
+        env:
          TEST_ENTERPRISE_AUTH: plain
        args: [*task-runner, setup-test]
    - command: subprocess.exec
@@ -312,7 +324,7 @@ functions:
      params:
        binary: bash
        include_expansions_in_env: [AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN]
-        env: 
+        env:
          TEST_ENTERPRISE_AUTH: gssapi
        args: [*task-runner, setup-test]
    - command: subprocess.exec
@@ -344,7 +356,7 @@ functions:
      type: test
      params:
        binary: bash
-        env: 
+        env:
          TOPOLOGY: server
          AUTH: auth
          SSL: ssl
@@ -362,9 +374,9 @@ functions:
      type: test
      params:
        binary: bash
-        env: 
+        env:
          GO_BUILD_TAGS: cse
-        include_expansions_in_env: [AUTH, SSL, MONGODB_URI, TOPOLOGY, MONGO_GO_DRIVER_COMPRESSOR, 
+        include_expansions_in_env: [AUTH, SSL, MONGODB_URI, TOPOLOGY, MONGO_GO_DRIVER_COMPRESSOR,
          REQUIRE_API_VERSION, SKIP_CRYPT_SHARED_LIB, CRYPT_SHARED_LIB_PATH]
        args: [*task-runner, setup-test]
    - command: subprocess.exec
@@ -379,7 +391,7 @@ functions:
      params:
        binary: bash
        include_expansions_in_env: [SINGLE_MONGOS_LB_URI, MULTI_MONGOS_LB_URI, AUTH, SSL, MONGO_GO_DRIVER_COMPRESSOR]
-        env: 
+        env:
          LOAD_BALANCER: "true"
        args: [*task-runner, setup-test]
    - command: subprocess.exec
@@ -393,7 +405,7 @@ functions:
      type: test
      params:
        binary: "bash"
-        env: 
+        env:
          AUTH: auth
          SSL: nossl
          TOPOLOGY: server
@@ -417,7 +429,7 @@ functions:
      type: test
      params:
        binary: "bash"
-        env: 
+        env:
          TOPOLOGY: sharded_cluster
          TASKFILE_TARGET: test-short
        args: [*task-runner, run-docker]
@@ -477,7 +489,7 @@ functions:
      params:
        include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
        binary: "bash"
-        args: 
+        args:
          - ${DRIVERS_TOOLS}/.evergreen/auth_aws/setup-secrets.sh

  run-aws-auth-test-with-regular-aws-credentials:
@@ -502,7 +514,7 @@ functions:
    - command: subprocess.exec
      type: test
      params:
-        binary: bash 
+        binary: bash
        include_expansions_in_env: [SKIP_EC2_AUTH_TEST]
        env:
          AWS_TEST: ec2
@@ -547,7 +559,7 @@ functions:
      type: test
      params:
        binary: bash
-        env: 
+        env:
          AWS_ROLE_SESSION_NAME: test
          AWS_TEST: web-identity
        include_expansions_in_env: [SKIP_WEB_IDENTITY_AUTH_TEST]
@@ -570,14 +582,25 @@ functions:
        binary: bash
        args: ["${DRIVERS_TOOLS}/.evergreen/csfle/await-servers.sh"]

+  start-mongoproxy:
+    - command: ec2.assume_role # TODO: not sure this is needed
+      params:
+        role_arn: ${aws_test_secrets_role}
+    - command: subprocess.exec
+      params:
+        binary: bash
+        include_expansions_in_env: [AUTH, SSL, MONGODB_URI]
+        background: true
+        args: [*task-runner, start-mongoproxy]
+
  run-kms-tls-test:
    - command: subprocess.exec
      type: test
      params:
        binary: "bash"
-        env: 
+        env:
          GO_BUILD_TAGS: cse
-        include_expansions_in_env: [AUTH, SSL, MONGODB_URI, TOPOLOGY, 
+        include_expansions_in_env: [AUTH, SSL, MONGODB_URI, TOPOLOGY,
          MONGO_GO_DRIVER_COMPRESSOR]
        args: [*task-runner, setup-test]
    - command: subprocess.exec
@@ -592,9 +615,9 @@ functions:
      type: test
      params:
        binary: "bash"
-        env: 
+        env:
          GO_BUILD_TAGS: cse
-        include_expansions_in_env: [AUTH, SSL, MONGODB_URI, TOPOLOGY, 
+        include_expansions_in_env: [AUTH, SSL, MONGODB_URI, TOPOLOGY,
          MONGO_GO_DRIVER_COMPRESSOR]
        args: [*task-runner, setup-test]
    - command: subprocess.exec
@@ -610,9 +633,9 @@ functions:
      type: test
      params:
        binary: "bash"
-        env: 
+        env:
          GO_BUILD_TAGS: cse
-        include_expansions_in_env: [AUTH, SSL, MONGODB_URI, TOPOLOGY, 
+        include_expansions_in_env: [AUTH, SSL, MONGODB_URI, TOPOLOGY,
          MONGO_GO_DRIVER_COMPRESSOR]
        args: [*task-runner, setup-test]
    - command: subprocess.exec
@@ -651,14 +674,14 @@ tasks:
    allowed_requesters: ["patch", "github_pr"]
    commands:
      - func: assume-test-secrets-ec2-role
-      - func: "add PR reviewer" 
+      - func: "add PR reviewer"
      - func: "add PR labels"
      - func: "create-api-report"

  - name: backport-pr
    allowed_requesters: ["commit"]
    commands:
-      - func: "backport pr" 
+      - func: "backport pr"

  - name: perf
    tags: ["performance"]
@@ -686,11 +709,13 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      - func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "server"
          AUTH: "noauth"
          SSL: "nossl"
+          MONGOPROXY: "true"

  - name: test-standalone-noauth-nossl-snappy-compression
    tags: ["test", "standalone", "compression", "snappy"]
@@ -701,6 +726,7 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "server"
@@ -717,6 +743,7 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "server"
@@ -733,6 +760,7 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "server"
@@ -749,6 +777,7 @@ tasks:
          AUTH: "auth"
          SSL: "ssl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "server"
@@ -764,6 +793,7 @@ tasks:
          AUTH: "auth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "server"
@@ -779,6 +809,7 @@ tasks:
          AUTH: "auth"
          SSL: "ssl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "server"
@@ -795,6 +826,7 @@ tasks:
          AUTH: "auth"
          SSL: "ssl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "server"
@@ -811,6 +843,7 @@ tasks:
          AUTH: "auth"
          SSL: "ssl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "server"
@@ -1231,6 +1264,7 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "replica_set"
@@ -1247,6 +1281,7 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "replica_set"
@@ -1262,6 +1297,7 @@ tasks:
          AUTH: "auth"
          SSL: "ssl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "replica_set"
@@ -1277,6 +1313,7 @@ tasks:
          AUTH: "auth"
          SSL: "ssl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "replica_set"
@@ -1295,6 +1332,7 @@ tasks:
          AUTH: "auth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "replica_set"
@@ -1310,6 +1348,7 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "sharded_cluster"
@@ -1325,6 +1364,7 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "sharded_cluster"
@@ -1341,6 +1381,7 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "sharded_cluster"
@@ -1357,6 +1398,7 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "sharded_cluster"
@@ -1373,6 +1415,7 @@ tasks:
          AUTH: "auth"
          SSL: "ssl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "sharded_cluster"
@@ -1388,6 +1431,7 @@ tasks:
          AUTH: "auth"
          SSL: "ssl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "sharded_cluster"
@@ -1404,6 +1448,7 @@ tasks:
          AUTH: "auth"
          SSL: "ssl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "sharded_cluster"
@@ -1420,6 +1465,7 @@ tasks:
          AUTH: "auth"
          SSL: "ssl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "sharded_cluster"
@@ -1436,6 +1482,7 @@ tasks:
          AUTH: "auth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-tests
        vars:
          TOPOLOGY: "sharded_cluster"
@@ -1454,7 +1501,7 @@ tasks:
        vars:
          MONGO_GO_DRIVER_COMPRESSOR: "snappy"
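For now the MONGOPROXY expansion is only wired into the first standalone noauth/nossl task; the commented-out `#- func: start-mongoproxy` lines mark where the remaining tasks could opt in later. A minimal sketch of how a Go test could gate itself on that expansion once `run-tests` forwards it into the environment (the helper name is hypothetical, not part of the harness):

```go
package mtest_test

import (
	"os"
	"testing"
)

// skipWithoutMongoProxy is a hypothetical helper: it skips tests that need
// the mongoproxy sidecar unless the MONGOPROXY expansion was forwarded into
// the test environment (see include_expansions_in_env above).
func skipWithoutMongoProxy(t *testing.T) {
	t.Helper()
	if os.Getenv("MONGOPROXY") != "true" {
		t.Skip("set MONGOPROXY=true to run proxy-backed tests")
	}
}
```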
-  # Build the compilecheck submodule with all supported versions of Go >= 
+  # Build the compilecheck submodule with all supported versions of Go >=
  # the minimum supported version.
  - name: go-build
    tags: ["compile-check"]
@@ -1507,6 +1554,7 @@ tasks:
          SSL: "nossl"
          REQUIRE_API_VERSION: true
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-versioned-api-test
        vars:
          TOPOLOGY: "server"
@@ -1524,6 +1572,7 @@ tasks:
          SSL: "nossl"
          REQUIRE_API_VERSION: true
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-versioned-api-test
        vars:
          TOPOLOGY: "sharded_cluster"
@@ -1541,6 +1590,7 @@ tasks:
          SSL: "nossl"
          ORCHESTRATION_FILE: "versioned-api-testing.json"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-versioned-api-test
        vars:
          TOPOLOGY: "server"
@@ -1556,6 +1606,7 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-kms-tls-test
        vars:
          KMS_TLS_TESTCASE: "INVALID_CERT"
@@ -1572,6 +1623,7 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-kms-tls-test
        vars:
          KMS_TLS_TESTCASE: "INVALID_HOSTNAME"
@@ -1588,6 +1640,7 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-kmip-tests
        vars:
          TOPOLOGY: "server"
@@ -1603,6 +1656,7 @@ tasks:
          AUTH: "noauth"
          SSL: "nossl"
      - func: start-cse-servers
+      #- func: start-mongoproxy
      - func: run-retry-kms-requests

  - name: "testgcpkms-task"
@@ -1645,7 +1699,7 @@ tasks:
      params:
        binary: "bash"
        add_expansions_to_env: true
-        env: 
+        env:
          EXPECT_ERROR: 'status=400'
        args: [*task-runner, test-awskms]

@@ -1901,13 +1955,13 @@ task_groups:
      params:
        binary: "bash"
        add_expansions_to_env: true
-        args: 
+        args:
          - ${DRIVERS_TOOLS}/.evergreen/csfle/gcpkms/setup.sh
    teardown_group:
      - command: subprocess.exec
        params:
          binary: "bash"
-          args: 
+          args:
            - ${DRIVERS_TOOLS}/.evergreen/csfle/gcpkms/teardown.sh
      - func: teardown
      - func: handle-test-artifacts
@@ -1925,7 +1979,7 @@ task_groups:
      params:
        binary: bash
        add_expansions_to_env: true
-        env: 
+        env:
          AZUREKMS_VMNAME_PREFIX: GODRIVER
        args:
          - ${DRIVERS_TOOLS}/.evergreen/csfle/azurekms/setup.sh
@@ -1933,7 +1987,7 @@ task_groups:
    - command: subprocess.exec
      params:
        binary: "bash"
-        args: 
+        args:
          - ${DRIVERS_TOOLS}/.evergreen/csfle/azurekms/teardown.sh
      - func: teardown
      - func: handle-test-artifacts
@@ -2156,7 +2210,7 @@ buildvariants:
      - rhel8.7-small
    expansions:
      GO_DIST: "/opt/golang/go1.23"
-    tasks: 
+    tasks:
      - name: "backport-pr"

  - name: atlas-test
diff --git a/.evergreen/run-task.sh b/.evergreen/run-task.sh
index 564d937703..a5001f04b8 100755
--- a/.evergreen/run-task.sh
+++ b/.evergreen/run-task.sh
@@ -1,13 +1,13 @@
 #!/usr/bin/env bash
 #
 # Source the env.sh file and run the given task
-set -eu
+set -exu

-SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
 PROJECT_DIRECTORY=$(dirname $SCRIPT_DIR)

-pushd ${PROJECT_DIRECTORY} > /dev/null
+pushd ${PROJECT_DIRECTORY} >/dev/null
 source env.sh

 task "$@"
-popd > /dev/null
+popd >/dev/null
diff --git a/.evergreen/setup-system.sh b/.evergreen/setup-system.sh
index 542060fee4..83f00d1572 100755
--- a/.evergreen/setup-system.sh
+++ b/.evergreen/setup-system.sh
@@ -4,7 +4,7 @@
 set -eu

 # Set up default environment variables.
-SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
 PROJECT_DIRECTORY=$(dirname $SCRIPT_DIR)
 pushd $PROJECT_DIRECTORY
 ROOT_DIR=$(dirname $PROJECT_DIRECTORY)
@@ -40,18 +40,19 @@ PATH="${GOROOT}/bin:${GOPATH}/bin:${MONGODB_BINARIES}:${EXTRA_PATH}:${PATH}"

 # Get the current unique version of this checkout.
 if [ "${IS_PATCH:-}" = "true" ]; then
-    CURRENT_VERSION=$(git describe)-patch-${VERSION_ID}
+	CURRENT_VERSION=$(git describe)-patch-${VERSION_ID}
 else
-    CURRENT_VERSION=latest
+	CURRENT_VERSION=latest
 fi

 # Ensure a checkout of drivers-tools.
 if [ ! -d "$DRIVERS_TOOLS" ]; then
-    git clone https://github.com/mongodb-labs/drivers-evergreen-tools $DRIVERS_TOOLS
+	#git clone https://github.com/mongodb-labs/drivers-evergreen-tools $DRIVERS_TOOLS
+	git clone -b DRIVERS-2884-mongoproxy https://github.com/prestonvasquez/drivers-evergreen-tools $DRIVERS_TOOLS
 fi

 # Write the .env file for drivers-tools.
-cat <<EOT > ${DRIVERS_TOOLS}/.env
+cat <<EOT >${DRIVERS_TOOLS}/.env
 SKIP_LEGACY_SHELL=1
 DRIVERS_TOOLS="$DRIVERS_TOOLS"
 MONGO_ORCHESTRATION_HOME="$MONGO_ORCHESTRATION_HOME"
@@ -67,7 +68,7 @@ go env
 go install github.com/go-task/task/v3/cmd/task@v3.39.1

 # Write our own env file.
-cat <<EOT > env.sh
+cat <<EOT >env.sh
 export GOROOT="$GOROOT"
 export GOPATH="$GOPATH"
 export GOCACHE="$GOCACHE"
@@ -78,12 +79,12 @@ export PATH="$PATH"
 EOT

 if [ "Windows_NT" = "$OS" ]; then
-    echo "export USERPROFILE=$USERPROFILE" >> env.sh
-    echo "export HOME=$HOME" >> env.sh
+	echo "export USERPROFILE=$USERPROFILE" >>env.sh
+	echo "export HOME=$HOME" >>env.sh
 fi

 # source the env.sh file and write the expansion file.
-cat <<EOT > expansion.yml
+cat <<EOT >expansion.yml
 CURRENT_VERSION: "$CURRENT_VERSION"
 DRIVERS_TOOLS: "$DRIVERS_TOOLS"
 PROJECT_DIRECTORY: "$PROJECT_DIRECTORY"
diff --git a/Taskfile.yml b/Taskfile.yml
index 3473cb4981..b25a76b61a 100644
--- a/Taskfile.yml
+++ b/Taskfile.yml
@@ -123,6 +123,8 @@ tasks:

  setup-encryption: bash etc/setup-encryption.sh

+  start-mongoproxy: bash etc/start-mongoproxy.sh
+
  evg-test:
    - go test -exec "env PKG_CONFIG_PATH=${PKG_CONFIG_PATH} LD_LIBRARY_PATH=${LD_LIBRARY_PATH} DYLD_LIBRARY_PATH=$MACOS_LIBRARY_PATH}" ${BUILD_TAGS} -v -timeout {{.TEST_TIMEOUT}}s -p 1 ./... >> test.suite
diff --git a/etc/setup-test.sh b/etc/setup-test.sh
index 8f31e301a4..6ce2da635a 100755
--- a/etc/setup-test.sh
+++ b/etc/setup-test.sh
@@ -10,108 +10,117 @@ RACE=${RACE:-}
 SERVERLESS=${SERVERLESS:-}
 LOAD_BALANCER=${LOAD_BALANCER:-}
 MONGODB_URI=${MONGODB_URI:-}
+MONGOPROXY=${MONGOPROXY:-}
+MONGO_PROXY_URI=${MONGO_PROXY_URI:-}

 # Handle special cases first.
 if [ -n "${TEST_ENTERPRISE_AUTH:-}" ]; then
-    . $DRIVERS_TOOLS/.evergreen/secrets_handling/setup-secrets.sh drivers/enterprise_auth
-    AUTH="auth"
-    case $TEST_ENTERPRISE_AUTH in
-    plain)
-        MONGODB_URI="mongodb://${SASL_USER}:${SASL_PASS}@${SASL_HOST}:${SASL_PORT}/ldap?authMechanism=PLAIN"
-        ;;
-    gssapi)
-        if [ "Windows_NT" = "${OS:-}" ]; then
-            MONGODB_URI="mongodb://${PRINCIPAL/@/%40}:${SASL_PASS}@${SASL_HOST}:${SASL_PORT}/kerberos?authMechanism=GSSAPI"
-        else
-            echo ${KEYTAB_BASE64} | base64 -d > ${PROJECT_DIRECTORY}/.evergreen/drivers.keytab
-            mkdir -p ~/.krb5
-            cat .evergreen/krb5.config | tee -a ~/.krb5/config
-            kinit -k -t .evergreen/drivers.keytab -p "${PRINCIPAL}"
-            MONGODB_URI="mongodb://${PRINCIPAL/@/%40}@${SASL_HOST}:${SASL_PORT}/kerberos?authMechanism=GSSAPI"
-        fi
-        ;;
-    esac
-    rm secrets-export.sh
+	. $DRIVERS_TOOLS/.evergreen/secrets_handling/setup-secrets.sh drivers/enterprise_auth
+	AUTH="auth"
+	case $TEST_ENTERPRISE_AUTH in
+	plain)
+		MONGODB_URI="mongodb://${SASL_USER}:${SASL_PASS}@${SASL_HOST}:${SASL_PORT}/ldap?authMechanism=PLAIN"
+		;;
+	gssapi)
+		if [ "Windows_NT" = "${OS:-}" ]; then
+			MONGODB_URI="mongodb://${PRINCIPAL/@/%40}:${SASL_PASS}@${SASL_HOST}:${SASL_PORT}/kerberos?authMechanism=GSSAPI"
+		else
+			echo ${KEYTAB_BASE64} | base64 -d >${PROJECT_DIRECTORY}/.evergreen/drivers.keytab
+			mkdir -p ~/.krb5
+			cat .evergreen/krb5.config | tee -a ~/.krb5/config
+			kinit -k -t .evergreen/drivers.keytab -p "${PRINCIPAL}"
+			MONGODB_URI="mongodb://${PRINCIPAL/@/%40}@${SASL_HOST}:${SASL_PORT}/kerberos?authMechanism=GSSAPI"
+		fi
+		;;
+	esac
+	rm secrets-export.sh
 fi

 if [ -n "${SERVERLESS}" ]; then
-    . $DRIVERS_TOOLS/.evergreen/serverless/secrets-export.sh
-    MONGODB_URI="${SERVERLESS_URI}"
-    AUTH="auth"
+	. $DRIVERS_TOOLS/.evergreen/serverless/secrets-export.sh
+	MONGODB_URI="${SERVERLESS_URI}"
+	AUTH="auth"
 fi

 if [ -n "${TEST_ATLAS_CONNECT:-}" ]; then
-    . $DRIVERS_TOOLS/.evergreen/secrets_handling/setup-secrets.sh drivers/atlas_connect
+	. $DRIVERS_TOOLS/.evergreen/secrets_handling/setup-secrets.sh drivers/atlas_connect
 fi

 if [ -n "${LOAD_BALANCER}" ]; then
-    # Verify that the required LB URI expansions are set to ensure that the test runner can correctly connect to
-    # the LBs.
-    if [ -z "${SINGLE_MONGOS_LB_URI}" ]; then
-        echo "SINGLE_MONGOS_LB_URI must be set for testing against LBs"
-        exit 1
-    fi
-    if [ -z "${MULTI_MONGOS_LB_URI}" ]; then
-        echo "MULTI_MONGOS_LB_URI must be set for testing against LBs"
-        exit 1
-    fi
-    MONGODB_URI="${SINGLE_MONGOS_LB_URI}"
+	# Verify that the required LB URI expansions are set to ensure that the test runner can correctly connect to
+	# the LBs.
+	if [ -z "${SINGLE_MONGOS_LB_URI}" ]; then
+		echo "SINGLE_MONGOS_LB_URI must be set for testing against LBs"
+		exit 1
+	fi
+	if [ -z "${MULTI_MONGOS_LB_URI}" ]; then
+		echo "MULTI_MONGOS_LB_URI must be set for testing against LBs"
+		exit 1
+	fi
+	MONGODB_URI="${SINGLE_MONGOS_LB_URI}"
+fi
+
+if [ -n "${MONGOPROXY}" ]; then
+	echo "MONGOPROXY is set, using the proxy for qualifying tests."
+	# When MONGOPROXY is set, qualifying tests connect through the proxy
+	# instead of directly to the server. This is useful for testing the
+	# proxy itself.
+	MONGO_PROXY_URI="mongodb://127.0.0.1:28017/?directConnection=true"
 fi

 if [ -n "${OCSP_ALGORITHM:-}" ]; then
-    MONGO_GO_DRIVER_CA_FILE="${DRIVERS_TOOLS}/.evergreen/ocsp/${OCSP_ALGORITHM}/ca.pem"
-    if [ "Windows_NT" = "$OS" ]; then
-        MONGO_GO_DRIVER_CA_FILE=$(cygpath -m $MONGO_GO_DRIVER_CA_FILE)
-    fi
+	MONGO_GO_DRIVER_CA_FILE="${DRIVERS_TOOLS}/.evergreen/ocsp/${OCSP_ALGORITHM}/ca.pem"
+	if [ "Windows_NT" = "$OS" ]; then
+		MONGO_GO_DRIVER_CA_FILE=$(cygpath -m $MONGO_GO_DRIVER_CA_FILE)
+	fi
 fi
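setup-test.sh pins the proxy endpoint to `mongodb://127.0.0.1:28017/?directConnection=true` and later writes it into `.test.env` as `MONGO_PROXY_URI`. A sketch of how a test-support package might resolve that value, falling back to the same default — the helper is illustrative, not part of the driver:

```go
package main

import (
	"fmt"
	"os"
)

// defaultMongoProxyURI mirrors the default hard-coded in etc/setup-test.sh.
const defaultMongoProxyURI = "mongodb://127.0.0.1:28017/?directConnection=true"

// mongoProxyURI resolves the proxy endpoint from the environment written to
// .test.env, falling back to the setup-test.sh default. Illustrative only.
func mongoProxyURI() string {
	if uri := os.Getenv("MONGO_PROXY_URI"); uri != "" {
		return uri
	}
	return defaultMongoProxyURI
}

func main() { fmt.Println(mongoProxyURI()) }
```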
 # Handle encryption.
 if [[ "${GO_BUILD_TAGS}" =~ cse ]]; then
-    # Install libmongocrypt if needed.
-    task install-libmongocrypt
-
-    # Handle libmongocrypt paths.
-    PKG_CONFIG_PATH=$(pwd)/install/libmongocrypt/lib64/pkgconfig
-    LD_LIBRARY_PATH=$(pwd)/install/libmongocrypt/lib64
-
-    if [ "$(uname -s)" = "Darwin" ]; then
-        PKG_CONFIG_PATH=$(pwd)/install/libmongocrypt/lib/pkgconfig
-        DYLD_FALLBACK_LIBRARY_PATH=$(pwd)/install/libmongocrypt/lib
-    fi
-
-    if [ "${SKIP_CRYPT_SHARED_LIB:-''}" = "true" ]; then
-        CRYPT_SHARED_LIB_PATH=""
-        echo "crypt_shared library is skipped"
-    elif [ -z "${CRYPT_SHARED_LIB_PATH:-}" ]; then
-        echo "crypt_shared library path is empty"
-    else
-        echo "crypt_shared library will be loaded from path: $CRYPT_SHARED_LIB_PATH"
-    fi
+	# Install libmongocrypt if needed.
+	task install-libmongocrypt
+
+	# Handle libmongocrypt paths.
+	PKG_CONFIG_PATH=$(pwd)/install/libmongocrypt/lib64/pkgconfig
+	LD_LIBRARY_PATH=$(pwd)/install/libmongocrypt/lib64
+
+	if [ "$(uname -s)" = "Darwin" ]; then
+		PKG_CONFIG_PATH=$(pwd)/install/libmongocrypt/lib/pkgconfig
+		DYLD_FALLBACK_LIBRARY_PATH=$(pwd)/install/libmongocrypt/lib
+	fi
+
+	if [ "${SKIP_CRYPT_SHARED_LIB:-''}" = "true" ]; then
+		CRYPT_SHARED_LIB_PATH=""
+		echo "crypt_shared library is skipped"
+	elif [ -z "${CRYPT_SHARED_LIB_PATH:-}" ]; then
+		echo "crypt_shared library path is empty"
+	else
+		echo "crypt_shared library will be loaded from path: $CRYPT_SHARED_LIB_PATH"
+	fi
 fi

 # Handle the build tags argument.
 if [ -n "${GO_BUILD_TAGS}" ]; then
-    BUILD_TAGS="${RACE} --tags=${GO_BUILD_TAGS}"
+	BUILD_TAGS="${RACE} --tags=${GO_BUILD_TAGS}"
 else
-    BUILD_TAGS="${RACE}"
+	BUILD_TAGS="${RACE}"
 fi

 # Handle certificates.
 if [ "$SSL" != "nossl" ] && [ -z "${SERVERLESS}" ] && [ -z "${OCSP_ALGORITHM:-}" ]; then
-    MONGO_GO_DRIVER_CA_FILE="$DRIVERS_TOOLS/.evergreen/x509gen/ca.pem"
-    MONGO_GO_DRIVER_KEY_FILE="$DRIVERS_TOOLS/.evergreen/x509gen/client.pem"
-    MONGO_GO_DRIVER_PKCS8_ENCRYPTED_KEY_FILE="$DRIVERS_TOOLS/.evergreen/x509gen/client-pkcs8-encrypted.pem"
-    MONGO_GO_DRIVER_PKCS8_UNENCRYPTED_KEY_FILE="$DRIVERS_TOOLS/.evergreen/x509gen/client-pkcs8-unencrypted.pem"
-
-    if [ "Windows_NT" = "$OS" ]; then
-        MONGO_GO_DRIVER_CA_FILE=$(cygpath -m $MONGO_GO_DRIVER_CA_FILE)
-        MONGO_GO_DRIVER_KEY_FILE=$(cygpath -m $MONGO_GO_DRIVER_KEY_FILE)
-        MONGO_GO_DRIVER_PKCS8_ENCRYPTED_KEY_FILE=$(cygpath -m $MONGO_GO_DRIVER_PKCS8_ENCRYPTED_KEY_FILE)
-        MONGO_GO_DRIVER_PKCS8_UNENCRYPTED_KEY_FILE=$(cygpath -m $MONGO_GO_DRIVER_PKCS8_UNENCRYPTED_KEY_FILE)
-    fi
+	MONGO_GO_DRIVER_CA_FILE="$DRIVERS_TOOLS/.evergreen/x509gen/ca.pem"
+	MONGO_GO_DRIVER_KEY_FILE="$DRIVERS_TOOLS/.evergreen/x509gen/client.pem"
+	MONGO_GO_DRIVER_PKCS8_ENCRYPTED_KEY_FILE="$DRIVERS_TOOLS/.evergreen/x509gen/client-pkcs8-encrypted.pem"
+	MONGO_GO_DRIVER_PKCS8_UNENCRYPTED_KEY_FILE="$DRIVERS_TOOLS/.evergreen/x509gen/client-pkcs8-unencrypted.pem"
+
+	if [ "Windows_NT" = "$OS" ]; then
+		MONGO_GO_DRIVER_CA_FILE=$(cygpath -m $MONGO_GO_DRIVER_CA_FILE)
+		MONGO_GO_DRIVER_KEY_FILE=$(cygpath -m $MONGO_GO_DRIVER_KEY_FILE)
+		MONGO_GO_DRIVER_PKCS8_ENCRYPTED_KEY_FILE=$(cygpath -m $MONGO_GO_DRIVER_PKCS8_ENCRYPTED_KEY_FILE)
+		MONGO_GO_DRIVER_PKCS8_UNENCRYPTED_KEY_FILE=$(cygpath -m $MONGO_GO_DRIVER_PKCS8_UNENCRYPTED_KEY_FILE)
+	fi
 fi

-cat <<EOT > .test.env
+cat <<EOT >.test.env
 AUTH="${AUTH:-}"
 SSL="${SSL}"
 MONGO_GO_DRIVER_CA_FILE="${MONGO_GO_DRIVER_CA_FILE:-}"
@@ -129,27 +138,28 @@ PKG_CONFIG_PATH="${PKG_CONFIG_PATH:-}"
 LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-}"
 MACOS_LIBRARY_PATH="${DYLD_FALLBACK_LIBRARY_PATH:-}"
 SKIP_CSOT_TESTS=${SKIP_CSOT_TESTS:-}
+MONGO_PROXY_URI="${MONGO_PROXY_URI:-}"
 EOT

 if [ -n "${MONGODB_URI}" ]; then
-    echo "MONGODB_URI=\"${MONGODB_URI}\"" >> .test.env
+	echo "MONGODB_URI=\"${MONGODB_URI}\"" >>.test.env
 fi

 if [ -n "${SERVERLESS}" ]; then
-    echo "SERVERLESS_ATLAS_USER=$SERVERLESS_ATLAS_USER" >> .test.env
-    echo "SERVERLESS_ATLAS_PASSWORD=$SERVERLESS_ATLAS_PASSWORD" >> .test.env
+	echo "SERVERLESS_ATLAS_USER=$SERVERLESS_ATLAS_USER" >>.test.env
+	echo "SERVERLESS_ATLAS_PASSWORD=$SERVERLESS_ATLAS_PASSWORD" >>.test.env
 fi

-if [ -n "${LOAD_BALANCER}" ];then
-    echo "SINGLE_MONGOS_LB_URI=${SINGLE_MONGOS_LB_URI}" >> .test.env
-    echo "MULTI_MONGOS_LB_URI=${MULTI_MONGOS_LB_URI}" >> .test.env
+if [ -n "${LOAD_BALANCER}" ]; then
+	echo "SINGLE_MONGOS_LB_URI=${SINGLE_MONGOS_LB_URI}" >>.test.env
+	echo "MULTI_MONGOS_LB_URI=${MULTI_MONGOS_LB_URI}" >>.test.env
 fi

 # Add secrets to the test file.
 if [ -f "secrets-export.sh" ]; then
-    while read p; do
-        if [[ "$p" =~ ^export ]]; then
-            echo "$p" | sed 's/export //' >> .test.env
-        fi
-    done < secrets-export.sh
+	while read p; do
+		if [[ "$p" =~ ^export ]]; then
+			echo "$p" | sed 's/export //' >>.test.env
+		fi
+	done <secrets-export.sh
 fi
diff --git a/internal/integration/mtest/mtest.go b/internal/integration/mtest/mtest.go
@@ ... @@ func (t *T) ClearFailPoints() {
+	// ... > 1 and the blocking time is > default pending read timeout.
+	_ = t.Client.Ping(context.Background(), nil)
+
 	db := t.Client.Database("admin")
 	for _, fp := range t.failPointNames {
 		cmd := failpoint.FailPoint{
@@ -640,6 +667,14 @@ func (t *T) createTestClient() {
 				atomic.AddInt64(&t.connsCheckedOut, 1)
 			case event.ConnectionCheckedIn:
 				atomic.AddInt64(&t.connsCheckedOut, -1)
+			case event.ConnectionPendingResponseStarted:
+				atomic.AddInt64(&t.connPendingReadStarted, 1)
+			case event.ConnectionPendingResponseSucceeded:
+				atomic.AddInt64(&t.connPendingReadSucceeded, 1)
+			case event.ConnectionPendingResponseFailed:
+				atomic.AddInt64(&t.connPendingReadFailed, 1)
+			case event.ConnectionClosed:
+				atomic.AddInt64(&t.connClosed, 1)
 			}
 		},
 	})
@@ -660,6 +695,8 @@ func (t *T) createTestClient() {
 		t.mockDeployment = drivertest.NewMockDeployment()
 		clientOpts.Deployment = t.mockDeployment

+		t.Client, err = mongo.Connect(clientOpts)
+	case MongoProxy:
 		t.Client, err = mongo.Connect(clientOpts)
 	case Proxy:
 		t.proxyDialer = newProxyDialer()
diff --git a/internal/integration/mtest/options.go b/internal/integration/mtest/options.go
index aff188b481..72d3affddc 100644
--- a/internal/integration/mtest/options.go
+++ b/internal/integration/mtest/options.go
@@ -43,6 +43,9 @@ const (
 	// Proxy specifies a client that proxies messages to the server and also stores parsed copies. The proxied
 	// messages can be retrieved via T.GetProxiedMessages or T.GetRawProxiedMessages.
 	Proxy
+	// MongoProxy specifies a client that communicates with a MongoDB proxy server
+	// as defined in Drivers Evergreen Tools.
+	MongoProxy
 )

 var (
diff --git a/internal/integration/unified/event.go b/internal/integration/unified/event.go
index abbec74439..9ee8fe7404 100644
--- a/internal/integration/unified/event.go
+++ b/internal/integration/unified/event.go
@@ -16,27 +16,30 @@ import (
 type monitoringEventType string

 const (
-	commandStartedEvent             monitoringEventType = "CommandStartedEvent"
-	commandSucceededEvent           monitoringEventType = "CommandSucceededEvent"
-	commandFailedEvent              monitoringEventType = "CommandFailedEvent"
-	poolCreatedEvent                monitoringEventType = "PoolCreatedEvent"
-	poolReadyEvent                  monitoringEventType = "PoolReadyEvent"
-	poolClearedEvent                monitoringEventType = "PoolClearedEvent"
-	poolClosedEvent                 monitoringEventType = "PoolClosedEvent"
-	connectionCreatedEvent          monitoringEventType = "ConnectionCreatedEvent"
-	connectionReadyEvent            monitoringEventType = "ConnectionReadyEvent"
-	connectionClosedEvent           monitoringEventType = "ConnectionClosedEvent"
-	connectionCheckOutStartedEvent  monitoringEventType = "ConnectionCheckOutStartedEvent"
-	connectionCheckOutFailedEvent   monitoringEventType = "ConnectionCheckOutFailedEvent"
-	connectionCheckedOutEvent       monitoringEventType = "ConnectionCheckedOutEvent"
-	connectionCheckedInEvent        monitoringEventType = "ConnectionCheckedInEvent"
-	serverDescriptionChangedEvent   monitoringEventType = "ServerDescriptionChangedEvent"
-	serverHeartbeatFailedEvent      monitoringEventType = "ServerHeartbeatFailedEvent"
-	serverHeartbeatStartedEvent     monitoringEventType = "ServerHeartbeatStartedEvent"
-	serverHeartbeatSucceededEvent   monitoringEventType = "ServerHeartbeatSucceededEvent"
-	topologyDescriptionChangedEvent monitoringEventType = "TopologyDescriptionChangedEvent"
-	topologyOpeningEvent            monitoringEventType = "TopologyOpeningEvent"
-	topologyClosedEvent             monitoringEventType = "TopologyClosedEvent"
+	commandStartedEvent                monitoringEventType = "CommandStartedEvent"
+	commandSucceededEvent              monitoringEventType = "CommandSucceededEvent"
+	commandFailedEvent                 monitoringEventType = "CommandFailedEvent"
+	poolCreatedEvent                   monitoringEventType = "PoolCreatedEvent"
+	poolReadyEvent                     monitoringEventType = "PoolReadyEvent"
+	poolClearedEvent                   monitoringEventType = "PoolClearedEvent"
+	poolClosedEvent                    monitoringEventType = "PoolClosedEvent"
+	connectionCreatedEvent             monitoringEventType = "ConnectionCreatedEvent"
+	connectionReadyEvent               monitoringEventType = "ConnectionReadyEvent"
+	connectionClosedEvent              monitoringEventType = "ConnectionClosedEvent"
+	connectionCheckOutStartedEvent     monitoringEventType = "ConnectionCheckOutStartedEvent"
+	connectionCheckOutFailedEvent      monitoringEventType = "ConnectionCheckOutFailedEvent"
+	connectionCheckedOutEvent          monitoringEventType = "ConnectionCheckedOutEvent"
+	connectionCheckedInEvent           monitoringEventType = "ConnectionCheckedInEvent"
+	connectionPendingResponseStarted   monitoringEventType = "ConnectionPendingResponseStarted"
+	connectionPendingResponseSucceeded monitoringEventType = "ConnectionPendingResponseSucceeded"
+	connectionPendingResponseFailed    monitoringEventType = "ConnectionPendingResponseFailed"
+	serverDescriptionChangedEvent      monitoringEventType = "ServerDescriptionChangedEvent"
+	serverHeartbeatFailedEvent         monitoringEventType = "ServerHeartbeatFailedEvent"
+	serverHeartbeatStartedEvent        monitoringEventType = "ServerHeartbeatStartedEvent"
+	serverHeartbeatSucceededEvent      monitoringEventType = "ServerHeartbeatSucceededEvent"
+	topologyDescriptionChangedEvent    monitoringEventType = "TopologyDescriptionChangedEvent"
+	topologyOpeningEvent               monitoringEventType = "TopologyOpeningEvent"
+	topologyClosedEvent                monitoringEventType = "TopologyClosedEvent"
 )

 func monitoringEventTypeFromString(eventStr string) (monitoringEventType, bool) {
@@ -69,6 +72,12 @@ func monitoringEventTypeFromString(eventStr string) (monitoringEventType, bool)
 		return connectionCheckedOutEvent, true
 	case "connectioncheckedinevent":
 		return connectionCheckedInEvent, true
+	case "connectionpendingresponsestarted":
+		return connectionPendingResponseStarted, true
+	case "connectionpendingresponsesucceeded":
+		return connectionPendingResponseSucceeded, true
+	case "connectionpendingresponsefailed":
+		return connectionPendingResponseFailed, true
 	case "serverdescriptionchangedevent":
 		return serverDescriptionChangedEvent, true
 	case "serverheartbeatfailedevent":
@@ -112,6 +121,12 @@ func monitoringEventTypeFromPoolEvent(evt *event.PoolEvent) monitoringEventType
 		return connectionCheckedOutEvent
 	case event.ConnectionCheckedIn:
 		return connectionCheckedInEvent
+	case event.ConnectionPendingResponseStarted:
+		return connectionPendingResponseStarted
+	case event.ConnectionPendingResponseSucceeded:
+		return connectionPendingResponseSucceeded
+	case event.ConnectionPendingResponseFailed:
+		return connectionPendingResponseFailed
 	default:
 		return ""
 	}
diff --git a/internal/integration/unified/event_verification.go b/internal/integration/unified/event_verification.go
index 56c53f8adb..eb3e8b49be 100644
--- a/internal/integration/unified/event_verification.go
+++ b/internal/integration/unified/event_verification.go
@@ -56,7 +56,10 @@ type cmapEvent struct {
 		Reason *string `bson:"reason"`
 	} `bson:"connectionCheckOutFailedEvent"`

-	ConnectionCheckedInEvent *struct{} `bson:"connectionCheckedInEvent"`
+	ConnectionCheckedInEvent           *struct{} `bson:"connectionCheckedInEvent"`
+	ConnectionPendingResponseStarted   *struct{} `bson:"connectionPendingResponseStarted"`
+	ConnectionPendingResponseSucceeded *struct{} `bson:"connectionPendingResponseSucceeded"`
+	ConnectionPendingResponseFailed    *struct{} `bson:"connectionPendingResponseFailed"`

 	PoolClearedEvent *struct {
 		HasServiceID *bool `bson:"hasServiceId"`
@@ -359,6 +362,18 @@ func verifyCMAPEvents(client *clientEntity, expectedEvents *expectedEvents) erro
 			if _, pooled, err = getNextPoolEvent(pooled, event.ConnectionCheckedIn); err != nil {
 				return newEventVerificationError(idx, client, "failed to get next pool event: %v", err.Error())
 			}
+		case evt.ConnectionPendingResponseStarted != nil:
+			if _, pooled, err = getNextPoolEvent(pooled, event.ConnectionPendingResponseStarted); err != nil {
+				return newEventVerificationError(idx, client, "failed to get next pool event: %v", err.Error())
+			}
+		case evt.ConnectionPendingResponseSucceeded != nil:
+			if _, pooled, err = getNextPoolEvent(pooled, event.ConnectionPendingResponseSucceeded); err != nil {
+				return newEventVerificationError(idx, client, "failed to get next pool event: %v", err.Error())
+			}
+		case evt.ConnectionPendingResponseFailed != nil:
+			if _, pooled, err = getNextPoolEvent(pooled, event.ConnectionPendingResponseFailed); err != nil {
+				return newEventVerificationError(idx, client, "failed to get next pool event: %v", err.Error())
+			}
 		case evt.PoolClearedEvent != nil:
 			var actual *event.PoolEvent
 			if actual, pooled, err = getNextPoolEvent(pooled, event.ConnectionPoolCleared); err != nil {
diff --git a/internal/integration/unified/schema_version.go b/internal/integration/unified/schema_version.go
index 7908b39017..f1f5447883 100644
--- a/internal/integration/unified/schema_version.go
+++ b/internal/integration/unified/schema_version.go
@@ -16,7 +16,7 @@ import (
 var (
 	supportedSchemaVersions = map[int]string{
-		1: "1.22",
+		1: "1.24",
 	}
 )
diff --git a/internal/logger/component.go b/internal/logger/component.go
index a601707cbf..366a482b86 100644
--- a/internal/logger/component.go
+++ b/internal/logger/component.go
@@ -14,32 +14,35 @@ import (
 )

 const (
-	CommandFailed                    = "Command failed"
-	CommandStarted                   = "Command started"
-	CommandSucceeded                 = "Command succeeded"
-	ConnectionPoolCreated            = "Connection pool created"
-	ConnectionPoolReady              = "Connection pool ready"
-	ConnectionPoolCleared            = "Connection pool cleared"
-	ConnectionPoolClosed             = "Connection pool closed"
-	ConnectionCreated                = "Connection created"
-	ConnectionReady                  = "Connection ready"
-	ConnectionClosed                 = "Connection closed"
-	ConnectionCheckoutStarted        = "Connection checkout started"
-	ConnectionCheckoutFailed         = "Connection checkout failed"
-	ConnectionCheckedOut             = "Connection checked out"
-	ConnectionCheckedIn              = "Connection checked in"
-	ServerSelectionFailed            = "Server selection failed"
-	ServerSelectionStarted           = "Server selection started"
-	ServerSelectionSucceeded         = "Server selection succeeded"
-	ServerSelectionWaiting           = "Waiting for suitable server to become available"
-	TopologyClosed                   = "Stopped topology monitoring"
-	TopologyDescriptionChanged       = "Topology description changed"
-	TopologyOpening                  = "Starting topology monitoring"
-	TopologyServerClosed             = "Stopped server monitoring"
-	TopologyServerHeartbeatFailed    = "Server heartbeat failed"
-	TopologyServerHeartbeatStarted   = "Server heartbeat started"
-	TopologyServerHeartbeatSucceeded = "Server heartbeat succeeded"
-	TopologyServerOpening            = "Starting server monitoring"
+	CommandFailed                      = "Command failed"
+	CommandStarted                     = "Command started"
+	CommandSucceeded                   = "Command succeeded"
+	ConnectionPoolCreated              = "Connection pool created"
+	ConnectionPoolReady                = "Connection pool ready"
+	ConnectionPoolCleared              = "Connection pool cleared"
+	ConnectionPoolClosed               = "Connection pool closed"
+	ConnectionCreated                  = "Connection created"
+	ConnectionReady                    = "Connection ready"
+	ConnectionClosed                   = "Connection closed"
+	ConnectionCheckoutStarted          = "Connection checkout started"
+	ConnectionCheckoutFailed           = "Connection checkout failed"
+	ConnectionCheckedOut               = "Connection checked out"
+	ConnectionCheckedIn                = "Connection checked in"
+	ConnectionPendingResponseStarted   = "Pending response started"
+	ConnectionPendingResponseSucceeded = "Pending response succeeded"
+	ConnectionPendingResponseFailed    = "Pending response failed"
+	ServerSelectionFailed              = "Server selection failed"
+	ServerSelectionStarted             = "Server selection started"
+	ServerSelectionSucceeded           = "Server selection succeeded"
+	ServerSelectionWaiting             = "Waiting for suitable server to become available"
+	TopologyClosed                     = "Stopped topology monitoring"
+	TopologyDescriptionChanged         = "Topology description changed"
+	TopologyOpening                    = "Starting topology monitoring"
+	TopologyServerClosed               = "Stopped server monitoring"
+	TopologyServerHeartbeatFailed      = "Server heartbeat failed"
+	TopologyServerHeartbeatStarted     = "Server heartbeat started"
+	TopologyServerHeartbeatSucceeded   = "Server heartbeat succeeded"
+	TopologyServerOpening              = "Starting server monitoring"
 )

 const (
diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go
index 50136456e4..d4a23c8293 100644
--- a/x/mongo/driver/operation.go
+++ b/x/mongo/driver/operation.go
@@ -76,6 +76,13 @@ type RetryablePoolError interface {
 	Retryable() bool
 }

+// RetryablePendingResponseError is an error indicating that a failure to read
+// a pending response (caused by a socket timeout) is retryable.
+type RetryablePendingResponseError interface {
+	Retryable() bool
+}
+
 // labeledError is an error that can have error labels added to it.
 type labeledError interface {
 	error
@@ -641,6 +648,11 @@ func (op Operation) Execute(ctx context.Context) error {
 		if srvr == nil || conn == nil {
 			srvr, conn, err = op.getServerAndConnection(ctx, requestID, deprioritizedServers)
 			if err != nil {
+				// If the returned error is a context error, return it immediately.
+				if ctx.Err() != nil {
+					return err
+				}
+
 				// If the returned error is retryable and there are retries remaining (negative
 				// retries means retry indefinitely), then retry the operation. Set the server
 				// and connection to nil to request a new server and connection.
@@ -785,6 +797,14 @@ func (op Operation) Execute(ctx context.Context) error {
 		if moreToCome {
 			roundTrip = op.moreToComeRoundTrip
 		}
+
+		// Set context values to handle a pending read in case of a socket
+		// timeout.
+		if maxTimeMS != 0 {
+			ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
+			ctx = driverutil.WithRequestID(ctx, startedInfo.requestID)
+		}
+
 		res, err = roundTrip(ctx, conn, *wm)

 		if ep, ok := srvr.(ErrorProcessor); ok {
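Execute tags the context only when `maxTimeMS` is non-zero: with a server-side timeout attached, the server is expected to reply eventually, so an interrupted read is worth finishing later. A self-contained sketch of the same context plumbing with plain `context.WithValue` (the real helpers live in `internal/driverutil`; the names below are illustrative):

```go
package main

import (
	"context"
	"fmt"
)

type ctxKey string

const (
	hasMaxTimeMSKey ctxKey = "hasMaxTimeMS"
	requestIDKey    ctxKey = "requestID"
)

// withPendingReadInfo mimics driverutil.WithValueHasMaxTimeMS and
// driverutil.WithRequestID: the read path later uses these values to decide
// whether a socket timeout should leave pending-response state behind.
func withPendingReadInfo(ctx context.Context, requestID int32) context.Context {
	ctx = context.WithValue(ctx, hasMaxTimeMSKey, true)
	return context.WithValue(ctx, requestIDKey, requestID)
}

func hasMaxTimeMS(ctx context.Context) bool {
	v, ok := ctx.Value(hasMaxTimeMSKey).(bool)
	return ok && v
}

func requestID(ctx context.Context) (int32, bool) {
	id, ok := ctx.Value(requestIDKey).(int32)
	return id, ok
}

func main() {
	ctx := withPendingReadInfo(context.Background(), 42)
	id, _ := requestID(ctx)
	fmt.Println(hasMaxTimeMS(ctx), id) // true 42
}
```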
diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go
index b1bf1d13f1..1c9b1e02ac 100644
--- a/x/mongo/driver/topology/connection.go
+++ b/x/mongo/driver/topology/connection.go
@@ -7,6 +7,7 @@
 package topology

 import (
+	"bufio"
 	"context"
 	"crypto/tls"
 	"encoding/binary"
@@ -46,6 +47,13 @@ var (

 func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }

+// pendingResponseState tracks a server reply that was started but not fully
+// read before a socket timeout.
+type pendingResponseState struct {
+	remainingBytes                   int32
+	sizeBytesReadBeforeSocketTimeout []byte
+	requestID                        int32
+	start                            time.Time
+}
+
 type connection struct {
 	// state must be accessed using the atomic package and should be at the beginning of the struct.
 	// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
@@ -53,7 +61,8 @@ type connection struct {
 	state int64

 	id           string
-	nc           net.Conn // When nil, the connection is closed.
+	nc           net.Conn      // When nil, the connection is closed.
+	br           *bufio.Reader // When non-nil, used to read from nc.
 	addr         address.Address
 	idleTimeout  time.Duration
 	idleStart    atomic.Value // Stores a time.Time
@@ -81,9 +90,11 @@ type connection struct {
 	// accessTokens in the OIDC authenticator cache.
 	oidcTokenGenID uint64

-	// awaitRemainingBytes indicates the size of server response that was not completely
-	// read before returning the connection to the pool.
-	awaitRemainingBytes *int32
+	// pendingResponseState contains information required to attempt a pending read
+	// in the event of a socket timeout for an operation that has appended
+	// maxTimeMS to the wire message.
+	pendingResponseState   *pendingResponseState
+	pendingResponseStateMu sync.Mutex
 }

 // newConnection handles the creation of a connection. It does not connect the connection.
@@ -233,6 +244,9 @@ func (c *connection) connect(ctx context.Context) (err error) {
 		c.nc = tlsNc
 	}

+	// Initialize the buffered reader now that we have a finalized net.Conn.
+	c.br = bufio.NewReader(c.nc)
+
 	// running hello and authentication is handled by a handshaker on the configuration instance.
 	handshaker := c.config.handshaker
 	if handshaker == nil {
@@ -409,11 +423,14 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
 	dst, errMsg, err := c.read(ctx)
 	if err != nil {
-		if c.awaitRemainingBytes == nil {
-			// If the connection was not marked as awaiting response, close the
-			// connection because we don't know what the connection state is.
+		c.pendingResponseStateMu.Lock()
+		if c.pendingResponseState == nil {
+			// If there is no pending read on the connection, use the pre-CSOT
+			// behavior and close the connection because we don't know if there are
+			// other bytes left to read.
 			c.close()
 		}
+		c.pendingResponseStateMu.Unlock()
 		message := errMsg
 		return nil, ConnectionError{
 			ConnectionID: c.id,
@@ -473,10 +490,16 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
 	// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
 	// because there might be more than one wire message waiting to be read, for example when
 	// reading messages from an exhaust cursor.
-	n, err := io.ReadFull(c.nc, sizeBuf[:])
+	n, err := io.ReadFull(c.br, sizeBuf[:]) // Read through the buffered reader.
 	if err != nil {
-		if l := int32(n); l == 0 && isCSOTTimeout(err) {
-			c.awaitRemainingBytes = &l
+		if isCSOTTimeout(err) && driverutil.HasMaxTimeMS(ctx) {
+			requestID, _ := driverutil.GetRequestID(ctx)
+
+			c.pendingResponseState = &pendingResponseState{
+				sizeBytesReadBeforeSocketTimeout: sizeBuf[:n],
+				requestID:                        requestID,
+				start:                            time.Now(),
+			}
 		}
 		return nil, "incomplete read of message header", err
 	}
@@ -488,11 +511,17 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
 	dst := make([]byte, size)
 	copy(dst, sizeBuf[:])

-	n, err = io.ReadFull(c.nc, dst[4:])
+	n, err = io.ReadFull(c.br, dst[4:]) // Read through the buffered reader.
 	if err != nil {
 		remainingBytes := size - 4 - int32(n)
-		if remainingBytes > 0 && isCSOTTimeout(err) {
-			c.awaitRemainingBytes = &remainingBytes
+		if remainingBytes > 0 && isCSOTTimeout(err) && driverutil.HasMaxTimeMS(ctx) {
+			requestID, _ := driverutil.GetRequestID(ctx)
+
+			c.pendingResponseState = &pendingResponseState{
+				remainingBytes: remainingBytes,
+				requestID:      requestID,
+				start:          time.Now(),
+			}
 		}
 		return dst, "incomplete read of full message", err
 	}
diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go
index 5b2f39f272..995c0f97f7 100644
--- a/x/mongo/driver/topology/connection_test.go
+++ b/x/mongo/driver/topology/connection_test.go
@@ -7,6 +7,7 @@
 package topology

 import (
+	"bufio"
 	"context"
 	"crypto/tls"
 	"errors"
@@ -380,7 +381,7 @@ func TestConnection(t *testing.T) {
 		t.Run("size read errors", func(t *testing.T) {
 			err := errors.New("Read error")
 			tnc := &testNetConn{readerr: err}
-			conn := &connection{id: "foobar", nc: tnc, state: connConnected}
+			conn := &connection{id: "foobar", nc: tnc, state: connConnected, br: bufio.NewReader(tnc)}
 			listener := newTestCancellationListener(false)
 			conn.cancellationListener = listener
@@ -397,7 +398,7 @@ func TestConnection(t *testing.T) {
 		t.Run("size too small errors", func(t *testing.T) {
 			err := errors.New("malformed message length: 3")
 			tnc := &testNetConn{readerr: err, buf: []byte{0x03, 0x00, 0x00, 0x00}}
-			conn := &connection{id: "foobar", nc: tnc, state: connConnected}
+			conn := &connection{id: "foobar", nc: tnc, state: connConnected, br: bufio.NewReader(tnc)}
 			listener := newTestCancellationListener(false)
 			conn.cancellationListener = listener
@@ -414,7 +415,7 @@ func TestConnection(t *testing.T) {
 		t.Run("full message read errors", func(t *testing.T) {
 			err := errors.New("Read error")
 			tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}}
-			conn := &connection{id: "foobar", nc: tnc, state: connConnected}
+			conn := &connection{id: "foobar", nc: tnc, state: connConnected, br: bufio.NewReader(tnc)}
 			listener := newTestCancellationListener(false)
 			conn.cancellationListener = listener
@@ -450,7 +451,7 @@ func TestConnection(t *testing.T) {
 				err := errors.New("length of read message too large")
 				tnc := &testNetConn{buf: make([]byte, len(tc.buffer))}
 				copy(tnc.buf, tc.buffer)
-				conn := &connection{id: "foobar", nc: tnc, state: connConnected, desc: tc.desc}
+				conn := &connection{id: "foobar", nc: tnc, state: connConnected, desc: tc.desc, br: bufio.NewReader(tnc)}
 				listener := newTestCancellationListener(false)
 				conn.cancellationListener = listener
@@ -467,7 +468,7 @@ func TestConnection(t *testing.T) {
 			want := []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}
 			tnc := &testNetConn{buf: make([]byte, len(want))}
 			copy(tnc.buf, want)
-			conn := &connection{id: "foobar", nc: tnc, state: connConnected}
+			conn := &connection{id: "foobar", nc: tnc, state: connConnected, br: bufio.NewReader(tnc)}
 			listener := newTestCancellationListener(false)
 			conn.cancellationListener = listener
@@ -497,7 +498,7 @@ func TestConnection(t *testing.T) {
 				readBuf := []byte{10, 0, 0, 0}
 				nc := newCancellationReadConn(&testNetConn{}, tc.skip, readBuf)
-				conn := &connection{id: "foobar", nc: nc, state: connConnected}
+				conn := &connection{id: "foobar", nc: nc, state: connConnected, br: bufio.NewReader(nc)}
 				listener := newTestCancellationListener(false)
 				conn.cancellationListener = listener
@@ -525,7 +526,7 @@ func TestConnection(t *testing.T) {
 		})
 		t.Run("closes connection if context is cancelled even if the socket read succeeds", func(t *testing.T) {
 			tnc := &testNetConn{buf: []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}}
-			conn := &connection{id: "foobar", nc: tnc, state: connConnected}
+			conn := &connection{id: "foobar", nc: tnc, state: connConnected, br: bufio.NewReader(tnc)}
 			listener := newTestCancellationListener(true)
 			conn.cancellationListener = listener
@@ -566,7 +567,7 @@ func TestConnection(t *testing.T) {
 	t.Run("cancellation listener callback", func(t *testing.T) {
 		t.Run("closes connection", func(t *testing.T) {
 			tnc := &testNetConn{}
-			conn := &connection{state: connConnected, nc: tnc}
+			conn := &connection{state: connConnected, nc: tnc, br: bufio.NewReader(tnc), id: "foobar"}

 			conn.cancellationListenerCallback()
 			assert.True(t, conn.state == connDisconnected, "expected connection state %v, got %v", connDisconnected,
diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go
index 162bb9c1af..af906a3c1a 100644
--- a/x/mongo/driver/topology/pool.go
+++ b/x/mongo/driver/topology/pool.go
@@ -8,8 +8,9 @@ package topology

 import (
 	"context"
+	"encoding/binary"
+	"errors"
 	"fmt"
-	"io"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -574,6 +575,10 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) {
 			return nil, w.err
 		}

+		if err := awaitPendingResponse(ctx, p, w.conn); err != nil {
+			return nil, err
+		}
+
 		duration = time.Since(start)
 		if mustLogPoolMessage(p) {
 			keysAndValues := logger.KeyValues{
@@ -630,6 +635,10 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) {
 		return nil, w.err
 	}

+	if err := awaitPendingResponse(ctx, p, w.conn); err != nil {
+		return nil, err
+	}
+
 	duration := time.Since(start)
 	if mustLogPoolMessage(p) {
 		keysAndValues := logger.KeyValues{
@@ -769,82 +778,324 @@ func (p *pool) removeConnection(conn *connection, reason reason, err error) erro
 	return nil
 }

-var (
-	// BGReadTimeout is the maximum amount of the to wait when trying to read
-	// the server reply on a connection after an operation timed out. The
-	// default is 400ms.
-	//
-	// Deprecated: BGReadTimeout is intended for internal use only and may be
-	// removed or modified at any time.
-	BGReadTimeout = 400 * time.Millisecond
+// PendingResponseTimeout is the maximum amount of time to wait when trying to
+// read the server reply on a connection after an operation timed out. The
+// default is 3000 milliseconds.
+//
+// Deprecated: PendingResponseTimeout is intended for internal use only and may
+// be removed or modified at any time.
+var PendingResponseTimeout = 3000 * time.Millisecond
+
+// publishPendingResponseStarted will log a message to the pool logger and
+// publish an event to the pool monitor if they are set.
+func publishPendingResponseStarted(pool *pool, conn *connection) {
+	prs := conn.pendingResponseState
+	if prs == nil {
+		return
+	}

-	// BGReadCallback is a callback for monitoring the behavior of the
-	// background-read-on-timeout connection preserving mechanism.
-	//
-	// Deprecated: BGReadCallback is intended for internal use only and may be
-	// removed or modified at any time.
-	BGReadCallback func(addr string, start, read time.Time, errs []error, connClosed bool)
-)
+	// Log a message to the pool logger if it is set.
+	if mustLogPoolMessage(pool) {
+		keysAndValues := logger.KeyValues{
+			logger.KeyDriverConnectionID, conn.driverConnectionID,
+			logger.KeyRequestID, prs.requestID,
+		}

-// bgRead sets a new read deadline on the provided connection and tries to read
-// any bytes returned by the server. If successful, it checks the connection
-// into the provided pool. If there are any errors, it closes the connection.
-//
-// It calls the package-global BGReadCallback function, if set, with the
-// address, timings, and any errors that occurred.
-func bgRead(pool *pool, conn *connection, size int32) {
-	var err error
-	start := time.Now()
+		logPoolMessage(pool, logger.ConnectionPendingResponseStarted, keysAndValues...)
+	}

-	defer func() {
-		read := time.Now()
-		errs := make([]error, 0)
-		connClosed := false
+	// Publish an event to the pool monitor if it is set.
+	if pool.monitor != nil {
+		event := &event.PoolEvent{
+			Type:         event.ConnectionPendingResponseStarted,
+			Address:      pool.address.String(),
+			ConnectionID: conn.driverConnectionID,
+			RequestID:    prs.requestID,
+		}
+
+		pool.monitor.Event(event)
+	}
+}
+
+func publishPendingResponseFailed(pool *pool, conn *connection, dur time.Duration, err error) {
+	prs := conn.pendingResponseState
+	if prs == nil {
+		return
+	}
+
+	reason := event.ReasonError
+	if errors.Is(err, context.DeadlineExceeded) {
+		reason = event.ReasonTimedOut
+	}
+
+	if mustLogPoolMessage(pool) {
+		keysAndValues := logger.KeyValues{
+			logger.KeyDriverConnectionID, conn.driverConnectionID,
+			logger.KeyRequestID, prs.requestID,
+			logger.KeyReason, reason,
+			logger.KeyError, err.Error(),
+		}
+
+		logPoolMessage(pool, logger.ConnectionPendingResponseFailed, keysAndValues...)
+	}
+
+	if pool.monitor != nil {
+		e := &event.PoolEvent{
+			Type:         event.ConnectionPendingResponseFailed,
+			Address:      pool.address.String(),
+			ConnectionID: conn.driverConnectionID,
+			RequestID:    prs.requestID,
+			Reason:       reason,
+			Error:        err,
+			Duration:     dur,
+		}
+		pool.monitor.Event(e)
+	}
+}
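These publishers emit through the public `event.PoolMonitor`, so applications (and the mtest harness shown earlier) can observe pending-response activity without any new API surface. A minimal sketch of wiring up a counter-based monitor:

```go
package main

import (
	"fmt"
	"sync/atomic"

	"go.mongodb.org/mongo-driver/v2/event"
	"go.mongodb.org/mongo-driver/v2/mongo/options"
)

func main() {
	var started, succeeded, failed int64

	monitor := &event.PoolMonitor{
		Event: func(evt *event.PoolEvent) {
			switch evt.Type {
			case event.ConnectionPendingResponseStarted:
				atomic.AddInt64(&started, 1)
			case event.ConnectionPendingResponseSucceeded:
				atomic.AddInt64(&succeeded, 1)
			case event.ConnectionPendingResponseFailed:
				atomic.AddInt64(&failed, 1)
			}
		},
	}

	// Attach the monitor when building client options.
	_ = options.Client().SetPoolMonitor(monitor)
	fmt.Println(started, succeeded, failed)
}
```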
+
+func publishPendingResponseSucceeded(pool *pool, conn *connection, dur time.Duration) {
+	prs := conn.pendingResponseState
+	if prs == nil {
+		return
+	}
+
+	if mustLogPoolMessage(pool) {
+		keysAndValues := logger.KeyValues{
+			logger.KeyDriverConnectionID, conn.driverConnectionID,
+			logger.KeyRequestID, prs.requestID,
+			logger.KeyDurationMS, dur.Milliseconds(),
+		}
+
+		logPoolMessage(pool, logger.ConnectionPendingResponseSucceeded, keysAndValues...)
+	}
+
+	if pool.monitor != nil {
+		event := &event.PoolEvent{
+			Type:         event.ConnectionPendingResponseSucceeded,
+			Address:      pool.address.String(),
+			ConnectionID: conn.driverConnectionID,
+			RequestID:    prs.requestID,
+			Duration:     dur,
+		}
+
+		pool.monitor.Event(event)
+	}
+}
+
+func peekConnectionAlive(conn *connection) (int, error) {
+	// Use a very short deadline so we don't block.
+	if err := conn.nc.SetReadDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
+		return 0, err
+	}
+
+	// Peek(1) will fill the bufio.Reader's buffer if needed,
+	// but will NOT advance it.
+	bytes, err := conn.br.Peek(1)
+	return len(bytes), err
+}
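`peekConnectionAlive` leans on the defining property of `bufio.Reader`: `Peek` fills the buffer from the underlying `net.Conn` but never advances it, so a 1ms-deadline probe can test liveness without corrupting the message stream. A standalone demonstration of that behavior (`net.Pipe` has supported read deadlines since Go 1.10):

```go
package main

import (
	"bufio"
	"fmt"
	"net"
	"time"
)

func main() {
	client, server := net.Pipe()
	go func() {
		// A little-endian wire-message length prefix: 42.
		server.Write([]byte{0x2A, 0x00, 0x00, 0x00})
	}()

	br := bufio.NewReader(client)
	_ = client.SetReadDeadline(time.Now().Add(time.Second))

	// Peek does not advance the reader: both calls see the same byte.
	first, _ := br.Peek(1)
	again, _ := br.Peek(1)
	fmt.Println(first[0], again[0]) // 42 42

	// A subsequent Read still returns the peeked bytes.
	buf := make([]byte, 4)
	n, _ := br.Read(buf)
	fmt.Println(n, buf[:n]) // 4 [42 0 0 0]
}
```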
+ } + + msgLen := int(binary.LittleEndian.Uint32(hdr)) + // consume those 4 bytes now that we know the message length + if _, err := conn.br.Discard(bytesLeft); err != nil { + return 0, fmt.Errorf("discarding length prefix: %w", err) + } + state.remainingBytes = int32(msgLen - 4) + } + + buf := make([]byte, 4096) + for state.remainingBytes > 0 { + // refresh the deadline so large messages don't time out + if err := conn.nc.SetReadDeadline(time.Now().Add(time.Until(dl))); err != nil { + return totalRead, fmt.Errorf("renewing deadline: %w", err) + } + + toRead := int32(len(buf)) + if state.remainingBytes < toRead { + toRead = state.remainingBytes + } + + n, err := conn.br.Read(buf[:toRead]) + if n > 0 { + totalRead += n + state.remainingBytes -= int32(n) + } if err != nil { - errs = append(errs, err) - connClosed = true - err = conn.close() - if err != nil { - errs = append(errs, fmt.Errorf("error closing conn after reading: %w", err)) + // if it's just a timeout, record how much is left + if ne, ok := err.(net.Error); ok && ne.Timeout() { + return totalRead, fmt.Errorf("timeout discarding %d bytes: %w", + state.remainingBytes, err) } + return totalRead, fmt.Errorf("reading body: %w", err) } + } + + return totalRead + 4, nil +} + +// poolClearedError is an error returned when the connection pool is cleared or currently paused. It +// is a retryable error. +type pendingResponseError struct { + err error +} + +var _ error = pendingResponseError{} +var _ driver.RetryablePoolError = pendingResponseError{} + +func (pre pendingResponseError) Error() string { + if pre.err == nil { + return "" + } + return pre.err.Error() +} + +// Retryable returns true. All poolClearedErrors are retryable. +func (pendingResponseError) Retryable() bool { return true } + +func (pre pendingResponseError) Unwrap() error { + if pre.err == nil { + return nil + } + return pre.err +} + +// awaitPendingResponse sets a new read deadline on the provided connection and +// tries to read any bytes returned by the server. If there are any errors, the +// connection will be checked back into the pool to be retried. +func awaitPendingResponse(ctx context.Context, pool *pool, conn *connection) error { + conn.pendingResponseStateMu.Lock() + defer conn.pendingResponseStateMu.Unlock() + + // If there are no bytes pending read, do nothing. + if conn.pendingResponseState == nil { + return nil + } + + publishPendingResponseStarted(pool, conn) + + var ( + pendingResponseState = conn.pendingResponseState + remainingTime = pendingResponseState.start.Add(PendingResponseTimeout).Sub(time.Now()) + err error + bytesRead int + alivenessCheck bool + ) + st := time.Now() + if remainingTime <= 0 { + // If there is no remaining time, we can just peek at the connection to check + // aliveness. In such cases, we don't want to close the connection. + bytesRead, err = peekConnectionAlive(conn) + + // Mark this attempt as alive but check the connection back it the pull and + // send a retryable error. + alivenessCheck = true + } else { + bytesRead, err = attemptPendingResponse(ctx, conn, remainingTime) + } + + endTime := time.Now() + endDuration := time.Since(st) + + if err != nil { // No matter what happens, always check the connection back into the // pool, which will either make it available for other operations or // remove it from the pool if it was closed. - err = pool.checkInNoEvent(conn) - if err != nil { - errs = append(errs, fmt.Errorf("error checking in: %w", err)) + // + // TODO(GODRIVER-3385): Figure out how to handle this error. 
It's possible + // that a single connection can be checked out to handle multiple concurrent + // operations. This is likely a bug in the Go Driver. So it's possible that + // the connection is idle at the point of check-in. + defer func() { + dur := time.Since(st) + publishPendingResponseFailed(pool, conn, dur, err) + + _ = pool.checkInNoEvent(conn) + }() + + isCSOTTimeout := func(err error) bool { + // If the error was a timeout error, instead of closing the + // connection mark it as awaiting response so the pool can read the + // response before making it available to other operations. + nerr := net.Error(nil) + return errors.As(err, &nerr) && nerr.Timeout() } - if BGReadCallback != nil { - BGReadCallback(conn.addr.String(), start, read, errs, connClosed) + if !isCSOTTimeout(err) { + if err := conn.close(); err != nil { + return pendingResponseError{err: err} + } + return pendingResponseError{err: err} } - }() + } - err = conn.nc.SetReadDeadline(time.Now().Add(BGReadTimeout)) - if err != nil { - err = fmt.Errorf("error setting a read deadline: %w", err) - return + // If the read was successful, then we should refresh the remaining timeout. + if bytesRead > 0 { + pendingResponseState.start = endTime } - if size == 0 { - var sizeBuf [4]byte - _, err = io.ReadFull(conn.nc, sizeBuf[:]) - if err != nil { - err = fmt.Errorf("error reading the message size: %w", err) - return + // If the remaining time has been exceeded, then close the connection. + if endTime.Sub(pendingResponseState.start) > PendingResponseTimeout { + if err := conn.close(); err != nil { + return pendingResponseError{err: err} } - size, err = conn.parseWmSizeBytes(sizeBuf) - if err != nil { - return - } - size -= 4 } - _, err = io.CopyN(io.Discard, conn.nc, int64(size)) + if err != nil { - err = fmt.Errorf("error discarding %d byte message: %w", size, err) + return pendingResponseError{err: err} + } + + // If the connection is alive but undrained we can check it back into the pool + // and return a pendingResponseError to indicate that the connection is + // alive and retryable. + if alivenessCheck { + dur := time.Since(st) + publishPendingResponseFailed(pool, conn, dur, err) + + _ = pool.checkInNoEvent(conn) + + // TODO this should be a special error noting that the remainting timeout + // has been exceeded. + return pendingResponseError{err: fmt.Errorf("connection is alive and retryable: %w", err)} } + + publishPendingResponseSucceeded(pool, conn, endDuration) + + conn.pendingResponseState = nil + + return nil } // checkIn returns an idle connection to the pool. If the connection is perished or the pool is @@ -886,21 +1137,6 @@ func (p *pool) checkInNoEvent(conn *connection) error { return ErrWrongPool } - // If the connection has an awaiting server response, try to read the - // response in another goroutine before checking it back into the pool. - // - // Do this here because we want to publish checkIn events when the operation - // is done with the connection, not when it's ready to be used again. That - // means that connections in "awaiting response" state are checked in but - // not usable, which is not covered by the current pool events. We may need - // to add pool event information in the future to communicate that. - if conn.awaitRemainingBytes != nil { - size := *conn.awaitRemainingBytes - conn.awaitRemainingBytes = nil - go bgRead(p, conn, size) - return nil - } - // Bump the connection idle start time here because we're about to make the // connection "available". 
diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go
index f58e1cf204..697fe3bb9b 100644
--- a/x/mongo/driver/topology/pool_test.go
+++ b/x/mongo/driver/topology/pool_test.go
@@ -18,6 +18,7 @@ import (
 	"go.mongodb.org/mongo-driver/v2/event"
 	"go.mongodb.org/mongo-driver/v2/internal/assert"
 	"go.mongodb.org/mongo-driver/v2/internal/csot"
+	"go.mongodb.org/mongo-driver/v2/internal/driverutil"
 	"go.mongodb.org/mongo-driver/v2/internal/eventtest"
 	"go.mongodb.org/mongo-driver/v2/internal/require"
 	"go.mongodb.org/mongo-driver/v2/mongo/address"
@@ -1161,24 +1162,10 @@ func TestPool_maintain(t *testing.T) {
 	})
 }

-func TestBackgroundRead(t *testing.T) {
+func TestAwaitPendingRead(t *testing.T) {
 	t.Parallel()

-	newBGReadCallback := func(errsCh chan []error) func(string, time.Time, time.Time, []error, bool) {
-		return func(_ string, _, _ time.Time, errs []error, _ bool) {
-			errsCh <- errs
-			close(errsCh)
-		}
-	}
-
 	t.Run("incomplete read of message header", func(t *testing.T) {
-		errsCh := make(chan []error)
-
-		var originalCallback func(string, time.Time, time.Time, []error, bool)
-		originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
-		t.Cleanup(func() {
-			BGReadCallback = originalCallback
-		})
-
 		timeout := 10 * time.Millisecond

 		cleanup := make(chan struct{})
@@ -1202,24 +1189,22 @@
 		conn, err := p.checkOut(context.Background())
 		require.NoError(t, err)
+
 		ctx, cancel := csot.WithTimeout(context.Background(), &timeout)
 		defer cancel()
+
+		ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
+		ctx = driverutil.WithRequestID(ctx, -1)
+
 		_, err = conn.readWireMessage(ctx)
 		regex := regexp.MustCompile(
 			`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: client timed out waiting for server response: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
 		)
 		assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
-		assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitRemainingBytes should be nil")
-		close(errsCh) // this line causes a double close if BGReadCallback is ever called.
+		assert.NotNil(t, conn.pendingResponseState)
+		assert.Len(t, conn.pendingResponseState.sizeBytesReadBeforeSocketTimeout, 3)
 	})

 	t.Run("timeout reading message header, successful background read", func(t *testing.T) {
-		errsCh := make(chan []error)
-		var originalCallback func(string, time.Time, time.Time, []error, bool)
-		originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
-		t.Cleanup(func() {
-			BGReadCallback = originalCallback
-		})
-
 		timeout := 10 * time.Millisecond

 		addr := bootstrapConnections(t, 1, func(nc net.Conn) {
@@ -1233,8 +1218,20 @@
 			require.NoError(t, err)
 		})

+		var pendingReadError error
+		monitor := &event.PoolMonitor{
+			Event: func(pe *event.PoolEvent) {
+				if pe.Type == event.ConnectionPendingResponseFailed {
+					pendingReadError = pe.Error
+				}
+			},
+		}
+
 		p := newPool(
-			poolConfig{Address: address.Address(addr.String())},
+			poolConfig{
+				Address:     address.Address(addr.String()),
+				PoolMonitor: monitor,
+			},
 		)
 		defer p.close(context.Background())
 		err := p.ready()
@@ -1242,8 +1239,13 @@
 		conn, err := p.checkOut(context.Background())
 		require.NoError(t, err)
+
 		ctx, cancel := csot.WithTimeout(context.Background(), &timeout)
 		defer cancel()
+
+		ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
+		ctx = driverutil.WithRequestID(ctx, -1)
+
 		_, err = conn.readWireMessage(ctx)
 		regex := regexp.MustCompile(
 			`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: client timed out waiting for server response: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
@@ -1251,22 +1253,13 @@
 		assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
 		err = p.checkIn(conn)
 		require.NoError(t, err)
-		var bgErrs []error
-		select {
-		case bgErrs = <-errsCh:
-		case <-time.After(3 * time.Second):
-			assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
-		}
-		require.Len(t, bgErrs, 0, "expected no error from bgRead()")
+
+		_, err = p.checkOut(context.Background())
+		require.NoError(t, err)
+
+		require.NoError(t, pendingReadError)
 	})

 	t.Run("timeout reading message header, incomplete head during background read", func(t *testing.T) {
-		errsCh := make(chan []error)
-		var originalCallback func(string, time.Time, time.Time, []error, bool)
-		originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
-		t.Cleanup(func() {
-			BGReadCallback = originalCallback
-		})
-
 		timeout := 10 * time.Millisecond

 		addr := bootstrapConnections(t, 1, func(nc net.Conn) {
@@ -1280,8 +1273,20 @@
 			require.NoError(t, err)
 		})

+		var pendingReadError error
+		monitor := &event.PoolMonitor{
+			Event: func(pe *event.PoolEvent) {
+				if pe.Type == event.ConnectionPendingResponseFailed {
+					pendingReadError = pe.Error
+				}
+			},
+		}
+
 		p := newPool(
-			poolConfig{Address: address.Address(addr.String())},
+			poolConfig{
+				Address:     address.Address(addr.String()),
+				PoolMonitor: monitor,
+			},
 		)
 		defer p.close(context.Background())
 		err := p.ready()
@@ -1289,8 +1294,13 @@
 		conn, err := p.checkOut(context.Background())
 		require.NoError(t, err)
+
 		ctx, cancel := csot.WithTimeout(context.Background(), &timeout)
 		defer cancel()
+
+		ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
+		ctx = driverutil.WithRequestID(ctx, -1)
+
 		_, err = conn.readWireMessage(ctx)
 		regex := regexp.MustCompile(
 			`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: client timed out waiting for server response: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
@@ -1298,23 +1308,13 @@
 		assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
 		err = p.checkIn(conn)
 		require.NoError(t, err)
-		var bgErrs []error
-		select {
-		case bgErrs = <-errsCh:
-		case <-time.After(3 * time.Second):
-			assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
-		}
-		require.Len(t, bgErrs, 1, "expected 1 error from bgRead()")
-		assert.EqualError(t, bgErrs[0], "error reading the message size: unexpected EOF")
+
+		_, err = p.checkOut(context.Background())
+		require.Error(t, err)
+
+		assert.EqualError(t, pendingReadError, "peeking length prefix: EOF")
 	})

 	t.Run("timeout reading message header, background read timeout", func(t *testing.T) {
-		errsCh := make(chan []error)
-		var originalCallback func(string, time.Time, time.Time, []error, bool)
-		originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
-		t.Cleanup(func() {
-			BGReadCallback = originalCallback
-		})
-
 		timeout := 10 * time.Millisecond

 		cleanup := make(chan struct{})
@@ -1332,17 +1332,35 @@
 			require.NoError(t, err)
 		})

+		var pendingReadError error
+		monitor := &event.PoolMonitor{
+			Event: func(pe *event.PoolEvent) {
+				if pe.Type == event.ConnectionPendingResponseFailed {
+					pendingReadError = pe.Error
+				}
+			},
+		}
+
 		p := newPool(
-			poolConfig{Address: address.Address(addr.String())},
+			poolConfig{
+				Address:     address.Address(addr.String()),
+				PoolMonitor: monitor,
+			},
 		)
+		defer p.close(context.Background())
 		err := p.ready()
 		require.NoError(t, err)

 		conn, err := p.checkOut(context.Background())
 		require.NoError(t, err)
+
 		ctx, cancel := csot.WithTimeout(context.Background(), &timeout)
 		defer cancel()
+
+		ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
+		ctx = driverutil.WithRequestID(ctx, -1)
+
 		_, err = conn.readWireMessage(ctx)
 		regex := regexp.MustCompile(
 			`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: client timed out waiting for server response: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
@@ -1350,26 +1368,16 @@
 		assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
 		err = p.checkIn(conn)
 		require.NoError(t, err)
-		var bgErrs []error
-		select {
-		case bgErrs = <-errsCh:
-		case <-time.After(3 * time.Second):
-			assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
-		}
-		require.Len(t, bgErrs, 1, "expected 1 error from bgRead()")
+
+		_, err = p.checkOut(context.Background())
+		require.Error(t, err)
+
 		wantErr := regexp.MustCompile(
-			`^error discarding 6 byte message: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
+			`^timeout discarding 2 bytes: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
 		)
-		assert.True(t, wantErr.MatchString(bgErrs[0].Error()), "error %q does not match pattern %q", bgErrs[0], wantErr)
+		assert.True(t, wantErr.MatchString(pendingReadError.Error()), "error %q does not match pattern %q", pendingReadError, wantErr)
 	})

 	t.Run("timeout reading full message, successful background read", func(t *testing.T) {
-		errsCh := make(chan []error)
-		var originalCallback func(string, time.Time, time.Time, []error, bool)
-		originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
-		t.Cleanup(func() {
-			BGReadCallback = originalCallback
-		})
-
 		timeout := 10 * time.Millisecond

 		addr := bootstrapConnections(t, 1, func(nc net.Conn) {
@@ -1386,17 +1394,35 @@
 			require.NoError(t, err)
 		})

+		var pendingReadError error
+		monitor := &event.PoolMonitor{
+			Event: func(pe *event.PoolEvent) {
+				if pe.Type == event.ConnectionPendingResponseFailed {
+					pendingReadError = pe.Error
+				}
+			},
+		}
+
 		p := newPool(
-			poolConfig{Address: address.Address(addr.String())},
+			poolConfig{
+				Address:     address.Address(addr.String()),
+				PoolMonitor: monitor,
+			},
 		)
+		defer p.close(context.Background())
 		err := p.ready()
 		require.NoError(t, err)

 		conn, err := p.checkOut(context.Background())
 		require.NoError(t, err)
+
 		ctx, cancel := csot.WithTimeout(context.Background(), &timeout)
 		defer cancel()
+
+		ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
+		ctx = driverutil.WithRequestID(ctx, -1)
+
 		_, err = conn.readWireMessage(ctx)
 		regex := regexp.MustCompile(
 			`^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: client timed out waiting for server response: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
@@ -1404,22 +1430,13 @@
 		assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
 		err = p.checkIn(conn)
 		require.NoError(t, err)
-		var bgErrs []error
-		select {
-		case bgErrs = <-errsCh:
-		case <-time.After(3 * time.Second):
-			assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
-		}
-		require.Len(t, bgErrs, 0, "expected no error from bgRead()")
+
+		_, err = p.checkOut(context.Background())
+		require.NoError(t, err)
+
+		require.NoError(t, pendingReadError)
 	})

 	t.Run("timeout reading full message, background read EOF", func(t *testing.T) {
-		errsCh := make(chan []error)
-		var originalCallback func(string, time.Time, time.Time, []error, bool)
-		originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
-		t.Cleanup(func() {
-			BGReadCallback = originalCallback
-		})
-
 		timeout := 10 * time.Millisecond

 		addr := bootstrapConnections(t, 1, func(nc net.Conn) {
@@ -1436,17 +1453,35 @@
 			require.NoError(t, err)
 		})

+		var pendingReadError error
+		monitor := &event.PoolMonitor{
+			Event: func(pe *event.PoolEvent) {
+				if pe.Type == event.ConnectionPendingResponseFailed {
+					pendingReadError = pe.Error
+				}
+			},
+		}
+
 		p := newPool(
-			poolConfig{Address: address.Address(addr.String())},
+			poolConfig{
+				Address:     address.Address(addr.String()),
+				PoolMonitor: monitor,
+			},
 		)
+		defer p.close(context.Background())
 		err := p.ready()
 		require.NoError(t, err)

 		conn, err := p.checkOut(context.Background())
 		require.NoError(t, err)
+
 		ctx, cancel := csot.WithTimeout(context.Background(), &timeout)
 		defer cancel()
+
+		ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
+		ctx = driverutil.WithRequestID(ctx, -1)
+
 		_, err = conn.readWireMessage(ctx)
 		regex := regexp.MustCompile(
 			`^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: client timed out waiting for server response: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
@@ -1454,14 +1489,11 @@
 		assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
 		err = p.checkIn(conn)
 		require.NoError(t, err)
-		var bgErrs []error
-		select {
-		case bgErrs = <-errsCh:
-		case <-time.After(3 * time.Second):
-			assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
-		}
-		require.Len(t, bgErrs, 1, "expected 1 error from bgRead()")
-		assert.EqualError(t, bgErrs[0], "error discarding 3 byte message: EOF")
+
+		_, err = p.checkOut(context.Background())
+		require.Error(t, err)
+
+		assert.EqualError(t, pendingReadError, "reading body: EOF")
 	})
 }
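The tests above observe drain failures through the pool monitor rather than the removed BGReadCallback hook. The sketch below shows how an application could watch the same events; `ConnectionPendingResponseFailed` appears in the tests above, while the `Started`/`Succeeded` constant names are inferred from the `publishPendingResponse*` helpers in pool.go and may differ in the final API.

```go
package example

import (
	"log"

	"go.mongodb.org/mongo-driver/v2/event"
)

// pendingResponseMonitor is a hypothetical example of observing the
// pending-response pool events added in this patch. An application would
// pass the returned monitor to the client via
// options.Client().SetPoolMonitor(...).
func pendingResponseMonitor() *event.PoolMonitor {
	return &event.PoolMonitor{
		Event: func(pe *event.PoolEvent) {
			switch pe.Type {
			// Started/Succeeded are assumed names inferred from the
			// publish helpers; Failed is used directly by the tests above.
			case event.ConnectionPendingResponseStarted:
				log.Printf("pool %s: draining pending response", pe.Address)
			case event.ConnectionPendingResponseFailed:
				log.Printf("pool %s: pending response failed: %v", pe.Address, pe.Error)
			case event.ConnectionPendingResponseSucceeded:
				log.Printf("pool %s: pending response drained", pe.Address)
			}
		},
	}
}
```

Because the monitor callback runs synchronously with pool operations, handlers should stay cheap; the tests above only capture the event's error into a local variable for later assertions.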