diff --git a/main.ipynb b/main.ipynb new file mode 100644 index 0000000..420fc26 --- /dev/null +++ b/main.ipynb @@ -0,0 +1,245 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "bb9a1e74-8c5c-42bb-93fb-bfd3d6169dbd", + "metadata": {}, + "outputs": [], + "source": [ + "from micrograd.engine import Value\n", + "from micrograd.draw import draw_dot" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "afbfde7d-1592-476d-9ea1-399434752e8a", + "metadata": {}, + "outputs": [], + "source": [ + "x = Value(1.0)\n", + "y = Value(4.0)\n", + "z = Value(-1)\n", + "f = (y+z) + (x*y) + z" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "35bdfb93-8f89-4dd8-8c50-29774dd04984", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "\n", + "140306720081952\n", + "\n", + "data 4.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140306720081616+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140306720081952->140306720081616+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306720081952*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140306720081952*->140306720081952\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306720082000\n", + "\n", + "data -1.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140306720081520+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140306720082000->140306720081520+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306720082096+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140306720082000->140306720082096+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306720081520\n", + "\n", + "data 6.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140306720081520+->140306720081520\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306720082096\n", + "\n", + "data 3.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140306720082096->140306720081616+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306720082096+->140306720082096\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306720081616\n", + "\n", + "data 7.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140306720081616->140306720081520+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306720081616+->140306720081616\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306737924848\n", + "\n", + "data 4.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140306737924848->140306720081952*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306737924848->140306720082096+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306720080368\n", + "\n", + "data 1.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140306720080368->140306720081952*\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "draw_dot(f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8491fa95-c89e-4283-95fe-d685d50f7e07", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/micrograd/__init__.py b/micrograd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/micrograd/draw.py b/micrograd/draw.py new file mode 100644 index 0000000..d555b4b --- /dev/null +++ b/micrograd/draw.py @@ -0,0 +1,37 @@ +from graphviz import Digraph + + +def trace(root): + nodes, edges = set(), set() + + def build(v): + if v not in nodes: + nodes.add(v) + for child in v._prev: + edges.add((child, v)) + build(child) + build(root) + return nodes, edges + + +def draw_dot(root, format='svg', rankdir='LR'): + """ + format: png | svg | ... + rankdir: TB (top to bottom graph) | LR (left to right) + """ + assert rankdir in ['LR', 'TB'] + nodes, edges = trace(root) + # , node_attr={'rankdir': 'TB'}) + dot = Digraph(format=format, graph_attr={'rankdir': rankdir}) + + for n in nodes: + dot.node(name=str(id(n)), label="{ %s | data %.4f | grad %.4f }" + % (n.label, n.data, n.grad), shape='record') + if n._op: + dot.node(name=str(id(n)) + n._op, label=n._op) + dot.edge(str(id(n)) + n._op, str(id(n))) + + for n1, n2 in edges: + dot.edge(str(id(n1)), str(id(n2)) + n2._op) + + return dot diff --git a/micrograd/engine.py b/micrograd/engine.py new file mode 100644 index 0000000..d700cb2 --- /dev/null +++ b/micrograd/engine.py @@ -0,0 +1,26 @@ +class Value: + + def __init__(self, data, _children=(), _op='', label=''): + self.data = data + self.grad = 0.0 + self._prev = set(_children) + self._op = _op + self.label = label + + def __repr__(self): + return f"Value(data={self.data})" + + def __add__(self, other): + out = Value(data=self.data + other.data, + _children=(self, other), _op='+') + return out + + def __sub__(self, other): + out = Value(data=self.data - other.data, + _children=(self, other), _op='-') + return out + + def __mul__(self, other): + out = Value(data=self.data * other.data, + _children=(self, other), _op='*') + return out diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..fed7ac3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +pytest +notebook +graphviz diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_value.py b/test/test_value.py new file mode 100644 index 0000000..d0109c5 --- /dev/null +++ b/test/test_value.py @@ -0,0 +1,58 @@ +from micrograd.engine import Value + + +def test_value_init(): + v = Value(1) + assert v.data == 1 + + +def test_value_repr(): + v = Value(2.0) + assert "Value(data=2.0)" == repr(v) + + +def test_value_add(): + v1 = Value(2.0) + v2 = Value(4.0) + assert (v1 + v2).data == 6.0 + assert "Value(data=6.0)" == repr(v1 + v2) + + +def test_value_sub(): + v1 = Value(2.0) + v2 = Value(4.0) + assert (v1 - v2).data == -2.0 + assert "Value(data=-2.0)" == repr(v1 - v2) + + +def test_value_mul(): + v1 = Value(2.0) + v2 = Value(4.0) + v3 = Value(-1.0) + assert (v1 * v2).data == 8.0 + assert (v1 * v3).data == -2.0 + + +def test_value_mul_add(): + v1 = Value(2.0) + v2 = Value(4.0) + v3 = Value(-1.0) + assert ((v1 * v3) + v2).data == 2.0 + + +def test_children(): + v1 = Value(2.0) + v2 = Value(4.0) + out = v1 * v2 + assert len(out._prev) == 2 + assert v1 in out._prev + assert v2 in out._prev + + +def test_operations(): + v1 = Value(2.0) + v2 = Value(4.0) + mul = v1 * v2 + add = v1 + v2 + assert mul._op == '*' + assert add._op == '+'