ssh: move all ssh logic to a common place

We had ssh logic sprinkled between two git modules, and neither was
quite the right home for it.  This largely moves the logic as-is to
its new home.  We'll leave major refactoring to followup commits.

Bug: https://crbug.com/gerrit/12389
Change-Id: I300a8f7dba74f2bd132232a5eb1e856a8490e0e9
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/305483
Reviewed-by: Chris Mcdonald <cjmcdonald@google.com>
Tested-by: Mike Frysinger <vapier@google.com>
diff --git a/git_command.py b/git_command.py
index f8cb280..fabad0e 100644
--- a/git_command.py
+++ b/git_command.py
@@ -14,16 +14,14 @@
 
 import functools
 import os
-import re
 import sys
 import subprocess
-import tempfile
-from signal import SIGTERM
 
 from error import GitError
 from git_refs import HEAD
 import platform_utils
 from repo_trace import REPO_TRACE, IsTrace, Trace
+import ssh
 from wrapper import Wrapper
 
 GIT = 'git'
@@ -43,85 +41,6 @@
 LAST_GITDIR = None
 LAST_CWD = None
 
-_ssh_proxy_path = None
-_ssh_sock_path = None
-_ssh_clients = []
-
-
-def _run_ssh_version():
-  """run ssh -V to display the version number"""
-  return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode()
-
-
-def _parse_ssh_version(ver_str=None):
-  """parse a ssh version string into a tuple"""
-  if ver_str is None:
-    ver_str = _run_ssh_version()
-  m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str)
-  if m:
-    return tuple(int(x) for x in m.group(1).split('.'))
-  else:
-    return ()
-
-
-@functools.lru_cache(maxsize=None)
-def ssh_version():
-  """return ssh version as a tuple"""
-  try:
-    return _parse_ssh_version()
-  except subprocess.CalledProcessError:
-    print('fatal: unable to detect ssh version', file=sys.stderr)
-    sys.exit(1)
-
-
-def ssh_sock(create=True):
-  global _ssh_sock_path
-  if _ssh_sock_path is None:
-    if not create:
-      return None
-    tmp_dir = '/tmp'
-    if not os.path.exists(tmp_dir):
-      tmp_dir = tempfile.gettempdir()
-    if ssh_version() < (6, 7):
-      tokens = '%r@%h:%p'
-    else:
-      tokens = '%C'  # hash of %l%h%p%r
-    _ssh_sock_path = os.path.join(
-        tempfile.mkdtemp('', 'ssh-', tmp_dir),
-        'master-' + tokens)
-  return _ssh_sock_path
-
-
-def _ssh_proxy():
-  global _ssh_proxy_path
-  if _ssh_proxy_path is None:
-    _ssh_proxy_path = os.path.join(
-        os.path.dirname(__file__),
-        'git_ssh')
-  return _ssh_proxy_path
-
-
-def _add_ssh_client(p):
-  _ssh_clients.append(p)
-
-
-def _remove_ssh_client(p):
-  try:
-    _ssh_clients.remove(p)
-  except ValueError:
-    pass
-
-
-def terminate_ssh_clients():
-  global _ssh_clients
-  for p in _ssh_clients:
-    try:
-      os.kill(p.pid, SIGTERM)
-      p.wait()
-    except OSError:
-      pass
-  _ssh_clients = []
-
 
 class _GitCall(object):
   @functools.lru_cache(maxsize=None)
@@ -256,8 +175,8 @@
     if disable_editor:
       env['GIT_EDITOR'] = ':'
     if ssh_proxy:
-      env['REPO_SSH_SOCK'] = ssh_sock()
-      env['GIT_SSH'] = _ssh_proxy()
+      env['REPO_SSH_SOCK'] = ssh.sock()
+      env['GIT_SSH'] = ssh.proxy()
       env['GIT_SSH_VARIANT'] = 'ssh'
     if 'http_proxy' in env and 'darwin' == sys.platform:
       s = "'http.proxy=%s'" % (env['http_proxy'],)
@@ -340,7 +259,7 @@
       raise GitError('%s: %s' % (command[1], e))
 
     if ssh_proxy:
-      _add_ssh_client(p)
+      ssh.add_client(p)
 
     self.process = p
     if input:
@@ -352,7 +271,7 @@
     try:
       self.stdout, self.stderr = p.communicate()
     finally:
-      _remove_ssh_client(p)
+      ssh.remove_client(p)
     self.rc = p.wait()
 
   @staticmethod
diff --git a/git_config.py b/git_config.py
index fcd0446..1d8d136 100644
--- a/git_config.py
+++ b/git_config.py
@@ -18,25 +18,17 @@
 import json
 import os
 import re
-import signal
 import ssl
 import subprocess
 import sys
-try:
-  import threading as _threading
-except ImportError:
-  import dummy_threading as _threading
-import time
 import urllib.error
 import urllib.request
 
 from error import GitError, UploadError
 import platform_utils
 from repo_trace import Trace
-
+import ssh
 from git_command import GitCommand
-from git_command import ssh_sock
-from git_command import terminate_ssh_clients
 from git_refs import R_CHANGES, R_HEADS, R_TAGS
 
 ID_RE = re.compile(r'^[0-9a-f]{40}$')
@@ -440,129 +432,6 @@
     return s
 
 
-_master_processes = []
-_master_keys = set()
-_ssh_master = True
-_master_keys_lock = None
-
-
-def init_ssh():
-  """Should be called once at the start of repo to init ssh master handling.
-
-  At the moment, all we do is to create our lock.
-  """
-  global _master_keys_lock
-  assert _master_keys_lock is None, "Should only call init_ssh once"
-  _master_keys_lock = _threading.Lock()
-
-
-def _open_ssh(host, port=None):
-  global _ssh_master
-
-  # Bail before grabbing the lock if we already know that we aren't going to
-  # try creating new masters below.
-  if sys.platform in ('win32', 'cygwin'):
-    return False
-
-  # Acquire the lock.  This is needed to prevent opening multiple masters for
-  # the same host when we're running "repo sync -jN" (for N > 1) _and_ the
-  # manifest <remote fetch="ssh://xyz"> specifies a different host from the
-  # one that was passed to repo init.
-  _master_keys_lock.acquire()
-  try:
-
-    # Check to see whether we already think that the master is running; if we
-    # think it's already running, return right away.
-    if port is not None:
-      key = '%s:%s' % (host, port)
-    else:
-      key = host
-
-    if key in _master_keys:
-      return True
-
-    if not _ssh_master or 'GIT_SSH' in os.environ:
-      # Failed earlier, so don't retry.
-      return False
-
-    # We will make two calls to ssh; this is the common part of both calls.
-    command_base = ['ssh',
-                    '-o', 'ControlPath %s' % ssh_sock(),
-                    host]
-    if port is not None:
-      command_base[1:1] = ['-p', str(port)]
-
-    # Since the key wasn't in _master_keys, we think that master isn't running.
-    # ...but before actually starting a master, we'll double-check.  This can
-    # be important because we can't tell that that 'git@myhost.com' is the same
-    # as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file.
-    check_command = command_base + ['-O', 'check']
-    try:
-      Trace(': %s', ' '.join(check_command))
-      check_process = subprocess.Popen(check_command,
-                                       stdout=subprocess.PIPE,
-                                       stderr=subprocess.PIPE)
-      check_process.communicate()  # read output, but ignore it...
-      isnt_running = check_process.wait()
-
-      if not isnt_running:
-        # Our double-check found that the master _was_ infact running.  Add to
-        # the list of keys.
-        _master_keys.add(key)
-        return True
-    except Exception:
-      # Ignore excpetions.  We we will fall back to the normal command and print
-      # to the log there.
-      pass
-
-    command = command_base[:1] + ['-M', '-N'] + command_base[1:]
-    try:
-      Trace(': %s', ' '.join(command))
-      p = subprocess.Popen(command)
-    except Exception as e:
-      _ssh_master = False
-      print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
-            % (host, port, str(e)), file=sys.stderr)
-      return False
-
-    time.sleep(1)
-    ssh_died = (p.poll() is not None)
-    if ssh_died:
-      return False
-
-    _master_processes.append(p)
-    _master_keys.add(key)
-    return True
-  finally:
-    _master_keys_lock.release()
-
-
-def close_ssh():
-  global _master_keys_lock
-
-  terminate_ssh_clients()
-
-  for p in _master_processes:
-    try:
-      os.kill(p.pid, signal.SIGTERM)
-      p.wait()
-    except OSError:
-      pass
-  del _master_processes[:]
-  _master_keys.clear()
-
-  d = ssh_sock(create=False)
-  if d:
-    try:
-      platform_utils.rmdir(os.path.dirname(d))
-    except OSError:
-      pass
-
-  # We're done with the lock, so we can delete it.
-  _master_keys_lock = None
-
-
-URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
 URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
 
 
@@ -614,27 +483,6 @@
   yield cookiefile, None
 
 
-def _preconnect(url):
-  m = URI_ALL.match(url)
-  if m:
-    scheme = m.group(1)
-    host = m.group(2)
-    if ':' in host:
-      host, port = host.split(':')
-    else:
-      port = None
-    if scheme in ('ssh', 'git+ssh', 'ssh+git'):
-      return _open_ssh(host, port)
-    return False
-
-  m = URI_SCP.match(url)
-  if m:
-    host = m.group(1)
-    return _open_ssh(host)
-
-  return False
-
-
 class Remote(object):
   """Configuration options related to a remote.
   """
@@ -673,7 +521,7 @@
 
   def PreConnectFetch(self):
     connectionUrl = self._InsteadOf()
-    return _preconnect(connectionUrl)
+    return ssh.preconnect(connectionUrl)
 
   def ReviewUrl(self, userEmail, validate_certs):
     if self._review_url is None:
diff --git a/main.py b/main.py
index 8aba2ec..9674433 100755
--- a/main.py
+++ b/main.py
@@ -39,7 +39,7 @@
 import event_log
 from repo_trace import SetTrace
 from git_command import user_agent
-from git_config import init_ssh, close_ssh, RepoConfig
+from git_config import RepoConfig
 from git_trace2_event_log import EventLog
 from command import InteractiveCommand
 from command import MirrorSafeCommand
@@ -56,6 +56,7 @@
 import gitc_utils
 from manifest_xml import GitcClient, RepoClient
 from pager import RunPager, TerminatePager
+import ssh
 from wrapper import WrapperPath, Wrapper
 
 from subcmds import all_commands
@@ -592,7 +593,7 @@
   repo = _Repo(opt.repodir)
   try:
     try:
-      init_ssh()
+      ssh.init()
       init_http()
       name, gopts, argv = repo._ParseArgs(argv)
       run = lambda: repo._Run(name, gopts, argv) or 0
@@ -604,7 +605,7 @@
       else:
         result = run()
     finally:
-      close_ssh()
+      ssh.close()
   except KeyboardInterrupt:
     print('aborted by user', file=sys.stderr)
     result = 1
