command: add a helper for the parallel execution boilerplate

Now that we have a bunch of subcommands doing parallel execution, a
common pattern arises that we can factor out for most of them.  We
leave forall alone as it's a bit too complicated atm to cut over.

Change-Id: I3617a4f7c66142bcd1ab030cb4cca698a65010ac
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/301942
Tested-by: Mike Frysinger <vapier@google.com>
Reviewed-by: Chris Mcdonald <cjmcdonald@google.com>
diff --git a/command.py b/command.py
index be2d6a6..9b1220d 100644
--- a/command.py
+++ b/command.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import multiprocessing
 import os
 import optparse
 import platform
@@ -21,6 +22,7 @@
 from event_log import EventLog
 from error import NoSuchProjectError
 from error import InvalidProjectGroupsError
+import progress
 
 
 # Number of projects to submit to a single worker process at a time.
@@ -156,6 +158,44 @@
     """
     raise NotImplementedError
 
+  @staticmethod
+  def ExecuteInParallel(jobs, func, inputs, callback, output=None, ordered=False):
+    """Helper for managing parallel execution boiler plate.
+
+    For subcommands that can easily split their work up.
+
+    Args:
+      jobs: How many parallel processes to use.
+      func: The function to apply to each of the |inputs|.  Usually a
+          functools.partial for wrapping additional arguments.  It will be run
+          in a separate process, so it must be pickalable, so nested functions
+          won't work.  Methods on the subcommand Command class should work.
+      inputs: The list of items to process.  Must be a list.
+      callback: The function to pass the results to for processing.  It will be
+          executed in the main thread and process the results of |func| as they
+          become available.  Thus it may be a local nested function.  Its return
+          value is passed back directly.  It takes three arguments:
+          - The processing pool (or None with one job).
+          - The |output| argument.
+          - An iterator for the results.
+      output: An output manager.  May be progress.Progess or color.Coloring.
+      ordered: Whether the jobs should be processed in order.
+
+    Returns:
+      The |callback| function's results are returned.
+    """
+    try:
+      # NB: Multiprocessing is heavy, so don't spin it up for one job.
+      if len(inputs) == 1 or jobs == 1:
+        return callback(None, output, (func(x) for x in inputs))
+      else:
+        with multiprocessing.Pool(jobs) as pool:
+          submit = pool.imap if ordered else pool.imap_unordered
+          return callback(pool, output, submit(func, inputs, chunksize=WORKER_BATCH_SIZE))
+    finally:
+      if isinstance(output, progress.Progress):
+        output.end()
+
   def _ResetPathToProjectMap(self, projects):
     self._by_path = dict((p.worktree, p) for p in projects)
 
diff --git a/subcmds/abandon.py b/subcmds/abandon.py
index 1d22917..c7c127d 100644
--- a/subcmds/abandon.py
+++ b/subcmds/abandon.py
@@ -15,10 +15,9 @@
 from collections import defaultdict
 import functools
 import itertools
-import multiprocessing
 import sys
 
-from command import Command, DEFAULT_LOCAL_JOBS, WORKER_BATCH_SIZE
+from command import Command, DEFAULT_LOCAL_JOBS
 from git_command import git
 from progress import Progress
 
@@ -52,9 +51,9 @@
     else:
       args.insert(0, "'All local branches'")
 
-  def _ExecuteOne(self, opt, nb, project):
+  def _ExecuteOne(self, all_branches, nb, project):
     """Abandon one project."""
-    if opt.all:
+    if all_branches:
       branches = project.GetBranches()
     else:
       branches = [nb]
@@ -72,7 +71,7 @@
     success = defaultdict(list)
     all_projects = self.GetProjects(args[1:])
 
-    def _ProcessResults(states):
+    def _ProcessResults(_pool, pm, states):
       for (results, project) in states:
         for branch, status in results.items():
           if status:
@@ -81,17 +80,12 @@
             err[branch].append(project)
         pm.update()
 
-    pm = Progress('Abandon %s' % nb, len(all_projects), quiet=opt.quiet)
-    # NB: Multiprocessing is heavy, so don't spin it up for one job.
-    if len(all_projects) == 1 or opt.jobs == 1:
-      _ProcessResults(self._ExecuteOne(opt, nb, x) for x in all_projects)
-    else:
-      with multiprocessing.Pool(opt.jobs) as pool:
-        states = pool.imap_unordered(
-            functools.partial(self._ExecuteOne, opt, nb), all_projects,
-            chunksize=WORKER_BATCH_SIZE)
-        _ProcessResults(states)
-    pm.end()
+    self.ExecuteInParallel(
+        opt.jobs,
+        functools.partial(self._ExecuteOne, opt.all, nb),
+        all_projects,
+        callback=_ProcessResults,
+        output=Progress('Abandon %s' % (nb,), len(all_projects), quiet=opt.quiet))
 
     width = max(itertools.chain(
         [25], (len(x) for x in itertools.chain(success, err))))
diff --git a/subcmds/branches.py b/subcmds/branches.py
index d5ea580..2dc102b 100644
--- a/subcmds/branches.py
+++ b/subcmds/branches.py
@@ -13,10 +13,10 @@
 # limitations under the License.
 
 import itertools
-import multiprocessing
 import sys
+
 from color import Coloring
-from command import Command, DEFAULT_LOCAL_JOBS, WORKER_BATCH_SIZE
+from command import Command, DEFAULT_LOCAL_JOBS
 
 
 class BranchColoring(Coloring):
@@ -102,15 +102,19 @@
     out = BranchColoring(self.manifest.manifestProject.config)
     all_branches = {}
     project_cnt = len(projects)
-    with multiprocessing.Pool(processes=opt.jobs) as pool:
-      project_branches = pool.imap_unordered(
-          expand_project_to_branches, projects, chunksize=WORKER_BATCH_SIZE)
 
-      for name, b in itertools.chain.from_iterable(project_branches):
+    def _ProcessResults(_pool, _output, results):
+      for name, b in itertools.chain.from_iterable(results):
         if name not in all_branches:
           all_branches[name] = BranchInfo(name)
         all_branches[name].add(b)
 
+    self.ExecuteInParallel(
+        opt.jobs,
+        expand_project_to_branches,
+        projects,
+        callback=_ProcessResults)
+
     names = sorted(all_branches)
 
     if not names:
diff --git a/subcmds/checkout.py b/subcmds/checkout.py
index 6b71a8f..4d8009b 100644
--- a/subcmds/checkout.py
+++ b/subcmds/checkout.py
@@ -13,10 +13,9 @@
 # limitations under the License.
 
 import functools
-import multiprocessing
 import sys
 
-from command import Command, DEFAULT_LOCAL_JOBS, WORKER_BATCH_SIZE
+from command import Command, DEFAULT_LOCAL_JOBS
 from progress import Progress
 
 
@@ -50,7 +49,7 @@
     success = []
     all_projects = self.GetProjects(args[1:])
 
-    def _ProcessResults(results):
+    def _ProcessResults(_pool, pm, results):
       for status, project in results:
         if status is not None:
           if status:
@@ -59,17 +58,12 @@
             err.append(project)
         pm.update()
 
-    pm = Progress('Checkout %s' % nb, len(all_projects), quiet=opt.quiet)
-    # NB: Multiprocessing is heavy, so don't spin it up for one job.
-    if len(all_projects) == 1 or opt.jobs == 1:
-      _ProcessResults(self._ExecuteOne(nb, x) for x in all_projects)
-    else:
-      with multiprocessing.Pool(opt.jobs) as pool:
-        results = pool.imap_unordered(
-            functools.partial(self._ExecuteOne, nb), all_projects,
-            chunksize=WORKER_BATCH_SIZE)
-        _ProcessResults(results)
-    pm.end()
+    self.ExecuteInParallel(
+        opt.jobs,
+        functools.partial(self._ExecuteOne, nb),
+        all_projects,
+        callback=_ProcessResults,
+        output=Progress('Checkout %s' % (nb,), len(all_projects), quiet=opt.quiet))
 
     if err:
       for p in err:
diff --git a/subcmds/diff.py b/subcmds/diff.py
index cdc262e..4966bb1 100644
--- a/subcmds/diff.py
+++ b/subcmds/diff.py
@@ -14,9 +14,8 @@
 
 import functools
 import io
-import multiprocessing
 
-from command import DEFAULT_LOCAL_JOBS, PagedCommand, WORKER_BATCH_SIZE
+from command import DEFAULT_LOCAL_JOBS, PagedCommand
 
 
 class Diff(PagedCommand):
@@ -36,7 +35,7 @@
                  dest='absolute', action='store_true',
                  help='Paths are relative to the repository root')
 
-  def _DiffHelper(self, absolute, project):
+  def _ExecuteOne(self, absolute, project):
     """Obtains the diff for a specific project.
 
     Args:
