My solutions to Harvard's online course CS50AI, An Introduction to Machine Learning
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

212 lines
6.1 KiB

4 years ago
  1. import csv
  2. import itertools
  3. import sys
  4. PROBS = {
  5. # Unconditional probabilities for having gene
  6. "gene": {
  7. 2: 0.01,
  8. 1: 0.03,
  9. 0: 0.96
  10. },
  11. "trait": {
  12. # Probability of trait given two copies of gene
  13. 2: {
  14. True: 0.65,
  15. False: 0.35
  16. },
  17. # Probability of trait given one copy of gene
  18. 1: {
  19. True: 0.56,
  20. False: 0.44
  21. },
  22. # Probability of trait given no gene
  23. 0: {
  24. True: 0.01,
  25. False: 0.99
  26. }
  27. },
  28. # Mutation probability
  29. "mutation": 0.01
  30. }
  31. def main():
  32. # Check for proper usage
  33. if len(sys.argv) != 2:
  34. sys.exit("Usage: python heredity.py data.csv")
  35. people = load_data(sys.argv[1])
  36. # Keep track of gene and trait probabilities for each person
  37. probabilities = {
  38. person: {
  39. "gene": {
  40. 2: 0,
  41. 1: 0,
  42. 0: 0
  43. },
  44. "trait": {
  45. True: 0,
  46. False: 0
  47. }
  48. }
  49. for person in people
  50. }
  51. # Loop over all sets of people who might have the trait
  52. names = set(people)
  53. for have_trait in powerset(names):
  54. # Check if current set of people violates known information
  55. fails_evidence = any(
  56. (people[person]["trait"] is not None and
  57. people[person]["trait"] != (person in have_trait))
  58. for person in names
  59. )
  60. if fails_evidence:
  61. continue
  62. # Loop over all sets of people who might have the gene
  63. for one_gene in powerset(names):
  64. for two_genes in powerset(names - one_gene):
  65. # Update probabilities with new joint probability
  66. p = joint_probability(people, one_gene, two_genes, have_trait)
  67. update(probabilities, one_gene, two_genes, have_trait, p)
  68. # Ensure probabilities sum to 1
  69. normalize(probabilities)
  70. # Print results
  71. for person in people:
  72. print(f"{person}:")
  73. for field in probabilities[person]:
  74. print(f" {field.capitalize()}:")
  75. for value in probabilities[person][field]:
  76. p = probabilities[person][field][value]
  77. print(f" {value}: {p:.4f}")
  78. def load_data(filename):
  79. """
  80. Load gene and trait data from a file into a dictionary.
  81. File assumed to be a CSV containing fields name, mother, father, trait.
  82. mother, father must both be blank, or both be valid names in the CSV.
  83. trait should be 0 or 1 if trait is known, blank otherwise.
  84. """
  85. data = dict()
  86. with open(filename) as f:
  87. reader = csv.DictReader(f)
  88. for row in reader:
  89. name = row["name"]
  90. data[name] = {
  91. "name": name,
  92. "mother": row["mother"] or None,
  93. "father": row["father"] or None,
  94. "trait": (True if row["trait"] == "1" else
  95. False if row["trait"] == "0" else None)
  96. }
  97. return data
  98. def powerset(s):
  99. """
  100. Return a list of all possible subsets of set s.
  101. """
  102. s = list(s)
  103. return [
  104. set(s) for s in itertools.chain.from_iterable(
  105. itertools.combinations(s, r) for r in range(len(s) + 1)
  106. )
  107. ]
  108. def get_info(person, one_gene, two_genes, have_trait):
  109. trait = person in have_trait
  110. gene = 0
  111. if person in one_gene:
  112. gene = 1
  113. elif person in two_genes:
  114. gene = 2
  115. return gene, trait
  116. def joint_probability(people, one_gene, two_genes, have_trait):
  117. """
  118. Compute and return a joint probability.
  119. The probability returned should be the probability that
  120. * everyone in set `one_gene` has one copy of the gene, and
  121. * everyone in set `two_genes` has two copies of the gene, and
  122. * everyone not in `one_gene` or `two_gene` does not have the gene, and
  123. * everyone in set `have_trait` has the trait, and
  124. * everyone not in set` have_trait` does not have the trait.
  125. """
  126. def generate_prob(m_gene, f_gene, gene_combination):
  127. if m_gene == 1:
  128. m_prob = 0.5
  129. else:
  130. m_prob = 0.99 if m_gene/2 == gene_combination[0] else 0.01
  131. if f_gene == 1:
  132. f_prob = 0.5
  133. else:
  134. f_prob = 0.99 if f_gene/2 == gene_combination[1] else 0.01
  135. return m_prob * f_prob
  136. probabilities = []
  137. for person in people:
  138. gene, trait = get_info(person, one_gene, two_genes, have_trait)
  139. if people[person]["mother"] and people[person]["father"]:
  140. mother_gene, foo = get_info(people[person]["mother"], one_gene, two_genes, have_trait)
  141. father_gene, foo = get_info(people[person]["father"], one_gene, two_genes, have_trait)
  142. if gene == 1:
  143. gene_prob = generate_prob(mother_gene, father_gene, (0, 1)) + generate_prob(mother_gene, father_gene, (1, 0))
  144. else:
  145. gene_prob = generate_prob(mother_gene, father_gene, (gene/2, gene/2))
  146. else:
  147. gene_prob = PROBS["gene"][gene]
  148. probabilities.append(gene_prob * PROBS["trait"][gene][trait])
  149. joint_prob = 1
  150. for p in probabilities:
  151. joint_prob *= p
  152. return joint_prob
  153. def update(probabilities, one_gene, two_genes, have_trait, p):
  154. for person in probabilities:
  155. gene, trait = get_info(person, one_gene, two_genes, have_trait)
  156. probabilities[person]["gene"][gene] += p
  157. probabilities[person]["trait"][trait] += p
  158. def normalize(probabilities):
  159. for person in probabilities:
  160. psum = 0
  161. for gene in probabilities[person]["gene"]:
  162. psum += probabilities[person]["gene"][gene]
  163. gene_ratio = 1/psum
  164. for gene in probabilities[person]["gene"]:
  165. probabilities[person]["gene"][gene] *= gene_ratio
  166. psum = 0
  167. for trait in probabilities[person]["trait"]:
  168. psum += probabilities[person]["trait"][trait]
  169. trait_ratio = 1/psum
  170. for trait in probabilities[person]["trait"]:
  171. probabilities[person]["trait"][trait] *= trait_ratio
# Run the analysis only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()