fix grad accumulation, not overriding
This commit is contained in:
parent
6cf0186544
commit
4c9f7d8d7d
|
@ -34,8 +34,8 @@ class Value:
|
||||||
_children=(self, other), _op='+')
|
_children=(self, other), _op='+')
|
||||||
|
|
||||||
def _backward():
|
def _backward():
|
||||||
self.grad = 1.0 * out.grad
|
self.grad += 1.0 * out.grad
|
||||||
other.grad = 1.0 * out.grad
|
other.grad += 1.0 * out.grad
|
||||||
out._backward = _backward
|
out._backward = _backward
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -49,8 +49,8 @@ class Value:
|
||||||
_children=(self, other), _op='*')
|
_children=(self, other), _op='*')
|
||||||
|
|
||||||
def _backward():
|
def _backward():
|
||||||
self.grad = other.data * out.grad
|
self.grad += other.data * out.grad
|
||||||
other.grad = self.data * out.grad
|
other.grad += self.data * out.grad
|
||||||
out._backward = _backward
|
out._backward = _backward
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -61,6 +61,6 @@ class Value:
|
||||||
_op='tanh', label='tanh')
|
_op='tanh', label='tanh')
|
||||||
|
|
||||||
def _backward():
|
def _backward():
|
||||||
self.grad = (1 - t**2) * out.grad
|
self.grad += (1 - t**2) * out.grad
|
||||||
out._backward = _backward
|
out._backward = _backward
|
||||||
return out
|
return out
|
||||||
|
|
|
@ -61,7 +61,6 @@ def test_large_backprop():
|
||||||
assert pytest.approx(w2.grad, 0.001) == 0.0
|
assert pytest.approx(w2.grad, 0.001) == 0.0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="non-deterministic")
|
|
||||||
def test_auto_diff():
|
def test_auto_diff():
|
||||||
# inputs
|
# inputs
|
||||||
x1 = Value(2.0, label='x1')
|
x1 = Value(2.0, label='x1')
|
||||||
|
@ -89,7 +88,6 @@ def test_auto_diff():
|
||||||
|
|
||||||
y.backward()
|
y.backward()
|
||||||
assert pytest.approx(n.grad, 0.001) == 0.5
|
assert pytest.approx(n.grad, 0.001) == 0.5
|
||||||
assert h.grad == 0.0
|
|
||||||
|
|
||||||
assert pytest.approx(b.grad, 0.001) == 0.5
|
assert pytest.approx(b.grad, 0.001) == 0.5
|
||||||
assert pytest.approx(h.grad, 0.001) == 0.5
|
assert pytest.approx(h.grad, 0.001) == 0.5
|
||||||
|
@ -102,3 +100,10 @@ def test_auto_diff():
|
||||||
|
|
||||||
assert pytest.approx(x2.grad, 0.001) == 0.5
|
assert pytest.approx(x2.grad, 0.001) == 0.5
|
||||||
assert pytest.approx(w2.grad, 0.001) == 0.0
|
assert pytest.approx(w2.grad, 0.001) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_accumulation():
|
||||||
|
a = Value(2.0, label='x1')
|
||||||
|
b = a + a
|
||||||
|
b.backward()
|
||||||
|
assert a.grad == 2.0
|
||||||
|
|
Loading…
Reference in New Issue