自分でタイタニックのチュートリアルを解く(3)
趣味でやってる合気道、そして友人との飲み会の合間に調べ調べ実装してみました。
import matplotlib.pyplot as plt import numpy as np from numpy.random import * #解析データづくり x_true = rand(100,1) * 100 - 0 y_true = rand(100,1) * 100 - 0 t_true = np.ones((100,1)) x_false = rand(100,1) * 100 - 100 y_false = rand(100,1) * 100 - 100 t_false = -np.ones((100,1)) data_true = np.c_[x_true,y_true,t_true] data_false = np.c_[x_false,y_false,t_false]
とりあえず、これで解析しやすそうな判別データを作成。
ちなみに描画してみるとこんな感じ
これで解析すべきデータはできた。
あとはこれを2分する直線を書いてみる。
%matplotlib inline import matplotlib.pyplot as plt import numpy as np import csv from numpy.random import * def f(x,y,w0,w1,w2): return w0 + w1*x + w2*y # Import Data with open('data.csv', 'r') as d: reader = csv.reader(d) data = list(reader) # 初期値 w0 = 100.0 w1 = -1.0 w2 = 1.0 x_true=[] y_true=[] x_false=[] y_false=[] # main for itr in range(0,10): for i in range(0,len(data)): x = float(data[i][0]) y = float(data[i][1]) t = float(data[i][2]) if f(x,y,w0,w1,w2)*t > 0: continue else: print("false") w0 = w0 + t w1 = w1 + t*x w2 = w2 + t*y print(w0,w1,w2) print(w1/w2,w0/w2) x = np.linspace(-100,100,201) y = -(w1/w2)*x - (w0/w2) plt.plot(x,y,"r-") for i in range (0,len(data)): if int(data[i][2]) > 0: x_true.append(float(data[i][0])) y_true.append(float(data[i][1])) else: x_false.append(float(data[i][0])) y_false.append(float(data[i][1])) plt.scatter(x_true,y_true,color='b',marker='x') plt.scatter(x_false,y_false,color='r',marker='x') plt.show() plt.close()
直線をあわせてプロットするとこんな感じ
恥を忍んでダメダメコードを公開してみた。