import copy
from collections import defaultdict
from debug import myprint
from apm_helpers.messages import Messages as msg
from.simple_tree_solver import SimpleTreeSolver
from.solver_base import Solution
from.helpers import find_missed_branches,IsLoopChecker
class TwoPhaseSolver(SimpleTreeSolver):
    """Solver that selects offload heads first and then re-evaluates them together.

    Phase 1 obtains candidate offload heads from
    ``self.find_most_profitable_heads`` with region combinations disabled.
    Phase 2 re-estimates only those heads with region combinations and
    data-reuse analysis enabled, and accepts estimations in descending order
    of ``(relative_gain, gain, speedup)`` as long as their rows are not
    already covered by a previously accepted estimation.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base solver and read the ``optimistic_reuse`` setting.

        ``optimistic_reuse`` defaults to ``False`` and is overridden by
        ``self.settings['optimistic_reuse']`` when ``self.settings`` is truthy.
        """
        super().__init__(*args, **kwargs)
        self.optimistic_reuse = False
        if self.settings:
            self.optimistic_reuse = self.settings.get('optimistic_reuse', self.optimistic_reuse)

    @staticmethod
    def _two_phase_per_region_settings(offload_heads, solution):
        """Collect the per-region estimation parameters of the given heads.

        Args:
            offload_heads: iterable of rows selected as offload heads.
            solution: mapping from row to a region solution whose
                ``estimation['params'].get_dict(row)`` returns a dict of
                parameter name -> value.

        Returns:
            ``{param_name: {'per_region': {head: value}}}``.
        """
        settings = {}
        for head in offload_heads:
            head_params = solution[head].estimation['params'].get_dict(head)
            for name, value in head_params.items():
                settings.setdefault(name, {'per_region': {}})['per_region'][head] = value
        return settings

    def _solve(self, top_rows, params):
        """Run both phases and return the populated ``Solution``.

        Args:
            top_rows: rows passed to the host/accelerator estimators and to
                ``find_most_profitable_heads``.
            params: mapping of solver parameters. ``min_required_speed_up``
                is read by subscript; ``estimate_data_transfer_tax`` and
                ``estimate_max_speedup`` are read with a default of ``True``
                when ``optimistic_reuse`` is enabled.  A deep copy is
                modified for phase 1; ``params`` itself is not mutated here.

        Returns:
            A ``Solution`` with ``loop_nests``, ``host_estimations``,
            ``total_time_on_host``, ``accel_estimations``,
            ``offloads_by_loop_nest``, ``offload_heads``,
            ``non_offload_heads`` and ``regions`` filled in, after
            ``self._calc_per_region_speed_up`` has been applied to it.
        """
        platform = self.platform
        res = Solution()
        res.loop_nests = self.loop_nests
        res.host_estimations = platform.host.generate_estimations(top_rows)
        res.total_time_on_host = sum(res.host_estimations[row]['total_time'] for row in top_rows)

        # ---- Phase 1: pick offload heads without region combinations. ----
        p1_params = copy.deepcopy(params)
        p1_params.update({'consider_region_combinations': False})

        if not self.optimistic_reuse:
            estimations, p1_accel_estimations = self.find_most_profitable_heads(
                top_rows, res.host_estimations, res.total_time_on_host, platform, p1_params)
            p1_loopnests, p1_offload_heads, p1_non_offload_heads, p1_solution = estimations
        else:
            # Optimistic mode: choose heads with transfer taxes and max-speedup
            # estimation switched off, then re-estimate all top rows using the
            # caller's original preferences for those switches.
            p1_params['estimate_data_transfer_tax'] = False
            p1_params['estimate_read_data_transfer_tax'] = False
            p1_params['estimate_write_data_transfer_tax'] = False
            p1_params['estimate_max_speedup'] = False
            estimations, _ = self.find_most_profitable_heads(
                top_rows, res.host_estimations, res.total_time_on_host, platform, p1_params)
            p1_loopnests, p1_offload_heads, p1_non_offload_heads, p1_solution = estimations

            settings = self._two_phase_per_region_settings(p1_offload_heads, p1_solution)
            transfer_tax = params.get('estimate_data_transfer_tax', True)
            # The global entries below replace any per-region entry collected
            # above for the same parameter name.
            settings.update({
                'consider_region_combinations': {'global': False},
                'model_children': {'global': False},
                'estimate_read_data_transfer_tax': {'global': transfer_tax},
                'estimate_write_data_transfer_tax': {'global': transfer_tax},
                'data_reuse_analysis': {'global': False},
                'estimate_max_speedup': {'global': params.get('estimate_max_speedup', True)},
            })

            p1_accel_estimations = []
            for accel in platform.accelerators:
                p1_accel_estimations += accel.generate_estimations(top_rows, settings=settings)

            # Index single-row estimations by their row together with the gain.
            row2est_idx = defaultdict(list)
            for idx, est in enumerate(p1_accel_estimations):
                if len(est['rows']) > 1:
                    continue
                gain = None
                if est['is_offload_candidate'] and est['does_fit']:
                    gain = self.objective_fn(est) - sum(
                        self.objective_fn(res.host_estimations[x]) for x in est['rows'])
                for row in est['rows']:
                    row2est_idx[row].append((est, gain, idx))

            for row in p1_solution:
                region = p1_solution[row]
                if row in res.host_estimations:
                    host_est = res.host_estimations[row]
                    region.base_time = host_est['total_base_time']
                    region.fractional_time = host_est['total_time'] / res.total_time_on_host
                else:
                    region.base_time = None
                    region.fractional_time = None
                if row in row2est_idx:
                    # Best non-relaxed estimation with a computed gain.
                    # NOTE(review): when none qualifies the estimation is reset
                    # to None; the phase-2 code below dereferences
                    # ``.estimation`` of every offload head - confirm heads
                    # always keep a qualifying estimation.
                    est, gain, est_idx = max(
                        [x for x in row2est_idx[row] if not x[0].get('relaxed') and x[1] is not None],
                        key=lambda x: x[1],
                        default=(None, None, None))
                    region.estimation = est
                    region.estimation_idx = est_idx
                    region.gain = gain

            for row in p1_solution:
                est = p1_solution[row].estimation
                if est is not None and est['is_offload_candidate'] and est['does_fit']:
                    # Index by the estimation's own rows (the original indexed
                    # by the outer ``row``, which only matched because these
                    # estimations are single-row).
                    host_time = sum(p1_solution[x].base_time for x in est['rows'])
                    speedup = host_time / est['time']
                    p1_solution[row].speed_up = speedup
                    p1_solution[row].max_speed_up = speedup

        # ---- Phase 2: re-estimate the chosen heads with combinations. ----
        p2_settings = self._two_phase_per_region_settings(p1_offload_heads, p1_solution)
        p2_settings.update({
            'consider_region_combinations': {'global': True},
            'model_children': {'global': False},
            'estimate_read_data_transfer_tax': {'global': True},
            'estimate_write_data_transfer_tax': {'global': True},
            'data_reuse_analysis': {'global': True},
            'estimate_max_speedup': {'global': False},
        })
        p2_accel_estimations = []
        for accel in platform.accelerators:
            p2_accel_estimations += accel.generate_estimations(p1_offload_heads, settings=p2_settings)

        # est index -> (relative_gain, gain, speedup, per_row_speedup, no_reuse_gain)
        # for every estimation that is an offload candidate and fits.
        est_idx2gain = {}
        for idx, est in enumerate(p2_accel_estimations):
            if not (est['is_offload_candidate'] and est['does_fit']):
                continue
            rows = est['rows']
            host_time = sum(p1_solution[x].base_time for x in rows)
            host_weight = sum(self.objective_fn(res.host_estimations[x]) for x in rows)
            accel_weight = self.objective_fn(est)
            # Gain the same rows achieve from their individual phase-1
            # estimations, counting only rows above the speed-up threshold.
            no_combinations_gain = sum(
                self.objective_fn(p1_solution[x].estimation) - self.objective_fn(res.host_estimations[x])
                for x in rows
                if p1_solution[x].speed_up > params['min_required_speed_up'])
            gain = accel_weight - host_weight
            speedup = host_time / est['time']
            relative_gain = gain - no_combinations_gain

            per_row_speedup = {}
            no_reuse_gain = {}
            for row in rows:
                row_estimation = next(
                    (x for x in est.get('regions', []) if x['measured_row'] == row), None)
                if row_estimation:
                    per_row_speedup[row] = p1_solution[row].base_time / row_estimation['time']
                    no_reuse_gain[row] = p1_solution[row].gain
            est_idx2gain[idx] = relative_gain, gain, speedup, per_row_speedup, no_reuse_gain

        # Start phase-2 regions from the phase-1 ones with offload flags cleared.
        p2_solution = {}
        for k, v in p1_solution.items():
            new_solution = v.copy()
            new_solution.is_profitable = False
            new_solution.is_potential_offload_head = False
            new_solution.is_offloaded = False
            p2_solution[k] = new_solution

        covered_rows = set()
        p2_loopnests = {k: (v[0], tuple()) for k, v in p1_loopnests.items()}
        ranked = sorted(est_idx2gain.items(), key=lambda x: x[1][:3], reverse=True)
        for est_idx, val in ranked:
            relative_gain, gain, speedup, per_row_speedup, no_reuse_gain = val
            est = p2_accel_estimations[est_idx]
            # ``relaxed`` is read with .get(), matching how phase 1 reads it,
            # so estimations without the key are treated as not relaxed.
            if est.get('relaxed') or speedup < params['min_required_speed_up']:
                continue
            if any(x not in per_row_speedup for x in est['rows']):
                myprint(msg.WARNING_WRONG_OFFLOAD_COMBINATION_ESTIMATION.format(est['rows']), severity=3)
                continue
            if len(est['rows']) > 1:
                # A combination must beat both the phase-1 gains and the sum
                # of the per-row gains of rows that meet the threshold.
                minimal_required_gain = 1e-6
                if relative_gain < minimal_required_gain:
                    continue
                no_reuse_total = sum(
                    no_reuse_gain[row] for row in est['rows']
                    if per_row_speedup[row] >= params['min_required_speed_up'])
                if gain - no_reuse_total < minimal_required_gain:
                    continue
            rows_set = set(est['rows'])
            if rows_set & covered_rows:
                continue
            covered_rows |= rows_set
            for row in est['rows']:
                loop_solution = p2_solution[row]
                loop_solution.is_profitable = True
                loop_solution.is_offloaded = True
                loop_solution.estimation = est
                # NOTE(review): this index refers to ``p2_accel_estimations``,
                # while ``res.accel_estimations`` below is the phase-1 list
                # plus only the multi-row phase-2 entries - confirm consumers
                # of ``estimation_idx`` expect that.
                loop_solution.estimation_idx = est_idx
                loop_solution.gain = gain
                loop_solution.speed_up = per_row_speedup.get(row)
                # Best effort: a KeyError while resolving the loop nest leaves
                # ``p2_loopnests`` untouched for this row.
                try:
                    loopnest_row = loop_solution.loopnest[1]
                    try:
                        offload_heads = p2_loopnests[loopnest_row['key_column']][1] + (row,)
                    except KeyError:
                        offload_heads = (row,)
                    p2_loopnests[loopnest_row['key_column']] = (loopnest_row, offload_heads)
                except KeyError:
                    pass

        res.accel_estimations = p1_accel_estimations + [
            x for x in p2_accel_estimations if len(x['rows']) > 1]
        res.offloads_by_loop_nest = p2_loopnests
        res.offload_heads = list(covered_rows)
        res.non_offload_heads = find_missed_branches(
            covered_rows, [x[0] for x in p2_loopnests.values()], row_fit=IsLoopChecker())
        for x in res.non_offload_heads:
            p2_solution[x].is_potential_offload_head = True
        res.regions = p2_solution
        self._calc_per_region_speed_up(res)
        return res
