Merge "Handle shallow checkout of SHA1 pinned repos"
diff --git a/color.py b/color.py
index 7970198..b279928 100644
--- a/color.py
+++ b/color.py
@@ -83,15 +83,38 @@
   return code
 
 
+DEFAULT = None
+
+def SetDefaultColoring(state):
+  """Override config-based coloring with |state| (e.g. from --color).
+
+  None or an unrecognized value leaves DEFAULT unchanged.
+  """
+  global DEFAULT
+  if state is None:
+    return  # Leave it alone -- return quick!
+  if isinstance(state, bool):  # Booleans have no .lower(); map to words.
+    state = 'true' if state else 'false'
+  state = state.lower()
+  if state == 'auto':
+    DEFAULT = 'auto'
+  elif state in ('always', 'yes', 'true'):
+    DEFAULT = 'always'
+  elif state in ('never', 'no', 'false'):
+    DEFAULT = 'never'
+
+
 class Coloring(object):
   def __init__(self, config, section_type):
     self._section = 'color.%s' % section_type
     self._config = config
     self._out = sys.stdout
 
-    on = self._config.GetString(self._section)
+    on = DEFAULT
     if on is None:
-      on = self._config.GetString('color.ui')
+      on = self._config.GetString(self._section)
+      if on is None:
+        on = self._config.GetString('color.ui')
 
     if on == 'auto':
       if pager.active or os.isatty(1):
diff --git a/main.py b/main.py
index 72fb39b..47f083d 100755
--- a/main.py
+++ b/main.py
@@ -36,6 +36,7 @@
 except ImportError:
   kerberos = None
 
+from color import SetDefaultColoring
 from trace import SetTrace
 from git_command import git, GitCommand
 from git_config import init_ssh, close_ssh
@@ -69,6 +70,9 @@
 global_options.add_option('--no-pager',
                           dest='no_pager', action='store_true',
                           help='disable the pager')
+global_options.add_option('--color',
+                          choices=('auto', 'always', 'never'), default=None,
+                          help='control color usage: auto, always, never')
 global_options.add_option('--trace',
                           dest='trace', action='store_true',
                           help='trace git command execution')
@@ -113,6 +117,8 @@
         print('fatal: invalid usage of --version', file=sys.stderr)
         return 1
 
+    SetDefaultColoring(gopts.color)
+
     try:
       cmd = self.commands[name]
     except KeyError:
diff --git a/project.py b/project.py
index b9a53dc..49db02e 100644
--- a/project.py
+++ b/project.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from __future__ import print_function
-import traceback
+import contextlib
 import errno
 import filecmp
 import os
@@ -26,6 +26,7 @@
 import tarfile
 import tempfile
 import time
+import traceback
 
 from color import Coloring
 from git_command import GitCommand, git_require
@@ -84,7 +85,7 @@
   global _project_hook_list
   if _project_hook_list is None:
     d = os.path.realpath(os.path.abspath(os.path.dirname(__file__)))
-    d = os.path.join(d , 'hooks')
+    d = os.path.join(d, 'hooks')
     _project_hook_list = [os.path.join(d, x) for x in os.listdir(d)]
   return _project_hook_list
 
@@ -182,28 +183,28 @@
 class StatusColoring(Coloring):
   def __init__(self, config):
     Coloring.__init__(self, config, 'status')
-    self.project   = self.printer('header',    attr = 'bold')
-    self.branch    = self.printer('header',    attr = 'bold')
-    self.nobranch  = self.printer('nobranch',  fg = 'red')
-    self.important = self.printer('important', fg = 'red')
+    self.project = self.printer('header', attr='bold')
+    self.branch = self.printer('header', attr='bold')
+    self.nobranch = self.printer('nobranch', fg='red')
+    self.important = self.printer('important', fg='red')
 
-    self.added     = self.printer('added',     fg = 'green')
-    self.changed   = self.printer('changed',   fg = 'red')
-    self.untracked = self.printer('untracked', fg = 'red')
+    self.added = self.printer('added', fg='green')
+    self.changed = self.printer('changed', fg='red')
+    self.untracked = self.printer('untracked', fg='red')
 
 
 class DiffColoring(Coloring):
   def __init__(self, config):
     Coloring.__init__(self, config, 'diff')
