"""
gen.py — Random molecule generation and adaptive genetic algorithm for LACAN.
This module has two responsibilities:
1. **Random generation** — assemble drug-like molecules from scratch by
recursively sampling fragments from the corpus and joining them through
their dummy-atom attachment points.
2. **Optimisation** — run an adaptive GA (:func:`generate_optimized_molecules`)
that iteratively improves molecules toward a user-defined scoring function
while maintaining population diversity.
Random generation
-----------------
:func:`generate_random_molecule` builds a molecule by starting from a randomly
chosen corpus fragment and recursively filling its open attachment points with
rings, linkers, or substituents sampled proportionally to their corpus
frequency. The recursion terminates when all dummies are consumed.
:func:`generate_filtered_molecule` wraps the above with a rejection loop that
discards molecules outside the atom-count window or below the LACAN score
threshold.
:func:`generate_filtered_molecules` parallelises the above using
``multiprocessing.Pool``, giving each worker a unique seed offset so the
molecules are independent.
Corpus biasing
--------------
:func:`get_corpus_from_mols` / :func:`bias_corpus` build a merged corpus that
over-represents fragments from a reference set (e.g. known actives). The
``ratio`` parameter controls the boost: at ratio=2 the reference fragments'
effective frequency is doubled relative to the background ChEMBL corpus.
The resulting corpus can be passed as ``fragcorpus=`` to any generation
function to steer sampling toward the desired chemotype.
Adaptive GA
-----------
See :func:`generate_optimized_molecules` for the full description.
GA design
---------
* A :class:`HallOfFame` collects the all-time best diverse molecules;
the final result is simply the top-N from it.
* **scoring_budget** — a hard cap on molecules sent to the scoring function
per generation. Candidates are pre-screened with cheap property filters
and MaxMin diversity sampling before scoring, keeping the budget tight even
when raw operation output is large.
* **Per-operation Thompson Sampling bandit** — seven arms (ring replacement,
linker replacement, substituent replacement, scaffold decoration, crossover,
random injection, and atom-level mutation) each maintain a Beta posterior
over their hit rate. Budget is allocated proportionally to sampled weights,
so productive arms get more budget while all arms remain explored.
* **Smooth explore_fraction** — a float in [0,1] controls the mix of
exploration arms (ring/linker/sub-replace, decoration, crossover, random)
vs exploitation arms (mutation). On plateau it shifts toward mutation
(highest empirical hit rate); on diversity collapse it shifts toward
exploration. It decays back toward a user-set baseline otherwise.
* **Presets** — ``preset="ml"`` / ``"medium"`` / ``"docking"`` / ``"guacamol"``
set a coherent bundle of defaults appropriate for fast, medium, slow, and
unlimited-budget scoring functions respectively. Individual kwargs override
preset values.
"""
from rdkit import Chem
from rdkit.Chem import Descriptors
from lacan import lacan
from lacan import mutate, breed, decompose
from lacan import replace as replace_module
import random
import math
import csv
import os
from rdkit import DataStructs
from rdkit.Chem import rdChemReactions
from rdkit.Chem import rdFingerprintGenerator
import multiprocessing
# Module-wide Morgan fingerprint generator (first argument 2, 4096-bit
# fingerprints), used by the similarity and diversity helpers in this module.
MFPGEN = rdFingerprintGenerator.GetMorganGenerator(2, fpSize=4096)
# Fragment-joining reactions used by generate_random_molecule().  Each takes
# two reactants that carry a dummy atom (atomic number 0), bonds the two atoms
# the dummies were attached to, and emits the paired dummies as a second
# product.  rxn1 matches single-bond attachment points, rxn2 double-bond ones.
rxn1 = rdChemReactions.ReactionFromSmarts("[*:0]-[#0:1].[*:2]-[#0:3]>>[*:0]-[*:2].[*:1]-[*:3]")
rxn2 = rdChemReactions.ReactionFromSmarts("[*:0]=[#0:1].[*:2]=[#0:3]>>[*:0]=[*:2].[*:1]=[*:3]")
# Parsed-corpus cache filled by load_corpus().
_corpus_cache = {} # keyed by min_count so different thresholds coexist
# ── Presets ───────────────────────────────────────────────────────────────────
# Each preset sets a coherent bundle of defaults. Individual kwargs passed to
# generate_optimized_molecules() always override preset values.
# Keys are the preset names accepted by the ``preset=`` argument; each value is
# a dict of keyword defaults.  Only "guacamol" sets ``maxmin_cap_factor``.
_PRESETS = {
    "ml": dict(
        # Fast QSAR/ML model: 0.1–10 ms/mol. Score many molecules freely.
        # Large startN is cheap and gives a much better initial population.
        startN=1000,
        popsize=40,
        scoring_budget=200,
        generations=15,
        plateau_patience=3,
        base_explore_fraction=0.4,
        hof_size=100,
        hof_sim_threshold=0.55,
        pool_diversity_k=2, # keep top 2*popsize, then diversity-select
        n_replacements=10,
        conservative=True,
    ),
    "medium": dict(
        # ADMET models, shape, ML ensemble: 0.1–2 s/mol.
        startN=40,
        popsize=25,
        scoring_budget=60,
        generations=12,
        plateau_patience=3,
        base_explore_fraction=0.45,
        hof_size=100,
        hof_sim_threshold=0.65,
        pool_diversity_k=2,
        n_replacements=8,
        conservative=True,
    ),
    "docking": dict(
        # Physics-based docking: 5–30 s/mol. Tight budget is critical.
        startN=20,
        popsize=15,
        scoring_budget=20,
        generations=10,
        plateau_patience=2,
        base_explore_fraction=0.5,
        hof_size=50,
        hof_sim_threshold=0.60,
        pool_diversity_k=2,
        n_replacements=5,
        conservative=True,
    ),
    "guacamol": dict(
        # Unlimited scoring budget (GuacaMol benchmark).
        # Large population + many generations; mutation-dominant.
        # Pool is culled immediately after scoring each generation to prevent
        # unbounded growth (was causing pool=22k which slows _diverse_top_k).
        startN=1000,
        popsize=200,
        scoring_budget=30000,
        generations=25,
        plateau_patience=5,
        base_explore_fraction=0.15,
        hof_size=1000,
        hof_sim_threshold=0.55,
        pool_diversity_k=3,
        n_replacements=10, # was 15 — reduces frag time ~30%
        conservative=True,
        maxmin_cap_factor=1,
    ),
}
[docs]
def load_corpus(min_count=200):
    """Return the fragment corpus stored in ``data/rls.csv`` next to this module.

    Results are memoised per ``min_count`` in the module-level
    ``_corpus_cache``: the CSV is parsed only the first time a given threshold
    is requested, and the very same list object is handed back afterwards.

    Parameters
    ----------
    min_count : int
        Occurrence threshold (default 200).  A row is kept only when its count
        is strictly greater than this value.

    Returns
    -------
    list of [smiles, count, degree, ftype, bonds]
        ``count`` and ``degree`` are converted to ``int``; the other columns
        are kept as the strings read from the file.
    """
    try:
        return _corpus_cache[min_count]
    except KeyError:
        pass

    csv_path = os.path.join(os.path.dirname(__file__), "data/rls.csv")
    with open(csv_path, newline="") as handle:
        rows = [
            [row[0], int(row[1]), int(row[2]), row[3], row[4]]
            for row in csv.reader(handle, delimiter=",", quotechar='"')
            # The first test skips the header row before int() can see it.
            if row[0] != "smiles" and int(row[1]) > min_count
        ]

    _corpus_cache[min_count] = rows
    return rows
def _fragcount(corpus):
"""Return total occurrence counts per fragment type for *corpus*."""
fc = {"Sub": 0, "Linker": 0, "Ring": 0}
for e in corpus:
fc[e[3]] += e[1]
return fc
[docs]
def generate_random_molecule(fragments=None, fragcorpus=None, needed_sample=None, bondstring="", mol=None):
    """Recursively assemble a random molecule from corpus fragments.

    Called without ``fragments`` it picks a seed fragment from the corpus
    (weighted by ``count ** 0.666``) and recurses to fill that fragment's
    attachment points.  Called with a non-empty ``fragments`` list it fills
    every open point described by ``needed_sample`` / ``bondstring`` and
    recurses again while the newly attached fragments open further points.

    Parameters
    ----------
    fragments : list or None — accumulates SMILES of all fragments used so far
        (appended to in place); ``None`` or empty starts a new molecule
    fragcorpus : list of corpus entries (default: ChEMBL rls.csv)
    needed_sample : list of lists or None — allowed fragment types for each
        open point
    bondstring : str — bond types for the current open attachment points
    mol : RDKit Mol — molecule built so far

    Returns
    -------
    (mol, fragments)

    Raises
    ------
    IndexError
        If no corpus fragment matches a required type / bond combination, or
        a joining reaction yields no product.
    """
    # None sentinels replace the former mutable ``[]`` defaults, so no list
    # object is shared between independent calls.
    if fragments is None:
        fragments = []
    if needed_sample is None:
        needed_sample = []
    if fragcorpus is None:
        fragcorpus = load_corpus()

    if not fragments:
        # Seed step: draw the first fragment, then open its attachment points.
        seed_entry = random.choices(fragcorpus, [k[1] ** 0.666 for k in fragcorpus])[0]
        mol = Chem.MolFromSmiles(seed_entry[0])
        if seed_entry[3] == "Linker":
            allowed = seed_entry[2] * [["Ring"]]
        elif seed_entry[3] == "Sub":
            allowed = [["Ring"]]
        else:
            allowed = seed_entry[2] * [["Sub", "Linker", "Ring"]]
        return generate_random_molecule([seed_entry[0]], fragcorpus, allowed, seed_entry[4], mol)

    if not needed_sample:
        # Nothing left to fill: the molecule is complete.
        return mol, fragments

    type_totals = _fragcount(fragcorpus)

    # Draw every fragment first (one type draw, then one entry draw per open
    # point) so the random stream is consumed in the same order as before.
    chosen = []
    for i, allowed_types in enumerate(needed_sample):
        fragtype = random.choices(allowed_types, [type_totals[ft] for ft in allowed_types])[0]
        matching = [e for e in fragcorpus if e[3] == fragtype and bondstring[i] == e[4][0]]
        chosen.append(random.choices(matching, [e[1] ** 0.666 for e in matching])[0])

    next_needed = []
    next_bonds = ""
    for i, entry in enumerate(chosen):
        fragments.append(entry[0])
        joiner = rxn1 if bondstring[i] == "-" else rxn2
        mol = joiner.RunReactants((mol, Chem.MolFromSmiles(entry[0])))[0][0]
        Chem.SanitizeMol(mol)
        # One attachment point was just consumed; the rest stay open.
        if entry[3] == "Ring" and entry[2] > 1:
            next_needed += (entry[2] - 1) * [["Sub", "Linker"]]
            next_bonds += entry[4][1:]
        if entry[3] == "Linker":
            next_needed += (entry[2] - 1) * [["Ring"]]
            next_bonds += entry[4][1:]

    if next_needed:
        mol, fragments = generate_random_molecule(fragments, fragcorpus, next_needed, next_bonds, mol)
    return mol, fragments
[docs]
def generate_filtered_molecule(profile, fragcorpus=None, threshold=0.5, seed=None, min_atoms=14, max_atoms=35):
    """Generate one random molecule that passes the LACAN score and size filters.

    Molecules are drawn with :func:`generate_random_molecule` until one lies
    strictly inside the atom-count window and its LACAN score is not below
    ``threshold``.

    Parameters
    ----------
    profile : LACAN profile dict
    fragcorpus: fragment corpus (default: ChEMBL rls.csv)
    threshold : minimum LACAN score to accept (default 0.5)
    seed : int or None — when not None (0 included) the global ``random``
        generator is seeded with it before sampling
    min_atoms : minimum atom count, exclusive (default 14)
    max_atoms : maximum atom count, exclusive (default 35)

    Returns
    -------
    RDKit Mol
    """
    if fragcorpus is None:
        fragcorpus = load_corpus()
    # ``is not None`` rather than truthiness: 0 is a legitimate seed.
    if seed is not None:
        random.seed(seed)

    # Always draw at least one molecule, so a non-positive threshold can no
    # longer fall through with ``mol`` unbound.
    while True:
        mol, _ = generate_random_molecule(fragcorpus=fragcorpus)
        # Cheap size check first; out-of-window molecules are never scored.
        if not (min_atoms < mol.GetNumAtoms() < max_atoms):
            continue
        score, _ = lacan.score_mol(mol, profile)
        if not score < threshold:
            return mol
[docs]
def generate_filtered_molecules(profile, fragcorpus=None, threshold=0.5, seed=None,
                                min_atoms=14, max_atoms=35, n_jobs=1, n_molecules=10):
    """Generate multiple filtered molecules, optionally in parallel.

    Molecule *i* is generated with seed ``seed + 823848 * i`` in both the
    sequential and the parallel path.  Previously the sequential path reused
    the same seed for every molecule and so returned ``n_molecules`` copies of
    one structure.

    Parameters
    ----------
    profile : LACAN profile dict
    fragcorpus : fragment corpus (default: ChEMBL rls.csv)
    threshold : minimum LACAN score (default 0.5)
    seed : int or None — base seed; None leaves the generator unseeded
    min_atoms : minimum atom count, exclusive (default 14)
    max_atoms : maximum atom count, exclusive (default 35)
    n_jobs : int — parallel workers; 1 = sequential, -1 = all CPU cores
    n_molecules : int — number of molecules to generate (default 10)

    Returns
    -------
    list of RDKit Mol
    """
    if fragcorpus is None:
        fragcorpus = load_corpus()

    # ``is None`` rather than truthiness so that a base seed of 0 is honoured.
    job_args = [
        (profile, fragcorpus, threshold,
         None if seed is None else seed + 823848 * i,
         min_atoms, max_atoms)
        for i in range(n_molecules)
    ]

    if n_jobs == 1:
        return [generate_filtered_molecule(*args) for args in job_args]

    if n_jobs == -1:
        n_jobs = multiprocessing.cpu_count()
    # The context manager releases the workers even if a task raises.
    with multiprocessing.Pool(processes=n_jobs) as pool:
        return pool.starmap(generate_filtered_molecule, job_args)
[docs]
def get_corpus_from_mols(mols, fragcorpus=None, ratio=1):
    """Merge a custom fragment corpus from reference molecules into the background corpus.

    The reference fragment counts are rescaled so that their total equals
    ``ratio`` times the total count of the background corpus, then added to
    the background counts of the same fragments.  Background fragments that do
    not occur in the reference set are carried over unchanged.

    Parameters
    ----------
    mols : iterable of RDKit Mol objects
    fragcorpus : background corpus (default: ChEMBL rls.csv)
    ratio : float — frequency multiplier for reference fragments (default 1)

    Returns
    -------
    list of corpus entries
        Reference-derived entries first, then the remaining background
        entries.  A copy of the background corpus is returned when the
        reference molecules yield no fragments.
    """
    if fragcorpus is None:
        fragcorpus = load_corpus()
    custom_entries = decompose.get_corpus(mols)

    custom_total = sum(e[1] for e in custom_entries)
    if custom_total == 0:
        # No reference fragments: nothing to boost, and nothing to divide by.
        return list(fragcorpus)

    # The background total must come from ``fragcorpus``.  It used to be summed
    # over ``custom_entries`` as well, which pinned the multiplier to ``ratio``
    # regardless of how large the background corpus was.
    background_total = sum(e[1] for e in fragcorpus)
    if background_total:
        mult = background_total / custom_total * ratio
    else:
        # Empty background: there is no scale to match, apply ratio directly.
        mult = ratio

    background_counts = {e[0]: e[1] for e in fragcorpus}
    custom_smiles = {e[0] for e in custom_entries}

    merged_entries = [
        [e[0], int(e[1] * mult) + background_counts.get(e[0], 0), *e[2:]]
        for e in custom_entries
    ]
    merged_entries.extend(e for e in fragcorpus if e[0] not in custom_smiles)
    return merged_entries
[docs]
def bias_corpus(mols, fragcorpus=None, ratio=2.0):
    """Return a fragment corpus skewed toward the fragments found in *mols*.

    Thin wrapper around :func:`get_corpus_from_mols` whose only difference is
    the default ``ratio`` of 2.0.

    Parameters
    ----------
    mols : iterable of RDKit Mol objects — reference molecules
    fragcorpus : background corpus (default: ChEMBL rls.csv)
    ratio : float — frequency multiplier (default 2.0)

    Returns
    -------
    list of corpus entries
    """
    background = load_corpus() if fragcorpus is None else fragcorpus
    return get_corpus_from_mols(mols, fragcorpus=background, ratio=ratio)
