From 87ed3047377a5e026132a6f1ed77c8806daae8a2 Mon Sep 17 00:00:00 2001 From: Echo009 Date: Thu, 11 Mar 2021 11:40:25 +0800 Subject: [PATCH] feat: SQL processor --- powerjob-official-processors/pom.xml | 18 ++ .../processors/impl/SqlProcessor.java | 201 ++++++++++++++++++ .../official/processors/TestUtils.java | 4 +- .../processors/impl/SqlProcessorTest.java | 80 +++++++ .../src/test/resources/db_init.sql | 7 + powerjob-worker-samples/pom.xml | 6 + .../config/SqlProcessorConfiguration.java | 51 +++++ .../worker/persistence/ConnectionFactory.java | 1 - 8 files changed, 365 insertions(+), 3 deletions(-) create mode 100644 powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/SqlProcessor.java create mode 100644 powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SqlProcessorTest.java create mode 100644 powerjob-official-processors/src/test/resources/db_init.sql create mode 100644 powerjob-worker-samples/src/main/java/com/github/kfcfans/powerjob/samples/config/SqlProcessorConfiguration.java diff --git a/powerjob-official-processors/pom.xml b/powerjob-official-processors/pom.xml index 73681a87..c7165e2f 100644 --- a/powerjob-official-processors/pom.xml +++ b/powerjob-official-processors/pom.xml @@ -21,6 +21,8 @@ 5.6.1 1.2.3 3.4.6 + 5.2.9.RELEASE + 1.4.200 1.2.68 @@ -65,6 +67,13 @@ provided + + org.springframework + spring-jdbc + ${spring.jdbc.version} + provided + + org.junit.jupiter @@ -80,6 +89,14 @@ ${logback.version} test + + + + com.h2database + h2 + ${h2.db.version} + test + @@ -104,6 +121,7 @@ shade.powerjob.org org.slf4j.* + org.springframework.* diff --git a/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/SqlProcessor.java b/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/SqlProcessor.java new file mode 100644 index 00000000..d66df931 --- /dev/null +++ b/powerjob-official-processors/src/main/java/tech/powerjob/official/processors/impl/SqlProcessor.java @@ -0,0 +1,201 @@ +package tech.powerjob.official.processors.impl; + +import com.alibaba.fastjson.JSON; +import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult; +import com.github.kfcfans.powerjob.worker.core.processor.TaskContext; +import com.github.kfcfans.powerjob.worker.log.OmsLogger; +import com.google.common.collect.Maps; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.util.Assert; +import org.springframework.util.StopWatch; +import org.springframework.util.StringUtils; +import tech.powerjob.official.processors.CommonBasicProcessor; +import tech.powerjob.official.processors.util.CommonUtils; + +import javax.sql.DataSource; +import java.util.Map; +import java.util.function.Predicate; + +/** + * SQL 处理器,只能用 Spring Bean 的方式加载 + * 注意 : 默认情况下没有过滤任何 SQL + * 建议生产环境一定要使用 {@link SqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL + * + * 默认情况下会直接执行参数中的 SQL + * 可以通过添加 {@link SqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求(比如 宏变量替换,参数替换等) + * + * @author Echo009 + * @since 2021/3/10 + */ +@Slf4j +public class SqlProcessor extends CommonBasicProcessor { + /** + * 默认的数据源名称 + */ + private static final String DEFAULT_DATASOURCE_NAME = "default"; + /** + * 默认超时时间 + */ + private static final int DEFAULT_TIMEOUT = 60; + /** + * name => data source + */ + private final Map dataSourceMap; + /** + * name => SQL validator + * 注意 : + * - 返回 true 表示验证通过 + * - 返回 false 表示 SQL 非法,将被拒绝执行 + */ + private final Map> sqlValidatorMap; + /** + * 自定义 SQL 解析器 + */ + private SqlParser sqlParser; + + /** + * 指定默认的数据源 + * + * @param defaultDataSource 默认数据源 + */ + public SqlProcessor(DataSource defaultDataSource) { + dataSourceMap = Maps.newConcurrentMap(); + sqlValidatorMap = Maps.newConcurrentMap(); + registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource); + } + + + @Override + protected ProcessResult process0(TaskContext taskContext) { + + OmsLogger omsLogger = taskContext.getOmsLogger(); + SqlParams sqlParams = JSON.parseObject(CommonUtils.parseParams(taskContext), SqlParams.class); + // 检查数据源 + if (StringUtils.isEmpty(sqlParams.getDataSourceName())) { + sqlParams.setDataSourceName(DEFAULT_DATASOURCE_NAME); + omsLogger.info("current data source name is empty, use the default data source"); + } + DataSource dataSource = dataSourceMap.computeIfAbsent(sqlParams.getDataSourceName(), dataSourceName -> { + throw new IllegalArgumentException("can't find data source with name " + dataSourceName); + }); + StopWatch stopWatch = new StopWatch("SQL Processor"); + + // 解析 + stopWatch.start("Parse SQL"); + if (sqlParser != null) { + sqlParams.setSql(sqlParser.parse(sqlParams.getSql(), taskContext)); + } + stopWatch.stop(); + + // 校验 SQL + stopWatch.start("Validate SQL"); + validateSql(sqlParams.getSql()); + stopWatch.stop(); + + // 执行 + stopWatch.start("Execute SQL"); + JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource); + jdbcTemplate.setSkipResultsProcessing(true); + jdbcTemplate.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout()); + omsLogger.info("start to execute sql: {}", sqlParams.getSql()); + jdbcTemplate.execute(sqlParams.getSql()); + stopWatch.stop(); + omsLogger.info(stopWatch.prettyPrint()); + String message = String.format("execute successfully, used time: %s millisecond", stopWatch.getTotalTimeMillis()); + return new ProcessResult(true, message); + } + + /** + * 设置 SQL 验证器 + * + * @param sqlParser SQL 解析器 + */ + public void setSqlParser(SqlParser sqlParser) { + this.sqlParser = sqlParser; + } + + /** + * 注册一个 SQL 验证器 + * + * @param validatorName 验证器名称 + * @param sqlValidator 验证器 + */ + public void registerSqlValidator(String validatorName, Predicate sqlValidator) { + sqlValidatorMap.put(validatorName, sqlValidator); + log.info("[SqlProcessor]register sql validator({})' successfully.", validatorName); + } + + + /** + * 注册数据源 + * + * @param dataSourceName 数据源名称 + * @param dataSource 数据源 + */ + public void registerDataSource(String dataSourceName, DataSource dataSource) { + Assert.notNull(dataSourceName, "DataSource name must not be null"); + Assert.notNull(dataSource, "DataSource must not be null"); + dataSourceMap.put(dataSourceName, dataSource); + log.info("[SqlProcessor]register data source({})' successfully.", dataSourceName); + } + + /** + * 移除数据源 + * + * @param dataSourceName 数据源名称 + */ + public void removeDataSource(String dataSourceName) { + DataSource remove = dataSourceMap.remove(dataSourceName); + if (remove != null) { + log.warn("[SqlProcessor]remove data source({})' successfully.", dataSourceName); + } + } + + + /** + * 校验 SQL 合法性 + */ + private void validateSql(String sql) { + if (sqlValidatorMap.isEmpty()) { + return; + } + for (Map.Entry> entry : sqlValidatorMap.entrySet()) { + Predicate validator = entry.getValue(); + if (!validator.test(sql)) { + throw new IllegalArgumentException("illegal sql, can't pass the validation of " + entry.getKey()); + } + } + } + + @Data + public static class SqlParams { + /** + * 数据源名称 + */ + private String dataSourceName; + /** + * 需要执行的 SQL + */ + private String sql; + /** + * 超时时间 + */ + private Integer timeout; + + } + + @FunctionalInterface + public interface SqlParser { + /** + * 自定义 SQL 解析逻辑 + * + * @param sql 原始 SQL 语句 + * @param taskContext 任务上下文 + * @return 解析后的 SQL + */ + String parse(String sql, TaskContext taskContext); + } + +} diff --git a/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/TestUtils.java b/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/TestUtils.java index 5673447a..d252e149 100644 --- a/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/TestUtils.java +++ b/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/TestUtils.java @@ -1,8 +1,8 @@ package tech.powerjob.official.processors; import com.github.kfcfans.powerjob.worker.core.processor.TaskContext; +import com.github.kfcfans.powerjob.worker.core.processor.WorkflowContext; import com.github.kfcfans.powerjob.worker.log.impl.OmsLocalLogger; -import com.github.kfcfans.powerjob.worker.log.impl.OmsServerLogger; import java.util.concurrent.ThreadLocalRandom; @@ -25,7 +25,7 @@ public class TestUtils { taskContext.setTaskId("0.0"); taskContext.setTaskName("TEST_TASK"); taskContext.setOmsLogger(new OmsLocalLogger()); - + taskContext.setWorkflowContext(new WorkflowContext(null, null)); return taskContext; } } diff --git a/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SqlProcessorTest.java b/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SqlProcessorTest.java new file mode 100644 index 00000000..c216db7d --- /dev/null +++ b/powerjob-official-processors/src/test/java/tech/powerjob/official/processors/impl/SqlProcessorTest.java @@ -0,0 +1,80 @@ +package tech.powerjob.official.processors.impl; + +import com.alibaba.fastjson.JSON; +import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; +import tech.powerjob.official.processors.TestUtils; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * @author Echo009 + * @since 2021/3/11 + */ +@Slf4j +class SqlProcessorTest { + + private static SqlProcessor sqlProcessor; + + @BeforeAll + static void initSqlProcessor() { + + EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder(); + EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2) + .addScript("classpath:db_init.sql") + .build(); + sqlProcessor = new SqlProcessor(database); + // do nothing + sqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true); + // 排除掉包含 drop 的 SQL + sqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$")); + // do nothing + sqlProcessor.setSqlParser((sql, taskContext) -> sql); + log.info("init sql processor successfully!"); + + } + + + @Test + void testSqlValidator() { + SqlProcessor.SqlParams sqlParams = new SqlProcessor.SqlParams(); + sqlParams.setSql("drop table test_table"); + // 校验不通过 + assertThrows(IllegalArgumentException.class, () -> sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)))); + } + + @Test + void testIncorrectDataSourceName() { + SqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))"); + sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧"); + // 数据源名称非法 + assertThrows(IllegalArgumentException.class, () -> sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)))); + } + + @Test + void testExecDDL() { + SqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))"); + ProcessResult processResult = sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))); + assertTrue(processResult.isSuccess()); + } + + @Test + void testExecSQL() { + SqlProcessor.SqlParams sqlParams = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')"); + ProcessResult processResult = sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))); + assertTrue(processResult.isSuccess()); + } + + static SqlProcessor.SqlParams constructSqlParam(String sql){ + SqlProcessor.SqlParams sqlParams = new SqlProcessor.SqlParams(); + sqlParams.setSql(sql); + return sqlParams; + } + +} diff --git a/powerjob-official-processors/src/test/resources/db_init.sql b/powerjob-official-processors/src/test/resources/db_init.sql new file mode 100644 index 00000000..0d31f4dd --- /dev/null +++ b/powerjob-official-processors/src/test/resources/db_init.sql @@ -0,0 +1,7 @@ +create table test_table +( + id bigint(20) primary key, + content varchar(255), + gmt_create datetime default now(), + gmt_modified datetime default now() +); \ No newline at end of file diff --git a/powerjob-worker-samples/pom.xml b/powerjob-worker-samples/pom.xml index 9526b515..73336f2a 100644 --- a/powerjob-worker-samples/pom.xml +++ b/powerjob-worker-samples/pom.xml @@ -16,6 +16,7 @@ 2.2.6.RELEASE 3.4.6 1.2.68 + 1.0.1 true @@ -52,6 +53,11 @@ fastjson ${fastjson.version} + + com.github.kfcfans + powerjob-official-processors + ${powerjob.official.processors.version} + diff --git a/powerjob-worker-samples/src/main/java/com/github/kfcfans/powerjob/samples/config/SqlProcessorConfiguration.java b/powerjob-worker-samples/src/main/java/com/github/kfcfans/powerjob/samples/config/SqlProcessorConfiguration.java new file mode 100644 index 00000000..e9ca6db3 --- /dev/null +++ b/powerjob-worker-samples/src/main/java/com/github/kfcfans/powerjob/samples/config/SqlProcessorConfiguration.java @@ -0,0 +1,51 @@ +package com.github.kfcfans.powerjob.samples.config; + +import com.github.kfcfans.powerjob.worker.common.utils.OmsWorkerFileUtils; +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import org.h2.Driver; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.DependsOn; +import tech.powerjob.official.processors.impl.SqlProcessor; + +import javax.sql.DataSource; + +/** + * @author Echo009 + * @since 2021/3/10 + */ +@Configuration +public class SqlProcessorConfiguration { + + + @Bean + @DependsOn({"initPowerJob"}) + public DataSource sqlProcessorDataSource() { + String jdbcUrl = String.format("jdbc:h2:file:%spowerjob_sql_processor_db;DB_CLOSE_DELAY=-1;DATABASE_TO_UPPER=false", OmsWorkerFileUtils.getH2WorkDir()); + HikariConfig config = new HikariConfig(); + config.setDriverClassName(Driver.class.getName()); + config.setJdbcUrl(jdbcUrl); + config.setAutoCommit(true); + // 池中最小空闲连接数量 + config.setMinimumIdle(1); + // 池中最大连接数量 + config.setMaximumPoolSize(10); + return new HikariDataSource(config); + } + + + @Bean + public SqlProcessor sqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) { + SqlProcessor sqlProcessor = new SqlProcessor(dataSource); + // do nothing + sqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true); + // 排除掉包含 drop 的 SQL + sqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$")); + // do nothing + sqlProcessor.setSqlParser((sql, taskContext) -> sql); + return sqlProcessor; + } + +} diff --git a/powerjob-worker/src/main/java/com/github/kfcfans/powerjob/worker/persistence/ConnectionFactory.java b/powerjob-worker/src/main/java/com/github/kfcfans/powerjob/worker/persistence/ConnectionFactory.java index 7bd23325..5cb8c8ed 100644 --- a/powerjob-worker/src/main/java/com/github/kfcfans/powerjob/worker/persistence/ConnectionFactory.java +++ b/powerjob-worker/src/main/java/com/github/kfcfans/powerjob/worker/persistence/ConnectionFactory.java @@ -1,6 +1,5 @@ package com.github.kfcfans.powerjob.worker.persistence; -import com.github.kfcfans.powerjob.worker.OhMyWorker; import com.github.kfcfans.powerjob.worker.common.constants.StoreStrategy; import com.github.kfcfans.powerjob.worker.common.utils.OmsWorkerFileUtils; import com.zaxxer.hikari.HikariConfig;