Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions pkg/frontend/mysql_cmd_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1224,7 +1224,19 @@ func doDeallocate(ses *Session, execCtx *ExecCtx, st *tree.Deallocate) error {
return nil
}

func doReset(_ context.Context, _ *Session, _ *tree.Reset) error {
func doReset(ctx context.Context, ses *Session, st *tree.Reset) error {
if ses == nil || st == nil {
return nil
}
stmtName := string(st.Name)
if stmtName == "" {
return nil
}
preStmt, ok := ses.prepareStmts[stmtName]
if !ok || preStmt == nil {
return nil
}
preStmt.resetBinaryParamState()
return nil
}

Expand Down Expand Up @@ -3382,19 +3394,17 @@ func ExecRequest(ses *Session, execCtx *ExecCtx, req *Request) (resp *Response,
var prepareStmt *PrepareStmt
sql, prepareStmt, err = parseStmtExecute(execCtx.reqCtx, ses, req.GetData().([]byte))
if err != nil {
if prepareStmt != nil {
prepareStmt.clearBinaryParamState(ses.GetProc())
}
return NewGeneralErrorResponse(COM_STMT_EXECUTE, ses.GetTxnHandler().GetServerStatus(), err), nil
}
execCtx.prepareColDef = prepareStmt.ColDefData
err = doComQuery(ses, execCtx, &UserInput{sql: sql, stmtName: prepareStmt.Name, stmt: prepareStmt.PrepareStmt, preparePlan: prepareStmt.PreparePlan, isBinaryProtExecute: true})
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_EXECUTE, ses.GetTxnHandler().GetServerStatus(), err)
}
if prepareStmt.params != nil {
prepareStmt.params.GetNulls().Reset()
for k := range prepareStmt.getFromSendLongData {
delete(prepareStmt.getFromSendLongData, k)
}
}
prepareStmt.clearBinaryParamState(ses.GetProc())
return resp, nil

case COM_STMT_SEND_LONG_DATA:
Expand All @@ -3414,6 +3424,7 @@ func ExecRequest(ses *Session, execCtx *ExecCtx, req *Request) (resp *Response,
preStmt, err = ses.GetPrepareStmt(execCtx.reqCtx, stmtName)
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_CLOSE, ses.GetTxnHandler().GetServerStatus(), err)
return resp, nil
}
prefix := ""
if preStmt.IsCloudNonuser {
Expand All @@ -3435,7 +3446,8 @@ func ExecRequest(ses *Session, execCtx *ExecCtx, req *Request) (resp *Response,
var preStmt *PrepareStmt
preStmt, err = ses.GetPrepareStmt(execCtx.reqCtx, stmtName)
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_CLOSE, ses.GetTxnHandler().GetServerStatus(), err)
resp = NewGeneralErrorResponse(COM_STMT_RESET, ses.GetTxnHandler().GetServerStatus(), err)
return resp, nil
}
prefix := ""
if preStmt.IsCloudNonuser {
Expand Down Expand Up @@ -3487,7 +3499,7 @@ func parseStmtExecute(reqCtx context.Context, ses *Session, data []byte) (string
ses.Debug(reqCtx, "query trace", logutil.QueryField(sql))
err = ses.GetResponser().MysqlRrWr().ParseExecuteData(reqCtx, ses.GetProc(), preStmt, data, pos)
if err != nil {
return "", nil, err
return "", preStmt, err
}
return sql, preStmt, nil
}
Expand Down
148 changes: 148 additions & 0 deletions pkg/frontend/mysql_cmd_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package frontend

import (
"context"
"encoding/binary"
"fmt"
"io"
"sync/atomic"
Expand Down Expand Up @@ -857,6 +858,110 @@ func Test_HandleDeallocate(t *testing.T) {
})
}

func Test_doResetClearsPreparedBinaryState(t *testing.T) {
ctx := context.TODO()
ses := &Session{
prepareStmts: make(map[string]*PrepareStmt),
}
proc := testutil.NewProc(t)
params := vector.NewVec(types.T_text.ToType())
require.NoError(t, vector.AppendBytes(params, []byte("long-data"), false, proc.Mp()))
params.GetNulls().Set(0)

stmtName := "stmt1"
prepareStmt := &PrepareStmt{
Name: stmtName,
proc: proc,
params: params,
getFromSendLongData: map[int]struct{}{0: {}},
}
defer prepareStmt.Close()
ses.prepareStmts[stmtName] = prepareStmt

require.NoError(t, doReset(ctx, ses, tree.NewReset(tree.Identifier(stmtName))))
require.False(t, prepareStmt.params.GetNulls().Any())
require.Empty(t, prepareStmt.getFromSendLongData)
}

