Skip to content
217 changes: 152 additions & 65 deletions src/snowflake/connector/crl.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,61 +270,62 @@ def from_config(
cache_manager=cache_manager,
)

def validate_certificate_chains(
self, certificate_chains: list[list[x509.Certificate]]
def validate_certificate_chain(
self, peer_cert: x509.Certificate, chain: list[x509.Certificate] | None
) -> bool:
"""
Validate certificate chains against CRLs with actual HTTP requests
Validate a certificate chain against CRLs with actual HTTP requests

Args:
certificate_chains: List of certificate chains to validate
peer_cert: The peer certificate to validate (e.g., server certificate)
chain: Certificate chain to use for validation (can be None or empty)

Returns:
True if validation passes, False otherwise

Raises:
ValueError: If certificate_chains is None or empty
"""
if self._cert_revocation_check_mode == CertRevocationCheckMode.DISABLED:
return True

if certificate_chains is None or len(certificate_chains) == 0:
logger.warning("Certificate chains are empty")
return self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY

results = []
for chain in certificate_chains:
result = self._validate_single_chain(chain)
# If any of the chains is valid, the whole check is considered positive
if result == CRLValidationResult.UNREVOKED:
return True
results.append(result)
chain = chain if chain is not None else []
result = self._validate_chain(peer_cert, chain)

# In non-advisory mode we require at least one chain get a clear UNREVOKED status
if self._cert_revocation_check_mode != CertRevocationCheckMode.ADVISORY:
if result == CRLValidationResult.UNREVOKED:
return True
if result == CRLValidationResult.REVOKED:
return False
# In advisory mode, errors are treated positively
return self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY

# We're in advisory mode, so any error is treated positively
return any(result == CRLValidationResult.ERROR for result in results)

def _validate_single_chain(
self, chain: list[x509.Certificate]
def _validate_chain(
self, start_cert: x509.Certificate, chain: list[x509.Certificate]
) -> CRLValidationResult:
"""
Validate a certificate chain starting from start_cert.

Args:
start_cert: The certificate to start validation from
chain: List of certificates to use for building the trust path

Returns:
UNREVOKED: If there is a path to any trusted certificate where all certificates are unrevoked.
REVOKED: If all paths to trusted certificates are revoked.
ERROR: If there is a path to any trusted certificate on which none certificate is revoked,
but some certificates can't be verified.
"""
# An empty chain is considered an error
if len(chain) == 0:
# Check if start certificate is expired
if not self._is_valid(start_cert):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit/optional: I would be tempted to move away from is_valid which generally means more than simply having correct dates. The documentation is very helpful regardless.

Maybe _is_within_lifetime or _is_within_validity_dates?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, changed

logger.warning(
"Start certificate is expired or not yet valid: %s", start_cert.subject
)
return CRLValidationResult.ERROR

subject_certificates: dict[x509.Name, list[x509.Certificate]] = defaultdict(
list
)
for cert in chain:
if not self._is_ca_certificate(cert):
logger.warning("Ignoring non-CA certificate: %s", cert)
continue
Comment on lines +326 to +328
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also be checking the lifetimes of the intermediates.

I realize we are almost certainly not going to get attacked by an expired leaked intermediate, but it should be cheap to add and would be closer to spec.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the catch! I've overlooked this - fixed the check and tests

subject_certificates[cert.subject].append(cert)
currently_visited_subjects: set[x509.Name] = set()

Expand Down Expand Up @@ -387,19 +388,7 @@ def traverse_chain(cert: x509.Certificate) -> CRLValidationResult | None:
# no ERROR result found, all paths are REVOKED
return CRLValidationResult.REVOKED

currently_visited_subjects.add(chain[0].subject)
error_result = False
revoked_result = False
for cert in subject_certificates[chain[0].subject]:
result = traverse_chain(cert)
if result == CRLValidationResult.UNREVOKED:
return result
error_result |= result == CRLValidationResult.ERROR
revoked_result |= result == CRLValidationResult.REVOKED

if error_result or not revoked_result:
return CRLValidationResult.ERROR
return CRLValidationResult.REVOKED
return traverse_chain(start_cert)

def _is_certificate_trusted_by_os(self, cert: x509.Certificate) -> bool:
if cert.subject not in self._trusted_ca:
Expand All @@ -426,6 +415,51 @@ def _verify_certificate_signature(
except Exception:
return False

@staticmethod
def _is_ca_certificate(ca_cert: x509.Certificate) -> bool:
# Check if a certificate has basicConstraints extension with CA flag set to True.
try:
basic_constraints = ca_cert.extensions.get_extension_for_oid(
ExtensionOID.BASIC_CONSTRAINTS
).value
return basic_constraints.ca
except x509.ExtensionNotFound:
# If the extension is not present, the certificate is not a CA
return False

@staticmethod
def _get_certificate_validity_dates(
cert: x509.Certificate,
) -> tuple[datetime, datetime]:
# Extract UTC-aware validity dates from a certificate.

try:
# Use timezone-aware versions to avoid deprecation warnings
not_valid_before = cert.not_valid_before_utc
not_valid_after = cert.not_valid_after_utc
except AttributeError:
# Fallback for older versions without _utc methods
not_valid_before = cert.not_valid_before
not_valid_after = cert.not_valid_after

# Convert to UTC if not timezone-aware
if not_valid_before.tzinfo is None:
not_valid_before = not_valid_before.replace(tzinfo=timezone.utc)
if not_valid_after.tzinfo is None:
not_valid_after = not_valid_after.replace(tzinfo=timezone.utc)

return not_valid_before, not_valid_after

@staticmethod
def _is_valid(cert: x509.Certificate) -> bool:
# Check if a certificate is currently valid (not expired and not before validity period).

not_valid_before, not_valid_after = (
CRLValidator._get_certificate_validity_dates(cert)
)
now = datetime.now(timezone.utc)
return not_valid_before <= now <= not_valid_after

def _validate_certificate_is_not_revoked_with_cache(
self, cert: x509.Certificate, ca_cert: x509.Certificate
) -> CRLValidationResult:
Expand Down Expand Up @@ -474,18 +508,8 @@ def _is_short_lived_certificate(cert: x509.Certificate) -> bool:
- For certificates issued on or after 15 March 2026:
validity period <= 7 days (604,800 seconds)
"""
try:
# Use timezone.utc versions to avoid deprecation warnings
issue_date = cert.not_valid_before_utc
validity_period = cert.not_valid_after_utc - cert.not_valid_before_utc
except AttributeError:
# Fallback for older versions
issue_date = cert.not_valid_before
validity_period = cert.not_valid_after - cert.not_valid_before

# Convert issue_date to UTC if it's not timezone-aware
if issue_date.tzinfo is None:
issue_date = issue_date.replace(tzinfo=timezone.utc)
issue_date, expiry_date = CRLValidator._get_certificate_validity_dates(cert)
validity_period = expiry_date - issue_date
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the notAfter date is inclusive so this calculation is off by one [text from RFC 5280]. I think your code falls on the safe side and CAs worth their salt say away from the edges, but we should probably fix it to show competence.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed


march_15_2026 = datetime(2026, 3, 15, tzinfo=timezone.utc)
if issue_date >= march_15_2026:
Expand Down Expand Up @@ -594,6 +618,11 @@ def _check_certificate_against_crl_url(
# We cannot trust a CRL whose signature cannot be verified
return CRLValidationResult.ERROR

# Verify that the CRL URL matches the IDP extension
if not self._verify_against_idp_extension(crl, crl_url):
logger.warning("CRL URL does not match IDP extension for URL: %s", crl_url)
return CRLValidationResult.ERROR

# Check if certificate is revoked
return self._check_certificate_against_crl(cert, crl)

Expand Down Expand Up @@ -645,6 +674,52 @@ def _verify_crl_signature(
logger.warning("CRL signature verification failed: %s", e)
return False

def _verify_against_idp_extension(
self, crl: x509.CertificateRevocationList, crl_url: str
) -> bool:
# Verify that the CRL distribution point URL matches the IDP extension.
logger.debug(
"Trying to verify CRL URL against IDP extension for URL: %s", crl_url
)

try:
idp_extension = crl.extensions.get_extension_for_oid(
ExtensionOID.ISSUING_DISTRIBUTION_POINT
)
idp = idp_extension.value

# If the IDP has a distribution point, verify it matches the CRL URL
if not idp.full_name:
# according to baseline requirements this should not happen
# https://github.com/cabforum/servercert/blob/main/docs/BR.md
logger.debug(
"IDP extension has no full_name - treating as invalid",
crl_url,
)
return False

for name in idp.full_name:
if isinstance(name, x509.UniformResourceIdentifier):
if name.value == crl_url:
logger.debug("CRL URL matches IDP extension: %s", crl_url)
return True
# If we found distribution points but none matched
logger.warning(
"CRL URL %s does not match any IDP distribution point", crl_url
)
return False

except x509.ExtensionNotFound:
# If the IDP extension is not present, consider it valid
logger.debug(
"No IDP extension found in CRL, treating as valid for URL: %s", crl_url
)
return True
except Exception as e:
# If we can't parse the IDP extension, log and treat as error
logger.warning("Failed to verify IDP extension: %s", e)
return False

def _check_certificate_against_crl(
self, cert: x509.Certificate, crl: x509.CertificateRevocationList
) -> CRLValidationResult:
Expand All @@ -660,44 +735,56 @@ def validate_connection(self, connection: SSLConnection) -> bool:
"""
Validate an OpenSSL connection against CRLs.

This method extracts certificate chains from the connection and validates them
against Certificate Revocation Lists (CRLs).
This method extracts the peer certificate and certificate chain from the
connection and validates them against Certificate Revocation Lists (CRLs).

Args:
connection: OpenSSL connection object

Returns:
True if validation passes, False otherwise
"""
certificate_chains = self._extract_certificate_chains_from_connection(
connection
)
return self.validate_certificate_chains(certificate_chains)
try:
# Get the peer certificate (the start certificate)
peer_cert = connection.get_peer_certificate(as_cryptography=True)
if peer_cert is None:
logger.warning("No peer certificate found in connection")
return (
self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY
)

# Extract the certificate chain
cert_chain = self._extract_certificate_chain_from_connection(connection)

return self.validate_certificate_chain(peer_cert, cert_chain)
except Exception as e:
logger.warning("Failed to validate connection: %s", e)
return self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY

def _extract_certificate_chains_from_connection(
def _extract_certificate_chain_from_connection(
self, connection
) -> list[list[x509.Certificate]]:
"""Extract certificate chains from OpenSSL connection for CRL validation.
) -> list[x509.Certificate] | None:
"""Extract certificate chain from OpenSSL connection for CRL validation.

Args:
connection: OpenSSL connection object

Returns:
List of certificate chains, where each chain is a list of x509.Certificate objects
Certificate chain as a list of x509.Certificate objects, or None on error
"""
try:
# Convert OpenSSL certificates to cryptography x509 certificates
cert_chain = connection.get_peer_cert_chain(as_cryptography=True)
if not cert_chain:
logger.debug("No certificate chain found in connection")
return []
return None
logger.debug(
"Extracted %d certificates for CRL validation", len(cert_chain)
)
return [cert_chain] # Return as a single chain
return cert_chain

except Exception as e:
logger.warning(
"Failed to extract certificate chain for CRL validation: %s", e
)
return []
return None
Loading
Loading