-
Notifications
You must be signed in to change notification settings - Fork 12
Use pytricia for ip matching on non-windows machines #547
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1c6d6f3
e968e77
64d263e
d873fdf
9ddede9
09e98c6
c4dfc10
828d265
66eb7b2
eb5e75f
b1e95a7
fc2fb7a
4c89db1
848fbdf
0d09bbb
64679b8
3cecdb2
3844902
27812b6
944fe15
f3e248e
4cfae74
061ef75
9244f8b
dd2ad34
7e31023
52e2d88
1ef435a
7f965f8
7443ef7
c006e20
23df79a
dcbda4b
1b63127
13ab678
d032fdd
07d2ebf
e4799fa
7b28bfa
22fb8af
a6a0e94
7b1dbb2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,49 +1,64 @@ | ||
| """ | ||
| Based on https://github.com/demskie/netparser | ||
| MIT License - Copyright (c) 2019 alex | ||
| """ | ||
|
|
||
| from .shared import parse_base_network, sort_networks, summarize_sorted_networks | ||
| from .sort import binary_search_for_insertion_index | ||
|
|
||
|
|
||
| class IPMatcher: | ||
| def __init__(self, networks=None): | ||
| self.sorted = [] | ||
| if networks is not None: | ||
| subnets = [] | ||
| for s in networks: | ||
| net = parse_base_network(s, False) | ||
| if net and net.is_valid(): | ||
| subnets.append(net) | ||
| sort_networks(subnets) | ||
| self.sorted = summarize_sorted_networks(subnets) | ||
|
|
||
| def has(self, network): | ||
| """ | ||
| Checks if the given IP address is in the list of networks. | ||
| """ | ||
| net = parse_base_network(network, False) | ||
| if not net or not net.is_valid(): | ||
| return False | ||
| idx = binary_search_for_insertion_index(net, self.sorted) | ||
| if idx < 0: | ||
| return False | ||
| if idx < len(self.sorted) and self.sorted[idx].contains(net): | ||
| return True | ||
| if idx - 1 >= 0 and self.sorted[idx - 1].contains(net): | ||
| return True | ||
| return False | ||
|
|
||
| def add(self, network): | ||
| net = parse_base_network(network, False) | ||
| if not net or not net.is_valid(): | ||
| return self | ||
| idx = binary_search_for_insertion_index(net, self.sorted) | ||
| if idx < len(self.sorted) and self.sorted[idx].compare(net) == 0: | ||
| import ipaddress | ||
|
|
||
| try: | ||
| import pytricia | ||
|
|
||
| PYTRICIA_AVAILABLE = True | ||
| except ImportError: | ||
| PYTRICIA_AVAILABLE = False | ||
| from aikido_zen.helpers.logging import logger | ||
|
|
||
| logger.warning( | ||
| "pytricia is not available. This happens on windows devices where pytricia is not supported yet." | ||
| "Using fallback, this may result in slower performance." | ||
| "You can try to install pytricia for better performance: pip install pytricia" | ||
| ) | ||
|
|
||
|
|
||
| def preparse(network: str) -> str: | ||
| # Remove the brackets around IPv6 addresses if they are there. | ||
| network = network.strip("[]") | ||
| try: | ||
| ip = ipaddress.IPv6Address(network) | ||
| if ip.ipv4_mapped: | ||
| return str(ip.ipv4_mapped) | ||
| except ValueError: | ||
bitterpanda63 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| pass | ||
| return network | ||
|
|
||
|
|
||
| if PYTRICIA_AVAILABLE: | ||
|
|
||
| class IPMatcher: | ||
| def __init__(self, networks=None): | ||
| self.trie = pytricia.PyTricia(128) | ||
| if networks is not None: | ||
| for s in networks: | ||
| self._add(s) | ||
| # We freeze in constructor ensuring that after initialization the IPMatcher is always frozen. | ||
| self.trie.freeze() | ||
|
|
||
| def has(self, network): | ||
| try: | ||
| return self.trie.get(preparse(network)) is not None | ||
| except ValueError: | ||
| return False | ||
|
|
||
| def _add(self, network): | ||
| try: | ||
| self.trie[preparse(network)] = True | ||
| except ValueError: | ||
| pass | ||
| except SystemError: | ||
bitterpanda63 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # SystemError's have been known to occur in the PyTricia library (see issue #34 e.g.), | ||
| # best to play it safe and catch these errors. | ||
| pass | ||
| return self | ||
| self.sorted.insert(idx, net) | ||
| return self | ||
|
|
||
| def is_empty(self): | ||
| return len(self.sorted) == 0 | ||
| def is_empty(self): | ||
| return len(self.trie) == 0 | ||
|
|
||
| else: | ||
| # Fallback to pure Python implementation - this happens on windows machines since pytricia is not | ||
| # fully supported there. | ||
| from aikido_zen.helpers.ip_matcher_fallback import IPMatcher # noqa: F401 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| """ | ||
| Based on https://github.com/demskie/netparser | ||
| MIT License - Copyright (c) 2019 alex | ||
| """ | ||
|
|
||
| from .shared import parse_base_network, sort_networks, summarize_sorted_networks | ||
| from .sort import binary_search_for_insertion_index | ||
|
|
||
|
|
||
| class IPMatcher: | ||
| def __init__(self, networks=None): | ||
| self.sorted = [] | ||
| if networks is not None: | ||
| subnets = [] | ||
| for s in networks: | ||
| net = parse_base_network(s, False) | ||
| if net and net.is_valid(): | ||
| subnets.append(net) | ||
| sort_networks(subnets) | ||
| self.sorted = summarize_sorted_networks(subnets) | ||
|
|
||
| def has(self, network): | ||
| """ | ||
| Checks if the given IP address is in the list of networks. | ||
| """ | ||
| net = parse_base_network(network, False) | ||
| if not net or not net.is_valid(): | ||
| return False | ||
| idx = binary_search_for_insertion_index(net, self.sorted) | ||
| if idx < 0: | ||
| return False | ||
| if idx < len(self.sorted) and self.sorted[idx].contains(net): | ||
| return True | ||
| if idx - 1 >= 0 and self.sorted[idx - 1].contains(net): | ||
| return True | ||
| return False | ||
|
|
||
| def add(self, network): | ||
| net = parse_base_network(network, False) | ||
| if not net or not net.is_valid(): | ||
| return self | ||
| idx = binary_search_for_insertion_index(net, self.sorted) | ||
| if idx < len(self.sorted) and self.sorted[idx].compare(net) == 0: | ||
| return self | ||
| self.sorted.insert(idx, net) | ||
| return self | ||
|
|
||
| def is_empty(self): | ||
| return len(self.sorted) == 0 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,198 @@ | ||
| import pytest | ||
| from . import IPMatcher | ||
|
|
||
|
|
||
| def test_single_ipv4s(): | ||
| input_list = [ | ||
| "192.168.0.0/32", | ||
| "192.168.0.3/32", | ||
| "192.168.0.24/32", | ||
| "192.168.0.52/32", | ||
| "192.168.0.123/32", | ||
| "192.168.0.124/32", | ||
| "192.168.0.125/32", | ||
| "192.168.0.170/32", | ||
| "192.168.0.171/32", | ||
| "192.168.0.222/32", | ||
| "192.168.0.234/32", | ||
| "192.168.0.255/32", | ||
| ] | ||
| matcher = IPMatcher(input_list) | ||
| assert matcher.has("192.168.0.254") == False | ||
| assert matcher.has("192.168.0.1") == False | ||
| assert matcher.has("192.168.0.255") == True | ||
| assert matcher.has("192.168.0.24") == True | ||
|
|
||
|
|
||
| def test_with_ranges(): | ||
| input_list = [ | ||
| "192.168.0.0/24", | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe good to test with multiple overlapping ranges too?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these test cases are from node iirc? (also a bit out of scope, since this was existing code copied over) |
||
| "192.168.0.3/32", | ||
| "192.168.0.24/32", | ||
| "192.168.0.52/32", | ||
| "192.168.0.123/32", | ||
| "192.168.0.124/32", | ||
| "192.168.0.125/32", | ||
| "192.168.0.170/32", | ||
| "192.168.0.171/32", | ||
| "192.168.0.222/32", | ||
| "192.168.0.234/32", | ||
| "192.168.0.255/32", | ||
| ] | ||
| matcher = IPMatcher(input_list) | ||
| assert matcher.has("192.168.0.254") == True # Now included because of the /24 range | ||
| assert matcher.has("10.0.0.1") == False | ||
| assert matcher.has("192.168.0.234") == True | ||
|
|
||
|
|
||
| def test_with_invalid_ranges(): | ||
| input_list = [ | ||
| "192.168.0.0/24", | ||
| "192.168.0.3/32", | ||
| "192.168.0.24/32", | ||
| "192.168.0.52/32", | ||
| "foobar", | ||
| "0.a.0.0/32", | ||
| "123.123.123.123/1999", | ||
| "", | ||
| ",,,", | ||
| "192.168.0.124/32", | ||
| "192.168.0.125/32", | ||
| "192.168.0.170/32", | ||
| "192.168.0.171/32", | ||
| "192.168.0.222/32", | ||
| "192.168.0.234/32", | ||
| "192.168.0.255", | ||
| ] | ||
| matcher = IPMatcher(input_list) | ||
| assert matcher.has("192.168.0.254") == True | ||
| assert matcher.has("foobar") == False | ||
| assert matcher.has("192.168.0.222") == True | ||
| assert matcher.has("192.168.0.1") == True | ||
| assert matcher.has("10.0.0.1") == False | ||
| assert matcher.has("192.168.0.255") == True | ||
| assert matcher.has("") == False | ||
| assert matcher.has("1") == False | ||
| assert matcher.has("192.168.0.1/32") == True | ||
|
|
||
|
|
||
| def test_with_empty_ranges(): | ||
| input_list = [] | ||
| matcher = IPMatcher(input_list) | ||
| assert matcher.has("192.168.2.1") == False | ||
| assert matcher.has("foobar") == False | ||
|
|
||
|
|
||
| def test_with_ipv6_ranges(): | ||
| input_list = [ | ||
| "2002:db8::/32", | ||
| "2001:db8::1/128", | ||
| "2001:db8::2/128", | ||
| "2001:db8::3/128", | ||
| "2001:db8::4/128", | ||
| "2001:db8::5/128", | ||
| "2001:db8::6/128", | ||
| "2001:db8::7/128", | ||
| "2001:db8::8/128", | ||
| "2001:db8::9/128", | ||
| "2001:db8::a/128", | ||
| "2001:db8::b/128", | ||
| "2001:db8::c/128", | ||
| "2001:db8::d/128", | ||
| "2001:db8::e/128", | ||
| "[2001:db8::f]", | ||
| "2001:db9::abc", | ||
| ] | ||
| matcher = IPMatcher(input_list) | ||
| assert matcher.has("2001:db8::1") == True | ||
| assert matcher.has("2001:db8::0") == False | ||
| assert matcher.has("2001:db8::f") == True | ||
| assert matcher.has("[2001:db8::f]") == True | ||
| assert matcher.has("2001:db8::10") == False | ||
| assert matcher.has("2002:db8::1") == True | ||
| assert matcher.has("2002:db8::2f:2") == True | ||
| assert matcher.has("2001:db9::abc") == True | ||
|
|
||
|
|
||
| def test_mix_ipv4_and_ipv6(): | ||
| input_list = ["2002:db8::/32", "10.0.0.0/8"] | ||
| matcher = IPMatcher(input_list) | ||
| assert matcher.has("2001:db8::1") == False | ||
| assert matcher.has("2001:db8::0") == False | ||
| assert matcher.has("2002:db8::1") == True | ||
| assert matcher.has("10.0.0.1") == True | ||
| assert matcher.has("10.0.0.255") == True | ||
| assert matcher.has("192.168.1.1") == False | ||
|
|
||
|
|
||
| def test_add_ips_later(): | ||
| input_list = ["2002:db8::/32", "10.0.0.0/8"] | ||
| matcher = IPMatcher() | ||
| assert matcher.has("2001:db8::0") == False | ||
| assert matcher.has("2002:db8::1") == False | ||
| for ip in input_list: | ||
| matcher.add(ip) | ||
| assert matcher.has("2001:db8::1") == False | ||
| assert matcher.has("2001:db8::0") == False | ||
| assert matcher.has("2002:db8::1") == True | ||
| assert matcher.has("10.0.0.1") == True | ||
| assert matcher.has("10.0.0.255") == True | ||
| assert matcher.has("192.168.1.1") == False | ||
|
|
||
|
|
||
| def test_strange_ips(): | ||
| input_list = ["::ffff:0.0.0.0", "::ffff:0:0:0:0", "::ffff:127.0.0.1"] | ||
| matcher = IPMatcher(input_list) | ||
| assert matcher.has("::ffff:0.0.0.0") == True | ||
| assert matcher.has("::ffff:127.0.0.1") == True | ||
| assert matcher.has("::ffff:123") == False | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so this case is different for both?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, PyTricia is more accurate in this regard, this is also not a huge issue, and I think no one is running with windows in production. Given IPC also doesnt work |
||
| assert matcher.has("2001:db8::1") == False | ||
| assert matcher.has("[::ffff:0.0.0.0]") == True | ||
| assert matcher.has("::ffff:0:0:0:0") == True | ||
|
|
||
|
|
||
| def test_different_cidr_ranges(): | ||
| assert IPMatcher(["123.2.0.2/0"]).has("1.1.1.1") == True | ||
| assert IPMatcher(["123.2.0.2/1"]).has("1.1.1.1") == True | ||
| assert IPMatcher(["123.2.0.2/2"]).has("1.1.1.1") == False | ||
| assert IPMatcher(["123.2.0.2/3"]).has("123.3.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/4"]).has("123.3.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/5"]).has("123.3.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/6"]).has("123.3.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/7"]).has("123.3.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/8"]).has("123.3.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/9"]).has("123.3.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/10"]).has("123.3.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/11"]).has("123.3.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/12"]).has("123.3.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/13"]).has("123.3.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/14"]).has("123.3.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/15"]).has("123.3.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/16"]).has("123.3.0.1") == False | ||
| assert IPMatcher(["123.2.0.2/17"]).has("123.2.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/18"]).has("123.2.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/19"]).has("123.2.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/20"]).has("123.2.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/21"]).has("123.2.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/22"]).has("123.2.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/23"]).has("123.2.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/24"]).has("123.2.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/25"]).has("123.2.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/26"]).has("123.2.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/27"]).has("123.2.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/29"]).has("123.2.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/30"]).has("123.2.0.1") == True | ||
| assert IPMatcher(["123.2.0.2/31"]).has("123.2.0.1") == False | ||
| assert IPMatcher(["123.2.0.2/32"]).has("123.2.0.2") == True | ||
|
|
||
|
|
||
| def test_allow_all_ips(): | ||
| matcher = IPMatcher(["0.0.0.0/0", "::/0"]) | ||
| assert matcher.has("1.2.3.4") == True | ||
| assert matcher.has("::1") == True | ||
| assert matcher.has("::ffff:1234") == True | ||
| assert matcher.has("1.1.1.1") == True | ||
| assert matcher.has("2002:db8::1") == True | ||
| assert matcher.has("10.0.0.1") == True | ||
| assert matcher.has("10.0.0.255") == True | ||
| assert matcher.has("192.168.1.1") == True | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment on preparse merely restates the code. Either remove it or replace with why this normalization is needed (e.g., to canonicalize bracketed IPv6 and IPv4-mapped IPv6 inputs).
Details
✨ AI Reasoning
A newly added comment adjacent to the preparse function only restates what the code does (removing square brackets). It doesn't explain why this normalization is required (e.g., to handle bracketed IPv6 input or IPv4-mapped IPv6 addresses in callers), so it provides low informational value and increases maintenance burden.
🔧 How do I fix it?
Write comments that explain the purpose, reasoning, or business logic behind the code using words like 'because', 'so that', or 'in order to'.
Reply
@AikidoSec feedback: [FEEDBACK]to get better review comments in the future.Reply
@AikidoSec ignore: [REASON]to ignore this issue.More info