SecurityFilterProviderCollection.java
/**
* Waffle (https://github.com/dblock/waffle)
*
* Copyright (c) 2010 - 2015 Application Security, Inc.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* Application Security, Inc.
*/
package waffle.servlet.spi;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.sun.jna.platform.win32.Win32Exception;
import waffle.util.AuthorizationHeader;
import waffle.windows.auth.IWindowsAuthProvider;
import waffle.windows.auth.IWindowsIdentity;
/**
* A collection of security filter providers.
*
* @author dblock[at]dblock[dot]org
*/
public class SecurityFilterProviderCollection {

    /** The Constant LOGGER. */
    private static final Logger LOGGER = LoggerFactory.getLogger(SecurityFilterProviderCollection.class);

    /** The providers, kept in registration order; lookups return the first match. */
    private final List<SecurityFilterProvider> providers = new ArrayList<SecurityFilterProvider>();

    /**
     * Instantiates a new security filter provider collection from already constructed providers.
     *
     * @param providerArray
     *            the providers to register, in lookup order
     */
    public SecurityFilterProviderCollection(final SecurityFilterProvider[] providerArray) {
        for (final SecurityFilterProvider provider : providerArray) {
            SecurityFilterProviderCollection.LOGGER.info("using '{}'", provider.getClass().getName());
            this.providers.add(provider);
        }
    }

    /**
     * Instantiates a new security filter provider collection by reflectively loading each named class and invoking
     * its constructor that takes a single {@link IWindowsAuthProvider}.
     * <p>
     * Names are trimmed before loading. A class that cannot be found aborts construction; any other reflective
     * failure is logged at error level and that provider is skipped, so the resulting collection may hold fewer
     * providers than names were given.
     *
     * @param providerNames
     *            the fully qualified provider class names
     * @param auth
     *            the auth provider handed to each provider constructor
     * @throws RuntimeException
     *             wrapping the {@link ClassNotFoundException} when a named class cannot be found
     */
    public SecurityFilterProviderCollection(final String[] providerNames, final IWindowsAuthProvider auth) {
        for (final String rawName : providerNames) {
            final String providerName = rawName.trim();
            SecurityFilterProviderCollection.LOGGER.info("loading '{}'", providerName);
            try {
                // Class.forName only yields Class<?>; the element type is not verified by this cast.
                @SuppressWarnings("unchecked")
                final Class<SecurityFilterProvider> providerClass = (Class<SecurityFilterProvider>) Class
                        .forName(providerName);
                final Constructor<SecurityFilterProvider> providerConstructor = providerClass
                        .getConstructor(IWindowsAuthProvider.class);
                final SecurityFilterProvider provider = providerConstructor.newInstance(auth);
                this.providers.add(provider);
            } catch (final ClassNotFoundException e) {
                SecurityFilterProviderCollection.logLoadFailure(providerName, e);
                throw new RuntimeException(e);
            } catch (final SecurityException e) {
                SecurityFilterProviderCollection.logLoadFailure(providerName, e);
            } catch (final NoSuchMethodException e) {
                SecurityFilterProviderCollection.logLoadFailure(providerName, e);
            } catch (final IllegalArgumentException e) {
                SecurityFilterProviderCollection.logLoadFailure(providerName, e);
            } catch (final InstantiationException e) {
                SecurityFilterProviderCollection.logLoadFailure(providerName, e);
            } catch (final IllegalAccessException e) {
                SecurityFilterProviderCollection.logLoadFailure(providerName, e);
            } catch (final InvocationTargetException e) {
                SecurityFilterProviderCollection.logLoadFailure(providerName, e);
            }
        }
    }

    /**
     * Instantiates a new security filter provider collection holding a {@link NegotiateSecurityFilterProvider}
     * followed by a {@link BasicSecurityFilterProvider}.
     *
     * @param auth
     *            the auth provider handed to both providers
     */
    public SecurityFilterProviderCollection(final IWindowsAuthProvider auth) {
        this.providers.add(new NegotiateSecurityFilterProvider(auth));
        this.providers.add(new BasicSecurityFilterProvider(auth));
    }

    /**
     * Logs a provider loading failure: a one-line summary at error level and the stack trace at trace level.
     *
     * @param providerName
     *            the class name that failed to load
     * @param failure
     *            the exception raised while loading or instantiating the provider
     */
    private static void logLoadFailure(final String providerName, final Throwable failure) {
        // An InvocationTargetException has no message of its own; the constructor's real error is its cause.
        Throwable reported = failure;
        if (failure instanceof InvocationTargetException && failure.getCause() != null) {
            reported = failure.getCause();
        }
        SecurityFilterProviderCollection.LOGGER.error("error loading '{}': {}", providerName, reported.getMessage());
        // A "{}" format here would be printed literally, because the throwable is not a format argument.
        SecurityFilterProviderCollection.LOGGER.trace("", failure);
    }

    /**
     * Tests whether a specific security package is supported by any of the underlying providers.
     *
     * @param securityPackage
     *            Security package.
     * @return True if the security package is supported, false otherwise.
     */
    public boolean isSecurityPackageSupported(final String securityPackage) {
        return this.get(securityPackage) != null;
    }

    /**
     * Gets the first registered provider that reports support for the given security package.
     *
     * @param securityPackage
     *            the security package
     * @return the matching security filter provider, or {@code null} when no provider supports the package
     */
    private SecurityFilterProvider get(final String securityPackage) {
        for (final SecurityFilterProvider provider : this.providers) {
            if (provider.isSecurityPackageSupported(securityPackage)) {
                return provider;
            }
        }
        return null;
    }

    /**
     * Filter. Selects the provider matching the security package named in the request's authorization header and
     * delegates the request to it.
     *
     * @param request
     *            Http Request
     * @param response
     *            Http Response
     * @return Windows Identity or NULL.
     * @throws IOException
     *             on doFilter, including when the provider fails with a {@link Win32Exception} (kept as the cause).
     * @throws RuntimeException
     *             when no registered provider supports the requested security package.
     */
    public IWindowsIdentity doFilter(final HttpServletRequest request, final HttpServletResponse response)
            throws IOException {
        final AuthorizationHeader authorizationHeader = new AuthorizationHeader(request);
        final SecurityFilterProvider provider = this.get(authorizationHeader.getSecurityPackage());
        if (provider == null) {
            throw new RuntimeException("Unsupported security package: " + authorizationHeader.getSecurityPackage());
        }
        try {
            return provider.doFilter(request, response);
        } catch (final Win32Exception e) {
            // Surface native failures through the declared checked exception.
            throw new IOException(e);
        }
    }

    /**
     * Returns true if authentication still needs to happen despite an existing principal. The answer is true as soon
     * as any registered provider reports a principal exception for the request.
     *
     * @param request
     *            Http Request
     * @return True if authentication is required.
     */
    public boolean isPrincipalException(final HttpServletRequest request) {
        for (final SecurityFilterProvider provider : this.providers) {
            if (provider.isPrincipalException(request)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Send authorization headers. Every registered provider is asked, in registration order, to contribute to the
     * response.
     *
     * @param response
     *            Http Response
     */
    public void sendUnauthorized(final HttpServletResponse response) {
        for (final SecurityFilterProvider provider : this.providers) {
            provider.sendUnauthorized(response);
        }
    }

    /**
     * Number of providers.
     *
     * @return Number of providers.
     */
    public int size() {
        return this.providers.size();
    }

    /**
     * Get a security provider by class name. The name is compared for exact equality with the runtime class name of
     * each registered provider; the first match is returned.
     *
     * @param name
     *            Class name.
     * @return A security provider instance.
     * @throws ClassNotFoundException
     *             when no registered provider has that class name.
     */
    public SecurityFilterProvider getByClassName(final String name) throws ClassNotFoundException {
        for (final SecurityFilterProvider provider : this.providers) {
            if (provider.getClass().getName().equals(name)) {
                return provider;
            }
        }
        throw new ClassNotFoundException(name);
    }
}