Code Monkey home page Code Monkey logo

Comments (3)

nmrtv avatar nmrtv commented on August 17, 2024

I had some issues with graphene-sqlalchemy version 2.0, and instead of trying to fix them, I wrote my own classes. Below is a draft of those classes and their usage; maybe you will find something interesting:

from functools import partial

import graphene
from graphene import Field, ObjectType
from graphene.relay import ConnectionField
from graphene.relay.connection import PageInfo
from graphene.types.objecttype import ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs, get_type
from graphene_sqlalchemy.fields import registerConnectionFieldFactory
from graphene_sqlalchemy.types import construct_fields
from graphene_sqlalchemy.utils import is_mapped_class, is_mapped_instance
from graphql_relay.connection.arrayconnection import connection_from_list_slice
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.query import Query


class MyConnection(graphene.relay.Connection):
    # Abstract base for the project's relay connections. Adds a `total_count`
    # field and records each concrete connection class on its node's options
    # (node._meta.connection) so MyConnectionField can find it from the node.
    class Meta:
        abstract = True

    total_count = graphene.Int()

    @classmethod
    def __init_subclass_with_meta__(cls, node=None, name=None, **options):
        # Link node -> connection before handing over to graphene. When no node
        # was declared, skip the link so graphene's own Meta validation reports
        # the mistake instead of an AttributeError on `None._meta` here.
        if node is not None:
            node._meta.connection = cls
        return super(MyConnection, cls).__init_subclass_with_meta__(node=node, name=name, **options)

    @staticmethod
    def resolve_total_count(root, info):
        # `root` is the connection object built by
        # MyConnectionField.connection_resolver, which stores the total row
        # count in its `length` attribute.
        return root.length

    @classmethod
    def get_query(cls, info, model, **args):
        """Return the default row source for ``model``: its ``query`` attribute.

        Subclasses override this to filter/sort using the field arguments in
        ``args``.
        """
        return model.query


class MyObjectTypeOptions(ObjectTypeOptions):
    """Options object (``_meta``) for MyObjectType subclasses.

    On top of the regular ObjectType options it stores the SQLAlchemy model,
    the Registry the type is registered in, the id attribute name and the
    connection class attached by MyConnection. ``connection`` remains
    writable even after the options are frozen, because the connection class
    is declared after (and attaches itself to) an already-built node type.
    """
    model = None
    registry = None
    id = None
    # Default so a type without a connection reads as None rather than raising
    # AttributeError when `_meta.connection` is looked up.
    connection = None

    def __setattr__(self, name, value):
        if name == 'connection':
            # Written by MyConnection after freezing; bypass the frozen check.
            self.__dict__[name] = value
        elif not self._frozen:
            super(MyObjectTypeOptions, self).__setattr__(name, value)
        else:
            raise Exception("Can't modify frozen Options {0}".format(self))


class MyObjectType(ObjectType):
    # ObjectType whose fields are generated from the SQLAlchemy model named in
    # Meta.model (optionally filtered by only_fields / exclude_fields); each
    # subclass is registered so the model can be mapped back to the type.

    @classmethod
    def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False,
                                    only_fields=(), exclude_fields=(), interfaces=(), id=None, **options):
        assert is_mapped_class(model), (
            'You need to pass a valid SQLAlchemy Model in '
            '{}.Meta, received "{}".'
        ).format(cls.__name__, model)

        # Fall back to the module-wide registry when none was given.
        registry = registry or get_global_registry()

        assert isinstance(registry, Registry), (
            'The attribute registry in {} needs to be an instance of '
            'Registry, received "{}".'
        ).format(cls.__name__, registry)

        generated = construct_fields(model, registry, only_fields, exclude_fields)
        model_fields = yank_fields_from_attrs(generated, _as=Field)

        meta = MyObjectTypeOptions(cls)
        meta.model = model
        meta.registry = registry
        meta.fields = model_fields
        # Name of the id attribute; 'id' unless Meta.id overrides it.
        meta.id = id or 'id'

        super(MyObjectType, cls).__init_subclass_with_meta__(_meta=meta, interfaces=interfaces, **options)

        if skip_registry:
            return
        registry.register(cls)

    @classmethod
    def is_type_of(cls, root, info):
        """Tell whether ``root`` can be resolved as this GraphQL type.

        Raises:
            Exception: if ``root`` is neither an instance of ``cls`` nor a
                mapped SQLAlchemy instance.
        """
        if isinstance(root, cls):
            return True
        if is_mapped_instance(root):
            return isinstance(root, cls._meta.model)
        raise Exception('Received incompatible instance "{}".'.format(root))

    @classmethod
    def get_query(cls, info):
        """Return the base query of the model behind this type."""
        return cls._meta.model.query

    @classmethod
    def get_node(cls, info, id):
        """Load one row by primary key; None when NoResultFound is raised."""
        try:
            node = cls.get_query(info).get(id)
        except NoResultFound:
            node = None
        return node

    def resolve_id(self, info):
        # Single-column primary keys resolve to the bare value; composite keys
        # resolve to a tuple of the key values.
        pk_values = self.__mapper__.primary_key_from_instance(self)
        if len(pk_values) > 1:
            return tuple(pk_values)
        return pk_values[0]


