openvla驱动mujoco的franka机械臂运动

  1. openvla-7b 模型需要自行提前下载,大小约 15GB,下载后把代码中的 MODEL_PATH 指向模型所在目录;如果无法访问 Hugging Face,可参考《不会科学上网自己从魔搭下载openvla-7b》,改从魔搭(ModelScope)下载。
  2. MuJoCo 使用的 fr3.xml 模型文件来自 Fr3py 项目,可通过加速 git 仓库获取:将整个 mujoco 目录复制到 franka.py 所在目录下(代码读取的是 mujoco/fr3.xml)。
  3. 第一次运行 MuJoCo 渲染时通常会遇到不少报错,请参考内部文档《mujoco安装问题处理》自行解决。

运行如下代码,即可用 OpenVLA 直接驱动基于 MuJoCo 的 Franka 机械臂运动。相机仿真尚未完成,后续再补充:目前 GUI 模式下输入给模型的是随机噪声图像,只有离屏模式(--no_gui)才会渲染 track 相机画面。

franka.py
#!/usr/bin/env python3
 
import os
import sys
import time
import numpy as np
from PIL import Image
 
os.environ['MUJOCO_GL'] = 'egl'
 
import mujoco
import mujoco.viewer
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
 
# MJCF scene, resolved relative to this script's real location (abspath makes it
# independent of the current working directory).
XML_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "mujoco", "fr3.xml")
# Local OpenVLA-7B checkpoint directory; set OPENVLA_MODEL_PATH to override the
# author's default without editing the source.
MODEL_PATH = os.environ.get("OPENVLA_MODEL_PATH", "/home/ctbots/llm/openvla-7b")
 
 
class OpenVLAMuJoCoController:
    """Drive the first 7 actuators of a MuJoCo Franka model with OpenVLA predictions.

    Each policy query grabs an image, asks OpenVLA for a 7-value action, and
    writes ``qpos[:7] + scaled_action`` (clipped to joint limits) into
    ``data.ctrl[:7]``. The arm is periodically re-seeded near ``initial_qpos``.

    NOTE(review): this assumes ``ctrl[:7]`` are position targets for the 7 arm
    joints, and that ``qpos[:7]`` / ``qvel[:7]`` belong to those joints. Confirm
    against fr3.xml. OpenVLA's ``bridge_orig`` actions are presumably
    end-effector deltas plus a gripper value, not joint deltas. Applying them in
    joint space makes the motion illustrative rather than task-meaningful.
    """

    def __init__(self, use_gui: bool = True):
        """Load the MuJoCo scene and the OpenVLA model, and set up rendering.

        Args:
            use_gui: request the interactive viewer. This is only honoured when
                a ``DISPLAY`` is set; otherwise an offscreen renderer is used.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[INFO] 使用设备: {self.device}")

        if torch.cuda.is_available():
            print(f"[INFO] GPU: {torch.cuda.get_device_name(0)}")

        print(f"[INFO] 加载 MuJoCo 模型: {XML_PATH}")
        self.model = mujoco.MjModel.from_xml_path(XML_PATH)
        self.data = mujoco.MjData(self.model)
        print("[SUCCESS] MuJoCo 模型加载成功")

        print("[INFO] 加载 OpenVLA 模型...")
        self.processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
        self.vla = AutoModelForVision2Seq.from_pretrained(
            MODEL_PATH,
            torch_dtype=self._torch_dtype(),
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).to(self.device)
        print("[SUCCESS] OpenVLA 模型加载成功")

        # Coerce to a real bool; without an X display we fall back to headless mode.
        self.use_gui = bool(use_gui and os.environ.get("DISPLAY"))
        self.renderer = None

        if not self.use_gui:
            print("[INFO] 初始化离屏渲染器...")
            try:
                self.renderer = mujoco.Renderer(self.model, height=224, width=224)
                print("[SUCCESS] 渲染器初始化成功")
            except Exception as e:
                # Best effort: without a renderer get_camera_image() returns noise.
                print(f"[WARNING] 渲染器初始化失败: {e}")

        self.action_scale = 0.05
        self.step_count = 0
        # Loop cadence, in simulation steps. These are attributes so callers can tune them.
        self.policy_interval = 10
        self.log_interval = 50
        self.reset_interval = 100
        self.reset_noise = 0.1

        self.initial_qpos = np.array([0, 0, 0, -1.57079, 0, 1.57079, -0.7853])
        self.data.qpos[:7] = self.initial_qpos
        # Hold the start pose until the first policy query. ctrl otherwise starts
        # at zero, and (given position actuators) the arm would be pulled toward q = 0.
        self.data.ctrl[:7] = self.initial_qpos
        mujoco.mj_forward(self.model, self.data)

    def _torch_dtype(self):
        """Return bfloat16 on CUDA and float32 on CPU, matching how the VLA weights are loaded."""
        return torch.bfloat16 if self.device == "cuda" else torch.float32

    def get_camera_image(self):
        """Return an RGB frame from the "track" camera.

        Without a renderer (GUI mode, or offscreen init failed), return a random
        224x224x3 uint8 image as a placeholder.
        """
        if self.renderer is None:
            return np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
        self.renderer.update_scene(self.data, camera="track")
        return self.renderer.render()

    def predict_action(self, image: np.ndarray, instruction: str):
        """Query OpenVLA for one action and return it as a 7-element scaled numpy array.

        Args:
            image: HxWx3 image array; it is converted to RGB if needed.
            instruction: natural-language task inserted into the OpenVLA prompt.
        """
        pil_image = Image.fromarray(image)
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')

        prompt = f"In: What action should the robot take to {instruction}?\nOut:"

        inputs = self.processor(prompt, pil_image).to(self.device, dtype=self._torch_dtype())

        with torch.no_grad():
            action = self.vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)

        return self._postprocess_action(action)

    def _postprocess_action(self, action) -> np.ndarray:
        """Flatten to exactly 7 values (truncate or zero-pad), clip to [-1, 1] and scale."""
        if isinstance(action, torch.Tensor):
            # .float() first: NumPy has no bfloat16, so .numpy() on a bf16 tensor raises.
            action_np = action.detach().float().cpu().numpy().ravel()
        else:
            action_np = np.asarray(action, dtype=float).ravel()

        action_np = action_np[:7]
        if action_np.size < 7:
            action_np = np.pad(action_np, (0, 7 - action_np.size))

        return np.clip(action_np, -1.0, 1.0) * self.action_scale

    def _clip_to_joint_limits(self, q: np.ndarray) -> np.ndarray:
        """Clip the first 7 joint values to their ranges, leaving unlimited joints untouched."""
        lo = self.model.jnt_range[:7, 0]
        hi = self.model.jnt_range[:7, 1]
        # The range of a joint without limits is not meaningful (MJCF default "0 0").
        # Clipping such a joint would pin it to zero.
        limited = np.asarray(self.model.jnt_limited[:7], dtype=bool)
        return np.where(limited, np.clip(q, lo, hi), q)

    def _apply_policy(self, instruction: str):
        """Run one VLA query and set ctrl[:7] to the clipped target qpos[:7] + action."""
        rgb_image = self.get_camera_image()
        action = self.predict_action(rgb_image, instruction)
        self.data.ctrl[:7] = self._clip_to_joint_limits(self.data.qpos[:7] + action)

    def _reset_pose(self):
        """Teleport the arm to initial_qpos plus Gaussian noise and hold it there."""
        print("[INFO] 重置到初始位置 (周期性运动)")
        qpos = self._clip_to_joint_limits(
            self.initial_qpos + np.random.randn(7) * self.reset_noise
        )
        self.data.qpos[:7] = qpos
        # Clear leftover velocity and retarget ctrl. Otherwise the previous
        # target would immediately drag the arm away from the fresh pose.
        self.data.qvel[:7] = 0.0
        self.data.ctrl[:7] = qpos
        mujoco.mj_forward(self.model, self.data)

    def _log_ee_position(self, step: int):
        """Print the world position of the 'hand' body."""
        ee_pos = self.data.xpos[self.model.body('hand').id]
        print(f"[Step {step:04d}] 末端位置: [{ee_pos[0]:.3f}, {ee_pos[1]:.3f}, {ee_pos[2]:.3f}]")

    def _sim_step(self, step: int, instruction: str):
        """Advance the simulation one step; *step* is the 0-based index within the run.

        The order is: optional periodic reset, then an optional policy query (so
        the first action after a reset is computed from the reset pose), then
        the physics step, then optional logging.
        """
        self.step_count += 1

        if step > 0 and step % self.reset_interval == 0:
            self._reset_pose()

        if step % self.policy_interval == 0:
            self._apply_policy(instruction)

        mujoco.mj_step(self.model, self.data)

        if step % self.log_interval == 0:
            self._log_ee_position(step)

    def run_with_gui(self, instruction: str, num_steps: int = 500):
        """Run the control loop in MuJoCo's passive viewer.

        The loop stops after *num_steps* steps or when the window is closed.
        """
        print(f"\n[INFO] 任务: {instruction}")
        print(f"[INFO] 总步数: {num_steps}")
        print("[INFO] 启动交互式查看器...")
        print("[INFO] 按 Ctrl+C 停止")

        step = 0
        with mujoco.viewer.launch_passive(self.model, self.data) as viewer:
            while viewer.is_running() and step < num_steps:
                self._sim_step(step, instruction)
                viewer.sync()
                step += 1

        print(f"\n[SUCCESS] 完成 {step} 步仿真")

    def run_headless(self, instruction: str, num_steps: int = 500):
        """Run the control loop without a window (offscreen rendering, if available)."""
        print(f"\n[INFO] 任务: {instruction}")
        print(f"[INFO] 总步数: {num_steps}")
        print("[INFO] 离屏模式运行...")

        for step in range(num_steps):
            self._sim_step(step, instruction)

        print(f"\n[SUCCESS] 完成 {num_steps} 步仿真")

    def close(self):
        """Release the offscreen renderer, if any. Safe to call more than once."""
        if self.renderer is not None:
            self.renderer.close()
            self.renderer = None
 
 
def main():
    """Command-line entry point: parse options, build the controller and run it.

    The interactive viewer is used unless ``--no_gui`` is given or the
    controller decided no display is available. The controller is always
    closed on exit, including after Ctrl+C.
    """
    import argparse

    cli = argparse.ArgumentParser(description="OpenVLA + MuJoCo Franka 仿真")
    cli.add_argument("--instruction", type=str, default="pick up the object", help="任务指令")
    cli.add_argument("--steps", type=int, default=500, help="仿真步数")
    cli.add_argument("--no_gui", action="store_true", help="禁用 GUI,使用离屏渲染")
    opts = cli.parse_args()

    controller = OpenVLAMuJoCoController(use_gui=not opts.no_gui)

    try:
        runner = controller.run_with_gui if controller.use_gui else controller.run_headless
        runner(opts.instruction, opts.steps)
    except KeyboardInterrupt:
        print("\n[INFO] 用户中断")
    finally:
        controller.close()


if __name__ == "__main__":
    main()