selection.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. package gago
  2. import (
  3. "errors"
  4. "fmt"
  5. "math/rand"
  6. "sort"
  7. )
  8. // Selector chooses a subset of size n from a group of individuals. The group of
  9. // individuals a Selector is applied to is expected to be sorted.
  10. type Selector interface {
  11. Apply(n int, indis Individuals, rng *rand.Rand) (selected Individuals, indexes []int, err error)
  12. Validate() error
  13. }
  14. // SelElitism selection returns the n best individuals of a group.
  15. type SelElitism struct{}
  16. // Apply SelElitism.
  17. func (sel SelElitism) Apply(n int, indis Individuals, rng *rand.Rand) (Individuals, []int, error) {
  18. indis.SortByFitness()
  19. return indis[:n].Clone(rng), newInts(n), nil
  20. }
  21. // Validate SelElitism fields.
  22. func (sel SelElitism) Validate() error {
  23. return nil
  24. }
  25. // SelTournament samples individuals through tournament selection. The
  26. // tournament is composed of randomly chosen individuals. The winner of the
  27. // tournament is the chosen individual with the lowest fitness. The obtained
  28. // individuals are all distinct.
  29. type SelTournament struct {
  30. NContestants int
  31. }
  32. // Apply SelTournament.
  33. func (sel SelTournament) Apply(n int, indis Individuals, rng *rand.Rand) (Individuals, []int, error) {
  34. // Check that the number of individuals is large enough
  35. if len(indis)-n < sel.NContestants-1 {
  36. return nil, nil, fmt.Errorf("Not enough individuals to select %d "+
  37. "with NContestants = %d, have %d individuals and need at least %d",
  38. n, sel.NContestants, len(indis), sel.NContestants+n-1)
  39. }
  40. var (
  41. winners = make(Individuals, n)
  42. indexes = make([]int, n)
  43. notSelectedIdxs = newInts(len(indis))
  44. )
  45. for i := range winners {
  46. // Sample contestants
  47. var (
  48. contestants, idxs, _ = sampleInts(notSelectedIdxs, sel.NContestants, rng)
  49. winnerIdx int
  50. )
  51. // Find the best contestant
  52. winners[i] = indis[contestants[0]]
  53. winners[i].Evaluate()
  54. for j, idx := range contestants[1:] {
  55. if indis[idx].GetFitness() < winners[i].Fitness {
  56. winners[i] = indis[idx]
  57. indexes[i] = idx
  58. winnerIdx = idxs[j]
  59. }
  60. }
  61. // Ban the winner from re-participating
  62. notSelectedIdxs = append(notSelectedIdxs[:winnerIdx], notSelectedIdxs[winnerIdx+1:]...)
  63. }
  64. return winners.Clone(rng), indexes, nil
  65. }
  66. // Validate SelTournament fields.
  67. func (sel SelTournament) Validate() error {
  68. if sel.NContestants < 1 {
  69. return errors.New("NContestants should be higher than 0")
  70. }
  71. return nil
  72. }
  73. // SelRoulette samples individuals through roulette wheel selection (also known
  74. // as fitness proportionate selection).
  75. type SelRoulette struct{}
  76. func buildWheel(fitnesses []float64) []float64 {
  77. var (
  78. n = len(fitnesses)
  79. wheel = make([]float64, n)
  80. )
  81. for i, v := range fitnesses {
  82. wheel[i] = fitnesses[n-1] - v + 1
  83. }
  84. return cumsum(divide(wheel, sumFloat64s(wheel)))
  85. }
  86. // Apply SelRoulette.
  87. func (sel SelRoulette) Apply(n int, indis Individuals, rng *rand.Rand) (Individuals, []int, error) {
  88. var (
  89. selected = make(Individuals, n)
  90. indexes = make([]int, n)
  91. wheel = buildWheel(indis.getFitnesses())
  92. )
  93. for i := range selected {
  94. var (
  95. index = sort.SearchFloat64s(wheel, rand.Float64())
  96. winner = indis[index]
  97. )
  98. indexes[i] = index
  99. selected[i] = winner
  100. }
  101. return selected.Clone(rng), indexes, nil
  102. }
  103. // Validate SelRoulette fields.
  104. func (sel SelRoulette) Validate() error {
  105. return nil
  106. }