From 93158ba19bd360f99d37ebc757547f12c86a8e48 Mon Sep 17 00:00:00 2001 From: Echo009 Date: Sat, 13 Mar 2021 14:49:42 +0800 Subject: [PATCH] refactor: SimpleSpringJdbcTemplateSqlProcessor => SimpleSpringSqlProcessor --- ...sor.java => SimpleSpringSqlProcessor.java} | 42 +++++--- ...pleSpringJdbcTemplateSqlProcessorTest.java | 81 ---------------- .../impl/SimpleSpringSqlProcessorTest.java | 96 +++++++++++++++++++ .../config/SqlProcessorConfiguration.java | 14 +-- 4 files changed, 133 insertions(+), 100 deletions(-) rename powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/{SimpleSpringJdbcTemplateSqlProcessor.java => SimpleSpringSqlProcessor.java} (61%) delete mode 100644 powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SimpleSpringJdbcTemplateSqlProcessorTest.java create mode 100644 powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SimpleSpringSqlProcessorTest.java diff --git a/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/SimpleSpringJdbcTemplateSqlProcessor.java b/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/SimpleSpringSqlProcessor.java similarity index 61% rename from powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/SimpleSpringJdbcTemplateSqlProcessor.java rename to powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/SimpleSpringSqlProcessor.java index bccae773..26b2c3c6 100644 --- a/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/SimpleSpringJdbcTemplateSqlProcessor.java +++ b/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/sql/SimpleSpringSqlProcessor.java @@ -2,30 +2,32 @@ package tech.powerjob.official.processors.impl.sql; import com.github.kfcfans.powerjob.worker.core.processor.TaskContext; import com.google.common.collect.Maps; +import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; -import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import javax.sql.DataSource; +import java.sql.Connection; +import java.sql.Statement; import java.util.Map; /** - * 简单 Spring SQL 处理器,强依赖于 Spring jdbc ,只能用 Spring Bean 的方式加载 + * 简单 Spring SQL 处理器,目前只能用 Spring Bean 的方式加载 * 直接忽略 SQL 执行的返回值 * * 注意 : * 默认情况下没有过滤任何 SQL - * 建议生产环境一定要使用 {@link SimpleSpringJdbcTemplateSqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL + * 建议生产环境一定要使用 {@link SimpleSpringSqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL * * 默认情况下会直接执行参数中的 SQL - * 可以通过添加 {@link SimpleSpringJdbcTemplateSqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求(比如 宏变量替换,参数替换等) + * 可以通过添加 {@link SimpleSpringSqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求(比如 宏变量替换,参数替换等) * * @author Echo009 * @since 2021/3/10 */ @Slf4j -public class SimpleSpringJdbcTemplateSqlProcessor extends AbstractSqlProcessor { +public class SimpleSpringSqlProcessor extends AbstractSqlProcessor { /** * 默认的数据源名称 */ @@ -40,7 +42,7 @@ public class SimpleSpringJdbcTemplateSqlProcessor extends AbstractSqlProcessor { * * @param defaultDataSource 默认数据源 */ - public SimpleSpringJdbcTemplateSqlProcessor(DataSource defaultDataSource) { + public SimpleSpringSqlProcessor(DataSource defaultDataSource) { dataSourceMap = Maps.newConcurrentMap(); registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource); } @@ -63,17 +65,33 @@ public class SimpleSpringJdbcTemplateSqlProcessor extends AbstractSqlProcessor { } /** - * 执行 SQL + * 执行 SQL,忽略返回值 * * @param sqlParams SQL processor 参数信息 * @param taskContext 任务上下文 */ @Override - void executeSql(SqlParams sqlParams, TaskContext taskContext) { - JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSourceMap.get(sqlParams.getDataSourceName())); - jdbcTemplate.setSkipResultsProcessing(true); - jdbcTemplate.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout()); - jdbcTemplate.execute(sqlParams.getSql()); + @SneakyThrows + @SuppressWarnings({"squid:S1181"}) + protected void executeSql(SqlParams sqlParams, TaskContext taskContext) { + DataSource currentDataSource = dataSourceMap.get(sqlParams.getDataSourceName()); + boolean originAutoCommitFlag ; + try (Connection connection = currentDataSource.getConnection()) { + originAutoCommitFlag = connection.getAutoCommit(); + connection.setAutoCommit(false); + try (Statement statement = connection.createStatement()) { + statement.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout()); + statement.execute(sqlParams.getSql()); + connection.commit(); + } catch (Throwable e) { + connection.rollback(); + // rethrow + throw e; + } finally { + // reset + connection.setAutoCommit(originAutoCommitFlag); + } + } } /** diff --git a/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SimpleSpringJdbcTemplateSqlProcessorTest.java b/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SimpleSpringJdbcTemplateSqlProcessorTest.java deleted file mode 100644 index f792304e..00000000 --- a/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SimpleSpringJdbcTemplateSqlProcessorTest.java +++ /dev/null @@ -1,81 +0,0 @@ -package tech.powerjob.official.processors.impl; - -import com.alibaba.fastjson.JSON; -import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult; -import lombok.extern.slf4j.Slf4j; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; -import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; -import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; -import tech.powerjob.official.processors.TestUtils; -import tech.powerjob.official.processors.impl.sql.SimpleSpringJdbcTemplateSqlProcessor; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; - -/** - * @author Echo009 - * @since 2021/3/11 - */ -@Slf4j -class SimpleSpringJdbcTemplateSqlProcessorTest { - - private static SimpleSpringJdbcTemplateSqlProcessor simpleSpringJdbcTemplateSqlProcessor; - - @BeforeAll - static void initSqlProcessor() { - - EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder(); - EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2) - .addScript("classpath:db_init.sql") - .build(); - simpleSpringJdbcTemplateSqlProcessor = new SimpleSpringJdbcTemplateSqlProcessor(database); - // do nothing - simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true); - // 排除掉包含 drop 的 SQL - simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$")); - // do nothing - simpleSpringJdbcTemplateSqlProcessor.setSqlParser((sql, taskContext) -> sql); - log.info("init sql processor successfully!"); - - } - - - @Test - void testSqlValidator() { - SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = new SimpleSpringJdbcTemplateSqlProcessor.SqlParams(); - sqlParams.setSql("drop table test_table"); - // 校验不通过 - assertThrows(IllegalArgumentException.class, () -> simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)))); - } - - @Test - void testIncorrectDataSourceName() { - SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))"); - sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧"); - // 数据源名称非法 - assertThrows(IllegalArgumentException.class, () -> simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)))); - } - - @Test - void testExecDDL() { - SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))"); - ProcessResult processResult = simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))); - assertTrue(processResult.isSuccess()); - } - - @Test - void testExecSQL() { - SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')"); - ProcessResult processResult = simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))); - assertTrue(processResult.isSuccess()); - } - - static SimpleSpringJdbcTemplateSqlProcessor.SqlParams constructSqlParam(String sql){ - SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = new SimpleSpringJdbcTemplateSqlProcessor.SqlParams(); - sqlParams.setSql(sql); - return sqlParams; - } - -} diff --git a/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SimpleSpringSqlProcessorTest.java b/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SimpleSpringSqlProcessorTest.java new file mode 100644 index 00000000..9024893d --- /dev/null +++ b/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SimpleSpringSqlProcessorTest.java @@ -0,0 +1,96 @@ +package tech.powerjob.official.processors.impl; + +import com.alibaba.fastjson.JSON; +import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult; +import lombok.extern.slf4j.Slf4j; +import org.h2.jdbc.JdbcSQLIntegrityConstraintViolationException; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; +import tech.powerjob.official.processors.TestUtils; +import tech.powerjob.official.processors.impl.sql.SimpleSpringSqlProcessor; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * @author Echo009 + * @since 2021/3/11 + */ +@Slf4j +class SimpleSpringSqlProcessorTest { + + private static SimpleSpringSqlProcessor simpleSpringSqlProcessor; + + @BeforeAll + static void initSqlProcessor() { + + EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder(); + EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2) + .addScript("classpath:db_init.sql") + .build(); + simpleSpringSqlProcessor = new SimpleSpringSqlProcessor(database); + // do nothing + simpleSpringSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true); + // 排除掉包含 drop 的 SQL + simpleSpringSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$")); + // do nothing + simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> sql); + log.info("init sql processor successfully!"); + + } + + + @Test + void testSqlValidator() { + SimpleSpringSqlProcessor.SqlParams sqlParams = new SimpleSpringSqlProcessor.SqlParams(); + sqlParams.setSql("drop table test_table"); + // 校验不通过 + assertThrows(IllegalArgumentException.class, () -> simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)))); + } + + @Test + void testIncorrectDataSourceName() { + SimpleSpringSqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))"); + sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧"); + // 数据源名称非法 + assertThrows(IllegalArgumentException.class, () -> simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)))); + } + + @Test + void testExecDDL() { + SimpleSpringSqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))"); + ProcessResult processResult = simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))); + assertTrue(processResult.isSuccess()); + } + + @Test + void testExecSQL() { + + SimpleSpringSqlProcessor.SqlParams sqlParams1 = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')"); + ProcessResult processResult1 = simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams1))); + assertTrue(processResult1.isSuccess()); + + assertThrows(JdbcSQLIntegrityConstraintViolationException.class, () -> simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams1)))); + // 第二条会失败回滚 + SimpleSpringSqlProcessor.SqlParams sqlParams2 = constructSqlParam("insert into test_table (id, content) values (1, '?');insert into test_table (id, content) values (0, 'Fight for a better tomorrow')"); + assertThrows(JdbcSQLIntegrityConstraintViolationException.class, () -> simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams2)))); + // 上方回滚,这里就能成功插入 + SimpleSpringSqlProcessor.SqlParams sqlParams3 = constructSqlParam("insert into test_table (id, content) values (1, '?')"); + ProcessResult processResult3 = simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams3))); + assertTrue(processResult3.isSuccess()); + + SimpleSpringSqlProcessor.SqlParams sqlParams4 = constructSqlParam("insert into test_table (id, content) values (2, '?');insert into test_table (id, content) values (3, '?')"); + ProcessResult processResult4 = simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams4))); + assertTrue(processResult4.isSuccess()); + + } + + static SimpleSpringSqlProcessor.SqlParams constructSqlParam(String sql){ + SimpleSpringSqlProcessor.SqlParams sqlParams = new SimpleSpringSqlProcessor.SqlParams(); + sqlParams.setSql(sql); + return sqlParams; + } + +} diff --git a/powerjob-worker-samples/src/main/java/com/github/kfcfans/powerjob/samples/config/SqlProcessorConfiguration.java b/powerjob-worker-samples/src/main/java/com/github/kfcfans/powerjob/samples/config/SqlProcessorConfiguration.java index 29c77d4b..7dbb30ba 100644 --- a/powerjob-worker-samples/src/main/java/com/github/kfcfans/powerjob/samples/config/SqlProcessorConfiguration.java +++ b/powerjob-worker-samples/src/main/java/com/github/kfcfans/powerjob/samples/config/SqlProcessorConfiguration.java @@ -8,7 +8,7 @@ import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.DependsOn; -import tech.powerjob.official.processors.impl.sql.SimpleSpringJdbcTemplateSqlProcessor; +import tech.powerjob.official.processors.impl.sql.SimpleSpringSqlProcessor; import javax.sql.DataSource; @@ -37,15 +37,15 @@ public class SqlProcessorConfiguration { @Bean - public SimpleSpringJdbcTemplateSqlProcessor springSqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) { - SimpleSpringJdbcTemplateSqlProcessor simpleSpringJdbcTemplateSqlProcessor = new SimpleSpringJdbcTemplateSqlProcessor(dataSource); + public SimpleSpringSqlProcessor simpleSpringSqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) { + SimpleSpringSqlProcessor simpleSpringSqlProcessor = new SimpleSpringSqlProcessor(dataSource); // do nothing - simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true); + simpleSpringSqlProcessor.registerSqlValidator("fakeSqlValidator", sql -> true); // 排除掉包含 drop 的 SQL - simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$")); + simpleSpringSqlProcessor.registerSqlValidator("interceptDropValidator", sql -> sql.matches("^(?i)((?!drop).)*$")); // do nothing - simpleSpringJdbcTemplateSqlProcessor.setSqlParser((sql, taskContext) -> sql); - return simpleSpringJdbcTemplateSqlProcessor; + simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> sql); + return simpleSpringSqlProcessor; } }