feat: show result in sql processor

This commit is contained in:
tjq 2021-03-13 20:35:43 +08:00
parent 93158ba19b
commit 5a9a5c6910
4 changed files with 123 additions and 47 deletions

View File

@ -67,11 +67,12 @@
<scope>provided</scope> <scope>provided</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.springframework</groupId> <groupId>org.springframework</groupId>
<artifactId>spring-jdbc</artifactId> <artifactId>spring-jdbc</artifactId>
<version>${spring.jdbc.version}</version> <version>${spring.jdbc.version}</version>
<scope>provided</scope> <scope>test</scope>
</dependency> </dependency>
<!-- Junit tests --> <!-- Junit tests -->

View File

@ -4,13 +4,22 @@ import com.alibaba.fastjson.JSON;
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult; import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext; import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
import com.github.kfcfans.powerjob.worker.log.OmsLogger; import com.github.kfcfans.powerjob.worker.log.OmsLogger;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps; import com.google.common.collect.Maps;
import lombok.Data; import lombok.Data;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StopWatch; import org.springframework.util.StopWatch;
import tech.powerjob.official.processors.CommonBasicProcessor; import tech.powerjob.official.processors.CommonBasicProcessor;
import tech.powerjob.official.processors.util.CommonUtils; import tech.powerjob.official.processors.util.CommonUtils;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Predicate; import java.util.function.Predicate;
@ -28,6 +37,7 @@ import java.util.function.Predicate;
*/ */
@Slf4j @Slf4j
public abstract class AbstractSqlProcessor extends CommonBasicProcessor { public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
/** /**
* 默认超时时间 * 默认超时时间
*/ */
@ -44,6 +54,8 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
*/ */
protected SqlParser sqlParser; protected SqlParser sqlParser;
private static final Joiner JOINER = Joiner.on("|").useForNull("-");
@Override @Override
public ProcessResult process0(TaskContext taskContext) { public ProcessResult process0(TaskContext taskContext) {
@ -51,26 +63,29 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
OmsLogger omsLogger = taskContext.getOmsLogger(); OmsLogger omsLogger = taskContext.getOmsLogger();
// 解析参数 // 解析参数
SqlParams sqlParams = extractParams(taskContext); SqlParams sqlParams = extractParams(taskContext);
omsLogger.info("[AbstractSqlProcessor-{}]origin sql params: {}", taskContext.getInstanceId(), JSON.toJSON(sqlParams)); omsLogger.info("origin sql params: {}", JSON.toJSON(sqlParams));
// 校验参数 // 校验参数
validateParams(sqlParams); validateParams(sqlParams);
StopWatch stopWatch = new StopWatch("SQL Processor"); StopWatch stopWatch = new StopWatch(this.getClass().getSimpleName());
// 解析 // 解析
stopWatch.start("Parse SQL"); stopWatch.start("Parse SQL");
if (sqlParser != null) { if (sqlParser != null) {
sqlParams.setSql(sqlParser.parse(sqlParams.getSql(), taskContext)); omsLogger.info("before parse sql: {}", sqlParams.getSql());
String newSQL = sqlParser.parse(sqlParams.getSql(), taskContext);
sqlParams.setSql(newSQL);
omsLogger.info("after parse sql: {}", newSQL);
} }
stopWatch.stop(); stopWatch.stop();
// 校验 SQL // 校验 SQL
stopWatch.start("Validate SQL"); stopWatch.start("Validate SQL");
validateSql(sqlParams.getSql()); validateSql(sqlParams.getSql(), omsLogger);
stopWatch.stop(); stopWatch.stop();
// 执行 // 执行
stopWatch.start("Execute SQL"); stopWatch.start("Execute SQL");
omsLogger.info("[AbstractSqlProcessor-{}]final sql params: {}", taskContext.getInstanceId(), JSON.toJSON(sqlParams)); omsLogger.info("final sql params: {}", JSON.toJSON(sqlParams));
executeSql(sqlParams, taskContext); executeSql(sqlParams, taskContext);
stopWatch.stop(); stopWatch.stop();
@ -79,13 +94,78 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
return new ProcessResult(true, message); return new ProcessResult(true, message);
} }
abstract DataSource getDataSource(SqlParams sqlParams, TaskContext taskContext);
/** /**
* 执行 SQL * 执行 SQL
* * @param sqlParams SQL processor 参数信息
* @param sqlParams SQL processor 参数信息 * @param ctx 任务上下文
* @param taskContext 任务上下文
*/ */
abstract void executeSql(SqlParams sqlParams, TaskContext taskContext); @SneakyThrows
private void executeSql(SqlParams sqlParams, TaskContext ctx) {
OmsLogger omsLogger = ctx.getOmsLogger();
boolean originAutoCommitFlag ;
try (Connection connection = getDataSource(sqlParams, ctx).getConnection()) {
originAutoCommitFlag = connection.getAutoCommit();
connection.setAutoCommit(false);
try (Statement statement = connection.createStatement()) {
statement.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
statement.execute(sqlParams.getSql());
connection.commit();
if (sqlParams.showResult) {
outputSqlResult(statement, omsLogger);
}
} catch (Throwable e) {
omsLogger.error("execute sql failed, try to rollback", e);
connection.rollback();
throw e;
} finally {
connection.setAutoCommit(originAutoCommitFlag);
}
}
}
private void outputSqlResult(Statement statement, OmsLogger omsLogger) throws SQLException {
omsLogger.info("====== SQL EXECUTE RESULT ======");
for (int index = 0; index < Integer.MAX_VALUE; index++) {
// 某一个结果集
ResultSet resultSet = statement.getResultSet();
if (resultSet != null) {
try (ResultSet rs = resultSet) {
int columnCount = rs.getMetaData().getColumnCount();
List<String> columnNames = Lists.newLinkedList();
//column the first column is 1, the second is 2, ...
for (int i = 1; i <= columnCount; i++) {
columnNames.add(rs.getMetaData().getColumnName(i));
}
omsLogger.info("[Result-{}] [Columns] {}" + System.lineSeparator(), index, JOINER.join(columnNames));
int rowIndex = 0;
List<Object> row = Lists.newLinkedList();
while (rs.next()) {
for (int i = 1; i <= columnCount; i++) {
row.add(rs.getObject(i));
}
omsLogger.info("[Result-{}] [Row-{}] {}" + System.lineSeparator(), index, rowIndex++, JOINER.join(row));
}
}
} else {
int updateCount = statement.getUpdateCount();
if (updateCount != -1) {
omsLogger.info("[Result-{}] update count: {}", index, updateCount);
}
}
if (((!statement.getMoreResults()) && (statement.getUpdateCount() == -1))) {
break;
}
}
omsLogger.info("====== SQL EXECUTE RESULT ======");
}
/** /**
* 提取参数信息 * 提取参数信息
@ -132,13 +212,14 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
/** /**
* 校验 SQL 合法性 * 校验 SQL 合法性
*/ */
private void validateSql(String sql) { private void validateSql(String sql, OmsLogger omsLogger) {
if (sqlValidatorMap.isEmpty()) { if (sqlValidatorMap.isEmpty()) {
return; return;
} }
for (Map.Entry<String, Predicate<String>> entry : sqlValidatorMap.entrySet()) { for (Map.Entry<String, Predicate<String>> entry : sqlValidatorMap.entrySet()) {
Predicate<String> validator = entry.getValue(); Predicate<String> validator = entry.getValue();
if (!validator.test(sql)) { if (!validator.test(sql)) {
omsLogger.error("validate sql by validator[{}] failed, skip to process!", entry.getKey());
throw new IllegalArgumentException("illegal sql, can't pass the validation of " + entry.getKey()); throw new IllegalArgumentException("illegal sql, can't pass the validation of " + entry.getKey());
} }
} }
@ -164,7 +245,10 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
* 具体格式可参考 https://www.baeldung.com/java-jdbc-url-format * 具体格式可参考 https://www.baeldung.com/java-jdbc-url-format
*/ */
private String jdbcUrl; private String jdbcUrl;
/**
* 是否展示 SQL 执行结果
*/
private boolean showResult;
} }

View File

@ -2,14 +2,11 @@ package tech.powerjob.official.processors.impl.sql;
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext; import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
import com.google.common.collect.Maps; import com.google.common.collect.Maps;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import javax.sql.DataSource; import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.Statement;
import java.util.Map; import java.util.Map;
/** /**
@ -64,34 +61,9 @@ public class SimpleSpringSqlProcessor extends AbstractSqlProcessor {
}); });
} }
/**
* 执行 SQL忽略返回值
*
* @param sqlParams SQL processor 参数信息
* @param taskContext 任务上下文
*/
@Override @Override
@SneakyThrows DataSource getDataSource(SqlParams sqlParams, TaskContext taskContext) {
@SuppressWarnings({"squid:S1181"}) return dataSourceMap.get(sqlParams.getDataSourceName());
protected void executeSql(SqlParams sqlParams, TaskContext taskContext) {
DataSource currentDataSource = dataSourceMap.get(sqlParams.getDataSourceName());
boolean originAutoCommitFlag ;
try (Connection connection = currentDataSource.getConnection()) {
originAutoCommitFlag = connection.getAutoCommit();
connection.setAutoCommit(false);
try (Statement statement = connection.createStatement()) {
statement.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
statement.execute(sqlParams.getSql());
connection.commit();
} catch (Throwable e) {
connection.rollback();
// rethrow
throw e;
} finally {
// reset
connection.setAutoCommit(originAutoCommitFlag);
}
}
} }
/** /**

View File

@ -1,4 +1,4 @@
package tech.powerjob.official.processors.impl; package tech.powerjob.official.processors.impl.sql;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult; import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
@ -10,9 +10,9 @@ import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import tech.powerjob.official.processors.TestUtils; import tech.powerjob.official.processors.TestUtils;
import tech.powerjob.official.processors.impl.sql.SimpleSpringSqlProcessor;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
/** /**
* @author Echo009 * @author Echo009
@ -35,8 +35,17 @@ class SimpleSpringSqlProcessorTest {
simpleSpringSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true); simpleSpringSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
// 排除掉包含 drop SQL // 排除掉包含 drop SQL
simpleSpringSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$")); simpleSpringSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
// do nothing // add ';'
simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> sql); simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> {
if (!sql.endsWith(";")) {
return sql + ";";
}
return sql;
});
// just invoke clean datasource method
simpleSpringSqlProcessor.removeDataSource("NULL_DATASOURCE");
log.info("init sql processor successfully!"); log.info("init sql processor successfully!");
} }
@ -87,9 +96,19 @@ class SimpleSpringSqlProcessorTest {
} }
@Test
public void testQuery() {
SimpleSpringSqlProcessor.SqlParams insertParams = constructSqlParam("insert into test_table (id, content) values (1, '?');insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(insertParams)));
SimpleSpringSqlProcessor.SqlParams queryParams = constructSqlParam("select * from test_table");
simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(queryParams)));
}
static SimpleSpringSqlProcessor.SqlParams constructSqlParam(String sql){ static SimpleSpringSqlProcessor.SqlParams constructSqlParam(String sql){
SimpleSpringSqlProcessor.SqlParams sqlParams = new SimpleSpringSqlProcessor.SqlParams(); SimpleSpringSqlProcessor.SqlParams sqlParams = new SimpleSpringSqlProcessor.SqlParams();
sqlParams.setSql(sql); sqlParams.setSql(sql);
sqlParams.setShowResult(true);
return sqlParams; return sqlParams;
} }