diff --git a/powerjob-official-processors/pom.xml b/powerjob-official-processors/pom.xml
index 73681a87..c7165e2f 100644
--- a/powerjob-official-processors/pom.xml
+++ b/powerjob-official-processors/pom.xml
@@ -21,6 +21,8 @@
5.6.1
1.2.3
3.4.6
+ 5.2.9.RELEASE
+
1.4.200
1.2.68
@@ -65,6 +67,13 @@
provided
+
+ org.springframework
+ spring-jdbc
+ ${spring.jdbc.version}
+ provided
+
+
org.junit.jupiter
@@ -80,6 +89,14 @@
${logback.version}
test
+
+
+
+ com.h2database
+ h2
+ ${h2.db.version}
+ test
+
@@ -104,6 +121,7 @@
shade.powerjob.org
org.slf4j.*
+ org.springframework.*
diff --git a/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/SqlProcessor.java b/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/SqlProcessor.java
new file mode 100644
index 00000000..d66df931
--- /dev/null
+++ b/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/SqlProcessor.java
@@ -0,0 +1,201 @@
+package tech.powerjob.official.processors.impl;
+
+import com.alibaba.fastjson.JSON;
+import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
+import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
+import com.github.kfcfans.powerjob.worker.log.OmsLogger;
+import com.google.common.collect.Maps;
+import lombok.Data;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.util.Assert;
+import org.springframework.util.StopWatch;
+import org.springframework.util.StringUtils;
+import tech.powerjob.official.processors.CommonBasicProcessor;
+import tech.powerjob.official.processors.util.CommonUtils;
+
+import javax.sql.DataSource;
+import java.util.Map;
+import java.util.function.Predicate;
+
+/**
+ * SQL 处理器,只能用 Spring Bean 的方式加载
+ * 注意 : 默认情况下没有过滤任何 SQL
+ * 建议生产环境一定要使用 {@link SqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL
+ *
+ * 默认情况下会直接执行参数中的 SQL
+ * 可以通过添加 {@link SqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求(比如 宏变量替换,参数替换等)
+ *
+ * @author Echo009
+ * @since 2021/3/10
+ */
+@Slf4j
+public class SqlProcessor extends CommonBasicProcessor {
+ /**
+ * 默认的数据源名称
+ */
+ private static final String DEFAULT_DATASOURCE_NAME = "default";
+ /**
+ * 默认超时时间
+ */
+ private static final int DEFAULT_TIMEOUT = 60;
+ /**
+ * name => data source
+ */
+ private final Map dataSourceMap;
+ /**
+ * name => SQL validator
+ * 注意 :
+ * - 返回 true 表示验证通过
+ * - 返回 false 表示 SQL 非法,将被拒绝执行
+ */
+ private final Map> sqlValidatorMap;
+ /**
+ * 自定义 SQL 解析器
+ */
+ private SqlParser sqlParser;
+
+ /**
+ * 指定默认的数据源
+ *
+ * @param defaultDataSource 默认数据源
+ */
+ public SqlProcessor(DataSource defaultDataSource) {
+ dataSourceMap = Maps.newConcurrentMap();
+ sqlValidatorMap = Maps.newConcurrentMap();
+ registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource);
+ }
+
+
+ @Override
+ protected ProcessResult process0(TaskContext taskContext) {
+
+ OmsLogger omsLogger = taskContext.getOmsLogger();
+ SqlParams sqlParams = JSON.parseObject(CommonUtils.parseParams(taskContext), SqlParams.class);
+ // 检查数据源
+ if (StringUtils.isEmpty(sqlParams.getDataSourceName())) {
+ sqlParams.setDataSourceName(DEFAULT_DATASOURCE_NAME);
+ omsLogger.info("current data source name is empty, use the default data source");
+ }
+ DataSource dataSource = dataSourceMap.computeIfAbsent(sqlParams.getDataSourceName(), dataSourceName -> {
+ throw new IllegalArgumentException("can't find data source with name " + dataSourceName);
+ });
+ StopWatch stopWatch = new StopWatch("SQL Processor");
+
+ // 解析
+ stopWatch.start("Parse SQL");
+ if (sqlParser != null) {
+ sqlParams.setSql(sqlParser.parse(sqlParams.getSql(), taskContext));
+ }
+ stopWatch.stop();
+
+ // 校验 SQL
+ stopWatch.start("Validate SQL");
+ validateSql(sqlParams.getSql());
+ stopWatch.stop();
+
+ // 执行
+ stopWatch.start("Execute SQL");
+ JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
+ jdbcTemplate.setSkipResultsProcessing(true);
+ jdbcTemplate.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
+ omsLogger.info("start to execute sql: {}", sqlParams.getSql());
+ jdbcTemplate.execute(sqlParams.getSql());
+ stopWatch.stop();
+ omsLogger.info(stopWatch.prettyPrint());
+ String message = String.format("execute successfully, used time: %s millisecond", stopWatch.getTotalTimeMillis());
+ return new ProcessResult(true, message);
+ }
+
+ /**
+ * 设置 SQL 验证器
+ *
+ * @param sqlParser SQL 解析器
+ */
+ public void setSqlParser(SqlParser sqlParser) {
+ this.sqlParser = sqlParser;
+ }
+
+ /**
+ * 注册一个 SQL 验证器
+ *
+ * @param validatorName 验证器名称
+ * @param sqlValidator 验证器
+ */
+ public void registerSqlValidator(String validatorName, Predicate sqlValidator) {
+ sqlValidatorMap.put(validatorName, sqlValidator);
+ log.info("[SqlProcessor]register sql validator({})' successfully.", validatorName);
+ }
+
+
+ /**
+ * 注册数据源
+ *
+ * @param dataSourceName 数据源名称
+ * @param dataSource 数据源
+ */
+ public void registerDataSource(String dataSourceName, DataSource dataSource) {
+ Assert.notNull(dataSourceName, "DataSource name must not be null");
+ Assert.notNull(dataSource, "DataSource must not be null");
+ dataSourceMap.put(dataSourceName, dataSource);
+ log.info("[SqlProcessor]register data source({})' successfully.", dataSourceName);
+ }
+
+ /**
+ * 移除数据源
+ *
+ * @param dataSourceName 数据源名称
+ */
+ public void removeDataSource(String dataSourceName) {
+ DataSource remove = dataSourceMap.remove(dataSourceName);
+ if (remove != null) {
+ log.warn("[SqlProcessor]remove data source({})' successfully.", dataSourceName);
+ }
+ }
+
+
+ /**
+ * 校验 SQL 合法性
+ */
+ private void validateSql(String sql) {
+ if (sqlValidatorMap.isEmpty()) {
+ return;
+ }
+ for (Map.Entry> entry : sqlValidatorMap.entrySet()) {
+ Predicate validator = entry.getValue();
+ if (!validator.test(sql)) {
+ throw new IllegalArgumentException("illegal sql, can't pass the validation of " + entry.getKey());
+ }
+ }
+ }
+
+ @Data
+ public static class SqlParams {
+ /**
+ * 数据源名称
+ */
+ private String dataSourceName;
+ /**
+ * 需要执行的 SQL
+ */
+ private String sql;
+ /**
+ * 超时时间
+ */
+ private Integer timeout;
+
+ }
+
+ @FunctionalInterface
+ public interface SqlParser {
+ /**
+ * 自定义 SQL 解析逻辑
+ *
+ * @param sql 原始 SQL 语句
+ * @param taskContext 任务上下文
+ * @return 解析后的 SQL
+ */
+ String parse(String sql, TaskContext taskContext);
+ }
+
+}
diff --git a/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/TestUtils.java b/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/TestUtils.java
index 5673447a..d252e149 100644
--- a/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/TestUtils.java
+++ b/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/TestUtils.java
@@ -1,8 +1,8 @@
package tech.powerjob.official.processors;
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
+import com.github.kfcfans.powerjob.worker.core.processor.WorkflowContext;
import com.github.kfcfans.powerjob.worker.log.impl.OmsLocalLogger;
-import com.github.kfcfans.powerjob.worker.log.impl.OmsServerLogger;
import java.util.concurrent.ThreadLocalRandom;
@@ -25,7 +25,7 @@ public class TestUtils {
taskContext.setTaskId("0.0");
taskContext.setTaskName("TEST_TASK");
taskContext.setOmsLogger(new OmsLocalLogger());
-
+ taskContext.setWorkflowContext(new WorkflowContext(null, null));
return taskContext;
}
}
diff --git a/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SqlProcessorTest.java b/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SqlProcessorTest.java
new file mode 100644
index 00000000..c216db7d
--- /dev/null
+++ b/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SqlProcessorTest.java
@@ -0,0 +1,80 @@
+package tech.powerjob.official.processors.impl;
+
+import com.alibaba.fastjson.JSON;
+import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
+import lombok.extern.slf4j.Slf4j;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
+import tech.powerjob.official.processors.TestUtils;
+
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * @author Echo009
+ * @since 2021/3/11
+ */
+@Slf4j
+class SqlProcessorTest {
+
+ private static SqlProcessor sqlProcessor;
+
+ @BeforeAll
+ static void initSqlProcessor() {
+
+ EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder();
+ EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2)
+ .addScript("classpath:db_init.sql")
+ .build();
+ sqlProcessor = new SqlProcessor(database);
+ // do nothing
+ sqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
+ // 排除掉包含 drop 的 SQL
+ sqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
+ // do nothing
+ sqlProcessor.setSqlParser((sql, taskContext) -> sql);
+ log.info("init sql processor successfully!");
+
+ }
+
+
+ @Test
+ void testSqlValidator() {
+ SqlProcessor.SqlParams sqlParams = new SqlProcessor.SqlParams();
+ sqlParams.setSql("drop table test_table");
+ // 校验不通过
+ assertThrows(IllegalArgumentException.class, () -> sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
+ }
+
+ @Test
+ void testIncorrectDataSourceName() {
+ SqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))");
+ sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧");
+ // 数据源名称非法
+ assertThrows(IllegalArgumentException.class, () -> sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
+ }
+
+ @Test
+ void testExecDDL() {
+ SqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))");
+ ProcessResult processResult = sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
+ assertTrue(processResult.isSuccess());
+ }
+
+ @Test
+ void testExecSQL() {
+ SqlProcessor.SqlParams sqlParams = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
+ ProcessResult processResult = sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
+ assertTrue(processResult.isSuccess());
+ }
+
+ static SqlProcessor.SqlParams constructSqlParam(String sql){
+ SqlProcessor.SqlParams sqlParams = new SqlProcessor.SqlParams();
+ sqlParams.setSql(sql);
+ return sqlParams;
+ }
+
+}
diff --git a/powerjob-official-processors/src/test/resources/db_init.sql b/powerjob-official-processors/src/test/resources/db_init.sql
new file mode 100644
index 00000000..0d31f4dd
--- /dev/null
+++ b/powerjob-official-processors/src/test/resources/db_init.sql
@@ -0,0 +1,7 @@
+create table test_table
+(
+ id bigint(20) primary key,
+ content varchar(255),
+ gmt_create datetime default now(),
+ gmt_modified datetime default now()
+);
\ No newline at end of file
diff --git a/powerjob-worker-samples/pom.xml b/powerjob-worker-samples/pom.xml
index 9526b515..73336f2a 100644
--- a/powerjob-worker-samples/pom.xml
+++ b/powerjob-worker-samples/pom.xml
@@ -16,6 +16,7 @@
2.2.6.RELEASE
3.4.6
1.2.68
+ 1.0.1
true
@@ -52,6 +53,11 @@
fastjson
${fastjson.version}
+
+ com.github.kfcfans
+ powerjob-official-processors
+ ${powerjob.official.processors.version}
+
diff --git a/powerjob-worker-samples/src/main/java/com/github/kfcfans/powerjob/samples/config/SqlProcessorConfiguration.java b/powerjob-worker-samples/src/main/java/com/github/kfcfans/powerjob/samples/config/SqlProcessorConfiguration.java
new file mode 100644
index 00000000..e9ca6db3
--- /dev/null
+++ b/powerjob-worker-samples/src/main/java/com/github/kfcfans/powerjob/samples/config/SqlProcessorConfiguration.java
@@ -0,0 +1,51 @@
+package com.github.kfcfans.powerjob.samples.config;
+
+import com.github.kfcfans.powerjob.worker.common.utils.OmsWorkerFileUtils;
+import com.zaxxer.hikari.HikariConfig;
+import com.zaxxer.hikari.HikariDataSource;
+import org.h2.Driver;
+import org.springframework.beans.factory.annotation.Qualifier;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.context.annotation.DependsOn;
+import tech.powerjob.official.processors.impl.SqlProcessor;
+
+import javax.sql.DataSource;
+
+/**
+ * @author Echo009
+ * @since 2021/3/10
+ */
+@Configuration
+public class SqlProcessorConfiguration {
+
+
+ @Bean
+ @DependsOn({"initPowerJob"})
+ public DataSource sqlProcessorDataSource() {
+ String jdbcUrl = String.format("jdbc:h2:file:%spowerjob_sql_processor_db;DB_CLOSE_DELAY=-1;DATABASE_TO_UPPER=false", OmsWorkerFileUtils.getH2WorkDir());
+ HikariConfig config = new HikariConfig();
+ config.setDriverClassName(Driver.class.getName());
+ config.setJdbcUrl(jdbcUrl);
+ config.setAutoCommit(true);
+ // 池中最小空闲连接数量
+ config.setMinimumIdle(1);
+ // 池中最大连接数量
+ config.setMaximumPoolSize(10);
+ return new HikariDataSource(config);
+ }
+
+
+ @Bean
+ public SqlProcessor sqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) {
+ SqlProcessor sqlProcessor = new SqlProcessor(dataSource);
+ // do nothing
+ sqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
+ // 排除掉包含 drop 的 SQL
+ sqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
+ // do nothing
+ sqlProcessor.setSqlParser((sql, taskContext) -> sql);
+ return sqlProcessor;
+ }
+
+}
diff --git a/powerjob-worker/src/main/java/com/github/kfcfans/powerjob/worker/persistence/ConnectionFactory.java b/powerjob-worker/src/main/java/com/github/kfcfans/powerjob/worker/persistence/ConnectionFactory.java
index 7bd23325..5cb8c8ed 100644
--- a/powerjob-worker/src/main/java/com/github/kfcfans/powerjob/worker/persistence/ConnectionFactory.java
+++ b/powerjob-worker/src/main/java/com/github/kfcfans/powerjob/worker/persistence/ConnectionFactory.java
@@ -1,6 +1,5 @@
package com.github.kfcfans.powerjob.worker.persistence;
-import com.github.kfcfans.powerjob.worker.OhMyWorker;
import com.github.kfcfans.powerjob.worker.common.constants.StoreStrategy;
import com.github.kfcfans.powerjob.worker.common.utils.OmsWorkerFileUtils;
import com.zaxxer.hikari.HikariConfig;