--[[
# Copyright 2001-2014 Cisco Systems, Inc. and/or its affiliates. All rights
# reserved.
#
# This file contains proprietary Detector Content created by Cisco Systems,
# Inc. or its affiliates ("Cisco") and is distributed under the GNU General
# Public License, v2 (the "GPL").  This file may also include Detector Content
# contributed by third parties. Third party contributors are identified in the
# "authors" file.  The Detector Content created by Cisco is owned by, and
# remains the property of, Cisco.  Detector Content from third party
# contributors is owned by, and remains the property of, such third parties and
# is distributed under the GPL.  The term "Detector Content" means specifically
# formulated patterns and logic to identify applications based on network
# traffic characteristics, comprised of instructions in source code or object
# code form (including the structure, sequence, organization, and syntax
# thereof), and all documentation related thereto that have been officially
# approved by Cisco.  Modifications are considered part of the Detector
# Content.
--]]
--[[
detection_name: NetBIOS-ssn (SMB)
version: 9
description: Netbios session service, also known as SMB.
bundle_description: $VAR1 = {
          'SMBv1' => 'Server Message Block version 1, a set of early SMB dialects including SMB, SMB1, and CIFS.',
          'SMBv3-unencrypted' => 'Server Message Block version 3, more recent SMB dialects including SMB 3.0, SMB 3.0.1, and SMB 3.1.1.',
          'SMBv3-encrypted' => 'Server Message Block version 3, encrypted traffic.',
          'SMBv2' => 'Server Message Block version 2. This set of SMB dialects includes SMB 2.0 and SMB 2.1.',
          'NetBIOS-ssn (SMB)' => 'Netbios session service, also known as SMB.'
        };

--]]

require "DetectorCommon"

local DC = DetectorCommon
local FT = flowTrackerModule

-- Service identity reported to the host. gDetector is filled in by DetectorInit.
gServiceId = 17
gServiceName = 'NetBIOS-ssn (SMB)'
gDetector = nil

-- Registration descriptor for this TCP server-side detector. The callback
-- entries are the names of the global functions defined later in this file.
DetectorPackageInfo = {
    name = "NetBIOS-ssn (SMB)",
    proto = DC.ipproto.tcp,
    server = {
        init = 'DetectorInit',
        validate = 'DetectorValidator',
        fini = 'DetectorFini',
    }
}

-- Application ids: the generic SMB service (used by addService) plus the
-- per-version payload ids chosen by DetectorValidator.
gSfAppIdSMB = 755
gSfAppIdSMBv1 = 4645
gSfAppIdSMBv2 = 4646
gSfAppIdSMBv3_encrypted = 4647
gSfAppIdSMBv3_unencrypted = 4665

-- Each pattern is {bytes, offset, appId}.
-- The SMB protocol ids sit at offset 4, right after the 4-byte NBSS header:
--   0xFF 'SMB' (decimal 255) -> SMBv1
--   0xFE 'SMB' (decimal 254) -> SMB2 header (SMBv2 or SMBv3 dialects)
--   0xFD 'SMB' (decimal 253) -> SMBv3 encrypted (transform header)
-- nbss_response matches NBSS message type 0x82 (decimal 130, positive
-- session response) at offset 0.
gPatterns = {
    smbanner1 = { "\255SMB", 4, gSfAppIdSMB},
    smbanner2 = { "\254SMB", 4, gSfAppIdSMB},
    smbanner3 = { "\253SMB", 4, gSfAppIdSMB},
    nbss_response = { "\130", 0, gSfAppIdSMB},
}

-- Patterns registered with the host's fast-pattern matcher, as {proto, pattern}.
gFastPatterns = {
    {DC.ipproto.tcp, gPatterns.smbanner1},
    {DC.ipproto.tcp, gPatterns.smbanner2},
    {DC.ipproto.tcp, gPatterns.smbanner3},
}

-- Well-known NetBIOS session service (139) and direct SMB (445) ports.
gPorts = {
    {DC.ipproto.tcp, 139},
    {DC.ipproto.tcp, 445},
}

-- {appId, flag} pairs passed verbatim to gDetector:registerAppId.
-- NOTE(review): the meaning of the second field is host-defined; TODO confirm.
gAppRegistry = {
    {gSfAppIdSMB, 1},
    {gSfAppIdSMBv1, 1},
    {gSfAppIdSMBv2, 1},
    {gSfAppIdSMBv3_encrypted, 1},
    {gSfAppIdSMBv3_unencrypted, 1},
}

--- Report that this flow is still being evaluated.
-- gDetector:inProcessService() is called only while the flow's
-- serviceDetected flag is unset or zero.
-- @param context per-packet table built by DetectorValidator
-- @return DC.serviceStatus.inProcess
function serviceInProcess(context)
    local detectedFlag = context.detectorFlow:getFlowFlag(DC.flowFlags.serviceDetected)
    local alreadyDetected = detectedFlag and detectedFlag ~= 0

    if not alreadyDetected then
        gDetector:inProcessService()
    end

    DC.printf('%s: Inprocess, packetCount: %d\n', gServiceName, context.packetCount)
    return DC.serviceStatus.inProcess
