ååããã®ïŒãã®æåŸã«æžãããéã¿W1ãW2ããã€ã¢ã¹b1ãb2ã®ã°ã©ãæç»çšããŒã¿ãäžæ¬ããŠæ¡åããã³ãŒãæ¹é ã¯ããã£ããã§ããã
#コード4-0
import sys, os
sys.path.append(os.pardir)
import numpy as np
from common.functions import *
from common.gradient import numerical_gradient as n_g
# Training data: the four binary input pairs and their one-hot labels.
# The label is class 1 exactly when x0 != x1 (i.e. XOR).
x_e = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])
t_e = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])
# Multiplier applied to the fixed initial weights below; this article
# deliberately uses a large value (50).
weight_init_std = 50.
W1 = weight_init_std * np.array([
    [0.07395519, -0.13489392, -0.1178099],
    [0.01890785, -0.02397794, 0.18300705]])
W2 = weight_init_std * np.array([
    [-0.13469725, 0.1634472],
    [0.13778756, -0.06120645],
    [0.03805643, 0.24871219]])
b1 = np.zeros(3)
b2 = np.zeros(2)


def predict(x):
    """Return the softmax output of the 2-3-2 network for the inputs x.

    Computes affine -> sigmoid -> affine -> softmax using the module-level
    parameters W1, b1, W2, b2. They are looked up at call time, so in-place
    updates and later re-assignments (code 4-1 / 4-10) take effect.
    """
    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    return softmax(a2)
def loss(x, t):
    """Return the cross-entropy error between predict(x) and the one-hot targets t."""
    return cross_entropy_error(predict(x), t)
def acc(x, t):
    """Return the fraction of rows of x whose predicted class matches argmax(t)."""
    n_samples = float(x.shape[0])
    hits = np.argmax(predict(x), axis=1) == np.argmax(t, axis=1)
    return np.sum(hits) / n_samples
import matplotlib.pyplot as plt
# Objective passed to numerical_gradient. The argument W is deliberately
# ignored: loss() reads the module-level W1, b1, W2, b2, so the gradient is
# taken with respect to whichever of those arrays is handed to n_g
# (presumably n_g perturbs that array in place - see common/gradient.py).
loss_W = lambda W: loss(x_e, t_e)
# Per-iteration history of the loss and the accuracy.
loss_list, acc_list = [], []
# data_list[p][k] holds the trajectory of element k of parameter group p:
#   p=0: W1[0, k], p=1: b1[k], p=2: W2[k, 0], p=3: b2[k] (only k=0, 1 used).
data_list = [[[] for _ in range(3)] for _ in range(4)]
# Learning rate and number of gradient-descent steps.
l_r, s_n = 5.0, 20
append() メソッドでグラフ要素を追加していく空の多次元リストを作成する方法は、次のサイトを参考にさせていただきました。ありがとうございました。
Pythonのリスト（配列）を任意の値・要素数で初期化 | note.nkmk.me
Â
#コード4-1
# Plain gradient descent. Each parameter array is updated in place, in the
# order W1, b1, W2, b2, so later gradients see the already-updated earlier
# arrays. After each step, record loss, accuracy and the tracked elements.
for i in range(s_n):
    for param in (W1, b1, W2, b2):
        param -= l_r * n_g(loss_W, param)
    loss_list.append(loss(x_e, t_e))
    acc_list.append(acc(x_e, t_e))
    # Trajectories: W1 first row, all of b1, W2 first column, all of b2.
    for series, value in zip(data_list[0], W1[0]):
        series.append(value)
    for series, value in zip(data_list[1], b1):
        series.append(value)
    for series, value in zip(data_list[2], W2[:, 0]):
        series.append(value)
    for series, value in zip(data_list[3], b2):
        series.append(value)
Â
Anacondaプロンプトの対話モードで上掲「#コード4-0」と「#コード4-1」に続けて次の「#コード4-2」を貼り付ければ正解率と損失関数の値のグラフが…
#コード4-2
# Plot the loss and the accuracy recorded by code 4-1 against the iteration.
x = np.arange(len(loss_list))
plt.plot(x, loss_list, label='loss')
plt.plot(x, acc_list, label='acc', linestyle='--')
plt.xlabel("iteration")  # x-axis label
plt.legend()  # legend
plt.show()
Â
「#コード4-3」を貼り付ければ重みW1の1行目3要素の3D折れ線グラフが…
#コード4-3
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection (needed on old matplotlib)