class MyConnectionField(ConnectionField):
    """Relay connection field that resolves into a MyConnection.

    The field type may be a connection class or a MyObjectType node. For a
    node, the connection class recorded in the node's ``_meta.connection`` is
    used. When the parent resolver returns None, rows come from the
    connection's ``get_query``.
    """

    @property
    def model(self):
        # SQLAlchemy model of the node this connection pages over.
        return self.type._meta.node._meta.model

    @property
    def type(self):
        _type = get_type(self._type)
        # `type` below is the builtin (class scope is not visible inside
        # methods). issubclass() only accepts classes, so check that first.
        if isinstance(_type, type) and issubclass(_type, MyObjectType):
            # getattr: the options object may not define `connection` at all,
            # which previously raised AttributeError instead of this assertion.
            connection = getattr(_type._meta, 'connection', None)
            assert connection, "The type {} doesn't have a connection".format(_type.__name__)
            return connection
        return super(MyConnectionField, self).type

    @classmethod
    def connection_resolver(cls, resolver, connection, model, root, info, **args):
        """Build an instance of ``connection`` for the requested page.

        Args:
            resolver: the parent resolver; if it returns None, rows come from
                ``connection.get_query(info, model, **args)``.
            connection: the connection class to instantiate.
            model: SQLAlchemy model of the connection's node.
            args: field arguments (relay paging args plus any custom ones).

        A SQLAlchemy ``Query`` is counted with ``Query.count()``; any other
        iterable must support ``len()``. Both are passed unchanged to
        ``connection_from_list_slice``.
        """
        iterable = resolver(root, info, **args)
        if iterable is None:
            iterable = connection.get_query(info, model, **args)
        if isinstance(iterable, Query):
            length = iterable.count()
        else:
            length = len(iterable)
        result = connection_from_list_slice(
            iterable,
            args,
            slice_start=0,
            list_length=length,
            list_slice_length=length,
            connection_type=connection,
            pageinfo_type=PageInfo,
            edge_type=connection.Edge,
        )
        # Exposed for resolvers on the connection type (e.g. total_count reads
        # `length`).
        result.iterable = iterable
        result.length = length
        return result

    def get_resolver(self, parent_resolver):
        # Bind the parent resolver, connection class and model up front; graphene
        # supplies root, info and the field arguments at execution time.
        return partial(self.connection_resolver, parent_resolver, self.type, self.model)


# Make graphene-sqlalchemy build relationship connection fields with this class.
registerConnectionFieldFactory(MyConnectionField)


class Registry(object):
    """Maps SQLAlchemy models to the MyObjectType subclasses built for them.

    It also keeps converters for SQLAlchemy composite types.
    """

    def __init__(self):
        self._registry = {}
        self._registry_models = {}
        self._registry_composites = {}

    def register(self, cls):
        """Register ``cls`` as the GraphQL type for ``cls._meta.model``.

        Raises:
            AssertionError: if ``cls`` is not a MyObjectType subclass, or its
                options reference a different registry.
        """
        # The message must be a single string: the former trailing comma made
        # it a tuple, so a failing assert raised AttributeError from
        # `tuple.format` instead of the intended AssertionError.
        assert issubclass(cls, MyObjectType), (
            'Only classes of type MyObjectType can be registered, '
            'received "{}"'
        ).format(cls.__name__)
        assert cls._meta.registry == self, 'Registry for a Model have to match.'
        # NOTE(review): there is no duplicate-model check, so registering a
        # second type for the same model silently replaces the first one.
        self._registry[cls._meta.model] = cls

    def get_type_for_model(self, model):
        """Return the type registered for ``model``, or None."""
        return self._registry.get(model)

    def register_composite_converter(self, composite, converter):
        """Store ``converter`` for the SQLAlchemy composite ``composite``."""
        self._registry_composites[composite] = converter

    def get_converter_for_composite(self, composite):
        """Return the converter stored for ``composite``, or None."""
        return self._registry_composites.get(composite)


# Module-wide Registry, created lazily by get_global_registry() and used by
# MyObjectType subclasses that do not declare their own registry.
registry = None


def get_global_registry():
    """Return the module-wide Registry, building one if none is set yet."""
    global registry
    registry = registry or Registry()
    return registry


def reset_global_registry():
    """Forget the module-wide Registry so the next lookup starts fresh."""
    global registry
    registry = None
import graphene
from graphene import relay
from graphql_relay import from_global_id
from sqlalchemy import inspect, asc, desc, and_, func
from sqlalchemy.orm import load_only

from application import db
from application.graphql_types import MyConnectionField, MyObjectType, MyConnection
from application.models.invoices import Customer as CustomerModel, TermItem as TermItemModel
from application.models.invoices import InvoiceProduct as InvoiceProductModel, Invoice as InvoiceModel, Tax as TaxModel
from application.models.users import User as UserModel


class Owner(MyObjectType):
    # GraphQL type over UserModel exposing only id, rivile_id and name.
    class Meta:
        interfaces = (relay.Node,)
        only_fields = ('id', 'rivile_id', 'name')
        model = UserModel


class OwnerConnection(MyConnection):
    # Users that have a rivile_id, ordered case-insensitively by name.
    class Meta:
        node = Owner

    @classmethod
    def get_query(cls, info, model, **args):
        """Return users whose ``rivile_id`` is set, ordered by lower(name).

        ``model`` and ``args`` are not used; the query always targets UserModel.
        """
        # isnot(None) renders IS NOT NULL explicitly (instead of relying on the
        # `!= None` operator overload, which linters flag as E711).
        return UserModel.query.filter(
            UserModel.rivile_id.isnot(None)
        ).order_by(func.lower(UserModel.name))


class Tax(MyObjectType):
    # Tax entries; no only_fields filter, so every generated field is exposed.
    class Meta:
        interfaces = (relay.Node,)
        model = TaxModel


class TaxConnection(MyConnection):
    # Relies on the default MyConnection.get_query (TaxModel.query, unsorted).
    class Meta:
        node = Tax


class Customer(MyObjectType):
    # Customers, limited to the identifying and address fields listed below.
    class Meta:
        interfaces = (relay.Node,)
        only_fields = ('id', 'name', 'code', 'vatcode', 'address')
        model = CustomerModel


class CustomerConnection(MyConnection):
    # Customers, optionally ordered by a client-chosen column (sort_col/sort_dir).
    class Meta:
        node = Customer

    @classmethod
    def get_query(cls, info, model, **args):
        """Return ``model.query``, optionally ordered by a client-chosen column.

        Field arguments read from ``args``:
            sort_col: name of a mapped column attribute of ``model`` to order by.
            sort_dir: ``'asc'`` sorts ascending; any other value sorts
                descending. Defaults to ``'asc'``.

        Raises:
            Exception: if ``sort_col`` is not a column attribute of ``model``.
        """

        def col_by_name(model, colname):
            # Search column attributes only: `sort_col` comes from the client,
            # and non-column mapper properties (e.g. relationships) have no
            # `.expression`, which previously raised AttributeError.
            for column in inspect(model).column_attrs:
                if column.key == colname:
                    return column.expression
            return None

        query = model.query
        _sort_dir = args.get('sort_dir', 'asc')
        _sort_col = args.get('sort_col')

        if _sort_col:
            sort_col = col_by_name(model, _sort_col)
            if sort_col is None:
                raise Exception('Wrong sort column specified')
            sort_dir = asc if _sort_dir == 'asc' else desc
            query = query.order_by(sort_dir(sort_col))

        return query


class InvoiceProduct(MyObjectType):
    # Product line of an invoice, limited to the fields listed below.
    class Meta:
        model = InvoiceProductModel
        interfaces = (relay.Node,)
        only_fields = ('id', 'order_no', 'product_name', 'total', 'invoice', 'alt_count', 'price')


