1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package com.github.springtestdbunit;
18
19 import java.lang.annotation.Annotation;
20 import java.lang.reflect.AnnotatedElement;
21 import java.sql.SQLException;
22 import java.util.ArrayList;
23 import java.util.Collection;
24 import java.util.Collections;
25 import java.util.Iterator;
26 import java.util.LinkedList;
27 import java.util.List;
28 import java.util.Map;
29
30 import org.apache.commons.logging.Log;
31 import org.apache.commons.logging.LogFactory;
32 import org.dbunit.DatabaseUnitException;
33 import org.dbunit.database.IDatabaseConnection;
34 import org.dbunit.dataset.CompositeDataSet;
35 import org.dbunit.dataset.DataSetException;
36 import org.dbunit.dataset.IDataSet;
37 import org.dbunit.dataset.ITable;
38 import org.dbunit.dataset.filter.IColumnFilter;
39 import org.springframework.core.annotation.AnnotationUtils;
40 import org.springframework.util.Assert;
41 import org.springframework.util.StringUtils;
42
43 import com.github.springtestdbunit.annotation.DatabaseOperation;
44 import com.github.springtestdbunit.annotation.DatabaseSetup;
45 import com.github.springtestdbunit.annotation.DatabaseSetups;
46 import com.github.springtestdbunit.annotation.DatabaseTearDown;
47 import com.github.springtestdbunit.annotation.DatabaseTearDowns;
48 import com.github.springtestdbunit.annotation.ExpectedDatabase;
49 import com.github.springtestdbunit.annotation.ExpectedDatabases;
50 import com.github.springtestdbunit.assertion.DatabaseAssertion;
51 import com.github.springtestdbunit.dataset.DataSetLoader;
52 import com.github.springtestdbunit.dataset.DataSetModifier;
53
54
55
56
57
58
59
60
61
62
63 public class DbUnitRunner {
64
// Shared logger for this runner.
// NOTE(review): the log category is DbUnitTestExecutionListener rather than DbUnitRunner —
// confirm this is intentional before changing it, since logging configuration may rely on it.
private static final Log logger = LogFactory.getLog(DbUnitTestExecutionListener.class);
66
67
68
69
70
71
72 public void beforeTestMethod(DbUnitTestContext testContext) throws Exception {
73 Annotations<DatabaseSetup> annotations = Annotations.get(testContext, DatabaseSetups.class,
74 DatabaseSetup.class);
75 setupOrTeardown(testContext, true, AnnotationAttributes.get(annotations));
76 }
77
78
79
80
81
82
83 public void afterTestMethod(DbUnitTestContext testContext) throws Exception {
84 try {
85 try {
86 verifyExpected(testContext,
87 Annotations.get(testContext, ExpectedDatabases.class, ExpectedDatabase.class));
88 } finally {
89 Annotations<DatabaseTearDown> annotations = Annotations.get(testContext, DatabaseTearDowns.class,
90 DatabaseTearDown.class);
91 try {
92 setupOrTeardown(testContext, false, AnnotationAttributes.get(annotations));
93 } catch (RuntimeException ex) {
94 if (testContext.getTestException() == null) {
95 throw ex;
96 }
97 if (logger.isWarnEnabled()) {
98 logger.warn("Unable to throw database cleanup exception due to existing test error", ex);
99 }
100 }
101 }
102 } finally {
103 testContext.getConnections().closeAll();
104 }
105 }
106
107 private void verifyExpected(DbUnitTestContext testContext, Annotations<ExpectedDatabase> annotations)
108 throws Exception {
109 if (testContext.getTestException() != null) {
110 if (logger.isDebugEnabled()) {
111 logger.debug("Skipping @DatabaseTest expectation due to test exception "
112 + testContext.getTestException().getClass());
113 }
114 return;
115 }
116 DatabaseConnections connections = testContext.getConnections();
117 DataSetModifier modifier = getModifier(testContext, annotations);
118 boolean override = false;
119 for (ExpectedDatabase annotation : annotations.getMethodAnnotations()) {
120 verifyExpected(testContext, connections, modifier, annotation);
121 override |= annotation.override();
122 }
123 if (!override) {
124 for (ExpectedDatabase annotation : annotations.getClassAnnotations()) {
125 verifyExpected(testContext, connections, modifier, annotation);
126 }
127 }
128 }
129
130 private void verifyExpected(DbUnitTestContext testContext, DatabaseConnections connections,
131 DataSetModifier modifier, ExpectedDatabase annotation)
132 throws Exception, DataSetException, SQLException, DatabaseUnitException {
133 String query = annotation.query();
134 String table = annotation.table();
135 IDataSet expectedDataSet = loadDataset(testContext, annotation.value(), modifier);
136 IDatabaseConnection connection = connections.get(annotation.connection());
137 if (expectedDataSet != null) {
138 if (logger.isDebugEnabled()) {
139 logger.debug("Veriftying @DatabaseTest expectation using " + annotation.value());
140 }
141 DatabaseAssertion assertion = annotation.assertionMode().getDatabaseAssertion();
142 List<IColumnFilter> columnFilters = getColumnFilters(annotation);
143 if (StringUtils.hasLength(query)) {
144 Assert.hasLength(table, "The table name must be specified when using a SQL query");
145 ITable expectedTable = expectedDataSet.getTable(table);
146 ITable actualTable = connection.createQueryTable(table, query);
147 assertion.assertEquals(expectedTable, actualTable, columnFilters);
148 } else if (StringUtils.hasLength(table)) {
149 ITable actualTable = connection.createTable(table);
150 ITable expectedTable = expectedDataSet.getTable(table);
151 assertion.assertEquals(expectedTable, actualTable, columnFilters);
152 } else {
153 IDataSet actualDataSet = connection.createDataSet();
154 assertion.assertEquals(expectedDataSet, actualDataSet, columnFilters);
155 }
156 }
157 }
158
159 private DataSetModifier getModifier(DbUnitTestContext testContext, Annotations<ExpectedDatabase> annotations) {
160 DataSetModifiers modifiers = new DataSetModifiers();
161 for (ExpectedDatabase annotation : annotations) {
162 for (Class<? extends DataSetModifier> modifierClass : annotation.modifiers()) {
163 modifiers.add(testContext.getTestInstance(), modifierClass);
164 }
165 }
166 return modifiers;
167 }
168
169 private void setupOrTeardown(DbUnitTestContext testContext, boolean isSetup,
170 Collection<AnnotationAttributes> annotations) throws Exception {
171 DatabaseConnections connections = testContext.getConnections();
172 for (AnnotationAttributes annotation : annotations) {
173 List<IDataSet> datasets = loadDataSets(testContext, annotation);
174 DatabaseOperation operation = annotation.getType();
175 org.dbunit.operation.DatabaseOperation dbUnitOperation = getDbUnitDatabaseOperation(testContext, operation);
176 if (!datasets.isEmpty()) {
177 if (logger.isDebugEnabled()) {
178 logger.debug("Executing " + (isSetup ? "Setup" : "Teardown") + " of @DatabaseTest using "
179 + operation + " on " + datasets.toString());
180 }
181 IDatabaseConnection connection = connections.get(annotation.getConnection());
182 IDataSet dataSet = new CompositeDataSet(datasets.toArray(new IDataSet[datasets.size()]));
183 dbUnitOperation.execute(connection, dataSet);
184 }
185 }
186 }
187
188 private List<IDataSet> loadDataSets(DbUnitTestContext testContext, AnnotationAttributes annotation)
189 throws Exception {
190 List<IDataSet> datasets = new ArrayList<IDataSet>();
191 for (String dataSetLocation : annotation.getValue()) {
192 datasets.add(loadDataset(testContext, dataSetLocation, DataSetModifier.NONE));
193 }
194 if (datasets.isEmpty()) {
195 datasets.add(getFullDatabaseDataSet(testContext, annotation.getConnection()));
196 }
197 return datasets;
198 }
199
200 private IDataSet getFullDatabaseDataSet(DbUnitTestContext testContext, String name) throws Exception {
201 IDatabaseConnection connection = testContext.getConnections().get(name);
202 return connection.createDataSet();
203 }
204
205 private IDataSet loadDataset(DbUnitTestContext testContext, String dataSetLocation, DataSetModifier modifier)
206 throws Exception {
207 DataSetLoader dataSetLoader = testContext.getDataSetLoader();
208 if (StringUtils.hasLength(dataSetLocation)) {
209 IDataSet dataSet = dataSetLoader.loadDataSet(testContext.getTestClass(), dataSetLocation);
210 dataSet = modifier.modify(dataSet);
211 Assert.notNull(dataSet,
212 "Unable to load dataset from \"" + dataSetLocation + "\" using " + dataSetLoader.getClass());
213 return dataSet;
214 }
215 return null;
216 }
217
218 private List<IColumnFilter> getColumnFilters(ExpectedDatabase annotation) throws Exception {
219 Class<? extends IColumnFilter>[] columnFilterClasses = annotation.columnFilters();
220 List<IColumnFilter> columnFilters = new LinkedList<IColumnFilter>();
221 for (Class<? extends IColumnFilter> columnFilterClass : columnFilterClasses) {
222 columnFilters.add(columnFilterClass.newInstance());
223 }
224 return columnFilters;
225 }
226
227 private org.dbunit.operation.DatabaseOperation getDbUnitDatabaseOperation(DbUnitTestContext testContext,
228 DatabaseOperation operation) {
229 org.dbunit.operation.DatabaseOperation databaseOperation = testContext.getDatbaseOperationLookup()
230 .get(operation);
231 Assert.state(databaseOperation != null, "The database operation " + operation + " is not supported");
232 return databaseOperation;
233 }
234
235 private static class AnnotationAttributes {
236
237 private final DatabaseOperation type;
238
239 private final String[] value;
240
241 private final String connection;
242
243 public AnnotationAttributes(Annotation annotation) {
244 Assert.state((annotation instanceof DatabaseSetup) || (annotation instanceof DatabaseTearDown),
245 "Only DatabaseSetup and DatabaseTearDown annotations are supported");
246 Map<String, Object> attributes = AnnotationUtils.getAnnotationAttributes(annotation);
247 this.type = (DatabaseOperation) attributes.get("type");
248 this.value = (String[]) attributes.get("value");
249 this.connection = (String) attributes.get("connection");
250 }
251
252 public DatabaseOperation getType() {
253 return this.type;
254 }
255
256 public String[] getValue() {
257 return this.value;
258 }
259
260 public String getConnection() {
261 return this.connection;
262 }
263
264 public static <T extends Annotation> Collection<AnnotationAttributes> get(Annotations<T> annotations) {
265 List<AnnotationAttributes> annotationAttributes = new ArrayList<AnnotationAttributes>();
266 for (T annotation : annotations) {
267 annotationAttributes.add(new AnnotationAttributes(annotation));
268 }
269 return annotationAttributes;
270 }
271
272 }
273
274 private static class Annotations<T extends Annotation> implements Iterable<T> {
275
276 private final List<T> classAnnotations;
277
278 private final List<T> methodAnnotations;
279
280 private final List<T> allAnnotations;
281
282 public Annotations(DbUnitTestContext context, Class<? extends Annotation> container, Class<T> annotation) {
283 this.classAnnotations = getAnnotations(context.getTestClass(), container, annotation);
284 this.methodAnnotations = getAnnotations(context.getTestMethod(), container, annotation);
285 List<T> allAnnotations = new ArrayList<T>(this.classAnnotations.size() + this.methodAnnotations.size());
286 allAnnotations.addAll(this.classAnnotations);
287 allAnnotations.addAll(this.methodAnnotations);
288 this.allAnnotations = Collections.unmodifiableList(allAnnotations);
289 }
290
291 private List<T> getAnnotations(AnnotatedElement element, Class<? extends Annotation> container,
292 Class<T> annotation) {
293 List<T> annotations = new ArrayList<T>();
294 addAnnotationToList(annotations, AnnotationUtils.findAnnotation(element, annotation));
295 addRepeatableAnnotationsToList(annotations, AnnotationUtils.findAnnotation(element, container));
296 return Collections.unmodifiableList(annotations);
297 }
298
299 private void addAnnotationToList(List<T> annotations, T annotation) {
300 if (annotation != null) {
301 annotations.add(annotation);
302 }
303 }
304
305 @SuppressWarnings("unchecked")
306 private void addRepeatableAnnotationsToList(List<T> annotations, Annotation container) {
307 if (container != null) {
308 T[] value = (T[]) AnnotationUtils.getValue(container);
309 for (T annotation : value) {
310 annotations.add(annotation);
311 }
312 }
313 }
314
315 public List<T> getClassAnnotations() {
316 return this.classAnnotations;
317 }
318
319 public List<T> getMethodAnnotations() {
320 return this.methodAnnotations;
321 }
322
323 public Iterator<T> iterator() {
324 return this.allAnnotations.iterator();
325 }
326
327 private static <T extends Annotation> Annotations<T> get(DbUnitTestContext testContext,
328 Class<? extends Annotation> container, Class<T> annotation) {
329 return new Annotations<T>(testContext, container, annotation);
330 }
331
332 }
333
334 }