feat: show result in sql processor

This commit is contained in:
tjq 2021-03-13 20:35:43 +08:00
parent 93158ba19b
commit 5a9a5c6910
4 changed files with 123 additions and 47 deletions

View File

@ -67,11 +67,12 @@
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-jdbc</artifactId>
<version>${spring.jdbc.version}</version>
<scope>provided</scope>
<scope>test</scope>
</dependency>
<!-- Junit tests -->

View File

@ -4,13 +4,22 @@ import com.alibaba.fastjson.JSON;
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
import com.github.kfcfans.powerjob.worker.log.OmsLogger;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import lombok.Data;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StopWatch;
import tech.powerjob.official.processors.CommonBasicProcessor;
import tech.powerjob.official.processors.util.CommonUtils;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
@ -28,6 +37,7 @@ import java.util.function.Predicate;
*/
@Slf4j
public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
/**
* 默认超时时间
*/
@ -44,6 +54,8 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
*/
protected SqlParser sqlParser;
private static final Joiner JOINER = Joiner.on("|").useForNull("-");
@Override
public ProcessResult process0(TaskContext taskContext) {
@ -51,26 +63,29 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
OmsLogger omsLogger = taskContext.getOmsLogger();
// 解析参数
SqlParams sqlParams = extractParams(taskContext);
omsLogger.info("[AbstractSqlProcessor-{}]origin sql params: {}", taskContext.getInstanceId(), JSON.toJSON(sqlParams));
omsLogger.info("origin sql params: {}", JSON.toJSON(sqlParams));
// 校验参数
validateParams(sqlParams);
StopWatch stopWatch = new StopWatch("SQL Processor");
StopWatch stopWatch = new StopWatch(this.getClass().getSimpleName());
// 解析
stopWatch.start("Parse SQL");
if (sqlParser != null) {
sqlParams.setSql(sqlParser.parse(sqlParams.getSql(), taskContext));
omsLogger.info("before parse sql: {}", sqlParams.getSql());
String newSQL = sqlParser.parse(sqlParams.getSql(), taskContext);
sqlParams.setSql(newSQL);
omsLogger.info("after parse sql: {}", newSQL);
}
stopWatch.stop();
// 校验 SQL
stopWatch.start("Validate SQL");
validateSql(sqlParams.getSql());
validateSql(sqlParams.getSql(), omsLogger);
stopWatch.stop();
// 执行
stopWatch.start("Execute SQL");
omsLogger.info("[AbstractSqlProcessor-{}]final sql params: {}", taskContext.getInstanceId(), JSON.toJSON(sqlParams));
omsLogger.info("final sql params: {}", JSON.toJSON(sqlParams));
executeSql(sqlParams, taskContext);
stopWatch.stop();
@ -79,13 +94,78 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
return new ProcessResult(true, message);
}
abstract DataSource getDataSource(SqlParams sqlParams, TaskContext taskContext);
/**
* 执行 SQL
*
* @param sqlParams SQL processor 参数信息
* @param taskContext 任务上下文
* @param ctx 任务上下文
*/
abstract void executeSql(SqlParams sqlParams, TaskContext taskContext);
@SneakyThrows
private void executeSql(SqlParams sqlParams, TaskContext ctx) {
OmsLogger omsLogger = ctx.getOmsLogger();
boolean originAutoCommitFlag ;
try (Connection connection = getDataSource(sqlParams, ctx).getConnection()) {
originAutoCommitFlag = connection.getAutoCommit();
connection.setAutoCommit(false);
try (Statement statement = connection.createStatement()) {
statement.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
statement.execute(sqlParams.getSql());
connection.commit();
if (sqlParams.showResult) {
outputSqlResult(statement, omsLogger);
}
} catch (Throwable e) {
omsLogger.error("execute sql failed, try to rollback", e);
connection.rollback();
throw e;
} finally {
connection.setAutoCommit(originAutoCommitFlag);
}
}
}
private void outputSqlResult(Statement statement, OmsLogger omsLogger) throws SQLException {
omsLogger.info("====== SQL EXECUTE RESULT ======");
for (int index = 0; index < Integer.MAX_VALUE; index++) {
// 某一个结果集
ResultSet resultSet = statement.getResultSet();
if (resultSet != null) {
try (ResultSet rs = resultSet) {
int columnCount = rs.getMetaData().getColumnCount();
List<String> columnNames = Lists.newLinkedList();
//column the first column is 1, the second is 2, ...
for (int i = 1; i <= columnCount; i++) {
columnNames.add(rs.getMetaData().getColumnName(i));
}
omsLogger.info("[Result-{}] [Columns] {}" + System.lineSeparator(), index, JOINER.join(columnNames));
int rowIndex = 0;
List<Object> row = Lists.newLinkedList();
while (rs.next()) {
for (int i = 1; i <= columnCount; i++) {
row.add(rs.getObject(i));
}
omsLogger.info("[Result-{}] [Row-{}] {}" + System.lineSeparator(), index, rowIndex++, JOINER.join(row));
}
}
} else {
int updateCount = statement.getUpdateCount();
if (updateCount != -1) {
omsLogger.info("[Result-{}] update count: {}", index, updateCount);
}
}
if (((!statement.getMoreResults()) && (statement.getUpdateCount() == -1))) {
break;
}
}
omsLogger.info("====== SQL EXECUTE RESULT ======");
}
/**
* 提取参数信息
@ -132,13 +212,14 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
/**
* 校验 SQL 合法性
*/
private void validateSql(String sql) {
private void validateSql(String sql, OmsLogger omsLogger) {
if (sqlValidatorMap.isEmpty()) {
return;
}
for (Map.Entry<String, Predicate<String>> entry : sqlValidatorMap.entrySet()) {
Predicate<String> validator = entry.getValue();
if (!validator.test(sql)) {
omsLogger.error("validate sql by validator[{}] failed, skip to process!", entry.getKey());
throw new IllegalArgumentException("illegal sql, can't pass the validation of " + entry.getKey());
}
}
@ -164,7 +245,10 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
* 具体格式可参考 https://www.baeldung.com/java-jdbc-url-format
*/
private String jdbcUrl;
/**
* 是否展示 SQL 执行结果
*/
private boolean showResult;
}

View File

@ -2,14 +2,11 @@ package tech.powerjob.official.processors.impl.sql;
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
import com.google.common.collect.Maps;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.Statement;
import java.util.Map;
/**
@ -64,34 +61,9 @@ public class SimpleSpringSqlProcessor extends AbstractSqlProcessor {
});
}
/**
* 执行 SQL忽略返回值
*
* @param sqlParams SQL processor 参数信息
* @param taskContext 任务上下文
*/
@Override
@SneakyThrows
@SuppressWarnings({"squid:S1181"})
protected void executeSql(SqlParams sqlParams, TaskContext taskContext) {
DataSource currentDataSource = dataSourceMap.get(sqlParams.getDataSourceName());
boolean originAutoCommitFlag ;
try (Connection connection = currentDataSource.getConnection()) {
originAutoCommitFlag = connection.getAutoCommit();
connection.setAutoCommit(false);
try (Statement statement = connection.createStatement()) {
statement.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
statement.execute(sqlParams.getSql());
connection.commit();
} catch (Throwable e) {
connection.rollback();
// rethrow
throw e;
} finally {
// reset
connection.setAutoCommit(originAutoCommitFlag);
}
}
DataSource getDataSource(SqlParams sqlParams, TaskContext taskContext) {
return dataSourceMap.get(sqlParams.getDataSourceName());
}
/**

View File

@ -1,4 +1,4 @@
package tech.powerjob.official.processors.impl;
package tech.powerjob.official.processors.impl.sql;
import com.alibaba.fastjson.JSON;
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
@ -10,9 +10,9 @@ import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import tech.powerjob.official.processors.TestUtils;
import tech.powerjob.official.processors.impl.sql.SimpleSpringSqlProcessor;
import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* @author Echo009
@ -35,8 +35,17 @@ class SimpleSpringSqlProcessorTest {
simpleSpringSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
// 排除掉包含 drop SQL
simpleSpringSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
// do nothing
simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> sql);
// add ';'
simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> {
if (!sql.endsWith(";")) {
return sql + ";";
}
return sql;
});
// just invoke clean datasource method
simpleSpringSqlProcessor.removeDataSource("NULL_DATASOURCE");
log.info("init sql processor successfully!");
}
@ -87,9 +96,19 @@ class SimpleSpringSqlProcessorTest {
}
@Test
public void testQuery() {
SimpleSpringSqlProcessor.SqlParams insertParams = constructSqlParam("insert into test_table (id, content) values (1, '?');insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(insertParams)));
SimpleSpringSqlProcessor.SqlParams queryParams = constructSqlParam("select * from test_table");
simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(queryParams)));
}
static SimpleSpringSqlProcessor.SqlParams constructSqlParam(String sql){
SimpleSpringSqlProcessor.SqlParams sqlParams = new SimpleSpringSqlProcessor.SqlParams();
sqlParams.setSql(sql);
sqlParams.setShowResult(true);
return sqlParams;
}