refactor: SimpleSpringJdbcTemplateSqlProcessor => SimpleSpringSqlProcessor

This commit is contained in:
Echo009 2021-03-13 14:49:42 +08:00
parent 86b584be2f
commit 93158ba19b
4 changed files with 133 additions and 100 deletions

View File

@ -2,30 +2,32 @@ package tech.powerjob.official.processors.impl.sql;
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext; import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
import com.google.common.collect.Maps; import com.google.common.collect.Maps;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import javax.sql.DataSource; import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.Statement;
import java.util.Map; import java.util.Map;
/** /**
* 简单 Spring SQL 处理器强依赖于 Spring jdbc 只能用 Spring Bean 的方式加载 * 简单 Spring SQL 处理器目前只能用 Spring Bean 的方式加载
* 直接忽略 SQL 执行的返回值 * 直接忽略 SQL 执行的返回值
* *
* 注意 : * 注意 :
* 默认情况下没有过滤任何 SQL * 默认情况下没有过滤任何 SQL
* 建议生产环境一定要使用 {@link SimpleSpringJdbcTemplateSqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL * 建议生产环境一定要使用 {@link SimpleSpringSqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL
* *
* 默认情况下会直接执行参数中的 SQL * 默认情况下会直接执行参数中的 SQL
* 可以通过添加 {@link SimpleSpringJdbcTemplateSqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求比如 宏变量替换参数替换等 * 可以通过添加 {@link SimpleSpringSqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求比如 宏变量替换参数替换等
* *
* @author Echo009 * @author Echo009
* @since 2021/3/10 * @since 2021/3/10
*/ */
@Slf4j @Slf4j
public class SimpleSpringJdbcTemplateSqlProcessor extends AbstractSqlProcessor { public class SimpleSpringSqlProcessor extends AbstractSqlProcessor {
/** /**
* 默认的数据源名称 * 默认的数据源名称
*/ */
@ -40,7 +42,7 @@ public class SimpleSpringJdbcTemplateSqlProcessor extends AbstractSqlProcessor {
* *
* @param defaultDataSource 默认数据源 * @param defaultDataSource 默认数据源
*/ */
public SimpleSpringJdbcTemplateSqlProcessor(DataSource defaultDataSource) { public SimpleSpringSqlProcessor(DataSource defaultDataSource) {
dataSourceMap = Maps.newConcurrentMap(); dataSourceMap = Maps.newConcurrentMap();
registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource); registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource);
} }
@ -63,17 +65,33 @@ public class SimpleSpringJdbcTemplateSqlProcessor extends AbstractSqlProcessor {
} }
/** /**
* 执行 SQL * 执行 SQL忽略返回值
* *
* @param sqlParams SQL processor 参数信息 * @param sqlParams SQL processor 参数信息
* @param taskContext 任务上下文 * @param taskContext 任务上下文
*/ */
@Override @Override
void executeSql(SqlParams sqlParams, TaskContext taskContext) { @SneakyThrows
JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSourceMap.get(sqlParams.getDataSourceName())); @SuppressWarnings({"squid:S1181"})
jdbcTemplate.setSkipResultsProcessing(true); protected void executeSql(SqlParams sqlParams, TaskContext taskContext) {
jdbcTemplate.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout()); DataSource currentDataSource = dataSourceMap.get(sqlParams.getDataSourceName());
jdbcTemplate.execute(sqlParams.getSql()); boolean originAutoCommitFlag ;
try (Connection connection = currentDataSource.getConnection()) {
originAutoCommitFlag = connection.getAutoCommit();
connection.setAutoCommit(false);
try (Statement statement = connection.createStatement()) {
statement.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
statement.execute(sqlParams.getSql());
connection.commit();
} catch (Throwable e) {
connection.rollback();
// rethrow
throw e;
} finally {
// reset
connection.setAutoCommit(originAutoCommitFlag);
}
}
} }
/** /**

View File

@ -1,81 +0,0 @@
package tech.powerjob.official.processors.impl;
import com.alibaba.fastjson.JSON;
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import tech.powerjob.official.processors.TestUtils;
import tech.powerjob.official.processors.impl.sql.SimpleSpringJdbcTemplateSqlProcessor;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* @author Echo009
* @since 2021/3/11
*/
@Slf4j
class SimpleSpringJdbcTemplateSqlProcessorTest {
private static SimpleSpringJdbcTemplateSqlProcessor simpleSpringJdbcTemplateSqlProcessor;
@BeforeAll
static void initSqlProcessor() {
EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder();
EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2)
.addScript("classpath:db_init.sql")
.build();
simpleSpringJdbcTemplateSqlProcessor = new SimpleSpringJdbcTemplateSqlProcessor(database);
// do nothing
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
// 排除掉包含 drop SQL
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
// do nothing
simpleSpringJdbcTemplateSqlProcessor.setSqlParser((sql, taskContext) -> sql);
log.info("init sql processor successfully!");
}
@Test
void testSqlValidator() {
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = new SimpleSpringJdbcTemplateSqlProcessor.SqlParams();
sqlParams.setSql("drop table test_table");
// 校验不通过
assertThrows(IllegalArgumentException.class, () -> simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testIncorrectDataSourceName() {
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))");
sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧");
// 数据源名称非法
assertThrows(IllegalArgumentException.class, () -> simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testExecDDL() {
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))");
ProcessResult processResult = simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
assertTrue(processResult.isSuccess());
}
@Test
void testExecSQL() {
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
ProcessResult processResult = simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
assertTrue(processResult.isSuccess());
}
static SimpleSpringJdbcTemplateSqlProcessor.SqlParams constructSqlParam(String sql){
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = new SimpleSpringJdbcTemplateSqlProcessor.SqlParams();
sqlParams.setSql(sql);
return sqlParams;
}
}

