1
2
3 import argparse
4 import atexit
5 import re
6 import collections
7 from string import Template
8 from copy import copy
9 from itertools import product
10 from functools import reduce
11
12 try:
13 import madgraph
14 except:
15 import internal.misc as misc
16 else:
17 import madgraph.various.misc as misc
18 import mmap
19 try:
20 from tqdm import tqdm
21 except ImportError:
22 tqdm = misc.tqdm
26 fp = open(file_path, 'r+')
27 buf = mmap.mmap(fp.fileno(),0)
28 lines = 0
29 while buf.readline():
30 lines += 1
31 return lines
32
34
36 self.graph = {}
37 self.all_wavs = []
38 self.external_wavs = []
39 self.internal_wavs = []
40
42 self.all_wavs.append(wav)
43 nature = wav.nature
44 if nature == 'external':
45 self.external_wavs.append(wav)
46 if nature == 'internal':
47 self.internal_wavs.append(wav)
48 for ext in ext_deps:
49 self.add_branch(wav, ext)
50
52 try:
53 self.graph[node_i].append(node_f)
54 except KeyError:
55 self.graph[node_i] = [node_f]
56
58 deps = [wav for wav in self.all_wavs
59 if wav.old_name == old_name and not wav.dead]
60 return deps
61
63 for wav in self.all_wavs:
64 if wav.old_name == old_name:
65 wav.dead = True
66
68 return {wav.old_name for wav in self.all_wavs}
69
71 '''Taken from https://www.python.org/doc/essays/graphs/'''
72
73 path = path + [start]
74 if start == end:
75 return path
76 if start not in self.graph:
77 return None
78 for node in self.graph[start]:
79 if node not in path:
80 newpath = self.find_path(node, end, path)
81 if newpath:
82 return newpath
83 return None
84
87
89 print_str = 'With new names:\n\t'
90 print_str += '\n\t'.join([f'{key} : {item}' for key, item in self.graph.items() ])
91 print_str += '\n\nWith old names:\n\t'
92 print_str += '\n\t'.join([f'{key.old_name} : {[i.old_name for i in item]}' for key, item in self.graph.items() ])
93 return print_str
94
98 '''Abstract class for wavefunctions and Amplitudes'''
99
100
101
102 ext_deps = None
103
def __init__(self, arguments, old_name, nature):
    '''Set up the bookkeeping shared by wavefunctions and amplitudes.

    arguments -- argument list of the call this object stands for
    old_name  -- name of the output variable in the original file
    nature    -- category label; this file passes 'external' and
                 'amplitude' from the subclasses visible here
    '''
    self.args = arguments
    self.old_name = old_name
    self.nature = nature
    # No new name yet; External.generate_wavfuncs assigns one via set_name.
    self.name = None
    # Flipped to True by the graph when an old name is overwritten.
    self.dead = False
    # Reference count and the objects this one depends on; both are
    # filled in while helicities are unfolded.
    self.nb_used = 0
    self.linkdag = []