# 3D polyline of the trajectory of W1[0, 0], W1[0, 1], W1[0, 2].
fig = plt.figure()
# Axes3D(fig) stopped attaching itself to the figure in newer matplotlib
# (deprecated in 3.4), which leaves an empty window; add_subplot with the
# '3d' projection works on both old and new versions.
ax = fig.add_subplot(111, projection='3d')
ax.plot(data_list[0][0], data_list[0][1], data_list[0][2], "o-")
ax.set_xlabel('W100')  # axis labels
ax.set_ylabel('W101')
ax.set_zlabel('W102')
plt.show()  # display
Â
「#コード4-4」を貼り付ければ重みW1の1行目3要素の3面2D展開図風グラフが…
#コード4-4
# Three 2D projections of the W1[0, :] trajectory, arranged like an unfolded
# three-view drawing on a 2x2 grid (the top-left cell is left empty).
plt.subplot(2, 2, 2)
plt.plot(data_list[0][0], data_list[0][2], 'o-')
plt.xlabel("W100")
plt.ylabel("W102")
plt.subplot(2, 2, 3)
plt.plot(data_list[0][1], data_list[0][2], 'o-')
plt.xlabel("W101")
plt.ylabel("W102")
plt.subplot(2, 2, 4)
plt.plot(data_list[0][0], data_list[0][1], 'o-')
plt.xlabel("W100")
plt.ylabel("W101")
plt.show()
Â
「#コード4-5」を貼り付ければバイアスb1の3D折れ線グラフが…
#コード4-5
# 3D polyline of the trajectory of b1[0], b1[1], b1[2].
fig = plt.figure()
# Axes3D(fig) stopped attaching itself to the figure in newer matplotlib
# (deprecated in 3.4) and also required the import from code 4-3; the '3d'
# projection via add_subplot avoids both problems (on matplotlib < 3.2 the
# mpl_toolkits.mplot3d import from code 4-3 is still needed to register it).
ax = fig.add_subplot(111, projection='3d')
ax.plot(data_list[1][0], data_list[1][1], data_list[1][2], "o-")
ax.set_xlabel('b10')
ax.set_ylabel('b11')
ax.set_zlabel('b12')
plt.show()
Â
「#コード4-6」を貼り付ければバイアスb1の3面2D展開図風グラフが…
#コード4-6
# Three 2D projections of the b1 trajectory, arranged like an unfolded
# three-view drawing on a 2x2 grid (the top-left cell is left empty).
plt.subplot(2, 2, 2)
plt.plot(data_list[1][0], data_list[1][2], 'o-')
plt.xlabel("b10")
plt.ylabel("b12")
plt.subplot(2, 2, 3)
plt.plot(data_list[1][1], data_list[1][2], 'o-')
plt.xlabel("b11")
plt.ylabel("b12")
plt.subplot(2, 2, 4)
plt.plot(data_list[1][0], data_list[1][1], 'o-')
plt.xlabel("b10")
plt.ylabel("b11")
plt.show()
Â
「#コード4-7」を貼り付ければ重みW2の1列目3要素の3D折れ線グラフが…
#コード4-7
# 3D polyline of the trajectory of W2[0, 0], W2[1, 0], W2[2, 0].
fig = plt.figure()
# Axes3D(fig) stopped attaching itself to the figure in newer matplotlib
# (deprecated in 3.4) and also required the import from code 4-3; the '3d'
# projection via add_subplot avoids both problems (on matplotlib < 3.2 the
# mpl_toolkits.mplot3d import from code 4-3 is still needed to register it).
ax = fig.add_subplot(111, projection='3d')
ax.plot(data_list[2][0], data_list[2][1], data_list[2][2], "o-")
ax.set_xlabel('W200')
ax.set_ylabel('W210')
ax.set_zlabel('W220')
plt.show()
Â
「#コード4-8」を貼り付ければ重みW2の1列目3要素の3面2D展開図風グラフが…
#コード4-8
# Three 2D projections of the W2[:, 0] trajectory, arranged like an unfolded
# three-view drawing on a 2x2 grid (the top-left cell is left empty).
plt.subplot(2, 2, 2)
plt.plot(data_list[2][0], data_list[2][2], 'o-')
plt.xlabel("W200")  # x data is data_list[2][0], i.e. W2[0, 0]
plt.ylabel("W220")
plt.subplot(2, 2, 3)
plt.plot(data_list[2][1], data_list[2][2], 'o-')
plt.xlabel("W210")
plt.ylabel("W220")
plt.subplot(2, 2, 4)
plt.plot(data_list[2][0], data_list[2][1], 'o-')
plt.xlabel("W200")
plt.ylabel("W210")
plt.show()
Â
「#コード4-9」を貼り付ければバイアスb2の2D折れ線グラフが表示されるはずである。
#コード4-9
# 2D polyline of the b2 trajectory (b2[0] on the x axis, b2[1] on the y axis).
plt.plot(*data_list[3][:2], 'o-')
plt.xlabel("b20")
plt.ylabel("b21")
plt.show()
Â
さらに「#コード4-10」を貼り付けると、W1、W2を小数点以下6桁で丸めた状態で、W1、b1、W2、b2のグラフ要素を格納するリストが初期化される。
#コード4-10
# Re-initialise everything for another run of code 4-1 .. 4-9. W1 and W2 are
# rebuilt from the same base values as in code 4-0 (times weight_init_std)
# and then rounded to 6 decimal places. The biases are zeroed again and the
# recorded histories are cleared.
W1 = weight_init_std * np.array([
    [0.07395519, -0.13489392, -0.1178099],
    [0.01890785, -0.02397794, 0.18300705]])
