需要结合深度学习模型
1、pom依赖
注意结尾的webp-imageio 包,用于解决ImageIO.read读取部分图片返回为null的问题
<dependency>
    <groupId>org.openpnp</groupId>
    <artifactId>opencv</artifactId>
    <version>4.7.0-0</version>
</dependency>
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.17.1</version>
</dependency>

<!-- Server-side inference engine -->
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>${djl.version}</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>basicdataset</artifactId>
    <version>${djl.version}</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>model-zoo</artifactId>
    <version>${djl.version}</version>
</dependency>

<!-- PyTorch -->
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>${djl.version}</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>${djl.version}</version>
</dependency>

<!-- ONNX -->
<dependency>
    <groupId>ai.djl.onnxruntime</groupId>
    <artifactId>onnxruntime-engine</artifactId>
    <version>${djl.version}</version>
</dependency>

<!-- Works around ImageIO.read returning null for some images -->
<dependency>
    <groupId>org.sejda.imageio</groupId>
    <artifactId>webp-imageio</artifactId>
    <version>0.1.6</version>
</dependency>
2、加载模型
注意:需在加载模型之前设置 ENGINE_CACHE_DIR(下面的代码设置的是 JVM 系统属性,而非操作系统环境变量)。PyTorch 引擎依赖本地动态库文件(如 dll),如果该目录下不存在,则默认自动下载。
// Set the ENGINE_CACHE_DIR system property to modelPath; do this before any model is loaded.
System.setProperty("ENGINE_CACHE_DIR", modelPath);
import ai.djl.Device;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;public Criteria<Image, T> criteria() {Translator<Image, T> translator = getTranslator(arguments);try {JarFileUtils.copyFileFromJar("/onnx/models/" + modelName, PathConstants.ONNX, null, false, true);} catch (IOException e) {throw new RuntimeException(e);}
// String model_path = PathConstants.TEMP_DIR + PathConstants.ONNX + "/" + modelName;String modelPath = PathConstants.TEMP_DIR + File.separator+PathConstants.ONNX_NAME+ File.separator + modelName;log.info("路径修改前:{}",modelPath);modelPath= DjlHandlerUtil.getFixedModelPath(modelPath);log.info("路径修改后:{}",modelPath);Criteria<Image, T> criteria =Criteria.builder().setTypes(Image.class, getClassOfT()).optModelUrls(modelPath).optTranslator(translator).optDevice(Device.cpu()).optEngine(getEngine()) // Use PyTorch engine.optProgress(new ProgressBar()).build();return criteria;}protected Translator<Image, float[]> getTranslator(Map<String, Object> arguments) {BaseImageTranslator.BaseBuilder<?> builder=new BaseImageTranslator.BaseBuilder<BaseImageTranslator.BaseBuilder>() {@Overrideprotected BaseImageTranslator.BaseBuilder self() {return this;}};return new BaseImageTranslator<float[]>(builder) {@Overridepublic float[] processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {return ndList.get(0).toFloatArray();}};}
3、FaceFeatureTranslator
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;/*** @author gc.x* @date 2022-04*/
public final class FaceFeatureTranslator implements Translator<Image, float[]> {public FaceFeatureTranslator() {}@Overridepublic NDList processInput(TranslatorContext ctx, Image input) {NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);Pipeline pipeline = new Pipeline();pipeline// .add(new Resize(160)).add(new ToTensor()).add(new Normalize(new float[]{127.5f / 255.0f, 127.5f / 255.0f, 127.5f / 255.0f},new float[]{128.0f / 255.0f, 128.0f / 255.0f, 128.0f / 255.0f}));return pipeline.transform(new NDList(array));}@Overridepublic float[] processOutput(TranslatorContext ctx, NDList list) {NDList result = new NDList();long numOutputs = list.singletonOrThrow().getShape().get(0);for (int i = 0; i < numOutputs; i++) {result.add(list.singletonOrThrow().get(i));}float[][] embeddings = result.stream().map(NDArray::toFloatArray).toArray(float[][]::new);float[] feature = new float[embeddings.length];for (int i = 0; i < embeddings.length; i++) {feature[i] = embeddings[i][0];}return feature;}@Overridepublic Batchifier getBatchifier() {return Batchifier.STACK;}
}
4、BaseImageTranslator
import ai.djl.Model;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.*;
import ai.djl.util.Utils;import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.Arrays;
import java.util.List;
import java.util.Map;public abstract class BaseImageTranslator<T> implements Translator<Image, T> {private static final float[] MEAN = {0.485f, 0.456f, 0.406f};private static final float[] STD = {0.229f, 0.224f, 0.225f};private Image.Flag flag;private Pipeline pipeline;private Batchifier batchifier;/*** Constructs an ImageTranslator with the provided builder.** @param builder the data to build with*/public BaseImageTranslator(BaseBuilder<?> builder) {flag = builder.flag;pipeline = builder.pipeline;batchifier = builder.batchifier;}/** {@inheritDoc} */@Overridepublic Batchifier getBatchifier() {return batchifier;}/*** Processes the {@link Image} input and converts it to NDList.** @param ctx the toolkit that helps create the input NDArray* @param input the {@link Image} input* @return a {@link NDList}*/@Overridepublic NDList processInput(TranslatorContext ctx, Image input) {NDArray array = input.toNDArray(ctx.getNDManager(), flag);array = NDImageUtils.resize(array, 640, 640);array = array.transpose(2, 0, 1); // HWC -> CHW RGB -> BGR
// array = array.expandDims(0);array = array.div(255f);return new NDList(array);
// return pipeline.transform(new NDList(array));}protected static String getStringValue(Map<String, ?> arguments, String key, String def) {Object value = arguments.get(key);if (value == null) {return def;}return value.toString();}protected static int getIntValue(Map<String, ?> arguments, String key, int def) {Object value = arguments.get(key);if (value == null) {return def;}return (int) Double.parseDouble(value.toString());}protected static float getFloatValue(Map<String, ?> arguments, String key, float def) {Object value = arguments.get(key);if (value == null) {return def;}return (float) Double.parseDouble(value.toString());}protected static boolean getBooleanValue(Map<String, ?> arguments, String key, boolean def) {Object value = arguments.get(key);if (value == null) {return def;}return Boolean.parseBoolean(value.toString());}/*** A builder to extend for all classes extending the {@link BaseImageTranslator}.** @param <T> the concrete builder type*/@SuppressWarnings("rawtypes")public abstract static class BaseBuilder<T extends BaseBuilder> {protected int width = 224;protected int height = 224;protected Image.Flag flag = Image.Flag.COLOR;protected Pipeline pipeline;protected Batchifier batchifier = Batchifier.STACK;/*** Sets the optional {@link Image.Flag} (default is {@link* Image.Flag#COLOR}).** @param flag the color mode for the images* @return this builder*/public T optFlag(Image.Flag flag) {this.flag = flag;return self();}/*** Sets the {@link Pipeline} to use for pre-processing the image.** @param pipeline the pre-processing pipeline* @return this builder*/public T setPipeline(Pipeline pipeline) {this.pipeline = pipeline;return self();}/*** Adds the {@link Transform} to the {@link Pipeline} use for pre-processing the image.** @param transform the {@link Transform} to be added* @return this builder*/public T addTransform(Transform transform) {if (pipeline == null) {pipeline = new Pipeline();}pipeline.add(transform);return self();}/*** Sets the {@link 
Batchifier} for the {@link Translator}.** @param batchifier the {@link Batchifier} to be set* @return this builder*/public T optBatchifier(Batchifier batchifier) {this.batchifier = batchifier;return self();}protected abstract T self();protected void validate() {if (pipeline == null) {throw new IllegalArgumentException("pipeline is required.");}}protected void configPreProcess(Map<String, ?> arguments) {if (pipeline == null) {pipeline = new Pipeline();}width = getIntValue(arguments, "width", 224);height = getIntValue(arguments, "height", 224);if (arguments.containsKey("flag")) {flag = Image.Flag.valueOf(arguments.get("flag").toString());}if (getBooleanValue(arguments, "centerCrop", false)) {addTransform(new CenterCrop());}if (getBooleanValue(arguments, "resize", false)) {addTransform(new Resize(width, height));}if (getBooleanValue(arguments, "toTensor", true)) {addTransform(new ToTensor());}String normalize = getStringValue(arguments, "normalize", "false");if ("true".equals(normalize)) {addTransform(new Normalize(MEAN, STD));} else if (!"false".equals(normalize)) {String[] tokens = normalize.split("\\s*,\\s*");if (tokens.length != 6) {throw new IllegalArgumentException("Invalid normalize value: " + normalize);}float[] mean = {Float.parseFloat(tokens[0]),Float.parseFloat(tokens[1]),Float.parseFloat(tokens[2])};float[] std = {Float.parseFloat(tokens[3]),Float.parseFloat(tokens[4]),Float.parseFloat(tokens[5])};addTransform(new Normalize(mean, std));}String range = (String) arguments.get("range");if ("0,1".equals(range)) {addTransform(a -> a.div(255f));} else if ("-1,1".equals(range)) {addTransform(a -> a.div(128f).sub(1));}if (arguments.containsKey("batchifier")) {batchifier = Batchifier.fromString((String) arguments.get("batchifier"));}}protected void configPostProcess(Map<String, ?> arguments) {}}/** A Builder to construct a {@code ImageClassificationTranslator}. 
*/@SuppressWarnings("rawtypes")public abstract static class ClassificationBuilder<T extends BaseBuilder>extends BaseBuilder<T> {protected SynsetLoader synsetLoader;/*** Sets the name of the synset file listing the potential classes for an image.** @param synsetArtifactName a file listing the potential classes for an image* @return the builder*/public T optSynsetArtifactName(String synsetArtifactName) {synsetLoader = new SynsetLoader(synsetArtifactName);return self();}/*** Sets the URL of the synset file.** @param synsetUrl the URL of the synset file* @return the builder*/public T optSynsetUrl(String synsetUrl) {try {this.synsetLoader = new SynsetLoader(new URL(synsetUrl));} catch (MalformedURLException e) {throw new IllegalArgumentException("Invalid synsetUrl: " + synsetUrl, e);}return self();}/*** Sets the potential classes for an image.** @param synset the potential classes for an image* @return the builder*/public T optSynset(List<String> synset) {synsetLoader = new SynsetLoader(synset);return self();}/** {@inheritDoc} */@Overrideprotected void validate() {super.validate();if (synsetLoader == null) {synsetLoader = new SynsetLoader("synset.txt");}}/** {@inheritDoc} */@Overrideprotected void configPostProcess(Map<String, ?> arguments) {String synset = (String) arguments.get("synset");if (synset != null) {optSynset(Arrays.asList(synset.split(",")));}String synsetUrl = (String) arguments.get("synsetUrl");if (synsetUrl != null) {optSynsetUrl(synsetUrl);}String synsetFileName = (String) arguments.get("synsetFileName");if (synsetFileName != null) {optSynsetArtifactName(synsetFileName);}}}protected static final class SynsetLoader {private String synsetFileName;private URL synsetUrl;private List<String> synset;public SynsetLoader(List<String> synset) {this.synset = synset;}public SynsetLoader(URL synsetUrl) {this.synsetUrl = synsetUrl;}public SynsetLoader(String synsetFileName) {this.synsetFileName = synsetFileName;}public List<String> load(Model model) throws 
IOException {if (synset != null) {return synset;} else if (synsetUrl != null) {try (InputStream is = synsetUrl.openStream()) {return Utils.readLines(is);}}return model.getArtifact(synsetFileName, Utils::readLines);}}
}
5、创建向量索引字段
需要在索引库创建的时候,一并创建对应字段。
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch._types.mapping.Property;
import co.elastic.clients.elasticsearch._types.mapping.TypeMapping;
import co.elastic.clients.elasticsearch.indices.Alias;
import co.elastic.clients.elasticsearch.indices.CreateIndexRequest;
import co.elastic.clients.elasticsearch.indices.CreateIndexResponse;
import co.elastic.clients.elasticsearch.indices.ExistsRequest;

// Creates the index together with its dense-vector field and a write alias.
CreateIndexResponse response = null;
try {
    TypeMapping.Builder tmBuilder = new TypeMapping.Builder();
    // Image similarity uses the dot product; text similarity uses cosine similarity.
    // NOTE(review): Elasticsearch documents dot_product as requiring unit-length vectors;
    // confirm the model output is normalized before it is indexed.
    tmBuilder.properties(
            "_img_vector", // must be a String literal; single quotes do not compile in Java
            new Property.Builder()
                    .denseVector(builder -> builder
                            .index(true)
                            .dims(1024)
                            .similarity("dot_product")
                            .indexOptions(opBuilder -> opBuilder
                                    .type("hnsw")
                                    .m(12)
                                    .efConstruction(100)))
                    .build());
    TypeMapping typeMapping = tmBuilder.build();

    CreateIndexRequest request = CreateIndexRequest.of(builder -> builder
            .index(indexName)
            .aliases(indexName + "_alias", new Alias.Builder().isWriteIndex(true).build())
            .mappings(typeMapping));

    response = esClient.indices().create(request);
    log.info("acknowledged: {}", response.acknowledged());
    log.info("index: {}", response.index());
    log.info("shardsAcknowledged: {}", response.shardsAcknowledged());
    flag = response.acknowledged();
} catch (IOException e) {
    // Report through the logger with the stack trace instead of printing to stderr.
    log.error("Failed to create index {}", indexName, e);
}
创建后生成的结构数据如下
6、添加到ES
float[] feature;
// 自定义属性字段数据,构建文档Map<String, Object> dataMap = req.getDataMap();// 自定义内置参数dataMap.put("_es_doc_type", "IMAGE");dataMap.put("_img_vector", feature);IndexRequest<Map> request = IndexRequest.of(i -> i.index(req.getIndexLib()).id(req.getDocId()).document(dataMap));IndexResponse response = esClient.index(request);boolean flag = response.result() == Result.Created;log.info("添加文档id={},结果={}", req.getDocId(), flag);
实际存储的数据结构如下图
7、pytorch环境依赖
cpu/linux-x86_64/native/lib/libc10.so.gz
cpu/linux-x86_64/native/lib/libtorch_cpu.so.gz
cpu/linux-x86_64/native/lib/libtorch.so.gz
cpu/linux-x86_64/native/lib/libgomp-52f2fd74.so.1.gz
cpu/osx-aarch64/native/lib/libtorch_cpu.dylib.gz
cpu/osx-aarch64/native/lib/libtorch.dylib.gz
cpu/osx-aarch64/native/lib/libc10.dylib.gz
cpu/osx-x86_64/native/lib/libtorch_cpu.dylib.gz
cpu/osx-x86_64/native/lib/libiomp5.dylib.gz
cpu/osx-x86_64/native/lib/libtorch.dylib.gz
cpu/osx-x86_64/native/lib/libc10.dylib.gz
cpu/win-x86_64/native/lib/torch.dll.gz
cpu/win-x86_64/native/lib/uv.dll.gz
cpu/win-x86_64/native/lib/torch_cpu.dll.gz
cpu/win-x86_64/native/lib/c10.dll.gz
cpu/win-x86_64/native/lib/fbgemm.dll.gz
cpu/win-x86_64/native/lib/libiomp5md.dll.gz
cpu/win-x86_64/native/lib/asmjit.dll.gz
cpu/win-x86_64/native/lib/libiompstubs5md.dll.gz
cpu-precxx11/linux-aarch64/native/lib/libc10.so.gz
cpu-precxx11/linux-aarch64/native/lib/libtorch_cpu.so.gz
cpu-precxx11/linux-aarch64/native/lib/libarm_compute-973e5a6b.so.gz
cpu-precxx11/linux-aarch64/native/lib/libopenblasp-r0-56e95da7.3.24.so.gz
cpu-precxx11/linux-aarch64/native/lib/libtorch.so.gz
cpu-precxx11/linux-aarch64/native/lib/libarm_compute_graph-6990f339.so.gz
cpu-precxx11/linux-aarch64/native/lib/libstdc%2B%2B.so.6.gz
cpu-precxx11/linux-aarch64/native/lib/libarm_compute_core-0793f69d.so.gz
cpu-precxx11/linux-aarch64/native/lib/libgfortran-b6d57c85.so.5.0.0.gz
cpu-precxx11/linux-aarch64/native/lib/libgomp-6e1a1d1b.so.1.0.0.gz
cpu-precxx11/linux-x86_64/native/lib/libgomp-a34b3233.so.1.gz
cpu-precxx11/linux-x86_64/native/lib/libc10.so.gz
cpu-precxx11/linux-x86_64/native/lib/libtorch_cpu.so.gz
cpu-precxx11/linux-x86_64/native/lib/libtorch.so.gz
cpu-precxx11/linux-x86_64/native/lib/libstdc%2B%2B.so.6.gz
cu121/linux-x86_64/native/lib/libc10_cuda.so.gz
cu121/linux-x86_64/native/lib/libcudnn.so.8.gz
cu121/linux-x86_64/native/lib/libnvfuser_codegen.so.gz
cu121/linux-x86_64/native/lib/libc10.so.gz
cu121/linux-x86_64/native/lib/libtorch_cpu.so.gz
cu121/linux-x86_64/native/lib/libcaffe2_nvrtc.so.gz
cu121/linux-x86_64/native/lib/libcudnn_adv_infer.so.8.gz
cu121/linux-x86_64/native/lib/libcudnn_cnn_train.so.8.gz
cu121/linux-x86_64/native/lib/libcudnn_ops_infer.so.8.gz
cu121/linux-x86_64/native/lib/libnvrtc-builtins-6c5639ce.so.12.1.gz
cu121/linux-x86_64/native/lib/libnvrtc-b51b459d.so.12.gz
cu121/linux-x86_64/native/lib/libtorch.so.gz
cu121/linux-x86_64/native/lib/libtorch_cuda_linalg.so.gz
cu121/linux-x86_64/native/lib/libcublas-37d11411.so.12.gz
cu121/linux-x86_64/native/lib/libtorch_cuda.so.gz
cu121/linux-x86_64/native/lib/libcudnn_adv_train.so.8.gz
cu121/linux-x86_64/native/lib/libcublasLt-f97bfc2c.so.12.gz
cu121/linux-x86_64/native/lib/libnvToolsExt-847d78f2.so.1.gz
cu121/linux-x86_64/native/lib/libcudnn_ops_train.so.8.gz
cu121/linux-x86_64/native/lib/libcudnn_cnn_infer.so.8.gz
cu121/linux-x86_64/native/lib/libgomp-52f2fd74.so.1.gz
cu121/linux-x86_64/native/lib/libcudart-9335f6a2.so.12.gz
cu121/win-x86_64/native/lib/zlibwapi.dll.gz
cu121/win-x86_64/native/lib/cudnn_ops_train64_8.dll.gz
cu121/win-x86_64/native/lib/torch.dll.gz
cu121/win-x86_64/native/lib/nvrtc-builtins64_121.dll.gz
cu121/win-x86_64/native/lib/cufftw64_11.dll.gz
cu121/win-x86_64/native/lib/cudnn_adv_infer64_8.dll.gz
cu121/win-x86_64/native/lib/nvrtc64_120_0.dll.gz
cu121/win-x86_64/native/lib/cusolverMg64_11.dll.gz
cu121/win-x86_64/native/lib/torch_cuda.dll.gz
cu121/win-x86_64/native/lib/cufft64_11.dll.gz
cu121/win-x86_64/native/lib/cublas64_12.dll.gz
cu121/win-x86_64/native/lib/cudnn64_8.dll.gz
cu121/win-x86_64/native/lib/uv.dll.gz
cu121/win-x86_64/native/lib/cudnn_cnn_train64_8.dll.gz
cu121/win-x86_64/native/lib/caffe2_nvrtc.dll.gz
cu121/win-x86_64/native/lib/torch_cpu.dll.gz
cu121/win-x86_64/native/lib/c10.dll.gz
cu121/win-x86_64/native/lib/cudnn_cnn_infer64_8.dll.gz
cu121/win-x86_64/native/lib/c10_cuda.dll.gz
cu121/win-x86_64/native/lib/cudart64_12.dll.gz
cu121/win-x86_64/native/lib/nvfuser_codegen.dll.gz
cu121/win-x86_64/native/lib/fbgemm.dll.gz
cu121/win-x86_64/native/lib/curand64_10.dll.gz
cu121/win-x86_64/native/lib/libiomp5md.dll.gz
cu121/win-x86_64/native/lib/cusolver64_11.dll.gz
cu121/win-x86_64/native/lib/cudnn_adv_train64_8.dll.gz
cu121/win-x86_64/native/lib/cublasLt64_12.dll.gz
cu121/win-x86_64/native/lib/nvToolsExt64_1.dll.gz
cu121/win-x86_64/native/lib/nvJitLink_120_0.dll.gz
cu121/win-x86_64/native/lib/cusparse64_12.dll.gz
cu121/win-x86_64/native/lib/asmjit.dll.gz
cu121/win-x86_64/native/lib/cudnn_ops_infer64_8.dll.gz
cu121/win-x86_64/native/lib/libiompstubs5md.dll.gz
cu121/win-x86_64/native/lib/cupti64_2023.1.1.dll.gz
cu121-precxx11/linux-x86_64/native/lib/libgomp-a34b3233.so.1.gz
cu121-precxx11/linux-x86_64/native/lib/libc10_cuda.so.gz
cu121-precxx11/linux-x86_64/native/lib/libcudnn.so.8.gz
cu121-precxx11/linux-x86_64/native/lib/libnvfuser_codegen.so.gz
cu121-precxx11/linux-x86_64/native/lib/libc10.so.gz
cu121-precxx11/linux-x86_64/native/lib/libtorch_cpu.so.gz
cu121-precxx11/linux-x86_64/native/lib/libcaffe2_nvrtc.so.gz
cu121-precxx11/linux-x86_64/native/lib/libcudnn_adv_infer.so.8.gz
cu121-precxx11/linux-x86_64/native/lib/libcudnn_cnn_train.so.8.gz
cu121-precxx11/linux-x86_64/native/lib/libcudnn_ops_infer.so.8.gz
cu121-precxx11/linux-x86_64/native/lib/libnvrtc-builtins-6c5639ce.so.12.1.gz
cu121-precxx11/linux-x86_64/native/lib/libnvrtc-b51b459d.so.12.gz
cu121-precxx11/linux-x86_64/native/lib/libtorch.so.gz
cu121-precxx11/linux-x86_64/native/lib/libtorch_cuda_linalg.so.gz
cu121-precxx11/linux-x86_64/native/lib/libcublas-37d11411.so.12.gz
cu121-precxx11/linux-x86_64/native/lib/libtorch_cuda.so.gz
cu121-precxx11/linux-x86_64/native/lib/libstdc%2B%2B.so.6.gz
cu121-precxx11/linux-x86_64/native/lib/libcudnn_adv_train.so.8.gz
cu121-precxx11/linux-x86_64/native/lib/libcublasLt-f97bfc2c.so.12.gz
cu121-precxx11/linux-x86_64/native/lib/libnvToolsExt-847d78f2.so.1.gz
cu121-precxx11/linux-x86_64/native/lib/libcudnn_ops_train.so.8.gz
cu121-precxx11/linux-x86_64/native/lib/libcudnn_cnn_infer.so.8.gz
cu121-precxx11/linux-x86_64/native/lib/libcudart-9335f6a2.so.12.gz