Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mutation #13

Merged
merged 15 commits into from
Jan 30, 2024
Merged

Mutation #13

merged 15 commits into from
Jan 30, 2024

Conversation

oddoking
Copy link
Contributor

add mutation random key and mutate prob

Copy link
Contributor

@younik younik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many thanks for doing this!

I left some comments. I have time now, so if you need any help just let me know!

Comment on lines 154 to 161
if mutation is None:
self.mutation = 0.0
elif mutation > 0 and mutation < 1:
self.mutation = mutation
else:
raise ValueError(
f"mutation must be between 0 and 1, but got {mutation}"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To simplify this, you can use as default 0 in init definition mutation: float = 0f

Then here, something like this:

if mutation < 0 or mutation > 0:
   raise ValueError(
                 f"mutation must be between 0 and 1, but got {mutation}"
             )
self.mutation = mutation

Comment on lines 177 to 178
print(f'shape is {self.random_key.shape} {self.random_key[0]}')
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to remove

rec_sites = samples < recombination_vec
crossover_mask = jax.lax.associative_scan(jnp.logical_xor, rec_sites)

crossover_mask = crossover_mask.astype(jnp.int8)
haploid = jnp.take_along_axis(individual, crossover_mask[:, None], axis=-1)

mutation_samples = jax.random.uniform(mutate_random_key, shape=haploid.shape)
mutation_sites = mutation_samples > mutate_probability
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be mutation_samples < mutate_probability, isn't it?
If mutate_probability = 0, mutation_sites should always be False

Comment on lines 50 to 63
random_keys = jax.random.split(random_key, num=len(parents) * 2 * parents.shape[3])
random_keys = random_keys.reshape(len(parents), 2, parents.shape[3], 2)
offsprings = _cross(parents, recombination_vec, random_keys)
cross_random_key = jax.random.split(cross_random_key, num=len(parents) * 2 * parents.shape[3])
cross_random_key = cross_random_key.reshape(len(parents), 2, parents.shape[3], 2)

mutate_split_key = jax.random.split(mutate_split_key, num=len(parents) * 2 * parents.shape[3])
mutate_split_key = mutate_split_key.reshape(len(parents), 2, parents.shape[3], 2)

offsprings = _cross(parents, recombination_vec, cross_random_key, mutate_split_key, mutate_probability)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would avoid adding the other keys as argument, but I would generate them internally here.
So you can do:

random_keys = jax.random.split(random_key, num=2 * len(parents) * 2 * parents.shape[3])
random_keys = random_keys.reshape(3, len(parents), 2, parents.shape[3], 2)
cross_random_key, mutate_random_key = random_keys

Also, use default value of 0 for mutate_probability, so previous code will continue working

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

random_keys = random_keys.reshape(3, len(parents), 2, parents.shape[3], 2)

you mean
random_keys = random_keys.reshape(2, len(parents), 2, parents.shape[3], 2)?

Comment on lines 84 to 86
cross_random_key: jax.random.PRNGKeyArray,
mutate_random_key: jax.random.PRNGKeyArray,
mutate_probability: float,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as cross

@oddoking
Copy link
Contributor Author

have a look new change

@younik younik merged commit 450ca24 into kora-labs:master Jan 30, 2024
5 checks passed
@younik
Copy link
Contributor

younik commented Jan 30, 2024

Merged, thank you! @oddoking

@younik younik mentioned this pull request Jan 30, 2024
@oddoking oddoking deleted the mutation branch March 9, 2024 01:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants