diff --git a/src/main/java/it/inaf/oats/vospace/UriService.java b/src/main/java/it/inaf/oats/vospace/UriService.java index 6f78cb11d4273cd6d8e3790df6a21a71a6eb43e8..f140d2a0f8b6311d0acbf7efccb9cf3e8fc376dd 100644 --- a/src/main/java/it/inaf/oats/vospace/UriService.java +++ b/src/main/java/it/inaf/oats/vospace/UriService.java @@ -28,6 +28,7 @@ import java.net.MalformedURLException; import java.net.URL; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; import javax.servlet.http.HttpServletRequest; @@ -94,14 +95,15 @@ public class UriService { JobService.JobDirection jobDirection = JobDirection.getJobDirectionEnumFromTransfer(transfer); - List<String> validProtocolUris = new ArrayList<>(); + List<ProtocolType> validProtocolTypes = new ArrayList<>(); switch (jobDirection) { case pullFromVoSpace: + validProtocolTypes.add(ProtocolType.HTTPSGET); case pullToVoSpace: - validProtocolUris.add("ivo://ivoa.net/vospace/core#httpget"); + validProtocolTypes.add(ProtocolType.HTTPGET); break; case pushToVoSpace: - validProtocolUris.add("ivo://ivoa.net/vospace/core#httpput"); + validProtocolTypes.add(ProtocolType.HTTPPUT); break; default: @@ -110,15 +112,30 @@ public class UriService { List<Protocol> validProtocols = transfer.getProtocols().stream() - // discard invalid protocols - .filter(protocol -> validProtocolUris.contains(protocol.getUri())) + // discard invalid protocols by uri String + .filter(protocol + -> validProtocolTypes.stream().map(pt + -> { + return pt.getUri(); + }) + .collect(Collectors.toList()) + .contains(protocol.getUri())) .map(p -> { // set endpoints - Protocol protocol = new Protocol(); - protocol.setUri(p.getUri()); - protocol.setEndpoint(getEndpoint(job, transfer)); - return protocol; - }).collect(Collectors.toList()); + String endpoint = getEndpoint(job, transfer); + ProtocolType pt + = ProtocolType.getProtocolTypeFromURI(p.getUri()); + + if (pt.isEndpointCompliant(endpoint)) { + Protocol protocol = new Protocol(); + protocol.setUri(p.getUri()); + protocol.setEndpoint(endpoint); + return protocol; + } else { + return null; + } + }).filter(Objects::nonNull) + .collect(Collectors.toList()); if (validProtocols.isEmpty()) { Protocol protocol = transfer.getProtocols().get(0); @@ -326,4 +343,41 @@ public class UriService { } } } + + public enum ProtocolType { + // Please keep the URIs in this enum UNIQUE! + // will add a unit test to check this + HTTPGET("ivo://ivoa.net/vospace/core#httpget", "http"), + HTTPSGET("ivo://ivoa.net/vospace/core#httpsget", "https"), + HTTPPUT("ivo://ivoa.net/vospace/core#httpput", "http"), + HTTPSPUT("ivo://ivoa.net/vospace/core#httpsput", "https"); + + private final String uri; + private final String protocolString; + + private ProtocolType(String uri, String protocolString) { + this.uri = uri; + this.protocolString = protocolString; + } + + public String getUri() { + return this.uri; + } + + public boolean isEndpointCompliant(String endpoint) { + return endpoint.toLowerCase() + .startsWith(this.protocolString + "://"); + } + + public static ProtocolType getProtocolTypeFromURI(String uri) { + for (ProtocolType pt : ProtocolType.values()) { + if (pt.getUri().equals(uri)) { + return pt; + } + } + + return null; + } + + } }