// Copyright (C) 2023 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import com.google.common.base.Strings
import com.google.common.flogger.FluentLogger
import com.google.gerrit.extensions.annotations.Listen
import com.google.gerrit.extensions.annotations.PluginName
import com.google.gerrit.extensions.events.LifecycleListener
import com.google.gerrit.metrics.CallbackMetric1
import com.google.gerrit.metrics.Description
import com.google.gerrit.metrics.Field
import com.google.gerrit.metrics.MetricMaker
import com.google.gerrit.server.config.ConfigUtil
import com.google.gerrit.server.config.PluginConfigFactory
import com.google.gerrit.server.git.WorkQueue
import com.google.gerrit.server.logging.Metadata
import com.google.inject.Inject
import com.google.inject.Singleton
import sun.security.x509.GeneralNameInterface

import javax.net.ssl.SSLSocket
import javax.net.ssl.SSLSocketFactory
import java.security.cert.Certificate
import java.security.cert.X509Certificate
import java.time.Duration
import java.util.concurrent.ScheduledFuture

import static java.util.concurrent.TimeUnit.HOURS
import static java.util.concurrent.TimeUnit.MILLISECONDS
import static java.util.concurrent.TimeUnit.SECONDS

/**
 * Periodically opens a TLS connection to each configured endpoint and publishes, per host, the
 * number of days left before the matching server certificate expires.
 *
 * <p>Configuration is read from the plugin's global config ({@code getGlobalPluginConfig}):
 *
 * <ul>
 *   <li>{@code validation.endpoint} (multi-valued): endpoints in {@code <host>:<port>} form.
 *   <li>{@code validation.checkInterval}: how often to run the check. A bare number is taken as
 *       hours; values with a unit suffix are honoured as given. Defaults to 24 hours.
 * </ul>
 */
@Singleton
@Listen
class CertificatesValidityChecker implements LifecycleListener {
  private static final int DEFAULT_CHECK_INTERVAL_HOURS = 24
  private final WorkQueue queue
  private final PluginConfigFactory config
  private final String pluginName
  private final CertificatesCheckMetrics metrics

  private ScheduledFuture<?> certificatesValidityChecksTask
  private List<String> endpoints
  private Long checkIntervalInMillis

  @Inject
  CertificatesValidityChecker(WorkQueue queue, PluginConfigFactory cfg,
                              CertificatesCheckMetrics metrics,
                              @PluginName String pluginName) {
    this.metrics = metrics
    this.queue = queue
    this.config = cfg
    this.pluginName = pluginName
  }

  /** Reads the configuration and schedules the periodic check on the default work queue. */
  @Override
  void start() {
    endpoints = getEndpointsList(config, pluginName)
    checkIntervalInMillis = getCheckIntervalMillis(config, pluginName)
    // First check runs one second after start-up, then every checkIntervalInMillis.
    certificatesValidityChecksTask = queue.getDefaultQueue()
        .scheduleAtFixedRate(
            new CheckCertificatesValidityTask(metrics, endpoints),
            SECONDS.toMillis(1),
            checkIntervalInMillis,
            MILLISECONDS)
  }

  /** Cancels the scheduled check (interrupting a running one) and forgets the handle. */
  @Override
  void stop() {
    if (certificatesValidityChecksTask != null) {
      certificatesValidityChecksTask.cancel(true)
      certificatesValidityChecksTask = null
    }
  }

  /**
   * Returns the configured check interval in milliseconds.
   *
   * <p>A bare number (no unit) keeps its original meaning of hours. Values with an explicit unit
   * are parsed straight into milliseconds, so sub-hour settings such as {@code 30 min} are no
   * longer truncated to zero hours. A zero period would be rejected by {@code scheduleAtFixedRate}.
   * Any non-positive result falls back to the default interval.
   */
  private Long getCheckIntervalMillis(PluginConfigFactory cfg, String pluginName) {
    long defaultMillis = HOURS.toMillis(DEFAULT_CHECK_INTERVAL_HOURS)
    String fromConfig =
        Strings.nullToEmpty(cfg.getGlobalPluginConfig(pluginName).getString("validation",null,"checkInterval")).trim()
    long millis = (fromConfig ==~ /(0|[1-9]\d*)/)
        ? HOURS.toMillis(Long.parseLong(fromConfig))
        : ConfigUtil.getTimeUnit(fromConfig, defaultMillis, MILLISECONDS)
    return millis > 0 ? millis : defaultMillis
  }

  /** Returns all {@code validation.endpoint} values; empty when none are configured. */
  private List<String> getEndpointsList(PluginConfigFactory cfg, String pluginName) {
    return cfg.getGlobalPluginConfig(pluginName).getStringList("validation",null,"endpoint")
  }

  /** Holds the per-endpoint gauge reporting the number of days until certificate expiry. */
  private static class CertificatesCheckMetrics {
    // NOTE(review): the endpoint label is recorded under the `cacheName` metadata key; presumably
    // reused because Metadata has no dedicated endpoint key — confirm this is intended.
    private static final Field<String> ENDPOINT_NAME =
        Field.ofString("endpoint_name", Metadata.Builder.&cacheName).build()
    private final CallbackMetric1<String, Integer> metrics

    @Inject
    CertificatesCheckMetrics(MetricMaker metricMaker) {
      this.metrics =
          metricMaker.newCallbackMetric(
              "certificates/number_of_day_to_expire/per_endpoint",
              Integer.class,
              new Description("Per-endpoint certificate expiration date")
                  .setGauge()
                  .setUnit("days"),
              ENDPOINT_NAME)
    }

    /** Publishes {@code numberOfDays} as the current value for {@code endpoint}. */
    def setMetric(String endpoint, int numberOfDays) {
      metrics.set(endpoint, numberOfDays)
    }
  }

  /** One pass over all configured endpoints; scheduled repeatedly by the enclosing listener. */
  private static class CheckCertificatesValidityTask implements Runnable {
    private static final FluentLogger logger = FluentLogger.forEnclosingClass()
    // GeneralName type code for dNSName, as documented by X509Certificate#getSubjectAlternativeNames.
    // Used instead of sun.security.x509.GeneralNameInterface, an internal JDK API.
    private static final int SAN_DNS_NAME = 2
    // Bounds both the TCP connect and each blocking read during the TLS handshake, so an
    // unresponsive endpoint cannot block a shared work-queue thread indefinitely.
    private static final int SOCKET_TIMEOUT_MILLIS = (int) SECONDS.toMillis(30)

    private final CertificatesCheckMetrics metrics
    private final List<String> endpoints

    CheckCertificatesValidityTask(CertificatesCheckMetrics metrics, List<String> endpoints) {
      this.endpoints = endpoints
      this.metrics = metrics
    }

    /**
     * Checks every configured endpoint once. Failures are logged per endpoint and never
     * propagated: under the ScheduledExecutorService contract, an exception escaping run()
     * suppresses all later executions of the periodic task.
     */
    @Override
    void run() {
      for (String endpoint : endpoints) {
        logger.atInfo().log("Checking certificate expiry date for %s endpoint", endpoint)
        try {
          checkEndpoint(endpoint)
        } catch (Exception e) {
          logger.atSevere()
              .withCause(e)
              .log("Cannot check certificates expiry date for %s endpoint", endpoint)
        }
      }
    }

    /**
     * Connects to one endpoint and updates the metric for each peer certificate whose DNS
     * subject alternative names match the endpoint host.
     */
    private void checkEndpoint(String endpoint) {
      def (String hostname, Integer port) = parseEndpoint(endpoint)
      SSLSocket conn = openConnection(hostname, port as int)
      try {
        // NOTE(review): the default SSLSocketFactory validates the peer chain, so presumably an
        // already-expired certificate makes the handshake fail (logged by run()) rather than
        // producing a negative day count — confirm this is the desired behaviour.
        conn.startHandshake()
        Certificate[] certs = conn.getSession().getPeerCertificates()
        for (Certificate cert : certs) {
          if (cert instanceof X509Certificate && isCertificateForHost((X509Certificate) cert, hostname)) {
            long numberOfDaysToExpire = Duration
                .between(new Date().toInstant(), ((X509Certificate) cert).notAfter.toInstant()).toDays()
            metrics.setMetric(hostname, (int) numberOfDaysToExpire)
          } else {
            logger.atFine().log("Certificate type %s is not a valid X.509 certificate for the specified endpoint: %s. Skipping!", cert.getType(), endpoint)
          }
        }
      } finally {
        closeQuietly(conn, endpoint)
      }
    }

    /**
     * Splits {@code <host>:<port>} into {@code [host, port]} (String, Integer).
     *
     * @throws IllegalArgumentException if the format is wrong, the host is empty or the port is
     *     not an integer in 1..65535
     */
    def parseEndpoint(String endpoint) {
      String[] hostAndPort = (endpoint ?: "").split(':')
      if (hostAndPort.size() != 2) {
        throw wrongEndpointFormat(endpoint, null)
      }
      String host = hostAndPort[0].trim()
      int port
      try {
        port = Integer.parseInt(hostAndPort[1].trim())
      } catch (NumberFormatException e) {
        throw wrongEndpointFormat(endpoint, e)
      }
      if (host.isEmpty() || port < 1 || port > 65535) {
        throw wrongEndpointFormat(endpoint, null)
      }
      [host, port]
    }

    /** Builds the exception reported for an unparsable endpoint entry. */
    private static IllegalArgumentException wrongEndpointFormat(String endpoint, Throwable cause) {
      new IllegalArgumentException("Wrong endpoint format, expected <host>:<port> but was ${endpoint}".toString(), cause)
    }

    /** True if any dNSName subject alternative name of {@code cert} matches {@code hostname}. */
    private static boolean isCertificateForHost(X509Certificate cert, String hostname) {
      Collection<List<?>> altNames = cert.getSubjectAlternativeNames()
      // Null when the certificate has no SAN extension (typical for CA certificates in the chain).
      if (altNames == null) {
        return false
      }
      for (List<?> altName : altNames) {
        if (altName.get(0) == SAN_DNS_NAME && isHostnameMatching(hostname, altName.get(1) as String)) {
          return true
        }
      }
      return false
    }

    /**
     * Matches {@code hostname} against a certificate DNS name, case-insensitively.
     *
     * <p>The first {@code *} in {@code certName} is a wildcard that matches characters within a
     * single DNS label only; every other character is compared literally. Previously {@code .}
     * was treated as a regex metacharacter.
     */
    private static boolean isHostnameMatching(String hostname, String certName) {
      if (!hostname || !certName) {
        return false
      }
      String host = hostname.toLowerCase(Locale.ROOT)
      String name = certName.toLowerCase(Locale.ROOT)
      int wildcard = name.indexOf('*')
      if (wildcard < 0) {
        return host == name
      }
      String prefix = name.substring(0, wildcard)
      String suffix = name.substring(wildcard + 1)
      if (host.length() < prefix.length() + suffix.length()
          || !host.startsWith(prefix) || !host.endsWith(suffix)) {
        return false
      }
      // The part covered by the wildcard must not span a label boundary.
      return !host.substring(prefix.length(), host.length() - suffix.length()).contains('.')
    }

    /**
     * Opens a TLS socket to {@code hostname:port} with bounded connect and read timeouts.
     * The TLS layer is created with the hostname so that it is available for SNI. On any
     * failure the underlying TCP socket is closed before the exception is rethrown.
     */
    private SSLSocket openConnection(String hostname, int port) {
      Socket socket = new Socket()
      try {
        socket.connect(new InetSocketAddress(hostname, port), SOCKET_TIMEOUT_MILLIS)
        socket.setSoTimeout(SOCKET_TIMEOUT_MILLIS)
        SSLSocket sslSocket = (SSLSocket) ((SSLSocketFactory) SSLSocketFactory.getDefault())
            .createSocket(socket, hostname, port, true)
        logger
            .atInfo()
            .log("Opening connection for %s endpoint successful",
                hostname)
        return sslSocket
      } catch (Exception e) {
        try {
          socket.close()
        } catch (IOException suppressed) {
          e.addSuppressed(suppressed)
        }
        throw e
      }
    }

    /** Closes {@code conn}, logging rather than propagating any failure. */
    private static void closeQuietly(SSLSocket conn, String endpoint) {
      try {
        conn.close()
      } catch (IOException e) {
        logger.atWarning().withCause(e).log("Cannot close connection for %s endpoint", endpoint)
      }
    }
  }
}