View File

@ -0,0 +1,96 @@
package tech.powerjob.official.processors.impl;
import com.alibaba.fastjson.JSON;
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
import lombok.extern.slf4j.Slf4j;
import org.h2.jdbc.JdbcSQLIntegrityConstraintViolationException;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import tech.powerjob.official.processors.TestUtils;
import tech.powerjob.official.processors.impl.sql.SimpleSpringSqlProcessor;
import static org.junit.jupiter.api.Assertions.*;
/**
* @author Echo009
* @since 2021/3/11
*/
@Slf4j
class SimpleSpringSqlProcessorTest {
private static SimpleSpringSqlProcessor simpleSpringSqlProcessor;
@BeforeAll
static void initSqlProcessor() {
EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder();
EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2)
.addScript("classpath:db_init.sql")
.build();
simpleSpringSqlProcessor = new SimpleSpringSqlProcessor(database);
// do nothing
simpleSpringSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
// 排除掉包含 drop SQL
simpleSpringSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
// do nothing
simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> sql);
log.info("init sql processor successfully!");
}
@Test
void testSqlValidator() {
SimpleSpringSqlProcessor.SqlParams sqlParams = new SimpleSpringSqlProcessor.SqlParams();
sqlParams.setSql("drop table test_table");
// 校验不通过
assertThrows(IllegalArgumentException.class, () -> simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testIncorrectDataSourceName() {
SimpleSpringSqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))");
sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧");
// 数据源名称非法
assertThrows(IllegalArgumentException.class, () -> simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testExecDDL() {
SimpleSpringSqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))");
ProcessResult processResult = simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
assertTrue(processResult.isSuccess());
}
@Test
void testExecSQL() {
SimpleSpringSqlProcessor.SqlParams sqlParams1 = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
ProcessResult processResult1 = simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams1)));
assertTrue(processResult1.isSuccess());
assertThrows(JdbcSQLIntegrityConstraintViolationException.class, () -> simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams1))));
// 第二条会失败回滚
SimpleSpringSqlProcessor.SqlParams sqlParams2 = constructSqlParam("insert into test_table (id, content) values (1, '?');insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
assertThrows(JdbcSQLIntegrityConstraintViolationException.class, () -> simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams2))));
// 上方回滚这里就能成功插入
SimpleSpringSqlProcessor.SqlParams sqlParams3 = constructSqlParam("insert into test_table (id, content) values (1, '?')");
ProcessResult processResult3 = simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams3)));
assertTrue(processResult3.isSuccess());
SimpleSpringSqlProcessor.SqlParams sqlParams4 = constructSqlParam("insert into test_table (id, content) values (2, '?');insert into test_table (id, content) values (3, '?')");
ProcessResult processResult4 = simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams4)));
assertTrue(processResult4.isSuccess());
}
static SimpleSpringSqlProcessor.SqlParams constructSqlParam(String sql){
SimpleSpringSqlProcessor.SqlParams sqlParams = new SimpleSpringSqlProcessor.SqlParams();
sqlParams.setSql(sql);
return sqlParams;
}
}

View File

@ -8,7 +8,7 @@ import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.DependsOn; import org.springframework.context.annotation.DependsOn;
import tech.powerjob.official.processors.impl.sql.SimpleSpringJdbcTemplateSqlProcessor; import tech.powerjob.official.processors.impl.sql.SimpleSpringSqlProcessor;
import javax.sql.DataSource; import javax.sql.DataSource;
@ -37,15 +37,15 @@ public class SqlProcessorConfiguration {
@Bean @Bean
public SimpleSpringJdbcTemplateSqlProcessor springSqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) { public SimpleSpringSqlProcessor simpleSpringSqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) {
SimpleSpringJdbcTemplateSqlProcessor simpleSpringJdbcTemplateSqlProcessor = new SimpleSpringJdbcTemplateSqlProcessor(dataSource); SimpleSpringSqlProcessor simpleSpringSqlProcessor = new SimpleSpringSqlProcessor(dataSource);
// do nothing // do nothing
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true); simpleSpringSqlProcessor.registerSqlValidator("fakeSqlValidator", sql -> true);
// 排除掉包含 drop SQL // 排除掉包含 drop SQL
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$")); simpleSpringSqlProcessor.registerSqlValidator("interceptDropValidator", sql -> sql.matches("^(?i)((?!drop).)*$"));
// do nothing // do nothing
simpleSpringJdbcTemplateSqlProcessor.setSqlParser((sql, taskContext) -> sql); simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> sql);
return simpleSpringJdbcTemplateSqlProcessor; return simpleSpringSqlProcessor;
} }
} }