hop-2012/amqp_codegen.py

211 lines
7.5 KiB
Python
Raw Normal View History

from __future__ import with_statement
# Copyright (C) 2012 Tony Garnock-Jones. All rights reserved.
# OCaml-syntax copyright comment; print_codec emits it as the first line of
# the generated codec.
copyright_stmt = '(* Copyright (C) 2012 Tony Garnock-Jones. All rights reserved. *)'
import sys
import xml.dom.minidom
from collections import namedtuple
###########################################################################
# XML utils
def attr(n, a, d=None):
    """Return attribute `a` of DOM element `n` with surrounding whitespace
    stripped, or the fallback `d` when the attribute is not present."""
    if not n.hasAttribute(a):
        return d
    return n.getAttribute(a).strip()
def kids(e, t):
    """Return the elements named `t` that are direct children of `e`,
    in document order (deeper descendants with that name are skipped)."""
    return list(filter(lambda k: k.parentNode is e, e.getElementsByTagName(t)))
##########################################################################
# Identifier utils
def mlify(s):
    """Return `s` with every '-' and ' ' replaced by '_', the form used for
    identifiers in the generated OCaml code."""
    for bad in ('-', ' '):
        s = s.replace(bad, '_')
    return s
def ctor(s):
    """Return an OCaml constructor name for `s`: mlify it, then apply
    str.capitalize (first character upper-cased, the rest lower-cased)."""
    ident = mlify(s)
    return ident.capitalize()
def tname(s):
    """Return the OCaml type name for `s`: its mlified form suffixed with '_t'."""
    return '%s_t' % (mlify(s),)
###########################################################################
# Load & parse the spec
# Parse the spec file, opened by a path relative to the current working directory.
with open('amqp0-9-1.stripped.xml') as f:
    spec_xml = xml.dom.minidom.parse(f)
amqp_elt = spec_xml.getElementsByTagName('amqp')[0]

# Version numbers and port come from attributes of the first <amqp> element;
# a missing attribute falls back to the given default.
major, minor, port, revision = (int(attr(amqp_elt, 'major', '0')),
                                int(attr(amqp_elt, 'minor', '0')),
                                int(attr(amqp_elt, 'port', '5672')),
                                int(attr(amqp_elt, 'revision', '0')))

constant_elts = amqp_elt.getElementsByTagName('constant')
def constants():
    """Generate one (name, value) pair of attribute strings per <constant>
    element found under the <amqp> root."""
    for elt in constant_elts:
        pair = (attr(elt, 'name'), attr(elt, 'value'))
        yield pair
domain_elts = amqp_elt.getElementsByTagName('domain')
# Map each <domain> name to the type name it declares; if a name repeats,
# the last definition wins.
domains = dict((attr(d, 'name'), attr(d, 'type')) for d in domain_elts)
def resolve(typename):
    """Follow the `domains` alias chain from `typename` and return where it ends.

    Resolution stops at the first name that is not a domain, or at the first
    name already visited. The second case covers self-aliases such as
    bit -> bit, and it also makes cycles terminate.
    """
    visited = set()
    while typename in domains and typename not in visited:
        visited.add(typename)
        typename = domains[typename]
    return typename
class AccessibleFieldsMixin:
    """Mixin for records that carry a `fields` sequence of Field-like items."""

    @property
    def accessible_fields(self):
        """List of the items in self.fields whose `reserved` flag is falsy,
        in their original order."""
        return list(filter(lambda fld: not fld.reserved, self.fields))
class Class(AccessibleFieldsMixin,
            namedtuple('Class', ['index', 'name', 'fields', 'methods'])):
    """One parsed <class> element: its numeric index, its name, its own
    Field list (content properties) and its list of Method records."""
class Method(AccessibleFieldsMixin,
             namedtuple('Method',
                        'class_name class_index has_content deprecated '
                        'index name synchronous responses fields')):
    """One parsed <method> element, tagged with the name and index of the
    class it belongs to."""

    @property
    def full_name(self):
        """The class name and the method name joined with '-'."""
        return '-'.join((self.class_name, self.name))
