Package madgraph :: Package madevent :: Module hel_recycle
[hide private]
[frames] | [no frames]

Source Code for Module madgraph.madevent.hel_recycle

  1  #!/usr/bin/env python3 
  2   
  3  import argparse 
  4  import atexit 
  5  import re 
  6  import collections 
  7  from string import Template 
  8  from copy import copy 
  9  from itertools import product 
 10  from functools import reduce  
 11   
 12  try: 
 13       import madgraph 
 14  except: 
 15       import internal.misc as misc 
 16  else: 
 17       import madgraph.various.misc as misc 
 18  import mmap 
 19  try: 
 20      from tqdm import tqdm 
 21  except ImportError: 
 22      tqdm = misc.tqdm 
def get_num_lines(file_path):
    """Return the number of lines in ``file_path``.

    The file is opened read-only and memory-mapped, then scanned with
    ``readline``; a final line without a trailing newline is counted too.
    An empty file yields 0 (``mmap`` refuses to map a zero-length file, so
    that case is answered before mapping).  Both the file and the mapping
    are closed before returning.
    """
    with open(file_path, 'rb') as fp:
        # seek() returns the new offset, i.e. the file size here.
        if fp.seek(0, 2) == 0:
            return 0
        with mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ) as buf:
            lines = 0
            while buf.readline():
                lines += 1
            return lines
32
class DAG:
    '''Dependency graph between HELAS wavefunction objects.

    ``graph`` maps an object to the list of objects it has an edge to
    (built by ``add_branch``); ``store_wav`` adds edges from a new
    wavefunction to the external wavefunctions it depends on.
    ``all_wavs`` keeps every stored object in insertion order, and
    ``external_wavs`` / ``internal_wavs`` hold the subsets whose ``nature``
    is 'external' / 'internal'.
    '''

    def __init__(self):
        self.graph = {}
        self.all_wavs = []
        self.external_wavs = []
        self.internal_wavs = []

    def store_wav(self, wav, ext_deps=None):
        '''Record ``wav`` and add an edge from it to every object in ``ext_deps``.

        ``ext_deps`` is any iterable of previously stored objects; None (the
        default) means no dependencies.  Objects whose nature is neither
        'external' nor 'internal' are only kept in ``all_wavs``.
        '''
        if ext_deps is None:
            ext_deps = ()
        self.all_wavs.append(wav)
        if wav.nature == 'external':
            self.external_wavs.append(wav)
        elif wav.nature == 'internal':
            self.internal_wavs.append(wav)
        for ext in ext_deps:
            self.add_branch(wav, ext)

    def add_branch(self, node_i, node_f):
        '''Add the directed edge ``node_i -> node_f``.'''
        self.graph.setdefault(node_i, []).append(node_f)

    def dependencies(self, old_name):
        '''Return the stored objects named ``old_name`` that are not dead.'''
        return [wav for wav in self.all_wavs
                if wav.old_name == old_name and not wav.dead]

    def kill_old(self, old_name):
        '''Mark every stored object whose ``old_name`` matches as dead.'''
        for wav in self.all_wavs:
            if wav.old_name == old_name:
                wav.dead = True

    def old_names(self):
        '''Return the set of ``old_name`` values of all stored objects.'''
        return {wav.old_name for wav in self.all_wavs}

    def find_path(self, start, end, path=None):
        '''Return a list of nodes leading from ``start`` to ``end``, or None.

        Depth-first search over ``graph``.  ``path`` holds the nodes already
        on the current branch (used to avoid cycles); it is never mutated.
        Taken from https://www.python.org/doc/essays/graphs/
        '''
        path = [start] if path is None else path + [start]
        if start == end:
            return path
        for node in self.graph.get(start, ()):
            if node not in path:
                newpath = self.find_path(node, end, path)
                if newpath:
                    return newpath
        return None

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        print_str = 'With new names:\n\t'
        print_str += '\n\t'.join([f'{key} : {item}'
                                  for key, item in self.graph.items()])
        print_str += '\n\nWith old names:\n\t'
        print_str += '\n\t'.join([f'{key.old_name} : {[i.old_name for i in item]}'
                                  for key, item in self.graph.items()])
        return print_str
