#!/usr/bin/env ruby
#
# Copyright (C) 2008-2009  Kouhei Sutou <kou@cozmixng.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

require 'pathname'
require 'time'
require 'find'
require 'optparse'
require 'net/smtp'
require 'digest/sha2'
require 'thread'

# Command line tool that sends test mails to an SMTP server and reports how
# long each SMTP transaction took, plus how many mails were temporarily
# failed, rejected or ended in an error. Mails are sent either all at once
# (one thread per mail), at a fixed interval (--interval), or spread evenly
# over a period (--period).
class MilterPerformanceTool
  def initialize
    @smtp_server = "localhost"
    @smtp_port = 25
    @helo_fqdn = "localhost.localdomain"
    @from = "from@example.com"
    @recipients = []
    @force_from = nil
    @force_recipients = nil
    @default_recipients = ["to@example.com"]
    @n_mails = 100
    @mails = []
    @period = nil
    @interval = nil
    @shuffle = false
    @mail_source_from_stdin = nil
    # Guards the statistics counters and the one-time stdin read; both are
    # touched from the per-mail sender threads.
    @mutex = Mutex.new
  end

  # Parses the command line options in +argv+ (destructively, via
  # OptionParser#parse!). The remaining arguments are kept as the mail
  # files/directories to send; "-" stands for a mail read from stdin.
  # Falls back to the default recipient list when no --recipient is given.
  #
  # Raises an OptionParser::ParseError subclass on invalid input.
  def parse_options(argv)
    parser = OptionParser.new do |opts|
      opts.separator ""
      opts.separator "Help options:"

      opts.on("-h", "--help", "Show this message") do
        puts opts
        exit(0)
      end

      opts.separator ""
      opts.separator "Application options:"

      opts.on("--smtp-server=SERVER",
              "Use SERVER as SMTP server",
              "(#{@smtp_server})") do |smtp_server|
        @smtp_server = smtp_server
      end

      opts.on("--smtp-port=PORT", Integer,
              "Use PORT as SMTP port",
              "(#{@smtp_port})") do |smtp_port|
        @smtp_port = smtp_port
      end

      opts.on("--helo-fqdn=FQDN",
              "Use FQDN for SMTP HELO command as default value.",
              "(#{@helo_fqdn})") do |helo_fqdn|
        @helo_fqdn = helo_fqdn
      end

      opts.on("--from=FROM",
              "Use FROM as envelope from address",
              "on SMTP MAIL command as default value.",
              "(#{@from})") do |from|
        @from = from
      end

      opts.on("--recipient=RECIPIENT",
              "Use RECIPIENT as envelope recipient address",
              "on SMTP RCPT command as default value.",
              "This option can be used n-times to set multi recipients.",
              "(#{@default_recipients.inspect})") do |recipient|
        @recipients << recipient
      end

      opts.on("--force-from=FROM",
              "Ensure using FROM as envelope from address on SMTP MAIL command.",
              "(#{@force_from})") do |from|
        @force_from = from
      end

      opts.on("--force-recipient=RECIPIENT",
              "Ensure using RECIPIENT as envelope recipient address",
              "on SMTP RCPT command as default value.",
              "This option can be used n-times to set multi recipients.",
              "(#{@force_recipients.inspect})") do |recipient|
        @force_recipients ||= []
        @force_recipients << recipient
      end

      opts.on("--n-mails=N", Integer,
              "Send a test mail N times",
              "This option is ignored when mail files are specified",
              "(#{@n_mails})") do |n_mails|
        @n_mails = n_mails
      end

      opts.on("--period=PERIOD",
              "Send mail files on average in PERIOD seconds/minutes/hours",
              "e.g.: 5s, 5m, 1.5h and so on. Default is seconds",
              "conflict option: --interval",
              "(#{@period})") do |period|
        if @interval
          raise OptionParser::InvalidOption, "can't use with --interval"
        end
        @period = parse_period(period)
      end

      opts.on("--interval=INTERVAL",
              "Send mail files at intervals of INTERVAL seconds/minutes/hours",
              "e.g.: 5s, 5m, 1.5h and so on. Default is seconds",
              "conflict option: --period",
              "(#{@interval})") do |interval|
        if @period
          raise OptionParser::InvalidOption, "can't use with --period"
        end
        @interval = parse_period(interval)
      end

      opts.on("--[no-]shuffle",
              "Shuffle target mails",
              "(#{@shuffle})") do |shuffle|
        @shuffle = shuffle
      end
    end
    @mails = parser.parse!(argv)
    @recipients = @default_recipients if @recipients.empty?
  end

  # Sends every mail and blocks until all sender threads finish (or the user
  # interrupts with Ctrl-C), then records the wall-clock time of the run.
  def run
    Thread.abort_on_exception = true
    @sources = []
    @rejected_sources = []
    @temporary_failure_sources = []
    @accumulated_elapsed_time = 0
    @max_elapsed_time = 0
    @min_elapsed_time = nil
    @n_temporary_failure_mails = 0
    @n_reject_mails = 0
    @n_error_mails = 0
    @n_processed_mails = 0
    mails = expand_mails(@mails)
    # Zero or one source: repeat it @n_mails times. A nil entry means
    # "generate a test mail".
    mails = [mails[0]] * @n_mails if mails.size <= 1
    mails = mails.sort_by {rand} if @shuffle

    start_time = Time.now
    begin
      if @interval
        threads = send_mails_in_interval(mails, @interval)
      elsif @period
        threads = send_mails_in_period(mails)
      else
        threads = send_mails_in_parallel(mails)
      end
      threads.each do |thread|
        thread.join
      end
    rescue Interrupt
    end
    @elapsed_time = Time.now - start_time
  end

  # Prints the 10 slowest mail files, then the rejected and temporarily
  # failed mail files (each section only when non-empty).
  def report_sources
    unless @sources.empty?
      slowest = @sources.sort_by {|elapsed_time, _source_file| -elapsed_time}
      slowest.first(10).each do |elapsed_time, source_file|
        puts "%#5.2f (sec): %s" % [elapsed_time, source_file]
      end
      puts
    end

    [
     ["Rejected mails:", @rejected_sources],
     ["Temporary failed mails:", @temporary_failure_sources],
    ].each do |label, sources|
      next if sources.empty?
      puts label
      sources.each do |source_file|
        puts source_file
      end
      puts
    end
  end

  # Prints the summary: totals, per-mail timings and status counters.
  def report
    report_sources
    puts "Total:"
    puts "     N mails: %d" % @n_processed_mails
    puts "Elapsed time: %#5.2f (sec)" % @elapsed_time
    puts "     Average: %#5.2f (mails/sec)" % average(@elapsed_time)
    puts
    puts "Per mail:"
    puts "         Max: %#5.2f (sec)" % @max_elapsed_time
    puts "         Min: %#5.2f (sec)" % (@min_elapsed_time || 0)
    puts "     Average: %#5.2f (sec)" % average_per_mail(@accumulated_elapsed_time)
    puts
    puts "Status:"
    puts statistics_report("Temporary failure", @n_temporary_failure_mails)
    puts statistics_report("           Reject", @n_reject_mails)
    puts statistics_report("            Error", @n_error_mails)
  end

  private
  # Expands files/directories in +mails+ into a sorted list of regular
  # files. "-" (when no file of that name exists) is kept as-is and means
  # "read the mail from stdin".
  def expand_mails(mails)
    expanded_mails = []
    mails.each do |mail|
      if mail == "-" && !File.exist?(mail)
        expanded_mails << mail
        next
      end
      Find.find(mail) do |file|
        expanded_mails << file if File.file?(file)
      end
    end
    expanded_mails.sort
  end

  # Starts one sender thread per mail, sleeping +interval+ seconds between
  # starts (not after the last one). Returns the threads.
  def send_mails_in_interval(mails, interval)
    i = 0
    last = mails.size
    mails.collect do |mail|
      i += 1
      thread = Thread.start do
        send_mail(mail)
      end
      sleep(interval) if interval > 0 and i != last
      thread
    end
  end

  # Spreads the mails evenly over @period seconds.
  def send_mails_in_period(mails)
    send_mails_in_interval(mails, @period / mails.size)
  end

  # Starts one sender thread per mail at once. Returns the threads.
  def send_mails_in_parallel(mails)
    mails.collect do |mail|
      Thread.start do
        send_mail(mail)
      end
    end
  end

  # Loads (or generates, when +mail_source_file+ is nil) the mail source and
  # works out the SMTP envelope for it.
  #
  # Returns [helo_fqdn, from, recipients, mail_source]. Values come from the
  # mail headers, then the configured defaults; --force-* options win.
  def prepare_send_mail(mail_source_file)
    if mail_source_file.nil?
      mail_source = generate_mail_source
    elsif mail_source_file == "-" && !File.exist?(mail_source_file)
      mail_source = read_mails_source_from_stdin
    else
      mail_source = File.read(mail_source_file)
    end
    helo_fqdn, from, recipients = parse_mail_source(mail_source)
    helo_fqdn ||= @helo_fqdn
    from ||= @from
    recipients ||= @recipients

    from = @force_from || from
    recipients = @force_recipients || recipients

    [helo_fqdn, from, recipients, mail_source]
  end

  # Sends one mail over SMTP, timing the whole transaction, and records the
  # outcome in the shared statistics.
  def send_mail(mail_source_file=nil)
    helo_fqdn, from, recipients, source = prepare_send_mail(mail_source_file)

    temporary_failure = false
    reject = false
    error = false
    error_object = nil
    start_time = Time.now
    begin
      Net::SMTP.start(@smtp_server, @smtp_port, helo_fqdn) do |smtp|
        smtp.send_mail(source, from, *recipients)
      end
    rescue Net::SMTPServerBusy
      temporary_failure = true
    rescue Net::SMTPFatalError
      reject = true
    rescue Net::ProtoError, Timeout::Error, SystemCallError, IOError => e
      # Thread.abort_on_exception is enabled in #run, so anything not
      # rescued here (e.g. ECONNREFUSED) would abort the whole benchmark
      # instead of being counted as an error.
      error = true
      error_object = e
    end
    elapsed_time = Time.now - start_time
    puts "#{error_object.class}: #{error_object.message}" if error_object
    update_statistics(mail_source_file,
                      elapsed_time, temporary_failure, reject, error)
  end

  # Thread-safely folds one mail's result into the shared counters.
  def update_statistics(source_file,
                        elapsed_time, temporary_failure, reject, error)
    @mutex.synchronize do
      if source_file
        @sources << [elapsed_time, source_file]
        @rejected_sources << source_file if reject
        @temporary_failure_sources << source_file if temporary_failure
      end
      @accumulated_elapsed_time += elapsed_time
      @n_processed_mails += 1
      @max_elapsed_time = [elapsed_time, @max_elapsed_time].max
      @min_elapsed_time = [elapsed_time, @min_elapsed_time || elapsed_time].min
      @n_temporary_failure_mails += 1 if temporary_failure
      @n_reject_mails += 1 if reject
      @n_error_mails += 1 if error
    end
  end

  # Returns the mail read from stdin. Stdin is read exactly once, under the
  # mutex, because several sender threads may ask concurrently. The first
  # caller gets the mail as-is; later callers get a copy with a fresh
  # Message-Id.
  def read_mails_source_from_stdin
    first_read = false
    source = @mutex.synchronize do
      if @mail_source_from_stdin.nil?
        @mail_source_from_stdin = $stdin.read
        first_read = true
      end
      @mail_source_from_stdin
    end
    return source if first_read
    source.gsub(/^Message-Id:.*$/i) do
      "Message-Id: <#{random_tag}@mail.example.com>"
    end
  end

  # Extracts [helo_fqdn, from, recipients] from the mail headers. Each value
  # is nil when it can't be determined. The HELO FQDN comes from the first
  # Received: header; recipients come from To: and Cc:.
  def parse_mail_source(source)
    header_part, _body_part = source.split(/(?:\r?\n){2}/, 2)
    header_part ||= ""
    _, *names_and_values = header_part.split(/^([a-z][a-z\-]+):\s*/i)
    headers = {}
    received = []
    until names_and_values.empty?
      name = names_and_values.shift
      # String#split drops a trailing empty value, so this can be nil.
      value = names_and_values.shift.to_s
      # Unfold continuation lines into a single line.
      value = value.chomp.gsub(/(?:\r?\n)\s*/, ' ')
      received << value if name == "Received"
      headers[name] = value
    end
    helo_fqdn = extract_helo_fqdn_from_received(received[0])
    from = extract_mail_address(headers["From"])
    recipients = parse_recipient_header(headers["To"])
    recipients += parse_recipient_header(headers["Cc"])
    recipients = recipients.collect do |recipient|
      extract_mail_address(recipient)
    end
    recipients = nil if recipients.empty?
    [helo_fqdn, from, recipients]
  end

  # Splits a comma-separated recipient header; [] for nil.
  def parse_recipient_header(header_value)
    return [] if header_value.nil?
    header_value.split(/\s*,\s*/)
  end

  # Returns the host name after a leading "from " in a Received: value.
  def extract_helo_fqdn_from_received(received)
    return nil if received.nil?
    if /\Afrom ([a-z.]+)/i =~ received
      $1
    else
      nil
    end
  end

  # Returns the address inside <...>, or the value with (comments) removed.
  def extract_mail_address(address)
    return nil if address.nil?
    if /<(.+?)>/ =~ address
      $1
    else
      address.gsub(/\(.*?\)/, '').strip
    end
  end

  # Converts "5", "5s", "1.5m", "2h", ... into seconds (Float).
  # Raises OptionParser::InvalidArgument for anything else.
  def parse_period(period)
    if /\A(\d+(?:\.\d+)?)(s|sec|seconds?|m|min|minutes?|h|hours?)?\z/i =~ period
      numeric = $1
      unit = $2
      numeric = Float(numeric)
      unit ||= "seconds"
      case unit.downcase[0]
      when ?s
        numeric
      when ?m
        numeric * 60
      when ?h
        numeric * 60 * 60
      else
        raise OptionParser::InvalidArgument, "invalid period unit"
      end
    else
      raise OptionParser::InvalidArgument, "invalid period format"
    end
  end

  # Builds a minimal plain-text test mail from the configured HELO FQDN,
  # sender and recipients. Received: continuation lines start with a tab.
  def generate_mail_source
    now = Time.now.rfc2822
    <<-EOM
