kabaohのブログ

仕事ができなすぎて辛いので、趣味に逃げるカバ野郎

自分でタイタニックのチュートリアルを解く(3)

趣味でやってる合気道、そして友人との飲み会の合間に調べ調べ実装してみました。

import matplotlib.pyplot as plt
import numpy as np
from numpy.random import *

#解析データづくり
x_true = rand(100,1) * 100 - 0
y_true = rand(100,1) * 100 - 0
t_true = np.ones((100,1))


x_false = rand(100,1) * 100 - 100
y_false = rand(100,1) * 100 - 100
t_false = -np.ones((100,1))

data_true = np.c_[x_true,y_true,t_true]
data_false = np.c_[x_false,y_false,t_false]

とりあえず、これで解析しやすそうな判別データを作成。
ちなみに描画してみるとこんな感じ
f:id:kabaoh:20151130003044p:plain

これで解析すべきデータはできた。
あとはこれを2分する直線を書いてみる。

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import csv 
from numpy.random import *
def f(x,y,w0,w1,w2):
    return w0 + w1*x + w2*y 

# Import Data
with open('data.csv', 'r') as d:
    reader = csv.reader(d)
    data = list(reader)
    
# 初期値
w0 = 100.0
w1 =  -1.0
w2 = 1.0

x_true=[]
y_true=[]

x_false=[]
y_false=[]


# main

for itr in range(0,10):
    for i in range(0,len(data)):
        x = float(data[i][0])
        y = float(data[i][1])
        t = float(data[i][2]) 
        
        if f(x,y,w0,w1,w2)*t > 0:
            continue
        else:
            print("false")
            w0 = w0 + t 
            w1 = w1 + t*x
            w2 = w2 + t*y            
            print(w0,w1,w2)
            print(w1/w2,w0/w2)

x = np.linspace(-100,100,201)  
y = -(w1/w2)*x - (w0/w2)  

plt.plot(x,y,"r-")

for i in range (0,len(data)):
    if int(data[i][2]) > 0:
        x_true.append(float(data[i][0]))
        y_true.append(float(data[i][1]))
     
    else:
        x_false.append(float(data[i][0]))
        y_false.append(float(data[i][1]))

plt.scatter(x_true,y_true,color='b',marker='x')
plt.scatter(x_false,y_false,color='r',marker='x')
plt.show()
plt.close()

直線をあわせてプロットするとこんな感じ
f:id:kabaoh:20151130003716p:plain


恥を忍んでダメダメコードを公開してみた。