欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 财经 > 金融 > TensorFlow音频分类修复

TensorFlow音频分类修复

2024/10/25 2:18:33 来源:https://blog.csdn.net/tiantiantbtb/article/details/139754576  浏览:    关键词:TensorFlow音频分类修复

原先上传的是 wav 格式,后来发现前端生成的 wav 文件格式不完整,于是后端改为接收 mp3。其实接口对 mp3 和 wav 都可以接收。

但问题是:wav 文件格式有误时无法获取时长,也就是说实际上只能传 mp3,除非你不关心时长。

修复TensorFlow放到生产后报错问题-CSDN博客

依赖

  <!-- TensorFlow Java core API (platform-independent classes) -->
  <dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-api</artifactId>
    <version>0.4.2</version>
  </dependency>
  <!-- Same artifact with the linux-x86_64 classifier -->
  <dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-api</artifactId>
    <version>0.4.2</version>
    <classifier>linux-x86_64</classifier>
  </dependency>
  <!-- Same artifact with the windows-x86_64 classifier -->
  <dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-api</artifactId>
    <version>0.4.2</version>
    <classifier>windows-x86_64</classifier>
  </dependency>
  <!-- JLayer, used by the code below to decode MP3 frames: -->
  <!-- https://mvnrepository.com/artifact/com.googlecode.soundlibs/jlayer -->
  <dependency>
    <groupId>com.googlecode.soundlibs</groupId>
    <artifactId>jlayer</artifactId>
    <version>1.0.1.4</version>
  </dependency>

TensorFlow工具类

