import numpy
import json
CATEGORY = 10
SAMPLES = 1000
# Class labels are the integer indices 0..CATEGORY-1.
label_list = list(range(CATEGORY))
# numpy.random.randint's `high` bound is exclusive, so pass CATEGORY (not
# CATEGORY-1); otherwise the last label is never generated and its row and
# column of the confusion matrix stay all zeros.
pred_list = numpy.random.randint(0, CATEGORY, size=SAMPLES)
y_batch_list = numpy.random.randint(0, CATEGORY, size=SAMPLES)
print(pred_list, y_batch_list)
class confusion_matrix:
    """Confusion matrix with per-class precision and recall.

    Rows of the matrix are indexed by the true label (``y_batch_list``) and
    columns by the predicted label (``pred_list``). Label values are used
    directly as list indices, so every value in ``pred_list`` and
    ``y_batch_list`` must be an integer in ``range(len(label_list))``.
    """

    def __init__(self, pred_list, y_batch_list, label_list):
        """Store predictions/targets and allocate zeroed result containers.

        Args:
            pred_list: Predicted class index for each sample.
            y_batch_list: True class index for each sample.
            label_list: The class labels; only its length (the number of
                classes) is used.

        Raises:
            ValueError: If ``pred_list`` and ``y_batch_list`` differ in length.
        """
        if len(pred_list) != len(y_batch_list):
            raise ValueError('Prediction length is different from Label list!')
        self.pred_list = pred_list
        self.y_batch_list = y_batch_list
        self.matrix_size = len(label_list)
        n = self.matrix_size
        # 2-D matrix indexed as [true label][predicted label]; each row is a
        # distinct list so rows can be incremented independently.
        self.confusion_matrix = [[0] * n for _ in range(n)]
        self.precision_list = [0] * n
        self.recall_list = [0] * n

    def calculate_confusion_matrix(self):
        """Count (true, predicted) label pairs into ``self.confusion_matrix``.

        The matrix is zeroed first, so calling this more than once does not
        double-count.

        Raises:
            IndexError: If a label falls outside ``range(matrix_size)``.
        """
        n = self.matrix_size
        # Reset in place so existing references to the outer list stay valid.
        self.confusion_matrix[:] = [[0] * n for _ in range(n)]
        for truth, pred in zip(self.y_batch_list, self.pred_list):
            # Reject out-of-range labels explicitly: negative values would
            # otherwise wrap around to the last rows/columns and miscount.
            if not (0 <= truth < n and 0 <= pred < n):
                raise IndexError(
                    'Label out of range: true=%r, pred=%r (expected 0..%d)'
                    % (truth, pred, n - 1))
            self.confusion_matrix[truth][pred] += 1

    def calculate_recall_precision_list(self):
        """Compute per-class recall and precision from the confusion matrix.

        recall[i]    = matrix[i][i] / sum(row i)     (samples truly of class i)
        precision[j] = matrix[j][j] / sum(column j)  (samples predicted as j)

        A class whose denominator is zero gets 0.
        """
        matrix = self.confusion_matrix
        row_totals = [sum(row) for row in matrix]
        col_totals = [sum(col) for col in zip(*matrix)]
        # Update in place so existing references to the lists see the results.
        self.recall_list[:] = [
            float(matrix[i][i]) / total if total != 0 else 0
            for i, total in enumerate(row_totals)
        ]
        self.precision_list[:] = [
            float(matrix[j][j]) / total if total != 0 else 0
            for j, total in enumerate(col_totals)
        ]

    def gen_json_data(self):
        """Return the matrix and per-class metrics bundled in one dict."""
        return {
            'confusion_matrix': self.confusion_matrix,
            'precision_list': self.precision_list,
            'recall_list': self.recall_list,
        }
# Hand plain Python lists (not NumPy arrays) to the confusion matrix.
predicted = pred_list.tolist()
true_labels = y_batch_list.tolist()
ret = confusion_matrix(predicted, true_labels, label_list)
# Counts must be filled before recall/precision can be derived from them.
ret.calculate_confusion_matrix()
ret.calculate_recall_precision_list()
# Result (output of the print call below):
summary = ret.gen_json_data()
print(summary)
{'precision_list': [0.0625, 0.14912280701754385, 0.02654867256637168, 0.1452991452991453, 0.07377049180327869, 0.10526315789473684, 0.11320754716981132, 0.13, 0.13725490196078433, 0],
'confusion_matrix': [[7, 14, 10, 15, 17, 19, 17, 18, 14, 0], [10, 17, 14, 9, 5, 11, 9, 12, 12, 0], [11, 11, 3, 19, 16, 13, 4, 11, 7, 0], [13, 18, 16, 17, 13, 12, 11, 11, 12, 0], [15, 12, 15, 14, 9, 13, 17, 9, 11, 0], [19, 8, 11, 11, 17, 12, 13, 10, 8, 0], [9, 9, 10, 11, 14, 11, 12, 7, 15, 0], [20, 14, 13, 10, 18, 10, 11, 13, 9, 0], [8, 11, 21, 11, 13, 13, 12, 9, 14, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
'recall_list': [0.05343511450381679, 0.1717171717171717, 0.031578947368421054, 0.13821138211382114, 0.0782608695652174, 0.11009174311926606, 0.12244897959183673, 0.11016949152542373, 0.125, 0]
}
1 comment:
best post.
Learn Data Science Online Course
Data Science Online Training In India
Post a Comment