diff --git a/processor/src/test/java/org/mapstruct/ap/testutil/MapperTestBase.java b/processor/src/test/java/org/mapstruct/ap/testutil/MapperTestBase.java index 5844c67a6..d8de75053 100644 --- a/processor/src/test/java/org/mapstruct/ap/testutil/MapperTestBase.java +++ b/processor/src/test/java/org/mapstruct/ap/testutil/MapperTestBase.java @@ -25,10 +25,13 @@ import java.net.URL; import java.net.URLClassLoader; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; +import java.util.HashSet; import java.util.Iterator; import java.util.List; +import java.util.Set; import javax.tools.DiagnosticCollector; import javax.tools.JavaCompiler; import javax.tools.JavaCompiler.CompilationTask; @@ -110,7 +113,7 @@ public abstract class MapperTestBase { @BeforeMethod public void generateMapperImplementation(Method testMethod) { diagnostics = new DiagnosticCollector(); - List sourceFiles = getSourceFiles( getTestClasses( testMethod ) ); + Set sourceFiles = getSourceFiles( getTestClasses( testMethod ) ); List processorOptions = getProcessorOptions( testMethod ); boolean compilationSuccessful = compile( diagnostics, sourceFiles, processorOptions ); @@ -182,22 +185,28 @@ public abstract class MapperTestBase { * * @param testMethod The test method of interest * - * @return A list containing the classes to be compiled for this test + * @return A set containing the classes to be compiled for this test */ - private List> getTestClasses(Method testMethod) { - WithClasses withClasses = testMethod.getAnnotation( WithClasses.class ); + private Set> getTestClasses(Method testMethod) { + Set> testClasses = new HashSet>(); - if ( withClasses == null ) { - withClasses = this.getClass().getAnnotation( WithClasses.class ); + WithClasses withClasses = testMethod.getAnnotation( WithClasses.class ); + if ( withClasses != null ) { + testClasses.addAll( Arrays.asList( withClasses.value() ) ); } - if ( withClasses == null || withClasses.value().length == 0 ) { + withClasses = this.getClass().getAnnotation( WithClasses.class ); + if ( withClasses != null ) { + testClasses.addAll( Arrays.asList( withClasses.value() ) ); + } + + if ( testClasses.isEmpty() ) { throw new IllegalStateException( "The classes to be compiled during the test must be specified via @WithClasses." ); } - return Arrays.asList( withClasses.value() ); + return testClasses; } /** @@ -222,8 +231,8 @@ public abstract class MapperTestBase { return String.format( "-A%s=%s", processorOption.name(), processorOption.value() ); } - private List getSourceFiles(List> classes) { - List sourceFiles = new ArrayList( classes.size() ); + private Set getSourceFiles(Collection> classes) { + Set sourceFiles = new HashSet( classes.size() ); for ( Class clazz : classes ) { sourceFiles.add(