diff --git a/powerjob-official-processors/pom.xml b/powerjob-official-processors/pom.xml index db7c4199..4b906e4a 100644 --- a/powerjob-official-processors/pom.xml +++ b/powerjob-official-processors/pom.xml @@ -67,11 +67,12 @@ provided + org.springframework spring-jdbc ${spring.jdbc.version} - provided + test diff --git a/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/AbstractSqlProcessor.java b/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/AbstractSqlProcessor.java index 99d857b8..bb11c8d7 100644 --- a/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/AbstractSqlProcessor.java +++ b/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/AbstractSqlProcessor.java @@ -4,13 +4,22 @@ import com.alibaba.fastjson.JSON; import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult; import com.github.kfcfans.powerjob.worker.core.processor.TaskContext; import com.github.kfcfans.powerjob.worker.log.OmsLogger; +import com.google.common.base.Joiner; +import com.google.common.collect.Lists; import com.google.common.collect.Maps; import lombok.Data; +import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.springframework.util.StopWatch; import tech.powerjob.official.processors.CommonBasicProcessor; import tech.powerjob.official.processors.util.CommonUtils; +import javax.sql.DataSource; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.List; import java.util.Map; import java.util.function.Predicate; @@ -28,6 +37,7 @@ import java.util.function.Predicate; */ @Slf4j public abstract class AbstractSqlProcessor extends CommonBasicProcessor { + /** * 默认超时时间 */ @@ -44,6 +54,8 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor { */ protected SqlParser sqlParser; + private static final Joiner JOINER = Joiner.on("|").useForNull("-"); + @Override public ProcessResult process0(TaskContext taskContext) { @@ -51,26 +63,29 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor { OmsLogger omsLogger = taskContext.getOmsLogger(); // 解析参数 SqlParams sqlParams = extractParams(taskContext); - omsLogger.info("[AbstractSqlProcessor-{}]origin sql params: {}", taskContext.getInstanceId(), JSON.toJSON(sqlParams)); + omsLogger.info("origin sql params: {}", JSON.toJSON(sqlParams)); // 校验参数 validateParams(sqlParams); - StopWatch stopWatch = new StopWatch("SQL Processor"); + StopWatch stopWatch = new StopWatch(this.getClass().getSimpleName()); // 解析 stopWatch.start("Parse SQL"); if (sqlParser != null) { - sqlParams.setSql(sqlParser.parse(sqlParams.getSql(), taskContext)); + omsLogger.info("before parse sql: {}", sqlParams.getSql()); + String newSQL = sqlParser.parse(sqlParams.getSql(), taskContext); + sqlParams.setSql(newSQL); + omsLogger.info("after parse sql: {}", newSQL); } stopWatch.stop(); // 校验 SQL stopWatch.start("Validate SQL"); - validateSql(sqlParams.getSql()); + validateSql(sqlParams.getSql(), omsLogger); stopWatch.stop(); // 执行 stopWatch.start("Execute SQL"); - omsLogger.info("[AbstractSqlProcessor-{}]final sql params: {}", taskContext.getInstanceId(), JSON.toJSON(sqlParams)); + omsLogger.info("final sql params: {}", JSON.toJSON(sqlParams)); executeSql(sqlParams, taskContext); stopWatch.stop(); @@ -79,13 +94,78 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor { return new ProcessResult(true, message); } + abstract DataSource getDataSource(SqlParams sqlParams, TaskContext taskContext); + /** * 执行 SQL - * - * @param sqlParams SQL processor 参数信息 - * @param taskContext 任务上下文 + * @param sqlParams SQL processor 参数信息 + * @param ctx 任务上下文 */ - abstract void executeSql(SqlParams sqlParams, TaskContext taskContext); + @SneakyThrows + private void executeSql(SqlParams sqlParams, TaskContext ctx) { + + OmsLogger omsLogger = ctx.getOmsLogger(); + + boolean originAutoCommitFlag ; + try (Connection connection = getDataSource(sqlParams, ctx).getConnection()) { + originAutoCommitFlag = connection.getAutoCommit(); + connection.setAutoCommit(false); + try (Statement statement = connection.createStatement()) { + statement.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout()); + statement.execute(sqlParams.getSql()); + + connection.commit(); + + if (sqlParams.showResult) { + outputSqlResult(statement, omsLogger); + } + } catch (Throwable e) { + omsLogger.error("execute sql failed, try to rollback", e); + connection.rollback(); + throw e; + } finally { + connection.setAutoCommit(originAutoCommitFlag); + } + } + } + + private void outputSqlResult(Statement statement, OmsLogger omsLogger) throws SQLException { + omsLogger.info("====== SQL EXECUTE RESULT ======"); + + for (int index = 0; index < Integer.MAX_VALUE; index++) { + + // 某一个结果集 + ResultSet resultSet = statement.getResultSet(); + if (resultSet != null) { + try (ResultSet rs = resultSet) { + int columnCount = rs.getMetaData().getColumnCount(); + List columnNames = Lists.newLinkedList(); + //column – the first column is 1, the second is 2, ... + for (int i = 1; i <= columnCount; i++) { + columnNames.add(rs.getMetaData().getColumnName(i)); + } + omsLogger.info("[Result-{}] [Columns] {}" + System.lineSeparator(), index, JOINER.join(columnNames)); + int rowIndex = 0; + List row = Lists.newLinkedList(); + while (rs.next()) { + for (int i = 1; i <= columnCount; i++) { + row.add(rs.getObject(i)); + } + omsLogger.info("[Result-{}] [Row-{}] {}" + System.lineSeparator(), index, rowIndex++, JOINER.join(row)); + } + } + } else { + int updateCount = statement.getUpdateCount(); + if (updateCount != -1) { + omsLogger.info("[Result-{}] update count: {}", index, updateCount); + } + } + if (((!statement.getMoreResults()) && (statement.getUpdateCount() == -1))) { + break; + } + } + omsLogger.info("====== SQL EXECUTE RESULT ======"); + } /** * 提取参数信息 @@ -132,13 +212,14 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor { /** * 校验 SQL 合法性 */ - private void validateSql(String sql) { + private void validateSql(String sql, OmsLogger omsLogger) { if (sqlValidatorMap.isEmpty()) { return; } for (Map.Entry> entry : sqlValidatorMap.entrySet()) { Predicate validator = entry.getValue(); if (!validator.test(sql)) { + omsLogger.error("validate sql by validator[{}] failed, skip to process!", entry.getKey()); throw new IllegalArgumentException("illegal sql, can't pass the validation of " + entry.getKey()); } } @@ -164,7 +245,10 @@ public abstract class AbstractSqlProcessor extends CommonBasicProcessor { * 具体格式可参考 https://www.baeldung.com/java-jdbc-url-format */ private String jdbcUrl; - + /** + * 是否展示 SQL 执行结果 + */ + private boolean showResult; } diff --git a/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/SimpleSpringSqlProcessor.java b/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/SimpleSpringSqlProcessor.java index 26b2c3c6..bca7ba36 100644 --- a/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/SimpleSpringSqlProcessor.java +++ b/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/SimpleSpringSqlProcessor.java @@ -2,14 +2,11 @@ package tech.powerjob.official.processors.impl.sql; import com.github.kfcfans.powerjob.worker.core.processor.TaskContext; import com.google.common.collect.Maps; -import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import javax.sql.DataSource; -import java.sql.Connection; -import java.sql.Statement; import java.util.Map; /** @@ -64,34 +61,9 @@ public class SimpleSpringSqlProcessor extends AbstractSqlProcessor { }); } - /** - * 执行 SQL,忽略返回值 - * - * @param sqlParams SQL processor 参数信息 - * @param taskContext 任务上下文 - */ @Override - @SneakyThrows - @SuppressWarnings({"squid:S1181"}) - protected void executeSql(SqlParams sqlParams, TaskContext taskContext) { - DataSource currentDataSource = dataSourceMap.get(sqlParams.getDataSourceName()); - boolean originAutoCommitFlag ; - try (Connection connection = currentDataSource.getConnection()) { - originAutoCommitFlag = connection.getAutoCommit(); - connection.setAutoCommit(false); - try (Statement statement = connection.createStatement()) { - statement.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout()); - statement.execute(sqlParams.getSql()); - connection.commit(); - } catch (Throwable e) { - connection.rollback(); - // rethrow - throw e; - } finally { - // reset - connection.setAutoCommit(originAutoCommitFlag); - } - } + DataSource getDataSource(SqlParams sqlParams, TaskContext taskContext) { + return dataSourceMap.get(sqlParams.getDataSourceName()); } /** diff --git a/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SimpleSpringSqlProcessorTest.java b/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/sql/SimpleSpringSqlProcessorTest.java similarity index 81% rename from powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SimpleSpringSqlProcessorTest.java rename to powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/sql/SimpleSpringSqlProcessorTest.java index 9024893d..66ce3a4d 100644 --- a/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SimpleSpringSqlProcessorTest.java +++ b/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/sql/SimpleSpringSqlProcessorTest.java @@ -1,4 +1,4 @@ -package tech.powerjob.official.processors.impl; +package tech.powerjob.official.processors.impl.sql; import com.alibaba.fastjson.JSON; import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult; @@ -10,9 +10,9 @@ import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; import tech.powerjob.official.processors.TestUtils; -import tech.powerjob.official.processors.impl.sql.SimpleSpringSqlProcessor; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * @author Echo009 @@ -35,8 +35,17 @@ class SimpleSpringSqlProcessorTest { simpleSpringSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true); // 排除掉包含 drop 的 SQL simpleSpringSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$")); - // do nothing - simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> sql); + // add ';' + simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> { + if (!sql.endsWith(";")) { + return sql + ";"; + } + return sql; + }); + + // just invoke clean datasource method + simpleSpringSqlProcessor.removeDataSource("NULL_DATASOURCE"); + log.info("init sql processor successfully!"); } @@ -87,9 +96,19 @@ class SimpleSpringSqlProcessorTest { } + @Test + public void testQuery() { + SimpleSpringSqlProcessor.SqlParams insertParams = constructSqlParam("insert into test_table (id, content) values (1, '?');insert into test_table (id, content) values (0, 'Fight for a better tomorrow')"); + simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(insertParams))); + + SimpleSpringSqlProcessor.SqlParams queryParams = constructSqlParam("select * from test_table"); + simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(queryParams))); + } + static SimpleSpringSqlProcessor.SqlParams constructSqlParam(String sql){ SimpleSpringSqlProcessor.SqlParams sqlParams = new SimpleSpringSqlProcessor.SqlParams(); sqlParams.setSql(sql); + sqlParams.setShowResult(true); return sqlParams; }