1
2
3
4
5
6
7
8
9
10
11
12
13
14 package waffle.shiro.negotiate;
15
16
17
18
19
20
21
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.web.filter.authc.AuthenticatingFilter;
import org.apache.shiro.web.filter.authc.FormAuthenticationFilter;
import org.apache.shiro.web.util.WebUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.io.BaseEncoding;

import waffle.util.AuthorizationHeader;
import waffle.util.NtlmServletRequest;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
44
45
46
47
48
49
50
51
52 public class NegotiateAuthenticationFilter extends AuthenticatingFilter {
53
54
55
56
57 private static final Logger LOGGER = LoggerFactory
58 .getLogger(NegotiateAuthenticationFilter.class);
59
60
61
62
63
64
65 private static final List<String> PROTOCOLS = new ArrayList<String>();
66
67
68 private String failureKeyAttribute = FormAuthenticationFilter.DEFAULT_ERROR_KEY_ATTRIBUTE_NAME;
69
70
71 private String rememberMeParam = FormAuthenticationFilter.DEFAULT_REMEMBER_ME_PARAM;
72
73
74
75
76 public NegotiateAuthenticationFilter() {
77 NegotiateAuthenticationFilter.PROTOCOLS.add("Negotiate");
78 NegotiateAuthenticationFilter.PROTOCOLS.add("NTLM");
79 }
80
81
82
83
84
85
86 public String getRememberMeParam() {
87 return this.rememberMeParam;
88 }
89
90
91
92
93
94
95
96
97
98
99
100
101 public void setRememberMeParam(final String value) {
102 this.rememberMeParam = value;
103 }
104
105
106
107
108 @Override
109 protected boolean isRememberMe(final ServletRequest request) {
110 return WebUtils.isTrue(request, this.getRememberMeParam());
111 }
112
113
114
115
116 @Override
117 protected AuthenticationToken createToken(final ServletRequest request, final ServletResponse response) {
118 final String authorization = this.getAuthzHeader(request);
119 final String[] elements = authorization.split(" ");
120 final byte[] inToken = BaseEncoding.base64().decode(elements[1]);
121
122
123
124 final String connectionId = NtlmServletRequest.getConnectionId((HttpServletRequest) request);
125 final String securityPackage = elements[0];
126
127
128 final AuthorizationHeader authorizationHeader = new AuthorizationHeader((HttpServletRequest) request);
129 final boolean ntlmPost = authorizationHeader.isNtlmType1PostAuthorizationHeader();
130
131 NegotiateAuthenticationFilter.LOGGER.debug("security package: {}, connection id: {}, ntlmPost: {}", securityPackage, connectionId,
132 Boolean.valueOf(ntlmPost));
133
134 final boolean rememberMe = this.isRememberMe(request);
135 final String host = this.getHost(request);
136
137 return new NegotiateToken(inToken, new byte[0], connectionId, securityPackage, ntlmPost, rememberMe, host);
138 }
139
140
141
142
143 @Override
144 protected boolean onLoginSuccess(final AuthenticationToken token, final Subject subject,
145 final ServletRequest request, final ServletResponse response) throws Exception {
146 request.setAttribute("MY_SUBJECT", ((NegotiateToken) token).getSubject());
147 return true;
148 }
149
150
151
152
153 @Override
154 protected boolean onLoginFailure(final AuthenticationToken token, final AuthenticationException e,
155 final ServletRequest request, final ServletResponse response) {
156 if (e instanceof AuthenticationInProgressException) {
157
158 final String protocol = this.getAuthzHeaderProtocol(request);
159 NegotiateAuthenticationFilter.LOGGER.debug("Negotiation in progress for protocol: {}", protocol);
160 this.sendChallengeDuringNegotiate(protocol, response, ((NegotiateToken) token).getOut());
161 return false;
162 }
163 NegotiateAuthenticationFilter.LOGGER.warn("login exception: {}", e.getMessage());
164
165
166 this.sendChallengeOnFailure(response);
167
168 this.setFailureAttribute(request, e);
169 return true;
170 }
171
172
173
174
175
176
177
178
179
180 protected void setFailureAttribute(final ServletRequest request, final AuthenticationException ae) {
181 final String className = ae.getClass().getName();
182 request.setAttribute(this.getFailureKeyAttribute(), className);
183 }
184
185
186
187
188
189
190 public String getFailureKeyAttribute() {
191 return this.failureKeyAttribute;
192 }
193
194
195
196
197
198
199
200 public void setFailureKeyAttribute(final String value) {
201 this.failureKeyAttribute = value;
202 }
203
204
205
206
207 @Override
208 protected boolean onAccessDenied(final ServletRequest request, final ServletResponse response) throws Exception {
209
210 boolean loggedIn = false;
211
212 if (this.isLoginAttempt(request)) {
213 loggedIn = this.executeLogin(request, response);
214 } else {
215 NegotiateAuthenticationFilter.LOGGER.debug("authorization required, supported protocols: {}", NegotiateAuthenticationFilter.PROTOCOLS);
216 this.sendChallengeInitiateNegotiate(response);
217 }
218 return loggedIn;
219 }
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234 private boolean isLoginAttempt(final ServletRequest request) {
235 final String authzHeader = this.getAuthzHeader(request);
236 return authzHeader != null && this.isLoginAttempt(authzHeader);
237 }
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252 private String getAuthzHeader(final ServletRequest request) {
253 final HttpServletRequest httpRequest = WebUtils.toHttp(request);
254 return httpRequest.getHeader("Authorization");
255 }
256
257
258
259
260
261
262
263
264 private String getAuthzHeaderProtocol(final ServletRequest request) {
265 final String authzHeader = this.getAuthzHeader(request);
266 return authzHeader.substring(0, authzHeader.indexOf(" "));
267 }
268
269
270
271
272
273
274
275
276
277
278
279 boolean isLoginAttempt(final String authzHeader) {
280 for (final String protocol : NegotiateAuthenticationFilter.PROTOCOLS) {
281 if (authzHeader.toLowerCase().startsWith(protocol.toLowerCase())) {
282 return true;
283 }
284 }
285 return false;
286 }
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302 private void sendChallenge(final List<String> protocols, final ServletResponse response, final byte[] out) {
303 final HttpServletResponse httpResponse = WebUtils.toHttp(response);
304 this.sendAuthenticateHeader(protocols, out, httpResponse);
305 httpResponse.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
306 }
307
308
309
310
311
312
313
314 void sendChallengeInitiateNegotiate(final ServletResponse response) {
315 this.sendChallenge(NegotiateAuthenticationFilter.PROTOCOLS, response, null);
316 }
317
318
319
320
321
322
323
324
325
326
327
328 void sendChallengeDuringNegotiate(final String protocol, final ServletResponse response, final byte[] out) {
329 final List<String> protocolsList = new ArrayList<String>();
330 protocolsList.add(protocol);
331 this.sendChallenge(protocolsList, response, out);
332 }
333
334
335
336
337
338
339
340 void sendChallengeOnFailure(final ServletResponse response) {
341 final HttpServletResponse httpResponse = WebUtils.toHttp(response);
342 this.sendUnauthorized(NegotiateAuthenticationFilter.PROTOCOLS, null, httpResponse);
343 httpResponse.setHeader("Connection", "close");
344 try {
345 httpResponse.sendError(HttpServletResponse.SC_UNAUTHORIZED);
346 httpResponse.flushBuffer();
347 } catch (final IOException e) {
348 throw new RuntimeException(e);
349 }
350 }
351
352
353
354
355
356
357
358
359
360
361
362 private void sendAuthenticateHeader(final List<String> protocolsList, final byte[] out,
363 final HttpServletResponse httpResponse) {
364 this.sendUnauthorized(protocolsList, out, httpResponse);
365 httpResponse.setHeader("Connection", "keep-alive");
366 }
367
368
369
370
371
372
373
374
375
376
377
378 private void sendUnauthorized(final List<String> protocols, final byte[] out, final HttpServletResponse response) {
379 for (final String protocol : protocols) {
380 if (out == null || out.length == 0) {
381 response.addHeader("WWW-Authenticate", protocol);
382 } else {
383 response.setHeader("WWW-Authenticate", protocol + " " + BaseEncoding.base64().encode(out));
384 }
385 }
386 }
387
388 }