Handling Data Skew with Spark DataFrame

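Crawlers and similar traffic can produce an excessive number of log records for a single ID. In Spark, key-based operations such as groupBy or join shuffle all records with the same ID into a single partition, so the one task that processes that partition runs far longer than the rest and slows down the whole job.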
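To make the skew concrete, here is a minimal plain-Scala sketch (no Spark required) of why a hot ID ends up in one place. It assigns each record to a partition by hashing its key, which is the same general idea as a hash partitioner but not Spark's exact hash function. The names `SkewDemo`, `partitionOf`, and `partitionSizes` are hypothetical and exist only for this illustration.

```scala
// Sketch only: simulates hash partitioning with plain collections.
// SkewDemo, partitionOf and partitionSizes are hypothetical names.
object SkewDemo {
  // Non-negative modulo of the key's hashCode.
  // Every record with the same key gets the same partition index.
  def partitionOf(key: String, numPartitions: Int): Int = {
    val raw = key.hashCode % numPartitions
    if (raw < 0) raw + numPartitions else raw
  }

  // Number of records that land in each non-empty partition.
  def partitionSizes(keys: Seq[String], numPartitions: Int): Map[Int, Int] =
    keys.groupBy(partitionOf(_, numPartitions)).map { case (p, ks) => (p, ks.size) }

  def main(args: Array[String]): Unit = {
    // Same ID distribution as the sample data used in this post.
    val ids = Seq("u1", "u2", "u2") ++
      Seq.fill(6)("FakeUserA") ++
      Seq.fill(4)("FakeUserB")
    // All 6 FakeUserA records share one partition, however many partitions exist.
    println(partitionSizes(ids, 4))
  }
}
```

Adding more partitions does not help here: the records of a hot ID cannot be split across partitions by key, so the only remedies are to change the key or to remove the hot IDs.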

Because the traffic from these IDs is invalid anyway, these users can simply be filtered out.

Without further ado, here are the output and the code of the Scala DataFrame version (reference links are in the code comments):

spark version: 1.6.1

Original DataFrame (with fake users):
+---------+------+
|       id| movie|
+---------+------+
|       u1|WhoAmI|
|       u2|Zoppia|
|       u2|  Lost|
|FakeUserA|Zoppia|
|FakeUserA|  Lost|
|FakeUserA|Zoppia|
|FakeUserA|  Lost|
|FakeUserA|Zoppia|
|FakeUserA|  Lost|
|FakeUserB|  Lost|
|FakeUserB|  Lost|
|FakeUserB|  Lost|
|FakeUserB|  Lost|
+---------+------+

Fake Users with count (threshold=2):
+---------+-----+
|       id|count|
+---------+-----+
|FakeUserA|    6|
|FakeUserB|    4|
+---------+-----+

Fake Users:
Set(FakeUserA, FakeUserB)

Valid users after filter:
+---+------+
| id| movie|
+---+------+
| u1|WhoAmI|
| u2|Zoppia|
| u2|  Lost|
+---+------+

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions._

/**
  * Created by colinliang on 2017/8/14.
  */
case class IDMovie(id: String, movie: String)
object BroadcastTest {
  def main(args: Array[String]): Unit = {
    Logger.getRootLogger().setLevel(Level.FATAL) //http://stackoverflow.com/questions/27781187/how-to-stop-messages-displaying-on-spark-console
    val conf = new SparkConf().setAppName("word count").setMaster("local[1]")
    val sc = new SparkContext(conf)
    println("spark version: " + sc.version)
    sc.setLogLevel("WARN") //http://stackoverflow.com/questions/27781187/how-to-stop-messages-displaying-on-spark-console
    val spark = new SQLContext(sc)



    val idvids = List(
      IDMovie("u1", "WhoAmI")
      , IDMovie("u2", "Zoppia")
      , IDMovie("u2", "Lost")
      , IDMovie("FakeUserA", "Zoppia")
      , IDMovie("FakeUserA", "Lost")
      , IDMovie("FakeUserA", "Zoppia")
      , IDMovie("FakeUserA", "Lost")
      , IDMovie("FakeUserA", "Zoppia")
      , IDMovie("FakeUserA", "Lost")
      , IDMovie("FakeUserB", "Lost")
      , IDMovie("FakeUserB", "Lost")
      , IDMovie("FakeUserB", "Lost")
      , IDMovie("FakeUserB", "Lost")
      );


    val df = spark
      .createDataFrame(idvids)
      .repartition(col("id"))

    println("Original DataFrame (with fake users): ")
    df.show()

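//    val df_fakeUsers_with_count=df.sample(false,0.1).groupBy(col("id")).count().filter(col("count")>2).limit(10000) // In practice you can count on a sample of the data instead; the threshold then applies to the sampled counts
    val df_fakeUsers_with_count=df.groupBy(col("id")).count().filter(col("count")>2)
    /** groupBy followed by count() on a DataFrame runs as an aggregation: Spark pre-aggregates within each partition before the shuffle,
      * so at most one partial count per ID per partition is shuffled instead of every raw row, and this step stays fast even for hot IDs.
      * See: https://forums.databricks.com/questions/956/how-do-i-group-my-dataset-by-a-key-or-combination.html
      * More aggregate functions: https://spark.apache.org/docs/1.6.1/api/scala/index.html#org.apache.spark.sql.functions$
      * You can also use agg() to aggregate several columns of the grouped data.
      */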
    println("Fake Users with count (threshold=2):")
    df_fakeUsers_with_count.show()


    val set_fakeUsers=df_fakeUsers_with_count.select("id").collect().map(_.getString(0)).toSet // collect the hot IDs to the driver as a Set[String]
    println("Fake Users:")
    println(set_fakeUsers)
    val set_fakeUsers_broadcast=sc.broadcast(set_fakeUsers)
    /** Broadcast tutorial: https://jaceklaskowski.gitbooks.io/mastering-apache-spark/spark-broadcast.html
      * Official documentation: http://spark.apache.org/docs/latest/rdd-programming-guide.html#broadcast-variables
      */

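    // Calling set_fakeUsers.contains(id) directly inside the UDF also works, but is less efficient: the set is then serialized into the closure
    // and shipped with every task, whereas a broadcast variable is sent to each executor once and cached there.
    // See http://spark.apache.org/docs/latest/rdd-programming-guide.html#broadcast-variables
    val udf_isValidUser = udf((id: String) => !set_fakeUsers_broadcast.value.contains(id))
    val df_filtered=df.filter(udf_isValidUser(col("id")) ) // filter out the fake users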
    /** If you want to keep a set of users rather than filter them out, and the set is small, no UDF is needed:
      * https://stackoverflow.com/questions/39234360/filter-spark-scala-dataframe-if-column-is-present-in-set
      * val validValues = Set("A", "B", "C")
      * data.filter($"myColumn".isin(validValues.toSeq: _*))
      */
    /** If you want to keep a set of users and the set is fairly large, you can join with a broadcast DataFrame:
      * https://stackoverflow.com/questions/33824933/spark-dataframe-filtering-retain-element-belonging-to-a-list
      * import org.apache.spark.sql.functions.broadcast
      * initialDataFrame.join(broadcast(usersToKeep), $"userID" === $"userID_")
      */
    println("\nValid users after filter:")
    df_filtered.show()
  }
}
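The count-then-filter logic of the program above can also be checked without a Spark installation. The following is a minimal sketch using plain Scala collections, assuming the same sample data and the same rule (drop every ID whose record count is strictly greater than the threshold). `FilterHotIds` and `dropHotIds` are hypothetical names introduced for this illustration and are not part of the original program.

```scala
// Sketch only: the same pipeline as the Spark job, on plain collections.
// FilterHotIds and dropHotIds are hypothetical names.
object FilterHotIds {
  case class IDMovie(id: String, movie: String)

  // Returns (IDs whose record count exceeds `threshold`, records of all other IDs).
  def dropHotIds(records: Seq[IDMovie], threshold: Int): (Set[String], Seq[IDMovie]) = {
    // Step 1: count records per ID (groupBy + count in the Spark version).
    val counts: Map[String, Int] =
      records.groupBy(_.id).map { case (id, rs) => (id, rs.size) }
    // Step 2: keep only the IDs above the threshold (filter on count).
    val hotIds: Set[String] =
      counts.collect { case (id, n) if n > threshold => id }.toSet
    // Step 3: drop every record whose ID is in the hot set (the UDF filter).
    (hotIds, records.filterNot(r => hotIds.contains(r.id)))
  }

  def main(args: Array[String]): Unit = {
    val records =
      Seq(IDMovie("u1", "WhoAmI"), IDMovie("u2", "Zoppia"), IDMovie("u2", "Lost")) ++
        Seq.fill(6)(IDMovie("FakeUserA", "Lost")) ++
        Seq.fill(4)(IDMovie("FakeUserB", "Lost"))
    val (hot, kept) = dropHotIds(records, threshold = 2)
    println(hot)
    kept.foreach(println)
  }
}
```

Note that u2 has exactly 2 records and is kept, because the comparison is strict, matching `col("count")>2` in the Spark code.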