112
116
119
120 @staticmethod
122 old_args = get_arguments(line)
123 old_name = old_args[-1].replace(' ','')
124 matches = graph.old_names() & set(old_args)
125 try:
126 matches.remove(old_name)
127 except KeyError:
128 pass
129 old_deps = old_args[0:len(matches)]
130
131
132 graph.kill_old(old_name)
133 return [graph.dependencies(dep) for dep in old_deps]
134
135 @classmethod
def good_helicity(cls, wavs, graph, diag_number=None, all_hel=None, bad_hel_amp=None):
    '''Tell whether the combination of dependencies `wavs` should be kept.

    Every external wavefunction of `graph` that can be reached from one
    of `wavs` (according to graph.find_path) is collected in
    cls.ext_deps.  The combination is good when that set is contained
    in at least one entry of External.good_wav_combs.

    When `diag_number` is given and `bad_hel_amp` is not empty, the
    helicities of the external dependencies are looked up in `all_hel`
    (1-based position) and the combination is rejected if the pair
    (helicity number, diag_number) appears in `bad_hel_amp`.

    Returns False for a bad combination, otherwise cls.ext_deps itself,
    which is falsy when it is empty.
    Raises ValueError if the helicity configuration is not in `all_hel`
    while a bad-amplitude lookup is required.
    '''
    # None instead of [] as default: no list object shared between calls,
    # and generate_amps may hand over all_hel=None.
    all_hel = [] if all_hel is None else all_hel
    bad_hel_amp = [] if bad_hel_amp is None else bad_hel_amp

    externals = graph.external_wavs
    # Stored on the class, as before, so it stays readable after the call.
    cls.ext_deps = {ext for dep in wavs for ext in externals
                    if graph.find_path(dep, ext)}

    this_comb_good = any(cls.ext_deps.issubset(comb)
                         for comb in External.good_wav_combs)

    # With no bad (helicity, diagram) pair on record nothing can be
    # vetoed, so the helicity lookup is skipped altogether.
    if diag_number and this_comb_good and cls.ext_deps and bad_hel_amp:
        helicity = {wav.get_id(): wav.hel for wav in cls.ext_deps}
        # Particle ids are expected to run from 1 to len(helicity).
        this_hel = tuple(helicity[i] for i in range(1, len(helicity) + 1))
        hel_number = 1 + all_hel.index(this_hel)
        if (hel_number, diag_number) in bad_hel_amp:
            this_comb_good = False

    return this_comb_good and cls.ext_deps
156
157 @staticmethod
159 old_args = get_arguments(line)
160 old_name = old_args[-1].replace(' ','')
161
162 this_args = copy(old_args)
163 wav_names = [w.name for w in wavs]
164 this_args[0:len(wavs)] = wav_names
165
166
167 return this_args
168
169 @staticmethod
172
173 @classmethod
174 - def get_obj(cls, line, wavs, graph, diag_num = None):
184
185
188
191
193 '''Class for storing external wavefunctions'''
194
195 good_hel = []
196 nhel_lines = ''
197 num_externals = 0
198
199 wavs_same_leg = {}
200 good_wav_combs = []
201
def __init__(self, arguments, old_name):
    '''Create an external wavefunction from the arguments of its call.

    arguments -- argument list of the call; the third entry must be
                 convertible to int and is stored as the helicity
    old_name  -- name of the output variable in the original file
    '''
    super().__init__(arguments, old_name, 'external')
    # generate_wavfuncs writes the helicity string into args[2]
    # before building the object.
    self.hel = int(self.args[2])
    self.hel_ranges = []
    # raise_num is defined elsewhere in the class; presumably it bumps
    # the count of externals -- TODO confirm.
    self.raise_num()
