import type { InvestmentFactorForecast, Portfolio } from 'venn-api';
import { isNil, partition, sortBy } from 'lodash';
import { calculateComparison, normalizePortfolio } from 'venn-utils';

/**
 * Aggregate summary of the trades needed to move from one portfolio to another.
 */
export interface TradesStatistics {
  /** Number of individual investment trades counted. */
  tradeCount: number;
  /**
   * Capital moved. The subtree helpers in this file accumulate the value sold out of the original
   * portfolio; `getTradeStatistics` additionally adds any increase of the root allocation.
   */
  capitalMoved: number;
  /** Per-investment breakdown of the counted trades. */
  investments: InvestmentTradeStatistic[];
}

/** Trade details for a single investment between an original and a target portfolio. */
export interface InvestmentTradeStatistic {
  /** Name of the investment (portfolio node name or fund name, depending on the producer). */
  name: string;
  /** Fund id of the investment. */
  id: string;
  /** Forecast return when available (populated by `getFullPortfolioTradeStatistics`). */
  forecastReturn?: number;
  /** Signed trade: target allocation minus original allocation (negative means a reduction/sale). */
  tradeValue: number;
  /** Allocation in the original portfolio; `undefined` when the investment was not held originally. */
  originalAllocation: number | undefined;
}

/** One row of the flattened portfolio trade view produced by `getFullPortfolioTradeStatistics`. */
export interface PortfolioNodeTradeStatistic {
  /** Portfolio node id (taken from the solution node, falling back to the base node). */
  id: number;
  /** Node name (solution node first, falling back to the base node). */
  name: string;
  /** Depth in the tree; the root (level 0) is not emitted, so rows start at level 1. */
  level: number;
  /** Level of the row emitted immediately before this one, or 0 for the first row. */
  previousItemLevel: number;
  /** Base allocation; `undefined` for draft base nodes whose allocation is 0. */
  originalValue: number | undefined;
  /** Solution allocation minus base allocation (a missing side counts as 0). */
  tradeValue: number;
  /** True when the base node exists and is a draft. */
  isNew: boolean;
  /** True when the solution node's allocation is exactly 0. */
  isDeleted: boolean;
  /** Investment-level trade details; present only for fund nodes. */
  investmentTrade?: InvestmentTradeStatistic;
}

/**
 * Walks the solution portfolio (plus any "ghost" nodes that only exist in the base) and produces one
 * row per non-root node describing how its allocation changes from `base` to `solution`.
 *
 * @param solution - Target portfolio.
 * @param base - Starting portfolio.
 * @param investmentForecasts - Forecasts keyed by fund id; used for `investmentTrade.forecastReturn`.
 * @returns Rows in depth-first, pre-order (root excluded).
 */
export const getFullPortfolioTradeStatistics = (
  solution: Portfolio,
  base: Portfolio,
  investmentForecasts: { [key: string]: InvestmentFactorForecast },
): PortfolioNodeTradeStatistic[] => {
  const baseNormalized = normalizePortfolio(base);
  const solutionNormalized = normalizePortfolio(solution);
  const [compareNodesById, ghostChildrenByParentId] = calculateComparison(solutionNormalized, baseNormalized);
  const rows: PortfolioNodeTradeStatistic[] = [];

  const visit = (solutionNode: Portfolio | undefined, baseNode: Portfolio | undefined, level: number): void => {
    if (isNil(solutionNode) && isNil(baseNode)) {
      return;
    }

    // The root itself (level 0) is not emitted as a row.
    if (level !== 0) {
      const baseAllocation = baseNode?.allocation ?? 0;
      const tradeValue = (solutionNode?.allocation ?? 0) - baseAllocation;
      const fund = solutionNode?.fund ?? baseNode?.fund;
      const previousRow = rows[rows.length - 1];

      rows.push({
        id: (solutionNode?.id ?? baseNode?.id)!,
        name: (solutionNode?.name ?? baseNode?.name)!,
        level,
        previousItemLevel: previousRow === undefined ? 0 : previousRow.level,
        // Draft nodes with no base allocation have no meaningful "original" value.
        originalValue: baseNode?.draft && baseNode?.allocation === 0 ? undefined : baseAllocation,
        tradeValue,
        isDeleted: solutionNode?.allocation === 0,
        isNew: !isNil(baseNode) && baseNode.draft,
        investmentTrade: isNil(fund)
          ? undefined
          : {
              originalAllocation: baseAllocation,
              tradeValue,
              id: fund.id,
              name: fund.name,
              forecastReturn: investmentForecasts[fund.id]?.annualizedTotalReturn,
            },
      });
    }

    // Only strategies that are still present in the solution are descended into.
    if (isNil(solutionNode) || !isNil(solutionNode.fund)) {
      return;
    }
    for (const child of solutionNode.children) {
      visit(child, compareNodesById.get(child.id), level + 1);
    }
    ghostChildrenByParentId.get(solutionNode.id)?.forEach((ghost) => visit(undefined, ghost, level + 1));
  };

  visit(solution, base, 0);

  return rows;
};

