1
2
3 import argparse
4 import atexit
5 import re
6 import collections
7 from string import Template
8 from copy import copy
9 from itertools import product
10 from functools import reduce
11
12 try:
13 import madgraph
14 except:
15 import internal.misc as misc
16 else:
17 import madgraph.various.misc as misc
18 import mmap
19 try:
20 from tqdm import tqdm
21 except ImportError:
22 tqdm = misc.tqdm
26 fp = open(file_path, 'r+')
27 buf = mmap.mmap(fp.fileno(),0)
28 lines = 0
29 while buf.readline():
30 lines += 1
31 return lines
32
34
36 self.graph = {}
37 self.all_wavs = []
38 self.external_wavs = []
39 self.internal_wavs = []
40
42 self.all_wavs.append(wav)
43 nature = wav.nature
44 if nature == 'external':
45 self.external_wavs.append(wav)
46 if nature == 'internal':
47 self.internal_wavs.append(wav)
48 for ext in ext_deps:
49 self.add_branch(wav, ext)
50
52 try:
53 self.graph[node_i].append(node_f)
54 except KeyError:
55 self.graph[node_i] = [node_f]
56
58 deps = [wav for wav in self.all_wavs
59 if wav.old_name == old_name and not wav.dead]
60 return deps
61
63 for wav in self.all_wavs:
64 if wav.old_name == old_name:
65 wav.dead = True
66
68 return {wav.old_name for wav in self.all_wavs}
69
71 '''Taken from https://www.python.org/doc/essays/graphs/'''
72
73 path = path + [start]
74 if start == end:
75 return path
76 if start not in self.graph:
77 return None
78 for node in self.graph[start]:
79 if node not in path:
80 newpath = self.find_path(node, end, path)
81 if newpath:
82 return newpath
83 return None
84
87
89 print_str = 'With new names:\n\t'
90 print_str += '\n\t'.join([f'{key} : {item}' for key, item in self.graph.items() ])
91 print_str += '\n\nWith old names:\n\t'
92 print_str += '\n\t'.join([f'{key.old_name} : {[i.old_name for i in item]}' for key, item in self.graph.items() ])
93 return print_str
94
98 '''Abstract class for wavefunctions and Amplitudes'''
99
100
101
102 ext_deps = None
103
def __init__(self, arguments, old_name, nature):
    '''Store the call arguments and initialise the usage bookkeeping.

    arguments -- kept unchanged as ``self.args``
    old_name  -- kept as ``self.old_name``
    nature    -- kept as ``self.nature``

    The new name starts as ``None``, the object is not dead, its usage
    counter is zero and its list of DAG links is empty.
    '''
    # Assigned left to right, i.e. args, then old_name, then nature.
    self.args, self.old_name, self.nature = arguments, old_name, nature

    # Bookkeeping defaults; the list is created per instance.
    self.name = None
    self.dead = False
    self.nb_used = 0
    self.linkdag = []
112
116
119
120 @staticmethod
122 old_args = get_arguments(line)
123 old_name = old_args[-1].replace(' ','')
124 matches = graph.old_names() & set(old_args)
125 try:
126 matches.remove(old_name)
127 except KeyError:
128 pass
129 old_deps = old_args[0:len(matches)]
130
131
132 graph.kill_old(old_name)
133 return [graph.dependencies(dep) for dep in old_deps]
134
@classmethod
def good_helicity(cls, wavs, graph, diag_number=None, all_hel=None, bad_hel_amp=None):
    '''Decide whether the combination of wavefunctions `wavs` is kept.

    The external wavefunctions reachable from any element of `wavs`
    (according to ``graph.find_path``) are collected into
    ``cls.ext_deps``.  The combination is good when that set is a subset
    of at least one entry of ``External.good_wav_combs``.

    When `diag_number` is truthy and the combination is good so far, the
    helicities of the external dependencies are looked up in `all_hel`
    (1-based position); the combination is rejected if the pair
    ``(hel_number, diag_number)`` appears in `bad_hel_amp`.

    wavs        -- iterable of wavefunctions forming one combination
    graph       -- object exposing ``external_wavs`` and ``find_path``
    diag_number -- diagram number, or None to skip the per-helicity veto
    all_hel     -- sequence of helicity tuples (None means empty)
    bad_hel_amp -- container of vetoed (helicity, diagram) pairs
                   (None means empty)

    Returns a falsy value when the combination is rejected, otherwise
    the (possibly empty) set ``cls.ext_deps``.

    Raises ValueError if the helicity tuple is not found in `all_hel`.
    '''
    # None replaces the former mutable [] defaults, so no list object is
    # shared between calls.
    all_hel = [] if all_hel is None else all_hel
    bad_hel_amp = [] if bad_hel_amp is None else bad_hel_amp

    exts = graph.external_wavs
    # Stored on the class: other code reads cls.ext_deps afterwards.
    cls.ext_deps = {ext for dep in wavs for ext in exts
                    if graph.find_path(dep, ext)}

    this_comb_good = any(cls.ext_deps.issubset(set(comb))
                         for comb in External.good_wav_combs)

    if diag_number and this_comb_good and cls.ext_deps:
        helicity = {wav.get_id(): wav.hel for wav in cls.ext_deps}
        # Ids are expected to run contiguously from 1 to len(helicity).
        this_hel = [helicity[i] for i in range(1, len(helicity) + 1)]
        hel_number = 1 + all_hel.index(tuple(this_hel))

        if (hel_number, diag_number) in bad_hel_amp:
            this_comb_good = False

    return this_comb_good and cls.ext_deps
156
157 @staticmethod
159 old_args = get_arguments(line)
160 old_name = old_args[-1].replace(' ','')
161
162 this_args = copy(old_args)
163 wav_names = [w.name for w in wavs]
164 this_args[0:len(wavs)] = wav_names
165
166
167 return this_args
168
169 @staticmethod
172
173 @classmethod
174 - def get_obj(cls, line, wavs, graph, diag_num = None):
184
185
188
191
193 '''Class for storing external wavefunctions'''
194
195 good_hel = []
196 nhel_lines = ''
197 num_externals = 0
198
199 wavs_same_leg = {}
200 good_wav_combs = []
201
def __init__(self, arguments, old_name):
    '''Build an external wavefunction from its call arguments.

    The helicity is the integer value of the third argument.  ``mg`` is
    the integer found in the last comma-separated field of the first
    argument once its final character is dropped (for instance
    ``'P(0,3)'`` gives 3).
    '''
    super().__init__(arguments, old_name, 'external')
    self.hel = int(self.args[2])

    # Last comma-separated field, minus its trailing character.
    trailing_field = arguments[0].rsplit(',', 1)[-1]
    self.mg = int(trailing_field[:-1])

    self.hel_ranges = []
    self.raise_num()
208
209 @classmethod
212
213 @classmethod
215
216
217 old_args = get_arguments(line)
218 old_name = old_args[-1].replace(' ','')
219 graph.kill_old(old_name)
220
221 if 'NHEL' in old_args[2].upper():
222 nhel_index = re.search(r'\(.*?\)', old_args[2]).group()
223 ext_num = int(nhel_index[1:-1]) - 1
224 new_hels = sorted(list(External.hel_ranges[ext_num]), reverse=True)
225 new_hels = [int_to_string(i) for i in new_hels]
226 else:
227
228 ext_num = int(re.search(r'\(0,(\d+)\)', old_args[0]).group(1)) -1
229 new_hels = [' 0']
230
231 new_wavfuncs = []
232 for hel in new_hels:
233
234 this_args = copy(old_args)
235 this_args[2] = hel
236
237 this_wavfunc = External(this_args, old_name)
238 this_wavfunc.set_name(len(graph.external_wavs) + len(graph.internal_wavs) +1)
239
240 graph.store_wav(this_wavfunc)
241 new_wavfuncs.append(this_wavfunc)
242 if ext_num in cls.wavs_same_leg:
243 cls.wavs_same_leg[ext_num] += new_wavfuncs
244 else:
245 cls.wavs_same_leg[ext_num] = new_wavfuncs
246
247 return new_wavfuncs
248
249 @classmethod
251 num_combs = len(cls.good_hel)
252 gwc_old = [[] for x in range(num_combs)]
253 gwc=[]
254 for n, comb in enumerate(cls.good_hel):
255 sols = [[]]
256 for leg, wavs in cls.wavs_same_leg.items():
257 valid = []
258 for wav in wavs:
259 if comb[leg] == wav.hel:
260 valid.append(wav)
261 gwc_old[n].append(wav)
262 if len(valid) == 1:
263 for sol in sols:
264 sol.append(valid[0])
265 else:
266 tmp = []
267 for w in valid:
268 for sol in sols:
269 tmp2 = list(sol)
270 tmp2.append(w)
271 tmp.append(tmp2)
272 sols = tmp
273 gwc += sols
274
275 cls.good_wav_combs = gwc
276
277 @staticmethod
280
282 """ return the id of the particle under consideration """
283
284 try:
285 return self.id
286 except:
287 self.id = int(re.findall(r'P\(0,(\d+)\)', self.args[0])[0])
288 return self.id
289
293 '''Class for storing internal wavefunctions'''
294
295 max_wav_num = 0
296 num_internals = 0
297
298 @classmethod
301
302 @classmethod
304 deps = cls.get_deps(line, graph)
305 new_wavfuncs = [ cls.get_obj(line, wavs, graph)
306 for wavs in product(*deps)
307 if cls.good_helicity(wavs, graph) ]
308
309 return new_wavfuncs
310
311
312
313 @classmethod
316
317 @classmethod
323
324 - def __init__(self, arguments, old_name):
327
328
329 @staticmethod
332
334 '''Class for storing Amplitudes'''
335
336 max_amp_num = 0
337
def __init__(self, arguments, old_name, diag_num):
    '''Create an amplitude attached to diagram number `diag_num`.

    The diagram number is stored first; the remaining set-up is
    delegated to the parent initialiser with nature 'amplitude'.
    '''
    self.diag_num = diag_num
    nature = 'amplitude'
    super().__init__(arguments, old_name, nature)
341
342
343 @staticmethod
346
@classmethod
def generate_amps(cls, line, graph, all_hel=None, all_bad_hel=None):
    '''Create the amplitude objects generated by one amplitude call.

    The last argument of `line` is the old amplitude name; the integer
    between its first pair of parentheses is the diagram number.  One
    object is built (via ``cls.get_obj``) for every combination of
    dependencies returned by ``cls.get_deps`` that passes
    ``cls.good_helicity``.

    line        -- source line holding the amplitude call
    graph       -- dependency graph handed to get_deps/good_helicity
    all_hel     -- helicity combinations, forwarded unchanged
    all_bad_hel -- vetoed (helicity, diagram) pairs (None means empty)

    Returns the list of new amplitude objects.
    '''
    # None replaces the former mutable [] default; normalise here so
    # good_helicity always receives a real container.
    if all_bad_hel is None:
        all_bad_hel = []

    old_name = get_arguments(line)[-1].replace(' ', '')

    amp_index = re.search(r'\(.*?\)', old_name).group()
    diag_num = int(amp_index[1:-1])  # strip the surrounding parentheses

    deps = cls.get_deps(line, graph)

    return [cls.get_obj(line, wavs, graph, diag_num)
            for wavs in product(*deps)
            if cls.good_helicity(wavs, graph, diag_num, all_hel, all_bad_hel)]