func Test_ExecRequestPrepareCommandMissingStmt(t *testing.T) {
ctx := context.TODO()
ctrl := gomock.NewController(t)
defer ctrl.Finish()

for _, tc := range []struct {
name string
cmd CommandType
}{
{name: "close", cmd: COM_STMT_CLOSE},
{name: "reset", cmd: COM_STMT_RESET},
} {
t.Run(tc.name, func(t *testing.T) {
ses := newTestSession(t, ctrl)
ec := newTestExecCtx(ctx, ctrl)
stmtID := uint32(123)
data := make([]byte, 4)
binary.LittleEndian.PutUint32(data, stmtID)

resp, err := ExecRequest(ses, ec, &Request{cmd: tc.cmd, data: data})
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, ErrorResponse, resp.category)
require.Equal(t, int(tc.cmd), resp.cmd)
require.Error(t, resp.GetData().(error))
})
}
}

func Test_ExecRequestStmtExecuteErrorClearsPreparedBinaryState(t *testing.T) {
ctx := context.TODO()
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ses := newTestSession(t, ctrl)
ec := newTestExecCtx(ctx, ctrl)
stmtID := uint32(321)
stmtName := getPrepareStmtName(stmtID)
st := tree.NewPrepareString(tree.Identifier(stmtName), "select ?, ?")
stmts, err := mysql.Parse(ctx, st.Sql, 1)
require.NoError(t, err)
compCtx := plan.NewEmptyCompilerContext()
preparePlan, err := buildPlan(ctx, nil, compCtx, st)
require.NoError(t, err)

proc := ses.GetProc()
params := vector.NewVec(types.T_text.ToType())
require.NoError(t, vector.AppendBytes(params, []byte("leftover"), false, proc.Mp()))
require.NoError(t, vector.AppendBytes(params, []byte("leftover"), false, proc.Mp()))
params.GetNulls().Set(1)

prepareStmt := &PrepareStmt{
Name: stmtName,
PreparePlan: preparePlan,
PrepareStmt: stmts[0],
proc: proc,
params: params,
getFromSendLongData: map[int]struct{}{0: {}},
}
defer prepareStmt.Close()
require.NoError(t, ses.SetPrepareStmt(ctx, stmtName, prepareStmt))

data := make([]byte, 0, 10)
buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, stmtID)
data = append(data, buf...)
data = append(data, 0) // flag
data = append(data, 0, 0, 0, 0) // iteration-count
data = append(data, 0) // null bitmap
data = append(data, 0) // use existing ParamTypes, which are empty

resp, err := ExecRequest(ses, ec, &Request{cmd: COM_STMT_EXECUTE, data: data})
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, ErrorResponse, resp.category)
require.Nil(t, prepareStmt.params)
require.Empty(t, prepareStmt.getFromSendLongData)
}

func Test_CMD_FIELD_LIST(t *testing.T) {
ctx := defines.AttachAccountId(context.TODO(), catalog.System_Account)
convey.Convey("cmd field list", t, func() {
Expand Down Expand Up @@ -1658,6 +1763,49 @@ func Test_unsupportedCommand(t *testing.T) {
assert.Equal(t, "internal error: unsupported command. 0x0", respErr.Error())
}

func Test_ExecRequestStmtExecuteErrorClearsPreparedParamState(t *testing.T) {
ctx := context.TODO()
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ses := newTestSession(t, ctrl)
execCtx := &ExecCtx{
ses: ses,
reqCtx: ctx,
}

st := tree.NewPrepareString(tree.Identifier(getPrepareStmtName(1)), "select ?")
stmts, err := mysql.Parse(ctx, st.Sql, 1)
require.NoError(t, err)

compCtx := plan.NewEmptyCompilerContext()
preparePlan, err := buildPlan(ctx, nil, compCtx, st)
require.NoError(t, err)

prepareStmt := &PrepareStmt{
Name: preparePlan.GetDcl().GetPrepare().GetName(),
PreparePlan: preparePlan,
PrepareStmt: stmts[0],
getFromSendLongData: make(map[int]struct{}),
}
require.NoError(t, ses.SetPrepareStmt(ctx, prepareStmt.Name, prepareStmt))

payload := make([]byte, 4)
binary.LittleEndian.PutUint32(payload, 1)
payload = append(payload, 0) // flag
payload = append(payload, 0, 0, 0, 0) // iteration-count
payload = append(payload, 0) // null bitmap
payload = append(payload, 1) // new param bound flag
payload = append(payload, uint8(defines.MYSQL_TYPE_VAR_STRING), 0)
payload = append(payload, 5, 'a', 'b') // truncated lenenc string

resp, err := ExecRequest(ses, execCtx, &Request{cmd: COM_STMT_EXECUTE, data: payload})
require.NoError(t, err)
require.NotNil(t, resp)
require.Nil(t, prepareStmt.params)
require.Empty(t, prepareStmt.getFromSendLongData)
}

func Test_panic(t *testing.T) {
fault.EnableDomain(fault.DomainFrontend)
defer fault.DisableDomain(fault.DomainFrontend)
Expand Down
Loading
Loading