class MathsObject:
    '''Abstract class for wavefunctions and Amplitudes.

    Subclasses supply ``format_name``, ``get_number`` and
    ``call_constructor``.  The shared helpers read a HELAS call line, look
    up in a DAG which stored objects replace its input arguments, and build
    one new object per allowed helicity combination.
    '''

    # Store here which externals the last wav/amp depends on.
    # This saves us having to call find_path multiple times.
    ext_deps = None

    def __init__(self, arguments, old_name, nature):
        # arguments: argument strings of the HELAS call; the last entry is
        # the output variable and is rewritten by set_name.
        self.args = arguments
        self.old_name = old_name  # output name in the original file
        self.nature = nature      # 'external', 'internal' or 'amplitude'
        self.name = None          # new output name, set by set_name
        self.dead = False         # set by DAG.kill_old once overwritten
        self.nb_used = 0          # usage counter maintained by HelicityRecycler
        self.linkdag = []         # objects this one consumes

    def set_name(self, *args):
        '''Rename the output argument to ``format_name(*args)``.'''
        self.args[-1] = self.format_name(*args)
        self.name = self.args[-1]

    def format_name(self, *nums):
        '''Build the output variable name; overridden by subclasses.'''
        pass

    @staticmethod
    def get_deps(line, graph):
        '''Return, per dependency argument of ``line``, the live wavs replacing it.

        Also marks every stored object carrying the output name of ``line``
        as dead, since this call overwrites it.
        '''
        old_args = get_arguments(line)
        old_name = old_args[-1].replace(' ', '')
        matches = graph.old_names() & set(old_args)
        matches.discard(old_name)
        # NOTE(review): assumes the dependency wavefunctions are the leading
        # arguments of the call; only their count is taken from ``matches``.
        old_deps = old_args[0:len(matches)]

        # If we're overwriting a wav clear it from graph
        graph.kill_old(old_name)
        return [graph.dependencies(dep) for dep in old_deps]

    @classmethod
    def good_helicity(cls, wavs, graph, diag_number=None, all_hel=None,
                      bad_hel_amp=None):
        '''Decide whether the combination of input wavs ``wavs`` is allowed.

        Sets ``cls.ext_deps`` to the external wavs reachable from ``wavs``
        in ``graph``.  The combination is allowed when those externals are a
        subset of one of ``External.good_wav_combs`` and, when
        ``diag_number`` is given, the pair (helicity number, diag_number) is
        not in ``bad_hel_amp``.  The helicity number is the 1-based position
        in ``all_hel`` of the externals' helicities ordered by particle id.

        Returns the (truthy) set of externals when allowed, otherwise a
        falsy value (False or an empty set).
        '''
        if all_hel is None:
            all_hel = []
        if bad_hel_amp is None:
            bad_hel_amp = []
        exts = graph.external_wavs
        cls.ext_deps = {i for dep in wavs for i in exts if graph.find_path(dep, i)}
        this_comb_good = any(cls.ext_deps.issubset(set(comb))
                             for comb in External.good_wav_combs)

        if diag_number and this_comb_good and cls.ext_deps:
            helicity = {a.get_id(): a.hel for a in cls.ext_deps}
            # NOTE(review): assumes the particle ids are exactly
            # 1..len(ext_deps), i.e. the amplitude uses every external.
            this_hel = [helicity[i] for i in range(1, len(helicity) + 1)]
            hel_number = 1 + all_hel.index(tuple(this_hel))
            if (hel_number, diag_number) in bad_hel_amp:
                this_comb_good = False

        return this_comb_good and cls.ext_deps

    @staticmethod
    def get_new_args(line, wavs):
        '''Return the arguments of ``line`` with the leading ones replaced by the names of ``wavs``.'''
        # get_arguments builds a fresh list, so it can be modified in place.
        this_args = get_arguments(line)
        this_args[0:len(wavs)] = [w.name for w in wavs]
        # This isnt maximally efficient
        # Could take the num from wavs that've been deleted in graph
        return this_args

    @staticmethod
    def get_number():
        '''Return the number used to name a new object; overridden by subclasses.'''
        pass

    @classmethod
    def get_obj(cls, line, wavs, graph, diag_num=None):
        '''Create, name and (unless it is an amplitude) store the object for ``wavs``.

        Non-amplitude objects are added to ``graph`` with edges to
        ``cls.ext_deps`` as left by the preceding ``good_helicity`` call.
        '''
        old_name = get_arguments(line)[-1].replace(' ', '')
        new_args = cls.get_new_args(line, wavs)
        num = cls.get_number(wavs, graph)

        this_obj = cls.call_constructor(new_args, old_name, diag_num)
        this_obj.set_name(num, diag_num)
        if this_obj.nature != 'amplitude':
            graph.store_wav(this_obj, cls.ext_deps)
        return this_obj

    def __str__(self):
        # str() so that objects not yet named print as 'None' instead of
        # raising TypeError.
        return str(self.name)

    def __repr__(self):
        return str(self.name)
class External(MathsObject):
    '''Class for storing external wavefunctions.

    The class attributes hold state shared by all externals of the file
    being processed; HelicityRecycler.__init__ resets most of them.
    '''

    good_hel = []        # helicity tuples kept after filtering
    nhel_lines = ''
    num_externals = 0    # number of External instances created so far
    # Could get this from dag but I'm worried about preserving order
    wavs_same_leg = {}   # 0-based leg index -> External wavs of that leg
    good_wav_combs = []  # allowed lists of External wavs (one per leg)
    # Per leg, the set of helicities appearing in good_hel; assigned by
    # HelicityRecycler.get_good_hel before generate_wavfuncs reads it.
    hel_ranges = []

    def __init__(self, arguments, old_name):
        super().__init__(arguments, old_name, 'external')
        # The third HELAS argument carries the helicity literal.
        self.hel = int(self.args[2])
        self.hel_ranges = []
        self.raise_num()

    @classmethod
    def raise_num(cls):
        cls.num_externals += 1

    @classmethod
    def generate_wavfuncs(cls, line, graph):
        '''Expand one external wavefunction call into one External per allowed helicity.

        Each new object is named ``W(1,k)`` with k the next free slot in
        ``graph``, stored in ``graph`` and appended to ``wavs_same_leg`` for
        its leg.  Returns the list of new objects.
        '''
        # If graph is passed in Internal it should be done here to so
        # we can set names
        old_args = get_arguments(line)
        old_name = old_args[-1].replace(' ', '')
        graph.kill_old(old_name)

        if 'NHEL' in old_args[2].upper():
            nhel_index = re.search(r'\(.*?\)', old_args[2]).group()
            ext_num = int(nhel_index[1:-1]) - 1
            new_hels = [int_to_string(i)
                        for i in sorted(External.hel_ranges[ext_num], reverse=True)]
        else:
            # No NHEL argument: the wavefunction must be a scalar, so give
            # it hel = 0; the leg number comes from P(0,n).
            ext_num = int(re.search(r'\(0,(\d+)\)', old_args[0]).group(1)) - 1
            new_hels = [' 0']

        new_wavfuncs = []
        for hel in new_hels:
            this_args = list(old_args)
            this_args[2] = hel

            this_wavfunc = External(this_args, old_name)
            this_wavfunc.set_name(len(graph.external_wavs) + len(graph.internal_wavs) + 1)

            graph.store_wav(this_wavfunc)
            new_wavfuncs.append(this_wavfunc)
        # extend() so the per-leg list never aliases the returned list.
        cls.wavs_same_leg.setdefault(ext_num, []).extend(new_wavfuncs)

        return new_wavfuncs

    @classmethod
    def get_gwc(cls):
        '''Build ``good_wav_combs`` from ``good_hel`` and ``wavs_same_leg``.

        For every helicity tuple in ``good_hel``, list every choice of one
        External per leg whose ``hel`` matches that tuple.  The order of the
        result matters: Amplitude.get_number uses the 1-based position of a
        combination as the amplitude number.
        '''
        gwc = []
        for comb in cls.good_hel:
            sols = [[]]
            for leg, wavs in cls.wavs_same_leg.items():
                valid = [wav for wav in wavs if comb[leg] == wav.hel]
                if len(valid) == 1:
                    for sol in sols:
                        sol.append(valid[0])
                else:
                    # Zero or several matches on this leg: branch every
                    # partial solution over them (wav-major order).
                    sols = [sol + [w] for w in valid for sol in sols]
            gwc += sols

        cls.good_wav_combs = gwc

    @staticmethod
    def format_name(*nums):
        return f'W(1,{nums[0]})'

    def get_id(self):
        """ return the id of the particle under consideration """
        try:
            return self.id
        except AttributeError:
            # First call: parse n from the momentum argument P(0,n).
            self.id = int(re.findall(r'P\(0,(\d+)\)', self.args[0])[0])
            return self.id
class Internal(MathsObject):
    '''Internal wavefunctions produced by HELAS calls from other wavefunctions.'''

    # Largest wavefunction number handed out so far; read by
    # HelicityRecycler.read_orig to fill 'nwavefuncs'.
    max_wav_num = 0
    # Number of Internal instances created so far.
    num_internals = 0

    def __init__(self, arguments, old_name):
        super().__init__(arguments, old_name, 'internal')
        self.raise_num()

    @classmethod
    def raise_num(cls):
        cls.num_internals += 1

    @classmethod
    def generate_wavfuncs(cls, line, graph):
        '''Create one Internal for every allowed combination of input wavs.'''
        candidates = product(*cls.get_deps(line, graph))
        produced = []
        for combo in candidates:
            if not cls.good_helicity(combo, graph):
                continue
            produced.append(cls.get_obj(line, combo, graph))
        return produced

    # There must be a better way
    @classmethod
    def call_constructor(cls, new_args, old_name, diag_num):
        # diag_num is part of the shared interface (used by Amplitude).
        return Internal(new_args, old_name)

    @classmethod
    def get_number(cls, *args):
        '''Return the next free wavefunction number, tracking the maximum.'''
        candidate = External.num_externals + Internal.num_internals + 1
        if candidate > cls.max_wav_num:
            cls.max_wav_num = candidate
        return candidate

    @staticmethod
    def format_name(*nums):
        return 'W(1,{})'.format(nums[0])