/**
 * Summarises the trades required to go from `original` to `optimized`.
 *
 * Any growth of the root allocation is new money that must also be moved, so a positive difference is
 * added on top of the capital moved by the per-investment trades.
 */
export const getTradeStatistics = (original: Portfolio, optimized: Portfolio): TradesStatistics => {
  const subtreeStats = getTradeStatsForSubtree(original, optimized);
  const rootGrowth = (optimized.allocation ?? 0) - (original.allocation ?? 0);
  return rootGrowth > 0
    ? { ...subtreeStats, capitalMoved: subtreeStats.capitalMoved + rootGrowth }
    : subtreeStats;
};

/**
 * Recursively compares a node of the original portfolio (`node`) with the corresponding node of the
 * optimized portfolio (`optimizedNode`) and accumulates the investment trades between them.
 *
 * - `tradeCount` counts the traded investment entries (entries with a zero trade are skipped).
 * - `capitalMoved` accumulates only the selling side, i.e. allocation removed from the original
 *   portfolio; purchases contribute nothing to it here (for non-negative allocations).
 * - `investments` lists each traded investment with a signed `tradeValue` (optimized minus original).
 *
 * Child funds are matched by fund id and child strategies by node id, via a sorted merge of both child
 * lists. Unmatched children are treated as fully sold (original side) or fully bought (optimized side)
 * through `getTradeStatsVsNull`.
 *
 * @param node - Node from the original portfolio, or null if absent.
 * @param optimizedNode - Corresponding node from the optimized portfolio, or null if absent.
 */
const getTradeStatsForSubtree = (node: Portfolio | null, optimizedNode: Portfolio | null): TradesStatistics => {
  if (isNil(node) && isNil(optimizedNode)) {
    return {
      tradeCount: 0,
      capitalMoved: 0,
      investments: [],
    };
  }
  // Only one side exists: everything on that side is either sold (original) or bought (optimized).
  if (isNil(node)) {
    return getTradeStatsVsNull(optimizedNode!, false);
  }
  if (isNil(optimizedNode)) {
    return getTradeStatsVsNull(node, true);
  }

  if (isNil(node.fund) !== isNil(optimizedNode.fund)) {
    // One node is a strategy, one is a fund: the trades number is a sum of selling one and buying another
    const nodeStats = getTradeStatsVsNull(node, true);
    const optimizedStats = getTradeStatsVsNull(optimizedNode, false);
    return {
      tradeCount: nodeStats.tradeCount + optimizedStats.tradeCount,
      capitalMoved: nodeStats.capitalMoved + optimizedStats.capitalMoved,
      investments: [...nodeStats.investments, ...optimizedStats.investments],
    };
  }

  // If both nodes are funds, calculate the trade statistic
  if (!isNil(node.fund) && !isNil(optimizedNode.fund)) {
    // Capital moved: the reduction of the same fund, or the whole original allocation when the fund changed.
    const tradeValue =
      node.fund.id === optimizedNode.fund.id
        ? Math.max(0, (node.allocation ?? 0) - (optimizedNode.allocation ?? 0))
        : (node.allocation ?? 0);
    const investments = (
      node.fund.id === optimizedNode.fund.id
        ? [
            {
              name: node.fund.name,
              id: node.fund.id,
              tradeValue: (optimizedNode.allocation ?? 0) - (node.allocation ?? 0),
              originalAllocation: node.allocation ?? 0,
            },
          ]
        : [
            {
              name: node.fund.name,
              id: node.fund.id,
              tradeValue: -(node.allocation ?? 0),
              originalAllocation: node.allocation ?? 0,
            },
            {
              name: optimizedNode.fund.name,
              id: optimizedNode.fund.id,
              tradeValue: optimizedNode.allocation ?? 0,
              originalAllocation: undefined,
            },
          ]
    ).filter((item) => item.tradeValue !== 0);
    return {
      tradeCount: investments.length,
      capitalMoved: tradeValue,
      investments,
    };
  }

  // Both nodes are strategies: split their children into funds and sub-strategies.
  const [nodeFunds, nodeStrategies] = partition(node.children, (item) => item.fund);
  const [optimizedFunds, optimizedStrategies] = partition(optimizedNode.children, (item) => item.fund);

  const cumulativeStats: TradesStatistics = {
    tradeCount: 0,
    capitalMoved: 0,
    investments: [],
  };

  // Handle children that are funds
  // Sorting by fund id (then allocation, so duplicate funds pair up in allocation order) lets the loop
  // below match both lists with a single merge pass.
  const sortedNodeFunds = sortBy(nodeFunds, ['fund.id', 'allocation']);
  const sortedOptimizedFunds = sortBy(optimizedFunds, ['fund.id', 'allocation']);

  let nodeIdx = 0;
  let optiIdx = 0;

  while (nodeIdx < sortedNodeFunds.length || optiIdx < sortedOptimizedFunds.length) {
    // tradeValue is the capital moved by this step (sells only); tradeInvestment.tradeValue is the signed trade.
    let tradeValue = 0;
    let tradeInvestment: InvestmentTradeStatistic | undefined;

    if (nodeIdx >= sortedNodeFunds.length) {
      // If optimized solution has extra funds that don't exist in the primary node, count them in
      // (as a purchase, which adds no capital moved)
      tradeValue = 0;
      tradeInvestment = {
        name: sortedOptimizedFunds[optiIdx]!.name,
        id: sortedOptimizedFunds[optiIdx]!.fund!.id,
        tradeValue: sortedOptimizedFunds[optiIdx]!.allocation ?? 0,
        originalAllocation: undefined,
      };
      optiIdx++;
    } else if (optiIdx >= sortedOptimizedFunds.length) {
      // If primary node has extra funds that don't exist in the optimized solution, count their allocation
      tradeValue = sortedNodeFunds[nodeIdx]!.allocation ?? 0;
      tradeInvestment = {
        name: sortedNodeFunds[nodeIdx]!.name,
        id: sortedNodeFunds[nodeIdx]!.fund!.id,
        tradeValue: -tradeValue,
        originalAllocation: sortedNodeFunds[nodeIdx]!.allocation ?? 0,
      };
      nodeIdx++;
    } else {
      // Try to match funds between nodes
      const nodeFund = sortedNodeFunds[nodeIdx]!;
      const optiFund = sortedOptimizedFunds[optiIdx]!;
      if (nodeFund.fund!.id === optiFund.fund!.id) {
        // In both nodes we're currently considering the same fund
        tradeValue = Math.max(0, (nodeFund.allocation ?? 0) - (optiFund.allocation ?? 0));
        tradeInvestment = {
          name: nodeFund.name,
          id: nodeFund.fund!.id,
          tradeValue: (optiFund.allocation ?? 0) - (nodeFund.allocation ?? 0),
          originalAllocation: nodeFund.allocation ?? 0,
        };
        nodeIdx++;
        optiIdx++;
      } else if (nodeFund.fund!.id < optiFund.fund!.id) {
        // In optimized node, we're already looking at a fund with lexicographically greater fund id.
        // Count in the primary node's trade and move further in the primary node's child funds.
        tradeValue = nodeFund.allocation ?? 0;
        tradeInvestment = {
          name: nodeFund.name,
          id: nodeFund.fund!.id,
          tradeValue: -tradeValue,
          originalAllocation: nodeFund.allocation ?? 0,
        };
        nodeIdx++;
      } else {
        // In primary node, we're already looking at a fund with lexicographically greater fund id.
        // Count in the optimized node's trade and move further in the optimized node's child funds.
        tradeValue = 0;
        tradeInvestment = {
          name: optiFund.name,
          id: optiFund.fund!.id,
          tradeValue: optiFund.allocation ?? 0,
          originalAllocation: undefined,
        };
        optiIdx++;
      }
    }

    // Unchanged allocations are not trades.
    if (tradeInvestment.tradeValue !== 0) {
      cumulativeStats.tradeCount += 1;
      cumulativeStats.capitalMoved += tradeValue;
      cumulativeStats.investments.push(tradeInvestment);
    }
  }

  // Handle children that are strategies
  // Strategies are matched by node id using the same sorted-merge approach; matched pairs recurse.
  const sortedNodeStrategies = sortBy(nodeStrategies, ['id']);
  const sortedOptimizedStrategies = sortBy(optimizedStrategies, ['id']);

  nodeIdx = 0;
  optiIdx = 0;

  while (nodeIdx < sortedNodeStrategies.length || optiIdx < sortedOptimizedStrategies.length) {
    let trade: TradesStatistics = {
      tradeCount: 0,
      capitalMoved: 0,
      investments: [],
    };

    if (nodeIdx >= sortedNodeStrategies.length) {
      // If optimized solution has extra strategies that don't exist in the primary node, count their allocation
      trade = getTradeStatsVsNull(sortedOptimizedStrategies[optiIdx]!, false);
      optiIdx++;
    } else if (optiIdx >= sortedOptimizedStrategies.length) {
      // If primary node has extra strategies that don't exist in the optimized solution, count their allocation
      trade = getTradeStatsVsNull(sortedNodeStrategies[nodeIdx]!, true);
      nodeIdx++;
    } else {
      // Try to match strategies between nodes
      const nodeStrategy = sortedNodeStrategies[nodeIdx]!;
      const optiStrategy = sortedOptimizedStrategies[optiIdx]!;
      if (nodeStrategy.id === optiStrategy.id) {
        // In both nodes we're currently considering the same strategy
        trade = getTradeStatsForSubtree(nodeStrategy, optiStrategy);
        nodeIdx++;
        optiIdx++;
      } else if (nodeStrategy.id < optiStrategy.id) {
        // In optimized node, we're already looking at a strategy with a greater strategy id.
        // Count in the primary node's trades and move further in the primary node's child strategies.
        trade = getTradeStatsVsNull(nodeStrategy, true);
        nodeIdx++;
      } else {
        // In primary node, we're already looking at a strategy with a greater strategy id.
        // Count in the optimized node's trades and move further in the optimized node's child strategies.
        trade = getTradeStatsVsNull(optiStrategy, false);
        optiIdx++;
      }
    }

    cumulativeStats.tradeCount += trade.tradeCount;
    cumulativeStats.capitalMoved += trade.capitalMoved;
    cumulativeStats.investments = [...cumulativeStats.investments, ...trade.investments];
  }

  return cumulativeStats;
};

/**
 * Trade statistics for a subtree that exists on only one side of the comparison.
 *
 * On the base side every positive fund allocation is fully sold (negative `tradeValue`, counted in
 * `capitalMoved`). On the optimized side every positive fund allocation is a fresh purchase (positive
 * `tradeValue`, no capital moved).
 *
 * @param node - Subtree root (fund or strategy).
 * @param isBasePortfolio - True when `node` belongs to the original/base portfolio.
 */
const getTradeStatsVsNull = (node: Portfolio, isBasePortfolio: boolean): TradesStatistics => {
  if (isNil(node.fund)) {
    // Strategy: aggregate children that produced at least one trade, preserving child order.
    const totals: TradesStatistics = { tradeCount: 0, capitalMoved: 0, investments: [] };
    node.children
      .map((child) => getTradeStatsVsNull(child, isBasePortfolio))
      .filter((childStats) => childStats.tradeCount > 0)
      .forEach((childStats) => {
        totals.tradeCount += childStats.tradeCount;
        totals.capitalMoved += childStats.capitalMoved;
        totals.investments.push(...childStats.investments);
      });
    return totals;
  }

  const allocation = node.allocation ?? 0;
  const investments: TradesStatistics['investments'] =
    allocation > 0
      ? [
          {
            name: node.name,
            id: node.fund.id,
            tradeValue: isBasePortfolio ? -allocation : allocation,
            originalAllocation: isBasePortfolio ? allocation : undefined,
          },
        ]
      : [];
  return {
    tradeCount: investments.length,
    // Note: reported for the base side even when the allocation is not positive.
    capitalMoved: isBasePortfolio ? allocation : 0,
    investments,
  };
};

