diff --git a/main.ipynb b/main.ipynb index 420fc26..2e1cd9a 100644 --- a/main.ipynb +++ b/main.ipynb @@ -39,168 +39,182 @@ "\n", "\n", - "\n", + "\n", "\n", "%3\n", - "\n", - "\n", + "\n", + "\n", "\n", - "140306720081952\n", - "\n", - "data 4.0000\n", - "\n", - "grad 0.0000\n", + "140118324758560\n", + "\n", + " \n", + "\n", + "data 7.0000\n", + "\n", + "grad 0.0000\n", "\n", - "\n", + "\n", "\n", - "140306720081616+\n", - "\n", - "+\n", + "140118324758464+\n", + "\n", + "+\n", "\n", - "\n", - "\n", - "140306720081952->140306720081616+\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "140306720081952*\n", - "\n", - "*\n", - "\n", - "\n", - "\n", - "140306720081952*->140306720081952\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "140306720082000\n", - "\n", - "data -1.0000\n", - "\n", - "grad 0.0000\n", - "\n", - "\n", - "\n", - "140306720081520+\n", - "\n", - "+\n", - "\n", - "\n", - "\n", - "140306720082000->140306720081520+\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "140306720082096+\n", - "\n", - "+\n", - "\n", - "\n", + "\n", "\n", - "140306720082000->140306720082096+\n", - "\n", - "\n", + "140118324758560->140118324758464+\n", + "\n", + "\n", "\n", - "\n", - "\n", - "140306720081520\n", - "\n", - "data 6.0000\n", - "\n", - "grad 0.0000\n", + "\n", + "\n", + "140118324758560+\n", + "\n", + "+\n", "\n", - "\n", - "\n", - "140306720081520+->140306720081520\n", - "\n", - "\n", + "\n", + "\n", + "140118324758560+->140118324758560\n", + "\n", + "\n", "\n", - "\n", + "\n", + "\n", + "140118324760816\n", + "\n", + " \n", + "\n", + "data 1.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", "\n", - "140306720082096\n", - "\n", - "data 3.0000\n", - "\n", - "grad 0.0000\n", + "140118324758800*\n", + "\n", + "*\n", "\n", - "\n", - "\n", - "140306720082096->140306720081616+\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "140306720082096+->140306720082096\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "140306720081616\n", - "\n", - "data 7.0000\n", - "\n", - "grad 0.0000\n", - "\n", - "\n", - "\n", - "140306720081616->140306720081520+\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "140306720081616+->140306720081616\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "140306737924848\n", - "\n", - "data 4.0000\n", - "\n", - "grad 0.0000\n", - "\n", - "\n", - "\n", - "140306737924848->140306720081952*\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "140306737924848->140306720082096+\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "140306720080368\n", - "\n", - "data 1.0000\n", - "\n", - "grad 0.0000\n", - "\n", - "\n", + "\n", "\n", - "140306720080368->140306720081952*\n", - "\n", - "\n", + "140118324760816->140118324758800*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118745713392\n", + "\n", + " \n", + "\n", + "data 4.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118745713392->140118324758800*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324758992+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140118745713392->140118324758992+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324758800\n", + "\n", + " \n", + "\n", + "data 4.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118324758800->140118324758560+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324758800*->140118324758800\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324758848\n", + "\n", + " \n", + "\n", + "data -1.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118324758848->140118324758464+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324758848->140118324758992+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324758464\n", + "\n", + " \n", + "\n", + "data 6.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118324758464+->140118324758464\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324758992\n", + "\n", + " \n", + "\n", + "data 3.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118324758992->140118324758560+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324758992+->140118324758992\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -214,9 +228,283 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "8491fa95-c89e-4283-95fe-d685d50f7e07", "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "\n", + "140118324770320\n", + "\n", + "n\n", + "\n", + "data 0.2000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118324770464tanh\n", + "\n", + "tanh\n", + "\n", + "\n", + "\n", + "140118324770320->140118324770464tanh\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324770320+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140118324770320+->140118324770320\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118313495584\n", + "\n", + "x2\n", + "\n", + "data 0.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118324770128*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140118313495584->140118324770128*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324770464\n", + "\n", + "y\n", + "\n", + "data 0.1974\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118324770464tanh->140118324770464\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324769984\n", + "\n", + "b\n", + "\n", + "data 6.2000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118324769984->140118324770320+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324770032\n", + "\n", + "h1\n", + "\n", + "data -6.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118324770176+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140118324770032->140118324770176+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324770032*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140118324770032*->140118324770032\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118313493232\n", + "\n", + "x1\n", + "\n", + "data 2.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118313493232->140118324770032*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118313498944\n", + "\n", + "w2\n", + "\n", + "data 1.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118313498944->140118324770128*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324770128\n", + "\n", + "h2\n", + "\n", + "data 0.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118324770128->140118324770176+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324770128*->140118324770128\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118313484640\n", + "\n", + "w1\n", + "\n", + "data -3.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118313484640->140118324770032*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324770176\n", + "\n", + "h\n", + "\n", + "data -6.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140118324770176->140118324770320+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140118324770176+->140118324770176\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# inputs\n", + "x1 = Value(2.0, label='x1')\n", + "x2 = Value(0.0, label='x2')\n", + "\n", + "# weights\n", + "w1 = Value(-3.0, label='w1')\n", + "w2 = Value(1.0, label='w2')\n", + "\n", + "# bias\n", + "b = Value(6.2, label='b')\n", + "\n", + "h1 = x1 * w1\n", + "h1.label = 'h1'\n", + "h2 = x2 * w2\n", + "h2.label = 'h2'\n", + "\n", + "h = h1 + h2\n", + "h.label = 'h'\n", + "\n", + "n = h + b\n", + "n.label = 'n'\n", + "y = n.tanh()\n", + "y.label = 'y'\n", + "\n", + "draw_dot(y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "543bd866-2ce7-4b9a-8d8b-55a160a7c83b", + "metadata": {}, "outputs": [], "source": [] } diff --git a/micrograd/engine.py b/micrograd/engine.py index d700cb2..cfc95bc 100644 --- a/micrograd/engine.py +++ b/micrograd/engine.py @@ -6,6 +6,25 @@ class Value: self._prev = set(_children) self._op = _op self.label = label + self._backward = lambda: None + + def backward(self): + def topological(root): + ordered = [] + visited = set() + + def build_topo(v): + if v not in visited: + visited.add(v) + for child in v._prev: + build_topo(child) + ordered.append(v) + build_topo(root) + return ordered + ordered = topological(self) + self.grad = 1.0 + for node in reversed(ordered): + node._backward() def __repr__(self): return f"Value(data={self.data})" @@ -13,6 +32,11 @@ class Value: def __add__(self, other): out = Value(data=self.data + other.data, _children=(self, other), _op='+') + + def _backward(): + self.grad = 1.0 * out.grad + other.grad = 1.0 * out.grad + out._backward = _backward return out def __sub__(self, other): @@ -23,4 +47,20 @@ class Value: def __mul__(self, other): out = Value(data=self.data * other.data, _children=(self, other), _op='*') + + def _backward(): + self.grad = other.data * out.grad + other.grad = self.data * out.grad + out._backward = _backward + return out + + def tanh(self): + from math import exp + t = (exp(2 * self.data) - 1) / (exp(2 * self.data) + 1) + out = Value(data=t, _children=(self,), + _op='tanh', label='tanh') + + def _backward(): + self.grad = (1 - t**2) * out.grad + out._backward = _backward return out diff --git a/test/test_backprop.py b/test/test_backprop.py new file mode 100644 index 0000000..829ee18 --- /dev/null +++ b/test/test_backprop.py @@ -0,0 +1,104 @@ +import pytest +from micrograd.engine import Value + + +def test_backward_tanh(): + # inputs + x = Value(0.8814) + y = x.tanh() + y.grad = 1.0 + y._backward() + + assert pytest.approx(x.grad, 0.1) == 0.5 + + +def test_large_backprop(): + # inputs + x1 = Value(2.0, label='x1') + x2 = Value(0.0, label='x2') + + # weights + w1 = Value(-3.0, label='w1') + w2 = Value(1.0, label='w2') + + # bias + b = Value(6.8813735870195432, label='b') + + h1 = x1 * w1 + h1.label = 'h1' + h2 = x2 * w2 + h2.label = 'h2' + + h = h1 + h2 + h.label = 'h' + + n = h + b + n.label = 'n' + y = n.tanh() + y.label = 'y' + + y.grad = 1.0 + y._backward() + + assert pytest.approx(n.grad, 0.001) == 0.5 + assert h.grad == 0.0 + + n._backward() + assert pytest.approx(b.grad, 0.001) == 0.5 + assert pytest.approx(h.grad, 0.001) == 0.5 + + b._backward() + h._backward() + assert pytest.approx(h1.grad, 0.001) == 0.5 + assert pytest.approx(h2.grad, 0.001) == 0.5 + + h1._backward() + h2._backward() + assert pytest.approx(x1.grad, 0.001) == -1.5 + assert pytest.approx(w1.grad, 0.001) == 1.0 + + assert pytest.approx(x2.grad, 0.001) == 0.5 + assert pytest.approx(w2.grad, 0.001) == 0.0 + + +@pytest.mark.skip(reason="non-deterministic") +def test_auto_diff(): + # inputs + x1 = Value(2.0, label='x1') + x2 = Value(0.0, label='x2') + + # weights + w1 = Value(-3.0, label='w1') + w2 = Value(1.0, label='w2') + + # bias + b = Value(6.8813735870195432, label='b') + + h1 = x1 * w1 + h1.label = 'h1' + h2 = x2 * w2 + h2.label = 'h2' + + h = h1 + h2 + h.label = 'h' + + n = h + b + n.label = 'n' + y = n.tanh() + y.label = 'y' + + y.backward() + assert pytest.approx(n.grad, 0.001) == 0.5 + assert h.grad == 0.0 + + assert pytest.approx(b.grad, 0.001) == 0.5 + assert pytest.approx(h.grad, 0.001) == 0.5 + + assert pytest.approx(h1.grad, 0.001) == 0.5 + assert pytest.approx(h2.grad, 0.001) == 0.5 + + assert pytest.approx(x1.grad, 0.001) == -1.5 + assert pytest.approx(w1.grad, 0.001) == 1.0 + + assert pytest.approx(x2.grad, 0.001) == 0.5 + assert pytest.approx(w2.grad, 0.001) == 0.0 diff --git a/test/test_neuron.py b/test/test_neuron.py new file mode 100644 index 0000000..f34a8da --- /dev/null +++ b/test/test_neuron.py @@ -0,0 +1,31 @@ +import pytest +from micrograd.engine import Value + + +# @pytest.mark.skip(reason="complicated assertion") +def test_big_neuron(): + # inputs + x1 = Value(2.0, label='x1') + x2 = Value(0.0, label='x2') + + # weights + w1 = Value(-3.0, label='w1') + w2 = Value(1.0, label='w2') + + # bias + b = Value(6.8813735870195432, label='b') + + h1 = x1 * w1 + h1.label = 'h1' + h2 = x2 * w2 + h2.label = 'h2' + + h = h1 + h2 + h.label = 'h' + + n = h + b + n.label = 'n' + y = n.tanh() + y.label = 'y' + + assert pytest.approx(y.data, 0.01) == 0.7071 diff --git a/test/test_value.py b/test/test_value.py index d0109c5..493d916 100644 --- a/test/test_value.py +++ b/test/test_value.py @@ -56,3 +56,16 @@ def test_operations(): add = v1 + v2 assert mul._op == '*' assert add._op == '+' + + +def test_tanh(): + t = Value(2.0).tanh() + assert t.data > 0 + assert t.data < 1 + + t = Value(0.0).tanh() + assert t.data == 0 + + t = Value(-2.0).tanh() + assert t.data < 0 + assert t.data > -1