class Amplitude(MathsObject):
    '''Class for storing Amplitudes'''

    # Largest amplitude (helicity-combination) number handed out so far.
    max_amp_num = 0

    def __init__(self, arguments, old_name, diag_num):
        self.diag_num = diag_num
        super().__init__(arguments, old_name, 'amplitude')

    @staticmethod
    def format_name(*nums):
        return f'AMP({nums[0]},{nums[1]})'

    @classmethod
    def generate_amps(cls, line, graph, all_hel=None, all_bad_hel=None):
        '''Create one Amplitude per allowed combination of input wavs.

        The diagram number is parsed from the output argument ``AMP(n)``;
        ``all_hel`` and ``all_bad_hel`` are forwarded to ``good_helicity``.
        '''
        if all_bad_hel is None:
            all_bad_hel = []
        old_args = get_arguments(line)
        old_name = old_args[-1].replace(' ', '')

        amp_index = re.search(r'\(.*?\)', old_name).group()
        diag_num = int(amp_index[1:-1])

        deps = cls.get_deps(line, graph)

        return [cls.get_obj(line, wavs, graph, diag_num)
                for wavs in product(*deps)
                if cls.good_helicity(wavs, graph, diag_num, all_hel, all_bad_hel)]

    @classmethod
    def call_constructor(cls, new_args, old_name, diag_num):
        return Amplitude(new_args, old_name, diag_num)

    @classmethod
    def get_number(cls, *args):
        '''Return the 1-based index in External.good_wav_combs matching ``cls.ext_deps``.

        If several combinations match, the last one wins.  Prints a message
        and exits with status 1 when none matches.
        '''
        wavs, graph = args  # (wavs, graph) contract shared with Internal
        target = set(cls.ext_deps)
        amp_num = -1
        for i, comb in enumerate(External.good_wav_combs):
            if set(comb) == target:
                # Offset because Fortran counts from 1
                amp_num = i + 1
        if amp_num < 1:
            print('Failed to find amp_num')
            # Same effect as exit(1) without relying on the site builtin.
            raise SystemExit(1)
        if cls.max_amp_num < amp_num:
            cls.max_amp_num = amp_num
        return amp_num
