mirror of
https://github.com/PowerJob/PowerJob.git
synced 2025-07-17 00:00:04 +08:00
refactor: SqlProcessor => SimpleSpringJdbcTemplateSqlProcessor 💡
This commit is contained in:
parent
30d0d7d338
commit
1d67e97b45
@ -1,4 +1,4 @@
|
|||||||
package tech.powerjob.official.processors.impl;
|
package tech.powerjob.official.processors.impl.sql;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSON;
|
import com.alibaba.fastjson.JSON;
|
||||||
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
|
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
|
||||||
@ -7,81 +7,55 @@ import com.github.kfcfans.powerjob.worker.log.OmsLogger;
|
|||||||
import com.google.common.collect.Maps;
|
import com.google.common.collect.Maps;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.jdbc.core.JdbcTemplate;
|
|
||||||
import org.springframework.util.Assert;
|
|
||||||
import org.springframework.util.StopWatch;
|
import org.springframework.util.StopWatch;
|
||||||
import org.springframework.util.StringUtils;
|
|
||||||
import tech.powerjob.official.processors.CommonBasicProcessor;
|
import tech.powerjob.official.processors.CommonBasicProcessor;
|
||||||
import tech.powerjob.official.processors.util.CommonUtils;
|
import tech.powerjob.official.processors.util.CommonUtils;
|
||||||
|
|
||||||
import javax.sql.DataSource;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.function.Predicate;
|
import java.util.function.Predicate;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* SQL 处理器,只能用 Spring Bean 的方式加载
|
* SQL Processor
|
||||||
* 注意 : 默认情况下没有过滤任何 SQL
|
|
||||||
* 建议生产环境一定要使用 {@link SqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL
|
|
||||||
*
|
*
|
||||||
* 默认情况下会直接执行参数中的 SQL
|
* 处理流程:
|
||||||
* 可以通过添加 {@link SqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求(比如 宏变量替换,参数替换等)
|
* * * 解析参数 => 校验参数 => 解析 SQL => 校验 SQL => 执行 SQL
|
||||||
|
*
|
||||||
|
* 可以通过 {@link AbstractSqlProcessor#registerSqlValidator} 方法注册 SQL 校验器拦截非法 SQL
|
||||||
|
* 可以通过指定 {@link AbstractSqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求(比如 宏变量替换,参数替换等)
|
||||||
*
|
*
|
||||||
* @author Echo009
|
* @author Echo009
|
||||||
* @since 2021/3/10
|
* @since 2021/3/12
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class SqlProcessor extends CommonBasicProcessor {
|
public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
|
||||||
/**
|
|
||||||
* 默认的数据源名称
|
|
||||||
*/
|
|
||||||
private static final String DEFAULT_DATASOURCE_NAME = "default";
|
|
||||||
/**
|
/**
|
||||||
* 默认超时时间
|
* 默认超时时间
|
||||||
*/
|
*/
|
||||||
private static final int DEFAULT_TIMEOUT = 60;
|
protected static final int DEFAULT_TIMEOUT = 60;
|
||||||
/**
|
|
||||||
* name => data source
|
|
||||||
*/
|
|
||||||
private final Map<String, DataSource> dataSourceMap;
|
|
||||||
/**
|
/**
|
||||||
* name => SQL validator
|
* name => SQL validator
|
||||||
* 注意 :
|
* 注意 :
|
||||||
* - 返回 true 表示验证通过
|
* - 返回 true 表示验证通过
|
||||||
* - 返回 false 表示 SQL 非法,将被拒绝执行
|
* - 返回 false 表示 SQL 非法,将被拒绝执行
|
||||||
*/
|
*/
|
||||||
private final Map<String, Predicate<String>> sqlValidatorMap;
|
protected final Map<String, Predicate<String>> sqlValidatorMap = Maps.newConcurrentMap();
|
||||||
/**
|
/**
|
||||||
* 自定义 SQL 解析器
|
* 自定义 SQL 解析器
|
||||||
*/
|
*/
|
||||||
private SqlParser sqlParser;
|
protected SqlParser sqlParser;
|
||||||
|
|
||||||
/**
|
|
||||||
* 指定默认的数据源
|
|
||||||
*
|
|
||||||
* @param defaultDataSource 默认数据源
|
|
||||||
*/
|
|
||||||
public SqlProcessor(DataSource defaultDataSource) {
|
|
||||||
dataSourceMap = Maps.newConcurrentMap();
|
|
||||||
sqlValidatorMap = Maps.newConcurrentMap();
|
|
||||||
registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected ProcessResult process0(TaskContext taskContext) {
|
public ProcessResult process0(TaskContext taskContext) {
|
||||||
|
|
||||||
OmsLogger omsLogger = taskContext.getOmsLogger();
|
OmsLogger omsLogger = taskContext.getOmsLogger();
|
||||||
SqlParams sqlParams = JSON.parseObject(CommonUtils.parseParams(taskContext), SqlParams.class);
|
// 解析参数
|
||||||
// 检查数据源
|
SqlParams sqlParams = extractParams(taskContext);
|
||||||
if (StringUtils.isEmpty(sqlParams.getDataSourceName())) {
|
omsLogger.info("[AbstractSqlProcessor-{}]origin sql params: {}", taskContext.getInstanceId(), JSON.toJSON(sqlParams));
|
||||||
sqlParams.setDataSourceName(DEFAULT_DATASOURCE_NAME);
|
// 校验参数
|
||||||
omsLogger.info("current data source name is empty, use the default data source");
|
validateParams(sqlParams);
|
||||||
}
|
|
||||||
DataSource dataSource = dataSourceMap.computeIfAbsent(sqlParams.getDataSourceName(), dataSourceName -> {
|
|
||||||
throw new IllegalArgumentException("can't find data source with name " + dataSourceName);
|
|
||||||
});
|
|
||||||
StopWatch stopWatch = new StopWatch("SQL Processor");
|
|
||||||
|
|
||||||
|
StopWatch stopWatch = new StopWatch("SQL Processor");
|
||||||
// 解析
|
// 解析
|
||||||
stopWatch.start("Parse SQL");
|
stopWatch.start("Parse SQL");
|
||||||
if (sqlParser != null) {
|
if (sqlParser != null) {
|
||||||
@ -96,17 +70,43 @@ public class SqlProcessor extends CommonBasicProcessor {
|
|||||||
|
|
||||||
// 执行
|
// 执行
|
||||||
stopWatch.start("Execute SQL");
|
stopWatch.start("Execute SQL");
|
||||||
JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
|
omsLogger.info("[AbstractSqlProcessor-{}]final sql params: {}", taskContext.getInstanceId(), JSON.toJSON(sqlParams));
|
||||||
jdbcTemplate.setSkipResultsProcessing(true);
|
executeSql(sqlParams, taskContext);
|
||||||
jdbcTemplate.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
|
|
||||||
omsLogger.info("start to execute sql: {}", sqlParams.getSql());
|
|
||||||
jdbcTemplate.execute(sqlParams.getSql());
|
|
||||||
stopWatch.stop();
|
stopWatch.stop();
|
||||||
|
|
||||||
omsLogger.info(stopWatch.prettyPrint());
|
omsLogger.info(stopWatch.prettyPrint());
|
||||||
String message = String.format("execute successfully, used time: %s millisecond", stopWatch.getTotalTimeMillis());
|
String message = String.format("execute successfully, used time: %s millisecond", stopWatch.getTotalTimeMillis());
|
||||||
return new ProcessResult(true, message);
|
return new ProcessResult(true, message);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行 SQL
|
||||||
|
*
|
||||||
|
* @param sqlParams SQL processor 参数信息
|
||||||
|
* @param taskContext 任务上下文
|
||||||
|
*/
|
||||||
|
abstract void executeSql(SqlParams sqlParams, TaskContext taskContext);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 提取参数信息
|
||||||
|
*
|
||||||
|
* @param taskContext 任务上下文
|
||||||
|
* @return SqlParams
|
||||||
|
*/
|
||||||
|
protected SqlParams extractParams(TaskContext taskContext) {
|
||||||
|
return JSON.parseObject(CommonUtils.parseParams(taskContext), SqlParams.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 校验参数,如果校验不通过直接抛异常
|
||||||
|
*
|
||||||
|
* @param sqlParams SQL 参数信息
|
||||||
|
*/
|
||||||
|
protected void validateParams(SqlParams sqlParams) {
|
||||||
|
// do nothing
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 设置 SQL 验证器
|
* 设置 SQL 验证器
|
||||||
*
|
*
|
||||||
@ -116,6 +116,7 @@ public class SqlProcessor extends CommonBasicProcessor {
|
|||||||
this.sqlParser = sqlParser;
|
this.sqlParser = sqlParser;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 注册一个 SQL 验证器
|
* 注册一个 SQL 验证器
|
||||||
*
|
*
|
||||||
@ -124,33 +125,7 @@ public class SqlProcessor extends CommonBasicProcessor {
|
|||||||
*/
|
*/
|
||||||
public void registerSqlValidator(String validatorName, Predicate<String> sqlValidator) {
|
public void registerSqlValidator(String validatorName, Predicate<String> sqlValidator) {
|
||||||
sqlValidatorMap.put(validatorName, sqlValidator);
|
sqlValidatorMap.put(validatorName, sqlValidator);
|
||||||
log.info("[SqlProcessor]register sql validator({})' successfully.", validatorName);
|
log.info("register sql validator({})' successfully.", validatorName);
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 注册数据源
|
|
||||||
*
|
|
||||||
* @param dataSourceName 数据源名称
|
|
||||||
* @param dataSource 数据源
|
|
||||||
*/
|
|
||||||
public void registerDataSource(String dataSourceName, DataSource dataSource) {
|
|
||||||
Assert.notNull(dataSourceName, "DataSource name must not be null");
|
|
||||||
Assert.notNull(dataSource, "DataSource must not be null");
|
|
||||||
dataSourceMap.put(dataSourceName, dataSource);
|
|
||||||
log.info("[SqlProcessor]register data source({})' successfully.", dataSourceName);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 移除数据源
|
|
||||||
*
|
|
||||||
* @param dataSourceName 数据源名称
|
|
||||||
*/
|
|
||||||
public void removeDataSource(String dataSourceName) {
|
|
||||||
DataSource remove = dataSourceMap.remove(dataSourceName);
|
|
||||||
if (remove != null) {
|
|
||||||
log.warn("[SqlProcessor]remove data source({})' successfully.", dataSourceName);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -169,6 +144,7 @@ public class SqlProcessor extends CommonBasicProcessor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public static class SqlParams {
|
public static class SqlParams {
|
||||||
/**
|
/**
|
||||||
@ -183,9 +159,15 @@ public class SqlProcessor extends CommonBasicProcessor {
|
|||||||
* 超时时间
|
* 超时时间
|
||||||
*/
|
*/
|
||||||
private Integer timeout;
|
private Integer timeout;
|
||||||
|
/**
|
||||||
|
* jdbc url
|
||||||
|
* 具体格式可参考 https://www.baeldung.com/java-jdbc-url-format
|
||||||
|
*/
|
||||||
|
private String jdbcUrl;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@FunctionalInterface
|
@FunctionalInterface
|
||||||
public interface SqlParser {
|
public interface SqlParser {
|
||||||
/**
|
/**
|
||||||
@ -198,4 +180,5 @@ public class SqlProcessor extends CommonBasicProcessor {
|
|||||||
String parse(String sql, TaskContext taskContext);
|
String parse(String sql, TaskContext taskContext);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
@ -0,0 +1,103 @@
|
|||||||
|
package tech.powerjob.official.processors.impl.sql;
|
||||||
|
|
||||||
|
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
|
||||||
|
import com.google.common.collect.Maps;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.jdbc.core.JdbcTemplate;
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
|
||||||
|
import javax.sql.DataSource;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 简单 Spring SQL 处理器,强依赖于 Spring jdbc ,只能用 Spring Bean 的方式加载
|
||||||
|
* 直接忽略 SQL 执行的返回值
|
||||||
|
*
|
||||||
|
* 注意 :
|
||||||
|
* 默认情况下没有过滤任何 SQL
|
||||||
|
* 建议生产环境一定要使用 {@link SimpleSpringJdbcTemplateSqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL
|
||||||
|
*
|
||||||
|
* 默认情况下会直接执行参数中的 SQL
|
||||||
|
* 可以通过添加 {@link SimpleSpringJdbcTemplateSqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求(比如 宏变量替换,参数替换等)
|
||||||
|
*
|
||||||
|
* @author Echo009
|
||||||
|
* @since 2021/3/10
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class SimpleSpringJdbcTemplateSqlProcessor extends AbstractSqlProcessor {
|
||||||
|
/**
|
||||||
|
* 默认的数据源名称
|
||||||
|
*/
|
||||||
|
private static final String DEFAULT_DATASOURCE_NAME = "default";
|
||||||
|
/**
|
||||||
|
* name => data source
|
||||||
|
*/
|
||||||
|
private final Map<String, DataSource> dataSourceMap;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 指定默认的数据源
|
||||||
|
*
|
||||||
|
* @param defaultDataSource 默认数据源
|
||||||
|
*/
|
||||||
|
public SimpleSpringJdbcTemplateSqlProcessor(DataSource defaultDataSource) {
|
||||||
|
dataSourceMap = Maps.newConcurrentMap();
|
||||||
|
registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 校验参数,如果校验不通过直接抛异常
|
||||||
|
*
|
||||||
|
* @param sqlParams SQL 参数信息
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected void validateParams(SqlParams sqlParams) {
|
||||||
|
// 检查数据源
|
||||||
|
if (StringUtils.isEmpty(sqlParams.getDataSourceName())) {
|
||||||
|
// use the default data source when current data source name is empty
|
||||||
|
sqlParams.setDataSourceName(DEFAULT_DATASOURCE_NAME);
|
||||||
|
}
|
||||||
|
dataSourceMap.computeIfAbsent(sqlParams.getDataSourceName(), dataSourceName -> {
|
||||||
|
throw new IllegalArgumentException("can't find data source with name " + dataSourceName);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行 SQL
|
||||||
|
*
|
||||||
|
* @param sqlParams SQL processor 参数信息
|
||||||
|
* @param taskContext 任务上下文
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
void executeSql(SqlParams sqlParams, TaskContext taskContext) {
|
||||||
|
JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSourceMap.get(sqlParams.getDataSourceName()));
|
||||||
|
jdbcTemplate.setSkipResultsProcessing(true);
|
||||||
|
jdbcTemplate.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
|
||||||
|
jdbcTemplate.execute(sqlParams.getSql());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 注册数据源
|
||||||
|
*
|
||||||
|
* @param dataSourceName 数据源名称
|
||||||
|
* @param dataSource 数据源
|
||||||
|
*/
|
||||||
|
public void registerDataSource(String dataSourceName, DataSource dataSource) {
|
||||||
|
Assert.notNull(dataSourceName, "DataSource name must not be null");
|
||||||
|
Assert.notNull(dataSource, "DataSource must not be null");
|
||||||
|
dataSourceMap.put(dataSourceName, dataSource);
|
||||||
|
log.info("register data source({})' successfully.", dataSourceName);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 移除数据源
|
||||||
|
*
|
||||||
|
* @param dataSourceName 数据源名称
|
||||||
|
*/
|
||||||
|
public void removeDataSource(String dataSourceName) {
|
||||||
|
DataSource remove = dataSourceMap.remove(dataSourceName);
|
||||||
|
if (remove != null) {
|
||||||
|
log.warn("remove data source({})' successfully.", dataSourceName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,81 @@
|
|||||||
|
package tech.powerjob.official.processors.impl;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson.JSON;
|
||||||
|
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
|
||||||
|
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
|
||||||
|
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
|
||||||
|
import tech.powerjob.official.processors.TestUtils;
|
||||||
|
import tech.powerjob.official.processors.impl.sql.SimpleSpringJdbcTemplateSqlProcessor;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author Echo009
|
||||||
|
* @since 2021/3/11
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
class SimpleSpringJdbcTemplateSqlProcessorTest {
|
||||||
|
|
||||||
|
private static SimpleSpringJdbcTemplateSqlProcessor simpleSpringJdbcTemplateSqlProcessor;
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
static void initSqlProcessor() {
|
||||||
|
|
||||||
|
EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder();
|
||||||
|
EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2)
|
||||||
|
.addScript("classpath:db_init.sql")
|
||||||
|
.build();
|
||||||
|
simpleSpringJdbcTemplateSqlProcessor = new SimpleSpringJdbcTemplateSqlProcessor(database);
|
||||||
|
// do nothing
|
||||||
|
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
|
||||||
|
// 排除掉包含 drop 的 SQL
|
||||||
|
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
|
||||||
|
// do nothing
|
||||||
|
simpleSpringJdbcTemplateSqlProcessor.setSqlParser((sql, taskContext) -> sql);
|
||||||
|
log.info("init sql processor successfully!");
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testSqlValidator() {
|
||||||
|
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = new SimpleSpringJdbcTemplateSqlProcessor.SqlParams();
|
||||||
|
sqlParams.setSql("drop table test_table");
|
||||||
|
// 校验不通过
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testIncorrectDataSourceName() {
|
||||||
|
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))");
|
||||||
|
sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧");
|
||||||
|
// 数据源名称非法
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testExecDDL() {
|
||||||
|
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))");
|
||||||
|
ProcessResult processResult = simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
|
||||||
|
assertTrue(processResult.isSuccess());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testExecSQL() {
|
||||||
|
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
|
||||||
|
ProcessResult processResult = simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
|
||||||
|
assertTrue(processResult.isSuccess());
|
||||||
|
}
|
||||||
|
|
||||||
|
static SimpleSpringJdbcTemplateSqlProcessor.SqlParams constructSqlParam(String sql){
|
||||||
|
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = new SimpleSpringJdbcTemplateSqlProcessor.SqlParams();
|
||||||
|
sqlParams.setSql(sql);
|
||||||
|
return sqlParams;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -1,80 +0,0 @@
|
|||||||
package tech.powerjob.official.processors.impl;
|
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSON;
|
|
||||||
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
|
|
||||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
|
|
||||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
|
|
||||||
import tech.powerjob.official.processors.TestUtils;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Echo009
|
|
||||||
* @since 2021/3/11
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
class SqlProcessorTest {
|
|
||||||
|
|
||||||
private static SqlProcessor sqlProcessor;
|
|
||||||
|
|
||||||
@BeforeAll
|
|
||||||
static void initSqlProcessor() {
|
|
||||||
|
|
||||||
EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder();
|
|
||||||
EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2)
|
|
||||||
.addScript("classpath:db_init.sql")
|
|
||||||
.build();
|
|
||||||
sqlProcessor = new SqlProcessor(database);
|
|
||||||
// do nothing
|
|
||||||
sqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
|
|
||||||
// 排除掉包含 drop 的 SQL
|
|
||||||
sqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
|
|
||||||
// do nothing
|
|
||||||
sqlProcessor.setSqlParser((sql, taskContext) -> sql);
|
|
||||||
log.info("init sql processor successfully!");
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void testSqlValidator() {
|
|
||||||
SqlProcessor.SqlParams sqlParams = new SqlProcessor.SqlParams();
|
|
||||||
sqlParams.setSql("drop table test_table");
|
|
||||||
// 校验不通过
|
|
||||||
assertThrows(IllegalArgumentException.class, () -> sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void testIncorrectDataSourceName() {
|
|
||||||
SqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))");
|
|
||||||
sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧");
|
|
||||||
// 数据源名称非法
|
|
||||||
assertThrows(IllegalArgumentException.class, () -> sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void testExecDDL() {
|
|
||||||
SqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))");
|
|
||||||
ProcessResult processResult = sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
|
|
||||||
assertTrue(processResult.isSuccess());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void testExecSQL() {
|
|
||||||
SqlProcessor.SqlParams sqlParams = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
|
|
||||||
ProcessResult processResult = sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
|
|
||||||
assertTrue(processResult.isSuccess());
|
|
||||||
}
|
|
||||||
|
|
||||||
static SqlProcessor.SqlParams constructSqlParam(String sql){
|
|
||||||
SqlProcessor.SqlParams sqlParams = new SqlProcessor.SqlParams();
|
|
||||||
sqlParams.setSql(sql);
|
|
||||||
return sqlParams;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -8,7 +8,7 @@ import org.springframework.beans.factory.annotation.Qualifier;
|
|||||||
import org.springframework.context.annotation.Bean;
|
import org.springframework.context.annotation.Bean;
|
||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
import org.springframework.context.annotation.DependsOn;
|
import org.springframework.context.annotation.DependsOn;
|
||||||
import tech.powerjob.official.processors.impl.SqlProcessor;
|
import tech.powerjob.official.processors.impl.sql.SimpleSpringJdbcTemplateSqlProcessor;
|
||||||
|
|
||||||
import javax.sql.DataSource;
|
import javax.sql.DataSource;
|
||||||
|
|
||||||
@ -37,15 +37,15 @@ public class SqlProcessorConfiguration {
|
|||||||
|
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
public SqlProcessor sqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) {
|
public SimpleSpringJdbcTemplateSqlProcessor springSqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) {
|
||||||
SqlProcessor sqlProcessor = new SqlProcessor(dataSource);
|
SimpleSpringJdbcTemplateSqlProcessor simpleSpringJdbcTemplateSqlProcessor = new SimpleSpringJdbcTemplateSqlProcessor(dataSource);
|
||||||
// do nothing
|
// do nothing
|
||||||
sqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
|
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
|
||||||
// 排除掉包含 drop 的 SQL
|
// 排除掉包含 drop 的 SQL
|
||||||
sqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
|
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
|
||||||
// do nothing
|
// do nothing
|
||||||
sqlProcessor.setSqlParser((sql, taskContext) -> sql);
|
simpleSpringJdbcTemplateSqlProcessor.setSqlParser((sql, taskContext) -> sql);
|
||||||
return sqlProcessor;
|
return simpleSpringJdbcTemplateSqlProcessor;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -13,11 +13,11 @@ public interface BasicProcessor {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 核心处理逻辑
|
* 核心处理逻辑
|
||||||
* 可通过 {@link TaskContext#workflowContext} 获取工作流上下文
|
* 可通过 {@link TaskContext#getWorkflowContext()} 方法获取工作流上下文
|
||||||
*
|
*
|
||||||
* @param context 任务上下文,可通过 jobParams 和 instanceParams 分别获取控制台参数和OpenAPI传递的任务实例参数
|
* @param context 任务上下文,可通过 jobParams 和 instanceParams 分别获取控制台参数和OpenAPI传递的任务实例参数
|
||||||
* @return 处理结果,msg有长度限制,超长会被裁剪,不允许返回 null
|
* @return 处理结果,msg有长度限制,超长会被裁剪,不允许返回 null
|
||||||
* @throws Exception 异常,允许抛出异常,但不推荐,最好由业务开发者自己处理
|
* @throws Exception 异常,允许抛出异常,但不推荐,最好由业务开发者自己处理
|
||||||
*/
|
*/
|
||||||
ProcessResult process(TaskContext context) throws Exception;
|
ProcessResult process(TaskContext context) throws Exception;
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user