package com.ruoyi.webapp.tensorflow;import org.springframework.stereotype.Component;
import org.springframework.web.multipart.MultipartFile;
import javax.sound.sampled.*;
import java.io.*;
import java.nio.file.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.tensorflow.*;
import org.tensorflow.ndarray.*;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.types.TFloat32;
import com.google.protobuf.InvalidProtocolBufferException;import javazoom.jl.decoder.Bitstream;
import javazoom.jl.decoder.BitstreamException;
import javazoom.jl.decoder.Decoder;
import javazoom.jl.decoder.JavaLayerException;
import javazoom.jl.decoder.SampleBuffer;
import javazoom.jl.decoder.Header;@Component
public class YamnetUtils3 {private static final int SAMPLE_RATE = 16000;private static final int FRAME_SIZE_IN_MS = 96;private static final int HOP_SIZE_IN_MS = 48;//private static final String MODEL_PATH = "C:\\Users\\user\\Downloads\\archive";private static final String MODEL_PATH = "/usr/local/develop/archive"; // TensorFlow 模型路径private static Map<String, SignatureDef> signatureDefMap;static {try (SavedModelBundle savedModelBundle = SavedModelBundle.load(MODEL_PATH, "serve")) {signatureDefMap = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef().toByteArray()).getSignatureDefMap();} catch (InvalidProtocolBufferException e) {e.printStackTrace();}}private static final SignatureDef modelSig = signatureDefMap.get("serving_default");private static final String inputTensorName = modelSig.getInputsMap().get("waveform").getName();private static final String outputTensorName = modelSig.getOutputsMap().get("output_0").getName();private static Map<String, String> map = new ConcurrentHashMap<>();static {//String csvFile = "C:\\Users\\user\\Downloads\\archive\\assets\\yamnet_class_map.csv";String csvFile = "/usr/local/develop/archive/assets/yamnet_class_map.csv";try {List<String> lines = Files.readAllLines(Paths.get(csvFile));for (String line : lines) {String[] values = line.split(",");map.put(values[0], values[2]);}} catch (IOException e) {e.printStackTrace();}}public String classifyAudio(MultipartFile file) throws IOException, UnsupportedAudioFileException {// Convert the MP3 file to a supported format (WAV)File wavFile = convertMp3ToWav(file);// Process the converted filereturn yamnetPare(wavFile);}private File convertMp3ToWav(MultipartFile file) throws IOException {File tempFile = File.createTempFile("temp", ".wav");File mp3File = File.createTempFile("temp", ".mp3");file.transferTo(mp3File);try (FileInputStream mp3Stream = new FileInputStream(mp3File);FileOutputStream wavStream = new FileOutputStream(tempFile)) {Bitstream bitstream = new Bitstream(mp3Stream);Decoder 
decoder = new Decoder();// Write WAV headerwriteWavHeader(wavStream, 0, 1, SAMPLE_RATE, 16);Header header;while ((header = bitstream.readFrame()) != null) {SampleBuffer output = (SampleBuffer) decoder.decodeFrame(header, bitstream);short[] samples = output.getBuffer();for (short sample : samples) {wavStream.write(shortToBytes(sample));}bitstream.closeFrame();}// Update WAV header with data sizeupdateWavHeader(tempFile);} catch (Exception e) {throw new IOException("Failed to convert MP3 to WAV", e);}return tempFile;}private void writeWavHeader(OutputStream out, long totalAudioLen, int channels, long sampleRate, int bitDepth) throws IOException {long totalDataLen = totalAudioLen + 36;long byteRate = sampleRate * channels * bitDepth / 8;byte[] header = new byte[44];header[0] = 'R';  // RIFF/WAVE headerheader[1] = 'I';header[2] = 'F';header[3] = 'F';header[4] = (byte) (totalDataLen & 0xff);header[5] = (byte) ((totalDataLen >> 8) & 0xff);header[6] = (byte) ((totalDataLen >> 16) & 0xff);header[7] = (byte) ((totalDataLen >> 24) & 0xff);header[8] = 'W';header[9] = 'A';header[10] = 'V';header[11] = 'E';header[12] = 'f';  // 'fmt ' chunkheader[13] = 'm';header[14] = 't';header[15] = ' ';header[16] = 16;  // 4 bytes: size of 'fmt ' chunkheader[17] = 0;header[18] = 0;header[19] = 0;header[20] = 1;  // format = 1header[21] = 0;header[22] = (byte) channels;header[23] = 0;header[24] = (byte) (sampleRate & 0xff);header[25] = (byte) ((sampleRate >> 8) & 0xff);header[26] = (byte) ((sampleRate >> 16) & 0xff);header[27] = (byte) ((sampleRate >> 24) & 0xff);header[28] = (byte) (byteRate & 0xff);header[29] = (byte) ((byteRate >> 8) & 0xff);header[30] = (byte) ((byteRate >> 16) & 0xff);header[31] = (byte) ((byteRate >> 24) & 0xff);header[32] = (byte) (2 * 16 / 8);  // block alignheader[33] = 0;header[34] = (byte) bitDepth;  // bits per sampleheader[35] = 0;header[36] = 'd';header[37] = 'a';header[38] = 't';header[39] = 'a';header[40] = (byte) (totalAudioLen & 0xff);header[41] = (byte) 
((totalAudioLen >> 8) & 0xff);header[42] = (byte) ((totalAudioLen >> 16) & 0xff);header[43] = (byte) ((totalAudioLen >> 24) & 0xff);out.write(header, 0, 44);}private void updateWavHeader(File wavFile) throws IOException {RandomAccessFile wavRAF = new RandomAccessFile(wavFile, "rw");wavRAF.seek(4);wavRAF.write(intToBytes((int) (wavRAF.length() - 8)));wavRAF.seek(40);wavRAF.write(intToBytes((int) (wavRAF.length() - 44)));wavRAF.close();}private byte[] intToBytes(int value) {return new byte[]{(byte) (value & 0xFF),(byte) ((value >> 8) & 0xFF),(byte) ((value >> 16) & 0xFF),(byte) ((value >> 24) & 0xFF)};}private byte[] shortToBytes(short value) {return new byte[]{(byte) (value & 0xFF),(byte) ((value >> 8) & 0xFF)};}private FloatNdArray processAudio(File file) throws IOException, UnsupportedAudioFileException {try (AudioInputStream audioStream = AudioSystem.getAudioInputStream(file)) {AudioFormat format = audioStream.getFormat();if (format.getSampleRate() != SAMPLE_RATE || format.getChannels() != 1) {System.out.println("Warning: Audio must be 16kHz mono. 
Consider preprocessing.");}int frameSize = (int) (SAMPLE_RATE * FRAME_SIZE_IN_MS / 1000);int hopSize = (int) (SAMPLE_RATE * HOP_SIZE_IN_MS / 1000);byte[] buffer = new byte[frameSize * format.getFrameSize()];short[] audioSamples = new short[frameSize];List<Float> floatList = new ArrayList<>();while (true) {int bytesRead = audioStream.read(buffer);if (bytesRead == -1) {break;}for (int i = 0; i < bytesRead / format.getFrameSize(); i++) {audioSamples[i] = (short) ((buffer[i * 2] & 0xFF) | (buffer[i * 2 + 1] << 8));}float[] floats = normalizeAudio(audioSamples);for (float aFloat : floats) {floatList.add(aFloat);}System.arraycopy(audioSamples, hopSize, audioSamples, 0, frameSize - hopSize);}float[] floatArray = new float[floatList.size()];for (int i = 0; i < floatList.size(); i++) {floatArray[i] = floatList.get(i);}return StdArrays.ndCopyOf(floatArray);}}private float[] normalizeAudio(short[] frame) {float[] normalizedFrame = new float[frame.length];for (int i = 0; i < frame.length; i++) {normalizedFrame[i] = frame[i] / 32768f;}return normalizedFrame;}private String yamnetPare(File file) throws IOException, UnsupportedAudioFileException {FloatNdArray floatNdArray = processAudio(file);TFloat32 tFloat32 = TFloat32.tensorOf(floatNdArray);try (SavedModelBundle savedModelBundle = SavedModelBundle.load(MODEL_PATH, "serve")) {try (Session session = savedModelBundle.session()) {List<Tensor> run = session.runner().feed(inputTensorName, tFloat32).fetch(outputTensorName).run();Tensor tensor = run.get(0);Shape shape = tensor.shape();System.out.println(shape + "--------------------------------------------------");String key = String.valueOf(shape.asArray()[0]);String value = map.get(key);return value;}}}
}

文件接口

 //记录文件@PostMapping("/reportUpload")//@RequireVip//睡眠监测时 文件上传接口public AjaxResult reportUpload(MultipartFile file) throws IOException {try {//获取文件名String originalFilename = file.getOriginalFilename();//获取时长秒//String wavDuration = getWavDuration(file);//String wavDuration = getWavDuration3(file);//String wavDuration = getWavDuration55(file);String filePath = RuoYiConfig.getUploadPath();///usr/local/nginx/html/upload/upload//上传并返回新文件路径String fileName = FileUploadUtils.upload(filePath, file);//获取音频类型//String s = yamnetUtils2.yamnetPare(file);String s = yamnetUtils3.classifyAudio(file);//获取时长秒//String wavDuration =getWavDuration(fileName);getMp3DurationString wavDuration =getMp3Duration(fileName);if(!StringUtils.isEmpty(s)){if (!s.equals("Speech") && !s.equals("Snoring") && !s.equals("Cough")) {s = "Other";}TReportFile tReportFile=new TReportFile();tReportFile.setUid(getLoginUser().getUserId());tReportFile.setFileType(s);tReportFile.setLengths(wavDuration);tReportFile.setFilePath(fileName);//tReportFile.setCreateTime(new Date());tReportFileService.insertTReportFile(tReportFile);}return success();}catch (Exception e){e.printStackTrace();return error();}}

 //获取wav文件时长private String getWavDuration(String relativeFilePath) {String basePath = "/usr/local/nginx/html/upload";if (relativeFilePath.startsWith("/profile")) {relativeFilePath = relativeFilePath.substring(8);}String fullPath = basePath + relativeFilePath;File file = new File(fullPath);if (!file.exists()) {return "File not found";}try (AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(file)) {AudioFileFormat fileFormat = AudioSystem.getAudioFileFormat(file);long frameLength = audioInputStream.getFrameLength();float frameRate = fileFormat.getFormat().getFrameRate();float durationInSeconds = frameLength / frameRate;return Math.round(durationInSeconds) + "s";} catch (UnsupportedAudioFileException | IOException e) {e.printStackTrace();return "Error";}}//获取Mp3文件时长public static String getMp3Duration(String relativeFilePath) {String basePath = "/usr/local/nginx/html/upload";if (relativeFilePath.startsWith("/profile")) {relativeFilePath = relativeFilePath.substring(8);}String fullPath = basePath + relativeFilePath;File file = new File(fullPath);if (!file.exists()) {return "File not found";}try (FileInputStream fileInputStream = new FileInputStream(file)) {Bitstream bitstream = new Bitstream(fileInputStream);Header header;int totalFrames = 0;float totalDuration = 0;while ((header = bitstream.readFrame()) != null) {totalDuration += header.ms_per_frame();totalFrames++;bitstream.closeFrame();}float durationInSeconds = totalDuration / 1000.0f;return Math.round(durationInSeconds) + "s";} catch (IOException | BitstreamException e) {e.printStackTrace();return "Error";}}

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com