repo_trace: Avoid race conditions with trace_file updating.
Change-Id: I0bc1bb3c8f60465dc6bee5081688a9f163dd8cf8
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/354515
Reviewed-by: Mike Frysinger <vapier@google.com>
Tested-by: Joanna Wang <jojwang@google.com>
diff --git a/repo_trace.py b/repo_trace.py
index d79408d..1ba86c7 100644
--- a/repo_trace.py
+++ b/repo_trace.py
@@ -23,6 +23,7 @@
import sys
import os
import time
+import tempfile
from contextlib import ContextDecorator
import platform_utils
@@ -35,34 +36,40 @@
_TRACE_TO_STDERR = False
_TRACE_FILE = None
_TRACE_FILE_NAME = 'TRACE_FILE'
-_MAX_SIZE = 70 # in mb
+_MAX_SIZE = 70 # in MiB
_NEW_COMMAND_SEP = '+++++++++++++++NEW COMMAND+++++++++++++++++++'
def IsTraceToStderr():
+ """Whether traces are written to stderr."""
return _TRACE_TO_STDERR
def IsTrace():
+ """Whether tracing is enabled."""
return _TRACE
def SetTraceToStderr():
+ """Enables tracing logging to stderr."""
global _TRACE_TO_STDERR
_TRACE_TO_STDERR = True
def SetTrace():
+ """Enables tracing."""
global _TRACE
_TRACE = True
def _SetTraceFile(quiet):
+ """Sets the trace file location."""
global _TRACE_FILE
_TRACE_FILE = _GetTraceFile(quiet)
class Trace(ContextDecorator):
+ """Used to capture and save git traces."""
def _time(self):
"""Generate nanoseconds of time in a py3.6 safe way"""
@@ -128,20 +135,32 @@
def _ClearOldTraces():
- """Clear the oldest commands if trace file is too big.
+ """Clear the oldest commands if trace file is too big."""
+ try:
+ with open(_TRACE_FILE, 'r', errors='ignore') as f:
+ if os.path.getsize(f.name) / (1024 * 1024) <= _MAX_SIZE:
+ return
+ trace_lines = f.readlines()
+ except FileNotFoundError:
+ return
- Note: If the trace file contains output from two `repo`
- commands that were running at the same time, this
- will not work precisely.
- """
- if os.path.isfile(_TRACE_FILE):
- while os.path.getsize(_TRACE_FILE) / (1024 * 1024) > _MAX_SIZE:
- temp_file = _TRACE_FILE + '.tmp'
- with open(_TRACE_FILE, 'r', errors='ignore') as fin:
- with open(temp_file, 'w') as tf:
- trace_lines = fin.readlines()
- for i, l in enumerate(trace_lines):
- if 'END:' in l and _NEW_COMMAND_SEP in l:
- tf.writelines(trace_lines[i + 1:])
- break
- platform_utils.rename(temp_file, _TRACE_FILE)
+ while sum(len(x) for x in trace_lines) / (1024 * 1024) > _MAX_SIZE:
+ for i, line in enumerate(trace_lines):
+ if 'END:' in line and _NEW_COMMAND_SEP in line:
+ trace_lines = trace_lines[i + 1:]
+ break
+ else:
+ # The last chunk is bigger than _MAX_SIZE, so just throw everything away.
+ trace_lines = []
+
+ while trace_lines and trace_lines[-1] == '\n':
+ trace_lines = trace_lines[:-1]
+ # Write to a temporary file with a unique name in the same filesystem
+ # before replacing the original trace file.
+ temp_dir, temp_prefix = os.path.split(_TRACE_FILE)
+ with tempfile.NamedTemporaryFile('w',
+ dir=temp_dir,
+ prefix=temp_prefix,
+ delete=False) as f:
+ f.writelines(trace_lines)
+ platform_utils.rename(f.name, _TRACE_FILE)
diff --git a/tests/test_repo_trace.py b/tests/test_repo_trace.py
new file mode 100644
index 0000000..5faf293
--- /dev/null
+++ b/tests/test_repo_trace.py
@@ -0,0 +1,56 @@
+# Copyright 2022 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unittests for the repo_trace.py module."""
+
+import os
+import unittest
+from unittest import mock
+
+import repo_trace
+
+
+class TraceTests(unittest.TestCase):
+ """Check Trace behavior."""
+
+ def testTrace_MaxSizeEnforced(self):
+ content = 'git chicken'
+
+ with repo_trace.Trace(content, first_trace=True):
+ pass
+ first_trace_size = os.path.getsize(repo_trace._TRACE_FILE)
+
+ with repo_trace.Trace(content):
+ pass
+ self.assertGreater(
+ os.path.getsize(repo_trace._TRACE_FILE), first_trace_size)
+
+ # Check we clear everything is the last chunk is larger than _MAX_SIZE.
+ with mock.patch('repo_trace._MAX_SIZE', 0):
+ with repo_trace.Trace(content, first_trace=True):
+ pass
+ self.assertEqual(first_trace_size,
+ os.path.getsize(repo_trace._TRACE_FILE))
+
+ # Check we only clear the chunks we need to.
+ repo_trace._MAX_SIZE = (first_trace_size + 1) / (1024 * 1024)
+ with repo_trace.Trace(content, first_trace=True):
+ pass
+ self.assertEqual(first_trace_size * 2,
+ os.path.getsize(repo_trace._TRACE_FILE))
+
+ with repo_trace.Trace(content, first_trace=True):
+ pass
+ self.assertEqual(first_trace_size * 2,
+ os.path.getsize(repo_trace._TRACE_FILE))