1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package com.github.springtestdbunit;
18
19 import java.lang.reflect.Method;
20 import java.util.Arrays;
21
22 import javax.sql.DataSource;
23
24 import org.apache.commons.logging.Log;
25 import org.apache.commons.logging.LogFactory;
26 import org.dbunit.database.IDatabaseConnection;
27 import org.springframework.context.ApplicationContext;
28 import org.springframework.core.Conventions;
29 import org.springframework.test.context.TestContext;
30 import org.springframework.test.context.support.AbstractTestExecutionListener;
31 import org.springframework.test.context.transaction.TransactionalTestExecutionListener;
32 import org.springframework.util.Assert;
33 import org.springframework.util.ObjectUtils;
34 import org.springframework.util.ReflectionUtils;
35 import org.springframework.util.StringUtils;
36
37 import com.github.springtestdbunit.annotation.DatabaseSetup;
38 import com.github.springtestdbunit.annotation.DatabaseTearDown;
39 import com.github.springtestdbunit.annotation.DbUnitConfiguration;
40 import com.github.springtestdbunit.annotation.ExpectedDatabase;
41 import com.github.springtestdbunit.bean.DatabaseDataSourceConnectionFactoryBean;
42 import com.github.springtestdbunit.dataset.DataSetLoader;
43 import com.github.springtestdbunit.dataset.FlatXmlDataSetLoader;
44 import com.github.springtestdbunit.operation.DatabaseOperationLookup;
45 import com.github.springtestdbunit.operation.DefaultDatabaseOperationLookup;
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66 public class DbUnitTestExecutionListener extends AbstractTestExecutionListener {
67
68 private static final Log logger = LogFactory.getLog(DbUnitTestExecutionListener.class);
69
70 private static final String[] COMMON_DATABASE_CONNECTION_BEAN_NAMES = { "dbUnitDatabaseConnection", "dataSource" };
71
72 private static final String DATA_SET_LOADER_BEAN_NAME = "dbUnitDataSetLoader";
73
74 protected static final String CONNECTION_ATTRIBUTE = Conventions
75 .getQualifiedAttributeName(DbUnitTestExecutionListener.class, "connection");
76
77 protected static final String DATA_SET_LOADER_ATTRIBUTE = Conventions
78 .getQualifiedAttributeName(DbUnitTestExecutionListener.class, "dataSetLoader");
79
80 protected static final String DATABASE_OPERATION_LOOKUP_ATTRIBUTE = Conventions
81 .getQualifiedAttributeName(DbUnitTestExecutionListener.class, "databseOperationLookup");
82
83 private static DbUnitRunner runner = new DbUnitRunner();
84
85 @Override
86 public void prepareTestInstance(TestContext testContext) throws Exception {
87 prepareTestInstance(new DbUnitTestContextAdapter(testContext));
88 }
89
90 public void prepareTestInstance(DbUnitTestContextAdapter testContext) throws Exception {
91 if (logger.isDebugEnabled()) {
92 logger.debug("Preparing test instance " + testContext.getTestClass() + " for DBUnit");
93 }
94 String[] databaseConnectionBeanNames = null;
95 String dataSetLoaderBeanName = null;
96 Class<? extends DataSetLoader> dataSetLoaderClass = FlatXmlDataSetLoader.class;
97 Class<? extends DatabaseOperationLookup> databaseOperationLookupClass = DefaultDatabaseOperationLookup.class;
98
99 DbUnitConfiguration configuration = testContext.getTestClass().getAnnotation(DbUnitConfiguration.class);
100 if (configuration != null) {
101 if (logger.isDebugEnabled()) {
102 logger.debug("Using @DbUnitConfiguration configuration");
103 }
104 databaseConnectionBeanNames = configuration.databaseConnection();
105 dataSetLoaderClass = configuration.dataSetLoader();
106 dataSetLoaderBeanName = configuration.dataSetLoaderBean();
107 databaseOperationLookupClass = configuration.databaseOperationLookup();
108 }
109
110 if (ObjectUtils.isEmpty(databaseConnectionBeanNames)
111 || ((databaseConnectionBeanNames.length == 1) && StringUtils.isEmpty(databaseConnectionBeanNames[0]))) {
112 databaseConnectionBeanNames = new String[] { getDatabaseConnectionUsingCommonBeanNames(testContext) };
113 }
114
115 if (!StringUtils.hasLength(dataSetLoaderBeanName)) {
116 if (testContext.getApplicationContext().containsBean(DATA_SET_LOADER_BEAN_NAME)) {
117 dataSetLoaderBeanName = DATA_SET_LOADER_BEAN_NAME;
118 }
119 }
120
121 if (logger.isDebugEnabled()) {
122 logger.debug("DBUnit tests will run using databaseConnection \""
123 + StringUtils.arrayToCommaDelimitedString(databaseConnectionBeanNames)
124 + "\", datasets will be loaded using " + (StringUtils.hasLength(dataSetLoaderBeanName)
125 ? "'" + dataSetLoaderBeanName + "'" : dataSetLoaderClass));
126 }
127 prepareDatabaseConnection(testContext, databaseConnectionBeanNames);
128 prepareDataSetLoader(testContext, dataSetLoaderBeanName, dataSetLoaderClass);
129 prepareDatabaseOperationLookup(testContext, databaseOperationLookupClass);
130 }
131
132 private String getDatabaseConnectionUsingCommonBeanNames(DbUnitTestContextAdapter testContext) {
133 for (String beanName : COMMON_DATABASE_CONNECTION_BEAN_NAMES) {
134 if (testContext.getApplicationContext().containsBean(beanName)) {
135 return beanName;
136 }
137 }
138 throw new IllegalStateException(
139 "Unable to find a DB Unit database connection, missing one the following beans: "
140 + Arrays.asList(COMMON_DATABASE_CONNECTION_BEAN_NAMES));
141 }
142
143 private void prepareDatabaseConnection(DbUnitTestContextAdapter testContext, String[] connectionBeanNames)
144 throws Exception {
145 IDatabaseConnection[] connections = new IDatabaseConnection[connectionBeanNames.length];
146 for (int i = 0; i < connectionBeanNames.length; i++) {
147 Object databaseConnection = testContext.getApplicationContext().getBean(connectionBeanNames[i]);
148 if (databaseConnection instanceof DataSource) {
149 databaseConnection = DatabaseDataSourceConnectionFactoryBean
150 .newConnection((DataSource) databaseConnection);
151 }
152 Assert.isInstanceOf(IDatabaseConnection.class, databaseConnection);
153 connections[i] = (IDatabaseConnection) databaseConnection;
154 }
155 testContext.setAttribute(CONNECTION_ATTRIBUTE, new DatabaseConnections(connectionBeanNames, connections));
156 }
157
158 private void prepareDataSetLoader(DbUnitTestContextAdapter testContext, String beanName,
159 Class<? extends DataSetLoader> dataSetLoaderClass) {
160 if (StringUtils.hasLength(beanName)) {
161 testContext.setAttribute(DATA_SET_LOADER_ATTRIBUTE,
162 testContext.getApplicationContext().getBean(beanName, DataSetLoader.class));
163 } else {
164 try {
165 testContext.setAttribute(DATA_SET_LOADER_ATTRIBUTE, dataSetLoaderClass.newInstance());
166 } catch (Exception ex) {
167 throw new IllegalArgumentException(
168 "Unable to create data set loader instance for " + dataSetLoaderClass, ex);
169 }
170 }
171 }
172
173 private void prepareDatabaseOperationLookup(DbUnitTestContextAdapter testContext,
174 Class<? extends DatabaseOperationLookup> databaseOperationLookupClass) {
175 try {
176 testContext.setAttribute(DATABASE_OPERATION_LOOKUP_ATTRIBUTE, databaseOperationLookupClass.newInstance());
177 } catch (Exception ex) {
178 throw new IllegalArgumentException(
179 "Unable to create database operation lookup instance for " + databaseOperationLookupClass, ex);
180 }
181 }
182
183 @Override
184 public void beforeTestMethod(TestContext testContext) throws Exception {
185 runner.beforeTestMethod(new DbUnitTestContextAdapter(testContext));
186 }
187
188 @Override
189 public void afterTestMethod(TestContext testContext) throws Exception {
190 runner.afterTestMethod(new DbUnitTestContextAdapter(testContext));
191 }
192
193
194
195
196
197 private static class DbUnitTestContextAdapter implements DbUnitTestContext {
198
199 private static final Method GET_TEST_CLASS;
200 private static final Method GET_TEST_INSTANCE;
201 private static final Method GET_TEST_METHOD;
202 private static final Method GET_TEST_EXCEPTION;
203 private static final Method GET_APPLICATION_CONTEXT;
204 private static final Method GET_ATTRIBUTE;
205 private static final Method SET_ATTRIBUTE;
206
207 static {
208 try {
209 GET_TEST_CLASS = TestContext.class.getMethod("getTestClass");
210 GET_TEST_INSTANCE = TestContext.class.getMethod("getTestInstance");
211 GET_TEST_METHOD = TestContext.class.getMethod("getTestMethod");
212 GET_TEST_EXCEPTION = TestContext.class.getMethod("getTestException");
213 GET_APPLICATION_CONTEXT = TestContext.class.getMethod("getApplicationContext");
214 GET_ATTRIBUTE = TestContext.class.getMethod("getAttribute", String.class);
215 SET_ATTRIBUTE = TestContext.class.getMethod("setAttribute", String.class, Object.class);
216 } catch (Exception ex) {
217 throw new IllegalStateException(ex);
218 }
219 }
220
221 private TestContext testContext;
222
223 public DbUnitTestContextAdapter(TestContext testContext) {
224 this.testContext = testContext;
225 }
226
227 public DatabaseConnections getConnections() {
228 return (DatabaseConnections) getAttribute(CONNECTION_ATTRIBUTE);
229 }
230
231 public DataSetLoader getDataSetLoader() {
232 return (DataSetLoader) getAttribute(DATA_SET_LOADER_ATTRIBUTE);
233 }
234
235 public DatabaseOperationLookup getDatbaseOperationLookup() {
236 return (DatabaseOperationLookup) getAttribute(DATABASE_OPERATION_LOOKUP_ATTRIBUTE);
237 }
238
239 public Class<?> getTestClass() {
240 return (Class<?>) ReflectionUtils.invokeMethod(GET_TEST_CLASS, this.testContext);
241 }
242
243 public Method getTestMethod() {
244 return (Method) ReflectionUtils.invokeMethod(GET_TEST_METHOD, this.testContext);
245 }
246
247 public Object getTestInstance() {
248 return ReflectionUtils.invokeMethod(GET_TEST_INSTANCE, this.testContext);
249 }
250
251 public Throwable getTestException() {
252 return (Throwable) ReflectionUtils.invokeMethod(GET_TEST_EXCEPTION, this.testContext);
253 }
254
255 public ApplicationContext getApplicationContext() {
256 return (ApplicationContext) ReflectionUtils.invokeMethod(GET_APPLICATION_CONTEXT, this.testContext);
257 }
258
259 public Object getAttribute(String name) {
260 return ReflectionUtils.invokeMethod(GET_ATTRIBUTE, this.testContext, name);
261 }
262
263 public void setAttribute(String name, Object value) {
264 ReflectionUtils.invokeMethod(SET_ATTRIBUTE, this.testContext, name, value);
265 }
266
267 }
268 }