#include "KWSManager.h"

#include <QDebug>
#include <QDir>
#include <QFileInfo>
#include <QSettings>

#include <cstdlib>
#include <cstring>
#include <memory>

namespace {

// Sample rate (Hz) written into the spotter's feature config and reported to
// sherpa-onnx in acceptWaveform(); keeping one constant keeps them in sync.
constexpr int kSampleRate = 16000;

// Feature dimension written into the spotter's feature config.
constexpr int kFeatureDim = 80;

// Directory that holds the bundled zipformer KWS model, tokens and keywords.
QString defaultModelDirectory()
{
    return QDir::homePath()
           + "/.config/QSmartAssistant/Data"
           + "/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01";
}

// Returns true if `path` exists; otherwise logs `message` followed by the path
// and returns false.
bool requireFile(const QString& path, const char* message)
{
    if (QFileInfo::exists(path)) {
        return true;
    }
    qDebug() << message << path;
    return false;
}

// Locates the decoder/joiner file (`role` is "decoder" or "joiner") that sits
// next to `encoderPath`. If the encoder's file name starts with "encoder",
// the same suffix is tried first (e.g. "encoder-X.onnx" -> "decoder-X.onnx")
// so other exports of the model work; otherwise, or if that file is absent,
// the historical hard-coded name is returned.
QString companionModelPath(const QString& encoderPath, const QString& role)
{
    const QFileInfo encoderInfo(encoderPath);
    const QString directory = encoderInfo.absolutePath();
    const QString encoderName = encoderInfo.fileName();
    const QString encoderPrefix = QStringLiteral("encoder");

    if (encoderName.startsWith(encoderPrefix)) {
        const QString derived = directory + "/" + role + encoderName.mid(encoderPrefix.size());
        if (QFileInfo::exists(derived)) {
            return derived;
        }
    }
    return directory + "/" + role + "-epoch-12-avg-2-chunk-16-left-64.onnx";
}

// Heap-copies `bytes` as a NUL-terminated C string allocated with malloc, so
// it can be released with std::free (portable replacement for POSIX strdup).
// Returns nullptr if the allocation fails.
char* duplicateCString(const QByteArray& bytes)
{
    const std::size_t length = static_cast<std::size_t>(bytes.size());
    char* copy = static_cast<char*>(std::malloc(length + 1));
    if (copy) {
        std::memcpy(copy, bytes.constData(), length);
        copy[length] = '\0';
    }
    return copy;
}

// Frees a string produced by duplicateCString() and nulls the config field.
void releaseCString(const char*& field)
{
    std::free(const_cast<char*>(field)); // free(nullptr) is a no-op
    field = nullptr;
}

// Fetches the current keyword result for `stream` and converts it with
// `extract`. The result is always released, even if `extract` throws.
// Returns a null QString when either handle is null or no result is available.
template <typename Extract>
QString readKeywordResult(const SherpaOnnxKeywordSpotter* spotter,
                          const SherpaOnnxOnlineStream* stream,
                          Extract extract)
{
    if (!stream || !spotter) {
        return QString();
    }
    const std::unique_ptr<const SherpaOnnxKeywordResult, void (*)(const SherpaOnnxKeywordResult*)>
        result(SherpaOnnxGetKeywordResult(spotter, stream), &SherpaOnnxDestroyKeywordResult);
    if (!result) {
        return QString();
    }
    return extract(*result);
}

} // namespace

/// Creates an uninitialized manager; call initialize() before creating spotters.
KWSManager::KWSManager(QObject* parent)
    : QObject(parent), initialized(false)
{
    // Zero the whole C config struct so every pointer field starts as nullptr,
    // which cleanup() relies on.
    std::memset(&kwsConfig, 0, sizeof(kwsConfig));
}

/// Releases the path strings owned by the config.
KWSManager::~KWSManager()
{
    cleanup();
}

/// @return Path of the bundled encoder model.
QString KWSManager::getDefaultModelPath() const
{
    return defaultModelDirectory() + "/encoder-epoch-12-avg-2-chunk-16-left-64.onnx";
}

/// @return Path of the bundled tokens file.
QString KWSManager::getDefaultTokensPath() const
{
    return defaultModelDirectory() + "/tokens.txt";
}

/// @return Path of the bundled keywords file.
QString KWSManager::getDefaultKeywordsPath() const
{
    return defaultModelDirectory() + "/keywords.txt";
}

/// Initializes with the bundled default model files.
/// Currently forces the default paths, bypassing any user settings (the
/// original author marked this as temporary).
/// @return true on success; see the three-argument overload.
bool KWSManager::initialize()
{
    const QString forcedModelPath = getDefaultModelPath();
    const QString forcedTokensPath = getDefaultTokensPath();
    const QString forcedKeywordsPath = getDefaultKeywordsPath();

    qDebug() << "KWS初始化 - 强制使用正确的路径:";
    qDebug() << "模型:" << forcedModelPath;
    qDebug() << "词汇表:" << forcedTokensPath;
    qDebug() << "关键词:" << forcedKeywordsPath;

    return initialize(forcedModelPath, forcedTokensPath, forcedKeywordsPath);
}

/// Validates the model files and fills the sherpa-onnx keyword-spotter config.
///
/// Every required file (encoder, decoder, joiner, tokens, keywords) is checked
/// before the previous configuration is released, so a failed call leaves any
/// earlier successful configuration untouched.
///
/// @param modelPath    Encoder .onnx file. The decoder and joiner are looked up
///                     in the same directory (see companionModelPath()).
/// @param tokensPath   Tokens file.
/// @param keywordsPath Keywords file.
/// @return true if the config is ready for createKeywordSpotter().
bool KWSManager::initialize(const QString& modelPath, const QString& tokensPath, const QString& keywordsPath)
{
    qDebug() << "初始化KWS管理器";
    qDebug() << "模型路径:" << modelPath;
    qDebug() << "词汇表路径:" << tokensPath;
    qDebug() << "关键词路径:" << keywordsPath;

    if (!requireFile(modelPath, "KWS模型文件不存在:")
        || !requireFile(tokensPath, "KWS词汇表文件不存在:")
        || !requireFile(keywordsPath, "KWS关键词文件不存在:")) {
        return false;
    }

    const QString decoderPath = companionModelPath(modelPath, QStringLiteral("decoder"));
    const QString joinerPath = companionModelPath(modelPath, QStringLiteral("joiner"));

    qDebug() << "Encoder路径:" << modelPath;
    qDebug() << "Decoder路径:" << decoderPath;
    qDebug() << "Joiner路径:" << joinerPath;

    // Checked before cleanup() so a missing decoder/joiner does not wipe a
    // previously working configuration.
    if (!requireFile(decoderPath, "KWS Decoder文件不存在:")
        || !requireFile(joinerPath, "KWS Joiner文件不存在:")) {
        return false;
    }

    // Release the previous configuration's strings.
    cleanup();

    this->modelPath = modelPath;
    this->tokensPath = tokensPath;
    this->keywordsPath = keywordsPath;

    kwsConfig.feat_config.sample_rate = kSampleRate;
    kwsConfig.feat_config.feature_dim = kFeatureDim;

    // The config stores raw C-string pointers. These are heap copies owned by
    // this object (freed in cleanup()), so they stay valid for later
    // createKeywordSpotter() calls.
    kwsConfig.model_config.transducer.encoder = duplicateCString(modelPath.toUtf8());
    kwsConfig.model_config.transducer.decoder = duplicateCString(decoderPath.toUtf8());
    kwsConfig.model_config.transducer.joiner = duplicateCString(joinerPath.toUtf8());
    kwsConfig.model_config.tokens = duplicateCString(tokensPath.toUtf8());
    kwsConfig.keywords_file = duplicateCString(keywordsPath.toUtf8());

    if (!kwsConfig.model_config.transducer.encoder
        || !kwsConfig.model_config.transducer.decoder
        || !kwsConfig.model_config.transducer.joiner
        || !kwsConfig.model_config.tokens
        || !kwsConfig.keywords_file) {
        qDebug() << "KWS路径字符串内存分配失败";
        cleanup();
        return false;
    }

    qDebug() << "配置后的Encoder路径:" << kwsConfig.model_config.transducer.encoder;
    qDebug() << "配置后的Decoder路径:" << kwsConfig.model_config.transducer.decoder;
    qDebug() << "配置后的Joiner路径:" << kwsConfig.model_config.transducer.joiner;
    qDebug() << "配置后的词汇表路径:" << kwsConfig.model_config.tokens;
    qDebug() << "配置后的关键词路径:" << kwsConfig.keywords_file;

    // Tunable decoding parameters come from the "KWS" settings group.
    QSettings settings;
    settings.beginGroup("KWS");
    const float threshold = settings.value("threshold", 0.25f).toFloat();
    const int maxActivePaths = settings.value("maxActivePaths", 8).toInt();
    const int numTrailingBlanks = settings.value("numTrailingBlanks", 2).toInt();
    const float keywordsScore = settings.value("keywordsScore", 1.5f).toFloat();
    const int numThreads = settings.value("numThreads", 2).toInt();
    settings.endGroup();

    // Clamp user-supplied values into the ranges this application accepts.
    kwsConfig.max_active_paths = qBound(1, maxActivePaths, 16);
    kwsConfig.num_trailing_blanks = qBound(1, numTrailingBlanks, 5);
    kwsConfig.keywords_score = qBound(0.5f, keywordsScore, 3.0f);
    kwsConfig.keywords_threshold = qBound(0.01f, threshold, 1.0f);
    kwsConfig.model_config.num_threads = qBound(1, numThreads, 4);
    // String literals: static storage, deliberately not freed by cleanup().
    kwsConfig.model_config.provider = "cpu";
    kwsConfig.model_config.model_type = "";

    qDebug() << "KWS配置完成";
    qDebug() << "采样率:" << kwsConfig.feat_config.sample_rate;
    qDebug() << "特征维度:" << kwsConfig.feat_config.feature_dim;
    qDebug() << "关键词阈值:" << kwsConfig.keywords_threshold;

    initialized = true;
    qDebug() << "KWS管理器初始化成功";
    return true;
}

