Package madgraph :: Package madevent :: Module hel_recycle
[hide private]
[frames] | [no frames]

Source Code for Module madgraph.madevent.hel_recycle

  1  #!/usr/bin/env python3 
  2   
  3  import argparse 
  4  import atexit 
  5  import os 
  6  import re 
  7  import collections 
  8  from string import Template 
  9  from copy import copy 
 10  from itertools import product 
 11  from functools import reduce  
 12   
 13  try: 
 14       import madgraph 
 15  except: 
 16       import internal.misc as misc 
 17  else: 
 18       import madgraph.various.misc as misc 
 19  import mmap 
 20  try: 
 21      from tqdm import tqdm 
 22  except ImportError: 
 23      tqdm = misc.tqdm 
def get_num_lines(file_path):
    '''Return the number of lines in *file_path* (used as the tqdm total).

    A final line without a trailing newline is counted; an empty file has
    0 lines. The file is opened read-only, so read-only inputs work.
    '''
    # Binary iteration splits on b'\n' only, exactly like mmap.readline did,
    # but needs no write access and works on empty files.
    with open(file_path, 'rb') as handle:
        return sum(1 for _ in handle)
class DAG:
    '''Dependency graph of the generated wavefunctions.

    ``graph`` maps a node to the list of nodes it depends on; ``store_wav``
    adds one edge from the stored wavefunction to every entry of its
    ``ext_deps``.
    '''

    def __init__(self):
        self.graph = {}          # node -> list of nodes it depends on
        self.all_wavs = []       # every stored wavefunction, in order
        self.external_wavs = []  # stored wavs with nature 'external'
        self.internal_wavs = []  # stored wavs with nature 'internal'

    def store_wav(self, wav, ext_deps=None):
        '''Record *wav* and link it to each node in *ext_deps* (may be None).'''
        self.all_wavs.append(wav)
        nature = wav.nature
        if nature == 'external':
            self.external_wavs.append(wav)
        elif nature == 'internal':
            self.internal_wavs.append(wav)
        for ext in ext_deps or ():
            self.add_branch(wav, ext)

    def add_branch(self, node_i, node_f):
        '''Add the edge node_i -> node_f.'''
        self.graph.setdefault(node_i, []).append(node_f)

    def dependencies(self, old_name):
        '''Return the live (not dead) wavs whose old name is *old_name*.'''
        return [wav for wav in self.all_wavs
                if wav.old_name == old_name and not wav.dead]

    def kill_old(self, old_name):
        '''Mark every wav with old name *old_name* as dead.'''
        for wav in self.all_wavs:
            if wav.old_name == old_name:
                wav.dead = True

    def old_names(self):
        '''Return the set of old names of all stored wavs.'''
        return {wav.old_name for wav in self.all_wavs}

    def find_path(self, start, end, path=None):
        '''Return a list of nodes leading from *start* to *end*, or None.

        Depth-first search taken from
        https://www.python.org/doc/essays/graphs/
        '''
        # None default: a shared mutable [] default is an accident waiting
        # to happen even though the list is never mutated in place here.
        path = (path or []) + [start]
        if start == end:
            return path
        if start not in self.graph:
            return None
        for node in self.graph[start]:
            if node not in path:
                newpath = self.find_path(node, end, path)
                if newpath:
                    return newpath
        return None

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        print_str = 'With new names:\n\t'
        print_str += '\n\t'.join([f'{key} : {item}'
                                  for key, item in self.graph.items()])
        print_str += '\n\nWith old names:\n\t'
        print_str += '\n\t'.join([f'{key.old_name} : {[i.old_name for i in item]}'
                                  for key, item in self.graph.items()])
        return print_str
class MathsObject:
    '''Abstract class for wavefunctions and Amplitudes.

    Subclasses supply ``format_name``, ``get_number`` and
    ``call_constructor``.
    '''

    # Store here which externals the last wav/amp depends on.
    # This saves us having to call find_path multiple times.
    ext_deps = None

    def __init__(self, arguments, old_name, nature):
        self.args = arguments     # argument strings of the Fortran call
        self.old_name = old_name  # output name used in the original file
        self.nature = nature      # 'external', 'internal' or 'amplitude'
        self.name = None          # new output name, set by set_name
        self.dead = False         # set once old_name is overwritten
        self.nb_used = 0          # use count; 0 means the call can be dropped
        self.linkdag = []         # objects this one consumes

    def set_name(self, *args):
        '''Rename the output (last argument) using format_name(*args).'''
        self.args[-1] = self.format_name(*args)
        self.name = self.args[-1]

    def format_name(self, *nums):
        '''Return the Fortran name for this object; overridden by subclasses.'''
        return None

    @staticmethod
    def get_deps(line, graph):
        '''Return, for each input wavefunction of *line*, the live objects in
        *graph* carrying that old name.

        Assumes the wavefunction inputs are the leading arguments of the
        call. Side effect: the call's own output name is marked dead in
        *graph*.
        '''
        old_args = get_arguments(line)
        old_name = old_args[-1].replace(' ', '')
        matches = graph.old_names() & {old.replace(' ', '') for old in old_args}
        matches.discard(old_name)
        old_deps = old_args[0:len(matches)]

        # If we're overwriting a wav clear it from graph
        graph.kill_old(old_name)
        return [graph.dependencies(dep) for dep in old_deps]

    @classmethod
    def good_helicity(cls, wavs, graph, diag_number=None, all_hel=None,
                      bad_hel_amp=None):
        '''Decide whether combining *wavs* gives an allowed helicity set.

        Sets ``cls.ext_deps`` to the external wavs reachable from *wavs* in
        *graph* and accepts it when it is a subset of one entry of
        ``External.good_wav_combs``. When *diag_number* is given and
        *bad_hel_amp* is non-empty, the combination is also rejected if
        (helicity number, diag_number) is in *bad_hel_amp*; the helicity
        number is the 1-based position of the externals' helicities in
        *all_hel*, which must then be given.

        Returns False when rejected, otherwise ``cls.ext_deps`` (an empty
        set is falsy, so an empty dependency set also counts as rejected).
        '''
        exts = graph.external_wavs
        cls.ext_deps = {ext for dep in wavs for ext in exts
                        if graph.find_path(dep, ext)}
        this_comb_good = any(cls.ext_deps.issubset(comb)
                             for comb in External.good_wav_combs)

        # Only look the helicity number up when there is something to
        # reject: the lookup needs all_hel and raises if it is missing.
        if diag_number and this_comb_good and cls.ext_deps and bad_hel_amp:
            helicity = {ext.get_id(): ext.hel for ext in cls.ext_deps}
            this_hel = tuple(helicity[i] for i in range(1, len(helicity) + 1))
            hel_number = 1 + all_hel.index(this_hel)
            if (hel_number, diag_number) in bad_hel_amp:
                this_comb_good = False

        return this_comb_good and cls.ext_deps

    @staticmethod
    def get_new_args(line, wavs):
        '''Return the arguments of *line* with the leading inputs replaced
        by the names of *wavs*.'''
        this_args = get_arguments(line)
        this_args[0:len(wavs)] = [w.name for w in wavs]
        # This isnt maximally efficient
        # Could take the num from wavs that've been deleted in graph
        return this_args

    @staticmethod
    def get_number(*args):
        '''Return the number used to name a new object; overridden by
        subclasses, which receive (wavs, graph).'''
        return None

    @classmethod
    def call_constructor(cls, new_args, old_name, diag_num):
        '''Build an instance of the concrete subclass; overridden.'''
        raise NotImplementedError(
            f'{cls.__name__} must implement call_constructor')

    @classmethod
    def get_obj(cls, line, wavs, graph, diag_num=None):
        '''Create the object produced by *line* from inputs *wavs*, name it
        and, unless it is an amplitude, store it in *graph*.

        Must follow a good_helicity call for the same *wavs*: it reuses the
        ``cls.ext_deps`` set there.
        '''
        old_name = get_arguments(line)[-1].replace(' ', '')
        new_args = cls.get_new_args(line, wavs)
        num = cls.get_number(wavs, graph)

        this_obj = cls.call_constructor(new_args, old_name, diag_num)
        this_obj.set_name(num, diag_num)
        if this_obj.nature != 'amplitude':
            graph.store_wav(this_obj, cls.ext_deps)
        return this_obj

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        # name stays None until set_name runs; __repr__ must return a str.
        if self.name is None:
            return f'<{type(self).__name__} {self.old_name}>'
        return self.name
class External(MathsObject):
    '''Class for storing external wavefunctions (one object per helicity
    of an external leg).'''

    good_hel = []       # helicity tuples that are kept
    nhel_lines = ''
    num_externals = 0   # number of External objects created so far
    # Could get this from dag but I'm worried about preserving order
    wavs_same_leg = {}  # 0-based leg index -> External wavs on that leg
    good_wav_combs = []  # lists of External wavs forming a good helicity
    # Filled from the NHEL data by the recycler before externals are made:
    hel_ranges = []     # per leg, the set of helicities used
    map_hel = {}        # good helicity tuple -> 0-based index

    def __init__(self, arguments, old_name):
        super().__init__(arguments, old_name, 'external')
        self.hel = int(self.args[2])
        # The first argument looks like "P(0,n)"; keep the leg number n.
        self.mg = int(arguments[0].split(',')[-1][:-1])
        # Kept for compatibility; the class-level hel_ranges is what is used.
        self.hel_ranges = []
        self.raise_num()

    @classmethod
    def raise_num(cls):
        cls.num_externals += 1

    @classmethod
    def generate_wavfuncs(cls, line, graph):
        '''Create one External per allowed helicity of the leg set up by
        *line*, store them in *graph* and return them as a new list.'''
        # If graph is passed in Internal it should be done here to so
        # we can set names
        old_args = get_arguments(line)
        old_name = old_args[-1].replace(' ', '')
        graph.kill_old(old_name)

        if 'NHEL' in old_args[2].upper():
            nhel_index = re.search(r'\(.*?\)', old_args[2]).group()
            ext_num = int(nhel_index[1:-1]) - 1
            new_hels = [int_to_string(h)
                        for h in sorted(External.hel_ranges[ext_num], reverse=True)]
        else:
            # Spinor must be a scalar so give it hel = 0
            ext_num = int(re.search(r'\(0,(\d+)\)', old_args[0]).group(1)) - 1
            new_hels = [' 0']

        new_wavfuncs = []
        for hel in new_hels:
            this_args = copy(old_args)
            this_args[2] = hel

            this_wavfunc = External(this_args, old_name)
            this_wavfunc.set_name(len(graph.external_wavs)
                                  + len(graph.internal_wavs) + 1)

            graph.store_wav(this_wavfunc)
            new_wavfuncs.append(this_wavfunc)

        # Extend a list owned by wavs_same_leg: storing new_wavfuncs itself
        # would let a later call grow the list returned here.
        cls.wavs_same_leg.setdefault(ext_num, []).extend(new_wavfuncs)
        return new_wavfuncs

    @classmethod
    def get_gwc(cls):
        '''Rebuild good_wav_combs: for every good helicity tuple, every way
        of picking one matching External per leg.'''
        gwc = []
        for comb in cls.good_hel:
            sols = [[]]
            for leg, wavs in cls.wavs_same_leg.items():
                valid = [wav for wav in wavs if comb[leg] == wav.hel]
                if len(valid) == 1:
                    for sol in sols:
                        sol.append(valid[0])
                else:
                    sols = [sol + [w] for w in valid for sol in sols]
            gwc += sols

        cls.good_wav_combs = gwc

    @staticmethod
    def format_name(*nums):
        return f'W(1,{nums[0]})'

    def get_id(self):
        """ return the id of the particle under consideration """
        try:
            return self.id
        except AttributeError:
            # First call: parse n from "P(0,n)" and cache it.
            self.id = int(re.findall(r'P\(0,(\d+)\)', self.args[0])[0])
            return self.id
class Internal(MathsObject):
    '''Internal (off-shell) wavefunction produced by a helas call.'''

    # Highest wavefunction number handed out by get_number so far.
    max_wav_num = 0
    # Count of Internal objects constructed so far.
    num_internals = 0

    def __init__(self, arguments, old_name):
        super().__init__(arguments, old_name, 'internal')
        self.raise_num()

    @classmethod
    def raise_num(cls):
        cls.num_internals += 1

    @classmethod
    def generate_wavfuncs(cls, line, graph):
        '''Return one Internal per combination of input wavs that
        good_helicity accepts.'''
        input_choices = cls.get_deps(line, graph)
        produced = []
        for wav_combo in product(*input_choices):
            # good_helicity records cls.ext_deps, which get_obj reads next.
            if not cls.good_helicity(wav_combo, graph):
                continue
            produced.append(cls.get_obj(line, wav_combo, graph))
        return produced

    # There must be a better way
    @classmethod
    def call_constructor(cls, new_args, old_name, diag_num):
        return Internal(new_args, old_name)

    @classmethod
    def get_number(cls, *args):
        '''Next wavefunction number: one past all externals and internals.'''
        candidate = External.num_externals + Internal.num_internals + 1
        if candidate > cls.max_wav_num:
            cls.max_wav_num = candidate
        return candidate

    @staticmethod
    def format_name(*nums):
        return f'W(1,{nums[0]})'
class Amplitude(MathsObject):
    '''Class for storing Amplitudes (one per allowed helicity of a "_0"
    call).'''

    max_amp_num = 0  # highest helicity number used in an AMP name so far

    def __init__(self, arguments, old_name, diag_num):
        self.diag_num = diag_num
        super().__init__(arguments, old_name, 'amplitude')

    @staticmethod
    def format_name(*nums):
        return f'AMP({nums[0]},{nums[1]})'

    @classmethod
    def generate_amps(cls, line, graph, all_hel=None, all_bad_hel=None):
        '''Return one Amplitude per accepted combination of input wavs.

        *all_hel* is the list of helicity tuples and *all_bad_hel* the
        (helicity number, diagram number) pairs to drop; both are passed
        through to good_helicity.
        '''
        if all_bad_hel is None:
            all_bad_hel = []
        old_args = get_arguments(line)
        old_name = old_args[-1].replace(' ', '')

        # Output is "AMP(n)": n is the diagram number.
        diag_num = int(re.search(r'\((.*?)\)', old_name).group(1))

        deps = cls.get_deps(line, graph)

        return [cls.get_obj(line, wavs, graph, diag_num)
                for wavs in product(*deps)
                if cls.good_helicity(wavs, graph, diag_num, all_hel, all_bad_hel)]

    @classmethod
    def call_constructor(cls, new_args, old_name, diag_num):
        return Amplitude(new_args, old_name, diag_num)

    @classmethod
    def get_number(cls, *args):
        '''Return the 1-based index of this amplitude's helicity tuple.

        The tuple is read from cls.ext_deps (set by the preceding
        good_helicity call), ordered by leg number.
        '''
        hel_amp = tuple(w.hel for w in sorted(cls.ext_deps, key=lambda x: x.mg))
        amp_num = External.map_hel[hel_amp] + 1  # Offset because Fortran counts from 1

        if cls.max_amp_num < amp_num:
            cls.max_amp_num = amp_num
        return amp_num
