/*  propagator.cpp
 *
 *  Copyright (C) 2010-2012 Andreas von Manteuffel
 *  Copyright (C) 2010-2012 Cedric Studerus
 *
 *  This file is part of the package Reduze 2.
 *  It is distributed under the GNU General Public License version 3
 *  (see the file GPL-3.0.txt or http://www.gnu.org/licenses/gpl-3.0.txt).
 */
#include "propagator.h"
#include "functions.h"
#include "ginacutils.h"
#include <exception>
#include "files.h"
#include "globalsymbols.h"
#include "integralfamily.h"
#include "kinematics.h"
#include "yamlutils.h"

// Register Propagator and ScalarProduct with GiNaC's class registry. The
// print_func options select the printing method per print context:
// do_print_dflt for generic GiNaC::print_context output and do_print_mma for
// Reduze::print_mma output.
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(Propagator, GiNaC::basic,
		print_func<GiNaC::print_context>(&Propagator::do_print_dflt).
		print_func<Reduze::print_mma>(&Propagator::do_print_mma))
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(ScalarProduct, GiNaC::basic,
		print_func<GiNaC::print_context>(&ScalarProduct::do_print_dflt).
		print_func<Reduze::print_mma>(&ScalarProduct::do_print_mma))

using namespace GiNaC;
using namespace std;
using namespace Reduze;

/// standard propagator 1/(k^2 - m2): both momentum slots hold 'momentum'
Propagator::Propagator(const GiNaC::ex& momentum, const GiNaC::ex& squaredmass)
		: momentum1_(momentum),
		  momentum2_(momentum),
		  squaredmass_(squaredmass) {
}

/// generalised (bilinear) propagator 1/(k1.k2 - sqm)
Propagator::Propagator(const GiNaC::ex& k1, const GiNaC::ex& k2,
		const GiNaC::ex& sqm)
		: momentum1_(k1),
		  momentum2_(k2),
		  squaredmass_(sqm) {
}

/// eval hook of the symbolic function Prop(k, m2): yields a standard Propagator
static ex Prop_eval(const ex& mom, const ex& sqmass) {
	return ex(Propagator(mom, sqmass));
}

/// eval hook of the symbolic function GProp(k1, k2, m2): yields a bilinear Propagator
static ex GProp_eval(const ex& mom1, const ex& mom2, const ex& sqmass) {
	return ex(Propagator(mom1, mom2, sqmass));
}

/// eval hook of the symbolic function SP(a, b): yields a ScalarProduct
static ex SP_eval(const ex& lhs, const ex& rhs) {
	return ex(ScalarProduct(lhs, rhs));
}

// GiNaC symbolic functions Prop, GProp and SP; their eval functions (defined
// above) replace them by Propagator and ScalarProduct objects.
REGISTER_FUNCTION(Prop, eval_func(Prop_eval))
REGISTER_FUNCTION(GProp, eval_func(GProp_eval))
REGISTER_FUNCTION(SP, eval_func(SP_eval))

/// default constructor (required by GiNaC's class registration); all
/// operands are default-constructed GiNaC expressions
Propagator::Propagator()
		: momentum1_(), momentum2_(), squaredmass_() {
}

/// destructor: the GiNaC::ex members release their references themselves
Propagator::~Propagator() {
}

/// true if both momentum slots hold the same expression, i.e. 1/(k^2 - m^2)
bool Propagator::has_squared_momentum() const {
	const bool same_momenta = momentum1_.is_equal(momentum2_);
	return same_momenta;
}

/// the unique momentum k of a standard propagator
/** throws std::runtime_error for a bilinear propagator (k1 != k2) */
GiNaC::ex Propagator::momentum() const {
	if (has_squared_momentum())
		return momentum1_;
	throw std::runtime_error("propagator has no uniquely defined momentum");
}

/// first momentum slot k1 of 1/(k1.k2 - m^2)
GiNaC::ex Propagator::momentum1() const {
	const ex& k1 = momentum1_;
	return k1;
}

/// second momentum slot k2 of 1/(k1.k2 - m^2)
GiNaC::ex Propagator::momentum2() const {
	const ex& k2 = momentum2_;
	return k2;
}

bool has_overall_negative_sign(const ex& e) {
	if (is_a < mul > (e)) {
		for (size_t i = 0; i < e.nops(); ++i)
			if (e.op(i).info(GiNaC::info_flags::negative)) {
				return true;
			}
	}
	return false;
}

/// deterministically picks one of {e1,e2}; meant to select the "simpler" sign
/// when e2 == -e1
/** The choice is symmetric in e1 and e2. Rules, in order of priority:
 **  1. prefer the expression without an overall negative sign,
 **  2. for two products or two sums: prefer fewer operands,
 **  3. for two sums of equal length: prefer fewer negative terms, then a
 **     non-negative first term,
 **  4. otherwise prefer the one that comes first in GiNaC's canonical order.
 ** If is_a<basic>(e1) and e2 == -e1, e1 is returned (relevant for wildcards). */
ex pick_simpler_sign(const ex& e1, const ex& e2) {
	const bool neg1 = has_overall_negative_sign(e1);
	const bool neg2 = has_overall_negative_sign(e2);
	if (neg1 != neg2)
		return neg1 ? e2 : e1;

	const bool both_mul = is_a<mul>(e1) && is_a<mul>(e2);
	const bool both_add = is_a<add>(e1) && is_a<add>(e2);
	const size_t len1 = e1.nops();
	const size_t len2 = e2.nops();
	if ((both_mul || both_add) && len1 != len2)
		return len1 < len2 ? e1 : e2;

	// reaching here with two sums implies len1 == len2
	if (both_add && len1 > 0) {
		int negterms1 = 0;
		int negterms2 = 0;
		for (size_t t = 0; t < len1; ++t) {
			if (has_overall_negative_sign(e1.op(t)))
				++negterms1;
			if (has_overall_negative_sign(e2.op(t)))
				++negterms2;
		}
		if (negterms1 != negterms2)
			return negterms1 < negterms2 ? e1 : e2;
		const bool leading_neg1 = has_overall_negative_sign(e1.op(0));
		const bool leading_neg2 = has_overall_negative_sign(e2.op(0));
		if (leading_neg1 != leading_neg2)
			return leading_neg1 ? e2 : e1;
	}

	if (e2.compare(e1) < 0)
		return e2;
	return e1;
}

/// canonical form of the propagator w.r.t. the signs and order of its momenta
/** Uses that k1.k2 is invariant under (k1,k2) -> (-k1,-k2) and under
 ** (k1,k2) -> (k2,k1): among k1, -k1, k2, -k2 the candidate preferred by
 ** pick_simpler_sign() (and, between the two remaining candidates, by GiNaC's
 ** canonical order) is moved to the first momentum slot.
 ** The object itself is returned (held) only if it is already canonical and
 ** no operand changed during evaluation. */
