1
2
3 import argparse
4 import atexit
5 import os
6 import re
7 import collections
8 from string import Template
9 from copy import copy
10 from itertools import product
11 from functools import reduce
12
13 try:
14 import madgraph
15 except:
16 import internal.misc as misc
17 else:
18 import madgraph.various.misc as misc
19 import mmap
20 try:
21 from tqdm import tqdm
22 except ImportError:
23 tqdm = misc.tqdm
27 fp = open(file_path, 'r+')
28 buf = mmap.mmap(fp.fileno(),0)
29 lines = 0
30 while buf.readline():
31 lines += 1
32 return lines
33
35
37 self.graph = {}
38 self.all_wavs = []
39 self.external_wavs = []
40 self.internal_wavs = []
41
43 self.all_wavs.append(wav)
44 nature = wav.nature
45 if nature == 'external':
46 self.external_wavs.append(wav)
47 if nature == 'internal':
48 self.internal_wavs.append(wav)
49 for ext in ext_deps:
50 self.add_branch(wav, ext)
51
53 try:
54 self.graph[node_i].append(node_f)
55 except KeyError:
56 self.graph[node_i] = [node_f]
57
59 deps = [wav for wav in self.all_wavs
60 if wav.old_name == old_name and not wav.dead]
61 return deps
62
64 for wav in self.all_wavs:
65 if wav.old_name == old_name:
66 wav.dead = True
67
69 return {wav.old_name for wav in self.all_wavs}
70
72 '''Taken from https://www.python.org/doc/essays/graphs/'''
73
74 path = path + [start]
75 if start == end:
76 return path
77 if start not in self.graph:
78 return None
79 for node in self.graph[start]:
80 if node not in path:
81 newpath = self.find_path(node, end, path)
82 if newpath:
83 return newpath
84 return None
85
88
90 print_str = 'With new names:\n\t'
91 print_str += '\n\t'.join([f'{key} : {item}' for key, item in self.graph.items() ])
92 print_str += '\n\nWith old names:\n\t'
93 print_str += '\n\t'.join([f'{key.old_name} : {[i.old_name for i in item]}' for key, item in self.graph.items() ])
94 return print_str
95
99 '''Abstract class for wavefunctions and Amplitudes'''
100
101
102
103 ext_deps = None
104
def __init__(self, arguments, old_name, nature):
    """Store the bookkeeping common to wavefunctions and amplitudes.

    Args:
        arguments: list of argument strings of the HELAS call.
        old_name: name the object had in the original input file.
        nature: category tag; the subclasses in this file pass
            'external', 'internal' or 'amplitude'.
    """
    # Identity of the call as read from the input file.
    self.args, self.old_name, self.nature = arguments, old_name, nature
    # New name is assigned later (None until then).
    self.name = None
    # Set to True by DAG.kill_old once a newer definition replaces this name.
    self.dead = False
    # Usage counter and dependency links used when pruning unused calls.
    self.nb_used = 0
    self.linkdag = []
113
117
120
121 @staticmethod
123 old_args = get_arguments(line)
124 old_name = old_args[-1].replace(' ','')
125 matches = graph.old_names() & set([old.replace(' ','') for old in old_args])
126 try:
127 matches.remove(old_name)
128 except KeyError:
129 pass
130 old_deps = old_args[0:len(matches)]
131
132
133 graph.kill_old(old_name)
134 return [graph.dependencies(dep) for dep in old_deps]
135
@classmethod
def good_helicity(cls, wavs, graph, diag_number=None, all_hel=None, bad_hel_amp=None):
    """Decide whether a combination of dependency wavefunctions is kept.

    As a side effect, the set of external wavefunctions reachable in
    ``graph`` from any element of ``wavs`` is stored on ``cls.ext_deps``.
    Other code (e.g. amplitude numbering) reads that attribute afterwards.

    Args:
        wavs: iterable of wavefunction objects the new call depends on.
        graph: DAG holding the wavefunctions and their edges.
        diag_number: diagram number of an amplitude. When falsy, the
            per-helicity bad-amplitude filter is skipped.
        all_hel: sequence of helicity tuples. The 1-based position of a
            tuple in it is its helicity number. Defaults to an empty list.
        bad_hel_amp: collection of ``(hel_number, diag_number)`` pairs to
            drop. Defaults to an empty list.

    Returns:
        ``cls.ext_deps`` (a set, falsy when empty) when the combination is
        kept, ``False`` otherwise.

    Raises:
        ValueError: if the helicity tuple of the combination is not in
            ``all_hel`` while the per-helicity filter is active.
    """
    # Avoid shared mutable defaults.
    if all_hel is None:
        all_hel = []
    if bad_hel_amp is None:
        bad_hel_amp = []

    exts = graph.external_wavs
    cls.ext_deps = {ext for dep in wavs for ext in exts if graph.find_path(dep, ext)}

    # Keep the combination only if its externals belong to one good combination.
    this_comb_good = any(cls.ext_deps.issubset(set(comb))
                         for comb in External.good_wav_combs)

    if diag_number and this_comb_good and cls.ext_deps:
        # Rebuild the helicity tuple ordered by particle id (1..n).
        helicity = {ext.get_id(): ext.hel for ext in cls.ext_deps}
        this_hel = [helicity[i] for i in range(1, len(helicity) + 1)]
        hel_number = 1 + all_hel.index(tuple(this_hel))

        if (hel_number, diag_number) in bad_hel_amp:
            this_comb_good = False

    return this_comb_good and cls.ext_deps
158
159 @staticmethod
161 old_args = get_arguments(line)
162 old_name = old_args[-1].replace(' ','')
163
164 this_args = copy(old_args)
165 wav_names = [w.name for w in wavs]
166 this_args[0:len(wavs)] = wav_names
167
168
169 return this_args
170
171 @staticmethod
174
175 @classmethod
176 - def get_obj(cls, line, wavs, graph, diag_num = None):
186
187
190
193
195 '''Class for storing external wavefunctions'''
196
197 good_hel = []
198 nhel_lines = ''
199 num_externals = 0
200
201 wavs_same_leg = {}
202 good_wav_combs = []
203
def __init__(self, arguments, old_name):
    """Build an external wavefunction from its HELAS call arguments.

    Args:
        arguments: argument strings of the external call. The first one
            looks like 'P(0,<n>)' and the third holds the helicity value.
        old_name: name of the wavefunction in the original file.
    """
    super().__init__(arguments, old_name, 'external')
    # Helicity strings such as '+1', ' 0', '-1' all parse with int().
    self.hel = int(self.args[2])
    # Take '<n>)' after the last comma of the first argument and drop the ')'.
    leg_field = arguments[0].split(',')[-1]
    self.mg = int(leg_field[:-1])
    self.hel_ranges = []
    self.raise_num()