@@ -51,22 +50,20 @@
     return (ret, buf.getvalue())
 
   def Execute(self, opt, args):
-    ret = 0
     all_projects = self.GetProjects(args)
 
-    # NB: Multiprocessing is heavy, so don't spin it up for one job.
-    if len(all_projects) == 1 or opt.jobs == 1:
-      for project in all_projects:
-        if not project.PrintWorkTreeDiff(opt.absolute):
+    def _ProcessResults(_pool, _output, results):
+      ret = 0
+      for (state, output) in results:
+        if output:
+          print(output, end='')
+        if not state:
           ret = 1
-    else:
-      with multiprocessing.Pool(opt.jobs) as pool:
-        states = pool.imap(functools.partial(self._DiffHelper, opt.absolute),
-                           all_projects, WORKER_BATCH_SIZE)
-        for (state, output) in states:
-          if output:
-            print(output, end='')
-          if not state:
-            ret = 1
+      return ret
 
-    return ret
+    return self.ExecuteInParallel(
+        opt.jobs,
+        functools.partial(self._ExecuteOne, opt.absolute),
+        all_projects,
+        callback=_ProcessResults,
+        ordered=True)
diff --git a/subcmds/grep.py b/subcmds/grep.py
index 9a4a8a3..6cb1445 100644
--- a/subcmds/grep.py
+++ b/subcmds/grep.py
@@ -13,11 +13,10 @@
 # limitations under the License.
 
 import functools
-import multiprocessing
 import sys
 
 from color import Coloring
-from command import DEFAULT_LOCAL_JOBS, PagedCommand, WORKER_BATCH_SIZE
+from command import DEFAULT_LOCAL_JOBS, PagedCommand
 from error import GitError
 from git_command import GitCommand
 
@@ -173,7 +172,7 @@
     return (project, p.Wait(), p.stdout, p.stderr)
 
   @staticmethod
-  def _ProcessResults(out, full_name, have_rev, results):
+  def _ProcessResults(full_name, have_rev, _pool, out, results):
     git_failed = False
     bad_rev = False
     have_match = False
@@ -256,18 +255,13 @@
       cmd_argv.extend(opt.revision)
     cmd_argv.append('--')
 
-    process_results = functools.partial(
-        self._ProcessResults, out, full_name, have_rev)
-    # NB: Multiprocessing is heavy, so don't spin it up for one job.
-    if len(projects) == 1 or opt.jobs == 1:
-      git_failed, bad_rev, have_match = process_results(
-          self._ExecuteOne(cmd_argv, x) for x in projects)
-    else:
-      with multiprocessing.Pool(opt.jobs) as pool:
-        results = pool.imap(
-            functools.partial(self._ExecuteOne, cmd_argv), projects,
-            chunksize=WORKER_BATCH_SIZE)
-        git_failed, bad_rev, have_match = process_results(results)
+    git_failed, bad_rev, have_match = self.ExecuteInParallel(
+        opt.jobs,
+        functools.partial(self._ExecuteOne, cmd_argv),
+        projects,
+        callback=functools.partial(self._ProcessResults, full_name, have_rev),
+        output=out,
+        ordered=True)
 
     if git_failed:
       sys.exit(1)
diff --git a/subcmds/prune.py b/subcmds/prune.py
index 4084c8b..236b647 100644
--- a/subcmds/prune.py
+++ b/subcmds/prune.py
@@ -13,10 +13,9 @@
 # limitations under the License.
 
 import itertools
-import multiprocessing
 
 from color import Coloring
-from command import DEFAULT_LOCAL_JOBS, PagedCommand, WORKER_BATCH_SIZE
+from command import DEFAULT_LOCAL_JOBS, PagedCommand
 
 
 class Prune(PagedCommand):
@@ -36,18 +35,15 @@
 
     # NB: Should be able to refactor this module to display summary as results
     # come back from children.
-    def _ProcessResults(results):
+    def _ProcessResults(_pool, _output, results):
       return list(itertools.chain.from_iterable(results))
 
-    # NB: Multiprocessing is heavy, so don't spin it up for one job.
-    if len(projects) == 1 or opt.jobs == 1:
-      all_branches = _ProcessResults(self._ExecuteOne(x) for x in projects)
-    else:
-      with multiprocessing.Pool(opt.jobs) as pool:
-        results = pool.imap(
-            self._ExecuteOne, projects,
-            chunksize=WORKER_BATCH_SIZE)
-        all_branches = _ProcessResults(results)
+    all_branches = self.ExecuteInParallel(
+        opt.jobs,
+        self._ExecuteOne,
+        projects,
+        callback=_ProcessResults,
+        ordered=True)
 
     if not all_branches:
       return
diff --git a/subcmds/start.py b/subcmds/start.py
index aa2f915..ff2bae5 100644
--- a/subcmds/start.py
+++ b/subcmds/start.py
@@ -13,11 +13,10 @@
 # limitations under the License.
 
 import functools
-import multiprocessing
 import os
 import sys
 
-from command import Command, DEFAULT_LOCAL_JOBS, WORKER_BATCH_SIZE
+from command import Command, DEFAULT_LOCAL_JOBS
 from git_config import IsImmutable
 from git_command import git
 import gitc_utils
@@ -55,7 +54,7 @@
     if not git.check_ref_format('heads/%s' % nb):
       self.OptionParser.error("'%s' is not a valid name" % nb)
 
-  def _ExecuteOne(self, opt, nb, project):
+  def _ExecuteOne(self, revision, nb, project):
     """Start one project."""
     # If the current revision is immutable, such as a SHA1, a tag or
     # a change, then we can't push back to it. Substitute with
@@ -69,7 +68,7 @@
 
     try:
       ret = project.StartBranch(
-          nb, branch_merge=branch_merge, revision=opt.revision)
+          nb, branch_merge=branch_merge, revision=revision)
     except Exception as e:
       print('error: unable to checkout %s: %s' % (project.name, e), file=sys.stderr)
       ret = False
@@ -123,23 +122,18 @@
         pm.update()
       pm.end()
 
-    def _ProcessResults(results):
+    def _ProcessResults(_pool, pm, results):
       for (result, project) in results:
         if not result:
           err.append(project)
         pm.update()
 
-    pm = Progress('Starting %s' % nb, len(all_projects), quiet=opt.quiet)
-    # NB: Multiprocessing is heavy, so don't spin it up for one job.
-    if len(all_projects) == 1 or opt.jobs == 1:
-      _ProcessResults(self._ExecuteOne(opt, nb, x) for x in all_projects)
-    else:
-      with multiprocessing.Pool(opt.jobs) as pool:
-        results = pool.imap_unordered(
-            functools.partial(self._ExecuteOne, opt, nb), all_projects,
-            chunksize=WORKER_BATCH_SIZE)
-        _ProcessResults(results)
-    pm.end()
+    self.ExecuteInParallel(
+        opt.jobs,
+        functools.partial(self._ExecuteOne, opt.revision, nb),
+        all_projects,
+        callback=_ProcessResults,
+        output=Progress('Starting %s' % (nb,), len(all_projects), quiet=opt.quiet))
 
     if err:
       for p in err:
diff --git a/subcmds/status.py b/subcmds/status.py
index dc223a0..1b48dce 100644
--- a/subcmds/status.py
+++ b/subcmds/status.py
@@ -15,10 +15,9 @@
 import functools
 import glob
 import io
-import multiprocessing
 import os
 
-from command import DEFAULT_LOCAL_JOBS, PagedCommand, WORKER_BATCH_SIZE
+from command import DEFAULT_LOCAL_JOBS, PagedCommand
 
 from color import Coloring
 import platform_utils
@@ -119,22 +118,23 @@
 
   def Execute(self, opt, args):
     all_projects = self.GetProjects(args)
-    counter = 0
 
-    if opt.jobs == 1:
-      for project in all_projects:
-        state = project.PrintWorkTreeStatus(quiet=opt.quiet)
+    def _ProcessResults(_pool, _output, results):
+      ret = 0
+      for (state, output) in results:
+        if output:
+          print(output, end='')
         if state == 'CLEAN':
-          counter += 1
-    else:
-      with multiprocessing.Pool(opt.jobs) as pool:
-        states = pool.imap(functools.partial(self._StatusHelper, opt.quiet),
-                           all_projects, chunksize=WORKER_BATCH_SIZE)
-        for (state, output) in states:
-          if output:
-            print(output, end='')
-          if state == 'CLEAN':
-            counter += 1
+          ret += 1
+      return ret
+
+    counter = self.ExecuteInParallel(
+        opt.jobs,
+        functools.partial(self._StatusHelper, opt.quiet),
+        all_projects,
+        callback=_ProcessResults,
+        ordered=True)
+
     if not opt.quiet and len(all_projects) == counter:
       print('nothing to commit (working directory clean)')
 
diff --git a/subcmds/sync.py b/subcmds/sync.py
index 21166af..4763fad 100644
--- a/subcmds/sync.py
+++ b/subcmds/sync.py
@@ -51,7 +51,7 @@
 import gitc_utils
 from project import Project
 from project import RemoteSpec
-from command import Command, MirrorSafeCommand, WORKER_BATCH_SIZE
+from command import Command, MirrorSafeCommand
 from error import RepoChangedException, GitError, ManifestParseError
 import platform_utils
 from project import SyncBuffer
@@ -428,11 +428,12 @@
 
     return (ret, fetched)
 
-  def _CheckoutOne(self, opt, project):
+  def _CheckoutOne(self, detach_head, force_sync, project):
     """Checkout work tree for one project
 
     Args:
-      opt: Program options returned from optparse.  See _Options().
+      detach_head: Whether to leave a detached HEAD.
+      force_sync: Force checking out of the repo.
       project: Project object for the project to checkout.
 
     Returns:
@@ -440,10 +441,10 @@
     """
     start = time.time()
     syncbuf = SyncBuffer(self.manifest.manifestProject.config,
-                         detach_head=opt.detach_head)
+                         detach_head=detach_head)
     success = False
     try:
-      project.Sync_LocalHalf(syncbuf, force_sync=opt.force_sync)
+      project.Sync_LocalHalf(syncbuf, force_sync=force_sync)
       success = syncbuf.Finish()
     except Exception as e:
       print('error: Cannot checkout %s: %s: %s' %
@@ -464,44 +465,32 @@
       opt: Program options returned from optparse.  See _Options().
       err_results: A list of strings, paths to git repos where checkout failed.
     """
-    ret = True
-    jobs = opt.jobs_checkout if opt.jobs_checkout else self.jobs
-
     # Only checkout projects with worktrees.
     all_projects = [x for x in all_projects if x.worktree]
 
-    pm = Progress('Checking out', len(all_projects), quiet=opt.quiet)
-
-    def _ProcessResults(results):
+    def _ProcessResults(pool, pm, results):
+      ret = True
       for (success, project, start, finish) in results:
         self.event_log.AddSync(project, event_log.TASK_SYNC_LOCAL,
                                start, finish, success)
         # Check for any errors before running any more tasks.
         # ...we'll let existing jobs finish, though.
         if not success:
+          ret = False
           err_results.append(project.relpath)
           if opt.fail_fast:
-            return False
+            if pool:
+              pool.close()
+            return ret
         pm.update(msg=project.name)
-      return True
+      return ret
 
-    # NB: Multiprocessing is heavy, so don't spin it up for one job.
-    if len(all_projects) == 1 or jobs == 1:
-      if not _ProcessResults(self._CheckoutOne(opt, x) for x in all_projects):
-        ret = False
-    else:
-      with multiprocessing.Pool(jobs) as pool:
-        results = pool.imap_unordered(
-            functools.partial(self._CheckoutOne, opt),
-            all_projects,
-            chunksize=WORKER_BATCH_SIZE)
-        if not _ProcessResults(results):
-          ret = False
-          pool.close()
-
-    pm.end()
-
-    return ret and not err_results
+    return self.ExecuteInParallel(
+        opt.jobs_checkout if opt.jobs_checkout else self.jobs,
+        functools.partial(self._CheckoutOne, opt.detach_head, opt.force_sync),
+        all_projects,
+        callback=_ProcessResults,
+        output=Progress('Checking out', len(all_projects), quiet=opt.quiet)) and not err_results
 
   def _GCProjects(self, projects, opt, err_event):
     pm = Progress('Garbage collecting', len(projects), delay=False, quiet=opt.quiet)