My solutions to Harvard's online course CS50AI (CS50's Introduction to Artificial Intelligence with Python)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

263 lines
7.7 KiB

4 years ago
  1. import itertools
  2. class Sentence():
  3. def evaluate(self, model):
  4. """Evaluates the logical sentence."""
  5. raise Exception("nothing to evaluate")
  6. def formula(self):
  7. """Returns string formula representing logical sentence."""
  8. return ""
  9. def symbols(self):
  10. """Returns a set of all symbols in the logical sentence."""
  11. return set()
  12. @classmethod
  13. def validate(cls, sentence):
  14. if not isinstance(sentence, Sentence):
  15. raise TypeError("must be a logical sentence")
  16. @classmethod
  17. def parenthesize(cls, s):
  18. """Parenthesizes an expression if not already parenthesized."""
  19. def balanced(s):
  20. """Checks if a string has balanced parentheses."""
  21. count = 0
  22. for c in s:
  23. if c == "(":
  24. count += 1
  25. elif c == ")":
  26. if count <= 0:
  27. return False
  28. count -= 1
  29. return count == 0
  30. if not len(s) or s.isalpha() or (
  31. s[0] == "(" and s[-1] == ")" and balanced(s[1:-1])
  32. ):
  33. return s
  34. else:
  35. return f"({s})"
  36. class Symbol(Sentence):
  37. def __init__(self, name):
  38. self.name = name
  39. def __eq__(self, other):
  40. return isinstance(other, Symbol) and self.name == other.name
  41. def __hash__(self):
  42. return hash(("symbol", self.name))
  43. def __repr__(self):
  44. return self.name
  45. def evaluate(self, model):
  46. try:
  47. return bool(model[self.name])
  48. except KeyError:
  49. raise Exception(f"variable {self.name} not in model")
  50. def formula(self):
  51. return self.name
  52. def symbols(self):
  53. return {self.name}
  54. class Not(Sentence):
  55. def __init__(self, operand):
  56. Sentence.validate(operand)
  57. self.operand = operand
  58. def __eq__(self, other):
  59. return isinstance(other, Not) and self.operand == other.operand
  60. def __hash__(self):
  61. return hash(("not", hash(self.operand)))
  62. def __repr__(self):
  63. return f"Not({self.operand})"
  64. def evaluate(self, model):
  65. return not self.operand.evaluate(model)
  66. def formula(self):
  67. return "¬" + Sentence.parenthesize(self.operand.formula())
  68. def symbols(self):
  69. return self.operand.symbols()
  70. class And(Sentence):
  71. def __init__(self, *conjuncts):
  72. for conjunct in conjuncts:
  73. Sentence.validate(conjunct)
  74. self.conjuncts = list(conjuncts)
  75. def __eq__(self, other):
  76. return isinstance(other, And) and self.conjuncts == other.conjuncts
  77. def __hash__(self):
  78. return hash(
  79. ("and", tuple(hash(conjunct) for conjunct in self.conjuncts))
  80. )
  81. def __repr__(self):
  82. conjunctions = ", ".join(
  83. [str(conjunct) for conjunct in self.conjuncts]
  84. )
  85. return f"And({conjunctions})"
  86. def add(self, conjunct):
  87. Sentence.validate(conjunct)
  88. self.conjuncts.append(conjunct)
  89. def evaluate(self, model):
  90. return all(conjunct.evaluate(model) for conjunct in self.conjuncts)
  91. def formula(self):
  92. if len(self.conjuncts) == 1:
  93. return self.conjuncts[0].formula()
  94. return "".join([Sentence.parenthesize(conjunct.formula())
  95. for conjunct in self.conjuncts])
  96. def symbols(self):
  97. return set.union(*[conjunct.symbols() for conjunct in self.conjuncts])
  98. class Or(Sentence):
  99. def __init__(self, *disjuncts):
  100. for disjunct in disjuncts:
  101. Sentence.validate(disjunct)
  102. self.disjuncts = list(disjuncts)
  103. def __eq__(self, other):
  104. return isinstance(other, Or) and self.disjuncts == other.disjuncts
  105. def __hash__(self):
  106. return hash(
  107. ("or", tuple(hash(disjunct) for disjunct in self.disjuncts))
  108. )
  109. def __repr__(self):
  110. disjuncts = ", ".join([str(disjunct) for disjunct in self.disjuncts])
  111. return f"Or({disjuncts})"
  112. def evaluate(self, model):
  113. return any(disjunct.evaluate(model) for disjunct in self.disjuncts)
  114. def formula(self):
  115. if len(self.disjuncts) == 1:
  116. return self.disjuncts[0].formula()
  117. return "".join([Sentence.parenthesize(disjunct.formula())
  118. for disjunct in self.disjuncts])
  119. def symbols(self):
  120. return set.union(*[disjunct.symbols() for disjunct in self.disjuncts])
  121. class Implication(Sentence):
  122. def __init__(self, antecedent, consequent):
  123. Sentence.validate(antecedent)
  124. Sentence.validate(consequent)
  125. self.antecedent = antecedent
  126. self.consequent = consequent
  127. def __eq__(self, other):
  128. return (isinstance(other, Implication)
  129. and self.antecedent == other.antecedent
  130. and self.consequent == other.consequent)
  131. def __hash__(self):
  132. return hash(("implies", hash(self.antecedent), hash(self.consequent)))
  133. def __repr__(self):
  134. return f"Implication({self.antecedent}, {self.consequent})"
  135. def evaluate(self, model):
  136. return ((not self.antecedent.evaluate(model))
  137. or self.consequent.evaluate(model))
  138. def formula(self):
  139. antecedent = Sentence.parenthesize(self.antecedent.formula())
  140. consequent = Sentence.parenthesize(self.consequent.formula())
  141. return f"{antecedent} => {consequent}"
  142. def symbols(self):
  143. return set.union(self.antecedent.symbols(), self.consequent.symbols())
  144. class Biconditional(Sentence):
  145. def __init__(self, left, right):
  146. Sentence.validate(left)
  147. Sentence.validate(right)
  148. self.left = left
  149. self.right = right
  150. def __eq__(self, other):
  151. return (isinstance(other, Biconditional)
  152. and self.left == other.left
  153. and self.right == other.right)
  154. def __hash__(self):
  155. return hash(("biconditional", hash(self.left), hash(self.right)))
  156. def __repr__(self):
  157. return f"Biconditional({self.left}, {self.right})"
  158. def evaluate(self, model):
  159. return ((self.left.evaluate(model)
  160. and self.right.evaluate(model))
  161. or (not self.left.evaluate(model)
  162. and not self.right.evaluate(model)))
  163. def formula(self):
  164. left = Sentence.parenthesize(str(self.left))
  165. right = Sentence.parenthesize(str(self.right))
  166. return f"{left} <=> {right}"
  167. def symbols(self):
  168. return set.union(self.left.symbols(), self.right.symbols())
  169. def model_check(knowledge, query):
  170. """Checks if knowledge base entails query."""
  171. def check_all(knowledge, query, symbols, model):
  172. """Checks if knowledge base entails query, given a particular model."""
  173. # If model has an assignment for each symbol
  174. if not symbols:
  175. # If knowledge base is true in model, then query must also be true
  176. if knowledge.evaluate(model):
  177. return query.evaluate(model)
  178. return True
  179. else:
  180. # Choose one of the remaining unused symbols
  181. remaining = symbols.copy()
  182. p = remaining.pop()
  183. # Create a model where the symbol is true
  184. model_true = model.copy()
  185. model_true[p] = True
  186. # Create a model where the symbol is false
  187. model_false = model.copy()
  188. model_false[p] = False
  189. # Ensure entailment holds in both models
  190. return (check_all(knowledge, query, remaining, model_true) and
  191. check_all(knowledge, query, remaining, model_false))
  192. # Get all symbols in both knowledge and query
  193. symbols = set.union(knowledge.symbols(), query.symbols())
  194. # Check that knowledge entails query
  195. return check_all(knowledge, query, symbols, dict())