210
211 @classmethod
214
215 @classmethod
217
218
219 old_args = get_arguments(line)
220 old_name = old_args[-1].replace(' ','')
221 graph.kill_old(old_name)
222
223 if 'NHEL' in old_args[2].upper():
224 nhel_index = re.search(r'\(.*?\)', old_args[2]).group()
225 ext_num = int(nhel_index[1:-1]) - 1
226 new_hels = sorted(list(External.hel_ranges[ext_num]), reverse=True)
227 new_hels = [int_to_string(i) for i in new_hels]
228 else:
229
230 ext_num = int(re.search(r'\(0,(\d+)\)', old_args[0]).group(1)) -1
231 new_hels = [' 0']
232
233 new_wavfuncs = []
234 for hel in new_hels:
235
236 this_args = copy(old_args)
237 this_args[2] = hel
238
239 this_wavfunc = External(this_args, old_name)
240 this_wavfunc.set_name(len(graph.external_wavs) + len(graph.internal_wavs) +1)
241
242 graph.store_wav(this_wavfunc)
243 new_wavfuncs.append(this_wavfunc)
244 if ext_num in cls.wavs_same_leg:
245 cls.wavs_same_leg[ext_num] += new_wavfuncs
246 else:
247 cls.wavs_same_leg[ext_num] = new_wavfuncs
248
249 return new_wavfuncs
250
251 @classmethod
253 num_combs = len(cls.good_hel)
254 gwc_old = [[] for x in range(num_combs)]
255 gwc=[]
256 for n, comb in enumerate(cls.good_hel):
257 sols = [[]]
258 for leg, wavs in cls.wavs_same_leg.items():
259 valid = []
260 for wav in wavs:
261 if comb[leg] == wav.hel:
262 valid.append(wav)
263 gwc_old[n].append(wav)
264 if len(valid) == 1:
265 for sol in sols:
266 sol.append(valid[0])
267 else:
268 tmp = []
269 for w in valid:
270 for sol in sols:
271 tmp2 = list(sol)
272 tmp2.append(w)
273 tmp.append(tmp2)
274 sols = tmp
275 gwc += sols
276
277 cls.good_wav_combs = gwc
278
279 @staticmethod
282
284 """ return the id of the particle under consideration """
285
286 try:
287 return self.id
288 except:
289 self.id = int(re.findall(r'P\(0,(\d+)\)', self.args[0])[0])
290 return self.id
291
295 '''Class for storing internal wavefunctions'''
296
297 max_wav_num = 0
298 num_internals = 0
299
300 @classmethod
303
304 @classmethod
306 deps = cls.get_deps(line, graph)
307 new_wavfuncs = [ cls.get_obj(line, wavs, graph)
308 for wavs in product(*deps)
309 if cls.good_helicity(wavs, graph) ]
310
311 return new_wavfuncs
312
313
314
315 @classmethod
318
319 @classmethod
325
326 - def __init__(self, arguments, old_name):
329
330
331 @staticmethod
334
336 '''Class for storing Amplitudes'''
337
338 max_amp_num = 0
339
def __init__(self, arguments, old_name, diag_num):
    """Create an amplitude object belonging to diagram ``diag_num``.

    Args:
        arguments: argument strings of the HELAS amplitude call.
        old_name: name of the amplitude in the original file.
        diag_num: diagram number parsed from the AMP(<n>) name.
    """
    # The diagram number is recorded before the base initialiser runs.
    self.diag_num = diag_num
    kind = 'amplitude'
    super().__init__(arguments, old_name, kind)
343
344
345 @staticmethod
348
@classmethod
def generate_amps(cls, line, graph, all_hel=None, all_bad_hel=None):
    """Unfold one amplitude call into one object per kept helicity combination.

    Args:
        line: the original HELAS amplitude call line.
        graph: DAG holding the wavefunctions produced so far.
        all_hel: sequence of helicity tuples, forwarded to
            ``good_helicity``. Defaults to an empty list.
        all_bad_hel: ``(hel_number, diag_number)`` pairs to drop, forwarded
            to ``good_helicity``. Defaults to an empty list.

    Returns:
        List of amplitude objects built by ``cls.get_obj``.
    """
    # Avoid shared mutable defaults.
    if all_hel is None:
        all_hel = []
    if all_bad_hel is None:
        all_bad_hel = []

    old_name = get_arguments(line)[-1].replace(' ', '')

    # The amplitude name looks like AMP(<n>); <n> is the diagram number.
    amp_index = re.search(r'\(.*?\)', old_name).group()
    diag_num = int(amp_index[1:-1])

    # Candidate wavefunctions for each dependency of the call.
    deps = cls.get_deps(line, graph)

    return [cls.get_obj(line, wavs, graph, diag_num)
            for wavs in product(*deps)
            if cls.good_helicity(wavs, graph, diag_num, all_hel, all_bad_hel)]
364
365 @classmethod
367 return Amplitude(new_args, old_name, diag_num)
368
369 @classmethod
371 wavs, graph = args
372 amp_num = -1
373 exts = graph.external_wavs
374 hel_amp = tuple([w.hel for w in sorted(cls.ext_deps, key=lambda x: x.mg)])
375 amp_num = External.map_hel[hel_amp] +1
376
377 if cls.max_amp_num < amp_num:
378 cls.max_amp_num = amp_num
379 return amp_num
380
382 '''Class for recycling helicity'''
383
def __init__(self, good_elements, bad_amps=None, bad_amps_perhel=None):
    """Set up a fresh helicity recycler.

    External, Internal and Amplitude keep their bookkeeping in class
    attributes, so that state is reset first. Otherwise it would leak from a
    previous recycler into this one.

    Args:
        good_elements: helicity data (main() passes one whitespace-split
            line of the helicity file).
        bad_amps: diagram numbers whose amplitudes are dropped. They are
            compared against the string captured from 'AMP(<n>)', so
            presumably strings (TODO confirm with callers).
        bad_amps_perhel: ``(helicity_number, diagram_number)`` pairs dropped
            for a single helicity only.

    Omitted optional arguments get fresh lists per instance. The former
    ``[]`` defaults were a single object shared by every recycler.
    """
    # Reset class-level state shared between recycler runs.
    External.good_hel = []
    External.nhel_lines = ''
    External.num_externals = 0
    External.wavs_same_leg = {}
    External.good_wav_combs = []

    Internal.max_wav_num = 0
    Internal.num_internals = 0

    Amplitude.max_amp_num = 0

    self.last_category = None
    self.good_elements = good_elements
    self.bad_amps = [] if bad_amps is None else bad_amps
    self.bad_amps_perhel = [] if bad_amps_perhel is None else bad_amps_perhel

    # Default file names; main() overrides them via set_input/set_output/set_template.
    self.input_file = 'matrix_orig.f'
    self.output_file = 'matrix_orig.f'
    self.template_file = 'template_matrix.f'

    # Values substituted into the template with string.Template.
    self.template_dict = {}

    self.template_dict['helicity_lines'] = '\n'
    self.template_dict['helas_calls'] = []
    self.template_dict['jamp_lines'] = '\n'
    self.template_dict['amp2_lines'] = '\n'
    self.template_dict['ncomb'] = '0'
    self.template_dict['nwavefuncs'] = '0'

    self.dag = DAG()

    self.diag_num = 1
    self.got_gwc = False

    self.procedure_name = self.input_file.split('.')[0].upper()
    self.procedure_kind = 'FUNCTION'

    self.old_out_name = ''
    self.loop_var = 'K'

    self.all_hel = []
    self.hel_filt = True
