View Javadoc
1   /*
2    * Copyright 2002-2016 the original author or authors
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package com.github.springtestdbunit;
18  
19  import java.lang.reflect.Method;
20  import java.util.Arrays;
21  
22  import javax.sql.DataSource;
23  
24  import org.apache.commons.logging.Log;
25  import org.apache.commons.logging.LogFactory;
26  import org.dbunit.database.IDatabaseConnection;
27  import org.springframework.context.ApplicationContext;
28  import org.springframework.core.Conventions;
29  import org.springframework.test.context.TestContext;
30  import org.springframework.test.context.support.AbstractTestExecutionListener;
31  import org.springframework.test.context.transaction.TransactionalTestExecutionListener;
32  import org.springframework.util.Assert;
33  import org.springframework.util.ObjectUtils;
34  import org.springframework.util.ReflectionUtils;
35  import org.springframework.util.StringUtils;
36  
37  import com.github.springtestdbunit.annotation.DatabaseSetup;
38  import com.github.springtestdbunit.annotation.DatabaseTearDown;
39  import com.github.springtestdbunit.annotation.DbUnitConfiguration;
40  import com.github.springtestdbunit.annotation.ExpectedDatabase;
41  import com.github.springtestdbunit.bean.DatabaseDataSourceConnectionFactoryBean;
42  import com.github.springtestdbunit.dataset.DataSetLoader;
43  import com.github.springtestdbunit.dataset.FlatXmlDataSetLoader;
44  import com.github.springtestdbunit.operation.DatabaseOperationLookup;
45  import com.github.springtestdbunit.operation.DefaultDatabaseOperationLookup;
46  
47  /**
48   * <code>TestExecutionListener</code> which provides support for {@link DatabaseSetup &#064;DatabaseSetup},
49   * {@link DatabaseTearDown &#064;DatabaseTearDown} and {@link ExpectedDatabase &#064;ExpectedDatabase} annotations.
50   * <p>
51   * A bean named "<tt>dbUnitDatabaseConnection</tt>" or "<tt>dataSource</tt>" is expected in the
52   * <tt>ApplicationContext</tt> associated with the test. This bean can contain either a {@link IDatabaseConnection} or a
53   * {@link DataSource} . A custom bean name can also be specified using the
54   * {@link DbUnitConfiguration#databaseConnection() &#064;DbUnitConfiguration} annotation.
55   * <p>
56   * Datasets are loaded using the {@link FlatXmlDataSetLoader} and DBUnit database operation lookups are performed using
57   * the {@link DefaultDatabaseOperationLookup} unless otherwise {@link DbUnitConfiguration#dataSetLoader() configured}.
58   * <p>
59   * If you are running this listener in combination with the {@link TransactionalTestExecutionListener} then consider
60   * using {@link TransactionDbUnitTestExecutionListener} instead.
61   *
62   * @see TransactionDbUnitTestExecutionListener
63   *
64   * @author Phillip Webb
65   */
66  public class DbUnitTestExecutionListener extends AbstractTestExecutionListener {
67  
68  	private static final Log logger = LogFactory.getLog(DbUnitTestExecutionListener.class);
69  
70  	private static final String[] COMMON_DATABASE_CONNECTION_BEAN_NAMES = { "dbUnitDatabaseConnection", "dataSource" };
71  
72  	private static final String DATA_SET_LOADER_BEAN_NAME = "dbUnitDataSetLoader";
73  
74  	protected static final String CONNECTION_ATTRIBUTE = Conventions
75  			.getQualifiedAttributeName(DbUnitTestExecutionListener.class, "connection");
76  
77  	protected static final String DATA_SET_LOADER_ATTRIBUTE = Conventions
78  			.getQualifiedAttributeName(DbUnitTestExecutionListener.class, "dataSetLoader");
79  
80  	protected static final String DATABASE_OPERATION_LOOKUP_ATTRIBUTE = Conventions
81  			.getQualifiedAttributeName(DbUnitTestExecutionListener.class, "databseOperationLookup");
82  
83  	private static DbUnitRunner runner = new DbUnitRunner();
84  
85  	@Override
86  	public void prepareTestInstance(TestContext testContext) throws Exception {
87  		prepareTestInstance(new DbUnitTestContextAdapter(testContext));
88  	}
89  
90  	public void prepareTestInstance(DbUnitTestContextAdapter testContext) throws Exception {
91  		if (logger.isDebugEnabled()) {
92  			logger.debug("Preparing test instance " + testContext.getTestClass() + " for DBUnit");
93  		}
94  		String[] databaseConnectionBeanNames = null;
95  		String dataSetLoaderBeanName = null;
96  		Class<? extends DataSetLoader> dataSetLoaderClass = FlatXmlDataSetLoader.class;
97  		Class<? extends DatabaseOperationLookup> databaseOperationLookupClass = DefaultDatabaseOperationLookup.class;
98  
99  		DbUnitConfiguration configuration = testContext.getTestClass().getAnnotation(DbUnitConfiguration.class);
100 		if (configuration != null) {
101 			if (logger.isDebugEnabled()) {
102 				logger.debug("Using @DbUnitConfiguration configuration");
103 			}
104 			databaseConnectionBeanNames = configuration.databaseConnection();
105 			dataSetLoaderClass = configuration.dataSetLoader();
106 			dataSetLoaderBeanName = configuration.dataSetLoaderBean();
107 			databaseOperationLookupClass = configuration.databaseOperationLookup();
108 		}
109 
110 		if (ObjectUtils.isEmpty(databaseConnectionBeanNames)
111 				|| ((databaseConnectionBeanNames.length == 1) && StringUtils.isEmpty(databaseConnectionBeanNames[0]))) {
112 			databaseConnectionBeanNames = new String[] { getDatabaseConnectionUsingCommonBeanNames(testContext) };
113 		}
114 
115 		if (!StringUtils.hasLength(dataSetLoaderBeanName)) {
116 			if (testContext.getApplicationContext().containsBean(DATA_SET_LOADER_BEAN_NAME)) {
117 				dataSetLoaderBeanName = DATA_SET_LOADER_BEAN_NAME;
118 			}
119 		}
120 
121 		if (logger.isDebugEnabled()) {
122 			logger.debug("DBUnit tests will run using databaseConnection \""
123 					+ StringUtils.arrayToCommaDelimitedString(databaseConnectionBeanNames)
124 					+ "\", datasets will be loaded using " + (StringUtils.hasLength(dataSetLoaderBeanName)
125 							? "'" + dataSetLoaderBeanName + "'" : dataSetLoaderClass));
126 		}
127 		prepareDatabaseConnection(testContext, databaseConnectionBeanNames);
128 		prepareDataSetLoader(testContext, dataSetLoaderBeanName, dataSetLoaderClass);
129 		prepareDatabaseOperationLookup(testContext, databaseOperationLookupClass);
130 	}
131 
132 	private String getDatabaseConnectionUsingCommonBeanNames(DbUnitTestContextAdapter testContext) {
133 		for (String beanName : COMMON_DATABASE_CONNECTION_BEAN_NAMES) {
134 			if (testContext.getApplicationContext().containsBean(beanName)) {
135 				return beanName;
136 			}
137 		}
138 		throw new IllegalStateException(
139 				"Unable to find a DB Unit database connection, missing one the following beans: "
140 						+ Arrays.asList(COMMON_DATABASE_CONNECTION_BEAN_NAMES));
141 	}
142 
143 	private void prepareDatabaseConnection(DbUnitTestContextAdapter testContext, String[] connectionBeanNames)
144 			throws Exception {
145 		IDatabaseConnection[] connections = new IDatabaseConnection[connectionBeanNames.length];
146 		for (int i = 0; i < connectionBeanNames.length; i++) {
147 			Object databaseConnection = testContext.getApplicationContext().getBean(connectionBeanNames[i]);
148 			if (databaseConnection instanceof DataSource) {
149 				databaseConnection = DatabaseDataSourceConnectionFactoryBean
150 						.newConnection((DataSource) databaseConnection);
151 			}
152 			Assert.isInstanceOf(IDatabaseConnection.class, databaseConnection);
153 			connections[i] = (IDatabaseConnection) databaseConnection;
154 		}
155 		testContext.setAttribute(CONNECTION_ATTRIBUTE, new DatabaseConnections(connectionBeanNames, connections));
156 	}
157 
158 	private void prepareDataSetLoader(DbUnitTestContextAdapter testContext, String beanName,
159 			Class<? extends DataSetLoader> dataSetLoaderClass) {
160 		if (StringUtils.hasLength(beanName)) {
161 			testContext.setAttribute(DATA_SET_LOADER_ATTRIBUTE,
162 					testContext.getApplicationContext().getBean(beanName, DataSetLoader.class));
163 		} else {
164 			try {
165 				testContext.setAttribute(DATA_SET_LOADER_ATTRIBUTE, dataSetLoaderClass.newInstance());
166 			} catch (Exception ex) {
167 				throw new IllegalArgumentException(
168 						"Unable to create data set loader instance for " + dataSetLoaderClass, ex);
169 			}
170 		}
171 	}
172 
173 	private void prepareDatabaseOperationLookup(DbUnitTestContextAdapter testContext,
174 			Class<? extends DatabaseOperationLookup> databaseOperationLookupClass) {
175 		try {
176 			testContext.setAttribute(DATABASE_OPERATION_LOOKUP_ATTRIBUTE, databaseOperationLookupClass.newInstance());
177 		} catch (Exception ex) {
178 			throw new IllegalArgumentException(
179 					"Unable to create database operation lookup instance for " + databaseOperationLookupClass, ex);
180 		}
181 	}
182 
183 	@Override
184 	public void beforeTestMethod(TestContext testContext) throws Exception {
185 		runner.beforeTestMethod(new DbUnitTestContextAdapter(testContext));
186 	}
187 
188 	@Override
189 	public void afterTestMethod(TestContext testContext) throws Exception {
190 		runner.afterTestMethod(new DbUnitTestContextAdapter(testContext));
191 	}
192 
193 	/**
194 	 * Adapter class to convert Spring's {@link TestContext} to a {@link DbUnitTestContext}. Since Spring 4.0 change the
195 	 * TestContext class from a class to an interface this method uses reflection.
196 	 */
197 	private static class DbUnitTestContextAdapter implements DbUnitTestContext {
198 
199 		private static final Method GET_TEST_CLASS;
200 		private static final Method GET_TEST_INSTANCE;
201 		private static final Method GET_TEST_METHOD;
202 		private static final Method GET_TEST_EXCEPTION;
203 		private static final Method GET_APPLICATION_CONTEXT;
204 		private static final Method GET_ATTRIBUTE;
205 		private static final Method SET_ATTRIBUTE;
206 
207 		static {
208 			try {
209 				GET_TEST_CLASS = TestContext.class.getMethod("getTestClass");
210 				GET_TEST_INSTANCE = TestContext.class.getMethod("getTestInstance");
211 				GET_TEST_METHOD = TestContext.class.getMethod("getTestMethod");
212 				GET_TEST_EXCEPTION = TestContext.class.getMethod("getTestException");
213 				GET_APPLICATION_CONTEXT = TestContext.class.getMethod("getApplicationContext");
214 				GET_ATTRIBUTE = TestContext.class.getMethod("getAttribute", String.class);
215 				SET_ATTRIBUTE = TestContext.class.getMethod("setAttribute", String.class, Object.class);
216 			} catch (Exception ex) {
217 				throw new IllegalStateException(ex);
218 			}
219 		}
220 
221 		private TestContext testContext;
222 
223 		public DbUnitTestContextAdapter(TestContext testContext) {
224 			this.testContext = testContext;
225 		}
226 
227 		public DatabaseConnections getConnections() {
228 			return (DatabaseConnections) getAttribute(CONNECTION_ATTRIBUTE);
229 		}
230 
231 		public DataSetLoader getDataSetLoader() {
232 			return (DataSetLoader) getAttribute(DATA_SET_LOADER_ATTRIBUTE);
233 		}
234 
235 		public DatabaseOperationLookup getDatbaseOperationLookup() {
236 			return (DatabaseOperationLookup) getAttribute(DATABASE_OPERATION_LOOKUP_ATTRIBUTE);
237 		}
238 
239 		public Class<?> getTestClass() {
240 			return (Class<?>) ReflectionUtils.invokeMethod(GET_TEST_CLASS, this.testContext);
241 		}
242 
243 		public Method getTestMethod() {
244 			return (Method) ReflectionUtils.invokeMethod(GET_TEST_METHOD, this.testContext);
245 		}
246 
247 		public Object getTestInstance() {
248 			return ReflectionUtils.invokeMethod(GET_TEST_INSTANCE, this.testContext);
249 		}
250 
251 		public Throwable getTestException() {
252 			return (Throwable) ReflectionUtils.invokeMethod(GET_TEST_EXCEPTION, this.testContext);
253 		}
254 
255 		public ApplicationContext getApplicationContext() {
256 			return (ApplicationContext) ReflectionUtils.invokeMethod(GET_APPLICATION_CONTEXT, this.testContext);
257 		}
258 
259 		public Object getAttribute(String name) {
260 			return ReflectionUtils.invokeMethod(GET_ATTRIBUTE, this.testContext, name);
261 		}
262 
263 		public void setAttribute(String name, Object value) {
264 			ReflectionUtils.invokeMethod(SET_ATTRIBUTE, this.testContext, name, value);
265 		}
266 
267 	}
268 }