117 lines
4.2 KiB
Python
117 lines
4.2 KiB
Python
#!/usr/bin/python
|
|
|
|
import math
|
|
|
|
class Table(object):
    """Generate fixed-point lookup tables for sigmoid and tanh as C source.

    The generated file contains, for each activation function:

    * a "unified" table in q7 and in q15 format,
    * a q15 "L" table with half as many entries and a 4x finer input step,
    * a q15 "H" table with three quarters as many entries.

    Table indices are laid out like a two's-complement number: the first
    part of each table holds non-negative inputs and the second part holds
    negative inputs.

    ``table_entry`` and ``table_range`` are expected to be integers; all
    index arithmetic uses floor division, matching the original Python 2
    behaviour of ``/`` on ints.
    """

    # Text written verbatim at the top of the generated C file.
    _HEADER = "/*\n * Common tables for NN\n *\n *\n *\n *\n */\n\n#include \"arm_math.h\"\n#include \"NNCommonTable.h\"\n\n/*\n * Table for sigmoid\n */\n"

    def __init__(self, table_entry=256, table_range=8):
        """Store the table geometry.

        Args:
            table_entry: number of entries in the unified table.
            table_range: integer scale applied to the table index before it
                is divided by the table size to obtain the function input.
        """
        self.table_entry = table_entry
        self.table_range = table_range

    def sigmoid(self, x):
        """Return the logistic function 1 / (1 + e**-x).

        Inputs so negative that ``e**-x`` overflows a float yield 0.0
        (the mathematical limit) instead of raising OverflowError.
        """
        try:
            return 1 / (1 + math.exp(-1*x))
        except OverflowError:
            return 0.0

    def tanh(self, x):
        """Return the hyperbolic tangent (e**2x - 1) / (e**2x + 1).

        Inputs so large that ``e**2x`` overflows a float yield 1.0
        (the mathematical limit) instead of raising OverflowError.
        """
        try:
            e2x = math.exp(2*x)
        except OverflowError:
            return 1.0
        return (e2x - 1) / (e2x + 1)

    def _to_fixed(self, x, frac_bits):
        """Round x to a signed fixed-point code with ``frac_bits`` fraction bits.

        The value is rounded half-up, saturated to the signed range
        [-2**frac_bits, 2**frac_bits - 1], and returned as the unsigned
        two's-complement bit pattern of width ``frac_bits + 1``.
        """
        limit = 2 ** frac_bits
        # int() keeps the result an integer on Python 2, where math.floor
        # returns a float.
        code = int(math.floor(x * limit + 0.5))
        if code >= limit:
            code = limit - 1
        if code < -limit:
            code = -limit
        if code >= 0:
            return code
        return 2 * limit + code

    def fp2q7(self, x):
        """Convert float x to its q7 two's-complement byte (0..0xff)."""
        return self._to_fixed(x, 7)

    def fp2q15(self, x):
        """Convert float x to its q15 two's-complement half-word (0..0xffff)."""
        return self._to_fixed(x, 15)

    def _render_table(self, name, data_type, act_func, inputs):
        """Return the C definition of one table as a string.

        Args:
            name: table identifier without the ``_q<N>`` suffix.
            data_type: 7 or 15, selecting q7_t / q15_t and the hex width.
            act_func: activation function applied to every input.
            inputs: list of float inputs, one per table entry.
        """
        quan_func = getattr(self, 'fp2q' + str(data_type))
        fmt = '0x%02x, ' if data_type == 7 else '0x%04x, '
        parts = ['const q%d_t %s_q%d[%d] = {\n'
                 % (data_type, name, data_type, len(inputs))]
        for idx, x in enumerate(inputs):
            parts.append(fmt % quan_func(act_func(x)))
            # Eight values per output line.
            if idx % 8 == 7:
                parts.append("\n")
        parts.append("};\n\n")
        return "".join(parts)

    def _unified_inputs(self):
        """Return the inputs of the unified table (one per ``table_entry``)."""
        entries = self.table_entry
        half = entries // 2
        inputs = []
        for i in range(entries):
            # Upper half of the index space represents negative inputs.
            scaled = self.table_range * (i if i < half else i - entries)
            inputs.append(float(scaled) / entries)
        return inputs

    def _low_inputs(self):
        """Return the inputs of the L table: half the entries, 1/4 the step."""
        size = self.table_entry // 2
        quarter = self.table_entry // 4
        inputs = []
        for i in range(size):
            index = i if i < quarter else i - size
            # Floor division reproduces Python 2 integer '/' semantics.
            scaled = self.table_range * index // 4
            inputs.append(float(scaled) / size)
        return inputs

    def _high_inputs(self):
        """Return the inputs of the H table.

        The table index is shifted by ``table_entry // 8`` before scaling,
        so the entries skip the inputs nearest zero.
        """
        entries = self.table_entry
        size = 3 * entries // 4
        split = 3 * entries // 8
        offset = entries // 8
        inputs = []
        for i in range(size):
            index = i + offset if i < split else i + offset - entries
            inputs.append(float(self.table_range * index) / entries)
        return inputs

    def table_gen(self, filename="NNCommonTable.c"):
        """Write all sigmoid/tanh tables to ``filename`` as C source.

        Args:
            filename: output path; defaults to "NNCommonTable.c" in the
                current working directory.

        The whole file is rendered in memory first, so a failure while
        computing table values does not leave a truncated file behind.
        """
        chunks = [self._HEADER]
        for function_type in ("sigmoid", "tanh"):
            act_func = getattr(self, function_type)

            # unified tables, q7 then q15
            unified = self._unified_inputs()
            for data_type in (7, 15):
                chunks.append(self._render_table(
                    function_type + 'Table', data_type, act_func, unified))

            # H-L tables exist in q15 only
            chunks.append(self._render_table(
                function_type + 'LTable', 15, act_func, self._low_inputs()))
            chunks.append(self._render_table(
                function_type + 'HTable', 15, act_func, self._high_inputs()))

        # Binary mode keeps '\n' line endings on every platform; encoding
        # explicitly makes the write valid on both Python 2 and Python 3.
        with open(filename, "wb") as outfile:
            outfile.write("".join(chunks).encode("ascii"))
|
|
|
|
|
|
# Script driver: build a generator with 256 entries and a range scale of 16,
# then write the tables to the generator's default output file.
mytable = Table(
    table_entry=256,
    table_range=16,
)
mytable.table_gen()
|