class InvoiceProductConnection(MyConnection):
    # Product lines of invoices with op_type 51 or 52, optionally ordered by a
    # column of InvoiceProductModel or of the joined InvoiceModel.
    class Meta:
        node = InvoiceProduct

    @classmethod
    def get_query(cls, info, model, **args):
        """Return product lines of invoices with op_type 51/52, optionally sorted.

        Field arguments read from ``args``:
            sort_col: column attribute name, looked up first on
                InvoiceProductModel and then on InvoiceModel.
            sort_dir: ``'asc'`` sorts ascending; any other value sorts
                descending. Defaults to ``'asc'``.

        Raises:
            Exception: if ``sort_col`` matches a column attribute of neither model.
        """

        def col_by_name(model, colname):
            # Search column attributes only: `sort_col` comes from the client,
            # and non-column mapper properties (e.g. relationships) have no
            # `.expression`, which previously raised AttributeError.
            for column in inspect(model).column_attrs:
                if column.key == colname:
                    return column.expression
            return None

        # Join the parent invoice so its op_type (and columns) can be used;
        # load_only restricts which product columns are loaded.
        query = InvoiceProductModel.query.join(InvoiceModel).options(load_only(
            InvoiceProductModel.invoice_id,
            InvoiceProductModel.product_id,
            InvoiceProductModel.order_no,
            InvoiceProductModel.product_name,
            InvoiceProductModel.total,
            InvoiceProductModel.alt_count,
            InvoiceProductModel.price
        )).filter(InvoiceModel.op_type.in_((51, 52)))

        _sort_dir = args.get('sort_dir', 'asc')
        _sort_col = args.get('sort_col')

        if _sort_col:
            # Product columns win over same-named invoice columns.
            sort_col = col_by_name(InvoiceProductModel, _sort_col)
            if sort_col is None:
                sort_col = col_by_name(InvoiceModel, _sort_col)
            if sort_col is None:
                raise Exception('Wrong sort column specified')
            sort_dir = asc if _sort_dir == 'asc' else desc
            query = query.order_by(sort_dir(sort_col))

        return query


class Invoice(MyObjectType):
    # Invoice header, limited to the fields listed below.
    class Meta:
        model = InvoiceModel
        interfaces = (relay.Node,)
        only_fields = ('id', 'docdate', 'number', 'customer_name', 'term_date', 'op_type', 'updated', 'products', 'tax', 'total_without_tax', 'customer', 'owner_name')


class InvoiceConnection(MyConnection):
    # Invoices filtered by type/owner/number and optionally sorted.
    class Meta:
        node = Invoice

    @classmethod
    def get_query(cls, info, model, **args):
        """Return invoices filtered and optionally ordered from ``args``.

        Field arguments read from ``args``:
            supplier: required. True selects invoices with op_type 1; False
                selects op_type 51 or 52.
            owner_id: optional relay id of a user (only used when ``supplier``
                is false); keeps invoices whose ``object_id`` equals that user's
                ``rivile_id``.
            number: optional case-insensitive substring of the invoice number.
            sort_col: column attribute of InvoiceModel to order by.
            sort_dir: ``'asc'`` sorts ascending; any other value sorts
                descending. Defaults to ``'asc'``.

        Raises:
            Exception: if ``owner_id`` matches no user, or ``sort_col`` is not a
                column attribute of InvoiceModel.
        """

        def col_by_name(model, colname):
            # Search column attributes only: `sort_col` comes from the client,
            # and non-column mapper properties (e.g. relationships) have no
            # `.expression`, which previously raised AttributeError.
            for column in inspect(model).column_attrs:
                if column.key == colname:
                    return column.expression
            return None

        clauses = []
        if args['supplier']:
            clauses.append(InvoiceModel.op_type == 1)
        else:
            clauses.append(InvoiceModel.op_type.in_((51, 52)))
            # An omitted or null owner_id means "any owner".
            owner_global_id = args.get('owner_id')
            if owner_global_id:
                _, owner_id = from_global_id(owner_global_id)
                owner = UserModel.query.get(owner_id)
                if owner is None:
                    # Previously failed with AttributeError on None.rivile_id.
                    raise Exception('Owner not found')
                clauses.append(InvoiceModel.object_id == owner.rivile_id)

        number = args.get('number')
        if number is not None:
            # NOTE(review): '%' and '_' typed by the user act as LIKE wildcards.
            clauses.append(InvoiceModel.number.ilike('%{}%'.format(number)))

        where = and_(*clauses)

        query = InvoiceModel.query.options(load_only(
            InvoiceModel.id,
            InvoiceModel.number,
            InvoiceModel.docdate,
            InvoiceModel.customer_name,
            InvoiceModel.term_date,
            InvoiceModel.total_without_tax,
            InvoiceModel.owner_name
        )).filter(where)

        _sort_dir = args.get('sort_dir', 'asc')
        _sort_col = args.get('sort_col')

        if _sort_col:
            sort_col = col_by_name(InvoiceModel, _sort_col)
            if sort_col is None:
                raise Exception('Wrong sort column specified')
            sort_dir = asc if _sort_dir == 'asc' else desc
            query = query.order_by(sort_dir(sort_col))

        return query


class InvoiceProductInput(graphene.InputObjectType):
    # One product line in the AddInvoice / UpdateInvoice inputs.
    product_name = graphene.String(required=True)
    alt_count = graphene.Int(required=True)
    price = graphene.Float(required=True)


class AddInvoice(graphene.ClientIDMutation):

    invoice = graphene.Field(Invoice)

    class Input:
        docdate = graphene.String(required=True)
        customer_id = graphene.GlobalID(required=True)
        tax_id = graphene.GlobalID(required=True)
        products = graphene.List(InvoiceProductInput, required=True)

    @classmethod
    def mutate_and_get_payload(cls, root, info, **args):
        """Create an invoice with its product lines and commit it.

        The customer's name and address are copied onto the invoice. If the
        customer has ``term_days``, a TermItem with order_no 1 is added from
        ``term_days``/``term_percents``. Every product line gets the selected
        tax's ``percents`` and a 1-based ``order_no``.

        Raises:
            Exception: if ``products`` is empty.
            sqlalchemy.orm.exc.NoResultFound: from ``Query.one()`` when the tax
                or customer id does not exist.
        """
        # Validate before touching the database.
        if not args['products']:
            raise Exception('Invoice must have at least one product.')

        _input = args.copy()
        _, _input['tax_id'] = from_global_id(args['tax_id'])
        _, _input['customer_id'] = from_global_id(args['customer_id'])
        # Products become child rows below, and the relay client_mutation_id
        # (present when the client sends clientMutationId) is not an
        # InvoiceModel column; passing either to the constructor would fail.
        _input.pop('products')
        _input.pop('client_mutation_id', None)
        invoice = InvoiceModel(**_input)

        tax = TaxModel.query.with_entities(
            TaxModel.id,
            TaxModel.percents
        ).filter(
            TaxModel.id == _input['tax_id']
        ).one()
        invoice.tax_id = tax.id

        customer = CustomerModel.query.with_entities(
            CustomerModel.id,
            CustomerModel.name,
            CustomerModel.address,
            CustomerModel.term_days,
            CustomerModel.term_percents
        ).filter(
            CustomerModel.id == _input['customer_id']
        ).one()
        invoice.customer_id = customer.id
        invoice.customer_name = customer.name
        invoice.customer_address = customer.address

        if customer.term_days is not None:
            t = TermItemModel()
            t.order_no = 1
            t.days = customer.term_days
            t.percents = customer.term_percents
            invoice.terms.append(t)

        for i, prod in enumerate(args['products']):
            product = InvoiceProductModel(**prod)
            product.order_no = i + 1
            product.tax_percents = tax.percents
            invoice.products.append(product)

        db.session.add(invoice)
        db.session.commit()

        return cls(invoice=invoice)


