"""
Portfolio Service - Manages portfolio state and positions
"""
from typing import Dict, List, Optional
from decimal import Decimal
from datetime import datetime
from sqlalchemy.orm import Session
from app.models import Position, PortfolioSnapshot, Trade
from app.config import settings
import logging

logger = logging.getLogger(__name__)


class PortfolioService:
    """
    Manages portfolio state including cash, positions, and PnL.

    Attributes:
        cash_usdt: Cached cash balance. Initialised from
            ``settings.initial_balance_usdt`` and overwritten from the newest
            ``PortfolioSnapshot`` whenever ``get_state`` finds one.
        initial_balance: Starting balance; the baseline for the portfolio
            ``pnl_pct`` reported by ``get_state``.
        positions: Symbol -> Position map. Not read or written by any method
            of this class.
    """

    def __init__(self):
        # Go through str() so a float setting becomes the decimal value as
        # written (Decimal(0.1) would capture the binary approximation). This
        # matches the Decimal(str(...)) conversions used elsewhere in the class.
        initial_balance = Decimal(str(settings.initial_balance_usdt))
        self.cash_usdt = initial_balance
        self.initial_balance = initial_balance
        self.positions: Dict[str, Position] = {}

    @staticmethod
    def _position_pnl_pct(side: str, entry_price: Decimal, current_price: Decimal) -> float:
        """
        Return a position's percentage return relative to its entry price.

        'LONG' positions gain as the price rises; any other side is treated as
        SHORT (gains as the price falls), mirroring ``get_state``. Returns 0.0
        when ``entry_price`` is zero/unset so a single malformed position
        cannot make the whole ``get_state`` call raise.
        """
        if not entry_price:
            return 0.0
        ratio = current_price / entry_price
        if side == 'LONG':
            return float((ratio - 1) * 100)
        return float((1 - ratio) * 100)

    def get_state(self, db: Session, current_prices: Dict[str, float]) -> Dict:
        """
        Get current portfolio state, marking open positions to market.

        Side effects: writes ``current_price`` and ``pnl_unrealized`` onto each
        valued ``Position`` row and commits the session; refreshes
        ``self.cash_usdt`` from the latest ``PortfolioSnapshot`` if one exists.

        Args:
            db: Database session
            current_prices: Dict of symbol -> current price. When a symbol is
                missing or priced at 0, the position's last stored
                ``current_price`` is used instead; if that is also unset the
                position is left out of the totals and of ``positions``.

        Returns:
            Portfolio state dict (the keys ``save_snapshot`` reads, plus
            ``positions`` and ``max_position_size_usdt``).
        """
        # Active positions are those with a positive quantity.
        positions = db.query(Position).filter(Position.quantity > 0).all()

        positions_value = Decimal(0)
        pnl_unrealized = Decimal(0)
        position_details = []

        for pos in positions:
            # `or 0` also covers an explicit None quote, which Decimal("None")
            # would reject with InvalidOperation.
            current_price = Decimal(str(current_prices.get(pos.symbol) or 0))
            if current_price == 0 and pos.current_price:
                # No live quote: value the position at its last persisted mark
                # rather than dropping it, which would understate equity and
                # pnl_pct (and get persisted by save_snapshot).
                logger.warning(
                    "No current price for %s; using last known price %s",
                    pos.symbol, pos.current_price,
                )
                current_price = Decimal(str(pos.current_price))
            if current_price == 0:
                continue

            # Persist the mark; committed below.
            pos.current_price = current_price

            # NOTE(review): SHORT positions also add quantity * price to
            # equity here. That is only right if the cash accounting for
            # shorts offsets it — verify against the trade-execution code.
            position_value = pos.quantity * current_price
            positions_value += position_value

            if pos.side == 'LONG':
                unrealized = (current_price - pos.entry_price) * pos.quantity
            else:  # SHORT
                unrealized = (pos.entry_price - current_price) * pos.quantity

            pos.pnl_unrealized = unrealized
            pnl_unrealized += unrealized

            position_details.append({
                'symbol': pos.symbol,
                'side': pos.side,
                'quantity': float(pos.quantity),
                'entry_price': float(pos.entry_price),
                'current_price': float(current_price),
                'value_usdt': float(position_value),
                'pnl_unrealized': float(unrealized),
                'pnl_pct': self._position_pnl_pct(pos.side, pos.entry_price, current_price),
            })

        db.commit()

        latest_snapshot = (
            db.query(PortfolioSnapshot)
            .order_by(PortfolioSnapshot.timestamp.desc())
            .first()
        )

        # NOTE(review): cash and realized PnL are only read back from the
        # newest snapshot here; presumably the code that records trades writes
        # a snapshot with updated values — confirm.
        if latest_snapshot:
            pnl_realized = latest_snapshot.pnl_realized_total
            self.cash_usdt = latest_snapshot.cash_usdt
        else:
            pnl_realized = Decimal(0)

        equity = self.cash_usdt + positions_value
        pnl_total = pnl_realized + pnl_unrealized
        # A zero configured starting balance would otherwise raise on division.
        pnl_pct = (
            float((equity / self.initial_balance - 1) * 100)
            if self.initial_balance
            else 0.0
        )

        total_trades = db.query(Trade).count()

        return {
            'equity_usdt': float(equity),
            'cash_usdt': float(self.cash_usdt),
            'positions_value_usdt': float(positions_value),
            'pnl_realized': float(pnl_realized),
            'pnl_unrealized': float(pnl_unrealized),
            'pnl_total': float(pnl_total),
            'pnl_pct': pnl_pct,
            'num_positions': len(position_details),
            'num_trades': total_trades,
            'positions': position_details,
            'max_position_size_usdt': float(equity * Decimal(str(settings.max_position_size_pct))),
        }

    def can_open_position(
        self,
        db: Session,
        symbol: str,
        size_usdt: float
    ) -> tuple[bool, str]:
        """
        Check if we can open a new position.

        Rejects when a position for ``symbol`` is already open, when the
        open-position count has reached ``settings.max_open_positions``, or
        when ``size_usdt`` exceeds the cached ``self.cash_usdt``. The cash check
        uses the cached value, so it is only as fresh as the last
        ``get_state`` call that found a snapshot.

        Args:
            db: Database session
            symbol: Symbol to open a position in
            size_usdt: Intended position size in USDT

        Returns:
            (can_open, reason) — reason is "OK" when allowed.
        """
        existing = db.query(Position).filter(
            Position.symbol == symbol,
            Position.quantity > 0
        ).first()
        if existing:
            return False, f"Position already open for {symbol}"

        open_positions = db.query(Position).filter(
            Position.quantity > 0
        ).count()
        if open_positions >= settings.max_open_positions:
            return False, f"Max open positions reached ({settings.max_open_positions})"

        if Decimal(str(size_usdt)) > self.cash_usdt:
            return False, f"Insufficient cash: need {size_usdt}, have {float(self.cash_usdt)}"

        return True, "OK"

    def save_snapshot(self, db: Session, state: Dict):
        """
        Save portfolio snapshot to database.

        Args:
            db: Database session (committed by this method)
            state: A dict shaped like the return value of ``get_state``; its
                float fields are converted to Decimal via str().
        """
        snapshot = PortfolioSnapshot(
            equity_usdt=Decimal(str(state['equity_usdt'])),
            cash_usdt=Decimal(str(state['cash_usdt'])),
            positions_value_usdt=Decimal(str(state['positions_value_usdt'])),
            pnl_realized_total=Decimal(str(state['pnl_realized'])),
            pnl_unrealized_total=Decimal(str(state['pnl_unrealized'])),
            pnl_total=Decimal(str(state['pnl_total'])),
            pnl_pct=Decimal(str(state['pnl_pct'])),
            num_positions=state['num_positions'],
            num_trades=state['num_trades'],
        )

        db.add(snapshot)
        db.commit()

        # Lazy %-formatting: same message, rendered only if INFO is enabled.
        logger.info(
            "Portfolio snapshot: Equity=$%.2f, PnL=%.2f%%",
            state['equity_usdt'], state['pnl_pct'],
        )


# Module-level PortfolioService created once at import time (its __init__
# reads settings); presumably shared by callers needing portfolio state.
portfolio_service: PortfolioService = PortfolioService()
