improve text-classification components
Browse files- public/workers/text-classification.js +4 -4
- src/components/ModelCode.tsx +3 -1
- src/components/PipelineLayout.tsx +4 -0
- src/components/PipelineSelector.tsx +3 -3
- src/components/Sidebar.tsx +2 -0
- src/components/pipelines/TextClassification.tsx +51 -14
- src/components/pipelines/TextClassificationConfig.tsx +51 -0
- src/contexts/TextClassificationContext.tsx +47 -0
- src/types.ts +3 -0
public/workers/text-classification.js
CHANGED
|
@@ -41,7 +41,7 @@ class MyTextClassificationPipeline {
|
|
| 41 |
// Listen for messages from the main thread
|
| 42 |
self.addEventListener('message', async (event) => {
|
| 43 |
try {
|
| 44 |
-
const { type, model, dtype, text } = event.data
|
| 45 |
|
| 46 |
if (!model) {
|
| 47 |
self.postMessage({
|
|
@@ -76,13 +76,13 @@ self.addEventListener('message', async (event) => {
|
|
| 76 |
const split = text.split('\n')
|
| 77 |
for (const line of split) {
|
| 78 |
if (line.trim()) {
|
| 79 |
-
const output = await classifier(line)
|
| 80 |
self.postMessage({
|
| 81 |
status: 'output',
|
| 82 |
output: {
|
| 83 |
sequence: line,
|
| 84 |
-
labels:
|
| 85 |
-
scores:
|
| 86 |
}
|
| 87 |
})
|
| 88 |
}
|
|
|
|
| 41 |
// Listen for messages from the main thread
|
| 42 |
self.addEventListener('message', async (event) => {
|
| 43 |
try {
|
| 44 |
+
const { type, model, dtype, text, config } = event.data
|
| 45 |
|
| 46 |
if (!model) {
|
| 47 |
self.postMessage({
|
|
|
|
| 76 |
const split = text.split('\n')
|
| 77 |
for (const line of split) {
|
| 78 |
if (line.trim()) {
|
| 79 |
+
const output = await classifier(line, config)
|
| 80 |
self.postMessage({
|
| 81 |
status: 'output',
|
| 82 |
output: {
|
| 83 |
sequence: line,
|
| 84 |
+
labels: output.map((item) => item.label),
|
| 85 |
+
scores: output.map((item) => item.score)
|
| 86 |
}
|
| 87 |
})
|
| 88 |
}
|
src/components/ModelCode.tsx
CHANGED
|
@@ -38,7 +38,9 @@ const ModelCode = ({ isCodeModalOpen, setIsCodeModalOpen }: ModelCodeProps) => {
|
|
| 38 |
case 'text-classification':
|
| 39 |
classType = 'classifier'
|
| 40 |
exampleData = 'I love this product!'
|
| 41 |
-
config = {
|
|
|
|
|
|
|
| 42 |
break
|
| 43 |
case 'text-generation':
|
| 44 |
classType = 'generator'
|
|
|
|
| 38 |
case 'text-classification':
|
| 39 |
classType = 'classifier'
|
| 40 |
exampleData = 'I love this product!'
|
| 41 |
+
config = {
|
| 42 |
+
top_k: 1
|
| 43 |
+
}
|
| 44 |
break
|
| 45 |
case 'text-generation':
|
| 46 |
classType = 'generator'
|
src/components/PipelineLayout.tsx
CHANGED
|
@@ -3,6 +3,7 @@ import { TextGenerationProvider } from '../contexts/TextGenerationContext'
|
|
| 3 |
import { FeatureExtractionProvider } from '../contexts/FeatureExtractionContext'
|
| 4 |
import { ZeroShotClassificationProvider } from '../contexts/ZeroShotClassificationContext'
|
| 5 |
import { ImageClassificationProvider } from '../contexts/ImageClassificationContext'
|
|
|
|
| 6 |
|
| 7 |
export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
|
| 8 |
const { pipeline } = useModel()
|
|
@@ -26,6 +27,9 @@ export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
|
|
| 26 |
<ImageClassificationProvider>{children}</ImageClassificationProvider>
|
| 27 |
)
|
| 28 |
|
|
|
|
|
|
|
|
|
|
| 29 |
default:
|
| 30 |
return <>{children}</>
|
| 31 |
}
|
|
|
|
| 3 |
import { FeatureExtractionProvider } from '../contexts/FeatureExtractionContext'
|
| 4 |
import { ZeroShotClassificationProvider } from '../contexts/ZeroShotClassificationContext'
|
| 5 |
import { ImageClassificationProvider } from '../contexts/ImageClassificationContext'
|
| 6 |
+
import { TextClassificationProvider } from '../contexts/TextClassificationContext'
|
| 7 |
|
| 8 |
export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
|
| 9 |
const { pipeline } = useModel()
|
|
|
|
| 27 |
<ImageClassificationProvider>{children}</ImageClassificationProvider>
|
| 28 |
)
|
| 29 |
|
| 30 |
+
case 'text-classification':
|
| 31 |
+
return <TextClassificationProvider>{children}</TextClassificationProvider>
|
| 32 |
+
|
| 33 |
default:
|
| 34 |
return <>{children}</>
|
| 35 |
}
|
src/components/PipelineSelector.tsx
CHANGED
|
@@ -12,9 +12,9 @@ export const supportedPipelines = [
|
|
| 12 |
'image-classification',
|
| 13 |
'text-generation',
|
| 14 |
'zero-shot-classification',
|
| 15 |
-
'text-classification'
|
| 16 |
-
'summarization',
|
| 17 |
-
'translation'
|
| 18 |
]
|
| 19 |
|
| 20 |
interface PipelineSelectorProps {
|
|
|
|
| 12 |
'image-classification',
|
| 13 |
'text-generation',
|
| 14 |
'zero-shot-classification',
|
| 15 |
+
'text-classification'
|
| 16 |
+
// 'summarization',
|
| 17 |
+
// 'translation'
|
| 18 |
]
|
| 19 |
|
| 20 |
interface PipelineSelectorProps {
|
src/components/Sidebar.tsx
CHANGED
|
@@ -7,6 +7,7 @@ import TextGenerationConfig from './pipelines/TextGenerationConfig'
|
|
| 7 |
import FeatureExtractionConfig from './pipelines/FeatureExtractionConfig'
|
| 8 |
import ZeroShotClassificationConfig from './pipelines/ZeroShotClassificationConfig'
|
| 9 |
import ImageClassificationConfig from './pipelines/ImageClassificationConfig'
|
|
|
|
| 10 |
import { Button } from '@/components/ui/button'
|
| 11 |
|
| 12 |
interface SidebarProps {
|
|
@@ -102,6 +103,7 @@ const Sidebar = ({
|
|
| 102 |
{pipeline === 'image-classification' && (
|
| 103 |
<ImageClassificationConfig />
|
| 104 |
)}
|
|
|
|
| 105 |
</div>
|
| 106 |
</div>
|
| 107 |
</div>
|
|
|
|
| 7 |
import FeatureExtractionConfig from './pipelines/FeatureExtractionConfig'
|
| 8 |
import ZeroShotClassificationConfig from './pipelines/ZeroShotClassificationConfig'
|
| 9 |
import ImageClassificationConfig from './pipelines/ImageClassificationConfig'
|
| 10 |
+
import TextClassificationConfig from './pipelines/TextClassificationConfig'
|
| 11 |
import { Button } from '@/components/ui/button'
|
| 12 |
|
| 13 |
interface SidebarProps {
|
|
|
|
| 103 |
{pipeline === 'image-classification' && (
|
| 104 |
<ImageClassificationConfig />
|
| 105 |
)}
|
| 106 |
+
{pipeline === 'text-classification' && <TextClassificationConfig />}
|
| 107 |
</div>
|
| 108 |
</div>
|
| 109 |
</div>
|
src/components/pipelines/TextClassification.tsx
CHANGED
|
@@ -1,6 +1,11 @@
|
|
| 1 |
import { useState, useCallback, useEffect } from 'react'
|
| 2 |
-
import {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import { useModel } from '../../contexts/ModelContext'
|
|
|
|
| 4 |
|
| 5 |
const PLACEHOLDER_TEXTS: string[] = [
|
| 6 |
'I absolutely love this product! It exceeded all my expectations.',
|
|
@@ -18,7 +23,7 @@ const PLACEHOLDER_TEXTS: string[] = [
|
|
| 18 |
function TextClassification() {
|
| 19 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
| 20 |
const [numberExamples, setNumberExamples] = useState(PLACEHOLDER_TEXTS.length)
|
| 21 |
-
const [results, setResults] = useState<
|
| 22 |
const {
|
| 23 |
activeWorker,
|
| 24 |
status,
|
|
@@ -27,6 +32,7 @@ function TextClassification() {
|
|
| 27 |
hasBeenLoaded,
|
| 28 |
selectedQuantization
|
| 29 |
} = useModel()
|
|
|
|
| 30 |
|
| 31 |
useEffect(() => {
|
| 32 |
if (modelInfo?.widgetData) {
|
|
@@ -51,10 +57,11 @@ function TextClassification() {
|
|
| 51 |
type: 'classify',
|
| 52 |
text,
|
| 53 |
model: modelInfo.id,
|
| 54 |
-
dtype: selectedQuantization ?? 'fp32'
|
|
|
|
| 55 |
}
|
| 56 |
activeWorker.postMessage(message)
|
| 57 |
-
}, [text, modelInfo, activeWorker, selectedQuantization, setResults])
|
| 58 |
|
| 59 |
// Handle worker messages
|
| 60 |
useEffect(() => {
|
|
@@ -65,7 +72,7 @@ function TextClassification() {
|
|
| 65 |
if (status === 'output') {
|
| 66 |
setStatus('output')
|
| 67 |
const result = e.data.output!
|
| 68 |
-
setResults((prev:
|
| 69 |
}
|
| 70 |
}
|
| 71 |
|
|
@@ -135,17 +142,47 @@ function TextClassification() {
|
|
| 135 |
<div className="space-y-3">
|
| 136 |
{results.map((result, index) => (
|
| 137 |
<div key={index} className="p-3 rounded-sm border-2">
|
| 138 |
-
<div className="
|
| 139 |
-
<span className="font-semibold text-sm">
|
| 140 |
-
{result.labels[0]}
|
| 141 |
-
</span>
|
| 142 |
-
<span className="text-sm font-mono">
|
| 143 |
-
{(result.scores[0] * 100).toFixed(1)}%
|
| 144 |
-
</span>
|
| 145 |
-
</div>
|
| 146 |
-
<div className="text-sm text-gray-700">
|
| 147 |
{result.sequence}
|
| 148 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
</div>
|
| 150 |
))}
|
| 151 |
</div>
|
|
|
|
| 1 |
import { useState, useCallback, useEffect } from 'react'
|
| 2 |
+
import {
|
| 3 |
+
ClassificationOutput,
|
| 4 |
+
TextClassificationWorkerInput,
|
| 5 |
+
WorkerMessage
|
| 6 |
+
} from '../../types'
|
| 7 |
import { useModel } from '../../contexts/ModelContext'
|
| 8 |
+
import { useTextClassification } from '../../contexts/TextClassificationContext'
|
| 9 |
|
| 10 |
const PLACEHOLDER_TEXTS: string[] = [
|
| 11 |
'I absolutely love this product! It exceeded all my expectations.',
|
|
|
|
| 23 |
function TextClassification() {
|
| 24 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
| 25 |
const [numberExamples, setNumberExamples] = useState(PLACEHOLDER_TEXTS.length)
|
| 26 |
+
const [results, setResults] = useState<ClassificationOutput[]>([])
|
| 27 |
const {
|
| 28 |
activeWorker,
|
| 29 |
status,
|
|
|
|
| 32 |
hasBeenLoaded,
|
| 33 |
selectedQuantization
|
| 34 |
} = useModel()
|
| 35 |
+
const { config } = useTextClassification()
|
| 36 |
|
| 37 |
useEffect(() => {
|
| 38 |
if (modelInfo?.widgetData) {
|
|
|
|
| 57 |
type: 'classify',
|
| 58 |
text,
|
| 59 |
model: modelInfo.id,
|
| 60 |
+
dtype: selectedQuantization ?? 'fp32',
|
| 61 |
+
config
|
| 62 |
}
|
| 63 |
activeWorker.postMessage(message)
|
| 64 |
+
}, [text, modelInfo, activeWorker, selectedQuantization, config, setResults])
|
| 65 |
|
| 66 |
// Handle worker messages
|
| 67 |
useEffect(() => {
|
|
|
|
| 72 |
if (status === 'output') {
|
| 73 |
setStatus('output')
|
| 74 |
const result = e.data.output!
|
| 75 |
+
setResults((prev: ClassificationOutput[]) => [...prev, result])
|
| 76 |
}
|
| 77 |
}
|
| 78 |
|
|
|
|
| 142 |
<div className="space-y-3">
|
| 143 |
{results.map((result, index) => (
|
| 144 |
<div key={index} className="p-3 rounded-sm border-2">
|
| 145 |
+
<div className="text-sm text-gray-700 mb-3">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
{result.sequence}
|
| 147 |
</div>
|
| 148 |
+
<div className="space-y-2">
|
| 149 |
+
{result.labels.map(
|
| 150 |
+
(label: string, labelIndex: number) => {
|
| 151 |
+
const score = result.scores[labelIndex]
|
| 152 |
+
const isTopPrediction = labelIndex === 0
|
| 153 |
+
|
| 154 |
+
return (
|
| 155 |
+
<div
|
| 156 |
+
key={labelIndex}
|
| 157 |
+
className={`flex justify-between items-center p-2 rounded ${
|
| 158 |
+
isTopPrediction
|
| 159 |
+
? 'bg-blue-50 border-l-4 border-blue-500'
|
| 160 |
+
: 'bg-gray-50'
|
| 161 |
+
}`}
|
| 162 |
+
>
|
| 163 |
+
<span
|
| 164 |
+
className={`font-medium text-sm ${
|
| 165 |
+
isTopPrediction
|
| 166 |
+
? 'text-blue-700'
|
| 167 |
+
: 'text-gray-700'
|
| 168 |
+
}`}
|
| 169 |
+
>
|
| 170 |
+
{label}
|
| 171 |
+
</span>
|
| 172 |
+
<span
|
| 173 |
+
className={`text-sm font-mono ${
|
| 174 |
+
isTopPrediction
|
| 175 |
+
? 'text-blue-600'
|
| 176 |
+
: 'text-gray-600'
|
| 177 |
+
}`}
|
| 178 |
+
>
|
| 179 |
+
{(score * 100).toFixed(1)}%
|
| 180 |
+
</span>
|
| 181 |
+
</div>
|
| 182 |
+
)
|
| 183 |
+
}
|
| 184 |
+
)}
|
| 185 |
+
</div>
|
| 186 |
</div>
|
| 187 |
))}
|
| 188 |
</div>
|
src/components/pipelines/TextClassificationConfig.tsx
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React from 'react'
|
| 2 |
+
import { useTextClassification } from '../../contexts/TextClassificationContext'
|
| 3 |
+
import { Slider } from '../ui/slider'
|
| 4 |
+
|
| 5 |
+
const TextClassificationConfig = () => {
|
| 6 |
+
const { config, setConfig } = useTextClassification()
|
| 7 |
+
|
| 8 |
+
return (
|
| 9 |
+
<div className="space-y-4">
|
| 10 |
+
<h3 className="text-lg font-semibold text-foreground">
|
| 11 |
+
Text Classification Settings
|
| 12 |
+
</h3>
|
| 13 |
+
|
| 14 |
+
<div className="space-y-3">
|
| 15 |
+
<div>
|
| 16 |
+
<label className="block text-sm font-medium text-foreground/80 mb-1">
|
| 17 |
+
Top K Predictions: {config.top_k}
|
| 18 |
+
</label>
|
| 19 |
+
<Slider
|
| 20 |
+
defaultValue={[config.top_k]}
|
| 21 |
+
min={1}
|
| 22 |
+
max={10}
|
| 23 |
+
step={1}
|
| 24 |
+
onValueChange={(value) => setConfig({ top_k: value[0] })}
|
| 25 |
+
className="w-full rounded-lg"
|
| 26 |
+
/>
|
| 27 |
+
<div className="flex justify-between text-xs text-muted-foreground/60 mt-1">
|
| 28 |
+
<span>1</span>
|
| 29 |
+
<span>4</span>
|
| 30 |
+
<span>7</span>
|
| 31 |
+
<span>10</span>
|
| 32 |
+
</div>
|
| 33 |
+
<p className="text-xs text-muted-foreground mt-1">
|
| 34 |
+
Number of top predictions to return for each text
|
| 35 |
+
</p>
|
| 36 |
+
</div>
|
| 37 |
+
|
| 38 |
+
<div className="p-3 bg-chart-4/10 border border-chart-4/20 rounded-lg">
|
| 39 |
+
<h4 className="text-sm font-medium text-chart-4 mb-2">💡 Tips</h4>
|
| 40 |
+
<div className="text-xs text-chart-4 space-y-1">
|
| 41 |
+
<p>• Use Top K = 1-3 for most cases</p>
|
| 42 |
+
<p>• Higher values show more detailed rankings</p>
|
| 43 |
+
<p>• Try quantized models for faster processing</p>
|
| 44 |
+
</div>
|
| 45 |
+
</div>
|
| 46 |
+
</div>
|
| 47 |
+
</div>
|
| 48 |
+
)
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
export default TextClassificationConfig
|
src/contexts/TextClassificationContext.tsx
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { createContext, useContext, useState } from 'react'
|
| 2 |
+
|
| 3 |
+
interface TextClassificationConfig {
|
| 4 |
+
top_k: number
|
| 5 |
+
}
|
| 6 |
+
|
| 7 |
+
interface TextClassificationContextType {
|
| 8 |
+
config: TextClassificationConfig
|
| 9 |
+
setConfig: React.Dispatch<React.SetStateAction<TextClassificationConfig>>
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
const TextClassificationContext = createContext<
|
| 13 |
+
TextClassificationContextType | undefined
|
| 14 |
+
>(undefined)
|
| 15 |
+
|
| 16 |
+
export function useTextClassification() {
|
| 17 |
+
const context = useContext(TextClassificationContext)
|
| 18 |
+
if (context === undefined) {
|
| 19 |
+
throw new Error(
|
| 20 |
+
'useTextClassification must be used within a TextClassificationProvider'
|
| 21 |
+
)
|
| 22 |
+
}
|
| 23 |
+
return context
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
interface TextClassificationProviderProps {
|
| 27 |
+
children: React.ReactNode
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
export function TextClassificationProvider({
|
| 31 |
+
children
|
| 32 |
+
}: TextClassificationProviderProps) {
|
| 33 |
+
const [config, setConfig] = useState<TextClassificationConfig>({
|
| 34 |
+
top_k: 1
|
| 35 |
+
})
|
| 36 |
+
|
| 37 |
+
const value: TextClassificationContextType = {
|
| 38 |
+
config,
|
| 39 |
+
setConfig
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
return (
|
| 43 |
+
<TextClassificationContext.Provider value={value}>
|
| 44 |
+
{children}
|
| 45 |
+
</TextClassificationContext.Provider>
|
| 46 |
+
)
|
| 47 |
+
}
|
src/types.ts
CHANGED
|
@@ -48,6 +48,9 @@ export interface TextClassificationWorkerInput {
|
|
| 48 |
text: string
|
| 49 |
model: string
|
| 50 |
dtype: QuantizationType
|
|
|
|
|
|
|
|
|
|
| 51 |
}
|
| 52 |
|
| 53 |
export interface TextGenerationWorkerInput {
|
|
|
|
| 48 |
text: string
|
| 49 |
model: string
|
| 50 |
dtype: QuantizationType
|
| 51 |
+
config?: {
|
| 52 |
+
top_k?: number
|
| 53 |
+
}
|
| 54 |
}
|
| 55 |
|
| 56 |
export interface TextGenerationWorkerInput {
|