W2 = weight_init_std * np.array([
    [-0.13469725, 0.1634472],
    [0.13778756, -0.06120645],
    [0.03805643, 0.24871219]])
# Change the number after decimals= (both lines) to vary the rounding.
W1 = np.round(W1, decimals=6)
W2 = np.round(W2, decimals=6)
b1 = np.zeros(3)
b2 = np.zeros(2)
loss_list, acc_list = [], []
data_list = [[[] for _ in range(3)] for _ in range(4)]
Â
続けて「#コード4-0」以外の「#コード4-1」～「#コード4-9」を貼り付けると、W1とW2を6桁で丸めた初期値による各種グラフが撮れる。
「#コード4-10」中の「decimals=」に続く数字2か所を変更すると、丸めの桁数を変えることができる。
Â
ただし毎回断っている通り、斎藤康毅 「ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装」(O'REILLY) のサンプルスクリプトをダウンロードしたディレクトリに、事前に移動しておく必要があります。

ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装
- 作者:斎藤 康毅
- 発売日: 2016/09/24
- メディア: 単行本（ソフトカバー）
Â
ä»åã¯ããŸãã¯W1ãšW2ã®åæå€ã«å¯Ÿãã乿°Â weight_init_std ã50ãšå€§ããããŠã°ã©ãã®ã¹ã¯ãªãŒã³ã·ã§ãããæ®ã£ãŠã¿ããããªãã¡ãã¿ãã©ã€å¹æãããã¯ã«ãªã¹ãããçŸè±¡ãèµ·ããªãç¶æ ã®ã°ã©ãã§ããã
ãã¿ãã©ã€å¹æãããçŸè±¡ãèµ·ããŠããã°ã©ãã§ã¯ãããããªç¹åŸŽã芳枬ããããŸã ãŸãšããããŠããªãã
W1の2行目、W2の2列目はどうなるかについても後日説明する予定。少しだけ先走って書いてしまうと、中途半端な対称性が現れるのだ。
Â
「#コード4-2」による正解率と損失関数の値のグラフ。
左：丸めなし、右：丸め6桁。


Â
左：丸め4桁、右：丸め2桁。


ご覧の通り、まったく区別がつかなかった！
Windowsフォトで画像を順次表示したところ、完璧に重なっていた！
Â
「#コード4-3」によるW1の3D折れ線グラフ。
左：丸めなし、右：丸め6桁。


Â
左：丸め4桁、右：丸め2桁。


Â
「#コード4-4」によるW1の3面2D展開図風グラフ。
左：丸めなし、右：丸め6桁。


Â
左：丸め4桁、右：丸め2桁。


Â
「#コード4-5」によるb1の3D折れ線グラフ。
左：丸めなし、右：丸め6桁。


Â
左：丸め4桁、右：丸め2桁。


Â
「#コード4-6」によるb1の3面2D展開図風グラフ。
左：丸めなし、右：丸め6桁。


Â
左：丸め4桁、右：丸め2桁。


Â
「#コード4-7」によるW2の3D折れ線グラフ。
左：丸めなし、右：丸め6桁。


Â
左：丸め4桁、右：丸め2桁。


Â
「#コード4-8」によるW2の3面2D展開図風グラフ。
左：丸めなし、右：丸め6桁。


Â
左：丸め4桁、右：丸め2桁。


Â
「#コード4-9」によるバイアスb2の2D折れ線グラフ。
左：丸めなし、右：丸め6桁。


Â
左：丸め4桁、右：丸め2桁。


繰り返すが、寸分たがわぬグラフができたのだ！
Â
さらにグラフを描画した後でW1、b1、W2、b2の値をダンプしたところ、私見だが驚くべき特徴が観察されたと考える！
丸めなし。
>>> W1
array([[ 3.10715279, -6.92996836, -8.0399468 ],
[-3.65469687, -2.36826359, 7.24733878]])
>>> b1
array([-2.79170761, 1.27328557, -4.31921463])
>>> W2
array([[-6.51184203, 7.94933953],
[ 6.5981932 , -2.7691377 ],
[ 1.59700312, 12.74142788]])
>>> b2
array([ 1.84055305, -1.84055305])
Â
丸め6桁。
>>> W1
array([[ 3.10715289, -6.92996828, -8.03994567],
[-3.65469717, -2.36826367, 7.24733962]])
>>> b1
array([-2.79170755, 1.27328572, -4.3192158 ])
>>> W2
array([[-6.51184149, 7.94933949],
[ 6.59819307, -2.76913807],
[ 1.59700329, 12.74142871]])
>>> b2
array([ 1.8405531, -1.8405531])
Â
丸め4桁。
>>> W1
array([[ 3.10710145, -6.92997272, -8.04005328],
[-3.6547279 , -2.36824909, 7.2472656 ]])
>>> b1
array([-2.79167077, 1.27327315, -4.31912361])
>>> W2
array([[-6.51186187, 7.94936187],
[ 6.59822748, -2.76912748],
[ 1.59701523, 12.74138477]])
>>> b2
array([ 1.84053306, -1.84053306])
Â
丸め2桁。
>>> W1
array([[ 3.09749693, -6.92642448, -8.03894874],
[-3.65772763, -2.36530438, 7.2475458 ]])
>>> b1
array([-2.78123196, 1.27206239, -4.32347713])
>>> W2
array([[-6.50319204, 7.94319204],
[ 6.60072934, -2.77072934],
[ 1.59685762, 12.74314238]])
>>> b2
array([ 1.83860683, -1.83860683])
驚くべき特徴というのは、各変数で有効数字の範囲まで数字の一致が見られることだ！
すなわち丸め2桁と4桁では小数点以下2桁まで、丸め4桁と6桁では小数点以下4桁まで、丸め6桁と丸めなしでは小数点以下6桁まで、数字が（ほぼ）一致している！
éããããã ããªããããœã³ã³ããæ®ã£ãã¹ã¯ã·ã§ã¯åºå¥ã€ããªãããïŒ
Â
でも有効数字って、そういうことだったっけ？
あるいは「何を当たり前のことを」という反応が返ってくるかも知れない。
ところが weight_init_std の値を小さくすると、このような特徴が現れなくなるのだ。
小数点以下どころか、いちばん上の桁も一致しなくなり、甚だしい場合は、符号すら異なることがある！
これを何かの判定基準とすることは、できないだろうか？
スポンサーリンク
Â