repo_trace: Avoid race conditions with trace_file updating.

Change-Id: I0bc1bb3c8f60465dc6bee5081688a9f163dd8cf8
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/354515
Reviewed-by: Mike Frysinger <vapier@google.com>
Tested-by: Joanna Wang <jojwang@google.com>
diff --git a/repo_trace.py b/repo_trace.py
index d79408d..1ba86c7 100644
--- a/repo_trace.py
+++ b/repo_trace.py
@@ -23,6 +23,7 @@
 import sys
 import os
 import time
+import tempfile
 from contextlib import ContextDecorator
 
 import platform_utils
@@ -35,34 +36,40 @@
 _TRACE_TO_STDERR = False
 _TRACE_FILE = None
 _TRACE_FILE_NAME = 'TRACE_FILE'
-_MAX_SIZE = 70  # in mb
+_MAX_SIZE = 70  # in MiB
 _NEW_COMMAND_SEP = '+++++++++++++++NEW COMMAND+++++++++++++++++++'
 
 
 def IsTraceToStderr():
+  """Whether traces are written to stderr."""
   return _TRACE_TO_STDERR
 
 
 def IsTrace():
+  """Whether tracing is enabled."""
   return _TRACE
 
 
 def SetTraceToStderr():
+  """Enables tracing logging to stderr."""
   global _TRACE_TO_STDERR
   _TRACE_TO_STDERR = True
 
 
 def SetTrace():
+  """Enables tracing."""
   global _TRACE
   _TRACE = True
 
 
 def _SetTraceFile(quiet):
+  """Sets the trace file location."""
   global _TRACE_FILE
   _TRACE_FILE = _GetTraceFile(quiet)
 
 
 class Trace(ContextDecorator):
+  """Used to capture and save git traces."""
 
   def _time(self):
     """Generate nanoseconds of time in a py3.6 safe way"""
@@ -128,20 +135,32 @@
 
 
 def _ClearOldTraces():
-  """Clear the oldest commands if trace file is too big.
+  """Clear the oldest commands if trace file is too big."""
+  try:
+    with open(_TRACE_FILE, 'r', errors='ignore') as f:
+      if os.path.getsize(f.name) / (1024 * 1024) <= _MAX_SIZE:
+        return
+      trace_lines = f.readlines()
+  except FileNotFoundError:
+    return
 
-  Note: If the trace file contains output from two `repo`
-        commands that were running at the same time, this
-        will not work precisely.
-  """
-  if os.path.isfile(_TRACE_FILE):
-    while os.path.getsize(_TRACE_FILE) / (1024 * 1024) > _MAX_SIZE:
-      temp_file = _TRACE_FILE + '.tmp'
-      with open(_TRACE_FILE, 'r', errors='ignore') as fin:
-        with open(temp_file, 'w') as tf:
-          trace_lines = fin.readlines()
-          for i, l in enumerate(trace_lines):
-            if 'END:' in l and _NEW_COMMAND_SEP in l:
-              tf.writelines(trace_lines[i + 1:])
-              break
-      platform_utils.rename(temp_file, _TRACE_FILE)
+  while sum(len(x) for x in trace_lines) / (1024 * 1024) > _MAX_SIZE:
+    for i, line in enumerate(trace_lines):
+      if 'END:' in line and _NEW_COMMAND_SEP in line:
+        trace_lines = trace_lines[i + 1:]
+        break
+    else:
+      # The last chunk is bigger than _MAX_SIZE, so just throw everything away.
+      trace_lines = []
+
+  while trace_lines and trace_lines[-1] == '\n':
+    trace_lines = trace_lines[:-1]
+  # Write to a temporary file with a unique name in the same filesystem
+  # before replacing the original trace file.
+  temp_dir, temp_prefix = os.path.split(_TRACE_FILE)
+  with tempfile.NamedTemporaryFile('w',
+                                   dir=temp_dir,
+                                   prefix=temp_prefix,
+                                   delete=False) as f:
+    f.writelines(trace_lines)
+  platform_utils.rename(f.name, _TRACE_FILE)
diff --git a/tests/test_repo_trace.py b/tests/test_repo_trace.py
new file mode 100644
index 0000000..5faf293
--- /dev/null
+++ b/tests/test_repo_trace.py
@@ -0,0 +1,56 @@
+# Copyright 2022 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unittests for the repo_trace.py module."""
+
+import os
+import unittest
+from unittest import mock
+
+import repo_trace
+
+
+class TraceTests(unittest.TestCase):
+  """Check Trace behavior."""
+
+  def testTrace_MaxSizeEnforced(self):
+    content = 'git chicken'
+
+    with repo_trace.Trace(content, first_trace=True):
+      pass
+    first_trace_size = os.path.getsize(repo_trace._TRACE_FILE)
+
+    with repo_trace.Trace(content):
+      pass
+    self.assertGreater(
+        os.path.getsize(repo_trace._TRACE_FILE), first_trace_size)
+
+    # Check we clear everything is the last chunk is larger than _MAX_SIZE.
+    with mock.patch('repo_trace._MAX_SIZE', 0):
+      with repo_trace.Trace(content, first_trace=True):
+        pass
+      self.assertEqual(first_trace_size,
+                       os.path.getsize(repo_trace._TRACE_FILE))
+
+      # Check we only clear the chunks we need to.
+      repo_trace._MAX_SIZE = (first_trace_size + 1) / (1024 * 1024)
+      with repo_trace.Trace(content, first_trace=True):
+        pass
+      self.assertEqual(first_trace_size * 2,
+                       os.path.getsize(repo_trace._TRACE_FILE))
+
+      with repo_trace.Trace(content, first_trace=True):
+        pass
+      self.assertEqual(first_trace_size * 2,
+                       os.path.getsize(repo_trace._TRACE_FILE))