end

--- Report a successful SMB match for this packet.
-- If the validator chose a payload id and set add_payload, the payload is
-- reported first. The generic SMB service is added only while the flow's
-- serviceDetected flag is unset or zero.
-- @param context per-packet table built by DetectorValidator
--   (reads detectorFlow, payload_id, add_payload, packetCount)
-- @return DC.serviceStatus.success
function serviceSuccess(context)
    local detectedFlag = context.detectorFlow:getFlowFlag(DC.flowFlags.serviceDetected)
    local payload = context.payload_id

    if payload and context.add_payload then
        DC.printf("%s: adding payload %d\n", gServiceName, payload)
        gDetector:service_analyzePayload(payload)
    end

    local alreadyDetected = detectedFlag and detectedFlag ~= 0
    if not alreadyDetected then
        gDetector:addService(gServiceId, "", "", gSfAppIdSMB)
    end

    DC.printf('%s: Detected, packetCount: %d\n', gServiceName, context.packetCount)
    return DC.serviceStatus.success
end

--- Report that this flow does not match SMB.
-- gDetector:failService() is called only while the flow's serviceDetected
-- flag is unset or zero. The continue flag is always cleared; this is the
-- same flag DetectorValidator sets after a success to keep receiving packets.
-- @param context per-packet table built by DetectorValidator
-- @return DC.serviceStatus.nomatch
function serviceFail(context)
    local flow = context.detectorFlow
    local detectedFlag = flow:getFlowFlag(DC.flowFlags.serviceDetected)

    if not (detectedFlag and detectedFlag ~= 0) then
        gDetector:failService()
    end

    flow:clearFlowFlag(DC.flowFlags.continue)
    DC.printf('%s: Failed, packetCount: %d\n', gServiceName, context.packetCount)
    return DC.serviceStatus.nomatch
end