207
208 @classmethod
211
212 @classmethod
214
215
216 old_args = get_arguments(line)
217 old_name = old_args[-1].replace(' ','')
218 graph.kill_old(old_name)
219
220 if 'NHEL' in old_args[2].upper():
221 nhel_index = re.search(r'\(.*?\)', old_args[2]).group()
222 ext_num = int(nhel_index[1:-1]) - 1
223 new_hels = sorted(list(External.hel_ranges[ext_num]), reverse=True)
224 new_hels = [int_to_string(i) for i in new_hels]
225 else:
226
227 ext_num = int(re.search(r'\(0,(\d+)\)', old_args[0]).group(1)) -1
228 new_hels = [' 0']
229
230 new_wavfuncs = []
231 for hel in new_hels:
232
233 this_args = copy(old_args)
234 this_args[2] = hel
235
236 this_wavfunc = External(this_args, old_name)
237 this_wavfunc.set_name(len(graph.external_wavs) + len(graph.internal_wavs) +1)
238
239 graph.store_wav(this_wavfunc)
240 new_wavfuncs.append(this_wavfunc)
241 if ext_num in cls.wavs_same_leg:
242 cls.wavs_same_leg[ext_num] += new_wavfuncs
243 else:
244 cls.wavs_same_leg[ext_num] = new_wavfuncs
245
246 return new_wavfuncs
247
248 @classmethod
250 num_combs = len(cls.good_hel)
251 gwc_old = [[] for x in range(num_combs)]
252 gwc=[]
253 for n, comb in enumerate(cls.good_hel):
254 sols = [[]]
255 for leg, wavs in cls.wavs_same_leg.items():
256 valid = []
257
258 for wav in wavs:
259 if comb[leg] == wav.hel:
260 valid.append(wav)
261 gwc_old[n].append(wav)
262 if len(valid) == 1:
263 for sol in sols:
264 sol.append(valid[0])
265 else:
266 tmp = []
267 for w in valid:
268 for sol in sols:
269 tmp2 = list(sol)
270 tmp2.append(w)
271 tmp.append(tmp2)
272 sols = tmp
273 gwc += sols
274
275 cls.good_wav_combs = gwc
276
277 @staticmethod
280
282 """ return the id of the particle under consideration """
283
284 try:
285 return self.id
286 except:
287 self.id = int(re.findall(r'P\(0,(\d+)\)', self.args[0])[0])
288 return self.id
289
293 '''Class for storing internal wavefunctions'''
294
295 max_wav_num = 0
296 num_internals = 0
297
298 @classmethod
301
302 @classmethod
304 deps = cls.get_deps(line, graph)
305 new_wavfuncs = [ cls.get_obj(line, wavs, graph)
306 for wavs in product(*deps)
307 if cls.good_helicity(wavs, graph) ]
308
309 return new_wavfuncs
310
311
312
313 @classmethod
316
317 @classmethod
323
324 - def __init__(self, arguments, old_name):
327
328
329 @staticmethod
332
334 '''Class for storing Amplitudes'''
335
336 max_amp_num = 0
337
def __init__(self, arguments, old_name, diag_num):
    '''Create an amplitude object.

    arguments -- argument list of the call computing the amplitude
    old_name  -- name of the output variable in the original file
    diag_num  -- number of the diagram the amplitude belongs to
    '''
    self.diag_num = diag_num
    super().__init__(arguments, old_name, 'amplitude')
341
342
343 @staticmethod
346
347 @classmethod
def generate_amps(cls, line, graph, all_hel=None, all_bad_hel=None):
    '''Build one amplitude object per acceptable set of dependencies.

    line        -- the amplitude call of the original file
    graph       -- dependency graph holding the wavefunctions seen so far
    all_hel     -- helicity configurations, forwarded to good_helicity
    all_bad_hel -- (helicity number, diagram number) pairs to discard,
                   forwarded to good_helicity

    The diagram number is read from the parenthesised index of the last
    argument of the call.  Every element of the cartesian product of the
    dependencies returned by get_deps is tried, and those accepted by
    good_helicity are turned into objects with get_obj.

    Raises ValueError when the output variable carries no index.
    '''
    # Fresh lists per call instead of a shared mutable default; this also
    # keeps None from reaching good_helicity.
    all_hel = [] if all_hel is None else all_hel
    all_bad_hel = [] if all_bad_hel is None else all_bad_hel

    old_name = get_arguments(line)[-1].replace(' ', '')

    index_match = re.search(r'\(.*?\)', old_name)
    if index_match is None:
        raise ValueError(f'no amplitude index found in {old_name!r}')
    # Drop the surrounding parentheses before converting.
    diag_num = int(index_match.group()[1:-1])

    deps = cls.get_deps(line, graph)

    new_amps = []
    for wavs in product(*deps):
        # good_helicity has to run first: it stores the external
        # dependencies on the class before get_obj is called.
        if cls.good_helicity(wavs, graph, diag_num, all_hel, all_bad_hel):
            new_amps.append(cls.get_obj(line, wavs, graph, diag_num))
    return new_amps
362
363 @classmethod
365 return Amplitude(new_args, old_name, diag_num)
366
367 @classmethod
382
384 '''Class for recycling helicity'''
385
def __init__(self, good_elements, bad_amps=None, bad_amps_perhel=None):
    '''Reset the class-level state and prepare a fresh recycler.

    good_elements   -- stored as given; main() passes the whitespace
                       separated tokens of the first line of the
                       helicity file
    bad_amps        -- diagram numbers whose amplitude call is dropped
                       (compared against strings in unfold_helicities)
    bad_amps_perhel -- (helicity number, diagram number) pairs handed
                       to Amplitude.generate_amps
    '''
    # These attributes live on the classes, so they have to be cleared
    # explicitly or a previous recycler would leak into this one.
    External.good_hel = []
    External.nhel_lines = ''
    External.num_externals = 0
    External.wavs_same_leg = {}
    External.good_wav_combs = []

    Internal.max_wav_num = 0
    Internal.num_internals = 0

    Amplitude.max_amp_num = 0

    self.last_category = None
    self.good_elements = good_elements
    # A fresh list per instance: the former `=[]` defaults were one
    # object shared by every recycler built without these arguments.
    self.bad_amps = [] if bad_amps is None else bad_amps
    self.bad_amps_perhel = [] if bad_amps_perhel is None else bad_amps_perhel

    # NOTE(review): input and output default to the same file name;
    # main() overrides both, confirm other callers do as well.
    self.input_file = 'matrix_orig.f'
    self.output_file = 'matrix_orig.f'
    self.template_file = 'template_matrix.f'

    # Values substituted into the template when the output is written.
    self.template_dict = {
        'helicity_lines': '\n',
        'helas_calls': [],
        'jamp_lines': '\n',
        'amp2_lines': '\n',
        'ncomb': '0',
        'nwavefuncs': '0',
    }

    self.dag = DAG()

    self.diag_num = 1
    self.got_gwc = False

    # Upper-cased file name without its extension.
    self.procedure_name = self.input_file.split('.')[0].upper()
    self.procedure_kind = 'FUNCTION'

    self.old_out_name = ''
    self.loop_var = 'K'

    self.all_hel = []
    self.hel_filt = True
430
439
441 self.output_file = file
442
445
447
448 if not 'CALL' in line:
449 return None
450
451
452 ext_calls = ['CALL OXXXXX', 'CALL IXXXXX', 'CALL VXXXXX', 'CALL SXXXXX']
453 if any( call in line for call in ext_calls ):
454 return 'external'
455
456
457
458
459 if not self.dag.external_wavs:
460 return None
461
462
463
464 matches = self.dag.old_names() & set(get_arguments(line))
465 try:
466 matches.remove(get_arguments(line)[-1])
467 except KeyError:
468 pass
469 try:
470 function = (line.split('(', 1)[0]).split()[-1]
471 except IndexError:
472 return None
473
474
475 if (function.split('_')[-1] != '0'):
476 return 'internal'
477 elif (function.split('_')[-1] == '0'):
478 return 'amplitude'
479 else:
480 print(f'Ahhhh what is going on here?\n{line}')
481 set_trace()
482
483 return None
484
485
486
488 old_pat = matchobj.group()
489 new_pat = old_pat.replace('AMP(', 'AMP( %s,' % self.loop_var)
490
491
492
493 return new_pat
494
496 '''Add loop_var index to amp and output variable.
497 Also update name of output variable.'''
498
499 new_line = re.sub(r'\WAMP\(.*?\)', self.add_amp_index, line)
500 new_line = re.sub(r'MATRIX\d+', 'TS(K)', new_line)
501 return new_line
502
504
505
506
507
508 return 'init_mode' in line.lower()
509
510
511
512
514 if f'{self.procedure_kind} {self.procedure_name}' in line:
515 if 'SUBROUTINE' == self.procedure_kind:
516 self.old_out_name = get_arguments(line)[-1]
517 if 'FUNCTION' == self.procedure_kind:
518 self.old_out_name = line.split('(')[0].split()[-1]
519
521
522 if 'diagram number' in line:
523 self.amp_calc_started = True
524
525 if ('AMP' not in get_arguments(line)[-1]
526 and self.amp_calc_started and list(line)[0] != 'C'):
527
528 if self.function_call(line) not in ['external',
529 'internal',
530 'amplitude']:
531 self.jamp_started = True
532 self.amp_calc_started = False
533 if self.jamp_started:
534 self.get_jamp_lines(line)
535 if self.in_amp2:
536 self.get_amp2_lines(line)
537 if self.find_amp2 and line.startswith(' ENDDO'):
538 self.in_amp2 = True
539 self.find_amp2 = False
540
542 if self.jamp_finished(line):
543 self.jamp_started = False
544 self.find_amp2 = True
545 elif not line.isspace():
546 self.template_dict['jamp_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'
547
549 if line.startswith(' DO I = 1, NCOLOR'):
550 self.in_amp2 = False
551 elif not line.isspace():
552 self.template_dict['amp2_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'
553
555 self.amp_calc_started = False
556 self.jamp_started = False
557 self.find_amp2 = False
558 self.in_amp2 = False
559 self.nhel_started = False
560
562
563
564
565
566 if nature not in ['external', 'internal', 'amplitude']:
567 raise Exception('wrong unfolding')
568
569 if nature == 'external':
570 new_objs = External.generate_wavfuncs(line, self.dag)
571 for obj in new_objs:
572 obj.line = apply_args(line, [obj.args])
573 else:
574 deps = Amplitude.get_deps(line, self.dag)
575 name2dep = dict([(d.name,d) for d in sum(deps,[])])
576
577
578 if nature == 'internal':
579 new_objs = Internal.generate_wavfuncs(line, self.dag)
580 for obj in new_objs:
581 obj.line = apply_args(line, [obj.args])
582 obj.linkdag = []
583 for name in obj.args:
584 if name in name2dep:
585 name2dep[name].nb_used +=1
586 obj.linkdag.append(name2dep[name])
587
588 if nature == 'amplitude':
589 nb_diag = re.findall(r'AMP\((\d+)\)', line)[0]
590 if nb_diag not in self.bad_amps:
591 new_objs = Amplitude.generate_amps(line, self.dag, self.all_hel, self.bad_amps_perhel)
592 out_line = self.apply_amps(line, new_objs)
593 for i,obj in enumerate(new_objs):
594 if i == 0:
595 obj.line = out_line
596 obj.nb_used = 1
597 else:
598 obj.line = ''
599 obj.nb_used = 1
600 obj.linkdag = []
601 for name in obj.args:
602 if name in name2dep:
603 name2dep[name].nb_used +=1
604 obj.linkdag.append(name2dep[name])
605 else:
606 return ''
607
608
609 return new_objs
610
611
618
def get_gwc(self, line, category):
    '''Refresh the good wavefunction combinations after an external call.

    line     -- the line being processed (not used here)
    category -- result of function_call for that line

    Lines that are not an 'external', 'internal' or 'amplitude' call
    are ignored and leave self.last_category untouched.  For the others,
    External.get_gwc() is run whenever the previously recorded category
    was 'external', and the current category is recorded afterwards.
    '''
    if category not in ('external', 'internal', 'amplitude'):
        return

    if self.last_category == 'external':
        External.get_gwc()
    self.last_category = category
630
656
658 self.counter += 1
659 formatted_hel = [f'{hel}' if hel < 0 else f' {hel}' for hel in hel_comb]
660 nexternal = len(hel_comb)
661 return (f' DATA (NHEL(I,{self.counter}),I=1,{nexternal}) /{",".join(formatted_hel)}/')
662
664
665 with open(self.input_file, 'r') as input_file:
666
667 self.prepare_bools()
668
669 for line_num, line in tqdm(enumerate(input_file), total=get_num_lines(self.input_file)):
670 if line_num == 0:
671 line_cache = line
672 continue
673
674 if '!SKIP' in line:
675 continue
676
677 char_5 = ''
678 try:
679 char_5 = line[5]
680 except IndexError:
681 pass
682 if char_5 == '$':
683 line_cache = undo_multiline(line_cache, line)
684 continue
685
686 line, line_cache = line_cache, line
687
688 self.get_old_name(line)
689 self.get_good_hel(line)
690 self.get_amp_stuff(line_num, line)
691 call_type = self.function_call(line)
692 self.get_gwc(line, call_type)
693
694
695
696 if call_type in ['external', 'internal', 'amplitude']:
697 self.template_dict['helas_calls'] += self.unfold_helicities(
698 line, call_type)
699
700 self.template_dict['nwavefuncs'] = max(External.num_externals, Internal.max_wav_num)
701
702 for i in range(len(self.template_dict['helas_calls'])-1,-1,-1):
703 obj = self.template_dict['helas_calls'][i]
704 if obj.nb_used == 0:
705 obj.line = ''
706 for dep in obj.linkdag:
707 dep.nb_used -= 1
708
709
710
711 self.template_dict['helas_calls'] = '\n'.join([f'{obj.line.rstrip()} ! count {obj.nb_used}'
712 for obj in self.template_dict['helas_calls']
713 if obj.nb_used > 0 and obj.line])
714
716 out_file = open(self.output_file, 'w+')
717 with open(self.template_file, 'r') as file:
718 for line in file:
719 s = Template(line)
720 line = s.safe_substitute(self.template_dict)
721 line = '\n'.join([do_multiline(sub_lines) for sub_lines in line.split('\n')])
722 out_file.write(line)
723 out_file.close()
724
726 out_file = open(self.output_file, 'w+')
727 self.template_dict['ncomb'] = '0'
728 self.template_dict['nwavefuncs'] = '0'
729 self.template_dict['helas_calls'] = ''
730 with open(self.template_file, 'r') as file:
731 for line in file:
732 s = Template(line)
733 line = s.safe_substitute(self.template_dict)
734 line = '\n'.join([do_multiline(sub_lines) for sub_lines in line.split('\n')])
735 out_file.write(line)
736 out_file.close()
737
738
749
752
755 '''Find the substrings separated by commas between the first
756 closed set of parentheses in 'line'.
757 '''
758 bracket_depth = 0
759 element = 0
760 arguments = ['']
761 for char in line:
762 if char == '(':
763 bracket_depth += 1
764 if bracket_depth - 1 == 0:
765
766
767 continue
768 if char == ')':
769 bracket_depth -= 1
770 if bracket_depth == 0:
771
772 break
773 if char == ',' and bracket_depth == 1:
774 element += 1
775 arguments.append('')
776 continue
777 if bracket_depth > 0:
778 arguments[element] += char
779 return arguments
780
783 function = (old_line.split('(')[0]).split()[-1]
784 old_args = old_line.split(function)[-1]
785 new_lines = [old_line.replace(old_args, f'({",".join(x)})\n')
786 for x in all_the_args]
787
788 return ''.join(new_lines)
789
791 fct = line.split('(',1)[0].split('_0')[0]
792 for i,amp in enumerate(new_amps):
793 if i == 0:
794 occur = []
795 for a in amp.args:
796 if "W(1," in a:
797 tmp = collections.defaultdict(int)
798 tmp[a] += 1
799 occur.append(tmp)
800 else:
801 for i in range(len(occur)):
802 a = amp.args[i]
803 occur[i][a] +=1
804
805
806 nb_wav = [len(o) for o in occur]
807 to_remove = nb_wav.index(max(nb_wav))
808
809 occur.pop(to_remove)
810
811 lines = []
812
813 wav_name = [o.keys() for o in occur]
814 for wfcts in product(*wav_name):
815
816 sub_amps = [amp for amp in new_amps
817 if all(w in amp.args for w in wfcts)]
818 if not sub_amps:
819 continue
820 if len(sub_amps) ==1:
821 lines.append(apply_args(line, [i.args for i in sub_amps]).replace('\n',''))
822
823 continue
824
825
826 sub_amps.sort(key=lambda a: int(a.args[-1][:-1].split(',',1)[1]))
827 windices = []
828 hel_calculated = []
829 iamp = 0
830 for i,amp in enumerate(sub_amps):
831 args = amp.args[:]
832
833 wcontract = args.pop(to_remove)
834 windex = wcontract.split(',')[1].split(')')[0]
835 windices.append(windex)
836 amp_result, args[-1] = args[-1], 'TMP(1)'
837
838 if i ==0:
839
840
841 spin = fct.split(None,1)[1][to_remove]
842 lines.append('%sP1N_%s(%s)' % (fct, to_remove+1, ', '.join(args)))
843
844 hel, iamp = re.findall('AMP\((\d+),(\d+)\)', amp_result)[0]
845 hel_calculated.append(hel)
846
847
848
849
850 if spin in "VF":
851 lines.append(""" call CombineAmp(%(nb)i,
852 & (/%(hel_list)s/),
853 & (/%(w_list)s/),
854 & TMP, W, AMP(1,%(iamp)s))""" %
855 {'nb': len(sub_amps),
856 'hel_list': ','.join(hel_calculated),
857 'w_list': ','.join(windices),
858 'iamp': iamp
859 })
860 elif spin == "S":
861 lines.append(""" call CombineAmpS(%(nb)i,
862 &(/%(hel_list)s/),
863 & (/%(w_list)s/),
864 & TMP, W, AMP(1,%(iamp)s))""" %
865 {'nb': len(sub_amps),
866 'hel_list': ','.join(hel_calculated),
867 'w_list': ','.join(windices),
868 'iamp': iamp
869 })
870 else:
871 raise Exception("split amp are not supported for spin2 and 3/2")
872
873
874 return '\n'.join(lines)
875
877 name = wav.name
878 between_brackets = re.search(r'\(.*?\)', name).group()
879 num = int(between_brackets[1:-1].split(',')[-1])
880 return num
881
883 new_line = new_line[6:]
884 old_line = old_line.replace('\n','')
885 return f'{old_line}{new_line}'
886
888 char_limit = 72
889 num_splits = len(line)//char_limit
890 if num_splits != 0 and len(line) != 72 and '!' not in line:
891 split_line = [line[i*char_limit:char_limit*(i+1)] for i in range(num_splits+1)]
892 indent = ''
893 for char in line[6:]:
894 if char == ' ':
895 indent += char
896 else:
897 break
898
899 line = f'\n ${indent}'.join(split_line)
900 return line
901
903 if i == 1:
904 return '+1'
905 if i == 0:
906 return ' 0'
907 if i == -1:
908 return '-1'
909 else:
910 print(f'How can {i} be a helicity?')
911 set_trace()
912 exit(1)
913
915 parser = argparse.ArgumentParser()
916 parser.add_argument('input_file', help='The file containing the '
917 'original matrix calculation')
918 parser.add_argument('hel_file', help='The file containing the '
919 'contributing helicities')
920 parser.add_argument('--hf-off', dest='hel_filt', action='store_false', default=True, help='Disable helicity filtering')
921 parser.add_argument('--as-off', dest='amp_splt', action='store_false', default=True, help='Disable amplitude splitting')
922
923 args = parser.parse_args()
924
925 with open(args.hel_file, 'r') as file:
926 good_elements = file.readline().split()
927
928 recycler = HelicityRecycler(good_elements)
929
930 recycler.hel_filt = args.hel_filt
931 recycler.amp_splt = args.amp_splt
932
933 recycler.set_input(args.input_file)
934 recycler.set_output('green_matrix.f')
935 recycler.set_template('template_matrix1.f')
936
937 recycler.generate_output_file()
938
939 if __name__ == '__main__':
940 main()
941