Spaces:
Runtime error
Runtime error
Commit
·
db037f3
1
Parent(s):
41b76bb
Update main.py
Browse files
main.py
CHANGED
|
@@ -67,7 +67,7 @@ def create_position_map(pos, image_size=64, original_width=1024, original_height
|
|
| 67 |
pos_map = torch.zeros((1, image_size, image_size))
|
| 68 |
pos_map[0, y_scaled, x_scaled] = 1.0
|
| 69 |
|
| 70 |
-
return pos_map
|
| 71 |
|
| 72 |
# Serve the index.html file at the root URL
|
| 73 |
@app.get("/")
|
|
@@ -77,7 +77,7 @@ async def get():
|
|
| 77 |
def generate_random_image(width: int, height: int) -> np.ndarray:
|
| 78 |
return np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
| 79 |
|
| 80 |
-
def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
|
| 81 |
pil_image = Image.fromarray(image)
|
| 82 |
#pil_image = Image.open('image_3.png')
|
| 83 |
draw = ImageDraw.Draw(pil_image)
|
|
@@ -95,10 +95,12 @@ def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]])
|
|
| 95 |
y = y * 256 / 640
|
| 96 |
draw.ellipse([x-2, y-2, x+2, y+2], fill=color)
|
| 97 |
|
|
|
|
| 98 |
if prev_x is not None:
|
| 99 |
#prev_x, prev_y = previous_actions[i-1][1]
|
| 100 |
draw.line([prev_x, prev_y, x, y], fill=color, width=1)
|
| 101 |
prev_x, prev_y = x, y
|
|
|
|
| 102 |
|
| 103 |
return np.array(pil_image)
|
| 104 |
|
|
@@ -204,7 +206,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
| 204 |
|
| 205 |
prompt = " ".join(action_descriptions[-8:])
|
| 206 |
|
| 207 |
-
pos_map = create_position_map(parse_action_string(action_descriptions[-1]))
|
| 208 |
|
| 209 |
|
| 210 |
#prompt = ''
|
|
@@ -220,7 +222,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
| 220 |
new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
|
| 221 |
|
| 222 |
# Draw the trace of previous actions
|
| 223 |
-
new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions)
|
| 224 |
|
| 225 |
return new_frame_with_trace, new_frame_denormalized
|
| 226 |
|
|
|
|
| 67 |
pos_map = torch.zeros((1, image_size, image_size))
|
| 68 |
pos_map[0, y_scaled, x_scaled] = 1.0
|
| 69 |
|
| 70 |
+
return pos_map, x_scaled, y_scaled
|
| 71 |
|
| 72 |
# Serve the index.html file at the root URL
|
| 73 |
@app.get("/")
|
|
|
|
| 77 |
def generate_random_image(width: int, height: int) -> np.ndarray:
|
| 78 |
return np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
| 79 |
|
| 80 |
+
def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]], x_scaled=-1, y_scaled=-1) -> np.ndarray:
|
| 81 |
pil_image = Image.fromarray(image)
|
| 82 |
#pil_image = Image.open('image_3.png')
|
| 83 |
draw = ImageDraw.Draw(pil_image)
|
|
|
|
| 95 |
y = y * 256 / 640
|
| 96 |
draw.ellipse([x-2, y-2, x+2, y+2], fill=color)
|
| 97 |
|
| 98 |
+
|
| 99 |
if prev_x is not None:
|
| 100 |
#prev_x, prev_y = previous_actions[i-1][1]
|
| 101 |
draw.line([prev_x, prev_y, x, y], fill=color, width=1)
|
| 102 |
prev_x, prev_y = x, y
|
| 103 |
+
draw.ellipse([x_scaled*4-2, y_scaled*4-2, x_scaled*4+2, y_scaled*4+2], fill=(0, 255, 0))
|
| 104 |
|
| 105 |
return np.array(pil_image)
|
| 106 |
|
|
|
|
| 206 |
|
| 207 |
prompt = " ".join(action_descriptions[-8:])
|
| 208 |
|
| 209 |
+
pos_map, x_scaled, y_scaled = create_position_map(parse_action_string(action_descriptions[-1]))
|
| 210 |
|
| 211 |
|
| 212 |
#prompt = ''
|
|
|
|
| 222 |
new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
|
| 223 |
|
| 224 |
# Draw the trace of previous actions
|
| 225 |
+
new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions, x_scaled, y_scaled)
|
| 226 |
|
| 227 |
return new_frame_with_trace, new_frame_denormalized
|
| 228 |
|