class UpdateInvoice(graphene.ClientIDMutation):

    invoice = graphene.Field(Invoice)

    class Input:
        id = graphene.GlobalID(required=True)
        docdate = graphene.String()
        customer_id = graphene.GlobalID(required=False)
        tax_id = graphene.GlobalID(required=False)
        products = graphene.List(InvoiceProductInput)

    @classmethod
    def mutate_and_get_payload(cls, root, info, **args):
        """Update an existing invoice whose op_type is 51 or 52, then commit.

        Only the inputs that are supplied are applied:

        * plain fields (e.g. ``docdate``) are copied onto the invoice;
        * ``tax_id`` switches the invoice's tax;
        * ``customer_id`` switches the customer, copies its name/address and,
          when the customer has ``term_days``, sets the first term item;
        * a non-empty ``products`` list replaces all existing product lines.

        Raises:
            Exception: if the invoice does not exist, has another op_type, or
                would be left without products.
            sqlalchemy.orm.exc.NoResultFound: from ``Query.one()`` when the tax
                or customer id does not exist.
        """
        _type, invoice_id = from_global_id(args['id'])
        invoice = InvoiceModel.query.get(invoice_id)
        if invoice is None:
            raise Exception('Invoice not found')

        if invoice.op_type not in (51, 52):
            raise Exception("Can't edit. Wrong invoice type")

        # Optional inputs may be absent, so drop the specially handled keys with
        # a default (plain `del` raised KeyError for omitted inputs). The relay
        # client_mutation_id is not an invoice attribute either.
        _input = args.copy()
        for key in ('id', 'tax_id', 'customer_id', 'products', 'client_mutation_id'):
            _input.pop(key, None)
        for key, value in _input.items():
            setattr(invoice, key, value)

        if args.get('tax_id'):
            _, tax_id = from_global_id(args['tax_id'])
            if invoice.tax_id != tax_id:
                tax = TaxModel.query.with_entities(
                    TaxModel.id,
                    TaxModel.percents
                ).filter(
                    TaxModel.id == tax_id
                ).one()
                invoice.tax_id = tax.id
                # NOTE(review): existing product lines keep their previous
                # tax_percents unless `products` is supplied in the same call.

        if args.get('customer_id'):
            _, customer_id = from_global_id(args['customer_id'])
            if invoice.customer_id != customer_id:
                customer = CustomerModel.query.with_entities(
                    CustomerModel.id,
                    CustomerModel.name,
                    CustomerModel.address,
                    CustomerModel.term_days,
                    CustomerModel.term_percents
                ).filter(
                    CustomerModel.id == customer_id
                ).one()
                invoice.customer_id = customer.id
                invoice.customer_name = customer.name
                invoice.customer_address = customer.address

                if customer.term_days is not None:
                    t = TermItemModel()
                    t.order_no = 1
                    t.days = customer.term_days
                    t.percents = customer.term_percents
                    if invoice.terms:
                        invoice.terms[0] = t  # TODO: check whether an invoice can have several terms.
                    else:
                        invoice.terms.append(t)

        if args.get('products'):
            # remove old
            for prod in invoice.products:
                db.session.delete(prod)
            invoice.products = []

            # get tax value
            tax = TaxModel.query.with_entities(
                TaxModel.id,
                TaxModel.percents
            ).filter(
                TaxModel.id == invoice.tax_id
            ).one()

            # add new
            for i, prod in enumerate(args['products']):
                product = InvoiceProductModel(**prod)
                product.order_no = i + 1
                product.tax_percents = tax.percents
                invoice.products.append(product)
        elif not invoice.products:
            raise Exception('Invoice must have at least one product.')

        db.session.commit()

        # Return this mutation's own payload type (previously AddInvoice).
        return cls(invoice=invoice)


class MyMutations(graphene.ObjectType):
    # Root mutation type.
    add_invoice = AddInvoice.Field()
    update_invoice = UpdateInvoice.Field()


class Query(graphene.ObjectType):
    # Root query type. Extra keyword arguments on each MyConnectionField are
    # GraphQL arguments; connection_resolver forwards them to the connection's
    # get_query when the parent resolver returns None.
    node = relay.Node.Field()
    all_invoices = MyConnectionField(
        InvoiceConnection,
        sort_dir=graphene.String(),
        sort_col=graphene.String(),
        owner_id=graphene.ID(),
        number=graphene.String(),
        supplier=graphene.Boolean(required=True),
    )
    all_invoice_products = MyConnectionField(
        InvoiceProductConnection,
        sort_dir=graphene.String(),
        sort_col=graphene.String(),
    )
    invoice = relay.Node.Field(Invoice)

    all_customers = MyConnectionField(
        CustomerConnection,
        sort_dir=graphene.String(),
        sort_col=graphene.String(),
    )
    all_owners = MyConnectionField(OwnerConnection)
    all_taxes = MyConnectionField(TaxConnection)


# Executable schema built from the root query and mutation types.
schema = graphene.Schema(query=Query, mutation=MyMutations)

from graphene-sqlalchemy.

HeyHugo avatar HeyHugo commented on August 17, 2024

The mistake was mine, it's working now.

I think I was changing multiple things at once and probably broke something, which confused me into thinking the upgrade was the issue. I must have fiddled with the Query class, since it should really be using SQLAlchemyConnectionField rather than relay.ConnectionField.

from graphene-sqlalchemy.

github-actions avatar github-actions commented on August 17, 2024

This issue has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related topics referencing this issue.

from graphene-sqlalchemy.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.