vet/pkg/code/languages/python.go
abhisek e6f6288701
feat: Code analysis framework infra
feat: Building code graph

Refactor to support import processing

Handle relative import name fixup

Add docs for code analysis framework

Update docs to include additional examples

feat: Function call graph

Update code graph to link function decl and calls

Include call node in function calls

feat: Flatten vulnerabilities in CSV reporter

refactor: Maintain separation of concerns for code analysis framework

refactor: Separate storage entities in its own package

feat: Add callback support in code graph builder

docs: Fix code analysis framework docs
Signed-off-by: abhisek <abhisek.datta@gmail.com>
2024-07-11 15:09:11 +05:30

234 lines
6.4 KiB
Go

package languages
import (
"fmt"
"path/filepath"
"strings"
"github.com/safedep/vet/pkg/code"
"github.com/safedep/vet/pkg/code/nodes"
"github.com/safedep/vet/pkg/common/logger"
sitter "github.com/smacker/go-tree-sitter"
"github.com/smacker/go-tree-sitter/python"
)
type pythonSourceLanguage struct {
commonSourceLanguage
}
func NewPythonSourceLanguage() (code.SourceLanguage, error) {
parser := sitter.NewParser()
parser.SetLanguage(python.GetLanguage())
return &pythonSourceLanguage{
commonSourceLanguage: commonSourceLanguage{
parser: parser,
lang: python.GetLanguage(),
},
}, nil
}
func (l *pythonSourceLanguage) GetMeta() code.SourceLanguageMeta {
return code.SourceLanguageMeta{
SourceFileGlobs: []string{"*.py"},
}
}
func (l *pythonSourceLanguage) GetImportNodes(cst *nodes.CST) ([]nodes.CSTImportNode, error) {
// Tree Sitter query to get import nodes in Python
// The order of capture names are important because the extraction
// uses capture index
query := `
(import_statement
name: ((dotted_name) @module_name))
(import_from_statement
module_name: (dotted_name) @module_name
name: (dotted_name
(identifier) @submodule_name @submodule_alias))
(import_from_statement
module_name: (relative_import) @module_name
name: (dotted_name
(identifier) @submodule_name @submodule_alias))
(import_statement
name: (aliased_import
name: ((dotted_name) @module_name @submodule_name)
alias: (identifier) @submodule_alias))
(import_from_statement
module_name: (dotted_name) @module_name
name: (aliased_import
name: (dotted_name
(identifier) @submodule_name)
alias: (identifier) @submodule_alias))
(import_from_statement
module_name: (relative_import) @module_name
name: (aliased_import
name: ((dotted_name) @submodule_name)
alias: (identifier) @submodule_alias))
`
importNodes := []nodes.CSTImportNode{}
err := code.TSExecQuery(query, python.GetLanguage(),
cst.Code(),
cst.Root(),
func(m *sitter.QueryMatch, _ *sitter.Query, ok bool) error {
importNode := nodes.NewCSTImportNode(cst).
WithModuleName(m.Captures[0].Node)
if len(m.Captures) > 1 {
importNode = importNode.WithModuleItem(m.Captures[1].Node)
}
if len(m.Captures) > 2 {
importNode = importNode.WithModuleAlias(m.Captures[2].Node)
}
logger.Debugf("Found imported module: %s (%s) as (%s)",
importNode.ImportName(),
importNode.ImportItem(),
importNode.ImportAlias())
importNodes = append(importNodes, *importNode)
return nil
})
return importNodes, err
}
func (l *pythonSourceLanguage) GetFunctionDeclarationNodes(cst *nodes.CST) ([]nodes.CSTFunctionNode, error) {
query := `
(function_definition
name: (identifier) @function_name
parameters: (parameters) @function_args
body: (block) @function_body) @function_declaration
`
functionNodes := []nodes.CSTFunctionNode{}
err := code.TSExecQuery(query, python.GetLanguage(),
cst.Code(),
cst.Root(),
func(m *sitter.QueryMatch, _ *sitter.Query, ok bool) error {
if len(m.Captures) != 4 {
return fmt.Errorf("expected 4 captures, got %d", len(m.Captures))
}
functionNode := nodes.NewCSTFunctionNode(cst).
WithDeclaration(m.Captures[0].Node).
WithName(m.Captures[1].Node).
WithArgs(m.Captures[2].Node).
WithBody(m.Captures[3].Node)
// Find the containing class if any by walking up the tree
parent := functionNode.Declaration().Parent()
for parent != nil {
if parent.Type() == "class_definition" {
functionNode = functionNode.WithContainer(parent.ChildByFieldName("name"))
break
}
parent = parent.Parent()
}
logger.Debugf("Found function declaration: %s/%s",
functionNode.Container(),
functionNode.Name())
functionNodes = append(functionNodes, *functionNode)
return nil
})
return functionNodes, err
}
func (l *pythonSourceLanguage) GetFunctionCallNodes(cst *nodes.CST) ([]nodes.CSTFunctionCallNode, error) {
query := `
(call
function: (identifier) @function_name
arguments: (argument_list) @arguments) @function_call
(call
function: (attribute
object: (identifier) @object
attribute: (identifier) @function_name)
arguments: (argument_list) @arguments) @function_call
`
callNodes := []nodes.CSTFunctionCallNode{}
err := code.TSExecQuery(query, python.GetLanguage(),
cst.Code(),
cst.Root(),
func(m *sitter.QueryMatch, q *sitter.Query, ok bool) error {
if len(m.Captures) < 3 {
return fmt.Errorf("expected at least 3 captures, got %d", len(m.Captures))
}
functionCallNode := nodes.NewCSTFunctionCallNode(cst).WithCall(m.Captures[0].Node)
n := len(m.Captures)
switch n {
case 3:
functionCallNode = functionCallNode.WithCallee(m.Captures[1].Node)
functionCallNode = functionCallNode.WithArgs(m.Captures[2].Node)
case 4:
functionCallNode = functionCallNode.WithReceiver(m.Captures[1].Node)
functionCallNode = functionCallNode.WithCallee(m.Captures[2].Node)
functionCallNode = functionCallNode.WithArgs(m.Captures[3].Node)
}
logger.Debugf("Found function call: (%s).%s",
functionCallNode.Receiver(),
functionCallNode.Callee())
callNodes = append(callNodes, *functionCallNode)
return nil
})
return callNodes, err
}
func (l *pythonSourceLanguage) ResolveImportNameFromPath(relPath string) (string, error) {
if relPath[0] == '/' {
return "", fmt.Errorf("path is not relative: %s", relPath)
}
relPath = strings.TrimSuffix(relPath, "__init__.py")
relPath = strings.TrimSuffix(relPath, "/")
relPath = strings.TrimSuffix(relPath, ".py")
relPath = strings.ReplaceAll(relPath, "/", ".")
return relPath, nil
}
func (l *pythonSourceLanguage) ResolveImportPathsFromName(currentFile code.SourceFile,
importName string, includeImports bool) ([]string, error) {
paths := []string{}
if len(importName) == 0 {
return paths, fmt.Errorf("import name is empty")
}
// If its a relative import, resolve it to the root
if importName[0] == '.' {
currDir := filepath.Dir(currentFile.Path)
relativeImportName := filepath.Join(currDir, importName[1:])
rootRelativePath, err := currentFile.Repository().GetRelativePath(relativeImportName, includeImports)
if err != nil {
return paths, fmt.Errorf("failed to get relative path: %w", err)
}
importName = rootRelativePath
}
importName = strings.ReplaceAll(importName, ".", "/")
paths = append(paths, importName+".py")
paths = append(paths, importName+"/__init__.py")
return paths, nil
}