Deal with DataSet[AggregatedAuditEvent] in Auditlog

Previous this change, the auditlog ETL job was exposing transformed data
as dataframe, this changes modifies that interface to deal with
Dataset[AggregatedAuditEvent] instead.

An AggregatedAuditEvent represents the result of spark transformations
and aggregations before being indexed in elasticsearch.

Change-Id: Ic3e8b8aa1b9e4294b3500de1d98e66623188f5d1
diff --git a/auditlog/src/main/scala/com/gerritforge/analytics/auditlog/model/AggregatedAuditEvent.scala b/auditlog/src/main/scala/com/gerritforge/analytics/auditlog/model/AggregatedAuditEvent.scala
new file mode 100644
index 0000000..2fd3331
--- /dev/null
+++ b/auditlog/src/main/scala/com/gerritforge/analytics/auditlog/model/AggregatedAuditEvent.scala
@@ -0,0 +1,29 @@
+// Copyright (C) 2019 GerritForge Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.gerritforge.analytics.auditlog.model
+import java.sql.Timestamp
+
+case class AggregatedAuditEvent(
+  events_time_bucket: Timestamp,
+  audit_type: String,
+  user_identifier: Option[String],
+  user_type: Option[String],
+  access_path: Option[String],
+  command: String,
+  command_arguments: String,
+  project: Option[String],
+  result: String,
+  num_events: Long
+)
\ No newline at end of file
diff --git a/auditlog/src/main/scala/com/gerritforge/analytics/auditlog/spark/AuditLogsTransformer.scala b/auditlog/src/main/scala/com/gerritforge/analytics/auditlog/spark/AuditLogsTransformer.scala
index 3fef687..9710149 100644
--- a/auditlog/src/main/scala/com/gerritforge/analytics/auditlog/spark/AuditLogsTransformer.scala
+++ b/auditlog/src/main/scala/com/gerritforge/analytics/auditlog/spark/AuditLogsTransformer.scala
@@ -15,13 +15,13 @@
 package com.gerritforge.analytics.auditlog.spark
 
 import com.gerritforge.analytics.auditlog.broadcast.{AdditionalUsersInfo, GerritProjects, GerritUserIdentifiers}
-import com.gerritforge.analytics.auditlog.model.AuditEvent
+import com.gerritforge.analytics.auditlog.model.{AggregatedAuditEvent, AuditEvent}
 import com.gerritforge.analytics.auditlog.model.ElasticSearchFields._
 import com.gerritforge.analytics.auditlog.range.TimeRange
 import com.gerritforge.analytics.auditlog.spark.dataframe.ops.DataFrameOps._
 import com.gerritforge.analytics.auditlog.spark.rdd.ops.SparkRDDOps._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.{Dataset, SparkSession}
 
 case class AuditLogsTransformer(
   gerritIdentifiers: GerritUserIdentifiers = GerritUserIdentifiers.empty,
@@ -33,7 +33,8 @@
   private val broadcastAdditionalUsersInfo = spark.sparkContext.broadcast(additionalUsersInfo)
   private val broadcastGerritProjects = spark.sparkContext.broadcast(gerritProjects)
 
-  def transform(auditEventsRDD: RDD[AuditEvent], timeAggregation: String, timeRange: TimeRange = TimeRange.always): DataFrame =
+  def transform(auditEventsRDD: RDD[AuditEvent], timeAggregation: String, timeRange: TimeRange = TimeRange.always): Dataset[AggregatedAuditEvent] = {
+    import spark.implicits._
     auditEventsRDD
       .filterWithinRange(TimeRange(timeRange.since, timeRange.until))
       .toJsonString
@@ -44,4 +45,6 @@
       .withUserTypeColumn(USER_TYPE_FIELD, broadcastAdditionalUsersInfo.value)
       .withProjectColumn(PROJECT_FIELD, broadcastGerritProjects.value)
       .aggregateNumEventsColumn(NUM_EVENTS_FIELD, FACETING_FIELDS)
-}
\ No newline at end of file
+      .as[AggregatedAuditEvent]
+  }
+}
diff --git a/auditlog/src/main/scala/com/gerritforge/analytics/auditlog/spark/rdd/ops/SparkRDDOps.scala b/auditlog/src/main/scala/com/gerritforge/analytics/auditlog/spark/rdd/ops/SparkRDDOps.scala
index bfc466e..3887aca 100644
--- a/auditlog/src/main/scala/com/gerritforge/analytics/auditlog/spark/rdd/ops/SparkRDDOps.scala
+++ b/auditlog/src/main/scala/com/gerritforge/analytics/auditlog/spark/rdd/ops/SparkRDDOps.scala
@@ -15,11 +15,11 @@
 package com.gerritforge.analytics.auditlog.spark.rdd.ops
 
 import com.gerritforge.analytics.auditlog.broadcast.{AdditionalUsersInfo, GerritProjects, GerritUserIdentifiers}
