diff --git a/dspace-api/src/main/java/org/dspace/service/impl/ClientInfoServiceImpl.java b/dspace-api/src/main/java/org/dspace/service/impl/ClientInfoServiceImpl.java index f63a7a4f91..01fc6d4e72 100644 --- a/dspace-api/src/main/java/org/dspace/service/impl/ClientInfoServiceImpl.java +++ b/dspace-api/src/main/java/org/dspace/service/impl/ClientInfoServiceImpl.java @@ -139,7 +139,7 @@ public class ClientInfoServiceImpl implements ClientInfoService { // If our IPTable is not empty, log the trusted proxies and return it if (!ipTable.isEmpty()) { - log.info("Trusted proxies (configure via 'proxies.trusted.ipranges'): {}", ipTable.toSet().toString()); + log.info("Trusted proxies (configure via 'proxies.trusted.ipranges'): {}", ipTable.toString()); return ipTable; } else { return null; diff --git a/dspace-api/src/main/java/org/dspace/statistics/util/IPTable.java b/dspace-api/src/main/java/org/dspace/statistics/util/IPTable.java index 139b75e8cf..b3ae961d35 100644 --- a/dspace-api/src/main/java/org/dspace/statistics/util/IPTable.java +++ b/dspace-api/src/main/java/org/dspace/statistics/util/IPTable.java @@ -7,9 +7,10 @@ */ package org.dspace.statistics.util; -import java.util.HashMap; +import java.net.InetAddress; +import java.net.UnknownHostException; import java.util.HashSet; -import java.util.Map; +import java.util.Iterator; import java.util.Set; import org.apache.logging.log4j.LogManager; @@ -25,8 +26,32 @@ public class IPTable { private static final Logger log = LogManager.getLogger(IPTable.class); /* A lookup tree for IP addresses and SubnetRanges */ - private final Map>>> map - = new HashMap<>(); + private final Set ipRanges = new HashSet<>(); + + /** + * Internal class representing an IP range + */ + static class IPRange { + + /* Lowest address in the range */ + private final long ipLo; + + /* Highest address in the range */ + private final long ipHi; + + IPRange(long ipLo, long ipHi) { + this.ipLo = ipLo; + this.ipHi = ipHi; + } + + public long getIpLo() { + return ipLo; + } + + public long getIpHi() { + return ipHi; + } + } /** * Can be full v4 IP, subnet or range string. @@ -45,155 +70,114 @@ public class IPTable { */ public void add(String ip) throws IPFormatException { - String[] start; + String start; - String[] end; + String end; String[] range = ip.split("-"); - if (range.length >= 2) { + if (range.length == 2) { - start = range[0].trim().split("/")[0].split("\\."); - end = range[1].trim().split("/")[0].split("\\."); + start = range[0].trim(); + end = range[1].trim(); - if (start.length != 4 || end.length != 4) { - throw new IPFormatException(ip + " - Ranges need to be full IPv4 Addresses"); + try { + long ipLo = ipToLong(InetAddress.getByName(start)); + long ipHi = ipToLong(InetAddress.getByName(end)); + ipRanges.add(new IPRange(ipLo, ipHi)); + return; + } catch (UnknownHostException e) { + throw new IPFormatException(ip + " - Range format should be similar to 1.2.3.0-1.2.3.255"); } - if (!(start[0].equals(end[0]) && start[1].equals(end[1]) && start[2].equals(end[2]))) { - throw new IPFormatException(ip + " - Ranges can only be across the last subnet x.y.z.0 - x.y.z.254"); - } - - } else { + } else if (ip.contains("/")) { //need to ignore CIDR notation for the moment. //ip = ip.split("\\/")[0]; - String[] subnets = ip.split("\\."); - - if (subnets.length < 3) { - throw new IPFormatException(ip + " - require at least three subnet places (255.255.255.0"); - } - - start = subnets; - end = subnets; - } - - if (start.length >= 3) { - Map>> first = map.get(start[0]); - - if (first == null) { - first = new HashMap<>(); - map.put(start[0], first); - } - - Map> second = first.get(start[1]); - - if (second == null) { - second = new HashMap<>(); - first.put(start[1], second); - } - - Set third = second.get(start[2]); - - if (third == null) { - third = new HashSet<>(); - second.put(start[2], third); - } - - //now populate fourth place (* or value 0-254); - - if (start.length == 3) { - third.add("*"); - } - - if (third.contains("*")) { - return; - } - - if (start.length >= 4) { - int s = Integer.valueOf(start[3]); - int e = Integer.valueOf(end[3]); - for (int i = s; i <= e; i++) { - third.add(String.valueOf(i)); + String[] parts = ip.split("/"); + try { + byte[] octets = InetAddress.getByName(parts[0]).getAddress(); + long result = 0; + for (byte octet : octets) { + result <<= 8; + result |= octet & 0xff; } + long mask = (long) Math.pow(2, 32 - Integer.parseInt(parts[1])); + long ipLo = (result / mask) * mask; + long ipHi = (( (result / mask) + 1) * mask) - 1; + ipRanges.add(new IPRange(ipLo, ipHi)); + return; + } catch (Exception e) { + throw new IPFormatException(ip + " - Range format should be similar to 172.16.0.0/12"); + } + + } else { + try { + long ipLo = ipToLong(InetAddress.getByName(ip)); + ipRanges.add(new IPRange(ipLo, ipLo)); + return; + } catch (UnknownHostException e) { + throw new IPFormatException(ip + " - IP address format should be similar to 1.2.3.14"); } } } + public static long ipToLong(InetAddress ip) { + byte[] octets = ip.getAddress(); + long result = 0; + for (byte octet : octets) { + result <<= 8; + result |= octet & 0xff; + } + return result; + } + + public static String longToIp(long ip) { + long part = ip; + String[] parts = new String[4]; + for (int i = 0; i < 4; i++) { + long octet = part & 0xff; + parts[3-i] = String.valueOf(octet); + part = part / 256; + } + + return parts[0]+"."+parts[1]+"."+parts[2]+"."+parts[3]; + } + /** * Check whether a given address is contained in this netblock. * * @param ip the address to be tested * @return true if {@code ip} is within this table's limits. Returns false - * if {@link ip} looks like an IPv6 address. + * if {@code ip} looks like an IPv6 address. * @throws IPFormatException Exception Class to deal with IPFormat errors. */ public boolean contains(String ip) throws IPFormatException { - String[] subnets = ip.split("\\."); - - // Does it look like IPv6? - if (subnets.length > 4 || ip.contains("::")) { - log.warn("Address {} assumed not to match. IPv6 is not implemented.", ip); - return false; + try { + long ipToTest = ipToLong(InetAddress.getByName(ip)); + return ipRanges.stream() + .anyMatch(ipRange -> (ipToTest >= ipRange.getIpLo() && ipToTest <= ipRange.getIpHi())); + } catch (UnknownHostException e) { + throw new IPFormatException("ip not valid"); } - - // Does it look like a subnet? - if (subnets.length < 4) { - throw new IPFormatException("needs to be a single IP address"); - } - - Map>> first = map.get(subnets[0]); - - if (first == null) { - return false; - } - - Map> second = first.get(subnets[1]); - - if (second == null) { - return false; - } - - Set third = second.get(subnets[2]); - - if (third == null) { - return false; - } - - return third.contains(subnets[3]) || third.contains("*"); - } /** - * Convert to a Set. + * Convert to a Set. This set contains all IPs in the range * * @return this table's content as a Set */ public Set toSet() { HashSet set = new HashSet<>(); - for (Map.Entry>>> first : map.entrySet()) { - String firstString = first.getKey(); - Map>> secondMap = first.getValue(); - - for (Map.Entry>> second : secondMap.entrySet()) { - String secondString = second.getKey(); - Map> thirdMap = second.getValue(); - - for (Map.Entry> third : thirdMap.entrySet()) { - String thirdString = third.getKey(); - Set fourthSet = third.getValue(); - - if (fourthSet.contains("*")) { - set.add(firstString + "." + secondString + "." + thirdString); - } else { - for (String fourth : fourthSet) { - set.add(firstString + "." + secondString + "." + thirdString + "." + fourth); - } - } - - } + Iterator ipRangeIterator = ipRanges.iterator(); + while (ipRangeIterator.hasNext()) { + IPRange ipRange = ipRangeIterator.next(); + long ipLo = ipRange.getIpLo(); + long ipHi = ipRange.getIpHi(); + for (long ip = ipLo; ip <= ipHi; ip++) { + set.add(longToIp(ip).toString()); } } @@ -205,7 +189,7 @@ public class IPTable { * @return true if empty, false otherwise */ public boolean isEmpty() { - return map.isEmpty(); + return ipRanges.isEmpty(); } /** @@ -217,5 +201,19 @@ public class IPTable { } } - + @Override + public String toString() { + StringBuilder stringBuilder = new StringBuilder(); + Iterator ipRangeIterator = ipRanges.iterator(); + while (ipRangeIterator.hasNext()) { + IPRange ipRange = ipRangeIterator.next(); + stringBuilder.append(longToIp(ipRange.getIpLo())) + .append("-") + .append(longToIp(ipRange.getIpHi())); + if (ipRangeIterator.hasNext()) { + stringBuilder.append(", "); + } + } + return stringBuilder.toString(); + } }