diff --git a/ssh.py b/ssh.py
new file mode 100644
index 0000000..d06c4eb
--- /dev/null
+++ b/ssh.py
@@ -0,0 +1,257 @@
+# Copyright (C) 2008 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common SSH management logic."""
+
+import functools
+import os
+import re
+import signal
+import subprocess
+import sys
+import tempfile
+try:
+  import threading as _threading
+except ImportError:
+  import dummy_threading as _threading
+import time
+
+import platform_utils
+from repo_trace import Trace
+
+
+_ssh_proxy_path = None
+_ssh_sock_path = None
+_ssh_clients = []
+
+
+def _run_ssh_version():
+  """run ssh -V to display the version number"""
+  return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode()
+
+
+def _parse_ssh_version(ver_str=None):
+  """parse a ssh version string into a tuple"""
+  if ver_str is None:
+    ver_str = _run_ssh_version()
+  m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str)
+  if m:
+    return tuple(int(x) for x in m.group(1).split('.'))
+  else:
+    return ()
+
+
+@functools.lru_cache(maxsize=None)
+def version():
+  """return ssh version as a tuple"""
+  try:
+    return _parse_ssh_version()
+  except subprocess.CalledProcessError:
+    print('fatal: unable to detect ssh version', file=sys.stderr)
+    sys.exit(1)
+
+
+def proxy():
+  global _ssh_proxy_path
+  if _ssh_proxy_path is None:
+    _ssh_proxy_path = os.path.join(
+        os.path.dirname(__file__),
+        'git_ssh')
+  return _ssh_proxy_path
+
+
+def add_client(p):
+  _ssh_clients.append(p)
+
+
+def remove_client(p):
+  try:
+    _ssh_clients.remove(p)
+  except ValueError:
+    pass
+
+
+def _terminate_clients():
+  global _ssh_clients
+  for p in _ssh_clients:
+    try:
+      os.kill(p.pid, signal.SIGTERM)
+      p.wait()
+    except OSError:
+      pass
+  _ssh_clients = []
+
+
+_master_processes = []
+_master_keys = set()
+_ssh_master = True
+_master_keys_lock = None
+
+
+def init():
+  """Should be called once at the start of repo to init ssh master handling.
+
+  At the moment, all we do is to create our lock.
+  """
+  global _master_keys_lock
+  assert _master_keys_lock is None, "Should only call init once"
+  _master_keys_lock = _threading.Lock()
+
+
+def _open_ssh(host, port=None):
+  global _ssh_master
+
+  # Bail before grabbing the lock if we already know that we aren't going to
+  # try creating new masters below.
+  if sys.platform in ('win32', 'cygwin'):
+    return False
+
+  # Acquire the lock.  This is needed to prevent opening multiple masters for
+  # the same host when we're running "repo sync -jN" (for N > 1) _and_ the
+  # manifest <remote fetch="ssh://xyz"> specifies a different host from the
+  # one that was passed to repo init.
+  _master_keys_lock.acquire()
+  try:
+
+    # Check to see whether we already think that the master is running; if we
+    # think it's already running, return right away.
+    if port is not None:
+      key = '%s:%s' % (host, port)
+    else:
+      key = host
+
+    if key in _master_keys:
+      return True
+
+    if not _ssh_master or 'GIT_SSH' in os.environ:
+      # A master failed to start earlier, or GIT_SSH is set; don't start one.
+      return False
+
+    # We will make two calls to ssh; this is the common part of both calls.
+    command_base = ['ssh',
+                    '-o', 'ControlPath %s' % sock(),
+                    host]
+    if port is not None:
+      command_base[1:1] = ['-p', str(port)]
+
+    # Since the key wasn't in _master_keys, we think that master isn't running.
+    # ...but before actually starting a master, we'll double-check.  This can
+    # be important because we can't tell that 'git@myhost.com' is the same
+    # as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file.
+    check_command = command_base + ['-O', 'check']
+    try:
+      Trace(': %s', ' '.join(check_command))
+      check_process = subprocess.Popen(check_command,
+                                       stdout=subprocess.PIPE,
+                                       stderr=subprocess.PIPE)
+      check_process.communicate()  # read output, but ignore it...
+      isnt_running = check_process.wait()
+
+      if not isnt_running:
+        # Our double-check found that the master _was_ in fact running.  Add to
+        # the list of keys.
+        _master_keys.add(key)
+        return True
+    except Exception:
+      # Ignore exceptions.  We will fall back to the normal command and print
+      # to the log there.
+      pass
+
+    command = command_base[:1] + ['-M', '-N'] + command_base[1:]
+    try:
+      Trace(': %s', ' '.join(command))
+      p = subprocess.Popen(command)
+    except Exception as e:
+      _ssh_master = False
+      print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
+            % (host, port, str(e)), file=sys.stderr)
+      return False
+
+    time.sleep(1)
+    ssh_died = (p.poll() is not None)
+    if ssh_died:
+      return False
+
+    _master_processes.append(p)
+    _master_keys.add(key)
+    return True
+  finally:
+    _master_keys_lock.release()
+
+
+def close():
+  global _master_keys_lock
+
+  _terminate_clients()
+
+  for p in _master_processes:
+    try:
+      os.kill(p.pid, signal.SIGTERM)
+      p.wait()
+    except OSError:
+      pass
+  del _master_processes[:]
+  _master_keys.clear()
+
+  d = sock(create=False)
+  if d:
+    try:
+      platform_utils.rmdir(os.path.dirname(d))
+    except OSError:
+      pass
+
+  # We're done with the lock, so we can delete it.
+  _master_keys_lock = None
+
+
+URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
+URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
+
+
+def preconnect(url):
+  m = URI_ALL.match(url)
+  if m:
+    scheme = m.group(1)
+    host = m.group(2)
+    if ':' in host:
+      host, port = host.split(':')
+    else:
+      port = None
+    if scheme in ('ssh', 'git+ssh', 'ssh+git'):
+      return _open_ssh(host, port)
+    return False
+
+  m = URI_SCP.match(url)
+  if m:
+    host = m.group(1)
+    return _open_ssh(host)
+
+  return False
+
+def sock(create=True):
+  global _ssh_sock_path
+  if _ssh_sock_path is None:
+    if not create:
+      return None
+    tmp_dir = '/tmp'
+    if not os.path.exists(tmp_dir):
+      tmp_dir = tempfile.gettempdir()
+    if version() < (6, 7):
+      tokens = '%r@%h:%p'
+    else:
+      tokens = '%C'  # hash of %l%h%p%r
+    _ssh_sock_path = os.path.join(
+        tempfile.mkdtemp('', 'ssh-', tmp_dir),
+        'master-' + tokens)
+  return _ssh_sock_path
diff --git a/tests/test_git_command.py b/tests/test_git_command.py
index 76c092f..93300a6 100644
--- a/tests/test_git_command.py
+++ b/tests/test_git_command.py
@@ -26,38 +26,6 @@
 import wrapper
 
 
-class SSHUnitTest(unittest.TestCase):
-  """Tests the ssh functions."""
-
-  def test_parse_ssh_version(self):
-    """Check parse_ssh_version() handling."""
-    ver = git_command._parse_ssh_version('Unknown\n')
-    self.assertEqual(ver, ())
-    ver = git_command._parse_ssh_version('OpenSSH_1.0\n')
-    self.assertEqual(ver, (1, 0))
-    ver = git_command._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n')
-    self.assertEqual(ver, (6, 6, 1))
-    ver = git_command._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n  7 Dec 2017\n')
-    self.assertEqual(ver, (7, 6))
-
-  def test_ssh_version(self):
-    """Check ssh_version() handling."""
-    with mock.patch('git_command._run_ssh_version', return_value='OpenSSH_1.2\n'):
-      self.assertEqual(git_command.ssh_version(), (1, 2))
-
-  def test_ssh_sock(self):
-    """Check ssh_sock() function."""
-    with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
-      # old ssh version uses port
-      with mock.patch('git_command.ssh_version', return_value=(6, 6)):
-        self.assertTrue(git_command.ssh_sock().endswith('%p'))
-      git_command._ssh_sock_path = None
-      # new ssh version uses hash
-      with mock.patch('git_command.ssh_version', return_value=(6, 7)):
-        self.assertTrue(git_command.ssh_sock().endswith('%C'))
-      git_command._ssh_sock_path = None
-
-
 class GitCallUnitTest(unittest.TestCase):
   """Tests the _GitCall class (via git_command.git)."""
 
diff --git a/tests/test_ssh.py b/tests/test_ssh.py
new file mode 100644
index 0000000..5a4f27e
--- /dev/null
+++ b/tests/test_ssh.py
@@ -0,0 +1,52 @@
+# Copyright 2019 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unittests for the ssh.py module."""
+
+import unittest
+from unittest import mock
+
+import ssh
+
+
+class SshTests(unittest.TestCase):
+  """Tests the ssh functions."""
+
+  def test_parse_ssh_version(self):
+    """Check _parse_ssh_version() handling."""
+    ver = ssh._parse_ssh_version('Unknown\n')
+    self.assertEqual(ver, ())
+    ver = ssh._parse_ssh_version('OpenSSH_1.0\n')
+    self.assertEqual(ver, (1, 0))
+    ver = ssh._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n')
+    self.assertEqual(ver, (6, 6, 1))
+    ver = ssh._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n  7 Dec 2017\n')
+    self.assertEqual(ver, (7, 6))
+
+  def test_version(self):
+    """Check version() handling."""
+    with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'):
+      self.assertEqual(ssh.version(), (1, 2))
+
+  def test_ssh_sock(self):
+    """Check sock() function."""
+    with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
+      # old ssh version uses port
+      with mock.patch('ssh.version', return_value=(6, 6)):
+        self.assertTrue(ssh.sock().endswith('%p'))
+      ssh._ssh_sock_path = None
+      # new ssh version uses hash
+      with mock.patch('ssh.version', return_value=(6, 7)):
+        self.assertTrue(ssh.sock().endswith('%C'))
+      ssh._ssh_sock_path = None