Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions pkg/assembler/assembler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/retroenv/retroasm/pkg/parser"
"github.com/retroenv/retroasm/pkg/parser/ast"
"github.com/retroenv/retroasm/pkg/scope"
"github.com/retroenv/retrogolib/set"
)

var errNoCurrentSegment = errors.New("no current segment found")
Expand Down Expand Up @@ -116,10 +117,11 @@ func (asm *Assembler[T]) Symbols() map[string]uint64 {
// parseASTNodes processes the given AST nodes and converts them to internal types.
func (asm *Assembler[T]) parseASTNodes(ctx context.Context, nodes []ast.Node) error {
p := &parseAST[T]{
cfg: asm.cfg,
fileReader: asm.fileReader,
currentScope: asm.fileScope,
segments: map[string]*segment{},
cfg: asm.cfg,
fileReader: asm.fileReader,
includeActive: set.New[string](),
currentScope: asm.fileScope,
segments: map[string]*segment{},
}
if len(asm.cfg.SegmentsOrdered) == 1 {
segCfg := asm.cfg.SegmentsOrdered[0]
Expand Down Expand Up @@ -151,7 +153,7 @@ func (asm *Assembler[T]) parseASTNodes(ctx context.Context, nodes []ast.Node) er
return errNoCurrentSegment
}

newNodes, err := parseASTNode(p, node)
newNodes, err := parseASTNode(ctx, p, node)
if err != nil {
return err
}
Expand Down
49 changes: 49 additions & 0 deletions pkg/assembler/assembler_asm6_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package assembler
import (
"bytes"
"context"
"fmt"
"strings"
"testing"

Expand Down Expand Up @@ -602,6 +603,54 @@ func TestAssemblerContextCancellation(t *testing.T) {
assert.ErrorIs(t, err, context.Canceled, "expected context.Canceled error")
}

var asm6SourceIncludeTestCode = `
.segment "HEADER"
.include "defs.asm"
DB myval
`

func TestAssemblerAsm6SourceInclude(t *testing.T) {
cfg := m6502.New()
assert.NoError(t, cfg.ReadCa65Config(strings.NewReader(unitTestConfig)))

reader := strings.NewReader(asm6SourceIncludeTestCode)
var buf bytes.Buffer
asm := New(cfg, &buf)

asm.fileReader = func(name string) ([]byte, error) {
assert.Equal(t, "defs.asm", name)
return []byte("myval = $42\n"), nil
}

assert.NoError(t, asm.Process(t.Context(), reader))
assert.Equal(t, []byte{0x42}, buf.Bytes())
}

func TestAssemblerAsm6SourceIncludeCycle(t *testing.T) {
cfg := m6502.New()
assert.NoError(t, cfg.ReadCa65Config(strings.NewReader(unitTestConfig)))

reader := strings.NewReader(asm6SourceIncludeTestCode)
var buf bytes.Buffer
asm := New(cfg, &buf)

asm.fileReader = func(name string) ([]byte, error) {
switch name {
case "defs.asm":
return []byte(".include \"more.asm\"\n"), nil
case "more.asm":
return []byte(".include \"defs.asm\"\n"), nil
default:
return nil, fmt.Errorf("unexpected include %q", name)
}
}

err := asm.Process(t.Context(), reader)
assert.Error(t, err)
assert.Contains(t, err.Error(), "include cycle detected")
assert.Contains(t, err.Error(), "defs.asm -> more.asm -> defs.asm")
}

func runAsm6Test(t *testing.T, testConfig, testCode string) ([]byte, error) {
t.Helper()

Expand Down
73 changes: 66 additions & 7 deletions pkg/assembler/parse_ast_nodes.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package assembler

import (
"bytes"
"context"
"errors"
"fmt"
"strings"
Expand All @@ -9,14 +11,18 @@ import (
"github.com/retroenv/retroasm/pkg/expression"
"github.com/retroenv/retroasm/pkg/lexer/token"
"github.com/retroenv/retroasm/pkg/number"
"github.com/retroenv/retroasm/pkg/parser"
"github.com/retroenv/retroasm/pkg/parser/ast"
"github.com/retroenv/retroasm/pkg/scope"
"github.com/retroenv/retrogolib/set"
)

type parseAST[T any] struct {
cfg *config.Config[T]
// a function that reads in a file, for testing includes, defaults to os.ReadFile
fileReader func(name string) ([]byte, error)
fileReader func(name string) ([]byte, error)
includeActive set.Set[string]
includeStack []string

currentScope *scope.Scope // current scope, can be a function scope with file scope as parent
currentSegment *segment // the current segment being parsed
Expand All @@ -28,7 +34,7 @@ type parseAST[T any] struct {
var errNilInstructionArgument = errors.New("instruction argument cannot be nil")

//nolint:cyclop,funlen // type switch with one case per AST node type
func parseASTNode[T any](asm *parseAST[T], node ast.Node) ([]ast.Node, error) {
func parseASTNode[T any](ctx context.Context, asm *parseAST[T], node ast.Node) ([]ast.Node, error) {
var (
err error
nodes []ast.Node
Expand Down Expand Up @@ -60,7 +66,7 @@ func parseASTNode[T any](asm *parseAST[T], node ast.Node) ([]ast.Node, error) {
nodes, err = parseInstruction(n)

case ast.Include:
nodes, err = parseInclude(asm, n)
nodes, err = parseInclude(ctx, asm, n)

case ast.Macro:
nodes, err = parseMacro(n)
Expand Down Expand Up @@ -355,12 +361,17 @@ func nameWithModifiers(name string, modifiers []ast.Modifier) (string, error) {
return fmt.Sprintf("%s%d", name, offset), nil // offset is negative, fmt includes '-'
}

func parseInclude[T any](asm *parseAST[T], inc ast.Include) ([]ast.Node, error) {
if !inc.Binary {
return nil, errors.New("non binary includes are currently not supported") // TODO implement
func parseInclude[T any](ctx context.Context, asm *parseAST[T], inc ast.Include) ([]ast.Node, error) {
name := strings.Trim(inc.Name, "\"'")

if inc.Binary {
return parseBinaryInclude(asm, name)
}

name := strings.Trim(inc.Name, "\"'")
return parseSourceInclude(ctx, asm, name)
}

func parseBinaryInclude[T any](asm *parseAST[T], name string) ([]ast.Node, error) {
b, err := asm.fileReader(name)
if err != nil {
return nil, fmt.Errorf("reading file '%s': %w", name, err)
Expand All @@ -372,6 +383,54 @@ func parseInclude[T any](asm *parseAST[T], inc ast.Include) ([]ast.Node, error)
return []ast.Node{dat}, nil
}

func parseSourceInclude[T any](ctx context.Context, asm *parseAST[T], name string) ([]ast.Node, error) {
if asm.includeActive.Contains(name) {
chain := append(append([]string{}, asm.includeStack...), name)
return nil, fmt.Errorf("include cycle detected: %s", strings.Join(chain, " -> "))
}
asm.includeActive.Add(name)
asm.includeStack = append(asm.includeStack, name)
defer func() {
asm.includeActive.Remove(name)
asm.includeStack = asm.includeStack[:len(asm.includeStack)-1]
}()

b, err := asm.fileReader(name)
if err != nil {
return nil, fmt.Errorf("reading file '%s': %w", name, err)
}

pars := parser.New[T](asm.cfg.Arch, bytes.NewReader(b))
if err := pars.Read(ctx); err != nil {
return nil, fmt.Errorf("parsing included file '%s': %w", name, err)
}

nodes, err := pars.TokensToAstNodes()
if err != nil {
return nil, fmt.Errorf("converting tokens for included file '%s': %w", name, err)
}

var result []ast.Node
for _, node := range nodes {
switch n := node.(type) {
case *ast.Comment:
continue
case ast.Segment:
if err := parseSegment(asm, n); err != nil {
return nil, fmt.Errorf("parsing segment in included file '%s': %w", name, err)
}
default:
newNodes, err := parseASTNode(ctx, asm, node)
if err != nil {
return nil, fmt.Errorf("processing node in included file '%s': %w", name, err)
}
result = append(result, newNodes...)
}
}

return result, nil
}

func parseVariable(astVar ast.Variable) []ast.Node {
v := &variable{v: astVar}
return []ast.Node{v}
Expand Down
39 changes: 35 additions & 4 deletions pkg/assembler/parse_ast_nodes_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package assembler

import (
"strings"
"testing"

"github.com/retroenv/retroasm/pkg/arch/m6502"
"github.com/retroenv/retroasm/pkg/parser/ast"
"github.com/retroenv/retroasm/pkg/scope"
cpu6502 "github.com/retroenv/retrogolib/arch/cpu/m6502"
"github.com/retroenv/retrogolib/assert"
"github.com/retroenv/retrogolib/set"
)

func TestModifierOffset(t *testing.T) { //nolint:funlen
Expand Down Expand Up @@ -338,7 +342,7 @@ func TestParseScope(t *testing.T) {
currentScope: fileScope,
}

nodes, err := parseASTNode(p, ast.NewScope("inner"))
nodes, err := parseASTNode(t.Context(), p, ast.NewScope("inner"))
assert.NoError(t, err)
assert.Len(t, nodes, 2)
assert.NotNil(t, p.currentScope)
Expand All @@ -362,7 +366,7 @@ func TestParseUnnamedScope(t *testing.T) {
currentScope: fileScope,
}

nodes, err := parseASTNode(p, ast.NewScope(""))
nodes, err := parseASTNode(t.Context(), p, ast.NewScope(""))
assert.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, fileScope, p.currentScope.Parent())
Expand All @@ -379,7 +383,7 @@ func TestParseScopeEnd(t *testing.T) {
currentScope: childScope,
}

nodes, err := parseASTNode(p, ast.NewScopeEnd())
nodes, err := parseASTNode(t.Context(), p, ast.NewScopeEnd())
assert.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, fileScope, p.currentScope)
Expand All @@ -394,10 +398,37 @@ func TestParseScopeEndWithoutParent(t *testing.T) {
currentScope: scope.New(nil),
}

_, err := parseASTNode(p, ast.NewScopeEnd())
_, err := parseASTNode(t.Context(), p, ast.NewScopeEnd())
assert.Error(t, err)
}

func TestParseSourceIncludeCycle(t *testing.T) {
cfg := m6502.New()
assert.NoError(t, cfg.ReadCa65Config(strings.NewReader(unitTestConfig)))

p := &parseAST[*cpu6502.Instruction]{
cfg: cfg,
fileReader: func(name string) ([]byte, error) {
switch name {
case "defs.asm":
return []byte(".include \"more.asm\"\n"), nil
case "more.asm":
return []byte(".include \"defs.asm\"\n"), nil
default:
return nil, nil
}
},
includeActive: set.New[string](),
currentScope: scope.New(nil),
segments: map[string]*segment{},
}

_, err := parseASTNode(t.Context(), p, ast.NewInclude("defs.asm", false, 0, 0))
assert.Error(t, err)
assert.Contains(t, err.Error(), "include cycle detected")
assert.Contains(t, err.Error(), "defs.asm -> more.asm -> defs.asm")
}

type testTypedInstructionArgument struct {
register string
width int
Expand Down
14 changes: 8 additions & 6 deletions pkg/assembler/process_macros_step.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/retroenv/retroasm/pkg/lexer/token"
"github.com/retroenv/retroasm/pkg/parser"
"github.com/retroenv/retroasm/pkg/parser/ast"
"github.com/retroenv/retrogolib/set"
)

// processMacrosStep processes macro and rept nodes and replace them by their resolved nodes.
Expand Down Expand Up @@ -80,7 +81,7 @@ func resolveMacroUsage[T any](ctx context.Context, asm *Assembler[T], id ast.Ide
return macroTokensToAStNodes(ctx, asm, mac.tokens)
}

func macroTokensToAStNodes[T any](_ context.Context, asm *Assembler[T], tokens []token.Token) ([]ast.Node, error) {
func macroTokensToAStNodes[T any](ctx context.Context, asm *Assembler[T], tokens []token.Token) ([]ast.Node, error) {
// convert the adjusted tokens to AST nodes
par := parser.NewWithTokens(asm.cfg.Arch, tokens)
astNodes, err := par.TokensToAstNodes()
Expand All @@ -89,16 +90,17 @@ func macroTokensToAStNodes[T any](_ context.Context, asm *Assembler[T], tokens [
}

p := &parseAST[T]{
cfg: asm.cfg,
fileReader: asm.fileReader,
currentScope: asm.fileScope,
segments: map[string]*segment{},
cfg: asm.cfg,
fileReader: asm.fileReader,
includeActive: set.New[string](),
currentScope: asm.fileScope,
segments: map[string]*segment{},
}

// process the AST nodes
var result []ast.Node
for _, node := range astNodes {
nodes, err := parseASTNode(p, node)
nodes, err := parseASTNode(ctx, p, node)
if err != nil {
return nil, err
}
Expand Down
Loading