# One <field> of a class or method: its name, resolved type name and an int
# reserved flag (non-zero means reserved).
Field = namedtuple('Field', ['name', 'type', 'reserved'])
def load_fields(e):
    """Return a Field for each <field> element that is a direct child of `e`.

    The type comes from the `domain` attribute, or from `type` when `domain`
    is absent, and is then passed through resolve(). `reserved` defaults to 0.
    """
    result = []
    for fe in kids(e, 'field'):
        declared = attr(fe, 'domain', attr(fe, 'type'))
        result.append(Field(attr(fe, 'name'),
                            resolve(declared),
                            int(attr(fe, 'reserved', '0'))))
    return result
class_elts = amqp_elt.getElementsByTagName('class')


def _load_method(class_elt, method_elt):
    """Build a Method record from a <method> element of the <class> class_elt."""
    return Method(attr(class_elt, 'name'),
                  int(attr(class_elt, 'index')),
                  int(attr(method_elt, 'content', '0')),
                  int(attr(method_elt, 'deprecated', '0')),
                  int(attr(method_elt, 'index')),
                  attr(method_elt, 'name'),
                  int(attr(method_elt, 'synchronous', '0')),
                  [attr(r, 'name') for r in kids(method_elt, 'response')],
                  load_fields(method_elt))


def _load_class(class_elt):
    """Build a Class record, including its Method records, from a <class> element."""
    return Class(int(attr(class_elt, 'index')),
                 attr(class_elt, 'name'),
                 load_fields(class_elt),
                 [_load_method(class_elt, me) for me in kids(class_elt, 'method')])


# One Class per <class> element, in document order.
classes = [_load_class(ce) for ce in class_elts]
# Every method of every class, flattened in class order and then method order.
methods = [m for c in classes for m in c.methods]
###########################################################################
# OCaml reserved words. Field names are spliced into the generated code as
# variable bindings, so a field called e.g. "type" must be renamed.
_OCAML_KEYWORDS = frozenset('''
    and as assert asr begin class constraint do done downto else end
    exception external false for fun function functor if in include inherit
    initializer land lazy let lor lsl lsr lxor match method mod module
    mutable new nonrec object of open or private rec sig struct then to true
    try type val virtual when while with
'''.split())


def _ml_var(name):
    """Return an OCaml variable name for field `name`: mlified, with '_'
    appended when the result would be an OCaml keyword."""
    ident = mlify(name)
    if ident in _OCAML_KEYWORDS:
        ident += '_'
    return ident


def _tuple_type(fields):
    """Return the parenthesised OCaml tuple type of `fields`' types."""
    # OCaml separates tuple-type components with '*'; ',' is not valid here.
    return '(' + ' * '.join([tname(f.type) for f in fields]) + ')'


def _ctor_pattern(m, bind_names):
    """Return the constructor for method `m` with its accessible fields as
    arguments: bound variable names if `bind_names`, '_' wildcards otherwise.
    The result is usable both as a pattern and as an expression."""
    name = ctor(m.full_name)
    if not m.accessible_fields:
        return name
    if bind_names:
        args = [_ml_var(f.name) for f in m.accessible_fields]
    else:
        args = ['_'] * len(m.accessible_fields)
    return '%s (%s)' % (name, ', '.join(args))


def _method_type_lines(method_list):
    """Lines declaring the variant type method_t, one constructor per method."""
    lines = ['type method_t =']
    for m in method_list:
        if m.accessible_fields:
            lines.append(' | %s of %s' % (ctor(m.full_name),
                                          _tuple_type(m.accessible_fields)))
        else:
            lines.append(' | %s' % (ctor(m.full_name),))
    return lines


def _properties_type_lines(class_list):
    """Lines declaring properties_t, one constructor per class that has fields."""
    lines = ['type properties_t =']
    for c in class_list:
        if not c.fields:
            continue
        if c.accessible_fields:
            lines.append(' | %s_properties of %s'
                         % (ctor(c.name), _tuple_type(c.accessible_fields)))
        else:
            lines.append(' | %s_properties' % (ctor(c.name),))
    return lines


