add def for add,mul,draw value
This commit is contained in:
parent
045469f53e
commit
0ed01f7a84
|
@ -0,0 +1,245 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "bb9a1e74-8c5c-42bb-93fb-bfd3d6169dbd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from micrograd.engine import Value\n",
|
||||
"from micrograd.draw import draw_dot"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "afbfde7d-1592-476d-9ea1-399434752e8a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = Value(1.0)\n",
|
||||
"y = Value(4.0)\n",
|
||||
"z = Value(-1)\n",
|
||||
"f = (y+z) + (x*y) + z"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "35bdfb93-8f89-4dd8-8c50-29774dd04984",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"image/svg+xml": [
|
||||
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
||||
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
||||
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
||||
"<!-- Generated by graphviz version 2.43.0 (0)\n",
|
||||
" -->\n",
|
||||
"<!-- Title: %3 Pages: 1 -->\n",
|
||||
"<svg width=\"1215pt\" height=\"155pt\"\n",
|
||||
" viewBox=\"0.00 0.00 1215.00 155.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
||||
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 151)\">\n",
|
||||
"<title>%3</title>\n",
|
||||
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-151 1211,-151 1211,4 -4,4\"/>\n",
|
||||
"<!-- 140306720081952 -->\n",
|
||||
"<g id=\"node1\" class=\"node\">\n",
|
||||
"<title>140306720081952</title>\n",
|
||||
"<polygon fill=\"none\" stroke=\"black\" points=\"337,-8.5 337,-44.5 543,-44.5 543,-8.5 337,-8.5\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"388\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 4.0000</text>\n",
|
||||
"<polyline fill=\"none\" stroke=\"black\" points=\"439,-8.5 439,-44.5 \"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"491\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720081616+ -->\n",
|
||||
"<g id=\"node9\" class=\"node\">\n",
|
||||
"<title>140306720081616+</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"606\" cy=\"-81.5\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"606\" y=\"-77.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720081952->140306720081616+ -->\n",
|
||||
"<g id=\"edge11\" class=\"edge\">\n",
|
||||
"<title>140306720081952->140306720081616+</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M514.99,-44.53C524.49,-47.3 534.02,-50.31 543,-53.5 553.54,-57.25 564.77,-62.05 574.74,-66.62\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"573.3,-69.82 583.84,-70.89 576.27,-63.48 573.3,-69.82\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720081952* -->\n",
|
||||
"<g id=\"node2\" class=\"node\">\n",
|
||||
"<title>140306720081952*</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"274\" cy=\"-26.5\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"274\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720081952*->140306720081952 -->\n",
|
||||
"<g id=\"edge1\" class=\"edge\">\n",
|
||||
"<title>140306720081952*->140306720081952</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M301.35,-26.5C308.8,-26.5 317.4,-26.5 326.61,-26.5\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"326.8,-30 336.8,-26.5 326.8,-23 326.8,-30\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720082000 -->\n",
|
||||
"<g id=\"node3\" class=\"node\">\n",
|
||||
"<title>140306720082000</title>\n",
|
||||
"<polygon fill=\"none\" stroke=\"black\" points=\"0,-110.5 0,-146.5 211,-146.5 211,-110.5 0,-110.5\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"53.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">data -1.0000</text>\n",
|
||||
"<polyline fill=\"none\" stroke=\"black\" points=\"107,-110.5 107,-146.5 \"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"159\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720081520+ -->\n",
|
||||
"<g id=\"node5\" class=\"node\">\n",
|
||||
"<title>140306720081520+</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"938\" cy=\"-104.5\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"938\" y=\"-100.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720082000->140306720081520+ -->\n",
|
||||
"<g id=\"edge6\" class=\"edge\">\n",
|
||||
"<title>140306720082000->140306720081520+</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M211.33,-128.5C276.78,-128.5 362.8,-128.5 439,-128.5 439,-128.5 439,-128.5 607,-128.5 714.29,-128.5 840.7,-115.72 901.2,-108.81\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"901.84,-112.26 911.37,-107.64 901.04,-105.31 901.84,-112.26\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720082096+ -->\n",
|
||||
"<g id=\"node7\" class=\"node\">\n",
|
||||
"<title>140306720082096+</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"274\" cy=\"-81.5\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"274\" y=\"-77.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720082000->140306720082096+ -->\n",
|
||||
"<g id=\"edge9\" class=\"edge\">\n",
|
||||
"<title>140306720082000->140306720082096+</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M178.61,-110.47C189.49,-107.57 200.56,-104.53 211,-101.5 220.41,-98.77 230.56,-95.59 239.86,-92.59\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"241.13,-95.86 249.55,-89.43 238.96,-89.2 241.13,-95.86\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720081520 -->\n",
|
||||
"<g id=\"node4\" class=\"node\">\n",
|
||||
"<title>140306720081520</title>\n",
|
||||
"<polygon fill=\"none\" stroke=\"black\" points=\"1001,-86.5 1001,-122.5 1207,-122.5 1207,-86.5 1001,-86.5\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"1052\" y=\"-100.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 6.0000</text>\n",
|
||||
"<polyline fill=\"none\" stroke=\"black\" points=\"1103,-86.5 1103,-122.5 \"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"1155\" y=\"-100.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720081520+->140306720081520 -->\n",
|
||||
"<g id=\"edge2\" class=\"edge\">\n",
|
||||
"<title>140306720081520+->140306720081520</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M965.35,-104.5C972.8,-104.5 981.4,-104.5 990.61,-104.5\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"990.8,-108 1000.8,-104.5 990.8,-101 990.8,-108\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720082096 -->\n",
|
||||
"<g id=\"node6\" class=\"node\">\n",
|
||||
"<title>140306720082096</title>\n",
|
||||
"<polygon fill=\"none\" stroke=\"black\" points=\"337,-63.5 337,-99.5 543,-99.5 543,-63.5 337,-63.5\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"388\" y=\"-77.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 3.0000</text>\n",
|
||||
"<polyline fill=\"none\" stroke=\"black\" points=\"439,-63.5 439,-99.5 \"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"491\" y=\"-77.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720082096->140306720081616+ -->\n",
|
||||
"<g id=\"edge5\" class=\"edge\">\n",
|
||||
"<title>140306720082096->140306720081616+</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M543.47,-81.5C552.37,-81.5 560.93,-81.5 568.69,-81.5\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"568.72,-85 578.72,-81.5 568.72,-78 568.72,-85\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720082096+->140306720082096 -->\n",
|
||||
"<g id=\"edge3\" class=\"edge\">\n",
|
||||
"<title>140306720082096+->140306720082096</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M301.35,-81.5C308.8,-81.5 317.4,-81.5 326.61,-81.5\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"326.8,-85 336.8,-81.5 326.8,-78 326.8,-85\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720081616 -->\n",
|
||||
"<g id=\"node8\" class=\"node\">\n",
|
||||
"<title>140306720081616</title>\n",
|
||||
"<polygon fill=\"none\" stroke=\"black\" points=\"669,-63.5 669,-99.5 875,-99.5 875,-63.5 669,-63.5\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"720\" y=\"-77.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 7.0000</text>\n",
|
||||
"<polyline fill=\"none\" stroke=\"black\" points=\"771,-63.5 771,-99.5 \"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"823\" y=\"-77.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720081616->140306720081520+ -->\n",
|
||||
"<g id=\"edge8\" class=\"edge\">\n",
|
||||
"<title>140306720081616->140306720081520+</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M875.47,-95.87C884.65,-97.16 893.47,-98.39 901.42,-99.51\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"900.94,-102.98 911.33,-100.9 901.91,-96.04 900.94,-102.98\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720081616+->140306720081616 -->\n",
|
||||
"<g id=\"edge4\" class=\"edge\">\n",
|
||||
"<title>140306720081616+->140306720081616</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M633.35,-81.5C640.8,-81.5 649.4,-81.5 658.61,-81.5\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"658.8,-85 668.8,-81.5 658.8,-78 658.8,-85\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306737924848 -->\n",
|
||||
"<g id=\"node10\" class=\"node\">\n",
|
||||
"<title>140306737924848</title>\n",
|
||||
"<polygon fill=\"none\" stroke=\"black\" points=\"2.5,-55.5 2.5,-91.5 208.5,-91.5 208.5,-55.5 2.5,-55.5\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"53.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 4.0000</text>\n",
|
||||
"<polyline fill=\"none\" stroke=\"black\" points=\"104.5,-55.5 104.5,-91.5 \"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"156.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306737924848->140306720081952* -->\n",
|
||||
"<g id=\"edge10\" class=\"edge\">\n",
|
||||
"<title>140306737924848->140306720081952*</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M175.17,-55.47C187.14,-52.23 199.45,-48.82 211,-45.5 220.3,-42.83 230.34,-39.81 239.58,-36.98\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"240.69,-40.3 249.21,-34.01 238.62,-33.61 240.69,-40.3\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306737924848->140306720082096+ -->\n",
|
||||
"<g id=\"edge12\" class=\"edge\">\n",
|
||||
"<title>140306737924848->140306720082096+</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M208.63,-78.41C218.68,-78.89 228.34,-79.35 236.99,-79.77\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"236.83,-83.27 246.99,-80.25 237.17,-76.27 236.83,-83.27\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720080368 -->\n",
|
||||
"<g id=\"node11\" class=\"node\">\n",
|
||||
"<title>140306720080368</title>\n",
|
||||
"<polygon fill=\"none\" stroke=\"black\" points=\"2.5,-0.5 2.5,-36.5 208.5,-36.5 208.5,-0.5 2.5,-0.5\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"53.5\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 1.0000</text>\n",
|
||||
"<polyline fill=\"none\" stroke=\"black\" points=\"104.5,-0.5 104.5,-36.5 \"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"156.5\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 140306720080368->140306720081952* -->\n",
|
||||
"<g id=\"edge7\" class=\"edge\">\n",
|
||||
"<title>140306720080368->140306720081952*</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M208.63,-23.41C218.68,-23.89 228.34,-24.35 236.99,-24.77\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"236.83,-28.27 246.99,-25.25 237.17,-21.27 236.83,-28.27\"/>\n",
|
||||
"</g>\n",
|
||||
"</g>\n",
|
||||
"</svg>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"<graphviz.graphs.Digraph at 0x7f9bb42e4e80>"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"draw_dot(f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8491fa95-c89e-4283-95fe-d685d50f7e07",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
from graphviz import Digraph
|
||||
|
||||
|
||||
def trace(root):
|
||||
nodes, edges = set(), set()
|
||||
|
||||
def build(v):
|
||||
if v not in nodes:
|
||||
nodes.add(v)
|
||||
for child in v._prev:
|
||||
edges.add((child, v))
|
||||
build(child)
|
||||
build(root)
|
||||
return nodes, edges
|
||||
|
||||
|
||||
def draw_dot(root, format='svg', rankdir='LR'):
|
||||
"""
|
||||
format: png | svg | ...
|
||||
rankdir: TB (top to bottom graph) | LR (left to right)
|
||||
"""
|
||||
assert rankdir in ['LR', 'TB']
|
||||
nodes, edges = trace(root)
|
||||
# , node_attr={'rankdir': 'TB'})
|
||||
dot = Digraph(format=format, graph_attr={'rankdir': rankdir})
|
||||
|
||||
for n in nodes:
|
||||
dot.node(name=str(id(n)), label="{ %s | data %.4f | grad %.4f }"
|
||||
% (n.label, n.data, n.grad), shape='record')
|
||||
if n._op:
|
||||
dot.node(name=str(id(n)) + n._op, label=n._op)
|
||||
dot.edge(str(id(n)) + n._op, str(id(n)))
|
||||
|
||||
for n1, n2 in edges:
|
||||
dot.edge(str(id(n1)), str(id(n2)) + n2._op)
|
||||
|
||||
return dot
|
|
@ -0,0 +1,26 @@
|
|||
class Value:
|
||||
|
||||
def __init__(self, data, _children=(), _op='', label=''):
|
||||
self.data = data
|
||||
self.grad = 0.0
|
||||
self._prev = set(_children)
|
||||
self._op = _op
|
||||
self.label = label
|
||||
|
||||
def __repr__(self):
|
||||
return f"Value(data={self.data})"
|
||||
|
||||
def __add__(self, other):
|
||||
out = Value(data=self.data + other.data,
|
||||
_children=(self, other), _op='+')
|
||||
return out
|
||||
|
||||
def __sub__(self, other):
|
||||
out = Value(data=self.data - other.data,
|
||||
_children=(self, other), _op='-')
|
||||
return out
|
||||
|
||||
def __mul__(self, other):
|
||||
out = Value(data=self.data * other.data,
|
||||
_children=(self, other), _op='*')
|
||||
return out
|
|
@ -0,0 +1,3 @@
|
|||
pytest
|
||||
notebook
|
||||
graphviz
|
|
@ -0,0 +1,58 @@
|
|||
from micrograd.engine import Value
|
||||
|
||||
|
||||
def test_value_init():
|
||||
v = Value(1)
|
||||
assert v.data == 1
|
||||
|
||||
|
||||
def test_value_repr():
|
||||
v = Value(2.0)
|
||||
assert "Value(data=2.0)" == repr(v)
|
||||
|
||||
|
||||
def test_value_add():
|
||||
v1 = Value(2.0)
|
||||
v2 = Value(4.0)
|
||||
assert (v1 + v2).data == 6.0
|
||||
assert "Value(data=6.0)" == repr(v1 + v2)
|
||||
|
||||
|
||||
def test_value_sub():
|
||||
v1 = Value(2.0)
|
||||
v2 = Value(4.0)
|
||||
assert (v1 - v2).data == -2.0
|
||||
assert "Value(data=-2.0)" == repr(v1 - v2)
|
||||
|
||||
|
||||
def test_value_mul():
|
||||
v1 = Value(2.0)
|
||||
v2 = Value(4.0)
|
||||
v3 = Value(-1.0)
|
||||
assert (v1 * v2).data == 8.0
|
||||
assert (v1 * v3).data == -2.0
|
||||
|
||||
|
||||
def test_value_mul_add():
|
||||
v1 = Value(2.0)
|
||||
v2 = Value(4.0)
|
||||
v3 = Value(-1.0)
|
||||
assert ((v1 * v3) + v2).data == 2.0
|
||||
|
||||
|
||||
def test_children():
|
||||
v1 = Value(2.0)
|
||||
v2 = Value(4.0)
|
||||
out = v1 * v2
|
||||
assert len(out._prev) == 2
|
||||
assert v1 in out._prev
|
||||
assert v2 in out._prev
|
||||
|
||||
|
||||
def test_operations():
|
||||
v1 = Value(2.0)
|
||||
v2 = Value(4.0)
|
||||
mul = v1 * v2
|
||||
add = v1 + v2
|
||||
assert mul._op == '*'
|
||||
assert add._op == '+'
|
Loading…
Reference in New Issue