Scalaコードをリファクタリングして、よりエレガントで慣用的なScalaにする方法についてのアドバイスを探しています。
機能があります
def joinDataFramesOnColumns(joinColumns: Seq[String]) : org.apache.spark.sql.DataFrame
でSeq[org.apache.spark.sql.DataFrame]
それらを結合することにより、で動作しますjoinColumns
。関数の定義は次のとおりです。
implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]){
def joinDataFramesOnColumns(joinColumns: Seq[String]) : DataFrame = {
val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
if (nonEmptyDataFrames.isEmpty){
emptyDataFrame
}
else {
if (joinColumns.isEmpty) {
return nonEmptyDataFrames.reduce(_.crossJoin(_))
}
nonEmptyDataFrames.reduce(_.join(_, joinColumns))
}
}
}
すべて成功するユニットテストがいくつかあります。
class FeatureGeneratorDataFrameExtensionsTest extends WordSpec {
val fruitValues = Seq(
Row(0, "BasketA", "Bananas", "Jack"),
Row(2, "BasketB", "Oranges", "Jack"),
Row(2, "BasketC", "Oranges", "Jill"),
Row(3, "BasketD", "Oranges", "Jack"),
Row(4, "BasketE", "Oranges", "Jack"),
Row(4, "BasketE", "Apples", "Jack"),
Row(4, "BasketF", "Bananas", "Jill")
)
val schema = List(
StructField("weeksPrior", IntegerType, true),
StructField("basket", StringType, true),
StructField("Product", StringType, true),
StructField("Customer", StringType, true)
)
val fruitDf = spark.createDataFrame(
spark.sparkContext.parallelize(fruitValues),
StructType(schema)
).withColumn("Date", udfDateSubWeeks(lit(dayPriorToAsAt), col("weeksPrior")))
"FeatureGenerator.SequenceOfDataFrames" should {
"join multiple dataframes on a specified set of columns" in {
val sequenceOfDataFrames = Seq[DataFrame](
fruitDf.withColumnRenamed("weeksPrior", "weeksPrior1"),
fruitDf.withColumnRenamed("weeksPrior", "weeksPrior2"),
fruitDf.withColumnRenamed("weeksPrior", "weeksPrior3"),
fruitDf.withColumnRenamed("weeksPrior", "weeksPrior4"),
fruitDf.withColumnRenamed("weeksPrior", "weeksPrior5")
)
val joinedDataFrames = sequenceOfDataFrames.joinDataFramesOnColumns(Seq("basket", "Product", "Customer", "Date"))
assert(joinedDataFrames.columns.length === 9)
assert(joinedDataFrames.columns.contains("basket"))
assert(joinedDataFrames.columns.contains("Product"))
assert(joinedDataFrames.columns.contains("Customer"))
assert(joinedDataFrames.columns.contains("Date"))
assert(joinedDataFrames.columns.contains("weeksPrior1"))
assert(joinedDataFrames.columns.contains("weeksPrior2"))
assert(joinedDataFrames.columns.contains("weeksPrior3"))
assert(joinedDataFrames.columns.contains("weeksPrior4"))
assert(joinedDataFrames.columns.contains("weeksPrior5"))
}
"when passed a list of one dataframe return that same dataframe" in {
val sequenceOfDataFrames = Seq[DataFrame](fruitDf)
val joinedDataFrame = sequenceOfDataFrames.joinDataFramesOnColumns(Seq("basket", "Product"))
assert(joinedDataFrame.columns.sorted === fruitDf.columns.sorted)
assert(joinedDataFrame.count === fruitDf.count)
}
"when passed an empty list of dataframes return an empty dataframe" in {
val joinedDataFrame = Seq[DataFrame]().joinDataFramesOnColumns(Seq("basket"))
assert(joinedDataFrame === spark.emptyDataFrame)
}
"when passed an empty list of joinColumns return the dataframes crossjoined" in {
val sequenceOfDataFrames = Seq[DataFrame](fruitDf,fruitDf, fruitDf)
val joinedDataFrame = sequenceOfDataFrames.joinDataFramesOnColumns(Seq[String]())
assert(joinedDataFrame.count === scala.math.pow(fruitDf.count, sequenceOfDataFrames.size))
assert(joinedDataFrame.columns.size === fruitDf.columns.size * sequenceOfDataFrames.size)
}
}
}
このSparkバグが原因でエラーが発生するまで、これはすべて正常に機能していました:https://issues.apache.org/jira/browse/SPARK-25150これは、結合列が同じ名前の場合、特定の条件下でエラーを引き起こす可能性があります。
回避策は、列を別のものとしてエイリアスすることです。そのため、結合列をエイリアスし、結合を実行してから、名前を元に戻すように関数を書き直しました。
implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]){
def joinDataFramesOnColumns(joinColumns: Seq[String]) : DataFrame = {
val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
if (nonEmptyDataFrames.isEmpty){
emptyDataFrame
}
else {
if (joinColumns.isEmpty) {
return nonEmptyDataFrames.reduce(_.crossJoin(_))
}
/*
The horrible, gnarly, unelegent code below would ideally exist simply as:
nonEmptyDataFrames.reduce(_.join(_, joinColumns))
however that will fail in certain specific circumstances due to a bug in spark,
see https://issues.apache.org/jira/browse/SPARK-25150 for details
*/
val aliasSuffix = "_aliased"
val aliasedJoinColumns = joinColumns.map(joinColumn => joinColumn+aliasSuffix)
var aliasedNonEmptyDataFrames: Seq[DataFrame] = Seq()
nonEmptyDataFrames.foreach(
nonEmptyDataFrame =>{
var tempNonEmptyDataFrame = nonEmptyDataFrame
joinColumns.foreach(
joinColumn => {
tempNonEmptyDataFrame = tempNonEmptyDataFrame.withColumnRenamed(joinColumn, joinColumn+aliasSuffix)
}
)
aliasedNonEmptyDataFrames = aliasedNonEmptyDataFrames :+ tempNonEmptyDataFrame
}
)
var joinedAliasedNonEmptyDataFrames = aliasedNonEmptyDataFrames.reduce(_.join(_, aliasedJoinColumns))
joinColumns.foreach(
joinColumn => joinedAliasedNonEmptyDataFrames = joinedAliasedNonEmptyDataFrames.withColumnRenamed(
joinColumn+aliasSuffix, joinColumn
)
)
joinedAliasedNonEmptyDataFrames
}
}
}
テストはまだ合格しているので、かなり満足していvar
ますがvar
、各反復で結果をそれに戻すループを調べています...特に元のテストと比較して、かなりエレガントではなく、醜いです関数のバージョン。var
sを使わなくてもいいように書く方法があるはずなのに、試行錯誤の末、これが一番いいと思います。
誰かがよりエレガントな解決策を提案できますか?初心者のScala開発者として、このような問題を解決する慣用的な方法に慣れることは本当に助けになります。
コードの残りの部分(テストなど)に関する建設的なコメントも歓迎します
foldLeft()の使用を提案してくれた@Duelistに感謝します。ScalaのfoldLeftはDataFrameでどのように機能しますか?これにより、コードをそのように適合させて、var
sを排除することになりました。
implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]){
def joinDataFramesOnColumns(joinColumns: Seq[String]) : DataFrame = {
val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
if (nonEmptyDataFrames.isEmpty){
emptyDataFrame
}
else {
if (joinColumns.isEmpty) {
return nonEmptyDataFrames.reduce(_.crossJoin(_))
}
/*
The code below would ideally exist simply as:
nonEmptyDataFrames.reduce(_.join(_, joinColumns))
however that will fail in certain specific circumstances due to a bug in spark,
see https://issues.apache.org/jira/browse/SPARK-25150 for details
hence this code aliases the joinColumns, performs the join, then renames the
aliased columns back to their original name
*/
val aliasSuffix = "_aliased"
val aliasedJoinColumns = joinColumns.map(joinColumn => joinColumn+aliasSuffix)
val joinedAliasedNonEmptyDataFrames = nonEmptyDataFrames.foldLeft(Seq[DataFrame]()){
(tempDf, nonEmptyDataFrame) => tempDf :+ joinColumns.foldLeft(nonEmptyDataFrame){
(tempDf2, joinColumn) => tempDf2.withColumnRenamed(joinColumn, joinColumn+aliasSuffix)
}
}.reduce(_.join(_, aliasedJoinColumns))
joinColumns.foldLeft(joinedAliasedNonEmptyDataFrames){
(tempDf, joinColumn) => tempDf.withColumnRenamed(joinColumn+aliasSuffix, joinColumn)
}
}
}
}
2つのステートメントを1つにまとめて削除することでさらに進めることができましたがval joinedAliasedNonEmptyDataFrames
、その暫定的な使用によってもたらされる明確さを好みましたval
。
この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。
侵害の場合は、連絡してください[email protected]
コメントを追加