View Javadoc
1   /*
2    * Copyright 2002-2016 the original author or authors
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package com.github.springtestdbunit;
18  
19  import java.lang.annotation.Annotation;
20  import java.lang.reflect.AnnotatedElement;
21  import java.sql.SQLException;
22  import java.util.ArrayList;
23  import java.util.Collection;
24  import java.util.Collections;
25  import java.util.Iterator;
26  import java.util.LinkedList;
27  import java.util.List;
28  import java.util.Map;
29  
30  import org.apache.commons.logging.Log;
31  import org.apache.commons.logging.LogFactory;
32  import org.dbunit.DatabaseUnitException;
33  import org.dbunit.database.IDatabaseConnection;
34  import org.dbunit.dataset.CompositeDataSet;
35  import org.dbunit.dataset.DataSetException;
36  import org.dbunit.dataset.IDataSet;
37  import org.dbunit.dataset.ITable;
38  import org.dbunit.dataset.filter.IColumnFilter;
39  import org.springframework.core.annotation.AnnotationUtils;
40  import org.springframework.util.Assert;
41  import org.springframework.util.StringUtils;
42  
43  import com.github.springtestdbunit.annotation.DatabaseOperation;
44  import com.github.springtestdbunit.annotation.DatabaseSetup;
45  import com.github.springtestdbunit.annotation.DatabaseSetups;
46  import com.github.springtestdbunit.annotation.DatabaseTearDown;
47  import com.github.springtestdbunit.annotation.DatabaseTearDowns;
48  import com.github.springtestdbunit.annotation.ExpectedDatabase;
49  import com.github.springtestdbunit.annotation.ExpectedDatabases;
50  import com.github.springtestdbunit.assertion.DatabaseAssertion;
51  import com.github.springtestdbunit.dataset.DataSetLoader;
52  import com.github.springtestdbunit.dataset.DataSetModifier;
53  
54  /**
55   * Internal delegate class used to run tests with support for {@link DatabaseSetup @DatabaseSetup},
56   * {@link DatabaseTearDown @DatabaseTearDown} and {@link ExpectedDatabase @ExpectedDatabase} annotations.
57   *
58   * @author Phillip Webb
59   * @author Mario Zagar
60   * @author Sunitha Rajarathnam
61   * @author Oleksii Lomako
62   */
63  public class DbUnitRunner {
64  
65  	private static final Log logger = LogFactory.getLog(DbUnitTestExecutionListener.class);
66  
67  	/**
68  	 * Called before a test method is executed to perform any database setup.
69  	 * @param testContext The test context
70  	 * @throws Exception
71  	 */
72  	public void beforeTestMethod(DbUnitTestContext testContext) throws Exception {
73  		Annotations<DatabaseSetup> annotations = Annotations.get(testContext, DatabaseSetups.class,
74  				DatabaseSetup.class);
75  		setupOrTeardown(testContext, true, AnnotationAttributes.get(annotations));
76  	}
77  
78  	/**
79  	 * Called after a test method is executed to perform any database teardown and to check expected results.
80  	 * @param testContext The test context
81  	 * @throws Exception
82  	 */
83  	public void afterTestMethod(DbUnitTestContext testContext) throws Exception {
84  		try {
85  			try {
86  				verifyExpected(testContext,
87  						Annotations.get(testContext, ExpectedDatabases.class, ExpectedDatabase.class));
88  			} finally {
89  				Annotations<DatabaseTearDown> annotations = Annotations.get(testContext, DatabaseTearDowns.class,
90  						DatabaseTearDown.class);
91  				try {
92  					setupOrTeardown(testContext, false, AnnotationAttributes.get(annotations));
93  				} catch (RuntimeException ex) {
94  					if (testContext.getTestException() == null) {
95  						throw ex;
96  					}
97  					if (logger.isWarnEnabled()) {
98  						logger.warn("Unable to throw database cleanup exception due to existing test error", ex);
99  					}
100 				}
101 			}
102 		} finally {
103 			testContext.getConnections().closeAll();
104 		}
105 	}
106 
107 	private void verifyExpected(DbUnitTestContext testContext, Annotations<ExpectedDatabase> annotations)
108 			throws Exception {
109 		if (testContext.getTestException() != null) {
110 			if (logger.isDebugEnabled()) {
111 				logger.debug("Skipping @DatabaseTest expectation due to test exception "
112 						+ testContext.getTestException().getClass());
113 			}
114 			return;
115 		}
116 		DatabaseConnections connections = testContext.getConnections();
117 		DataSetModifier modifier = getModifier(testContext, annotations);
118 		boolean override = false;
119 		for (ExpectedDatabase annotation : annotations.getMethodAnnotations()) {
120 			verifyExpected(testContext, connections, modifier, annotation);
121 			override |= annotation.override();
122 		}
123 		if (!override) {
124 			for (ExpectedDatabase annotation : annotations.getClassAnnotations()) {
125 				verifyExpected(testContext, connections, modifier, annotation);
126 			}
127 		}
128 	}
129 
130 	private void verifyExpected(DbUnitTestContext testContext, DatabaseConnections connections,
131 			DataSetModifier modifier, ExpectedDatabase annotation)
132 					throws Exception, DataSetException, SQLException, DatabaseUnitException {
133 		String query = annotation.query();
134 		String table = annotation.table();
135 		IDataSet expectedDataSet = loadDataset(testContext, annotation.value(), modifier);
136 		IDatabaseConnection connection = connections.get(annotation.connection());
137 		if (expectedDataSet != null) {
138 			if (logger.isDebugEnabled()) {
139 				logger.debug("Veriftying @DatabaseTest expectation using " + annotation.value());
140 			}
141 			DatabaseAssertion assertion = annotation.assertionMode().getDatabaseAssertion();
142 			List<IColumnFilter> columnFilters = getColumnFilters(annotation);
143 			if (StringUtils.hasLength(query)) {
144 				Assert.hasLength(table, "The table name must be specified when using a SQL query");
145 				ITable expectedTable = expectedDataSet.getTable(table);
146 				ITable actualTable = connection.createQueryTable(table, query);
147 				assertion.assertEquals(expectedTable, actualTable, columnFilters);
148 			} else if (StringUtils.hasLength(table)) {
149 				ITable actualTable = connection.createTable(table);
150 				ITable expectedTable = expectedDataSet.getTable(table);
151 				assertion.assertEquals(expectedTable, actualTable, columnFilters);
152 			} else {
153 				IDataSet actualDataSet = connection.createDataSet();
154 				assertion.assertEquals(expectedDataSet, actualDataSet, columnFilters);
155 			}
156 		}
157 	}
158 
159 	private DataSetModifier getModifier(DbUnitTestContext testContext, Annotations<ExpectedDatabase> annotations) {
160 		DataSetModifiers modifiers = new DataSetModifiers();
161 		for (ExpectedDatabase annotation : annotations) {
162 			for (Class<? extends DataSetModifier> modifierClass : annotation.modifiers()) {
163 				modifiers.add(testContext.getTestInstance(), modifierClass);
164 			}
165 		}
166 		return modifiers;
167 	}
168 
169 	private void setupOrTeardown(DbUnitTestContext testContext, boolean isSetup,
170 			Collection<AnnotationAttributes> annotations) throws Exception {
171 		DatabaseConnections connections = testContext.getConnections();
172 		for (AnnotationAttributes annotation : annotations) {
173 			List<IDataSet> datasets = loadDataSets(testContext, annotation);
174 			DatabaseOperation operation = annotation.getType();
175 			org.dbunit.operation.DatabaseOperation dbUnitOperation = getDbUnitDatabaseOperation(testContext, operation);
176 			if (!datasets.isEmpty()) {
177 				if (logger.isDebugEnabled()) {
178 					logger.debug("Executing " + (isSetup ? "Setup" : "Teardown") + " of @DatabaseTest using "
179 							+ operation + " on " + datasets.toString());
180 				}
181 				IDatabaseConnection connection = connections.get(annotation.getConnection());
182 				IDataSet dataSet = new CompositeDataSet(datasets.toArray(new IDataSet[datasets.size()]));
183 				dbUnitOperation.execute(connection, dataSet);
184 			}
185 		}
186 	}
187 
188 	private List<IDataSet> loadDataSets(DbUnitTestContext testContext, AnnotationAttributes annotation)
189 			throws Exception {
190 		List<IDataSet> datasets = new ArrayList<IDataSet>();
191 		for (String dataSetLocation : annotation.getValue()) {
192 			datasets.add(loadDataset(testContext, dataSetLocation, DataSetModifier.NONE));
193 		}
194 		if (datasets.isEmpty()) {
195 			datasets.add(getFullDatabaseDataSet(testContext, annotation.getConnection()));
196 		}
197 		return datasets;
198 	}
199 
200 	private IDataSet getFullDatabaseDataSet(DbUnitTestContext testContext, String name) throws Exception {
201 		IDatabaseConnection connection = testContext.getConnections().get(name);
202 		return connection.createDataSet();
203 	}
204 
205 	private IDataSet loadDataset(DbUnitTestContext testContext, String dataSetLocation, DataSetModifier modifier)
206 			throws Exception {
207 		DataSetLoader dataSetLoader = testContext.getDataSetLoader();
208 		if (StringUtils.hasLength(dataSetLocation)) {
209 			IDataSet dataSet = dataSetLoader.loadDataSet(testContext.getTestClass(), dataSetLocation);
210 			dataSet = modifier.modify(dataSet);
211 			Assert.notNull(dataSet,
212 					"Unable to load dataset from \"" + dataSetLocation + "\" using " + dataSetLoader.getClass());
213 			return dataSet;
214 		}
215 		return null;
216 	}
217 
218 	private List<IColumnFilter> getColumnFilters(ExpectedDatabase annotation) throws Exception {
219 		Class<? extends IColumnFilter>[] columnFilterClasses = annotation.columnFilters();
220 		List<IColumnFilter> columnFilters = new LinkedList<IColumnFilter>();
221 		for (Class<? extends IColumnFilter> columnFilterClass : columnFilterClasses) {
222 			columnFilters.add(columnFilterClass.newInstance());
223 		}
224 		return columnFilters;
225 	}
226 
227 	private org.dbunit.operation.DatabaseOperation getDbUnitDatabaseOperation(DbUnitTestContext testContext,
228 			DatabaseOperation operation) {
229 		org.dbunit.operation.DatabaseOperation databaseOperation = testContext.getDatbaseOperationLookup()
230 				.get(operation);
231 		Assert.state(databaseOperation != null, "The database operation " + operation + " is not supported");
232 		return databaseOperation;
233 	}
234 
235 	private static class AnnotationAttributes {
236 
237 		private final DatabaseOperation type;
238 
239 		private final String[] value;
240 
241 		private final String connection;
242 
243 		public AnnotationAttributes(Annotation annotation) {
244 			Assert.state((annotation instanceof DatabaseSetup) || (annotation instanceof DatabaseTearDown),
245 					"Only DatabaseSetup and DatabaseTearDown annotations are supported");
246 			Map<String, Object> attributes = AnnotationUtils.getAnnotationAttributes(annotation);
247 			this.type = (DatabaseOperation) attributes.get("type");
248 			this.value = (String[]) attributes.get("value");
249 			this.connection = (String) attributes.get("connection");
250 		}
251 
252 		public DatabaseOperation getType() {
253 			return this.type;
254 		}
255 
256 		public String[] getValue() {
257 			return this.value;
258 		}
259 
260 		public String getConnection() {
261 			return this.connection;
262 		}
263 
264 		public static <T extends Annotation> Collection<AnnotationAttributes> get(Annotations<T> annotations) {
265 			List<AnnotationAttributes> annotationAttributes = new ArrayList<AnnotationAttributes>();
266 			for (T annotation : annotations) {
267 				annotationAttributes.add(new AnnotationAttributes(annotation));
268 			}
269 			return annotationAttributes;
270 		}
271 
272 	}
273 
274 	private static class Annotations<T extends Annotation> implements Iterable<T> {
275 
276 		private final List<T> classAnnotations;
277 
278 		private final List<T> methodAnnotations;
279 
280 		private final List<T> allAnnotations;
281 
282 		public Annotations(DbUnitTestContext context, Class<? extends Annotation> container, Class<T> annotation) {
283 			this.classAnnotations = getAnnotations(context.getTestClass(), container, annotation);
284 			this.methodAnnotations = getAnnotations(context.getTestMethod(), container, annotation);
285 			List<T> allAnnotations = new ArrayList<T>(this.classAnnotations.size() + this.methodAnnotations.size());
286 			allAnnotations.addAll(this.classAnnotations);
287 			allAnnotations.addAll(this.methodAnnotations);
288 			this.allAnnotations = Collections.unmodifiableList(allAnnotations);
289 		}
290 
291 		private List<T> getAnnotations(AnnotatedElement element, Class<? extends Annotation> container,
292 				Class<T> annotation) {
293 			List<T> annotations = new ArrayList<T>();
294 			addAnnotationToList(annotations, AnnotationUtils.findAnnotation(element, annotation));
295 			addRepeatableAnnotationsToList(annotations, AnnotationUtils.findAnnotation(element, container));
296 			return Collections.unmodifiableList(annotations);
297 		}
298 
299 		private void addAnnotationToList(List<T> annotations, T annotation) {
300 			if (annotation != null) {
301 				annotations.add(annotation);
302 			}
303 		}
304 
305 		@SuppressWarnings("unchecked")
306 		private void addRepeatableAnnotationsToList(List<T> annotations, Annotation container) {
307 			if (container != null) {
308 				T[] value = (T[]) AnnotationUtils.getValue(container);
309 				for (T annotation : value) {
310 					annotations.add(annotation);
311 				}
312 			}
313 		}
314 
315 		public List<T> getClassAnnotations() {
316 			return this.classAnnotations;
317 		}
318 
319 		public List<T> getMethodAnnotations() {
320 			return this.methodAnnotations;
321 		}
322 
323 		public Iterator<T> iterator() {
324 			return this.allAnnotations.iterator();
325 		}
326 
327 		private static <T extends Annotation> Annotations<T> get(DbUnitTestContext testContext,
328 				Class<? extends Annotation> container, Class<T> annotation) {
329 			return new Annotations<T>(testContext, container, annotation);
330 		}
331 
332 	}
333 
334 }