Skip to content

Commit d325a49

Browse files
committed
use phoneme matching
1 parent 038ae63 commit d325a49

File tree

7 files changed

+1078
-256
lines changed

7 files changed

+1078
-256
lines changed

Sources/FluidAudio/ASR/AsrManager.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ public final class AsrManager {
3333
}
3434
#endif
3535

36+
/// Custom vocabulary context for context biasing
37+
private var customVocabulary: CustomVocabularyContext?
38+
3639
// TODO:: the decoder state should be moved higher up in the API interface
3740
internal var microphoneDecoderState: TdtDecoderState
3841
internal var systemDecoderState: TdtDecoderState
@@ -92,6 +95,25 @@ public final class AsrManager {
9295
logger.info("AsrManager initialized successfully with provided models")
9396
}
9497

98+
/// Update custom vocabulary for context biasing without reinitializing ASR
99+
/// - Parameter vocabulary: New custom vocabulary context, or nil to disable context biasing
100+
public func setCustomVocabulary(_ vocabulary: CustomVocabularyContext?) {
101+
self.customVocabulary = vocabulary
102+
if let vocab = vocabulary {
103+
logger.info(
104+
"Custom vocabulary updated: \(vocab.terms.count) terms, "
105+
+ "thresholds: similarity=\(String(format: "%.2f", vocab.minSimilarity)), "
106+
+ "combined=\(String(format: "%.2f", vocab.minCombinedConfidence))")
107+
} else {
108+
logger.info("Custom vocabulary disabled")
109+
}
110+
}
111+
112+
/// Get current custom vocabulary
113+
public func getCustomVocabulary() -> CustomVocabularyContext? {
114+
return customVocabulary
115+
}
116+
95117
private func createFeatureProvider(
96118
features: [(name: String, array: MLMultiArray)]
97119
) throws

