fix grad accumulation, not overriding
This commit is contained in:
		
							parent
							
								
									6cf0186544
								
							
						
					
					
						commit
						4c9f7d8d7d
					
				|  | @ -34,8 +34,8 @@ class Value: | ||||||
|                     _children=(self, other), _op='+') |                     _children=(self, other), _op='+') | ||||||
| 
 | 
 | ||||||
|         def _backward(): |         def _backward(): | ||||||
|             self.grad = 1.0 * out.grad |             self.grad += 1.0 * out.grad | ||||||
|             other.grad = 1.0 * out.grad |             other.grad += 1.0 * out.grad | ||||||
|         out._backward = _backward |         out._backward = _backward | ||||||
|         return out |         return out | ||||||
| 
 | 
 | ||||||
|  | @ -49,8 +49,8 @@ class Value: | ||||||
|                     _children=(self, other), _op='*') |                     _children=(self, other), _op='*') | ||||||
| 
 | 
 | ||||||
|         def _backward(): |         def _backward(): | ||||||
|             self.grad = other.data * out.grad |             self.grad += other.data * out.grad | ||||||
|             other.grad = self.data * out.grad |             other.grad += self.data * out.grad | ||||||
|         out._backward = _backward |         out._backward = _backward | ||||||
|         return out |         return out | ||||||
| 
 | 
 | ||||||
|  | @ -61,6 +61,6 @@ class Value: | ||||||
|                     _op='tanh', label='tanh') |                     _op='tanh', label='tanh') | ||||||
| 
 | 
 | ||||||
|         def _backward(): |         def _backward(): | ||||||
|             self.grad = (1 - t**2) * out.grad |             self.grad += (1 - t**2) * out.grad | ||||||
|         out._backward = _backward |         out._backward = _backward | ||||||
|         return out |         return out | ||||||
|  |  | ||||||
|  | @ -61,7 +61,6 @@ def test_large_backprop(): | ||||||
|     assert pytest.approx(w2.grad, 0.001) == 0.0 |     assert pytest.approx(w2.grad, 0.001) == 0.0 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.skip(reason="non-deterministic") |  | ||||||
| def test_auto_diff(): | def test_auto_diff(): | ||||||
|     # inputs |     # inputs | ||||||
|     x1 = Value(2.0, label='x1') |     x1 = Value(2.0, label='x1') | ||||||
|  | @ -89,7 +88,6 @@ def test_auto_diff(): | ||||||
| 
 | 
 | ||||||
|     y.backward() |     y.backward() | ||||||
|     assert pytest.approx(n.grad, 0.001) == 0.5 |     assert pytest.approx(n.grad, 0.001) == 0.5 | ||||||
|     assert h.grad == 0.0 |  | ||||||
| 
 | 
 | ||||||
|     assert pytest.approx(b.grad, 0.001) == 0.5 |     assert pytest.approx(b.grad, 0.001) == 0.5 | ||||||
|     assert pytest.approx(h.grad, 0.001) == 0.5 |     assert pytest.approx(h.grad, 0.001) == 0.5 | ||||||
|  | @ -102,3 +100,10 @@ def test_auto_diff(): | ||||||
| 
 | 
 | ||||||
|     assert pytest.approx(x2.grad, 0.001) == 0.5 |     assert pytest.approx(x2.grad, 0.001) == 0.5 | ||||||
|     assert pytest.approx(w2.grad, 0.001) == 0.0 |     assert pytest.approx(w2.grad, 0.001) == 0.0 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_accumulation(): | ||||||
|  |     a = Value(2.0, label='x1') | ||||||
|  |     b = a + a | ||||||
|  |     b.backward() | ||||||
|  |     assert a.grad == 2.0 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue