ssh: move all ssh logic to a common place
We had ssh logic sprinkled between two git modules, and neither was
quite the right home for it. This largely moves the logic as-is to
its new home. We'll leave major refactoring to followup commits.
Bug: https://crbug.com/gerrit/12389
Change-Id: I300a8f7dba74f2bd132232a5eb1e856a8490e0e9
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/305483
Reviewed-by: Chris Mcdonald <cjmcdonald@google.com>
Tested-by: Mike Frysinger <vapier@google.com>
diff --git a/git_command.py b/git_command.py
index f8cb280..fabad0e 100644
--- a/git_command.py
+++ b/git_command.py
@@ -14,16 +14,14 @@
import functools
import os
-import re
import sys
import subprocess
-import tempfile
-from signal import SIGTERM
from error import GitError
from git_refs import HEAD
import platform_utils
from repo_trace import REPO_TRACE, IsTrace, Trace
+import ssh
from wrapper import Wrapper
GIT = 'git'
@@ -43,85 +41,6 @@
LAST_GITDIR = None
LAST_CWD = None
-_ssh_proxy_path = None
-_ssh_sock_path = None
-_ssh_clients = []
-
-
-def _run_ssh_version():
- """run ssh -V to display the version number"""
- return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode()
-
-
-def _parse_ssh_version(ver_str=None):
- """parse a ssh version string into a tuple"""
- if ver_str is None:
- ver_str = _run_ssh_version()
- m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str)
- if m:
- return tuple(int(x) for x in m.group(1).split('.'))
- else:
- return ()
-
-
-@functools.lru_cache(maxsize=None)
-def ssh_version():
- """return ssh version as a tuple"""
- try:
- return _parse_ssh_version()
- except subprocess.CalledProcessError:
- print('fatal: unable to detect ssh version', file=sys.stderr)
- sys.exit(1)
-
-
-def ssh_sock(create=True):
- global _ssh_sock_path
- if _ssh_sock_path is None:
- if not create:
- return None
- tmp_dir = '/tmp'
- if not os.path.exists(tmp_dir):
- tmp_dir = tempfile.gettempdir()
- if ssh_version() < (6, 7):
- tokens = '%r@%h:%p'
- else:
- tokens = '%C' # hash of %l%h%p%r
- _ssh_sock_path = os.path.join(
- tempfile.mkdtemp('', 'ssh-', tmp_dir),
- 'master-' + tokens)
- return _ssh_sock_path
-
-
-def _ssh_proxy():
- global _ssh_proxy_path
- if _ssh_proxy_path is None:
- _ssh_proxy_path = os.path.join(
- os.path.dirname(__file__),
- 'git_ssh')
- return _ssh_proxy_path
-
-
-def _add_ssh_client(p):
- _ssh_clients.append(p)
-
-
-def _remove_ssh_client(p):
- try:
- _ssh_clients.remove(p)
- except ValueError:
- pass
-
-
-def terminate_ssh_clients():
- global _ssh_clients
- for p in _ssh_clients:
- try:
- os.kill(p.pid, SIGTERM)
- p.wait()
- except OSError:
- pass
- _ssh_clients = []
-
class _GitCall(object):
@functools.lru_cache(maxsize=None)
@@ -256,8 +175,8 @@
if disable_editor:
env['GIT_EDITOR'] = ':'
if ssh_proxy:
- env['REPO_SSH_SOCK'] = ssh_sock()
- env['GIT_SSH'] = _ssh_proxy()
+ env['REPO_SSH_SOCK'] = ssh.sock()
+ env['GIT_SSH'] = ssh.proxy()
env['GIT_SSH_VARIANT'] = 'ssh'
if 'http_proxy' in env and 'darwin' == sys.platform:
s = "'http.proxy=%s'" % (env['http_proxy'],)
@@ -340,7 +259,7 @@
raise GitError('%s: %s' % (command[1], e))
if ssh_proxy:
- _add_ssh_client(p)
+ ssh.add_client(p)
self.process = p
if input:
@@ -352,7 +271,7 @@
try:
self.stdout, self.stderr = p.communicate()
finally:
- _remove_ssh_client(p)
+ ssh.remove_client(p)
self.rc = p.wait()
@staticmethod
diff --git a/git_config.py b/git_config.py
index fcd0446..1d8d136 100644
--- a/git_config.py
+++ b/git_config.py
@@ -18,25 +18,17 @@
import json
import os
import re
-import signal
import ssl
import subprocess
import sys
-try:
- import threading as _threading
-except ImportError:
- import dummy_threading as _threading
-import time
import urllib.error
import urllib.request
from error import GitError, UploadError
import platform_utils
from repo_trace import Trace
-
+import ssh
from git_command import GitCommand
-from git_command import ssh_sock
-from git_command import terminate_ssh_clients
from git_refs import R_CHANGES, R_HEADS, R_TAGS
ID_RE = re.compile(r'^[0-9a-f]{40}$')
@@ -440,129 +432,6 @@
return s
-_master_processes = []
-_master_keys = set()
-_ssh_master = True
-_master_keys_lock = None
-
-
-def init_ssh():
- """Should be called once at the start of repo to init ssh master handling.
-
- At the moment, all we do is to create our lock.
- """
- global _master_keys_lock
- assert _master_keys_lock is None, "Should only call init_ssh once"
- _master_keys_lock = _threading.Lock()
-
-
-def _open_ssh(host, port=None):
- global _ssh_master
-
- # Bail before grabbing the lock if we already know that we aren't going to
- # try creating new masters below.
- if sys.platform in ('win32', 'cygwin'):
- return False
-
- # Acquire the lock. This is needed to prevent opening multiple masters for
- # the same host when we're running "repo sync -jN" (for N > 1) _and_ the
- # manifest <remote fetch="ssh://xyz"> specifies a different host from the
- # one that was passed to repo init.
- _master_keys_lock.acquire()
- try:
-
- # Check to see whether we already think that the master is running; if we
- # think it's already running, return right away.
- if port is not None:
- key = '%s:%s' % (host, port)
- else:
- key = host
-
- if key in _master_keys:
- return True
-
- if not _ssh_master or 'GIT_SSH' in os.environ:
- # Failed earlier, so don't retry.
- return False
-
- # We will make two calls to ssh; this is the common part of both calls.
- command_base = ['ssh',
- '-o', 'ControlPath %s' % ssh_sock(),
- host]
- if port is not None:
- command_base[1:1] = ['-p', str(port)]
-
- # Since the key wasn't in _master_keys, we think that master isn't running.
- # ...but before actually starting a master, we'll double-check. This can
- # be important because we can't tell that that 'git@myhost.com' is the same
- # as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file.
- check_command = command_base + ['-O', 'check']
- try:
- Trace(': %s', ' '.join(check_command))
- check_process = subprocess.Popen(check_command,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE)
- check_process.communicate() # read output, but ignore it...
- isnt_running = check_process.wait()
-
- if not isnt_running:
- # Our double-check found that the master _was_ infact running. Add to
- # the list of keys.
- _master_keys.add(key)
- return True
- except Exception:
- # Ignore excpetions. We we will fall back to the normal command and print
- # to the log there.
- pass
-
- command = command_base[:1] + ['-M', '-N'] + command_base[1:]
- try:
- Trace(': %s', ' '.join(command))
- p = subprocess.Popen(command)
- except Exception as e:
- _ssh_master = False
- print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
- % (host, port, str(e)), file=sys.stderr)
- return False
-
- time.sleep(1)
- ssh_died = (p.poll() is not None)
- if ssh_died:
- return False
-
- _master_processes.append(p)
- _master_keys.add(key)
- return True
- finally:
- _master_keys_lock.release()
-
-
-def close_ssh():
- global _master_keys_lock
-
- terminate_ssh_clients()
-
- for p in _master_processes:
- try:
- os.kill(p.pid, signal.SIGTERM)
- p.wait()
- except OSError:
- pass
- del _master_processes[:]
- _master_keys.clear()
-
- d = ssh_sock(create=False)
- if d:
- try:
- platform_utils.rmdir(os.path.dirname(d))
- except OSError:
- pass
-
- # We're done with the lock, so we can delete it.
- _master_keys_lock = None
-
-
-URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
@@ -614,27 +483,6 @@
yield cookiefile, None
-def _preconnect(url):
- m = URI_ALL.match(url)
- if m:
- scheme = m.group(1)
- host = m.group(2)
- if ':' in host:
- host, port = host.split(':')
- else:
- port = None
- if scheme in ('ssh', 'git+ssh', 'ssh+git'):
- return _open_ssh(host, port)
- return False
-
- m = URI_SCP.match(url)
- if m:
- host = m.group(1)
- return _open_ssh(host)
-
- return False
-
-
class Remote(object):
"""Configuration options related to a remote.
"""
@@ -673,7 +521,7 @@
def PreConnectFetch(self):
connectionUrl = self._InsteadOf()
- return _preconnect(connectionUrl)
+ return ssh.preconnect(connectionUrl)
def ReviewUrl(self, userEmail, validate_certs):
if self._review_url is None:
diff --git a/main.py b/main.py
index 8aba2ec..9674433 100755
--- a/main.py
+++ b/main.py
@@ -39,7 +39,7 @@
import event_log
from repo_trace import SetTrace
from git_command import user_agent
-from git_config import init_ssh, close_ssh, RepoConfig
+from git_config import RepoConfig
from git_trace2_event_log import EventLog
from command import InteractiveCommand
from command import MirrorSafeCommand
@@ -56,6 +56,7 @@
import gitc_utils
from manifest_xml import GitcClient, RepoClient
from pager import RunPager, TerminatePager
+import ssh
from wrapper import WrapperPath, Wrapper
from subcmds import all_commands
@@ -592,7 +593,7 @@
repo = _Repo(opt.repodir)
try:
try:
- init_ssh()
+ ssh.init()
init_http()
name, gopts, argv = repo._ParseArgs(argv)
run = lambda: repo._Run(name, gopts, argv) or 0
@@ -604,7 +605,7 @@
else:
result = run()
finally:
- close_ssh()
+ ssh.close()
except KeyboardInterrupt:
print('aborted by user', file=sys.stderr)
result = 1
diff --git a/ssh.py b/ssh.py
new file mode 100644
index 0000000..d06c4eb
--- /dev/null
+++ b/ssh.py
@@ -0,0 +1,257 @@
+# Copyright (C) 2008 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common SSH management logic."""
+
+import functools
+import os
+import re
+import signal
+import subprocess
+import sys
+import tempfile
+try:
+ import threading as _threading
+except ImportError:
+ import dummy_threading as _threading
+import time
+
+import platform_utils
+from repo_trace import Trace
+
+
+_ssh_proxy_path = None
+_ssh_sock_path = None
+_ssh_clients = []
+
+
+def _run_ssh_version():
+ """run ssh -V to display the version number"""
+ return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode()
+
+
+def _parse_ssh_version(ver_str=None):
+ """parse a ssh version string into a tuple"""
+ if ver_str is None:
+ ver_str = _run_ssh_version()
+ m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str)
+ if m:
+ return tuple(int(x) for x in m.group(1).split('.'))
+ else:
+ return ()
+
+
+@functools.lru_cache(maxsize=None)
+def version():
+ """return ssh version as a tuple"""
+ try:
+ return _parse_ssh_version()
+ except subprocess.CalledProcessError:
+ print('fatal: unable to detect ssh version', file=sys.stderr)
+ sys.exit(1)
+
+
+def proxy():
+ global _ssh_proxy_path
+ if _ssh_proxy_path is None:
+ _ssh_proxy_path = os.path.join(
+ os.path.dirname(__file__),
+ 'git_ssh')
+ return _ssh_proxy_path
+
+
+def add_client(p):
+ _ssh_clients.append(p)
+
+
+def remove_client(p):
+ try:
+ _ssh_clients.remove(p)
+ except ValueError:
+ pass
+
+
+def _terminate_clients():
+ global _ssh_clients
+ for p in _ssh_clients:
+ try:
+ os.kill(p.pid, signal.SIGTERM)
+ p.wait()
+ except OSError:
+ pass
+ _ssh_clients = []
+
+
+_master_processes = []
+_master_keys = set()
+_ssh_master = True
+_master_keys_lock = None
+
+
+def init():
+ """Should be called once at the start of repo to init ssh master handling.
+
+ At the moment, all we do is to create our lock.
+ """
+ global _master_keys_lock
+ assert _master_keys_lock is None, "Should only call init once"
+ _master_keys_lock = _threading.Lock()
+
+
+def _open_ssh(host, port=None):
+ global _ssh_master
+
+ # Bail before grabbing the lock if we already know that we aren't going to
+ # try creating new masters below.
+ if sys.platform in ('win32', 'cygwin'):
+ return False
+
+ # Acquire the lock. This is needed to prevent opening multiple masters for
+ # the same host when we're running "repo sync -jN" (for N > 1) _and_ the
+ # manifest <remote fetch="ssh://xyz"> specifies a different host from the
+ # one that was passed to repo init.
+ _master_keys_lock.acquire()
+ try:
+
+ # Check to see whether we already think that the master is running; if we
+ # think it's already running, return right away.
+ if port is not None:
+ key = '%s:%s' % (host, port)
+ else:
+ key = host
+
+ if key in _master_keys:
+ return True
+
+ if not _ssh_master or 'GIT_SSH' in os.environ:
+ # Failed earlier, so don't retry.
+ return False
+
+ # We will make two calls to ssh; this is the common part of both calls.
+ command_base = ['ssh',
+ '-o', 'ControlPath %s' % sock(),
+ host]
+ if port is not None:
+ command_base[1:1] = ['-p', str(port)]
+
+ # Since the key wasn't in _master_keys, we think that master isn't running.
+ # ...but before actually starting a master, we'll double-check. This can
+ # be important because we can't tell that that 'git@myhost.com' is the same
+ # as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file.
+ check_command = command_base + ['-O', 'check']
+ try:
+ Trace(': %s', ' '.join(check_command))
+ check_process = subprocess.Popen(check_command,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ check_process.communicate() # read output, but ignore it...
+ isnt_running = check_process.wait()
+
+ if not isnt_running:
+ # Our double-check found that the master _was_ infact running. Add to
+ # the list of keys.
+ _master_keys.add(key)
+ return True
+ except Exception:
+ # Ignore excpetions. We we will fall back to the normal command and print
+ # to the log there.
+ pass
+
+ command = command_base[:1] + ['-M', '-N'] + command_base[1:]
+ try:
+ Trace(': %s', ' '.join(command))
+ p = subprocess.Popen(command)
+ except Exception as e:
+ _ssh_master = False
+ print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
+ % (host, port, str(e)), file=sys.stderr)
+ return False
+
+ time.sleep(1)
+ ssh_died = (p.poll() is not None)
+ if ssh_died:
+ return False
+
+ _master_processes.append(p)
+ _master_keys.add(key)
+ return True
+ finally:
+ _master_keys_lock.release()
+
+
+def close():
+ global _master_keys_lock
+
+ _terminate_clients()
+
+ for p in _master_processes:
+ try:
+ os.kill(p.pid, signal.SIGTERM)
+ p.wait()
+ except OSError:
+ pass
+ del _master_processes[:]
+ _master_keys.clear()
+
+ d = sock(create=False)
+ if d:
+ try:
+ platform_utils.rmdir(os.path.dirname(d))
+ except OSError:
+ pass
+
+ # We're done with the lock, so we can delete it.
+ _master_keys_lock = None
+
+
+URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
+URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
+
+
+def preconnect(url):
+ m = URI_ALL.match(url)
+ if m:
+ scheme = m.group(1)
+ host = m.group(2)
+ if ':' in host:
+ host, port = host.split(':')
+ else:
+ port = None
+ if scheme in ('ssh', 'git+ssh', 'ssh+git'):
+ return _open_ssh(host, port)
+ return False
+
+ m = URI_SCP.match(url)
+ if m:
+ host = m.group(1)
+ return _open_ssh(host)
+
+ return False
+
+def sock(create=True):
+ global _ssh_sock_path
+ if _ssh_sock_path is None:
+ if not create:
+ return None
+ tmp_dir = '/tmp'
+ if not os.path.exists(tmp_dir):
+ tmp_dir = tempfile.gettempdir()
+ if version() < (6, 7):
+ tokens = '%r@%h:%p'
+ else:
+ tokens = '%C' # hash of %l%h%p%r
+ _ssh_sock_path = os.path.join(
+ tempfile.mkdtemp('', 'ssh-', tmp_dir),
+ 'master-' + tokens)
+ return _ssh_sock_path
diff --git a/tests/test_git_command.py b/tests/test_git_command.py
index 76c092f..93300a6 100644
--- a/tests/test_git_command.py
+++ b/tests/test_git_command.py
@@ -26,38 +26,6 @@
import wrapper
-class SSHUnitTest(unittest.TestCase):
- """Tests the ssh functions."""
-
- def test_parse_ssh_version(self):
- """Check parse_ssh_version() handling."""
- ver = git_command._parse_ssh_version('Unknown\n')
- self.assertEqual(ver, ())
- ver = git_command._parse_ssh_version('OpenSSH_1.0\n')
- self.assertEqual(ver, (1, 0))
- ver = git_command._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n')
- self.assertEqual(ver, (6, 6, 1))
- ver = git_command._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n')
- self.assertEqual(ver, (7, 6))
-
- def test_ssh_version(self):
- """Check ssh_version() handling."""
- with mock.patch('git_command._run_ssh_version', return_value='OpenSSH_1.2\n'):
- self.assertEqual(git_command.ssh_version(), (1, 2))
-
- def test_ssh_sock(self):
- """Check ssh_sock() function."""
- with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
- # old ssh version uses port
- with mock.patch('git_command.ssh_version', return_value=(6, 6)):
- self.assertTrue(git_command.ssh_sock().endswith('%p'))
- git_command._ssh_sock_path = None
- # new ssh version uses hash
- with mock.patch('git_command.ssh_version', return_value=(6, 7)):
- self.assertTrue(git_command.ssh_sock().endswith('%C'))
- git_command._ssh_sock_path = None
-
-
class GitCallUnitTest(unittest.TestCase):
"""Tests the _GitCall class (via git_command.git)."""
diff --git a/tests/test_ssh.py b/tests/test_ssh.py
new file mode 100644
index 0000000..5a4f27e
--- /dev/null
+++ b/tests/test_ssh.py
@@ -0,0 +1,52 @@
+# Copyright 2019 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unittests for the ssh.py module."""
+
+import unittest
+from unittest import mock
+
+import ssh
+
+
+class SshTests(unittest.TestCase):
+ """Tests the ssh functions."""
+
+ def test_parse_ssh_version(self):
+ """Check _parse_ssh_version() handling."""
+ ver = ssh._parse_ssh_version('Unknown\n')
+ self.assertEqual(ver, ())
+ ver = ssh._parse_ssh_version('OpenSSH_1.0\n')
+ self.assertEqual(ver, (1, 0))
+ ver = ssh._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n')
+ self.assertEqual(ver, (6, 6, 1))
+ ver = ssh._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n')
+ self.assertEqual(ver, (7, 6))
+
+ def test_version(self):
+ """Check version() handling."""
+ with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'):
+ self.assertEqual(ssh.version(), (1, 2))
+
+ def test_ssh_sock(self):
+ """Check sock() function."""
+ with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
+ # old ssh version uses port
+ with mock.patch('ssh.version', return_value=(6, 6)):
+ self.assertTrue(ssh.sock().endswith('%p'))
+ ssh._ssh_sock_path = None
+ # new ssh version uses hash
+ with mock.patch('ssh.version', return_value=(6, 7)):
+ self.assertTrue(ssh.sock().endswith('%C'))
+ ssh._ssh_sock_path = None