428
437
439 self.output_file = file
440
443
445
446 if not 'CALL' in line:
447 return None
448
449
450 ext_calls = ['CALL OXXXXX', 'CALL IXXXXX', 'CALL VXXXXX', 'CALL SXXXXX']
451 if any( call in line for call in ext_calls ):
452 return 'external'
453
454
455
456
457 if not self.dag.external_wavs:
458 return None
459
460
461
462 matches = self.dag.old_names() & set(get_arguments(line))
463 try:
464 matches.remove(get_arguments(line)[-1])
465 except KeyError:
466 pass
467 try:
468 function = (line.split('(', 1)[0]).split()[-1]
469 except IndexError:
470 return None
471
472
473 if (function.split('_')[-1] != '0'):
474 return 'internal'
475 elif (function.split('_')[-1] == '0'):
476 return 'amplitude'
477 else:
478 print(f'Ahhhh what is going on here?\n{line}')
479 set_trace()
480
481 return None
482
483
484
486 old_pat = matchobj.group()
487 new_pat = old_pat.replace('AMP(', 'AMP( %s,' % self.loop_var)
488
489
490 return new_pat
491
493 '''Add loop_var index to amp and output variable.
494 Also update name of output variable.'''
495
496 new_line = re.sub(r'\WAMP\(.*?\)', self.add_amp_index, line)
497 new_line = re.sub(r'MATRIX\d+', 'TS(K)', new_line)
498 return new_line
499
501
502
503
504
505 return 'init_mode' in line.lower()
506
507
508
509
511 if f'{self.procedure_kind} {self.procedure_name}' in line:
512 if 'SUBROUTINE' == self.procedure_kind:
513 self.old_out_name = get_arguments(line)[-1]
514 if 'FUNCTION' == self.procedure_kind:
515 self.old_out_name = line.split('(')[0].split()[-1]
516
518
519 if 'diagram number' in line:
520 self.amp_calc_started = True
521
522 if ('AMP' not in get_arguments(line)[-1]
523 and self.amp_calc_started and list(line)[0] != 'C'):
524
525 if self.function_call(line) not in ['external',
526 'internal',
527 'amplitude']:
528 self.jamp_started = True
529 self.amp_calc_started = False
530 if self.jamp_started:
531 self.get_jamp_lines(line)
532 if self.in_amp2:
533 self.get_amp2_lines(line)
534 if self.find_amp2 and line.startswith(' ENDDO'):
535 self.in_amp2 = True
536 self.find_amp2 = False
537
539 if self.jamp_finished(line):
540 self.jamp_started = False
541 self.find_amp2 = True
542 elif not line.isspace():
543 self.template_dict['jamp_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'
544
546 if line.startswith(' DO I = 1, NCOLOR'):
547 self.in_amp2 = False
548 elif not line.isspace():
549 self.template_dict['amp2_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'
550
552 self.amp_calc_started = False
553 self.jamp_started = False
554 self.find_amp2 = False
555 self.in_amp2 = False
556 self.nhel_started = False
557
559
560
561
562
563 if nature not in ['external', 'internal', 'amplitude']:
564 raise Exception('wrong unfolding')
565
566 if nature == 'external':
567 new_objs = External.generate_wavfuncs(line, self.dag)
568 for obj in new_objs:
569 obj.line = apply_args(line, [obj.args])
570 else:
571 deps = Amplitude.get_deps(line, self.dag)
572 name2dep = dict([(d.name,d) for d in sum(deps,[])])
573
574
575 if nature == 'internal':
576 new_objs = Internal.generate_wavfuncs(line, self.dag)
577 for obj in new_objs:
578 obj.line = apply_args(line, [obj.args])
579 obj.linkdag = []
580 for name in obj.args:
581 if name in name2dep:
582 name2dep[name].nb_used +=1
583 obj.linkdag.append(name2dep[name])
584
585 if nature == 'amplitude':
586 nb_diag = re.findall(r'AMP\((\d+)\)', line)[0]
587 if nb_diag not in self.bad_amps:
588 new_objs = Amplitude.generate_amps(line, self.dag, self.all_hel, self.bad_amps_perhel)
589 out_line = self.apply_amps(line, new_objs)
590 for i,obj in enumerate(new_objs):
591 if i == 0:
592 obj.line = out_line
593 obj.nb_used = 1
594 else:
595 obj.line = ''
596 obj.nb_used = 1
597 obj.linkdag = []
598 for name in obj.args:
599 if name in name2dep:
600 name2dep[name].nb_used +=1
601 obj.linkdag.append(name2dep[name])
602 else:
603 return ''
604
605
606 return new_objs
607
608
615
def get_gwc(self, line, category):
    """Refresh External's good wavefunction combinations when needed.

    Only HELAS call categories ('external', 'internal', 'amplitude') are
    considered; anything else leaves the state untouched. Whenever the
    previous call was external, ``External.get_gwc()`` is recomputed before
    the current category is recorded.

    Args:
        line: the current source line (not used here).
        category: category returned by ``function_call`` for ``line``.
    """
    if category in ['external', 'internal', 'amplitude']:
        previous_was_external = self.last_category == 'external'
        if previous_was_external:
            External.get_gwc()
        self.last_category = category
627
654
656 self.counter += 1
657 formatted_hel = [f'{hel}' if hel < 0 else f' {hel}' for hel in hel_comb]
658 nexternal = len(hel_comb)
659 return (f' DATA (NHEL(I,{self.counter}),I=1,{nexternal}) /{",".join(formatted_hel)}/')
660
662
663 with open(self.input_file, 'r') as input_file:
664
665 self.prepare_bools()
666
667 for line_num, line in tqdm(enumerate(input_file), total=get_num_lines(self.input_file)):
668 if line_num == 0:
669 line_cache = line
670 continue
671
672 if '!SKIP' in line:
673 continue
674
675 char_5 = ''
676 try:
677 char_5 = line[5]
678 except IndexError:
679 pass
680 if char_5 == '$':
681 line_cache = undo_multiline(line_cache, line)
682 continue
683
684 line, line_cache = line_cache, line
685
686 self.get_old_name(line)
687 self.get_good_hel(line)
688 self.get_amp_stuff(line_num, line)
689 call_type = self.function_call(line)
690 self.get_gwc(line, call_type)
691
692
693 if call_type in ['external', 'internal', 'amplitude']:
694 self.template_dict['helas_calls'] += self.unfold_helicities(
695 line, call_type)
696
697 self.template_dict['nwavefuncs'] = max(External.num_externals, Internal.max_wav_num)
698
699 for i in range(len(self.template_dict['helas_calls'])-1,-1,-1):
700 obj = self.template_dict['helas_calls'][i]
701 if obj.nb_used == 0:
702 obj.line = ''
703 for dep in obj.linkdag:
704 dep.nb_used -= 1
705
706
707
708 self.template_dict['helas_calls'] = '\n'.join([f'{obj.line.rstrip()} ! count {obj.nb_used}'
709 for obj in self.template_dict['helas_calls']
710 if obj.nb_used > 0 and obj.line])
711
713 out_file = open(self.output_file, 'w+')
714 with open(self.template_file, 'r') as file:
715 for line in file:
716 s = Template(line)
717 line = s.safe_substitute(self.template_dict)
718 line = '\n'.join([do_multiline(sub_lines) for sub_lines in line.split('\n')])
719 out_file.write(line)
720 out_file.close()
721
723 try:
724 os.remove(self.output_file)
725 except Exception:
726 pass
727 input_file = self.output_file.replace("_optim.f", "_orig.f")
728 os.symlink(input_file, self.output_file)
729
730
741
744
747 '''Find the substrings separated by commas between the first
748 closed set of parentheses in 'line'.
749 '''
750 bracket_depth = 0
751 element = 0
752 arguments = ['']
753 for char in line:
754 if char == '(':
755 bracket_depth += 1
756 if bracket_depth - 1 == 0:
757
758
759 continue
760 if char == ')':
761 bracket_depth -= 1
762 if bracket_depth == 0:
763
764 break
765 if char == ',' and bracket_depth == 1:
766 element += 1
767 arguments.append('')
768 continue
769 if bracket_depth > 0:
770 arguments[element] += char
771 return arguments
772
775 function = (old_line.split('(')[0]).split()[-1]
776 old_args = old_line.split(function)[-1]
777 new_lines = [old_line.replace(old_args, f'({",".join(x)})\n')
778 for x in all_the_args]
779
780 return ''.join(new_lines)
781
783 if not new_amps:
784 return ''
785 fct = line.split('(',1)[0].split('_0')[0]
786 for i,amp in enumerate(new_amps):
787 if i == 0:
788 occur = []
789 for a in amp.args:
790 if "W(1," in a:
791 tmp = collections.defaultdict(int)
792 tmp[a] += 1
793 occur.append(tmp)
794 else:
795 for i in range(len(occur)):
796 a = amp.args[i]
797 occur[i][a] +=1
798
799
800 nb_wav = [len(o) for o in occur]
801 to_remove = nb_wav.index(max(nb_wav))
802
803 occur.pop(to_remove)
804
805 lines = []
806
807 wav_name = [o.keys() for o in occur]
808 for wfcts in product(*wav_name):
809
810 sub_amps = [amp for amp in new_amps
811 if all(w in amp.args for w in wfcts)]
812 if not sub_amps:
813 continue
814 if len(sub_amps) ==1:
815 lines.append(apply_args(line, [i.args for i in sub_amps]).replace('\n',''))
816
817 continue
818
819
820 sub_amps.sort(key=lambda a: int(a.args[-1][:-1].split(',',1)[1]))
821 windices = []
822 hel_calculated = []
823 iamp = 0
824 for i,amp in enumerate(sub_amps):
825 args = amp.args[:]
826
827 wcontract = args.pop(to_remove)
828 windex = wcontract.split(',')[1].split(')')[0]
829 windices.append(windex)
830 amp_result, args[-1] = args[-1], 'TMP(1)'
831
832 if i ==0:
833
834
835 spin = fct.split(None,1)[1][to_remove]
836 lines.append('%sP1N_%s(%s)' % (fct, to_remove+1, ', '.join(args)))
837
838 hel, iamp = re.findall('AMP\((\d+),(\d+)\)', amp_result)[0]
839 hel_calculated.append(hel)
840
841
842
843
844 if spin in "VF":
845 lines.append(""" call CombineAmp(%(nb)i,
846 & (/%(hel_list)s/),
847 & (/%(w_list)s/),
848 & TMP, W, AMP(1,%(iamp)s))""" %
849 {'nb': len(sub_amps),
850 'hel_list': ','.join(hel_calculated),
851 'w_list': ','.join(windices),
852 'iamp': iamp
853 })
854 elif spin == "S":
855 lines.append(""" call CombineAmpS(%(nb)i,
856 &(/%(hel_list)s/),
857 & (/%(w_list)s/),
858 & TMP, W, AMP(1,%(iamp)s))""" %
859 {'nb': len(sub_amps),
860 'hel_list': ','.join(hel_calculated),
861 'w_list': ','.join(windices),
862 'iamp': iamp
863 })
864 else:
865 raise Exception("split amp are not supported for spin2 and 3/2")
866
867
868 return '\n'.join(lines)
869
871 name = wav.name
872 between_brackets = re.search(r'\(.*?\)', name).group()
873 num = int(between_brackets[1:-1].split(',')[-1])
874 return num
875
877 new_line = new_line[6:]
878 old_line = old_line.replace('\n','')
879 return f'{old_line}{new_line}'
880
882 char_limit = 72
883 num_splits = len(line)//char_limit
884 if num_splits != 0 and len(line) != 72 and '!' not in line:
885 split_line = [line[i*char_limit:char_limit*(i+1)] for i in range(num_splits+1)]
886 indent = ''
887 for char in line[6:]:
888 if char == ' ':
889 indent += char
890 else:
891 break
892
893 line = f'\n ${indent}'.join(split_line)
894 return line
895
897 if i == 1:
898 return '+1'
899 if i == 0:
900 return ' 0'
901 if i == -1:
902 return '-1'
903 else:
904 print(f'How can {i} be a helicity?')
905 set_trace()
906 exit(1)
907
909 parser = argparse.ArgumentParser()
910 parser.add_argument('input_file', help='The file containing the '
911 'original matrix calculation')
912 parser.add_argument('hel_file', help='The file containing the '
913 'contributing helicities')
914 parser.add_argument('--hf-off', dest='hel_filt', action='store_false', default=True, help='Disable helicity filtering')
915 parser.add_argument('--as-off', dest='amp_splt', action='store_false', default=True, help='Disable amplitude splitting')
916
917 args = parser.parse_args()
918
919 with open(args.hel_file, 'r') as file:
920 good_elements = file.readline().split()
921
922 recycler = HelicityRecycler(good_elements)
923
924 recycler.hel_filt = args.hel_filt
925 recycler.amp_splt = args.amp_splt
926
927 recycler.set_input(args.input_file)
928 recycler.set_output('green_matrix.f')
929 recycler.set_template('template_matrix1.f')
930
931 recycler.generate_output_file()
932
if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script.
    main()
935