diff --git a/pkg/assembler/assembler.go b/pkg/assembler/assembler.go index f44d16f..bf2b8d2 100644 --- a/pkg/assembler/assembler.go +++ b/pkg/assembler/assembler.go @@ -27,6 +27,7 @@ import ( "github.com/retroenv/retroasm/pkg/parser" "github.com/retroenv/retroasm/pkg/parser/ast" "github.com/retroenv/retroasm/pkg/scope" + "github.com/retroenv/retrogolib/set" ) var errNoCurrentSegment = errors.New("no current segment found") @@ -116,10 +117,11 @@ func (asm *Assembler[T]) Symbols() map[string]uint64 { // parseASTNodes processes the given AST nodes and converts them to internal types. func (asm *Assembler[T]) parseASTNodes(ctx context.Context, nodes []ast.Node) error { p := &parseAST[T]{ - cfg: asm.cfg, - fileReader: asm.fileReader, - currentScope: asm.fileScope, - segments: map[string]*segment{}, + cfg: asm.cfg, + fileReader: asm.fileReader, + includeActive: set.New[string](), + currentScope: asm.fileScope, + segments: map[string]*segment{}, } if len(asm.cfg.SegmentsOrdered) == 1 { segCfg := asm.cfg.SegmentsOrdered[0] @@ -151,7 +153,7 @@ func (asm *Assembler[T]) parseASTNodes(ctx context.Context, nodes []ast.Node) er return errNoCurrentSegment } - newNodes, err := parseASTNode(p, node) + newNodes, err := parseASTNode(ctx, p, node) if err != nil { return err } diff --git a/pkg/assembler/assembler_asm6_test.go b/pkg/assembler/assembler_asm6_test.go index 341a68b..e3671b1 100644 --- a/pkg/assembler/assembler_asm6_test.go +++ b/pkg/assembler/assembler_asm6_test.go @@ -3,6 +3,7 @@ package assembler import ( "bytes" "context" + "fmt" "strings" "testing" @@ -602,6 +603,54 @@ func TestAssemblerContextCancellation(t *testing.T) { assert.ErrorIs(t, err, context.Canceled, "expected context.Canceled error") } +var asm6SourceIncludeTestCode = ` +.segment "HEADER" +.include "defs.asm" +DB myval +` + +func TestAssemblerAsm6SourceInclude(t *testing.T) { + cfg := m6502.New() + assert.NoError(t, cfg.ReadCa65Config(strings.NewReader(unitTestConfig))) + + reader := strings.NewReader(asm6SourceIncludeTestCode) + var buf bytes.Buffer + asm := New(cfg, &buf) + + asm.fileReader = func(name string) ([]byte, error) { + assert.Equal(t, "defs.asm", name) + return []byte("myval = $42\n"), nil + } + + assert.NoError(t, asm.Process(t.Context(), reader)) + assert.Equal(t, []byte{0x42}, buf.Bytes()) +} + +func TestAssemblerAsm6SourceIncludeCycle(t *testing.T) { + cfg := m6502.New() + assert.NoError(t, cfg.ReadCa65Config(strings.NewReader(unitTestConfig))) + + reader := strings.NewReader(asm6SourceIncludeTestCode) + var buf bytes.Buffer + asm := New(cfg, &buf) + + asm.fileReader = func(name string) ([]byte, error) { + switch name { + case "defs.asm": + return []byte(".include \"more.asm\"\n"), nil + case "more.asm": + return []byte(".include \"defs.asm\"\n"), nil + default: + return nil, fmt.Errorf("unexpected include %q", name) + } + } + + err := asm.Process(t.Context(), reader) + assert.Error(t, err) + assert.Contains(t, err.Error(), "include cycle detected") + assert.Contains(t, err.Error(), "defs.asm -> more.asm -> defs.asm") +} + func runAsm6Test(t *testing.T, testConfig, testCode string) ([]byte, error) { t.Helper() diff --git a/pkg/assembler/parse_ast_nodes.go b/pkg/assembler/parse_ast_nodes.go index 3155e48..4bf48f8 100644 --- a/pkg/assembler/parse_ast_nodes.go +++ b/pkg/assembler/parse_ast_nodes.go @@ -1,6 +1,8 @@ package assembler import ( + "bytes" + "context" "errors" "fmt" "strings" @@ -9,14 +11,18 @@ import ( "github.com/retroenv/retroasm/pkg/expression" "github.com/retroenv/retroasm/pkg/lexer/token" "github.com/retroenv/retroasm/pkg/number" + "github.com/retroenv/retroasm/pkg/parser" "github.com/retroenv/retroasm/pkg/parser/ast" "github.com/retroenv/retroasm/pkg/scope" + "github.com/retroenv/retrogolib/set" ) type parseAST[T any] struct { cfg *config.Config[T] // a function that reads in a file, for testing includes, defaults to os.ReadFile - fileReader func(name string) ([]byte, error) + fileReader func(name string) ([]byte, error) + includeActive set.Set[string] + includeStack []string currentScope *scope.Scope // current scope, can be a function scope with file scope as parent currentSegment *segment // the current segment being parsed @@ -28,7 +34,7 @@ type parseAST[T any] struct { var errNilInstructionArgument = errors.New("instruction argument cannot be nil") //nolint:cyclop,funlen // type switch with one case per AST node type -func parseASTNode[T any](asm *parseAST[T], node ast.Node) ([]ast.Node, error) { +func parseASTNode[T any](ctx context.Context, asm *parseAST[T], node ast.Node) ([]ast.Node, error) { var ( err error nodes []ast.Node @@ -60,7 +66,7 @@ func parseASTNode[T any](asm *parseAST[T], node ast.Node) ([]ast.Node, error) { nodes, err = parseInstruction(n) case ast.Include: - nodes, err = parseInclude(asm, n) + nodes, err = parseInclude(ctx, asm, n) case ast.Macro: nodes, err = parseMacro(n) @@ -355,12 +361,17 @@ func nameWithModifiers(name string, modifiers []ast.Modifier) (string, error) { return fmt.Sprintf("%s%d", name, offset), nil // offset is negative, fmt includes '-' } -func parseInclude[T any](asm *parseAST[T], inc ast.Include) ([]ast.Node, error) { - if !inc.Binary { - return nil, errors.New("non binary includes are currently not supported") // TODO implement +func parseInclude[T any](ctx context.Context, asm *parseAST[T], inc ast.Include) ([]ast.Node, error) { + name := strings.Trim(inc.Name, "\"'") + + if inc.Binary { + return parseBinaryInclude(asm, name) } - name := strings.Trim(inc.Name, "\"'") + return parseSourceInclude(ctx, asm, name) +} + +func parseBinaryInclude[T any](asm *parseAST[T], name string) ([]ast.Node, error) { b, err := asm.fileReader(name) if err != nil { return nil, fmt.Errorf("reading file '%s': %w", name, err) @@ -372,6 +383,54 @@ func parseInclude[T any](asm *parseAST[T], inc ast.Include) ([]ast.Node, error) return []ast.Node{dat}, nil } +func parseSourceInclude[T any](ctx context.Context, asm *parseAST[T], name string) ([]ast.Node, error) { + if asm.includeActive.Contains(name) { + chain := append(append([]string{}, asm.includeStack...), name) + return nil, fmt.Errorf("include cycle detected: %s", strings.Join(chain, " -> ")) + } + asm.includeActive.Add(name) + asm.includeStack = append(asm.includeStack, name) + defer func() { + asm.includeActive.Remove(name) + asm.includeStack = asm.includeStack[:len(asm.includeStack)-1] + }() + + b, err := asm.fileReader(name) + if err != nil { + return nil, fmt.Errorf("reading file '%s': %w", name, err) + } + + pars := parser.New[T](asm.cfg.Arch, bytes.NewReader(b)) + if err := pars.Read(ctx); err != nil { + return nil, fmt.Errorf("parsing included file '%s': %w", name, err) + } + + nodes, err := pars.TokensToAstNodes() + if err != nil { + return nil, fmt.Errorf("converting tokens for included file '%s': %w", name, err) + } + + var result []ast.Node + for _, node := range nodes { + switch n := node.(type) { + case *ast.Comment: + continue + case ast.Segment: + if err := parseSegment(asm, n); err != nil { + return nil, fmt.Errorf("parsing segment in included file '%s': %w", name, err) + } + default: + newNodes, err := parseASTNode(ctx, asm, node) + if err != nil { + return nil, fmt.Errorf("processing node in included file '%s': %w", name, err) + } + result = append(result, newNodes...) + } + } + + return result, nil +} + func parseVariable(astVar ast.Variable) []ast.Node { v := &variable{v: astVar} return []ast.Node{v} diff --git a/pkg/assembler/parse_ast_nodes_test.go b/pkg/assembler/parse_ast_nodes_test.go index c1685fa..2557716 100644 --- a/pkg/assembler/parse_ast_nodes_test.go +++ b/pkg/assembler/parse_ast_nodes_test.go @@ -1,11 +1,15 @@ package assembler import ( + "strings" "testing" + "github.com/retroenv/retroasm/pkg/arch/m6502" "github.com/retroenv/retroasm/pkg/parser/ast" "github.com/retroenv/retroasm/pkg/scope" + cpu6502 "github.com/retroenv/retrogolib/arch/cpu/m6502" "github.com/retroenv/retrogolib/assert" + "github.com/retroenv/retrogolib/set" ) func TestModifierOffset(t *testing.T) { //nolint:funlen @@ -338,7 +342,7 @@ func TestParseScope(t *testing.T) { currentScope: fileScope, } - nodes, err := parseASTNode(p, ast.NewScope("inner")) + nodes, err := parseASTNode(t.Context(), p, ast.NewScope("inner")) assert.NoError(t, err) assert.Len(t, nodes, 2) assert.NotNil(t, p.currentScope) @@ -362,7 +366,7 @@ func TestParseUnnamedScope(t *testing.T) { currentScope: fileScope, } - nodes, err := parseASTNode(p, ast.NewScope("")) + nodes, err := parseASTNode(t.Context(), p, ast.NewScope("")) assert.NoError(t, err) assert.Len(t, nodes, 1) assert.Equal(t, fileScope, p.currentScope.Parent()) @@ -379,7 +383,7 @@ func TestParseScopeEnd(t *testing.T) { currentScope: childScope, } - nodes, err := parseASTNode(p, ast.NewScopeEnd()) + nodes, err := parseASTNode(t.Context(), p, ast.NewScopeEnd()) assert.NoError(t, err) assert.Len(t, nodes, 1) assert.Equal(t, fileScope, p.currentScope) @@ -394,10 +398,37 @@ func TestParseScopeEndWithoutParent(t *testing.T) { currentScope: scope.New(nil), } - _, err := parseASTNode(p, ast.NewScopeEnd()) + _, err := parseASTNode(t.Context(), p, ast.NewScopeEnd()) assert.Error(t, err) } +func TestParseSourceIncludeCycle(t *testing.T) { + cfg := m6502.New() + assert.NoError(t, cfg.ReadCa65Config(strings.NewReader(unitTestConfig))) + + p := &parseAST[*cpu6502.Instruction]{ + cfg: cfg, + fileReader: func(name string) ([]byte, error) { + switch name { + case "defs.asm": + return []byte(".include \"more.asm\"\n"), nil + case "more.asm": + return []byte(".include \"defs.asm\"\n"), nil + default: + return nil, nil + } + }, + includeActive: set.New[string](), + currentScope: scope.New(nil), + segments: map[string]*segment{}, + } + + _, err := parseASTNode(t.Context(), p, ast.NewInclude("defs.asm", false, 0, 0)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "include cycle detected") + assert.Contains(t, err.Error(), "defs.asm -> more.asm -> defs.asm") +} + type testTypedInstructionArgument struct { register string width int diff --git a/pkg/assembler/process_macros_step.go b/pkg/assembler/process_macros_step.go index 5fe1b28..4cede86 100644 --- a/pkg/assembler/process_macros_step.go +++ b/pkg/assembler/process_macros_step.go @@ -7,6 +7,7 @@ import ( "github.com/retroenv/retroasm/pkg/lexer/token" "github.com/retroenv/retroasm/pkg/parser" "github.com/retroenv/retroasm/pkg/parser/ast" + "github.com/retroenv/retrogolib/set" ) // processMacrosStep processes macro and rept nodes and replace them by their resolved nodes. @@ -80,7 +81,7 @@ func resolveMacroUsage[T any](ctx context.Context, asm *Assembler[T], id ast.Ide return macroTokensToAStNodes(ctx, asm, mac.tokens) } -func macroTokensToAStNodes[T any](_ context.Context, asm *Assembler[T], tokens []token.Token) ([]ast.Node, error) { +func macroTokensToAStNodes[T any](ctx context.Context, asm *Assembler[T], tokens []token.Token) ([]ast.Node, error) { // convert the adjusted tokens to AST nodes par := parser.NewWithTokens(asm.cfg.Arch, tokens) astNodes, err := par.TokensToAstNodes() @@ -89,16 +90,17 @@ func macroTokensToAStNodes[T any](_ context.Context, asm *Assembler[T], tokens [ } p := &parseAST[T]{ - cfg: asm.cfg, - fileReader: asm.fileReader, - currentScope: asm.fileScope, - segments: map[string]*segment{}, + cfg: asm.cfg, + fileReader: asm.fileReader, + includeActive: set.New[string](), + currentScope: asm.fileScope, + segments: map[string]*segment{}, } // process the AST nodes var result []ast.Node for _, node := range astNodes { - nodes, err := parseASTNode(p, node) + nodes, err := parseASTNode(ctx, p, node) if err != nil { return nil, err }