251 lines
9.4 KiB
Python
Executable File
251 lines
9.4 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
#
|
|
# Copyright (C) 2022 Intel Corporation.
|
|
#
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
#
|
|
|
|
import os
|
|
import argparse
|
|
from copy import deepcopy
|
|
|
|
from pipeline import PipelineObject, PipelineStage, PipelineEngine
|
|
|
|
class SchemaTypeSlicer:
|
|
xpath_ns = {
|
|
"xs": "http://www.w3.org/2001/XMLSchema",
|
|
"acrn": "https://projectacrn.org",
|
|
}
|
|
|
|
@classmethod
|
|
def get_node(cls, element, xpath):
|
|
return element.find(xpath, namespaces=cls.xpath_ns)
|
|
|
|
@classmethod
|
|
def get_nodes(cls, element, xpath):
|
|
return element.findall(xpath, namespaces=cls.xpath_ns)
|
|
|
|
def __init__(self, etree):
|
|
self.etree = etree
|
|
|
|
def get_type_definition(self, type_name):
|
|
type_node = self.get_node(self.etree, f"//xs:complexType[@name='{type_name}']")
|
|
if type_node is None:
|
|
type_node = self.get_node(self.etree, f"//xs:simpleType[@name='{type_name}']")
|
|
return type_node
|
|
|
|
def slice_element_list(self, element_list_node, new_nodes):
|
|
sliced = False
|
|
|
|
for element_node in self.get_nodes(element_list_node, "xs:element"):
|
|
if not self.is_element_needed(element_node):
|
|
element_list_node.remove(element_node)
|
|
sliced = True
|
|
continue
|
|
|
|
# For embedded complex type definition, also slice in place. If the sliced type contains no sub-element,
|
|
# remove the element itself, too.
|
|
element_type_node = self.get_node(element_node, "xs:complexType")
|
|
if element_type_node is not None:
|
|
new_sub_nodes = self.slice(element_type_node, in_place=True)
|
|
if len(self.get_nodes(element_type_node, ".//xs:element")) > 0:
|
|
new_nodes.extend(new_sub_nodes)
|
|
else:
|
|
element_list_node.remove(element_node)
|
|
continue
|
|
|
|
# For external type definition, create a copy to slice. If the sliced type contains no sub-element, remove
|
|
# the element itself.
|
|
element_type_name = element_node.get("type")
|
|
if element_type_name:
|
|
element_type_node = self.get_type_definition(element_type_name)
|
|
if element_type_node is not None:
|
|
sliced_type_name = self.get_name_of_slice(element_type_name)
|
|
|
|
# If a sliced type already exists, do not duplicate the effort
|
|
type_node = self.get_type_definition(sliced_type_name)
|
|
if type_node is not None:
|
|
element_node.set("type", sliced_type_name)
|
|
sliced = True
|
|
else:
|
|
new_sub_nodes = self.slice(element_type_node)
|
|
if len(new_sub_nodes) == 0:
|
|
continue
|
|
elif new_sub_nodes[-1].tag.endswith("simpleType") or len(self.get_nodes(new_sub_nodes[-1], ".//xs:element")) > 0:
|
|
new_nodes.extend(new_sub_nodes)
|
|
element_node.set("type", sliced_type_name)
|
|
sliced = True
|
|
else:
|
|
element_list_node.remove(element_node)
|
|
|
|
return sliced
|
|
|
|
def slice_restriction(self, restriction_node):
|
|
sliced = False
|
|
|
|
for restriction in self.get_nodes(restriction_node, "xs:enumeration"):
|
|
if not self.is_element_needed(restriction):
|
|
restriction_node.remove(restriction)
|
|
sliced = True
|
|
|
|
return sliced
|
|
|
|
def slice(self, type_node, in_place=False, force_copy=False):
|
|
new_nodes = []
|
|
sliced = False
|
|
|
|
if in_place:
|
|
new_type_node = type_node
|
|
else:
|
|
new_type_node = deepcopy(type_node)
|
|
type_name = type_node.get("name")
|
|
if type_name != None:
|
|
sliced_type_name = self.get_name_of_slice(type_name)
|
|
new_type_node.set("name", sliced_type_name)
|
|
|
|
element_list_node = self.get_node(new_type_node, "xs:all")
|
|
if element_list_node is not None:
|
|
sliced = self.slice_element_list(element_list_node, new_nodes)
|
|
|
|
restriction_node = self.get_node(new_type_node, "xs:restriction")
|
|
if restriction_node is not None:
|
|
sliced = self.slice_restriction(restriction_node)
|
|
|
|
if not in_place and (sliced or force_copy):
|
|
new_nodes.append(new_type_node)
|
|
return new_nodes
|
|
|
|
def is_element_needed(self, element_node):
|
|
return True
|
|
|
|
def get_name_of_slice(self, name):
|
|
return f"Sliced{name}"
|
|
|
|
class SlicingSchemaByVMTypeStage(PipelineStage):
|
|
uses = {"schema_etree"}
|
|
provides = {"schema_etree"}
|
|
|
|
class VMTypeSlicer(SchemaTypeSlicer):
|
|
def is_element_needed(self, element_node):
|
|
annot_node = self.get_node(element_node, "xs:annotation")
|
|
if annot_node is None:
|
|
return True
|
|
applicable_vms = annot_node.get("{https://projectacrn.org}applicable-vms")
|
|
return applicable_vms is None or applicable_vms.find(self.vm_type_indicator) >= 0
|
|
|
|
def get_name_of_slice(self, name):
|
|
return f"{self.type_prefix}{name}"
|
|
|
|
class PreLaunchedTypeSlicer(VMTypeSlicer):
|
|
vm_type_indicator = "pre-launched"
|
|
type_prefix = "PreLaunched"
|
|
|
|
class ServiceVMTypeSlicer(VMTypeSlicer):
|
|
vm_type_indicator = "service-vm"
|
|
type_prefix = "Service"
|
|
|
|
class PostLaunchedTypeSlicer(VMTypeSlicer):
|
|
vm_type_indicator = "post-launched"
|
|
type_prefix = "PostLaunched"
|
|
|
|
def run(self, obj):
|
|
schema_etree = obj.get("schema_etree")
|
|
|
|
vm_type_name = "VMConfigType"
|
|
vm_type_node = SchemaTypeSlicer.get_node(schema_etree, f"//xs:complexType[@name='{vm_type_name}']")
|
|
slicers = [
|
|
self.PreLaunchedTypeSlicer(schema_etree),
|
|
self.ServiceVMTypeSlicer(schema_etree),
|
|
self.PostLaunchedTypeSlicer(schema_etree)
|
|
]
|
|
|
|
for slicer in slicers:
|
|
new_nodes = slicer.slice(vm_type_node, force_copy=True)
|
|
for n in new_nodes:
|
|
schema_etree.getroot().append(n)
|
|
|
|
for node in SchemaTypeSlicer.get_nodes(schema_etree, "//xs:complexType[@name='ACRNConfigType']//xs:element[@name='vm']//xs:alternative"):
|
|
test = node.get("test")
|
|
if test.find("PRE_LAUNCHED_VM") >= 0:
|
|
node.set("type", slicers[0].get_name_of_slice(vm_type_name))
|
|
elif test.find("SERVICE_VM") >= 0:
|
|
node.set("type", slicers[1].get_name_of_slice(vm_type_name))
|
|
elif test.find("POST_LAUNCHED_VM") >= 0:
|
|
node.set("type", slicers[2].get_name_of_slice(vm_type_name))
|
|
|
|
obj.set("schema_etree", schema_etree)
|
|
|
|
class SlicingSchemaByViewStage(PipelineStage):
|
|
uses = {"schema_etree"}
|
|
provides = {"schema_etree"}
|
|
|
|
class ViewSlicer(SchemaTypeSlicer):
|
|
def is_element_needed(self, element_node):
|
|
annot_node = self.get_node(element_node, "xs:annotation")
|
|
if annot_node is None:
|
|
return True
|
|
views = annot_node.get("{https://projectacrn.org}views")
|
|
return views is None or views.find(self.view_indicator) >= 0
|
|
|
|
def get_name_of_slice(self, name):
|
|
if name.find("ConfigType") >= 0:
|
|
return name.replace("ConfigType", f"{self.type_prefix}ConfigType")
|
|
else:
|
|
return f"{self.type_prefix}{name}"
|
|
|
|
class BasicViewSlicer(ViewSlicer):
|
|
view_indicator = "basic"
|
|
type_prefix = "Basic"
|
|
|
|
class AdvancedViewSlicer(ViewSlicer):
|
|
view_indicator = "advanced"
|
|
type_prefix = "Advanced"
|
|
|
|
def run(self, obj):
|
|
schema_etree = obj.get("schema_etree")
|
|
|
|
type_nodes = list(filter(lambda x: x.get("name") and x.get("name").endswith("VMConfigType"), SchemaTypeSlicer.get_nodes(schema_etree, "//xs:complexType")))
|
|
type_nodes.append(SchemaTypeSlicer.get_node(schema_etree, "//xs:complexType[@name = 'HVConfigType']"))
|
|
|
|
slicers = [
|
|
self.BasicViewSlicer(schema_etree),
|
|
self.AdvancedViewSlicer(schema_etree),
|
|
]
|
|
|
|
for slicer in slicers:
|
|
for type_node in type_nodes:
|
|
new_nodes = slicer.slice(type_node, force_copy=True)
|
|
for n in new_nodes:
|
|
schema_etree.getroot().append(n)
|
|
|
|
obj.set("schema_etree", schema_etree)
|
|
|
|
def main(args):
|
|
from lxml_loader import LXMLLoadStage
|
|
|
|
pipeline = PipelineEngine(["schema_path"])
|
|
pipeline.add_stages([
|
|
LXMLLoadStage("schema"),
|
|
SlicingSchemaByVMTypeStage(),
|
|
SlicingSchemaByViewStage(),
|
|
])
|
|
|
|
obj = PipelineObject(schema_path = args.schema)
|
|
pipeline.run(obj)
|
|
obj.get("schema_etree").write(args.out)
|
|
|
|
print(f"Sliced schema written to {args.out}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# abs __file__ path to ignore `__file__ == 'schema_slicer.py'` issue
|
|
config_tools_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
|
|
schema_dir = os.path.join(config_tools_dir, "schema")
|
|
|
|
parser = argparse.ArgumentParser(description="Slice a given scenario schema by VM types and views")
|
|
parser.add_argument("out", nargs="?", default=os.path.join(schema_dir, "sliced.xsd"), help="Path where the output is placed")
|
|
parser.add_argument("--schema", default=os.path.join(schema_dir, "config.xsd"), help="the XML schema that defines the syntax of scenario XMLs")
|
|
args = parser.parse_args()
|
|
|
|
main(args)
|