mirror of
https://github.com/PowerJob/PowerJob.git
synced 2025-07-17 00:00:04 +08:00
feat: show result in sql processor
This commit is contained in:
parent
93158ba19b
commit
5a9a5c6910
@ -67,11 +67,12 @@
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-jdbc</artifactId>
|
||||
<version>${spring.jdbc.version}</version>
|
||||
<scope>provided</scope>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<!-- Junit tests -->
|
||||
|
@ -4,13 +4,22 @@ import com.alibaba.fastjson.JSON;
|
||||
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
|
||||
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
|
||||
import com.github.kfcfans.powerjob.worker.log.OmsLogger;
|
||||
import com.google.common.base.Joiner;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import lombok.Data;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.StopWatch;
|
||||
import tech.powerjob.official.processors.CommonBasicProcessor;
|
||||
import tech.powerjob.official.processors.util.CommonUtils;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
import java.sql.Connection;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Statement;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
@ -28,6 +37,7 @@ import java.util.function.Predicate;
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
|
||||
|
||||
/**
|
||||
* 默认超时时间
|
||||
*/
|
||||
@ -44,6 +54,8 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
|
||||
*/
|
||||
protected SqlParser sqlParser;
|
||||
|
||||
private static final Joiner JOINER = Joiner.on("|").useForNull("-");
|
||||
|
||||
|
||||
@Override
|
||||
public ProcessResult process0(TaskContext taskContext) {
|
||||
@ -51,26 +63,29 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
|
||||
OmsLogger omsLogger = taskContext.getOmsLogger();
|
||||
// 解析参数
|
||||
SqlParams sqlParams = extractParams(taskContext);
|
||||
omsLogger.info("[AbstractSqlProcessor-{}]origin sql params: {}", taskContext.getInstanceId(), JSON.toJSON(sqlParams));
|
||||
omsLogger.info("origin sql params: {}", JSON.toJSON(sqlParams));
|
||||
// 校验参数
|
||||
validateParams(sqlParams);
|
||||
|
||||
StopWatch stopWatch = new StopWatch("SQL Processor");
|
||||
StopWatch stopWatch = new StopWatch(this.getClass().getSimpleName());
|
||||
// 解析
|
||||
stopWatch.start("Parse SQL");
|
||||
if (sqlParser != null) {
|
||||
sqlParams.setSql(sqlParser.parse(sqlParams.getSql(), taskContext));
|
||||
omsLogger.info("before parse sql: {}", sqlParams.getSql());
|
||||
String newSQL = sqlParser.parse(sqlParams.getSql(), taskContext);
|
||||
sqlParams.setSql(newSQL);
|
||||
omsLogger.info("after parse sql: {}", newSQL);
|
||||
}
|
||||
stopWatch.stop();
|
||||
|
||||
// 校验 SQL
|
||||
stopWatch.start("Validate SQL");
|
||||
validateSql(sqlParams.getSql());
|
||||
validateSql(sqlParams.getSql(), omsLogger);
|
||||
stopWatch.stop();
|
||||
|
||||
// 执行
|
||||
stopWatch.start("Execute SQL");
|
||||
omsLogger.info("[AbstractSqlProcessor-{}]final sql params: {}", taskContext.getInstanceId(), JSON.toJSON(sqlParams));
|
||||
omsLogger.info("final sql params: {}", JSON.toJSON(sqlParams));
|
||||
executeSql(sqlParams, taskContext);
|
||||
stopWatch.stop();
|
||||
|
||||
@ -79,13 +94,78 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
|
||||
return new ProcessResult(true, message);
|
||||
}
|
||||
|
||||
abstract DataSource getDataSource(SqlParams sqlParams, TaskContext taskContext);
|
||||
|
||||
/**
|
||||
* 执行 SQL
|
||||
*
|
||||
* @param sqlParams SQL processor 参数信息
|
||||
* @param taskContext 任务上下文
|
||||
* @param ctx 任务上下文
|
||||
*/
|
||||
abstract void executeSql(SqlParams sqlParams, TaskContext taskContext);
|
||||
@SneakyThrows
|
||||
private void executeSql(SqlParams sqlParams, TaskContext ctx) {
|
||||
|
||||
OmsLogger omsLogger = ctx.getOmsLogger();
|
||||
|
||||
boolean originAutoCommitFlag ;
|
||||
try (Connection connection = getDataSource(sqlParams, ctx).getConnection()) {
|
||||
originAutoCommitFlag = connection.getAutoCommit();
|
||||
connection.setAutoCommit(false);
|
||||
try (Statement statement = connection.createStatement()) {
|
||||
statement.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
|
||||
statement.execute(sqlParams.getSql());
|
||||
|
||||
connection.commit();
|
||||
|
||||
if (sqlParams.showResult) {
|
||||
outputSqlResult(statement, omsLogger);
|
||||
}
|
||||
} catch (Throwable e) {
|
||||
omsLogger.error("execute sql failed, try to rollback", e);
|
||||
connection.rollback();
|
||||
throw e;
|
||||
} finally {
|
||||
connection.setAutoCommit(originAutoCommitFlag);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void outputSqlResult(Statement statement, OmsLogger omsLogger) throws SQLException {
|
||||
omsLogger.info("====== SQL EXECUTE RESULT ======");
|
||||
|
||||
for (int index = 0; index < Integer.MAX_VALUE; index++) {
|
||||
|
||||
// 某一个结果集
|
||||
ResultSet resultSet = statement.getResultSet();
|
||||
if (resultSet != null) {
|
||||
try (ResultSet rs = resultSet) {
|
||||
int columnCount = rs.getMetaData().getColumnCount();
|
||||
List<String> columnNames = Lists.newLinkedList();
|
||||
//column – the first column is 1, the second is 2, ...
|
||||
for (int i = 1; i <= columnCount; i++) {
|
||||
columnNames.add(rs.getMetaData().getColumnName(i));
|
||||
}
|
||||
omsLogger.info("[Result-{}] [Columns] {}" + System.lineSeparator(), index, JOINER.join(columnNames));
|
||||
int rowIndex = 0;
|
||||
List<Object> row = Lists.newLinkedList();
|
||||
while (rs.next()) {
|
||||
for (int i = 1; i <= columnCount; i++) {
|
||||
row.add(rs.getObject(i));
|
||||
}
|
||||
omsLogger.info("[Result-{}] [Row-{}] {}" + System.lineSeparator(), index, rowIndex++, JOINER.join(row));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int updateCount = statement.getUpdateCount();
|
||||
if (updateCount != -1) {
|
||||
omsLogger.info("[Result-{}] update count: {}", index, updateCount);
|
||||
}
|
||||
}
|
||||
if (((!statement.getMoreResults()) && (statement.getUpdateCount() == -1))) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
omsLogger.info("====== SQL EXECUTE RESULT ======");
|
||||
}
|
||||
|
||||
/**
|
||||
* 提取参数信息
|
||||
@ -132,13 +212,14 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
|
||||
/**
|
||||
* 校验 SQL 合法性
|
||||
*/
|
||||
private void validateSql(String sql) {
|
||||
private void validateSql(String sql, OmsLogger omsLogger) {
|
||||
if (sqlValidatorMap.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
for (Map.Entry<String, Predicate<String>> entry : sqlValidatorMap.entrySet()) {
|
||||
Predicate<String> validator = entry.getValue();
|
||||
if (!validator.test(sql)) {
|
||||
omsLogger.error("validate sql by validator[{}] failed, skip to process!", entry.getKey());
|
||||
throw new IllegalArgumentException("illegal sql, can't pass the validation of " + entry.getKey());
|
||||
}
|
||||
}
|
||||
@ -164,7 +245,10 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
|
||||
* 具体格式可参考 https://www.baeldung.com/java-jdbc-url-format
|
||||
*/
|
||||
private String jdbcUrl;
|
||||
|
||||
/**
|
||||
* 是否展示 SQL 执行结果
|
||||
*/
|
||||
private boolean showResult;
|
||||
}
|
||||
|
||||
|
||||
|
@ -2,14 +2,11 @@ package tech.powerjob.official.processors.impl.sql;
|
||||
|
||||
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
|
||||
import com.google.common.collect.Maps;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
import java.sql.Connection;
|
||||
import java.sql.Statement;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
@ -64,34 +61,9 @@ public class SimpleSpringSqlProcessor extends AbstractSqlProcessor {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 SQL,忽略返回值
|
||||
*
|
||||
* @param sqlParams SQL processor 参数信息
|
||||
* @param taskContext 任务上下文
|
||||
*/
|
||||
@Override
|
||||
@SneakyThrows
|
||||
@SuppressWarnings({"squid:S1181"})
|
||||
protected void executeSql(SqlParams sqlParams, TaskContext taskContext) {
|
||||
DataSource currentDataSource = dataSourceMap.get(sqlParams.getDataSourceName());
|
||||
boolean originAutoCommitFlag ;
|
||||
try (Connection connection = currentDataSource.getConnection()) {
|
||||
originAutoCommitFlag = connection.getAutoCommit();
|
||||
connection.setAutoCommit(false);
|
||||
try (Statement statement = connection.createStatement()) {
|
||||
statement.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
|
||||
statement.execute(sqlParams.getSql());
|
||||
connection.commit();
|
||||
} catch (Throwable e) {
|
||||
connection.rollback();
|
||||
// rethrow
|
||||
throw e;
|
||||
} finally {
|
||||
// reset
|
||||
connection.setAutoCommit(originAutoCommitFlag);
|
||||
}
|
||||
}
|
||||
DataSource getDataSource(SqlParams sqlParams, TaskContext taskContext) {
|
||||
return dataSourceMap.get(sqlParams.getDataSourceName());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1,4 +1,4 @@
|
||||
package tech.powerjob.official.processors.impl;
|
||||
package tech.powerjob.official.processors.impl.sql;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
|
||||
@ -10,9 +10,9 @@ import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
|
||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
|
||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
|
||||
import tech.powerjob.official.processors.TestUtils;
|
||||
import tech.powerjob.official.processors.impl.sql.SimpleSpringSqlProcessor;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
/**
|
||||
* @author Echo009
|
||||
@ -35,8 +35,17 @@ class SimpleSpringSqlProcessorTest {
|
||||
simpleSpringSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
|
||||
// 排除掉包含 drop 的 SQL
|
||||
simpleSpringSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
|
||||
// do nothing
|
||||
simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> sql);
|
||||
// add ';'
|
||||
simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> {
|
||||
if (!sql.endsWith(";")) {
|
||||
return sql + ";";
|
||||
}
|
||||
return sql;
|
||||
});
|
||||
|
||||
// just invoke clean datasource method
|
||||
simpleSpringSqlProcessor.removeDataSource("NULL_DATASOURCE");
|
||||
|
||||
log.info("init sql processor successfully!");
|
||||
|
||||
}
|
||||
@ -87,9 +96,19 @@ class SimpleSpringSqlProcessorTest {
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testQuery() {
|
||||
SimpleSpringSqlProcessor.SqlParams insertParams = constructSqlParam("insert into test_table (id, content) values (1, '?');insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
|
||||
simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(insertParams)));
|
||||
|
||||
SimpleSpringSqlProcessor.SqlParams queryParams = constructSqlParam("select * from test_table");
|
||||
simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(queryParams)));
|
||||
}
|
||||
|
||||
static SimpleSpringSqlProcessor.SqlParams constructSqlParam(String sql){
|
||||
SimpleSpringSqlProcessor.SqlParams sqlParams = new SimpleSpringSqlProcessor.SqlParams();
|
||||
sqlParams.setSql(sql);
|
||||
sqlParams.setShowResult(true);
|
||||
return sqlParams;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user