# ── GA helpers ────────────────────────────────────────────────────────────────
def _mean_diversity(smis):
"""Mean pairwise Tanimoto distance for a list of SMILES. Range [0,1]."""
if len(smis) < 2:
return 1.0
fps = [MFPGEN.GetFingerprint(Chem.MolFromSmiles(s)) for s in smis]
total, n = 0.0, 0
for i in range(len(fps)):
sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[i+1:])
total += sum(1 - s for s in sims)
n += len(sims)
return total / n if n > 0 else 1.0
def _sign(higher_is_better):
"""Return -1 if higher is better (GA minimises internally)."""
return -1 if higher_is_better else 1
def _safe_score(scoring_function, mols):
"""Call scoring_function safely, returning 0.0 for None mols or exceptions."""
scores = []
valid_mols = []
valid_indices = []
for i, mol in enumerate(mols):
if mol is not None:
valid_mols.append(mol)
valid_indices.append(i)
if not valid_mols:
return [0.0] * len(mols)
try:
raw = scoring_function(valid_mols)
except Exception:
raw = [0.0] * len(valid_mols)
result = [0.0] * len(mols)
for idx, score in zip(valid_indices, raw):
try:
result[idx] = float(score)
except Exception:
result[idx] = 0.0
return result
def _basic_property_filter(mol, min_atoms=8, max_atoms=60, max_mw=700.0,
max_rotbonds=15, max_hbd=7):
"""Fast property-based pre-filter before expensive scoring.
Returns True if the molecule passes all filters.
"""
if mol is None:
return False
try:
n = mol.GetNumAtoms()
if not (min_atoms <= n <= max_atoms):
return False
mw = Descriptors.MolWt(mol)
if mw > max_mw:
return False
rb = Descriptors.NumRotatableBonds(mol)
if rb > max_rotbonds:
return False
hbd = Descriptors.NumHDonors(mol)
if hbd > max_hbd:
return False
return True
except Exception:
return False
def _maxmin_sample(mols, k):
"""MaxMin diversity sampling: pick k molecules maximising pairwise diversity.
Precomputes the full n×n Tanimoto distance matrix once, then runs the
greedy MaxMin selection using numpy min/argmax over columns — no per-
iteration BulkTanimotoSimilarity calls inside the loop.
Time complexity: O(n²) for matrix build + O(n×k) for selection,
but the selection loop is pure numpy with no Python overhead per candidate.
Parameters
----------
mols : list of RDKit Mol
k : int — number to select (if k >= len(mols), returns all)
Returns
-------
list of RDKit Mol
"""
import numpy as np
n = len(mols)
if k >= n:
return list(mols)
fps = [MFPGEN.GetFingerprint(m) for m in mols]
# Build full n×n distance matrix in one pass.
# BulkTanimotoSimilarity(fps[i], fps) returns row i of the similarity
# matrix; we compute all rows and convert to distances in one go.
sim_matrix = np.zeros((n, n), dtype=np.float32)
for i, fp in enumerate(fps):
row = DataStructs.BulkTanimotoSimilarity(fp, fps)
sim_matrix[i] = row
dist_matrix = 1.0 - sim_matrix # distance = 1 - similarity
# Seed: pick the molecule with the highest mean distance to all others
# (the most peripheral point).
mean_dist = dist_matrix.sum(axis=1) / (n - 1)
selected = [int(np.argmax(mean_dist))]
# min_dist_to_selected[i] = distance from mol i to its nearest selected mol
min_dist_to_selected = dist_matrix[selected[0]].copy()
selected_mask = np.zeros(n, dtype=bool)
selected_mask[selected[0]] = True
while len(selected) < k:
# Zero out already-selected so they can't be picked again
min_dist_to_selected[selected_mask] = -1.0
best_idx = int(np.argmax(min_dist_to_selected))
if min_dist_to_selected[best_idx] < 0:
break
selected.append(best_idx)
selected_mask[best_idx] = True
# Update: min distance to selected set = elementwise min of current
# min-distances and distances to the newly added molecule
np.minimum(min_dist_to_selected, dist_matrix[best_idx],
out=min_dist_to_selected)
return [mols[i] for i in selected]
def _prescreened_candidates(raw_mols, budget, seen_smis,
min_atoms=8, max_atoms=60,
max_mw=700.0, max_rotbonds=15, max_hbd=7,
maxmin_cap_factor=10):
"""Filter and diversity-sample a raw candidate list down to scoring budget.
Pipeline:
1. Drop None and already-seen SMILES
2. Apply cheap property filters
3. Random subsample to ``budget * maxmin_cap_factor`` before MaxMin —
caps the O(n²) MaxMin cost regardless of how large the raw pool is
4. MaxMin diversity sampling to reach ``budget``
Parameters
----------
raw_mols : list of RDKit Mol (may contain None)
budget : int — maximum number to return
seen_smis : set of SMILES already scored (mutated in place)
maxmin_cap_factor : int — MaxMin input is capped at
``budget * maxmin_cap_factor`` (default 10).
E.g. budget=200 → MaxMin sees at most 2000 candidates.
Raise if you want more diversity coverage at the cost
of speed; lower to speed up further.
Returns
-------
list of (smiles, mol) tuples, length <= budget
"""
# Step 1: deduplicate and property filter
candidates = []
for m in raw_mols:
if not _basic_property_filter(m, min_atoms=min_atoms, max_atoms=max_atoms,
max_mw=max_mw, max_rotbonds=max_rotbonds,
max_hbd=max_hbd):
continue
try:
smi = Chem.MolToSmiles(m)
except Exception:
continue
if smi in seen_smis:
continue
candidates.append((smi, m))
if not candidates:
return []
# Step 2: random subsample before MaxMin to bound O(n²) FP cost.
# Shuffle first so the random cap is unbiased, then MaxMin picks the
# most diverse subset from this pre-sampled pool.
maxmin_input_cap = budget * maxmin_cap_factor
if len(candidates) > maxmin_input_cap:
candidates = random.sample(candidates, maxmin_input_cap)
# Step 3: MaxMin diversity sampling down to budget
if len(candidates) > budget:
sampled_mols = _maxmin_sample([m for _, m in candidates], budget)
sampled_smis = {Chem.MolToSmiles(m) for m in sampled_mols}
candidates = [(smi, m) for smi, m in candidates if smi in sampled_smis]
# Register all returned SMILES as seen
for smi, _ in candidates:
seen_smis.add(smi)
return candidates[:budget]
def _fragment_moves_by_type_worker(args):
    """Picklable top-level entry point for running fragment moves in a worker.

    *args* is the 6-tuple ``(smiles, profile, fragcorpus, n_replacements,
    conservative, min_atoms)``.  The parent is passed as SMILES and parsed
    here; an unparsable SMILES yields a result dict with four empty lists.
    """
    (smiles, profile, fragcorpus,
     n_replacements, conservative, min_atoms) = args

    parent = Chem.MolFromSmiles(smiles)
    if parent is not None:
        return _fragment_moves_by_type(parent, profile, fragcorpus,
                                       n_replacements=n_replacements,
                                       conservative=conservative,
                                       min_atoms=min_atoms)
    return {"ring": [], "linker": [], "sub": [], "decorate": []}
def _fragment_moves_all_parents(pool, profile, fragcorpus, n_replacements=5,
conservative=True, n_jobs=1):
"""Apply fragment moves to all pool parents, in parallel if n_jobs != 1.
Returns
-------
list of (parent_internal_score, by_type_dict) — one per parent
"""
if n_jobs == 1 or len(pool) <= 1:
out = []
for smi, sc in pool:
mol = Chem.MolFromSmiles(smi)
if mol is None:
out.append((sc, {"ring": [], "linker": [], "sub": [], "decorate": []}))
else:
out.append((sc, _fragment_moves_by_type(mol, profile, fragcorpus,
n_replacements=n_replacements,
conservative=conservative)))
return out
actual_jobs = multiprocessing.cpu_count() if n_jobs == -1 else n_jobs
worker_args = [(smi, profile, fragcorpus, n_replacements, conservative, 5)
for smi, _ in pool]
with multiprocessing.Pool(processes=actual_jobs) as mp_pool:
by_type_list = mp_pool.map(_fragment_moves_by_type_worker, worker_args)
return [(sc, bt) for (_, sc), bt in zip(pool, by_type_list)]
def _fragment_moves_by_type(mol, profile, fragcorpus, n_replacements=5,
                            conservative=True, min_atoms=5, protect_smarts=None):
    """Run each fragment operation on *mol* and keep the outputs separate.

    Ring, linker and substituent replacement run first, then scaffold
    decoration (``mode="Hydrogen"``).  An operation that raises simply leaves
    its list empty.  Products that are None or have fewer than ``min_atoms``
    atoms are dropped.

    Returns
    -------
    dict mapping operation name -> list of RDKit Mol
    Keys: "ring", "linker", "sub", "decorate"
    """
    def _big_enough(products):
        return [p for p in products if p is not None and p.GetNumAtoms() >= min_atoms]

    results = {"ring": [], "linker": [], "sub": [], "decorate": []}

    replacers = (
        ("ring", replace_module.replace_ring),
        ("linker", replace_module.replace_linker),
        ("sub", replace_module.replace_substituent),
    )
    for key, replacer in replacers:
        try:
            results[key] = _big_enough(
                replacer(mol, profile, score_threshold=0.0,
                         fragcorpus=fragcorpus, n_replacements=n_replacements,
                         conservative=conservative, protect_smarts=protect_smarts))
        except Exception:
            # Best effort: a failing operation contributes nothing.
            pass

    try:
        results["decorate"] = _big_enough(
            replace_module.decorate_scaffold(
                mol, profile, score_threshold=0.0,
                fragcorpus=fragcorpus, n_replacements=n_replacements,
                mode="Hydrogen", protect_smarts=protect_smarts))
    except Exception:
        pass

    return results
def _fragment_moves(mol, profile, fragcorpus, n_replacements=5, conservative=True, min_atoms=5, protect_smarts=None):
    """Run every fragment-level operation on *mol* and return one flat list.

    The per-operation outputs of ``_fragment_moves_by_type`` are concatenated
    in dict order and deduplicated by SMILES, keeping the first molecule seen
    for each SMILES.

    Kept for backwards compatibility with optimize_from_mol.
    """
    per_operation = _fragment_moves_by_type(mol, profile, fragcorpus,
                                            n_replacements=n_replacements,
                                            conservative=conservative,
                                            min_atoms=min_atoms,
                                            protect_smarts=protect_smarts)
    first_by_smiles = {}
    for products in per_operation.values():
        for product in products:
            # setdefault keeps the earliest molecule for a repeated SMILES.
            first_by_smiles.setdefault(Chem.MolToSmiles(product), product)
    return list(first_by_smiles.values())