Return-Path: <#{@from}>
Received: from #{@helo_fqdn} (#{@helo_fqdn} [192.168.1.1])
\tby mail.example.com with ESMTP id #{generate_id};
\t#{now}
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Type: text/plain; charset=US-ASCII
X-Mailer: milter-performance-check
Message-Id: <#{random_tag}@mail.example.com>
Subject: test mail
From: #{@from}
To: #{@recipients.join(', ')}
Date: #{now}

Hello,

This is a test mail.
EOM
  end

  # Returns a random 10-character alphanumeric ID.
  def generate_id
    characters = ("0".."9").to_a + ("a".."z").to_a + ("A".."Z").to_a
    length = 10
    id = ""
    length.times do
      id << characters[rand(characters.size)]
    end
    id
  end

  # Returns a pseudo-random hex tag for Message-Id values.
  def random_tag
    Digest::SHA2.hexdigest("#{Time.now.to_i.to_s}.#{rand(Time.now.to_i)}")
  end

  # Mails per second over +elapsed_time+ (the raw count when it is zero).
  def average(elapsed_time)
    if elapsed_time.zero?
      @n_processed_mails
    else
      @n_processed_mails / elapsed_time.to_f
    end
  end

  # Seconds per mail (the raw time when no mail was processed).
  def average_per_mail(elapsed_time)
    if @n_processed_mails.zero?
      elapsed_time.to_f
    else
      elapsed_time.to_f / @n_processed_mails
    end
  end

  # Formats "label: count (percentage%)". Reports 0% instead of NaN when no
  # mail has been processed.
  def statistics_report(label, n_mails)
    if @n_processed_mails.zero?
      percentage = 0.0
    else
      percentage = n_mails.to_f / @n_processed_mails * 100
    end
    " #{label}: %3d (%5.2f%%)" % [n_mails, percentage]
  end
end

# Script entry point: parse ARGV, send the mails, then print the summary.
if $0 == __FILE__
  tool = MilterPerformanceTool.new
  tool.parse_options(ARGV)
  tool.run
  tool.report
end

# vi:ts=2:nowrap:ai:expandtab:sw=2