362
363 @classmethod
365 return Amplitude(new_args, old_name, diag_num)
366
367 @classmethod
369 wavs, graph = args
370 amp_num = -1
371 exts = graph.external_wavs
372 hel_amp = tuple([w.hel for w in sorted(cls.ext_deps, key=lambda x: x.mg)])
373 amp_num = External.map_hel[hel_amp] +1
374
375 if cls.max_amp_num < amp_num:
376 cls.max_amp_num = amp_num
377 return amp_num
378
380 '''Class for recycling helicity'''
381
def __init__(self, good_elements, bad_amps=None, bad_amps_perhel=None):
    '''Set up a recycler with default file names and empty templates.

    good_elements   -- stored as ``self.good_elements``
    bad_amps        -- stored as ``self.bad_amps`` (None means a new
                       empty list)
    bad_amps_perhel -- stored as ``self.bad_amps_perhel`` (None means a
                       new empty list)

    The class-level state of External, Internal and Amplitude is reset
    as well, so values left by an earlier recycler are discarded.
    '''
    # Reset class-level state.
    External.good_hel = []
    External.nhel_lines = ''
    External.num_externals = 0
    External.wavs_same_leg = {}
    External.good_wav_combs = []

    Internal.max_wav_num = 0
    Internal.num_internals = 0

    Amplitude.max_amp_num = 0

    self.last_category = None
    self.good_elements = good_elements
    # The lists are stored on the instance, so a shared mutable default
    # would leak between recyclers; build a fresh list per instance.
    self.bad_amps = [] if bad_amps is None else bad_amps
    self.bad_amps_perhel = [] if bad_amps_perhel is None else bad_amps_perhel

    # Default file names.
    self.input_file = 'matrix_orig.f'
    self.output_file = 'matrix_orig.f'
    self.template_file = 'template_matrix.f'

    # Values substituted into the template when writing the output.
    self.template_dict = {
        'helicity_lines': '\n',
        'helas_calls': [],
        'jamp_lines': '\n',
        'amp2_lines': '\n',
        'ncomb': '0',
        'nwavefuncs': '0',
    }

    self.dag = DAG()

    self.diag_num = 1
    self.got_gwc = False

    # Procedure name is the upper-cased input file name up to its first dot.
    self.procedure_name = self.input_file.split('.')[0].upper()
    self.procedure_kind = 'FUNCTION'

    self.old_out_name = ''
    self.loop_var = 'K'

    self.all_hel = []
    self.hel_filt = True
