diff --git a/packages/components/nodes/textsplitters/CodeTextSplitter/CodeTextSplitter.ts b/packages/components/nodes/textsplitters/CodeTextSplitter/CodeTextSplitter.ts index 292486ac2bf..05b6cf4b9a2 100644 --- a/packages/components/nodes/textsplitters/CodeTextSplitter/CodeTextSplitter.ts +++ b/packages/components/nodes/textsplitters/CodeTextSplitter/CodeTextSplitter.ts @@ -3,9 +3,194 @@ import { getBaseClasses } from '../../../src/utils' import { RecursiveCharacterTextSplitter, RecursiveCharacterTextSplitterParams, - SupportedTextSplitterLanguage + SupportedTextSplitterLanguage, + SupportedTextSplitterLanguages } from '@langchain/textsplitters' +const extraLanguageSeparators: Record = { + c: [ + '\nstruct ', + '\nunion ', + '\nenum ', + '\nvoid ', + '\nint ', + '\nfloat ', + '\ndouble ', + '\nif ', + '\nfor ', + '\nwhile ', + '\nswitch ', + '\ncase ', + '\n\n', + '\n', + ' ', + '' + ], + csharp: [ + '\nnamespace ', + '\ninterface ', + '\nenum ', + '\nstruct ', + '\ndelegate ', + '\nevent ', + '\nclass ', + '\nabstract ', + '\npublic ', + '\nprotected ', + '\nprivate ', + '\nstatic ', + '\nreturn ', + '\nif ', + '\ncontinue ', + '\nfor ', + '\nforeach ', + '\nwhile ', + '\nswitch ', + '\nbreak ', + '\ncase ', + '\nelse ', + '\ntry ', + '\nthrow ', + '\nfinally ', + '\ncatch ', + '\n\n', + '\n', + ' ', + '' + ], + cobol: [ + '\nIDENTIFICATION DIVISION.', + '\nENVIRONMENT DIVISION.', + '\nDATA DIVISION.', + '\nPROCEDURE DIVISION.', + '\nWORKING-STORAGE SECTION.', + '\nLINKAGE SECTION.', + '\nFILE SECTION.', + '\nINPUT-OUTPUT SECTION.', + '\nOPEN ', + '\nCLOSE ', + '\nREAD ', + '\nWRITE ', + '\nIF ', + '\nELSE ', + '\nMOVE ', + '\nPERFORM ', + '\nUNTIL ', + '\nVARYING ', + '\nACCEPT ', + '\nDISPLAY ', + '\nSTOP RUN.', + '\n', + ' ', + '' + ], + elixir: [ + '\ndef ', + '\ndefp ', + '\ndefmodule ', + '\ndefprotocol ', + '\ndefmacro ', + '\ndefmacrop ', + '\nif ', + '\nunless ', + '\ncase ', + '\ncond ', + '\nwith ', + '\nfor ', + '\ndo ', + '\n\n', + '\n', + ' ', + '' + ], + haskell: [ + '\nmain :: ', + '\nmain = ', + '\nlet ', + '\nin ', + '\ndo ', + '\nwhere ', + '\n:: ', + '\n= ', + '\ndata ', + '\nnewtype ', + '\ntype ', + '\nmodule ', + '\nimport ', + '\nqualified ', + '\nimport qualified ', + '\nclass ', + '\ninstance ', + '\ncase ', + '\n| ', + '\n= {', + '\n, ', + '\n\n', + '\n', + ' ', + '' + ], + kotlin: [ + '\nclass ', + '\npublic ', + '\nprotected ', + '\nprivate ', + '\ninternal ', + '\ncompanion ', + '\nfun ', + '\nval ', + '\nvar ', + '\nif ', + '\nfor ', + '\nwhile ', + '\nwhen ', + '\nelse ', + '\n\n', + '\n', + ' ', + '' + ], + lua: ['\nlocal ', '\nfunction ', '\nif ', '\nfor ', '\nwhile ', '\nrepeat ', '\n\n', '\n', ' ', ''], + powershell: [ + '\nfunction ', + '\nparam ', + '\nif ', + '\nforeach ', + '\nfor ', + '\nwhile ', + '\nswitch ', + '\nclass ', + '\ntry ', + '\ncatch ', + '\nfinally ', + '\n\n', + '\n', + ' ', + '' + ], + ts: [ + '\nenum ', + '\ninterface ', + '\nnamespace ', + '\ntype ', + '\nclass ', + '\nfunction ', + '\nconst ', + '\nlet ', + '\nvar ', + '\nif ', + '\nfor ', + '\nwhile ', + '\nswitch ', + '\ncase ', + '\ndefault ', + '\n\n', + '\n', + ' ', + '' + ] +} + class CodeTextSplitter_TextSplitters implements INode { label: string name: string @@ -31,14 +216,38 @@ class CodeTextSplitter_TextSplitters implements INode { name: 'language', type: 'options', options: [ + { + label: 'c', + name: 'c' + }, + { + label: 'cobol', + name: 'cobol' + }, { label: 'cpp', name: 'cpp' }, + { + label: 'csharp', + name: 'csharp' + }, + { + label: 'elixir', + name: 'elixir' + }, { label: 'go', name: 'go' }, + { + label: 'haskell', + name: 'haskell' + }, + { + label: 'html', + name: 'html' + }, { label: 'java', name: 'java' @@ -47,10 +256,30 @@ class CodeTextSplitter_TextSplitters implements INode { label: 'js', name: 'js' }, + { + label: 'kotlin', + name: 'kotlin' + }, + { + label: 'latex', + name: 'latex' + }, + { + label: 'lua', + name: 'lua' + }, + { + label: 'markdown', + name: 'markdown' + }, { label: 'php', name: 'php' }, + { + label: 'powershell', + name: 'powershell' + }, { label: 'proto', name: 'proto' @@ -76,24 +305,16 @@ class CodeTextSplitter_TextSplitters implements INode { name: 'scala' }, { - label: 'swift', - name: 'swift' - }, - { - label: 'markdown', - name: 'markdown' - }, - { - label: 'latex', - name: 'latex' + label: 'sol', + name: 'sol' }, { - label: 'html', - name: 'html' + label: 'swift', + name: 'swift' }, { - label: 'sol', - name: 'sol' + label: 'ts', + name: 'ts' } ] }, @@ -118,16 +339,23 @@ class CodeTextSplitter_TextSplitters implements INode { async init(nodeData: INodeData): Promise { const chunkSize = nodeData.inputs?.chunkSize as string const chunkOverlap = nodeData.inputs?.chunkOverlap as string - const language = nodeData.inputs?.language as SupportedTextSplitterLanguage + const language = nodeData.inputs?.language as string const obj = {} as RecursiveCharacterTextSplitterParams if (chunkSize) obj.chunkSize = parseInt(chunkSize, 10) if (chunkOverlap) obj.chunkOverlap = parseInt(chunkOverlap, 10) - const splitter = RecursiveCharacterTextSplitter.fromLanguage(language, obj) + if ((SupportedTextSplitterLanguages as readonly string[]).includes(language)) { + return RecursiveCharacterTextSplitter.fromLanguage(language as SupportedTextSplitterLanguage, obj) + } + + const separators = extraLanguageSeparators[language] + if (separators) { + return new RecursiveCharacterTextSplitter({ ...obj, separators }) + } - return splitter + return new RecursiveCharacterTextSplitter(obj) } } module.exports = { nodeClass: CodeTextSplitter_TextSplitters }