# ── Thompson Sampling Bandit ──────────────────────────────────────────────────
[docs]
class ThompsonBandit:
    """Beta-Bernoulli Thompson Sampling bandit over GA operations.

    Models each arm as a Bernoulli process: did this arm produce at least one
    improvement this generation? Maintains Beta(α, β) posteriors over each
    arm's true hit rate. Allocation weights are drawn from these posteriors
    each generation, so uncertain arms are still explored but arms with
    consistently higher hit rates get proportionally more budget over time.

    Thompson Sampling is well-suited here because most moves are non-productive
    (hit rate ~1–5%), so the Bernoulli posterior handles sparse signals natively.
    Budget allocation is proportional (soft) rather than winner-take-all, so all
    arms remain explored. Statistics can persist across runs (stored as JSON),
    letting the bandit accumulate signal over many optimisation campaigns.

    Prior α₀ / β₀
    --------------
    The prior encodes domain knowledge about each arm's expected hit rate:

    - sub, mutate  : slightly higher prior hit rate (α=2, β=18 → ~10%)
    - ring, linker : moderate (α=1, β=14 → ~7%)
    - decorate     : moderate (α=1, β=14 → ~7%)
    - crossover    : lower prior (α=1, β=24 → ~4%)
    - random       : lowest (α=1, β=49 → ~2%) — random injection rarely
      beats an optimised population directly

    These are weak priors (total pseudo-counts 15–50) so they are overridden
    by real data once enough generations accumulate.  Arms without a built-in
    prior fall back to (α=1, β=9).

    Parameters
    ----------
    arm_names : list of str
    stats_file : str or None — path to JSON file for persistent statistics.
        If provided, α/β counts are loaded on init; they are
        written back only when :meth:`save` is called.
        Pass None to disable persistence.
    prior_alpha : dict mapping arm name -> α₀ (default: built-in domain priors)
    prior_beta : dict mapping arm name -> β₀ (default: built-in domain priors)
    """

    # Domain-knowledge priors: (alpha0, beta0) per arm
    _DEFAULT_PRIORS = {
        "ring": (1, 14),       # ~7% hit rate prior
        "linker": (1, 14),     # ~7%
        "sub": (2, 18),        # ~10%
        "decorate": (1, 14),   # ~7%
        "crossover": (1, 24),  # ~4%
        "random": (1, 49),     # ~2%
        "mutate": (2, 18),     # ~10%
    }
    # Prior for arm names that have no entry in _DEFAULT_PRIORS.
    _FALLBACK_PRIOR = (1, 9)

    def __init__(self, arm_names, stats_file=None, prior_alpha=None, prior_beta=None):
        self.arms = arm_names
        self.stats_file = stats_file
        # Build priors — per-arm overrides, falling back to defaults
        pa = prior_alpha or {}
        pb = prior_beta or {}
        self._alpha0 = {a: pa.get(a, self._DEFAULT_PRIORS.get(a, self._FALLBACK_PRIOR)[0])
                        for a in arm_names}
        self._beta0 = {a: pb.get(a, self._DEFAULT_PRIORS.get(a, self._FALLBACK_PRIOR)[1])
                       for a in arm_names}
        # Posterior counts (accumulated across runs if stats_file is used)
        self.alpha = dict(self._alpha0)  # successes + prior
        self.beta = dict(self._beta0)    # failures + prior
        self.n_pulls = {a: 0 for a in arm_names}
        # Load persisted stats if available
        if stats_file is not None:
            self._load()

    # ── persistence ──────────────────────────────────────────────────────────
    @staticmethod
    def _is_positive_number(value):
        """True for a finite int/float > 0 — what a Beta parameter must be."""
        return (isinstance(value, (int, float)) and not isinstance(value, bool)
                and 0 < value < float("inf"))

    @staticmethod
    def _is_pull_count(value):
        """True for a non-negative int (``summary`` formats pulls with ``d``)."""
        return isinstance(value, int) and not isinstance(value, bool) and value >= 0

    def _load(self):
        """Load α/β/pull counts from stats_file if it exists.

        A missing, unreadable or corrupt file leaves the priors untouched.
        Individual stored values are validated: a count that is not a positive
        finite number (or, for pulls, not a non-negative int) is skipped,
        because it would make :meth:`sample_weights` or :meth:`summary` fail
        later on.
        """
        import json
        if not os.path.exists(self.stats_file):
            return
        try:
            with open(self.stats_file) as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers JSON and text-decoding errors.
            return
        if not isinstance(data, dict):
            return

        sections = (
            (self.alpha, "alpha", self._is_positive_number),
            (self.beta, "beta", self._is_positive_number),
            (self.n_pulls, "n_pulls", self._is_pull_count),
        )
        for target, key, is_valid in sections:
            stored = data.get(key)
            if not isinstance(stored, dict):
                continue
            for a in self.arms:
                if a in stored and is_valid(stored[a]):
                    target[a] = stored[a]

    def save(self):
        """Persist current α/β/pull counts to stats_file (no-op without one).

        Best effort: I/O and serialisation errors are swallowed so a failed
        write never aborts an optimisation run.
        """
        import json
        if self.stats_file is None:
            return
        try:
            os.makedirs(os.path.dirname(os.path.abspath(self.stats_file)), exist_ok=True)
            with open(self.stats_file, "w") as f:
                json.dump({"alpha": self.alpha, "beta": self.beta,
                           "n_pulls": self.n_pulls}, f, indent=2)
        except (OSError, TypeError, ValueError):
            pass

    def reset_stats(self):
        """Reset posteriors to priors and delete stats_file if present."""
        self.alpha = {a: self._alpha0[a] for a in self.arms}
        self.beta = {a: self._beta0[a] for a in self.arms}
        self.n_pulls = {a: 0 for a in self.arms}
        if self.stats_file and os.path.exists(self.stats_file):
            try:
                os.remove(self.stats_file)
            except OSError:
                pass

    # ── bandit interface ──────────────────────────────────────────────────────
    def sample_weights(self):
        """Draw one sample from each arm's Beta posterior.

        Returns a dict arm -> sampled weight. Use these as proportional
        allocation weights for the scoring budget.
        """
        return {a: random.betavariate(self.alpha[a], self.beta[a])
                for a in self.arms}

    def update(self, arm, success):
        """Update posterior for *arm* (in memory only; see :meth:`save`).

        Parameters
        ----------
        arm : str
        success : bool or int — True/1 if the arm produced at least one
            improvement this generation, False/0 otherwise.
        """
        self.n_pulls[arm] += 1
        if success:
            self.alpha[arm] += 1
        else:
            self.beta[arm] += 1

    @property
    def hit_rate(self):
        """Posterior mean hit rate per arm: α / (α + β)."""
        return {a: self.alpha[a] / (self.alpha[a] + self.beta[a])
                for a in self.arms}

    def summary(self):
        """Return one text line per arm with pulls, α, β and mean hit rate."""
        hr = self.hit_rate
        rows = []
        for a in self.arms:
            rows.append(
                f" {a:12s} pulls={self.n_pulls[a]:4d}"
                f" alpha={self.alpha[a]:6.1f} beta={self.beta[a]:6.1f}"
                f" hit_rate={hr[a]:.3f}"
            )
        return "\n".join(rows)
# ── HallOfFame ────────────────────────────────────────────────────────────────
[docs]
class HallOfFame:
    """Diversity-gated collection of the best molecules seen across all generations.

    A candidate whose Tanimoto similarity to its most similar member reaches
    ``sim_threshold`` can only enter by replacing that member, and only if it
    scores better.  A candidate that is dissimilar to every member is appended
    while there is room; once full it displaces the worst member if it scores
    better.

    Parameters
    ----------
    max_size : int — maximum number of entries (default 100)
    sim_threshold : float — similarity at or above which a member counts as a
        near-duplicate of the candidate (default 0.55)
    higher_is_better : bool (default True)
    """

    # Retained for backwards compatibility only.  The gate has always depended
    # on nothing but the single most similar member (the maximum of the top-K
    # similarities is the overall maximum), so this value does not influence
    # offer().
    _DIVERSITY_GATE_K = 20

    def __init__(self, max_size=100, sim_threshold=0.55, higher_is_better=True):
        self.max_size = max_size
        self.sim_threshold = sim_threshold
        self.higher_is_better = higher_is_better
        self._entries = []  # list of (smiles, score, fp)

    def _is_better(self, a, b):
        """True if score *a* strictly beats score *b* in the configured direction."""
        return a > b if self.higher_is_better else a < b

    def _worst_index(self):
        """Index of the worst-scoring entry (first one on ties)."""
        if self.higher_is_better:
            return min(range(len(self._entries)), key=lambda i: self._entries[i][1])
        return min(range(len(self._entries)), key=lambda i: -self._entries[i][1])

    def offer(self, smi, score, mol):
        """Try to add (smi, score) to the HallOfFame.

        Diversity gate: if the most similar existing member has a Tanimoto
        similarity of at least ``sim_threshold``, the candidate replaces that
        member when it scores strictly better and is rejected otherwise.
        Candidates that pass the gate are appended while there is room, or
        displace the worst member if they beat it.

        Returns True if accepted, False if rejected (including when no
        fingerprint can be computed for *mol*).
        """
        try:
            fp = MFPGEN.GetFingerprint(mol)
        except Exception:
            return False

        if self._entries:
            sims = DataStructs.BulkTanimotoSimilarity(fp, [e[2] for e in self._entries])
            # One argmax replaces the former full sort of the similarities.
            nearest_idx = int(max(range(len(sims)), key=lambda i: sims[i]))
            if sims[nearest_idx] >= self.sim_threshold:
                # Similar neighbour found — only block if that neighbour is at
                # least as good as the candidate (don't block an improvement).
                if not self._is_better(score, self._entries[nearest_idx][1]):
                    return False
                self._entries[nearest_idx] = (smi, score, fp)
                return True

        if len(self._entries) < self.max_size:
            self._entries.append((smi, score, fp))
            return True

        if not self._entries:
            # max_size <= 0: there is no slot to fill and nobody to displace.
            return False

        # HoF full and no similar neighbour to replace: displace the worst
        # member if the candidate is better than it.
        worst_idx = self._worst_index()
        if self._is_better(score, self._entries[worst_idx][1]):
            self._entries[worst_idx] = (smi, score, fp)
            return True
        return False

    def _worth_offering(self, score):
        """Quick check: is this score worth running the full offer() logic?

        Returns True if the HoF isn't full, or if the score beats the current
        worst member. Avoids fingerprint computation for clearly hopeless cases.
        """
        if len(self._entries) < self.max_size:
            return True
        if not self._entries:
            # max_size <= 0: nothing can ever be stored.
            return False
        scores = [e[1] for e in self._entries]
        worst_score = min(scores) if self.higher_is_better else max(scores)
        return self._is_better(score, worst_score)

    def top(self, n=None):
        """Return top-n (smiles, score) pairs sorted best-first."""
        sorted_entries = sorted(self._entries, key=lambda e: e[1],
                                reverse=self.higher_is_better)
        result = [(smi, sc) for smi, sc, _ in sorted_entries]
        return result if n is None else result[:n]

    @property
    def best_score(self):
        """Best score currently stored, or None when empty."""
        if not self._entries:
            return None
        scores = [e[1] for e in self._entries]
        return max(scores) if self.higher_is_better else min(scores)

    def __len__(self):
        return len(self._entries)
