Skip to content
Snippets Groups Projects
Commit e6ec0673 authored by Alinga Yeung's avatar Alinga Yeung
Browse files

Merge branch 'ac2' of /srv/cadc/git/wopencadc into ac2

parents ebdf965e 9af3038e
No related branches found
No related tags found
No related merge requests found
...@@ -82,6 +82,7 @@ import java.util.Set; ...@@ -82,6 +82,7 @@ import java.util.Set;
public class GetUserAction extends AbstractUserAction public class GetUserAction extends AbstractUserAction
{ {
private static final Logger log = Logger.getLogger(GetUserAction.class); private static final Logger log = Logger.getLogger(GetUserAction.class);
private final Principal userID; private final Principal userID;
private final String detail; private final String detail;
...@@ -105,30 +106,33 @@ public class GetUserAction extends AbstractUserAction ...@@ -105,30 +106,33 @@ public class GetUserAction extends AbstractUserAction
/** /**
* Special case 1 * Special case 1
* If the calling Subject user is the notAugmentedX500User, AND it is
* a GET, call the userDAO to get the User with all identities.
*/
if (isAugmentUser())
{
log.debug("getting augmented user " + principal.getName());
user = userPersistence.getAugmentedUser(principal);
}
/**
* Special case 2
* If detail=identity, AND if the calling Subject user is the same as * If detail=identity, AND if the calling Subject user is the same as
* the requested User, then return the User with the principals from the * the requested User, then return the User with the principals from the
* Subject which has already been augmented. * Subject which has already been augmented.
*/ */
if (detail != null && else if (detail != null &&
detail.equalsIgnoreCase("identity") && detail.equalsIgnoreCase("identity") &&
isSubjectUser(principal)) isSubjectUser(principal))
{ {
log.debug("augmenting " + principal.getName() + " from subject");
Subject subject = Subject.getSubject(AccessController.getContext()); Subject subject = Subject.getSubject(AccessController.getContext());
user = new User<Principal>(principal); user = new User<Principal>(principal);
user.getIdentities().addAll(subject.getPrincipals()); user.getIdentities().addAll(subject.getPrincipals());
} }
/**
* Special case 2
* If the calling Subject user is the notAugmentedX500User, AND it is
* a GET, call the userDAO to get the User with all identities.
*/
else if (this.isAugmentUser)
{
user = userPersistence.getAugmentedUser(principal);
}
else else
{ {
log.debug("getting user " + principal.getName());
try try
{ {
user = userPersistence.getUser(principal); user = userPersistence.getUser(principal);
......
...@@ -90,6 +90,7 @@ public class UserServlet extends HttpServlet ...@@ -90,6 +90,7 @@ public class UserServlet extends HttpServlet
private static final long serialVersionUID = 5289130885807305288L; private static final long serialVersionUID = 5289130885807305288L;
private static final Logger log = Logger.getLogger(UserServlet.class); private static final Logger log = Logger.getLogger(UserServlet.class);
private String notAugmentedX500User; private String notAugmentedX500User;
@Override @Override
...@@ -121,7 +122,6 @@ public class UserServlet extends HttpServlet ...@@ -121,7 +122,6 @@ public class UserServlet extends HttpServlet
{ {
log.info(logInfo.start()); log.info(logInfo.start());
AbstractUserAction action = factory.createAction(request); AbstractUserAction action = factory.createAction(request);
SyncOutput syncOut = new SyncOutput(response);
// Special case: if the calling subject has a servops X500Principal, // Special case: if the calling subject has a servops X500Principal,
// AND it is a GET request, do not augment the subject. // AND it is a GET request, do not augment the subject.
...@@ -129,14 +129,17 @@ public class UserServlet extends HttpServlet ...@@ -129,14 +129,17 @@ public class UserServlet extends HttpServlet
if (action instanceof GetUserAction && isNotAugmentedSubject()) if (action instanceof GetUserAction && isNotAugmentedSubject())
{ {
subject = Subject.getSubject(AccessController.getContext()); subject = Subject.getSubject(AccessController.getContext());
log.debug("subject not augmented: " + subject);
action.setAugmentUser(true); action.setAugmentUser(true);
} }
else else
{ {
subject = AuthenticationUtil.getSubject(request); subject = AuthenticationUtil.getSubject(request);
log.debug("augmented subject: " + subject);
} }
logInfo.setSubject(subject); logInfo.setSubject(subject);
SyncOutput syncOut = new SyncOutput(response);
action.setLogInfo(logInfo); action.setLogInfo(logInfo);
action.setSyncOut(syncOut); action.setSyncOut(syncOut);
action.setAcceptedContentType(getAcceptedContentType(request)); action.setAcceptedContentType(getAcceptedContentType(request));
...@@ -251,13 +254,16 @@ public class UserServlet extends HttpServlet ...@@ -251,13 +254,16 @@ public class UserServlet extends HttpServlet
{ {
boolean notAugmented = false; boolean notAugmented = false;
Subject subject = Subject.getSubject(AccessController.getContext()); Subject subject = Subject.getSubject(AccessController.getContext());
log.debug("subject: " + subject);
if (subject != null) if (subject != null)
{ {
log.debug("notAugmentedX500User" + notAugmentedX500User);
for (Principal principal : subject.getPrincipals()) for (Principal principal : subject.getPrincipals())
{ {
if (principal instanceof X500Principal) if (principal instanceof X500Principal)
{ {
if (principal.getName().equalsIgnoreCase(this.notAugmentedX500User)) log.debug("principal: " + principal.getName());
if (principal.getName().equalsIgnoreCase(notAugmentedX500User))
{ {
notAugmented = true; notAugmented = true;
break; break;
......
...@@ -78,12 +78,14 @@ import ca.nrc.cadc.auth.HttpPrincipal; ...@@ -78,12 +78,14 @@ import ca.nrc.cadc.auth.HttpPrincipal;
import ca.nrc.cadc.auth.NumericPrincipal; import ca.nrc.cadc.auth.NumericPrincipal;
import org.junit.Test; import org.junit.Test;
import javax.security.auth.Subject;
import javax.security.auth.x500.X500Principal; import javax.security.auth.x500.X500Principal;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.io.PrintWriter; import java.io.PrintWriter;
import java.io.StringWriter; import java.io.StringWriter;
import java.io.Writer; import java.io.Writer;
import java.security.Principal; import java.security.Principal;
import java.security.PrivilegedExceptionAction;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
...@@ -136,12 +138,26 @@ public class GetUserActionTest ...@@ -136,12 +138,26 @@ public class GetUserActionTest
@Test @Test
public void writeUserWithDetailIdentity() throws Exception public void writeUserWithDetailIdentity() throws Exception
{ {
final HttpPrincipal httpPrincipal = new HttpPrincipal("CADCtest");
final NumericPrincipal numericPrincipal = new NumericPrincipal(789);
final X500Principal x500Principal = new X500Principal("cn=foo,o=bar");
Subject testUser = new Subject();
testUser.getPrincipals().add(httpPrincipal);
testUser.getPrincipals().add(numericPrincipal);
testUser.getPrincipals().add(x500Principal);
Subject.doAs(testUser, new PrivilegedExceptionAction<Object>()
{
public Object run() throws Exception
{
final HttpServletResponse mockResponse = createMock(HttpServletResponse.class); final HttpServletResponse mockResponse = createMock(HttpServletResponse.class);
final UserPersistence<HttpPrincipal> mockUserPersistence = final UserPersistence<HttpPrincipal> mockUserPersistence =
createMock(UserPersistence.class); createMock(UserPersistence.class);
final HttpPrincipal userID = new HttpPrincipal("CADCtest");
final GetUserAction testSubject = new GetUserAction(userID, "identity")
final GetUserAction testSubject = new GetUserAction(httpPrincipal, "identity")
{ {
@Override @Override
UserPersistence<HttpPrincipal> getUserPersistence() UserPersistence<HttpPrincipal> getUserPersistence()
...@@ -150,9 +166,10 @@ public class GetUserActionTest ...@@ -150,9 +166,10 @@ public class GetUserActionTest
} }
}; };
final User<HttpPrincipal> expected = new User<HttpPrincipal>(userID); final User<HttpPrincipal> expected = new User<HttpPrincipal>(httpPrincipal);
expected.getIdentities().add(new NumericPrincipal(789)); expected.getIdentities().add(httpPrincipal);
expected.getIdentities().add(new X500Principal("cn=foo,o=bar")); expected.getIdentities().add(numericPrincipal);
expected.getIdentities().add(x500Principal);
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
UserWriter userWriter = new UserWriter(); UserWriter userWriter = new UserWriter();
...@@ -169,7 +186,6 @@ public class GetUserActionTest ...@@ -169,7 +186,6 @@ public class GetUserActionTest
final Writer writer = new StringWriter(); final Writer writer = new StringWriter();
final PrintWriter printWriter = new PrintWriter(writer); final PrintWriter printWriter = new PrintWriter(writer);
expect(mockUserPersistence.getUser(userID)).andReturn(expected).once();
mockResponse.setHeader("Content-Type", "text/xml"); mockResponse.setHeader("Content-Type", "text/xml");
expectLastCall().once(); expectLastCall().once();
expect(mockResponse.getWriter()).andReturn(printWriter).once(); expect(mockResponse.getWriter()).andReturn(printWriter).once();
...@@ -185,6 +201,10 @@ public class GetUserActionTest ...@@ -185,6 +201,10 @@ public class GetUserActionTest
assertEquals(expectedUser, actualUser); assertEquals(expectedUser, actualUser);
verify(mockUserPersistence, mockResponse); verify(mockUserPersistence, mockResponse);
return null;
}
});
} }
@Test @Test
...@@ -245,6 +265,58 @@ public class GetUserActionTest ...@@ -245,6 +265,58 @@ public class GetUserActionTest
verify(mockUserPersistence, mockResponse); verify(mockUserPersistence, mockResponse);
} }
@Test
public void writeAugmentedUser() throws Exception
{
final UserPersistence<Principal> mockUserPersistence =
createMock(UserPersistence.class);
final HttpServletResponse mockResponse = createMock(HttpServletResponse.class);
final HttpPrincipal userID = new HttpPrincipal("CADCtest");
final GetUserAction testSubject = new GetUserAction(userID, null)
{
@Override
UserPersistence<Principal> getUserPersistence()
{
return mockUserPersistence;
}
};
testSubject.setAugmentUser(true);
final NumericPrincipal numericPrincipal = new NumericPrincipal(789);
final X500Principal x500Principal = new X500Principal("cn=foo,o=bar");
final User<Principal> expected = new User<Principal>(userID);
expected.getIdentities().add(userID);
expected.getIdentities().add(numericPrincipal);
expected.getIdentities().add(x500Principal);
StringBuilder sb = new StringBuilder();
UserWriter userWriter = new UserWriter();
userWriter.write(expected, sb);
String expectedUser = sb.toString();
final Writer writer = new StringWriter();
final PrintWriter printWriter = new PrintWriter(writer);
expect(mockUserPersistence.getAugmentedUser(userID)).andReturn(expected).once();
mockResponse.setHeader("Content-Type", "text/xml");
expectLastCall().once();
expect(mockResponse.getWriter()).andReturn(printWriter).once();
replay(mockUserPersistence, mockResponse);
SyncOutput syncOutput = new SyncOutput(mockResponse);
testSubject.setSyncOut(syncOutput);
testSubject.doAction();
String actualUser = writer.toString();
assertEquals(expectedUser, actualUser);
verify(mockUserPersistence, mockResponse);
}
@Test @Test
public void writeUserJSON() throws Exception public void writeUserJSON() throws Exception
{ {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment