acrn-hypervisor/misc/config_tools/scenario_config/schema_slicer.py

251 lines
9.4 KiB
Python
Executable File

#!/usr/bin/env python3
#
# Copyright (C) 2022 Intel Corporation.
#
# SPDX-License-Identifier: BSD-3-Clause
#
import os
import argparse
from copy import deepcopy
from pipeline import PipelineObject, PipelineStage, PipelineEngine
class SchemaTypeSlicer:
    """Derive trimmed-down copies ("slices") of XML schema type definitions.

    Subclasses decide which ``xs:element`` / ``xs:enumeration`` nodes are kept by overriding
    ``is_element_needed`` and how sliced types are named by overriding ``get_name_of_slice``.
    Types referenced by kept elements are sliced recursively. The newly created type definitions
    are returned to the caller, which is responsible for adding them to the schema.
    """

    # Namespace prefixes understood by every XPath expression used with get_node/get_nodes.
    xpath_ns = {
        "xs": "http://www.w3.org/2001/XMLSchema",
        "acrn": "https://projectacrn.org",
    }

    @classmethod
    def get_node(cls, element, xpath):
        """Return the first node matching ``xpath`` under ``element``, or None."""
        return element.find(xpath, namespaces=cls.xpath_ns)

    @classmethod
    def get_nodes(cls, element, xpath):
        """Return a list of all nodes matching ``xpath`` under ``element``."""
        return element.findall(xpath, namespaces=cls.xpath_ns)

    def __init__(self, etree):
        # Schema tree in which referenced types and already-created slices are looked up.
        self.etree = etree

    def get_type_definition(self, type_name):
        """Return the complexType or simpleType named ``type_name`` in the schema, or None."""
        # ".//" rather than "//": an absolute path on an ElementTree is evaluated as ".//" anyway
        # but emits a FutureWarning, and it is rejected outright when self.etree is an Element.
        type_node = self.get_node(self.etree, f".//xs:complexType[@name='{type_name}']")
        if type_node is None:
            type_node = self.get_node(self.etree, f".//xs:simpleType[@name='{type_name}']")
        return type_node

    def slice_element_list(self, element_list_node, new_nodes):
        """Slice the ``xs:element`` children of ``element_list_node`` in place.

        Elements that are not needed are removed. Elements with an embedded complexType have that
        type sliced in place. Elements referring to a type defined in the schema are switched to
        a sliced copy of that type. In both of the latter cases, an element whose sliced type has
        no sub-element left is removed.

        :param element_list_node: the ``xs:all`` node whose children are sliced.
        :param new_nodes: list to which newly created type definitions are appended.
        :return: True if ``element_list_node`` or anything below it was modified.
        """
        sliced = False

        for element_node in self.get_nodes(element_list_node, "xs:element"):
            if not self.is_element_needed(element_node):
                element_list_node.remove(element_node)
                sliced = True
                continue

            # For embedded complex type definition, also slice in place. If the sliced type contains no sub-element,
            # remove the element itself, too.
            element_type_node = self.get_node(element_node, "xs:complexType")
            if element_type_node is not None:
                new_sub_nodes, sub_sliced = self._slice(element_type_node, in_place=True)
                if self.get_nodes(element_type_node, ".//xs:element"):
                    new_nodes.extend(new_sub_nodes)
                    # The in-place changes live inside this list, so they count as slicing it.
                    sliced = sliced or sub_sliced
                else:
                    element_list_node.remove(element_node)
                    sliced = True
                continue

            # For external type definition, create a copy to slice. If the sliced type contains no sub-element, remove
            # the element itself.
            element_type_name = element_node.get("type")
            if not element_type_name:
                continue
            element_type_node = self.get_type_definition(element_type_name)
            if element_type_node is None:
                # Types not defined in this schema (e.g. built-in xs:* types) are left alone.
                continue

            sliced_type_name = self.get_name_of_slice(element_type_name)
            # If a sliced type already exists, do not duplicate the effort
            if self.get_type_definition(sliced_type_name) is not None:
                element_node.set("type", sliced_type_name)
                sliced = True
                continue

            new_sub_nodes = self.slice(element_type_node)
            if not new_sub_nodes:
                # Nothing was sliced out of the referenced type; keep referring to the original.
                continue
            # slice() appends the sliced copy of the type itself as the last entry.
            sliced_type_node = new_sub_nodes[-1]
            if sliced_type_node.tag.endswith("simpleType") or self.get_nodes(sliced_type_node, ".//xs:element"):
                new_nodes.extend(new_sub_nodes)
                element_node.set("type", sliced_type_name)
            else:
                element_list_node.remove(element_node)
            # Either the element now uses a sliced type or it was removed: the list has changed.
            sliced = True

        return sliced

    def slice_restriction(self, restriction_node):
        """Remove unneeded ``xs:enumeration`` facets from ``restriction_node``.

        :return: True if at least one enumeration was removed.
        """
        sliced = False
        for enumeration in self.get_nodes(restriction_node, "xs:enumeration"):
            if not self.is_element_needed(enumeration):
                restriction_node.remove(enumeration)
                sliced = True
        return sliced

    def slice(self, type_node, in_place=False, force_copy=False):
        """Slice the type definition ``type_node``.

        :param type_node: an ``xs:complexType`` or ``xs:simpleType`` node.
        :param in_place: modify ``type_node`` itself instead of a renamed deep copy.
        :param force_copy: return the renamed copy even if nothing was sliced out of it.
        :return: the list of new type definitions created while slicing. Unless ``in_place`` is
            set, the renamed copy of ``type_node`` is the last entry. That copy is omitted when
            nothing changed and ``force_copy`` is false.
        """
        new_nodes, _ = self._slice(type_node, in_place=in_place, force_copy=force_copy)
        return new_nodes

    def _slice(self, type_node, in_place=False, force_copy=False):
        """Implementation of slice() that additionally reports whether anything was modified."""
        new_nodes = []
        if in_place:
            new_type_node = type_node
        else:
            new_type_node = deepcopy(type_node)
            type_name = type_node.get("name")
            if type_name is not None:
                new_type_node.set("name", self.get_name_of_slice(type_name))

        sliced = False
        element_list_node = self.get_node(new_type_node, "xs:all")
        if element_list_node is not None:
            # Call first, then combine, so short-circuiting never skips the slicing itself.
            list_sliced = self.slice_element_list(element_list_node, new_nodes)
            sliced = sliced or list_sliced
        restriction_node = self.get_node(new_type_node, "xs:restriction")
        if restriction_node is not None:
            restriction_sliced = self.slice_restriction(restriction_node)
            sliced = sliced or restriction_sliced

        if not in_place and (sliced or force_copy):
            new_nodes.append(new_type_node)
        return new_nodes, sliced

    def is_element_needed(self, element_node):
        """Return True if ``element_node`` should be kept; the base class keeps everything."""
        return True

    def get_name_of_slice(self, name):
        """Return the name given to the sliced copy of the type called ``name``."""
        return f"Sliced{name}"
class SlicingSchemaByVMTypeStage(PipelineStage):
    """Pipeline stage that derives one VM configuration type per VM type.

    ``VMConfigType`` is sliced for pre-launched, Service and post-launched VMs. The resulting
    types are appended to the schema root. The ``xs:alternative`` entries of the ``vm`` element
    in ``ACRNConfigType`` are then pointed at the matching sliced type.
    """

    uses = {"schema_etree"}
    provides = {"schema_etree"}

    class VMTypeSlicer(SchemaTypeSlicer):
        """Keep only the schema nodes applicable to one VM type.

        Subclasses set ``vm_type_indicator``, the substring searched for in the
        ``acrn:applicable-vms`` annotation attribute. They also set ``type_prefix``, which is
        prepended to the names of sliced types.
        """

        def is_element_needed(self, element_node):
            annot_node = self.get_node(element_node, "xs:annotation")
            if annot_node is None:
                return True
            applicable_vms = annot_node.get("{https://projectacrn.org}applicable-vms")
            # Without an acrn:applicable-vms attribute the node is kept for every VM type.
            return applicable_vms is None or self.vm_type_indicator in applicable_vms

        def get_name_of_slice(self, name):
            return f"{self.type_prefix}{name}"

    class PreLaunchedTypeSlicer(VMTypeSlicer):
        vm_type_indicator = "pre-launched"
        type_prefix = "PreLaunched"

    class ServiceVMTypeSlicer(VMTypeSlicer):
        vm_type_indicator = "service-vm"
        type_prefix = "Service"

    class PostLaunchedTypeSlicer(VMTypeSlicer):
        vm_type_indicator = "post-launched"
        type_prefix = "PostLaunched"

    def run(self, obj):
        """Slice VMConfigType per VM type and retarget the vm element's type alternatives."""
        schema_etree = obj.get("schema_etree")

        vm_type_name = "VMConfigType"
        # ".//" rather than "//": an absolute path on an ElementTree is evaluated as ".//" anyway
        # but emits a FutureWarning.
        vm_type_node = SchemaTypeSlicer.get_node(schema_etree, f".//xs:complexType[@name='{vm_type_name}']")
        if vm_type_node is None:
            raise ValueError(f"The schema does not define the complex type '{vm_type_name}'")

        # Ordered (marker in the alternative's test expression, slicer) pairs; the markers are
        # checked in this order, as before.
        slicers = [
            ("PRE_LAUNCHED_VM", self.PreLaunchedTypeSlicer(schema_etree)),
            ("SERVICE_VM", self.ServiceVMTypeSlicer(schema_etree)),
            ("POST_LAUNCHED_VM", self.PostLaunchedTypeSlicer(schema_etree)),
        ]

        for _, slicer in slicers:
            for new_node in slicer.slice(vm_type_node, force_copy=True):
                schema_etree.getroot().append(new_node)

        alternatives_xpath = ".//xs:complexType[@name='ACRNConfigType']//xs:element[@name='vm']//xs:alternative"
        for node in SchemaTypeSlicer.get_nodes(schema_etree, alternatives_xpath):
            # An alternative without a test attribute matches no VM type and is left unchanged.
            test = node.get("test") or ""
            for marker, slicer in slicers:
                if marker in test:
                    node.set("type", slicer.get_name_of_slice(vm_type_name))
                    break

        obj.set("schema_etree", schema_etree)
class SlicingSchemaByViewStage(PipelineStage):
    """Pipeline stage that derives a basic and an advanced view of the VM and hypervisor types.

    Every complex type whose name ends with ``VMConfigType``, plus ``HVConfigType``, is sliced once
    per view. The sliced types are appended to the schema root.
    """

    uses = {"schema_etree"}
    provides = {"schema_etree"}

    class ViewSlicer(SchemaTypeSlicer):
        """Keep only the schema nodes that belong to one view.

        Subclasses set ``view_indicator``, the substring searched for in the ``acrn:views``
        annotation attribute. They also set ``type_prefix``, which is used to name sliced types.
        """

        def is_element_needed(self, element_node):
            annot_node = self.get_node(element_node, "xs:annotation")
            if annot_node is None:
                return True
            views = annot_node.get("{https://projectacrn.org}views")
            # Without an acrn:views attribute the node is kept in every view.
            return views is None or self.view_indicator in views

        def get_name_of_slice(self, name):
            # e.g. "HVConfigType" -> "HVBasicConfigType"; other names just get the prefix.
            if "ConfigType" in name:
                return name.replace("ConfigType", f"{self.type_prefix}ConfigType")
            return f"{self.type_prefix}{name}"

    class BasicViewSlicer(ViewSlicer):
        view_indicator = "basic"
        type_prefix = "Basic"

    class AdvancedViewSlicer(ViewSlicer):
        view_indicator = "advanced"
        type_prefix = "Advanced"

    def run(self, obj):
        """Slice every *VMConfigType and HVConfigType once per view and add the results to the schema."""
        schema_etree = obj.get("schema_etree")

        # ".//" rather than "//": an absolute path on an ElementTree is evaluated as ".//" anyway
        # but emits a FutureWarning.
        type_nodes = [
            node for node in SchemaTypeSlicer.get_nodes(schema_etree, ".//xs:complexType")
            if (node.get("name") or "").endswith("VMConfigType")
        ]
        hv_type_node = SchemaTypeSlicer.get_node(schema_etree, ".//xs:complexType[@name='HVConfigType']")
        if hv_type_node is None:
            raise ValueError("The schema does not define the complex type 'HVConfigType'")
        type_nodes.append(hv_type_node)

        slicers = [
            self.BasicViewSlicer(schema_etree),
            self.AdvancedViewSlicer(schema_etree),
        ]
        # type_nodes was collected before any slice is appended, so the new types are not re-sliced.
        for slicer in slicers:
            for type_node in type_nodes:
                for new_node in slicer.slice(type_node, force_copy=True):
                    schema_etree.getroot().append(new_node)

        obj.set("schema_etree", schema_etree)
def main(args):
    """Load the schema at ``args.schema``, slice it by VM type and view, and write it to ``args.out``."""
    # Imported inside the function rather than at module level, as in the original code.
    from lxml_loader import LXMLLoadStage

    engine = PipelineEngine(["schema_path"])
    engine.add_stages([
        LXMLLoadStage("schema"),
        SlicingSchemaByVMTypeStage(),
        SlicingSchemaByViewStage(),
    ])

    pipeline_obj = PipelineObject(schema_path=args.schema)
    engine.run(pipeline_obj)

    sliced_etree = pipeline_obj.get("schema_etree")
    sliced_etree.write(args.out)
    print(f"Sliced schema written to {args.out}")
if __name__ == "__main__":
    # Make __file__ absolute before taking its directory: when invoked as `python3 schema_slicer.py`,
    # __file__ can be the bare file name and its dirname would be empty.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    config_tools_dir = os.path.abspath(os.path.join(script_dir, ".."))
    schema_dir = os.path.join(config_tools_dir, "schema")

    parser = argparse.ArgumentParser(description="Slice a given scenario schema by VM types and views")
    parser.add_argument(
        "out",
        nargs="?",
        default=os.path.join(schema_dir, "sliced.xsd"),
        help="Path where the output is placed",
    )
    parser.add_argument(
        "--schema",
        default=os.path.join(schema_dir, "config.xsd"),
        help="the XML schema that defines the syntax of scenario XMLs",
    )
    main(parser.parse_args())