add def for add,mul,draw value

This commit is contained in:
publicmatt 2024-02-22 07:15:45 -08:00
parent 045469f53e
commit 0ed01f7a84
7 changed files with 369 additions and 0 deletions

245
main.ipynb Normal file
View File

@ -0,0 +1,245 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "bb9a1e74-8c5c-42bb-93fb-bfd3d6169dbd",
"metadata": {},
"outputs": [],
"source": [
"from micrograd.engine import Value\n",
"from micrograd.draw import draw_dot"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "afbfde7d-1592-476d-9ea1-399434752e8a",
"metadata": {},
"outputs": [],
"source": [
"x = Value(1.0)\n",
"y = Value(4.0)\n",
"z = Value(-1)\n",
"f = (y+z) + (x*y) + z"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "35bdfb93-8f89-4dd8-8c50-29774dd04984",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.43.0 (0)\n",
" -->\n",
"<!-- Title: %3 Pages: 1 -->\n",
"<svg width=\"1215pt\" height=\"155pt\"\n",
" viewBox=\"0.00 0.00 1215.00 155.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 151)\">\n",
"<title>%3</title>\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-151 1211,-151 1211,4 -4,4\"/>\n",
"<!-- 140306720081952 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>140306720081952</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"337,-8.5 337,-44.5 543,-44.5 543,-8.5 337,-8.5\"/>\n",
"<text text-anchor=\"middle\" x=\"388\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 4.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"439,-8.5 439,-44.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"491\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140306720081616+ -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>140306720081616+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"606\" cy=\"-81.5\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"606\" y=\"-77.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 140306720081952&#45;&gt;140306720081616+ -->\n",
"<g id=\"edge11\" class=\"edge\">\n",
"<title>140306720081952&#45;&gt;140306720081616+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M514.99,-44.53C524.49,-47.3 534.02,-50.31 543,-53.5 553.54,-57.25 564.77,-62.05 574.74,-66.62\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"573.3,-69.82 583.84,-70.89 576.27,-63.48 573.3,-69.82\"/>\n",
"</g>\n",
"<!-- 140306720081952* -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>140306720081952*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"274\" cy=\"-26.5\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"274\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 140306720081952*&#45;&gt;140306720081952 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>140306720081952*&#45;&gt;140306720081952</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M301.35,-26.5C308.8,-26.5 317.4,-26.5 326.61,-26.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"326.8,-30 336.8,-26.5 326.8,-23 326.8,-30\"/>\n",
"</g>\n",
"<!-- 140306720082000 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>140306720082000</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"0,-110.5 0,-146.5 211,-146.5 211,-110.5 0,-110.5\"/>\n",
"<text text-anchor=\"middle\" x=\"53.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">data &#45;1.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"107,-110.5 107,-146.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"159\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140306720081520+ -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>140306720081520+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"938\" cy=\"-104.5\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"938\" y=\"-100.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 140306720082000&#45;&gt;140306720081520+ -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>140306720082000&#45;&gt;140306720081520+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M211.33,-128.5C276.78,-128.5 362.8,-128.5 439,-128.5 439,-128.5 439,-128.5 607,-128.5 714.29,-128.5 840.7,-115.72 901.2,-108.81\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"901.84,-112.26 911.37,-107.64 901.04,-105.31 901.84,-112.26\"/>\n",
"</g>\n",
"<!-- 140306720082096+ -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>140306720082096+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"274\" cy=\"-81.5\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"274\" y=\"-77.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 140306720082000&#45;&gt;140306720082096+ -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>140306720082000&#45;&gt;140306720082096+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M178.61,-110.47C189.49,-107.57 200.56,-104.53 211,-101.5 220.41,-98.77 230.56,-95.59 239.86,-92.59\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"241.13,-95.86 249.55,-89.43 238.96,-89.2 241.13,-95.86\"/>\n",
"</g>\n",
"<!-- 140306720081520 -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>140306720081520</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"1001,-86.5 1001,-122.5 1207,-122.5 1207,-86.5 1001,-86.5\"/>\n",
"<text text-anchor=\"middle\" x=\"1052\" y=\"-100.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1103,-86.5 1103,-122.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"1155\" y=\"-100.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140306720081520+&#45;&gt;140306720081520 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>140306720081520+&#45;&gt;140306720081520</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M965.35,-104.5C972.8,-104.5 981.4,-104.5 990.61,-104.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"990.8,-108 1000.8,-104.5 990.8,-101 990.8,-108\"/>\n",
"</g>\n",
"<!-- 140306720082096 -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>140306720082096</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"337,-63.5 337,-99.5 543,-99.5 543,-63.5 337,-63.5\"/>\n",
"<text text-anchor=\"middle\" x=\"388\" y=\"-77.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 3.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"439,-63.5 439,-99.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"491\" y=\"-77.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140306720082096&#45;&gt;140306720081616+ -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>140306720082096&#45;&gt;140306720081616+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M543.47,-81.5C552.37,-81.5 560.93,-81.5 568.69,-81.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"568.72,-85 578.72,-81.5 568.72,-78 568.72,-85\"/>\n",
"</g>\n",
"<!-- 140306720082096+&#45;&gt;140306720082096 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>140306720082096+&#45;&gt;140306720082096</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M301.35,-81.5C308.8,-81.5 317.4,-81.5 326.61,-81.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"326.8,-85 336.8,-81.5 326.8,-78 326.8,-85\"/>\n",
"</g>\n",
"<!-- 140306720081616 -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>140306720081616</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"669,-63.5 669,-99.5 875,-99.5 875,-63.5 669,-63.5\"/>\n",
"<text text-anchor=\"middle\" x=\"720\" y=\"-77.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 7.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"771,-63.5 771,-99.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"823\" y=\"-77.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140306720081616&#45;&gt;140306720081520+ -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>140306720081616&#45;&gt;140306720081520+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M875.47,-95.87C884.65,-97.16 893.47,-98.39 901.42,-99.51\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"900.94,-102.98 911.33,-100.9 901.91,-96.04 900.94,-102.98\"/>\n",
"</g>\n",
"<!-- 140306720081616+&#45;&gt;140306720081616 -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>140306720081616+&#45;&gt;140306720081616</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M633.35,-81.5C640.8,-81.5 649.4,-81.5 658.61,-81.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"658.8,-85 668.8,-81.5 658.8,-78 658.8,-85\"/>\n",
"</g>\n",
"<!-- 140306737924848 -->\n",
"<g id=\"node10\" class=\"node\">\n",
"<title>140306737924848</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"2.5,-55.5 2.5,-91.5 208.5,-91.5 208.5,-55.5 2.5,-55.5\"/>\n",
"<text text-anchor=\"middle\" x=\"53.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 4.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"104.5,-55.5 104.5,-91.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"156.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140306737924848&#45;&gt;140306720081952* -->\n",
"<g id=\"edge10\" class=\"edge\">\n",
"<title>140306737924848&#45;&gt;140306720081952*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M175.17,-55.47C187.14,-52.23 199.45,-48.82 211,-45.5 220.3,-42.83 230.34,-39.81 239.58,-36.98\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"240.69,-40.3 249.21,-34.01 238.62,-33.61 240.69,-40.3\"/>\n",
"</g>\n",
"<!-- 140306737924848&#45;&gt;140306720082096+ -->\n",
"<g id=\"edge12\" class=\"edge\">\n",
"<title>140306737924848&#45;&gt;140306720082096+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M208.63,-78.41C218.68,-78.89 228.34,-79.35 236.99,-79.77\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"236.83,-83.27 246.99,-80.25 237.17,-76.27 236.83,-83.27\"/>\n",
"</g>\n",
"<!-- 140306720080368 -->\n",
"<g id=\"node11\" class=\"node\">\n",
"<title>140306720080368</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"2.5,-0.5 2.5,-36.5 208.5,-36.5 208.5,-0.5 2.5,-0.5\"/>\n",
"<text text-anchor=\"middle\" x=\"53.5\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 1.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"104.5,-0.5 104.5,-36.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"156.5\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140306720080368&#45;&gt;140306720081952* -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>140306720080368&#45;&gt;140306720081952*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M208.63,-23.41C218.68,-23.89 228.34,-24.35 236.99,-24.77\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"236.83,-28.27 246.99,-25.25 237.17,-21.27 236.83,-28.27\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7f9bb42e4e80>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"draw_dot(f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8491fa95-c89e-4283-95fe-d685d50f7e07",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0
micrograd/__init__.py Normal file
View File

37
micrograd/draw.py Normal file
View File

@ -0,0 +1,37 @@
from graphviz import Digraph
def trace(root):
nodes, edges = set(), set()
def build(v):
if v not in nodes:
nodes.add(v)
for child in v._prev:
edges.add((child, v))
build(child)
build(root)
return nodes, edges
def draw_dot(root, format='svg', rankdir='LR'):
"""
format: png | svg | ...
rankdir: TB (top to bottom graph) | LR (left to right)
"""
assert rankdir in ['LR', 'TB']
nodes, edges = trace(root)
# , node_attr={'rankdir': 'TB'})
dot = Digraph(format=format, graph_attr={'rankdir': rankdir})
for n in nodes:
dot.node(name=str(id(n)), label="{ %s | data %.4f | grad %.4f }"
% (n.label, n.data, n.grad), shape='record')
if n._op:
dot.node(name=str(id(n)) + n._op, label=n._op)
dot.edge(str(id(n)) + n._op, str(id(n)))
for n1, n2 in edges:
dot.edge(str(id(n1)), str(id(n2)) + n2._op)
return dot

26
micrograd/engine.py Normal file
View File

@ -0,0 +1,26 @@
class Value:
def __init__(self, data, _children=(), _op='', label=''):
self.data = data
self.grad = 0.0
self._prev = set(_children)
self._op = _op
self.label = label
def __repr__(self):
return f"Value(data={self.data})"
def __add__(self, other):
out = Value(data=self.data + other.data,
_children=(self, other), _op='+')
return out
def __sub__(self, other):
out = Value(data=self.data - other.data,
_children=(self, other), _op='-')
return out
def __mul__(self, other):
out = Value(data=self.data * other.data,
_children=(self, other), _op='*')
return out

3
requirements.txt Normal file
View File

@ -0,0 +1,3 @@
pytest
notebook
graphviz

0
test/__init__.py Normal file
View File

58
test/test_value.py Normal file
View File

@ -0,0 +1,58 @@
from micrograd.engine import Value
def test_value_init():
v = Value(1)
assert v.data == 1
def test_value_repr():
v = Value(2.0)
assert "Value(data=2.0)" == repr(v)
def test_value_add():
v1 = Value(2.0)
v2 = Value(4.0)
assert (v1 + v2).data == 6.0
assert "Value(data=6.0)" == repr(v1 + v2)
def test_value_sub():
v1 = Value(2.0)
v2 = Value(4.0)
assert (v1 - v2).data == -2.0
assert "Value(data=-2.0)" == repr(v1 - v2)
def test_value_mul():
v1 = Value(2.0)
v2 = Value(4.0)
v3 = Value(-1.0)
assert (v1 * v2).data == 8.0
assert (v1 * v3).data == -2.0
def test_value_mul_add():
v1 = Value(2.0)
v2 = Value(4.0)
v3 = Value(-1.0)
assert ((v1 * v3) + v2).data == 2.0
def test_children():
v1 = Value(2.0)
v2 = Value(4.0)
out = v1 * v2
assert len(out._prev) == 2
assert v1 in out._prev
assert v2 in out._prev
def test_operations():
v1 = Value(2.0)
v2 = Value(4.0)
mul = v1 * v2
add = v1 + v2
assert mul._op == '*'
assert add._op == '+'