82 lines
2.0 KiB
Python
Executable File
82 lines
2.0 KiB
Python
Executable File
import cv2
|
||
from PIL import Image
|
||
import matplotlib as mpl
|
||
|
||
from src.utils.load_emnist import *
|
||
|
||
|
||
def get_mapping(num, with_type='letters'):
|
||
"""
|
||
根据 mapping,由传入的 num 计算 UTF8 字符。详见 README.md
|
||
:param num:
|
||
:param with_type:
|
||
:return:
|
||
"""
|
||
if with_type == 'byclass':
|
||
if num <= 9:
|
||
return chr(num + 48) # 数字
|
||
elif num <= 35:
|
||
return chr(num + 55) # 大写字母
|
||
else:
|
||
return chr(num + 61) # 小写字母
|
||
elif with_type == 'letters':
|
||
return chr(num + 64) + " / " + chr(num + 96) # 大写/小写字母
|
||
elif with_type == 'digits':
|
||
return chr(num + 96)
|
||
else:
|
||
return num
|
||
|
||
|
||
def show_pred_result(path, result):
|
||
"""
|
||
可视化的方式展示图片和预测结果
|
||
:param path: 图片路径
|
||
:param result: 预测结果
|
||
:return: None
|
||
"""
|
||
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 中文字体支持
|
||
img = Image.open(path)
|
||
plt.figure(path) # 文件路径作窗口标题
|
||
plt.axis('off') # 关闭坐标系显示
|
||
plt.title("预测结果是:" + result)
|
||
plt.imshow(img)
|
||
plt.show()
|
||
|
||
|
||
def show_data_set(images_data, num=1000):
|
||
"""
|
||
转换数据集中的内容并在窗口中展示
|
||
:param images_data:
|
||
:param num: 展示几个
|
||
:return: None
|
||
"""
|
||
images = []
|
||
img = []
|
||
|
||
for i in range(images_data.shape[0]):
|
||
im = images_data[i]
|
||
im = im.reshape(28, 28)
|
||
if len(img) == 0:
|
||
img = im
|
||
else:
|
||
# 横向组合
|
||
img = np.hstack((img, im))
|
||
|
||
# 每行显示30个数字图片
|
||
if img.shape[1] / 28 == 45:
|
||
if len(images) == 0:
|
||
images = img.copy()
|
||
img = []
|
||
else:
|
||
# 纵向组合
|
||
images = np.vstack((images, img))
|
||
img = []
|
||
|
||
# 显示前 num 个
|
||
if i == num:
|
||
break
|
||
|
||
cv2.imshow('data_set_visualization', images)
|
||
cv2.waitKey()
|
||
cv2.destroyAllWindows()
|