/*  crossing.cpp
 *
 *  Copyright (C) 2010-2012 Andreas von Manteuffel
 *  Copyright (C) 2010-2012 Cedric Studerus
 *
 *  This file is part of the package Reduze 2.
 *  It is distributed under the GNU General Public License version 3
 *  (see the file GPL-3.0.txt or http://www.gnu.org/licenses/gpl-3.0.txt).
 */

#include "crossing.h"
#include "files.h"
#include "functions.h"
#include "ginacutils.h"
#include "yamlutils.h"
#include "integralfamily.h"
#include "equation.h"
#include "kinematics.h"
#include "sector.h"

using namespace std;
using namespace GiNaC;

namespace Reduze {

// Register the Crossing type for YAML handling: a single file-local
// YAMLProxy<Crossing> object is created during static initialization
// (presumably YAMLProxy's constructor performs the registration).
namespace {
YAMLProxy<Crossing> crossing_yaml_proxy_registration;
}

// Identity crossing for the kinematics kin: the permutation, the name and the
// substitution rules stay default-constructed (a default Permutation is the
// identity), so the crossing is trivially equivalent to the identity and has
// no rule whose lhs is the symbol replaced by one.
Crossing::Crossing(const Kinematics* kin)
	: kinematics_(kin),
	  is_equivalent_to_identity_(true),
	  has_symb2one_in_lhs_(false) {
}

// Constructs the crossing defined by a permutation of the external legs.
// The permutation must transform the legs which are labeled from {1, ... n},
// where n is the number of external momenta of the kinematics.
//
// From the permutation this derives
//  - name_:             "x" per cycle, followed by the leg numbers of the cycle,
//  - rules_momenta_:    substitution rules for the independent external momenta,
//  - rules_invariants_: substitution rules for those kinematic invariants that
//                       are changed by the crossing (unchanged ones are omitted),
// and sets has_symb2one_in_lhs_ and is_equivalent_to_identity_.
//
// Throws runtime_error if a leg number lies outside [1, n], if two legs mapped
// onto each other have different squared momenta, or if the derived rules are
// inconsistent with the scalar product -> invariant rules of the kinematics
// (e.g. when numbers are used for kinematic invariants).
Crossing::Crossing(const Kinematics* kin, const Permutation& perm) :
	kinematics_(kin), permutation_(perm), is_equivalent_to_identity_(true),
			has_symb2one_in_lhs_(false) {

	const unsigned no_pat = subs_options::no_pattern;
	const lst& em = kinematics_->external_momenta();
	const lst& iem = kinematics_->independent_external_momenta();
	// alternative symbols for the external momenta, used as the unknowns of the
	// linear system below (presumably alt receives fresh symbols and
	// em2alt / alt2em map between em and alt -- TODO confirm)
	exmap em2alt, alt2em;
	lst alt;
	provide_alternative_symbols(em, alt, em2alt, alt2em);
	const exmap& rmc = kinematics_->rule_momentum_conservation();
	const exmap inv2sp = kinematics_->find_rules_invariants_to_sp();
	const exmap sp2inv = kinematics_->rules_sp_to_invariants();
	const lst& in_mom = kinematics_->incoming_momenta();
	// permutation as a list of cycles
	const vector<vector<int> >& vec = permutation_.vector_rep();
	const int num_legs = em.nops();

	// verify all numbers *v of the permutation fulfill 1 <= *v <= num_legs
	for (vector<vector<int> >::const_iterator c = vec.begin(); c != vec.end(); ++c)
		for (vector<int>::const_iterator v = c->begin(); v != c->end(); ++v)
			if (*v < 1 || *v > num_legs)
				throw runtime_error("Invalid crossing: invalid leg number "
						+ to_string(*v) + " in " + to_string(permutation_));

	// determine name and the map to transform external momenta and kinematic invariants
	name_ = "";
	lst eqs;
	set<int> visited;
	for (vector<vector<int> >::const_iterator c = vec.begin(); c != vec.end(); ++c) {
		name_ += "x";
		for (vector<int>::const_iterator v = c->begin(); v != c->end();) {
			string vstr = to_string(*v);
			// a "p" separator is appended after leg numbers > 10, except for the
			// last leg of the cycle
			// NOTE(review): leg 10 also has two digits but gets no separator;
			// ">= 10" may have been intended. Changing it would rename crossed
			// integral families, so it is left as is.
			bool large_digit = *v > 10 ? true : false;
			visited.insert(*v);
			bool in1 = in_mom.has(em.op(*v - 1)), in2; // incoming ?
			ex mom1 = em.op(*v - 1).subs(rmc).expand(), mom2;
			++v;
			if (v == c->end()) {
				// last leg of the cycle is mapped onto the first one
				mom2 = em.op(*c->begin() - 1);
				in2 = in_mom.has(mom2); // incoming ?
			} else {
				mom2 = em.op(*v - 1);
				in2 = in_mom.has(mom2); // incoming ?
				if (large_digit)
					vstr += "p";
			}
			mom2 = mom2.subs(rmc, no_pat).expand();
			// momentum flips sign if one leg is incoming and the other outgoing
			int sign = (in1 == in2 ? 1 : -1);
			// both legs must have the same squared momentum ("mass")
			ex tst = ScalarProduct(mom1, mom1) - ScalarProduct(mom2, mom2);
			tst = tst.subs(rmc, no_pat).expand().subs(sp2inv, no_pat).expand();
			if (!tst.is_zero())
				throw runtime_error(
						"Invalid crossing: external legs with different masses found: "
								+ to_string(permutation_));
			// momentum of this leg, written in the alternative symbols, equals
			// +-(momentum of the next leg in the cycle)
			eqs.append((mom1.subs(em2alt, no_pat)) == sign * mom2);
			name_ += vstr;
		}
	}
	// append also the momenta which are not permuted
	for (int i = 1; i <= num_legs; ++i)
		if (visited.find(i) == visited.end()) {
			ex mom = em.op(i - 1).subs(rmc).expand();
			eqs.append((mom.subs(em2alt, no_pat)) == mom);
		}

	// solve for the alternative independent external momenta and turn the
	// solutions into substitution rules (alt2em renames the alternative
	// symbols back to the original momenta)
	lst altiem = ex_to<lst> (ex(iem).subs(em2alt));
	eqs = ex_to<lst> (lsolve(eqs, altiem));
	rules_momenta_ = equations_to_substitutions(eqs, alt2em, true);
	// a non-identity permutation must change at least one momentum
	ASSERT(permutation_.is_identity() || !rules_momenta_.empty());

	// image of each invariant: invariants -> scalar products, apply the momentum
	// rules, scalar products -> invariants
	const ex inv1 = kinematics_->kinematic_invariants();
	const ex inv2 = //
					inv1.subs(inv2sp, no_pat).subs(rules_momenta_, no_pat).expand().subs(
							sp2inv, no_pat).expand();
	ASSERT(inv1.nops() == inv2.nops());
	// keep only the invariants that actually change
	for (unsigned i = 0; i < inv1.nops(); ++i) {
		ex i1 = inv1.op(i);
		ex i2 = inv2.op(i);
		ASSERT(is_a<symbol>(i1));
		if (i1.is_equal(i2))
			continue;
		rules_invariants_[i1] = i2;
	}

	// now verify derived rules
	// (might be wrong if numbers for invariants are used in user input)
	// for every rule sp -> inv: crossing(sp) must equal crossing(inv)
	for (exmap::const_iterator r = sp2inv.begin(); r != sp2inv.end(); ++r) {
		ex lhsprime = r->first.subs(rules_momenta_, no_pat);
		ex rhsprime = r->second.subs(rules_invariants_, no_pat);
		ex diff = (lhsprime - rhsprime).expand().subs(sp2inv, no_pat);
		if (!diff.expand().is_zero())
			throw runtime_error(
					string("Failed to generate crossing:\n") +
					"Can't map crossing " + to_string(permutation_) +
					" to transformations of momenta and invariants.\n" +
					"This error is expected if numbers are used for kinematic"
					" invariants.\n"
					"Please try to disable generate_crossings in kinematics and"
					" avoid graph based jobs.");
	}

	// does the symbol to replace by one appear on lhs ?

	const symbol* symb2one = kinematics_->symbol_to_replace_by_one();
	if (symb2one)
		for (exmap::const_iterator r = rules_invariants_.begin(); r
				!= rules_invariants_.end(); ++r)
			if (r->first.is_equal(*symb2one)) {
				has_symb2one_in_lhs_ = true;
				break;
			}
	// is equivalent to identity crossing ?
	// (true if substituting rules_invariants_ leaves all invariants unchanged)
	ex thisinvs = inv1.subs(rules_invariants_, no_pat);
	is_equivalent_to_identity_ = (thisinvs - inv1).expand().is_zero();
}

// Returns the crossing of the same kinematics built from the inverse permutation.
Crossing Crossing::inverse() const {
	const Permutation inverted = permutation_.inverse();
	return Crossing(kinematics_, inverted);
}

// Returns the crossing whose permutation is Permutation::compose(c1, c2).
// Throws runtime_error if c1 and c2 belong to different kinematics.
Crossing Crossing::compose(const Crossing& c1, const Crossing& c2) {
	const Kinematics& kin1 = *c1.kinematics_;
	const Kinematics& kin2 = *c2.kinematics_;
	if (kin1 != kin2)
		throw runtime_error(
				"Cannot compose crossings from different kinematics.");
	const Permutation composed = Permutation::compose(c1.permutation_,
			c2.permutation_);
	return Crossing(c1.kinematics_, composed);
}

std::string Crossing::name_for_crossed_family(const IntegralFamily* ic) const {
	if (!ic->is_crossed())
		return ic->name() + name();
	ASSERT(dynamic_cast<const CrossedIntegralFamily*>(ic) != 0);
	const CrossedIntegralFamily* derived =
			static_cast<const CrossedIntegralFamily*> (ic);
	string crossing = Crossing::compose(*this, derived->crossing()).name();
	return ic->source_integralfamily()->name() + crossing;
}

// Looks up (via Files) the integral family that results from crossing f.
const IntegralFamily* Crossing::transform(const IntegralFamily* f) const {
	const string crossed_name = name_for_crossed_family(f);
	return Files::instance()->integralfamily(crossed_name);
}

// Sector with the same id as s, living in the crossed integral family.
Sector Crossing::transform(const Sector& s) const {
	// the IntegralFamily overload does the name construction and Files lookup
	return Sector(transform(s.integralfamily()), s.id());
}

// Integral with the same propagator exponents i.v(), living in the crossed
// integral family.
INT Crossing::transform(const INT& i) const {
	// the IntegralFamily overload does the name construction and Files lookup
	return INT(transform(i.integralfamily()), i.v());
}

// Crosses every term of a linear combination: each integral is mapped by the
// INT overload (and optionally by transform_to_minimal_equivalent()), and each
// coefficient gets rules_invariants_ substituted.
LinearCombination Crossing::transform(const LinearCombination& id,
		bool transform_to_minimal_equiv) const {
	LinearCombination result;
	LinearCombination::const_iterator term = id.begin();
	while (term != id.end()) {
		INT crossed = transform(term->first);
		if (transform_to_minimal_equiv)
			crossed = Crossing::transform_to_minimal_equivalent(crossed);
		result.insert(crossed, term->second.subs(rules_invariants_));
		++term;
	}
	return result;
}

// Crosses every term of an identity. Each integral is mapped by the INT
// overload, and optionally by transform_to_minimal_equivalent(). Each
// coefficient gets rules_invariants_ substituted and then the kinematics'
// rule for the symbol replaced by one.
// If some lhs of rules_invariants_ is that symbol, the terms are taken from a
// copy of id on which reconstruct_symbol_replaced_by_one() was called first
// (presumably so the symbol is present again before crossing -- confirm).
Identity Crossing::transform(const Identity& id,
		bool transform_to_minimal_equiv) const {
	const exmap& s2one = kinematics_->get_rule_symbol_to_replace_by_one();

	Identity restored;
	const Identity* source = &id;
	if (has_symb2one_in_lhs_) { // lhs of a rule has symb to replace by one
		restored = id;
		restored.reconstruct_symbol_replaced_by_one();
		source = &restored;
	}

	Identity result;
	for (Identity::const_iterator term = source->begin(); term != source->end();
			++term) {
		INT crossed = transform(term->first);
		if (transform_to_minimal_equiv)
			crossed = Crossing::transform_to_minimal_equivalent(crossed);
		result.insert(crossed,
				term->second.subs(rules_invariants_).subs(s2one));
	}
	return result;
}

// Applies the crossing to an expression: first the invariant rules, then the
// momentum rules.
GiNaC::ex Crossing::transform(const GiNaC::ex& e) const {
	const ex with_invariants = e.subs(rules_invariants_);
	return with_invariants.subs(rules_momenta_);
}

// Splits an integral into (integral in the source family, crossing applied).
// Integrals of a family that is not a CrossedIntegralFamily are returned
// unchanged together with the identity crossing of their kinematics.
pair<INT, Crossing> Crossing::uncross(const INT& integral) {
	const IntegralFamily* family = integral.integralfamily();
	const CrossedIntegralFamily* crossed =
			dynamic_cast<const CrossedIntegralFamily*> (family);
	if (crossed != 0) {
		INT source_integral(crossed->source_integralfamily(), integral.v());
		return make_pair(source_integral, crossed->crossing());
	}
	return make_pair(integral, Crossing(family->kinematics()));
}

// Splits a family into (source family, crossing applied). A family that is not
// a CrossedIntegralFamily is its own source and comes with the identity
// crossing of its kinematics.
pair<const IntegralFamily*, Crossing> Crossing::uncross(const IntegralFamily* ic) {
	const CrossedIntegralFamily* crossed =
			dynamic_cast<const CrossedIntegralFamily*> (ic);
	if (crossed != 0)
		return make_pair(crossed->source_integralfamily(), crossed->crossing());
	return make_pair(ic, Crossing(ic->kinematics()));
}

// Moves an integral of a crossed family to the family given by the
// representative of its crossing in equivalent_crossings(), keeping the
// exponents i.v(). Integrals of uncrossed families are returned unchanged.
// Throws std::out_of_range (map::at) if the crossing is not registered.
INT Crossing::transform_to_minimal_equivalent(const INT& i) {
	const IntegralFamily* family = i.integralfamily();
	if (!family->is_crossed())
		return i;
	const CrossedIntegralFamily* crossed =
			static_cast<const CrossedIntegralFamily*> (family);
	const IntegralFamily* source = crossed->source_integralfamily();
	const map<Crossing, Crossing>& equivalents = Files::instance()->crossings(
			crossed->kinematics()->name())->equivalent_crossings();
	const Crossing representative = equivalents.at(crossed->crossing());
	const string target_name = representative.name_for_crossed_family(source);
	return INT(Files::instance()->integralfamily(target_name), i.v());
}

bool Crossing::is_equivalent(const Crossing& other) const {
	if (*kinematics_ != *other.kinematics_)
		return false;
	const lst invs = kinematics_->kinematic_invariants();
	ex thisinvs = invs.subs(rules_invariants_);
	ex otherinvs = invs.subs(other.rules_invariants_);
	return (thisinvs - otherinvs).expand().is_zero();
}

// Re-expresses a sector of a crossed family in the family given by the
// representative of its crossing in equivalent_crossings(). Sectors of
// uncrossed families are returned unchanged. VERIFY fails if the crossing is
// not registered.
Sector Crossing::to_minimal_crossing(const Sector& s) {
	const IntegralFamily* family = s.integralfamily();
	if (!family->is_crossed())
		return s;
	const map<Crossing, Crossing>& equivalents = Files::instance()->crossings(
			family->kinematics()->name())->equivalent_crossings();
	pair<INT, Crossing> uncrossed = Crossing::uncross(INT(s));
	const Sector source_sector = uncrossed.first.get_sector();
	map<Crossing, Crossing>::const_iterator representative = equivalents.find(
			uncrossed.second);
	VERIFY(representative != equivalents.end());
	return representative->second.transform(source_sector);
}

// Integral counterpart of the Sector overload: the exponents i.v() are kept,
// and i itself is returned when its sector does not change.
INT Crossing::to_minimal_crossing(const INT& i) {
	if (!i.integralfamily()->is_crossed())
		return i;
	Sector original_sector = i.get_sector();
	Sector minimal = to_minimal_crossing(original_sector);
	return (minimal == original_sector) ? i : INT(minimal.integralfamily(), i.v());
}

// Emits the crossing as a YAML map with the key "permutation". If kinematics
// are set, the name and the derived rules are added as a YAML comment only
// (they are recomputed from the permutation when reading).
void Crossing::print(YAML::Emitter& ye) const {
	ye << YAML::BeginMap;
	ye << YAML::Key << "permutation" << YAML::Value << permutation_;
	if (kinematics_ != 0) {
		std::ostringstream info;
		info << "name: " << name_ << "\n"
				<< "rules_momenta: " << rules_momenta_ << "\n"
				<< "rules_invariants: " << rules_invariants_;
		ye << YAML::Comment(info.str());
	}
	ye << YAML::EndMap;
}

// Reads the "permutation" entry and rebuilds the whole crossing from it using
// the kinematics already set on this object (VERIFY fails if there are none).
void Crossing::read(const YAML::Node& node) {
	Permutation perm;
	node["permutation"] >> perm;
	VERIFY(kinematics_ != 0);
	const Crossing rebuilt(kinematics_, perm);
	*this = rebuilt;
}

// Orders first by kinematics. Within one kinematics, crossings equivalent to
// the identity come first, and ties are broken by permutation.
bool Crossing::operator<(const Crossing& other) const {
	const Kinematics& kin = *kinematics_;
	const Kinematics& other_kin = *other.kinematics_;
	if (kin != other_kin)
		return kin < other_kin;
	if (is_equivalent_to_identity_ == other.is_equivalent_to_identity_)
		return permutation() < other.permutation();
	return is_equivalent_to_identity_;
}
// Equal if defined on the same kinematics with the same permutation.
bool Crossing::operator==(const Crossing& other) const {
	return !(*kinematics() != *other.kinematics())
			&& permutation_ == other.permutation_;
}
// Negation of operator==.
bool Crossing::operator!=(const Crossing& other) const {
	return !operator==(other);
}

//
//
//

// Starts with an empty list of crossings; init() then registers the identity
// crossing in the lookup tables.
OrderedCrossings::OrderedCrossings(const Kinematics* kin)
	: kinematics_(kin) {
	init();
}

// True if the crossing is a key of the equivalence table filled by init().
bool OrderedCrossings::has(const Crossing& crossing) const {
	return equivalent_crossings_.count(crossing) > 0;
}

// Emits the ordered crossings (identity not included) as a YAML sequence.
void OrderedCrossings::print(YAML::Emitter& ye) const {
	ye << YAML::BeginSeq;
	const list<Crossing>::const_iterator last = ordered_crossings_.end();
	for (list<Crossing>::const_iterator it = ordered_crossings_.begin();
			it != last; ++it)
		ye << *it;
	ye << YAML::EndSeq;
}

// Reads crossings from a YAML sequence (one Crossing node per entry), appends
// them to ordered_crossings_ and validates the resulting list:
//  - no crossing may occur twice,
//  - the identity crossing must not be listed,
//  - the inverse of every listed permutation must be listed too.
// Throws runtime_error on a violation, otherwise calls init() to fill the
// lookup tables.
void OrderedCrossings::read(const YAML::Node& n) {
	VERIFY(n.Type() == YAML::NodeType::Sequence);
	for (YAML::Iterator node = n.begin(); node != n.end(); ++node) {
		Crossing x(kinematics_);
		*node >> x;
		ordered_crossings_.push_back(x);
		LOGX("using crossing " << x.name() << " (invs: " << x.rules_invariants()
				<< ", mom: " << x.rules_momenta() << ")");
	}

	// duplicates collapse in a set
	const set<Crossing> distinct(ordered_crossings_.begin(),
			ordered_crossings_.end());
	if (distinct.size() != ordered_crossings_.size())
		throw runtime_error("multiple defined crossings");

	// reject the identity and collect the permutations for the inverse check
	set<Permutation> perms;
	for (list<Crossing>::const_iterator c = ordered_crossings_.begin();
			c != ordered_crossings_.end(); ++c) {
		if (c->is_identity())
			throw runtime_error("identity crossing must not be contained"
					" in the list of crossings");
		perms.insert(c->permutation());
	}

	for (set<Permutation>::const_iterator p = perms.begin(); p != perms.end();
			++p) {
		const Permutation inv = p->inverse();
		if (perms.find(inv) == perms.end())
			throw runtime_error(
					"missing crossing " + to_string(inv)
							+ " which is the inverse of crossing "
							+ to_string(*p));
	}

	// setup remaining members of the class
	init();
}

void OrderedCrossings::init() {
	const lst& invs = kinematics_->kinematic_invariants();
	list<Crossing> all = ordered_crossings_;
	all.push_front(Crossing(kinematics_));
	list<Crossing>::const_reverse_iterator cfrom;
	for (cfrom = all.rbegin(); cfrom != all.rend(); ++cfrom) {
		ex frominvs = invs.subs(cfrom->rules_invariants());
		list<Crossing>::const_iterator cto;
		for (cto = all.begin(); cto != --cfrom.base(); ++cto) {
			ex toinvs = invs.subs(cto->rules_invariants());
			ex diff = (frominvs - toinvs).expand();
			if (diff.is_zero())
				break;
		}
		crossings_by_name_.insert(make_pair(cfrom->name(), *cfrom));
		equivalent_crossings_.insert(make_pair(*cfrom, *cto));
		equivalent_crossing_classes_.insert(*cto);
		string fn = cfrom->name(), tn = cto->name();
		LOGX("Crossing equivalence: '" << fn << "' -> '" << tn << "'");
	}
}

// Generates the crossings of kinematics_ automatically and appends the
// non-identity ones to ordered_crossings_, which is then sorted; finally
// init() fills the lookup tables.
//
// Steps:
//  1. group the external legs by their squared momentum ("mass"), evaluated
//     after momentum conservation and the scalar product -> invariant rules,
//  2. collect every permutation within each group of equal-mass legs,
//  3. drop permutations for which constructing a Crossing throws,
//  4. complete the remaining set via PermutationSet::complete_permutations()
//     (presumably closure under composition -- TODO confirm) and add the
//     identity,
//  5. construct a Crossing for each permutation and keep the non-identity ones.
void OrderedCrossings::construct_crossings() {
	const lst& em = kinematics_->external_momenta();
	const exmap& rmc = kinematics_->rule_momentum_conservation();
	const exmap& sp2inv = kinematics_->rules_sp_to_invariants();
	const lst& inem = kinematics_->incoming_momenta();

	// find all legs with the same mass

	set<int> alllegs;
	LOGX("External legs:");
	// squared momentum -> set of leg numbers (1-based)
	map<ex, set<int>, ex_is_less> legs_by_mass;
	for (unsigned i = 0; i < em.nops(); ++i) {
		int leg = i + 1;
		ASSERT(is_a<symbol>(em.op(i)));
		ex mom = em.op(i);
		string directed = (inem.has(mom) ? "incoming" : "outgoing");
		mom = mom.subs(rmc);
		ex mass = ScalarProduct(mom, mom).expand().subs(sp2inv).expand();
		LOGX("  external leg " << leg << " is " << directed << ", has momenta " //
				<< mom << " and mass " << mass);
		legs_by_mass[mass].insert(leg);
		alllegs.insert(leg);
	}
	LOGX("Rules to replace scalar products by kinematic invariants:");
	for (exmap::const_iterator r = sp2inv.begin(); r != sp2inv.end(); ++r)
		LOGX("  " << r->first << " -> " << r->second);

	// get all permutations of the legs with the same mass

	set<Permutation>::iterator p;
	set<Permutation> permutations;
	map<ex, set<int>, ex_is_less>::const_iterator m;
	for (m = legs_by_mass.begin(); m != legs_by_mass.end(); ++m) {
		// 'from' is sorted (it comes from a std::set), so next_permutation
		// enumerates every ordering of the legs of this mass group
		vector<int> from(m->second.begin(), m->second.end());
		vector<int> next(from);
		do {
			permutations.insert(Permutation(from, next));
		} while (next_permutation(next.begin(), next.end()));
	}
	LOGX("Permutations of external legs (before completing):");
	for (p = permutations.begin(); p != permutations.end(); ++p)
		LOGX(*p);

	// discard invalid crossings
	// (a permutation is invalid if the Crossing constructor throws for it)
	LOGX("Discarding invalid crossings due to numeric kinematic invariants:");
	int num_discarded = 0;
	for (p = permutations.begin(); p != permutations.end();) {
		try {
			Crossing cr(kinematics_, *p);
			++p;
		} catch (exception& e) {
			LOGX("  discard crossing: " << *p);
			// post-increment keeps p valid across the erase
			permutations.erase(p++);
			++num_discarded;
		}
	}
	LOGX("discarded " << num_discarded << " crossings");
	LOGX("Completing the permutations:");
	PermutationSet pset(permutations);
	pset.complete_permutations(); // complete the set of permutations
	pset.insert(Permutation()); // include identity
	permutations = pset.pset();
	LOGX("Permutations of external legs:");
	for (p = permutations.begin(); p != permutations.end(); ++p)
		LOGX(*p);

	// set up the corresponding crossing

	LOGX("Generating crossings:");
	for (p = permutations.begin(); p != permutations.end(); ++p) {
		try {
			Crossing cr(kinematics_, *p);
			// the identity is implicit and never stored in ordered_crossings_
			if (!cr.is_identity()) {
				ordered_crossings_.push_back(cr);
				LOGX("  generated crossing " << *p);
			}
		} catch (exception& e) {
			// NOTE(review): completion may re-create a permutation discarded
			// above; how ERROR handles this is defined elsewhere -- confirm
			ERROR("Failed to generate crossing\n" << e.what());
		}
	}
	// order given by Crossing::operator<
	ordered_crossings_.sort();
	size_t num_oc = ordered_crossings_.size();
	LOGX("Number of crossings (identity included): " << num_oc + 1);

	init(); // set the remaining members of the class
}

void OrderedCrossings::print_info() const {
    LOGX("Crossings for kinematics '" << kinematics_->name() << "':");
	const list<Crossing>& oc = ordered_crossings();
	for (list<Crossing>::const_iterator c = oc.begin(); c != oc.end(); ++c) {
		const map<Crossing, Crossing>& xmap = equivalent_crossings();
		VERIFY(xmap.find(*c) != xmap.end());
		YAML::Emitter ye;
		ye << *c;
		LOGX("  " << ye.c_str() << "\n" << "is equivalent to '" //
				<< xmap.find(*c)->second.name() << "'");
	}
	LOGX("");
}


}
