Enhance text classification pipeline: add device support, improve error handling, and refine message processing logic
Browse files- public/workers/text-classification.js +47 -37
- src/components/ModelLoader.tsx +7 -1
- src/components/TextClassification.tsx +33 -32
- src/types.ts +7 -1
public/workers/text-classification.js
CHANGED
|
@@ -9,7 +9,7 @@ class MyTextClassificationPipeline {
|
|
| 9 |
this.instance = pipeline(
|
| 10 |
this.task,
|
| 11 |
model,
|
| 12 |
-
{ dtype, progress_callback },
|
| 13 |
)
|
| 14 |
return this.instance
|
| 15 |
}
|
|
@@ -17,49 +17,59 @@ class MyTextClassificationPipeline {
|
|
| 17 |
|
| 18 |
// Listen for messages from the main thread
|
| 19 |
self.addEventListener('message', async (event) => {
|
| 20 |
-
|
|
|
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
}
|
| 29 |
-
|
| 30 |
-
// Retrieve the pipeline. This will download the model if not already cached.
|
| 31 |
-
const classifier = await MyTextClassificationPipeline.getInstance(
|
| 32 |
-
model,
|
| 33 |
-
dtype,
|
| 34 |
-
(x) => {
|
| 35 |
-
self.postMessage({ status: 'loading', output: x })
|
| 36 |
}
|
| 37 |
-
)
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
| 47 |
return
|
| 48 |
}
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
if (
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
}
|
|
|
|
| 62 |
}
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
}
|
| 65 |
})
|
|
|
|
| 9 |
this.instance = pipeline(
|
| 10 |
this.task,
|
| 11 |
model,
|
| 12 |
+
{ dtype, device: "webgpu", progress_callback },
|
| 13 |
)
|
| 14 |
return this.instance
|
| 15 |
}
|
|
|
|
| 17 |
|
| 18 |
// Listen for messages from the main thread
|
| 19 |
self.addEventListener('message', async (event) => {
|
| 20 |
+
try {
|
| 21 |
+
const { type, model, dtype, text } = event.data
|
| 22 |
|
| 23 |
+
if (!model) {
|
| 24 |
+
self.postMessage({
|
| 25 |
+
status: 'error',
|
| 26 |
+
output: 'No model provided'
|
| 27 |
+
})
|
| 28 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
}
|
|
|
|
| 30 |
|
| 31 |
+
// Retrieve the pipeline. This will download the model if not already cached.
|
| 32 |
+
const classifier = await MyTextClassificationPipeline.getInstance(
|
| 33 |
+
model,
|
| 34 |
+
dtype,
|
| 35 |
+
(x) => {
|
| 36 |
+
self.postMessage({ status: 'loading', output: x })
|
| 37 |
+
}
|
| 38 |
+
)
|
| 39 |
|
| 40 |
+
if (type === 'load') {
|
| 41 |
+
self.postMessage({
|
| 42 |
+
status: 'ready',
|
| 43 |
+
output: `Model ${model}, dtype ${dtype} loaded`
|
| 44 |
+
})
|
| 45 |
return
|
| 46 |
}
|
| 47 |
+
|
| 48 |
+
if (type === 'classify') {
|
| 49 |
+
if (!text) {
|
| 50 |
+
self.postMessage({ status: 'ready' }) // Nothing to process
|
| 51 |
+
return
|
| 52 |
+
}
|
| 53 |
+
const split = text.split('\n')
|
| 54 |
+
for (const line of split) {
|
| 55 |
+
if (line.trim()) {
|
| 56 |
+
const output = await classifier(line)
|
| 57 |
+
self.postMessage({
|
| 58 |
+
status: 'output',
|
| 59 |
+
output: {
|
| 60 |
+
sequence: line,
|
| 61 |
+
labels: [output[0].label],
|
| 62 |
+
scores: [output[0].score]
|
| 63 |
+
}
|
| 64 |
+
})
|
| 65 |
+
}
|
| 66 |
}
|
| 67 |
+
self.postMessage({ status: 'ready' })
|
| 68 |
}
|
| 69 |
+
} catch (error) {
|
| 70 |
+
self.postMessage({
|
| 71 |
+
status: 'error',
|
| 72 |
+
output: error.message || 'An error occurred during processing'
|
| 73 |
+
})
|
| 74 |
}
|
| 75 |
})
|
src/components/ModelLoader.tsx
CHANGED
|
@@ -21,6 +21,9 @@ const ModelLoader = () => {
|
|
| 21 |
setHasBeenLoaded
|
| 22 |
} = useModel()
|
| 23 |
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
useEffect(() => {
|
| 26 |
if (!modelInfo) return
|
|
@@ -43,6 +46,8 @@ const ModelLoader = () => {
|
|
| 43 |
setHasBeenLoaded(false)
|
| 44 |
}, [modelInfo, setSelectedQuantization, setHasBeenLoaded])
|
| 45 |
|
|
|
|
|
|
|
| 46 |
useEffect(() => {
|
| 47 |
if (!modelInfo) return
|
| 48 |
|
|
@@ -61,6 +66,7 @@ const ModelLoader = () => {
|
|
| 61 |
const { status, output } = e.data
|
| 62 |
if (status === 'ready') {
|
| 63 |
setStatus('ready')
|
|
|
|
| 64 |
setHasBeenLoaded(true)
|
| 65 |
} else if (status === 'loading' && output && !hasBeenLoaded) {
|
| 66 |
setStatus('loading')
|
|
@@ -156,7 +162,7 @@ const ModelLoader = () => {
|
|
| 156 |
<div className="flex justify-center">
|
| 157 |
<button
|
| 158 |
className="w-32 py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm inline-flex items-center text-center justify-center space-x-2"
|
| 159 |
-
disabled={hasBeenLoaded}
|
| 160 |
onClick={loadModel}
|
| 161 |
>
|
| 162 |
{status === 'loading' && !hasBeenLoaded ? (
|
|
|
|
| 21 |
setHasBeenLoaded
|
| 22 |
} = useModel()
|
| 23 |
|
| 24 |
+
useEffect(() => {
|
| 25 |
+
setHasBeenLoaded(false)
|
| 26 |
+
}, [selectedQuantization])
|
| 27 |
|
| 28 |
useEffect(() => {
|
| 29 |
if (!modelInfo) return
|
|
|
|
| 46 |
setHasBeenLoaded(false)
|
| 47 |
}, [modelInfo, setSelectedQuantization, setHasBeenLoaded])
|
| 48 |
|
| 49 |
+
|
| 50 |
+
|
| 51 |
useEffect(() => {
|
| 52 |
if (!modelInfo) return
|
| 53 |
|
|
|
|
| 66 |
const { status, output } = e.data
|
| 67 |
if (status === 'ready') {
|
| 68 |
setStatus('ready')
|
| 69 |
+
if (e.data.output) console.log(e.data.output)
|
| 70 |
setHasBeenLoaded(true)
|
| 71 |
} else if (status === 'loading' && output && !hasBeenLoaded) {
|
| 72 |
setStatus('loading')
|
|
|
|
| 162 |
<div className="flex justify-center">
|
| 163 |
<button
|
| 164 |
className="w-32 py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm inline-flex items-center text-center justify-center space-x-2"
|
| 165 |
+
disabled={hasBeenLoaded || status === 'loading'}
|
| 166 |
onClick={loadModel}
|
| 167 |
>
|
| 168 |
{status === 'loading' && !hasBeenLoaded ? (
|
src/components/TextClassification.tsx
CHANGED
|
@@ -24,8 +24,6 @@ function TextClassification() {
|
|
| 24 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
| 25 |
const { activeWorker, status, setStatus, modelInfo, results, setResults, hasBeenLoaded} = useModel()
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
const classify = useCallback(() => {
|
| 30 |
if (!modelInfo || !activeWorker) {
|
| 31 |
console.error('Model info or worker is not available')
|
|
@@ -48,46 +46,49 @@ function TextClassification() {
|
|
| 48 |
|
| 49 |
return (
|
| 50 |
<div className="flex flex-col h-[60vh] max-h-[100vh] w-full p-4">
|
| 51 |
-
<h1 className="text-2xl font-bold mb-4">Text Classification</h1>
|
| 52 |
|
| 53 |
-
<div className="flex flex-col lg:flex-row gap-4 h-
|
| 54 |
{/* Input Section */}
|
| 55 |
-
<div className="flex flex-col w-full lg:w-1/2">
|
| 56 |
-
<label className="text-lg font-medium mb-2">Input Text:</label>
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
| 81 |
</div>
|
| 82 |
</div>
|
| 83 |
|
| 84 |
{/* Results Section */}
|
| 85 |
-
<div className="flex flex-col w-full lg:w-1/2">
|
| 86 |
-
<label className="text-lg font-medium mb-2">
|
| 87 |
Classification Results ({results.length}):
|
| 88 |
</label>
|
| 89 |
|
| 90 |
-
<div className="border border-gray-300 rounded p-3 flex-
|
| 91 |
{results.length === 0 ? (
|
| 92 |
<div className="text-gray-500 text-center py-8">
|
| 93 |
No results yet. Click "Classify Text" to analyze your input.
|
|
|
|
| 24 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
| 25 |
const { activeWorker, status, setStatus, modelInfo, results, setResults, hasBeenLoaded} = useModel()
|
| 26 |
|
|
|
|
|
|
|
| 27 |
const classify = useCallback(() => {
|
| 28 |
if (!modelInfo || !activeWorker) {
|
| 29 |
console.error('Model info or worker is not available')
|
|
|
|
| 46 |
|
| 47 |
return (
|
| 48 |
<div className="flex flex-col h-[60vh] max-h-[100vh] w-full p-4">
|
| 49 |
+
<h1 className="text-2xl font-bold mb-4 flex-shrink-0">Text Classification</h1>
|
| 50 |
|
| 51 |
+
<div className="flex flex-col lg:flex-row gap-4 flex-1 min-h-0">
|
| 52 |
{/* Input Section */}
|
| 53 |
+
<div className="flex flex-col w-full lg:w-1/2 min-h-0">
|
| 54 |
+
<label className="text-lg font-medium mb-2 flex-shrink-0">Input Text:</label>
|
| 55 |
+
|
| 56 |
+
<div className="flex flex-col flex-1 min-h-0">
|
| 57 |
+
<textarea
|
| 58 |
+
className="border border-gray-300 rounded p-3 flex-1 resize-none min-h-[200px]"
|
| 59 |
+
value={text}
|
| 60 |
+
onChange={(e) => setText(e.target.value)}
|
| 61 |
+
placeholder="Enter text to classify (one per line)..."
|
| 62 |
+
/>
|
| 63 |
|
| 64 |
+
<div className="flex gap-2 mt-4 flex-shrink-0">
|
| 65 |
+
<button
|
| 66 |
+
className="flex-1 py-2 px-4 bg-blue-500 hover:bg-blue-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
|
| 67 |
+
disabled={busy}
|
| 68 |
+
onClick={classify}
|
| 69 |
+
>
|
| 70 |
+
{hasBeenLoaded ? !busy
|
| 71 |
+
? 'Classify Text'
|
| 72 |
+
: 'Processing...'
|
| 73 |
+
: 'Load model first'}
|
| 74 |
+
</button>
|
| 75 |
+
<button
|
| 76 |
+
className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
|
| 77 |
+
onClick={handleClear}
|
| 78 |
+
>
|
| 79 |
+
Clear Results
|
| 80 |
+
</button>
|
| 81 |
+
</div>
|
| 82 |
</div>
|
| 83 |
</div>
|
| 84 |
|
| 85 |
{/* Results Section */}
|
| 86 |
+
<div className="flex flex-col w-full lg:w-1/2 min-h-0">
|
| 87 |
+
<label className="text-lg font-medium mb-2 flex-shrink-0">
|
| 88 |
Classification Results ({results.length}):
|
| 89 |
</label>
|
| 90 |
|
| 91 |
+
<div className="border border-gray-300 rounded p-3 flex-1 overflow-y-auto min-h-[200px]">
|
| 92 |
{results.length === 0 ? (
|
| 93 |
<div className="text-gray-500 text-center py-8">
|
| 94 |
No results yet. Click "Classify Text" to analyze your input.
|
src/types.ts
CHANGED
|
@@ -9,7 +9,13 @@ export interface ClassificationOutput {
|
|
| 9 |
scores: number[]
|
| 10 |
}
|
| 11 |
|
| 12 |
-
export type WorkerStatus =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
export interface WorkerMessage {
|
| 15 |
status: WorkerStatus
|
|
|
|
| 9 |
scores: number[]
|
| 10 |
}
|
| 11 |
|
| 12 |
+
export type WorkerStatus =
|
| 13 |
+
| 'initiate'
|
| 14 |
+
| 'ready'
|
| 15 |
+
| 'output'
|
| 16 |
+
| 'loading'
|
| 17 |
+
| 'error'
|
| 18 |
+
| 'disposed'
|
| 19 |
|
| 20 |
export interface WorkerMessage {
|
| 21 |
status: WorkerStatus
|