--- Register this detector's ports, fast patterns and application ids.
-- Registration is best-effort. A failing pattern or app id is logged through
-- DC.printf and does not stop the remaining entries from being registered.
function registerPortsPatterns()
    for _, port in ipairs(gPorts) do
        gDetector:addPort(port[1], port[2])
    end

    for _, fast in ipairs(gFastPatterns) do
        local proto, pat = fast[1], fast[2]
        -- pat is {bytes, offset, appId}; a non-zero return is treated as failure.
        if gDetector:registerPattern(proto, pat[1], #pat[1], pat[2], pat[3]) ~= 0 then
            DC.printf('%s: register pattern failed for %s\n', gServiceName, pat[1])
        else
            DC.printf('%s: register pattern successful for %s\n', gServiceName, pat[1])
        end
    end

    -- The method lookup and call both run inside pcall, so an error from
    -- registerAppId cannot abort detector initialization. The error is
    -- logged rather than silently discarded.
    for _, entry in ipairs(gAppRegistry) do
        local ok, err = pcall(function () gDetector:registerAppId(entry[1], entry[2]) end)
        if not ok then
            DC.printf('%s: registerAppId failed for %s: %s\n',
                gServiceName, tostring(entry[1]), tostring(err))
        end
    end
end

--- Decide whether an SMB2 NEGOTIATE response selects an SMB 3.x dialect.
-- The caller has already matched the SMB2 protocol id "\254SMB" at offset 4,
-- right after the 4-byte NBSS header. Offsets below are from the packet start:
--   8..9    SMB2 header StructureSize (header length), little-endian
--   16..17  SMB2 Command, little-endian; 0 is NEGOTIATE
--   4 + hdr_len + 5  high byte of the NEGOTIATE response DialectRevision
--                    (little-endian field at body offset 4)
-- @param size packet payload length in bytes
-- @return 1 when that dialect byte is >= 3, otherwise nil. nil is also
--   returned if the packet is too short or a field cannot be extracted.
function check_smbv3_dialect(size)
    if size < 18 then
        return nil
    end

    -- we need the smb2 header length to find our dialect field
    local matched, hdr_len_raw = gDetector:getPcreGroups("(..)", 8)
    if not matched or not hdr_len_raw then
        return nil
    end
    local hdr_len = DC.reverseBinaryStringToNumber(hdr_len_raw, 2)

    -- we are only interested in Negotiate Protocol Response packets - the cmd is "0"
    local cmd_raw
    matched, cmd_raw = gDetector:getPcreGroups("(..)", 16)
    if not matched or not cmd_raw then
        return nil
    end
    local cmd = DC.reverseBinaryStringToNumber(cmd_raw, 2)

    DC.printf("%s: check_smbv3_dialect: smb2 header size %d, total size %d, smb2_cmd %d\n",
        gServiceName, hdr_len, size, cmd)
    local dialect_index = 4 + hdr_len + 5
    DC.printf("%s: dialect index is %d\n", gServiceName, dialect_index)

    -- The dialect byte must lie inside the packet.
    if cmd ~= 0 or size <= dialect_index then
        return nil
    end

    local dialect_raw
    matched, dialect_raw = gDetector:getPcreGroups("(.)", dialect_index)
    if not matched or not dialect_raw then
        return nil
    end
    local dialect = DC.binaryStringToNumber(dialect_raw, 1)
    DC.printf("%s: dialect is %d\n", gServiceName, dialect)

    if dialect >= 3 then
        return 1
    end
    return nil
end

--- Host entry point that initialises the detector.
-- Stores the instance in gDetector, calls its init with this service's name
-- and the validator/fini callback names, and then registers ports, patterns
-- and app ids.
-- @param detectorInstance host-provided detector object
-- @return the same detector instance
function DetectorInit(detectorInstance)
    local det = detectorInstance
    gDetector = det

    DC.printf('%s:DetectorInit()\n', gServiceName)
    det:init(gServiceName, 'DetectorValidator', 'DetectorFini')

    registerPortsPatterns()
    return det
end

--- Per-packet validator: classify NBSS/SMB traffic and report flow status.
-- Packets with size 0 or direction 0 are reported as in-process.
-- NOTE(review): direction 1 is presumably server-to-client; TODO confirm.
-- Direction-1 packets must hold exactly one complete NBSS message: its
-- length, from header bytes 1..3, must equal size - 4. After that, the SMB
-- protocol id at offset 4 selects the SMB version payload id.
-- Each recognised SMB message increments a per-flow counter. At 5 messages
-- the payload is reported and the continue flag is cleared.
-- @return a DC.serviceStatus value from serviceInProcess/serviceSuccess/serviceFail
function DetectorValidator()
    local context = {}
    context.detectorFlow = gDetector:getFlow()
    context.packetDataLen = gDetector:getPacketSize()
    context.packetDir = gDetector:getPacketDir()
    context.flowKey = context.detectorFlow:getFlowKey()
    context.packetCount = gDetector:getPktCount()
    local size = context.packetDataLen
    local dir = context.packetDir
    local flowKey = context.flowKey

    DC.printf ('%s:DetectorValidator(): packetCount %d, dir %d, size %d\n',
        gServiceName, context.packetCount, dir, size)

    if size == 0 or dir == 0 then
        return serviceInProcess(context)
    end

    -- Per-flow state across packets: nbss_count (SMB messages seen) and
    -- smbv3_dialect (set once an SMB 3.x dialect or encrypted header is seen).
    local rft = FT.getFlowTracker(flowKey)
    if not rft then
        rft = FT.addFlowTracker(flowKey, {nbss_count = 0})
    end

    if dir ~= 1 or size < 4 then
        return serviceFail(context)
    end

    -- Bytes 1..3 of the NBSS header hold the message length.
    local matched, len_raw = gDetector:getPcreGroups('.(...)', 0)
    if not matched or not len_raw then
        return serviceFail(context)
    end
    local len = DC.binaryStringToNumber(len_raw, 3)
    DC.printf("len %d, size-4 %d\n", len, size-4)
    if len ~= size - 4 then
        return serviceFail(context)
    end

    if DC.checkPattern(gDetector, gPatterns.nbss_response) then
        DC.printf("NBSS packet\n")
        return serviceInProcess(context)
    end

    -- check the header to determine SMB version
    if size >= 10 and DC.checkPattern(gDetector, gPatterns.smbanner1) then
        DC.printf("detected SMBv1\n")
        context.payload_id = gSfAppIdSMBv1
    elseif size >= 10 and DC.checkPattern(gDetector, gPatterns.smbanner2) then
        DC.printf("detected SMBv2\n")
        -- Once any packet on the flow has shown an SMB 3.x dialect, later
        -- SMB2-header packets are reported as SMBv3 too.
        if rft.smbv3_dialect or check_smbv3_dialect(size) then
            DC.printf("dialect SMBv3\n")
            rft.smbv3_dialect = 1
            context.payload_id = gSfAppIdSMBv3_unencrypted
            context.add_payload = 1
        else
            DC.printf("dialect IS NOT SMBv3\n")
            context.payload_id = gSfAppIdSMBv2
        end
    elseif size >= 10 and DC.checkPattern(gDetector, gPatterns.smbanner3) then
        DC.printf("detected SMBv3 banner\n")
        rft.smbv3_dialect = 1
        context.payload_id = gSfAppIdSMBv3_encrypted
        context.add_payload = 1
    else
        -- fail if we don't see one of those headers
        return serviceFail(context)
    end

    rft.nbss_count = rft.nbss_count + 1
    DC.printf("rft.nbss_count %d\n", rft.nbss_count)
    if rft.nbss_count >= 5 then
        -- Enough messages seen: clear the continue flag and report the payload
        -- (SMBv1/SMBv2 payloads are only reported here).
        context.detectorFlow:clearFlowFlag(DC.flowFlags.continue)
        context.add_payload = 1
    else
        -- Set the continue flag so subsequent packets on this flow are validated too.
        context.detectorFlow:setFlowFlag(DC.flowFlags.continue)
    end
    return serviceSuccess(context)
end

--- Host teardown hook. This detector holds no state that needs releasing.
function DetectorFini()
    return
end

