Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions sgpt/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from sgpt.llm_functions.init_functions import install_functions as inst_funcs
from sgpt.role import DefaultRoles, SystemRole
from sgpt.utils import (
extract_command_from_completion,
get_edited_prompt,
get_fixed_prompt,
get_sgpt_version,
install_shell_integration,
run_command,
Expand Down Expand Up @@ -84,6 +86,10 @@ def main(
False,
help="Open $EDITOR to provide a prompt.",
),
fix: bool = typer.Option(
False,
help="Fix the wrong last command.",
),
cache: bool = typer.Option(
True,
help="Cache completion results.",
Expand Down Expand Up @@ -199,6 +205,9 @@ def main(
if editor:
prompt = get_edited_prompt()

if fix:
prompt = get_fixed_prompt()

role_class = (
DefaultRoles.check_get(shell, describe_shell, code)
if not role
Expand All @@ -218,6 +227,11 @@ def main(
functions=function_schemas,
)

if not prompt:
if not show_chat:
print("Prompt cant be empty. Use `sgpt <prompt>` to get started.")
return

if chat:
full_completion = ChatHandler(chat, role_class, md).handle(
prompt=prompt,
Expand All @@ -238,6 +252,7 @@ def main(
)

while shell and interaction:
command = extract_command_from_completion(full_completion)
option = typer.prompt(
text="[E]xecute, [D]escribe, [A]bort",
type=Choice(("e", "d", "a", "y"), case_sensitive=False),
Expand All @@ -247,10 +262,10 @@ def main(
)
if option in ("e", "y"):
# "y" option is for keeping compatibility with old version.
run_command(full_completion)
run_command(command)
elif option == "d":
DefaultHandler(DefaultRoles.DESCRIBE_SHELL.get_role(), md).handle(
full_completion,
command,
model=model,
temperature=temperature,
top_p=top_p,
Expand Down
31 changes: 31 additions & 0 deletions sgpt/command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pathlib import Path
from click import UsageError
import json


class Command:
def __init__(self, command_path: Path):
self.command_path = command_path
self.command_path.mkdir(parents=True, exist_ok=True)

def get_last_command(self) -> tuple[str, str]:
"""
get the last command and output from the command path
"""
last_command_file = self.command_path / "last_command.json"
if not last_command_file.exists():
raise UsageError("No last command and output found.")
with open(last_command_file, "r", encoding="utf-8") as file:
data = json.load(file)
command = data.get("command", "")
output = data.get("output", "")
return command, output

def set_last_command(self, command: str, output: str) -> None:
"""
set the last command and output to the command path
"""
last_command_file = self.command_path / "last_command.json"
with open(last_command_file, "w", encoding="utf-8") as file:
data = {"command": command, "output": output}
json.dump(data, file, ensure_ascii=False)
2 changes: 2 additions & 0 deletions sgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
FUNCTIONS_PATH = SHELL_GPT_CONFIG_FOLDER / "functions"
CHAT_CACHE_PATH = Path(gettempdir()) / "chat_cache"
CACHE_PATH = Path(gettempdir()) / "cache"
COMMAND_PATH = Path(gettempdir()) / "commands"

# TODO: Refactor ENV variables with SGPT_ prefix.
DEFAULT_CONFIG = {
Expand All @@ -37,6 +38,7 @@
"SHELL_INTERACTION": os.getenv("SHELL_INTERACTION ", "true"),
"OS_NAME": os.getenv("OS_NAME", "auto"),
"SHELL_NAME": os.getenv("SHELL_NAME", "auto"),
"COMMAND_PATH": os.getenv("COMMAND_PATH", str(COMMAND_PATH)),
# New features might add their own config variables here.
}

Expand Down
11 changes: 7 additions & 4 deletions sgpt/handlers/repl_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from rich.rule import Rule

from ..role import DefaultRoles, SystemRole
from ..utils import run_command
from ..utils import run_command, get_fixed_prompt, extract_command_from_completion
from .chat_handler import ChatHandler
from .default_handler import DefaultHandler

Expand All @@ -32,7 +32,7 @@ def handle(self, init_prompt: str, **kwargs: Any) -> None: # type: ignore
if not self.role.name == DefaultRoles.SHELL.value
else (
"Entering shell REPL mode, type [e] to execute commands "
"or [d] to describe the commands, press Ctrl+C to exit."
"or [d] to describe the commands, or [f] to fix the last command, press Ctrl+C to exit."
)
)
typer.secho(info_message, fg="yellow")
Expand All @@ -53,14 +53,17 @@ def handle(self, init_prompt: str, **kwargs: Any) -> None: # type: ignore
if init_prompt:
prompt = f"{init_prompt}\n\n\n{prompt}"
init_prompt = ""
if self.role.name == DefaultRoles.SHELL.value and prompt == "f":
prompt = get_fixed_prompt()
command = extract_command_from_completion(full_completion)
if self.role.name == DefaultRoles.SHELL.value and prompt == "e":
typer.echo()
run_command(full_completion)
run_command(command)
typer.echo()
rich_print(Rule(style="bold magenta"))
elif self.role.name == DefaultRoles.SHELL.value and prompt == "d":
DefaultHandler(
DefaultRoles.DESCRIBE_SHELL.get_role(), self.markdown
).handle(prompt=full_completion, **kwargs)
).handle(prompt=command, **kwargs)
else:
full_completion = super().handle(prompt=prompt, **kwargs)
38 changes: 37 additions & 1 deletion sgpt/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import os
import re
from pathlib import Path
import platform
import shlex
import subprocess
from tempfile import NamedTemporaryFile
from typing import Any, Callable

import typer
from click import BadParameter, UsageError

from sgpt.__version__ import __version__
from sgpt.command import Command
from sgpt.config import cfg
from sgpt.integration import bash_integration, zsh_integration


Expand All @@ -33,6 +38,26 @@ def get_edited_prompt() -> str:
return output


command_helper = Command(command_path=Path(cfg.get("COMMAND_PATH")))


def get_fixed_prompt() -> str:
"""
get the last command and output then return a PROMPT
"""
command, output = command_helper.get_last_command()
return f"The last command `{command}` failed with error:\n{output}\nPlease fix it."


def extract_command_from_completion(completion: str) -> str:
"""
using regex to extract the command from the completion
"""
if match := re.search(r"```(.*sh)?(.*?)```", completion, re.DOTALL):
return match[2].strip()
return completion


def run_command(command: str) -> None:
"""
Runs a command in the user's shell.
Expand All @@ -50,7 +75,18 @@ def run_command(command: str) -> None:
shell = os.environ.get("SHELL", "/bin/sh")
full_command = f"{shell} -c {shlex.quote(command)}"

os.system(full_command)
# os.system(full_command)
process = subprocess.Popen(
args=full_command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
shell=True,
)
output, _ = process.communicate()
print(output)

command_helper.set_last_command(command, output)


def option_callback(func: Callable) -> Callable: # type: ignore
Expand Down
28 changes: 22 additions & 6 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import subprocess
from pathlib import Path
from unittest.mock import patch

Expand Down Expand Up @@ -81,15 +82,22 @@ def test_describe_shell_stdin(completion):
assert "lists" in result.stdout


@patch("os.system")
@patch("subprocess.Popen")
@patch("sgpt.handlers.handler.completion")
def test_shell_run_description(completion, system):
def test_shell_run_description(completion, mock_popen):
mock_popen.return_value.communicate.return_value = ("stdout", None)
completion.side_effect = [mock_comp("echo hello"), mock_comp("prints hello")]
args = {"prompt": "echo hello", "--shell": True}
inputs = "__sgpt__eof__\nd\ne\n"
result = runner.invoke(app, cmd_args(**args), input=inputs)
shell = os.environ.get("SHELL", "/bin/sh")
system.assert_called_once_with(f"{shell} -c 'echo hello'")
mock_popen.assert_called_once_with(
args=f"{shell} -c 'echo hello'",
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
shell=True,
)
assert result.exit_code == 0
assert "echo hello" in result.stdout
assert "prints hello" in result.stdout
Expand Down Expand Up @@ -133,9 +141,10 @@ def test_shell_chat(completion):
# TODO: Shell chat can be recalled without --shell option.


@patch("os.system")
@patch("subprocess.Popen")
@patch("sgpt.handlers.handler.completion")
def test_shell_repl(completion, mock_system):
def test_shell_repl(completion, mock_popen):
mock_popen.return_value.communicate.return_value = ("stdout", None)
completion.side_effect = [mock_comp("ls"), mock_comp("ls | sort")]
role = SystemRole.get(DefaultRoles.SHELL.value)
chat_name = "_test"
Expand All @@ -146,7 +155,14 @@ def test_shell_repl(completion, mock_system):
inputs = ["__sgpt__eof__", "list folder", "sort by name", "e", "exit()"]
result = runner.invoke(app, cmd_args(**args), input="\n".join(inputs))
shell = os.environ.get("SHELL", "/bin/sh")
mock_system.assert_called_once_with(f"{shell} -c 'ls | sort'")

mock_popen.assert_called_once_with(
args=f"{shell} -c 'ls | sort'",
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
shell=True,
)

expected_messages = [
{"role": "system", "content": role.role},
Expand Down