# ── Pool diversity-aware selection ───────────────────────────────────────────
def _diverse_top_k(pool_entries, k, diversity_weight=0.3):
"""Select k molecules from pool balancing score and diversity.
Algorithm: iteratively pick the candidate with the best combined
score_rank + diversity_rank (where diversity = min distance to already
selected set), starting from the top scorer.
Parameters
----------
pool_entries : list of (smiles, internal_score) [lower internal = better]
k : int — how many to keep
diversity_weight: float in [0,1] — weight of diversity vs score (default 0.3)
Returns
-------
list of (smiles, internal_score)
"""
if len(pool_entries) <= k:
return list(pool_entries)
# Pre-compute fingerprints
fps = []
valid = []
for entry in pool_entries:
try:
fp = MFPGEN.GetFingerprint(Chem.MolFromSmiles(entry[0]))
fps.append(fp)
valid.append(entry)
except Exception:
pass
n = len(valid)
if n <= k:
return valid
# Rank by score (ascending internal score = better)
score_rank = [0] * n
for rank, idx in enumerate(sorted(range(n), key=lambda i: valid[i][1])):
score_rank[idx] = rank
selected_indices = []
selected_fps = []
min_dist_to_selected = [1.0] * n # initialise: no selected yet → max distance
# Always start with the best scorer
best_idx = min(range(n), key=lambda i: valid[i][1])
selected_indices.append(best_idx)
selected_fps.append(fps[best_idx])
sims = DataStructs.BulkTanimotoSimilarity(fps[best_idx], fps)
min_dist_to_selected = [min(1 - s, d) for s, d in zip(sims, min_dist_to_selected)]
while len(selected_indices) < k:
# Rank by diversity (higher min-distance = more diverse = better)
remaining = [i for i in range(n) if i not in selected_indices]
div_scores = [min_dist_to_selected[i] for i in remaining]
max_div = max(div_scores) if max(div_scores) > 0 else 1.0
div_rank = {i: div_scores[j] / max_div for j, i in enumerate(remaining)}
# Combined score: lower score_rank_normalised + higher div is better
max_srank = n - 1
best = max(remaining,
key=lambda i: (1 - diversity_weight) * (1 - score_rank[i] / max_srank)
+ diversity_weight * div_rank[i])
selected_indices.append(best)
selected_fps.append(fps[best])
sims = DataStructs.BulkTanimotoSimilarity(fps[best], fps)
min_dist_to_selected = [min(1 - s, d) for s, d in zip(sims, min_dist_to_selected)]
return [valid[i] for i in selected_indices]
[docs]
def generate_optimized_molecules(
        scoring_function,
        profile,
        seed=123,
        startN=None,
        generations=None,
        popsize=None,
        scoring_budget=None,
        higher_is_better=True,
        diversity_threshold=0.35,
        plateau_patience=None,
        base_explore_fraction=None,
        explore_ratio=None,  # legacy alias for base_explore_fraction
        hof_size=None,
        hof_sim_threshold=None,
        sim_threshold=None,  # legacy alias for hof_sim_threshold
        pool_diversity_k=2,
        n_replacements=None,
        conservative=None,
        fragcorpus=None,
        n_jobs=-1,
        quiet=False,
        callback=None,
        seed_mols=None,
        preset="ml",
        bandit_stats_file=None,
        maxmin_cap_factor=None,  # cap for MaxMin prescreening; 1 = skip MaxMin (fast scorers)
        # Legacy / ignored parameters kept for API compatibility
        win_threshold=None,  # ignored — use HallOfFame instead
        **kwargs):
    """
    Adaptive GA with per-operation Thompson Sampling bandit, scoring budget, and HallOfFame.

    Quick start
    -----------
    For most use cases, set ``preset=`` and override individual parameters as
    needed::

        # Fast ML model
        winners = gen.generate_optimized_molecules(rf_score, profile, preset="ml")

        # Slow docking function — tight budget
        winners = gen.generate_optimized_molecules(dock, profile, preset="docking",
                                                   scoring_budget=15)

        # GuacaMol benchmark — unlimited scoring budget
        winners = gen.generate_optimized_molecules(guac_score, profile, preset="guacamol")

    Presets
    -------
    ``preset="ml"``       — fast QSAR/ML models (0.1–10 ms/mol)
    ``preset="medium"``   — ADMET, shape, ML ensembles (0.1–2 s/mol)
    ``preset="docking"``  — physics-based docking (5–30 s/mol)
    ``preset="guacamol"`` — unlimited scoring budget, large population,
                            mutation-dominant allocation

    All preset values can be overridden by passing explicit kwargs. An
    unknown preset name falls back to ``"ml"``.

    Key parameters
    --------------
    scoring_budget : int
        Hard cap on molecules sent to the scoring function per generation.
        Candidates are pre-screened with cheap property filters and MaxMin
        diversity sampling before scoring. This is the most important knob
        for slow scoring functions.
    base_explore_fraction : float in [0,1]
        Baseline fraction of the scoring budget allocated to exploration arms
        (ring/linker/sub replacement, decoration, crossover, random injection)
        vs exploitation arms (atom-level mutation). On scoring plateau this
        shifts toward mutation (highest empirical hit rate); on diversity
        collapse it shifts toward exploration. ``explore_ratio`` is a legacy
        alias, used only when ``base_explore_fraction`` is None.
    hof_size : int
        Maximum size of the HallOfFame. The HoF collects the best diverse
        molecules found across all generations; the final result is the top-N
        from it sorted best-first.
    hof_sim_threshold : float
        Maximum Tanimoto similarity allowed between two HallOfFame members
        (diversity gate). Lower = more diverse but smaller effective HoF.
        ``sim_threshold`` is a legacy alias, used only when
        ``hof_sim_threshold`` is None.
    diversity_threshold : float
        If mean pairwise Tanimoto distance of the pool falls below this,
        explore_fraction is boosted.
    plateau_patience : int
        After this many generations without improvement, explore_fraction is
        shifted toward mutation to exploit the arm with the highest hit rate.

    Other parameters
    ----------------
    scoring_function : callable
        Scores a list of molecules; invoked through ``_safe_score``.
    profile : LACAN profile, forwarded to every generation/mutation helper.
    seed : int
        Seeds ``random`` and the molecule generators. When truthy it is
        re-derived at the start of every generation.
    startN : int
        Number of random molecules generated for the initial population
        (only used when ``seed_mols`` is None).
    generations, popsize : int
        Number of GA generations and target pool size after each cull.
    higher_is_better : bool
        Direction of the raw score. Internally scores are multiplied by
        ``_sign(higher_is_better)`` so that lower internal score = better.
    pool_diversity_k : int
        The best ``popsize * pool_diversity_k`` entries by score are kept
        before diversity-aware selection reduces them to ``popsize``.
    n_replacements, conservative
        Forwarded to the fragment-move helper.
    fragcorpus
        Fragment corpus; ``load_corpus()`` is used when None.
    n_jobs : int
        Worker count forwarded to the parallel helpers; ``-1`` means
        ``multiprocessing.cpu_count()`` for the crossover pool.
    quiet : bool
        Suppress all progress output.
    callback : callable or None
        Called once per scored generation with a stats dict containing the
        keys ``generation``, ``explore_fraction``, ``pool_size``,
        ``hof_size``, ``diversity``, ``plateau``, ``best_pool``,
        ``best_hof``, ``bandit`` and ``timing``. It is not called for a
        generation that is skipped because no candidate survived
        prescreening.
    seed_mols : list of Mol or None
        Initial population to use instead of random generation; ``None``
        entries are dropped.
    bandit_stats_file
        Forwarded to ``ThompsonBandit(stats_file=...)``; the bandit is saved
        after every scored generation.
    maxmin_cap_factor
        Forwarded to the prescreening helper. Defaults to the preset's value,
        or 10 when the preset does not define one.
    win_threshold, **kwargs
        Accepted for API compatibility and ignored.

    Returns
    -------
    list of (smiles, score) tuples
        All HallOfFame members sorted best-first.
    """
    # ── resolve preset ────────────────────────────────────────────────────────
    p = _PRESETS.get(preset, _PRESETS["ml"]).copy()

    def _resolve(val, key):
        """Return val if not None, else preset default."""
        return val if val is not None else p[key]

    startN = _resolve(startN, "startN")
    popsize = _resolve(popsize, "popsize")
    scoring_budget = _resolve(scoring_budget, "scoring_budget")
    generations = _resolve(generations, "generations")
    plateau_patience = _resolve(plateau_patience, "plateau_patience")
    n_replacements = _resolve(n_replacements, "n_replacements")
    conservative = _resolve(conservative, "conservative")
    hof_size = _resolve(hof_size, "hof_size")
    pool_diversity_k = _resolve(pool_diversity_k, "pool_diversity_k")

    # maxmin_cap_factor: preset may define it; kwarg overrides; fallback=10 (original default)
    _mmcf_default = p.get("maxmin_cap_factor", 10)
    _maxmin_cap_factor = maxmin_cap_factor if maxmin_cap_factor is not None else _mmcf_default

    # Handle legacy aliases
    _base_ef = base_explore_fraction if base_explore_fraction is not None \
        else (explore_ratio if explore_ratio is not None
              else p["base_explore_fraction"])
    _hof_sim = hof_sim_threshold if hof_sim_threshold is not None \
        else (sim_threshold if sim_threshold is not None
              else p["hof_sim_threshold"])

    if fragcorpus is None:
        fragcorpus = load_corpus()
    random.seed(seed)
    sign = _sign(higher_is_better)

    if not quiet:
        print(f"[LACAN GA] preset={preset!r} startN={startN} popsize={popsize}"
              f" scoring_budget={scoring_budget}/gen generations={generations}"
              f" base_explore_fraction={_base_ef:.2f} maxmin_cap_factor={_maxmin_cap_factor}")
        if win_threshold is not None:
            print(f"[LACAN GA] Note: win_threshold={win_threshold} is ignored in this version. "
                  f"Results are returned as all HallOfFame entries (top {hof_size} diverse molecules).")

    # ── data structures ───────────────────────────────────────────────────────
    hof = HallOfFame(max_size=hof_size, sim_threshold=_hof_sim,
                     higher_is_better=higher_is_better)
    pool = []          # list of (smiles, internal_score)
    seen_smis = set()  # all SMILES ever scored

    # Bandit over 7 arms — Thompson Sampling with optional persistent stats
    EXPLORE_ARMS = ["ring", "linker", "sub", "decorate", "crossover", "random"]
    EXPLOIT_ARMS = ["mutate"]
    ALL_ARMS = EXPLORE_ARMS + EXPLOIT_ARMS
    bandit = ThompsonBandit(ALL_ARMS, stats_file=bandit_stats_file)

    def _allocate_budget(arms, total_budget, ts_weights):
        """Split total_budget over arms proportionally to their sampled weights.

        Every arm gets at least 1; the last arm receives whatever remains.
        """
        weights = {a: ts_weights[a] for a in arms}
        total_w = sum(weights.values())
        per_arm = {}
        remaining = total_budget
        arm_list = list(arms)
        for i, a in enumerate(arm_list):
            if i == len(arm_list) - 1:
                per_arm[a] = max(1, remaining)
            else:
                if total_w > 0:
                    alloc = max(1, int(total_budget * weights[a] / total_w))
                else:
                    alloc = max(1, total_budget // len(arm_list))
                per_arm[a] = alloc
                remaining = max(0, remaining - alloc)
        return per_arm

    # ── initial population ────────────────────────────────────────────────────
    if seed_mols is not None:
        if not quiet:
            print(f" Seeding from {len(seed_mols)} provided molecules...")
        start_mols = [m for m in seed_mols if m is not None]
    else:
        if not quiet:
            print(f" Generating {startN} initial molecules...")
        start_mols = generate_filtered_molecules(profile, fragcorpus=fragcorpus,
                                                 min_atoms=15, n_molecules=startN,
                                                 n_jobs=n_jobs, seed=seed)
    # Register the starting molecules as seen so later generations do not
    # spend scoring budget on them again.
    for m in start_mols:
        try:
            seen_smis.add(Chem.MolToSmiles(m))
        except Exception:
            pass

    raw_scores = _safe_score(scoring_function, start_mols)
    for mol, rs in zip(start_mols, raw_scores):
        try:
            smi = Chem.MolToSmiles(mol)
        except Exception:
            continue
        if hof._worth_offering(rs):
            hof.offer(smi, rs, mol)
        pool.append((smi, rs * sign))

    # For the initial pool, keep a larger slice proportional to startN so that
    # a large initial sweep (e.g. startN=1000) benefits the first few GA
    # generations rather than being immediately discarded. We use diversity-
    # aware selection so the larger pool is also well-spread, then decay to
    # popsize at the start of gen 1 via the normal per-generation cull.
    initial_pool_size = min(len(pool), max(popsize * pool_diversity_k,
                                           startN // 10))
    pool.sort(key=lambda x: x[1])
    pool = pool[:initial_pool_size * pool_diversity_k]
    pool = _diverse_top_k(pool, initial_pool_size, diversity_weight=0.3)

    plateau_count = 0
    explore_fraction = _base_ef
    # Initialise best_internal from the culled pool so the plateau tracker
    # has a consistent reference point from the start.
    best_internal = pool[0][1] if pool else 0.0

    if not quiet:
        print(f" Initial pool: {len(pool)} molecules | HoF: {len(hof)}"
              f" | best={hof.best_score}")

    # ── generational loop ─────────────────────────────────────────────────────
    import time as _time
    _viable_for_crossover = None  # cache: recomputed only when pool changes significantly

    for gen_idx in range(generations):
        _t_gen_start = _time.perf_counter()
        if seed:
            seed = (seed * 6543 + gen_idx) % (2**31)

        # ── phase: pool cull + diversity-select ───────────────────────────────
        _t0 = _time.perf_counter()
        pool.sort(key=lambda x: x[1])
        pool = pool[:popsize * pool_diversity_k]
        pool = _diverse_top_k(pool, popsize, diversity_weight=0.3)
        _t_pool_cull = _time.perf_counter() - _t0

        culled_best = pool[0][1] if pool else best_internal
        if culled_best < best_internal - 1e-6:
            best_internal = culled_best
            plateau_count = 0
        else:
            plateau_count += 1

        # ── phase: mean diversity ─────────────────────────────────────────────
        _t0 = _time.perf_counter()
        diversity = _mean_diversity([smi for smi, _ in pool])
        _t_diversity = _time.perf_counter() - _t0

        # NOTE: despite its name, force_explore (plateau reached) shifts the
        # budget TOWARD mutation, i.e. it lowers explore_fraction.
        force_explore = plateau_count >= plateau_patience
        low_diversity = diversity < diversity_threshold
        if low_diversity and not force_explore:
            explore_fraction = min(explore_fraction + 0.10, 0.60)
        elif force_explore:
            explore_fraction = max(explore_fraction - 0.05, 0.05)
        else:
            explore_fraction = max(explore_fraction - 0.03, _base_ef)

        n_exploit_budget = max(1, int(scoring_budget * (1 - explore_fraction)))
        n_explore_budget = max(1, scoring_budget - n_exploit_budget)

        parent_mols = [Chem.MolFromSmiles(smi) for smi, _ in pool]
        arm_candidates = {a: [] for a in ALL_ARMS}

        # ── phase: fragment moves (ring/linker/sub/decorate) ──────────────────
        _t0 = _time.perf_counter()
        parent_results = _fragment_moves_all_parents(
            pool, profile, fragcorpus,
            n_replacements=n_replacements,
            conservative=conservative,
            n_jobs=n_jobs,
        )
        for (parent_sc, by_type) in parent_results:
            for arm in ["ring", "linker", "sub", "decorate"]:
                for m in by_type.get(arm, []):
                    arm_candidates[arm].append((m, parent_sc))
        _t_frag = _time.perf_counter() - _t0

        # ── Budget allocation via Thompson Sampling ───────────────────────────
        ts_weights = bandit.sample_weights()
        explore_alloc = _allocate_budget(EXPLORE_ARMS, n_explore_budget, ts_weights)
        exploit_alloc = _allocate_budget(EXPLOIT_ARMS, n_exploit_budget, ts_weights)
        arm_alloc = {**explore_alloc, **exploit_alloc}

        # ── phase: crossover ──────────────────────────────────────────────────
        # We call breed.breed directly (not cross_breed_mols) to control
        # hacrange. Viability check (fragment_molecule) is cached across
        # generations — recomputed only when the pool population changes
        # (i.e. when plateau_count resets, meaning new best molecules entered).
        _t0 = _time.perf_counter()
        _crossover_err = None
        pool_mean_sc = sum(sc for _, sc in pool) / len(pool) if pool else 0.0
        if len(parent_mols) >= 2:
            try:
                # Recompute viable parents only on improvement (plateau reset)
                # or if cache is empty. Saves ~2-4s of fragment_molecule calls.
                if _viable_for_crossover is None or plateau_count == 0:
                    _viable_for_crossover = []
                    for m in parent_mols:
                        if m is None:
                            continue
                        try:
                            s, c = breed.fragment_molecule(m, 3)
                            if not c:
                                s, c = breed.fragment_molecule(m, 2)
                            if s and c:
                                _viable_for_crossover.append(m)
                        except Exception:
                            pass
                if len(_viable_for_crossover) >= 2:
                    _nmols_per_pair = max(1, explore_alloc["crossover"] // max(1, len(_viable_for_crossover)))
                    _actual_jobs = multiprocessing.cpu_count() if n_jobs == -1 else max(1, n_jobs)
                    _breed_inputs = [
                        (m, random.choice(_viable_for_crossover), profile,
                         _nmols_per_pair, 3, (0.6, 1.5), (0.25, 0.75), False)
                        for m in _viable_for_crossover
                    ]
                    with multiprocessing.Pool(processes=_actual_jobs) as _xpool:
                        _xresults = _xpool.starmap(breed.breed, _breed_inputs)
                    for _xmols in _xresults:
                        for m in _xmols:
                            arm_candidates["crossover"].append((m, pool_mean_sc))
                else:
                    _crossover_err = f"too few fragmentable parents ({len(_viable_for_crossover) if _viable_for_crossover else 0}/{len(parent_mols)})"
            except Exception as _e:
                _crossover_err = repr(_e)
        _t_crossover = _time.perf_counter() - _t0

        # ── phase: random injection ───────────────────────────────────────────
        _t0 = _time.perf_counter()
        n_rand = max(2, explore_alloc["random"])
        try:
            randoms = generate_filtered_molecules(
                profile, fragcorpus=fragcorpus, min_atoms=14,
                n_molecules=n_rand, n_jobs=n_jobs, seed=seed)
            for m in randoms:
                arm_candidates["random"].append((m, pool_mean_sc))
        except Exception as _e:
            # Best effort: a failed injection just leaves the arm empty, but
            # report it (like MUTATE_ERR) instead of hiding it completely.
            if not quiet:
                print(f" RANDOM_ERR: {repr(_e)}")
        _t_random = _time.perf_counter() - _t0

        # ── phase: mutation ───────────────────────────────────────────────────
        _t0 = _time.perf_counter()
        try:
            mutants = mutate.apply_mutations_mols(
                parent_mols, profile, score_threshold=0.0, n_jobs=n_jobs)
            # Cap raw mutants: scoring budget for mutate arm * small multiplier.
            # apply_mutations_mols on 200 parents produces ~18k candidates but
            # we only need ~425. A cap of 5x leaves enough diversity without
            # running 18k through the property filter loop.
            _mutant_cap = exploit_alloc["mutate"] * 5
            if len(mutants) > _mutant_cap:
                random.shuffle(mutants)
                mutants = mutants[:_mutant_cap]
            for m in mutants:
                arm_candidates["mutate"].append((m, pool_mean_sc))
        except Exception as _e:
            if not quiet:
                print(f" MUTATE_ERR: {repr(_e)}")
        _t_mutate = _time.perf_counter() - _t0

        # ── phase: prescreening ───────────────────────────────────────────────
        _t0 = _time.perf_counter()
        prescreened = {}
        _prescreen_counts = {}  # raw -> screened per arm, for diagnostics
        _empty_arms = set()     # arms already recorded as a failure this generation
        for arm, alloc in arm_alloc.items():
            raw_pairs = arm_candidates[arm]
            if not raw_pairs:
                prescreened[arm] = []
                bandit.update(arm, False)
                _empty_arms.add(arm)
                _prescreen_counts[arm] = (0, 0)
                continue
            raw_mols = [m for m, _ in raw_pairs]
            raw_smi_to_parent_sc = {}
            for m, psc in raw_pairs:
                try:
                    raw_smi_to_parent_sc[Chem.MolToSmiles(m)] = psc
                except Exception:
                    pass
            screened = _prescreened_candidates(raw_mols, alloc, seen_smis,
                                               maxmin_cap_factor=_maxmin_cap_factor)
            prescreened[arm] = [(smi, mol, raw_smi_to_parent_sc.get(smi, pool_mean_sc))
                                for smi, mol in screened]
            _prescreen_counts[arm] = (len(raw_pairs), len(screened))
        _t_prescreen = _time.perf_counter() - _t0

        # ── phase: scoring ────────────────────────────────────────────────────
        all_candidates = []
        for arm in ALL_ARMS:
            for item in prescreened[arm]:
                all_candidates.append((arm, item[0], item[1], item[2]))

        if not all_candidates:
            plateau_count += 1
            if not quiet:
                print(f" Gen {gen_idx+1}/{generations} [no candidates — skipping]")
            continue

        _t0 = _time.perf_counter()
        score_mols = [mol for _, _, mol, _ in all_candidates]
        raw_scores = _safe_score(scoring_function, score_mols)
        _t_score = _time.perf_counter() - _t0

        arm_yields = {a: [] for a in ALL_ARMS}
        for (arm, smi, mol, parent_sc), rs in zip(all_candidates, raw_scores):
            internal = rs * sign
            pool.append((smi, internal))
            if hof._worth_offering(rs):
                hof.offer(smi, rs, mol)
            # Yield = improvement over the parent's internal score (>= 0).
            arm_yields[arm].append(max(0.0, parent_sc - internal))

        # ── immediate pool cull ───────────────────────────────────────────────
        # Cull right after scoring so the pool never grows to 20k+ entries.
        # Without this, _diverse_top_k at the next gen start gets a huge list.
        _pool_hard_cap = popsize * pool_diversity_k * 4  # generous but bounded
        if len(pool) > _pool_hard_cap:
            pool.sort(key=lambda x: x[1])
            pool = pool[:_pool_hard_cap]

        # ── phase: bandit update ──────────────────────────────────────────────
        # One Bernoulli outcome per arm per generation. Arms that produced no
        # raw candidates were already recorded as a failure during
        # prescreening, so they are skipped here to avoid a double penalty.
        for arm, yields in arm_yields.items():
            if arm in _empty_arms:
                continue
            bandit.update(arm, any(y > 1e-6 for y in yields))
        bandit.save()

        _t_gen_total = _time.perf_counter() - _t_gen_start

        # The pool is not sorted at this point (new candidates were appended
        # after the diversity selection), so the best entry must be found
        # with min() rather than read from pool[0].
        best_pool_raw = (min(s for _, s in pool) / sign) if pool else None

        # ── per-generation report ─────────────────────────────────────────────
        if not quiet:
            best_raw = best_pool_raw if best_pool_raw is not None else float("nan")
            # Candidate counts per arm: "mutate:2134->423"
            arm_counts = " ".join(
                f"{a}:{_prescreen_counts.get(a,(0,0))[0]}->{_prescreen_counts.get(a,(0,0))[1]}"
                for a in ALL_ARMS
            )
            print(
                f" Gen {gen_idx+1}/{generations}"
                f" pool={len(pool)}"
                f" explore_f={explore_fraction:.2f}"
                f" diversity={diversity:.3f}"
                f" plateau={plateau_count}"
                f" best_pool={best_raw:.4f}"
                f" HoF={len(hof)} best_hof={hof.best_score}"
                f"\n TIMING(s): total={_t_gen_total:.1f}"
                f" pool_cull={_t_pool_cull:.2f}"
                f" diversity={_t_diversity:.2f}"
                f" frag={_t_frag:.2f}"
                f" crossover={_t_crossover:.2f}"
                f" random={_t_random:.2f}"
                f" mutate={_t_mutate:.2f}"
                f" prescreen={_t_prescreen:.2f}"
                f" score={_t_score:.2f}"
                f" scored={len(all_candidates)}"
                f"\n CANDIDATES: {arm_counts}"
                + (f"\n CROSSOVER_ERR: {_crossover_err}" if _crossover_err else "")
            )

        if callback is not None:
            callback({
                "generation": gen_idx + 1,
                "explore_fraction": explore_fraction,
                "pool_size": len(pool),
                "hof_size": len(hof),
                "diversity": diversity,
                "plateau": plateau_count,
                "best_pool": best_pool_raw,
                "best_hof": hof.best_score,
                "bandit": bandit.hit_rate,
                "timing": {
                    "total": _t_gen_total,
                    "pool_cull": _t_pool_cull,
                    "diversity": _t_diversity,
                    "frag": _t_frag,
                    "crossover": _t_crossover,
                    "random": _t_random,
                    "mutate": _t_mutate,
                    "prescreen": _t_prescreen,
                    "score": _t_score,
                    "n_scored": len(all_candidates),
                },
            })

    if not quiet:
        print(f"\n[LACAN GA] Done. HallOfFame: {len(hof)} molecules.")
        print("[LACAN GA] Per-arm bandit summary (posterior hit rates):")
        print(bandit.summary())
        if bandit_stats_file:
            print(f"[LACAN GA] Bandit stats saved to {bandit_stats_file}")

    return hof.top()
[docs]
class GAReporter:
    """Collects per-generation statistics from :func:`generate_optimized_molecules` and plots them.

    Pass an instance as the ``callback=`` argument to the GA. After the run,
    call :meth:`plot` to visualise the results, or compare multiple runs by
    passing a list of reporters to :func:`compare`.

    Parameters
    ----------
    label : str
        Name for this run, shown in plot legends (default ``"run"``).

    Example
    -------
    ::

        reporter = GAReporter(label="docking_preset")
        winners = gen.generate_optimized_molecules(
            my_score_fn, profile,
            preset="docking",
            callback=reporter,
        )
        reporter.plot()
    """

    def __init__(self, label="run"):
        self.label = label
        # One stats dict per reported generation, in the order received.
        self.history = []

    def __call__(self, stats):
        """Record the stats dict emitted for one generation."""
        self.history.append(stats)

    @staticmethod
    def _nan_if_none(values):
        """Return *values* as a list with every ``None`` replaced by NaN.

        ``best_pool`` and ``best_hof`` are reported as ``None`` when the pool
        or the HallOfFame is empty; NaN makes such points plot as gaps.
        """
        return [float("nan") if v is None else v for v in values]

    def plot(self, ax=None, show=True):
        """Plot score, diversity, explore fraction, and per-arm hit rates over time.

        Parameters
        ----------
        ax : ignored
            Accepted for API compatibility; a new three-panel figure is
            always created.
        show : bool
            Call ``plt.show()`` before returning (default True).

        Returns
        -------
        The matplotlib figure, or ``None`` when no generation was recorded.
        """
        import matplotlib.pyplot as plt

        if not self.history:
            print("No data recorded yet.")
            return None

        gens = [h["generation"] for h in self.history]
        best_pool = self._nan_if_none(h["best_pool"] for h in self.history)
        best_hof = self._nan_if_none(h["best_hof"] for h in self.history)
        diversity = [h["diversity"] for h in self.history]
        explore_f = [h["explore_fraction"] for h in self.history]
        hof_sizes = [h["hof_size"] for h in self.history]

        fig, axes = plt.subplots(3, 1, figsize=(10, 9), sharex=True)
        fig.suptitle(f"GA run: {self.label}", fontsize=13, fontweight="bold")

        # Panel 1: scores
        ax1 = axes[0]
        ax1.plot(gens, best_pool, "o-", color="steelblue", label="best in pool", lw=1.5)
        ax1.plot(gens, best_hof, "s--", color="darkorange", label="best HoF", lw=1.5)
        ax1.set_ylabel("Score")
        ax1.legend(loc="lower right", fontsize=9)
        ax1.set_title("Score per generation")

        # Panel 2: diversity + explore fraction + HoF size
        ax2 = axes[1]
        ax2.plot(gens, diversity, "^-", color="mediumpurple", label="diversity", lw=1.5)
        ax2.plot(gens, explore_f, "v--", color="teal", label="explore_fraction", lw=1.5)
        ax2_r = ax2.twinx()
        ax2_r.bar(gens, hof_sizes, color="salmon", alpha=0.4, label="HoF size")
        ax2_r.set_ylabel("HoF size", color="salmon")
        ax2_r.tick_params(axis="y", labelcolor="salmon")
        ax2.set_ylabel("Fraction / Distance")
        ax2.set_ylim(0, 1)
        ax2.legend(loc="upper left", fontsize=9)
        ax2_r.legend(loc="upper right", fontsize=9)
        ax2.set_title("Diversity, explore fraction, and HallOfFame size")

        # Panel 3: bandit arm hit rates over time
        ax3 = axes[2]
        colors3 = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        # Arm names are taken from the first recorded generation.
        arm_names = list(self.history[0]["bandit"].keys())
        for i, arm in enumerate(arm_names):
            rates = [h["bandit"].get(arm, 0) for h in self.history]
            ax3.plot(gens, rates, "o-", label=arm,
                     color=colors3[i % len(colors3)], lw=1.5, alpha=0.85)
        ax3.set_ylabel("Posterior hit rate")
        ax3.set_xlabel("Generation")
        ax3.set_ylim(0, None)
        ax3.set_title("Bandit arm posterior hit rates (Thompson Sampling)")
        ax3.legend(fontsize=8, loc="upper right", ncol=4)

        plt.tight_layout()
        if show:
            plt.show()
        return fig

    @staticmethod
    def compare(reporters, metric="best_pool", show=True):
        """Overlay one metric from multiple runs on a single axis.

        Parameters
        ----------
        reporters : iterable of GAReporter
            Reporters without recorded history are skipped.
        metric : str
            Key of the per-generation stats dict to plot
            (default ``"best_pool"``).
        show : bool
            Call ``plt.show()`` before returning (default True).

        Returns
        -------
        The matplotlib figure.
        """
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=(9, 4))
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        for i, rep in enumerate(reporters):
            if not rep.history:
                continue
            gens = [h["generation"] for h in rep.history]
            # Any metric may be None for a generation (e.g. best_pool with an
            # empty pool, best_hof with an empty HallOfFame).
            vals = GAReporter._nan_if_none(h[metric] for h in rep.history)
            ax.plot(gens, vals, "o-", color=colors[i % len(colors)],
                    label=rep.label, lw=1.8)

        ylabel = {
            "best_pool": "Score (best in pool)",
            "best_hof": "Score (best HoF)",
            "diversity": "Mean Tanimoto distance",
        }.get(metric, metric)
        ax.set_ylabel(ylabel)
        ax.set_xlabel("Generation")
        ax.set_title(f"GA comparison — {metric}")
        ax.legend(fontsize=9)
        plt.tight_layout()
        if show:
            plt.show()
        return fig

    def summary(self):
        """Print a compact summary table of the run."""
        if not self.history:
            print("No data.")
            return
        print(f"\n{'Gen':>4} {'ExpF':>5} {'Pool':>5} {'HoF':>5} "
              f"{'Div':>6} {'BestPool':>9} {'BestHoF':>9}")
        print("-" * 58)
        for h in self.history:
            # Missing scores are shown as a dash instead of a number.
            bh = f"{h['best_hof']:.4f}" if h["best_hof"] is not None else " — "
            bp = f"{h['best_pool']:.4f}" if h["best_pool"] is not None else " — "
            print(f"{h['generation']:>4} {h['explore_fraction']:>5.2f}"
                  f" {h['pool_size']:>5} {h['hof_size']:>5} "
                  f"{h['diversity']:>6.3f} {bp:>9} {bh:>9}")
        print()
[docs]
def optimize_from_mol(
        mol,
        scoring_function,
        profile,
        generations=20,
        beam_width=10,
        n_replacements=15,
        win_threshold=0.8,
        higher_is_better=True,
        conservative=True,
        fragcorpus=None,
        plateau_patience=5,
        quiet=False,
        callback=None,
        protect_smarts=None):
    """Beam-search optimisation of one seed molecule via fragment-level moves.

    Where :func:`generate_optimized_molecules` explores from random starting
    structures, this routine begins with *one specific molecule*. Each round
    applies the fragment moves (ring, linker, substituent, decoration) to
    every beam member, scores the previously unseen products, and keeps the
    best ``beam_width`` molecules as parents for the next round.

    Parameters
    ----------
    mol : RDKit Mol — the seed / lead molecule
    scoring_function : callable — ``fn(list[Mol]) -> list[float]``
    profile : LACAN profile dict
    generations : int — maximum number of rounds (default 20)
    beam_width : int — top molecules kept as parents each round (default 10)
    n_replacements : int — fragment operation attempts per parent (default 15)
    win_threshold : float — raw score a molecule must beat to be reported as
        a winner (default 0.8). When nothing beats it, the final beam is
        returned instead.
    higher_is_better : bool (default True)
    conservative : bool — prefer similar fragment replacements (default True)
    fragcorpus : fragment corpus; ``load_corpus()`` is used when None
    plateau_patience : int — stop early after this many consecutive rounds
        without improvement (default 5)
    quiet : bool (default False)
    callback : callable or None — called after every scored round with a
        dict holding ``generation``, ``n_candidates``, ``n_winners``,
        ``best_score`` and ``plateau``
    protect_smarts : str or None — SMARTS pattern forwarded to the fragment
        moves for every parent in every round; per the original
        documentation, matching atoms are excluded from all fragment
        operations (see :func:`~lacan.protect.get_protected_atoms`).
        ``None`` disables atom exclusion (default).

    Returns
    -------
    list of (smiles, score) tuples, sorted best-first
    """
    corpus = load_corpus() if fragcorpus is None else fragcorpus
    sign = _sign(higher_is_better)
    # Scores are compared in "internal" form (raw * sign), where lower wins.
    win_cutoff = win_threshold * sign

    seed_smi = Chem.MolToSmiles(mol)
    seed_raw = _safe_score(scoring_function, [mol])[0]
    seed_internal = seed_raw * sign

    beam = [(seed_smi, seed_internal)]
    visited = {seed_smi}
    winners = [(seed_smi, seed_raw)] if seed_internal < win_cutoff else []
    best_internal = seed_internal
    stalled = 0

    if not quiet:
        print(f"optimize_from_mol: seed score = {seed_raw:.4f}")

    for gen_idx in range(generations):
        # Expand every beam member with fragment moves.
        parents = [Chem.MolFromSmiles(smi) for smi, _ in beam]
        candidates = []
        for parent in parents:
            candidates.extend(
                _fragment_moves(parent, profile, corpus,
                                n_replacements=n_replacements,
                                conservative=conservative,
                                protect_smarts=protect_smarts))

        # Keep only structures never encountered in this run.
        fresh = []
        for cand in candidates:
            try:
                cand_smi = Chem.MolToSmiles(cand)
            except Exception:
                continue
            if cand_smi in visited:
                continue
            visited.add(cand_smi)
            fresh.append((cand_smi, cand))

        if not fresh:
            stalled += 1
            if not quiet:
                print(f" Gen {gen_idx+1}: no novel candidates — plateau {stalled}/{plateau_patience}")
            if stalled >= plateau_patience:
                break
            continue

        fresh_smis = [smi for smi, _ in fresh]
        fresh_mols = [m for _, m in fresh]
        raw_scores = _safe_score(scoring_function, fresh_mols)

        # Merge the new molecules into the beam and track the round's best.
        extended = list(beam)
        round_best = best_internal
        for cand_smi, raw in zip(fresh_smis, raw_scores):
            internal = raw * sign
            if internal < win_cutoff:
                winners.append((cand_smi, raw))
            extended.append((cand_smi, internal))
            if internal < round_best:
                round_best = internal
        extended.sort(key=lambda entry: entry[1])
        beam = extended[:beam_width]

        improved = round_best < best_internal - 1e-6
        if improved:
            best_internal = round_best
            stalled = 0
        else:
            stalled += 1

        best_raw = best_internal / sign
        if not quiet:
            print(f" Gen {gen_idx+1}/{generations} | candidates={len(fresh)} "
                  f"| winners={len(winners)} | best={best_raw:.4f} "
                  f"| plateau={stalled}/{plateau_patience}")
        if callback is not None:
            callback({
                "generation": gen_idx + 1,
                "n_candidates": len(fresh),
                "n_winners": len(winners),
                "best_score": best_raw,
                "plateau": stalled,
            })

        if stalled >= plateau_patience:
            if not quiet:
                print(" Plateau reached — stopping early.")
            break

    if not winners:
        if not quiet:
            print(" No winners above threshold — returning best beam members.")
        winners = [(smi, internal / sign) for smi, internal in beam]

    # Best-first on the raw score, keeping the first occurrence of each SMILES.
    ranked = sorted(winners, key=lambda w: w[1], reverse=higher_is_better)
    emitted = set()
    result = []
    for smi, score in ranked:
        if smi in emitted:
            continue
        emitted.add(smi)
        result.append((smi, score))
    return result
[docs]
def next_population(population):
    """Reserved entry point for a future incremental population-advance API.

    Raises
    ------
    NotImplementedError
        Always; the function has no implementation yet.
    """
    message = "next_population is not yet implemented"
    raise NotImplementedError(message)