From 5784108865502505acf23a989a22247730280c27 Mon Sep 17 00:00:00 2001 From: rabbitstack Date: Sun, 1 Feb 2026 19:07:11 +0100 Subject: [PATCH] feat(functions): Implement COUNT function The count function counts the number of items in the slice or substrings in the string that is matching a wildcard pattern. --- pkg/filter/filter_test.go | 2 + pkg/filter/ql/function.go | 1 + pkg/filter/ql/functions/count.go | 92 +++++++++++++++++++++++++++ pkg/filter/ql/functions/count_test.go | 56 ++++++++++++++++ pkg/filter/ql/functions/types.go | 4 ++ pkg/filter/ql/parser_test.go | 2 +- 6 files changed, 156 insertions(+), 1 deletion(-) create mode 100644 pkg/filter/ql/functions/count.go create mode 100644 pkg/filter/ql/functions/count_test.go diff --git a/pkg/filter/filter_test.go b/pkg/filter/filter_test.go index cc02ee531..b034342db 100644 --- a/pkg/filter/filter_test.go +++ b/pkg/filter/filter_test.go @@ -304,6 +304,7 @@ func TestProcFilter(t *testing.T) { {`ps.modules IN ('kernel32.dll')`, true}, {`evt.name = 'CreateProcess' and evt.pid != ps.ppid`, true}, {`ps.parent.name = 'svchost.exe'`, true}, + {`count(ps.modules, '*.dll') >= 2`, true}, {`ps.ancestor[0] = 'svchost.exe'`, true}, {`ps.ancestor[0] = 'csrss.exe'`, false}, @@ -311,6 +312,7 @@ func TestProcFilter(t *testing.T) { {`ps.ancestor[2] = 'csrss.exe'`, true}, {`ps.ancestor[3] = ''`, true}, {`ps.ancestor intersects ('csrss.exe', 'services.exe', 'svchost.exe')`, true}, + {`count(ps.ancestor, '*.exe') = 3`, true}, {`foreach(ps._ancestors, $proc, $proc.name in ('csrss.exe', 'services.exe', 'System'))`, true}, {`foreach(ps._ancestors, $proc, $proc.name in ('csrss.exe', 'services.exe', 'System') and ps.is_packaged, ps.is_packaged)`, true}, diff --git a/pkg/filter/ql/function.go b/pkg/filter/ql/function.go index 03a18fbb8..5415166a8 100644 --- a/pkg/filter/ql/function.go +++ b/pkg/filter/ql/function.go @@ -81,6 +81,7 @@ var funcs = map[string]FunctionDef{ functions.GetRegValueFn.String(): &functions.GetRegValue{}, functions.YaraFn.String(): &functions.Yara{}, functions.ForeachFn.String(): &Foreach{}, + functions.CountFn.String(): &functions.Count{}, } // FunctionDef is the interface that all function definitions have to satisfy. diff --git a/pkg/filter/ql/functions/count.go b/pkg/filter/ql/functions/count.go new file mode 100644 index 000000000..1ed051659 --- /dev/null +++ b/pkg/filter/ql/functions/count.go @@ -0,0 +1,92 @@ +/* + * Copyright 2021-present by Nedim Sabic Sabic + * https://www.fibratus.io + * All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package functions + +import ( + "strings" + + "github.com/rabbitstack/fibratus/pkg/util/wildcard" +) + +// Count counts the number of items in the slice or substrings +// in the string that is matching a wildcard pattern. +type Count struct{} + +func (f Count) Call(args []interface{}) (any, bool) { + if len(args) < 2 { + return false, false + } + + var count int + var caseInsensitive bool + + pattern := parseString(1, args) + + if len(args) > 2 { + caseInsensitive, _ = args[2].(bool) + } else { + caseInsensitive = true + } + + switch s := args[0].(type) { + case string: + substrings := strings.Fields(s) + for _, ss := range substrings { + switch caseInsensitive { + case true: + if wildcard.Match(strings.ToLower(pattern), strings.ToLower(ss)) { + count++ + } + case false: + if wildcard.Match(pattern, ss) { + count++ + } + } + } + case []string: + for _, i := range s { + switch caseInsensitive { + case true: + if wildcard.Match(strings.ToLower(pattern), strings.ToLower(i)) { + count++ + } + case false: + if wildcard.Match(pattern, i) { + count++ + } + } + } + } + + return count, true +} + +func (f Count) Desc() FunctionDesc { + desc := FunctionDesc{ + Name: CountFn, + Args: []FunctionArgDesc{ + {Keyword: "string|slice", Types: []ArgType{Field, BoundField, BoundSegment, BareBoundVariable, Func, String, Slice}, Required: true}, + {Keyword: "pattern", Types: []ArgType{String}, Required: true}, + {Keyword: "case_insensitive", Types: []ArgType{Bool}, Required: false}, + }, + } + return desc +} + +func (f Count) Name() Fn { return CountFn } diff --git a/pkg/filter/ql/functions/count_test.go b/pkg/filter/ql/functions/count_test.go new file mode 100644 index 000000000..19dbbc03f --- /dev/null +++ b/pkg/filter/ql/functions/count_test.go @@ -0,0 +1,56 @@ +/* + * Copyright 2021-present by Nedim Sabic Sabic + * https://www.fibratus.io + * All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package functions + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCount(t *testing.T) { + var tests = []struct { + args []any + expected int + }{ + { + []any{"hello world", "?orld"}, + 1, + }, + { + []any{"hello world", "saturn"}, + 0, + }, + { + []any{[]string{"C:\\Windows\\System32\\ntdll.dll", "C:\\Windows\\System32\\NTDLL.dll"}, "*ntdll.dll"}, + 2, + }, + { + []any{[]string{"C:\\Windows\\System32\\ntdll.dll", "C:\\Windows\\System32\\NTDLL.dll"}, "*ntdll.dll", false}, + 1, + }, + } + + for i, tt := range tests { + f := Count{} + res, _ := f.Call(tt.args) + assert.Equal(t, tt.expected, res, fmt.Sprintf("%d. result mismatch: exp=%v got=%v", i, tt.expected, res)) + } +} diff --git a/pkg/filter/ql/functions/types.go b/pkg/filter/ql/functions/types.go index 626b32352..459b12383 100644 --- a/pkg/filter/ql/functions/types.go +++ b/pkg/filter/ql/functions/types.go @@ -74,6 +74,8 @@ const ( YaraFn // ForeachFn represents the FOREACH function ForeachFn + // CountFn reprsents the COUNT function + CountFn ) // ArgType is the type alias for the argument value type. @@ -228,6 +230,8 @@ func (f Fn) String() string { return "YARA" case ForeachFn: return "FOREACH" + case CountFn: + return "COUNT" default: return "UNDEFINED" } diff --git a/pkg/filter/ql/parser_test.go b/pkg/filter/ql/parser_test.go index c8e2b1669..0055544cb 100644 --- a/pkg/filter/ql/parser_test.go +++ b/pkg/filter/ql/parser_test.go @@ -60,7 +60,7 @@ func TestParser(t *testing.T) { {expr: "ps.none = 'cmd.exe'", err: errors.New("ps.none = 'cmd.exe'\n╭^\n|\n|\n╰─────────────────── expected field, bound field, string, number, bool, ip, function")}, {expr: "ps.name = 'cmd.exe' AND ps.name IN ('exe') ps.name", err: errors.New("ps.name = 'cmd.exe' AND ps.name IN ('exe') ps.name\n╭──────────────────────────────────────────^\n|\n|\n╰─────────────────── expected operator, ')', ',', '|'")}, - {expr: "ip_cidr(net.dip) = '24'", err: errors.New("ip_cidr function is undefined. Did you mean one of BASE|CIDR_CONTAINS|CONCAT|DIR|ENTROPY|EXT|FOREACH|GET_REG_VALUE|GLOB|INDEXOF|IS_ABS|IS_MINIDUMP|LENGTH|LOWER|LTRIM|MD5|REGEX|REPLACE|RTRIM|SPLIT|SUBSTR|UNDEFINED|UPPER|VOLUME|YARA?")}, + {expr: "ip_cidr(net.dip) = '24'", err: errors.New("ip_cidr function is undefined. Did you mean one of BASE|CIDR_CONTAINS|CONCAT|COUNT|DIR|ENTROPY|EXT|FOREACH|GET_REG_VALUE|GLOB|INDEXOF|IS_ABS|IS_MINIDUMP|LENGTH|LOWER|LTRIM|MD5|REGEX|REPLACE|RTRIM|SPLIT|SUBSTR|UNDEFINED|UPPER|VOLUME|YARA?")}, {expr: "ps.name = 'cmd.exe' and not cidr_contains(net.sip, '172.14.0.0')"}, {expr: `ps.envs[ProgramFiles] = 'C:\\Program Files'`},