#ifdef NEW_GINAC
ex Propagator::eval() const {
	// ex subexpressions are automatically evaluated in GiNaC >= 1.7
	ex k1 = momentum1_;
	ex k1m = -momentum1_;
	ex k2 = momentum2_;
	ex k2m = -momentum2_;
	ex msq = squaredmass_;
#else
ex Propagator::eval(int level) const {
	// eval subexpressions first; level 1 means top level only, which must
	// apply to all operands alike
	ex k1 = (level != 1 ? momentum1_.eval(level - 1) : momentum1_);
	ex k1m = (level != 1 ? (-momentum1_).eval(level - 1) : -momentum1_);
	ex k2 = (level != 1 ? momentum2_.eval(level - 1) : momentum2_);
	ex k2m = (level != 1 ? (-momentum2_).eval(level - 1) : -momentum2_);
	ex msq = (level != 1 ? squaredmass_.eval(level - 1) : squaredmass_);
#endif
	// note: be careful with swapping signs such that wildcards still work
	ex k1min = pick_simpler_sign(k1, k1m);
	ex k2min = pick_simpler_sign(k2, k2m);
	ex kmin = (k1min.compare(k2min) <= 0 ? k1min : k2min);
	// returning *this is only correct if evaluation left every operand
	// unchanged, otherwise evaluated operands would be dropped
	if (kmin.is_equal(momentum1_) && k2.is_equal(momentum2_)
			&& msq.is_equal(squaredmass_))
		return this->hold();
	else if (kmin.is_equal(k1))
		return Propagator(k1, k2, msq).hold();
	else if (kmin.is_equal(k1m))
		return Propagator(k1m, k2m, msq).hold(); // swap signs of both momenta
	else if (kmin.is_equal(k2))
		return Propagator(k2, k1, msq).hold();
	else
		// (kmin.is_equal(k2m))
		return Propagator(k2m, k1m, msq).hold(); // swap signs of both momenta
}

/// lexicographic comparison of (momentum1, momentum2, squared mass)
int Propagator::compare_same_type(const GiNaC::basic& other) const {
	// GiNaC only calls this for objects of the same type, so the cast is safe
	const Propagator& rhs = static_cast<const Propagator&>(other);
	const ex* const mine[] = { &momentum1_, &momentum2_, &squaredmass_ };
	const ex* const theirs[] = { &rhs.momentum1_, &rhs.momentum2_,
			&rhs.squaredmass_ };
	for (size_t slot = 0; slot < 3; ++slot) {
		if (!mine[slot]->is_equal(*theirs[slot]))
			return mine[slot]->compare(*theirs[slot]);
	}
	return 0;
}

/// q.d/dk applied to the propagator 1/(k1.k2 - m^2)
/** d/dk (1/D) = -(1/D)^2 dD/dk with dD/dk = d(k1*k2)/dk; the resulting
 ** derivative vector is contracted with q via a ScalarProduct. */
GiNaC::ex Propagator::derivative_contracted(const GiNaC::symbol& q,
		const GiNaC::symbol& k) const {
	const ex inner_derivative = (momentum1_ * momentum2_).diff(k);
	const ex prop_squared = pow(*this, 2);
	return -prop_squared * ScalarProduct(q, inner_derivative);
}

/// decompose mom = mom_loop + mom_ext such that mom_loop is free of external
/// momenta and mom_ext is free of loop momenta of the family 'fam'
/** Sums are split term by term. Every non-sum term must depend either on loop
 ** momenta only or on external momenta only; a zero momentum decomposes
 ** trivially into 0 + 0.
 ** Throws std::runtime_error if a term cannot be assigned to either part. */
void decompose(const GiNaC::ex& mom, GiNaC::ex& mom_loop, GiNaC::ex& mom_ext,
		const Reduze::IntegralFamily* fam) {
	mom_loop = mom_ext = 0;
	if (mom.is_zero())
		return; // 0 = 0 + 0, nothing to assign
	if (is_a<add>(mom)) {
		for (GiNaC::const_iterator i = mom.begin(); i != mom.end(); ++i) {
			ex ml, me;
			decompose(*i, ml, me, fam);
			mom_loop += ml;
			mom_ext += me;
		}
		return;
	}
	const bool has_loop = !Reduze::freeof(mom, fam->loop_momenta());
	const bool has_ext = !Reduze::freeof(mom,
			fam->kinematics()->external_momenta());
	if (has_ext && !has_loop)
		mom_ext = mom;
	else if (has_loop && !has_ext)
		mom_loop = mom;
	else
		throw std::runtime_error(
				std::string("Failed to decompose momentum ")
						+ Reduze::to_string(mom));
}

/// rewrites the propagator using an auxiliary mass M (exact identity)
/** With mom = k + q (k: loop part, q: external part, see decompose()) and the
 ** squared mass m^2 of this propagator:
 **   1/((k+q)^2 - m^2) = 1/(k^2 - M^2)
 **       + (m^2 - M^2 - 2 k.q - q^2) / ((k^2 - M^2) ((k+q)^2 - m^2))
 ** Throws std::runtime_error for a bilinear propagator (no unique momentum)
 ** or if the momentum cannot be decomposed. */
GiNaC::ex Propagator::mass_expand(const Reduze::IntegralFamily* fam,
		const GiNaC::ex& M) const {
	if (!has_squared_momentum())
		throw std::runtime_error("error: no unique momentum");
	ex mom = momentum1_.expand();
	ex k(0), q(0);
	decompose(mom, k, q, fam);
	ex m2 = squaredmass_;
	ex M2 = GiNaC::power(M, 2);
	ex res1 = Propagator(k, M2);
	// both factors 1/(k^2 - M^2) carry the squared mass M^2, not M
	ex res2 = (m2 - M2 - 2 * ScalarProduct(k, q) - ScalarProduct(q, q))
			* Propagator(k, M2) * Propagator(k + q, m2);
	return res1 + res2;
}

/// the denominator k1.k2 - m^2 of the propagator
GiNaC::ex Propagator::inverse_ex() const {
	const ex sp = ScalarProduct(momentum1_, momentum2_);
	return sp - squaredmass_;
}

/// operands: 0 = momentum1, 1 = momentum2, 2 = squared mass
size_t Propagator::nops() const {
	const size_t num_operands = 3;
	return num_operands;
}

/// read access to operand i (see nops()); throws std::range_error if i > 2
GiNaC::ex Propagator::op(size_t i) const {
	if (i == 0)
		return momentum1_;
	if (i == 1)
		return momentum2_;
	if (i == 2)
		return squaredmass_;
	throw std::range_error("Invalid operand");
}

/// write access to operand i (see nops()); throws std::range_error if i > 2
GiNaC::ex& Propagator::let_op(size_t i) {
	if (i == 0)
		return momentum1_;
	if (i == 1)
		return momentum2_;
	if (i == 2)
		return squaredmass_;
	throw std::range_error("Invalid operand");
}

/// default output: Prop(k,m2) for standard, GProp(k1,k2,m2) for bilinear
void Propagator::do_print_dflt(const GiNaC::print_context & c,
		unsigned level) const {
	const bool standard = has_squared_momentum();
	c.s << (standard ? "Prop(" : "GProp(") << momentum1_ << ',';
	if (!standard)
		c.s << momentum2_ << ',';
	c.s << squaredmass_ << ')';
}

/// Mathematica-style output: Prop[k,m2] or GProp[k1,k2,m2]
void Propagator::do_print_mma(const Reduze::print_mma & c,
		unsigned level) const {
	const bool standard = has_squared_momentum();
	c.s << (standard ? "Prop[" : "GProp[") << momentum1_ << ',';
	if (!standard)
		c.s << momentum2_ << ',';
	c.s << squaredmass_ << ']';
}

// YAML I/O for Propagator

/// reads a propagator from a YAML node
/** Accepted forms:
 **   [k, m2]                      short form of a standard propagator
 **   {standard: [k, m2]}          standard propagator with k^2
 **   {bilinear: [[k1, k2], m2]}   generalised propagator with k1.k2
 ** Momenta are parsed w.r.t. the external momenta of 'kin' plus 'loopmoms',
 ** and the momentum conservation rule of 'kin' is applied to them; the squared
 ** mass is parsed w.r.t. the kinematic invariants of 'kin'.
 ** Throws std::runtime_error (with position info) on malformed entries. */
Propagator Propagator::read(const Reduze::Kinematics*kin,
		const GiNaC::lst& loopmoms, const YAML::Node& n) {
	lst allmoms = add_lst(kin->external_momenta(), loopmoms);
	const GiNaC::exmap& mom_conserv = kin->rule_momentum_conservation();

	// determine the propagator type and the node holding its data
	const YAML::Node* np = 0;
	std::string type;
	if (n.Type() == YAML::NodeType::Sequence && n.size() == 2) {
		type = "standard"; // short form [k, m2]
		np = &n;
	} else if (n.Type() == YAML::NodeType::Map && n.size() == 1) {
		n.begin().first() >> type;
		np = &n.begin().second();
	} else {
		throw runtime_error("unknown entry in propagators " + position_info(n));
	}
	if (type != "standard" && type != "bilinear")
		throw runtime_error(
				"unknown propagator type '" + type + "' " + position_info(n));
	// both types expect a pair (momentum data, squared mass); validate the
	// shape here so malformed input is reported with position information
	if (np->Type() != YAML::NodeType::Sequence || np->size() != 2)
		throw runtime_error(
				"propagator of type '" + type
						+ "' must be a list of two entries " + position_info(n));

	string str;
	ex k1, k2;
	if (type == "standard") { // standard propagator with k^2
		(*np)[0] >> str;
		k1 = k2 = ex(str, allmoms).subs(mom_conserv);
	} else { // "bilinear": generalised propagator with k1.k2
		const YAML::Node& moms = (*np)[0];
		if (moms.Type() != YAML::NodeType::Sequence || moms.size() != 2)
			throw runtime_error(
					"bilinear propagator needs a list of two momenta "
							+ position_info(moms));
		moms[0] >> str;
		k1 = ex(str, allmoms).subs(mom_conserv);
		moms[1] >> str;
		k2 = ex(str, allmoms).subs(mom_conserv);
	}
	(*np)[1] >> str;
	const ex msq(str, kin->kinematic_invariants());
	return Propagator(k1, k2, msq);
}

/// writes the propagator as a YAML flow map in the format read() accepts:
/// {standard: [k, m2]} if both momenta coincide, else {bilinear: [[k1, k2], m2]}
void Propagator::print(YAML::Emitter& os) const {
	using namespace YAML;
	// a propagator is either standard or bilinear, so a plain if/else covers
	// every case (the former third branch was unreachable)
	os << Flow << BeginMap << Key;
	if (has_squared_momentum()) {
		os << "standard" << Value;
		os << BeginSeq << momentum1_ << squaredmass_ << EndSeq;
	} else {
		os << "bilinear" << Value;
		os << BeginSeq << BeginSeq << momentum1_ << momentum2_ << EndSeq;
		os << squaredmass_ << EndSeq;
	}
	os << EndMap;
}

// ScalarProduct

/// default constructor (required by GiNaC's class registration); both
/// operands are default-constructed GiNaC expressions
ScalarProduct::ScalarProduct()
		: left_(), right_() {
}

/// scalar product (left, right) of two momentum expressions
ScalarProduct::ScalarProduct(const GiNaC::ex& left, const GiNaC::ex& right)
		: left_(left),
		  right_(right) {
}

/// returns expr divided by numerical factors, prefactor is multiplied by them
/** considers numerical factors only if expr is a GiNaC::mul and numerical
 ** factors appear as direct operands of this mul **/
static ex without_numeric_factor(const ex& expr, ex& prefactor) {
	if (!is_a < mul > (expr))
		return expr;
	ex num_symbs = Reduze::Files::instance()->globalsymbols()->all();
	ex result = 1;
	for (size_t i = 0; i < expr.nops(); ++i)
		if (is_a < numeric > (expr.op(i)))
			prefactor *= expr.op(i);
		else if (is_a < GiNaC::power > (expr.op(i))
				&& is_a < symbol > (expr.op(i).op(0))
				&& num_symbs.has(expr.op(i).op(0))
				&& is_a < numeric > (expr.op(i).op(1)))
			prefactor *= expr.op(i);
		else if (is_a < symbol > (expr.op(i)) && num_symbs.has(expr.op(i)))
			prefactor *= expr.op(i);
		else
			result *= expr.op(i);
	return result;
}

/// canonical form of a scalar product
/** Pulls numerical factors (see without_numeric_factor()) out of both
 ** arguments, orders the arguments by GiNaC's compare() and returns 0 if an
 ** argument vanishes. The object itself is returned (held) if nothing changed;
 ** otherwise prefactor * SP(l, r) with the canonical arguments is returned. */
#ifdef NEW_GINAC
GiNaC::ex ScalarProduct::eval() const {
    // subexpressions are automatically evaluated in GiNaC >= 1.7
    ex l = left_;
    ex r = right_;
#else
GiNaC::ex ScalarProduct::eval(int level) const {
    // eval subexpressions first (level 1: top level only)
    ex l = (level != 1 ? left_.eval(level - 1) : left_);
    ex r = (level != 1 ? right_.eval(level - 1) : right_);
#endif
	// bilinearity: (l,number*x) -> number*(l,x),  (number*x,r) -> number*(x,r)
	ex prefac = 1;
	l = without_numeric_factor(l, prefac);
	r = without_numeric_factor(r, prefac);
	// commutativity: (l,r) -> (r,l) if not in canonical order
	if (l.compare(r) > 0)
		swap(l, r);
	if (l.is_zero() || r.is_zero())
		return 0;
	else if (l.is_equal(left_) && r.is_equal(right_))
		// unchanged arguments imply no factor was stripped (prefac == 1)
		return this->hold();
	else
		return prefac * ScalarProduct(l, r).hold();
}

/// expands both arguments and distributes the scalar product over sums:
/// SP(a1 + a2 + ..., b1 + b2 + ...) -> sum_ij SP(ai, bj), each term evaluated
GiNaC::ex ScalarProduct::expand(unsigned options) const {
	const ex lhs = left_.expand(options);
	const ex rhs = right_.expand(options);
	// a non-sum argument is treated as a sum with a single term
	const bool lhs_is_sum = is_a<add>(lhs);
	const bool rhs_is_sum = is_a<add>(rhs);
	const size_t n_lhs = lhs_is_sum ? lhs.nops() : 1;
	const size_t n_rhs = rhs_is_sum ? rhs.nops() : 1;
	ex sum = 0;
	for (size_t a = 0; a < n_lhs; ++a) {
		const ex lterm = lhs_is_sum ? lhs.op(a) : lhs;
		for (size_t b = 0; b < n_rhs; ++b) {
			const ex rterm = rhs_is_sum ? rhs.op(b) : rhs;
			sum += ScalarProduct(lterm, rterm).eval();
		}
	}
	return sum;
}

/// lexicographic comparison of (left, right)
int ScalarProduct::compare_same_type(const GiNaC::basic & other) const {
	// GiNaC only calls this for objects of the same type, so the cast is safe
	const ScalarProduct& rhs = static_cast<const ScalarProduct&>(other);
	if (left_.is_equal(rhs.left_))
		return right_.compare(rhs.right_);
	return left_.compare(rhs.left_);
}

/// operands: 0 = left argument, 1 = right argument
size_t ScalarProduct::nops() const {
	const size_t num_operands = 2;
	return num_operands;
}

/// read access to operand i (see nops()); throws std::range_error if i > 1
GiNaC::ex ScalarProduct::op(size_t i) const {
	if (i == 0)
		return left_;
	if (i == 1)
		return right_;
	throw std::range_error("Invalid operand");
}

/// write access to operand i (see nops()); throws std::range_error if i > 1
GiNaC::ex& ScalarProduct::let_op(size_t i) {
	if (i == 0)
		return left_;
	if (i == 1)
		return right_;
	throw std::range_error("Invalid operand");
}

/// default output: SP(left,right)
void ScalarProduct::do_print_dflt(const GiNaC::print_context & c,
		unsigned level) const {
	c.s << "SP(" << left_ << ',' << right_ << ')';
}

/// Mathematica-style output: SP[left,right]
void ScalarProduct::do_print_mma(const Reduze::print_mma & c,
		unsigned level) const {
	c.s << "SP[" << left_ << ',' << right_ << ']';
}