Sources/FluidAudio/ASR/ContextBiasing/CtcKeywordSpotter.swift

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,15 @@ public struct CtcKeywordSpotter {
141141
}
142142
}
143143
print(" frame[\(frameIdx)]: \(tokenLogProbs.joined(separator: ", "))", to: &standardError)
144+
145+
// Show top 10 most likely tokens at this frame
146+
let topK = 10
147+
let sortedIndices = frame.enumerated()
148+
.sorted { $0.element > $1.element }
149+
.prefix(topK)
150+
print(" top-\(topK) tokens: ", terminator: "", to: &standardError)
151+
let topTokens = sortedIndices.map { "id\($0.offset)=\(String(format: "%.4f", $0.element))" }
152+
print(topTokens.joined(separator: ", "), to: &standardError)
144153
}
145154
}
146155
}
@@ -451,6 +460,9 @@ public struct CtcKeywordSpotter {
451460

452461
/// Dynamic programming keyword alignment, ported from
453462
/// `NeMo/scripts/asr_context_biasing/ctc_word_spotter.py:ctc_word_spot`.
463+
// Wildcard token ID: -1 represents "*" that matches anything at zero cost
464+
private static let WILDCARD_TOKEN_ID = -1
465+
454466
func ctcWordSpot(
455467
logProbs: [[Float]],
456468
keywordTokens: [Int]
@@ -483,6 +495,18 @@ public struct CtcKeywordSpotter {
483495
for n in 1...N {
484496
let tokenId = keywordTokens[n - 1]
485497

498+
// Wildcard token: matches any symbol (including blank) at zero cost
499+
if tokenId == Self.WILDCARD_TOKEN_ID {
500+
// Wildcard can skip this frame at zero cost
501+
let wildcardSkip = dp[t - 1][n - 1] // Move to next token
502+
let wildcardStay = dp[t - 1][n] // Stay on wildcard
503+
504+
let wildcardScore = max(wildcardSkip, wildcardStay)
505+
dp[t][n] = wildcardScore
506+
backtrackTime[t][n] = wildcardScore == wildcardSkip ? t - 1 : backtrackTime[t - 1][n]
507+
continue
508+
}
509+
486510
if tokenId < 0 || tokenId >= frame.count {
487511
continue
488512
}
@@ -522,7 +546,10 @@ public struct CtcKeywordSpotter {
522546
}
523547

524548
let bestStart = backtrackTime[bestEnd][N]
525-
let normalizedScore = bestScore / Float(N)
549+
550+
// Normalize score only by non-wildcard tokens
551+
let nonWildcardCount = keywordTokens.filter { $0 != Self.WILDCARD_TOKEN_ID }.count
552+
let normalizedScore = nonWildcardCount > 0 ? bestScore / Float(nonWildcardCount) : bestScore
526553

527554
return (normalizedScore, bestStart, bestEnd)
528555
}

Sources/FluidAudio/ASR/ContextBiasing/CustomVocabularyContext.swift

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import Foundation
2+
import OSLog
23

34
/// A single custom vocabulary entry.
45
public struct CustomVocabularyTerm: Codable, Sendable {
@@ -21,6 +22,11 @@ public struct CustomVocabularyConfig: Codable, Sendable {
2122
public let depthScaling: Float?
2223
public let scorePerPhrase: Float?
2324
public let terms: [CustomVocabularyTerm]
25+
26+
// CTC keyword boosting confidence thresholds
27+
public let minCtcScore: Float?
28+
public let minSimilarity: Float?
29+
public let minCombinedConfidence: Float?
2430
}
2531

2632
/// Runtime context used by the decoder biasing system.
@@ -31,35 +37,105 @@ public struct CustomVocabularyContext: Sendable {
3137
public let depthScaling: Float
3238
public let scorePerPhrase: Float
3339

40+
// CTC keyword boosting confidence thresholds
41+
public let minCtcScore: Float
42+
public let minSimilarity: Float
43+
public let minCombinedConfidence: Float
44+
3445
public init(
3546
terms: [CustomVocabularyTerm],
3647
alpha: Float = 0.5,
3748
contextScore: Float = 1.2,
3849
depthScaling: Float = 2.0,
39-
scorePerPhrase: Float = 0.0
50+
scorePerPhrase: Float = 0.0,
51+
minCtcScore: Float = -10.0,
52+
minSimilarity: Float = 0.50,
53+
minCombinedConfidence: Float = 0.54
4054
) {
4155
self.terms = terms
4256
self.alpha = alpha
4357
self.contextScore = contextScore
4458
self.depthScaling = depthScaling
4559
self.scorePerPhrase = scorePerPhrase
60+
self.minCtcScore = minCtcScore
61+
self.minSimilarity = minSimilarity
62+
self.minCombinedConfidence = minCombinedConfidence
4663
}
4764

4865
/// Load a custom vocabulary JSON file produced by the analysis tooling.
4966
public static func load(from url: URL) throws -> CustomVocabularyContext {
67+
let logger = Logger(subsystem: "com.fluidaudio", category: "CustomVocabulary")
5068
let data = try Data(contentsOf: url)
5169
let config = try JSONDecoder().decode(CustomVocabularyConfig.self, from: data)
5270

5371
let alpha = config.alpha ?? 0.5
5472
let contextScore = config.contextScore ?? 1.2
5573
let depthScaling = config.depthScaling ?? 2.0
5674
let scorePerPhrase = config.scorePerPhrase ?? 0.0
75+
let minCtcScore = config.minCtcScore ?? -10.0
76+
let minSimilarity = config.minSimilarity ?? 0.50
77+
let minCombinedConfidence = config.minCombinedConfidence ?? 0.54
78+
79+
// Validate and normalize vocabulary terms
80+
var validatedTerms: [CustomVocabularyTerm] = []
81+
for term in config.terms {
82+
let (sanitized, warnings) = sanitizeVocabularyTerm(term.text)
83+
84+
if !warnings.isEmpty {
85+
logger.warning("Term '\(term.text)': \(warnings.joined(separator: ", "))")
86+
}
87+
88+
// Skip empty terms after sanitization
89+
guard !sanitized.isEmpty else {
90+
logger.warning("Term '\(term.text)' is empty after sanitization, skipping")
91+
continue
92+
}
93+
94+
validatedTerms.append(term)
95+
}
96+
5797
return CustomVocabularyContext(
58-
terms: config.terms,
98+
terms: validatedTerms,
5999
alpha: alpha,
60100
contextScore: contextScore,
61101
depthScaling: depthScaling,
62-
scorePerPhrase: scorePerPhrase
102+
scorePerPhrase: scorePerPhrase,
103+
minCtcScore: minCtcScore,
104+
minSimilarity: minSimilarity,
105+
minCombinedConfidence: minCombinedConfidence
63106
)
64107
}
108+
109+
/// Sanitize a vocabulary term and return warnings about potential issues.
110+
private static func sanitizeVocabularyTerm(_ text: String) -> (sanitized: String, warnings: [String]) {
111+
var warnings: [String] = []
112+
var result = text
113+
114+
// 1. Check for control characters
115+
if result.rangeOfCharacter(from: .controlCharacters) != nil {
116+
warnings.append("contains control characters")
117+
result = result.filter { !$0.isNewline && !$0.isWhitespace || $0 == " " }
118+
}
119+
120+
// 2. Check for diacritics (informational, not blocking)
121+
if result.folding(options: .diacriticInsensitive, locale: nil) != result {
122+
warnings.append("contains diacritics - consider adding ASCII alias")
123+
}
124+
125+
// 3. Check for numbers (informational)
126+
if result.rangeOfCharacter(from: .decimalDigits) != nil {
127+
warnings.append("contains numbers")
128+
}
129+
130+
// 4. Check for unusual characters (not letters, spaces, hyphens, apostrophes)
131+
let allowedChars = CharacterSet.letters
132+
.union(.whitespaces)
133+
.union(CharacterSet(charactersIn: "-'"))
134+
135+
if result.rangeOfCharacter(from: allowedChars.inverted) != nil {
136+
warnings.append("contains unusual characters")
137+
}
138+
139+
return (result, warnings)
140+
}
65141
}

0 commit comments

Comments
 (0)