Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle WDL structs #97

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions wdl2cwl/WdlV1_1ParserVisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,25 @@
from .WdlV1_1Parser import WdlV1_1Parser
else:
from WdlV1_1Parser import WdlV1_1Parser
from typing import List

# This class defines a complete generic visitor for a parse tree produced by WdlV1_1Parser.

class Struct:
"""A class to represent WDL's Struct's."""

def __init__(self, name: str):
""":param name: struct name."""
self.name = name
self.fields: List[str] = []


class WdlV1_1ParserVisitor(ParseTreeVisitor):
"""WDL parser AST visitor."""

def __init__(self):
"""Create a visitor."""
self.structs = []
self.task_inputs = []
self.task_inputs_bound = []
self.task_outputs = []
Expand Down Expand Up @@ -314,6 +327,11 @@ def visitImport_doc(self, ctx:WdlV1_1Parser.Import_docContext):

# Visit a parse tree produced by WdlV1_1Parser#struct.
def visitStruct(self, ctx:WdlV1_1Parser.StructContext):
"""Build a WDL struct to be used later when converting to CWL input/output records."""
struct = Struct(name=str(ctx.Identifier()))
for decl in ctx.unbound_decls():
struct.fields.append([str(decl.Identifier()), decl.wdl_type().getText()])
self.structs.append(struct)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the core of this change depends on saving the structs when visiting the WDL abstract tree. An alternative would be to fetch them when/if necessary from the tree (not from ast), but I opted for this approach. Happy to change if needed.

return self.visitChildren(ctx)


Expand Down
153 changes: 92 additions & 61 deletions wdl2cwl/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
"""Main entrypoint for WDL2CWL."""
import argparse
from argparse import Namespace
import sys
from io import StringIO
from typing import List, cast, Union
from typing import List, cast, Union, Dict
from io import StringIO
import textwrap
import re
Expand All @@ -17,7 +15,7 @@

from wdl2cwl.WdlV1_1Lexer import WdlV1_1Lexer
from wdl2cwl.WdlV1_1Parser import WdlV1_1Parser
from wdl2cwl.WdlV1_1ParserVisitor import WdlV1_1ParserVisitor
from wdl2cwl.WdlV1_1ParserVisitor import WdlV1_1ParserVisitor, Struct

# WDL-CWL Type Mappings
wdl_type = {
Expand Down Expand Up @@ -396,7 +394,7 @@ def get_command(
return new_command


def get_output(expression: str, input_names: List[str]) -> str:
def get_expression_output(expression: str, input_names: List[str]) -> str:
"""Get expression for outputs."""
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was confused at first as I thought get_output was to get_input, but apparently it did something slightly different (I am not familiar with expressions for outputs as in the docs). I thought it would be OK to rename it and create a get_output that does something similar to get_input?

output_value = ""

Expand Down Expand Up @@ -471,6 +469,7 @@ def get_input(
inputs: List[cwl.CommandInputParameter],
unbound_input: List[str],
bound_input: List[str],
record_schemas: Dict[str, cwl.RecordSchema],
) -> List[cwl.CommandInputParameter]:
"""Get bound and unbound inputs."""
for i in unbound_input:
Expand All @@ -491,6 +490,11 @@ def get_input(
)
)

elif i[0] in record_schemas.keys():
inputs.append(
cwl.CommandInputParameter(id=input_name, type=record_schemas.get(i[0]))
)

else:
input_type = (
wdl_type[i[0]]
Expand Down Expand Up @@ -533,6 +537,75 @@ def get_input(
return inputs


def get_output(
outputs: List[cwl.CommandOutputParameter],
input_names: List[str],
task_outputs: List[List[str]],
record_schemas: Dict[str, cwl.RecordSchema],
) -> List[cwl.CommandOutputParameter]:
"""Get outputs."""
for i in task_outputs:
output_name = i[1]
output_glob = get_expression_output(i[2], input_names)

if "Array" in i[0]:
# output_type = ""
temp_type = wdl_type[
i[0][i[0].find("[") + 1 : i[0].find("]")].replace('"', "")
]
output_type = temp_type if "?" not in i[0] else [temp_type, "null"]
outputs.append(
cwl.CommandOutputParameter(
id=output_name,
type=[
cwl.CommandOutputArraySchema(items=output_type, type="array")
],
outputBinding=cwl.CommandOutputBinding(glob=output_glob),
)
)
elif "read_string(" in i[2]:
outputs.append(
cwl.CommandOutputParameter(
id=output_name,
type="string",
outputBinding=cwl.CommandOutputBinding(
glob=output_glob.replace("read_string(", "")[:-1],
loadContents=True,
outputEval=r"$(self[0].contents.replace(/[\r\n]+$/, ''))",
),
)
)
elif i[2] == "stdout()":
outputs.append(
cwl.CommandOutputParameter(
id=output_name,
type="stdout",
)
)
elif i[0] in record_schemas.keys():
outputs.append(
cwl.CommandOutputParameter(
id=output_name, type=record_schemas.get(i[0])
)
)
else:
output_type = (
wdl_type[i[0]]
if "?" not in i[0]
else [wdl_type[i[0].replace("?", "")], "null"]
)

outputs.append(
cwl.CommandOutputParameter(
id=output_name,
type=output_type,
outputBinding=cwl.CommandOutputBinding(glob=output_glob),
)
)

return outputs

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ☝️ came from the convert method.


def convert(workflow: str) -> str:
"""Generate a CWL object to match "cat-tool.cwl"."""
f = open(workflow)
Expand All @@ -546,10 +619,18 @@ def convert(workflow: str) -> str:
ast = WdlV1_1ParserVisitor() # type: ignore
ast.walk_tree(tree) # type: ignore

record_schemas: Dict[str, cwl.RecordSchema] = {}
input_types: List[str] = []
input_names: List[str] = []
input_values: List[str] = []

for s in ast.structs:
schema = cwl.RecordSchema(
type="record",
fields=[cwl.RecordField(name=f[0], type=f[1]) for f in s.fields],
)
record_schemas[s.name] = schema

for i in ast.task_inputs:
input_types.append(i[0])
input_names.append(i[1])
Expand Down Expand Up @@ -587,7 +668,9 @@ def convert(workflow: str) -> str:
base_command = ["bash", "example.sh"]

cwl_inputs: List[cwl.CommandInputParameter] = []
cwl_inputs = get_input(cwl_inputs, ast.task_inputs, ast.task_inputs_bound)
cwl_inputs = get_input(
cwl_inputs, ast.task_inputs, ast.task_inputs_bound, record_schemas
)

requirements: List[cwl.ProcessRequirement] = []

Expand Down Expand Up @@ -690,60 +773,8 @@ def convert(workflow: str) -> str:
)
)

outputs = []

for i in ast.task_outputs:
output_name = i[1]
output_glob = get_output(i[2], input_names)

if "Array" in i[0]:
# output_type = ""
temp_type = wdl_type[
i[0][i[0].find("[") + 1 : i[0].find("]")].replace('"', "")
]
output_type = temp_type if "?" not in i[0] else [temp_type, "null"]
outputs.append(
cwl.CommandOutputParameter(
id=output_name,
type=[
cwl.CommandOutputArraySchema(items=output_type, type="array")
],
outputBinding=cwl.CommandOutputBinding(glob=output_glob),
)
)
elif "read_string(" in i[2]:
outputs.append(
cwl.CommandOutputParameter(
id=output_name,
type="string",
outputBinding=cwl.CommandOutputBinding(
glob=output_glob.replace("read_string(", "")[:-1],
loadContents=True,
outputEval=r"$(self[0].contents.replace(/[\r\n]+$/, ''))",
),
)
)
elif i[2] == "stdout()":
outputs.append(
cwl.CommandOutputParameter(
id=output_name,
type="stdout",
)
)
else:
output_type = (
wdl_type[i[0]]
if "?" not in i[0]
else [wdl_type[i[0].replace("?", "")], "null"]
)

outputs.append(
cwl.CommandOutputParameter(
id=output_name,
type=output_type,
outputBinding=cwl.CommandOutputBinding(glob=output_glob),
)
)
outputs: List[cwl.CommandOutputParameter] = []
outputs = get_output(outputs, input_names, ast.task_outputs, record_schemas)

cat_tool = cwl.CommandLineTool(
id=ast.task_name,
Expand Down Expand Up @@ -780,7 +811,7 @@ def convert(workflow: str) -> str:
result_stream = StringIO()
cwl_result = cat_tool.save()
scalarstring.walk_tree(cwl_result)
# ^ converts multine line strings to nice multiline YAML
# ^ converts multiline strings to nice multiline YAML
yaml.dump(cwl_result, result_stream)
yaml.dump(cwl_result, sys.stdout)

Expand Down
Loading