trace: allow writing traces to a socket

Git can write trace2 events to a Unix domain socket [1]. This can be
specified via Git's `trace2.eventTarget` config option, which we read to
determine where to log our own trace2 events. Currently, if the Git
config specifies a socket as the trace2 target, we fail to log any
traces.

Fix this by adding support for writing to a Unix domain socket,
following the same specification that Git supports.

[1]: https://git-scm.com/docs/api-trace2#_enabling_a_target

Change-Id: I928bc22ba04fba603a9132eb055141845fa48ab2
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/332339
Reviewed-by: Raman Tenneti <rtenneti@google.com>
Reviewed-by: Mike Frysinger <vapier@google.com>
Tested-by: Josh Steadmon <steadmon@google.com>
diff --git a/git_trace2_event_log.py b/git_trace2_event_log.py
index 0e5e908..7426aba 100644
--- a/git_trace2_event_log.py
+++ b/git_trace2_event_log.py
@@ -29,8 +29,10 @@
 
 
 import datetime
+import errno
 import json
 import os
+import socket
 import sys
 import tempfile
 import threading
@@ -218,20 +220,39 @@
           retval, p.stderr), file=sys.stderr)
     return path
 
+  def _WriteLog(self, write_fn):
+    """Writes the log out using a provided writer function.
+
+    Generate compact JSON output for each item in the log, and write it using
+    write_fn.
+
+    Args:
+      write_fn: A function that accepts byts and writes them to a destination.
+    """
+
+    for e in self._log:
+      # Dump in compact encoding mode.
+      # See 'Compact encoding' in Python docs:
+      # https://docs.python.org/3/library/json.html#module-json
+      write_fn(json.dumps(e, indent=None, separators=(',', ':')).encode('utf-8') + b'\n')
+
   def Write(self, path=None):
-    """Writes the log out to a file.
+    """Writes the log out to a file or socket.
 
     Log is only written if 'path' or 'git config --get trace2.eventtarget'
-    provide a valid path to write logs to.
+    provide a valid path (or socket) to write logs to.
 
     Logging filename format follows the git trace2 style of being a unique
     (exclusive writable) file.
 
     Args:
-      path: Path to where logs should be written.
+      path: Path to where logs should be written. The path may have a prefix of
+          the form "af_unix:[{stream|dgram}:]", in which case the path is
+          treated as a Unix domain socket. See
+          https://git-scm.com/docs/api-trace2#_enabling_a_target for details.
 
     Returns:
-      log_path: Path to the log file if log is written, otherwise None
+      log_path: Path to the log file or socket if log is written, otherwise None
     """
     log_path = None
     # If no logging path is specified, get the path from 'trace2.eventtarget'.
@@ -242,29 +263,66 @@
     if path is None:
       return None
 
+    path_is_socket = False
+    socket_type = None
     if isinstance(path, str):
-      # Get absolute path.
-      path = os.path.abspath(os.path.expanduser(path))
+      parts = path.split(':', 1)
+      if parts[0] == 'af_unix' and len(parts) == 2:
+        path_is_socket = True
+        path = parts[1]
+        parts = path.split(':', 1)
+        if parts[0] == 'stream' and len(parts) == 2:
+          socket_type = socket.SOCK_STREAM
+          path = parts[1]
+        elif parts[0] == 'dgram' and len(parts) == 2:
+          socket_type = socket.SOCK_DGRAM
+          path = parts[1]
+      else:
+        # Get absolute path.
+        path = os.path.abspath(os.path.expanduser(path))
     else:
       raise TypeError('path: str required but got %s.' % type(path))
 
     # Git trace2 requires a directory to write log to.
 
     # TODO(https://crbug.com/gerrit/13706): Support file (append) mode also.
-    if not os.path.isdir(path):
+    if not (path_is_socket or os.path.isdir(path)):
       return None
+
+    if path_is_socket:
+      if socket_type == socket.SOCK_STREAM or socket_type is None:
+        try:
+          with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
+            sock.connect(path)
+            self._WriteLog(sock.sendall)
+          return f'af_unix:stream:{path}'
+        except OSError as err:
+          # If we tried to connect to a DGRAM socket using STREAM, ignore the
+          # attempt and continue to DGRAM below. Otherwise, issue a warning.
+          if err.errno != errno.EPROTOTYPE:
+            print(f'repo: warning: git trace2 logging failed: {err}', file=sys.stderr)
+            return None
+      if socket_type == socket.SOCK_DGRAM or socket_type is None:
+        try:
+          with socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as sock:
+            self._WriteLog(lambda bs: sock.sendto(bs, path))
+            return f'af_unix:dgram:{path}'
+        except OSError as err:
+          print(f'repo: warning: git trace2 logging failed: {err}', file=sys.stderr)
+          return None
+      # Tried to open a socket but couldn't connect (SOCK_STREAM) or write
+      # (SOCK_DGRAM).
+      print('repo: warning: git trace2 logging failed: could not write to socket', file=sys.stderr)
+      return None
+
+    # Path is an absolute path
     # Use NamedTemporaryFile to generate a unique filename as required by git trace2.
     try:
-      with tempfile.NamedTemporaryFile(mode='x', prefix=self._sid, dir=path,
+      with tempfile.NamedTemporaryFile(mode='xb', prefix=self._sid, dir=path,
                                        delete=False) as f:
         # TODO(https://crbug.com/gerrit/13706): Support writing events as they
         # occur.
-        for e in self._log:
-          # Dump in compact encoding mode.
-          # See 'Compact encoding' in Python docs:
-          # https://docs.python.org/3/library/json.html#module-json
-          json.dump(e, f, indent=None, separators=(',', ':'))
-          f.write('\n')
+        self._WriteLog(f.write)
         log_path = f.name
     except FileExistsError as err:
       print('repo: warning: git trace2 logging failed: %r' % err,
diff --git a/tests/test_git_trace2_event_log.py b/tests/test_git_trace2_event_log.py
index 89dcfb9..0623d32 100644
--- a/tests/test_git_trace2_event_log.py
+++ b/tests/test_git_trace2_event_log.py
@@ -16,11 +16,42 @@
 
 import json
 import os
+import socket
 import tempfile
+import threading
 import unittest
 from unittest import mock
 
 import git_trace2_event_log
+import platform_utils
+
+
+def serverLoggingThread(socket_path, server_ready, received_traces):
+  """Helper function to receive logs over a Unix domain socket.
+
+  Appends received messages on the provided socket and appends to received_traces.
+
+  Args:
+    socket_path: path to a Unix domain socket on which to listen for traces
+    server_ready: a threading.Condition used to signal to the caller that this thread is ready to
+        accept connections
+    received_traces: a list to which received traces will be appended (after decoding to a utf-8
+        string).
+  """
+  platform_utils.remove(socket_path, missing_ok=True)
+  data = b''
+  with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
+    sock.bind(socket_path)
+    sock.listen(0)
+    with server_ready:
+      server_ready.notify()
+    with sock.accept()[0] as conn:
+      while True:
+        recved = conn.recv(4096)
+        if not recved:
+          break
+        data += recved
+  received_traces.extend(data.decode('utf-8').splitlines())
 
 
 class EventLogTestCase(unittest.TestCase):
@@ -324,6 +355,37 @@
     with self.assertRaises(TypeError):
       self._event_log_module.Write(path=1234)
 
+  def test_write_socket(self):
+    """Test Write() with Unix domain socket for |path| and validate received traces."""
+    received_traces = []
+    with tempfile.TemporaryDirectory(prefix='test_server_sockets') as tempdir:
+      socket_path = os.path.join(tempdir, "server.sock")
+      server_ready = threading.Condition()
+      # Start "server" listening on Unix domain socket at socket_path.
+      try:
+        server_thread = threading.Thread(
+            target=serverLoggingThread,
+            args=(socket_path, server_ready, received_traces))
+        server_thread.start()
+
+        with server_ready:
+          server_ready.wait()
+
+        self._event_log_module.StartEvent()
+        path = self._event_log_module.Write(path=f'af_unix:{socket_path}')
+      finally:
+        server_thread.join(timeout=5)
+
+    self.assertEqual(path, f'af_unix:stream:{socket_path}')
+    self.assertEqual(len(received_traces), 2)
+    version_event = json.loads(received_traces[0])
+    start_event = json.loads(received_traces[1])
+    self.verifyCommonKeys(version_event, expected_event_name='version')
+    self.verifyCommonKeys(start_event, expected_event_name='start')
+    # Check for 'start' event specific fields.
+    self.assertIn('argv', start_event)
+    self.assertIsInstance(start_event['argv'], list)
+
 
 if __name__ == '__main__':
   unittest.main()