/** One row of the flattened base-vs-optimized comparison produced by `getTradesBetweenBaseAndOptimized`. */
interface BaseAndOptimizedTrade {
  /** Node name. */
  name: string;
  /** Depth of the node; the top-level call's row uses the `level` argument, children add 1. */
  level: number;
  /** True for fund (investment) nodes. */
  isInvestment: boolean;
  /** True for draft base funds and for funds that only exist in the optimized portfolio. */
  isNewOpportunity?: boolean;
  /** Base node's allocation as-is; `undefined` for nodes that only exist in the optimized portfolio. */
  baseAllocation: number | undefined;
  /** Optimized allocation (0 when missing). */
  optimizedAllocation: number;
  /** Optimized allocation minus base allocation (missing values count as 0). */
  trade: number;
}

/**
 * Flattens `base` and its optimized counterpart `node` into a depth-first, pre-order list of rows,
 * one per node, describing base vs. optimized allocation and the resulting trade.
 *
 * Children are matched in order by node id. The optimized portfolio is assumed to have the same
 * structure as the base, optionally with extra funds inserted. Any optimized child that does not match
 * the next base child is reported as a new investment. Base children still unmatched once the
 * optimized children are exhausted are not reported; this is not expected to happen.
 *
 * @param base - Base (original) portfolio node.
 * @param node - Corresponding node of the optimized portfolio.
 * @param level - Depth assigned to the top-level row; children get `level + 1`.
 * @returns Rows in depth-first, pre-order.
 */
export const getTradesBetweenBaseAndOptimized = (
  base: Portfolio,
  node: Portfolio,
  level = 0,
): BaseAndOptimizedTrade[] => {
  const results: BaseAndOptimizedTrade[] = [];

  // Appends into one shared array. Re-spreading the accumulated rows at every recursion level was
  // quadratic in the size of the tree.
  const appendRows = (baseNode: Portfolio, optimizedNode: Portfolio, depth: number): void => {
    results.push({
      name: baseNode.name,
      level: depth,
      isInvestment: !isNil(baseNode.fund),
      isNewOpportunity: !isNil(baseNode.fund) && baseNode.draft,
      baseAllocation: baseNode.allocation,
      optimizedAllocation: optimizedNode.allocation ?? 0,
      trade: (optimizedNode.allocation ?? 0) - (baseNode.allocation ?? 0),
    });
    if (baseNode.fund) {
      return;
    }

    const baseChildren = baseNode.children;
    let baseIdx = 0;
    for (const optimizedChild of optimizedNode.children) {
      const baseChild = baseChildren[baseIdx];
      if (!isNil(baseChild) && baseChild.id === optimizedChild.id) {
        appendRows(baseChild, optimizedChild, depth + 1);
        baseIdx++;
      } else {
        // Either the base children are exhausted or the ids differ: treat the optimized child as new.
        results.push({
          name: optimizedChild.name,
          level: depth + 1,
          isInvestment: true, // Optimized portfolios can only differ from base by having additional funds
          isNewOpportunity: true,
          baseAllocation: undefined,
          optimizedAllocation: optimizedChild.allocation ?? 0,
          trade: optimizedChild.allocation ?? 0,
        });
      }
    }
  };

  appendRows(base, node, level);
  return results;
};