426
435
437 self.output_file = file
438
441
443
444 if not 'CALL' in line:
445 return None
446
447
448 ext_calls = ['CALL OXXXXX', 'CALL IXXXXX', 'CALL VXXXXX', 'CALL SXXXXX']
449 if any( call in line for call in ext_calls ):
450 return 'external'
451
452
453
454
455 if not self.dag.external_wavs:
456 return None
457
458
459
460 matches = self.dag.old_names() & set(get_arguments(line))
461 try:
462 matches.remove(get_arguments(line)[-1])
463 except KeyError:
464 pass
465 try:
466 function = (line.split('(', 1)[0]).split()[-1]
467 except IndexError:
468 return None
469
470
471 if (function.split('_')[-1] != '0'):
472 return 'internal'
473 elif (function.split('_')[-1] == '0'):
474 return 'amplitude'
475 else:
476 print(f'Ahhhh what is going on here?\n{line}')
477 set_trace()
478
479 return None
480
481
482
484 old_pat = matchobj.group()
485 new_pat = old_pat.replace('AMP(', 'AMP( %s,' % self.loop_var)
486
487
488 return new_pat
489
491 '''Add loop_var index to amp and output variable.
492 Also update name of output variable.'''
493
494 new_line = re.sub(r'\WAMP\(.*?\)', self.add_amp_index, line)
495 new_line = re.sub(r'MATRIX\d+', 'TS(K)', new_line)
496 return new_line
497
499
500
501
502
503 return 'init_mode' in line.lower()
504
505
506
507
509 if f'{self.procedure_kind} {self.procedure_name}' in line:
510 if 'SUBROUTINE' == self.procedure_kind:
511 self.old_out_name = get_arguments(line)[-1]
512 if 'FUNCTION' == self.procedure_kind:
513 self.old_out_name = line.split('(')[0].split()[-1]
514
516
517 if 'diagram number' in line:
518 self.amp_calc_started = True
519
520 if ('AMP' not in get_arguments(line)[-1]
521 and self.amp_calc_started and list(line)[0] != 'C'):
522
523 if self.function_call(line) not in ['external',
524 'internal',
525 'amplitude']:
526 self.jamp_started = True
527 self.amp_calc_started = False
528 if self.jamp_started:
529 self.get_jamp_lines(line)
530 if self.in_amp2:
531 self.get_amp2_lines(line)
532 if self.find_amp2 and line.startswith(' ENDDO'):
533 self.in_amp2 = True
534 self.find_amp2 = False
535
537 if self.jamp_finished(line):
538 self.jamp_started = False
539 self.find_amp2 = True
540 elif not line.isspace():
541 self.template_dict['jamp_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'
542
544 if line.startswith(' DO I = 1, NCOLOR'):
545 self.in_amp2 = False
546 elif not line.isspace():
547 self.template_dict['amp2_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'
548
550 self.amp_calc_started = False
551 self.jamp_started = False
552 self.find_amp2 = False
553 self.in_amp2 = False
554 self.nhel_started = False
555
557
558
559
560
561 if nature not in ['external', 'internal', 'amplitude']:
562 raise Exception('wrong unfolding')
563
564 if nature == 'external':
565 new_objs = External.generate_wavfuncs(line, self.dag)
566 for obj in new_objs:
567 obj.line = apply_args(line, [obj.args])
568 else:
569 deps = Amplitude.get_deps(line, self.dag)
570 name2dep = dict([(d.name,d) for d in sum(deps,[])])
571
572
573 if nature == 'internal':
574 new_objs = Internal.generate_wavfuncs(line, self.dag)
575 for obj in new_objs:
576 obj.line = apply_args(line, [obj.args])
577 obj.linkdag = []
578 for name in obj.args:
579 if name in name2dep:
580 name2dep[name].nb_used +=1
581 obj.linkdag.append(name2dep[name])
582
583 if nature == 'amplitude':
584 nb_diag = re.findall(r'AMP\((\d+)\)', line)[0]
585 if nb_diag not in self.bad_amps:
586 new_objs = Amplitude.generate_amps(line, self.dag, self.all_hel, self.bad_amps_perhel)
587 out_line = self.apply_amps(line, new_objs)
588 for i,obj in enumerate(new_objs):
589 if i == 0:
590 obj.line = out_line
591 obj.nb_used = 1
592 else:
593 obj.line = ''
594 obj.nb_used = 1
595 obj.linkdag = []
596 for name in obj.args:
597 if name in name2dep:
598 name2dep[name].nb_used +=1
599 obj.linkdag.append(name2dep[name])
600 else:
601 return ''
602
603
604 return new_objs
605
606
613
def get_gwc(self, line, category):
    '''Track the category of the latest call and refresh the combinations.

    Categories other than 'external', 'internal' or 'amplitude' are
    ignored.  If the previously recorded category was 'external',
    ``External.get_gwc()`` is called before the new category is
    recorded.  `line` is accepted but not used.
    '''
    if category not in ('external', 'internal', 'amplitude'):
        return

    if self.last_category == 'external':
        External.get_gwc()
    self.last_category = category
625
652
654 self.counter += 1
655 formatted_hel = [f'{hel}' if hel < 0 else f' {hel}' for hel in hel_comb]
656 nexternal = len(hel_comb)
657 return (f' DATA (NHEL(I,{self.counter}),I=1,{nexternal}) /{",".join(formatted_hel)}/')
658
660
661 with open(self.input_file, 'r') as input_file:
662
663 self.prepare_bools()
664
665 for line_num, line in tqdm(enumerate(input_file), total=get_num_lines(self.input_file)):
666 if line_num == 0:
667 line_cache = line
668 continue
669
670 if '!SKIP' in line:
671 continue
672
673 char_5 = ''
674 try:
675 char_5 = line[5]
676 except IndexError:
677 pass
678 if char_5 == '$':
679 line_cache = undo_multiline(line_cache, line)
680 continue
681
682 line, line_cache = line_cache, line
683
684 self.get_old_name(line)
685 self.get_good_hel(line)
686 self.get_amp_stuff(line_num, line)
687 call_type = self.function_call(line)
688 self.get_gwc(line, call_type)
689
690
691
692 if call_type in ['external', 'internal', 'amplitude']:
693 self.template_dict['helas_calls'] += self.unfold_helicities(
694 line, call_type)
695
696 self.template_dict['nwavefuncs'] = max(External.num_externals, Internal.max_wav_num)
697
698 for i in range(len(self.template_dict['helas_calls'])-1,-1,-1):
699 obj = self.template_dict['helas_calls'][i]
700 if obj.nb_used == 0:
701 obj.line = ''
702 for dep in obj.linkdag:
703 dep.nb_used -= 1
704
705
706
707 self.template_dict['helas_calls'] = '\n'.join([f'{obj.line.rstrip()} ! count {obj.nb_used}'
708 for obj in self.template_dict['helas_calls']
709 if obj.nb_used > 0 and obj.line])
710
712 out_file = open(self.output_file, 'w+')
713 with open(self.template_file, 'r') as file:
714 for line in file:
715 s = Template(line)
716 line = s.safe_substitute(self.template_dict)
717 line = '\n'.join([do_multiline(sub_lines) for sub_lines in line.split('\n')])
718 out_file.write(line)
719 out_file.close()
720
722 out_file = open(self.output_file, 'w+')
723 self.template_dict['ncomb'] = '0'
724 self.template_dict['nwavefuncs'] = '0'
725 self.template_dict['helas_calls'] = ''
726 with open(self.template_file, 'r') as file:
727 for line in file:
728 s = Template(line)
729 line = s.safe_substitute(self.template_dict)
730 line = '\n'.join([do_multiline(sub_lines) for sub_lines in line.split('\n')])
731 out_file.write(line)
732 out_file.close()
733
734
745
748
751 '''Find the substrings separated by commas between the first
752 closed set of parentheses in 'line'.
753 '''
754 bracket_depth = 0
755 element = 0
756 arguments = ['']
757 for char in line:
758 if char == '(':
759 bracket_depth += 1
760 if bracket_depth - 1 == 0:
761
762
763 continue
764 if char == ')':
765 bracket_depth -= 1
766 if bracket_depth == 0:
767
768 break
769 if char == ',' and bracket_depth == 1:
770 element += 1
771 arguments.append('')
772 continue
773 if bracket_depth > 0:
774 arguments[element] += char
775 return arguments
776
779 function = (old_line.split('(')[0]).split()[-1]
780 old_args = old_line.split(function)[-1]
781 new_lines = [old_line.replace(old_args, f'({",".join(x)})\n')
782 for x in all_the_args]
783
784 return ''.join(new_lines)
785
787 if not new_amps:
788 return ''
789 fct = line.split('(',1)[0].split('_0')[0]
790 for i,amp in enumerate(new_amps):
791 if i == 0:
792 occur = []
793 for a in amp.args:
794 if "W(1," in a:
795 tmp = collections.defaultdict(int)
796 tmp[a] += 1
797 occur.append(tmp)
798 else:
799 for i in range(len(occur)):
800 a = amp.args[i]
801 occur[i][a] +=1
802
803
804 nb_wav = [len(o) for o in occur]
805 to_remove = nb_wav.index(max(nb_wav))
806
807 occur.pop(to_remove)
808
809 lines = []
810
811 wav_name = [o.keys() for o in occur]
812 for wfcts in product(*wav_name):
813
814 sub_amps = [amp for amp in new_amps
815 if all(w in amp.args for w in wfcts)]
816 if not sub_amps:
817 continue
818 if len(sub_amps) ==1:
819 lines.append(apply_args(line, [i.args for i in sub_amps]).replace('\n',''))
820
821 continue
822
823
824 sub_amps.sort(key=lambda a: int(a.args[-1][:-1].split(',',1)[1]))
825 windices = []
826 hel_calculated = []
827 iamp = 0
828 for i,amp in enumerate(sub_amps):
829 args = amp.args[:]
830
831 wcontract = args.pop(to_remove)
832 windex = wcontract.split(',')[1].split(')')[0]
833 windices.append(windex)
834 amp_result, args[-1] = args[-1], 'TMP(1)'
835
836 if i ==0:
837
838
839 spin = fct.split(None,1)[1][to_remove]
840 lines.append('%sP1N_%s(%s)' % (fct, to_remove+1, ', '.join(args)))
841
842 hel, iamp = re.findall('AMP\((\d+),(\d+)\)', amp_result)[0]
843 hel_calculated.append(hel)
844
845
846
847
848 if spin in "VF":
849 lines.append(""" call CombineAmp(%(nb)i,
850 & (/%(hel_list)s/),
851 & (/%(w_list)s/),
852 & TMP, W, AMP(1,%(iamp)s))""" %
853 {'nb': len(sub_amps),
854 'hel_list': ','.join(hel_calculated),
855 'w_list': ','.join(windices),
856 'iamp': iamp
857 })
858 elif spin == "S":
859 lines.append(""" call CombineAmpS(%(nb)i,
860 &(/%(hel_list)s/),
861 & (/%(w_list)s/),
862 & TMP, W, AMP(1,%(iamp)s))""" %
863 {'nb': len(sub_amps),
864 'hel_list': ','.join(hel_calculated),
865 'w_list': ','.join(windices),
866 'iamp': iamp
867 })
868 else:
869 raise Exception("split amp are not supported for spin2 and 3/2")
870
871
872 return '\n'.join(lines)
873
875 name = wav.name
876 between_brackets = re.search(r'\(.*?\)', name).group()
877 num = int(between_brackets[1:-1].split(',')[-1])
878 return num
879
881 new_line = new_line[6:]
882 old_line = old_line.replace('\n','')
883 return f'{old_line}{new_line}'
884
886 char_limit = 72
887 num_splits = len(line)//char_limit
888 if num_splits != 0 and len(line) != 72 and '!' not in line:
889 split_line = [line[i*char_limit:char_limit*(i+1)] for i in range(num_splits+1)]
890 indent = ''
891 for char in line[6:]:
892 if char == ' ':
893 indent += char
894 else:
895 break
896
897 line = f'\n ${indent}'.join(split_line)
898 return line
899
901 if i == 1:
902 return '+1'
903 if i == 0:
904 return ' 0'
905 if i == -1:
906 return '-1'
907 else:
908 print(f'How can {i} be a helicity?')
909 set_trace()
910 exit(1)
911
913 parser = argparse.ArgumentParser()
914 parser.add_argument('input_file', help='The file containing the '
915 'original matrix calculation')
916 parser.add_argument('hel_file', help='The file containing the '
917 'contributing helicities')
918 parser.add_argument('--hf-off', dest='hel_filt', action='store_false', default=True, help='Disable helicity filtering')
919 parser.add_argument('--as-off', dest='amp_splt', action='store_false', default=True, help='Disable amplitude splitting')
920
921 args = parser.parse_args()
922
923 with open(args.hel_file, 'r') as file:
924 good_elements = file.readline().split()
925
926 recycler = HelicityRecycler(good_elements)
927
928 recycler.hel_filt = args.hel_filt
929 recycler.amp_splt = args.amp_splt
930
931 recycler.set_input(args.input_file)
932 recycler.set_output('green_matrix.f')
933 recycler.set_template('template_matrix1.f')
934
935 recycler.generate_output_file()
936
937 if __name__ == '__main__':
938 main()
939