diff --git a/gms/src/main/java/it/inaf/ia2/gms/rap/RapClient.java b/gms/src/main/java/it/inaf/ia2/gms/rap/RapClient.java index 5d8b2bde72af492fdd3886076d83182e95ce48aa..a2e1948fa3f620416baf5bdad053d29661fcefd5 100644 --- a/gms/src/main/java/it/inaf/ia2/gms/rap/RapClient.java +++ b/gms/src/main/java/it/inaf/ia2/gms/rap/RapClient.java @@ -1,5 +1,7 @@ package it.inaf.ia2.gms.rap; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import it.inaf.ia2.gms.authn.SessionData; import it.inaf.ia2.gms.model.RapUser; import java.util.ArrayList; @@ -21,6 +23,8 @@ import org.springframework.stereotype.Component; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.client.HttpClientErrorException; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.client.HttpStatusCodeException; import org.springframework.web.client.RestTemplate; @Component @@ -51,6 +55,8 @@ public class RapClient { private final RestTemplate refreshTokenRestTemplate; + private final ObjectMapper objectMapper = new ObjectMapper(); + @Autowired public RapClient(RestTemplate rapRestTemplate) { this.rapRestTemplate = rapRestTemplate; @@ -101,14 +107,30 @@ public class RapClient { private <R, T> R httpCall(Function<HttpEntity<?>, R> function, T body) { try { - return function.apply(getEntity(body)); - } catch (HttpClientErrorException.Unauthorized ex) { - if (request.getSession(false) == null) { - // we can't refresh the token without a session - throw ex; + try { + return function.apply(getEntity(body)); + } catch (HttpClientErrorException.Unauthorized ex) { + if (request.getSession(false) == null || sessionData.getExpiresIn() > 0) { + // we can't refresh the token without a session + throw ex; + } + refreshToken(); + return function.apply(getEntity(body)); + } + } catch (HttpStatusCodeException ex) { + try { + Map<String, String> map = objectMapper.readValue(ex.getResponseBodyAsString(), Map.class); + if (map.containsKey("error")) { + String error = map.get("error"); + if (ex instanceof HttpClientErrorException) { + throw new HttpClientErrorException(ex.getStatusCode(), error); + } else if (ex instanceof HttpServerErrorException) { + throw new HttpServerErrorException(ex.getStatusCode(), error); + } + } + } catch (JsonProcessingException ignore) { } - refreshToken(); - return function.apply(getEntity(body)); + throw ex; } } diff --git a/gms/src/test/java/it/inaf/ia2/gms/rap/RapClientTest.java b/gms/src/test/java/it/inaf/ia2/gms/rap/RapClientTest.java new file mode 100644 index 0000000000000000000000000000000000000000..56651a9b695c4776aa94aa272a26dceb14690807 --- /dev/null +++ b/gms/src/test/java/it/inaf/ia2/gms/rap/RapClientTest.java @@ -0,0 +1,153 @@ +package it.inaf.ia2.gms.rap; + +import it.inaf.ia2.gms.authn.SessionData; +import it.inaf.ia2.gms.model.RapUser; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpSession; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import org.mockito.Mock; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.web.client.HttpClientErrorException; +import org.springframework.web.client.HttpClientErrorException.Unauthorized; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.client.HttpServerErrorException.InternalServerError; +import org.springframework.web.client.RestTemplate; + +@RunWith(MockitoJUnitRunner.class) +public class RapClientTest { + + @Mock + private HttpServletRequest request; + + @Mock + private SessionData sessionData; + + @Mock + private RestTemplate restTemplate; + + @Mock + private RestTemplate refreshTokenRestTemplate; + + private RapClient rapClient; + + @Before + public void init() { + rapClient = new RapClient(restTemplate); + ReflectionTestUtils.setField(rapClient, "request", request); + ReflectionTestUtils.setField(rapClient, "refreshTokenRestTemplate", refreshTokenRestTemplate); + ReflectionTestUtils.setField(rapClient, "scope", "openid"); + } + + @Test + public void testUnauthorizedNoRefreshJsonMsg() { + + String jsonError = "{\"error\":\"Unauthorized: foo\"}"; + + HttpClientErrorException exception = Unauthorized + .create(HttpStatus.UNAUTHORIZED, "401", HttpHeaders.EMPTY, jsonError.getBytes(), StandardCharsets.UTF_8); + + when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(new ParameterizedTypeReference<RapUser>() { + }))).thenThrow(exception); + + try { + rapClient.getUser("123"); + } catch (HttpClientErrorException ex) { + assertEquals("401 Unauthorized: foo", ex.getMessage()); + } + } + + @Test + public void testUnauthorizedNoRefreshNotJsonMsg() { + + String errorMessage = "THIS IS NOT A JSON"; + + HttpClientErrorException exception = Unauthorized + .create(HttpStatus.UNAUTHORIZED, "401", HttpHeaders.EMPTY, errorMessage.getBytes(), StandardCharsets.UTF_8); + + when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(new ParameterizedTypeReference<RapUser>() { + }))).thenThrow(exception); + + try { + rapClient.getUser("123"); + } catch (HttpClientErrorException ex) { + assertNotNull(ex.getMessage()); + } + } + + @Test + public void testServerErrorJsonMsg() { + + String jsonError = "{\"error\":\"Fatal error\"}"; + + HttpServerErrorException exception = InternalServerError + .create(HttpStatus.INTERNAL_SERVER_ERROR, "500", HttpHeaders.EMPTY, jsonError.getBytes(), StandardCharsets.UTF_8); + + when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(new ParameterizedTypeReference<RapUser>() { + }))).thenThrow(exception); + + try { + rapClient.getUser("123"); + } catch (HttpServerErrorException ex) { + assertEquals("500 Fatal error", ex.getMessage()); + } + } + + @Test + public void testRefreshToken() { + + when(request.getSession(eq(false))).thenReturn(mock(HttpSession.class)); + when(sessionData.getExpiresIn()).thenReturn(-100l); + + ReflectionTestUtils.setField(rapClient, "sessionData", sessionData); + ReflectionTestUtils.setField(rapClient, "clientId", "clientId"); + ReflectionTestUtils.setField(rapClient, "clientSecret", "clientSecret"); + ReflectionTestUtils.setField(rapClient, "accessTokenUri", "https://sso.ia2.inaf.it"); + + String jsonError = "{\"error\":\"Unauthorized: token expired\"}"; + + HttpClientErrorException exception = Unauthorized + .create(HttpStatus.UNAUTHORIZED, "401", HttpHeaders.EMPTY, jsonError.getBytes(), StandardCharsets.UTF_8); + + when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(new ParameterizedTypeReference<RapUser>() { + }))).thenThrow(exception) + .thenReturn(ResponseEntity.ok(new RapUser())); + + ResponseEntity refreshTokenResponse = mock(ResponseEntity.class); + Map<String, Object> mockedBody = new HashMap<>(); + mockedBody.put("access_token", "<access_token>"); + mockedBody.put("refresh_token", "<refresh_token>"); + mockedBody.put("expires_in", 3600); + when(refreshTokenResponse.getBody()).thenReturn(mockedBody); + + when(refreshTokenRestTemplate.postForEntity(anyString(), any(HttpEntity.class), any())) + .thenReturn(refreshTokenResponse); + + RapUser user = rapClient.getUser("123"); + assertNotNull(user); + + // verifies that token is refreshed + verify(sessionData, times(1)).setAccessToken(eq("<access_token>")); + verify(sessionData, times(1)).setExpiresIn(eq(3600)); + } +}