Background
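Spark currently supports only table-level save modes when writing to MySQL: Append, Overwrite, ErrorIfExists and Ignore. Sometimes we need row-level operations on a table, such as update. In other words, we need to generate a statement like this: insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name=?,age=?;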
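To make the difference between a table-level mode and a row-level update concrete, here is a toy in-memory model. It is only a sketch: the table is a Scala Map keyed by the primary key id, Person, insert and upsert are hypothetical names, and only primary-key clashes are modelled (MySQL also treats a clash on any UNIQUE index as a duplicate).

```scala
// Hypothetical in-memory table keyed by primary key `id`. Real MySQL also treats
// a clash on any UNIQUE index as a duplicate; that is not modelled here.
final case class Person(id: Int, name: String, age: Int)

// Plain INSERT (the statement Spark issues for Append): a duplicate key is an error.
def insert(table: Map[Int, Person], row: Person): Map[Int, Person] = {
  if (table.contains(row.id)) throw new IllegalStateException(s"duplicate key: ${row.id}")
  table.updated(row.id, row)
}

// INSERT ... ON DUPLICATE KEY UPDATE: a duplicate key updates the existing row instead.
def upsert(table: Map[Int, Person], row: Person): Map[Int, Person] =
  table.updated(row.id, row)

val existing = Map(1 -> Person(1, "tom", 20))
val incoming = Seq(Person(1, "tom", 21), Person(2, "amy", 30))

println(incoming.foldLeft(existing)(upsert)) // row 1 updated, row 2 inserted
println(scala.util.Try(incoming.foldLeft(existing)(insert))) // a Failure: id 1 already exists
```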
Requirement: we want to leave existing code untouched and introduce no new API; adding a single option of the form savemode=update should be enough.
Implementation
Meeting these requirements means modifying the Spark source. First, create our own save mode, which simply adds Update:
```java
public enum I4SaveMode {
    Append, Overwrite, ErrorIfExists, Ignore, Update
}
```
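The JDBC data source is implemented mainly in JdbcRelationProvider, and the method to focus on is createRelation. There we convert Spark's SaveMode into our own mode and pass that mode on to saveTable. The modified method looks like this (every change is commented):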
```scala
override def createRelation(
    sqlContext: SQLContext,
    mode: SaveMode,
    parameters: Map[String, String],
    df: DataFrame): BaseRelation = {
  val options = new JDBCOptions(parameters)
  val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis

  // Replace Spark's SaveMode with our own I4SaveMode
  var saveMode = mode match {
    case SaveMode.Overwrite => I4SaveMode.Overwrite
    case SaveMode.Append => I4SaveMode.Append
    case SaveMode.ErrorIfExists => I4SaveMode.ErrorIfExists
    case SaveMode.Ignore => I4SaveMode.Ignore
  }
  // The key change: if a savemode=update option is present (key matched
  // case-insensitively), switch to Update mode
  val parameterLower = parameters.map(kv => (kv._1.toLowerCase, kv._2))
  if (parameterLower.get("savemode").contains("update")) {
    saveMode = I4SaveMode.Update
  }

  val conn = JdbcUtils.createConnectionFactory(options)()
  try {
    val tableExists = JdbcUtils.tableExists(conn, options)
    if (tableExists) {
      saveMode match {
        case I4SaveMode.Overwrite =>
          if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
            // In this case, we should truncate table and then load.
            truncateTable(conn, options.table)
            val tableSchema = JdbcUtils.getSchemaOption(conn, options)
            saveTable(df, tableSchema, isCaseSensitive, options, saveMode)
          } else {
            ......
          }
```
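The mode handling above can be exercised without Spark. The sketch below assumes a sealed-trait stand-in for the Java enum I4SaveMode and passes Spark's SaveMode as its lower-case name; resolveMode is a hypothetical helper, not part of Spark.

```scala
// Hypothetical Scala stand-in for the article's Java enum I4SaveMode.
sealed trait I4SaveMode
object I4SaveMode {
  case object Append extends I4SaveMode
  case object Overwrite extends I4SaveMode
  case object ErrorIfExists extends I4SaveMode
  case object Ignore extends I4SaveMode
  case object Update extends I4SaveMode
}

// `sparkMode` stands in for Spark's SaveMode, passed here as its lower-case name.
def resolveMode(sparkMode: String, parameters: Map[String, String]): I4SaveMode = {
  val base: I4SaveMode = sparkMode match {
    case "overwrite"     => I4SaveMode.Overwrite
    case "append"        => I4SaveMode.Append
    case "errorifexists" => I4SaveMode.ErrorIfExists
    case "ignore"        => I4SaveMode.Ignore
    case other           => throw new IllegalArgumentException(s"Unknown save mode: $other")
  }
  // Option keys are compared case-insensitively; the value must be exactly "update".
  val lowerKeys = parameters.map { case (k, v) => (k.toLowerCase, v) }
  if (lowerKeys.get("savemode").contains("update")) I4SaveMode.Update else base
}

println(resolveMode("append", Map("saveMode" -> "update"))) // Update
println(resolveMode("append", Map.empty))                   // Append
```

As in the modified createRelation, only the key is lower-cased, so savemode=UPDATE would leave the original mode in place.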
Next comes the saveTable method:
```scala
def saveTable(
    df: DataFrame,
    tableSchema: Option[StructType],
    isCaseSensitive: Boolean,
    options: JDBCOptions,
    mode: I4SaveMode): Unit = {
  ......
  // Pass the mode on so the statement can be built for update mode
  val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect, mode)
  .....
  // Pass the mode on so each partition knows how many placeholders to fill
  repartitionedDF.foreachPartition(iterator => savePartition(
    getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, mode)
  )
}
```
Here getInsertStatement builds the SQL statement, and then each partition is saved in turn. First, the changes to statement construction (every change is commented):
```scala
def getInsertStatement(
    table: String,
    rddSchema: StructType,
    tableSchema: Option[StructType],
    isCaseSensitive: Boolean,
    dialect: JdbcDialect,
    mode: I4SaveMode): String = {
  val columns = if (tableSchema.isEmpty) {
    rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
  } else {
    val columnNameEquality = if (isCaseSensitive) {
      org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
    } else {
      org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
    }
    val tableColumnNames = tableSchema.get.fieldNames
    rddSchema.fields.map { col =>
      val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
        throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
      }
      dialect.quoteIdentifier(normalizedName)
    }.mkString(",")
  }
  val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
  // Original: s"INSERT INTO $table ($columns) VALUES ($placeholders)"
  // Update mode needs its own statement
  mode match {
    case I4SaveMode.Update =>
      val duplicateSetting = rddSchema.fields
        .map(x => dialect.quoteIdentifier(x.name))
        .map(name => s"$name=?")
        .mkString(",")
      s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
    case _ =>
      s"INSERT INTO $table ($columns) VALUES ($placeholders)"
  }
}
```
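Stripped of Spark's column-name normalization, the statement construction in getInsertStatement reduces to the sketch below. It assumes backtick quoting (what MySQL uses for identifiers) in place of dialect.quoteIdentifier, and insertStatement is a hypothetical name.

```scala
// Assumption: backtick identifier quoting, as MySQL uses.
def quote(name: String): String = s"`$name`"

// Spark-free sketch of the modified getInsertStatement: the same INSERT, plus an
// ON DUPLICATE KEY UPDATE clause with one extra placeholder per column in update mode.
def insertStatement(table: String, columns: Seq[String], isUpdateMode: Boolean): String = {
  val quoted = columns.map(quote)
  val placeholders = columns.map(_ => "?").mkString(",")
  val insert = s"INSERT INTO $table (${quoted.mkString(",")}) VALUES ($placeholders)"
  if (isUpdateMode) {
    val duplicateSetting = quoted.map(c => s"$c=?").mkString(",")
    s"$insert ON DUPLICATE KEY UPDATE $duplicateSetting"
  } else insert
}

val cols = Seq("id", "name", "age")
println(insertStatement("tb", cols, isUpdateMode = false))
// INSERT INTO tb (`id`,`name`,`age`) VALUES (?,?,?)
println(insertStatement("tb", cols, isUpdateMode = true))
// INSERT INTO tb (`id`,`name`,`age`) VALUES (?,?,?) ON DUPLICATE KEY UPDATE `id`=?,`name`=?,`age`=?
```

The update statement carries twice as many `?` placeholders as there are columns, which is what the writing code has to account for.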
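All that changes is a check for update mode to build the matching SQL statement. Next comes the main part, the savePartition method, which shows how rows are actually saved: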
```scala
def savePartition(
    getConnection: () => Connection,
    table: String,
    iterator: Iterator[Row],
    rddSchema: StructType,
    insertStmt: String,
    batchSize: Int,
    dialect: JdbcDialect,
    isolationLevel: Int): Iterator[Byte] = {
  val conn = getConnection()
  var committed = false

  var finalIsolationLevel = Connection.TRANSACTION_NONE
  if (isolationLevel != Connection.TRANSACTION_NONE) {
    try {
      val metadata = conn.getMetaData
      if (metadata.supportsTransactions()) {
        // Update to at least use the default isolation, if any transaction level
        // has been chosen and transactions are supported
        val defaultIsolation = metadata.getDefaultTransactionIsolation
        finalIsolationLevel = defaultIsolation
        if (metadata.supportsTransactionIsolationLevel(isolationLevel)) {
          // Finally update to actually requested level if possible
          finalIsolationLevel = isolationLevel
        } else {
          logWarning(s"Requested isolation level $isolationLevel is not supported; " +
            s"falling back to default isolation level $defaultIsolation")
        }
      } else {
        logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
      }
    } catch {
      case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
    }
  }
  val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

  try {
    if (supportsTransactions) {
      conn.setAutoCommit(false) // Everything in the same db transaction.
      conn.setTransactionIsolation(finalIsolationLevel)
    }
    val stmt = conn.prepareStatement(insertStmt)
    val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
    val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
    val numFields = rddSchema.fields.length

    try {
      var rowCount = 0
      while (iterator.hasNext) {
        val row = iterator.next()
        var i = 0
        while (i < numFields) {
          if (row.isNullAt(i)) {
            stmt.setNull(i + 1, nullTypes(i))
          } else {
            setters(i).apply(stmt, row, i)
          }
          i = i + 1
        }
        stmt.addBatch()
        rowCount += 1
        if (rowCount % batchSize == 0) {
          stmt.executeBatch()
          rowCount = 0
        }
      }
      if (rowCount > 0) {
        stmt.executeBatch()
      }
    } finally {
      stmt.close()
    }
    if (supportsTransactions) {
      conn.commit()
    }
    committed = true
    Iterator.empty
  } catch {
    case e: SQLException =>
      val cause = e.getNextException
      if (cause != null && e.getCause != cause) {
        // If there is no cause already, set 'next exception' as cause. If cause is null,
        // it *may* be because no cause was set yet
        if (e.getCause == null) {
          try {
            e.initCause(cause)
          } catch {
            // Or it may be null because the cause *was* explicitly initialized, to *null*,
            // in which case this fails. There is no other way to detect it.
            // addSuppressed in this case as well.
            case _: IllegalStateException => e.addSuppressed(cause)
          }
        } else {
          e.addSuppressed(cause)
        }
      }
      throw e
  } finally {
    if (!committed) {
      // The stage must fail. We got here through an exception path, so
      // let the exception through unless rollback() or close() want to
      // tell the user about another problem.
      if (supportsTransactions) {
        conn.rollback()
      }
      conn.close()
    } else {
      // The stage must succeed. We cannot propagate any exception close() might throw.
      try {
        conn.close()
      } catch {
        case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
      }
    }
  }
}
```
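The general idea is that, before iterating over the partition's rows, a set of setters is built from the data schema as a binding template. During iteration the template is simply applied to each row, so the data type does not have to be checked again for every row.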
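The setter-template idea can be shown in isolation. In this sketch ColType, FakeStatement and Setter are hypothetical stand-ins for Spark's DataType, java.sql.PreparedStatement and JDBCValueSetter; the point is that the type match runs once per column when the setters are built, not once per row.

```scala
import scala.collection.mutable

// Hypothetical two-type schema standing in for Spark's DataType.
sealed trait ColType
case object IntCol extends ColType
case object StringCol extends ColType

// Hypothetical fake for PreparedStatement: records the value bound to each 1-based index.
final class FakeStatement {
  val params = mutable.Map.empty[Int, Any]
  def setInt(index: Int, v: Int): Unit = params.update(index, v)
  def setString(index: Int, v: String): Unit = params.update(index, v)
}

// Same shape as Spark's three-argument JDBCValueSetter: (statement, row, position).
type Setter = (FakeStatement, IndexedSeq[Any], Int) => Unit

// The type dispatch happens here, once per column.
def makeSetter(t: ColType): Setter = t match {
  case IntCol    => (stmt, row, pos) => stmt.setInt(pos + 1, row(pos).asInstanceOf[Int])
  case StringCol => (stmt, row, pos) => stmt.setString(pos + 1, row(pos).asInstanceOf[String])
}

val schema: Seq[ColType] = Seq(IntCol, StringCol, IntCol)
val setters: Array[Setter] = schema.map(makeSetter).toArray

// Per row, the loop only calls the prebuilt setters; there are no type checks here.
val stmt = new FakeStatement
val row: IndexedSeq[Any] = IndexedSeq(1, "tom", 20)
var i = 0
while (i < setters.length) { setters(i)(stmt, row, i); i += 1 }
println(stmt.params.toSeq.sortBy(_._1))
```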
Without update mode: insert into tb (id,name,age) values (?,?,?)
In update mode: insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name=?,age=?;
That is, there are twice as many placeholders, so in update mode the row data has to be fed to the PreparedStatement a second time. The original makeSetter method is:
```scala
private def makeSetter(
    conn: Connection,
    dialect: JdbcDialect,
    dataType: DataType): JDBCValueSetter = dataType match {
  case IntegerType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setInt(pos + 1, row.getInt(pos))
  case LongType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setLong(pos + 1, row.getLong(pos))
  ...
}
```
We only need to add one more relative-position parameter, offset, to control this, changing it to:
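```scala
// The JDBCValueSetter type alias must also gain a fourth parameter, offset
private type JDBCValueSetter = (PreparedStatement, Row, Int, Int) => Unit

private def makeSetter(
    conn: Connection,
    dialect: JdbcDialect,
    dataType: DataType): JDBCValueSetter = dataType match {
  case IntegerType =>
    // Bind placeholder pos + 1, but read column pos - offset of the row
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) =>
      stmt.setInt(pos + 1, row.getInt(pos - offset))
  case LongType =>
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) =>
      stmt.setLong(pos + 1, row.getLong(pos - offset))
  ...
}
```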
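In non-update mode offset is always 0. In update mode offset is 0 while the placeholder index is still within the number of columns, and equals the number of columns once the index goes beyond it. The modified savePartition method is: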
```scala
def savePartition(
    getConnection: () => Connection,
    table: String,
    iterator: Iterator[Row],
    rddSchema: StructType,
    insertStmt: String,
    batchSize: Int,
    dialect: JdbcDialect,
    isolationLevel: Int,
    mode: I4SaveMode): Iterator[Byte] = {
  ...
  // Are we in update mode?
  val isUpdateMode = mode == I4SaveMode.Update
  val stmt = conn.prepareStatement(insertStmt)
  val setters: Array[JDBCValueSetter] =
    rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
  val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
  val length = rddSchema.fields.length
  // Update mode has twice as many placeholders
  val numFields = if (isUpdateMode) length * 2 else length
  try {
    var rowCount = 0
    while (iterator.hasNext) {
      val row = iterator.next()
      var i = 0
      while (i < numFields) {
        // offset is 0 for the first `length` placeholders; in update mode the second
        // half (the ON DUPLICATE KEY UPDATE part) re-reads the same columns, so offset = length
        val offset = if (isUpdateMode && i >= length) length else 0
        val col = i - offset
        if (row.isNullAt(col)) {
          stmt.setNull(i + 1, nullTypes(col))
        } else {
          setters(col).apply(stmt, row, i, offset)
        }
        i = i + 1
      }
      ...
```
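The combined effect of the offset logic, null handling and batching in the modified savePartition can be checked with a Spark-free sketch. FakeBatchStatement and savePartitionSketch are hypothetical names; a row is an IndexedSeq[Any] with null standing for SQL NULL, and the fake statement only records what each row would bind instead of talking to a database.

```scala
import scala.collection.mutable

// Hypothetical fake statement: collects bound parameters per row (None = SQL NULL)
// and counts executeBatch calls.
final class FakeBatchStatement {
  private var current = mutable.Map.empty[Int, Option[Any]]
  val batched = mutable.ArrayBuffer.empty[Map[Int, Option[Any]]]
  var executeCalls = 0
  def setNull(idx: Int): Unit = current.update(idx, None)
  def setObject(idx: Int, v: Any): Unit = current.update(idx, Some(v))
  def addBatch(): Unit = { batched += current.toMap; current = mutable.Map.empty }
  def executeBatch(): Unit = executeCalls += 1
}

def savePartitionSketch(stmt: FakeBatchStatement, rows: Iterator[IndexedSeq[Any]],
    length: Int, batchSize: Int, isUpdateMode: Boolean): Unit = {
  val numFields = if (isUpdateMode) length * 2 else length
  var rowCount = 0
  while (rows.hasNext) {
    val row = rows.next()
    var i = 0
    while (i < numFields) {
      // Second half of the placeholders in update mode re-reads the same columns.
      val offset = if (isUpdateMode && i >= length) length else 0
      val col = i - offset
      if (row(col) == null) stmt.setNull(i + 1) else stmt.setObject(i + 1, row(col))
      i += 1
    }
    stmt.addBatch()
    rowCount += 1
    if (rowCount % batchSize == 0) { stmt.executeBatch(); rowCount = 0 }
  }
  if (rowCount > 0) stmt.executeBatch() // flush the remainder
}

val stmt = new FakeBatchStatement
val rows = Iterator(IndexedSeq[Any](1, "tom", 20), IndexedSeq[Any](2, null, 30), IndexedSeq[Any](3, "amy", 25))
savePartitionSketch(stmt, rows, length = 3, batchSize = 2, isUpdateMode = true)
println(s"${stmt.batched.size} rows batched, ${stmt.executeCalls} executeBatch calls")
```

With three columns in update mode each row binds six parameters, and the second three repeat the first three (including the NULL). With batchSize = 2 and three rows, executeBatch runs once after the second row and once more for the remainder.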
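After modifying the source, recompile and package it, then replace the corresponding jar in production. There is also a shortcut: create the same package names in your own project, package the modified sources into a jar, and replace the corresponding class files inside the production jar with the class files from that jar.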
How to use
To use update mode:
```scala
df.write.option("saveMode", "update").jdbc(...)
```
Author: BIGUFO
Link: https://www.jianshu.com/p/d0bac129a04c