class HelicityRecycler:
    '''Class for recycling helicity.

    ``read_orig`` parses the original Fortran matrix-element file
    (``input_file``) line by line, expanding every HELAS call into one call
    per allowed helicity combination and collecting helicity, JAMP and AMP2
    lines into ``template_dict``.  ``read_template`` substitutes those into
    ``template_file`` and writes ``output_file``.
    '''

    def __init__(self, good_elements, bad_amps=None, bad_amps_perhel=None):
        '''
        good_elements: 1-based indices (str or int) of the helicity
            combinations to keep, counted over the DATA (NHEL...) lines.
        bad_amps: diagram numbers whose amplitude calls are dropped.
            NOTE(review): compared with the number parsed from ``AMP(n)`` as
            a string, so integers here never match.
        bad_amps_perhel: (helicity number, diagram number) int pairs whose
            amplitudes are dropped.
        '''
        # Reset the class-level state of External/Internal/Amplitude, which
        # otherwise persists from one HelicityRecycler to the next.
        External.good_hel = []
        External.nhel_lines = ''
        External.num_externals = 0
        External.wavs_same_leg = {}
        External.good_wav_combs = []

        Internal.max_wav_num = 0
        Internal.num_internals = 0

        Amplitude.max_amp_num = 0
        self.last_category = None
        self.good_elements = good_elements
        self.bad_amps = [] if bad_amps is None else bad_amps
        self.bad_amps_perhel = [] if bad_amps_perhel is None else bad_amps_perhel

        # Default file names
        self.input_file = 'matrix_orig.f'
        self.output_file = 'matrix_orig.f'
        self.template_file = 'template_matrix.f'

        self.template_dict = {}
        #initialise everything as for zero matrix element
        self.template_dict['helicity_lines'] = '\n'
        self.template_dict['helas_calls'] = []
        self.template_dict['jamp_lines'] = '\n'
        self.template_dict['amp2_lines'] = '\n'
        self.template_dict['ncomb'] = '0'
        self.template_dict['nwavefuncs'] = '0'

        self.dag = DAG()

        self.diag_num = 1
        self.got_gwc = False

        self.procedure_name = self.input_file.split('.')[0].upper()
        self.procedure_kind = 'FUNCTION'

        self.old_out_name = ''
        self.loop_var = 'K'

        self.all_hel = []
        self.hel_filt = True
        # Read by apply_amps; must exist even when callers do not set it.
        # Defaults to True, like the CLI (disabled with --as-off).
        self.amp_splt = True

    def set_input(self, file):
        if 'born_matrix' in file:
            print('HelicityRecycler is currently '
                  f'unable to handle {file}')
            # Same effect as exit(1) without relying on the site builtin.
            raise SystemExit(1)
        self.procedure_name = file.split('.')[0].upper()
        self.procedure_kind = 'FUNCTION'
        self.input_file = file

    def set_output(self, file):
        self.output_file = file

    def set_template(self, file):
        self.template_file = file

    def function_call(self, line):
        '''Classify ``line`` as an 'external', 'internal' or 'amplitude' HELAS call.

        Returns None for lines without CALL, for non-external calls seen
        before any external wavefunction, and when no routine name can be
        extracted.
        '''
        # Check a function is called at all
        if 'CALL' not in line:
            return None

        # Now check for external spinor
        ext_calls = ['CALL OXXXXX', 'CALL IXXXXX', 'CALL VXXXXX', 'CALL SXXXXX']
        if any(call in line for call in ext_calls):
            return 'external'

        # Wont find a internal when no externals have been found...
        # ... I assume
        if not self.dag.external_wavs:
            return None

        try:
            function = line.split('(', 1)[0].split()[-1]
        except IndexError:
            return None
        # A routine whose last '_'-separated part is '0' is an amplitude;
        # any other routine is treated as producing an internal wav.
        # NOTE(review): so any other CALL seen after the externals is
        # classified as 'internal'.
        if function.split('_')[-1] == '0':
            return 'amplitude'
        return 'internal'

    # string manipulation

    def add_amp_index(self, matchobj):
        '''re.sub callback: rewrite ``AMP(...)`` as ``AMP( <loop_var>,...)``.'''
        return matchobj.group().replace('AMP(', 'AMP( %s,' % self.loop_var)

    def add_indices(self, line):
        '''Add loop_var index to amp and output variable.
        Also update name of output variable.'''
        # Doesnt work if the AMP arguments contain brackets
        new_line = re.sub(r'\WAMP\(.*?\)', self.add_amp_index, line)
        new_line = re.sub(r'MATRIX\d+', 'TS(K)', new_line)
        return new_line

    def jamp_finished(self, line):
        '''Return True on the line that ends the JAMP section (mentions init_mode).'''
        return 'init_mode' in line.lower()

    def get_old_name(self, line):
        '''Remember the output name when ``line`` declares the procedure being rewritten.'''
        if f'{self.procedure_kind} {self.procedure_name}' in line:
            if 'SUBROUTINE' == self.procedure_kind:
                self.old_out_name = get_arguments(line)[-1]
            if 'FUNCTION' == self.procedure_kind:
                self.old_out_name = line.split('(')[0].split()[-1]

    def get_amp_stuff(self, line_num, line):
        '''Track the diagram/JAMP/AMP2 sections of the file and collect their lines.'''
        if 'diagram number' in line:
            self.amp_calc_started = True
        # Check if the calculation of this diagram is finished
        if ('AMP' not in get_arguments(line)[-1]
                and self.amp_calc_started and line[:1] != 'C'):
            # Check if the calculation of all diagrams is finished
            if self.function_call(line) not in ['external',
                                                'internal',
                                                'amplitude']:
                self.jamp_started = True
            self.amp_calc_started = False
        if self.jamp_started:
            self.get_jamp_lines(line)
        if self.in_amp2:
            self.get_amp2_lines(line)
        if self.find_amp2 and line.startswith(' ENDDO'):
            self.in_amp2 = True
            self.find_amp2 = False

    def get_jamp_lines(self, line):
        if self.jamp_finished(line):
            self.jamp_started = False
            self.find_amp2 = True
        elif not line.isspace():
            self.template_dict['jamp_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'

    def get_amp2_lines(self, line):
        if line.startswith(' DO I = 1, NCOLOR'):
            self.in_amp2 = False
        elif not line.isspace():
            self.template_dict['amp2_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'

    def prepare_bools(self):
        self.amp_calc_started = False
        self.jamp_started = False
        self.find_amp2 = False
        self.in_amp2 = False
        self.nhel_started = False

    @staticmethod
    def _link_deps(obj, name2dep):
        '''Set obj.linkdag to the known wavs obj consumes and bump their nb_used.'''
        obj.linkdag = []
        for name in obj.args:
            if name in name2dep:
                name2dep[name].nb_used += 1
                obj.linkdag.append(name2dep[name])

    def unfold_helicities(self, line, nature):
        '''Expand one HELAS call into its per-helicity objects.

        Returns the list of new External/Internal/Amplitude objects, each
        with ``line`` set to the Fortran text to emit for it, or '' when the
        call is an amplitude of a diagram listed in ``bad_amps``.
        Raises Exception for an unknown ``nature``.
        '''
        if nature not in ['external', 'internal', 'amplitude']:
            raise Exception('wrong unfolding')

        if nature == 'external':
            new_objs = External.generate_wavfuncs(line, self.dag)
            for obj in new_objs:
                obj.line = apply_args(line, [obj.args])
            return new_objs

        # Map the names of the live wavs this call can consume to the wavs
        # themselves, so usage is tracked for the unused-call filter.
        deps = Amplitude.get_deps(line, self.dag)
        name2dep = {dep.name: dep for dep_list in deps for dep in dep_list}

        if nature == 'internal':
            new_objs = Internal.generate_wavfuncs(line, self.dag)
            for obj in new_objs:
                obj.line = apply_args(line, [obj.args])
                self._link_deps(obj, name2dep)
            return new_objs

        # nature == 'amplitude'
        nb_diag = re.findall(r'AMP\((\d+)\)', line)[0]
        if nb_diag in self.bad_amps:
            return ''
        new_objs = Amplitude.generate_amps(line, self.dag, self.all_hel,
                                           self.bad_amps_perhel)
        out_line = self.apply_amps(line, new_objs)
        for i, obj in enumerate(new_objs):
            # The whole output for this call is carried by the first object;
            # the others emit nothing but still count as used.
            obj.line = out_line if i == 0 else ''
            obj.nb_used = 1
            self._link_deps(obj, name2dep)
        return new_objs

    def apply_amps(self, line, new_objs):
        if self.amp_splt:
            return split_amps(line, new_objs)
        return apply_args(line, [i.args for i in new_objs])

    def get_gwc(self, line, category):
        '''Rebuild External.good_wav_combs after a run of external calls.

        Rebuilds whenever the previous HELAS call was external, so the last
        rebuild happens at the first non-external call (presumably once all
        externals are known).
        '''
        if category not in ['external', 'internal', 'amplitude']:
            return
        if self.last_category == 'external':
            External.get_gwc()
        self.last_category = category

    def get_good_hel(self, line):
        '''Collect DATA (NHEL...) lines; once they end, select the good helicities.'''
        if 'DATA (NHEL' in line:
            self.nhel_started = True
            this_hel = [int(hel) for hel in line.split('/')[1].split(',')]
            self.all_hel.append(tuple(this_hel))
        elif self.nhel_started:
            self.nhel_started = False

            if self.hel_filt:
                External.good_hel = [self.all_hel[int(i) - 1] for i in self.good_elements]
            else:
                External.good_hel = self.all_hel

            External.hel_ranges = [set() for hel in External.good_hel[0]]
            for comb in External.good_hel:
                for i, hel in enumerate(comb):
                    External.hel_ranges[i].add(hel)

            self.counter = 0
            nhel_array = [self.nhel_string(hel) for hel in External.good_hel]
            self.template_dict['helicity_lines'] += '\n'.join(nhel_array)

            self.template_dict['ncomb'] = len(External.good_hel)

    def nhel_string(self, hel_comb):
        '''Return the next DATA (NHEL(I,n),...) line for ``hel_comb``.'''
        self.counter += 1
        formatted_hel = [f'{hel}' if hel < 0 else f' {hel}' for hel in hel_comb]
        nexternal = len(hel_comb)
        return (f' DATA (NHEL(I,{self.counter}),I=1,{nexternal}) /{",".join(formatted_hel)}/')

    def read_orig(self):
        '''Parse input_file and fill template_dict.'''
        with open(self.input_file, 'r') as input_file:

            self.prepare_bools()

            # Lines are handled one step behind the reader so that Fortran
            # continuation lines ('$' in column 6) are merged first.
            # NOTE(review): the last line of the file stays in line_cache
            # and is never processed.
            for line_num, line in tqdm(enumerate(input_file),
                                       total=get_num_lines(self.input_file)):
                if line_num == 0:
                    line_cache = line
                    continue

                if '!SKIP' in line:
                    continue

                if line[5:6] == '$':
                    line_cache = undo_multiline(line_cache, line)
                    continue

                line, line_cache = line_cache, line

                self.get_old_name(line)
                self.get_good_hel(line)
                self.get_amp_stuff(line_num, line)
                call_type = self.function_call(line)
                self.get_gwc(line, call_type)

                if call_type in ['external', 'internal', 'amplitude']:
                    self.template_dict['helas_calls'] += self.unfold_helicities(
                        line, call_type)

        self.template_dict['nwavefuncs'] = max(External.num_externals, Internal.max_wav_num)
        self._drop_unused_calls()

    def _drop_unused_calls(self):
        '''Blank out unused HELAS calls and join the kept ones into a string.

        Calls are visited last-to-first so that dropping one releases its
        dependencies (nb_used decremented) before they are examined.
        '''
        calls = self.template_dict['helas_calls']
        for obj in reversed(calls):
            if obj.nb_used == 0:
                obj.line = ''
                for dep in obj.linkdag:
                    dep.nb_used -= 1

        self.template_dict['helas_calls'] = '\n'.join(
            [f'{obj.line.rstrip()} ! count {obj.nb_used}'
             for obj in calls if obj.nb_used > 0 and obj.line])

    def _render_template(self):
        '''Write template_file to output_file with template_dict substituted.'''
        with open(self.output_file, 'w+') as out_file, \
                open(self.template_file, 'r') as file:
            for line in file:
                line = Template(line).safe_substitute(self.template_dict)
                line = '\n'.join([do_multiline(sub_lines) for sub_lines in line.split('\n')])
                out_file.write(line)

    def read_template(self):
        self._render_template()

    def write_zero_matrix_element(self):
        '''Render the template for a matrix element with no contributing helicity.'''
        self.template_dict['ncomb'] = '0'
        self.template_dict['nwavefuncs'] = '0'
        self.template_dict['helas_calls'] = ''
        self._render_template()

    def generate_output_file(self):
        if not self.good_elements:
            misc.sprint("No helicity", self.input_file)
            self.write_zero_matrix_element()
            return

        atexit.register(self.clean_up)
        self.read_orig()
        self.read_template()
        atexit.unregister(self.clean_up)

    def clean_up(self):
        pass
def get_arguments(line):
    '''Split the contents of the first balanced (...) group of ``line`` on top-level commas.

    Nested brackets are kept inside their argument; whitespace is not
    stripped.  A line with no '(' yields [''].
    '''
    depth = 0
    pieces = ['']
    for ch in line:
        if ch == '(':
            depth += 1
            if depth == 1:
                # The opening bracket of the argument list itself.
                continue
        elif ch == ')':
            depth -= 1
            if depth == 0:
                # Closing bracket of the argument list: done.
                break
        elif ch == ',' and depth == 1:
            pieces.append('')
            continue
        if depth > 0:
            pieces[-1] += ch
    return pieces
def apply_args(old_line, all_the_args):
    '''Return one copy of the call ``old_line`` per argument list in ``all_the_args``.

    Everything after the routine name (the argument list and any trailing
    text, including the newline) is replaced by ``(a,b,...)`` plus a
    newline; the copies are concatenated.
    '''
    routine = old_line.split('(')[0].split()[-1]
    tail = old_line.split(routine)[-1]
    rendered = []
    for arg_list in all_the_args:
        rendered.append(old_line.replace(tail, '(' + ','.join(arg_list) + ')\n'))
    return ''.join(rendered)
def split_amps(line, new_amps):
    '''Emit Fortran for the amplitudes ``new_amps`` generated from one amplitude call.

    ``line`` is the original call and ``new_amps`` the Amplitude objects
    built from it; each has ``args`` (argument strings ending with the
    output ``AMP(h,n)``).  The ``W(1,...)`` argument position with the most
    distinct wavs is factored out.  For every combination of the remaining
    wavs, the matching amplitudes are computed by one
    ``<routine>P1N_<position>`` call writing ``TMP(1)``, followed by one
    ``CombineAmp`` (spin V/F) or ``CombineAmpS`` (spin S) call.  Groups with
    a single amplitude are emitted as a plain call via ``apply_args``.

    Returns the lines joined with '\\n', or '' when ``new_amps`` is empty.
    Raises Exception for any other spin letter at the factored position.
    '''
    if not new_amps:
        # Every helicity combination was filtered out: nothing to emit and
        # no argument column to factor out.
        return ''

    fct = line.split('(', 1)[0].split('_0')[0]

    # For each W(1,...) argument position, count how often each wav name
    # appears across all amplitudes.
    # NOTE(review): assumes the W(1,...) arguments come first in args.
    first, *others = new_amps
    occur = []
    for arg in first.args:
        if "W(1," in arg:
            counter = collections.defaultdict(int)
            counter[arg] += 1
            occur.append(counter)
    for amp in others:
        for pos, counter in enumerate(occur):
            counter[amp.args[pos]] += 1

    nb_wav = [len(o) for o in occur]
    to_remove = nb_wav.index(max(nb_wav))
    # Remove the one that occurs the most
    occur.pop(to_remove)

    amp_pattern = re.compile(r'AMP\((\d+),(\d+)\)')
    lines = []
    # Get the wavs per column
    wav_name = [o.keys() for o in occur]
    for wfcts in product(*wav_name):
        # Select the amplitudes produced by wfcts
        sub_amps = [amp for amp in new_amps
                    if all(w in amp.args for w in wfcts)]
        if not sub_amps:
            continue
        if len(sub_amps) == 1:
            lines.append(apply_args(line, [sub_amps[0].args]).replace('\n', ''))
            continue

        # the next line is to make the code nicer
        sub_amps.sort(key=lambda a: int(a.args[-1][:-1].split(',', 1)[1]))
        windices = []
        hel_calculated = []
        iamp = 0
        for idx, amp in enumerate(sub_amps):
            args = amp.args[:]
            # Remove the factored-out wav and record its W index.
            wcontract = args.pop(to_remove)
            windices.append(wcontract.split(',')[1].split(')')[0])
            amp_result, args[-1] = args[-1], 'TMP(1)'

            if idx == 0:
                # Presumably the routine name spells the spin of each leg
                # (e.g. 'FFV1'); take the letter of the factored-out leg.
                spin = fct.split(None, 1)[1][to_remove]
                # Call the original fct with P1N_<pos>, output into TMP(1).
                lines.append('%sP1N_%s(%s)' % (fct, to_remove + 1, ', '.join(args)))

            hel, iamp = amp_pattern.findall(amp_result)[0]
            hel_calculated.append(hel)

        values = {'nb': len(sub_amps),
                  'hel_list': ','.join(hel_calculated),
                  'w_list': ','.join(windices),
                  'iamp': iamp,
                  }
        # NOTE(review): these are fixed-form Fortran continuation lines; the
        # '&' marker must sit in column 6 - check the literals' whitespace.
        if spin in "VF":
            lines.append(""" call CombineAmp(%(nb)i,
 & (/%(hel_list)s/),
 & (/%(w_list)s/),
 & TMP, W, AMP(1,%(iamp)s))""" % values)
        elif spin == "S":
            lines.append(""" call CombineAmpS(%(nb)i,
 &(/%(hel_list)s/),
 & (/%(w_list)s/),
 & TMP, W, AMP(1,%(iamp)s))""" % values)
        else:
            raise Exception("split amp are not supported for spin2 and 3/2")

    return '\n'.join(lines)
def get_num(wav):
    '''Return the last comma-separated integer inside the first (...) of ``wav.name``.

    E.g. a name 'W(1,12)' gives 12 and 'AMP(3,4)' gives 4.
    '''
    bracketed = re.search(r'\(.*?\)', wav.name).group()
    *_, last = bracketed[1:-1].split(',')
    return int(last)
def undo_multiline(old_line, new_line):
    '''Join a Fortran continuation line onto the line it continues.

    The first six columns of ``new_line`` (label field and continuation
    marker) are dropped, and every newline is removed from ``old_line``.
    '''
    head = old_line.replace('\n', '')
    continuation = new_line[6:]
    return head + continuation
def do_multiline(line, char_limit=72):
    '''Split an over-long Fortran line into continuation lines.

    Lines of at most ``char_limit`` characters, and lines containing '!',
    are returned unchanged.  Otherwise the line is cut into
    ``char_limit``-sized chunks, rejoined with a newline, the continuation
    marker ' $' and the leading spaces that follow column 6 of the line.

    ``char_limit`` defaults to 72, the historical fixed-form width.
    '''
    if len(line) <= char_limit or '!' in line:
        return line
    # Leading blanks of the statement body (after the 6-column prefix),
    # repeated so continuation lines keep the statement's indentation.
    body = line[6:]
    indent = body[:len(body) - len(body.lstrip(' '))]
    # Stepping through range(0, len(line), ...) never yields an empty
    # trailing chunk, even when the length is a multiple of char_limit.
    chunks = [line[start:start + char_limit]
              for start in range(0, len(line), char_limit)]
    return f'\n ${indent}'.join(chunks)
def int_to_string(i):
    '''Format helicity ``i`` as a signed Fortran integer literal.

    0 gives ' 0' (a blank keeps the width of the signed values), a positive
    value '+n' and a negative value '-n', so -1/0/+1 map to '-1'/' 0'/'+1'.
    Raises ValueError for a non-integral value.
    '''
    if i == 0:
        return ' 0'
    if i != int(i):
        raise ValueError(f'How can {i} be a helicity?')
    return f'{int(i):+d}'
def main():
    '''Command-line driver: recycle the helicities of ``input_file``.

    The first line of ``hel_file`` lists the contributing helicity numbers.
    The result is written to green_matrix.f from template_matrix1.f.
    '''
    parser = argparse.ArgumentParser()
    positional = (
        ('input_file', 'The file containing the original matrix calculation'),
        ('hel_file', 'The file containing the contributing helicities'),
    )
    for arg_name, arg_help in positional:
        parser.add_argument(arg_name, help=arg_help)
    switches = (
        ('--hf-off', 'hel_filt', 'Disable helicity filtering'),
        ('--as-off', 'amp_splt', 'Disable amplitude splitting'),
    )
    for flag, dest, arg_help in switches:
        parser.add_argument(flag, dest=dest, action='store_false',
                            default=True, help=arg_help)
    options = parser.parse_args()

    with open(options.hel_file, 'r') as hel_file:
        good_elements = hel_file.readline().split()

    recycler = HelicityRecycler(good_elements)
    recycler.hel_filt = options.hel_filt
    recycler.amp_splt = options.amp_splt

    recycler.set_input(options.input_file)
    recycler.set_output('green_matrix.f')
    recycler.set_template('template_matrix1.f')

    recycler.generate_output_file()


if __name__ == '__main__':
    main()