From d45fc586f1620d00a6e531690345852190777e72 Mon Sep 17 00:00:00 2001 From: Lukas Heinrich Date: Mon, 17 Mar 2025 11:07:02 +0100 Subject: [PATCH 1/4] first version of custom modifiers with sympy and jax --- .../contrib/extended_modifiers/__init__.py | 0 .../contrib/extended_modifiers/purefunc.py | 141 ++++++++++++++++++ 2 files changed, 141 insertions(+) create mode 100644 src/pyhf/contrib/extended_modifiers/__init__.py create mode 100644 src/pyhf/contrib/extended_modifiers/purefunc.py diff --git a/src/pyhf/contrib/extended_modifiers/__init__.py b/src/pyhf/contrib/extended_modifiers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/pyhf/contrib/extended_modifiers/purefunc.py b/src/pyhf/contrib/extended_modifiers/purefunc.py new file mode 100644 index 0000000000..4cf947fe87 --- /dev/null +++ b/src/pyhf/contrib/extended_modifiers/purefunc.py @@ -0,0 +1,141 @@ + +import sympy.parsing.sympy_parser as parser +import sympy +from pyhf.parameters import ParamViewer +import jax.numpy as jnp +import jax + +def create_modifiers(additional_parameters = None): + + class PureFunctionModifierBuilder: + is_shared = True + def __init__(self, pdfconfig): + self.config = pdfconfig + self.required_parsets = additional_parameters or {} + self.builder_data = {'local': {},'global': {'symbols': set()}} + + def collect(self, thismod, nom): + maskval = True if thismod else False + mask = [maskval] * len(nom) + return {'mask': mask} + + def append(self, key, channel, sample, thismod, defined_samp): + self.builder_data['local'].setdefault(key, {}).setdefault(sample, {}).setdefault('data', {'mask': []}) + + nom = ( + defined_samp['data'] + if defined_samp + else [0.0] * self.config.channel_nbins[channel] + ) + moddata = self.collect(thismod, nom) + self.builder_data['local'][key][sample]['data']['mask'] += moddata['mask'] + + if thismod is not None: + formula = thismod['data']['formula'] + parsed = parser.parse_expr(formula) + free_symbols = parsed.free_symbols + for x in free_symbols: + self.builder_data['global'].setdefault('symbols',set()).add(x) + else: + parsed = None + self.builder_data['local'].setdefault(key,{}).setdefault(sample,{}).setdefault('channels',{}).setdefault(channel,{})['parsed'] = parsed + + def finalize(self): + list_of_symbols = [str(x) for x in self.builder_data['global']['symbols']] + self.builder_data['global']['symbol_names'] = list_of_symbols + for modname, modspec in self.builder_data['local'].items(): + for sample, samplespec in modspec.items(): + for channel, channelspec in samplespec['channels'].items(): + if channelspec['parsed'] is not None: + channelspec['jaxfunc'] = sympy.lambdify(list_of_symbols, channelspec['parsed'], 'jax') + else: + channelspec['jaxfunc'] = lambda *args: 1.0 + return self.builder_data + + class PureFunctionModifierApplicator: + op_code = 'multiplication' + name = 'purefunc' + + def __init__( + self, modifiers=None, pdfconfig=None, builder_data=None, batch_size=None + ): + self.builder_data = builder_data + self.batch_size = batch_size + self.pdfconfig = pdfconfig + self.inputs = [str(x) for x in builder_data['global']['symbols']] + + self.keys = [f'{mtype}/{m}' for m, mtype in modifiers] + self.modifiers = [m for m, _ in modifiers] + + parfield_shape = ( + (self.batch_size, pdfconfig.npars) + if self.batch_size + else (pdfconfig.npars,) + ) + + self.param_viewer = ParamViewer(parfield_shape, pdfconfig.par_map, self.inputs) + self.create_jax_eval() + + def create_jax_eval(self): + def eval_func(pars): + return jnp.array([ + [ + jnp.concatenate([ + self.builder_data['local'][m][s]['channels'][c]['jaxfunc'](*pars)*jnp.ones(self.pdfconfig.channel_nbins[c]) + for c in self.pdfconfig.channels + ]) + for s in self.pdfconfig.samples + ] + for m in self.keys + + ]) + self.jaxeval = eval_func + + def apply_nonbatched(self,pars): + return jnp.expand_dims(self.jaxeval(pars),2) + + def apply_batched(self,pars): + return jax.vmap(self.jaxeval, in_axes=(1,), out_axes=2)(pars) + + def apply(self, pars): + if not self.param_viewer.index_selection: + return + if self.batch_size is None: + par_selection = self.param_viewer.get(pars) + results_purefunc = self.apply_nonbatched(par_selection) + else: + par_selection = self.param_viewer.get(pars) + results_purefunc = self.apply_batched(par_selection) + return results_purefunc + + return PureFunctionModifierBuilder, PureFunctionModifierApplicator + + +from pyhf.modifiers import histfactory_set + +def enable(new_params = None): + modifier_set = {} + modifier_set.update(**histfactory_set) + + builder, applicator = create_modifiers(new_params) + + modifier_set.update(**{ + applicator.name: (builder, applicator)} + ) + return modifier_set + +def new_unconstrained_scalars(new_params): + param_spec = { + p['name']: + [{ + 'paramset_type': 'unconstrained', + 'n_parameters': 1, + 'is_shared': True, + 'inits': (p['init'],), + 'bounds': ((p['min'], p['max']),), + 'is_scalar': True, + 'fixed': False, + }] + for p in new_params + } + return param_spec \ No newline at end of file From 1eedbe4a7dd5c00371de86fc6da5ae7d7c7b3760 Mon Sep 17 00:00:00 2001 From: Lukas Heinrich Date: Mon, 17 Mar 2025 11:49:43 +0100 Subject: [PATCH 2/4] switch to on the fly creation of required parameters --- .../contrib/extended_modifiers/purefunc.py | 46 ++++++++++--------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/src/pyhf/contrib/extended_modifiers/purefunc.py b/src/pyhf/contrib/extended_modifiers/purefunc.py index 4cf947fe87..7a7fb74689 100644 --- a/src/pyhf/contrib/extended_modifiers/purefunc.py +++ b/src/pyhf/contrib/extended_modifiers/purefunc.py @@ -5,13 +5,13 @@ import jax.numpy as jnp import jax -def create_modifiers(additional_parameters = None): +def create_modifiers(): class PureFunctionModifierBuilder: is_shared = True def __init__(self, pdfconfig): self.config = pdfconfig - self.required_parsets = additional_parameters or {} + self.required_parsets = {} self.builder_data = {'local': {},'global': {'symbols': set()}} def collect(self, thismod, nom): @@ -19,6 +19,23 @@ def collect(self, thismod, nom): mask = [maskval] * len(nom) return {'mask': mask} + def require_synbols_as_scalars(self, symbols): + param_spec = { + p: + [{ + 'paramset_type': 'unconstrained', + 'n_parameters': 1, + 'is_shared': True, + 'inits': (1.0,), + 'bounds': ((0,10),), + 'is_scalar': True, + 'fixed': False, + }] + for p in symbols + } + return param_spec + + def append(self, key, channel, sample, thismod, defined_samp): self.builder_data['local'].setdefault(key, {}).setdefault(sample, {}).setdefault('data', {'mask': []}) @@ -42,6 +59,9 @@ def append(self, key, channel, sample, thismod, defined_samp): def finalize(self): list_of_symbols = [str(x) for x in self.builder_data['global']['symbols']] + + self.required_parsets = self.require_synbols_as_scalars(list_of_symbols) + self.builder_data['global']['symbol_names'] = list_of_symbols for modname, modspec in self.builder_data['local'].items(): for sample, samplespec in modspec.items(): @@ -113,29 +133,13 @@ def apply(self, pars): from pyhf.modifiers import histfactory_set -def enable(new_params = None): +def enable(): modifier_set = {} modifier_set.update(**histfactory_set) - builder, applicator = create_modifiers(new_params) + builder, applicator = create_modifiers() modifier_set.update(**{ applicator.name: (builder, applicator)} ) - return modifier_set - -def new_unconstrained_scalars(new_params): - param_spec = { - p['name']: - [{ - 'paramset_type': 'unconstrained', - 'n_parameters': 1, - 'is_shared': True, - 'inits': (p['init'],), - 'bounds': ((p['min'], p['max']),), - 'is_scalar': True, - 'fixed': False, - }] - for p in new_params - } - return param_spec \ No newline at end of file + return modifier_set \ No newline at end of file From 83645d65ff7d8df04411aaa3995e627b5cc538c1 Mon Sep 17 00:00:00 2001 From: Lukas Heinrich Date: Mon, 17 Mar 2025 15:04:24 +0100 Subject: [PATCH 3/4] fix typo --- src/pyhf/contrib/extended_modifiers/purefunc.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/pyhf/contrib/extended_modifiers/purefunc.py b/src/pyhf/contrib/extended_modifiers/purefunc.py index 7a7fb74689..d693ebb3a9 100644 --- a/src/pyhf/contrib/extended_modifiers/purefunc.py +++ b/src/pyhf/contrib/extended_modifiers/purefunc.py @@ -13,13 +13,14 @@ def __init__(self, pdfconfig): self.config = pdfconfig self.required_parsets = {} self.builder_data = {'local': {},'global': {'symbols': set()}} + self.encountered_expressions = {} def collect(self, thismod, nom): maskval = True if thismod else False mask = [maskval] * len(nom) return {'mask': mask} - def require_synbols_as_scalars(self, symbols): + def require_symbols_as_scalars(self, symbols): param_spec = { p: [{ @@ -52,7 +53,8 @@ def append(self, key, channel, sample, thismod, defined_samp): parsed = parser.parse_expr(formula) free_symbols = parsed.free_symbols for x in free_symbols: - self.builder_data['global'].setdefault('symbols',set()).add(x) + if x not in self.encountered_expressions: + self.builder_data['global'].setdefault('symbols',set()).add(x) else: parsed = None self.builder_data['local'].setdefault(key,{}).setdefault(sample,{}).setdefault('channels',{}).setdefault(channel,{})['parsed'] = parsed @@ -60,7 +62,7 @@ def append(self, key, channel, sample, thismod, defined_samp): def finalize(self): list_of_symbols = [str(x) for x in self.builder_data['global']['symbols']] - self.required_parsets = self.require_synbols_as_scalars(list_of_symbols) + self.required_parsets = self.require_symbols_as_scalars(list_of_symbols) self.builder_data['global']['symbol_names'] = list_of_symbols for modname, modspec in self.builder_data['local'].items(): From a75ce13631073ecba8433ceae9dd218dcb498d93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Mar 2025 14:05:03 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../contrib/extended_modifiers/purefunc.py | 85 +++++++++++-------- 1 file changed, 50 insertions(+), 35 deletions(-) diff --git a/src/pyhf/contrib/extended_modifiers/purefunc.py b/src/pyhf/contrib/extended_modifiers/purefunc.py index d693ebb3a9..27f475b933 100644 --- a/src/pyhf/contrib/extended_modifiers/purefunc.py +++ b/src/pyhf/contrib/extended_modifiers/purefunc.py @@ -1,18 +1,19 @@ - import sympy.parsing.sympy_parser as parser import sympy from pyhf.parameters import ParamViewer import jax.numpy as jnp import jax + def create_modifiers(): class PureFunctionModifierBuilder: is_shared = True + def __init__(self, pdfconfig): self.config = pdfconfig self.required_parsets = {} - self.builder_data = {'local': {},'global': {'symbols': set()}} + self.builder_data = {'local': {}, 'global': {'symbols': set()}} self.encountered_expressions = {} def collect(self, thismod, nom): @@ -22,23 +23,25 @@ def collect(self, thismod, nom): def require_symbols_as_scalars(self, symbols): param_spec = { - p: - [{ - 'paramset_type': 'unconstrained', - 'n_parameters': 1, - 'is_shared': True, - 'inits': (1.0,), - 'bounds': ((0,10),), - 'is_scalar': True, - 'fixed': False, - }] + p: [ + { + 'paramset_type': 'unconstrained', + 'n_parameters': 1, + 'is_shared': True, + 'inits': (1.0,), + 'bounds': ((0, 10),), + 'is_scalar': True, + 'fixed': False, + } + ] for p in symbols } return param_spec - def append(self, key, channel, sample, thismod, defined_samp): - self.builder_data['local'].setdefault(key, {}).setdefault(sample, {}).setdefault('data', {'mask': []}) + self.builder_data['local'].setdefault(key, {}).setdefault( + sample, {} + ).setdefault('data', {'mask': []}) nom = ( defined_samp['data'] @@ -54,10 +57,12 @@ def append(self, key, channel, sample, thismod, defined_samp): free_symbols = parsed.free_symbols for x in free_symbols: if x not in self.encountered_expressions: - self.builder_data['global'].setdefault('symbols',set()).add(x) + self.builder_data['global'].setdefault('symbols', set()).add(x) else: parsed = None - self.builder_data['local'].setdefault(key,{}).setdefault(sample,{}).setdefault('channels',{}).setdefault(channel,{})['parsed'] = parsed + self.builder_data['local'].setdefault(key, {}).setdefault( + sample, {} + ).setdefault('channels', {}).setdefault(channel, {})['parsed'] = parsed def finalize(self): list_of_symbols = [str(x) for x in self.builder_data['global']['symbols']] @@ -69,7 +74,9 @@ def finalize(self): for sample, samplespec in modspec.items(): for channel, channelspec in samplespec['channels'].items(): if channelspec['parsed'] is not None: - channelspec['jaxfunc'] = sympy.lambdify(list_of_symbols, channelspec['parsed'], 'jax') + channelspec['jaxfunc'] = sympy.lambdify( + list_of_symbols, channelspec['parsed'], 'jax' + ) else: channelspec['jaxfunc'] = lambda *args: 1.0 return self.builder_data @@ -95,28 +102,37 @@ def __init__( else (pdfconfig.npars,) ) - self.param_viewer = ParamViewer(parfield_shape, pdfconfig.par_map, self.inputs) + self.param_viewer = ParamViewer( + parfield_shape, pdfconfig.par_map, self.inputs + ) self.create_jax_eval() def create_jax_eval(self): def eval_func(pars): - return jnp.array([ + return jnp.array( [ - jnp.concatenate([ - self.builder_data['local'][m][s]['channels'][c]['jaxfunc'](*pars)*jnp.ones(self.pdfconfig.channel_nbins[c]) - for c in self.pdfconfig.channels - ]) - for s in self.pdfconfig.samples + [ + jnp.concatenate( + [ + self.builder_data['local'][m][s]['channels'][c][ + 'jaxfunc' + ](*pars) + * jnp.ones(self.pdfconfig.channel_nbins[c]) + for c in self.pdfconfig.channels + ] + ) + for s in self.pdfconfig.samples + ] + for m in self.keys ] - for m in self.keys + ) - ]) self.jaxeval = eval_func - - def apply_nonbatched(self,pars): - return jnp.expand_dims(self.jaxeval(pars),2) - def apply_batched(self,pars): + def apply_nonbatched(self, pars): + return jnp.expand_dims(self.jaxeval(pars), 2) + + def apply_batched(self, pars): return jax.vmap(self.jaxeval, in_axes=(1,), out_axes=2)(pars) def apply(self, pars): @@ -129,19 +145,18 @@ def apply(self, pars): par_selection = self.param_viewer.get(pars) results_purefunc = self.apply_batched(par_selection) return results_purefunc - + return PureFunctionModifierBuilder, PureFunctionModifierApplicator from pyhf.modifiers import histfactory_set + def enable(): modifier_set = {} modifier_set.update(**histfactory_set) builder, applicator = create_modifiers() - modifier_set.update(**{ - applicator.name: (builder, applicator)} - ) - return modifier_set \ No newline at end of file + modifier_set.update(**{applicator.name: (builder, applicator)}) + return modifier_set