def _predicate_lines(header, selected, selected_value, default_value):
    """Lines of a `match m with` predicate: an arm yielding `selected_value`
    for each method in `selected`, then a catch-all yielding `default_value`."""
    lines = [header]
    for m in selected:
        lines.append(' | %s -> %s' % (_ctor_pattern(m, False), selected_value))
    lines.append(' | _ -> %s' % (default_value,))
    return lines


def _sexp_of_method_lines(method_list):
    """Lines of sexp_of_method, which renders a method value as nested Arr/Str."""
    lines = ['let sexp_of_method m = match m with ']
    for m in method_list:
        lines.append(' | %s ->' % (_ctor_pattern(m, True),))
        # Class and method names are wrapped in Str, like the field names, so
        # every element of the Arr list is a sexp value.
        items = ['Str "%s"' % (m.class_name,), 'Str "%s"' % (m.name,)]
        for f in m.accessible_fields:
            items.append('Arr [Str "%s"; sexp_of_%s(%s)]'
                         % (f.name, mlify(f.type), _ml_var(f.name)))
        lines.append(' Arr [%s]' % ('; '.join(items),))
    return lines


def _read_method_lines(method_list):
    """Lines of read_method, which decodes every field (reserved ones included)
    for a (class_index, method_index) pair."""
    # NOTE(review): no catch-all arm is generated, so an unknown pair raises
    # Match_failure in the generated code.
    lines = ['let read_method class_index method_index ch = '
             'match (class_index, method_index) with']
    for m in method_list:
        lines.append(' | (%d, %d) ->' % (m.class_index, m.index))
        for f in m.fields:
            # Reserved fields are still consumed from the input, just discarded.
            target = '_' if f.reserved else _ml_var(f.name)
            lines.append(' let %s = read_%s ch in' % (target, mlify(f.type)))
        lines.append(' %s' % (_ctor_pattern(m, True),))
    return lines


def _method_index_lines(method_list):
    """Lines of method_index, mapping a method value to (class, method) indices."""
    lines = ['let method_index m = match m with']
    for m in method_list:
        lines.append(' | %s -> (%d, %d)'
                     % (_ctor_pattern(m, False), m.class_index, m.index))
    return lines


def _write_method_lines(method_list):
    """Lines of write_method, which encodes every field of a method in order."""
    lines = ['let write_method m ch = match m with']
    for m in method_list:
        lines.append(' | %s ->' % (_ctor_pattern(m, True),))
        if not m.fields:
            lines.append(' ()')
        # Every field is written, reserved ones included, even when the method
        # has no accessible fields, so the encoding mirrors read_method.
        last = len(m.fields) - 1
        for i, f in enumerate(m.fields):
            if f.reserved:
                value = 'reserved_value_%s' % (mlify(f.type),)
            else:
                value = _ml_var(f.name)
            sep = ';' if i < last else ''
            lines.append(' write_%s ch %s%s' % (mlify(f.type), value, sep))
    return lines


def print_codec():
    """Write the generated OCaml codec for the loaded spec to stdout.

    The output has these sections, separated by blank lines:
      1. the header comments;
      2. the opens;
      3. version;
      4. method_t;
      5. has_content;
      6. properties_t;
      7. is_synchronous;
      8. sexp_of_method;
      9. read_method;
     10. method_index;
     11. write_method.

    It reads the module-level copyright_stmt, major, minor, revision, methods
    and classes.
    """
    sections = [
        [copyright_stmt,
         '(* WARNING: Autogenerated code. Do not edit by hand! *)'],
        ['open Amqp_wireformat',
         'open Sexp'],
        ['let version = (%d, %d, %d)' % (major, minor, revision)],
        _method_type_lines(methods),
        _predicate_lines('let has_content m = match m with ',
                         [m for m in methods if m.has_content],
                         'true', 'false'),
        _properties_type_lines(classes),
        _predicate_lines('let is_synchronous m = match m with ',
                         [m for m in methods if not m.synchronous],
                         'false', 'true'),
        _sexp_of_method_lines(methods),
        _read_method_lines(methods),
        _method_index_lines(methods),
        _write_method_lines(methods),
    ]
    # sys.stdout.write behaves the same under Python 2 and 3.
    sys.stdout.write('\n\n'.join(['\n'.join(s) for s in sections]) + '\n')


if __name__ == '__main__':
    print_codec()