From 107c3103446d63ac885bc0da1f22d2b3701a3332 Mon Sep 17 00:00:00 2001 From: zhanglei Date: Tue, 16 Jan 2024 09:42:49 +0800 Subject: [PATCH 1/4] return selected indices --- chromax/functional.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/chromax/functional.py b/chromax/functional.py index 7d8c820..5efcfea 100644 --- a/chromax/functional.py +++ b/chromax/functional.py @@ -4,7 +4,7 @@ import jax import jax.numpy as jnp -from jaxtyping import Array, Float +from jaxtyping import Array, Float, Int from .typing import N_MARKERS, Haploid, Individual, Parents, Population @@ -142,7 +142,7 @@ def select( population: Population["n"], k: int, f_index: Callable[[Population["n"]], Float[Array, "n"]], -) -> Population["k"]: +) -> (Population["k"], Int[Array, "k"]): """Function to select individuals based on their score (index). :param population: input grouped population of shape (n, m, d) @@ -154,8 +154,8 @@ def select( (n, m, 2) and returns an array of n float number. :type f_index: Callable - :return: output population of (k, m, d) - :rtype: ndarray + :return: output population of (k, m, d), output indecies of (k,) + :rtype: ndarray, ndarray :Example: >>> from chromax import functional @@ -174,4 +174,4 @@ def select( """ indices = f_index(population) _, best_pop = jax.lax.top_k(indices, k) - return population[best_pop, :, :] + return population[best_pop, :, :], best_pop From 35b24eab4ab19acc9e6ed5b7124345952fd3d68d Mon Sep 17 00:00:00 2001 From: zhanglei Date: Wed, 17 Jan 2024 09:24:58 +0800 Subject: [PATCH 2/4] update typing to tuple --- chromax/functional.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chromax/functional.py b/chromax/functional.py index 5efcfea..bfbdeb9 100644 --- a/chromax/functional.py +++ b/chromax/functional.py @@ -1,6 +1,6 @@ """Functional module.""" from functools import partial -from typing import Callable +from typing import Callable, Tuple import jax import jax.numpy as jnp @@ -142,7 +142,7 @@ def select( population: Population["n"], k: int, f_index: Callable[[Population["n"]], Float[Array, "n"]], -) -> (Population["k"], Int[Array, "k"]): +) -> Tuple[Population["k"], Int[Array, "k"]]: """Function to select individuals based on their score (index). :param population: input grouped population of shape (n, m, d) @@ -168,7 +168,7 @@ def select( >>> marker_effects = np.random.randn(n_chr * chr_len) >>> gebv_model = TraitModel(marker_effects[:, None]) >>> f_index = conventional_index(gebv_model) - >>> f2 = functional.select(f1, k=10, f_index=f_index) + >>> f2, selected_indices = functional.select(f1, k=10, f_index=f_index) >>> f2.shape (10, 1000, 2) """ From be5d52a12f6ecd75b0bb6c1cca31f4f83c03a547 Mon Sep 17 00:00:00 2001 From: zhanglei Date: Thu, 18 Jan 2024 08:54:23 +0800 Subject: [PATCH 3/4] fix typo, fix tests --- chromax/functional.py | 2 +- examples/sample_usage.ipynb | 2 +- examples/time_wheat_bp.py | 8 ++++---- examples/wheat_bp.py | 12 ++++++------ tests/test_functional.py | 2 +- tests/test_simulator.py | 6 +++--- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/chromax/functional.py b/chromax/functional.py index bfbdeb9..c434612 100644 --- a/chromax/functional.py +++ b/chromax/functional.py @@ -154,7 +154,7 @@ def select( (n, m, 2) and returns an array of n float number. :type f_index: Callable - :return: output population of (k, m, d), output indecies of (k,) + :return: output population of (k, m, d), output indices of (k,) :rtype: ndarray, ndarray :Example: diff --git a/examples/sample_usage.ipynb b/examples/sample_usage.ipynb index c76333a..99bed27 100644 --- a/examples/sample_usage.ipynb +++ b/examples/sample_usage.ipynb @@ -53,7 +53,7 @@ } ], "source": [ - "selected_ind = simulator.select(population, k=10)\n", + "selected_ind, _ = simulator.select(population, k=10)\n", "simulator.GEBV(selected_ind).mean()" ] }, diff --git a/examples/time_wheat_bp.py b/examples/time_wheat_bp.py index 30c351a..49d3a2d 100644 --- a/examples/time_wheat_bp.py +++ b/examples/time_wheat_bp.py @@ -20,18 +20,18 @@ def wheat_schema( # dh_lines2 = simulator.double_haploid(f1[100*factor:], n_offspring=100) # dh_lines = jax.numpy.concatenate((dh_lines1, dh_lines2)) - headrows = simulator.select(dh_lines, 5, visual_selection(simulator, seed=7)) + headrows, _ = simulator.select(dh_lines, 5, visual_selection(simulator, seed=7)) headrows = headrows.reshape(1000 * factor, -1, 2) envs = simulator.create_environments(num_environments=16) - pyt = simulator.select( + pyt, _ = simulator.select( headrows, k=100 * factor, f_index=phenotype_index(simulator, envs[0]) ) - ayt = simulator.select( + ayt, _ = simulator.select( pyt, k=10 * factor, f_index=phenotype_index(simulator, envs[:4]) ) - released_variety = simulator.select( + released_variety, _ = simulator.select( ayt, k=1, f_index=phenotype_index(simulator, envs) ) diff --git a/examples/wheat_bp.py b/examples/wheat_bp.py index 4e11616..70d9042 100644 --- a/examples/wheat_bp.py +++ b/examples/wheat_bp.py @@ -16,23 +16,23 @@ def wheat_schema( ) -> Tuple[Population["50"], Individual]: f1, _ = simulator.random_crosses(germplasm, 100) dh_lines = simulator.double_haploid(f1, n_offspring=100) - headrows = simulator.select( + headrows, _ = simulator.select( dh_lines, k=5, f_index=visual_selection(simulator, seed=7) ).reshape(len(dh_lines) * 5, *dh_lines.shape[2:]) - hdrw_next_cycle = simulator.select( + hdrw_next_cycle, _ = simulator.select( dh_lines.reshape(dh_lines.shape[0] * dh_lines.shape[1], *dh_lines.shape[2:]), k=20, f_index=visual_selection(simulator, seed=7), ) envs = simulator.create_environments(num_environments=16) - pyt = simulator.select(headrows, k=50, f_index=phenotype_index(simulator, envs[0])) - pyt_next_cycle = simulator.select( + pyt, _ = simulator.select(headrows, k=50, f_index=phenotype_index(simulator, envs[0])) + pyt_next_cycle, _ = simulator.select( headrows, k=20, f_index=phenotype_index(simulator, envs[0]) ) - ayt = simulator.select(pyt, k=10, f_index=phenotype_index(simulator, envs[:4])) + ayt, _ = simulator.select(pyt, k=10, f_index=phenotype_index(simulator, envs[:4])) - released_variety = simulator.select( + released_variety, _ = simulator.select( ayt, k=1, f_index=phenotype_index(simulator, envs) ) diff --git a/tests/test_functional.py b/tests/test_functional.py index 34275e9..4e65857 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -44,7 +44,7 @@ def test_select(): marker_effects = np.random.randn(n_markers) gebv_model = TraitModel(marker_effects[:, None]) f_index = conventional_index(gebv_model) - f2 = functional.select(f1, k=k, f_index=f_index) + f2, _ = functional.select(f1, k=k, f_index=f_index) assert f2.shape == (k, *f1.shape[1:]) f1_gebv = gebv_model(f1) diff --git a/tests/test_simulator.py b/tests/test_simulator.py index 310b3b4..db820fe 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -123,14 +123,14 @@ def test_select(): population = simulator.load_population(n_ind, ploidy=ploidy) pop_GEBV = simulator.GEBV(population) - selected_pop = simulator.select(population, k=10) + selected_pop, _ = simulator.select(population, k=10) selected_GEBV = simulator.GEBV(selected_pop) assert np.all(selected_GEBV.mean() > pop_GEBV.mean()) assert np.all(selected_GEBV.max() == pop_GEBV.max()) assert np.all(selected_GEBV.min() > pop_GEBV.min()) dh = simulator.double_haploid(population, n_offspring=100) - selected_dh = simulator.select(dh, k=5) + selected_dh, _ = simulator.select(dh, k=5) assert selected_dh.shape == (n_ind, 5, n_markers, ploidy) for i in range(n_ind): dh_GEBV = simulator.GEBV(dh[i]) @@ -190,7 +190,7 @@ def test_device(): GEBV = simulator.GEBV_model(population) assert GEBV.device_buffer.device() == device - selected_pop = simulator.select(population, k=10) + selected_pop, _ = simulator.select(population, k=10) assert selected_pop.device_buffer.device() == device diallel = simulator.diallel(selected_pop, n_offspring=10) From 703e8fe3f39f2922a8be5425ca0f7f9cbd8a48db Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 18 Jan 2024 12:11:46 +0100 Subject: [PATCH 4/4] improve tests + fix doc Simulator.select --- chromax/functional.py | 4 ++-- chromax/simulator.py | 12 +++++++----- examples/wheat_bp.py | 4 +++- tests/test_functional.py | 3 ++- tests/test_simulator.py | 15 ++++++++++++--- 5 files changed, 26 insertions(+), 12 deletions(-) diff --git a/chromax/functional.py b/chromax/functional.py index c434612..45e2fa0 100644 --- a/chromax/functional.py +++ b/chromax/functional.py @@ -154,8 +154,8 @@ def select( (n, m, 2) and returns an array of n float number. :type f_index: Callable - :return: output population of (k, m, d), output indices of (k,) - :rtype: ndarray, ndarray + :return: output population of shape (k, m, d), output indices of shape (k,) + :rtype: tuple of two ndarrays :Example: >>> from chromax import functional diff --git a/chromax/simulator.py b/chromax/simulator.py index 08f5244..34a624b 100644 --- a/chromax/simulator.py +++ b/chromax/simulator.py @@ -409,7 +409,7 @@ def select( population: Population["_g n"], k: int, f_index: Optional[Callable[[Population["n"]], Float[Array, "n"]]] = None, - ) -> Population["_g k"]: + ) -> Tuple[Population["_g k"], Int[Array, "_g k"]]: """Function to select individuals based on their score (index). :param population: input population of shape (n, m, d), @@ -423,9 +423,9 @@ def select( i.e. the sum of the marker effects masked with the SNPs from the genetic_map. :type f_index: Callable - :return: output population of shape (k, m, d) or (g, k, m, d), - depending on the input population. - :rtype: ndarray + :return: output population of shape (k, m, d) or (g, k, m, d), depending on the input + population, and respective indices of shape (k,) or (g, k) + :rtype: tuple of two ndarrays :Example: >>> from chromax import Simulator, sample_data @@ -433,9 +433,11 @@ def select( >>> f1 = simulator.load_population(sample_data.genome) >>> len(f1), simulator.GEBV(f1).mean().values (371, [8.223844]) - >>> f2 = simulator.select(f1, k=20) + >>> f2, selected_indices = simulator.select(f1, k=20) >>> len(f2), simulator.GEBV(f2).mean().values (20, [14.595136]) + >>> selected_indices.shape + (20,) """ if f_index is None: f_index = conventional_index(self.GEBV_model) diff --git a/examples/wheat_bp.py b/examples/wheat_bp.py index 70d9042..eaa6e09 100644 --- a/examples/wheat_bp.py +++ b/examples/wheat_bp.py @@ -26,7 +26,9 @@ def wheat_schema( ) envs = simulator.create_environments(num_environments=16) - pyt, _ = simulator.select(headrows, k=50, f_index=phenotype_index(simulator, envs[0])) + pyt, _ = simulator.select( + headrows, k=50, f_index=phenotype_index(simulator, envs[0]) + ) pyt_next_cycle, _ = simulator.select( headrows, k=20, f_index=phenotype_index(simulator, envs[0]) ) diff --git a/tests/test_functional.py b/tests/test_functional.py index 4e65857..0ac9057 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -44,8 +44,9 @@ def test_select(): marker_effects = np.random.randn(n_markers) gebv_model = TraitModel(marker_effects[:, None]) f_index = conventional_index(gebv_model) - f2, _ = functional.select(f1, k=k, f_index=f_index) + f2, best_indices = functional.select(f1, k=k, f_index=f_index) assert f2.shape == (k, *f1.shape[1:]) + assert best_indices.shape == (k,) f1_gebv = gebv_model(f1) f2_gebv = gebv_model(f2) diff --git a/tests/test_simulator.py b/tests/test_simulator.py index db820fe..3417c7a 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -123,21 +123,30 @@ def test_select(): population = simulator.load_population(n_ind, ploidy=ploidy) pop_GEBV = simulator.GEBV(population) - selected_pop, _ = simulator.select(population, k=10) + k = 10 + selected_pop, selected_indices = simulator.select(population, k=k) + assert selected_pop.shape == (k, n_markers, ploidy) + assert selected_indices.shape == (k,) selected_GEBV = simulator.GEBV(selected_pop) assert np.all(selected_GEBV.mean() > pop_GEBV.mean()) assert np.all(selected_GEBV.max() == pop_GEBV.max()) assert np.all(selected_GEBV.min() > pop_GEBV.min()) + GEBV_indices = pop_GEBV.iloc[selected_indices] + assert np.all(GEBV_indices.reset_index(drop=True) == selected_GEBV) + k = 5 dh = simulator.double_haploid(population, n_offspring=100) - selected_dh, _ = simulator.select(dh, k=5) - assert selected_dh.shape == (n_ind, 5, n_markers, ploidy) + selected_dh, selected_indices = simulator.select(dh, k=k) + assert selected_dh.shape == (n_ind, k, n_markers, ploidy) + assert selected_indices.shape == (n_ind, k) for i in range(n_ind): dh_GEBV = simulator.GEBV(dh[i]) selected_GEBV = simulator.GEBV(selected_dh[i]) assert np.all(selected_GEBV.mean() > dh_GEBV.mean()) assert np.all(selected_GEBV.max() == dh_GEBV.max()) assert np.all(selected_GEBV.min() > dh_GEBV.min()) + GEBV_indices = dh_GEBV.iloc[selected_indices[i]] + assert np.all(GEBV_indices.reset_index(drop=True) == selected_GEBV) def test_random_crosses():