为了账号安全,请及时绑定邮箱和手机立即绑定

Spark 实现MySQL update操作

标签:
Spark

背景

目前 spark 对 MySQL 的操作只有 Append,Overwrite,ErrorIfExists,Ignore几种表级别的模式,有时我们需要对表进行行级别的操作,比如update。即我们需要构造这样的语句出来:insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name =? ,age=?;

需求:我们的目的是既不影响以前写的代码,又不引入新的API,只需新加一个配置如:savemode=update这样的形式来实现。

实践

要满足以上需求,肯定是要改源码的,首先创建自己的saveMode,只是新加了一个Update而已:

public enum I4SaveMode {
    Append,
    Overwrite,
    ErrorIfExists,
    Ignore,
    Update
}

JDBC数据源的相关实现主要在JdbcRelationProvider里,我们需要关注的是createRelation方法,我们可以在此方法里,把SaveMode改成我们自己的mode,并把mode带到saveTable方法里,所以改造后的方法如下(改了的地方都有注释):

   override def createRelation(
                                   sqlContext: SQLContext,
                                   mode: SaveMode,
                                   parameters: Map[String, String],
                                   df: DataFrame): BaseRelation = {
        val options = new JDBCOptions(parameters)
        val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis        // 替换成自己的saveMode
        var saveMode = mode match {                case SaveMode.Overwrite => I4SaveMode.Overwrite                case SaveMode.Append => I4SaveMode.Append                case SaveMode.ErrorIfExists => I4SaveMode.ErrorIfExists                case SaveMode.Ignore => I4SaveMode.Ignore
            }        //重点在这里,检查是否有saveMode=update的参数,并设为对应的模式
        val parameterLower = parameters.map(kv => (kv._1.toLowerCase,kv._2))        if(parameterLower.keySet.contains("savemode")){
            saveMode = if(parameterLower.get("savemode").get.equals("update")) I4SaveMode.Update else saveMode
        }
        val conn = JdbcUtils.createConnectionFactory(options)()        try {
            val tableExists = JdbcUtils.tableExists(conn, options)            if (tableExists) {
                saveMode match {                    case I4SaveMode.Overwrite =>                        if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {                            // In this case, we should truncate table and then load.
                            truncateTable(conn, options.table)
                            val tableSchema = JdbcUtils.getSchemaOption(conn, options)
                            saveTable(df, tableSchema, isCaseSensitive, options, saveMode)
                        } else {
                        ......
    }

接下来就是saveTable方法:

def saveTable(
      df: DataFrame,      tableSchema: Option[StructType],      isCaseSensitive: Boolean,      options: JDBCOptions,      mode: I4SaveMode): Unit = { 
    ......
    val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
    .....
    repartitionedDF.foreachPartition(iterator => savePartition(
      getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
    )
  }

这里通过getInsertStatement方法构造sql语句,接着遍历每个分区进行对应的save操作,我们先看是构造语句是怎么改的(改了的地方都有注释):

def getInsertStatement(
      table: String,      rddSchema: StructType,      tableSchema: Option[StructType],      isCaseSensitive: Boolean,      dialect: JdbcDialect,      mode: I4SaveMode): String = {
    val columns = if (tableSchema.isEmpty) {
      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
    } else {
      val columnNameEquality = if (isCaseSensitive) {
        org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
      } else {
        org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
      } 
      val tableColumnNames = tableSchema.get.fieldNames
      rddSchema.fields.map { col =>
        val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {          throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
        }
        dialect.quoteIdentifier(normalizedName)
      }.mkString(",")
    } 
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")    // s"INSERT INTO $table ($columns) VALUES ($placeholders)"
   //若为update模式需要单独构造
    mode match {            case I4SaveMode.Update 
                val duplicateSetting = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).map(name  s"$name=?").mkString(",")
                s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
            case _  s"INSERT INTO $table ($columns) VALUES ($placeholders)"
        }
  }

只需判断是否是update模式来构造对应的 sql语句,接着主要是看 savePartition 方法,看看具体是怎么保存的:

 def savePartition(
      getConnection: () => Connection,      table: String,      iterator: Iterator[Row],      rddSchema: StructType,      insertStmt: String,      batchSize: Int,      dialect: JdbcDialect,      isolationLevel: Int): Iterator[Byte] = {
    val conn = getConnection()    var committed = false

    var finalIsolationLevel = Connection.TRANSACTION_NONE    if (isolationLevel != Connection.TRANSACTION_NONE) {      try {
        val metadata = conn.getMetaData        if (metadata.supportsTransactions()) {          // Update to at least use the default isolation, if any transaction level
          // has been chosen and transactions are supported
          val defaultIsolation = metadata.getDefaultTransactionIsolation
          finalIsolationLevel = defaultIsolation          if (metadata.supportsTransactionIsolationLevel(isolationLevel))  {            // Finally update to actually requested level if possible
            finalIsolationLevel = isolationLevel
          } else {
            logWarning(s"Requested isolation level $isolationLevel is not supported; " +
                s"falling back to default isolation level $defaultIsolation")
          }
        } else {
          logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
        }
      } catch {        case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
      }
    }
    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE    try {      if (supportsTransactions) {
        conn.setAutoCommit(false) // Everything in the same db transaction.
        conn.setTransactionIsolation(finalIsolationLevel)
      }
      val stmt = conn.prepareStatement(insertStmt)
      val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
      val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
      val numFields = rddSchema.fields.length      try {        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()          var i = 0
          while (i < numFields) {            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            rowCount = 0
          }
        }        if (rowCount > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }      if (supportsTransactions) {
        conn.commit()
      }
      committed = true
      Iterator.empty
    } catch {      case e: SQLException =>
        val cause = e.getNextException        if (cause != null && e.getCause != cause) {          // If there is no cause already, set 'next exception' as cause. If cause is null,
          // it *may* be because no cause was set yet
          if (e.getCause == null) {            try {
              e.initCause(cause)
            } catch {              // Or it may be null because the cause *was* explicitly initialized, to *null*,
              // in which case this fails. There is no other way to detect it.
              // addSuppressed in this case as well.
              case _: IllegalStateException => e.addSuppressed(cause)
            }
          } else {
            e.addSuppressed(cause)
          }
        }        throw e
    } finally {      if (!committed) {        // The stage must fail.  We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        if (supportsTransactions) {
          conn.rollback()
        }
        conn.close()
      } else {        // The stage must succeed.  We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
        }
      }
    }
  }

大体思想就是在迭代该分区数据进行插入之前就先根据数据的schema设置好了插入模板setters,迭代的时候只需将此模板应用到每一行数据上就行了,避免了每一行都需要去判断数据类型。
在非update的情况下:insert into tb (id,name,age) values (?,?,?)
在update情况下:insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name =? ,age=?;
即占位符多了一倍,在update模式下进行写入的时候需要向PreparedStatement多喂一遍数据。原本的makeSetter方法如下:

private def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = dataType match {    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))
    ...
  }

我们只需要再加一个相对位置参数offset来控制,即改造成:

private def makeSetter(       conn: Connection,       dialect: JdbcDialect,       dataType: DataType): JDBCValueSetter = dataType match {     case IntegerType 
        (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) 
             stmt.setInt(pos + 1, row.getInt(pos - offset))     case LongType 
        (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) 
             stmt.setLong(pos + 1, row.getLong(pos - offset))
    ...

在非update模式下offset就为0,在update模式下在没有超过numFields时offset为0,超过numFileds时offset为numFields。改造后的savePartition方法为:

def savePartition(
                 getConnection: () => Connection,                 table: String,                 iterator: Iterator[Row],                 rddSchema: StructType,                 insertStmt: String,                 batchSize: Int,                 dialect: JdbcDialect,                 isolationLevel: Int,                 mode: I4SaveMode): Iterator[Byte] = {
    ...    //判断是否为update
    val isUpdateMode = mode == I4SaveMode.Update
    val stmt = conn.prepareStatement(insertStmt)
    val setters: Array[JDBCValueSetter] = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
    val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
    val length = rddSchema.fields.length    // update模式下占位符是2倍
    val numFields = if (isUpdateMode) length * 2 else length
    val midField = numFields / 2
    try {        var rowCount = 0
        while (iterator.hasNext) {
            val row = iterator.next()            var i = 0
            while (i < numFields) {                if (isUpdateMode) {                    // update模式下未超过字段长度,offset为0
                    i < midField match {                        case true ?                            if (row.isNullAt(i)) {
                                stmt.setNull(i + 1, nullTypes(i))
                            } else {
                                setters(i).apply(stmt, row, i, 0)
                            }                        // update模式下超过字段长度,offset为midField,即字段长度
                        case false ?                            if (row.isNullAt(i - midField)) {
                                stmt.setNull(i + 1, nullTypes(i - midField))
                            } else {
                                setters(i - midField).apply(stmt, row, i, midField)
                            }
                    }
                
                } else {                    if (row.isNullAt(i)) {
                        stmt.setNull(i + 1, nullTypes(i))
                    } else {
                        setters(i).apply(stmt, row, i, 0)
                    }
                }
                i = i + 1
            }
          ...

改造好源码后,需要重新编译打包,替换掉线上对应的jar即可。其实这里有个捷径,自己创建相同的包名,改好源码后打成jar包,把该jar里面的class文件替换掉线上jar里面对应的那些class文件就可以了。

如何使用

若需要使用到update模式:

df.write.option("saveMode","update").jdbc(...)



作者:BIGUFO
链接:https://www.jianshu.com/p/d0bac129a04c

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
JAVA开发工程师
手记
粉丝
50
获赞与收藏
175

关注作者,订阅最新文章

阅读免费教程

  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消