class HelicityRecycler:
    '''Rewrite a matrix-element Fortran file so that wavefunctions and
    amplitudes are unrolled over the good helicity combinations.

    Usage (as in main()): construct with the good helicity indices, call
    set_input / set_output / set_template, then generate_output_file().
    The input is read as fixed-form Fortran (label/continuation field in
    columns 1-6, statements from column 7).
    '''

    def __init__(self, good_elements, bad_amps=None, bad_amps_perhel=None):
        '''
        good_elements   -- 1-based indices (int or numeric str) of the
                           helicity combinations to keep.
        bad_amps        -- diagram numbers whose amplitude call is dropped.
        bad_amps_perhel -- (helicity number, diagram number) pairs dropped
                           for that helicity only.
        '''
        # The helper classes keep their state in class attributes; reset it
        # so every recycler starts from scratch.
        External.good_hel = []
        External.nhel_lines = ''
        External.num_externals = 0
        External.wavs_same_leg = {}
        External.good_wav_combs = []
        External.hel_ranges = []
        External.map_hel = {}

        Internal.max_wav_num = 0
        Internal.num_internals = 0

        Amplitude.max_amp_num = 0
        self.last_category = None
        self.good_elements = good_elements
        # None defaults: a shared [] default would leak between instances.
        self.bad_amps = [] if bad_amps is None else bad_amps
        self.bad_amps_perhel = [] if bad_amps_perhel is None else bad_amps_perhel

        # Default file names
        self.input_file = 'matrix_orig.f'
        self.output_file = 'matrix_orig.f'
        self.template_file = 'template_matrix.f'

        # initialise everything as for zero matrix element
        self.template_dict = {
            'helicity_lines': '\n',
            'helas_calls': [],
            'jamp_lines': '\n',
            'amp2_lines': '\n',
            'ncomb': '0',
            'nwavefuncs': '0',
        }

        self.dag = DAG()

        self.diag_num = 1
        self.got_gwc = False

        self.procedure_name = self._procedure_name(self.input_file)
        self.procedure_kind = 'FUNCTION'

        self.old_out_name = ''
        self.loop_var = 'K'

        self.all_hel = []
        self.hel_filt = True
        # Amplitude splitting is on by default (main() may turn it off).
        # Previously unset, so apply_amps raised AttributeError.
        self.amp_splt = True

    @staticmethod
    def _procedure_name(file):
        '''Upper-cased file name up to its first dot, ignoring directories.'''
        return os.path.basename(file).split('.')[0].upper()

    def set_input(self, file):
        if 'born_matrix' in file:
            print('HelicityRecycler is currently '
                  f'unable to handle {file}')
            raise SystemExit(1)
        self.procedure_name = self._procedure_name(file)
        self.procedure_kind = 'FUNCTION'
        self.input_file = file

    def set_output(self, file):
        self.output_file = file

    def set_template(self, file):
        self.template_file = file

    def function_call(self, line):
        '''Classify *line* as 'external', 'internal', 'amplitude' or None.'''
        # Check a function is called at all
        if 'CALL' not in line:
            return None

        # Now check for external spinor
        ext_calls = ('CALL OXXXXX', 'CALL IXXXXX', 'CALL VXXXXX', 'CALL SXXXXX')
        if any(call in line for call in ext_calls):
            return 'external'

        # Internals/amplitudes only appear once externals have been seen.
        if not self.dag.external_wavs:
            return None

        try:
            function = line.split('(', 1)[0].split()[-1]
        except IndexError:
            return None
        # A routine whose name ends in "_0" yields an amplitude; any other
        # suffix yields an internal wavefunction.
        if function.split('_')[-1] == '0':
            return 'amplitude'
        return 'internal'

    # string manipulation

    def add_amp_index(self, matchobj):
        '''re.sub callback: turn "AMP(" into "AMP( <loop_var>,".'''
        return matchobj.group().replace('AMP(', 'AMP( %s,' % self.loop_var)

    def add_indices(self, line):
        '''Add loop_var index to amp and output variable.
        Also update name of output variable.'''
        # Doesnt work if the AMP arguments contain brackets
        new_line = re.sub(r'\WAMP\(.*?\)', self.add_amp_index, line)
        return re.sub(r'MATRIX\d+', f'TS({self.loop_var})', new_line)

    def jamp_finished(self, line):
        '''The JAMP section ends at the first line mentioning init_mode.'''
        return 'init_mode' in line.lower()

    def get_old_name(self, line):
        if f'{self.procedure_kind} {self.procedure_name}' in line:
            if 'SUBROUTINE' == self.procedure_kind:
                self.old_out_name = get_arguments(line)[-1]
            if 'FUNCTION' == self.procedure_kind:
                self.old_out_name = line.split('(')[0].split()[-1]

    def get_amp_stuff(self, line_num, line):
        '''Track diagram/JAMP/AMP2 sections and collect their lines.'''
        if 'diagram number' in line:
            self.amp_calc_started = True
        # Check if the calculation of this diagram is finished
        if ('AMP' not in get_arguments(line)[-1]
                and self.amp_calc_started and line[:1] != 'C'):
            # Check if the calculation of all diagrams is finished
            if self.function_call(line) not in ('external', 'internal',
                                                'amplitude'):
                self.jamp_started = True
                self.amp_calc_started = False
        if self.jamp_started:
            self.get_jamp_lines(line)
        if self.in_amp2:
            self.get_amp2_lines(line)
        # Fixed-form: a top-level statement starts in column 7.
        if self.find_amp2 and line.startswith('      ENDDO'):
            self.in_amp2 = True
            self.find_amp2 = False

    def get_jamp_lines(self, line):
        if self.jamp_finished(line):
            self.jamp_started = False
            self.find_amp2 = True
        elif not line.isspace():
            self.template_dict['jamp_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'

    def get_amp2_lines(self, line):
        if line.startswith('      DO I = 1, NCOLOR'):
            self.in_amp2 = False
        elif not line.isspace():
            self.template_dict['amp2_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'

    def prepare_bools(self):
        self.amp_calc_started = False
        self.jamp_started = False
        self.find_amp2 = False
        self.in_amp2 = False
        self.nhel_started = False

    @staticmethod
    def _link_to_deps(obj, name2dep):
        '''Record in obj.linkdag the wavs obj consumes; bump their nb_used.'''
        obj.linkdag = []
        for name in obj.args:
            if name in name2dep:
                name2dep[name].nb_used += 1
                obj.linkdag.append(name2dep[name])

    def unfold_helicities(self, line, nature):
        '''Expand one call into one object per allowed helicity.

        Returns the new objects, each with its Fortran text in obj.line;
        an empty list for a dropped amplitude.
        '''
        if nature not in ('external', 'internal', 'amplitude'):
            raise ValueError('wrong unfolding')

        if nature == 'external':
            new_objs = External.generate_wavfuncs(line, self.dag)
            for obj in new_objs:
                obj.line = apply_args(line, [obj.args])
            return new_objs

        deps = Amplitude.get_deps(line, self.dag)
        name2dep = {dep.name: dep for group in deps for dep in group}

        if nature == 'internal':
            new_objs = Internal.generate_wavfuncs(line, self.dag)
            for obj in new_objs:
                obj.line = apply_args(line, [obj.args])
                self._link_to_deps(obj, name2dep)
            return new_objs

        # nature == 'amplitude'
        nb_diag = re.findall(r'AMP\((\d+)\)', line)[0]
        # NOTE(review): the element type of bad_amps is not fixed here, so
        # accept the diagram number both as str and as int.
        if nb_diag in self.bad_amps or int(nb_diag) in self.bad_amps:
            return []
        new_objs = Amplitude.generate_amps(line, self.dag, self.all_hel,
                                           self.bad_amps_perhel)
        out_line = self.apply_amps(line, new_objs)
        for i, obj in enumerate(new_objs):
            # The first object carries the code for the whole call.
            obj.line = out_line if i == 0 else ''
            obj.nb_used = 1
            self._link_to_deps(obj, name2dep)
        return new_objs

    def apply_amps(self, line, new_objs):
        if self.amp_splt:
            return split_amps(line, new_objs)
        return apply_args(line, [i.args for i in new_objs])

    def get_gwc(self, line, category):
        '''Refresh External.good_wav_combs whenever the previous call was an
        external one, so it is complete once the externals are done.'''
        if category not in ('external', 'internal', 'amplitude'):
            return
        if self.last_category == 'external':
            External.get_gwc()
        self.last_category = category

    def get_good_hel(self, line):
        '''Collect the NHEL data lines; after the last one, select the good
        helicities and fill the helicity-related template entries.'''
        if 'DATA (NHEL' in line:
            self.nhel_started = True
            this_hel = [int(hel) for hel in line.split('/')[1].split(',')]
            self.all_hel.append(tuple(this_hel))
        elif self.nhel_started:
            self.nhel_started = False

            if self.hel_filt:
                External.good_hel = [self.all_hel[int(i) - 1]
                                     for i in self.good_elements]
            else:
                External.good_hel = self.all_hel

            External.map_hel = {hel: i for i, hel in enumerate(External.good_hel)}
            External.hel_ranges = [set() for _ in External.good_hel[0]]
            for comb in External.good_hel:
                for i, hel in enumerate(comb):
                    External.hel_ranges[i].add(hel)

            self.counter = 0
            nhel_lines = '\n'.join(self.nhel_string(hel)
                                   for hel in External.good_hel)
            self.template_dict['helicity_lines'] += nhel_lines

            self.template_dict['ncomb'] = len(External.good_hel)

    def nhel_string(self, hel_comb):
        '''Return the next fixed-form DATA (NHEL...) line for *hel_comb*.'''
        self.counter += 1
        formatted_hel = [f'{hel}' if hel < 0 else f' {hel}' for hel in hel_comb]
        nexternal = len(hel_comb)
        return (f'      DATA (NHEL(I,{self.counter}),I=1,{nexternal}) '
                f'/{",".join(formatted_hel)}/')

    def _process_line(self, line_num, line):
        '''Run every per-line parser on one complete logical line.'''
        self.get_old_name(line)
        self.get_good_hel(line)
        self.get_amp_stuff(line_num, line)
        call_type = self.function_call(line)
        self.get_gwc(line, call_type)
        if call_type in ('external', 'internal', 'amplitude'):
            self.template_dict['helas_calls'] += self.unfold_helicities(
                line, call_type)

    def read_orig(self):
        '''Parse self.input_file and fill self.template_dict.'''
        with open(self.input_file, 'r') as input_file:
            self.prepare_bools()
            line_cache = None
            last_num = 0
            for line_num, line in tqdm(enumerate(input_file),
                                       total=get_num_lines(self.input_file)):
                last_num = line_num
                if line_num == 0:
                    line_cache = line
                    continue

                if '!SKIP' in line:
                    continue

                # '$' in column 6 continues the cached logical line.
                if line[5:6] == '$':
                    line_cache = undo_multiline(line_cache, line)
                    continue

                # The cached line is now complete: process it.
                line, line_cache = line_cache, line
                self._process_line(line_num, line)

            # The final logical line is still cached when the loop ends.
            if line_cache is not None:
                self._process_line(last_num, line_cache)

        self.template_dict['nwavefuncs'] = max(External.num_externals,
                                               Internal.max_wav_num)
        calls = self.template_dict['helas_calls']
        # filter out useless calls; walking backwards lets dropping a call
        # cascade to earlier calls it was the only consumer of.
        for obj in reversed(calls):
            if obj.nb_used == 0:
                obj.line = ''
                for dep in obj.linkdag:
                    dep.nb_used -= 1

        self.template_dict['helas_calls'] = '\n'.join(
            f'{obj.line.rstrip()} ! count {obj.nb_used}'
            for obj in calls if obj.nb_used > 0 and obj.line)

    def read_template(self):
        '''Write output_file from template_file with template_dict filled in.'''
        with open(self.output_file, 'w+') as out_file, \
                open(self.template_file, 'r') as template:
            for line in template:
                line = Template(line).safe_substitute(self.template_dict)
                out_file.write('\n'.join(do_multiline(sub_line)
                                         for sub_line in line.split('\n')))

    def write_zero_matrix_element(self):
        '''Replace output_file by a symlink to the matching "_orig.f" file
        (used when there is no good helicity).'''
        try:
            os.remove(self.output_file)
        except OSError:
            # Usually the file does not exist yet; os.symlink below reports
            # any real problem.
            pass
        input_file = self.output_file.replace("_optim.f", "_orig.f")
        os.symlink(input_file, self.output_file)

    def generate_output_file(self):
        if not self.good_elements:
            misc.sprint("No helicity", self.input_file)
            self.write_zero_matrix_element()
            return

        atexit.register(self.clean_up)
        self.read_orig()
        self.read_template()
        atexit.unregister(self.clean_up)

    def clean_up(self):
        '''Registered with atexit while generating output; does nothing.'''
        pass
def get_arguments(line):
    '''Split the text inside the first top-level parentheses of *line* on
    its top-level commas and return the pieces as strings.

    Nested parentheses are kept verbatim inside a piece; text before the
    opening parenthesis and after the matching closing one is ignored.
    A line without parentheses yields [''].
    '''
    depth = 0
    pieces = [[]]
    for ch in line:
        if ch == '(':
            depth += 1
            if depth == 1:
                # Opening bracket of the argument list itself: not kept.
                continue
        elif ch == ')':
            depth -= 1
            if depth == 0:
                # Matching closing bracket: the argument list is over.
                break
        elif ch == ',' and depth == 1:
            pieces.append([])
            continue
        if depth > 0:
            pieces[-1].append(ch)
    return [''.join(chars) for chars in pieces]
def apply_args(old_line, all_the_args):
    '''Return one copy of the call in *old_line* per argument list in
    *all_the_args*, each with its argument list replaced and ending in a
    newline; the copies are concatenated.'''
    head = old_line.split('(')[0]
    routine = head.split()[-1]
    # Everything after the last occurrence of the routine name.
    tail = old_line.rpartition(routine)[2]
    rebuilt = []
    for arg_set in all_the_args:
        rebuilt.append(old_line.replace(tail, '(' + ','.join(arg_set) + ')\n'))
    return ''.join(rebuilt)
def split_amps(line, new_amps):
    '''Return fixed-form Fortran evaluating the amplitudes *new_amps* of the
    single "_0" call in *line*.

    The wavefunction-argument column with the most distinct wavefunctions
    is contracted separately. For every combination of the other
    wavefunction arguments, the amplitudes sharing it are computed by one
    call to the "P1N_<column>" variant of the routine (result in TMP),
    followed by a CombineAmp (spin F/V) or CombineAmpS (spin S) call. A
    combination with a single amplitude falls back to the plain call.

    Assumes the wavefunction arguments ("W(1,...)") come first in every
    amplitude's args and that the last argument is named "AMP(<hel>,<n>)".
    Raises NotImplementedError for other spins.
    '''
    if not new_amps:
        return ''
    fct = line.split('(', 1)[0].split('_0')[0]

    # occur[k] counts how often each wavefunction appears in column k.
    occur = []
    for arg in new_amps[0].args:
        if "W(1," in arg:
            counts = collections.defaultdict(int)
            counts[arg] += 1
            occur.append(counts)
    for amp in new_amps[1:]:
        for col, counts in enumerate(occur):
            counts[amp.args[col]] += 1

    # Contract the column with the most distinct wavefunctions.
    nb_wav = [len(o) for o in occur]
    to_remove = nb_wav.index(max(nb_wav))
    occur.pop(to_remove)

    lines = []
    # Get the wavs per column
    wav_name = [o.keys() for o in occur]
    for wfcts in product(*wav_name):
        # Select the amplitudes produced by wfcts
        sub_amps = [amp for amp in new_amps
                    if all(w in amp.args for w in wfcts)]
        if not sub_amps:
            continue
        if len(sub_amps) == 1:
            lines.append(apply_args(line, [sub_amps[0].args]).replace('\n', ''))
            continue

        # the next line is to make the code nicer
        sub_amps.sort(key=lambda a: int(a.args[-1][:-1].split(',', 1)[1]))
        windices = []
        hel_calculated = []
        iamp = 0
        spin = None
        for k, amp in enumerate(sub_amps):
            args = amp.args[:]
            # Remove wav and get its index
            wcontract = args.pop(to_remove)
            windices.append(wcontract.split(',')[1].split(')')[0])
            amp_result, args[-1] = args[-1], 'TMP(1)'

            if k == 0:
                # Call the original fct with P1N_...
                # Final arg is replaced with TMP(1)
                spin = fct.split(None, 1)[1][to_remove]
                lines.append('%sP1N_%s(%s)' % (fct, to_remove + 1, ', '.join(args)))

            hel, iamp = re.findall(r'AMP\((\d+),(\d+)\)', amp_result)[0]
            hel_calculated.append(hel)

        params = {'nb': len(sub_amps),
                  'hel_list': ','.join(hel_calculated),
                  'w_list': ','.join(windices),
                  'iamp': iamp}
        # Fixed form: statements start in column 7, '&' in column 6 marks
        # a continuation line.
        if spin in "VF":
            lines.append("""      call CombineAmp(%(nb)i,
     & (/%(hel_list)s/),
     & (/%(w_list)s/),
     & TMP, W, AMP(1,%(iamp)s))""" % params)
        elif spin == "S":
            lines.append("""      call CombineAmpS(%(nb)i,
     &(/%(hel_list)s/),
     & (/%(w_list)s/),
     & TMP, W, AMP(1,%(iamp)s))""" % params)
        else:
            raise NotImplementedError("split amp are not supported for spin2 and 3/2")

    return '\n'.join(lines)
def get_num(wav):
    '''Return the last comma-separated integer inside the first pair of
    parentheses of wav.name, e.g. 5 for "W(1,5)".'''
    inside = re.search(r'\((.*?)\)', wav.name).group(1)
    return int(inside.rsplit(',', 1)[-1])
def undo_multiline(old_line, new_line):
    '''Join a fixed-form continuation line onto the line it continues.

    All newlines are removed from *old_line*; the first six columns
    (label and continuation field) of *new_line* are dropped.
    '''
    return ''.join((old_line.replace('\n', ''), new_line[6:]))
def do_multiline(line, char_limit=72):
    '''Split a line longer than *char_limit* characters into fixed-form
    Fortran continuation lines.

    Each continuation line starts with '$' in column 6 (the same column
    read_orig tests for), followed by the indentation of the original
    statement. Lines of at most *char_limit* characters and lines
    containing '!' (comments) are returned unchanged.
    '''
    if len(line) <= char_limit or '!' in line:
        return line
    # Stepping over the length never yields an empty trailing chunk, even
    # when len(line) is an exact multiple of char_limit.
    chunks = [line[i:i + char_limit] for i in range(0, len(line), char_limit)]
    body = line[6:]
    indent = body[:len(body) - len(body.lstrip(' '))]
    return f'\n     ${indent}'.join(chunks)
def int_to_string(i):
    '''Return the two-character helicity string for *i* ('+1', ' 0', '-1').

    Raises ValueError for any other value. The old code called an
    undefined set_trace() (NameError) and the interactive exit() here.
    '''
    names = {1: '+1', 0: ' 0', -1: '-1'}
    try:
        return names[i]
    except (KeyError, TypeError):
        raise ValueError(f'How can {i} be a helicity?') from None
def main():
    '''Command-line entry point.

    Positional arguments: the original matrix-element file and a file whose
    first line lists the contributing helicity indices. --hf-off disables
    helicity filtering and --as-off disables amplitude splitting. Output is
    written to green_matrix.f using the template template_matrix1.f.
    '''
    cli = argparse.ArgumentParser()
    cli.add_argument('input_file',
                     help='The file containing the original matrix calculation')
    cli.add_argument('hel_file',
                     help='The file containing the contributing helicities')
    cli.add_argument('--hf-off', dest='hel_filt', action='store_false',
                     default=True, help='Disable helicity filtering')
    cli.add_argument('--as-off', dest='amp_splt', action='store_false',
                     default=True, help='Disable amplitude splitting')
    opts = cli.parse_args()

    with open(opts.hel_file, 'r') as hel_stream:
        first_line = hel_stream.readline()

    recycler = HelicityRecycler(first_line.split())
    recycler.hel_filt = opts.hel_filt
    recycler.amp_splt = opts.amp_splt

    recycler.set_input(opts.input_file)
    recycler.set_output('green_matrix.f')
    recycler.set_template('template_matrix1.f')

    recycler.generate_output_file()


if __name__ == '__main__':
    main()