ga.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. package gago
  2. import (
  3. "errors"
  4. "log"
  5. "math/rand"
  6. "sync"
  7. "time"
  8. "golang.org/x/sync/errgroup"
  9. )
  10. // A GA contains population which themselves contain individuals.
  11. type GA struct {
  12. // Fields that are provided by the user
  13. GenomeFactory GenomeFactory `json:"-"`
  14. NPops int `json:"-"` // Number of Populations
  15. PopSize int `json:"-"` // Number of Individuls per Population
  16. Model Model `json:"-"`
  17. Migrator Migrator `json:"-"`
  18. MigFrequency int `json:"-"` // Frequency at which migrations occur
  19. Speciator Speciator `json:"-"`
  20. Logger *log.Logger `json:"-"`
  21. Callback func(ga *GA) `json:"-"`
  22. // Fields that are generated at runtime
  23. Populations Populations `json:"pops"`
  24. Best Individual `json:"best"` // Overall best individual
  25. Age time.Duration `json:"duration"`
  26. Generations int `json:"generations"`
  27. rng *rand.Rand
  28. }
  29. // Validate the parameters of a GA to ensure it will run correctly; some
  30. // settings or combination of settings may be incoherent during runtime.
  31. func (ga GA) Validate() error {
  32. // Check the GenomeFactory presence
  33. if ga.GenomeFactory == nil {
  34. return errors.New("GenomeFactory cannot be nil")
  35. }
  36. // Check the number of populations is higher than 0
  37. if ga.NPops < 1 {
  38. return errors.New("NPops should be higher than 0")
  39. }
  40. // Check the number of individuals per population is higher than 0
  41. if ga.PopSize < 1 {
  42. return errors.New("PopSize should be higher than 0")
  43. }
  44. // Check the model presence
  45. if ga.Model == nil {
  46. return errors.New("Model cannot be nil")
  47. }
  48. // Check the model is valid
  49. var modelErr = ga.Model.Validate()
  50. if modelErr != nil {
  51. return modelErr
  52. }
  53. // Check the migration frequency if a Migrator has been provided
  54. if ga.Migrator != nil && ga.MigFrequency < 1 {
  55. return errors.New("MigFrequency should be strictly higher than 0")
  56. }
  57. // Check the speciator is valid if it has been provided
  58. if ga.Speciator != nil {
  59. if specErr := ga.Speciator.Validate(); specErr != nil {
  60. return specErr
  61. }
  62. }
  63. // No error
  64. return nil
  65. }
  66. // Find the best individual in each population and then compare the best overall
  67. // individual to the current best individual. This method supposes that the
  68. // populations have been preemptively ascendingly sorted by fitness so that
  69. // checking the first individual of each population is sufficient.
  70. func (ga *GA) findBest() {
  71. for _, pop := range ga.Populations {
  72. var best = pop.Individuals[0]
  73. if best.Fitness < ga.Best.Fitness {
  74. ga.Best = best.Clone(pop.rng)
  75. }
  76. }
  77. }
  78. // Initialize each population in the GA and assign an initial fitness to each
  79. // individual in each population. Running Initialize after running Enhance will
  80. // reset the GA entirely.
  81. func (ga *GA) Initialize() {
  82. ga.Populations = make([]Population, ga.NPops)
  83. ga.rng = newRandomNumberGenerator()
  84. var wg sync.WaitGroup
  85. for i := range ga.Populations {
  86. wg.Add(1)
  87. go func(j int) {
  88. defer wg.Done()
  89. // Generate a population
  90. ga.Populations[j] = newPopulation(
  91. ga.PopSize,
  92. ga.GenomeFactory,
  93. randString(3, ga.rng),
  94. )
  95. // Evaluate its individuals
  96. ga.Populations[j].Individuals.Evaluate()
  97. // Sort its individuals
  98. ga.Populations[j].Individuals.SortByFitness()
  99. // Log current statistics if a logger has been provided
  100. if ga.Logger != nil {
  101. ga.Populations[j].Log(ga.Logger)
  102. }
  103. }(i)
  104. }
  105. wg.Wait()
  106. // The initial best individual is initialized randomly
  107. var rng = newRandomNumberGenerator()
  108. ga.Best = NewIndividual(ga.GenomeFactory(rng), rng)
  109. ga.findBest()
  110. // Execute the callback if it has been set
  111. if ga.Callback != nil {
  112. ga.Callback(ga)
  113. }
  114. }
  115. // Enhance each population in the GA. The population level operations are done
  116. // in parallel with a wait group. After all the population operations have been
  117. // run, the GA level operations are run.
  118. func (ga *GA) Enhance() error {
  119. var start = time.Now()
  120. ga.Generations++
  121. // Migrate the individuals between the populations if there are at least 2
  122. // Populations and that there is a migrator and that the migration frequency
  123. // divides the generation count
  124. if len(ga.Populations) > 1 && ga.Migrator != nil && ga.Generations%ga.MigFrequency == 0 {
  125. ga.Migrator.Apply(ga.Populations, ga.rng)
  126. }
  127. var g errgroup.Group
  128. for i := range ga.Populations {
  129. i := i // https://golang.org/doc/faq#closures_and_goroutines
  130. g.Go(func() error {
  131. var err error
  132. // Apply speciation if a positive number of species has been specified
  133. if ga.Speciator != nil {
  134. err = ga.Populations[i].speciateEvolveMerge(ga.Speciator, ga.Model)
  135. if err != nil {
  136. return err
  137. }
  138. } else {
  139. // Else apply the evolution model to the entire population
  140. err = ga.Model.Apply(&ga.Populations[i])
  141. if err != nil {
  142. return err
  143. }
  144. }
  145. // Evaluate and sort
  146. ga.Populations[i].Individuals.Evaluate()
  147. ga.Populations[i].Individuals.SortByFitness()
  148. ga.Populations[i].Age += time.Since(start)
  149. ga.Populations[i].Generations++
  150. // Log current statistics if a logger has been provided
  151. if ga.Logger != nil {
  152. ga.Populations[i].Log(ga.Logger)
  153. }
  154. return err
  155. })
  156. }
  157. if err := g.Wait(); err != nil {
  158. return err
  159. }
  160. // Check if there is an individual that is better than the current one
  161. ga.findBest()
  162. ga.Age += time.Since(start)
  163. // Execute the callback if it has been set
  164. if ga.Callback != nil {
  165. ga.Callback(ga)
  166. }
  167. // No error
  168. return nil
  169. }
  170. func (pop *Population) speciateEvolveMerge(spec Speciator, model Model) error {
  171. var (
  172. species, err = spec.Apply(pop.Individuals, pop.rng)
  173. pops = make([]Population, len(species))
  174. )
  175. if err != nil {
  176. return err
  177. }
  178. // Create a subpopulation from each specie so that the evolution Model can
  179. // be applied to it.
  180. for i, specie := range species {
  181. pops[i] = Population{
  182. Individuals: specie,
  183. Age: pop.Age,
  184. Generations: pop.Generations,
  185. ID: randString(len(pop.ID), pop.rng),
  186. rng: pop.rng,
  187. }
  188. err = model.Apply(&pops[i])
  189. if err != nil {
  190. return err
  191. }
  192. }
  193. // Merge each species back into the original population
  194. var i int
  195. for _, subpop := range pops {
  196. copy(pop.Individuals[i:i+len(subpop.Individuals)], subpop.Individuals)
  197. i += len(subpop.Individuals)
  198. }
  199. return nil
  200. }