Commit 8454a857 authored by Zhouxingyu's avatar Zhouxingyu

直接使用的图像相似度匹配工具包

parents
Pipeline #29 canceled with stages
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
model_urls = {
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
}
class VGG(nn.Module):
def __init__(self, features, num_classes=1000, init_weights=True):
super(VGG, self).__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(0),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(0),
nn.Linear(4096, num_classes),
nn.ReLU(True),
nn.Sigmoid(),
)
if init_weights:
self._initialize_weights()
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def make_layers(cfg, batch_norm=False):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
cfg = {
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
}
def vgg16_bn(pretrained=False, **kwargs):
"""VGG 16-layer model (configuration "D") with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
if pretrained:
kwargs['init_weights'] = False
model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
return model
\ No newline at end of file
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T
from torchvision.transforms import ToTensor, ToPILImage
from VGG16 import vgg16_bn
import torch as t
class img_similarity():
'''
1. 使用 load_img(path) 或者 data_load(path) 读取图片特征tensor。
2. 使用 P_feature_removal(float) ,获得公共特征,并暂时忽略这些特征。
3. 使用 similarity(path) 获得待计算图像和保存图像的相似度。
4. 使用 max_batch(int) 找到相似度前n个最大的图片名称。
5. 你可以使用 data_save(path) 保存提取的图像特征数据,方便下次读取。
'''
def __init__(self):
self.normalize = T.Normalize(mean = [0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225])
self.transforms = T.Compose([
T.Resize(224),
T.ToTensor(),
self.normalize
])
self.cos = t.nn.CosineSimilarity(dim=1, eps=1e-6)
self.Features = []
self.similarity_list = []
self.remove_list = []
self.pic_name_list = []
def pic2vec(self, img_path):
data = Image.open(img_path)
data = self.transforms(data)
data = data.unsqueeze(0)
model = vgg16_bn(pretrained=True)
Feature = (model(data)-0.5)*2
#Feature = Feature.detach().numpy().tolist()
return Feature
def data_load(self, path):
'''
要加载数据调用此函数,输入输入的地址。
'''
self.Features = t.load(path)
self.pic_name_list = t.load(path.split('\\')[-1].split('.')[-2]+'_namelyst.pth')
def data_save(self, path):
'''
要存储数据调用此函数,应该在load_image后使用,输入输出地址。
'''
path = os.path.abspath(path) #相对路径转绝对路径
print(path)
t.save(self.Features, path)
t.save(self.pic_name_list, path.split('\\')[-1].split('.')[-2]+'_namelyst.pth')
print('保存完毕!')
def vec_remove(self, vec):
m = 0
vec = vec.detach().numpy().tolist()
self.remove_list.sort()
for num in self.remove_list:
del vec[num-m]
m+=1
vec = t.tensor(vec)
vec = vec.unsqueeze(0)
return vec
def load_img(self, root_path):
'''
读取图片并且转化为特征向量。需要输入图片文件夹目录地址。
'''
imgs = [os.path.join(root_path, img) for img in os.listdir(root_path)]
for i in range(len(imgs)):
img_path = imgs[i]
#img_similarity.pic2vec(self, img_path)
name = img_path.split('.')[-2].split('\\')[-1]
self.Features.append(img_similarity.pic2vec(self, img_path))
self.pic_name_list.append(name)
print(f'图片 {name} 加载完毕!')
print('图片加载完成!如有需要请用data_save方法进行数据报错,以方便下次快速读取。')
def similarity(self, pic_path):
'''
计算相似度,输入待计算图片,将图片与文件夹内图片进行比对,并且输出和每一张图片的相似度列表。
'''
vec = img_similarity.pic2vec(self, pic_path)
#print('正在比对相似度。。。。')
for i in range(len(self.Features)):
'''
x = self.vec_remove(vec[0])
y = self.vec_remove(self.Features[i][0])
print(x.shape, y.shape)
a = self.cos(x, y)
a = a.detach().numpy().tolist()[0]
print(a)
'''
self.similarity_list.append(self.cos(self.vec_remove(vec[0]), self.vec_remove(self.Features[i][0])).detach().numpy().tolist()[0])
return self.similarity_list
def P_feature_removal(self, std=0.2):
'''
运行后将获得均方差小于std的向量。std可以修改。
'''
new_Features = []
for i in range(len(self.Features)):
new_Features.append(self.Features[i].detach().numpy().tolist())
for m in range(1000):
sub_list = []
for i in range(len(new_Features)):
sub_list.append(new_Features[i][0][m])
img_std = np.std(np.array(sub_list),ddof=1)
if img_std < std:
self.remove_list.append(m)
def max(self):
'''
m = 0
for i in range(1,len(self.similarity_list)):
if self.similarity_list[i] > self.similarity_list[m]:
m = i
'''
return self.pic_name_list[self.similarity_list.index(max(self.similarity_list))]
def max_batch(self, n=1):
'''
提取相似度最大的n个图片名称。n默认为1,可以自己设定。
'''
from collections import defaultdict
d = defaultdict(list)
for index, item in enumerate(self.similarity_list):
d[item].append(index)
front = [ d[i] for i in sorted(d)]
front = front[-n:]
max_sim_name = []
for i in range(len(front)):
max_sim_name.append(self.pic_name_list[front[i][0]])
return max_sim_name
'''
s = img_similarity()
#s.load_img('picture')
#s.data_save('data.pth')
#s.data_load('data.pth')
#print(s.pic_name_list)
s.data_load('data.pth')
s.P_feature_removal()
print(s.pic_name_list)
print(s.similarity('9387011a8ed714f6.jpg'))
print(s.max_batch())
'''
# 图片相似度
### 介绍
本程序使用VGG16进行特征提取,利用均方差阈值进行特征筛选,使用向量余弦相似度进行相似度匹配。
### 目录
1. [环境搭建](#环境搭建)
2. [如何使用](#如何使用)
3. [模型下载](#模型下载)
### 环境搭建
1. 按照requirements.txt搭建环境。
2. 使用python3.6以上版本。
### 如何使用
1. 使用 load_img(path) 或者 data_load(path) 读取图片特征tensor。
2. 使用 P_feature_removal(float) ,获得公共特征,并暂时忽略这些特征。
3. 使用 similarity(path) 获得待计算图像和保存图像的相似度。
4. 使用 max_batch(int) 找到相似度前n个最大的图片名称。
5. 你可以使用 data_save(path) 保存提取的图像特征数据,方便下次读取。
### 模型下载
1. 首次使用将自动下载VGG16的模型。
torch==1.2.0+cu92
torchvision==0.4.0+cu92
numpy==1.17.0
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment