Commit 7489e16a authored by Zhouxingyu

二次封装

parent 8454a857
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import gc
model_urls = {
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
}
......@@ -67,7 +67,6 @@ cfg = {
}
def vgg16_bn(pretrained=False, **kwargs):
    """Build a VGG 16-layer model (configuration "D") with batch normalization.

    Args:
        pretrained: When true, ``init_weights=False`` is passed to ``VGG`` and
            the checkpoint at ``model_urls['vgg16_bn']`` is loaded into the
            model afterwards.
        **kwargs: Forwarded to the ``VGG`` constructor.

    Returns:
        The constructed ``VGG`` instance.
    """
    if pretrained:
        # The state dict loaded below replaces the parameters, so ask VGG not
        # to run its own weight initialisation first.
        kwargs['init_weights'] = False
    # Build the layer stack exactly once; constructing the network a second
    # time only to discard the first instance wastes time and memory.
    features = make_layers(cfg['D'], batch_norm=True)
    model = VGG(features, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
    return model
\ No newline at end of file
# -*- coding:utf-8 -*-
import re
def pd(str):
    """Reduce a string to its Chinese characters, or else to its ASCII alphanumerics.

    All characters in the range U+4E00..U+9FA5 are collected in order. Only
    when none are found is the result built from the characters matching
    ``[a-zA-Z0-9]`` instead. The result is printed and returned.

    :param str: the text to filter (the name shadows the builtin, but is kept
        so existing keyword callers continue to work)
    :return: the filtered string; empty when nothing matches either pattern
    """
    text = str
    chinese_only = ''.join(re.findall(r'[\u4E00-\u9FA5]', text))
    # Fall back to letters and digits only when no Chinese character exists.
    new_res = chinese_only or ''.join(re.findall(r'[a-zA-Z0-9]', text))
    print(new_res)
    return new_res
if __name__ == '__main__':
    # Run pd over three sample inputs: alphanumeric only, mixed
    # Chinese/Latin, and Latin with punctuation.
    for sample in ('3M', '惠HP普', 'C&C'):
        pd(sample)
\ No newline at end of file
......@@ -6,6 +6,9 @@ from torchvision import transforms as T
from torchvision.transforms import ToTensor, ToPILImage
from VGG16 import vgg16_bn
import torch as t
import gc
import sys
import time
class img_similarity():
'''
......@@ -24,21 +27,13 @@ class img_similarity():
self.normalize
])
self.cos = t.nn.CosineSimilarity(dim=1, eps=1e-6)
self.model = vgg16_bn(pretrained=True).cuda()
self.Features = []
self.similarity_list = []
self.remove_list = []
self.pic_name_list = []
def pic2vec(self, img_path):
    """Turn the image at ``img_path`` into a feature tensor.

    The image is converted to RGB, passed through ``self.transforms``, given
    a leading batch dimension and fed to ``self.model``; the model output is
    then mapped with ``(x - 0.5) * 2``.

    :param img_path: path of the image file to encode
    :return: the rescaled model output as a CPU tensor with no autograd graph
    """
    # Open through a context manager so the file handle is released, and
    # force three channels the same way the other inference paths in this
    # class do (``.convert("RGB")``).
    with Image.open(img_path) as img:
        data = self.transforms(img.convert("RGB")).unsqueeze(0)
    # Reuse the shared ``self.model`` instead of building a fresh pretrained
    # network on every call; feed the input on whatever device it lives on.
    device = next(self.model.parameters()).device
    # Inference only: skip graph construction so stored features do not pin
    # the whole autograd graph in memory.
    with t.no_grad():
        Feature = (self.model(data.to(device)) - 0.5) * 2
    # NOTE(review): nothing visible here switches the model to eval();
    # confirm whether train-mode layers are intended during extraction.
    return Feature.cpu()
def data_load(self, path):
'''
要加载数据调用此函数,输入输入的地址。
......@@ -52,15 +47,23 @@ class img_similarity():
要存储数据调用此函数,应该在load_image后使用,输入输出地址。
'''
path = os.path.abspath(path) #相对路径转绝对路径
print(path)
t.save(self.Features, path)
t.save(self.pic_name_list, path.split('\\')[-1].split('.')[-2]+'_namelyst.pth')
print('保存完毕!')
t.save(self.pic_name_list, '.'.join(path.split('.')[:-1])+'_namelyst.pth')
print('图片信息保存完毕!')
def remove_save(self, path):
    """Write ``self.remove_list`` to ``path`` with ``torch.save`` and print a confirmation."""
    pending_removals = self.remove_list
    t.save(pending_removals, path)
    print('remove列表保存完毕!')
def remove_load(self, path):
    """Replace ``self.remove_list`` with the object stored at ``path``."""
    # NOTE(review): torch.load unpickles the file; only load files you trust.
    stored = t.load(path)
    self.remove_list = stored
def vec_remove(self, vec):
m = 0
vec = vec.detach().numpy().tolist()
vec = vec.cpu().detach().numpy().tolist()
self.remove_list.sort()
for num in self.remove_list:
del vec[num-m]
......@@ -75,21 +78,33 @@ class img_similarity():
读取图片并且转化为特征向量。需要输入图片文件夹目录地址。
'''
imgs = [os.path.join(root_path, img) for img in os.listdir(root_path)]
index = Index()
print(f'共计有{len(imgs)}张图片需要处理。')
for i in range(len(imgs)):
img_path = imgs[i]
#img_similarity.pic2vec(self, img_path)
name = img_path.split('.')[-2].split('\\')[-1]
self.Features.append(img_similarity.pic2vec(self, img_path))
Feature = (self.model(self.transforms(Image.open(img_path).convert("RGB")).cuda().unsqueeze(0))-0.5)*2
Feature = Feature.cpu()
Feature = Feature.detach().numpy()
Feature = t.from_numpy(Feature)
#a = img_similarity.pic2vec(self, img_path)
self.Features.append(Feature)
Feature = None
gc.collect()
self.pic_name_list.append(name)
print(f'图片 {name} 加载完毕!')
print('图片加载完成!如有需要请用data_save方法进行数据报错,以方便下次快速读取。')
#print(f'图片 {name} 加载完毕!')
print(index(i, len(imgs)-1), end='%')
time.sleep(0.01)
print('\n图片加载完成!如有需要请用data_save方法进行数据保存,以方便下次快速读取。')
def similarity(self, pic_path):
'''
计算相似度,输入待计算图片,将图片与文件夹内图片进行比对,并且输出和每一张图片的相似度列表。
'''
vec = img_similarity.pic2vec(self, pic_path)
self.pic_path = pic_path
vec = (self.model(self.transforms(Image.open(pic_path).convert("RGB")).cuda().unsqueeze(0))-0.5)*2
#print('正在比对相似度。。。。')
for i in range(len(self.Features)):
'''
......@@ -130,7 +145,7 @@ class img_similarity():
return self.pic_name_list[self.similarity_list.index(max(self.similarity_list))]
def max_batch(self, n=1):
def max_batch(self, n=2):
'''
提取相似度最大的n个图片名称。n默认为1,可以自己设定。
'''
......@@ -144,9 +159,58 @@ class img_similarity():
max_sim_name = []
for i in range(len(front)):
max_sim_name.append(self.pic_name_list[front[i][0]])
return max_sim_name
for value in max_sim_name:
if value != self.pic_path:
return value
class Index(object):
    """Text progress bar of the form ``[####    ] <percent>``.

    Calling an instance with the current and total counts returns the bar as
    a string that starts with a carriage return, so printing it repeatedly
    redraws the same terminal line.
    """

    def __init__(self, number=50, decimal=2):
        """
        :param number: width of the bar, i.e. how many "#" make a full bar
        :param decimal: number of decimal places kept in the percentage
        """
        self.decimal = decimal
        self.number = number
        self.a = 100 / number  # percentage points represented by one "#"

    def __call__(self, now, total):
        """
        Render the bar for the given progress.

        :param now: current count
        :param total: count that corresponds to 100%
        :return: carriage return + bar + space + percentage value
        """
        percentage = self.percentage_number(now, total)
        # How many "#" the current percentage is worth.
        well_num = int(percentage / self.a)
        return "\r%s %s" % (self.progress_bar(well_num), percentage)

    def percentage_number(self, now, total):
        """
        Compute the completed percentage.

        :param now: current count
        :param total: count that corresponds to 100%
        :return: percentage rounded to ``self.decimal`` places; 100.0 when
            ``total`` is zero
        """
        if total == 0:
            # Nothing to wait for (e.g. a single item indexed 0 of 0): report
            # completion instead of raising ZeroDivisionError.
            return round(100.0, self.decimal)
        return round(now / total * 100, self.decimal)

    def progress_bar(self, num):
        """
        Build the bracketed bar.

        :param num: requested number of "#" characters
        :return: the bar, always exactly ``self.number`` characters wide
            between the brackets
        """
        # Clamp so an out-of-range count can neither overflow the bar nor
        # produce a negative fill.
        filled = max(0, min(num, self.number))
        return '[%s%s]' % ("#" * filled, " " * (self.number - filled))
'''
s = img_similarity()
#s.load_img('picture')
......@@ -159,6 +223,20 @@ s.P_feature_removal()
print(s.pic_name_list)
print(s.similarity('9387011a8ed714f6.jpg'))
print(s.max_batch())
'''
s = img_similarity()
s.load_img('oil')
s.data_save('oil_data.pth')
s.P_feature_removal()
s.remove_save('oil_remove.pth')
s = img_similarity()
s.data_load('oil_data.pth')
s.remove_load('oil_remove.pth')
print(len(s.remove_list))
print(s.pic_name_list)
print(s.similarity('京东_刀麦_3.jpg'))
print(s.max_batch())
'''
from tool import sim_tools
from PIL import Image
from matplotlib import pyplot as plt
# Load the previously saved feature and removal data for the image folder,
# look up the stored image that img_match returns for the query picture,
# and display it.
p = sim_tools('食用油图片', 'data')
p.pth_load()
name = p.img_match('0da610f7fc57e3e1.jpg')
# Open through a context manager so the image file handle is released once
# the figure has been shown.
with Image.open(f'imgs/食用油图片/{name}.jpg') as img:
    plt.imshow(img)
    plt.axis('off')
    plt.show()
\ No newline at end of file
# -*-coding:utf-8-*-
from function import img_similarity
import os
class sim_tools(img_similarity):
    """Convenience wrapper around ``img_similarity`` for one image folder.

    It derives the data file names from the folder name, so features can be
    extracted and saved (``pth_save``), reloaded (``pth_load``) and queried
    (``img_match``) without spelling out paths.
    """

    def __init__(self, document_name, data_path=None):
        """
        :param document_name: folder that holds the images
        :param data_path: folder where the extracted data is stored; when
            None, data files are addressed relative to the working directory
        """
        super(sim_tools, self).__init__()
        self.document_name = document_name
        # Last path component, accepting both "\\" and "/" separators and
        # ignoring a trailing separator.
        self.name = document_name.replace('\\', '/').rstrip('/').split('/')[-1]
        if data_path is not None:
            self.data_path = data_path + '/'
        else:
            self.data_path = data_path

    def _data_file(self, suffix):
        """Return the data file path for this folder with the given suffix."""
        # Formatting None directly would yield a file literally named
        # "None<name>..."; use an empty prefix instead.
        prefix = self.data_path if self.data_path is not None else ''
        return f'{prefix}{self.name}{suffix}'

    def pth_save(self):
        """Extract features for every image in the folder and save them."""
        self.load_img(f'{self.document_name}')
        self.P_feature_removal()
        self.data_save(self._data_file('.pth'))
        self.remove_save(self._data_file('_remove.pth'))

    def pth_load(self):
        """Load the feature data and removal list saved by ``pth_save``."""
        self.data_load(self._data_file('.pth'))
        self.remove_load(self._data_file('_remove.pth'))

    def img_match(self, pic_path):
        """
        Compare ``pic_path`` against the loaded features.

        :param pic_path: path of the query image
        :return: whatever ``max_batch()`` returns after the comparison
        """
        self.similarity(pic_path)
        return self.max_batch()
def main(img_dir, data_dir):
    """Extract and save features for every sub-folder found under ``img_dir``.

    :param img_dir: root folder that is walked recursively
    :param data_dir: folder handed to ``sim_tools`` as the data location
    """
    for root, sub_dirs, _files in os.walk(img_dir):
        for sub_dir in sub_dirs:
            print(f'开始提取{sub_dir}文件夹的图片特征。。。')
            extractor = sim_tools(os.path.join(root, sub_dir), data_dir)
            extractor.pth_save()
    print('全部图片特征提取完毕!')
if __name__ == '__main__':
    # Prompt for the image folder first, then the data folder, and run the
    # extraction. Arguments are evaluated left to right, so the prompt order
    # is preserved.
    main(
        input('请输入存放图片的文件夹:'),
        input('请输入用来存放特征提取信息的文件夹:'),
    )
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment