blob: 841efdff5c231066382168aeae2147c85d4600d3 [file] [log] [blame]
// Copyright (C) 2023 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import com.google.common.base.Strings
import com.google.common.flogger.FluentLogger
import com.google.gerrit.extensions.annotations.Listen
import com.google.gerrit.extensions.annotations.PluginName
import com.google.gerrit.extensions.events.LifecycleListener
import com.google.gerrit.metrics.CallbackMetric1
import com.google.gerrit.metrics.Description
import com.google.gerrit.metrics.Field
import com.google.gerrit.metrics.MetricMaker
import com.google.gerrit.server.config.ConfigUtil
import com.google.gerrit.server.config.PluginConfigFactory
import com.google.gerrit.server.git.WorkQueue
import com.google.gerrit.server.logging.Metadata
import com.google.inject.Inject
import com.google.inject.Singleton
import sun.security.x509.GeneralNameInterface
import javax.net.ssl.SSLSocket
import javax.net.ssl.SSLSocketFactory
import java.security.cert.Certificate
import java.security.cert.X509Certificate
import java.time.Duration
import java.util.concurrent.ScheduledFuture
import static java.util.concurrent.TimeUnit.HOURS
import static java.util.concurrent.TimeUnit.MILLISECONDS
import static java.util.concurrent.TimeUnit.SECONDS
@Singleton
@Listen
class CertificatesValidityChecker implements LifecycleListener {
private static final int DEFAULT_CHECK_INTERVAL_HOURS = 24
private final WorkQueue queue
private final PluginConfigFactory config
private final String pluginName
private final CertificatesCheckMetrics metrics
private ScheduledFuture<?> certificatesValidityChecksTask
private List<String> endpoints
private Long checkIntervalInMillis
@Inject
CertificatesValidityChecker(WorkQueue queue, PluginConfigFactory cfg,
CertificatesCheckMetrics metrics,
@PluginName String pluginName) {
this.metrics = metrics
this.queue = queue
this.config = cfg
this.pluginName = pluginName
}
@Override
void start() {
endpoints = getEndpointsList(config, pluginName)
checkIntervalInMillis = getCheckIntervalMillis(config, pluginName)
certificatesValidityChecksTask = queue.getDefaultQueue()
.scheduleAtFixedRate(
new CheckCertificatesValidityTask(metrics, endpoints),
SECONDS.toMillis(1),
checkIntervalInMillis,
MILLISECONDS)
}
@Override
void stop() {
if (certificatesValidityChecksTask != null) {
certificatesValidityChecksTask.cancel(true)
certificatesValidityChecksTask = null
}
}
private Long getCheckIntervalMillis(PluginConfigFactory cfg, String pluginName) {
String fromConfig =
Strings.nullToEmpty(cfg.getGlobalPluginConfig(pluginName).getString("validation",null,"checkInterval"))
return HOURS.toMillis(ConfigUtil.getTimeUnit(fromConfig, DEFAULT_CHECK_INTERVAL_HOURS, HOURS))
}
private List<String> getEndpointsList(PluginConfigFactory cfg, String pluginName) {
return cfg.getGlobalPluginConfig(pluginName).getStringList("validation",null,"endpoint")
}
private static class CertificatesCheckMetrics {
private static final Field<String> ENDPOINT_NAME =
Field.ofString("endpoint_name", Metadata.Builder.&cacheName).build()
private final CallbackMetric1<String, Integer> metrics
@Inject
CertificatesCheckMetrics(MetricMaker metricMaker) {
this.metrics =
metricMaker.newCallbackMetric(
"certificates/number_of_day_to_expire/per_endpoint",
Integer.class,
new Description("Per-endpoint certificate expiration date")
.setGauge()
.setUnit("days"),
ENDPOINT_NAME)
}
def setMetric(String endpoint, int numberOfDays) {
metrics.set(endpoint, numberOfDays)
}
}
private static class CheckCertificatesValidityTask implements Runnable {
private static final FluentLogger logger = FluentLogger.forEnclosingClass()
private final CertificatesCheckMetrics metrics
private final List<String> endpoints
CheckCertificatesValidityTask(CertificatesCheckMetrics metrics, List<String> endpoints) {
this.endpoints = endpoints
this.metrics = metrics
}
@Override
void run() {
for (String endpoint : endpoints) {
logger.atInfo().log("Checking certificate expiry date for %s endpoint", endpoint)
SSLSocket conn
try {
def (hostname, port) = parseEndpoint(endpoint)
conn = openConnection(hostname as String, port as int)
conn.startHandshake();
Certificate[] certs = conn.getSession().getPeerCertificates();
for (Certificate cert : certs) {
if (cert instanceof X509Certificate &&
cert.getSubjectAlternativeNames().findAll{it[0] == GeneralNameInterface.NAME_DNS}
.any {isHostnameMatching(hostname as String, it.get(1) as String) }) {
def numberOfDaysToExpire = Duration
.between(new Date().toInstant(), cert.notAfter.toInstant()).toDays()
metrics
.setMetric(
hostname as String,
numberOfDaysToExpire.intValue())
} else {
logger.atFine().log("Certificate type %s is not a valid X.509 certificate for the specified endpoint: %s. Skipping!", cert.getType(), endpoint)
}
}
} catch(e) {
logger.atSevere()
.withCause(e)
.log("Cannot check certificates expiry date for %s endpoint", endpoint)
} finally {
conn?.close()
}
}
}
def parseEndpoint(String endpoint) {
def hostAndPort = endpoint.split(':')
if (hostAndPort.size() != 2) {
throw new IllegalArgumentException("Wrong endpoint format, expected <host>:<port> but was ${endpoint}")
}
hostAndPort
}
private boolean isHostnameMatching(String hostname, String certName) {
// Replace the wildcard (*) with a regex wildcard (.*)
def certPattern = certName.replaceFirst("\\*", ".*")
return hostname.matches(certPattern)
}
private SSLSocket openConnection(String hostname, int port) {
logger
.atInfo()
.log("Opening connection for %s endpoint successful",
hostname)
(SSLSocket) SSLSocketFactory.getDefault()
.createSocket(hostname, port);
}
}
}