/// @return true after a successful initialize() and before cleanup().
bool KWSManager::isInitialized() const
{
    return initialized;
}

/// Creates a keyword spotter from the current config.
/// @return The spotter (release with destroyKeywordSpotter()), or nullptr if
///         not initialized or creation failed.
const SherpaOnnxKeywordSpotter* KWSManager::createKeywordSpotter()
{
    if (!initialized) {
        qDebug() << "KWS管理器未初始化,无法创建关键词检测器";
        return nullptr;
    }

    qDebug() << "创建KWS关键词检测器";
    const SherpaOnnxKeywordSpotter* spotter = SherpaOnnxCreateKeywordSpotter(&kwsConfig);
    if (!spotter) {
        qDebug() << "创建KWS关键词检测器失败";
        return nullptr;
    }

    qDebug() << "KWS关键词检测器创建成功";
    return spotter;
}

/// Destroys a spotter returned by createKeywordSpotter(); nullptr is ignored.
void KWSManager::destroyKeywordSpotter(const SherpaOnnxKeywordSpotter* spotter)
{
    if (spotter) {
        qDebug() << "销毁KWS关键词检测器";
        SherpaOnnxDestroyKeywordSpotter(spotter);
    }
}

/// Creates a keyword stream bound to `spotter`.
/// @return The stream (release with destroyKeywordStream()), or nullptr.
const SherpaOnnxOnlineStream* KWSManager::createKeywordStream(const SherpaOnnxKeywordSpotter* spotter)
{
    if (!spotter) {
        qDebug() << "关键词检测器为空,无法创建流";
        return nullptr;
    }

    qDebug() << "创建KWS关键词流";
    const SherpaOnnxOnlineStream* stream = SherpaOnnxCreateKeywordStream(spotter);
    if (!stream) {
        qDebug() << "创建KWS关键词流失败";
        return nullptr;
    }

    qDebug() << "KWS关键词流创建成功";
    return stream;
}

/// Destroys a stream returned by createKeywordStream(); nullptr is ignored.
void KWSManager::destroyKeywordStream(const SherpaOnnxOnlineStream* stream)
{
    if (stream) {
        qDebug() << "销毁KWS关键词流";
        SherpaOnnxDestroyOnlineStream(stream);
    }
}

/// Feeds `sampleCount` float samples at kSampleRate into `stream`.
/// Null handles or a non-positive count are ignored.
void KWSManager::acceptWaveform(const SherpaOnnxOnlineStream* stream, const float* samples, int sampleCount)
{
    if (!stream || !samples || sampleCount <= 0) {
        return;
    }
    SherpaOnnxOnlineStreamAcceptWaveform(stream, kSampleRate, samples, sampleCount);
}

/// @return true if `stream` has enough buffered audio to decode; false for null handles.
bool KWSManager::isReady(const SherpaOnnxOnlineStream* stream, const SherpaOnnxKeywordSpotter* spotter)
{
    if (!stream || !spotter) {
        return false;
    }
    return SherpaOnnxIsKeywordStreamReady(spotter, stream) != 0;
}

/// Runs one decoding step on `stream`; null handles are ignored.
void KWSManager::decode(const SherpaOnnxOnlineStream* stream, const SherpaOnnxKeywordSpotter* spotter)
{
    if (!stream || !spotter) {
        return;
    }
    SherpaOnnxDecodeKeywordStream(spotter, stream);
}

/// @return The result's `keyword` field as UTF-8 text, or a null QString when
///         there is no result or a handle is null.
QString KWSManager::getResult(const SherpaOnnxOnlineStream* stream, const SherpaOnnxKeywordSpotter* spotter)
{
    return readKeywordResult(spotter, stream, [](const SherpaOnnxKeywordResult& result) {
        return QString::fromUtf8(result.keyword);
    });
}

/// @return The result's `tokens` field as UTF-8 text (empty if that field is
///         null), or a null QString when there is no result or a handle is null.
QString KWSManager::getPartialText(const SherpaOnnxOnlineStream* stream, const SherpaOnnxKeywordSpotter* spotter)
{
    return readKeywordResult(spotter, stream, [](const SherpaOnnxKeywordResult& result) {
        return QString::fromUtf8(result.tokens ? result.tokens : "");
    });
}

/// Resets `stream` via SherpaOnnxResetKeywordStream; null handles are ignored.
void KWSManager::reset(const SherpaOnnxOnlineStream* stream, const SherpaOnnxKeywordSpotter* spotter)
{
    if (!stream || !spotter) {
        return;
    }
    SherpaOnnxResetKeywordStream(spotter, stream);
}

/// Frees the path strings owned by the config, nulls them and marks the
/// manager uninitialized. Safe to call repeatedly.
void KWSManager::cleanup()
{
    releaseCString(kwsConfig.model_config.transducer.encoder);
    releaseCString(kwsConfig.model_config.transducer.decoder);
    releaseCString(kwsConfig.model_config.transducer.joiner);
    releaseCString(kwsConfig.model_config.tokens);
    releaseCString(kwsConfig.keywords_file);

    initialized = false;
    qDebug() << "KWS管理器清理完成";
}