-import com.gerritforge.analytics.auditlog.model.AuditEvent
+import com.gerritforge.analytics.auditlog.model.{AggregatedAuditEvent, AuditEvent}
 import com.gerritforge.analytics.auditlog.range.TimeRange
 import com.gerritforge.analytics.auditlog.spark.AuditLogsTransformer
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 
 object SparkRDDOps {
 
@@ -35,7 +35,7 @@
       gerritProjects: GerritProjects,
       timeAggregation: String,
       timeRange: TimeRange
-    )(implicit spark: SparkSession): DataFrame = {
+    )(implicit spark: SparkSession): Dataset[AggregatedAuditEvent] = {
 
       AuditLogsTransformer(gerritUserIdentifiers, additionalUsersInfo, gerritProjects)
         .transform(rdd, timeAggregation, timeRange)
diff --git a/auditlog/src/test/scala/com/gerritforge/analytics/auditlog/AuditLogsTransformerSpec.scala b/auditlog/src/test/scala/com/gerritforge/analytics/auditlog/AuditLogsTransformerSpec.scala
index 3b5fefc..79df0da 100644
--- a/auditlog/src/test/scala/com/gerritforge/analytics/auditlog/AuditLogsTransformerSpec.scala
+++ b/auditlog/src/test/scala/com/gerritforge/analytics/auditlog/AuditLogsTransformerSpec.scala
@@ -17,10 +17,9 @@
 
 import com.gerritforge.analytics.SparkTestSupport
 import com.gerritforge.analytics.auditlog.broadcast._
-import com.gerritforge.analytics.auditlog.model.{ElasticSearchFields, HttpAuditEvent, SshAuditEvent}
+import com.gerritforge.analytics.auditlog.model.{AggregatedAuditEvent, ElasticSearchFields, HttpAuditEvent, SshAuditEvent}
 import com.gerritforge.analytics.auditlog.spark.AuditLogsTransformer
 import com.gerritforge.analytics.support.ops.CommonTimeOperations._
-import org.apache.spark.sql.Row
 import org.scalatest.{FlatSpec, Matchers}
 
 class AuditLogsTransformerSpec extends FlatSpec with Matchers with SparkTestSupport with TestFixtures {
@@ -29,44 +28,42 @@
   it should "process an anonymous http audit entry" in {
     val events = Seq(anonymousHttpAuditEvent)
 
-    val dataFrame = AuditLogsTransformer().transform(sc.parallelize(events), timeAggregation="hour")
+    val aggregatedEventsDS = AuditLogsTransformer().transform(sc.parallelize(events), timeAggregation="hour")
 
-    dataFrame.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
+    aggregatedEventsDS.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
 
-    val expectedAggregatedCount = 1
-    dataFrame.collect should contain only Row(
-        AuditLogsTransformerSpec.epochMillisToNearestHour(timeAtStart),
-        HttpAuditEvent.auditType,
-        null, // no user identifier
-        null, // no user type
-        anonymousHttpAuditEvent.accessPath.get,
-        GIT_UPLOAD_PACK,
-        anonymousHttpAuditEvent.what,
-        null, // no project
-        anonymousHttpAuditEvent.result,
-        expectedAggregatedCount
+    aggregatedEventsDS.collect should contain only AggregatedAuditEvent(
+      AuditLogsTransformerSpec.epochMillisToNearestHour(timeAtStart),
+      HttpAuditEvent.auditType,
+      None,
+      None,
+      anonymousHttpAuditEvent.accessPath,
+      GIT_UPLOAD_PACK,
+      anonymousHttpAuditEvent.what,
+      None,
+      anonymousHttpAuditEvent.result,
+      num_events = 1
     )
   }
 
   it should "process an authenticated http audit entry where gerrit account couldn't be identified" in {
     val events = Seq(authenticatedHttpAuditEvent)
 
-    val dataFrame = AuditLogsTransformer().transform(sc.parallelize(events), timeAggregation="hour")
+    val aggregatedEventsDS = AuditLogsTransformer().transform(sc.parallelize(events), timeAggregation="hour")
 
-    dataFrame.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
+    aggregatedEventsDS.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
 
-    val expectedAggregatedCount = 1
-    dataFrame.collect should contain only Row(
+    aggregatedEventsDS.collect should contain only AggregatedAuditEvent(
       AuditLogsTransformerSpec.epochMillisToNearestHour(timeAtStart),
       HttpAuditEvent.auditType,
-      s"${authenticatedHttpAuditEvent.who.get}",
-      AdditionalUserInfo.DEFAULT_USER_TYPE,
-      authenticatedHttpAuditEvent.accessPath.get,
+      authenticatedHttpAuditEvent.who.map(_.toString),
+      Some(AdditionalUserInfo.DEFAULT_USER_TYPE),
+      anonymousHttpAuditEvent.accessPath,
       GIT_UPLOAD_PACK,
-      authenticatedHttpAuditEvent.what,
-      null, // no project
-      authenticatedHttpAuditEvent.result,
-      expectedAggregatedCount
+      anonymousHttpAuditEvent.what,
+      None,
+      anonymousHttpAuditEvent.result,
+      num_events = 1
     )
   }
 
@@ -74,68 +71,65 @@
     val events = Seq(authenticatedHttpAuditEvent)
     val gerritUserIdentifier = "Antonio Barone"
 
-    val dataFrame =
+    val aggregatedEventsDS =
       AuditLogsTransformer(GerritUserIdentifiers(Map(authenticatedHttpAuditEvent.who.get -> gerritUserIdentifier)))
         .transform(sc.parallelize(events), timeAggregation="hour")
 
-    dataFrame.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
+    aggregatedEventsDS.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
 
-    val expectedAggregatedCount = 1
-    dataFrame.collect should contain only Row(
+    aggregatedEventsDS.collect should contain only AggregatedAuditEvent(
       AuditLogsTransformerSpec.epochMillisToNearestHour(timeAtStart),
       HttpAuditEvent.auditType,
-      gerritUserIdentifier,
-      AdditionalUserInfo.DEFAULT_USER_TYPE,
-      authenticatedHttpAuditEvent.accessPath.get,
+      Some(gerritUserIdentifier),
+      Some(AdditionalUserInfo.DEFAULT_USER_TYPE),
+      anonymousHttpAuditEvent.accessPath,
       GIT_UPLOAD_PACK,
       authenticatedHttpAuditEvent.what,
-      null, // no project
+      None,
       authenticatedHttpAuditEvent.result,
-      expectedAggregatedCount
+      num_events = 1
     )
   }
 
   it should "process an SSH audit entry" in {
     val events = Seq(sshAuditEvent)
 
-    val dataFrame = AuditLogsTransformer().transform(sc.parallelize(events), timeAggregation="hour")
+    val aggregatedEventsDS = AuditLogsTransformer().transform(sc.parallelize(events), timeAggregation="hour")
 
-    dataFrame.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
+    aggregatedEventsDS.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
 
-    val expectedAggregatedCount = 1
-    dataFrame.collect should contain only Row(
+    aggregatedEventsDS.collect should contain only AggregatedAuditEvent(
       AuditLogsTransformerSpec.epochMillisToNearestHour(timeAtStart),
       SshAuditEvent.auditType,
-      s"${sshAuditEvent.who.get}",
-      AdditionalUserInfo.DEFAULT_USER_TYPE,
-      sshAuditEvent.accessPath.get,
+      sshAuditEvent.who.map(_.toString),
+      Some(AdditionalUserInfo.DEFAULT_USER_TYPE),
+      sshAuditEvent.accessPath,
       SSH_GERRIT_COMMAND,
       SSH_GERRIT_COMMAND_ARGUMENTS,
-      null, // no project
+      None,
       sshAuditEvent.result,
-      expectedAggregatedCount
+      num_events = 1
     )
   }
 
   it should "group ssh events from the same user together, if they fall within the same time bucket (hour)" in {
     val events = Seq(sshAuditEvent, sshAuditEvent.copy(timeAtStart = timeAtStart + 1000))
 
-    val dataFrame = AuditLogsTransformer().transform(sc.parallelize(events), timeAggregation="hour")
+    val aggregatedEventsDS = AuditLogsTransformer().transform(sc.parallelize(events), timeAggregation="hour")
 
-    dataFrame.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
+    aggregatedEventsDS.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
 
-    val expectedAggregatedCount = 2
-    dataFrame.collect should contain only Row(
+    aggregatedEventsDS.collect should contain only AggregatedAuditEvent(
       AuditLogsTransformerSpec.epochMillisToNearestHour(timeAtStart),
       SshAuditEvent.auditType,
-      s"${sshAuditEvent.who.get}",
-      AdditionalUserInfo.DEFAULT_USER_TYPE,
-      sshAuditEvent.accessPath.get,
+      sshAuditEvent.who.map(_.toString),
+      Some(AdditionalUserInfo.DEFAULT_USER_TYPE),
+      sshAuditEvent.accessPath,
       SSH_GERRIT_COMMAND,
       SSH_GERRIT_COMMAND_ARGUMENTS,
-      null, // no project
+      None,
       sshAuditEvent.result,
-      expectedAggregatedCount
+      num_events = 2
     )
   }
 
@@ -143,35 +137,34 @@
     val user2Id = sshAuditEvent.who.map(_ + 1)
     val events = Seq(sshAuditEvent, sshAuditEvent.copy(who=user2Id))
 
-    val dataFrame = AuditLogsTransformer().transform(sc.parallelize(events), timeAggregation="hour")
+    val aggregatedEventsDS = AuditLogsTransformer().transform(sc.parallelize(events), timeAggregation="hour")
 
-    dataFrame.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
+    aggregatedEventsDS.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
 
-    val expectedAggregatedCount = 1
-    dataFrame.collect should contain allOf (
-      Row(
+    aggregatedEventsDS.collect should contain allOf (
+      AggregatedAuditEvent(
         AuditLogsTransformerSpec.epochMillisToNearestHour(timeAtStart),
         SshAuditEvent.auditType,
-        s"${sshAuditEvent.who.get}",
-        AdditionalUserInfo.DEFAULT_USER_TYPE,
-        sshAuditEvent.accessPath.get,
+        sshAuditEvent.who.map(_.toString),
+        Some(AdditionalUserInfo.DEFAULT_USER_TYPE),
+        sshAuditEvent.accessPath,
         SSH_GERRIT_COMMAND,
         SSH_GERRIT_COMMAND_ARGUMENTS,
-        null, // no project
+        None,
         sshAuditEvent.result,
-        expectedAggregatedCount
+        num_events = 1
       ),
-      Row(
+      AggregatedAuditEvent(
         AuditLogsTransformerSpec.epochMillisToNearestHour(timeAtStart),
         SshAuditEvent.auditType,
-        s"${user2Id.get}",
-        AdditionalUserInfo.DEFAULT_USER_TYPE,
-        sshAuditEvent.accessPath.get,
+        user2Id.map(_.toString),
+        Some(AdditionalUserInfo.DEFAULT_USER_TYPE),
+        sshAuditEvent.accessPath,
         SSH_GERRIT_COMMAND,
         SSH_GERRIT_COMMAND_ARGUMENTS,
-        null, // no project
+        None,
         sshAuditEvent.result,
-        expectedAggregatedCount
+        num_events = 1
       )
     )
   }
@@ -179,36 +172,34 @@
   it should "group different event types separately, event if they fall within the same time bucket (hour)" in {
     val events = Seq(sshAuditEvent, authenticatedHttpAuditEvent.copy(timeAtStart = sshAuditEvent.timeAtStart + 1000))
 
-    val dataFrame = AuditLogsTransformer().transform(sc.parallelize(events), timeAggregation="hour")
+    val aggregatedEventsDS = AuditLogsTransformer().transform(sc.parallelize(events), timeAggregation="hour")
 
-    dataFrame.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
+    aggregatedEventsDS.columns should contain theSameElementsAs ElasticSearchFields.ALL_DOCUMENT_FIELDS
 
-    val expectedSshAggregatedCount = 1
-    val expectedHttpAggregatedCount = 1
-    dataFrame.collect should contain allOf (
-      Row(
+    aggregatedEventsDS.collect should contain allOf (
+      AggregatedAuditEvent(
         AuditLogsTransformerSpec.epochMillisToNearestHour(events.head.timeAtStart),
         SshAuditEvent.auditType,
-        s"${sshAuditEvent.who.get}",
-        AdditionalUserInfo.DEFAULT_USER_TYPE,
-        sshAuditEvent.accessPath.get,
+        sshAuditEvent.who.map(_.toString),
+        Some(AdditionalUserInfo.DEFAULT_USER_TYPE),
+        sshAuditEvent.accessPath,
         SSH_GERRIT_COMMAND,
         SSH_GERRIT_COMMAND_ARGUMENTS,
-        null, // no project
+        None,
         sshAuditEvent.result,
-        expectedSshAggregatedCount
+        num_events = 1
       ),
-      Row(
+      AggregatedAuditEvent(
         AuditLogsTransformerSpec.epochMillisToNearestHour(events.last.timeAtStart),
         HttpAuditEvent.auditType,
-        s"${authenticatedHttpAuditEvent.who.get}",
-        AdditionalUserInfo.DEFAULT_USER_TYPE,
-        authenticatedHttpAuditEvent.accessPath.get,
+        authenticatedHttpAuditEvent.who.map(_.toString),
+        Some(AdditionalUserInfo.DEFAULT_USER_TYPE),
+        authenticatedHttpAuditEvent.accessPath,
         GIT_UPLOAD_PACK,
         authenticatedHttpAuditEvent.what,
-        null, // no project
+        None,
         authenticatedHttpAuditEvent.result,
-        expectedHttpAggregatedCount
+        num_events = 1
       )
     )
   }
@@ -219,12 +210,12 @@
     val userType = "nonDefaultUserType"
     val additionalUserInfo = AdditionalUserInfo(authenticatedHttpAuditEvent.who.get, userType)
 
-    val dataFrame = AuditLogsTransformer(additionalUsersInfo = AdditionalUsersInfo(Map(authenticatedHttpAuditEvent.who.get -> additionalUserInfo))).transform(
+    val aggregatedEventsDS = AuditLogsTransformer(additionalUsersInfo = AdditionalUsersInfo(Map(authenticatedHttpAuditEvent.who.get -> additionalUserInfo))).transform(
         auditEventsRDD        = sc.parallelize(events),
         timeAggregation       = "hour"
     )
-    dataFrame.collect.length shouldBe 1
-    dataFrame.collect.head.getAs[String](ElasticSearchFields.USER_TYPE_FIELD) shouldBe userType
+    aggregatedEventsDS.collect.length shouldBe 1
+    aggregatedEventsDS.collect.head.user_type should contain(userType)
   }
 
   it should "process user type when gerrit account could be identified" in {
@@ -234,7 +225,7 @@
     val userType = "nonDefaultUserType"
     val additionalUserInfo = AdditionalUserInfo(authenticatedHttpAuditEvent.who.get, userType)
 
-    val dataFrame =
+    val aggregatedEventsDS =
       AuditLogsTransformer(
         gerritIdentifiers = GerritUserIdentifiers(Map(authenticatedHttpAuditEvent.who.get -> gerritUserIdentifier)),
         additionalUsersInfo = AdditionalUsersInfo(Map(authenticatedHttpAuditEvent.who.get -> additionalUserInfo))
@@ -243,28 +234,28 @@
           timeAggregation       = "hour"
       )
 
-    dataFrame.collect.length shouldBe 1
-    dataFrame.collect.head.getAs[String](ElasticSearchFields.USER_TYPE_FIELD) shouldBe userType
+    aggregatedEventsDS.collect.length shouldBe 1
+    aggregatedEventsDS.collect.head.user_type should contain(userType)
   }
 
   it should "extract gerrit project from an http event" in {
     val events = Seq(authenticatedHttpAuditEvent)
 
-    val dataFrame = AuditLogsTransformer(gerritProjects = GerritProjects(Map(project -> GerritProject(project))))
+    val aggregatedEventsDS = AuditLogsTransformer(gerritProjects = GerritProjects(Map(project -> GerritProject(project))))
       .transform(sc.parallelize(events), timeAggregation="hour")
 
-    dataFrame.collect.length shouldBe 1
-    dataFrame.collect.head.getAs[String](ElasticSearchFields.PROJECT_FIELD) shouldBe project
+    aggregatedEventsDS.collect.length shouldBe 1
+    aggregatedEventsDS.collect.head.project should contain(project)
   }
 
   it should "extract gerrit project from an ssh event" in {
     val events = Seq(sshAuditEvent)
 
-    val dataFrame = AuditLogsTransformer(gerritProjects = GerritProjects(Map(project -> GerritProject(project))))
+    val aggregatedEventsDS = AuditLogsTransformer(gerritProjects = GerritProjects(Map(project -> GerritProject(project))))
       .transform(sc.parallelize(events), timeAggregation="hour")
 
-    dataFrame.collect.length shouldBe 1
-    dataFrame.collect.head.getAs[String](ElasticSearchFields.PROJECT_FIELD) shouldBe project
+    aggregatedEventsDS.collect.length shouldBe 1
+    aggregatedEventsDS.collect.head.project should contain(project)
   }
 }