1
2
3
4
5
6
7
8
9
10
11
12
13
14 package waffle.servlet.spi;
15
16 import java.io.IOException;
17 import java.security.InvalidParameterException;
18 import java.util.ArrayList;
19 import java.util.Iterator;
20 import java.util.List;
21
22 import javax.servlet.http.HttpServletRequest;
23 import javax.servlet.http.HttpServletResponse;
24
25 import org.slf4j.Logger;
26 import org.slf4j.LoggerFactory;
27
28 import com.google.common.io.BaseEncoding;
29
30 import waffle.util.AuthorizationHeader;
31 import waffle.util.NtlmServletRequest;
32 import waffle.windows.auth.IWindowsAuthProvider;
33 import waffle.windows.auth.IWindowsIdentity;
34 import waffle.windows.auth.IWindowsSecurityContext;
35
36
37
38
39
40
41 public class NegotiateSecurityFilterProvider implements SecurityFilterProvider {
42
43
44 private static final Logger LOGGER = LoggerFactory.getLogger(NegotiateSecurityFilterProvider.class);
45
46
47 private static final String WWW_AUTHENTICATE = "WWW-Authenticate";
48
49
50 private static final String PROTOCOLS = "protocols";
51
52
53 private static final String NEGOTIATE = "Negotiate";
54
55
56 private static final String NTLM = "NTLM";
57
58
59 private List<String> protocols = new ArrayList<String>();
60
61
62 private final IWindowsAuthProvider auth;
63
64
65
66
67
68
69
70 public NegotiateSecurityFilterProvider(final IWindowsAuthProvider newAuthProvider) {
71 this.auth = newAuthProvider;
72 this.protocols.add(NegotiateSecurityFilterProvider.NEGOTIATE);
73 this.protocols.add(NegotiateSecurityFilterProvider.NTLM);
74 }
75
76
77
78
79
80
81 public List<String> getProtocols() {
82 return this.protocols;
83 }
84
85
86
87
88
89
90
91 public void setProtocols(final List<String> values) {
92 this.protocols = values;
93 }
94
95
96
97
98 @Override
99 public void sendUnauthorized(final HttpServletResponse response) {
100 final Iterator<String> protocolsIterator = this.protocols.iterator();
101 while (protocolsIterator.hasNext()) {
102 response.addHeader(NegotiateSecurityFilterProvider.WWW_AUTHENTICATE, protocolsIterator.next());
103 }
104 }
105
106
107
108
109 @Override
110 public boolean isPrincipalException(final HttpServletRequest request) {
111 final AuthorizationHeader authorizationHeader = new AuthorizationHeader(request);
112 final boolean ntlmPost = authorizationHeader.isNtlmType1PostAuthorizationHeader();
113 NegotiateSecurityFilterProvider.LOGGER.debug("authorization: {}, ntlm post: {}", authorizationHeader, Boolean.valueOf(ntlmPost));
114 return ntlmPost;
115 }
116
117
118
119
120 @Override
121 public IWindowsIdentity doFilter(final HttpServletRequest request, final HttpServletResponse response)
122 throws IOException {
123
124 final AuthorizationHeader authorizationHeader = new AuthorizationHeader(request);
125 final boolean ntlmPost = authorizationHeader.isNtlmType1PostAuthorizationHeader();
126
127
128 final String connectionId = NtlmServletRequest.getConnectionId(request);
129 final String securityPackage = authorizationHeader.getSecurityPackage();
130 NegotiateSecurityFilterProvider.LOGGER.debug("security package: {}, connection id: {}", securityPackage, connectionId);
131
132 if (ntlmPost) {
133
134 this.auth.resetSecurityToken(connectionId);
135 }
136
137 final byte[] tokenBuffer = authorizationHeader.getTokenBytes();
138 NegotiateSecurityFilterProvider.LOGGER.debug("token buffer: {} byte(s)", Integer.valueOf(tokenBuffer.length));
139 final IWindowsSecurityContext securityContext = this.auth.acceptSecurityToken(connectionId, tokenBuffer,
140 securityPackage);
141
142 final byte[] continueTokenBytes = securityContext.getToken();
143 if (continueTokenBytes != null && continueTokenBytes.length > 0) {
144 final String continueToken = BaseEncoding.base64().encode(continueTokenBytes);
145 NegotiateSecurityFilterProvider.LOGGER.debug("continue token: {}", continueToken);
146 response.addHeader(NegotiateSecurityFilterProvider.WWW_AUTHENTICATE, securityPackage + " " + continueToken);
147 }
148
149 NegotiateSecurityFilterProvider.LOGGER.debug("continue required: {}", Boolean.valueOf(securityContext.isContinue()));
150 if (securityContext.isContinue() || ntlmPost) {
151 response.setHeader("Connection", "keep-alive");
152 response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
153 response.flushBuffer();
154 return null;
155 }
156
157 final IWindowsIdentity identity = securityContext.getIdentity();
158 securityContext.dispose();
159 return identity;
160 }
161
162
163
164
165 @Override
166 public boolean isSecurityPackageSupported(final String securityPackage) {
167 for (final String protocol : this.protocols) {
168 if (protocol.equalsIgnoreCase(securityPackage)) {
169 return true;
170 }
171 }
172 return false;
173 }
174
175
176
177
178 @Override
179 public void initParameter(final String parameterName, final String parameterValue) {
180 if (parameterName.equals(NegotiateSecurityFilterProvider.PROTOCOLS)) {
181 this.protocols = new ArrayList<String>();
182 final String[] protocolNames = parameterValue.split("\\s+");
183 for (String protocolName : protocolNames) {
184 protocolName = protocolName.trim();
185 if (protocolName.length() > 0) {
186 NegotiateSecurityFilterProvider.LOGGER.debug("init protocol: {}", protocolName);
187 if (protocolName.equals(NegotiateSecurityFilterProvider.NEGOTIATE) || protocolName.equals(NegotiateSecurityFilterProvider.NTLM)) {
188 this.protocols.add(protocolName);
189 } else {
190 NegotiateSecurityFilterProvider.LOGGER.error("unsupported protocol: {}", protocolName);
191 throw new RuntimeException("Unsupported protocol: " + protocolName);
192 }
193 }
194 }
195 } else {
196 throw new InvalidParameterException(parameterName);
197 }
198 }
199 }