-    self.project   = self.printer('header',    attr = 'bold')
+    self.project = self.printer('header', attr='bold')
 
-class _Annotation:
+class _Annotation(object):
   def __init__(self, name, value, keep):
     self.name = name
     self.value = value
     self.keep = keep
 
-class _CopyFile:
+class _CopyFile(object):
   def __init__(self, src, dest, abssrc, absdest):
     self.src = src
     self.dest = dest
@@ -231,7 +232,7 @@
       except IOError:
         _error('Cannot copy file %s to %s', src, dest)
 
-class _LinkFile:
+class _LinkFile(object):
   def __init__(self, src, dest, abssrc, absdest):
     self.src = src
     self.dest = dest
@@ -258,9 +259,9 @@
 class RemoteSpec(object):
   def __init__(self,
                name,
-               url = None,
-               review = None,
-               revision = None):
+               url=None,
+               review=None,
+               revision=None):
     self.name = name
     self.url = url
     self.review = review
@@ -520,15 +521,15 @@
                relpath,
                revisionExpr,
                revisionId,
-               rebase = True,
-               groups = None,
-               sync_c = False,
-               sync_s = False,
-               clone_depth = None,
-               upstream = None,
-               parent = None,
-               is_derived = False,
-               dest_branch = None):
+               rebase=True,
+               groups=None,
+               sync_c=False,
+               sync_s=False,
+               clone_depth=None,
+               upstream=None,
+               parent=None,
+               is_derived=False,
+               dest_branch=None):
     """Init a Project object.
 
     Args:
@@ -585,8 +586,8 @@
     self.linkfiles = []
     self.annotations = []
     self.config = GitConfig.ForRepository(
-                    gitdir = self.gitdir,
-                    defaults =  self.manifest.globalConfig)
+                    gitdir=self.gitdir,
+                    defaults=self.manifest.globalConfig)
 
     if self.worktree:
       self.work_git = self._GitGetByExec(self, bare=False, gitdir=gitdir)
@@ -879,8 +880,8 @@
     cmd.append('--')
     p = GitCommand(self,
                    cmd,
-                   capture_stdout = True,
-                   capture_stderr = True)
+                   capture_stdout=True,
+                   capture_stderr=True)
     has_diff = False
     for line in p.process.stdout:
       if not has_diff:
@@ -965,7 +966,7 @@
     return None
 
   def UploadForReview(self, branch=None,
-                      people=([],[]),
+                      people=([], []),
                       auto_topic=False,
                       draft=False,
                       dest_branch=None):
@@ -1026,13 +1027,13 @@
         ref_spec = ref_spec + '%' + ','.join(rp)
     cmd.append(ref_spec)
 
-    if GitCommand(self, cmd, bare = True).Wait() != 0:
+    if GitCommand(self, cmd, bare=True).Wait() != 0:
       raise UploadError('Upload failed')
 
     msg = "posted to %s for %s" % (branch.remote.review, dest_branch)
     self.bare_git.UpdateRef(R_PUB + branch.name,
                             R_HEADS + branch.name,
-                            message = msg)
+                            message=msg)
 
 
 ## Sync ##
@@ -1133,7 +1134,7 @@
         and not self._RemoteFetch(initial=is_new, quiet=quiet, alt_dir=alt_dir,
                                   current_branch_only=current_branch_only,
                                   no_tags=no_tags)):
-          return False
+      return False
 
     if self.worktree:
       self._InitMRef()
@@ -1329,7 +1330,7 @@
 
     if cnt_mine > 0 and self.rebase:
       def _dorebase():
-        self._Rebase(upstream = '%s^1' % last_mine, onto = revid)
+        self._Rebase(upstream='%s^1' % last_mine, onto=revid)
         self._CopyAndLinkFiles()
       syncbuf.later2(self, _dorebase)
     elif local_changes:
@@ -1384,11 +1385,11 @@
       return True
 
     all_refs = self.bare_ref.all
-    if (R_HEADS + name) in all_refs:
+    if R_HEADS + name in all_refs:
       return GitCommand(self,
                         ['checkout', name, '--'],
-                        capture_stdout = True,
-                        capture_stderr = True).Wait() == 0
+                        capture_stdout=True,
+                        capture_stderr=True).Wait() == 0
 
     branch = self.GetBranch(name)
     branch.remote = self.GetRemote(self.remote.name)
@@ -1415,8 +1416,8 @@
 
     if GitCommand(self,
                   ['checkout', '-b', branch.name, revid],
-                  capture_stdout = True,
-                  capture_stderr = True).Wait() == 0:
+                  capture_stdout=True,
+                  capture_stderr=True).Wait() == 0:
       branch.Save()
       return True
     return False
@@ -1462,8 +1463,8 @@
 
     return GitCommand(self,
                       ['checkout', name, '--'],
-                      capture_stdout = True,
-                      capture_stderr = True).Wait() == 0
+                      capture_stdout=True,
+                      capture_stderr=True).Wait() == 0
 
   def AbandonBranch(self, name):
     """Destroy a local topic branch.
@@ -1497,8 +1498,8 @@
 
     return GitCommand(self,
                       ['branch', '-D', name],
-                      capture_stdout = True,
-                      capture_stderr = True).Wait() == 0
+                      capture_stdout=True,
+                      capture_stderr=True).Wait() == 0
 
   def PruneHeads(self):
     """Prune any topic branches already merged into upstream.
@@ -1515,7 +1516,7 @@
     rev = self.GetRevisionId(left)
     if cb is not None \
        and not self._revlist(HEAD + '...' + rev) \
-       and not self.IsDirty(consider_untracked = False):
+       and not self.IsDirty(consider_untracked=False):
       self.work_git.DetachHead(HEAD)
       kill.append(cb)
 
@@ -1548,7 +1549,7 @@
 
     kept = []
     for branch in kill:
-      if (R_HEADS + branch) in left:
+      if R_HEADS + branch in left:
         branch = self.GetBranch(branch)
         base = branch.LocalMerge
         if not base:
@@ -1598,8 +1599,8 @@
     def parse_gitmodules(gitdir, rev):
       cmd = ['cat-file', 'blob', '%s:.gitmodules' % rev]
       try:
-        p = GitCommand(None, cmd, capture_stdout = True, capture_stderr = True,
-                       bare = True, gitdir = gitdir)
+        p = GitCommand(None, cmd, capture_stdout=True, capture_stderr=True,
+                       bare=True, gitdir=gitdir)
       except GitError:
         return [], []
       if p.Wait() != 0:
@@ -1611,8 +1612,8 @@
         os.write(fd, p.stdout)
         os.close(fd)
         cmd = ['config', '--file', temp_gitmodules_path, '--list']
-        p = GitCommand(None, cmd, capture_stdout = True, capture_stderr = True,
-                       bare = True, gitdir = gitdir)
+        p = GitCommand(None, cmd, capture_stdout=True, capture_stderr=True,
+                       bare=True, gitdir=gitdir)
         if p.Wait() != 0:
           return [], []
         gitmodules_lines = p.stdout.split('\n')
@@ -1645,8 +1646,8 @@
       cmd = ['ls-tree', rev, '--']
       cmd.extend(paths)
       try:
-        p = GitCommand(None, cmd, capture_stdout = True, capture_stderr = True,
-                       bare = True, gitdir = gitdir)
+        p = GitCommand(None, cmd, capture_stdout=True, capture_stderr=True,
+                       bare=True, gitdir=gitdir)
       except GitError:
         return []
       if p.Wait() != 0:
@@ -1681,24 +1682,24 @@
         continue
 
       remote = RemoteSpec(self.remote.name,
-                          url = url,
-                          review = self.remote.review,
-                          revision = self.remote.revision)
-      subproject = Project(manifest = self.manifest,
-                           name = name,
-                           remote = remote,
-                           gitdir = gitdir,
-                           objdir = objdir,
-                           worktree = worktree,
-                           relpath = relpath,
-                           revisionExpr = self.revisionExpr,
-                           revisionId = rev,
-                           rebase = self.rebase,
-                           groups = self.groups,
-                           sync_c = self.sync_c,
-                           sync_s = self.sync_s,
-                           parent = self,
-                           is_derived = True)
+                          url=url,
+                          review=self.remote.review,
+                          revision=self.remote.revision)
+      subproject = Project(manifest=self.manifest,
+                           name=name,
+                           remote=remote,
+                           gitdir=gitdir,
+                           objdir=objdir,
+                           worktree=worktree,
+                           relpath=relpath,
+                           revisionExpr=self.revisionExpr,
+                           revisionId=rev,
+                           rebase=self.rebase,
+                           groups=self.groups,
+                           sync_c=self.sync_c,
+                           sync_s=self.sync_s,
+                           parent=self,
+                           is_derived=True)
       result.append(subproject)
       result.extend(subproject.GetDerivedSubprojects())
     return result
@@ -1866,9 +1867,9 @@
       GitCommand(self, ['fetch', '--unshallow', name] + shallowfetch.split(),
                  bare=True, ssh_proxy=ssh_proxy).Wait()
     if depth:
-        self.config.SetString('repo.shallowfetch', ' '.join(spec))
+      self.config.SetString('repo.shallowfetch', ' '.join(spec))
     else:
-        self.config.SetString('repo.shallowfetch', None)
+      self.config.SetString('repo.shallowfetch', None)
 
     ok = False
     for _i in range(2):
@@ -1958,34 +1959,34 @@
         os.remove(tmpPath)
     if 'http_proxy' in os.environ and 'darwin' == sys.platform:
       cmd += ['--proxy', os.environ['http_proxy']]
-    cookiefile = self._GetBundleCookieFile(srcUrl)
-    if cookiefile:
-      cmd += ['--cookie', cookiefile]
-    if srcUrl.startswith('persistent-'):
-      srcUrl = srcUrl[len('persistent-'):]
-    cmd += [srcUrl]
+    with self._GetBundleCookieFile(srcUrl, quiet) as cookiefile:
+      if cookiefile:
+        cmd += ['--cookie', cookiefile, '--cookie-jar', cookiefile]
+      if srcUrl.startswith('persistent-'):
+        srcUrl = srcUrl[len('persistent-'):]
+      cmd += [srcUrl]
 
-    if IsTrace():
-      Trace('%s', ' '.join(cmd))
-    try:
-      proc = subprocess.Popen(cmd)
-    except OSError:
-      return False
+      if IsTrace():
+        Trace('%s', ' '.join(cmd))
+      try:
+        proc = subprocess.Popen(cmd)
+      except OSError:
+        return False
 
-    curlret = proc.wait()
+      curlret = proc.wait()
 
-    if curlret == 22:
-      # From curl man page:
-      # 22: HTTP page not retrieved. The requested url was not found or
-      # returned another error with the HTTP error code being 400 or above.
-      # This return code only appears if -f, --fail is used.
-      if not quiet:
-        print("Server does not provide clone.bundle; ignoring.",
-              file=sys.stderr)
-      return False
+      if curlret == 22:
+        # From curl man page:
+        # 22: HTTP page not retrieved. The requested url was not found or
+        # returned another error with the HTTP error code being 400 or above.
+        # This return code only appears if -f, --fail is used.
+        if not quiet:
+          print("Server does not provide clone.bundle; ignoring.",
+                file=sys.stderr)
+        return False
 
     if os.path.exists(tmpPath):
-      if curlret == 0 and self._IsValidBundle(tmpPath):
+      if curlret == 0 and self._IsValidBundle(tmpPath, quiet):
         os.rename(tmpPath, dstPath)
         return True
       else:
@@ -1994,45 +1995,51 @@
     else:
       return False
 
-  def _IsValidBundle(self, path):
+  def _IsValidBundle(self, path, quiet):
     try:
       with open(path) as f:
         if f.read(16) == '# v2 git bundle\n':
           return True
         else:
-          print("Invalid clone.bundle file; ignoring.", file=sys.stderr)
+          if not quiet:
+            print("Invalid clone.bundle file; ignoring.", file=sys.stderr)
           return False
     except OSError:
       return False
 
-  def _GetBundleCookieFile(self, url):
+  @contextlib.contextmanager
+  def _GetBundleCookieFile(self, url, quiet):
     if url.startswith('persistent-'):
       try:
         p = subprocess.Popen(
             ['git-remote-persistent-https', '-print_config', url],
             stdin=subprocess.PIPE, stdout=subprocess.PIPE,
             stderr=subprocess.PIPE)
-        p.stdin.close()  # Tell subprocess it's ok to close.
-        prefix = 'http.cookiefile='
-        cookiefile = None
-        for line in p.stdout:
-          line = line.strip()
-          if line.startswith(prefix):
-            cookiefile = line[len(prefix):]
-            break
-        if p.wait():
-          err_msg = p.stderr.read()
-          if ' -print_config' in err_msg:
-            pass  # Persistent proxy doesn't support -print_config.
-          else:
-            print(err_msg, file=sys.stderr)
-        if cookiefile:
-          return cookiefile
+        try:
+          prefix = 'http.cookiefile='
+          cookiefile = None
+          for line in p.stdout:
+            line = line.strip()
+            if line.startswith(prefix):
+              cookiefile = line[len(prefix):]
+              break
+          # Leave subprocess open, as cookie file may be transient.
+          if cookiefile:
+            yield cookiefile
+            return
+        finally:
+          p.stdin.close()
+          if p.wait():
+            err_msg = p.stderr.read()
+            if ' -print_config' in err_msg:
+              pass  # Persistent proxy doesn't support -print_config.
+            elif not quiet:
+              print(err_msg, file=sys.stderr)
       except OSError as e:
         if e.errno == errno.ENOENT:
           pass  # No persistent proxy.
         raise
-    return GitConfig.ForUser().GetString('http.cookiefile')
+    yield GitConfig.ForUser().GetString('http.cookiefile')
 
   def _Checkout(self, rev, quiet=False):
     cmd = ['checkout']
@@ -2044,7 +2051,7 @@
       if self._allrefs:
         raise GitError('%s checkout %s ' % (self.name, rev))
 
-  def _CherryPick(self, rev, quiet=False):
+  def _CherryPick(self, rev):
     cmd = ['cherry-pick']
     cmd.append(rev)
     cmd.append('--')
@@ -2052,7 +2059,7 @@
       if self._allrefs:
         raise GitError('%s cherry-pick %s ' % (self.name, rev))
 
-  def _Revert(self, rev, quiet=False):
+  def _Revert(self, rev):
     cmd = ['revert']
     cmd.append('--no-edit')
     cmd.append(rev)
@@ -2069,7 +2076,7 @@
     if GitCommand(self, cmd).Wait() != 0:
       raise GitError('%s reset --hard %s ' % (self.name, rev))
 
-  def _Rebase(self, upstream, onto = None):
+  def _Rebase(self, upstream, onto=None):
     cmd = ['rebase']
     if onto is not None:
       cmd.extend(['--onto', onto])
@@ -2124,7 +2131,7 @@
 
       m = self.manifest.manifestProject.config
       for key in ['user.name', 'user.email']:
-        if m.Has(key, include_defaults = False):
+        if m.Has(key, include_defaults=False):
           self.config.SetString(key, m.GetString(key))
       if self.manifest.IsMirror:
         self.config.SetString('core.bare', 'true')
@@ -2133,15 +2140,6 @@
 
   def _UpdateHooks(self):
     if os.path.exists(self.gitdir):
-      # Always recreate hooks since they can have been changed
-      # since the latest update.
-      hooks = self._gitdir_path('hooks')
-      try:
-        to_rm = os.listdir(hooks)
-      except OSError:
-        to_rm = []
-      for old_hook in to_rm:
-        os.remove(os.path.join(hooks, old_hook))
       self._InitHooks()
 
   def _InitHooks(self):
@@ -2204,7 +2202,7 @@
       if cur != '' or self.bare_ref.get(ref) != self.revisionId:
         msg = 'manifest set to %s' % self.revisionId
         dst = self.revisionId + '^0'
-        self.bare_git.UpdateRef(ref, dst, message = msg, detach = True)
+        self.bare_git.UpdateRef(ref, dst, message=msg, detach=True)
     else:
       remote = self.GetRemote(self.remote.name)
       dst = remote.ToLocal(self.revisionExpr)
@@ -2348,10 +2346,10 @@
                       '-z',
                       '--others',
                       '--exclude-standard'],
-                     bare = False,
+                     bare=False,
                      gitdir=self._gitdir,
-                     capture_stdout = True,
-                     capture_stderr = True)
+                     capture_stdout=True,
+                     capture_stderr=True)
       if p.Wait() == 0:
         out = p.stdout
         if out:
@@ -2366,9 +2364,9 @@
       p = GitCommand(self._project,
                      cmd,
                      gitdir=self._gitdir,
-                     bare = False,
-                     capture_stdout = True,
-                     capture_stderr = True)
+                     bare=False,
+                     capture_stdout=True,
+                     capture_stderr=True)
       try:
         out = p.process.stdout.read()
         r = {}
@@ -2474,10 +2472,10 @@
       cmdv.extend(args)
       p = GitCommand(self._project,
                      cmdv,
-                     bare = self._bare,
+                     bare=self._bare,
                      gitdir=self._gitdir,
-                     capture_stdout = True,
-                     capture_stderr = True)
+                     capture_stdout=True,
+                     capture_stderr=True)
       r = []
       for line in p.process.stdout:
         if line[-1] == '\n':
@@ -2527,10 +2525,10 @@
         cmdv.extend(args)
         p = GitCommand(self._project,
                        cmdv,
-                       bare = self._bare,
+                       bare=self._bare,
                        gitdir=self._gitdir,
-                       capture_stdout = True,
-                       capture_stderr = True)
+                       capture_stdout=True,
+                       capture_stderr=True)
         if p.Wait() != 0:
           raise GitError('%s %s: %s' % (
                          self._project.name,
@@ -2595,9 +2593,9 @@
 class _SyncColoring(Coloring):
   def __init__(self, config):
     Coloring.__init__(self, config, 'reposync')
-    self.project   = self.printer('header', attr = 'bold')
-    self.info      = self.printer('info')
-    self.fail      = self.printer('fail', fg='red')
+    self.project = self.printer('header', attr='bold')
+    self.info = self.printer('info')
+    self.fail = self.printer('fail', fg='red')
 
 class SyncBuffer(object):
   def __init__(self, config, detach_head=False):
@@ -2659,16 +2657,16 @@
   """
   def __init__(self, manifest, name, gitdir, worktree):
     Project.__init__(self,
-                     manifest = manifest,
-                     name = name,
-                     gitdir = gitdir,
-                     objdir = gitdir,
-                     worktree = worktree,
-                     remote = RemoteSpec('origin'),
-                     relpath = '.repo/%s' % name,
-                     revisionExpr = 'refs/heads/master',
-                     revisionId = None,
-                     groups = None)
+                     manifest=manifest,
+                     name=name,
+                     gitdir=gitdir,
+                     objdir=gitdir,
+                     worktree=worktree,
+                     remote=RemoteSpec('origin'),
+                     relpath='.repo/%s' % name,
+                     revisionExpr='refs/heads/master',
+                     revisionId=None,
+                     groups=None)
 
   def PreSync(self):
     if self.Exists:
@@ -2679,20 +2677,20 @@
           self.revisionExpr = base
           self.revisionId = None
 
-  def MetaBranchSwitch(self, target):
+  def MetaBranchSwitch(self):
     """ Prepare MetaProject for manifest branch switch
     """
 
     # detach and delete manifest branch, allowing a new
     # branch to take over
-    syncbuf = SyncBuffer(self.config, detach_head = True)
+    syncbuf = SyncBuffer(self.config, detach_head=True)
     self.Sync_LocalHalf(syncbuf)
     syncbuf.Finish()
 
     return GitCommand(self,
                         ['update-ref', '-d', 'refs/heads/default'],
-                        capture_stdout = True,
-                        capture_stderr = True).Wait() == 0
+                        capture_stdout=True,
+                        capture_stderr=True).Wait() == 0
 
 
   @property
diff --git a/subcmds/init.py b/subcmds/init.py
index b1fcb69..b73de71 100644
--- a/subcmds/init.py
+++ b/subcmds/init.py
@@ -233,7 +233,7 @@
       sys.exit(1)
 
     if opt.manifest_branch:
-      m.MetaBranchSwitch(opt.manifest_branch)
+      m.MetaBranchSwitch()
 
     syncbuf = SyncBuffer(m.config)
     m.Sync_LocalHalf(syncbuf)