Refactor IPTable

This commit is contained in:
Ben Bosman
2022-07-06 13:52:49 +02:00
committed by Yura Bondarenko
parent 69345ff3fc
commit f9d6091c2c
2 changed files with 119 additions and 121 deletions

View File

@@ -139,7 +139,7 @@ public class ClientInfoServiceImpl implements ClientInfoService {
// If our IPTable is not empty, log the trusted proxies and return it
if (!ipTable.isEmpty()) {
log.info("Trusted proxies (configure via 'proxies.trusted.ipranges'): {}", ipTable.toSet().toString());
log.info("Trusted proxies (configure via 'proxies.trusted.ipranges'): {}", ipTable.toString());
return ipTable;
} else {
return null;

View File

@@ -7,9 +7,10 @@
*/
package org.dspace.statistics.util;
import java.util.HashMap;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.HashSet;
import java.util.Map;
import java.util.Iterator;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
@@ -25,8 +26,32 @@ public class IPTable {
private static final Logger log = LogManager.getLogger(IPTable.class);
/* A lookup tree for IP addresses and SubnetRanges */
private final Map<String, Map<String, Map<String, Set<String>>>> map
= new HashMap<>();
private final Set<IPRange> ipRanges = new HashSet<>();
/**
* Internal class representing an IP range
*/
static class IPRange {
/* Lowest address in the range */
private final long ipLo;
/* Highest address in the range */
private final long ipHi;
IPRange(long ipLo, long ipHi) {
this.ipLo = ipLo;
this.ipHi = ipHi;
}
public long getIpLo() {
return ipLo;
}
public long getIpHi() {
return ipHi;
}
}
/**
* Can be full v4 IP, subnet or range string.
@@ -45,155 +70,114 @@ public class IPTable {
*/
public void add(String ip) throws IPFormatException {
String[] start;
String start;
String[] end;
String end;
String[] range = ip.split("-");
if (range.length >= 2) {
if (range.length == 2) {
start = range[0].trim().split("/")[0].split("\\.");
end = range[1].trim().split("/")[0].split("\\.");
start = range[0].trim();
end = range[1].trim();
if (start.length != 4 || end.length != 4) {
throw new IPFormatException(ip + " - Ranges need to be full IPv4 Addresses");
try {
long ipLo = ipToLong(InetAddress.getByName(start));
long ipHi = ipToLong(InetAddress.getByName(end));
ipRanges.add(new IPRange(ipLo, ipHi));
return;
} catch (UnknownHostException e) {
throw new IPFormatException(ip + " - Range format should be similar to 1.2.3.0-1.2.3.255");
}
if (!(start[0].equals(end[0]) && start[1].equals(end[1]) && start[2].equals(end[2]))) {
throw new IPFormatException(ip + " - Ranges can only be across the last subnet x.y.z.0 - x.y.z.254");
}
} else {
} else if (ip.contains("/")) {
//need to ignore CIDR notation for the moment.
//ip = ip.split("\\/")[0];
String[] subnets = ip.split("\\.");
if (subnets.length < 3) {
throw new IPFormatException(ip + " - require at least three subnet places (255.255.255.0");
}
start = subnets;
end = subnets;
}
if (start.length >= 3) {
Map<String, Map<String, Set<String>>> first = map.get(start[0]);
if (first == null) {
first = new HashMap<>();
map.put(start[0], first);
}
Map<String, Set<String>> second = first.get(start[1]);
if (second == null) {
second = new HashMap<>();
first.put(start[1], second);
}
Set<String> third = second.get(start[2]);
if (third == null) {
third = new HashSet<>();
second.put(start[2], third);
}
//now populate fourth place (* or value 0-254);
if (start.length == 3) {
third.add("*");
}
if (third.contains("*")) {
return;
}
if (start.length >= 4) {
int s = Integer.valueOf(start[3]);
int e = Integer.valueOf(end[3]);
for (int i = s; i <= e; i++) {
third.add(String.valueOf(i));
String[] parts = ip.split("/");
try {
byte[] octets = InetAddress.getByName(parts[0]).getAddress();
long result = 0;
for (byte octet : octets) {
result <<= 8;
result |= octet & 0xff;
}
long mask = (long) Math.pow(2, 32 - Integer.parseInt(parts[1]));
long ipLo = (result / mask) * mask;
long ipHi = (( (result / mask) + 1) * mask) - 1;
ipRanges.add(new IPRange(ipLo, ipHi));
return;
} catch (Exception e) {
throw new IPFormatException(ip + " - Range format should be similar to 172.16.0.0/12");
}
} else {
try {
long ipLo = ipToLong(InetAddress.getByName(ip));
ipRanges.add(new IPRange(ipLo, ipLo));
return;
} catch (UnknownHostException e) {
throw new IPFormatException(ip + " - IP address format should be similar to 1.2.3.14");
}
}
}
public static long ipToLong(InetAddress ip) {
byte[] octets = ip.getAddress();
long result = 0;
for (byte octet : octets) {
result <<= 8;
result |= octet & 0xff;
}
return result;
}
public static String longToIp(long ip) {
long part = ip;
String[] parts = new String[4];
for (int i = 0; i < 4; i++) {
long octet = part & 0xff;
parts[3-i] = String.valueOf(octet);
part = part / 256;
}
return parts[0]+"."+parts[1]+"."+parts[2]+"."+parts[3];
}
/**
* Check whether a given address is contained in this netblock.
*
* @param ip the address to be tested
* @return true if {@code ip} is within this table's limits. Returns false
* if {@link ip} looks like an IPv6 address.
* if {@code ip} looks like an IPv6 address.
* @throws IPFormatException Exception Class to deal with IPFormat errors.
*/
public boolean contains(String ip) throws IPFormatException {
String[] subnets = ip.split("\\.");
// Does it look like IPv6?
if (subnets.length > 4 || ip.contains("::")) {
log.warn("Address {} assumed not to match. IPv6 is not implemented.", ip);
return false;
try {
long ipToTest = ipToLong(InetAddress.getByName(ip));
return ipRanges.stream()
.anyMatch(ipRange -> (ipToTest >= ipRange.getIpLo() && ipToTest <= ipRange.getIpHi()));
} catch (UnknownHostException e) {
throw new IPFormatException("ip not valid");
}
// Does it look like a subnet?
if (subnets.length < 4) {
throw new IPFormatException("needs to be a single IP address");
}
Map<String, Map<String, Set<String>>> first = map.get(subnets[0]);
if (first == null) {
return false;
}
Map<String, Set<String>> second = first.get(subnets[1]);
if (second == null) {
return false;
}
Set<String> third = second.get(subnets[2]);
if (third == null) {
return false;
}
return third.contains(subnets[3]) || third.contains("*");
}
/**
* Convert to a Set.
* Convert to a Set. This set contains all IPs in the range
*
* @return this table's content as a Set
*/
public Set<String> toSet() {
HashSet<String> set = new HashSet<>();
for (Map.Entry<String, Map<String, Map<String, Set<String>>>> first : map.entrySet()) {
String firstString = first.getKey();
Map<String, Map<String, Set<String>>> secondMap = first.getValue();
for (Map.Entry<String, Map<String, Set<String>>> second : secondMap.entrySet()) {
String secondString = second.getKey();
Map<String, Set<String>> thirdMap = second.getValue();
for (Map.Entry<String, Set<String>> third : thirdMap.entrySet()) {
String thirdString = third.getKey();
Set<String> fourthSet = third.getValue();
if (fourthSet.contains("*")) {
set.add(firstString + "." + secondString + "." + thirdString);
} else {
for (String fourth : fourthSet) {
set.add(firstString + "." + secondString + "." + thirdString + "." + fourth);
}
}
}
Iterator<IPRange> ipRangeIterator = ipRanges.iterator();
while (ipRangeIterator.hasNext()) {
IPRange ipRange = ipRangeIterator.next();
long ipLo = ipRange.getIpLo();
long ipHi = ipRange.getIpHi();
for (long ip = ipLo; ip <= ipHi; ip++) {
set.add(longToIp(ip).toString());
}
}
@@ -205,7 +189,7 @@ public class IPTable {
* @return true if empty, false otherwise
*/
public boolean isEmpty() {
return map.isEmpty();
return ipRanges.isEmpty();
}
/**
@@ -217,5 +201,19 @@ public class IPTable {
}
}
@Override
public String toString() {
StringBuilder stringBuilder = new StringBuilder();
Iterator<IPRange> ipRangeIterator = ipRanges.iterator();
while (ipRangeIterator.hasNext()) {
IPRange ipRange = ipRangeIterator.next();
stringBuilder.append(longToIp(ipRange.getIpLo()))
.append("-")
.append(longToIp(ipRange.getIpHi()));
if (ipRangeIterator.hasNext()) {
stringBuilder.append(", ");
}
}
return stringBuilder.toString();
}
}