#! /usr/bin/python

#######################################################################
# Copyright (C) 2008-2020 by Carnegie Mellon University.
#
# @OPENSOURCE_LICENSE_START@
# See license information in ../../LICENSE.txt
# @OPENSOURCE_LICENSE_END@
#
#######################################################################
# $SiLK: rwidsquery.in ef14e54179be 2020-04-14 21:57:45Z mthomas $
#######################################################################
# rwidsquery
#
# Invoke rwfilter to pull flows matching Snort signatures/alerts
#######################################################################

from __future__ import print_function

# netsa-python isn't released yet
# from netsa.utils import get_protocol_number
from optparse import OptionParser, OptionValueError

import os
import re
import datetime
import subprocess
import tempfile
import calendar
import subprocess
import sys


# BEGIN netsa-python code ----------------------------------------------------
def build_protocol_map(path='/etc/protocols'):
    """Return a dict mapping IP protocol numbers to lower-case names.

    path -- protocols database to read; each useful line looks like
        "<name> <number> [aliases...]"

    Blank lines, comment lines (starting with '#') and lines whose second
    field is missing or not an integer are skipped.  If the file cannot be
    opened or read, a minimal built-in table (icmp, tcp, udp) is returned
    instead.
    """
    fallback = {1: 'icmp', 6: 'tcp', 17: 'udp'}
    proto_map = {}
    try:
        # open() works on both Python 2 and 3; the old file() builtin only
        # exists on Python 2.
        with open(path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                tokens = line.split()
                try:
                    proto_map[int(tokens[1])] = tokens[0].lower()
                except (IndexError, ValueError):
                    # Malformed entry: ignore it and keep going.
                    continue
    except (IOError, OSError, UnicodeError):
        return fallback
    return proto_map


# Protocol tables, built once at import time: number -> name, and the
# inverse name -> number.
protocol_map = build_protocol_map()
protocol_map_reverse = dict(
    (name, number) for number, name in protocol_map.items())


def get_protocol_name(number):
    """Return the protocol name recorded for the integer protocol *number*.

    Raises KeyError when the number is not in protocol_map.
    """
    name = protocol_map[number]
    return name

def get_protocol_number(name):
    """Return the protocol number for *name*, compared case-insensitively.

    Raises KeyError when the lower-cased name is not in
    protocol_map_reverse.
    """
    key = str.lower(name)
    return protocol_map_reverse[key]
# END netsa-python code ------------------------------------------------------

# Extra keyword arguments handed to str() when converting input lines:
# an explicit ASCII encoding on Python 3 and later, nothing on Python 2.
coding = {"encoding": "ascii"} if sys.version_info[0] >= 3 else {}

# Defaults for --tolerance (in seconds) and --year (the current year).
default_tolerance = 3600
default_year = datetime.datetime.now().year

def parse_options():
    """Parse the rwidsquery command line taken from sys.argv.

    Everything after a literal '--' argument is not parsed here; it is
    returned unchanged so the caller can append it to the rwfilter command.

    Returns a tuple (options, args, filterargs):
      options    -- optparse values object; year and tolerance are
                    converted to int before returning
      args       -- remaining positional arguments (at most one input file)
      filterargs -- list of arguments that followed '--'

    Invalid combinations are reported through OptionParser.error(), which
    prints the message and terminates the program.  --version prints the
    version banner and exits.
    """
    file_types = ['full', 'fast', 'rule']
    op = OptionParser(usage="""%prog --intype=<intype> [options] <infile>""")
    op.set_defaults(
        output_file="stdout",
        year=default_year,
        tolerance=default_tolerance,
        type='auto')
    op.add_option("--version", dest="version", action="store_true",
        help="print version and exit")
    op.add_option("-o", "--output-file", dest="output_file",
        help="write flow records to this file (default: stdout)")
    op.add_option("-t", "--intype", dest="type",
        help="input file type (one of 'fast', 'full', or 'rule')")
    op.add_option("-s", "--start-date", dest="start",
        help="start date for flow selection")
    op.add_option("-e", "--end-date", dest="end",
        help="end date for flow selection")
    op.add_option("-y", "--year", dest="year",
        help="year to be used for alert timestamps")
    op.add_option("--tolerance", dest="tolerance",
        help="time tolerance in seconds between alert and flow timestamps")
    op.add_option("-c", "--config-file", dest="config_file",
        help="Snort configuration file location")
    op.add_option("-m", "--mask", dest="mask",
        help="list of rwfilter predicates to mask")
    op.add_option("--dry-run", dest="dry_run", action="store_true",
        help="display rwfilter command without running it")
    op.add_option("-v", "--verbose", dest="verbose", action="store_true",
        help="print rwfilter command before it's invoked")

    # Split off the pass-through rwfilter arguments at the first '--'.
    args = sys.argv[1:]
    filterargs = []
    if '--' in args:
        split_at = args.index('--')
        args, filterargs = args[:split_at], args[split_at + 1:]

    options, args = op.parse_args(args)

    if options.version:
        tool_version_exit()
    # The default type 'auto' is not in file_types, so --intype is
    # effectively mandatory.
    if options.type not in file_types:
        op.error("File type must be one of %s" %(file_types))
    if options.type == 'rule' and not (options.start and options.end):
        op.error("Start and end times for snort rule queries must be specified")
    if options.type in ['full', 'fast'] and (options.start or options.end):
        op.error("Start and end arguments not supported for alert file input")

    # Validate the numeric options here so that a typo is reported as
    # such, instead of surfacing later as an unrelated input-format error.
    # int() is used (rather than optparse's "int" type) so that values
    # with leading zeros keep their decimal meaning.
    try:
        options.year = int(options.year)
    except (TypeError, ValueError):
        op.error("Year must be an integer (got '%s')" %(options.year,))
    try:
        options.tolerance = int(options.tolerance)
    except (TypeError, ValueError):
        op.error("Tolerance must be an integer number of seconds (got '%s')"
                 %(options.tolerance,))
    if options.tolerance < 0:
        op.error("Tolerance must not be negative")

    if options.dry_run:
        # A dry run is only useful if the command is shown.
        options.verbose = True
    if len(args) > 1:
        op.error("Too many command-line arguments provided")
    return options, args, filterargs

# Pattern applied by process_alert() when the input type is 'full'.
# It matches a numeric "MM/DD-HH:MM:SS" timestamp which may be followed,
# after any amount of intervening text, by "sip:sport -> dip:dport" and a
# whitespace-delimited protocol token.  The address/port/protocol part is
# optional, so the sip, sport, dip, dport and proto groups can be None
# even when the pattern matches.
re_full = re.compile(r"""
    (?P<mon>\d+)/(?P<day>\d+)-(?P<hour>\d+):(?P<min>\d+):(?P<sec>\d+)
    (?:
    .*?
    (?P<sip>[\d.]+):(?P<sport>\d+)\s*->\s*(?P<dip>[\d.]+):(?P<dport>\d+)
    \s*
    (?P<proto>[\S]+)
    )?
""", re.VERBOSE | re.MULTILINE | re.DOTALL)

# Pattern applied by process_alert() when the input type is 'fast'.
# It matches a "<month word> <day> HH:MM:SS" timestamp which may be
# followed, after any amount of intervening text, by "{PROTO}" and
# "sip:sport -> dip:dport".  As with re_full, the trailing part is
# optional, so the proto, sip, sport, dip and dport groups can be None.
re_fast = re.compile(r"""
    (?P<mon>\w+)\s*(?P<day>\d+)\s*(?P<hour>\d+):(?P<min>\d+):(?P<sec>\d+)
    (?:
    .*?
    {(?P<proto>[\S]+)}
    \s*
    (?P<sip>[\d.]+):(?P<sport>\d+)\s*->\s*(?P<dip>[\d.]+):(?P<dport>\d+)
    )?
""", re.VERBOSE | re.MULTILINE | re.DOTALL)

# Pattern applied by process_rule() to a rule's text.  It matches seven
# whitespace-separated header fields -- action, protocol, source address,
# source port, a direction operator built from '<', '>' and '-',
# destination address, destination port -- followed by a parenthesised
# body that is captured whole in the 'options' group.  Because '.+' is
# greedy and DOTALL is set, 'options' extends to the last ')' in the text.
re_rule = re.compile(r"""
    (?P<action>\w+)
    \s+
    (?P<proto>\w+)
    \s+
    (?P<sip>\S+)
    \s+
    (?P<sport>\S+)
    \s+
    [<>\-]+
    \s+
    (?P<dip>\S+)
    \s+
    (?P<dport>\S+)
    \s+
    \(
        (?P<options>.+)
    \)
""", re.VERBOSE | re.MULTILINE | re.DOTALL)

# Pattern used with finditer() to pull "keyword: value;" pairs out of the
# body of a rule.  The keyword may not contain whitespace, ':' or ';', and
# the value runs (lazily) up to the first ';', with surrounding whitespace
# excluded.  Options that have no ':' (for example "nocase;") do not
# match and are skipped, so they can no longer be absorbed into the value
# of the preceding option.
re_rule_options = re.compile(r"""
    (?P<option>[^\s:;]+)
    \s*
    :
    \s*
    (?P<value>[^;]+?)
    \s*
    ;
""", re.VERBOSE | re.MULTILINE | re.DOTALL)

# Rule options that process_rule() knows how to translate, mapped to the
# name of the rwfilter switch (without the leading "--") that receives
# the value.
supported_rule_opts = dict([
    ('ip_proto', 'protocol'),
    ('itype', 'icmp-type'),
    ('icode', 'icmp-code'),
    ('flags', 'tcp-flags'),
])


def tool_version_exit():
    """Print the version and licence banner, then exit with status 0.

    The banner is prefixed with the module-level ``appname``, which main()
    sets before the command line is parsed.
    """
    banner = """%s: Part of SiLK 3.19.1
Copyright (C) 2001-2020 by Carnegie Mellon University
GNU General Public License (GPL) Rights pursuant to Version 2, June 1991
Government Purpose License Rights (GPLR) pursuant to DFARS 252.227.7013
Send bug reports, feature requests, and comments to netsa-help@cert.org"""
    print(banner % (appname))
    sys.exit(0)


def expand_ip_spec(prefix, val, tempfiles):
    """Translate a Snort address specification into rwfilter arguments.

    prefix -- 's' or 'd'; selects the source or destination switches
    val -- the address field of the rule: 'any', a single address (with an
        optional leading '!' for negation), or a bracketed comma-separated
        list
    tempfiles -- list that receives any NamedTemporaryFile created here.
        The file is deleted when closed, so the caller must keep it open
        until rwfilter has run, and close it afterwards.

    Returns a (possibly empty) list of rwfilter arguments:
      'any'         -> no arguments
      'addr'        -> --<prefix>address=addr
      '!addr'       -> --not-<prefix>address=addr
      '[a,b,...]'   -> --<prefix>ipset=<temporary set file built by
                       rwsetbuild from the non-negated entries>

    Negated list entries, and lists negated as a whole ('![a,b]'), are
    ignored: rwfilter does not accept --sipset together with --not-sipset,
    so there is no way to express them.
    """
    if val == "any":
        return []

    if val.find('[') < 0:
        if val.startswith('!'):
            # rwfilter does not understand '!' inside an address; negation
            # is spelled with the --not-* form of the switch.
            return ["--not-%saddress=%s" %(prefix, val[1:])]
        return ["--%saddress=%s" %(prefix, val)]

    if val.startswith('!'):
        # The whole list is negated; treating its members as positive
        # matches would invert the meaning of the rule, so skip it.
        return []

    entries = [x.strip() for x in val.strip('[]').split(',')]
    addrs = [x for x in entries if x and not x.startswith('!')]
    if not addrs:
        return []

    f = tempfile.NamedTemporaryFile(suffix='.set')
    tempfiles.append(f)
    proc = subprocess.Popen('rwsetbuild stdin stdout',
                            shell=True,
                            stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)
    # The pipes are binary, so the address list must be sent as bytes
    # (on Python 2 this is a no-op for a plain str).
    out, err = proc.communicate('\n'.join(addrs).encode('ascii'))
    if proc.returncode != 0:
        # Keep going (rwfilter will complain about the set file), but do
        # not hide the reason from the user.
        sys.stderr.write("rwsetbuild exited with status %d: %s\n"
                         %(proc.returncode,
                           err.decode('ascii', 'ignore').strip()))
    f.write(out)
    f.flush()

    return ['--%sipset=%s' %(prefix, f.name)]

def expand_port_spec(prefix, val):
    """Translate a Snort port specification into an rwfilter port switch.

    prefix -- 's' or 'd'; selects --sport or --dport
    val -- the port field of the rule: 'any', 'N', 'N:M', 'N:', ':M', or
        any of those preceded by '!' for negation

    Returns [] for 'any', otherwise a one-element list such as
    ['--sport=0-122,457-65535'].  A negated specification is expressed as
    the complement within 0-65535.

    Raises ValueError when the specification is not one of the supported
    forms (for example a bracketed list or an unexpanded $VARIABLE), when
    a port is out of range, or when a negation leaves no port at all.
    """
    if val == "any":
        return []

    max_port = 65535
    # Anchored at the end so that trailing text is rejected instead of
    # being silently dropped.
    m = re.match(r'(?P<neg>!)?(?P<p1>\d+)?(?P<range>:)?(?P<p2>\d+)?\Z', val)
    if m is None or not (m.group('p1') or m.group('p2')):
        raise ValueError("Unsupported port specification '%s'" %(val,))

    is_range = m.group('range') is not None
    low = int(m.group('p1')) if m.group('p1') else 0        # ':456' -> 0
    if m.group('p2'):
        high = int(m.group('p2'))
    elif is_range:
        high = max_port                                     # '123:'
    else:
        high = low                                          # '123'
    if low > high or high > max_port:
        raise ValueError("Port specification '%s' is out of range" %(val,))

    if not m.group('neg'):
        if is_range:
            portlist = "%d-%d" %(low, high)
        else:
            portlist = "%d" %(low)
    else:
        # Complement of [low, high]; a piece that would be empty (when
        # the excluded range touches 0 or 65535) is left out.
        pieces = [(0, low - 1), (high + 1, max_port)]
        portlist = ','.join(["%d-%d" %(a, b) for (a, b) in pieces if a <= b])
        if not portlist:
            raise ValueError("Port specification '%s' excludes every port"
                             %(val,))

    return ['--%sport=%s' %(prefix, portlist)]


def process_rule(file, vars, tempfiles):
    """Convert a Snort rule into corresponding rwfilter arguments.

    file -- the input file to be processed; its lines may be bytes (a file
        opened in binary mode) or text (for example standard input)
    vars -- dict of snort variables from snort.conf, or None
    tempfiles -- list of temporary files created (caller cleans them up)

    Returns a list of rwfilter arguments built from the rule's address and
    port fields and from the options listed in supported_rule_opts.  The
    list is empty when the input does not look like a rule.  Errors raised
    while translating an address, port or option value propagate to the
    caller.
    """

    opts = []

    lines = []
    for x in file.readlines():
        # Only byte strings need converting; str() refuses to "decode"
        # something that is already text.
        if not isinstance(x, str):
            x = str(x, **coding)
        lines.append(x)
    rule = ' '.join(lines)

    if vars:
        # Substitute the longest names first so that a variable whose name
        # is a prefix of another (HOME vs. HOME_NET) cannot corrupt it.
        for k in sorted(vars, key=len, reverse=True):
            rule = rule.replace("$%s" %(k), vars[k])

    matches = re_rule.search(rule)
    if not matches:
        return opts

    for prefix in ['s', 'd']:
        opts += expand_ip_spec(prefix,
            matches.group(prefix + 'ip'), tempfiles)
        opts += expand_port_spec(prefix, matches.group(prefix + 'port'))

    for opt in re_rule_options.finditer(matches.group('options')):
        optname = supported_rule_opts.get(opt.group('option'))
        if optname is None:
            continue
        # Keep only the text up to the first ';' and trim whitespace, in
        # case the captured value runs on into a following option.
        optval = opt.group('value').split(';', 1)[0].strip()
        if optname.find('icmp') >= 0:
            # Snort and rwfilter denote ranges differently
            if optval.startswith('<'):
                optval = "0-%s" %( int(optval.replace('<', ''))-1 )
            elif optval.startswith('>'):
                optval = "%s-255" %( int(optval.replace('>', ''))+1 )
        elif optname.find('flags') >= 0:
            # Tweak flags syntax: drop the optional ",mask" part and the
            # '+' / '*' match modifiers, which are not flag letters.
            if optval.find(',') > 0:
                optval = optval[:optval.index(',')]
            optval = optval.replace('+', '')
            optval = optval.replace('*', '')
            optval = optval.replace('1', 'C') # 1 == CWR
            optval = optval.replace('2', 'E') # 2 == ECE
        opts.append('--%s=%s' %(optname, optval))

    return opts

def process_alert(file, type, year, tolerance):
    """Convert a Snort alert into corresponding rwfilter arguments.

    file -- the input file to be processed; its lines may be bytes (a file
        opened in binary mode) or text (for example standard input)
    type -- type of alerts in the file (either 'fast' or 'full')
    year -- the year to be assumed for dates, since Snort timestamps lack one
    tolerance -- number of seconds to allow on either side of the alert
        timestamp when selecting flows

    Only the first alert found in the input is used.

    Returns a list of rwfilter arguments (--start-date, --end-date, --stime
    and, when the alert carries them, address, port and protocol
    switches), or None when type is neither 'fast' nor 'full'.

    Raises ValueError when no alert timestamp is found or the timestamp is
    not a valid date, and KeyError when the alert names a protocol that
    get_protocol_number() does not know.
    """

    opts = []

    lines = []
    for x in file.readlines():
        # Only byte strings need converting; str() refuses to "decode"
        # something that is already text.
        if not isinstance(x, str):
            x = str(x, **coding)
        lines.append(x)
    record = ' '.join(lines)

    if type == 'full':
        matches = re_full.search(record)
    elif type == 'fast':
        matches = re_fast.search(record)
    else:
        return None

    if matches is None:
        raise ValueError("No '%s' alert timestamp found in input" %(type,))

    if type == 'full':
        month = matches.group('mon')
    else:
        # Fast alerts spell the month as an abbreviated name.
        month = list(calendar.month_abbr).index(matches.group('mon'))

    dt = datetime.datetime(
        int(year),
        int(month),
        int(matches.group('day')),
        int(matches.group('hour')),
        int(matches.group('min')),
        int(matches.group('sec')))

    stime_min = dt - datetime.timedelta(seconds=tolerance)
    stime_max = dt + datetime.timedelta(seconds=tolerance)

    # --start-date/--end-date are given to the hour, so truncate.
    start_date = datetime.datetime(stime_min.year, stime_min.month,
        stime_min.day, stime_min.hour)
    end_date = datetime.datetime(stime_max.year, stime_max.month,
        stime_max.day, stime_max.hour)

    opts.append('--start-date=%s' %(start_date.strftime("%Y/%m/%d:%H")))
    opts.append('--end-date=%s' %(end_date.strftime("%Y/%m/%d:%H")))

    opts.append('--stime=%s-%s' %(stime_min.strftime("%Y/%m/%d:%H:%M:%S"),
        stime_max.strftime("%Y/%m/%d:%H:%M:%S")))

    # The address/port part of the pattern is optional; when it is absent
    # these groups are None and no switch must be emitted for them.
    for switch, group in [('saddress', 'sip'), ('sport', 'sport'),
                          ('daddress', 'dip'), ('dport', 'dport')]:
        value = matches.group(group)
        if value:
            opts.append("--%s=%s" %(switch, value))

    if matches.group('proto'):
        opts.append("--protocol=%d" %(
            get_protocol_number(matches.group('proto'))))

    return opts

def get_snort_vars(file):
    """Collect variable definitions from a Snort configuration file.

    file -- open configuration file; its lines may be bytes (a file opened
        in binary mode) or text

    A definition is a line of the form "<keyword> NAME VALUE", where the
    keyword is 'var', 'ipvar' or 'portvar' and VALUE is the rest of the
    line with surrounding whitespace removed.  All other lines, and
    definitions that have no value, are ignored.  When a name is defined
    more than once the last definition wins.

    Returns a dict mapping each NAME to its VALUE.
    """
    keywords = ('var', 'ipvar', 'portvar')
    snort_vars = {}
    for line in file.readlines():
        if not isinstance(line, str):
            # Binary-mode input on Python 3; undecodable bytes are
            # replaced rather than aborting the whole file.
            line = line.decode('ascii', 'replace')
        tokens = line.split(None, 2)
        if len(tokens) == 3 and tokens[0] in keywords:
            snort_vars[tokens[1]] = tokens[2].strip()
    return snort_vars

def main():
    """Build an rwfilter command line from Snort input and run it.

    The rwfilter executable is taken from the RWFILTER environment
    variable (default 'rwfilter').  The input is read from the file named
    on the command line, or from standard input when no file (or '-' /
    'stdin') is given.  Depending on --intype the input is translated by
    process_alert() or process_rule(); masked predicates are then removed,
    the arguments following '--' are appended, and rwfilter is invoked
    unless --dry-run was given.

    Exits with status 1 when the input cannot be opened or translated, or
    when rwfilter cannot be started, and with a non-zero status when
    rwfilter itself fails.
    """
    global appname
    appname = os.path.basename(sys.argv[0])
    rwfilter = os.getenv("RWFILTER", 'rwfilter')

    cmdline = [ rwfilter ]
    tempfiles = []
    options, args, filterargs = parse_options()

    # On Python 3 read standard input as bytes, the same as named files
    # (which are opened 'rb'), so the translators see one kind of input.
    stdin = getattr(sys.stdin, 'buffer', sys.stdin)

    if len(args) == 1:
        if args[0] in ['-', 'stdin']:
            f = stdin
        else:
            try:
                f = open(args[0],'rb')
            except IOError:
                sys.stderr.write("%s: Input file '%s' not found\n"
                                 %(appname, args[0]))
                sys.exit(1)
    else:
        f = stdin
        if f.isatty():
            sys.stderr.write("%s: Standard input is connected to a terminal\n"
                             %(appname))
            sys.exit(1)

    try:
        if options.type in [ 'full', 'fast' ]:
            try:
                cmdline += process_alert(f, options.type, options.year,
                                         int(options.tolerance))
            except Exception:
                # Exception (not a bare except) so that Ctrl-C and
                # sys.exit() are not reported as bad input.
                sys.stderr.write(
                    "%s: Input does not match expected intype '%s'\n"
                    %(appname, options.type))
                sys.exit(1)

        elif options.type == 'rule':
            vars = None
            if options.config_file:
                try:
                    conf_file = open(options.config_file, 'rb')
                    try:
                        vars = get_snort_vars(conf_file)
                    finally:
                        conf_file.close()
                except IOError:
                    sys.stderr.write(
                        "%s: Could not load snort conf file '%s'\n"
                        %(appname, options.config_file))
                    sys.exit(1)

            cmdline.append('--start-date=%s' %(options.start))
            cmdline.append('--end-date=%s' %(options.end))

            cmdline.append('--stime=%s-%s' %(options.start, options.end))
            try:
                cmdline += (process_rule(f, vars, tempfiles))
            except Exception:
                sys.stderr.write(
                    "%s: Input does not match expected intype '%s'\n"
                    %(appname, options.type))
                sys.exit(1)

        # Mask out rwfilter options the user doesn't want to filter on.
        # The executable name in cmdline[0] is never a candidate.
        if options.mask:
            masked = options.mask.split(',')
            cmdline = cmdline[:1] + [
                x for x in cmdline[1:]
                if x.split('=')[0].replace('--','') not in masked]

        # Add in extra rwfilter args from the command line
        cmdline += filterargs

        cmdline.append('--pass=%s' %(options.output_file))

        if options.verbose:
            print(' '.join(cmdline), file=sys.stderr)
        if not options.dry_run:
            try:
                status = subprocess.call(cmdline)
            except OSError as err:
                sys.stderr.write("%s: Unable to run '%s': %s\n"
                                 %(appname, cmdline[0], err))
                sys.exit(1)
            if status != 0:
                # A negative status means rwfilter was killed by a signal.
                if status < 0:
                    status = 1
                sys.exit(status)
    finally:
        # The temporary set files must outlive the rwfilter run; closing
        # them here also deletes them.
        for t in tempfiles:
            t.close()
        if f is not stdin:
            f.close()

if __name__ == '__main__':
    main()


# Local Variables:
# mode:python
# indent-tabs-mode:nil
# End:
