diff --git a/.gitignore b/.gitignore index 1571d34..f88a0b4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,18 @@ +anknotes/extra/ancillary/FrontTemplate-Processed.htm +anknotes/extra/logs/ +anknotes/extra/dev/anknotes.developer* +anknotes/extra/dev/beyond* +anknotes/extra/dev/auth_tokens.txt +anknotes/extra/user/*.enex +anknotes/extra/user/anki.profile +anknotes/constants_user.py +test.py +*.bk +*.lnk +*.py-bkup +*.c00k!e +ctags.tags + ################# ## Eclipse ################# diff --git a/README.md b/README.md new file mode 100644 index 0000000..24bcb34 --- /dev/null +++ b/README.md @@ -0,0 +1,208 @@ +# Anknotes (Evernote to Anki Importer) +**Forks and suggestions are very welcome.** + +##Outline +1. [Description] (#description) +1. [User Instructions] (#user-instructions) +1. [Current Features] (#current-features) +1. [Settings] (#settings) +1. [Details] (#details) + * [Templates] (#anki-templates) + * [Auto Import] (#auto-import) + * [Note Processing] (#note-processing-features) +1. [Beta Functions] (#beta-functions) +1. [See Also Footer Links] (#see-also-footer-links) +1. [Future Features] (#future-features) +1. [Developer Notes] (#developer-notes) + +## Description +An Anki plug-in for downloading Evernote notes to Anki directly from Anki. In addition to this core functionality, Anknotes can automatically modify Evernote notes, create new Evernote notes, and link related Evernote notes together. + +## User Instructions +1. Download everything, move it to your `Anki/addons` directory +1. Start Anki, and click the Menu Item `Anknotes → Import From Evernote` + - Optionally, you can customize [settings] (#settings) in the Anknotes tab in Anki's preferences +1. When you run it the first time a browser tab will open on the Evernote site asking you for access to your account + - When you click okay you are taken to a website where the OAuth verification key is displayed. You paste that key into the open Anki prompt and click okay. 
+ - Note that for the first 24 hours after granting access, you have unlimited API usage. After that, Evernote applies rate limiting. + - So, sync everything immediately! + +## Current Features +#### Evernote Importing +- A rich set of [options] (#settings) will dynamically generate your query from tags, notebook, title, last updated date, or free text +- Free text can include any valid [Evernote query] (https://dev.evernote.com/doc/articles/search_grammar.php) +- [Auto Import] (#auto-import) is possible + +#### Anki Note Generation +- [Four types of Anki Notes] (#anki-templates) can be generated: + - Standard, Reversible, Reverse-Only, and Cloze +- [Post-process] (#note-processing-features) Evernote Notes with a few improvements + - [Fix Evernote note links] (#post-process-links) + - [Automatically embed images] (#post-process-images) + - [Occlude certain text] (#post-process-occlude) on fronts of Anki cards + - [Generate Cloze Fields] (#post-process-cloze) + - [Process a "See Also" Footer field] (#see-also-footer-links) for showing links to other Evernote notes +- See the [Beta Functions] (#beta-functions) section below for info on See Also Footer fields, Table of Contents notes, and Outline notes +## Settings +#### Evernote Query +- You can enter any valid Evernote Query in the `Search Terms` field +- The check box before a given text field enables or disables that field +- Anknotes requires **all fields match** by default. + - You can use the `Match Any Terms` option to override this, but see the Evernote documentation on search for limitations +### Pagination +- Controls the offset parameter of the Evernote search. 
+- Auto Pagination is recommended and on by default +#### Anki Note Options +- Controls what is saved to Anki +- You can change the base Anki deck + - Anknotes can append the base deck with the Evernote note's Notebook Stack and Notebook Name + - Any colon will be converted to two colons, to enable Anki's sub-deck functionality +- You can change which Evernote tags are saved +#### Note Updating +- By default, Anknotes will update existing Anki notes in place. This preserves all Anki statistics. +- You can also ignore existing notes, or delete and re-add existing notes (this will erase any Anki statistics) + +## Details +#### Anki Templates +- All use an advanced Anki template with customized content and CSS +- Reversible notes will generate a normal and reversed card for each note + - Add `#Reversible` tag to Evernote note before importing +- Reverse-only notes will only generate a reversed card + - Add `#Reverse-Only` tag to Evernote note before importing +- [Cloze notes] (#post-process-cloze) are automatically detected by Anknotes + +#### Auto Import +1. Automatically import on profile load + - Enable via Anknotes Menu + - Auto Import will be delayed if an import has occurred in the past 30 minutes +1. Automatically page through an Evernote query + - Enable via Anknotes Settings + - Evernote only returns 250 results per search, so queries with > 250 possible results require multiple searches + - If more than 10 API calls are made during a search, the next search is delayed by 15 minutes +1. Automatically import continuously + - Only configurable via source code at this time + - Enable Auto Import and Pagination as per above, and then modify `constants.py`, setting `PAGING_RESTART_WHEN_COMPLETE` to `True` +#### Note Processing Features +1. 
Fix [Evernote Note Links] (https://dev.evernote.com/doc/articles/note_links.php) so that they can be opened in Anki + - Convert "New Style" Evernote web links to "Classic" Evernote in-app links so that any note links open directly in Evernote + - Convert all Evernote links to use two forward slashes instead of three to get around an Anki bug +1. Automatically embed images + - This is a workaround since Anki cannot import Evernote resources such as embedded images, PDF files, sounds, etc + - Anknotes will convert any of the following to embedded, linkable images: + - Any HTML Dropbox sharing link to an image `(https://www.dropbox.com/s/...)` + - Any Dropbox plain-text to an image (same as above, but plain-text links must end with `?dl=0` or `?dl=1`) + - Any HTML link with Link Text beginning with "Image Link", e.g.: `Image Link #1` +1. Occlude (hide) certain text on fronts of Anki cards + - Useful for displaying additional information but ensuring it only shows on backs of cards + - Anknotes converts any of the following to special text that will display in grey color, and only on the backs of cards: + - Any text with white foreground + - Any text within two brackets, such as `<>` +1. 
Automatically generate [Cloze fields] (http://ankisrs.net/docs/manual.html#cloze) + - Any text with a single curly bracket will be converted into a cloze field + - E.g., two cloze fields are generated from: The central nervous system is made up of the `{brain}` and `{spinal cord}` + - If you want to generate a single cloze field (not increment the field #), insert a pound character `('#')` after the first curly bracket: + - E.g., a single cloze field is generated from: The central nervous system is made up of the `{brain}` and `{#spinal cord}` + +##Beta Functions +#### Note Creation +- Anknotes can create and upload/update existing Evernote notes +- Currently this is limited to creating new Auto TOC notes and modifying the See Also Footer field of existing notes +- Anknotes uses client-side validation to decrease API usage, but there is currently an issue with use of the validation library in Anki. + - So, Anknotes will execute this validation using an **external** script, not as an Anki addon + - Therefore, you must **manually** ensure that **Python** and the **lxml** module is installed on your system + - Alternately, disable validation: Edit `constants.py` and set `ENABLE_VALIDATION` to `False` + +#### Find Deleted/Orphaned Notes +- Anknotes is not intended for use as a sync client with Evernote (this may change in the future) +- Thus, notes deleted from the Evernote servers will not be deleted from Anki +- Use `Anknotes → Maintenance Tasks → Find Deleted Notes` to find and delete these notes from Anki + - You can also find notes in Evernote that don't exist in Anki + - First, you must create a "Table of Contents" note using the Evernote desktop application: + - In the Windows client, select ALL notes you want imported into Anki, and click the `Create Table of Contents Note` button on the right-sided panel + - Alternately, select 'Copy Note Links' and paste the content into a new Evernote Note. 
+ - Export your Evernote note to `anknotes/extra/user/Table of Contents.enex` + +## "See Also" Footer Links +#### Concept +- You have topics (**Root Notes**) broken down into multiple sub-topics (**Sub Notes**) + - The Root Notes are too broad to be tested, and therefore not useful as Anki cards + - The Sub Notes are testable topics intended to be used as Anki cards +- Anknotes tries to link these related Sub Notes together so you can rapidly view related content in Evernote +#### Terms +1. **Table of Contents (TOC) Notes** + - Primarily contain a hierarchical list of links to other notes +2. **Outline Notes** + - Primarily contain content itself of sub-notes + - E.g. a summary of sub-notes or full text of sub-notes + - Common usage scenario is creating a broad **Outline** style note when studying a topic, and breaking that down into multiple **Sub Notes** to use in Anki +3. **"See Also" Footer** Fields + - Primarily consist of links to TOC notes, Outline notes, or other Evernote notes +4. **Root Titles** and **Sub Notes** + - Sub Notes are notes with a colon in the title + - Root Title is the portion of the title before the first colon + +#### Integration +###### With Anki: +- The **"See Also" Footer** field is shown on the backs of Anki cards only, so having a descriptive link in here won't give away the correct answer +- The content itself of **TOC** and **Outline** notes are also viewable on the backs of Anki cards +##### With Evernote: +- Anknotes can create new Evernote notes from automatically generated TOC notes +- Anknotes can update existing Evernote notes with modified See Also Footer fields + +#### Usage +###### Manual Usage: +- Add a new line to the end of your Evernote note that begins with `See Also`, and include relevant links after it +- Tag notes in Evernote before importing. + - Table of Contents (TOC) notes are designated by the `#TOC` tag. + - Outline notes are designed by the `#Outline` tag. 
+###### Automated Usage: +- Anknotes can automatically create: + - Table of Contents Notes + - Created for **Root Titles** containing two or more Sub Notes + - In Anki, click the `Anknotes Menu → Process See Also Footer Links → Step 3: Create Auto TOC Notes`. + - Once the Auto TOC notes are generated, click `Steps 4 & 5` to upload the notes to Evernote + - See Also' Footer fields for displaying links to other Evernote notes + - Any links from other notes, including automatically generated TOC notes, are inserted into this field by Anknotes + - Creation of Outline notes from sub-notes or sub-notes from outline notes is a possible future feature +#### Example: +Let's say we have nine **Sub Notes** titled `Diabetes: Symptoms`, `Diabetes: Treatment`, `Diabetes: Treatment: Types of Insulin`, and `Diabetes: Complications`, etc: +- Anknotes will generate a TOC note **`Diabetes`** with hierarchical links to all nine sub-notes as such: + + > DIABETES + > 1. Symptoms + > 2. Complications + > 1. Cardiovascular + > * Heart Attack Risk + > 2. Infectious + > 3. Ophthalmologic + > 3. 
Treatment + > * Types of Insulin +- Anknotes can then insert a link to that TOC note in the 'See Also' Footer field of the sub notes +- This 'See Also' Footer field will display on the backs of Anki cards +- The TOC note's contents themselves will also be available on the backs of Anki cards + +## Future Features +- More robust options + - Move options from source code into GUI + - Allow enabling/disabling of beta functions like See Also fields + - Customize criteria for detecting see also fields +- Implement full sync with Evernote servers +- Import resources (e.g., images, sounds, etc) from Evernote notes +- Automatically create Anki sub-notes from a large Evernote note + +## Developer Notes +#### Anki Template / CSS Files: +- Template File Location: `/extra/ancillary/FrontTemplate.htm` +- CSS File Location: `/extra/ancillary/_AviAnkiCSS.css` +- Message Box CSS: `/extra/ancillary/QMessageBox.css` + +#### Anknotes Local Database +- Anknotes saves all Evernote notes, tags, and notebooks in the SQL database of the active Anki profile + - You may force a resync with the local Anknotes database via the menu: `Anknotes → Maintenance Tasks` + - You may force update of ancillary tag/notebook data via this menu +- Maps of see also footer links and Table of Contents notes are also saved here +- All Evernote note history is saved in a separate table. This is not currently used but may be helpful if data loss occurs or for future functionality +#### Developer Functions +- If you are testing a new feature, you can automatically have Anki run that function when Anki starts. 
+ - Simply add the method to `__main__.py` under the comment `Add a function here and it will automatically run on profile load` + - Also, create the folder `/anknotes/extra/dev` and add files `anknotes.developer` and `anknotes.developer.automate` \ No newline at end of file diff --git a/anknotes/Anki.py b/anknotes/Anki.py new file mode 100644 index 0000000..6d0d664 --- /dev/null +++ b/anknotes/Anki.py @@ -0,0 +1,706 @@ +# -*- coding: utf-8 -*- +### Python Imports +import shutil +import sys +import re + +try: + from pysqlite2 import dbapi2 as sqlite +except ImportError: + from sqlite3 import dbapi2 as sqlite + +### Anknotes Imports +from anknotes.AnkiNotePrototype import AnkiNotePrototype +from anknotes.base import fmt, encode, decode +from anknotes.shared import * +from anknotes import stopwatch + +### Evernote Imports +# from evernote.edam.notestore.ttypes import NoteFilter, NotesMetadataResultSpec +# from evernote.edam.type.ttypes import NoteSortOrder, Note +# from evernote.edam.error.ttypes import EDAMSystemException, EDAMErrorCode, EDAMUserException, EDAMNotFoundException +# from evernote.api.client import EvernoteClient + +### Anki Imports +try: + import anki + from anki.notes import Note as AnkiNote + from anki.utils import intTime + import aqt + from aqt import mw +except Exception: + pass + + +class Anki: + def __init__(self): + self.deck = None + self.templates = None + + @staticmethod + def get_notebook_guid_from_ankdb(evernote_guid): + return ankDB().scalar("SELECT notebookGuid FROM {n} WHERE guid = '%s'" % evernote_guid) + + def get_deck_name_from_evernote_notebook(self, notebookGuid, deck=None): + if not deck: + deck = self.deck if self.deck else "" + if not hasattr(self, 'notebook_data'): + self.notebook_data = {} + if not notebookGuid in self.notebook_data: + # log_error("Unexpected error: Notebook GUID '%s' could not be found in notebook data: %s" % (notebookGuid, str(self.notebook_data))) + notebook = EvernoteNotebook(fetch_guid=notebookGuid) + if 
not notebook.success: + log_error( + " get_deck_name_from_evernote_notebook FATAL ERROR: UNABLE TO FIND NOTEBOOK '%s'. " % notebookGuid) + return None + # log("Getting notebook info: %s" % str(notebook)) + self.notebook_data[notebookGuid] = notebook + notebook = self.notebook_data[notebookGuid] + if notebook.Stack: + deck += u'::' + notebook.Stack + deck += "::" + notebook.Name + deck = deck.replace(": ", "::") + if deck[: + 2] == '::': + deck = deck[2:] + return deck + + def update_evernote_notes(self, evernote_notes, log_update_if_unchanged=True): + """ + Update Notes in Anki Database + :type evernote_notes: list[EvernoteNotePrototype.EvernoteNotePrototype] + :rtype : int + :param evernote_notes: List of EvernoteNote returned from server or local db + :param log_update_if_unchanged: + :return: Count of notes successfully updated + """ + return self.add_evernote_notes(evernote_notes, True, log_update_if_unchanged=log_update_if_unchanged) + + def add_evernote_notes(self, evernote_notes, update=False, log_update_if_unchanged=True): + """ + Add Notes to or Update Notes in Anki Database + :param evernote_notes: + :param update: + :param log_update_if_unchanged: + :type evernote_notes: list[EvernoteNotePrototype.EvernoteNotePrototype] + :type update: bool + :return: Count of notes successfully added or updated + """ + new_nids=[] + action_str_base = ['Add', 'Update'][update] + action_str = ['Adding', 'Updating'][update] + action_preposition = ['To', 'In'][update] + info = stopwatch.ActionInfo(action_str + ' Of', 'Evernote Notes', action_preposition + ' Anki', report_if_empty=False) + tmr = stopwatch.Timer(evernote_notes, 10, info=info, + label='Add\\Anki-%sEvernoteNotes' % action_str_base) + + for ankiNote in evernote_notes: + try: + title = ankiNote.FullTitle + content = decode(ankiNote.Content) + anki_field_info = { + FIELDS.TITLE: title, + FIELDS.CONTENT: content, + FIELDS.EVERNOTE_GUID: FIELDS.EVERNOTE_GUID_PREFIX + ankiNote.Guid, + FIELDS.UPDATE_SEQUENCE_NUM: 
str(ankiNote.UpdateSequenceNum), + FIELDS.SEE_ALSO: u'' + } + except Exception: + log_error("Unable to set field info for: Note '%s': '%s'" % (ankiNote.FullTitle, ankiNote.Guid)) + log_dump(ankiNote.Content, " NOTE CONTENTS ") + # log_dump(encode(ankiNote.Content), " NOTE CONTENTS ") + raise + tmr.step(title) + baseNote = None + if update: + baseNote = self.get_anki_note_from_evernote_guid(ankiNote.Guid) + if not baseNote: + log_error('Updating note %s: COULD NOT FIND BASE NOTE FOR ANKI NOTE ID' % ankiNote.Guid) + tmr.reportStatus(EvernoteAPIStatus.MissingDataError) + continue + if ankiNote.Tags is None: + log_error("Could note find tags object for note %s: %s. " % (ankiNote.Guid, ankiNote.FullTitle)) + tmr.reportStatus(EvernoteAPIStatus.MissingDataError) + continue + anki_note_prototype = AnkiNotePrototype(self, anki_field_info, ankiNote.TagNames, baseNote, + notebookGuid=ankiNote.NotebookGuid, count=tmr.count, + count_update=tmr.counts.updated.completed.val, max_count=tmr.max) + anki_note_prototype._log_update_if_unchanged_ = log_update_if_unchanged + nid = tmr.autoStep(anki_note_prototype.update_note() if update else anki_note_prototype.add_note(), + ankiNote.FullTitle, update) + if tmr.status.IsSuccess and not update: + new_nids.append([nid, ankiNote.Guid]) + elif tmr.status.IsError: + log("ANKI ERROR WHILE %s EVERNOTE NOTES: " % action_str.upper() + str(tmr.status), tmr.label + '-Error') + tmr.Report() + if new_nids: + ankDB().executemany("UPDATE {n} SET nid = ? 
WHERE guid = ?", new_nids) + return tmr.counts.success + + def delete_anki_cards(self, evernote_guids): + col = self.collection() + card_ids = [] + for evernote_guid in evernote_guids: + card_ids += mw.col.findCards(FIELDS.EVERNOTE_GUID_PREFIX + evernote_guid) + col.remCards(card_ids) + return len(card_ids) + + @staticmethod + def get_evernote_model_styles(): + if MODELS.OPTIONS.IMPORT_STYLES: + return '@import url("%s");' % FILES.ANCILLARY.CSS + return file(os.path.join(FOLDERS.ANCILLARY, FILES.ANCILLARY.CSS), 'r').read() + + def add_evernote_model(self, mm, modelName, forceRebuild=False, cloze=False, allowForceRebuild=True): + model = mm.byName(modelName) + model_css = self.get_evernote_model_styles() + templates = self.get_templates(modelName == MODELS.DEFAULT) + if model and modelName is MODELS.DEFAULT and allowForceRebuild: + front = model['tmpls'][0]['qfmt'] + evernote_account_info = get_evernote_account_ids() + if not evernote_account_info.Valid: + info = ankDB().first( + "SELECT uid, shard, COUNT(uid) as c1, COUNT(shard) as c2 from {s} GROUP BY uid, shard ORDER BY c1 DESC, c2 DESC LIMIT 1") + if info and evernote_account_info.update(info[0], info[1]): + forceRebuild = True + if evernote_account_info.Valid: + if not "evernote_uid = '%s'" % evernote_account_info.uid in front or not "evernote_shard = '%s'" % evernote_account_info.shard in front: + forceRebuild = True + if model['css'] != model_css: + forceRebuild = True + if model['tmpls'][0]['qfmt'] != templates['Front']: + forceRebuild = True + if not model or forceRebuild: + if model: + for t in model['tmpls']: + t['qfmt'] = templates['Front'] + t['afmt'] = templates['Back'] + model['css'] = model_css + mm.update(model) + else: + model = mm.new(modelName) + # Add Field for Evernote GUID: + # Note that this field is first because Anki requires the first field to be unique + evernote_guid_field = mm.newField(FIELDS.EVERNOTE_GUID) + evernote_guid_field['sticky'] = True + evernote_guid_field['font'] = 
'Consolas' + evernote_guid_field['size'] = 10 + mm.addField(model, evernote_guid_field) + + # Add Standard Fields: + mm.addField(model, mm.newField(FIELDS.TITLE)) + + evernote_content_field = mm.newField(FIELDS.CONTENT) + evernote_content_field['size'] = 14 + mm.addField(model, evernote_content_field) + + evernote_see_also_field = mm.newField(FIELDS.SEE_ALSO) + evernote_see_also_field['size'] = 14 + mm.addField(model, evernote_see_also_field) + + evernote_extra_field = mm.newField(FIELDS.EXTRA) + evernote_extra_field['size'] = 12 + mm.addField(model, evernote_extra_field) + + evernote_toc_field = mm.newField(FIELDS.TOC) + evernote_toc_field['size'] = 10 + mm.addField(model, evernote_toc_field) + + evernote_outline_field = mm.newField(FIELDS.OUTLINE) + evernote_outline_field['size'] = 10 + mm.addField(model, evernote_outline_field) + + # Add USN to keep track of changes vs Evernote's servers + evernote_usn_field = mm.newField(FIELDS.UPDATE_SEQUENCE_NUM) + evernote_usn_field['font'] = 'Consolas' + evernote_usn_field['size'] = 10 + mm.addField(model, evernote_usn_field) + + # Add Templates + + if modelName is MODELS.DEFAULT or modelName is MODELS.REVERSIBLE: + # Add Default Template + default_template = mm.newTemplate(TEMPLATES.DEFAULT) + default_template['qfmt'] = templates['Front'] + default_template['afmt'] = templates['Back'] + mm.addTemplate(model, default_template) + if modelName is MODELS.REVERSE_ONLY or modelName is MODELS.REVERSIBLE: + # Add Reversed Template + reversed_template = mm.newTemplate(TEMPLATES.REVERSED) + reversed_template['qfmt'] = templates['Front'] + reversed_template['afmt'] = templates['Back'] + mm.addTemplate(model, reversed_template) + if modelName is MODELS.CLOZE: + # Add Cloze Template + cloze_template = mm.newTemplate(TEMPLATES.CLOZE) + cloze_template['qfmt'] = templates['Front'] + cloze_template['afmt'] = templates['Back'] + mm.addTemplate(model, cloze_template) + + # Update Sort field to Title (By default set to GUID since it is the 
first field) + model['sortf'] = 1 + + # Update Model CSS + model['css'] = model_css + + # Set Type to Cloze + if cloze: + model['type'] = MODELS.TYPES.CLOZE + + # Add Model to Collection + mm.add(model) + + # Add Model id to list + self.evernoteModels[modelName] = model['id'] + return forceRebuild + + def get_templates(self, forceRebuild=False): + if not self.templates or forceRebuild: + evernote_account_info = get_evernote_account_ids() + field_names = { + "Title": FIELDS.TITLE, "Content": FIELDS.CONTENT, "Extra": FIELDS.EXTRA, + "See Also": FIELDS.SEE_ALSO, "TOC": FIELDS.TOC, "Outline": FIELDS.OUTLINE, + "Evernote GUID Prefix": FIELDS.EVERNOTE_GUID_PREFIX, "Evernote GUID": FIELDS.EVERNOTE_GUID, + "Evernote UID": evernote_account_info.uid, "Evernote shard": evernote_account_info.shard + } + # Generate Front and Back Templates from HTML Template in anknotes' addon directory + self.templates = {"Front": file(FILES.ANCILLARY.TEMPLATE, 'r').read() % field_names} + self.templates["Back"] = self.templates["Front"].replace("
", "
") + return self.templates + + def add_evernote_models(self, allowForceRebuild=True): + col = self.collection() + mm = col.models + self.evernoteModels = {} + + forceRebuild = self.add_evernote_model(mm, MODELS.DEFAULT, allowForceRebuild=allowForceRebuild) + self.add_evernote_model(mm, MODELS.REVERSE_ONLY, forceRebuild) + self.add_evernote_model(mm, MODELS.REVERSIBLE, forceRebuild) + self.add_evernote_model(mm, MODELS.CLOZE, forceRebuild, True) + + def setup_ancillary_files(self): + # Copy CSS file from anknotes addon directory to media directory + media_dir = re.sub("(?i)\.(anki2)$", ".media", self.collection().path) + if isinstance(media_dir, str): + media_dir = unicode(media_dir, sys.getfilesystemencoding()) + shutil.copy2(os.path.join(FOLDERS.ANCILLARY, FILES.ANCILLARY.CSS), os.path.join(media_dir, FILES.ANCILLARY.CSS)) + + def get_anki_fields_from_anki_note_id(self, a_id, fields_to_ignore=list()): + note = self.collection().getNote(a_id) + try: + items = note.items() + except Exception: + log_error("Unable to get note items for Note ID: %d" % a_id) + raise + return get_dict_from_list(items, fields_to_ignore) + + def get_evernote_guids_from_anki_note_ids(self, ids=None): + if ids is None: + ids = self.get_anknotes_note_ids() + evernote_guids = [] + self.usns = {} + for a_id in ids: + fields = self.get_anki_fields_from_anki_note_id(a_id, [FIELDS.CONTENT]) + evernote_guid = get_evernote_guid_from_anki_fields(fields) + if not evernote_guid: + continue + evernote_guids.append(evernote_guid) + # log('Anki USN for Note %s is %s' % (evernote_guid, fields[FIELDS.UPDATE_SEQUENCE_NUM]), 'anki-usn') + if FIELDS.UPDATE_SEQUENCE_NUM in fields: + self.usns[evernote_guid] = fields[FIELDS.UPDATE_SEQUENCE_NUM] + else: + log(" ! get_evernote_guids_from_anki_note_ids: Note '%s' is missing USN!" 
% evernote_guid) + return evernote_guids + + def get_evernote_guids_and_anki_fields_from_anki_note_ids(self, ids=None): + if ids is None: + ids = self.get_anknotes_note_ids() + evernote_guids = {} + for a_id in ids: + fields = self.get_anki_fields_from_anki_note_id(a_id) + evernote_guid = get_evernote_guid_from_anki_fields(fields) + if evernote_guid: + evernote_guids[evernote_guid] = fields + return evernote_guids + + def search_evernote_models_query(self): + query = "" + delimiter = "" + for mName, mid in self.evernoteModels.items(): + query += delimiter + "mid:" + str(mid) + delimiter = " OR " + return query + + def get_anknotes_note_ids(self, query_filter=""): + query = self.search_evernote_models_query() + if query_filter: + query = query_filter + " (%s)" % query + ids = self.collection().findNotes(query) + return ids + + def get_anki_note_from_evernote_guid(self, evernote_guid): + col = self.collection() + ids = col.findNotes(FIELDS.EVERNOTE_GUID_PREFIX + evernote_guid) + if not ids or not ids[0]: + return None + note = AnkiNote(col, None, ids[0]) + return note + + def get_anknotes_note_ids_by_tag(self, tag): + return self.get_anknotes_note_ids("tag:" + tag) + + def get_anknotes_note_ids_with_unadded_see_also(self): + return self.get_anknotes_note_ids('"See Also" "See_Also:"') + + def process_see_also_content(self, anki_note_ids): + log = Logger('See Also\\1-process_unadded_see_also_notes\\', rm_path=True) + tmr = stopwatch.Timer(anki_note_ids, infoStr='Processing Unadded See Also Notes', label=log.base_path) + tmr.info.BannerHeader('error') + for a_id in anki_note_ids: + ankiNote = self.collection().getNote(a_id) + try: + items = ankiNote.items() + except Exception: + log.error("Unable to get note items for Note ID: %d for %s" % (a_id, tmr.base_name)) + raise + fields = {} + for key, value in items: + fields[key] = value + if fields[FIELDS.SEE_ALSO]: + tmr.reportSkipped() + continue + anki_note_prototype = AnkiNotePrototype(self, fields, ankiNote.tags, 
ankiNote, count=tmr.count, + count_update=tmr.counts.updated.completed.val, + max_count=tmr.max, light_processing=True) + if not anki_note_prototype.Fields[FIELDS.SEE_ALSO]: + tmr.reportSkipped() + continue + log.go("Detected see also contents for Note '%s': %s" % ( + get_evernote_guid_from_anki_fields(fields), fields[FIELDS.TITLE])) + log.go(u" ::: %s " % strip_tags_and_new_lines(fields[FIELDS.SEE_ALSO])) + tmr.autoStep(anki_note_prototype.update_note(), fields[FIELDS.TITLE], update=True) + + def process_toc_and_outlines(self): + self.extract_links_from_toc() + self.insert_toc_into_see_also() + self.insert_toc_and_outline_contents_into_notes() + + def insert_toc_into_see_also(self): + db = ankDB() + db._db.row_factory = None + results = db.all( + "SELECT s.target_evernote_guid, s.source_evernote_guid, target_note.title, toc_note.title " + "FROM {s} as s, {n} as target_note, {n} as toc_note " + "WHERE s.source_evernote_guid != s.target_evernote_guid AND target_note.guid = s.target_evernote_guid " + "AND toc_note.guid = s.source_evernote_guid AND s.from_toc == 1 " + "ORDER BY target_note.title ASC") + # results_bad = db.all( + # "SELECT s.target_evernote_guid, s.source_evernote_guid FROM {t_see} as s WHERE s.source_evernote_guid COUNT(SELECT * FROM {tn} WHERE guid = s.source_evernote_guid) )" % ( + # TABLES.SEE_ALSO, TABLES.EVERNOTE.NOTES, TABLES.EVERNOTE.NOTES)) + all_child_guids = db.list("tagNames NOT LIKE '{t_toc}'", columns='guid') + all_toc_guids = db.list("tagNames LIKE '{t_toc}'", columns='guid') + grouped_results = {} + toc_titles = {} + for row in results: + target_guid = row[0] + toc_guid = row[1] + if toc_guid not in all_toc_guids: + continue + if target_guid not in all_toc_guids and target_guid not in all_child_guids: + continue + if target_guid not in grouped_results: + grouped_results[target_guid] = [row[2], []] + toc_titles[toc_guid] = row[3] + grouped_results[target_guid][1].append(toc_guid) + action_title = 'INSERT TOCS INTO ANKI NOTES' + info = 
stopwatch.ActionInfo('Inserting TOC Links into', 'Anki Notes', 'Anki Notes\' See Only Field') + log = Logger('See Also\\5-insert_toc_links_into_see_also\\', rm_path=True) + tmr = stopwatch.Timer(len(grouped_results), info=info, label=log.base_path) + tmr.info.BannerHeader('new', crosspost=['invalid', 'error']) + toc_separator = generate_evernote_span(u' | ', u'Links', u'See Also', bold=False) + log.add('

%s: %d TOTAL NOTES




' % (action_title, tmr.max), 'see_also_html', + timestamp=False, clear=True, + extension='htm') + logged_missing_anki_note = False + sorted_results = sorted(grouped_results.items(), key=lambda s: s[1][0]) + for target_guid, target_guid_info in sorted_results: + note_title, toc_guids = target_guid_info + ankiNote = self.get_anki_note_from_evernote_guid(target_guid) + # if tmr.step(): + # log.add("INSERTING TOC LINKS INTO NOTE %5s: %s: %s" % ('#' + str(tmr.count), tmr.progress, note_title), + # 'progress') + if not ankiNote: + log.dump(toc_guids, 'Missing Anki Note for ' + target_guid, tmr.label, timestamp=False, + crosspost_to_default=False) + if not logged_missing_anki_note: + log.error('%s: Missing Anki Note(s) for TOC entry. See %s dump log for more details' % + (action_title, tmr.label)) + logged_missing_anki_note = True + tmr.reportStatus(EvernoteAPIStatus.NotFoundError, title=note_title) + continue + fields = get_dict_from_list(ankiNote.items()) + see_also_html = fields[FIELDS.SEE_ALSO] + content_links = find_evernote_links_as_guids(fields[FIELDS.CONTENT]) + see_also_whole_links = find_evernote_links(see_also_html) + see_also_links = {x.Guid for x in see_also_whole_links} + invalid_see_also_links = {x for x in see_also_links if x not in all_child_guids and x not in all_toc_guids} + new_tocs = set(toc_guids) - see_also_links + if TAGS.TOC_AUTO in ankiNote.tags: + new_tocs -= set(content_links) + log.dump([new_tocs, toc_guids, invalid_see_also_links, see_also_links, content_links], + 'TOCs for %s' % fields[FIELDS.TITLE] + ' vs ' + note_title, 'new_tocs', + crosspost_to_default=False) + new_toc_count = len(new_tocs) + invalid_see_also_links_count = len(invalid_see_also_links) + if invalid_see_also_links_count > 0: + for link in see_also_whole_links: + if link.Guid in invalid_see_also_links: + see_also_html = remove_evernote_link(link, see_also_html) + see_also_links -= invalid_see_also_links + see_also_count = len(see_also_links) + + if new_toc_count > 0: + 
has_ol = u'%s' % toc_link) + toc_delimiter = toc_separator + if flat_links: + find_div_end = see_also_html.rfind('
') + if find_div_end > -1: + see_also_html = see_also_html[:find_div_end] + see_also_new + '\n' + see_also_html[ + find_div_end:] + see_also_new = '' + else: + see_also_toc_headers = { + 'ol': u'
\n%s
    ' % + generate_evernote_span('TABLE OF CONTENTS:', 'Levels', 'Auto TOC', escape=False) + } + see_also_toc_headers['ul'] = see_also_toc_headers['ol'].replace('
      ') + see_also_html = see_also_html[:find_ul_end] + '
    ' + see_also_html[find_ul_end + 5:] + see_also_html = see_also_html.replace(see_also_toc_headers['ul'], see_also_toc_headers['ol']) + if see_also_toc_headers['ol'] in see_also_html: + find_ol_end = see_also_html.rfind('
') + see_also_html = see_also_html[:find_ol_end] + see_also_new + '\n' + see_also_html[find_ol_end:] + see_also_new = '' + else: + header_type = 'ul' if new_toc_count is 1 else 'ol' + see_also_new = see_also_toc_headers[header_type] + u'%s\n' % (see_also_new, header_type) + if see_also_count == 0: + see_also_html = generate_evernote_span(u'See Also:', 'Links', 'See Also') + see_also_html += see_also_new + see_also_html = see_also_html.replace('
    ', '
      ') + log.add('

      %s


      ' % generate_evernote_span(fields[FIELDS.TITLE], 'Links', + 'TOC') + see_also_html + u'
      ', 'see_also_html', + crosspost='see_also_html\\' + note_title, timestamp=False, extension='htm') + see_also_html = see_also_html.replace('evernote:///', 'evernote://') + changed = see_also_html != fields[FIELDS.SEE_ALSO] + + fields[FIELDS.SEE_ALSO] = see_also_html + anki_note_prototype = AnkiNotePrototype(self, fields, ankiNote.tags, ankiNote, count=tmr.counts.handled, + count_update=tmr.counts.updated.completed.val, max_count=tmr.max, + light_processing=True, steps=[0, 1, 7]) + anki_note_prototype._log_update_if_unchanged_ = ( + changed or new_toc_count + invalid_see_also_links_count > 0) + tmr.autoStep(anki_note_prototype.update_note(error_if_unchanged=changed), note_title, True) + + crosspost = [] + if new_toc_count: + crosspost.append('new') + if invalid_see_also_links: + crosspost.append('invalid') + if tmr.status.IsError: + crosspost.append('error') + log.go(' %s | %2d TOTAL TOC''s | %s | %s | %s%s' % ( + format_count('%2d NEW TOC''s', new_toc_count), len(toc_guids), + format_count('%2d EXISTING LINKS', see_also_count), + format_count('%2d INVALID LINKS', invalid_see_also_links_count), + ('*' if changed else ' ') * 3, note_title), crosspost=crosspost, timestamp=False) + + db._db.row_factory = sqlite.Row + + def extract_links_from_toc(self): + db = ankDB(TABLES.SEE_ALSO) + db.setrowfactory() + toc_entries = db.all("SELECT * FROM {n} WHERE tagNames LIKE '{t_toc}' ORDER BY title ASC") + db.execute("DELETE FROM {t} WHERE from_toc = 1") + log = Logger('See Also\\4-extract_links_from_toc\\', timestamp=False, crosspost_to_default=False, rm_path=True) + tmr = stopwatch.Timer(toc_entries, 20, infoStr='Extracting Links', label=log.base_path) + tmr.info.BannerHeader('error') + toc_guids = [] + for toc_entry in toc_entries: + toc_evernote_guid, toc_link_title = toc_entry['guid'], toc_entry['title'] + toc_guids.append("'%s'" % toc_evernote_guid) + # toc_link_html = generate_evernote_span(toc_link_title, 'Links', 'TOC') + enLinks = 
find_evernote_links(toc_entry['content']) + tmr.increment(toc_link_title) + for enLink in enLinks: + target_evernote_guid = enLink.Guid + if not check_evernote_guid_is_valid(target_evernote_guid): + log.go("Invalid Target GUID for %-70s %s" % (toc_link_title + ':', target_evernote_guid), 'error') + continue + base = { + 'child_guid': target_evernote_guid, 'uid': enLink.Uid, + 'shard': enLink.Shard, 'toc_guid': toc_evernote_guid, 'l1': 'source', 'l2': 'source', + 'from_toc': 0, 'is_toc': 0 + } + query_count = "select COUNT(*) from {t} WHERE source_evernote_guid = '{%s_guid}'" + toc = { + 'num': 1 + db.scalar(fmt(query_count % 'toc', base)), + 'html': enLink.HTML.replace(u'\'', u'\'\''), + 'title': enLink.FullTitle.replace(u'\'', u'\'\''), + 'l1': 'target', + 'from_toc': 1 + } + # child = {1 + db.scalar(fmt(query_count % 'child', base)), + # 'html': toc_link_html.replace(u'\'', u'\'\''), + # 'title': toc_link_title.replace(u'\'', u'\'\''), + # 'l2': 'target', + # 'is_toc': 1 + # } + query = (u"INSERT OR REPLACE INTO `{t}`(`{l1}_evernote_guid`, `number`, `uid`, `shard`, " + u"`{l2}_evernote_guid`, `html`, `title`, `from_toc`, `is_toc`) " + u"VALUES('{child_guid}', {num}, {uid}, '{shard}', " + u"'{toc_guid}', '{html}', '{title}', {from_toc}, {is_toc})") + query_toc = fmt(query, base, toc) + db.execute(query_toc) + log.go("\t\t - Added %2d child link(s) from TOC %s" % (len(enLinks), encode(toc_link_title))) + db.update("is_toc = 1", where="target_evernote_guid IN (%s)" % ', '.join(toc_guids)) + db.commit() + + def insert_toc_and_outline_contents_into_notes(self): + linked_notes_fields = {} + db = ankDB(TABLES.SEE_ALSO) + source_guids = db.list("SELECT DISTINCT s.source_evernote_guid FROM {s} s, {n} n WHERE (s.is_toc = 1 OR " + "s.is_outline = 1) AND s.source_evernote_guid = n.guid ORDER BY n.title ASC") + info = stopwatch.ActionInfo('Insertion of', 'TOC/Outline Contents', 'Into Target Anki Notes') + log = Logger('See Also\\8-insert_toc_contents\\', rm_path=True, 
timestamp=False) + tmr = stopwatch.Timer(source_guids, 25, info=info, label=log.base_path) + tmr.info.BannerHeader('error') + for source_guid in source_guids: + note = self.get_anki_note_from_evernote_guid(source_guid) + if not note: + tmr.reportStatus(EvernoteAPIStatus.NotFoundError) + log.error("Could not find note for %s for %s" % (note.guid, tmr.base_name)) + continue + # if TAGS.TOC in note.tags: + # tmr.reportSkipped() + # continue + for fld in note._model['flds']: + if FIELDS.TITLE in fld.get('name'): + note_title = note.fields[fld.get('ord')] + continue + if not note_title: + tmr.reportStatus(EvernoteAPIStatus.NotFoundError) + log.error("Could not find note title for %s for %s" % (note.guid, tmr.base_name)) + continue + tmr.step(note_title) + note_toc = "" + note_outline = "" + toc_header = "" + outline_header = "" + toc_count = 0 + outline_count = 0 + toc_and_outline_links = db.execute("source_evernote_guid = '%s' AND (is_toc = 1 OR is_outline = 1) " + "ORDER BY number ASC" % source_guid, + columns='target_evernote_guid, is_toc, is_outline') + for target_evernote_guid, is_toc, is_outline in toc_and_outline_links: + if target_evernote_guid in linked_notes_fields: + linked_note_contents = linked_notes_fields[target_evernote_guid][FIELDS.CONTENT] + linked_note_title = linked_notes_fields[target_evernote_guid][FIELDS.TITLE] + else: + linked_note = self.get_anki_note_from_evernote_guid(target_evernote_guid) + if not linked_note: + continue + linked_note_contents = u"" + for fld in linked_note._model['flds']: + if FIELDS.CONTENT in fld.get('name'): + linked_note_contents = linked_note.fields[fld.get('ord')] + elif FIELDS.TITLE in fld.get('name'): + linked_note_title = linked_note.fields[fld.get('ord')] + if linked_note_contents: + linked_notes_fields[target_evernote_guid] = { + FIELDS.TITLE: linked_note_title, + FIELDS.CONTENT: linked_note_contents + } + if linked_note_contents: + linked_note_contents = decode(linked_note_contents) + if is_toc: + toc_count += 1 
+ if toc_count is 1: + toc_header = "TABLE OF CONTENTS: 1. %s" % linked_note_title + else: + note_toc += "

      "; toc_header += " | %d. %s" % ( + toc_count, linked_note_title) + note_toc += linked_note_contents + else: + outline_count += 1 + if outline_count is 1: + outline_header = "OUTLINE: 1. %s" % linked_note_title + else: + note_outline += "

      "; outline_header += " | %d. %s" % ( + outline_count, linked_note_title) + note_outline += linked_note_contents + if outline_count + toc_count is 0: + tmr.reportError(EvernoteAPIStatus.MissingDataError) + log.error(" No Valid TOCs or Outlines Found: %s" % note_title) + continue + tmr.reportSuccess() + + def makestr(title, count): + return '' if not count else 'One %s ' % title if count is 1 else '%s %ss' % (str(count).center(3), title) + + toc_str = makestr('TOC', toc_count).rjust(8) #if toc_count else '' + outline_str = makestr('Outline', outline_count).ljust(12) #if outline_count else '' + toc_str += ' & ' if toc_count and outline_count else ' ' + + log.go(" [%4d/%4d] + %s for Note %s: %s" % ( + tmr.count, tmr.max, toc_str + outline_str, source_guid.split('-')[0], note_title)) + + if outline_count > 1: + note_outline = "%s

      " % outline_header + note_outline + if toc_count > 1: + note_toc = "%s

      " % toc_header + note_toc + for fld in note._model['flds']: + if FIELDS.TOC in fld.get('name'): + note.fields[fld.get('ord')] = note_toc + elif FIELDS.OUTLINE in fld.get('name'): + note.fields[fld.get('ord')] = note_outline + # log.go(' '*16 + "> Flushing Note \r\n") + note.flush(intTime()) + tmr.Report() + + def start_editing(self): + self.window().requireReset() + + def stop_editing(self): + if self.collection(): + self.window().maybeReset() + + @staticmethod + def window(): + """ + :rtype : AnkiQt + :return: + """ + return aqt.mw + + def collection(self): + return self.window().col + + def models(self): + return self.collection().models + + def decks(self): + return self.collection().decks diff --git a/anknotes/AnkiNotePrototype.py b/anknotes/AnkiNotePrototype.py new file mode 100644 index 0000000..fcd0f18 --- /dev/null +++ b/anknotes/AnkiNotePrototype.py @@ -0,0 +1,612 @@ +# -*- coding: utf-8 -*- +### Anknotes Shared Imports +from anknotes.base import encode, decode +from anknotes.shared import * +from anknotes.error import HandleUnicodeError +from anknotes.EvernoteNoteTitle import EvernoteNoteTitle + +### Anki Imports + + +try: + import anki + from anki.notes import Note as AnkiNote + from anki.utils import intTime + from aqt import mw +except Exception: + pass + + +def get_self_referential_fmap(): + fmap = {} + for i in range(0, len(FIELDS.LIST)): + fmap[i] = i + return fmap + + +class AnkiNotePrototype: + Anki = None + """:type : anknotes.Anki.Anki """ + BaseNote = None + """:type : AnkiNote """ + enNote = None + """:type: EvernoteNotePrototype.EvernoteNotePrototype""" + Fields = {} + """:type : dict[str, str]""" + Tags = [] + """:type : list[str]""" + ModelName = None + """:type : str""" + # Guid = "" + # """:type : str""" + NotebookGuid = "" + """:type : str""" + __cloze_count = 0 + + class Counts: + Updated = 0 + Current = 0 + Max = 1 + + OriginalGuid = None + """:type : str""" + Changed = False + _unprocessed_content_ = "" + _unprocessed_see_also_ 
= "" + _log_update_if_unchanged_ = True + + @property + def Guid(self): + return get_evernote_guid_from_anki_fields(self.Fields) + + def __init__(self, anki=None, fields=None, tags=None, base_note=None, notebookGuid=None, count=-1, count_update=0, + max_count=1, counts=None, light_processing=False, enNote=None, **kw): + """ + Create Anki Note Prototype Class from fields or Base Anki Note + :param anki: Anki: Anknotes Main Class Instance + :type anki: anknotes.Anki.Anki + :param fields: Dict of Fields + :param tags: List of Tags + :type tags : list[str] + :param base_note: Base Anki Note if Updating an Existing Note + :type base_note : anki.notes.Note + :param enNote: Base Evernote Note Prototype from Anknotes DB, usually used just to process a note's contents + :type enNote : EvernoteNotePrototype.EvernoteNotePrototype + :param notebookGuid: + :param count: + :param count_update: + :param max_count: + :param counts: AnkiNotePrototype.Counts if being used to add/update multiple notes + :type counts : AnkiNotePrototype.Counts + :return: AnkiNotePrototype + """ + self.loggedSeeAlsoError = None + self.__log_name = 'Add\\AnkiNotePrototype' + self.light_processing = light_processing + self.Anki = anki + self.Fields = fields + self.BaseNote = base_note + if enNote and light_processing and not fields: + self.Fields = { + FIELDS.TITLE: enNote.FullTitle, FIELDS.CONTENT: enNote.Content, FIELDS.SEE_ALSO: u'', + FIELDS.EVERNOTE_GUID: FIELDS.EVERNOTE_GUID_PREFIX + enNote.Guid + } + self.enNote = enNote + self.Changed = False + self.logged = False + if counts: + self.Counts = counts + else: + self.Counts.Updated = count_update + self.Counts.Current = count + 1 + self.Counts.Max = max_count + self.initialize_fields() + # self.Guid = get_evernote_guid_from_anki_fields(self.Fields) + self.NotebookGuid = notebookGuid + self.ModelName = None # MODELS.DEFAULT + # self.Title = EvernoteNoteTitle() + if not self.NotebookGuid and self.Anki: + self.NotebookGuid = 
self.Anki.get_notebook_guid_from_ankdb(self.Guid) + if not self.Guid and (self.light_processing or self.NotebookGuid): + log('Guid/Notebook Guid missing for: ' + self.FullTitle) + log(self.Guid) + log(self.NotebookGuid) + raise ValueError + self._deck_parent_ = self.Anki.deck if self.Anki else '' + assert tags is not None + self.Tags = tags + self.__cloze_count = 0 + self.process_note(**kw) + + def initialize_fields(self): + if self.BaseNote: + self.originalFields = get_dict_from_list(self.BaseNote.items()) + for field in FIELDS.LIST: + if not field in self.Fields: + self.Fields[field] = self.originalFields[field] if self.BaseNote else u'' + # self.Title = EvernoteNoteTitle(self.Fields) + + def deck(self): + deck = self._deck_parent_ + if TAGS.TOC in self.Tags or TAGS.TOC_AUTO in self.Tags: + deck += DECKS.TOC_SUFFIX + elif TAGS.OUTLINE in self.Tags and TAGS.OUTLINE_TESTABLE not in self.Tags: + deck += DECKS.OUTLINE_SUFFIX + elif not self._deck_parent_ or SETTINGS.ANKI.DECKS.EVERNOTE_NOTEBOOK_INTEGRATION.fetch(): + deck = self.Anki.get_deck_name_from_evernote_notebook(self.NotebookGuid, self._deck_parent_) + if not deck: + return None + if deck[:2] == '::': + deck = deck[2:] + return deck + + def evernote_cloze_regex(self, match): + matchText = match.group(2) + if matchText.startswith("#"): + matchText = matchText[1:] + else: + self.__cloze_count += 1 + if self.__cloze_count == 0: + self.__cloze_count = 1 + return "%s{{c%d::%s}}%s" % (match.group(1), self.__cloze_count, matchText, match.group(3)) + + def regex_occlude_match(self, match): + matchText = match.group(0) + if 'class="Occluded"' in matchText or "class='Occluded'" in matchText: + return matchText + return r'<<' + match.group('PrefixKeep') + '
      ' + match.group( + 'OccludedText') + '
      >>' + + def process_note_see_also(self): + if not FIELDS.SEE_ALSO in self.Fields or not FIELDS.EVERNOTE_GUID in self.Fields: + return + db = ankDB(TABLES.SEE_ALSO) + db.delete("source_evernote_guid = '%s' " % self.Guid) + link_num = 0 + for enLink in find_evernote_links(self.Fields[FIELDS.SEE_ALSO]): + if not check_evernote_guid_is_valid(enLink.Guid): + self.Fields[FIELDS.SEE_ALSO] = remove_evernote_link(enLink, self.Fields[FIELDS.SEE_ALSO]) + continue + link_num += 1 + values = DictCaseInsensitive(source_evernote_guid=self.Guid, target_evernote_guid=enLink.Guid, + number=link_num, uid=enLink.Uid, shard=enLink.Shard, + html=enLink.HTML, title=enLink.FullTitle) + values.from_toc = 1 if ',%s,' % TAGS.TOC in self.Tags else 0 + values.is_toc = 1 if (values.title == "TOC" or values.title == "TABLE OF CONTENTS") else 0 + values.is_outline = 1 if (values.title == "O" or values.title == "Outline") else 0 + db.insert_or_replace(values) + if link_num is 0: + self.Fields[FIELDS.SEE_ALSO] = "" + db.commit() + + def process_note_content(self, steps=None): + def step_0_remove_evernote_css_attributes(): + ################################### Step 0: Correct weird Evernote formatting + self.Fields[FIELDS.CONTENT] = clean_evernote_css(self.Fields[FIELDS.CONTENT]) + + def step_1_modify_evernote_links(): + ################################### Step 1: Modify Evernote Links + # We need to modify Evernote's "Classic" Style Note Links due to an Anki bug with executing the evernote command with three forward slashes. + # For whatever reason, Anki cannot handle evernote links with three forward slashes, but *can* handle links with two forward slashes. + self.Fields[FIELDS.CONTENT] = self.Fields[FIELDS.CONTENT].replace("evernote:///", "evernote://") + + # Modify Evernote's "New" Style Note links that point to the Evernote website. Normally these links open the note using Evernote's web client. + # The web client then opens the local Evernote executable. 
Modifying the links as below will skip this step and open the note directly using the local Evernote executable + self.Fields[FIELDS.CONTENT] = re.sub(r'https://www.evernote.com/shard/(s\d+)/[\w\d]+/(\d+)/([\w\d\-]+)', + r'evernote://view/\2/\1/\3/\3/', self.Fields[FIELDS.CONTENT]) + + # If we are converting back to Evernote format + if self.light_processing: + self.Fields[FIELDS.CONTENT] = self.Fields[FIELDS.CONTENT].replace("evernote://", + "evernote:///") + + def step_2_modify_image_links(): + ################################### Step 2: Modify Image Links + # Currently anknotes does not support rendering images embedded into an Evernote note. + # As a work around, this code will convert any link to an image on Dropbox, to an embedded tag. + # This code modifies the Dropbox link so it links to a raw image file rather than an interstitial web page + # Step 2.1: Modify HTML links to Dropbox images + dropbox_image_url_base_regex = r'(?P[^"''])(?Phttps://www.dropbox.com/s/[\w\d]+/.+\.(jpg|png|jpeg|gif|bmp))' + dropbox_image_url_html_link_regex = dropbox_image_url_base_regex + r'(?P(?:\?dl=(?:0|1))?)' + dropbox_image_url_suffix = r'(?P[^"''])' + dropbox_image_src_subst = r'\g
      Dropbox Link %s Automatically Generated by Anknotes\g' + self.Fields[FIELDS.CONTENT] = re.sub(r']*>(?P.+?)</a>' % ( + dropbox_image_url_html_link_regex + dropbox_image_url_suffix), + dropbox_image_src_subst % "'\g<Title>'", self.Fields[FIELDS.CONTENT]) + + # Step 2.2: Modify Plain-text links to Dropbox images + try: + dropbox_image_url_regex = dropbox_image_url_base_regex + r'(?P<QueryString>\?dl=(?:0|1))' + dropbox_image_url_suffix + self.Fields[FIELDS.CONTENT] = re.sub(dropbox_image_url_regex, + r'\g<URLPrefix>' + dropbox_image_src_subst % "From Plain-Text Link" + r'\g<URLSuffix>', + self.Fields[FIELDS.CONTENT]) + except Exception: + log_error("\nERROR processing note, Step 2.2. Content: %s" % self.Fields[FIELDS.CONTENT]) + + # Step 2.3: Modify HTML links with the inner text of exactly "(Image Link*)" + self.Fields[FIELDS.CONTENT] = re.sub( + r'<a href=["''](?P<URL>.+?)["''][^>]*>(?P<Title>\(Image Link[^<]*\))</a>', + r'''<img src="\g<URL>" alt="'\g<Title>' Automatically Generated by Anknotes" /> <BR><a href="\g<URL>">\g<Title></a>''', + self.Fields[FIELDS.CONTENT]) + + def step_3_occlude_text(): + ################################### Step 3: Change white text to transparent + # I currently use white text in Evernote to display information that I want to be initially hidden, but visible when desired by selecting the white text. 
+ # We will change the white text to a special "occluded" CSS class so it can be visible on the back of cards, and also so we can adjust the color for the front of cards when using night mode + self.Fields[FIELDS.CONTENT] = self.Fields[FIELDS.CONTENT].replace( + '<span style="color: rgb(255, 255, 255);">', '<span class="occluded">') + + ################################### Step 4: Automatically Occlude Text in <<Double Angle Brackets>> + self.Fields[FIELDS.CONTENT] = re.sub( + "(?s)(?P<Prefix><|<) ?(?P=Prefix) ?(?P<PrefixKeep>(?:</div>)?)(?P<OccludedText>.+?)(?P<Suffix>>|>) ?(?P=Suffix) ?", + self.regex_occlude_match, self.Fields[FIELDS.CONTENT]) + + def step_5_create_cloze_fields(): + ################################### Step 5: Create Cloze fields from shorthand. Syntax is {Text}. Optionally {#Text} will prevent the Cloze # from incrementing. + self.Fields[FIELDS.CONTENT] = re.sub(r'([^{]){([^{].*?)}([^}])', self.evernote_cloze_regex, + self.Fields[FIELDS.CONTENT]) + + def step_6_process_see_also_links(): + ################################### Step 6: Process "See Also: " Links + see_also_match = regex_see_also().search(self.Fields[FIELDS.CONTENT]) + if not see_also_match: + i_see_also = self.Fields[FIELDS.CONTENT].find("See Also") + if i_see_also > -1: + self.loggedSeeAlsoError = self.Guid + i_div = self.Fields[FIELDS.CONTENT].rfind("<div", 0, i_see_also) + if i_div is -1: + i_div = i_see_also + log_error( + "No See Also Content Found, but phrase 'See Also' exists in " + self.Guid + ": " + self.FullTitle, + crosspost_to_default=False) + log( + "No See Also Content Found, but phrase 'See Also' exists: \n" + self.Guid + ": " + self.FullTitle + " \n" + + self.Fields[FIELDS.CONTENT][i_div:i_see_also + 50] + '\n', 'SeeAlso\\MatchExpected') + log(self.Fields[FIELDS.CONTENT], 'SeeAlso\\MatchExpected\\' + self.FullTitle) + # raise ValueError + return False + self.Fields[FIELDS.CONTENT] = self.Fields[FIELDS.CONTENT].replace(see_also_match.group(0), + 
see_also_match.group('Suffix')) + self.Fields[FIELDS.CONTENT] = self.Fields[FIELDS.CONTENT].replace('<div><b><br/></b></div></en-note>', + '</en-note>') + see_also = see_also_match.group('SeeAlso') + see_also_header = see_also_match.group('SeeAlsoHeader') + see_also_header_stripme = see_also_match.group('SeeAlsoHeaderStripMe') + if see_also_header_stripme: + see_also = see_also.replace(see_also_header, see_also_header.replace(see_also_header_stripme, '')) + if self.Fields[FIELDS.SEE_ALSO]: + self.Fields[FIELDS.SEE_ALSO] += "<br><br>\r\n" + self.Fields[FIELDS.SEE_ALSO] += see_also + if self.light_processing: + self.Fields[FIELDS.CONTENT] = self.Fields[FIELDS.CONTENT].replace(see_also_match.group('Suffix'), + self.Fields[ + FIELDS.SEE_ALSO] + see_also_match.group( + 'Suffix')) + return False + return True + + if not FIELDS.CONTENT in self.Fields: + return + self._unprocessed_content_ = self.Fields[FIELDS.CONTENT] + self._unprocessed_see_also_ = self.Fields[FIELDS.SEE_ALSO] + if steps is None: + steps = [0, 1, 6] if self.light_processing else range(0, 7) + if self.light_processing and not ANKI.NOTE_LIGHT_PROCESSING_INCLUDE_CSS_FORMATTING: + steps.remove(0) + if 0 in steps: + step_0_remove_evernote_css_attributes() + if 1 in steps: + step_1_modify_evernote_links() + if 2 in steps: + step_2_modify_image_links() + if 3 in steps: + step_3_occlude_text() + if 5 in steps: + step_5_create_cloze_fields() + if (6 in steps and step_6_process_see_also_links()) or (6 not in steps and 7 in steps): + self.process_note_see_also() + # TODO: Add support for extracting an 'Extra' field from the Evernote Note contents + ################################### Note Processing complete. 
+ + def detect_note_model(self): + # log('Title, self.model_name, tags, self.model_name', 'detectnotemodel') + # log(self.FullTitle, 'detectnotemodel') + # log(self.ModelName, 'detectnotemodel') + if FIELDS.CONTENT in self.Fields and "{{c1::" in self.Fields[FIELDS.CONTENT]: + self.ModelName = MODELS.CLOZE + if len(self.Tags) > 0: + reverse_override = (TAGS.TOC in self.Tags or TAGS.TOC_AUTO in self.Tags) + if TAGS.REVERSIBLE in self.Tags: + self.ModelName = MODELS.REVERSIBLE + self.Tags.remove(TAGS.REVERSIBLE) + elif TAGS.REVERSE_ONLY in self.Tags: + self.ModelName = MODELS.REVERSE_ONLY + self.Tags.remove(TAGS.REVERSE_ONLY) + if reverse_override: + self.ModelName = MODELS.DEFAULT + + def model_id(self): + if not self.ModelName: + return None + return long(self.Anki.models().byName(self.ModelName)['id']) + + def process_note(self, **kw): + self.process_note_content(**kw) + if not self.light_processing: + self.detect_note_model() + + def update_note_model(self): + modelNameNew = self.ModelName + if not modelNameNew: + return False + modelIdOld = self.note.mid + modelIdNew = self.model_id() + if modelIdOld == modelIdNew: + return False + mm = self.Anki.models() + modelOld = self.note.model() + modelNew = mm.get(modelIdNew) + modelNameOld = modelOld['name'] + fmap = get_self_referential_fmap() + cmap = {0: 0} + if modelNameOld == MODELS.REVERSE_ONLY and modelNameNew == MODELS.REVERSIBLE: + cmap[0] = 1 + elif modelNameOld == MODELS.REVERSIBLE: + if modelNameNew == MODELS.REVERSE_ONLY: + cmap = {0: None, 1: 0} + else: + cmap[1] = None + self.log_update("Changing model:\n From: '%s' \n To: '%s'" % (modelNameOld, modelNameNew)) + mm.change(modelOld, [self.note.id], modelNew, fmap, cmap) + self.Changed = True + return True + + def log_update(self, content=''): + if not self.logged: + count_updated_new = (self.Counts.Updated + 1 if content else 0) + count_str = '' + if self.Counts.Current > 0: + count_str = ' [' + if self.Counts.Current - count_updated_new > 0 and 
count_updated_new > 0: + count_str += '%3d/' % count_updated_new + count_str += '%-4d]/[' % self.Counts.Current + else: + count_str += '%4d/' % self.Counts.Current + count_str += '%-4d]' % self.Counts.Max + count_str += ' (%2d%%)' % (float(self.Counts.Current) / self.Counts.Max * 100) + log_title = '!' if content else '' + log_title += 'UPDATING NOTE%s: %-80s %s' % (count_str, self.FullTitle + ':', self.Guid) + log(log_title, self.__log_name, timestamp=(content is ''), + clear=((self.Counts.Current == 1 or self.Counts.Current == 100) and not self.logged)) + self.logged = True + if not content: + return + content = obj2log_simple(content) + content = content.replace('\n', '\n ') + if content.lstrip()[: + 1] != '>': content = '> ' + content + log(' %s\n' % content, self.__log_name, timestamp=False) + + def update_note_tags(self): + if len(self.Tags) == 0: + return False + self.Tags = get_tag_names_to_import(self.Tags) + if not self.BaseNote: + self.log_update("Error with unt") + self.log_update(self.Tags) + self.log_update(self.Fields) + self.log_update(self.BaseNote) + assert self.BaseNote + baseTags = sorted(self.BaseNote.tags, key=lambda s: s.lower()) + value = u','.join(self.Tags) + value_original = u','.join(baseTags) + if str(value) == str(value_original): + return False + self.log_update("Changing tags:\n From: '%s' \n To: '%s'" % (value_original, value)) + self.BaseNote.tags = self.Tags + self.Changed = True + return True + + def update_note_deck(self): + deckNameNew = self.deck() + if not deckNameNew: + return False + deckIDNew = self.Anki.decks().id(deckNameNew) + deckIDOld = get_anki_deck_id_from_note_id(self.note.id) + if deckIDNew == deckIDOld: + return False + self.log_update( + "Changing deck:\n From: '%s' \n To: '%s'" % (self.Anki.decks().nameOrNone(deckIDOld), self.deck())) + # Not sure if this is necessary or Anki does it by itself: + ankDB().execute("UPDATE cards SET did = ? 
WHERE nid = ?", [deckIDNew, self.note.id]) + return True + + def update_note_fields(self): + fields_to_update = [FIELDS.TITLE, FIELDS.CONTENT, FIELDS.SEE_ALSO, FIELDS.UPDATE_SEQUENCE_NUM] + fld_content_ord = -1 + flag_changed = False + field_updates = [] + fields_updated = {} + for fld in self.note._model['flds']: + if FIELDS.EVERNOTE_GUID in fld.get('name'): + self.OriginalGuid = self.note.fields[fld.get('ord')].replace(FIELDS.EVERNOTE_GUID_PREFIX, '') + for field_to_update in fields_to_update: + if field_to_update == fld.get('name') and field_to_update in self.Fields: + if field_to_update is FIELDS.CONTENT: + fld_content_ord = fld.get('ord') + try: + value = decode(self.Fields[field_to_update]) + value_original = decode(self.note.fields[fld.get('ord')]) + if not value == value_original: + flag_changed = True + self.note.fields[fld.get('ord')] = value + fields_updated[field_to_update] = value_original + if field_to_update is FIELDS.CONTENT or field_to_update is FIELDS.SEE_ALSO: + diff = generate_diff(value_original, value) + else: + diff = 'From: \n%s \n\n To: \n%s' % (value_original, value) + field_updates.append("Changing field #%d %s:\n%s" % (fld.get('ord'), field_to_update, diff)) + except Exception: + self.log_update(field_updates) + log_error( + "ERROR: UPDATE_NOTE: Note '%s': %s: Unable to set self.note.fields for field '%s'. Ord: %s. 
Note fields count: %d" % ( + self.Guid, self.FullTitle, field_to_update, str(fld.get('ord')), + len(self.note.fields))) + raise + for update in field_updates: + self.log_update(update) + if flag_changed: + self.Changed = True + return flag_changed + + def update_note(self, error_if_unchanged=True): + self.note = self.BaseNote + self.logged = False + if not self.BaseNote: + self.log_update("Not updating Note: Could not find base note") + return EvernoteAPIStatus.NotFoundError, self.note.id + self.Changed = False + self.update_note_tags() + self.update_note_fields() + i_see_also = self.Fields[FIELDS.CONTENT].find("See Also") + if i_see_also > -1: + i_div = self.Fields[FIELDS.CONTENT].rfind("<div", 0, i_see_also) + if i_div is -1: + i_div = i_see_also + if self.loggedSeeAlsoError != self.Guid: + log_error( + "No See Also Content Found, but phrase 'See Also' exists in " + self.Guid + ": " + self.FullTitle, + crosspost_to_default=False) + log( + "No See Also Content Found, but phrase 'See Also' exists: \n" + self.Guid + ": " + self.FullTitle + " \n" + + self.Fields[FIELDS.CONTENT][i_div:i_see_also + 50] + '\n', 'SeeAlso\\MatchExpectedUpdate') + log(self.Fields[FIELDS.CONTENT], 'SeeAlso\\MatchExpectedUpdate\\' + self.FullTitle) + if not (self.Changed or self.update_note_deck()): + if self._log_update_if_unchanged_: + self.log_update("Not updating Note: The fields, tags, and deck are the same") + elif (self.Counts.Updated is 0 or + self.Counts.Current / self.Counts.Updated > 9) and self.Counts.Current % 100 is 0: + self.log_update() + return EvernoteAPIStatus.UnchangedError if error_if_unchanged else EvernoteAPIStatus.Unchanged, self.note.id + if not self.Changed: + # i.e., the note deck has been changed but the tags and fields have not + self.Counts.Updated += 1 + return EvernoteAPIStatus.UnchangedError if error_if_unchanged else EvernoteAPIStatus.Success, self.note.id + if not self.OriginalGuid: + flds = get_dict_from_list(self.BaseNote.items()) + self.OriginalGuid = 
get_evernote_guid_from_anki_fields(flds) + db_title = get_evernote_title_from_guid(self.OriginalGuid) + self.check_titles_equal(db_title, self.FullTitle, self.Guid) + self.note.flush(intTime()) + self.log_update(" > Flushing Note") + self.update_note_model() + self.Counts.Updated += 1 + return EvernoteAPIStatus.Success, self.note.id + + def check_titles_equal(self, old_title, new_title, new_guid, log_title='DB INFO UNEQUAL'): + do_log_title = False + try: + new_title = decode(new_title) + except Exception: + do_log_title = True + try: + old_title = decode(old_title) + except Exception: + do_log_title = True + guid_text = '' if self.OriginalGuid is None else ' ' + self.OriginalGuid + ( + '' if new_guid == self.OriginalGuid else ' vs %s' % new_guid) + ':' + if do_log_title or new_title != old_title or (self.OriginalGuid and new_guid != self.OriginalGuid): + log_str = ' %s: %s%s' % ( + '*' if do_log_title else ' ' + log_title, guid_text, ' ' + new_title + ' vs ' + old_title) + log_error(log_str, crosspost_to_default=False) + self.log_update(log_str) + return False + return True + + @property + def Title(self): + """:rtype : EvernoteNoteTitle.EvernoteNoteTitle """ + title = "" + if FIELDS.TITLE in self.Fields: + title = self.Fields[FIELDS.TITLE] + elif self.BaseNote: + title = self.originalFields[FIELDS.TITLE] + return EvernoteNoteTitle(title) + + @property + def FullTitle(self): return self.Title.FullTitle + + def save_anki_fields_decoded(self, attempt, from_anp_fields=False, do_decode=None): + title = self.db_title if hasattr(self, 'db_title') else self.FullTitle + e_return = False + log_header = 'ANKI-->ANP-->' + if from_anp_fields: + log_header += 'CREATE ANKI FIELDS' + base_values = self.Fields.items() + else: + log_header += 'SAVE ANKI FIELDS (DECODED)' + base_values = enumerate(self.note.fields) + for key, value in base_values: + name = key if from_anp_fields else FIELDS.LIST[key - 1] if key > 0 else FIELDS.EVERNOTE_GUID + if isinstance(value, unicode) and not 
do_decode is True: + action = 'ENCODING' + elif isinstance(value, str) and not do_decode is False: + action = 'DECODING' + else: + action = 'DOING NOTHING' + log('\t - %s for %s field %s' % (action, value.__class__.__name__, name), 'unicode', timestamp=False) + if action is not 'DOING NOTHING': + try: + new_value = encode(value) if action == 'ENCODED' else decode(value) + if from_anp_fields: + self.note[key] = new_value + else: + self.note.fields[key] = new_value + except (UnicodeDecodeError, UnicodeEncodeError, UnicodeTranslateError, UnicodeError, Exception) as e: + e_return = HandleUnicodeError(log_header, e, self.Guid, title, action, attempt, value, field=name) + if e_return is not 1: + raise + if e_return is not False: + log_blank('unicode') + return 1 + + def add_note_try(self, attempt=1): + title = self.db_title if hasattr(self, 'db_title') else self.FullTitle + col = self.Anki.collection() + log_header = 'ANKI-->ANP-->ADD NOTE FAILED' + action = 'DECODING?' + try: + col.addNote(self.note) + except (UnicodeDecodeError, UnicodeEncodeError, UnicodeTranslateError, UnicodeError, Exception), e: + e_return = HandleUnicodeError(log_header, e, self.Guid, title, action, attempt, self.note[FIELDS.TITLE]) + if e_return is not 1: + raise + self.save_anki_fields_decoded(attempt + 1) + return self.add_note_try(attempt + 1) + return 1 + + def add_note(self): + self.create_note() + if self.note is None: + return EvernoteAPIStatus.NotFoundError, None + collection = self.Anki.collection() + db_title = get_evernote_title_from_guid(self.Guid) + log(' %s: ADD: ' % self.Guid + ' ' + self.FullTitle, self.__log_name) + self.check_titles_equal(db_title, self.FullTitle, self.Guid, 'NEW NOTE TITLE UNEQUAL TO DB ENTRY') + if self.add_note_try() is not 1: + return EvernoteAPIStatus.GenericError, None + collection.autosave() + self.Anki.start_editing() + return EvernoteAPIStatus.Success, self.note.id + + def create_note(self, attempt=1): + id_deck = self.Anki.decks().id(self.deck()) + if 
not self.ModelName: + self.ModelName = MODELS.DEFAULT + model = self.Anki.models().byName(self.ModelName) + col = self.Anki.collection() + self.note = AnkiNote(col, model) + self.note.model()['did'] = id_deck + self.note.tags = self.Tags + title = self.db_title if hasattr(self, 'db_title') else self.FullTitle + self.save_anki_fields_decoded(attempt, True, True) diff --git a/anknotes/Controller.py b/anknotes/Controller.py new file mode 100644 index 0000000..3b74e10 --- /dev/null +++ b/anknotes/Controller.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- +### Python Imports +import socket +from datetime import datetime + +try: + from pysqlite2 import dbapi2 as sqlite +except ImportError: + from sqlite3 import dbapi2 as sqlite + +### Anknotes Shared Imports +from anknotes.shared import * +from anknotes.error import * + +### Anknotes Class Imports +from anknotes.AnkiNotePrototype import AnkiNotePrototype +from anknotes.EvernoteNotePrototype import EvernoteNotePrototype +from anknotes.EvernoteNoteTitle import generateTOCTitle +from anknotes import stopwatch +### Anknotes Main Imports +from anknotes.Anki import Anki +from anknotes.ankEvernote import Evernote +from anknotes.EvernoteNotes import EvernoteNotes +from anknotes.EvernoteNoteFetcher import EvernoteNoteFetcher +from anknotes import settings +from anknotes.EvernoteImporter import EvernoteImporter + +### Evernote Imports +from anknotes.evernote.edam.notestore.ttypes import NoteFilter, NotesMetadataResultSpec +from anknotes.evernote.edam.type.ttypes import NoteSortOrder, Note as EvernoteNote +from anknotes.evernote.edam.error.ttypes import EDAMSystemException + +### Anki Imports +from aqt import mw + + +# load_time = datetime.now() +# log("Loaded controller at " + load_time.isoformat(), 'import') +class Controller: + evernoteImporter = None + """:type : EvernoteImporter""" + + def __init__(self): + self.forceAutoPage = False + self.auto_page_callback = None + self.anki = Anki() + self.anki.deck = 
SETTINGS.ANKI.DECKS.BASE.fetch() + self.anki.setup_ancillary_files() + ankDB().Init() + self.anki.add_evernote_models() + self.evernote = Evernote() + + def test_anki(self, title, evernote_guid, filename=""): + if not filename: + filename = title + fields = { + FIELDS.TITLE: title, + FIELDS.CONTENT: file( + os.path.join(FOLDERS.LOGS, filename.replace('.enex', '') + ".enex"), + 'r').read(), FIELDS.EVERNOTE_GUID: FIELDS.EVERNOTE_GUID_PREFIX + evernote_guid + } + tags = ['NoTags', 'NoTagsToRemove'] + return AnkiNotePrototype(self.anki, fields, tags) + + def process_unadded_see_also_notes(self): + update_regex() + anki_note_ids = self.anki.get_anknotes_note_ids_with_unadded_see_also() + self.evernote.getNoteCount = 0 + self.anki.process_see_also_content(anki_note_ids) + + def upload_validated_notes(self, automated=False): + db = ankDB(TABLES.NOTE_VALIDATION_QUEUE) + dbRows = db.all("validation_status = 1") + notes_created, notes_updated, queries1, queries2 = ([] for i in range(4)) + """ + :type: (list[EvernoteNote], list[EvernoteNote], list[str], list[str]) + """ + noteFetcher = EvernoteNoteFetcher() + tmr = stopwatch.Timer(len(dbRows), 25, infoStr="Upload of Validated Evernote Notes", automated=automated, + enabled=EVERNOTE.UPLOAD.ENABLED, max_allowed=EVERNOTE.UPLOAD.MAX, + label='Validation\\upload_validated_notes\\', display_initial_info=True) + if tmr.actionInitializationFailed: + return tmr.status, 0, 0 + for dbRow in dbRows: + entry = EvernoteValidationEntry(dbRow) + evernote_guid, rootTitle, contents, tagNames, notebookGuid, noteType = entry.items() + tagNames = tagNames.split(',') + if not tmr.checkLimits(): + break + whole_note = tmr.autoStep( + self.evernote.makeNote(rootTitle, contents, tagNames, notebookGuid, guid=evernote_guid, + noteType=noteType, validated=True), rootTitle, evernote_guid) + if tmr.report_result is False: + raise ValueError + if tmr.status.IsDelayableError: + break + if not tmr.status.IsSuccess: + continue + if not whole_note.tagNames: + 
whole_note.tagNames = tagNames + noteFetcher.addNoteFromServerToDB(whole_note, tagNames) + note = EvernoteNotePrototype(whole_note=whole_note) + assert whole_note.tagNames + assert note.Tags + if evernote_guid: + notes_updated.append(note) + queries1.append([evernote_guid]) + else: + notes_created.append(note) + queries2.append([rootTitle, contents]) + else: + tmr.reportNoBreak() + tmr.Report(self.anki.add_evernote_notes(notes_created) if tmr.counts.created else 0, + self.anki.update_evernote_notes(notes_updated) if tmr.counts.updated else 0) + if tmr.counts.created.completed.subcount: + db.executemany("DELETE FROM {t} WHERE title = ? and contents = ? ", queries2) + if tmr.counts.updated.completed.subcount: + db.executemany("DELETE FROM {t} WHERE guid = ? ", queries1) + if tmr.is_success: + db.commit() + if tmr.should_retry: + create_timer(30 if tmr.status.IsDelayableError else EVERNOTE.UPLOAD.RESTART_INTERVAL, + self.upload_validated_notes, True) + return tmr.status, tmr.count, 0 + + def create_toc_auto(self): + db = ankDB() + def check_old_values(): + old_values = db.first("UPPER(title) = UPPER(?) 
AND tagNames LIKE '{t_tauto}'", + rootTitle, columns='guid, content') + if not old_values: + log.go(rootTitle, 'Add') + return None, contents + evernote_guid, old_content = old_values + noteBodyUnencoded = self.evernote.makeNoteBody(contents, encode=False) + if type(old_content) != type(noteBodyUnencoded): + log.go([rootTitle, type(old_content), type(noteBodyUnencoded)], 'Update\\Diffs\\_') + raise UnicodeWarning + old_content = old_content.replace('guid-pending', evernote_guid).replace("'", '"') + noteBodyUnencoded = noteBodyUnencoded.replace('guid-pending', evernote_guid).replace("'", '"') + if old_content == noteBodyUnencoded: + log.go(rootTitle, 'Skipped') + tmr.reportSkipped() + return None, None + log.go(noteBodyUnencoded, 'Update\\New\\' + rootTitle, clear=True) + log.go(generate_diff(old_content, noteBodyUnencoded), 'Update\\Diffs\\' + rootTitle, clear=True) + return evernote_guid, contents.replace( + '/guid-pending/', '/%s/' % evernote_guid).replace('/guid-pending/', '/%s/' % evernote_guid) + + update_regex() + noteType = 'create-toc_auto_notes' + db.delete("noteType = '%s'" % noteType, table=TABLES.NOTE_VALIDATION_QUEUE) + NotesDB = EvernoteNotes() + NotesDB.baseQuery = ANKNOTES.HIERARCHY.ROOT_TITLES_BASE_QUERY + dbRows = NotesDB.populateAllNonCustomRootNotes() + notes_created, notes_updated = [], [] + """ + :type: (list[EvernoteNote], list[EvernoteNote]) + """ + info = stopwatch.ActionInfo('Creation of Table of Content Note(s)', row_source='Root Title(s)') + log = Logger('See Also\\2-%s\\' % noteType, rm_path=True) + tmr = stopwatch.Timer(len(dbRows), 25, info, max_allowed=EVERNOTE.UPLOAD.MAX, + label=log.base_path) + if tmr.actionInitializationFailed: + return tmr.status, 0, 0 + for dbRow in dbRows: + rootTitle, contents, tagNames, notebookGuid = dbRow.items() + tagNames = (set(tagNames[1:-1].split(',')) | {TAGS.TOC, TAGS.TOC_AUTO} | ( + {"#Sandbox"} if EVERNOTE.API.IS_SANDBOXED else set())) - {TAGS.REVERSIBLE, TAGS.REVERSE_ONLY} + rootTitle = 
    def proceed(self, auto_paging=False):
        """Kick off (or resume) an Evernote import via a lazily-created EvernoteImporter.

        Wires the importer to this controller's Anki/Evernote handles on first
        use, forwards paging state, then delegates to EvernoteImporter.proceed.
        """
        if not self.evernoteImporter:
            # First call: create and wire up the importer.
            self.evernoteImporter = EvernoteImporter()
            self.evernoteImporter.anki = self.anki
            self.evernoteImporter.evernote = self.evernote
        # Forwarded every call since they can change between runs.
        # NOTE(review): indentation reconstructed from a mangled diff — these
        # two lines are assumed to sit outside the creation branch; confirm.
        self.evernoteImporter.forceAutoPage = self.forceAutoPage
        self.evernoteImporter.auto_page_callback = self.auto_page_callback
        if not hasattr(self, 'currentPage'):
            self.currentPage = 1
        self.evernoteImporter.currentPage = self.currentPage
        if hasattr(self, 'ManualGUIDs'):
            self.evernoteImporter.ManualGUIDs = self.ManualGUIDs
        self.evernoteImporter.proceed(auto_paging)
EvernoteNoteFetcherResults""" + log(' > Finished Creating Evernote Notes: '.ljust(40) + tmr.str_long) + tmr.reset() + number = self.anki.update_evernote_notes(results.Notes, log_update_if_unchanged=False) + log(' > Finished Updating Anki Notes: '.ljust(40) + tmr.str_long) + tooltip = '%d Evernote Notes Created<BR>%d Anki Notes Successfully Updated' % (results.Local, number) + show_report(' > Resync with Local DB Complete', tooltip) diff --git a/anknotes/EvernoteImporter.py b/anknotes/EvernoteImporter.py new file mode 100644 index 0000000..86c88eb --- /dev/null +++ b/anknotes/EvernoteImporter.py @@ -0,0 +1,318 @@ +# -*- coding: utf-8 -*- +### Python Imports +import socket + +try: + from pysqlite2 import dbapi2 as sqlite +except ImportError: + from sqlite3 import dbapi2 as sqlite + +### Anknotes Shared Imports +from anknotes.shared import * +from anknotes.error import * + +### Anknotes Class Imports +from anknotes.AnkiNotePrototype import AnkiNotePrototype +from anknotes.structs_base import UpdateExistingNotes + +### Anknotes Main Imports +from anknotes.Anki import Anki +from anknotes.ankEvernote import Evernote +from anknotes.EvernoteNotes import EvernoteNotes +from anknotes.EvernoteNotePrototype import EvernoteNotePrototype + +try: + from anknotes import settings +except Exception: + pass + +### Evernote Imports +from anknotes.evernote.edam.notestore.ttypes import NoteFilter, NotesMetadataResultSpec, NoteMetadata, NotesMetadataList +from anknotes.evernote.edam.type.ttypes import NoteSortOrder, Note as EvernoteNote +from anknotes.evernote.edam.error.ttypes import EDAMSystemException + +### Anki Imports +try: + from aqt import mw +except Exception: + pass + + +class EvernoteImporter: + forceAutoPage = False + auto_page_callback = None + """:type : lambda""" + anki = None + """:type : Anki""" + evernote = None + """:type : Evernote""" + updateExistingNotes = UpdateExistingNotes.UpdateNotesInPlace + ManualGUIDs = None + + @property + def ManualMetadataMode(self): + 
    def override_evernote_metadata(self):
        """Build a synthetic findNotesMetadata result page from self.ManualGUIDs.

        Used in manual-metadata mode to bypass the server-side search: a
        NotesMetadataList is fabricated with one bare NoteMetadata per GUID in
        the current page window, then loaded into MetadataProgress exactly as
        a real API response would be.

        :returns: True (always; mirrors get_evernote_metadata's success path).
        """
        guids = self.ManualGUIDs
        self.MetadataProgress = EvernoteMetadataProgress(self.currentPage)
        self.MetadataProgress.Total = len(guids)
        # Page size: whatever remains past the offset, capped at QUERY_LIMIT.
        self.MetadataProgress.Current = min(self.MetadataProgress.Total - self.MetadataProgress.Offset,
                                            EVERNOTE.IMPORT.QUERY_LIMIT)
        result = NotesMetadataList()
        result.totalNotes = len(guids)
        result.updateCount = -1  # sentinel: no server update count available
        result.startIndex = self.MetadataProgress.Offset
        result.notes = []
        """:type : list[NoteMetadata]"""
        for i in range(self.MetadataProgress.Offset, self.MetadataProgress.Completed):
            result.notes.append(NoteMetadata(guids[i]))
        self.MetadataProgress.loadResults(result)
        self.evernote.metadata = self.MetadataProgress.NotesMetadata
        return True
    def import_into_anki(self, evernote_guids):
        """Create Evernote note objects for *evernote_guids* and add them to Anki.

        :param evernote_guids: GUIDs of notes to fetch and add as new Anki notes.
        :rtype : EvernoteNoteFetcherResults
        """
        Results = self.evernote.create_evernote_notes(evernote_guids)
        if self.ManualMetadataMode:
            # Manual-GUID runs skip the normal metadata search, so refresh
            # notebook data before handing it to Anki.
            self.evernote.check_notebooks_up_to_date()
        # NOTE(review): indentation reconstructed from a mangled diff — this
        # assignment is assumed unconditional (parallel to update_in_anki); confirm.
        self.anki.notebook_data = self.evernote.notebook_data
        Results.Imported = self.anki.add_evernote_notes(Results.Notes)
        return Results
self.evernote.metadata[evernote_guid].updateSequenceNum + if evernote_guid in self.anki.usns: + current_usn = self.anki.usns[evernote_guid] + if current_usn == str(server_usn): + log_info = None # 'ANKI NOTE UP-TO-DATE' + notes_already_up_to_date.append(evernote_guid) + elif str(db_usn) == str(server_usn): + log_info = 'DATABASE ENTRY UP-TO-DATE' + else: + log_info = 'NO COPIES UP-TO-DATE' + else: + current_usn = 'N/A' + log_info = 'NO ANKI USN EXISTS' + if log_info: + if not self.evernote.metadata[evernote_guid].updateSequenceNum: + log_info += ' (Unable to find Evernote Metadata) ' + log(" > USN check for note '%s': %s: db/current/server = %s,%s,%s" % ( + evernote_guid, log_info, str(db_usn), str(current_usn), str(server_usn)), 'usn') + return notes_already_up_to_date + + def proceed(self, auto_paging=False): + self.proceed_start(auto_paging) + self.proceed_find_metadata(auto_paging) + self.proceed_import_notes() + self.proceed_autopage() + + def proceed_start(self, auto_paging=False): + col = self.anki.collection() + lastImport = SETTINGS.EVERNOTE.LAST_IMPORT.fetch() + SETTINGS.EVERNOTE.LAST_IMPORT.save(datetime.now().strftime(ANKNOTES.DATE_FORMAT)) + lastImportStr = get_friendly_interval_string(lastImport) + if lastImportStr: + lastImportStr = ' [LAST IMPORT: %s]' % lastImportStr + log_str = " > Starting Evernote Import: Page %3s Query: %s" % ( + '#' + str(self.currentPage), settings.generate_evernote_query()) + log_banner(log_str.ljust(ANKNOTES.FORMATTING.TEXT_LENGTH-len(lastImportStr)) + lastImportStr, append_newline=False, + chr='=', length=0, center=False, clear=False, timestamp=True) + if auto_paging: + return True + notestore_status = self.evernote.initialize_note_store() + if not notestore_status == EvernoteAPIStatus.Success: + log(" > Note store does not exist. Aborting.") + show_tooltip("Could not connect to Evernote servers (Status Code: %s)... Aborting." 
    def proceed_find_metadata(self, auto_paging=False):
        """Fetch (or fabricate) the note-metadata page and build ImportProgress.

        On a rate-limit or socket error: shows a report, schedules a retry of
        proceed() via a timer, and returns False. On success falls through
        (returns None) with self.ImportProgress populated.
        """
        global latestEDAMRateLimit, latestSocketError

        if self.ManualMetadataMode:
            self.override_evernote_metadata()
        else:
            self.get_evernote_metadata()

        if self.MetadataProgress.Status == EvernoteAPIStatus.RateLimitError:
            # Evernote told us how long to back off; retry shortly after that.
            m, s = divmod(latestEDAMRateLimit, 60)
            show_report(" > Error: Delaying Operation",
                        "Over the rate limit when searching for Evernote metadata<BR>Evernote requested we wait %d:%02d min" % (
                            m, s), delay=5)
            create_timer(latestEDAMRateLimit + 10, self.proceed, auto_paging)
            return False
        elif self.MetadataProgress.Status == EvernoteAPIStatus.SocketError:
            show_report(" > Error: Delaying Operation:",
                        "%s when searching for Evernote metadata" %
                        latestSocketError['friendly_error_msg'], "We will try again in 30 seconds", delay=5)
            create_timer(30, self.proceed, auto_paging)
            return False

        self.ImportProgress = EvernoteImportProgress(self.anki, self.MetadataProgress)
        # Manual mode skips the USN comparison: every GUID is treated as new.
        self.ImportProgress.loadAlreadyUpdated(
            [] if self.ManualMetadataMode else self.check_note_sync_status(
                self.ImportProgress.GUIDs.Server.Existing.All))
        log(self.ImportProgress.Summary + "\n", line_padding_header="- Note Sync Status: ",
            line_padding=ANKNOTES.FORMATTING.LINE_PADDING_HEADER, timestamp=False)
self.import_into_anki(self.ImportProgress.GUIDs.Server.Existing.OutOfDate)) + show_report(" > Import Complete", self.ImportProgress.ResultsSummaryLines) + self.anki.stop_editing() + self.anki.collection().autosave() + + def save_current_page(self): + if self.forceAutoPage: + return + SETTINGS.EVERNOTE.PAGINATION_CURRENT_PAGE.save(self.currentPage) + + def proceed_autopage(self): + if not self.autoPagingEnabled: + return + global latestEDAMRateLimit, latestSocketError + status = self.ImportProgress.Status + restart = 0 + if status == EvernoteAPIStatus.RateLimitError: + m, s = divmod(latestEDAMRateLimit, 60) + show_report(" > Error: Delaying Auto Paging", + "Over the rate limit when getting Evernote notes<BR>Evernote requested we wait %d:%02d min" % ( + m, s), delay=5) + create_timer(latestEDAMRateLimit + 10, self.proceed, True) + return False + if status == EvernoteAPIStatus.SocketError: + show_report(" > Error: Delaying Auto Paging:", + "%s when getting Evernote notes" % latestSocketError[ + 'friendly_error_msg'], + "We will try again in 30 seconds", delay=5) + create_timer(30, self.proceed, True) + return False + if self.MetadataProgress.IsFinished: + self.currentPage = 1 + if self.forceAutoPage: + show_report(" > Terminating Auto Paging", + "All %d notes have been processed and forceAutoPage is True" % self.MetadataProgress.Total, + delay=5) + if self.auto_page_callback: + self.auto_page_callback() + return True + elif EVERNOTE.IMPORT.PAGING.RESTART.ENABLED: + restart = max(EVERNOTE.IMPORT.PAGING.RESTART.INTERVAL, 60 * 15) + restart_title = " > Restarting Auto Paging" + restart_msg = "All %d notes have been processed and EVERNOTE.IMPORT.PAGING.RESTART.ENABLED is True<BR>" % \ + self.MetadataProgress.Total + suffix = "Per EVERNOTE.IMPORT.PAGING.RESTART.INTERVAL, " + else: + show_report(" > Completed Auto Paging", + "All %d notes have been processed and EVERNOTE.IMPORT.PAGING.RESTART.ENABLED is False" % + self.MetadataProgress.Total, delay=5) + 
self.save_current_page() + return True + else: # Paging still in progress (else to ) + self.currentPage = self.MetadataProgress.Page + 1 + restart_title = " > Continuing Auto Paging" + restart_msg = "Page %d completed<BR>%d notes remain over %d page%s<BR>%d of %d notes have been processed" % ( + self.MetadataProgress.Page, self.MetadataProgress.Remaining, self.MetadataProgress.RemainingPages, + 's' if self.MetadataProgress.RemainingPages > 1 else '', self.MetadataProgress.Completed, + self.MetadataProgress.Total) + restart = -1 * max(30, EVERNOTE.IMPORT.PAGING.RESTART.INTERVAL_OVERRIDE) + if self.forceAutoPage: + suffix = "<BR>Only delaying {interval} as the forceAutoPage flag is set" + elif self.ImportProgress.APICallCount < EVERNOTE.IMPORT.PAGING.RESTART.DELAY_MINIMUM_API_CALLS: + suffix = "<BR>Only delaying {interval} as the API Call Count of %d is less than the minimum of %d set by EVERNOTE.IMPORT.PAGING.RESTART.DELAY_MINIMUM_API_CALLS" % ( + self.ImportProgress.APICallCount, EVERNOTE.IMPORT.PAGING.RESTART.DELAY_MINIMUM_API_CALLS) + else: + restart = max(EVERNOTE.IMPORT.PAGING.INTERVAL_SANDBOX, 60 * 5) if EVERNOTE.API.IS_SANDBOXED else max( + EVERNOTE.IMPORT.PAGING.INTERVAL, 60 * 10) + suffix = "<BR>Delaying Auto Paging: Per EVERNOTE.IMPORT.PAGING.INTERVAL, " + self.save_current_page() + if restart > 0: + suffix += "will delay for {interval} before continuing" + m, s = divmod(abs(restart), 60) + suffix = suffix.format(interval=['%2ds' % s, '%d:%02d min' % (m, s)][m > 0]) + show_report(restart_title, (restart_msg + suffix).split('<BR>'), delay=5) + if restart: + create_timer(restart, self.proceed, True) + return self.proceed(True) + + @property + def autoPagingEnabled(self): + return SETTINGS.EVERNOTE.AUTO_PAGING.fetch() or self.forceAutoPage diff --git a/anknotes/EvernoteNoteFetcher.py b/anknotes/EvernoteNoteFetcher.py new file mode 100644 index 0000000..aacb9f4 --- /dev/null +++ b/anknotes/EvernoteNoteFetcher.py @@ -0,0 +1,178 @@ +### Python Imports +import 
    def __init__(self, evernote=None, guid=None, use_local_db_only=False):
        """Create a fetcher; when both a client and a GUID are supplied (and
        local-only mode is off), immediately fetch that note.

        :type evernote: ankEvernote.Evernote
        :param guid: GUID of a note to fetch immediately (optional).
        :param use_local_db_only: When True, never hit the Evernote API.
        """
        self.__reset_data()
        self.results = EvernoteNoteFetcherResults()
        self.result = EvernoteNoteFetcherResult()
        self.api_calls = 0
        self.keepEvernoteTags, self.deleteQueryTags = True, True
        self.evernoteQueryTags, self.tagsToDelete = [], []
        self.use_local_db_only = use_local_db_only
        self.__update_sequence_number = -1  # -1 = server USN unknown
        self.evernote = evernote if evernote else None
        if not guid:
            self.guid = ""; return
        self.guid = guid
        if evernote and not self.use_local_db_only:
            # NOTE(review): nesting reconstructed from a mangled diff — the USN
            # lookup and the immediate fetch are assumed to be guarded by the
            # presence of a live Evernote client; confirm against upstream.
            self.__update_sequence_number = self.evernote.metadata[
                self.guid].updateSequenceNum
            self.getNote()
self.use_local_db_only: + log(' ' + '-' * 14 + ' ' * 5 + "> getNoteLocal: %s" % db_note['title'], 'api') + assert db_note['guid'] == self.guid + self.reportSuccess(EvernoteNotePrototype(db_note=db_note), 1) + self.setNoteTags(tag_names=self.result.Note.TagNames, tag_guids=self.result.Note.TagGuids) + return True + + def setNoteTags(self, tag_names=None, tag_guids=None): + if not self.keepEvernoteTags: + self.tagGuids, self.tagNames = [], []; return + # if not tag_names: + # if self.tagNames: tag_names = self.tagNames + # if not tag_names and self.result.Note: tag_names = self.result.Note.TagNames + # if not tag_names and self.whole_note: tag_names = self.whole_note.tagNames + # if not tag_names: tag_names = None + if not tag_guids: + tag_guids = self.tagGuids if self.tagGuids else ( + self.result.Note.TagGuids if self.result.Note else (self.whole_note.tagGuids if self.whole_note else None)) + if not tag_names: + tag_names = self.tagNames if self.tagNames else ( + self.result.Note.TagNames if self.result.Note else (self.whole_note.tagNames if self.whole_note else None)) + if not self.evernote or self.result.Source is 1: + self.tagGuids, self.tagNames = tag_guids, tag_names; return + self.tagGuids, self.tagNames = self.evernote.get_matching_tag_data(tag_guids, tag_names) + + def addNoteFromServerToDB(self, whole_note=None, tag_names=None): + """ + Adds note to Anknote DB from an Evernote Note object provided by the Evernote API + :type whole_note : evernote.edam.type.ttypes.Note + """ + if whole_note: + self.whole_note = whole_note + if tag_names: + self.tagNames = tag_names + log('Adding %s: %s' % (self.whole_note.guid, self.whole_note.title), 'ankDB') + if not self.tagGuids: + self.tagGuids = self.whole_note.tagGuids + auto_columns = ['guid', 'title', 'content', 'updated', 'created', 'updateSequenceNum', 'notebookGuid'] + columns = {key: getattr(self.whole_note, key) for key in auto_columns} + columns.update({key: getattr(self, key) for key in ['tagNames', 
    def getNoteRemoteAPICall(self):
        """Download the full note for self.guid via the Evernote NoteStore API.

        Stores the result in self.whole_note. Reports (and returns False) when
        the note store is unreachable, rate-limited, or a socket error occurs;
        re-raises when the error handler declines or debug raising is enabled.

        :returns: True when self.whole_note was populated.
        """
        notestore_status = self.evernote.initialize_note_store()
        if not notestore_status.IsSuccess:
            self.reportResult(notestore_status)
            return False
        api_action_str = u'trying to retrieve a note. We will save the notes downloaded thus far.'
        self.api_calls += 1
        log_api(" > getNote [%3d]" % self.api_calls, self.guid)
        try:
            # Flags presumably map to withContent=True and the three resource
            # payload options off, per Evernote's getNote signature — confirm.
            self.whole_note = self.evernote.noteStore.getNote(self.evernote.token, self.guid, True, False,
                                                              False, False)
            """:type : evernote.edam.type.ttypes.Note"""
        except EDAMSystemException as e:
            if not HandleEDAMRateLimitError(e, api_action_str) or EVERNOTE.API.DEBUG_RAISE_ERRORS:
                raise
            self.reportResult(EvernoteAPIStatus.RateLimitError)
            return False
        except socket.error as v:
            if not HandleSocketError(v, api_action_str) or EVERNOTE.API.DEBUG_RAISE_ERRORS:
                raise
            self.reportResult(EvernoteAPIStatus.SocketError)
            return False
        assert self.whole_note.guid == self.guid
        return True
    def getNote(self, guid=None):
        """Fetch a note, preferring the local DB and falling back to the API.

        :param guid: When given, retargets the fetcher (and its Evernote
            client) to this GUID before fetching.
        :returns: True when a note was obtained; always False on a local-DB
            miss when use_local_db_only is set.
        """
        self.__reset_data()
        if guid:
            self.result.Note = None
            self.guid = guid
            self.evernote.guid = guid
            # A known server USN lets getNoteLocal demand an up-to-date row;
            # -1 (local-only mode) accepts whatever the local DB holds.
            self.__update_sequence_number = self.evernote.metadata[
                self.guid].updateSequenceNum if not self.use_local_db_only else -1
        if self.getNoteLocal():
            return True
        if self.use_local_db_only:
            return False
        return self.getNoteRemote()
if whole_note is not None: + if self.TagNames is None: + self.TagNames = whole_note.tagNames + self.Title = EvernoteNoteTitle(whole_note) + self.Content = whole_note.content + self.Guid = whole_note.guid + self.NotebookGuid = whole_note.notebookGuid + self.UpdateSequenceNum = whole_note.updateSequenceNum + self.Status = EvernoteAPIStatus.Success + return + if db_note is not None: + self.Title = EvernoteNoteTitle(db_note) + db_note_keys = db_note.keys() + for key in ['content', 'guid', 'notebookGuid', 'updateSequenceNum', 'tagNames', 'tagGuids']: + if not key in db_note_keys: + log_error( + "FATAL ERROR: Unable to find key %s in db note %s! \n%s" % (key, self.FullTitle, db_note_keys)) + log("Values: \n\n" + str({k: db_note[k] for k in db_note_keys}), 'EvernoteNotePrototypeInit') + else: + setattr(self, upperFirst(key), db_note[key]) + self.TagNames = decode(self.TagNames) + self.Content = decode(self.Content) + self.process_tags() + self.Status = EvernoteAPIStatus.Success + return + self.Title = EvernoteNoteTitle(title) + self.Content = content + self.Guid = guid + self.NotebookGuid = notebookGuid + self.UpdateSequenceNum = updateSequenceNum + self.Status = EvernoteAPIStatus.Manual + + def generateURL(self): + return generate_evernote_url(self.Guid) + + def generateLink(self, value=None): + return generate_evernote_link(self.Guid, self.Title.Name, value) + + def generateLevelLink(self, value=None): + return generate_evernote_link_by_level(self.Guid, self.Title.Name, value) + + ### Shortcuts to EvernoteNoteTitle Properties; Autogenerated with regex /def +(\w+)\(\)\:/def \1\(\):\r\n\treturn self.Title.\1\r\n/ + @property + def Level(self): + return self.Title.Level + + @property + def Depth(self): + return self.Title.Depth + + @property + def FullTitle(self): return self.Title.FullTitle + + @property + def Name(self): + return self.Title.Name + + @property + def Root(self): + return self.Title.Root + + @property + def Base(self): + return self.Title.Base + + @property 
+ def Parent(self): + return self.Title.Parent + + @property + def TitleParts(self): + return self.Title.TitleParts + + @property + def IsChild(self): + return self.Title.IsChild + + @property + def IsRoot(self): + return self.Title.IsRoot + + def IsAboveLevel(self, level_check): + return self.Title.IsAboveLevel(level_check) + + def IsBelowLevel(self, level_check): + return self.Title.IsBelowLevel(level_check) + + def IsLevel(self, level_check): + return self.Title.IsLevel(level_check) + + ################## END CLASS Note ################ diff --git a/anknotes/EvernoteNoteTitle.py b/anknotes/EvernoteNoteTitle.py new file mode 100644 index 0000000..f221c8d --- /dev/null +++ b/anknotes/EvernoteNoteTitle.py @@ -0,0 +1,234 @@ +# -*- coding: utf-8 -*- +### Anknotes Shared Imports +from anknotes.shared import * +from anknotes.base import is_str_type +from sys import stderr + + +def generateTOCTitle(title): + title = EvernoteNoteTitle.titleObjectToString(title).upper() + for chr in u'αβδφḃ': + title = title.replace(chr.upper(), chr) + return title + + +class EvernoteNoteTitle: + level = 0 + __title = "" + """:type: str""" + __titleParts = None + """:type: list[str]""" + + # # Parent = None + # def __str__(self): + # return "%d: %s" % (self.Level(), self.Title) + + def __repr__(self): + return "<%s:%s>" % (self.__class__.__name__, self.FullTitle) + + @property + def TitleParts(self): + if not self.FullTitle: + return [] + if not self.__titleParts: + self.__titleParts = generateTitleParts(self.FullTitle) + return self.__titleParts + + @property + def Level(self): + """ + :rtype: int + :return: Current Level with 1 being the Root Title + """ + if not self.level: + self.level = len(self.TitleParts) + return self.level + + @property + def Depth(self): + return self.Level - 1 + + def Parts(self, level=-1): + return self.Slice(level) + + def Part(self, level=-1): + mySlice = self.Parts(level) + if not mySlice: + return None + return mySlice.Root + + def BaseParts(self, 
level=None): + return self.Slice(1, level) + + def Parents(self, level=-1): + # noinspection PyTypeChecker + return self.Slice(None, level) + + def Names(self, level=-1): + return self.Parts(level) + + @property + def TOCTitle(self): + return generateTOCTitle(self.FullTitle) + + @property + def TOCName(self): + return generateTOCTitle(self.Name) + + @property + def TOCRootTitle(self): + return generateTOCTitle(self.Root) + + @property + def Name(self): + return self.Part() + + @property + def Root(self): + return self.Parents(1).FullTitle + + @property + def Base(self): + return self.BaseParts() + + def Slice(self, start=0, end=None): + # print "Slicing: <%s> %s ~ %d,%d" % (type(self.Title), self.Title, start, end) + oldParts = self.TitleParts + # print "Slicing: %s ~ %d,%d from parts %s" % (self.Title, start, end, str(oldParts)) + assert self.FullTitle and oldParts + if start is None and end is None: + print "Slicing: %s ~ %d,%d from parts %s" % (self.FullTitle, start, end, str(oldParts)) + assert start is not None or end is not None + newParts = oldParts[start:end] + if not newParts: + log_error("Slice failed for %s-%s of %s" % (str(start), str(end), self.FullTitle)) + # return None + assert len(newParts) > 0 + newStr = ': '.join(newParts) + # print "Slice: Just created new title %s from %s" % (newStr , self.Title) + return EvernoteNoteTitle(newStr) + + @property + def Parent(self): + return self.Parents() + + def IsAboveLevel(self, level_check): + return self.Level > level_check + + def IsBelowLevel(self, level_check): + return self.Level < level_check + + def IsLevel(self, level_check): + return self.Level == level_check + + @property + def IsChild(self): + return self.IsAboveLevel(1) + + @property + def IsRoot(self): + return self.IsLevel(1) + + @staticmethod + def titleObjectToString(title, recursion=0): + """ + :param title: Title in string, unicode, dict, sqlite, TOCKey or NoteTitle formats. 
Note objects are also parseable + :type title: None | str | unicode | dict[str,str] | sqlite.Row | EvernoteNoteTitle + :return: string Title + :rtype: str + """ + # if recursion == 0: + # str_ = str_safe(title) + # try: log(u'\n---------------------------------%s' % str_, 'tOTS', timestamp=False) + # except Exception: log(u'\n---------------------------------%s' % '[UNABLE TO DISPLAY TITLE]', 'tOTS', timestamp=False) + # pass + + if title is None: + # log('NoneType', 'tOTS', timestamp=False) + return "" + if is_str_type(title): + # log('str/unicode', 'tOTS', timestamp=False) + return title + if hasattr(title, 'FullTitle'): + # log('FullTitle', 'tOTS', timestamp=False) + # noinspection PyCallingNonCallable + title = title.FullTitle() if callable(title.FullTitle) else title.FullTitle + elif hasattr(title, 'Title'): + # log('Title', 'tOTS', timestamp=False) + title = title.Title() if callable(title.Title) else title.Title + elif hasattr(title, 'title'): + # log('title', 'tOTS', timestamp=False) + title = title.title() if callable(title.title) else title.title + else: + try: + if hasattr(title, 'keys'): + keys = title.keys() if callable(title.keys) else title.keys + if 'title' in keys: + # log('keys[title]', 'tOTS', timestamp=False) + title = title['title'] + elif 'Title' in keys: + # log('keys[Title]', 'tOTS', timestamp=False) + title = title['Title'] + elif not keys: + # log('keys[empty dict?]', 'tOTS', timestamp=False) + raise + else: + log('keys[Unknown Attr]: %s' % str(keys), 'tOTS', timestamp=False) + return "" + elif 'title' in title: + # log('[title]', 'tOTS', timestamp=False) + title = title['title'] + elif 'Title' in title: + # log('[Title]', 'tOTS', timestamp=False) + title = title['Title'] + elif FIELDS.TITLE in title: + # log('[FIELDS.TITLE]', 'tOTS', timestamp=False) + title = title[FIELDS.TITLE] + else: + # log('Nothing Found', 'tOTS', timestamp=False) + # log(title) + # log(title.keys()) + return title + except Exception: + log('except', 'tOTS', 
timestamp=False) + log(title, 'toTS', timestamp=False) + raise LookupError + recursion += 1 + # log(u'recursing %d: ' % recursion, 'tOTS', timestamp=False) + return EvernoteNoteTitle.titleObjectToString(title, recursion) + + @property + def FullTitle(self): + """:rtype: str""" + return self.__title + + @property + def HTML(self): + return self.__html + + def __init__(self, titleObj=None): + """:type titleObj: str | unicode | sqlite.Row | EvernoteNoteTitle | evernote.edam.type.ttypes.Note | EvernoteNotePrototype.EvernoteNotePrototype """ + self.__html = self.titleObjectToString(titleObj) + self.__title = strip_tags_and_new_lines(self.__html) + + +def generateTitleParts(title): + title = EvernoteNoteTitle.titleObjectToString(title) + try: + strTitle = re.sub(':+', ':', title) + except Exception: + log('generateTitleParts Unable to re.sub') + log(type(title)) + raise + strTitle = strTitle.strip(':') + partsText = strTitle.split(':') + count = len(partsText) + for i in range(1, count + 1): + txt = partsText[i - 1] + try: + txt = txt.strip() + except Exception: + print_safe(title + ' -- ' + '"' + txt + '"') + raise + partsText[i - 1] = txt + return partsText diff --git a/anknotes/EvernoteNotes.py b/anknotes/EvernoteNotes.py new file mode 100644 index 0000000..95a9660 --- /dev/null +++ b/anknotes/EvernoteNotes.py @@ -0,0 +1,444 @@ +# -*- coding: utf-8 -*- +### Python Imports +from operator import itemgetter + +try: + from pysqlite2 import dbapi2 as sqlite +except ImportError: + from sqlite3 import dbapi2 as sqlite + +### Anknotes Main Imports +from anknotes.base import encode +from anknotes.shared import * +from anknotes.EvernoteNoteTitle import * +from anknotes.EvernoteNotePrototype import EvernoteNotePrototype +from anknotes.toc import TOCHierarchyClass +from anknotes.db import ankDB +from anknotes import stopwatch + +### Anknotes Class Imports +from anknotes.EvernoteNoteTitle import generateTOCTitle + +class EvernoteNoteProcessingFlags: + delayProcessing = False + 
populateRootTitlesList = True + populateRootTitlesDict = True + populateExistingRootTitlesList = False + populateExistingRootTitlesDict = False + populateMissingRootTitlesList = False + populateMissingRootTitlesDict = False + populateChildRootTitles = False + ignoreAutoTOCAsRootTitle = False + ignoreOutlineAsRootTitle = False + + def __init__(self, flags=None): + if isinstance(flags, bool): + if not flags: + self.set_default(False) + if flags: + self.update(flags) + + def set_default(self, flag): + self.populateRootTitlesList = flag + self.populateRootTitlesDict = flag + + def update(self, flags): + for flag_name, flag_value in flags: + if hasattr(self, flag_name): + setattr(self, flag_name, flag_value) + + +class EvernoteNotesCollection: + TitlesList = [] + TitlesDict = {} + NotesDict = {} + """:type : dict[str, EvernoteNote.EvernoteNote]""" + ChildNotesDict = {} + """:type : dict[str, EvernoteNote.EvernoteNote]""" + ChildTitlesDict = {} + + def __init__(self): + self.TitlesList = [] + self.TitlesDict = {} + self.NotesDict = {} + self.ChildNotesDict = {} + self.ChildTitlesDict = {} + + +class EvernoteNotes: + ################## CLASS Notes ################ + Notes = {} + """:type : dict[str, EvernoteNote.EvernoteNote]""" + RootNotes = EvernoteNotesCollection() + RootNotesChildren = EvernoteNotesCollection() + processingFlags = EvernoteNoteProcessingFlags() + baseQuery = "1" + + def __init__(self, delayProcessing=False): + self.processingFlags.delayProcessing = delayProcessing + self.RootNotes = EvernoteNotesCollection() + + def addNoteSilently(self, enNote): + """:type enNote: EvernoteNote.EvernoteNote""" + assert enNote + self.Notes[enNote.Guid] = enNote + + def addNote(self, enNote): + """:type enNote: EvernoteNote.EvernoteNote""" + assert enNote + self.addNoteSilently(enNote) + if self.processingFlags.delayProcessing: + return + self.processNote(enNote) + + def addDBNote(self, dbNote): + """:type dbNote: sqlite.Row""" + enNote = 
EvernoteNotePrototype(db_note=dbNote) + if not enNote: + log(dbNote) + log(dbNote.keys) + log(dir(dbNote)) + assert enNote + self.addNote(enNote) + + def addDBNotes(self, dbNotes): + """:type dbNotes: list[sqlite.Row]""" + for dbNote in dbNotes: + self.addDBNote(dbNote) + + def addDbQuery(self, sql_query, order=''): + if not sql_query: + sql_query = '1' + if self.baseQuery and self.baseQuery != '1': + if sql_query == '1': + sql_query = self.baseQuery + else: + sql_query = "(%s) AND (%s) " % (self.baseQuery, sql_query) + if order: + sql_query += ' ORDER BY ' + order + dbNotes = ankDB().execute(sql_query) + self.addDBNotes(dbNotes) + + @staticmethod + def getNoteFromDB(query): + """ + + :param query: + :return: + :rtype : sqlite.Row + """ + dbNote = ankDB().first(query) + if not dbNote: + return None + return dbNote + + def getNoteFromDBByGuid(self, guid): + sql_query = "guid = '%s' " % guid + return self.getNoteFromDB(sql_query) + + def getEnNoteFromDBByGuid(self, guid): + return EvernoteNotePrototype(db_note=self.getNoteFromDBByGuid(guid)) + + # def addChildNoteHierarchically(self, enChildNotes, enChildNote): + # parts = enChildNote.Title.TitleParts + # dict_updated = {} + # dict_building = {parts[len(parts)-1]: enChildNote} + # print_safe(parts) + # for i in range(len(parts), 1, -1): + # dict_building = {parts[i - 1]: dict_building} + # log_dump(dict_building) + # enChildNotes.update(dict_building) + # log_dump(enChildNotes) + # return enChildNotes + + def processNote(self, enNote): + """:type enNote: EvernoteNote.EvernoteNote""" + db = ankDB() + if self.processingFlags.populateRootTitlesList or self.processingFlags.populateRootTitlesDict or self.processingFlags.populateMissingRootTitlesList or self.processingFlags.populateMissingRootTitlesDict: + if enNote.IsChild: + # log([enNote.Title, enNote.Level, enNote.Title.TitleParts, enNote.IsChild]) + rootTitle = enNote.Title.Root + rootTitleStr = generateTOCTitle(rootTitle) + if 
self.processingFlags.populateMissingRootTitlesList or self.processingFlags.populateMissingRootTitlesDict: + if not rootTitleStr in self.RootNotesExisting.TitlesList: + if not rootTitleStr in self.RootNotesMissing.TitlesList: + self.RootNotesMissing.TitlesList.append(rootTitleStr) + self.RootNotesMissing.ChildTitlesDict[rootTitleStr] = {} + self.RootNotesMissing.ChildNotesDict[rootTitleStr] = {} + if not enNote.Title.Base: + log(enNote.Title) + log(enNote.Base) + assert enNote.Title.Base + childBaseTitleStr = enNote.Title.Base.FullTitle + if childBaseTitleStr in self.RootNotesMissing.ChildTitlesDict[rootTitleStr]: + log_error("Duplicate Child Base Title String. \n%-18s%s\n%-18s%s: %s\n%-18s%s" % ( + 'Root Note Title: ', rootTitleStr, 'Child Note: ', enNote.Guid, childBaseTitleStr, + 'Duplicate Note: ', + self.RootNotesMissing.ChildTitlesDict[rootTitleStr][childBaseTitleStr]), + crosspost_to_default=False) + if not hasattr(self, 'loggedDuplicateChildNotesWarning'): + log( + " > WARNING: Duplicate Child Notes found when processing Root Notes. 
See error log for more details") + self.loggedDuplicateChildNotesWarning = True + self.RootNotesMissing.ChildTitlesDict[rootTitleStr][childBaseTitleStr] = enNote.Guid + self.RootNotesMissing.ChildNotesDict[rootTitleStr][enNote.Guid] = enNote + if self.processingFlags.populateRootTitlesList or self.processingFlags.populateRootTitlesDict: + if not rootTitleStr in self.RootNotes.TitlesList: + self.RootNotes.TitlesList.append(rootTitleStr) + if self.processingFlags.populateRootTitlesDict: + self.RootNotes.TitlesDict[rootTitleStr][enNote.Guid] = enNote.Title.Base + self.RootNotes.NotesDict[rootTitleStr][enNote.Guid] = enNote + if self.processingFlags.populateChildRootTitles or self.processingFlags.populateExistingRootTitlesList or self.processingFlags.populateExistingRootTitlesDict: + if enNote.IsRoot: + rootTitle = enNote.Title + rootTitleStr = generateTOCTitle(rootTitle) + rootGuid = enNote.Guid + if self.processingFlags.populateExistingRootTitlesList or self.processingFlags.populateExistingRootTitlesDict or self.processingFlags.populateMissingRootTitlesList: + if not rootTitleStr in self.RootNotesExisting.TitlesList: + self.RootNotesExisting.TitlesList.append(rootTitleStr) + if self.processingFlags.populateChildRootTitles: + childNotes = db.execute("title LIKE ? 
|| ':%' ORDER BY title ASC", rootTitleStr) + child_count = 0 + for childDbNote in childNotes: + child_count += 1 + childGuid = childDbNote['guid'] + childEnNote = EvernoteNotePrototype(db_note=childDbNote) + if child_count is 1: + self.RootNotesChildren.TitlesDict[rootGuid] = {} + self.RootNotesChildren.NotesDict[rootGuid] = {} + childBaseTitle = childEnNote.Title.Base + self.RootNotesChildren.TitlesDict[rootGuid][childGuid] = childBaseTitle + self.RootNotesChildren.NotesDict[rootGuid][childGuid] = childEnNote + + def processNotes(self, populateRootTitlesList=True, populateRootTitlesDict=True): + if self.processingFlags.populateRootTitlesList or self.processingFlags.populateRootTitlesDict: + self.RootNotes = EvernoteNotesCollection() + + self.processingFlags.populateRootTitlesList = populateRootTitlesList + self.processingFlags.populateRootTitlesDict = populateRootTitlesDict + + for guid, enNote in self.Notes: + self.processNote(enNote) + + def processAllChildNotes(self): + self.processingFlags.populateRootTitlesList = True + self.processingFlags.populateRootTitlesDict = True + self.processNotes() + + def populateAllRootTitles(self): + self.getChildNotes() + self.processAllRootTitles() + + def processAllRootTitles(self): + count = 0 + for rootTitle, baseTitles in self.RootNotes.TitlesDict.items(): + count += 1 + baseNoteCount = len(baseTitles) + query = "UPPER(title) = '%s'" % escape_text_sql(rootTitle).upper() + if self.processingFlags.ignoreAutoTOCAsRootTitle: + query += " AND tagNames NOT LIKE '%%,%s,%%'" % TAGS.TOC_AUTO + if self.processingFlags.ignoreOutlineAsRootTitle: + query += " AND tagNames NOT LIKE '%%,%s,%%'" % TAGS.OUTLINE + rootNote = self.getNoteFromDB(query) + if rootNote: + self.RootNotesExisting.TitlesList.append(rootTitle) + else: + self.RootNotesMissing.TitlesList.append(rootTitle) + print_safe(rootNote, ' TOP LEVEL: [%4d::%2d]: [%7s] ' % (count, baseNoteCount, 'is_toc_outline_str')) + # for baseGuid, baseTitle in baseTitles: + # pass + + def 
getChildNotes(self): + self.addDbQuery("title LIKE '%%:%%'", 'title ASC') + + def getRootNotes(self): + query = "title NOT LIKE '%%:%%'" + if self.processingFlags.ignoreAutoTOCAsRootTitle: + query += " AND tagNames NOT LIKE '%%,%s,%%'" % TAGS.TOC_AUTO + if self.processingFlags.ignoreOutlineAsRootTitle: + query += " AND tagNames NOT LIKE '%%,%s,%%'" % TAGS.OUTLINE + self.addDbQuery(query, 'title ASC') + + def populateAllPotentialRootNotes(self): + self.RootNotesMissing = EvernoteNotesCollection() + processingFlags = EvernoteNoteProcessingFlags(False) + processingFlags.populateMissingRootTitlesList = True + processingFlags.populateMissingRootTitlesDict = True + self.processingFlags = processingFlags + + log_banner(" CHECKING FOR ALL POTENTIAL ROOT TITLES ", 'RootTitles\\TOC', clear=True, timestamp=False) + log_banner(" CHECKING FOR ISOLATED ROOT TITLES ", 'RootTitles\\Isolated', clear=True, timestamp=False) + self.getChildNotes() + log("Total %d Missing Root Titles" % len(self.RootNotesMissing.TitlesList), 'RootTitles\\TOC', + timestamp=False) + self.RootNotesMissing.TitlesList = sorted(self.RootNotesMissing.TitlesList, key=lambda s: s.lower()) + + return self.processAllRootNotesMissing() + + def populateAllNonCustomRootNotes(self): + return self.populateAllRootNotesMissing(True, True) + + def populateAllRootNotesMissing(self, ignoreAutoTOCAsRootTitle=False, ignoreOutlineAsRootTitle=False): + processingFlags = EvernoteNoteProcessingFlags(False) + processingFlags.populateMissingRootTitlesList = True + processingFlags.populateMissingRootTitlesDict = True + processingFlags.populateExistingRootTitlesList = True + processingFlags.populateExistingRootTitlesDict = True + processingFlags.ignoreAutoTOCAsRootTitle = ignoreAutoTOCAsRootTitle + processingFlags.ignoreOutlineAsRootTitle = ignoreOutlineAsRootTitle + self.processingFlags = processingFlags + self.RootNotesExisting = EvernoteNotesCollection() + self.RootNotesMissing = EvernoteNotesCollection() + # log(', 
'.join(self.RootNotesMissing.TitlesList)) + self.getRootNotes() + + log_banner(" CHECKING FOR MISSING ROOT TITLES ", 'RootTitles\\Missing', clear=True, timestamp=False) + log_banner(" CHECKING FOR ISOLATED ROOT TITLES ", 'RootTitles\\Isolated', clear=True, timestamp=False) + log("Total %d Existing Root Titles" % len(self.RootNotesExisting.TitlesList), 'RootTitles\\Missing', + timestamp=False) + self.getChildNotes() + log("Total %d Missing Root Titles" % len(self.RootNotesMissing.TitlesList), 'RootTitles\\Missing', + timestamp=False) + self.RootNotesMissing.TitlesList = sorted(self.RootNotesMissing.TitlesList, key=lambda s: s.lower()) + + return self.processAllRootNotesMissing() + + def processAllRootNotesMissing(self): + """:rtype : list[EvernoteTOCEntry]""" + DEBUG_HTML = False + # log (" CREATING TOC's " , 'tocList', clear=True, timestamp=False) + # log ("------------------------------------------------" , 'tocList', timestamp=False) + # if DEBUG_HTML: log('<h1>CREATING TOCs</h1>', 'extra\\logs\\toc-ols\\toc-index.htm', timestamp=False, clear=True, extension='htm') + ols = [] + dbRows = [] + returns = [] + """:type : list[EvernoteTOCEntry]""" + db = ankDB(TABLES.TOC_AUTO) + db.delete("1", table=db.table) + db.commit() + # olsz = None + tmr = stopwatch.Timer(self.RootNotesMissing.TitlesList, infoStr='Processing Root Notes', label='RootTitles\\') + for rootTitleStr in self.RootNotesMissing.TitlesList: + count_child = 0 + childTitlesDictSortedKeys = sorted(self.RootNotesMissing.ChildTitlesDict[rootTitleStr], + key=lambda s: s.lower()) + total_child = len(childTitlesDictSortedKeys) + tags = [] + outline = self.getNoteFromDB("UPPER(title) = '%s' AND tagNames LIKE '%%,%s,%%'" % ( + escape_text_sql(rootTitleStr.upper()), TAGS.OUTLINE)) + currentAutoNote = self.getNoteFromDB("UPPER(title) = '%s' AND tagNames LIKE '%%,%s,%%'" % ( + escape_text_sql(rootTitleStr.upper()), TAGS.TOC_AUTO)) + notebookGuids = {} + childGuid = None + is_isolated = total_child is 1 and not 
outline + if is_isolated: + tmr.counts.isolated.step() + childBaseTitle = childTitlesDictSortedKeys[0] + childGuid = self.RootNotesMissing.ChildTitlesDict[rootTitleStr][childBaseTitle] + enChildNote = self.RootNotesMissing.ChildNotesDict[rootTitleStr][childGuid] + # tags = enChildNote.Tags + log(" > ISOLATED ROOT TITLE: [%-3d]: %-60s --> %-40s: %s" % ( + tmr.counts.isolated.val, rootTitleStr + ':', childBaseTitle, childGuid), tmr.label + 'Isolated', + timestamp=False) + else: + tmr.counts.created.completed.step() + log_blank(tmr.label + 'TOC') + log(" [%-3d] %s %s" % (tmr.count, rootTitleStr, '(O)' if outline else ' '), tmr.label + 'TOC', + timestamp=False) + + tmr.step(rootTitleStr) + + if is_isolated: + continue + + tocHierarchy = TOCHierarchyClass(rootTitleStr) + if outline: + tocHierarchy.Outline = TOCHierarchyClass(note=outline) + tocHierarchy.Outline.parent = tocHierarchy + + for childBaseTitle in childTitlesDictSortedKeys: + count_child += 1 + childGuid = self.RootNotesMissing.ChildTitlesDict[rootTitleStr][childBaseTitle] + enChildNote = self.RootNotesMissing.ChildNotesDict[rootTitleStr][childGuid] + if count_child == 1: + tags = enChildNote.Tags + else: + tags = [x for x in tags if x in enChildNote.Tags] + if not enChildNote.NotebookGuid in notebookGuids: + notebookGuids[enChildNote.NotebookGuid] = 0 + notebookGuids[enChildNote.NotebookGuid] += 1 + level = enChildNote.Title.Level + # childName = enChildNote.Title.Name + # childTitle = enChildNote.FullTitle + log(" %2d: %d. 
--> %-60s" % (count_child, level, childBaseTitle), + tmr.label + 'TOC', timestamp=False) + # tocList.generateEntry(childTitle, enChildNote) + tocHierarchy.addNote(enChildNote) + realTitle = get_evernote_title_from_guid(childGuid) + realTitle = realTitle[0:realTitle.index(':')] + # realTitleUTF8 = realTitle.encode('utf8') + notebookGuid = sorted(notebookGuids.items(), key=itemgetter(1), reverse=True)[0][0] + + real_root_title = generateTOCTitle(realTitle) + + ol = tocHierarchy.GetOrderedList() + tocEntry = EvernoteTOCEntry(real_root_title, ol, ',' + ','.join(tags) + ',', notebookGuid) + returns.append(tocEntry) + dbRows.append(tocEntry.items()) + + if not DEBUG_HTML: + continue + + # ols.append(ol) + # olutf8 = encode(ol) + # fn = 'toc-ols\\toc-' + str(tmr.count) + '-' + rootTitleStr.replace('\\', '_') + '.htm' + # full_path = os.path.join(FOLDERS.LOGS, fn) + # if not os.path.exists(os.path.dirname(full_path)): + # os.mkdir(os.path.dirname(full_path)) + # file_object = open(full_path, 'w') + # file_object.write(olutf8) + # file_object.close() + + # if DEBUG_HTML: log(ol, 'toc-ols\\toc-' + str(count) + '-' + rootTitleStr.replace('\\', '_'), timestamp=False, clear=True, extension='htm') + # log("Created TOC #%d:\n%s\n\n" % (count, str_), 'tocList', timestamp=False) + if DEBUG_HTML: + ols_html = u'\r\n<BR><BR><HR><BR><BR>\r\n'.join(ols) + fn = 'toc-ols\\toc-index.htm' + file_object = open(os.path.join(FOLDERS.LOGS, fn), 'w') + try: + file_object.write(u'<h1>CREATING TOCs</h1>\n\n' + ols_html) + except Exception: + try: + file_object.write(u'<h1>CREATING TOCs</h1>\n\n' + encode(ols_html)) + except Exception: + pass + + file_object.close() + + db.executemany("INSERT INTO {t} (root_title, contents, tagNames, notebookGuid) VALUES(?, ?, ?, ?)", dbRows) + db.commit() + + return returns + + def populateAllRootNotesWithoutTOCOrOutlineDesignation(self): + processingFlags = EvernoteNoteProcessingFlags() + processingFlags.populateRootTitlesList = False + 
processingFlags.populateRootTitlesDict = False + processingFlags.populateChildRootTitles = True + self.processingFlags = processingFlags + self.getRootNotes() + self.processAllRootNotesWithoutTOCOrOutlineDesignation() + + def processAllRootNotesWithoutTOCOrOutlineDesignation(self): + count = 0 + for rootGuid, childBaseTitleDicts in self.RootNotesChildren.TitlesDict.items(): + rootEnNote = self.Notes[rootGuid] + if len(childBaseTitleDicts.items()) > 0: + is_toc = TAGS.TOC in rootEnNote.Tags + is_outline = TAGS.OUTLINE in rootEnNote.Tags + is_both = is_toc and is_outline + is_none = not is_toc and not is_outline + is_toc_outline_str = "BOTH ???" if is_both else "TOC" if is_toc else "OUTLINE" if is_outline else "N/A" + if is_none: + count += 1 + print_safe(rootEnNote, ' TOP LEVEL: [%3d] %-8s: ' % (count, is_toc_outline_str)) diff --git a/anknotes/README.md b/anknotes/README.md deleted file mode 100644 index 0204cc8..0000000 --- a/anknotes/README.md +++ /dev/null @@ -1,23 +0,0 @@ -# Evernote2Anki Importer (beta) -**Forks and suggestions are very welcome.** - -## Description -An Anki plug-in aiming for syncing evernote account with anki directly from anki. It aims to replace a Java standalone application [available here] (https://code.google.com/p/evernote2anki/) -Very rudimentary for the moment. I wait for suggestions according to the needs of evernote/anki users. - -## Users : How to use it -- download everything, move it to your Anki/addons directory -- start Anki, fill in the Infromation in the prefrences tap and then press Import from Evernote --When you run it the first Time a browser tab will open on the evernote site asking you for access to your account -- when you click ok you are taken to a website where the oauth verification key is displayed you paste that key into the open anki windows and click ok with that you are set. 
- -## Features and further development -####Current feature : -- Import all the notes from evernote with selected tags -- Possibility to choose the name of the deck, as well as the default tag in anki (but should not be changed) -- Does not import twice a card (only new cards are imported) -- - A window allowing the user to change the options (instead of manual edit of options.cfg) - -####Desirable new features (?) : - -- Updating anki cards accordingly the edit of evernote notes. diff --git a/anknotes/___sqlite3.py b/anknotes/___sqlite3.py new file mode 100644 index 0000000..c60baed --- /dev/null +++ b/anknotes/___sqlite3.py @@ -0,0 +1,190 @@ +"""Skeleton for 'sqlite3' stdlib module.""" + + +import sqlite3 + + +def connect(database, timeout=5.0, detect_types=0, isolation_level=None, + check_same_thread=False, factory=None, cached_statements=100): + """Opens a connection to the SQLite database file database. + + :type database: bytes | unicode + :type timeout: float + :type detect_types: int + :type isolation_level: string | None + :type check_same_thread: bool + :type factory: (() -> sqlite3.Connection) | None + :rtype: sqlite3.Connection + """ + return sqlite3.Connection() + + +def register_converter(typename, callable): + """Registers a callable to convert a bytestring from the database into a + custom Python type. + + :type typename: string + :type callable: (bytes) -> unknown + :rtype: None + """ + pass + + +def register_adapter(type, callable): + """Registers a callable to convert the custom Python type type into one of + SQLite's supported types. + + :type type: type + :type callable: (unknown) -> unknown + :rtype: None + """ + pass + + +def complete_statement(sql): + """Returns True if the string sql contains one or more complete SQL + statements terminated by semicolons. 
+ + :type sql: string + :rtype: bool + """ + return False + + +def enable_callback_tracebacks(flag): + """By default you will not get any tracebacks in user-defined functions, + aggregates, converters, authorizer callbacks etc. + + :type flag: bool + :rtype: None + """ + pass + + +class Connection(object): + """A SQLite database connection.""" + + def cursor(self, cursorClass=None): + """ + :type cursorClass: type | None + :rtype: sqlite3.Cursor + """ + return sqlite3.Cursor() + + def execute(self, sql, parameters=()): + """This is a nonstandard shortcut that creates an intermediate cursor + object by calling the cursor method, then calls the cursor's execute + method with the parameters given. + + :type sql: string + :type parameters: collections.Iterable + :rtype: sqlite3.Cursor + """ + pass + + def executemany(self, sql, seq_of_parameters=()): + """This is a nonstandard shortcut that creates an intermediate cursor + object by calling the cursor method, then calls the cursor's + executemany method with the parameters given. + + :type sql: string + :type seq_of_parameters: collections.Iterable[collections.Iterable] + :rtype: sqlite3.Cursor + """ + pass + + def executescript(self, sql_script): + """This is a nonstandard shortcut that creates an intermediate cursor + object by calling the cursor method, then calls the cursor's + executescript method with the parameters given. + + :type sql_script: bytes | unicode + :rtype: sqlite3.Cursor + """ + pass + + def create_function(self, name, num_params, func): + """Creates a user-defined function that you can later use from within + SQL statements under the function name name. + + :type name: string + :type num_params: int + :type func: collections.Callable + :rtype: None + """ + pass + + + def create_aggregate(self, name, num_params, aggregate_class): + """Creates a user-defined aggregate function. 
+ + :type name: string + :type num_params: int + :type aggregate_class: type + :rtype: None + """ + pass + + def create_collation(self, name, callable): + """Creates a collation with the specified name and callable. + + :type name: string + :type callable: collections.Callable + :rtype: None + """ + pass + + +class Cursor(object): + """A SQLite database cursor.""" + + def execute(self, sql, parameters=()): + """Executes an SQL statement. + + :type sql: string + :type parameters: collections.Iterable + :rtype: sqlite3.Cursor + """ + pass + + def executemany(self, sql, seq_of_parameters=()): + """Executes an SQL command against all parameter sequences or mappings + found in the sequence. + + :type sql: string + :type seq_of_parameters: collections.Iterable[collections.Iterable] + :rtype: sqlite3.Cursor + """ + pass + + def executescript(self, sql_script): + """This is a nonstandard convenience method for executing multiple SQL + statements at once. + + :type sql_script: bytes | unicode + :rtype: sqlite3.Cursor + """ + pass + + def fetchone(self): + """Fetches the next row of a query result set, returning a single + sequence, or None when no more data is available. + + :rtype: tuple | None + """ + pass + + def fetchmany(self, size=-1): + """Fetches the next set of rows of a query result, returning a list. + + :type size: numbers.Integral + :rtype: list[tuple] + """ + return [] + + def fetchall(self): + """Fetches all (remaining) rows of a query result, returning a list. 
+ + :rtype: list[tuple] + """ + return [] diff --git a/anknotes/__main__.py b/anknotes/__main__.py index f594b40..4aebb5a 100644 --- a/anknotes/__main__.py +++ b/anknotes/__main__.py @@ -1,393 +1,373 @@ +# -*- coding: utf-8 -*- +### Python Imports import os - -# from thrift.Thrift import * -from evernote.edam.notestore.ttypes import NoteFilter, NotesMetadataResultSpec -from evernote.edam.error.ttypes import EDAMSystemException, EDAMErrorCode -from evernote.api.client import EvernoteClient -# from evernote.edam.type.ttypes import SavedSearch - -import anki -import aqt -from anki.hooks import wrap +import re, sre_constants +import sys + +try: + from pysqlite2 import dbapi2 as sqlite + is_pysqlite = True +except ImportError: + from sqlite3 import dbapi2 as sqlite + is_pysqlite = False + + +### Anknotes Shared Imports +from anknotes.imports import in_anki +from anknotes.shared import * +from anknotes import stopwatch + +### Anknotes Main Imports +from anknotes import menu, settings + +### Anki Imports +if ANKNOTES.HOOKS.SEARCH: + from anki.find import Finder + from aqt import browser +if ANKNOTES.HOOKS.DB: + from anki.db import DB +from anki.hooks import wrap, addHook from aqt.preferences import Preferences -from aqt.utils import showInfo, getText, openLink, getOnlyText -from aqt.qt import QLineEdit, QLabel, QVBoxLayout, QGroupBox, SIGNAL, QCheckBox, QComboBox, QSpacerItem, QSizePolicy, QWidget from aqt import mw -# from pprint import pprint - - -# Note: This class was adapted from the Real-Time_Import_for_use_with_the_Rikaisama_Firefox_Extension plug-in -# by cb4960@gmail.com -# .. itself adapted from Yomichan plugin by Alex Yatskov. 
- -PATH = os.path.dirname(os.path.abspath(__file__)) -EVERNOTE_MODEL = 'evernote_note' -EVERNOTE_TEMPLATE_NAME = 'EvernoteReview' -TITLE_FIELD_NAME = 'title' -CONTENT_FIELD_NAME = 'content' -GUID_FIELD_NAME = 'Evernote GUID' - -SETTING_UPDATE_EXISTING_NOTES = 'evernoteUpdateExistingNotes' -SETTING_TOKEN = 'evernoteToken' -SETTING_KEEP_TAGS = 'evernoteKeepTags' -SETTING_TAGS_TO_IMPORT = 'evernoteTagsToImport' -SETTING_DEFAULT_TAG = 'evernoteDefaultTag' -SETTING_DEFAULT_DECK = 'evernoteDefaultDeck' - -class UpdateExistingNotes: - IgnoreExistingNotes, UpdateNotesInPlace, DeleteAndReAddNotes = range(3) - -class Anki: - def update_evernote_cards(self, evernote_cards, tag): - return self.add_evernote_cards(evernote_cards, None, tag, True) - - def add_evernote_cards(self, evernote_cards, deck, tag, update=False): - count = 0 - model_name = EVERNOTE_MODEL - for card in evernote_cards: - anki_field_info = {TITLE_FIELD_NAME: card.front.decode('utf-8'), - CONTENT_FIELD_NAME: card.back.decode('utf-8'), - GUID_FIELD_NAME: card.guid} - card.tags.append(tag) - if update: - self.update_note(anki_field_info, card.tags) - else: - self.add_note(deck, model_name, anki_field_info, card.tags) - count += 1 - return count - - def delete_anki_cards(self, guid_ids): - col = self.collection() - card_ids = [] - for guid in guid_ids: - card_ids += mw.col.findCards(guid) - col.remCards(card_ids) - return len(card_ids) - - def update_note(self, fields, tags=list()): - col = self.collection() - note_id = col.findNotes(fields[GUID_FIELD_NAME])[0] - note = anki.notes.Note(col, None, note_id) - note.tags = tags - for fld in note._model['flds']: - if TITLE_FIELD_NAME in fld.get('name'): - note.fields[fld.get('ord')] = fields[TITLE_FIELD_NAME] - elif CONTENT_FIELD_NAME in fld.get('name'): - note.fields[fld.get('ord')] = fields[CONTENT_FIELD_NAME] - # we dont have to update the evernote guid because if it changes we wont find this note anyway - note.flush() - return note.id - - def add_note(self, 
deck_name, model_name, fields, tags=list()): - note = self.create_note(deck_name, model_name, fields, tags) - if note is not None: - collection = self.collection() - collection.addNote(note) - collection.autosave() - self.start_editing() - return note.id - - def create_note(self, deck_name, model_name, fields, tags=list()): - id_deck = self.decks().id(deck_name) - model = self.models().byName(model_name) - col = self.collection() - note = anki.notes.Note(col, model) - note.model()['did'] = id_deck - note.tags = tags - for name, value in fields.items(): - note[name] = value - return note - - def add_evernote_model(self): # adapted from the IREAD plug-in from Frank - col = self.collection() - mm = col.models - evernote_model = mm.byName(EVERNOTE_MODEL) - if evernote_model is None: - evernote_model = mm.new(EVERNOTE_MODEL) - # Field for title: - model_field = mm.newField(TITLE_FIELD_NAME) - mm.addField(evernote_model, model_field) - # Field for text: - text_field = mm.newField(CONTENT_FIELD_NAME) - mm.addField(evernote_model, text_field) - # Field for source: - guid_field = mm.newField(GUID_FIELD_NAME) - guid_field['sticky'] = True - mm.addField(evernote_model, guid_field) - # Add template - t = mm.newTemplate(EVERNOTE_TEMPLATE_NAME) - t['qfmt'] = "{{" + TITLE_FIELD_NAME + "}}" - t['afmt'] = "{{" + CONTENT_FIELD_NAME + "}}" - mm.addTemplate(evernote_model, t) - mm.add(evernote_model) - return evernote_model +from aqt.qt import Qt, QIcon, QTreeWidget, QTreeWidgetItem, QDesktopServices, QUrl +from anki.utils import ids2str, splitFields + +def import_timer_toggle(): + title = "&Enable Auto Import On Profile Load" + doAutoImport = mw.col.conf.get( + SETTINGS.ANKNOTES_CHECKABLE_MENU_ITEMS_PREFIX + '_' + title.replace(' ', '_').replace('&', ''), False) + if not doAutoImport: + return + lastImport = SETTINGS.EVERNOTE.LAST_IMPORT.fetch() + importDelay = 0 + if lastImport: + td = (datetime.now() - datetime.strptime(lastImport, ANKNOTES.DATE_FORMAT)) + minimum = 
timedelta(seconds=max(EVERNOTE.IMPORT.INTERVAL, 20 * 60)) + if td < minimum: + importDelay = (minimum - td).total_seconds() + if importDelay is 0: + return menu.import_from_evernote() + m, s = divmod(importDelay, 60) + log("> Starting Auto Import, Triggered by Profile Load, in %d:%02d min" % (m, s)) + return create_timer(importDelay, menu.import_from_evernote) + +def _findEdited((val, args)): + try: + days = int(val) + except ValueError: + return None + return "c.mod > %d" % (time.time() - days * 86400) + + +def _findAnknotes((val, args)): + tmr = stopwatch.Timer(label='finder\\findAnknotes', begin=False) + log_banner("FINDANKNOTES SEARCH: " + val.upper().replace('_', ' '), tmr.label, append_newline=False, clear=False) + if not hasattr(_findAnknotes, 'note_ids'): + _findAnknotes.note_ids = {} + if val == 'hierarchical' or val == 'hierarchical_alt' and ( + val not in _findAnknotes.note_ids or not ANKNOTES.CACHE_SEARCHES): + tmr.reset() + val_root = val.replace('hierarchical', 'root') + val_child = val.replace('hierarchical', 'child') + _findAnknotes((val_root, None), ) + _findAnknotes((val_child, None), ) + _findAnknotes.note_ids[val] = _findAnknotes.note_ids[val_root] + _findAnknotes.note_ids[val_child] + write_file_contents(" > %s Search Complete: ".ljust(25) % val.upper().replace('_', ' ') + "%-5s --> %3d results" % ( + tmr.str_long, len(_findAnknotes.note_ids[val])), tmr.label) + + if not hasattr(_findAnknotes, 'queries'): + _findAnknotes.queries = { + 'all': get_evernote_model_ids(True), + 'sub': 'n.sfld like "%:%"', + 'root_alt': "n.sfld NOT LIKE '%:%' AND ank.title LIKE n.sfld || ':%'", + 'child_alt': "n.sfld LIKE '%%:%%' AND UPPER(SUBSTR(n.sfld, 0, INSTR(n.sfld, ':'))) IN (SELECT UPPER(title) FROM %s WHERE title NOT LIKE '%%:%%' AND tagNames LIKE '%%,%s,%%') " % ( + TABLES.EVERNOTE.NOTES, TAGS.TOC), + 'orphan_alt': "n.sfld LIKE '%%:%%' AND UPPER(SUBSTR(n.sfld, 0, INSTR(n.sfld, ':'))) NOT IN (SELECT UPPER(title) FROM %s WHERE title NOT LIKE '%%:%%' AND 
tagNames LIKE '%%,%s,%%') " % ( + TABLES.EVERNOTE.NOTES, TAGS.TOC) + } + + if val not in _findAnknotes.note_ids or (not ANKNOTES.CACHE_SEARCHES and 'hierarchical' not in val): + tmr.reset() + if val == 'root': + _findAnknotes.note_ids[val] = get_anknotes_root_notes_nids() + elif val == 'child': + _findAnknotes.note_ids[val] = get_anknotes_child_notes_nids() + elif val == 'orphan': + _findAnknotes.note_ids[val] = get_anknotes_orphan_notes_nids() + elif val in _findAnknotes.queries: + pred = _findAnknotes.queries[val] + col = 'n.id' + table = 'notes n' + if 'ank.' in pred: + col = 'DISTINCT ' + col + table += ', %s ank' % TABLES.EVERNOTE.NOTES + sql = 'select %s from %s where ' % (col, table) + pred + _findAnknotes.note_ids[val] = ankDB().list(sql) else: - fmap = mm.fieldMap(evernote_model) - title_ord, title_field = fmap[TITLE_FIELD_NAME] - text_ord, text_field = fmap[CONTENT_FIELD_NAME] - source_ord, source_field = fmap[GUID_FIELD_NAME] - source_field['sticky'] = False - - def get_guids_from_anki_id(self, ids): - guids = [] - for a_id in ids: - card = self.collection().getCard(a_id) - items = card.note().items() - if len(items) == 3: - guids.append(items[2][1]) # not a very smart access - return guids - - def can_add_note(self, deck_name, model_name, fields): - return bool(self.create_note(deck_name, model_name, fields)) - - def get_cards_id_from_tag(self, tag): - query = "tag:" + tag - ids = self.collection().findCards(query) - return ids - - def start_editing(self): - self.window().requireReset() - - def stop_editing(self): - if self.collection(): - self.window().maybeReset() - - def window(self): - return aqt.mw - - def collection(self): - return self.window().col - - def models(self): - return self.collection().models - - def decks(self): - return self.collection().decks - - -class EvernoteCard: - front = "" - back = "" - guid = "" - - def __init__(self, q, a, g, tags): - self.front = q - self.back = a - self.guid = g - self.tags = tags - - -class Evernote: - 
def __init__(self): - if not mw.col.conf.get(SETTING_TOKEN, False): - # First run of the Plugin we did not save the access key yet - client = EvernoteClient( - consumer_key='scriptkiddi-2682', - consumer_secret='965f1873e4df583c', - sandbox=False - ) - request_token = client.get_request_token('https://fap-studios.de/anknotes/index.html') - url = client.get_authorize_url(request_token) - showInfo("We will open a Evernote Tab in your browser so you can allow access to your account") - openLink(url) - oauth_verifier = getText(prompt="Please copy the code that showed up, after allowing access, in here")[0] - auth_token = client.get_access_token( - request_token.get('oauth_token'), - request_token.get('oauth_token_secret'), - oauth_verifier) - mw.col.conf[SETTING_TOKEN] = auth_token + return None + write_file_contents(" > Cached %s Note IDs: ".ljust(25) % val + "%-5s --> %3d results" % ( + tmr.str_long, len(_findAnknotes.note_ids[val])), tmr.label) + else: + write_file_contents(" > Retrieving %3d %s Note IDs from Cache" % (len(_findAnknotes.note_ids[val]), val), tmr.label) + log_blank(tmr.label) + return "c.nid IN %s" % ids2str(_findAnknotes.note_ids[val]) + + +class CallbackItem(QTreeWidgetItem): + def __init__(self, root, name, onclick, oncollapse=None): + QTreeWidgetItem.__init__(self, root, [name]) + self.onclick = onclick + self.oncollapse = oncollapse + + +def anknotes_browser_get_icon(icon=None): + if icon: + return QIcon(":/icons/" + icon) + if not hasattr(anknotes_browser_get_icon, 'default_icon'): + from anknotes.graphics import icoEvernoteWeb + anknotes_browser_get_icon.default_icon = icoEvernoteWeb + return anknotes_browser_get_icon.default_icon + + +def anknotes_browser_add_treeitem(self, tree, name, cmd, icon=None, index=None, root=None): + if root is None: + root = tree + def onclick(c=cmd): return self.setFilter(c) + if index: + widgetItem = QTreeWidgetItem([_(name)]) + widgetItem.onclick = onclick + widgetItem.setIcon(0, anknotes_browser_get_icon(icon)) 
+ root.insertTopLevelItem(index, widgetItem) + return root, tree + item = self.CallbackItem(tree, _(name), onclick) + item.setIcon(0, anknotes_browser_get_icon(icon)) + return root, tree + + +def anknotes_browser_add_tree(self, tree, items, root=None, name=None, icon=None): + if root is None: + root = tree + for item in items: + if isinstance(item[1], list): + new_name = item[0] + # write_file_contents('Tree: Name: %s: \n' % str(new_name) + repr(item)) + new_tree = self.CallbackItem(tree, _(new_name), None) + new_tree.setExpanded(True) + new_tree.setIcon(0, anknotes_browser_get_icon(icon)) + root = anknotes_browser_add_tree(self, new_tree, item[1], root, new_name, icon) else: - auth_token = mw.col.conf.get(SETTING_TOKEN, False) - self.token = auth_token - self.client = EvernoteClient(token=auth_token, sandbox=False) - self.noteStore = self.client.get_note_store() - - def find_tag_guid(self, tag): - list_tags = self.noteStore.listTags() - for evernote_tag in list_tags: - if str(evernote_tag.name).strip() == str(tag).strip(): - return evernote_tag.guid - - def create_evernote_cards(self, guid_set): - cards = [] - for guid in guid_set: - note_info = self.get_note_information(guid) - if note_info is None: - return cards - title, content, tags = note_info - cards.append(EvernoteCard(title, content, guid, tags)) - return cards - - def find_notes_filter_by_tag_guids(self, guids_list): - evernote_filter = NoteFilter() - evernote_filter.ascending = False - evernote_filter.tagGuids = guids_list - spec = NotesMetadataResultSpec() - spec.includeTitle = True - note_list = self.noteStore.findNotesMetadata(self.token, evernote_filter, 0, 10000, spec) - guids = [] - for note in note_list.notes: - guids.append(note.guid) - return guids - - def get_note_information(self, note_guid): - tags = [] - try: - whole_note = self.noteStore.getNote(self.token, note_guid, True, True, False, False) - if mw.col.conf.get(SETTING_KEEP_TAGS, False): - tags = 
self.noteStore.getNoteTagNames(self.token, note_guid) - except EDAMSystemException, e: - if e.errorCode == EDAMErrorCode.RATE_LIMIT_REACHED: - m, s = divmod(e.rateLimitDuration, 60) - showInfo("Rate limit has been reached. We will save the notes downloaded thus far.\r\n" - "Please retry your request in {} min".format("%d:%02d" % (m, s))) - return None - raise - return whole_note.title, whole_note.content, tags - - -class Controller: - def __init__(self): - self.evernoteTags = mw.col.conf.get(SETTING_TAGS_TO_IMPORT, "").split(",") - self.ankiTag = mw.col.conf.get(SETTING_DEFAULT_TAG, "anknotes") - self.deck = mw.col.conf.get(SETTING_DEFAULT_DECK, "Default") - self.updateExistingNotes = mw.col.conf.get(SETTING_UPDATE_EXISTING_NOTES, - UpdateExistingNotes.UpdateNotesInPlace) - self.anki = Anki() - self.anki.add_evernote_model() - self.evernote = Evernote() - - def proceed(self): - anki_ids = self.anki.get_cards_id_from_tag(self.ankiTag) - anki_guids = self.anki.get_guids_from_anki_id(anki_ids) - evernote_guids = self.get_evernote_guids_from_tag(self.evernoteTags) - cards_to_add = set(evernote_guids) - set(anki_guids) - cards_to_update = set(evernote_guids) - set(cards_to_add) - self.anki.start_editing() - n = self.import_into_anki(cards_to_add, self.deck, self.ankiTag) - if self.updateExistingNotes is UpdateExistingNotes.IgnoreExistingNotes: - show_tooltip("{} new card(s) have been imported. 
Updating is disabled.".format(str(n))) + # write_file_contents('Tree Item: Name: %s: \n' % str(name) + repr(item)) + root, tree = anknotes_browser_add_treeitem(self, tree, *item, root=root) + return root + + +def anknotes_browser_tagtree_wrap(self, root, _old): + """ + + :param root: + :type root : QTreeWidget + :param _old: + :return: + """ + root = _old(self, root) + indices = root.findItems(_("Added Today"), Qt.MatchFixedString) + index = (root.indexOfTopLevelItem(indices[0]) + 1) if indices else 3 + tags = \ + [ + ["Edited This Week", "edited:7", "view-pim-calendar.png", index], + ["Anknotes", + [ + ["All Anknotes", "anknotes:all"], + ["Hierarchy", + [ + ["All Hierarchical Notes", "anknotes:hierarchical"], + ["Root Notes", "anknotes:root"], + ["Sub Notes", "anknotes:sub"], + ["Child Notes", "anknotes:child"], + ["Orphan Notes", "anknotes:orphan"] + ] + ], + # ["Hierarchy: Alt", + # [ + # ["All Hierarchical Notes", "anknotes:hierarchical_alt"], + # ["Root Notes", "anknotes:root_alt"], + # ["Child Notes", "anknotes:child_alt"], + # ["Orphan Notes", "anknotes:orphan_alt"] + # ] + # ], + ["Front Cards", "card:1"] + ] + ] + ] + + return anknotes_browser_add_tree(self, root, tags) + + +def anknotes_finder_findCards_wrap(self, query, order=False, _old=None): + tmr = stopwatch.Timer(label='finder\\findCards') + log_banner("FINDCARDS SEARCH: " + query, tmr.label, append_newline=False, clear=False) + tokens = self._tokenize(query) + preds, args = self._where(tokens) + write_file_contents('Tokens: '.ljust(25) + ', '.join(tokens), tmr.label) + if args: + write_file_contents('Args: '.ljust(25) + ', '.join(tokens), tmr.label) + if preds is None: + write_file_contents('Preds: '.ljust(25) + '<NONE>', tmr.label) + log_blank(tmr.label) + return [] + + order, rev = self._order(order) + sql = self._query(preds, order) + try: + res = self.col.db.list(sql, *args) + except Exception as ex: + # invalid grouping + log_error("Error with findCards Query %s: %s.\n%s" % (query, str(ex), 
[sql, args]), crosspost=tmr.label) + return [] + if rev: + res.reverse() + write_file_contents("FINDCARDS DONE: ".ljust(25) + "%-5s --> %3d results" % (tmr.str_long, len(res)), tmr.label) + log_blank(tmr.label) + return res + return _old(self, query, order) + + +def anknotes_finder_query_wrap(self, preds=None, order=None, _old=None): + if _old is None or not isinstance(self, Finder): + log_dump([self, preds, order], 'Finder Query Wrap Error', 'finder\\error', crosspost_to_default=False) + return + sql = _old(self, preds, order) + if "ank." in preds: + sql = sql.replace("select c.id", "select distinct c.id").replace("from cards c", + "from cards c, %s ank" % TABLES.EVERNOTE.NOTES) + write_file_contents('Custom anknotes finder SELECT query: \n%s' % sql, 'finder\\ank-query') + elif TABLES.EVERNOTE.NOTES in preds: + write_file_contents('Custom anknotes finder alternate query: \n%s' % sql, 'finder\\ank-query') + else: + write_file_contents("Anki finder query: %s" % sql[:100], 'finder\\query') + return sql + + +def anknotes_search_hook(search): + anknotes_search = {'edited': _findEdited, 'anknotes': _findAnknotes} + for key, value in anknotes_search.items(): + if key not in search: + search[key] = anknotes_search[key] + +def reset_everything(upload=True): + show_tooltip_enabled = show_tooltip.enabled if hasattr(show_tooltip, 'enabled') else None + show_tooltip.enabled = False + ankDB().InitSeeAlso(True) + menu.resync_with_local_db() + menu.see_also(upload=upload) + show_tooltip.enabled = show_tooltip_enabled + + +def anknotes_profile_loaded(): + # write_file_contents('%s: anknotes_profile_loaded' % __name__, 'load') + last_profile_dir = os.path.dirname(FILES.USER.LAST_PROFILE_LOCATION) + if not os.path.exists(last_profile_dir): + os.makedirs(last_profile_dir) + with open(FILES.USER.LAST_PROFILE_LOCATION, 'w+') as myFile: + print>> myFile, mw.pm.name, + # write_file_contents('%s: anknotes_profile_loaded: menu.anknotes_load_menu_settings' % __name__, 'load') + 
menu.anknotes_load_menu_settings() + if EVERNOTE.UPLOAD.VALIDATION.ENABLED and EVERNOTE.UPLOAD.VALIDATION.AUTOMATED: + # write_file_contents('%s: anknotes_profile_loaded: menu.upload_validated_notes' % __name__, 'load') + menu.upload_validated_notes(True) + if ANKNOTES.UPDATE_DB_ON_START: + # write_file_contents('%s: anknotes_profile_loaded: update_anknotes_nids' % __name__, 'load') + update_anknotes_nids() + # write_file_contents('%s: anknotes_profile_loaded: import_timer_toggle' % __name__, 'load') + import_timer_toggle() + if ANKNOTES.DEVELOPER_MODE.AUTOMATED: + ''' + For testing purposes only! + Add a function here and it will automatically run on profile load + You must create the files 'anknotes.developer' and 'anknotes.developer.automate' in the /extra/dev/ folder + ''' + # write_file_contents('%s: anknotes_profile_loaded: ANKNOTES.DEVELOPER_MODE.AUTOMATED' % __name__, 'load') + # menu.lxml_test() + # menu.see_also([8]) + # menu.see_also(upload=False) + reset_everything(False) + # menu.see_also(set(range(0,10)) - {3,4,8}) + # ankDB().InitSeeAlso(True) + # menu.resync_with_local_db() + # menu.see_also([1, 2, 6, 7, 9]) + # menu.lxml_test() + # menu.see_also() + # reset_everything() + # menu.import_from_evernote(auto_page_callback=lambda: lambda: menu.see_also(3)) + # mw.progress.timer(20000, lambda : menu.find_deleted_notes(True), False) + pass + +def anknotes_scalar(self, *a, **kw): + log_text = 'Call to DB.scalar():' + if not isinstance(self, DB): + log_text += '\n - Self: ' + pf(self) + if a: + log_text += '\n - Args: ' + pf(a) + if kw: + log_text += '\n - KWArgs: ' + pf(kw) + last_query='<None>' + if hasattr(self, 'ank_lastquery'): + last_query = self.ank_lastquery + if is_str_type(last_query): + last_query = last_query[:50] else: - n2 = len(cards_to_update) - if self.updateExistingNotes is UpdateExistingNotes.UpdateNotesInPlace: - update_str = "in-place" - self.update_in_anki(cards_to_update, self.ankiTag) - else: - update_str = "(deleted and re-added)" - 
self.anki.delete_anki_cards(cards_to_update) - self.import_into_anki(cards_to_update, self.deck, self.ankiTag) - show_tooltip("{} new card(s) have been imported and {} existing card(s) have been updated {}." - .format(str(n), str(n2), update_str)) - self.anki.stop_editing() - self.anki.collection().autosave() - - def update_in_anki(self, guid_set, tag): - cards = self.evernote.create_evernote_cards(guid_set) - number = self.anki.update_evernote_cards(cards, tag) - return number - - def import_into_anki(self, guid_set, deck, tag): - cards = self.evernote.create_evernote_cards(guid_set) - number = self.anki.add_evernote_cards(cards, deck, tag) - return number - - def get_evernote_guids_from_tag(self, tags): - note_guids = [] - for tag in tags: - tag_guid = self.evernote.find_tag_guid(tag) - if tag_guid is not None: - note_guids += self.evernote.find_notes_filter_by_tag_guids([tag_guid]) - return note_guids - - -def show_tooltip(text, time_out=3000): - aqt.utils.tooltip(text, time_out) - - -def main(): - controller = Controller() - controller.proceed() - - -action = aqt.qt.QAction("Import from Evernote", aqt.mw) -aqt.mw.connect(action, aqt.qt.SIGNAL("triggered()"), main) -aqt.mw.form.menuTools.addAction(action) - - -def setup_evernote(self): - global evernote_default_deck - global evernote_default_tag - global evernote_tags_to_import - global keep_evernote_tags - global update_existing_notes - - widget = QWidget() - layout = QVBoxLayout() - - # Default Deck - evernote_default_deck_label = QLabel("Default Deck:") - evernote_default_deck = QLineEdit() - evernote_default_deck.setText(mw.col.conf.get(SETTING_DEFAULT_DECK, "")) - layout.insertWidget(int(layout.count()) + 1, evernote_default_deck_label) - layout.insertWidget(int(layout.count()) + 2, evernote_default_deck) - evernote_default_deck.connect(evernote_default_deck, SIGNAL("editingFinished()"), update_evernote_default_deck) - - # Default Tag - evernote_default_tag_label = QLabel("Default Tag:") - 
evernote_default_tag = QLineEdit() - evernote_default_tag.setText(mw.col.conf.get(SETTING_DEFAULT_TAG, "")) - layout.insertWidget(int(layout.count()) + 1, evernote_default_tag_label) - layout.insertWidget(int(layout.count()) + 2, evernote_default_tag) - evernote_default_tag.connect(evernote_default_tag, SIGNAL("editingFinished()"), update_evernote_default_tag) - - # Tags to Import - evernote_tags_to_import_label = QLabel("Tags to Import:") - evernote_tags_to_import = QLineEdit() - evernote_tags_to_import.setText(mw.col.conf.get(SETTING_TAGS_TO_IMPORT, "")) - layout.insertWidget(int(layout.count()) + 1, evernote_tags_to_import_label) - layout.insertWidget(int(layout.count()) + 2, evernote_tags_to_import) - evernote_tags_to_import.connect(evernote_tags_to_import, - SIGNAL("editingFinished()"), - update_evernote_tags_to_import) - - # Keep Evernote Tags - keep_evernote_tags = QCheckBox("Keep Evernote Tags", self) - keep_evernote_tags.setChecked(mw.col.conf.get(SETTING_KEEP_TAGS, False)) - keep_evernote_tags.stateChanged.connect(update_evernote_keep_tags) - layout.insertWidget(int(layout.count()) + 1, keep_evernote_tags) - - # Update Existing Notes - update_existing_notes = QComboBox() - update_existing_notes.addItems(["Ignore Existing Notes", "Update Existing Notes In-Place", - "Delete and Re-Add Existing Notes"]) - update_existing_notes.setCurrentIndex(mw.col.conf.get(SETTING_UPDATE_EXISTING_NOTES, - UpdateExistingNotes.UpdateNotesInPlace)) - update_existing_notes.activated.connect(update_evernote_update_existing_notes) - layout.insertWidget(int(layout.count()) + 1, update_existing_notes) - - # Vertical Spacer - vertical_spacer = QSpacerItem(20, 0, QSizePolicy.Minimum, QSizePolicy.Expanding) - layout.addItem(vertical_spacer) - - # Parent Widget - widget.setLayout(layout) - - # New Tab - self.form.tabWidget.addTab(widget, "Evernote Importer") - -def update_evernote_default_deck(): - mw.col.conf[SETTING_DEFAULT_DECK] = evernote_default_deck.text() - -def 
update_evernote_default_tag(): - mw.col.conf[SETTING_DEFAULT_TAG] = evernote_default_tag.text() - -def update_evernote_tags_to_import(): - mw.col.conf[SETTING_TAGS_TO_IMPORT] = evernote_tags_to_import.text() - -def update_evernote_keep_tags(): - mw.col.conf[SETTING_KEEP_TAGS] = keep_evernote_tags.isChecked() - -def update_evernote_update_existing_notes(index): - mw.col.conf[SETTING_UPDATE_EXISTING_NOTES] = index - -Preferences.setupOptions = wrap(Preferences.setupOptions, setup_evernote) + last_query = pf(last_query) + log_text += '\n - Last Query: ' + last_query + write_file_contents(log_text + '\n', 'sql\\scalar') + try: + res = self.execute(*a, **kw) + except TypeError as e: + write_file_contents(" > ERROR with scalar while executing query: %s\n > LAST QUERY: %s" % (str(e), last_query), 'sql\\scalar', crosspost='sql\\scalar-error') + raise + if not isinstance(res, sqlite.Cursor): + write_file_contents(' > Cursor: %s' % pf(res), 'sql\\scalar') + try: + res = res.fetchone() + except TypeError as e: + write_file_contents(" > ERROR with scalar while fetching result: %s\n > LAST QUERY: %s" % (str(e), last_query), 'sql\\scalar', crosspost='sql\\scalar-error') + raise + write_file_contents('', 'sql\\scalar') + if res: + return res[0] + return None + +def anknotes_execute(self, sql, *a, **kw): + log_text = 'Call to DB.execute():' + if not isinstance(self, DB): + log_text += '\n - Self: ' + pf(self) + if a: + log_text += '\n - Args: ' + pf(a) + if kw: + log_text += '\n - KWArgs: ' + pf(kw) + last_query=sql + if is_str_type(last_query): + last_query = last_query[:50] + else: + last_query = pf(last_query) + log_text += '\n - Query: ' + last_query + write_file_contents(log_text + '\n\n', 'sql\\execute') + self.ank_lastquery = sql + +def anknotes_onload(): + # write_file_contents('%s: anknotes_onload' % __name__, 'load') + if in_anki(): + addHook("profileLoaded", anknotes_profile_loaded) + if ANKNOTES.HOOKS.DB: + DB.scalar = anknotes_scalar # wrap(DB.scalar, anknotes_scalar, 
"before") + DB.execute = wrap(DB.execute, anknotes_execute, "before") + if ANKNOTES.HOOKS.SEARCH: + addHook("search", anknotes_search_hook) + Finder._query = wrap(Finder._query, anknotes_finder_query_wrap, "around") + Finder.findCards = wrap(Finder.findCards, anknotes_finder_findCards_wrap, "around") + browser.Browser._systemTagTree = wrap(browser.Browser._systemTagTree, anknotes_browser_tagtree_wrap, "around") + # write_file_contents('%s: anknotes_onload: anknotes_setup_menu' % __name__, 'load') + menu.anknotes_setup_menu() + Preferences.setupOptions = wrap(Preferences.setupOptions, settings.setup_evernote) + # write_file_contents('%s: anknotes_onload: complete' % __name__, 'load') + +anknotes_onload() \ No newline at end of file diff --git a/anknotes/_re.py b/anknotes/_re.py new file mode 100644 index 0000000..151ddc9 --- /dev/null +++ b/anknotes/_re.py @@ -0,0 +1,277 @@ +"""Skeleton for 're' stdlib module.""" + + +def compile(pattern, flags=0): + """Compile a regular expression pattern, returning a pattern object. + + :type pattern: bytes | unicode + :type flags: int + :rtype: __Regex + """ + pass + + +def search(pattern, string, flags=0): + """Scan through string looking for a match, and return a corresponding + match instance. Return None if no position in the string matches. + + :type pattern: bytes | unicode | __Regex + :type string: T <= bytes | unicode + :type flags: int + :rtype: __Match[T] | None + """ + pass + + +def match(pattern, string, flags=0): + """Matches zero or more characters at the beginning of the string. + + :type pattern: bytes | unicode | __Regex + :type string: T <= bytes | unicode + :type flags: int + :rtype: __Match[T] | None + """ + pass + + +def split(pattern, string, maxsplit=0, flags=0): + """Split string by the occurrences of pattern. 
+ + :type pattern: bytes | unicode | __Regex + :type string: T <= bytes | unicode + :type maxsplit: int + :type flags: int + :rtype: list[T] + """ + pass + + +def findall(pattern, string, flags=0): + """Return a list of all non-overlapping matches of pattern in string. + + :type pattern: bytes | unicode | __Regex + :type string: T <= bytes | unicode + :type flags: int + :rtype: list[T] + """ + pass + + +def finditer(pattern, string, flags=0): + """Return an iterator over all non-overlapping matches for the pattern in + string. For each match, the iterator returns a match object. + + :type pattern: bytes | unicode | __Regex + :type string: T <= bytes | unicode + :type flags: int + :rtype: collections.Iterable[__Match[T]] + """ + pass + + +def sub(pattern, repl, string, count=0, flags=0): + """Return the string obtained by replacing the leftmost non-overlapping + occurrences of pattern in string by the replacement repl. + + :type pattern: bytes | unicode | __Regex + :type repl: bytes | unicode | collections.Callable + :type string: T <= bytes | unicode + :type count: int + :type flags: int + :rtype: T + """ + pass + + +def subn(pattern, repl, string, count=0, flags=0): + """Return the tuple (new_string, number_of_subs_made) found by replacing + the leftmost non-overlapping occurrences of pattern with the + replacement repl. + + :type pattern: bytes | unicode | __Regex + :type repl: bytes | unicode | collections.Callable + :type string: T <= bytes | unicode + :type count: int + :type flags: int + :rtype: (T, int) + """ + pass + + +def escape(string): + """Escape all the characters in pattern except ASCII letters and numbers. + + :type string: T <= bytes | unicode + :type: T + """ + pass + + +class __Regex(object): + """Mock class for a regular expression pattern object.""" + + def __init__(self, flags, groups, groupindex, pattern): + """Create a new pattern object. 
+ + :type flags: int + :type groups: int + :type groupindex: dict[bytes | unicode, int] + :type pattern: bytes | unicode + """ + self.flags = flags + self.groups = groups + self.groupindex = groupindex + self.pattern = pattern + + def search(self, string, pos=0, endpos=-1): + """Scan through string looking for a match, and return a corresponding + match instance. Return None if no position in the string matches. + + :type string: T <= bytes | unicode + :type pos: int + :type endpos: int + :rtype: __Match[T] | None + """ + pass + + def match(self, string, pos=0, endpos=-1): + """Matches zero | more characters at the beginning of the string. + + :type string: T <= bytes | unicode + :type pos: int + :type endpos: int + :rtype: __Match[T] | None + """ + pass + + def split(self, string, maxsplit=0): + """Split string by the occurrences of pattern. + + :type string: T <= bytes | unicode + :type maxsplit: int + :rtype: list[T] + """ + pass + + def findall(self, string, pos=0, endpos=-1): + """Return a list of all non-overlapping matches of pattern in string. + + :type string: T <= bytes | unicode + :type pos: int + :type endpos: int + :rtype: list[T] + """ + pass + + def finditer(self, string, pos=0, endpos=-1): + """Return an iterator over all non-overlapping matches for the + pattern in string. For each match, the iterator returns a + match object. + + :type string: T <= bytes | unicode + :type pos: int + :type endpos: int + :rtype: collections.Iterable[__Match[T]] + """ + pass + + def sub(self, repl, string, count=0): + """Return the string obtained by replacing the leftmost non-overlapping + occurrences of pattern in string by the replacement repl. 
+ + :type repl: bytes | unicode | collections.Callable + :type string: T <= bytes | unicode + :type count: int + :rtype: T + """ + pass + + def subn(self, repl, string, count=0): + """Return the tuple (new_string, number_of_subs_made) found by replacing + the leftmost non-overlapping occurrences of pattern with the + replacement repl. + + :type repl: bytes | unicode | collections.Callable + :type string: T <= bytes | unicode + :type count: int + :rtype: (T, int) + """ + pass + + +class __Match(object): + """Mock class for a match object.""" + + def __init__(self, pos, endpos, lastindex, lastgroup, re, string): + """Create a new match object. + + :type pos: int + :type endpos: int + :type lastindex: int | None + :type lastgroup: int | bytes | unicode | None + :type re: __Regex + :type string: bytes | unicode + :rtype: __Match[T] + """ + self.pos = pos + self.endpos = endpos + self.lastindex = lastindex + self.lastgroup = lastgroup + self.re = re + self.string = string + + def expand(self, template): + """Return the string obtained by doing backslash substitution on the + template string template. + + :type template: T + :rtype: T + """ + pass + + def group(self, *args): + """Return one or more subgroups of the match. + + :rtype: T | tuple + """ + pass + + def groups(self, default=None): + """Return a tuple containing all the subgroups of the match, from 1 up + to however many groups are in the pattern. + + :rtype: tuple + """ + pass + + def groupdict(self, default=None): + """Return a dictionary containing all the named subgroups of the match, + keyed by the subgroup name. + + :rtype: dict[bytes | unicode, T] + """ + pass + + def start(self, group=0): + """Return the index of the start of the substring matched by group. + + :type group: int | bytes | unicode + :rtype: int + """ + pass + + def end(self, group=0): + """Return the index of the end of the substring matched by group. 
+ + :type group: int | bytes | unicode + :rtype: int + """ + pass + + def span(self, group=0): + """Return a 2-tuple (start, end) for the substring matched by group. + + :type group: int | bytes | unicode + :rtype: (int, int) + """ + pass diff --git a/anknotes/addict/__init__.py b/anknotes/addict/__init__.py new file mode 100644 index 0000000..a22f2e8 --- /dev/null +++ b/anknotes/addict/__init__.py @@ -0,0 +1,8 @@ +from .addict import Dict + +__title__ = 'addict' +__version__ = '0.4.0' +__author__ = 'Mats Julian Olsen' +__license__ = 'MIT' +__copyright__ = 'Copyright 2014 Mats Julian Olsen' +__all__ = ['Dict'] diff --git a/anknotes/addict/addict.py b/anknotes/addict/addict.py new file mode 100644 index 0000000..7ccb63f --- /dev/null +++ b/anknotes/addict/addict.py @@ -0,0 +1,368 @@ +from inspect import isgenerator +import re +import os +import copy +from anknotes.base import is_dict_type, item_to_list, is_seq_type + + +class Dict(dict): + """ + Dict is a subclass of dict, which allows you to get AND SET(!!) + items in the dict using the attribute syntax! + + When you previously had to write: + + my_dict = {'a': {'b': {'c': [1, 2, 3]}}} + + you can now do the same simply by: + + my_Dict = Dict() + my_Dict.a.b.c = [1, 2, 3] + + Or for instance, if you'd like to add some additional stuff, + where you'd with the normal dict would write + + my_dict['a']['b']['d'] = [4, 5, 6], + + you may now do the AWESOME + + my_Dict.a.b.d = [4, 5, 6] + + instead. But hey, you can always use the same syntax as a regular dict, + however, this will not raise TypeErrors or AttributeErrors at any time + while you try to get an item. A lot like a defaultdict. + + """ + + def __init__(self, *a, **kw): + """ + If we're initialized with a dict, make sure we turn all the + subdicts into Dicts as well. 
+ + """ + a = list(a) + mro = self._get_arg_(a, int, 'mro', kw) + pass # self.log_init('Dict', mro, a, kw) + self.update(*a, **kw) + + def __setattr__(self, name, value): + """ + setattr is called when the syntax a.b = 2 is used to set a value. + + """ + if hasattr(Dict, name): + raise AttributeError("'Dict' object attribute " + "'{0}' is read-only".format(name)) + else: + self[name] = value + + def __setitem__(self, name, value): + """ + This is called when trying to set a value of the Dict using []. + E.g. some_instance_of_Dict['b'] = val. If 'val + + """ + value = self._hook(value) + super(Dict, self).__setitem__(name, value) + + def _hook(self, item): + """ + Called to ensure that each dict-instance that are being set + is a addict Dict. Recurses. + + """ + if isinstance(item, dict): + return self._new_instance_(item) + if isinstance(item, (list, tuple)): + return item.__class__(self._hook(elem) for elem in item) + return item + + def __getattr__(self, item): + if item not in self and item in dir(self): + return super(Dict, self).__getattr__(item) + return self.__getitem__(item) + + def _new_instance_(self, *a, **kw): + return (self.__class__.mro()[self._mro_offset_] if self._is_obj_attr_('_mro_offset_') else self.__class__)(*a, **kw) + + def __getitem__(self, name): + """ + This is called when the Dict is accessed by []. E.g. + some_instance_of_Dict['a']; + If the name is in the dict, we return it. Otherwise we set both + the attr and item to a new instance of Dict. + + """ + if name not in self: + self[name] = self._new_instance_() + return super(Dict, self).__getitem__(name) + + def __delattr__(self, name): + """ Is invoked when del some_addict.b is called. 
""" + del self[name] + + _re_pattern = re.compile('[a-zA-Z_][a-zA-Z0-9_]*') + + # def log(self, str_, method, do_print=True, prefix=''): + # cls = self.__class__ + # str_lbl = self.label.full if self.label else '' + # if str_lbl: + # str_lbl += ': ' + # str_ = prefix + '%17s %-20s %s' % ('<%s>' % cls.__name__, str_lbl, str_) + # if do_print: + # print str_ + # write_file_contents(str_, 'Dicts\\%s\\%s' % (cls.__name__, method)) + + # def log_action(self, method, action, name, value, key=None, via=None, extra='', log_self=False): + # if key in ['_my_attrs_','_override_default_']: + # return + # if extra: + # extra += ' ' + # type = ('<%s>' % value.__class__.__name__).center(10) + # log_str = action.ljust(12) + ' ' + # log_str += name.ljust(12) + ' ' + # log_str += ('via '+via if via else '').ljust(10) + ' ' + # log_str += ('for `%s`' % key if key else '').ljust(25) + ' ' + # log_str += 'to %10s%s %s' % (extra, type, str(value)) + # if log_self: + # log_str += ' \n\n Self: ' + repr(self) + # pass # self.log(log_str, method); + + # def log_init(self, type, mro, a, kw): + # cls = self.__class__ + # mro_name = cls.mro()[mro].__name__ + # mro_name = (':' + mro_name) if mro_name != cls.__name__ and mro_name != type else '' + # log_str = "Init: %s%s #%d" % (type, mro_name, mro) + # log_str += "\n Args: %s" % a if a else "" + # log_str += "\n KWArgs: %s" % kw if kw else "" + # pass # self.log(log_str + '\n', '__init__', prefix='-'*40+'\n', do_print=False) + + # def clear_logs(self): + # name=self.__class__.__name__ + # reset_logs('Dicts' + os.path.sep + name, self.make_banner(name)) + + @staticmethod + def get_default_value(cls, default=None): + if default is not None: + return default + if cls is str or cls is unicode: + return '' + elif cls is int: + return 0 + elif cls is bool: + return False + return None + + def _get_arg_(self, a, cls=None, key=None, kw=None, default=None): + if cls is None: + cls = (str, unicode) + if a and isinstance(a[0], cls): + val = a[0] + #del 
a[0] + elif kw and key in kw: + val = kw[key] + del kw[key] + else: + val = self.get_default_value(cls, default) + return val + + def _key_transform_(self, key, keys=None, all=False, attrs=False): + return key + + def _key_transform_all_(self, key, keys=None): + return self._key_transform_(key, keys, all=True) + + def _key_transform_attrs_(self, key, keys=None): + return self._key_transform_(key, keys, attrs=True) + + def __contains__(self, item): + key = self._key_transform_(item) + return key in self._dict_keys_() + + def _is_obj_attr_(self, key): + keys = self._obj_attrs_() + return self._key_transform_(key, keys) in keys + + def _dict_keys_(self): + dict_keys = [] + for k in self.keys(): + if isinstance(k, str): + m = self._re_pattern.match(k) + if m: + dict_keys.append(m.string) + return dict_keys + + def _obj_attrs_(self): + return list(dir(self.__class__)) + + def __dir__(self): + """ + Return a list of addict object attributes. + This includes key names of any dict entries, filtered to the subset of + valid attribute names (e.g. alphanumeric strings beginning with a + letter or underscore). Also includes attributes of parent dict class. + """ + return self._dict_keys_() + self._obj_attrs_() + + def _ipython_display_(self): + print(str(self)) # pragma: no cover + + def _repr_html_(self): + return str(self) + + def prune(self, prune_zero=False, prune_empty_list=True): + """ + Removes all empty Dicts and falsy stuff inside the Dict. 
+ E.g + >>> a = Dict() + >>> a.b.c.d + {} + >>> a.a = 2 + >>> a + {'a': 2, 'b': {'c': {'d': {}}}} + >>> a.prune() + >>> a + {'a': 2} + + Set prune_zero=True to remove 0 values + E.g + >>> a = Dict() + >>> a.b.c.d = 0 + >>> a.prune(prune_zero=True) + >>> a + {} + + Set prune_empty_list=False to have them persist + E.g + >>> a = Dict({'a': []}) + >>> a.prune() + >>> a + {} + >>> a = Dict({'a': []}) + >>> a.prune(prune_empty_list=False) + >>> a + {'a': []} + """ + for key, val in list(self.items()): + if ((not val) and ((val != 0) or prune_zero) and + not isinstance(val, list)): + del self[key] + elif isinstance(val, Dict): + val.prune(prune_zero, prune_empty_list) + if not val: + del self[key] + elif isinstance(val, (list, tuple)): + new_iter = self._prune_iter(val, prune_zero, prune_empty_list) + if (not new_iter) and prune_empty_list: + del self[key] + else: + if isinstance(val, tuple): + new_iter = tuple(new_iter) + self[key] = new_iter + + @classmethod + def _prune_iter(cls, some_iter, prune_zero=False, prune_empty_list=True): + new_iter = [] + for item in some_iter: + if item == 0 and prune_zero: + continue + elif isinstance(item, Dict): + item.prune(prune_zero, prune_empty_list) + if item: + new_iter.append(item) + elif isinstance(item, (list, tuple)): + new_item = item.__class__( + cls._prune_iter(item, prune_zero, prune_empty_list)) + if new_item or not prune_empty_list: + new_iter.append(new_item) + else: + new_iter.append(item) + return new_iter + + def to_dict(self): + """ Recursively turn your addict Dicts into dicts. """ + base = {} + cls = self.__class__ + for key, value in self.items(): + if isinstance(value, cls): + base[key] = value.to_dict() + elif isinstance(value, (list, tuple)): + base[key] = value.__class__( + item.to_dict() if isinstance(item, cls) else + item for item in value) + else: + base[key] = value + return base + + def copy(self): + """ + Return a disconnected deep copy of self. 
Children of type Dict, list + and tuple are copied recursively while values that are instances of + other mutable objects are not copied. + + """ + return self._new_instance_(self.to_dict()) + + def __deepcopy__(self, memo): + """ Return a disconnected deep copy of self. """ + + y = self.__class__() + memo[id(self)] = y + for key, value in self.items(): + y[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) + return y + + def initialize_keys(self, arg, split_chr='|'): + """ + Initializes keys from string or sequence. + From string: + 1) String is converted to list, split by 'split_chr' argument + From list: + 2) If list item has two subitems, it will be treated as a key-value pair + E.g: [['key', 'value'], ['key2', 'value2']] will set keys with corresponding values + 3) Otherwise, it will be treated as a list of keys + E.g.: [['key1', 'key2', 'key3']] will instantiate keys as new Dicts + 4) If list item is not a sequence, it will be converted to a list as per Example 1 + E.g.: ['key1', 'key2', 'key3'] will act similarly to Example 3 + 5) Exception: If list item is a dict, it will be handled via update.update_dict + """ + if not is_seq_type(arg): + arg = item_to_list(arg, split_chr=split_chr) + for items in arg: + if is_dict_type(items): + self.update(items) + continue + if not is_seq_type(items): + items = item_to_list(items, split_chr=split_chr) + if len(items) is 1: + self[items[0]] + elif len(items) is 2: + self[items[0]] = items[1] + else: + self.update_seq(items) + + def update(self, *a, **kw): + """ Update self with dict, sequence, or kwargs """ + def update_dict(d): + """ Recursively merge d into self. 
""" + for k, v in d.items(): + if k in self and is_dict_type(self[k], v): + self[k].update(v) + else: + self[k] = v + + # Begin update() + for arg in a: + if not arg: + continue + elif isinstance(arg, dict): + update_dict(arg) + elif isinstance(arg, tuple) and len(arg) is 2 and not isinstance(arg[0], tuple): + self[arg[0]] = arg[1] + elif is_seq_type(arg): + self.initialize_keys(arg) + else: + raise TypeError("Dict does not understand " + "{0} types".format(arg.__class__)) + update_dict(kw) + diff --git a/anknotes/ankEvernote.py b/anknotes/ankEvernote.py new file mode 100644 index 0000000..28c078c --- /dev/null +++ b/anknotes/ankEvernote.py @@ -0,0 +1,603 @@ +# -*- coding: utf-8 -*- +### Python Imports +import socket +import stopwatch +from datetime import datetime, timedelta +from StringIO import StringIO + +# try: +# from lxml import etree +# eTreeImported = True +# except ImportError: +# eTreeImported = False + +try: + from pysqlite2 import dbapi2 as sqlite +except ImportError: + from sqlite3 import dbapi2 as sqlite + +### Anknotes Imports +from anknotes.shared import * +from anknotes.error import * +from anknotes.imports import in_anki +from anknotes.base import is_str, encode + +### Anki Imports +if in_anki(): + ### Anknotes Class Imports + from anknotes.EvernoteNoteFetcher import EvernoteNoteFetcher + from anknotes.EvernoteNotePrototype import EvernoteNotePrototype + + ### Evernote Imports + from anknotes.evernote.edam.type.ttypes import Note as EvernoteNote + from anknotes.evernote.edam.error.ttypes import EDAMSystemException, EDAMUserException, EDAMNotFoundException + from anknotes.evernote.api.client import EvernoteClient + + ### Anki Imports + from aqt.utils import openLink, getText, showInfo + from aqt import mw + + +### Anki Imports +# import anki +# import aqt +# from anki.hooks import wrap, addHook +# from aqt.preferences import Preferences +# from aqt.utils import getText, openLink, getOnlyText +# from aqt.qt import QLineEdit, QLabel, QVBoxLayout, 
QHBoxLayout, QGroupBox, SIGNAL, QCheckBox, \ +# QComboBox, QSpacerItem, QSizePolicy, QWidget, QSpinBox, QFormLayout, QGridLayout, QFrame, QPalette, \ +# QRect, QStackedLayout, QDateEdit, QDateTimeEdit, QTimeEdit, QDate, QDateTime, QTime, QPushButton, QIcon, QMessageBox, QPixmap, QMenu, QAction +# from aqt import mw + +etree = None + +class Evernote(object): + metadata = {} + """:type : dict[str, evernote.edam.type.ttypes.Note]""" + notebook_data = {} + """:type : dict[str, anknotes.structs.EvernoteNotebook]""" + tag_data = {} + """:type : dict[str, anknotes.structs.EvernoteTag]""" + DTD = None + __hasValidator = None + token = None + client = None + """:type : EvernoteClient """ + + def hasValidator(self): + global etree + if self.__hasValidator is None: + self.__hasValidator = import_etree() + if self.__hasValidator: + from anknotes.imports import etree + return self.__hasValidator + + def __init__(self): + self.tag_data = {} + self.notebook_data = {} + self.noteStore = None + self.getNoteCount = 0 + # self.hasValidator = eTreeImported + if ankDBIsLocal(): + log("Skipping Evernote client load (DB is Local)", 'client') + return + self.setup_client() + + def setup_client(self): + auth_token = SETTINGS.EVERNOTE.AUTH_TOKEN.fetch() + if not auth_token: + # First run of the Plugin we did not save the access key yet + secrets = {'holycrepe': '36f46ea5dec83d4a', 'scriptkiddi-2682': '965f1873e4df583c'} + client = EvernoteClient( + consumer_key=EVERNOTE.API.CONSUMER_KEY, + consumer_secret=secrets[EVERNOTE.API.CONSUMER_KEY], + sandbox=EVERNOTE.API.IS_SANDBOXED + ) + request_token = client.get_request_token('https://fap-studios.de/anknotes/index.html') + url = client.get_authorize_url(request_token) + showInfo("We will open a Evernote Tab in your browser so you can allow access to your account") + openLink(url) + oauth_verifier = getText(prompt="Please copy the code that showed up, after allowing access, in here")[0] + auth_token = client.get_access_token( + 
request_token.get('oauth_token'), + request_token.get('oauth_token_secret'), + oauth_verifier) + SETTINGS.EVERNOTE.AUTH_TOKEN.save(auth_token) + else: + client = EvernoteClient(token=auth_token, sandbox=EVERNOTE.API.IS_SANDBOXED) + self.token = auth_token + self.client = client + log("Set up Evernote Client", 'client') + + def initialize_note_store(self): + if self.noteStore: + return EvernoteAPIStatus.Success + api_action_str = u'trying to initialize the Evernote Note Store.' + log_api("get_note_store") + if not self.client: + log_error( + "Client does not exist for some reason. Did we not initialize Evernote Class? Current token: " + str( + self.token)) + self.setup_client() + try: + self.noteStore = self.client.get_note_store() + except EDAMSystemException as e: + if not HandleEDAMRateLimitError(e, api_action_str) or EVERNOTE.API.DEBUG_RAISE_ERRORS: + raise + return EvernoteAPIStatus.RateLimitError + except socket.error as v: + if not HandleSocketError(v, api_action_str) or EVERNOTE.API.DEBUG_RAISE_ERRORS: + raise + return EvernoteAPIStatus.SocketError + return EvernoteAPIStatus.Success + + def loadDTD(self): + if self.DTD: + return + timerInterval = stopwatch.Timer() + log("Loading ENML DTD", "lxml", timestamp=False, do_print=True) + self.DTD = etree.DTD(FILES.ANCILLARY.ENML_DTD) + log("DTD Loaded in %s\n" % str(timerInterval), "lxml", timestamp=False, do_print=True) + log(' > Note Validation: ENML DTD Loaded in %s' % str(timerInterval)) + del timerInterval + + def validateNoteBody(self, noteBody, title="Note Body"): + """ + + :param noteBody: + :type noteBody : str | unicode + :param title: + :return: + :rtype : (EvernoteAPIStatus, [str|unicode]) + """ + self.loadDTD() + noteBody = noteBody.replace('"http://xml.evernote.com/pub/enml2.dtd"', + '"%s"' % convert_filename_to_local_link(FILES.ANCILLARY.ENML_DTD)) + parser = etree.XMLParser(dtd_validation=True, attribute_defaults=True) + try: + root = etree.fromstring(noteBody, parser) + except Exception as e: + 
log_str = "XML Loading of %s failed.\n - Error Details: %s" % (title, str(e)) + log(log_str, "lxml", timestamp=False, do_print=True) + log_error(log_str, False) + return EvernoteAPIStatus.UserError, [log_str] + try: + success = self.DTD.validate(root) + except Exception as e: + log_str = "DTD Validation of %s failed.\n - Error Details: %s" % (title, str(e)) + log(log_str, "lxml", timestamp=False, do_print=True) + log_error(log_str, False) + return EvernoteAPIStatus.UserError, [log_str] + log("Validation %-9s for %s" % ("Succeeded" if success else "Failed", title), "lxml", timestamp=False, + do_print=True) + errors = [str(x) for x in self.DTD.error_log.filter_from_errors()] + if not success: + log_str = "DTD Validation Errors for %s: \n%s\n" % (title, str(errors)) + log(log_str, "lxml", timestamp=False) + log_error(log_str, False) + return EvernoteAPIStatus.Success if success else EvernoteAPIStatus.UserError, errors + + def validateNoteContent(self, content, title="Note Contents"): + """ + + :param content: Valid ENML without the <en-note></en-note> tags. 
Will be processed by makeNoteBody + :type content : str|unicode + :return: + :rtype : (EvernoteAPIStatus, [str|unicode]) + """ + return self.validateNoteBody(self.makeNoteBody(content), title) + + def updateNote(self, guid, noteTitle, noteBody, tagNames=None, parentNotebook=None, noteType=None, resources=None): + """ + Update a Note instance with title and body + Send Note object to user's account + :rtype : (EvernoteAPIStatus, evernote.edam.type.ttypes.Note) + :returns Status and Note + """ + if resources is None: + resources = [] + return self.makeNote(noteTitle, noteBody, tagNames=tagNames, parentNotebook=parentNotebook, noteType=noteType, + resources=resources, + guid=guid) + + @staticmethod + def makeNoteBody(content, resources=None, encode=True): + ## Build body of note + if resources is None: + resources = [] + nBody = content + if not nBody.startswith("<?xml"): + nBody = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>" + nBody += "<!DOCTYPE en-note SYSTEM \"http://xml.evernote.com/pub/enml2.dtd\">" + nBody += "<en-note>%s" % content + "</en-note>" + if encode: + nBody = encode(nBody) + return nBody + + @staticmethod + def addNoteToMakeNoteQueue(noteTitle, noteContents, tagNames=list(), parentNotebook=None, resources=None, + noteType=None, + guid=None): + db = ankDB(TABLES.NOTE_VALIDATION_QUEUE) + if not noteType: + noteType = 'Unspecified' + if resources is None: + resources = [] + args = [noteType] + sql = "FROM {t} WHERE noteType = ? AND " + if guid: + sql += 'guid = ?' + args.append(guid) + else: + sql += 'title = ? AND contents = ?' 
+ args += [noteTitle, noteContents] + statuses = db.all('SELECT validation_status ' + sql, args) + if statuses: + if str(statuses[0]['validation_status']) == '1': + return EvernoteAPIStatus.Success + db.execute("DELETE " + sql, args) + db.execute( + "INSERT INTO {t}(guid, title, contents, tagNames, notebookGuid, noteType) VALUES(?, ?, ?, ?, ?, ?)", + guid, noteTitle, noteContents, ','.join(tagNames), parentNotebook, noteType) + return EvernoteAPIStatus.RequestQueued + + def makeNote(self, noteTitle=None, noteContents=None, tagNames=None, parentNotebook=None, resources=None, + noteType=None, guid=None, + validated=None, enNote=None): + """ + Create or Update a Note instance with title and body + Send Note object to user's account + :type noteTitle: str + :param noteContents: Valid ENML without the <en-note></en-note> tags. Will be processed by makeNoteBody + :type enNote : EvernoteNotePrototype + :rtype : (EvernoteAPIStatus, EvernoteNote) + :returns Status and Note + """ + if tagNames is None: + tagNames = [] + if enNote: + guid, noteTitle, noteContents, tagNames, parentNotebook = enNote.Guid, enNote.FullTitle, enNote.Content, enNote.Tags, enNote.NotebookGuid or parentNotebook + if resources is None: + resources = [] + callType = "create" + validation_status = EvernoteAPIStatus.Uninitialized + if validated is None: + if not EVERNOTE.UPLOAD.VALIDATION.ENABLED: + validated = True + else: + validation_status = self.addNoteToMakeNoteQueue(noteTitle, noteContents, tagNames, parentNotebook, + resources, guid) + if not validation_status.IsSuccess and not self.hasValidator: + return validation_status, None + log('%s%s: %s: ' % ('+VALIDATOR ' if self.hasValidator else '' + noteType, str(validation_status), noteTitle), + 'validation') + ourNote = EvernoteNote() + ourNote.title = encode(noteTitle) + if guid: + callType = "update"; ourNote.guid = guid + + ## Build body of note + nBody = self.makeNoteBody(noteContents, resources) + if validated is not True and not 
validation_status.IsSuccess: + status, errors = self.validateNoteBody(nBody, ourNote.title) + assert isinstance(status, EvernoteAPIStatus) + if not status.IsSuccess: + return status, None + ourNote.content = nBody + + notestore_status = self.initialize_note_store() + if not notestore_status.IsSuccess: + return notestore_status, None + + while '' in tagNames: tagNames.remove('') + if tagNames: + if EVERNOTE.API.IS_SANDBOXED and not '#Sandbox' in tagNames: + tagNames.append("#Sandbox") + ourNote.tagNames = tagNames + + ## parentNotebook is optional; if omitted, default notebook is used + if parentNotebook: + if hasattr(parentNotebook, 'guid'): + ourNote.notebookGuid = parentNotebook.guid + elif hasattr(parentNotebook, 'Guid'): + ourNote.notebookGuid = parentNotebook.Guid + elif is_str(parentNotebook): + ourNote.notebookGuid = parentNotebook + + ## Attempt to create note in Evernote account + + api_action_str = u'trying to %s a note' % callType + log_api(callType + "Note", "'%s'" % noteTitle) + try: + note = getattr(self.noteStore, callType + 'Note')(self.token, ourNote) + except EDAMSystemException as e: + if not HandleEDAMRateLimitError(e, api_action_str) or EVERNOTE.API.DEBUG_RAISE_ERRORS: + raise + return EvernoteAPIStatus.RateLimitError, None + except socket.error as v: + if not HandleSocketError(v, api_action_str) or EVERNOTE.API.DEBUG_RAISE_ERRORS: + raise + return EvernoteAPIStatus.SocketError, None + except EDAMUserException as edue: + ## Something was wrong with the note data + ## See EDAMErrorCode enumeration for error code explanation + ## http://dev.evernote.com/documentation/reference/Errors.html#Enum_EDAMErrorCode + print "EDAMUserException:", edue + log_error("-" * 50, crosspost_to_default=False) + log_error("EDAMUserException: " + str(edue), crosspost='api') + log_error(str(ourNote.tagNames), crosspost_to_default=False) + log_error(str(ourNote.content), crosspost_to_default=False) + log_error("-" * 50 + "\r\n", crosspost_to_default=False) + if 
EVERNOTE.API.DEBUG_RAISE_ERRORS: + raise + return EvernoteAPIStatus.UserError, None + except EDAMNotFoundException as ednfe: + print "EDAMNotFoundException:", ednfe + log_error("-" * 50, crosspost_to_default=False) + log_error("EDAMNotFoundException: " + str(ednfe), crosspost='api') + if callType is "update": + log_error('GUID: ' + str(ourNote.guid), crosspost_to_default=False) + if ourNote.notebookGuid: + log_error('Notebook GUID: ' + str(ourNote.notebookGuid), crosspost_to_default=False) + log_error("-" * 50 + "\r\n", crosspost_to_default=False) + if EVERNOTE.API.DEBUG_RAISE_ERRORS: + raise + return EvernoteAPIStatus.NotFoundError, None + except Exception as e: + print "Unknown Exception:", e + log_error("-" * 50, crosspost_to_default=False) + log_error("Unknown Exception: " + str(e)) + log_error(str(ourNote.tagNames), crosspost_to_default=False) + log_error(str(ourNote.content), crosspost_to_default=False) + log_error("-" * 50 + "\r\n", crosspost_to_default=False) + # return EvernoteAPIStatus.UnhandledError, None + raise + # noinspection PyUnboundLocalVariable + note.content = nBody + return EvernoteAPIStatus.Success, note + + def create_evernote_notes(self, evernote_guids=None, use_local_db_only=False): + """ + Create EvernoteNote objects from Evernote GUIDs using EvernoteNoteFetcher.getNote(). 
+ Will prematurely return if fetcher.getNote fails + + :rtype : EvernoteNoteFetcherResults + :param evernote_guids: + :param use_local_db_only: Do not initiate API calls + :return: EvernoteNoteFetcherResults + """ + if not hasattr(self, 'evernote_guids') or evernote_guids: + self.evernote_guids = evernote_guids + if not use_local_db_only: + self.check_ancillary_data_up_to_date() + action_str_base = 'Create' + action_str = 'Creation Of' + info = stopwatch.ActionInfo(action_str, 'Evernote Notes', report_if_empty=False) + tmr = stopwatch.Timer(evernote_guids, info=info, + label='Add\\Evernote-%sNotes' % (action_str_base)) + fetcher = EvernoteNoteFetcher(self, use_local_db_only=use_local_db_only) + if not evernote_guids: + fetcher.results.Status = EvernoteAPIStatus.EmptyRequest; return fetcher.results + if in_anki(): + fetcher.evernoteQueryTags = SETTINGS.EVERNOTE.QUERY.TAGS.fetch().replace(',', ' ').split() + fetcher.keepEvernoteTags = SETTINGS.ANKI.TAGS.KEEP_TAGS.fetch() + fetcher.deleteQueryTags = SETTINGS.ANKI.TAGS.DELETE_EVERNOTE_QUERY_TAGS.fetch() + fetcher.tagsToDelete = SETTINGS.ANKI.TAGS.TO_DELETE.fetch().replace(',', ' ').split() + for evernote_guid in self.evernote_guids: + if not fetcher.getNote(evernote_guid): + return fetcher.results + tmr.reportSuccess() + tmr.step(fetcher.result.Note.FullTitle) + tmr.Report() + return fetcher.results + + def check_ancillary_data_up_to_date(self): + new_tags = 0 if self.check_tags_up_to_date() else self.update_tags_database( + "Tags were not up to date when checking ancillary data") + new_nbs = 0 if self.check_notebooks_up_to_date() else self.update_notebooks_database() + self.report_ancillary_data_results(new_tags, new_nbs, 'Forced ') + + def update_ancillary_data(self): + new_tags = self.update_tags_database("Manual call to update ancillary data") + new_nbs = self.update_notebooks_database() + self.report_ancillary_data_results(new_tags, new_nbs, 'Manual ', report_blank=True) + + @staticmethod + def 
report_ancillary_data_results(new_tags, new_nbs, title_prefix='', report_blank=False): + str_ = '' + if new_tags is 0 and new_nbs is 0: + if not report_blank: + return + str_ = 'No new tags or notebooks found' + elif new_tags is None and new_nbs is None: + str_ = 'Error downloading ancillary data' + elif new_tags is None: + str_ = 'Error downloading tags list, and ' + elif new_nbs is None: + str_ = 'Error downloading notebooks list, and ' + + if new_tags > 0 and new_nbs > 0: + str_ = '%d new tag%s and %d new notebook%s found' % ( + new_tags, '' if new_tags is 1 else 's', new_nbs, '' if new_nbs is 1 else 's') + elif new_nbs > 0: + str_ += '%d new notebook%s found' % (new_nbs, '' if new_nbs is 1 else 's') + elif new_tags > 0: + str_ += '%d new tag%s found' % (new_tags, '' if new_tags is 1 else 's') + show_tooltip("%sUpdate of ancillary data complete: " % title_prefix + str_, do_log=True) + + def set_notebook_data(self): + if not hasattr(self, 'notebook_data') or not self.notebook_data or len(self.notebook_data.keys()) == 0: + self.notebook_data = {x['guid']: EvernoteNotebook(x) for x in + ankDB().execute("SELECT guid, name FROM {nb} WHERE 1")} + + def check_notebook_metadata(self, notes): + """ + :param notes: + :type : list[EvernoteNotePrototype] + :return: + """ + self.set_notebook_data() + for note in notes: + assert (isinstance(note, EvernoteNotePrototype)) + if note.NotebookGuid in self.notebook_data: + continue + new_nbs = self.update_notebooks_database() + if note.NotebookGuid in self.notebook_data: + log( + "Missing notebook GUID %s for note %s when checking notebook metadata. Notebook was found after updating Anknotes' notebook database." + '' if new_nbs < 1 else ' In total, %d new notebooks were found.' 
% new_nbs) + continue + log_error("FATAL ERROR: Notebook GUID %s for Note %s: %s does not exist on Evernote servers" % ( + note.NotebookGuid, note.Guid, note.Title)) + raise EDAMNotFoundException() + return True + + def check_notebooks_up_to_date(self): + for evernote_guid in self.evernote_guids: + note_metadata = self.metadata[evernote_guid] + notebookGuid = note_metadata.notebookGuid + if not notebookGuid: + log_error(" > Notebook check: Unable to find notebook guid for '%s'. Returned '%s'. Metadata: %s" % ( + evernote_guid, str(notebookGuid), str(note_metadata)), crosspost_to_default=False) + elif notebookGuid not in self.notebook_data: + notebook = EvernoteNotebook(fetch_guid=notebookGuid) + if not notebook.success: + log(" > Notebook check: Missing notebook guid '%s'. Will update with an API call." % notebookGuid) + return False + self.notebook_data[notebookGuid] = notebook + return True + + def update_notebooks_database(self): + notestore_status = self.initialize_note_store() + if not notestore_status.IsSuccess: + return None # notestore_status + api_action_str = u'trying to update Evernote notebooks.' 
+ log_api("listNotebooks") + try: + notebooks = self.noteStore.listNotebooks(self.token) + """: type : list[evernote.edam.type.ttypes.Notebook] """ + except EDAMSystemException as e: + if not HandleEDAMRateLimitError(e, api_action_str) or EVERNOTE.API.DEBUG_RAISE_ERRORS: + raise + return None + except socket.error as v: + if not HandleSocketError(v, api_action_str) or EVERNOTE.API.DEBUG_RAISE_ERRORS: + raise + return None + data = [] + self.notebook_data = {} + for notebook in notebooks: + self.notebook_data[notebook.guid] = {"stack": notebook.stack, "name": notebook.name} + data.append( + [notebook.guid, notebook.name, notebook.updateSequenceNum, notebook.serviceUpdated, notebook.stack]) + db = ankDB(TABLES.EVERNOTE.NOTEBOOKS) + old_count = db.count() + db.drop(db.table) + db.recreate() + # log_dump(data, 'update_notebooks_database table data', crosspost_to_default=False) + db.executemany( + "INSERT INTO `{t}`(`guid`,`name`,`updateSequenceNum`,`serviceUpdated`, `stack`) VALUES (?, ?, ?, ?, ?)", + data) + db.commit() + return len(self.notebook_data) - old_count + + def update_tags_database(self, reason_str=''): + if hasattr(self, 'LastTagDBUpdate') and datetime.now() - self.LastTagDBUpdate < timedelta(minutes=15): + return None + self.LastTagDBUpdate = datetime.now() + notestore_status = self.initialize_note_store() + if not notestore_status.IsSuccess: + return None # notestore_status + api_action_str = u'trying to update Evernote tags.' 
+ log_api("listTags" + (': ' + reason_str) if reason_str else '') + try: + tags = self.noteStore.listTags(self.token) + """: type : list[evernote.edam.type.ttypes.Tag] """ + except EDAMSystemException as e: + if not HandleEDAMRateLimitError(e, api_action_str) or EVERNOTE.API.DEBUG_RAISE_ERRORS: + raise + return None + except socket.error as v: + if not HandleSocketError(v, api_action_str) or EVERNOTE.API.DEBUG_RAISE_ERRORS: + raise + return None + data = [] + self.tag_data = {} + enTag = None + for tag in tags: + enTag = EvernoteTag(tag) + self.tag_data[enTag.Guid] = enTag + data.append(enTag.items()) + if not enTag: + return None + db = ankDB(TABLES.EVERNOTE.TAGS) + old_count = db.count() + db.drop(db.table) + db.recreate() + db.executemany(enTag.sqlUpdateQuery(), data) + db.commit() + return len(self.tag_data) - old_count + + def set_tag_data(self): + if not hasattr(self, 'tag_data') or not self.tag_data or len(self.tag_data.keys()) == 0: + self.tag_data = {x['guid']: EvernoteTag(x) for x in ankDB().execute("SELECT guid, name FROM {tt} WHERE 1")} + + def get_missing_tags(self, current_tags, from_guids=True): + if isinstance(current_tags, list): + current_tags = set(current_tags) + self.set_tag_data() + all_tags = set(self.tag_data.keys() if from_guids else [v.Name for k, v in self.tag_data.items()]) + missing_tags = current_tags - all_tags + if missing_tags: + log_error("Missing Tag %s(s) were found:\nMissing: %s\n\nCurrent: %s\n\nAll Tags: %s\n\nTag Data: %s" % ( + 'Guids' if from_guids else 'Names', ', '.join(sorted(missing_tags)), ', '.join(sorted(current_tags)), + ', '.join(sorted(all_tags)), str(self.tag_data))) + return missing_tags + + def get_matching_tag_data(self, tag_guids=None, tag_names=None): + tagGuids = [] + tagNames = [] + assert tag_guids or tag_names + from_guids = True if (tag_guids is not None) else False + tags_original = tag_guids if from_guids else tag_names + if self.get_missing_tags(tags_original, from_guids): + 
self.update_tags_database("Missing Tag %s(s) Were found when attempting to get matching tag data" % ( + 'Guids' if from_guids else 'Names')) + missing_tags = self.get_missing_tags(tags_original, from_guids) + if missing_tags: + identifier = 'Guid' if from_guids else 'Name' + keys = ', '.join(sorted(missing_tags)) + log_error("FATAL ERROR: Tag %s(s) %s were not found on the Evernote Servers" % (identifier, keys)) + raise EDAMNotFoundException(identifier.lower(), keys) + if from_guids: + tags_dict = {x: self.tag_data[x] for x in tags_original} + else: + tags_dict = {[k for k, v in self.tag_data.items() if v.Name is tag_name][0]: tag_name for tag_name in + tags_original} + tagNamesToImport = get_tag_names_to_import(tags_dict) + """:type : dict[string, EvernoteTag]""" + if tagNamesToImport: + is_struct = None + for k, v in tagNamesToImport.items(): + if is_struct is None: + is_struct = isinstance(v, EvernoteTag) + tagGuids.append(k) + tagNames.append(v.Name if is_struct else v) + tagNames = sorted(tagNames, key=lambda s: s.lower()) + return tagGuids, tagNames + + def check_tags_up_to_date(self): + for evernote_guid in self.evernote_guids: + if evernote_guid not in self.metadata: + log_error('Could not find note metadata for Note ''%s''' % evernote_guid) + return False + note_metadata = self.metadata[evernote_guid] + if not note_metadata.tagGuids: + continue + for tag_guid in note_metadata.tagGuids: + if tag_guid in self.tag_data: + continue + tag = EvernoteTag(fetch_guid=tag_guid) + if not tag.success: + return False + self.tag_data[tag_guid] = tag + return True diff --git a/anknotes/args.py b/anknotes/args.py new file mode 100644 index 0000000..0d033a0 --- /dev/null +++ b/anknotes/args.py @@ -0,0 +1,204 @@ +# -*- coding: utf-8 -*- +import re +from datetime import datetime + +### Anknotes Imports +from anknotes.constants import * +from anknotes.base import item_to_list, caller_name, is_str, is_seq_type, is_dict_type +from anknotes.dicts import DictCaseInsensitive +from 
anknotes.dicts_base import DictKey + +class Args(object): + + def __init__(self, func_kwargs=None, func_args=None, func_args_list=None, set_list=None, set_dict=None, + require_all_args=None, limit_max_args=None, override_kwargs=None, use_set_list_as_arg_list=False): + from logging import write_file_contents, pf + self.__require_all_args = require_all_args or False + self.__limit_max_args = limit_max_args if limit_max_args is not None else True + self.__override_kwargs = override_kwargs or False + self.__set_args_and_kwargs(func_args, func_kwargs) + self.__func_args_list = use_set_list_as_arg_list and [set_list[i*2] for i in range(0, len(set_list)/2)] or func_args_list + if self.__func_args_list: + self.process_args() + self.__init_set(set_list, set_dict) + + def __init_set(self, set_list=None, set_dict=None): + if not set_list or set_dict: + return + self.__conv_set_list(set_list) + self.set_kwargs(set_list, set_dict) + + def __conv_set_list(self, set_list=None): + if set_list is None: + return + func_args_count = len(self.__func_args_list) + for i in range(0, len(set_list)/2): + if not set_list[i*2] and func_args_count > i: + set_list[i*2] = self.__func_args_list[i] + + def __set_args_and_kwargs(self, *a, **kw) + self.__func_args, self.__func_kwargs = self.__get_args_and_kwargs(*a, **kw) + + def __get_args_and_kwargs(self, func_args=None, func_kwargs=None, name=None, allow_cls_override=True): + if not func_args and not func_kwargs: + return self.args or [], self.kwargs or DictCaseInsensitive() + func_args = func_args or allow_cls_override and self.args or [] + func_kwargs = func_kwargs or allow_cls_override and self.kwargs or DictCaseInsensitive() + if is_seq_type(func_kwargs) and is_dict_type(func_args): + func_args, func_kwargs = func_kwargs, func_args + func_args = self.__args_to_list(func_args) + if isinstance(func_kwargs, dict): + func_kwargs = DictCaseInsensitive(func_kwargs, key=name, parent_key='kwargs') + if not isinstance(func_args, list): + func_args = 
[] + if not isinstance(func_kwargs, DictCaseInsensitive): + func_kwargs = DictCaseInsensitive(key=name) + return func_args, func_kwargs + + def __args_to_list(self, func_args): + if not is_str(func_args): + return list(func_args) + lst = [] + for arg in item_to_list(func_args, chrs=','): + lst += [arg] + [None] + return lst + + @property + def kwargs(self): + return self.__func_kwargs + + @property + def args(self): + return self.__func_args + + @property + def keys(self): + return self.kwargs.keys() + + def key_transform(self, key, keys=None): + if keys is None: + keys = self.keys + key = key.strip() + key_lower = key.lower() + for k in keys: + if k.lower() == key_lower: + return k + return key + + def get_kwarg(self, key, **kwargs): + kwargs['update_kwargs'] = False + return self.process_kwarg(key, **kwargs) + + def process_kwarg(self, key, default=None, func_kwargs=None, replace_none_type=True, delete_from_kwargs=None, return_value_only=True, update_cls_args=True): + delete_from_kwargs = delete_from_kwargs is not False + cls_kwargs = func_kwargs is None + func_kwargs = self.kwargs if cls_kwargs else DictCaseInsensitive(func_kwargs) + key = self.key_transform(key, func_kwargs.keys()) + if key not in func_kwargs: + return (func_kwargs, default) if delete_from_kwargs and not return_value_only else default + val = func_kwargs[key] + if val is None and replace_none_type: + val = default + if delete_from_kwargs: + del func_kwargs[key] + if cls_kwargs and update_cls_args: + del self.__func_kwargs[key] + if not delete_from_kwargs or return_value_only: + return val + return func_kwargs, val + + def process_kwargs(self, get_args=None, set_dict=None, func_kwargs=None, delete_from_kwargs=True, update_cls_args=True, **kwargs): + method_name='process_kwargs' + kwargs['return_value_only'] = False + cls_kwargs = func_kwargs is None + func_kwargs = self.kwargs if cls_kwargs else DictCaseInsensitive(func_kwargs) + keys = func_kwargs.keys() + for key, value in set_dict.items() if 
set_dict else []: + key = self.key_transform(key, keys) + if key not in func_kwargs: + func_kwargs[key] = value + if not get_args: + if cls_kwargs and update_cls_args: + self.__func_kwargs = func_kwargs + return func_kwargs + gets = [] + for arg in get_args: + # for arg in args: + if len(arg) is 1 and isinstance(arg[0], list): + arg = arg[0] + result = self.process_kwarg(arg[0], arg[1], func_kwargs=func_kwargs, delete_from_kwargs=delete_from_kwargs, **kwargs) + if delete_from_kwargs: + func_kwargs = result[0] + result = result[1] + gets.append(result) + if cls_kwargs and update_cls_args: + self.__func_kwargs = func_kwargs + if delete_from_kwargs: + return [func_kwargs] + gets + return gets + + def get_kwarg_values(self, *args, **kwargs): + kwargs['return_value_only'] = True + if not 'delete_from_kwargs' in kwargs: + kwargs['delete_from_kwargs'] = False + return self.get_kwargs(*args, **kwargs) + + def get_kwargs(self, *args_list, **kwargs): + method_name='get_kwargs' + lst = [] + for args in args_list: + if isinstance(args, dict): + args = item_to_list(args) + args_dict = args + if isinstance(args, list): + lst += [args[i * 2:i * 2 + 2] for i in range(0, len(args) / 2)] + else: + lst += [[arg, None] for arg in item_to_list(args)] + return self.process_kwargs(get_args=lst, **kwargs) + + def process_args(self, arg_list=None, func_args=None, func_kwargs=None, update_cls_args=True): + method_name='process_args' + arg_list = item_to_list(arg_list) if arg_list else self.__func_args_list + cls_args = func_args is None + cls_kwargs = func_kwargs is None + func_args, func_kwargs = self.__get_args_and_kwargs(func_args, func_kwargs) + arg_error = '' + if not func_args: + return func_args, func_kwargs + for i in range(0, len(arg_list)): + if len(func_args) is 0: + break + arg = arg_list[i] + if arg in func_kwargs and not self.__override_kwargs: + formats = (caller_name(return_string=True), arg) + raise TypeError("Anknotes.Args: %s() got multiple arguments for keyword argument 
'%s'" % formats) + func_kwargs[arg] = func_args[0] + del func_args[0] + else: + if self.__require_all_args: + arg_error = 'least' + if func_args and self.__limit_max_args: + arg_error = 'most' + if arg_error: + formats = (caller_name(return_string=True), arg_error, len(arg_list), '' if arg_list is 1 else 's', len(func_args)) + raise TypeError('Anknotes.Args: %s() takes at %s %d argument%s (%d given)' % formats) + if cls_args and update_cls_args: + self.__func_args = func_args + if cls_kwargs and update_cls_args: + self.__func_kwargs = func_kwargs + return func_args, func_kwargs + + def set_kwargs(self, set_list=None, set_dict=None, func_kwargs=None, name=None, delete_from_kwargs=None, *args, **kwargs): + method_name='set_kwargs' + new_args = [] + lst, dct = self.__get_args_and_kwargs(set_list, set_dict, allow_cls_override=False) + if isinstance(lst, list): + dct.update({lst[i * 2]: lst[i * 2 + 1] for i in range(0, len(lst) / 2)}) + lst = [] + for arg in args: + new_args += item_to_list(arg, False) + dct.update({key: None for key in item_to_list(lst, chrs=',') + new_args}) + dct.update(kwargs) + dct_items = dct.items() + processed_kwargs = self.process_kwargs(func_kwargs=func_kwargs, set_dict=dct, name=name, delete_from_kwargs=delete_from_kwargs).items() + return self.process_kwargs(func_kwargs=func_kwargs, set_dict=dct, name=name, delete_from_kwargs=delete_from_kwargs) \ No newline at end of file diff --git a/anknotes/base.py b/anknotes/base.py new file mode 100644 index 0000000..a574b46 --- /dev/null +++ b/anknotes/base.py @@ -0,0 +1,362 @@ +# -*- coding: utf-8 -*- +import re +from fnmatch import fnmatch +import inspect +from collections import defaultdict, Iterable +from bs4 import UnicodeDammit +import string +from datetime import datetime + + +### Anknotes Imports +from anknotes.imports import in_anki + +### Anki Imports +if in_anki(): + from aqt import mw + +class SafeDict(defaultdict): + def __init__(self, *a, **kw): + for i, arg in enumerate(a): + if arg is 
None: + raise TypeError("SafeDict arg %d is NoneType" % (i + 1)) + dct = dict(*a, **kw) + super(self.__class__, self).__init__(self.__missing__, dct) + + def __getitem__(self, key): + item = super(self.__class__, self).__getitem__(key) + if isinstance(item, dict): + item = SafeDict(item) + return item + + def __missing__(self, key): + return '{' + key + '}' + +def decode(str_, is_html=False, errors='strict'): + if isinstance(str_, unicode): + return str_ + if isinstance(str_, str): + return UnicodeDammit(str_, ['utf-8'], is_html=is_html).unicode_markup + return unicode(str_, 'utf-8', errors) + +def decode_html(str_): + return decode(str_, True) + +def encode(str_): + if isinstance(str_, unicode): + return str_.encode('utf-8') + return str_ + +def is_str(str_): + return str_ and is_str_type(str_) + +def is_str_type(str_): + return isinstance(str_, (str, unicode)) + +def is_seq_type(*a): + for item in a: + if not isinstance(item, Iterable) or not hasattr(item, '__iter__'): + return False + return True + +def is_dict_type(*a): + for item in a: + if not isinstance(item, dict) or hasattr(item, '__dict__'): + return False + return True + +def get_unique_strings(*a): + lst=[] + items=[] + if a and isinstance(a[0], dict): + lst = a[0].copy() + a = a[0].items() + else: + a = enumerate(a) + for key, str_ in sorted(a): + if isinstance(str_, list): + str_, attr = str_ + str_ = getattr(str_, attr, None) + if not str_ or str_ in lst or str_ in items: + if isinstance(lst, list): + lst.append('') + else: + lst[key] = '' + continue + items.append(str_) + str_ = str(str_) + if isinstance(lst, list): + lst.append(str_) + else: + lst[key] = str_ + return lst + +def call(func, *a, **kw): + if not callable(func): + return func + spec=inspect.getargspec(func) + if not spec.varargs: + a = a[:len(spec.args)] + if not spec.keywords: + kw = {key:value for key, value in kw.items() if key in spec.args} + return func(*a, **kw) + +def fmt(str_, recursion=None, *a, **kw): + """ + :type str_: str 
| unicode + :type recursion : int | dict | list + :rtype: str | unicode + """ + if not isinstance(recursion, int): + if recursion is not None: + a = [recursion] + list(a) + recursion = 1 + dct = SafeDict(*a, **kw) + str_ = string.Formatter().vformat(str_, [], dct) + if recursion <= 0: + return str_ + return fmt(str_, recursion-1, *a, **kw) + +def pad_digits(*a, **kw): + conv = [] + for str_ in a: + if isinstance(str_, int): + str_ = str(str_) + if not is_str_type(str_): + conv.append('') + else: + conv.append(str_.rjust(3) if str_.isdigit() else str_) + if len(conv) is 1: + return conv[0] + return conv + +def str_safe(str_, prefix=''): + repr_ = str_.__repr__() + try: + str_ = str(prefix + repr_) + except Exception: + str_ = str(prefix + encode(repr_, errors='replace')) + return str_ + +def str_split_case(str_, ignore_underscore=False): + words=[] + word='' + for chr in str_: + last_chr = word[-1:] + if chr.isupper() and (last_chr.islower() or (ignore_underscore and last_chr is '_')): + words.append(word) + word = '' + word += chr + return words + [word] + +def str_capitalize(str_, phrase_delimiter='.', word_delimiter='_'): + phrases = str_.split(phrase_delimiter) + return ''.join(''.join([word.capitalize() for word in phrase.split(word_delimiter)]) for phrase in phrases) + +def in_delimited_str(key, str_, chr='|', case_insensitive=True): + if case_insensitive: + key = key.lower() + str_ = str_.lower() + return key in str_.strip(chr).split(chr) + +def print_safe(str_, prefix=''): + print str_safe(str_, prefix) + +def item_to_list(item, list_from_unknown=True, chrs='', split_chr='|'): + if is_seq_type(item): + return list(item) + if isinstance(item, dict): + return [y for x in item.items() for y in x] + if is_str(item): + for c in chrs: + item = item.replace(c, split_chr or ' ') + return item.split(split_chr) + if item is None: + return [] + if list_from_unknown: + return [item] + return item + +def item_to_set(item, **kwargs): + if isinstance(item, set): + return 
item + item = item_to_list(item, **kwargs) + if not isinstance(item, list): + return item + return set(item) + +def matches_list(item, lst): + item = item.lower() + for index, value in enumerate(item_to_list(lst)): + value = value.lower() + if fnmatch(item, value) or fnmatch(item + 's', value): + return index + 1 + return 0 + +def get_default_value(cls, default=None): + if default is not None: + return default + if cls is str or cls is unicode: + return '' + elif cls is int: + return 0 + elif cls is bool: + return False + return None + +def key_transform(mapping, key, all=False): + key_lower = key.lower() + match = [k for k in (mapping if isinstance(mapping, Iterable) and not all else dir(mapping)) if k.lower() == key_lower] + return match and match[0] or key + +def delete_keys(mapping, keys_to_delete): + if not isinstance(keys_to_delete, list): + keys_to_delete = item_to_list(keys_to_delete, chrs=' *,') + for key in keys_to_delete: + key = key_transform(mapping, key) + if key in mapping: + del mapping[key] + +def ank_prop(self, keys, fget=None, fset=None, fdel=None, doc=None): + for key in list(keys): + all_args=locals() + args = {} + try: + property_ = getattr(self.__class__, key) + except AttributeError: + property_ = property() + + for v in ['fget', 'fset', 'fdel']: + args[v] = all_args[v] + if not args[v]: + args[v] = getattr(property_, v) + if is_str(args[v]): + args[v] = getattr(self.__class__, args[v]) + if isinstance(args[v], property): + args[v] = getattr(args[v], v) + if not doc: + doc = property_.__doc__ + if not doc: + doc = fget.__doc__ + args['doc'] = doc + property_ = property(**args) + setattr(self.__class__, key, property_) + +def get_friendly_interval_string(lastImport): + if not lastImport: + return "" + from anknotes.constants import ANKNOTES + td = (datetime.now() - datetime.strptime(lastImport, ANKNOTES.DATE_FORMAT)) + days = td.days + hours, remainder = divmod(td.total_seconds(), 3600) + minutes, seconds = divmod(remainder, 60) + if days > 
1: + lastImportStr = "%d days" % td.days + else: + hours = round(hours) + hours_str = '' if hours == 0 else ('1:%02d hr' % minutes) if hours == 1 else '%d Hours' % hours + if days == 1: + lastImportStr = "One Day%s" % ('' if hours == 0 else ', ' + hours_str) + elif hours > 0: + lastImportStr = hours_str + else: + lastImportStr = "%d:%02d min" % (minutes, seconds) + return lastImportStr + + +def clean_evernote_css(str_): + remove_style_attrs = '-webkit-text-size-adjust: auto|-webkit-text-stroke-width: 0px|background-color: rgb(255, 255, 255)|color: rgb(0, 0, 0)|font-family: Tahoma|font-size: medium;|font-style: normal|font-variant: normal|font-weight: normal|letter-spacing: normal|orphans: 2|text-align: -webkit-auto|text-indent: 0px|text-transform: none|white-space: normal|widows: 2|word-spacing: 0px|word-wrap: break-word|-webkit-nbsp-mode: space|-webkit-line-break: after-white-space'.replace( + '(', '\\(').replace(')', '\\)') + # 'margin: 0px; padding: 0px 0px 0px 40px; ' + return re.sub(r' ?(%s);? ?' 
% remove_style_attrs, '', str_).replace(' style=""', '') + + +def caller_names(return_string=True, simplify=True): + return [c.Base if return_string else c for c in [__caller_name(i, simplify) for i in range(0, 20)] if + c and c.Base] + + +class CallerInfo: + Class = [] + Module = [] + Outer = [] + Name = "" + simplify = True + __keywords_exclude = ['pydevd', 'logging', 'base', '__caller_name', 'stopwatch', 'process_args'] + __keywords_strip = ['__maxin__', 'anknotes', '<module>'] + __outer = [] + filtered = True + + @property + def __trace(self): + return self.Module + self.Outer + self.Class + [self.Name] + + @property + def Trace(self): + t = self.__strip(self.__trace) + return t if not self.filtered or not [e for e in self.__keywords_exclude if e in t] else [] + + @property + def Base(self): + return '.'.join(self.__strip(self.Module + self.Class + [self.Name])) if self.Trace else '' + + @property + def Full(self): + return '.'.join(self.Trace) + + def __strip(self, lst): + return [t for t in lst if t and t not in self.__keywords_strip] + + def __init__(self, parentframe=None): + """ + + :rtype : CallerInfo + """ + if not parentframe: + return + self.Class = parentframe.f_locals['self'].__class__.__name__.split( + '.') if 'self' in parentframe.f_locals else [] + module = inspect.getmodule(parentframe) + self.Module = module.__name__.split('.') if module else [] + self.Name = parentframe.f_code.co_name if parentframe.f_code.co_name is not '<module>' else '' + self.__outer = [[f[1], f[3]] for f in inspect.getouterframes(parentframe) if f] + self.__outer.reverse() + self.Outer = [f[1] for f in self.__outer if + f and f[1] and not [exclude for exclude in self.__keywords_exclude + [self.Name] if + exclude in f[0] or exclude in f[1]]] + del parentframe + + +def create_log_filename(str_): + if str_ is None: + return "" + str_ = str_.replace('.', '\\') + str_ = re.sub(r"(^|\\)([^\\]+)\\\2(\b.|\\.|$)", r"\1\2\\", str_) + str_ = re.sub(r"^\\*(.+?)\\*$", r"\1", str_) + 
return str_ + + +# @clockit +def caller_name(skip=None, simplify=True, return_string=False, return_filename=False): + if skip is None: + names = [__caller_name(i, simplify) for i in range(0, 20)] + else: + names = [__caller_name(skip, simplify=simplify)] + for c in [c for c in names if c and c.Base]: + return create_log_filename(c.Base) if return_filename else c.Base if return_string else c + return "" if return_filename or return_string else None + + +def __caller_name(skip=0, simplify=True): + """ + :rtype : CallerInfo + """ + stack = inspect.stack() + start = 0 + skip + if len(stack) < start + 1: + return None + parentframe = stack[start][0] + c_info = CallerInfo(parentframe) + del parentframe + return c_info + diff --git a/anknotes/constants.py b/anknotes/constants.py new file mode 100644 index 0000000..969bfb5 --- /dev/null +++ b/anknotes/constants.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- + +from anknotes.constants_standard import * +from anknotes.constants_settings import * diff --git a/anknotes/constants_default.py b/anknotes/constants_default.py new file mode 100644 index 0000000..289cc6f --- /dev/null +++ b/anknotes/constants_default.py @@ -0,0 +1,210 @@ +import os +from datetime import timedelta, datetime + +PATH = os.path.dirname(os.path.abspath(__file__)) + +class FOLDERS(): + ADDONS = os.path.dirname(PATH) + EXTRA = os.path.join(PATH, 'extra') + ANCILLARY = os.path.join(EXTRA, 'ancillary') + GRAPHICS = os.path.join(EXTRA, 'graphics') + LOGS = os.path.join(EXTRA, 'logs') + DEVELOPER = os.path.join(EXTRA, 'dev') + USER = os.path.join(EXTRA, 'user') + + +class FILES(): + class LOGS(): + class FDN(): + ANKI_ORPHANS = 'Find Deleted Notes\\' + UNIMPORTED_EVERNOTE_NOTES = ANKI_ORPHANS + 'UnimportedEvernoteNotes' + ANKI_TITLE_MISMATCHES = ANKI_ORPHANS + 'AnkiTitleMismatches' + ANKNOTES_TITLE_MISMATCHES = ANKI_ORPHANS + 'AnknotesTitleMismatches' + ANKNOTES_ORPHANS = ANKI_ORPHANS + 'AnknotesOrphans' + ANKI_ORPHANS += 'AnkiOrphans' + + BASE_NAME = '' + 
DEFAULT_NAME = 'anknotes' + MAIN = DEFAULT_NAME + ACTIVE = DEFAULT_NAME + USE_CALLER_NAME = False + ENABLED = ['*'] + DISABLED = ['finder*', 'args*', 'counter*', 'Dicts*'] + SEE_ALSO_DISABLED = [4,6] + + class ANCILLARY(): + TEMPLATE = os.path.join(FOLDERS.ANCILLARY, 'FrontTemplate.htm') + CSS = u'_AviAnkiCSS.css' + CSS_QMESSAGEBOX = os.path.join(FOLDERS.ANCILLARY, 'QMessageBox.css') + ENML_DTD = os.path.join(FOLDERS.ANCILLARY, 'enml2.dtd') + + class SCRIPTS(): + VALIDATION = os.path.join(FOLDERS.ADDONS, 'anknotes_start_note_validation.py') + FIND_DELETED_NOTES = os.path.join(FOLDERS.ADDONS, 'anknotes_start_find_deleted_notes.py') + + class GRAPHICS(): + class ICON(): + EVERNOTE_WEB = os.path.join(FOLDERS.GRAPHICS, u'evernote_web.ico') + EVERNOTE_ARTCORE = os.path.join(FOLDERS.GRAPHICS, u'evernote_artcore.ico') + TOMATO = os.path.join(FOLDERS.GRAPHICS, u'Tomato-icon.ico') + + class IMAGE(): + EVERNOTE_WEB = None + EVERNOTE_ARTCORE = None + + IMAGE.EVERNOTE_WEB = ICON.EVERNOTE_WEB.replace('.ico', '.png') + IMAGE.EVERNOTE_ARTCORE = ICON.EVERNOTE_ARTCORE.replace('.ico', '.png') + + class USER(): + TABLE_OF_CONTENTS_ENEX = os.path.join(FOLDERS.USER, "Table of Contents.enex") + LAST_PROFILE_LOCATION = os.path.join(FOLDERS.USER, 'anki.profile') + + +class ANKNOTES(): + DATE_FORMAT = '%Y-%m-%d %H:%M:%S' + CACHE_SEARCHES = False + UPDATE_DB_ON_START = False + + class HOOKS(): + DB = True + SEARCH = True + + class LXML(): + ENABLE_IN_ANKI = False + + class DEVELOPER_MODE: + ENABLED = (os.path.isfile(os.path.join(FOLDERS.DEVELOPER, 'anknotes.developer'))) + AUTOMATED = ENABLED and (os.path.isfile(os.path.join(FOLDERS.DEVELOPER, 'anknotes.developer.automate'))) + AUTO_RELOAD_MODULES = True + + class HIERARCHY(): + ROOT_TITLES_BASE_QUERY = "" + + class FORMATTING(): + BANNER_MINIMUM = 80 + COUNTER_BANNER_MINIMUM = 40 + LINE_PADDING_HEADER = 31 + LINE_LENGTH_TOTAL = 191 + LINE_LENGTH = LINE_LENGTH_TOTAL - 2 + LIST_PAD = 25 + PROGRESS_SUMMARY_PAD = 31 + PPRINT_WIDTH = 80 + 
TIMESTAMP_PAD = '\t' * 6 + TIMESTAMP_PAD_LENGTH = len(TIMESTAMP_PAD.replace('\t', ' ' * 4)) + TEXT_LENGTH = LINE_LENGTH_TOTAL - TIMESTAMP_PAD_LENGTH + + +class MODELS(): + class TYPES(): + CLOZE = 1 + + class OPTIONS(): + IMPORT_STYLES = True + + DEFAULT = 'evernote_note' + REVERSIBLE = 'evernote_note_reversible' + REVERSE_ONLY = 'evernote_note_reverse_only' + CLOZE = 'evernote_note_cloze' + + +class TEMPLATES(): + DEFAULT = 'EvernoteReview' + REVERSED = 'EvernoteReviewReversed' + CLOZE = 'EvernoteReviewCloze' + + +class FIELDS(): + TITLE = 'Title' + CONTENT = 'Content' + SEE_ALSO = 'See_Also' + TOC = 'TOC' + OUTLINE = 'Outline' + EXTRA = 'Extra' + EVERNOTE_GUID = 'Evernote GUID' + UPDATE_SEQUENCE_NUM = 'updateSequenceNum' + EVERNOTE_GUID_PREFIX = 'evernote_guid=' + LIST = [TITLE, CONTENT, SEE_ALSO, EXTRA, TOC, OUTLINE, + UPDATE_SEQUENCE_NUM] + + class ORD(): + EVERNOTE_GUID = 0 + + ORD.CONTENT = LIST.index(CONTENT) + 1 + ORD.SEE_ALSO = LIST.index(SEE_ALSO) + 1 + + +class DECKS(): + DEFAULT = "Evernote" + TOC_SUFFIX = "::See Also::TOC" + OUTLINE_SUFFIX = "::See Also::Outline" + + +class ANKI(): + PROFILE_NAME = '' + NOTE_LIGHT_PROCESSING_INCLUDE_CSS_FORMATTING = False + + +class TAGS(): + TOC = '#TOC' + TOC_AUTO = '#TOC.Auto' + OUTLINE = '#Outline' + OUTLINE_TESTABLE = '#Outline.Testable' + REVERSIBLE = '#Reversible' + REVERSE_ONLY = '#Reversible_Only' + + +class EVERNOTE(): + class IMPORT(): + class PAGING(): + # Note that Evernote's API documentation says not to run API calls to findNoteMetadata with any less than a 15 minute interval + # Auto Paging is probably only useful in the first 24 hours, when API usage is unlimited, or when executing a search that is likely to have most of the notes up-to-date locally + # To keep from overloading Evernote's servers, and flagging our API key, I recommend pausing 5-15 minutes in between searches, the higher the better. 
+ class RESTART(): + INTERVAL = None + DELAY_MINIMUM_API_CALLS = 10 + INTERVAL_OVERRIDE = 60 * 5 + ENABLED = False + + INTERVAL = 60 * 15 + INTERVAL_SANDBOX = 60 * 5 + RESTART.INTERVAL = INTERVAL * 2 + + INTERVAL = PAGING.INTERVAL * 4 / 3 + METADATA_RESULTS_LIMIT = 10000 + QUERY_LIMIT = 250 # Max returned by API is 250 + API_CALLS_LIMIT = 300 + + class UPLOAD(): + ENABLED = True # Set False if debugging note creation + MAX = -1 # Set to -1 for unlimited + RESTART_INTERVAL = 30 # In seconds + + class VALIDATION(): + ENABLED = True + AUTOMATED = False + + class API(): + class RateLimitErrorHandling: + IgnoreError, ToolTipError, AlertError = range(3) + + CONSUMER_KEY = "holycrepe" + IS_SANDBOXED = False + EDAM_RATE_LIMIT_ERROR_HANDLING = RateLimitErrorHandling.ToolTipError + DEBUG_RAISE_ERRORS = False + + +class TABLES(): + SEE_ALSO = "anknotes_see_also" + NOTE_VALIDATION_QUEUE = "anknotes_note_validation_queue" + TOC_AUTO = u'anknotes_toc_auto' + + class EVERNOTE(): + NOTEBOOKS = "anknotes_evernote_notebooks" + TAGS = "anknotes_evernote_tags" + NOTES = u'anknotes_evernote_notes' + NOTES_HISTORY = u'anknotes_evernote_notes_history' + +class HEADINGS(): + TOP = "Summary|Definitions|Classifications|Types|Presentations|Organ Involvement|Age of Onset|Si/Sx|Sx|Signs|Triggers|MCC's|MCCs|Inheritance|Incidence|Prognosis|Derivations|Origins|Embryological Origins|Mechanisms|MOA|Pathophysiology|Indications|Examples|Causes|Causative Organisms|Risk Factors|Complications|Side Effects|Drug S/Es|Associated Conditions|A/w|Diagnosis|Dx|Physical Exam|Labs|Hemodynamic Parameters|Lab Findings|Imaging|Screening Tests|Confirmatory Tests|Xray|CT|MRI" + BOTTOM = "Management|Work Up|Tx" + NOT_REVERSIBLE = BOTTOM + "|Dx|Diagnosis" diff --git a/anknotes/constants_settings.py b/anknotes/constants_settings.py new file mode 100644 index 0000000..9837132 --- /dev/null +++ b/anknotes/constants_settings.py @@ -0,0 +1,31 @@ +#Python Imports +from datetime import datetime, timedelta + + +#Anki Main 
Imports +from anknotes.constants_standard import EVERNOTE, DECKS + +#Anki Class Imports +from anknotes.structs_base import UpdateExistingNotes +from anknotes.dicts import DictSettings + +SETTINGS = DictSettings(key='anknotes') +with SETTINGS as s: + s.FORM.LABEL_MINIMUM_WIDTH = 100 + with s.EVERNOTE as e: + e.AUTH_TOKEN.setDefault(lambda dct: dct.key.name + '_' + EVERNOTE.API.CONSUMER_KEY.upper() + ("_SANDBOX" if EVERNOTE.API.IS_SANDBOXED else "")) + e.AUTO_PAGING = True + with e.QUERY as q: + q.TAGS = '#Anki_Import' + q.NOTEBOOK = 'My Anki Notebook' + with q.LAST_UPDATED.VALUE.ABSOLUTE as a: + a.DATE = "{:%Y %m %d}".format(datetime.now() - timedelta(days=7)) + with e.ACCOUNT as a: + a.UID = '0' + a.SHARD = 'x999' + with s.ANKI as a, a.DECKS as d, a.TAGS as t: + a.UPDATE_EXISTING_NOTES = UpdateExistingNotes.UpdateNotesInPlace + d.BASE = DECKS.DEFAULT + d.EVERNOTE_NOTEBOOK_INTEGRATION = True + t.KEEP_TAGS = True + t.DELETE_EVERNOTE_QUERY_TAGS = False diff --git a/anknotes/constants_standard.py b/anknotes/constants_standard.py new file mode 100644 index 0000000..87004fc --- /dev/null +++ b/anknotes/constants_standard.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +import os + +PATH = os.path.dirname(os.path.abspath(__file__)) +if os.path.isfile(os.path.join(PATH, 'constants_user.py')): + from anknotes.constants_user import * +else: + from anknotes.constants_default import * \ No newline at end of file diff --git a/anknotes/constants_user_example.py b/anknotes/constants_user_example.py new file mode 100644 index 0000000..a0bacd5 --- /dev/null +++ b/anknotes/constants_user_example.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +# INSTRUCTIONS: +# USE THIS FILE TO OVERRIDE THE MAIN SETTINGS FILE +# RENAME FILE TO constants_user.py +# DON'T FORGET TO REGENERATE ANY VARIABLES THAT DERIVE FROM THE ONES YOU ARE CHANGING +from anknotes.constants_default import * + +# BEGIN OVERRIDES HERE: +EVERNOTE.API.IS_SANDBOXED = True +EVERNOTE.UPLOAD.VALIDATION.AUTOMATED = False 
+EVERNOTE.UPLOAD.ENABLED = False \ No newline at end of file diff --git a/anknotes/counters.py b/anknotes/counters.py new file mode 100644 index 0000000..650c7f5 --- /dev/null +++ b/anknotes/counters.py @@ -0,0 +1,94 @@ +import os +import sys +from anknotes.constants_standard import ANKNOTES +from anknotes.base import item_to_list, item_to_set, is_str +from anknotes.dicts import DictNumeric, DictCaseInsensitive +from anknotes.dicts_base import DictKey + +class Counter(DictNumeric): + _override_default_ = False + _default_ = '_count_' + _count_ = 0 + _my_aggregates_ = 'max|max_allowed' + _my_attrs_ = '_count_' + + def __init__(self, *a, **kw): + a, cls, mro = list(a), self.__class__, self._get_arg_(a, int, 'mro', kw) + super(cls.mro()[mro], self).__init__(mro+1, *a, **kw) + self.prop(['count', 'cnt'], 'default') + cls.default_override = cls.sum + + def setCount(self, value): + self._count_ = value + + def getCount(self): + return self._count_ + + def getDefault(self, allow_override=True): + if allow_override and self._override_default_: + return self.default_override + return self.sum + + +class EvernoteCounter(Counter): + _mro_offset_ = 1 + _default_override_ = True + + def __init__(self, *a, **kw): + a, cls, mro = list(a), self.__class__, self._get_arg_(a, int, 'mro', kw) + super(cls.mro()[mro], self).__init__(mro+1, *a, **kw) + + @property + def success(self): + return self.created + self.updated + + @property + def queued(self): + return self.created.queued + self.updated.queued + + @property + def completed(self): + return self.created.completed + self.updated.completed + + @property + def delayed(self): + return self.skipped + self.queued + + @property + def handled(self): + return self.total - self.unhandled - self.error + + @property + def total(self): + return self.count + + def aggregateSummary(self, includeHeader=True): + aggs = '!max|!+max_allowed|total|+handled|++success|+++completed|+++queued|++delayed' + counts = self._get_summary_(header_only=True) 
if includeHeader else [] + parents, last_level = [], 1 + for key_code in aggs.split('|'): + override_default = key_code[0] is not '!' + counts += [DictCaseInsensitive(marker='*' if override_default else ' ', child_values={}, children=['<aggregate>'])] + if not override_default: + key_code = key_code[1:] + key = key_code.lstrip('+') + counts.level, counts.value = len(key_code) - len(key) + 1, getattr(self, key) + counts.class_name = type(counts.value) + if counts.class_name is not int: + counts.value = counts.value.getDefault() + parent_lbl = '.'.join(parents) + counts.key, counts.label = DictKey(key, parent_lbl), DictKey(key, parent_lbl, 'label') + if counts.level < last_level: + del parents[-1] + elif counts.level > last_level: + parents.append(key) + last_level = counts.level + return self._summarize_lines_(counts, includeHeader) + + def fullSummary(self, title='Evernote Counter'): + return '\n'.join( + [self.make_banner(title + ": Summary"), + self.__repr__(), + ' ', + self.make_banner(title + ": Aggregates"), + self.aggregateSummary(False)]) diff --git a/anknotes/create_subnotes.py b/anknotes/create_subnotes.py new file mode 100644 index 0000000..d80751d --- /dev/null +++ b/anknotes/create_subnotes.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +# Python Imports +from bs4 import BeautifulSoup, NavigableString, Tag +from copy import copy + +try: + from pysqlite2 import dbapi2 as sqlite +except ImportError: + from sqlite3 import dbapi2 as sqlite +# Anknotes Shared Imports +from anknotes.shared import * +from anknotes.imports import import_lxml +from anknotes.constants import * +from anknotes.base import matches_list, fmt, decode_html +from anknotes.dicts import DictCaseInsensitive +from anknotes.logging import show_tooltip + +# Anknotes Main Imports +import anknotes.Controller +# from anknotes.Controller import Controller + +# Anki Imports +from aqt.qt import SIGNAL, QMenu, QAction +from aqt import mw +from aqt.utils import getText + +def create_subnotes(guids): + 
def create_subnote(guid): + def process_lists(note, lst, levels=None, names=None): + def add_log_entry(title, content, filename=None, prefix_content=True, title_pad=16, **kw): + names_padded = u''.join(map(lambda x: (x+':').ljust(33) + ' ', names[1:-1])) + names[-1] + fmts = dict(levels_pad=u'\t' * level, levels=u'.'.join(map(str, levels)), + num_levels=len(levels), names=u': '.join(names[1:]).ljust(20), + names_padded=names_padded) + fmts['levels_str'] = (fmts['levels'] + ':').ljust(6) + if prefix_content: + fmts['content'] = content + content = u'{levels_pad}{levels_str} {content}' + if isinstance(lst_items, Tag) and lst_items.name in list_tag_names: + fmts['list_name'] = list_tag_names[lst_items.name] + content = fmt(content, 0, fmts) + if title: + title = (fmt(title, 0, fmts) + u': ').ljust(title_pad) + l.go(title + content, filename, **kw) + + def process_tag(): + def get_log_fn(): + return u'.'.join(map(str, levels)) + u' - ' + u'-'.join(names[1:]) + def log_tag(): + if not lst_items.contents: + add_log_entry('NO TOP TEXT', decode_html(lst_items.contents), crosspost='no_top_text') + if lst_items.name in list_tag_names: + add_log_entry('{list_name}', '{levels_pad}[{num_levels}] {levels}', prefix_content=False) + elif lst_items.name != 'li': + add_log_entry('OTHER TAG', decode(lst_items.contents[0]) if lst_items.contents else u'N/A') + elif not sublist.is_subnote: + add_log_entry('LIST ITEM', strip_tags(u''.join(sublist.list_items), True).strip()) + else: + subnote_fn = u'..\\subnotes\\process_tag*\\' + get_log_fn() + subnote_shared = '*\\..\\..\\subnotes\\process_tag-all' + l.banner(u': '.join(names), subnote_fn) + if not create_subnote.logged_subnote: + l.blank(subnote_shared) + l.banner(title, subnote_shared, clear=False, append_newline=False) + l.banner(title, '..\\subnotes\\process_tag') + create_subnote.logged_subnote = True + add_log_entry('SUBNOTE', sublist.heading) + add_log_entry('', sublist.heading, '..\\subnotes\\process_tag', crosspost=subnote_fn) 
+ add_log_entry('{levels}', '{names_padded}', subnote_shared, prefix_content=False, title_pad=13) + l.go(decode_html(sublist.subnote), subnote_fn) + + def add_note(sublist, new_levels, new_names): + subnote_html = decode_html(sublist.subnote) + log_fn = u'..\\subnotes\\add_note*\\' + get_log_fn() + add_log_entry('SUBNOTE', '{levels_str} {names}: \n%s\n' % subnote_html, '..\\subnotes\\add_note', crosspost=log_fn, prefix_content=False) + myNotes.append([new_levels, new_names, subnote_html]) + + def process_list_item(contents): + def check_subnote(li, sublist): + def check_heading_flags(): + if not isinstance(sublist.heading_flags, list): + sublist.heading_flags = [] + for str_ in "`':": + if sublist.heading.endswith(str_): + sublist.heading_flags.append(str_) + sublist.heading = sublist.heading[:-1*len(str_)] + check_heading_flags() + return + + #Begin check_subnote() + if not (isinstance(li, Tag) and (li.name in list_tag_names) and li.contents and li.contents[0]): + sublist.list_items.append(decode_html(li)) + return sublist + sublist.heading = strip_tags(decode_html(''.join(sublist.list_items)), True).strip() + sublist.base_title = u': '.join(names).replace(title + ': ', '') + sublist.is_reversible = not matches_list(sublist.heading, HEADINGS.NOT_REVERSIBLE) + check_heading_flags() + if "`" in sublist.heading_flags: + sublist.is_reversible = not sublist.is_reversible + sublist.use_descriptor = "'" in sublist.heading_flags or "`" in sublist.heading_flags + sublist.is_subnote = ':' in sublist.heading_flags or matches_list(sublist.heading, HEADINGS.TOP + '|' + HEADINGS.BOTTOM) + if not sublist.is_subnote: + return sublist + sublist.subnote = li + return sublist + + # Begin process_list_item() + sublist = DictCaseInsensitive(is_subnote=False, list_items=[]) + for li in contents: + sublist = check_subnote(li, sublist) + if sublist.is_subnote: + break + return sublist + + # Begin process_tag() + new_levels = levels[:] + new_names = names[:] + if lst_items.name in 
list_tag_names: + new_levels.append(0) + new_names.append('CHILD ' + lst_items.name.upper()) + elif lst_items.name == 'li': + levels[-1] = new_levels[-1] = levels[-1] + 1 + sublist = process_list_item(lst_items.contents) + if sublist.is_subnote: + names[-1] = new_names[-1] = sublist.heading + add_note(sublist, new_levels, new_names) + else: + names[-1] = new_names[-1] = sublist.heading if sublist.heading else 'Xx' + strip_tags(unicode(''.join(sublist.list_items)), True).strip() + log_tag() + if lst_items.name in list_tag_names or lst_items.name == 'li': + process_lists(note, lst_items.contents, new_levels, new_names) + + # Begin process_lists() + if levels is None or names is None: + levels = [] + names = [title] + level = len(levels) + for lst_items in lst: + if isinstance(lst_items, Tag): + process_tag() + elif isinstance(lst_items, NavigableString): + add_log_entry('NAV STRING', decode_html(lst_items).strip(), crosspost=['nav_strings', '*\\..\\..\\nav_strings']) + else: + add_log_entry('LST ITEMS', lst_items.__class__.__name__, crosspost=['unexpected-type', '*\\..\\..\\unexpected-type']) + + #Begin create_subnote() + content = db.scalar("guid = ?", guid, columns='content') + title = note_title = get_evernote_title_from_guid(guid) + l.path_suffix = '\\' + title + soup = BeautifulSoup(content) + en_note = soup.find('en-note') + note = DictCaseInsensitive(descriptor=None) + first_div = en_note.find('div') + if first_div: + descriptor_text = first_div.text + if descriptor_text.startswith('`'): + note.descriptor = descriptor_text[1:] + lists = en_note.find(['ol', 'ul']) + lists_all = soup.findAll(['ol', 'ul']) + l.banner(title, crosspost='strings') + create_subnote.logged_subnote = False + process_lists(note, [lists]) + l.go(decode_html(lists), 'lists', clear=True) + l.go(soup.prettify(), 'full', clear=True) + + #Begin create_subnotes() + list_tag_names = {'ul': 'UNORDERED LIST', 'ol': 'ORDERED LIST'} + db = ankDB() + myNotes = [] + if import_lxml() is False: + 
return False + from anknotes.imports import lxml + l = Logger('Create Subnotes\\', default_filename='bs4', timestamp=False, rm_path=True) + l.base_path += 'notes\\' + for guid in guids: + create_subnote(guid) + diff --git a/anknotes/db.py b/anknotes/db.py new file mode 100644 index 0000000..dbeab4b --- /dev/null +++ b/anknotes/db.py @@ -0,0 +1,585 @@ +### Python Imports +import time +from datetime import datetime +from copy import copy +import os +try: + from pysqlite2 import dbapi2 as sqlite +except ImportError: + from sqlite3 import dbapi2 as sqlite + +### For PyCharm code completion +# from anknotes import _sqlite3 + +### Anki Shared Imports +from anknotes.constants import * +from anknotes.base import is_str, item_to_list, fmt, is_dict_type, is_seq_type, encode +from anknotes.args import Args +from anknotes.logging import log_sql, log, log_error, log_blank, pf +from anknotes.dicts import DictCaseInsensitive +from anknotes.imports import in_anki + +### Anki Imports +if in_anki(): + from aqt import mw + from anki.utils import ids2str, splitFields + +ankNotesDBInstance = None +dbLocal = False + +lastHierarchyUpdate = datetime.now() + + +def anki_profile_path_root(): + return os.path.abspath(os.path.join(os.path.dirname(PATH), '..' 
+ os.path.sep)) + + +def last_anki_profile_name(): + root = anki_profile_path_root() + name = ANKI.PROFILE_NAME + if name and os.path.isdir(os.path.join(root, name)): + return name + if os.path.isfile(FILES.USER.LAST_PROFILE_LOCATION): + name = file(FILES.USER.LAST_PROFILE_LOCATION, 'r').read().strip() + if name and os.path.isdir(os.path.join(root, name)): + return name + dirs = [x for x in os.listdir(root) if os.path.isdir(os.path.join(root, x)) and x is not 'addons'] + if not dirs: + return "" + return dirs[0] + + +def ankDBSetLocal(): + global dbLocal + dbLocal = True + + +def ankDBIsLocal(): + global dbLocal + return dbLocal + + +def ankDB(table=None,reset=False): + global ankNotesDBInstance, dbLocal + if not ankNotesDBInstance or reset: + path = None + if dbLocal: + path = os.path.abspath(os.path.join(anki_profile_path_root(), last_anki_profile_name(), 'collection.anki2')) + ankNotesDBInstance = ank_DB(path) + if table: + db_copy = ank_DB(init_db=False, table=table) + db_copy._db = ankNotesDBInstance._db + db_copy._path = ankNotesDBInstance._path + return db_copy + return ankNotesDBInstance + + +def escape_text_sql(title): + return title.replace("'", "''") + + +def delete_anki_notes_and_cards_by_guid(evernote_guids): + data = [[FIELDS.EVERNOTE_GUID_PREFIX + x] for x in evernote_guids] + db = ankDB() + db.executemany("DELETE FROM cards WHERE nid in (SELECT id FROM notes WHERE flds LIKE ? || '%')", data) + db.executemany("DELETE FROM notes WHERE flds LIKE ? 
|| '%'", data) + + +def get_evernote_title_from_guid(guid): + return ankDB().scalar("SELECT title FROM {n} WHERE guid = '%s'" % guid) + + +def get_evernote_titles_from_nids(nids): + return get_evernote_titles(nids, 'nid') + + +def get_evernote_titles(guids, column='guid'): + return ankDB().list("SELECT title FROM {n} WHERE %s IN (%s) ORDER BY title ASC" % + (column, ', '.join(["'%s'" % x for x in guids]))) + + +def get_anki_deck_id_from_note_id(nid): + return long(ankDB().scalar("SELECT did FROM cards WHERE nid = ? LIMIT 1", nid)) + + +def get_anki_fields_from_evernote_guids(guids): + lst = isinstance(guids, list) + if not lst: + guids = [guids] + db = ankDB() + results = [db.scalar("SELECT flds FROM notes WHERE flds LIKE '{guid_prefix}' || ? || '%'", guid) for guid in guids] + if not lst: + return results[0] if results else None + return results + +def get_anki_card_ids_from_evernote_guids(guids, sql=None): + pred = "n.flds LIKE '%s' || ? || '%%'" % FIELDS.EVERNOTE_GUID_PREFIX + if sql is None: + sql = "SELECT c.id FROM cards c, notes n WHERE c.nid = n.id AND ({pred})" + return execute_sqlite_query(sql, guids, pred=pred) + + +def get_anki_note_id_from_evernote_guid(guid): + return ankDB().scalar("SELECT n.id FROM notes n WHERE n.flds LIKE '%s' || ? 
|| '%%'" % FIELDS.EVERNOTE_GUID_PREFIX, + guid) + + +def get_anki_note_ids_from_evernote_guids(guids): + return get_anki_card_ids_from_evernote_guids(guids, "SELECT n.id FROM notes n WHERE {pred}") + + +def get_paired_anki_note_ids_from_evernote_guids(guids): + return get_anki_card_ids_from_evernote_guids([[x, x] for x in guids], + "SELECT n.id, n.flds FROM notes n WHERE {pred}") + + +def get_anknotes_root_notes_nids(): + return get_cached_data(get_anknotes_root_notes_nids, lambda: get_anknotes_root_notes_guids('nid')) + + +def get_cached_data(func, data_generator, subkey=''): + if not ANKNOTES.CACHE_SEARCHES: + return data_generator() + if subkey: + subkey += '_' + if not hasattr(func, subkey + 'data') or getattr(func, subkey + 'update') < lastHierarchyUpdate: + setattr(func, subkey + 'data', data_generator()) + setattr(func, subkey + 'update', datetime.now()) + return getattr(func, subkey + 'data') + + +def get_anknotes_root_notes_guids(column='guid', tag=None): + sql = "SELECT %s FROM {n} WHERE UPPER(title) IN {pred}" % column + data_key = column + if tag: + sql += " AND tagNames LIKE '%%,%s,%%'" % tag; data_key += '-' + tag + + def cmd(): + titles = get_anknotes_potential_root_titles(upper_case=False, encode=False) + return execute_sqlite_in_query(sql, titles, pred='UPPER(?)') + return get_cached_data(get_anknotes_root_notes_guids, cmd, data_key) + + +def get_anknotes_root_notes_titles(): + return get_cached_data(get_anknotes_root_notes_titles, + lambda: get_evernote_titles(get_anknotes_root_notes_guids())) + + +def get_anknotes_potential_root_titles(upper_case=False, encode=False, **kwargs): + global generateTOCTitle + from anknotes.EvernoteNoteTitle import generateTOCTitle + def mapper(x): return generateTOCTitle(x) + if upper_case: + mapper = lambda x, f=mapper: f(x).upper() + if encode: + mapper = lambda x, f=mapper: encode(f(x)) + data = get_cached_data(get_anknotes_potential_root_titles, lambda: ankDB().list( + "SELECT DISTINCT SUBSTR(title, 0, 
INSTR(title, ':')) FROM {n} WHERE title LIKE '%:%'")) + return map(mapper, data) + + +# def __get_anknotes_root_notes_titles_query(): +# return '(%s)' % ' OR '.join(["title LIKE '%s'" % (escape_text_sql(x) + ':%') for x in get_anknotes_root_notes_titles()]) + +def __get_anknotes_root_notes_pred(base=None, column='guid', **kwargs): + if base is None: + base = "SELECT %(column)s FROM %(table)s WHERE {pred} " + base = base % {'column': column, 'table': TABLES.EVERNOTE.NOTES} + pred = "title LIKE ? || ':%'" + return execute_sqlite_query(base, get_anknotes_root_notes_titles(), pred=pred) + + +def execute_sqlite_in_query(sql, data, in_query=True, **kwargs): + return execute_sqlite_query(sql, data, in_query=True, **kwargs) + + +def execute_sqlite_query(sql, data, in_query=False, **kwargs): + queries = generate_sqlite_in_predicate(data, **kwargs) if in_query else generate_sqlite_predicate(data, **kwargs) + results = [] + db = ankDB() + for query, data in queries: + sql = fmt(sql, pred=query) + result = db.list(sql, *data) + log_sql('FROM execute_sqlite_query ' + sql, + ['Data [%d]: ' % len(data), data,result[:3]]) + results += result + return results + + +def generate_sqlite_predicate(data, pred='?', pred_delim=' OR ', query_base='(%s)', max_round=990): + if not query_base: + query_base = '%s' + length = len(data) + rounds = float(length) / max_round + rounds = int(rounds) + 1 if int(rounds) < rounds else 0 + queries = [] + for i in range(0, rounds): + start = max_round * i + end = min(length, start + max_round) + # log_sql('FROM generate_sqlite_predicate ' + query_base, ['gen sql #%d of %d: %d-%d' % (i, rounds, start, end) , pred_delim, 'Data [%d]: ' % len(data), data[:3]]) + queries.append([query_base % (pred + (pred_delim + pred) * (end - start - 1)), data[start:end]]) + return queries + + +def generate_sqlite_in_predicate(data, pred='?', pred_delim=', ', query_base='(%s)'): + return generate_sqlite_predicate(data, pred=pred, query_base=query_base, 
pred_delim=pred_delim) + + +def get_sql_anki_cids_from_evernote_guids(guids): + return "c.nid IN " + ids2str(get_anki_note_ids_from_evernote_guids(guids)) + + +def get_anknotes_child_notes_nids(**kwargs): + if 'column' in kwargs: + del kwargs['column'] + return get_anknotes_child_notes(column='nid', **kwargs) + + +def get_anknotes_child_notes(column='guid', **kwargs): + return get_cached_data(get_anknotes_child_notes, lambda: __get_anknotes_root_notes_pred(column=column, **kwargs), + column) + + +def get_anknotes_orphan_notes_nids(**kwargs): + if 'column' in kwargs: + del kwargs['column'] + return get_anknotes_orphan_notes(column='nid', **kwargs) + + +def get_anknotes_orphan_notes(column='guid', **kwargs): + return get_cached_data(get_anknotes_orphan_notes, lambda: __get_anknotes_root_notes_pred( + "SELECT %(column)s FROM %(table)s WHERE title LIKE '%%:%%' AND NOT {pred}", column=column, **kwargs), column) + + +def get_evernote_guid_from_anki_fields(fields): + if isinstance(fields, dict): + if not FIELDS.EVERNOTE_GUID in fields: + return None + return fields[FIELDS.EVERNOTE_GUID].replace(FIELDS.EVERNOTE_GUID_PREFIX, '') + if is_str(fields): + fields = splitFields(fields) + return fields[FIELDS.ORD.EVERNOTE_GUID].replace(FIELDS.EVERNOTE_GUID_PREFIX, '') + + +def get_all_local_db_guids(filter=None): + if filter is None: + filter = "1" + return ankDB().list("SELECT guid FROM {n} WHERE %s ORDER BY title ASC" % filter) + + +def get_evernote_model_ids(sql=False): + if not hasattr(get_evernote_model_ids, 'model_ids'): + from anknotes.Anki import Anki + anki = Anki() + anki.add_evernote_models(allowForceRebuild=False) + get_evernote_model_ids.model_ids = anki.evernoteModels + del anki + del Anki + if sql: + return 'n.mid IN (%s)' % ', '.join(get_evernote_model_ids.model_ids.values()) + return get_evernote_model_ids.model_ids + + +def update_anknotes_nids(): + db = ankDB() + count = db.count('nid <= 0') + if not count: + return count + paired_data = db.all("SELECT n.id, 
n.flds FROM notes n WHERE " + get_evernote_model_ids(True)) + paired_data = [[nid, get_evernote_guid_from_anki_fields(flds)] for nid, flds in paired_data] + db.executemany('UPDATE {n} SET nid = ? WHERE guid = ?', paired_data) + db.commit() + return count + + +class ank_DB(object): + echo = False + + def __init__(self, path=None, text=None, timeout=0, init_db=True, table=None): + self._table_ = table + self.ankdb_lastquery = None + self.echo = False + if not init_db: + return + encpath = path + if isinstance(encpath, unicode): + encpath = encode(path) + if path: + log('Creating local ankDB instance from path: ' + path, 'sql\\ankDB') + self._db = sqlite.connect(encpath, timeout=timeout) + self._db.row_factory = sqlite.Row + if text: + self._db.text_factory = text + self._path = path + else: + log('Creating local ankDB instance from Anki DB instance at: ' + mw.col.db._path, 'sql\\ankDB') + self._db = mw.col.db._db + """ + :type : sqlite.Connection + """ + self._db.row_factory = sqlite.Row + self._path = mw.col.db._path + # self._db = self._get_db_(**kw) + + @property + def table(self): + return self._table_ if self._table_ else TABLES.EVERNOTE.NOTES + + def setrowfactory(self): + self._db.row_factory = sqlite.Row + + def drop(self, table): + self.execute("DROP TABLE IF EXISTS " + table) + + @staticmethod + def _is_stmt_(sql, stmts=None): + s = sql.strip().lower() + stmts = ["insert", "update", "delete", "drop", "create", "replace"] + item_to_list(stmts) + for stmt in stmts: + if s.startswith(stmt): + return True + return False + + def update(self, sql=None, *a, **ka): + if 'where' in ka: + ka['columns'] = sql + sql = None + if sql is None: + sql = '{columns} WHERE {where}' + sql = "UPDATE {t} SET " + sql + self.execute(sql, a, ka) + + def delete(self, sql, *a, **ka): + sql = "DELETE FROM {t} WHERE " + sql + self.execute(sql, a, ka) + + def insert(self, auto, replace_into=False, **kw): + keys = auto.keys() + values = [":%s" % key for key in keys] + keys = ["'%s'" % key 
for key in keys] + sql = 'INSERT%s INTO {t}(%s) VALUES(%s)' % (' OR REPLACE' if replace_into else '', + ', '.join(keys), ', '.join(values)) + self.execute(sql, auto=auto, kw=kw) + + def insert_or_replace(self, *a, **kw): + kw['replace_into'] = True + self.insert(*a, **kw) + + def execute(self, sql, a=None, kw=None, auto=None, **kwargs): + if is_dict_type(a): + kw, a = a, kw + if not is_seq_type(a): + a = item_to_list(a) + if is_dict_type(sql): + auto = sql + sql = ' AND '.join(["`{0}` = :{0}".format(key) for key in auto.keys()]) + if kw is None: + kw = {} + kwargs.update(kw) + sql = self._create_query_(sql, **kwargs) + if auto: + kw = auto + log_sql(sql, a, kw, self=self) + self.ankdb_lastquery = sql + if self._is_stmt_(sql): + self.mod = True + t = time.time() + try: + if a: + # execute("...where id = ?", 5) + res = self._db.execute(sql, a) + elif kw: + # execute("...where id = :id", id=5) + res = self._db.execute(sql, kw) + else: + res = self._db.execute(sql) + except (sqlite.OperationalError, sqlite.ProgrammingError, sqlite.Error, Exception) as e: + log_sql(sql, a, kw, self=self, filter_disabled=False) + import traceback + log_error('Error with ankDB().execute(): %s\n Query: %s\n Trace: %s' % + (str(e), sql, traceback.format_exc())) + raise + if self.echo: + # print a, ka + print sql, "%0.3fms" % ((time.time() - t) * 1000) + if self.echo == "2": + print a, kw + return res + + def _fmt_query_(self, sql, **kw): + formats = dict(table=self.table, where='1', columns='*') + override = dict(n=TABLES.EVERNOTE.NOTES, s=TABLES.SEE_ALSO, a=TABLES.TOC_AUTO, + nv=TABLES.NOTE_VALIDATION_QUEUE, nb=TABLES.EVERNOTE.NOTEBOOKS, tt=TABLES.EVERNOTE.TAGS, + t_toc='%%,%s,%%' % TAGS.TOC, t_tauto='%%,%s,%%' % TAGS.TOC_AUTO, + t_out='%%,%s,%%' % TAGS.OUTLINE, anki_guid='{guid_prefix}{guid}%', + guid_prefix=FIELDS.EVERNOTE_GUID_PREFIX) + keys = formats.keys() + formats.update(kw) + formats['t'] = formats['table'] + formats.update(override) + sql = fmt(sql, formats) + for key in keys: + 
if key in kw: + del kw[key] + return sql + + def _create_query_(self, sql, **kw): + if not self._is_stmt_(sql, 'select'): + sql = 'SELECT {columns} FROM {t} WHERE ' + sql + sql = self._fmt_query_(sql, **kw) + if 'order' in kw and 'order by' not in sql.lower(): + sql += ' ORDER BY ' + kw['order'] + del kw['order'] + return sql + + def executemany(self, sql, data, **kw): + sql = self._create_query_(sql, **kw) + log_sql(sql, data, self=self) + self.mod = True + t = time.time() + try: + self._db.executemany(sql, data) + except (sqlite.OperationalError, sqlite.ProgrammingError, sqlite.Error, Exception) as e: + log_sql(sql, data, self=self, filter_disabled=False) + import traceback + log_error('Error with ankDB().executemany(): %s\n Query: %s\n Trace: %s' % (str(e), sql, traceback.format_exc())) + raise + if self.echo: + print sql, "%0.3fms" % ((time.time() - t) * 1000) + if self.echo == "2": + print data + + def commit(self): + t = time.time() + self._db.commit() + if self.echo: + print "commit %0.3fms" % ((time.time() - t) * 1000) + + def executescript(self, sql): + self.mod = True + if self.echo: + print sql + self._db.executescript(sql) + + def rollback(self): + self._db.rollback() + + def exists(self, *a, **kw): + count = self.count(*a, **kw) + return count is not None and count > 0 + + def count(self, *a, **kw): + return self.scalar('SELECT COUNT(*) FROM {t} WHERE {where}', *a, **kw) + + def scalar(self, sql='1', *a, **kw): + log_text = 'Call to DB.ankdb_scalar():' + if not isinstance(self, ank_DB): + log_text += '\n - Self: ' + pf(self) + if a: + log_text += '\n - Args: ' + pf(a) + if kw: + log_text += '\n - KWArgs: ' + pf(kw) + last_query='<None>' + if hasattr(self, 'ankdb_lastquery'): + last_query = self.ankdb_lastquery + if is_str(last_query): + last_query = last_query[:50] + else: + last_query = pf(last_query) + log_text += '\n - Last Query: ' + last_query + log(log_text + '\n', 'sql\\ankdb_scalar') + try: + res = self.execute(sql, a, kw) + except TypeError as 
e: + log(" > ERROR with ankdb_scalar while executing query: %s\n > LAST QUERY: %s" % (str(e), last_query), 'sql\\ankdb_scalar', crosspost='sql\\ankdb_scalar-error') + raise + if not isinstance(res, sqlite.Cursor): + log(' > Cursor: %s' % pf(res), 'sql\\ankdb_scalar') + try: + res = res.fetchone() + except TypeError as e: + log(" > ERROR with ankdb_scalar while fetching result: %s\n > LAST QUERY: %s" % (str(e), last_query), 'sql\\ankdb_scalar', crosspost='sql\\ankdb_scalar-error') + raise + log_blank('sql\\ankdb_scalar') + if res: + return res[0] + return None + + def all(self, sql='1', *a, **kw): + return self.execute(sql, a, kw).fetchall() + + def first(self, sql='1', *a, **kw): + c = self.execute(sql, a, kw) + res = c.fetchone() + c.close() + return res + + def list(self, sql='1', *a, **kw): + return [x[0] for x in self.execute(sql, a, kw)] + + def close(self): + self._db.close() + + def set_progress_handler(self, *args): + self._db.set_progress_handler(*args) + + def __enter__(self): + self._db.execute("begin") + return self + + def __exit__(self, exc_type, *args): + self._db.close() + + def totalChanges(self): + return self._db.total_changes + + def interrupt(self): + self._db.interrupt() + + def recreate(self, force=True, t='{t}'): + self.Init(t, force) + + def InitTags(self, force=False): + if_exists = " IF NOT EXISTS" if not force else "" + log("Rebuilding %stags table" % ('*' if force else ''), 'sql\\ankDB') + self.execute( + """CREATE TABLE%s `%s` ( `guid` TEXT NOT NULL UNIQUE, `name` TEXT NOT NULL, `parentGuid` TEXT, `updateSequenceNum` INTEGER NOT NULL, PRIMARY KEY(guid) );""" % ( + if_exists, TABLES.EVERNOTE.TAGS)) + + def InitNotebooks(self, force=False): + if_exists = " IF NOT EXISTS" if not force else "" + self.execute( + """CREATE TABLE%s `%s` ( `guid` TEXT NOT NULL UNIQUE, `name` TEXT NOT NULL, `updateSequenceNum` INTEGER NOT NULL, `serviceUpdated` INTEGER NOT NULL, `stack` TEXT, PRIMARY KEY(guid) );""" % ( + if_exists, TABLES.EVERNOTE.NOTEBOOKS)) 
+ + def InitSeeAlso(self, forceRebuild=False): + if_exists = " IF NOT EXISTS" + if forceRebuild: + self.drop(TABLES.SEE_ALSO) + self.commit() + if_exists = "" + self.execute( + """CREATE TABLE%s `%s` ( `id` INTEGER, `source_evernote_guid` TEXT NOT NULL, `number` INTEGER NOT NULL DEFAULT 100, `uid` INTEGER NOT NULL DEFAULT -1, `shard` TEXT NOT NULL DEFAULT -1, `target_evernote_guid` TEXT NOT NULL, `html` TEXT NOT NULL, `title` TEXT NOT NULL, `from_toc` INTEGER DEFAULT 0, `is_toc` INTEGER DEFAULT 0, `is_outline` INTEGER DEFAULT 0, PRIMARY KEY(id), unique(source_evernote_guid, target_evernote_guid) );""" % ( + if_exists, TABLES.SEE_ALSO)) + + def Init(self, table='*', force=False): + table = self._fmt_query_(table) + log("Rebuilding tables: %s" % table, 'sql\\ankDB') + if table == '*' or table == TABLES.EVERNOTE.NOTES: + self.execute( + """CREATE TABLE IF NOT EXISTS `{n}` ( `guid` TEXT NOT NULL UNIQUE, `nid` INTEGER NOT NULL DEFAULT -1, `title` TEXT NOT NULL, `content` TEXT NOT NULL, `updated` INTEGER NOT NULL, `created` INTEGER NOT NULL, `updateSequenceNum` INTEGER NOT NULL, `notebookGuid` TEXT NOT NULL, `tagGuids` TEXT NOT NULL, `tagNames` TEXT NOT NULL, PRIMARY KEY(guid) );""") + if table == '*' or table == TABLES.EVERNOTE.NOTES_HISTORY: + self.execute( + """CREATE TABLE IF NOT EXISTS `%s` ( `guid` TEXT NOT NULL, `title` TEXT NOT NULL, `content` TEXT NOT NULL, `updated` INTEGER NOT NULL, `created` INTEGER NOT NULL, `updateSequenceNum` INTEGER NOT NULL, `notebookGuid` TEXT NOT NULL, `tagGuids` TEXT NOT NULL, `tagNames` TEXT NOT NULL)""" % TABLES.EVERNOTE.NOTES_HISTORY) + if table == '*' or table == TABLES.TOC_AUTO: + self.execute( + """CREATE TABLE IF NOT EXISTS `%s` ( `root_title` TEXT NOT NULL UNIQUE, `contents` TEXT NOT NULL, `tagNames` TEXT NOT NULL, `notebookGuid` TEXT NOT NULL, PRIMARY KEY(root_title) );""" % TABLES.TOC_AUTO) + if table == '*' or table == TABLES.NOTE_VALIDATION_QUEUE: + self.execute( + """CREATE TABLE IF NOT EXISTS `%s` ( `guid` TEXT, `title` 
TEXT NOT NULL, `contents` TEXT NOT NULL, `tagNames` TEXT NOT NULL DEFAULT ',,', `notebookGuid` TEXT, `validation_status` INTEGER NOT NULL DEFAULT 0, `validation_result` TEXT, `noteType` TEXT);""" % TABLES.NOTE_VALIDATION_QUEUE) + if table == '*' or table == TABLES.SEE_ALSO: + self.InitSeeAlso(force) + if table == '*' or table == TABLES.EVERNOTE.TAGS: + self.InitTags(force) + if table == '*' or table == TABLES.EVERNOTE.NOTEBOOKS: + self.InitNotebooks(force) diff --git a/anknotes/detect_see_also_changes.py b/anknotes/detect_see_also_changes.py new file mode 100644 index 0000000..907e15d --- /dev/null +++ b/anknotes/detect_see_also_changes.py @@ -0,0 +1,308 @@ +# -*- coding: utf-8 -*- +import shutil +import sys + +try: + from pysqlite2 import dbapi2 as sqlite +except ImportError: + from sqlite3 import dbapi2 as sqlite + +from anknotes.shared import * +from anknotes import stopwatch + +from anknotes.EvernoteNotePrototype import EvernoteNotePrototype +from anknotes.AnkiNotePrototype import AnkiNotePrototype +from enum import Enum +from anknotes.enums import * +from anknotes.structs import EvernoteAPIStatus + +Error = sqlite.Error +ankDBSetLocal() +from anknotes.ankEvernote import Evernote +from anknotes.Anki import Anki + + +class notes: + class version(object): + class pstrings: + __updated = None + __processed = None + __original = None + __regex_updated = None + """: type : notes.version.see_also_match """ + __regex_processed = None + """: type : notes.version.see_also_match """ + __regex_original = None + """: type : notes.version.see_also_match """ + + @property + def regex_original(self): + if self.original is None: + return None + if self.__regex_original is None: + self.__regex_original = notes.version.see_also_match(self.original) + return self.__regex_original + + @property + def regex_processed(self): + if self.processed is None: + return None + if self.__regex_processed is None: + self.__regex_processed = notes.version.see_also_match(self.processed) + return 
self.__regex_processed + + @property + def regex_updated(self): + if self.updated is None: + return None + if self.__regex_updated is None: + self.__regex_updated = notes.version.see_also_match(self.updated) + return self.__regex_updated + + @property + def processed(self): + if self.__processed is None: + self.__processed = str_process(self.original) + return self.__processed + + @property + def updated(self): + if self.__updated is None: + return str_process(self.__original) + return self.__updated + + @updated.setter + def updated(self, value): + self.__regex_updated = None + self.__updated = value + + @property + def final(self): + return str_process_full(self.updated) + + @property + def original(self): + return self.__original + + def useProcessed(self): + self.updated = self.processed + + def __init__(self, original=None): + self.__original = original + + class see_also_match(object): + __subject = None + __content = None + __matchobject = None + """:type : anknotes._re.__Match """ + __match_attempted = 0 + + @property + def subject(self): + if not self.__subject: + return self.content + return self.__subject + + @subject.setter + def subject(self, value): + self.__subject = value + self.__match_attempted = 0 + self.__matchobject = None + + @property + def content(self): + return self.__content + + def groups(self, group=0): + """ + :param group: + :type group : int | str | unicode + :return: + """ + if not self.successful_match: + return None + return self.__matchobject.group(group) + + @property + def successful_match(self): + if self.__matchobject: + return True + if self.__match_attempted is 0 and self.subject is not None: + self.__matchobject = notes.rgx.search(self.subject) + """:type : anknotes._re.__Match """ + self.__match_attempted += 1 + return self.__matchobject is not None + + @property + def main(self): + return self.groups(0) + + @property + def see_also(self): + return self.groups('SeeAlso') + + @property + def see_also_content(self): + 
return self.groups('SeeAlsoContent') + + def __init__(self, content=None): + """ + + :type content: str | unicode + """ + self.__content = content + self.__match_attempted = 0 + self.__matchobject = None + """:type : anknotes._re.__Match """ + + content = pstrings() + see_also = pstrings() + + old = version() + new = version() + rgx = regex_see_also() + match_type = 'NA' + + +def str_process(str_): + if not str_: + return str_ + str_ = str_.replace(u"evernote:///", u"evernote://") + str_ = re.sub(r'https://www.evernote.com/shard/(s\d+)/[\w\d]+/(\d+)/([\w\d\-]+)', + r'evernote://view/\2/\1/\3/\3/', str_) + str_ = str_.replace(u"evernote://", u"evernote:///").replace(u'<BR>', u'<br />') + str_ = re.sub(r'<br ?/?>', u'<br/>', str_, 0, re.IGNORECASE) + str_ = re.sub(r'(?s)<<(?P<PrefixKeep>(?:</div>)?)<div class="occluded">(?P<OccludedText>.+?)</div>>>', + r'<<\g<PrefixKeep>>>', str_) + str_ = str_.replace('<span class="occluded">', '<span style="color: rgb(255, 255, 255);">') + return str_ + + +def str_process_full(str_): + return clean_evernote_css(str_) + + +def main(evernote=None, anki=None): + # @clockit + def print_results(log_folder='Diff\\SeeAlso', full=False, final=False): + if final: + oldResults = n.old.content.final + newResults = n.new.content.final + elif full: + oldResults = n.old.content.updated + newResults = n.new.content.updated + else: + oldResults = n.old.see_also.updated + newResults = n.new.see_also.updated + diff = generate_diff(oldResults, newResults) + if not 6 in FILES.LOGS.SEE_ALSO_DISABLED: + log.plain(diff, log_folder + '\\Diff\\%s\\' % n.match_type + enNote.FullTitle, extension='htm', clear=True) + log.plain(diffify(oldResults, split=False), log_folder + '\\Original\\%s\\' % n.match_type + enNote.FullTitle, + extension='htm', clear=True) + log.plain(diffify(newResults, split=False), log_folder + '\\New\\%s\\' % n.match_type + enNote.FullTitle, + extension='htm', clear=True) + if final: + log.plain(oldResults, log_folder + 
'\\Final\\Old\\%s\\' % n.match_type + enNote.FullTitle, extension='htm', + clear=True) + log.plain(newResults, log_folder + '\\Final\\New\\%s\\' % n.match_type + enNote.FullTitle, extension='htm', + clear=True) + log.plain(diff + '\n', log_folder + '\\__All') + + # @clockit + def process_note(): + n.old.content = notes.version.pstrings(enNote.Content) + if not n.old.content.regex_original.successful_match: + if n.new.see_also.original == "": + n.new.content = notes.version.pstrings(n.old.content.original) + return False + n.new.content = notes.version.pstrings(n.old.content.original.replace('</en-note>', + '<div><span><br/></span></div>' + n.new.see_also.original + '\n</en-note>')) + n.new.see_also.updated = str_process(n.new.content.original) + n.old.see_also.updated = str_process(n.old.content.original) + log.plain(enNote.Guid + '<BR>' + ', '.join( + enNote.TagNames) + '<HR>' + enNote.Content + '<HR>' + n.new.see_also.updated, + 'SeeAlsoMatchFail\\' + enNote.FullTitle, extension='htm', clear=True) + n.match_type = 'V1' + else: + n.old.see_also = notes.version.pstrings(n.old.content.regex_original.main) + n.match_type = 'V2' + if n.old.see_also.regex_processed.successful_match: + assert True or str_process(n.old.content.regex_original.main) is n.old.content.regex_processed.main + n.old.content.updated = n.old.content.original.replace(n.old.content.regex_original.main, + str_process(n.old.content.regex_original.main)) + n.old.see_also.useProcessed() + n.match_type += 'V3' + n.new.see_also.regex_original.subject = n.new.see_also.original + '</en-note>' + if not n.new.see_also.regex_original.successful_match: + log.plain(enNote.Guid + '\n' + ', '.join(enNote.TagNames) + '\n' + n.new.see_also.original, + 'SeeAlsoNewMatchFail\\' + enNote.FullTitle, extension='htm', clear=True) + # see_also_replace_old = n.old.content.original.match.processed.see_also.processed.content + n.old.see_also.updated = n.old.content.regex_updated.see_also + n.new.see_also.updated = 
n.new.see_also.processed + n.match_type += 'V4' + else: + assert (n.old.content.regex_processed.see_also_content == notes.version.see_also_match( + str_process(n.old.content.regex_original.main)).see_also_content) + n.old.see_also.updated = notes.version.see_also_match( + str_process(n.old.content.regex_original.main)).see_also_content + n.new.see_also.updated = str_process(n.new.see_also.regex_original.see_also_content) + n.match_type += 'V5' + n.new.content.updated = n.old.content.updated.replace(n.old.see_also.updated, n.new.see_also.updated) + + def print_results_fail(title, status=None): + log.go(title + ' for %s' % enNote.FullTitle, 'NoUpdate') + print_results('NoMatch\\SeeAlso') + print_results('NoMatch\\Contents', full=True) + if status is None: + status = EvernoteAPIStatus.GenericError + tmr.reportStatus(status) + + noteType = 'SeeAlso-Step6' + db = ankDB() + db.delete("noteType = '%s'" % noteType, table=TABLES.NOTE_VALIDATION_QUEUE) + results = db.all("SELECT DISTINCT s.target_evernote_guid, n.* FROM {s} as s, {n} as n " + "WHERE s.target_evernote_guid = n.guid AND n.tagNames NOT LIKE '{t_toc}' " + "AND n.tagNames NOT LIKE '{t_out}' ORDER BY n.title ASC;") + # count_queued = 0 + log = Logger('See Also\\6-update_see_also_footer_in_evernote_notes\\', rm_path=True) + tmr = stopwatch.Timer(len(results), 25, infoStr='Updating Evernote See Also Notes', + label=log.base_path, do_print=True) + # log.banner("UPDATING EVERNOTE SEE ALSO CONTENT: %d NOTES" % len(results), do_print=True) + notes_updated = [] + # number_updated = 0 + for result in results: + enNote = EvernoteNotePrototype(db_note=result) + n = notes() + tmr.step(enNote.FullTitle if enNote.Status.IsSuccess else '(%s)' % enNote.Guid) + flds = get_anki_fields_from_evernote_guids(enNote.Guid) + if not flds: + print_results_fail('No Anki Note Found') + continue + flds = flds.split("\x1f") + n.new.see_also = notes.version.pstrings(flds[FIELDS.ORD.SEE_ALSO]) + result = process_note() + if result is False: + 
print_results_fail('No Match') + continue + if n.match_type != 'V1' and str_process(n.old.see_also.updated) == n.new.see_also.updated: + print_results_fail('Match but contents are the same', EvernoteAPIStatus.RequestSkipped) + continue + print_results() + print_results('Diff\\Contents', final=True) + enNote.Content = n.new.content.final + if not EVERNOTE.UPLOAD.ENABLED: + tmr.reportStatus(EvernoteAPIStatus.Disabled) + continue + if not evernote: + evernote = Evernote() + whole_note = tmr.autoStep(evernote.makeNote(enNote=enNote, noteType=noteType), update=True) + if tmr.report_result is False: + raise ValueError + if tmr.status.IsDelayableError: + break + if tmr.status.IsSuccess: + notes_updated.append(EvernoteNotePrototype(whole_note=whole_note)) + if tmr.is_success and not anki: + anki = Anki() + tmr.Report(0, anki.update_evernote_notes(notes_updated) if tmr.is_success else 0) + diff --git a/anknotes/dicts.py b/anknotes/dicts.py new file mode 100644 index 0000000..578589c --- /dev/null +++ b/anknotes/dicts.py @@ -0,0 +1,117 @@ +import collections +from anknotes.imports import in_anki +from anknotes.base import item_to_list, is_str, is_str_type, in_delimited_str, delete_keys, key_transform, str_capitalize, ank_prop, pad_digits +from anknotes.dicts_base import DictAnk + +mw = None + +class DictCaseInsensitive(DictAnk): + def __init__(self, *a, **kw): + a, cls, mro = list(a), self.__class__, self._get_arg_(a, int, 'mro', kw) + super(cls.mro()[mro], self).__init__(mro+1, *a, **kw) + + def _key_transform_(self, key, keys=None, all=False, attrs=False): + mapping = keys or self + if attrs: + mapping, all = dir(mapping.__class__), False + return key_transform(mapping, key, all=all) + +class DictNumeric(DictCaseInsensitive): + _default_value_ = 0 + def __init__(self, *a, **kw): + a, cls, mro = list(a), self.__class__, self._get_arg_(a, int, 'mro', kw) + super(cls.mro()[mro], self).__init__(mro+1, *a, **kw) + + def _convert_(self, val=None): + def _check_(val): + return 
val if isinstance(val, (int, long, float)) else None + value = val is not None + if not value: + val = self + if _check_(val): + return val + if isinstance(val, (DictNumeric)): + return _check_(value and val.getDefault() or val.getDefaultAttr()) or _check_(val.getValueAttr()) or self._default_value_ + return self._default_value_ + + @property + def sum(self): + def_val = self._convert_() + sum = not self._override_default_ and def_val or 0 + for key in self: + if not self._is_my_aggregate_(key): + sum += self._convert_(self[key]) + if sum == int(sum): + return int(sum) + return sum + + def increment(self, val=1, negate=False, **kwargs): + new_count = self.__add__(val, negate, True) + self.setDefault(new_count) + return self + + step = increment + def __bool__(self): return self.__simplify__() > 0 + def __div__(self, y): return self.__simplify__() / y + def __rdiv__(self, y): return 1 / self.__div__(y) + __truediv__ = __div__ + + def __mul__(self, y): return y * self.__simplify__() + __rmul__ = __mul__ + + def __add__(self, y, negate=False, increment=False): return self.__simplify__(increment) + y * (-1 if negate else 1) + def __sub__ (self, y): return self.__add__(y, True) + def __rsub__ (self, y): return self.__sub__(y) * -1 + def __isub__ (self, y): return self.increment(y, True) + + default_override = sum + + +class DictString(DictCaseInsensitive): + _default_ = '_label_name_' + _default_value_ = '' + _value_ = '' + + def __init__(self, *a, **kw): + a, cls, mro = list(a), self.__class__, self._get_arg_(a, int, 'mro', kw) + cls_mro = cls.mro()[mro] + self._my_attrs_ += '|_value_|_summarize_dont_print_default_' + super(cls_mro, self).__init__(mro+1, *a, **kw) + cls_mro.setSecondary = cls_mro.setValueAttr + cls_mro.getSecondary = cls_mro.getValueAttr + + def getDefault(self): + lbl = str_capitalize(self.label.full) + return lbl[:1].lower() + lbl[1:] + + +class DictSettings(DictString): + _cls_missing_attrs_ = True + def __init__(self, *a, **kw): + a, cls, mro = 
list(a), self.__class__, self._get_arg_(a, int, 'mro', kw) + super(cls.mro()[mro], self).__init__(mro+1, *a, **kw) + + @property + def mw(self): + global mw + if mw is None and in_anki(): + from aqt import mw + return mw + + def fetch(self, default=''): + mw = self.mw + if not mw: + raise Exception("Attempted to fetch from DictSettings without mw instance") + default_value = self.val + if default_value is None: + default_value = default + return mw.col.conf.get(self.getDefault(), default_value) + + def save(self, value): + mw = self.mw + if not mw: + raise Exception("Attempted to save from DictSettings without mw instance") + mw.col.conf[self.getDefault()] = value + mw.col.setMod() + mw.col.save() + return True \ No newline at end of file diff --git a/anknotes/dicts_base.py b/anknotes/dicts_base.py new file mode 100644 index 0000000..a3eec68 --- /dev/null +++ b/anknotes/dicts_base.py @@ -0,0 +1,360 @@ +import collections +from addict import Dict +from anknotes.constants_standard import ANKNOTES +from anknotes.base import item_to_list, is_str, is_str_type, in_delimited_str, delete_keys, key_transform, str_capitalize, ank_prop, pad_digits, call, get_unique_strings +from anknotes import dicts_summary + +class DictKey(object): + _type_ = _name_ = _parent_ = _delimiter_ = None + _default_name_='Root' + _default_parent_='' + _parent_dict_ = _self_dict_ = None + + def __init__(self, name=None, parent=None, type=None, self_dict=None, parent_dict=None, default_name=None, default_parent=None, delimiter=None): + type = type or 'key' + parent_dict = parent_dict or self_dict is not None and self_dict._parent_ + if parent_dict is not None: + self._parent_dict_ = parent_dict + if type == 'key': + default_name = parent_dict.__class__.__name__ + else: + base_key = getattr(parent_dict, '_key_') + default_name = base_key.name + default_parent = base_key.parent + base_key = getattr(parent_dict, '_%s_' % type.lower()) + if name is None: + name = base_key.name + if parent is None: + 
parent = base_key.parent + elif parent is None: + parent = base_key.full + if delimiter is None: + delimiter = base_key.delimiter + all_args = locals() + for attr in 'name|parent|delimiter|default_name|default_parent|type|self_dict'.split('|'): + val = all_args[attr] + if val is None: + continue + if (attr == 'name' or attr == 'parent') and not is_str_type(val): + str_val = str(val) + if type == 'label' and str_val.isdigit(): + val = str_val + else: + raise TypeError("Cannot set %s %s from non string type <%s> %s" % (type.capitalize(), attr, val.__class__.__name__, str_val)) + setattr(self, '_%s_' % (attr), val) + + @property + def type(self): + return self._type_.capitalize() if self._type_ is not None else 'Key' + + @property + def name(self): + return self.call(self._name_, self._default_name_) + + @name.setter + def name(self, value): + self._name_ = value + + @property + def parent(self): + return self.call(self._parent_, '') + + @property + def delimiter(self): + return self._delimiter_ if self._delimiter_ is not None else '.' 
+ + @property + def full(self): + return self.join() + + def call(self, value, default=None): + if value is None: + return default + return call(value, self=self, dct=self._self_dict_, parent=self._parent_dict_) + + def join(self, delimiter=None): + delimiter = delimiter or self.delimiter if self.parent and self.name else '' + return self.parent + delimiter + self.name + + def __str__(self): + return self.full + + def __repr__(self): + return '<%s> %s: %s' % (self.__class__.__name__, self.type, self.full) + +class DictAnk(Dict): + _label_ = _key_ = _value_ = _parent_ = None + _my_aggregates_ = _my_attrs_ = '' + _mro_offset_ = 0 + _default_ = _default_value_ = _override_default_ = None + _cls_missing_attrs_ = False + + def __init__(self, *a, **kw): + def _copy_keys_from_parent_(kw): + def init(suffix=''): + k0, k1 = keys[0] + suffix, keys[1] + suffix + if k1 not in kw and k0 in kw and kw[k0]: + kw[k1] = kw[k0] + + # Begin _copy_keys_from_parent_(): + keys=['key', 'label'] + for k0 in keys: + if k0 in kw and is_str_type(kw[k0]): + kw[k0+'_name'] = kw[k0] + del kw[k0] + init(), init('_name') + for k in keys: + kv, kn, kp = self._get_kwargs_(kw, k, '%s_name' % k, ['%s_parent' % k, 'parent_%s' % k]) + self._my_attrs_ += '|_%s_' % k + kv = kv or DictKey(kn, kp, k, self, self._parent_) + setattr(self, '_%s_' % k, kv) + + # Begin __init__(): + cls, a = self.__class__, list(a) + self._my_attrs_ += '|_cls_missing_attrs_|_my_aggregates_|_default_|_default_value_|_override_default_|_mro_offset_|_parent_' + mro = self._get_arg_(a, int, 'mro', kw) + self._parent_=self._get_arg_(a, DictAnk) + _copy_keys_from_parent_(kw) + override_default = self._get_kwarg_(kw, 'override_default', True) + delete, initialize = self._get_kwargs_(kw, 'delete', 'initialize') + if self._default_: + self._my_attrs_ += '|' + self._default_ + self._override_default_ = override_default if self._override_default_ is not None else None + super(cls.mro()[mro], self).__init__(mro+1, *a, **kw) + if delete: + 
self.delete_keys(delete) + if initialize: + self.initialize_keys(initialize) + + prop = ank_prop + + @property + def key(self): + return self._key_ + + @property + def label(self): + return self._label_ + + @property + def _label_name_(self): + return self.label.name if self.label else '' + + @_label_name_.setter + def _label_name_(self, value): + if not self.label: + self._label_ = DictKey(value, type='label', self_dict=self) + else: + self._label_.name = value + + @property + def default(self): + return self.getDefault() + + @default.setter + def default(self, value): + return self.setDefault(value) + + def getDefaultAttr(self): + if self._default_ is None: + return None + if self._is_obj_attr_(self._default_): + val = getattr(self, self._default_) + return self._default_value_ if val is None else val + return self._default_value_ + + def _getDefault(self, allow_override=True): + if self._default_ is None: + return None + if allow_override and self._override_default_ and self.default_override: + return self.default_override + return self.getDefaultAttr() + + def setDefault(self, value, set_override=True): + if self._default_ is None: + return + if set_override is not None: + self._override_default_ = set_override + setattr(self, self._default_, value) + + def getDefaultOrValue(self): + if self._default_ is None: + return self.getValueAttr() + return self.getDefault() + + getDefault = _getDefault + getSecondary = getDefault + setSecondary = setDefault + + @property + def default_override(self): + return None + + @property + def has_value(self): + return self._is_my_attr_('_value_') + + def getValueAttr(self): + if self.has_value: + return self._value_ + return None + + def setValueAttr(self, value): + self._value_ = value + + def getValue(self): + if self.has_value: + return self._value_ + return self.getDefault() + + def setValue(self, value): + if self.has_value: + self._value_ = value + else: + self.setDefault(value) + + val = property(getValue, setValue, None, 
'Property for `val` attribute') + get = default + + def _is_my_attr_(self, key): + return in_delimited_str(key, self._my_attrs_ + '|_my_attrs_') + + def _is_my_aggregate_(self, key): + return in_delimited_str(key, self._my_aggregates_) + + @staticmethod + def _is_protected_(key): + return (key.startswith('_') and key.endswith('_')) or key.startswith('__') + + def _new_instance_(self, *a, **kw): + return self.__class__.mro()[self._mro_offset_](*a, **kw) + + def __hash__(self, return_items=False): + def _items_to_hash_(): + def _item_hash_(item): + if isinstance(item, DictAnk): + return item.__hash__(True) + return (item,) + + # Begin _items_to_hash_(): + base_hash = [self.__class__.__name__, self.key.full, self.label.full, self.default, self.val] + for i, item in enumerate(base_hash): + item_hash = _item_hash_(item) + base_hash[i] = item_hash[0] if len(item_hash) is 1 else item_hash + hashes=[tuple(base_hash)] + for key in self: + item = self[key] + key_hash = _item_hash_(key) + item_hash = _item_hash_(item) + hashes.append((key_hash, item_hash)) + return tuple(hashes) + + # Begin __hash__() + items = _items_to_hash_() + return items if return_items or items is None else hash(items) + + def print_banner(self, title): + print self.make_banner(title) + + @staticmethod + def make_banner(title): + return '\n'.join(["-" * max(ANKNOTES.FORMATTING.COUNTER_BANNER_MINIMUM, len(title) + 5), title, + "-" * max(ANKNOTES.FORMATTING.COUNTER_BANNER_MINIMUM, len(title) + 5)]) + + def delete_keys(self, keys_to_delete): + delete_keys(self, keys_to_delete) + + def _get_kwargs_(self, kwargs, *a, **kw): + return [self._get_kwarg_(kwargs, key, **kw) for key in a] + + def _get_kwarg_(self, kwargs, keys, default=None, replace_none_type=True, **kw): + retval = replace_none_type and default or None + for key in item_to_list(keys): + key = self._key_transform_(key, kwargs) + if key not in kwargs: + continue + val = kwargs[key] + retval = val or retval + del kwargs[key] + return retval + + def 
reset(self, keys_to_keep=None): + if keys_to_keep is None: + keys_to_keep = self._my_aggregates_.lower().split("|") + self.delete_keys([key for key in self if key.lower() not in keys_to_keep]) + + _summarize_lines_ = dicts_summary._summarize_lines_ + _get_summary_ = dicts_summary._get_summary_ + + def __repr__(self, **kw): + return self._summarize_lines_(self._get_summary_(), **kw) + + def increment(self, val, negate=False): + new_value = self.__add__(val, increment=True) + self.setDefault(new_value) + return self + + def __simplify__(self, increment=False): + if increment: + return self.__simplify_increment__() + return self.getDefault() + + def __simplify_increment__(self): return self.getDefaultAttr() + def __coerce__(self, y): return (self.__simplify__(), y) + def __bool__(self): return bool(self.__simplify__()) or bool(self.items()) + def __truth__(self): return self.__bool__() + __nonzero__ = __truth__ + + def __add__(self, y, increment=False): return self.__simplify__(increment) + y + def __radd__(self, y): return y + self.__simplify__() + def __iadd__(self, y): return self.increment(y) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + return exc_type is None + + def __setattr__(self, key, value): + key_adj = self._key_transform_(key) + if self._is_protected_(key): + if not self._is_my_attr_(key): + raise AttributeError("Attempted to set protected built-in item %s on %s\n\nMy Attrs: %s" % (key_adj, self.__class__.__name__, self._my_attrs_)) + else: + super(Dict, self).__setattr__(key_adj, value) + elif self._default_ and is_str(key_adj) and (key_adj.lower() == 'default' or key_adj.lower() == self._default_.strip('_')): + self.setDefault(value) + elif key_adj in self: + attr_val = getattr(self, key_adj) + if self._default_ and isinstance(attr_val, self.__class__): + attr_val.setSecondary(value.getDefaultAttr() if isinstance(value, self.__class__) else value) + else: + super(self.__class__.mro()[-4], 
self).__setattr__(key_adj, value) + else: + if self._cls_missing_attrs_: + self[key_adj] = self._new_instance_(self, key_name=key_adj, override_default=True) + self[key_adj].setValue(value) + else: + super(self.__class__.mro()[-4], self).__setitem__(key_adj, value) + + def __setitem__(self, name, value): + super(self.__class__.mro()[-4], self).__setitem__(name, value) + + def __getitem__(self, key): + key_adj = self._key_transform_all_(key) + if self._default_ and is_str(key_adj) and (key_adj.lower() == 'default' or key_adj.lower() == self._default_.strip('_')): + return self.getDefault() + if key_adj not in self: + if key_adj in dir(self.__class__): + return super(self.__class__.mro()[-3], self).__getattr__(key_adj) + elif self._is_protected_(key): + try: + return None if self._is_my_attr_(key) else super(Dict, self).__getattr__(key) + except KeyError: + raise KeyError("Could not find protected built-in item " + key) + self[key_adj] = self._new_instance_(self, key_name=key_adj, override_default=True) + try: + return super(self.__class__.mro()[-4], self).__getitem__(key_adj) + except TypeError: + return "<null>" + diff --git a/anknotes/dicts_summary.py b/anknotes/dicts_summary.py new file mode 100644 index 0000000..a44a7d9 --- /dev/null +++ b/anknotes/dicts_summary.py @@ -0,0 +1,95 @@ +import collections +from addict import Dict +from anknotes.constants_standard import ANKNOTES +from anknotes.base import item_to_list, is_str, is_str_type, in_delimited_str, delete_keys, key_transform, str_capitalize, ank_prop, pad_digits, call, get_unique_strings +from addict import Dict + +def _get_summary_(self, level=1, header_only=False): + summary=Dict(level=level, label=self.label, children=self.keys(), key=self.key, child_values={}) + summary.strs.update(class_name=self.__class__.__name__, marker=' ') + if self._default_ is not None: + dval = self.getDefault() + if dval != self._default_value_: + summary.strs.default_full = summary.strs.default = dval + if 
self._override_default_: + summary.strs.marker = '*' + attr = self.getDefaultAttr() + if attr != self._default_value_: + summary.strs.default_attr = attr + summary.strs.marker = '!' if self._override_default_ else '#' + summary.strs.value = self.has_value and self.val or '' + summaries=[] + if header_only: + return [summary] + for key in sorted(self.keys()): + if self._is_my_aggregate_(key): + continue + item = self[key] + if not isinstance(item, Dict): + summary.child_values[key] = item + elif not header_only: + summaries += item._get_summary_(level + 1) + return [summary] + summaries + +def _summarize_lines_(self, summary, header=True, value_pad=3): + def _summarize_child_lines_(pad_len): + if not item.child_values: + return '' + child_lines = [] + pad = ' '*(pad_len*3-1) + for child_key in sorted(item.child_values.keys()): + child_value = str(item.child_values[child_key]) + child_value = pad_digits(child_value) + marker = '+' + if child_key.startswith('#'): + marker = '#' + child_key = child_key[1:] + child_lines.append(('%s%s%-15s' % (pad, marker, child_key + ':')).ljust(16+pad_len * 4 + 11) + child_value + marker) + return '\n' + '\n'.join(child_lines) + + # Begin _summarize_lines_() + lines = [] + for i, item in enumerate(summary): + str_full_key = str_full_label = str_label = str_key = str_default_attr = str_value = str_default = str_default_full = '' + if item.key: + item.strs.update(akey_full=item.key.join(' -> '), akey_name=item.key.name, akey_parent=item.key.parent) + if item.label: + item.strs.update(alabel_full=item.label.join(' -> '), alabel_name=item.label.name, alabel_parent=item.label.parent) + strs = get_unique_strings(item.strs) + if strs.alabel_full: + if strs.akey_full and item.key.parent == item.label.parent: + strs.alabel_full = '* -> ' + item.label.name + strs.alabel_full='%sL[%s]' % (' | ' if strs.akey_full else '', strs.alabel_full) + if strs.akey_full and strs.default_full: + strs.alabel_full += ': ' + if 
self._is_my_attr_('_summarize_dont_print_default_') or not strs.default: + strs.default = strs.default_full = '' + elif strs.alabel_full: + strs.default_full = 'D[%s]' % strs.default + strs.default_full += ':' if strs.value else '' + if strs.alabel_name: + strs.alabel_name='%sL[%s]' % (' | ' if strs.akey_name else '', strs.alabel_name) + if i is 0 and header: + pad_len = len(strs.class_name) + (len(strs.alabel_full) - len(strs.alabel_name) if strs.alabel_full else 0) + lines.append("<%s%s:%s%s%s>" % (strs.marker.strip(), strs.class_name, strs.akey_full + strs.alabel_full, strs.default_full, + strs.value) + _summarize_child_lines_(item.level+1)) + continue + str_ = ' ' * (item.level * 3 - 1) + strs.marker + strs.akey_name + strs.alabel_name + child_str = _summarize_child_lines_(item.level+1) + strs.val_def = ' '*(value_pad+2) + strs.value, strs.default, strs.default_attr = pad_digits(strs.value, strs.default, strs.default_attr) + if strs.value and strs.default: + strs.val_def = (strs.value + ':').ljust(value_pad+1) + ' ' + strs.default + elif strs.default: + strs.val_def += strs.default + elif strs.value: + strs.val_def = strs.value.ljust(value_pad) + if strs.default_attr: + strs.val_def += ' ' + strs.default_attr + if strs.val_def.strip() == '': + strs.marker = '' + else: + str_ += ': ' + str_ = str_.ljust(16 + item.level * 4 + 5) + lines.append(str_ + ' ' + strs.val_def + strs.marker + child_str) + return '\n'.join(lines) \ No newline at end of file diff --git a/anknotes/enum/LICENSE b/anknotes/enum/LICENSE new file mode 100644 index 0000000..9003b88 --- /dev/null +++ b/anknotes/enum/LICENSE @@ -0,0 +1,32 @@ +Copyright (c) 2013, Ethan Furman. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + Redistributions of source code must retain the above + copyright notice, this list of conditions and the + following disclaimer. 
+ + Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials + provided with the distribution. + + Neither the name Ethan Furman nor the names of any + contributors may be used to endorse or promote products + derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/anknotes/enum/README b/anknotes/enum/README new file mode 100644 index 0000000..511af98 --- /dev/null +++ b/anknotes/enum/README @@ -0,0 +1,2 @@ +enum34 is the new Python stdlib enum module available in Python 3.4 +backported for previous versions of Python from 2.4 to 3.3. 
diff --git a/anknotes/enum/__init__.py b/anknotes/enum/__init__.py new file mode 100644 index 0000000..6a327a8 --- /dev/null +++ b/anknotes/enum/__init__.py @@ -0,0 +1,790 @@ +"""Python Enumerations""" + +import sys as _sys + +__all__ = ['Enum', 'IntEnum', 'unique'] + +version = 1, 0, 4 + +pyver = float('%s.%s' % _sys.version_info[:2]) + +try: + any +except NameError: + def any(iterable): + for element in iterable: + if element: + return True + return False + +try: + from collections import OrderedDict +except ImportError: + OrderedDict = None + +try: + basestring +except NameError: + # In Python 2 basestring is the ancestor of both str and unicode + # in Python 3 it's just str, but was missing in 3.1 + basestring = str + +try: + unicode +except NameError: + # In Python 3 unicode no longer exists (it's just str) + unicode = str + +class _RouteClassAttributeToGetattr(object): + """Route attribute access on a class to __getattr__. + + This is a descriptor, used to define attributes that act differently when + accessed through an instance and through a class. Instance access remains + normal, but access to an attribute through a class will be routed to the + class's __getattr__ method; this is done by raising AttributeError. 
+ + """ + def __init__(self, fget=None): + self.fget = fget + + def __get__(self, instance, ownerclass=None): + if instance is None: + raise AttributeError() + return self.fget(instance) + + def __set__(self, instance, value): + raise AttributeError("can't set attribute") + + def __delete__(self, instance): + raise AttributeError("can't delete attribute") + + +def _is_descriptor(obj): + """Returns True if obj is a descriptor, False otherwise.""" + return ( + hasattr(obj, '__get__') or + hasattr(obj, '__set__') or + hasattr(obj, '__delete__')) + + +def _is_dunder(name): + """Returns True if a __dunder__ name, False otherwise.""" + return (name[:2] == name[-2:] == '__' and + name[2:3] != '_' and + name[-3:-2] != '_' and + len(name) > 4) + + +def _is_sunder(name): + """Returns True if a _sunder_ name, False otherwise.""" + return (name[0] == name[-1] == '_' and + name[1:2] != '_' and + name[-2:-1] != '_' and + len(name) > 2) + + +def _make_class_unpicklable(cls): + """Make the given class un-picklable.""" + def _break_on_call_reduce(self, protocol=None): + raise TypeError('%r cannot be pickled' % self) + cls.__reduce_ex__ = _break_on_call_reduce + cls.__module__ = '<unknown>' + + +class _EnumDict(dict): + """Track enum member order and ensure member names are not reused. + + EnumMeta will use the names found in self._member_names as the + enumeration member names. + + """ + def __init__(self): + super(_EnumDict, self).__init__() + self._member_names = [] + + def __setitem__(self, key, value): + """Changes anything not dundered or not a descriptor. + + If a descriptor is added with the same name as an enum member, the name + is removed from _member_names (this may leave a hole in the numerical + sequence of values). + + If an enum member name is used twice, an error is raised; duplicate + values are not checked for. + + Single underscore (sunder) names are reserved. 
+ + Note: in 3.x __order__ is simply discarded as a not necessary piece + leftover from 2.x + + """ + if pyver >= 3.0 and key == '__order__': + return + if _is_sunder(key): + raise ValueError('_names_ are reserved for future Enum use') + elif _is_dunder(key): + pass + elif key in self._member_names: + # descriptor overwriting an enum? + raise TypeError('Attempted to reuse key: %r' % key) + elif not _is_descriptor(value): + if key in self: + # enum overwriting a descriptor? + raise TypeError('Key already defined as: %r' % self[key]) + self._member_names.append(key) + super(_EnumDict, self).__setitem__(key, value) + + +# Dummy value for Enum as EnumMeta explicity checks for it, but of course until +# EnumMeta finishes running the first time the Enum class doesn't exist. This +# is also why there are checks in EnumMeta like `if Enum is not None` +Enum = None + + +class EnumMeta(type): + """Metaclass for Enum""" + @classmethod + def __prepare__(metacls, cls, bases): + return _EnumDict() + + def __new__(metacls, cls, bases, classdict): + # an Enum class is final once enumeration items have been defined; it + # cannot be mixed with other types (int, float, etc.) if it has an + # inherited __new__ unless a new __new__ is defined (or the resulting + # class will fail). 
+ if type(classdict) is dict: + original_dict = classdict + classdict = _EnumDict() + for k, v in original_dict.items(): + classdict[k] = v + + member_type, first_enum = metacls._get_mixins_(bases) + __new__, save_new, use_args = metacls._find_new_(classdict, member_type, + first_enum) + # save enum items into separate mapping so they don't get baked into + # the new class + members = dict((k, classdict[k]) for k in classdict._member_names) + for name in classdict._member_names: + del classdict[name] + + # py2 support for definition order + __order__ = classdict.get('__order__') + if __order__ is None: + if pyver < 3.0: + try: + __order__ = [name for (name, value) in sorted(members.items(), key=lambda item: item[1])] + except TypeError: + __order__ = [name for name in sorted(members.keys())] + else: + __order__ = classdict._member_names + else: + del classdict['__order__'] + if pyver < 3.0: + __order__ = __order__.replace(',', ' ').split() + aliases = [name for name in members if name not in __order__] + __order__ += aliases + + # check for illegal enum names (any others?) + invalid_names = set(members) & set(['mro']) + if invalid_names: + raise ValueError('Invalid enum member name(s): %s' % ( + ', '.join(invalid_names), )) + + # create our new Enum type + enum_class = super(EnumMeta, metacls).__new__(metacls, cls, bases, classdict) + enum_class._member_names_ = [] # names in random order + if OrderedDict is not None: + enum_class._member_map_ = OrderedDict() + else: + enum_class._member_map_ = {} # name->value map + enum_class._member_type_ = member_type + + # Reverse value->name map for hashable values. 
+ enum_class._value2member_map_ = {} + + # instantiate them, checking for duplicates as we go + # we instantiate first instead of checking for duplicates first in case + # a custom __new__ is doing something funky with the values -- such as + # auto-numbering ;) + if __new__ is None: + __new__ = enum_class.__new__ + for member_name in __order__: + value = members[member_name] + if not isinstance(value, tuple): + args = (value, ) + else: + args = value + if member_type is tuple: # special case for tuple enums + args = (args, ) # wrap it one more time + if not use_args or not args: + enum_member = __new__(enum_class) + if not hasattr(enum_member, '_value_'): + enum_member._value_ = value + else: + enum_member = __new__(enum_class, *args) + if not hasattr(enum_member, '_value_'): + enum_member._value_ = member_type(*args) + value = enum_member._value_ + enum_member._name_ = member_name + enum_member.__objclass__ = enum_class + enum_member.__init__(*args) + # If another member with the same value was already defined, the + # new member becomes an alias to the existing one. + for name, canonical_member in enum_class._member_map_.items(): + if canonical_member.value == enum_member._value_: + enum_member = canonical_member + break + else: + # Aliases don't appear in member names (only in __members__). + enum_class._member_names_.append(member_name) + enum_class._member_map_[member_name] = enum_member + try: + # This may fail if value is not hashable. We can't add the value + # to the map, and by-value lookups for this value will be + # linear. + enum_class._value2member_map_[value] = enum_member + except TypeError: + pass + + + # If a custom type is mixed into the Enum, and it does not know how + # to pickle itself, pickle.dumps will succeed but pickle.loads will + # fail. Rather than have the error show up later and possibly far + # from the source, sabotage the pickle protocol for this class so + # that pickle.dumps also fails. 
+ # + # However, if the new class implements its own __reduce_ex__, do not + # sabotage -- it's on them to make sure it works correctly. We use + # __reduce_ex__ instead of any of the others as it is preferred by + # pickle over __reduce__, and it handles all pickle protocols. + unpicklable = False + if '__reduce_ex__' not in classdict: + if member_type is not object: + methods = ('__getnewargs_ex__', '__getnewargs__', + '__reduce_ex__', '__reduce__') + if not any(m in member_type.__dict__ for m in methods): + _make_class_unpicklable(enum_class) + unpicklable = True + + + # double check that repr and friends are not the mixin's or various + # things break (such as pickle) + for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): + class_method = getattr(enum_class, name) + obj_method = getattr(member_type, name, None) + enum_method = getattr(first_enum, name, None) + if name not in classdict and class_method is not enum_method: + if name == '__reduce_ex__' and unpicklable: + continue + setattr(enum_class, name, enum_method) + + # method resolution and int's are not playing nice + # Python's less than 2.6 use __cmp__ + + if pyver < 2.6: + + if issubclass(enum_class, int): + setattr(enum_class, '__cmp__', getattr(int, '__cmp__')) + + elif pyver < 3.0: + + if issubclass(enum_class, int): + for method in ( + '__le__', + '__lt__', + '__gt__', + '__ge__', + '__eq__', + '__ne__', + '__hash__', + ): + setattr(enum_class, method, getattr(int, method)) + + # replace any other __new__ with our own (as long as Enum is not None, + # anyway) -- again, this is to support pickle + if Enum is not None: + # if the user defined their own __new__, save it before it gets + # clobbered in case they subclass later + if save_new: + setattr(enum_class, '__member_new__', enum_class.__dict__['__new__']) + setattr(enum_class, '__new__', Enum.__dict__['__new__']) + return enum_class + + def __call__(cls, value, names=None, module=None, type=None): + """Either returns an existing 
member, or creates a new enum class. + + This method is used both when an enum class is given a value to match + to an enumeration member (i.e. Color(3)) and for the functional API + (i.e. Color = Enum('Color', names='red green blue')). + + When used for the functional API: `module`, if set, will be stored in + the new class' __module__ attribute; `type`, if set, will be mixed in + as the first base class. + + Note: if `module` is not set this routine will attempt to discover the + calling module by walking the frame stack; if this is unsuccessful + the resulting class will not be pickleable. + + """ + if names is None: # simple value lookup + return cls.__new__(cls, value) + # otherwise, functional API: we're creating a new Enum type + return cls._create_(value, names, module=module, type=type) + + def __contains__(cls, member): + return isinstance(member, cls) and member.name in cls._member_map_ + + def __delattr__(cls, attr): + # nicer error message when someone tries to delete an attribute + # (see issue19025). + if attr in cls._member_map_: + raise AttributeError( + "%s: cannot delete Enum member." % cls.__name__) + super(EnumMeta, cls).__delattr__(attr) + + def __dir__(self): + return (['__class__', '__doc__', '__members__', '__module__'] + + self._member_names_) + + @property + def __members__(cls): + """Returns a mapping of member name->value. + + This mapping lists all enum members, including aliases. Note that this + is a copy of the internal mapping. + + """ + return cls._member_map_.copy() + + def __getattr__(cls, name): + """Return the enum member matching `name` + + We use __getattr__ instead of descriptors or inserting into the enum + class' __dict__ in order to support `name` and `value` being both + properties for enum members (which live in the class' __dict__) and + enum members themselves. 
+ + """ + if _is_dunder(name): + raise AttributeError(name) + try: + return cls._member_map_[name] + except KeyError: + raise AttributeError(name) + + def __getitem__(cls, name): + return cls._member_map_[name] + + def __iter__(cls): + return (cls._member_map_[name] for name in cls._member_names_) + + def __reversed__(cls): + return (cls._member_map_[name] for name in reversed(cls._member_names_)) + + def __len__(cls): + return len(cls._member_names_) + + def __repr__(cls): + return "<enum %r>" % cls.__name__ + + def __setattr__(cls, name, value): + """Block attempts to reassign Enum members. + + A simple assignment to the class namespace only changes one of the + several possible ways to get an Enum member from the Enum class, + resulting in an inconsistent Enumeration. + + """ + member_map = cls.__dict__.get('_member_map_', {}) + if name in member_map: + raise AttributeError('Cannot reassign members.') + super(EnumMeta, cls).__setattr__(name, value) + + def _create_(cls, class_name, names=None, module=None, type=None): + """Convenience method to create a new Enum class. + + `names` can be: + + * A string containing member names, separated either with spaces or + commas. Values are auto-numbered from 1. + * An iterable of member names. Values are auto-numbered from 1. + * An iterable of (member name, value) pairs. + * A mapping of member name -> value. + + """ + if pyver < 3.0: + # if class_name is unicode, attempt a conversion to ASCII + if isinstance(class_name, unicode): + try: + class_name = class_name.encode('ascii') + except UnicodeEncodeError: + raise TypeError('%r is not representable in ASCII' % class_name) + metacls = cls.__class__ + if type is None: + bases = (cls, ) + else: + bases = (type, cls) + classdict = metacls.__prepare__(class_name, bases) + __order__ = [] + + # special processing needed for names? 
+ if isinstance(names, basestring): + names = names.replace(',', ' ').split() + if isinstance(names, (tuple, list)) and isinstance(names[0], basestring): + names = [(e, i+1) for (i, e) in enumerate(names)] + + # Here, names is either an iterable of (name, value) or a mapping. + for item in names: + if isinstance(item, basestring): + member_name, member_value = item, names[item] + else: + member_name, member_value = item + classdict[member_name] = member_value + __order__.append(member_name) + # only set __order__ in classdict if name/value was not from a mapping + if not isinstance(item, basestring): + classdict['__order__'] = ' '.join(__order__) + enum_class = metacls.__new__(metacls, class_name, bases, classdict) + + # TODO: replace the frame hack if a blessed way to know the calling + # module is ever developed + if module is None: + try: + module = _sys._getframe(2).f_globals['__name__'] + except (AttributeError, ValueError): + pass + if module is None: + _make_class_unpicklable(enum_class) + else: + enum_class.__module__ = module + + return enum_class + + @staticmethod + def _get_mixins_(bases): + """Returns the type for creating enum members, and the first inherited + enum class. 
+ + bases: the tuple of bases that was given to __new__ + + """ + if not bases or Enum is None: + return object, Enum + + + # double check that we are not subclassing a class with existing + # enumeration members; while we're at it, see if any other data + # type has been mixed in so we can use the correct __new__ + member_type = first_enum = None + for base in bases: + if (base is not Enum and + issubclass(base, Enum) and + base._member_names_): + raise TypeError("Cannot extend enumerations") + # base is now the last base in bases + if not issubclass(base, Enum): + raise TypeError("new enumerations must be created as " + "`ClassName([mixin_type,] enum_type)`") + + # get correct mix-in type (either mix-in type of Enum subclass, or + # first base if last base is Enum) + if not issubclass(bases[0], Enum): + member_type = bases[0] # first data type + first_enum = bases[-1] # enum type + else: + for base in bases[0].__mro__: + # most common: (IntEnum, int, Enum, object) + # possible: (<Enum 'AutoIntEnum'>, <Enum 'IntEnum'>, + # <class 'int'>, <Enum 'Enum'>, + # <class 'object'>) + if issubclass(base, Enum): + if first_enum is None: + first_enum = base + else: + if member_type is None: + member_type = base + + return member_type, first_enum + + if pyver < 3.0: + @staticmethod + def _find_new_(classdict, member_type, first_enum): + """Returns the __new__ to be used for creating the enum members. 
+ + classdict: the class dictionary given to __new__ + member_type: the data type whose __new__ will be used by default + first_enum: enumeration to check for an overriding __new__ + + """ + # now find the correct __new__, checking to see of one was defined + # by the user; also check earlier enum classes in case a __new__ was + # saved as __member_new__ + __new__ = classdict.get('__new__', None) + if __new__: + return None, True, True # __new__, save_new, use_args + + N__new__ = getattr(None, '__new__') + O__new__ = getattr(object, '__new__') + if Enum is None: + E__new__ = N__new__ + else: + E__new__ = Enum.__dict__['__new__'] + # check all possibles for __member_new__ before falling back to + # __new__ + for method in ('__member_new__', '__new__'): + for possible in (member_type, first_enum): + try: + target = possible.__dict__[method] + except (AttributeError, KeyError): + target = getattr(possible, method, None) + if target not in [ + None, + N__new__, + O__new__, + E__new__, + ]: + if method == '__member_new__': + classdict['__new__'] = target + return None, False, True + if isinstance(target, staticmethod): + target = target.__get__(member_type) + __new__ = target + break + if __new__ is not None: + break + else: + __new__ = object.__new__ + + # if a non-object.__new__ is used then whatever value/tuple was + # assigned to the enum member name will be passed to __new__ and to the + # new enum member's __init__ + if __new__ is object.__new__: + use_args = False + else: + use_args = True + + return __new__, False, use_args + else: + @staticmethod + def _find_new_(classdict, member_type, first_enum): + """Returns the __new__ to be used for creating the enum members. 
+ + classdict: the class dictionary given to __new__ + member_type: the data type whose __new__ will be used by default + first_enum: enumeration to check for an overriding __new__ + + """ + # now find the correct __new__, checking to see of one was defined + # by the user; also check earlier enum classes in case a __new__ was + # saved as __member_new__ + __new__ = classdict.get('__new__', None) + + # should __new__ be saved as __member_new__ later? + save_new = __new__ is not None + + if __new__ is None: + # check all possibles for __member_new__ before falling back to + # __new__ + for method in ('__member_new__', '__new__'): + for possible in (member_type, first_enum): + target = getattr(possible, method, None) + if target not in ( + None, + None.__new__, + object.__new__, + Enum.__new__, + ): + __new__ = target + break + if __new__ is not None: + break + else: + __new__ = object.__new__ + + # if a non-object.__new__ is used then whatever value/tuple was + # assigned to the enum member name will be passed to __new__ and to the + # new enum member's __init__ + if __new__ is object.__new__: + use_args = False + else: + use_args = True + + return __new__, save_new, use_args + + +######################################################## +# In order to support Python 2 and 3 with a single +# codebase we have to create the Enum methods separately +# and then use the `type(name, bases, dict)` method to +# create the class. +######################################################## +temp_enum_dict = {} +temp_enum_dict['__doc__'] = "Generic enumeration.\n\n Derive from this class to define new enumerations.\n\n" + +def __new__(cls, value): + # all enum instances are actually created during class construction + # without calling this method; this method is called by the metaclass' + # __call__ (i.e. 
Color(3) ), and by pickle + if type(value) is cls: + # For lookups like Color(Color.red) + value = value.value + #return value + # by-value search for a matching enum member + # see if it's in the reverse mapping (for hashable values) + try: + if value in cls._value2member_map_: + return cls._value2member_map_[value] + except TypeError: + # not there, now do long search -- O(n) behavior + for member in cls._member_map_.values(): + if member.value == value: + return member + raise ValueError("%s is not a valid %s" % (value, cls.__name__)) +temp_enum_dict['__new__'] = __new__ +del __new__ + +def __repr__(self): + return "<%s.%s: %r>" % ( + self.__class__.__name__, self._name_, self._value_) +temp_enum_dict['__repr__'] = __repr__ +del __repr__ + +def __str__(self): + return "%s.%s" % (self.__class__.__name__, self._name_) +temp_enum_dict['__str__'] = __str__ +del __str__ + +def __dir__(self): + added_behavior = [ + m + for cls in self.__class__.mro() + for m in cls.__dict__ + if m[0] != '_' + ] + return (['__class__', '__doc__', '__module__', ] + added_behavior) +temp_enum_dict['__dir__'] = __dir__ +del __dir__ + +def __format__(self, format_spec): + # mixed-in Enums should use the mixed-in type's __format__, otherwise + # we can get strange results with the Enum name showing up instead of + # the value + + # pure Enum branch + if self._member_type_ is object: + cls = str + val = str(self) + # mix-in branch + else: + cls = self._member_type_ + val = self.value + return cls.__format__(val, format_spec) +temp_enum_dict['__format__'] = __format__ +del __format__ + + +#################################### +# Python's less than 2.6 use __cmp__ + +if pyver < 2.6: + + def __cmp__(self, other): + if type(other) is self.__class__: + if self is other: + return 0 + return -1 + return NotImplemented + raise TypeError("unorderable types: %s() and %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__cmp__'] = __cmp__ + del __cmp__ + +else: + + def 
__le__(self, other): + raise TypeError("unorderable types: %s() <= %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__le__'] = __le__ + del __le__ + + def __lt__(self, other): + raise TypeError("unorderable types: %s() < %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__lt__'] = __lt__ + del __lt__ + + def __ge__(self, other): + raise TypeError("unorderable types: %s() >= %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__ge__'] = __ge__ + del __ge__ + + def __gt__(self, other): + raise TypeError("unorderable types: %s() > %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__gt__'] = __gt__ + del __gt__ + + +def __eq__(self, other): + if type(other) is self.__class__: + return self is other + return NotImplemented +temp_enum_dict['__eq__'] = __eq__ +del __eq__ + +def __ne__(self, other): + if type(other) is self.__class__: + return self is not other + return NotImplemented +temp_enum_dict['__ne__'] = __ne__ +del __ne__ + +def __hash__(self): + return hash(self._name_) +temp_enum_dict['__hash__'] = __hash__ +del __hash__ + +def __reduce_ex__(self, proto): + return self.__class__, (self._value_, ) +temp_enum_dict['__reduce_ex__'] = __reduce_ex__ +del __reduce_ex__ + +# _RouteClassAttributeToGetattr is used to provide access to the `name` +# and `value` properties of enum members while keeping some measure of +# protection from modification, while still allowing for an enumeration +# to have members named `name` and `value`. This works because enumeration +# members are not set directly on the enum class -- __getattr__ is +# used to look them up. 
+ +@_RouteClassAttributeToGetattr +def name(self): + return self._name_ +temp_enum_dict['name'] = name +del name + +@_RouteClassAttributeToGetattr +def value(self): + return self._value_ +temp_enum_dict['value'] = value +del value + +Enum = EnumMeta('Enum', (object, ), temp_enum_dict) +del temp_enum_dict + +# Enum has now been created +########################### + +class IntEnum(int, Enum): + """Enum where members are also (and must be) ints""" + + +def unique(enumeration): + """Class decorator that ensures only unique members exist in an enumeration.""" + duplicates = [] + for name, member in enumeration.__members__.items(): + if name != member.name: + duplicates.append((name, member.name)) + if duplicates: + duplicate_names = ', '.join( + ["%s -> %s" % (alias, name) for (alias, name) in duplicates] + ) + raise ValueError('duplicate names found in %r: %s' % + (enumeration, duplicate_names) + ) + return enumeration diff --git a/anknotes/enum/doc/enum.rst b/anknotes/enum/doc/enum.rst new file mode 100644 index 0000000..0d429bf --- /dev/null +++ b/anknotes/enum/doc/enum.rst @@ -0,0 +1,725 @@ +``enum`` --- support for enumerations +======================================== + +.. :synopsis: enumerations are sets of symbolic names bound to unique, constant + values. +.. :moduleauthor:: Ethan Furman <ethan@stoneleaf.us> +.. :sectionauthor:: Barry Warsaw <barry@python.org>, +.. :sectionauthor:: Eli Bendersky <eliben@gmail.com>, +.. :sectionauthor:: Ethan Furman <ethan@stoneleaf.us> + +---------------- + +An enumeration is a set of symbolic names (members) bound to unique, constant +values. Within an enumeration, the members can be compared by identity, and +the enumeration itself can be iterated over. + + +Module Contents +--------------- + +This module defines two enumeration classes that can be used to define unique +sets of names and values: ``Enum`` and ``IntEnum``. It also defines +one decorator, ``unique``. 
+ +``Enum`` + +Base class for creating enumerated constants. See section `Functional API`_ +for an alternate construction syntax. + +``IntEnum`` + +Base class for creating enumerated constants that are also subclasses of ``int``. + +``unique`` + +Enum class decorator that ensures only one name is bound to any one value. + + +Creating an Enum +---------------- + +Enumerations are created using the ``class`` syntax, which makes them +easy to read and write. An alternative creation method is described in +`Functional API`_. To define an enumeration, subclass ``Enum`` as +follows:: + + >>> from enum import Enum + >>> class Color(Enum): + ... red = 1 + ... green = 2 + ... blue = 3 + +Note: Nomenclature + + - The class ``Color`` is an *enumeration* (or *enum*) + - The attributes ``Color.red``, ``Color.green``, etc., are + *enumeration members* (or *enum members*). + - The enum members have *names* and *values* (the name of + ``Color.red`` is ``red``, the value of ``Color.blue`` is + ``3``, etc.) + +Note: + + Even though we use the ``class`` syntax to create Enums, Enums + are not normal Python classes. See `How are Enums different?`_ for + more details. + +Enumeration members have human readable string representations:: + + >>> print(Color.red) + Color.red + +...while their ``repr`` has more information:: + + >>> print(repr(Color.red)) + <Color.red: 1> + +The *type* of an enumeration member is the enumeration it belongs to:: + + >>> type(Color.red) + <enum 'Color'> + >>> isinstance(Color.green, Color) + True + >>> + +Enum members also have a property that contains just their item name:: + + >>> print(Color.red.name) + red + +Enumerations support iteration. In Python 3.x definition order is used; in +Python 2.x the definition order is not available, but class attribute +``__order__`` is supported; otherwise, value order is used:: + + >>> class Shake(Enum): + ... __order__ = 'vanilla chocolate cookies mint' # only needed in 2.x + ... vanilla = 7 + ... chocolate = 4 + ... 
cookies = 9 + ... mint = 3 + ... + >>> for shake in Shake: + ... print(shake) + ... + Shake.vanilla + Shake.chocolate + Shake.cookies + Shake.mint + +The ``__order__`` attribute is always removed, and in 3.x it is also ignored +(order is definition order); however, in the stdlib version it will be ignored +but not removed. + +Enumeration members are hashable, so they can be used in dictionaries and sets:: + + >>> apples = {} + >>> apples[Color.red] = 'red delicious' + >>> apples[Color.green] = 'granny smith' + >>> apples == {Color.red: 'red delicious', Color.green: 'granny smith'} + True + + +Programmatic access to enumeration members and their attributes +--------------------------------------------------------------- + +Sometimes it's useful to access members in enumerations programmatically (i.e. +situations where ``Color.red`` won't do because the exact color is not known +at program-writing time). ``Enum`` allows such access:: + + >>> Color(1) + <Color.red: 1> + >>> Color(3) + <Color.blue: 3> + +If you want to access enum members by *name*, use item access:: + + >>> Color['red'] + <Color.red: 1> + >>> Color['green'] + <Color.green: 2> + +If have an enum member and need its ``name`` or ``value``:: + + >>> member = Color.red + >>> member.name + 'red' + >>> member.value + 1 + + +Duplicating enum members and values +----------------------------------- + +Having two enum members (or any other attribute) with the same name is invalid; +in Python 3.x this would raise an error, but in Python 2.x the second member +simply overwrites the first:: + + >>> # python 2.x + >>> class Shape(Enum): + ... square = 2 + ... square = 3 + ... + >>> Shape.square + <Shape.square: 3> + + >>> # python 3.x + >>> class Shape(Enum): + ... square = 2 + ... square = 3 + Traceback (most recent call last): + ... + TypeError: Attempted to reuse key: 'square' + +However, two enum members are allowed to have the same value. 
Given two members +A and B with the same value (and A defined first), B is an alias to A. By-value +lookup of the value of A and B will return A. By-name lookup of B will also +return A:: + + >>> class Shape(Enum): + ... __order__ = 'square diamond circle alias_for_square' # only needed in 2.x + ... square = 2 + ... diamond = 1 + ... circle = 3 + ... alias_for_square = 2 + ... + >>> Shape.square + <Shape.square: 2> + >>> Shape.alias_for_square + <Shape.square: 2> + >>> Shape(2) + <Shape.square: 2> + + +Allowing aliases is not always desirable. ``unique`` can be used to ensure +that none exist in a particular enumeration:: + + >>> from enum import unique + >>> @unique + ... class Mistake(Enum): + ... __order__ = 'one two three four' # only needed in 2.x + ... one = 1 + ... two = 2 + ... three = 3 + ... four = 3 + Traceback (most recent call last): + ... + ValueError: duplicate names found in <enum 'Mistake'>: four -> three + +Iterating over the members of an enum does not provide the aliases:: + + >>> list(Shape) + [<Shape.square: 2>, <Shape.diamond: 1>, <Shape.circle: 3>] + +The special attribute ``__members__`` is a dictionary mapping names to members. +It includes all names defined in the enumeration, including the aliases:: + + >>> for name, member in sorted(Shape.__members__.items()): + ... name, member + ... + ('alias_for_square', <Shape.square: 2>) + ('circle', <Shape.circle: 3>) + ('diamond', <Shape.diamond: 1>) + ('square', <Shape.square: 2>) + +The ``__members__`` attribute can be used for detailed programmatic access to +the enumeration members. For example, finding all the aliases:: + + >>> [name for name, member in Shape.__members__.items() if member.name != name] + ['alias_for_square'] + +Comparisons +----------- + +Enumeration members are compared by identity:: + + >>> Color.red is Color.red + True + >>> Color.red is Color.blue + False + >>> Color.red is not Color.blue + True + +Ordered comparisons between enumeration values are *not* supported. 
Enum +members are not integers (but see `IntEnum`_ below):: + + >>> Color.red < Color.blue + Traceback (most recent call last): + File "<stdin>", line 1, in <module> + TypeError: unorderable types: Color() < Color() + +.. warning:: + + In Python 2 *everything* is ordered, even though the ordering may not + make sense. If you want your enumerations to have a sensible ordering + check out the `OrderedEnum`_ recipe below. + + +Equality comparisons are defined though:: + + >>> Color.blue == Color.red + False + >>> Color.blue != Color.red + True + >>> Color.blue == Color.blue + True + +Comparisons against non-enumeration values will always compare not equal +(again, ``IntEnum`` was explicitly designed to behave differently, see +below):: + + >>> Color.blue == 2 + False + + +Allowed members and attributes of enumerations +---------------------------------------------- + +The examples above use integers for enumeration values. Using integers is +short and handy (and provided by default by the `Functional API`_), but not +strictly enforced. In the vast majority of use-cases, one doesn't care what +the actual value of an enumeration is. But if the value *is* important, +enumerations can have arbitrary values. + +Enumerations are Python classes, and can have methods and special methods as +usual. If we have this enumeration:: + + >>> class Mood(Enum): + ... funky = 1 + ... happy = 3 + ... + ... def describe(self): + ... # self is the member here + ... return self.name, self.value + ... + ... def __str__(self): + ... return 'my custom str! {0}'.format(self.value) + ... + ... @classmethod + ... def favorite_mood(cls): + ... # cls here is the enumeration + ... return cls.happy + +Then:: + + >>> Mood.favorite_mood() + <Mood.happy: 3> + >>> Mood.happy.describe() + ('happy', 3) + >>> str(Mood.funky) + 'my custom str! 
1' + +The rules for what is allowed are as follows: _sunder_ names (starting and +ending with a single underscore) are reserved by enum and cannot be used; +all other attributes defined within an enumeration will become members of this +enumeration, with the exception of *__dunder__* names and descriptors (methods +are also descriptors). + +Note: + + If your enumeration defines ``__new__`` and/or ``__init__`` then + whatever value(s) were given to the enum member will be passed into + those methods. See `Planet`_ for an example. + + +Restricted subclassing of enumerations +-------------------------------------- + +Subclassing an enumeration is allowed only if the enumeration does not define +any members. So this is forbidden:: + + >>> class MoreColor(Color): + ... pink = 17 + Traceback (most recent call last): + ... + TypeError: Cannot extend enumerations + +But this is allowed:: + + >>> class Foo(Enum): + ... def some_behavior(self): + ... pass + ... + >>> class Bar(Foo): + ... happy = 1 + ... sad = 2 + ... + +Allowing subclassing of enums that define members would lead to a violation of +some important invariants of types and instances. On the other hand, it makes +sense to allow sharing some common behavior between a group of enumerations. +(See `OrderedEnum`_ for an example.) + + +Pickling +-------- + +Enumerations can be pickled and unpickled:: + + >>> from enum.test_enum import Fruit + >>> from pickle import dumps, loads + >>> Fruit.tomato is loads(dumps(Fruit.tomato, 2)) + True + +The usual restrictions for pickling apply: picklable enums must be defined in +the top level of a module, since unpickling requires them to be importable +from that module. + +Note: + + With pickle protocol version 4 (introduced in Python 3.4) it is possible + to easily pickle enums nested in other classes. 
+ + + +Functional API +-------------- + +The ``Enum`` class is callable, providing the following functional API:: + + >>> Animal = Enum('Animal', 'ant bee cat dog') + >>> Animal + <enum 'Animal'> + >>> Animal.ant + <Animal.ant: 1> + >>> Animal.ant.value + 1 + >>> list(Animal) + [<Animal.ant: 1>, <Animal.bee: 2>, <Animal.cat: 3>, <Animal.dog: 4>] + +The semantics of this API resemble ``namedtuple``. The first argument +of the call to ``Enum`` is the name of the enumeration. + +The second argument is the *source* of enumeration member names. It can be a +whitespace-separated string of names, a sequence of names, a sequence of +2-tuples with key/value pairs, or a mapping (e.g. dictionary) of names to +values. The last two options enable assigning arbitrary values to +enumerations; the others auto-assign increasing integers starting with 1. A +new class derived from ``Enum`` is returned. In other words, the above +assignment to ``Animal`` is equivalent to:: + + >>> class Animals(Enum): + ... ant = 1 + ... bee = 2 + ... cat = 3 + ... dog = 4 + +Pickling enums created with the functional API can be tricky as frame stack +implementation details are used to try and figure out which module the +enumeration is being created in (e.g. it will fail if you use a utility +function in separate module, and also may not work on IronPython or Jython). +The solution is to specify the module name explicitly as follows:: + + >>> Animals = Enum('Animals', 'ant bee cat dog', module=__name__) + +Derived Enumerations +-------------------- + +IntEnum +^^^^^^^ + +A variation of ``Enum`` is provided which is also a subclass of +``int``. Members of an ``IntEnum`` can be compared to integers; +by extension, integer enumerations of different types can also be compared +to each other:: + + >>> from enum import IntEnum + >>> class Shape(IntEnum): + ... circle = 1 + ... square = 2 + ... + >>> class Request(IntEnum): + ... post = 1 + ... get = 2 + ... 
+ >>> Shape == 1 + False + >>> Shape.circle == 1 + True + >>> Shape.circle == Request.post + True + +However, they still can't be compared to standard ``Enum`` enumerations:: + + >>> class Shape(IntEnum): + ... circle = 1 + ... square = 2 + ... + >>> class Color(Enum): + ... red = 1 + ... green = 2 + ... + >>> Shape.circle == Color.red + False + +``IntEnum`` values behave like integers in other ways you'd expect:: + + >>> int(Shape.circle) + 1 + >>> ['a', 'b', 'c'][Shape.circle] + 'b' + >>> [i for i in range(Shape.square)] + [0, 1] + +For the vast majority of code, ``Enum`` is strongly recommended, +since ``IntEnum`` breaks some semantic promises of an enumeration (by +being comparable to integers, and thus by transitivity to other +unrelated enumerations). It should be used only in special cases where +there's no other choice; for example, when integer constants are +replaced with enumerations and backwards compatibility is required with code +that still expects integers. + + +Others +^^^^^^ + +While ``IntEnum`` is part of the ``enum`` module, it would be very +simple to implement independently:: + + class IntEnum(int, Enum): + pass + +This demonstrates how similar derived enumerations can be defined; for example +a ``StrEnum`` that mixes in ``str`` instead of ``int``. + +Some rules: + +1. When subclassing ``Enum``, mix-in types must appear before + ``Enum`` itself in the sequence of bases, as in the ``IntEnum`` + example above. +2. While ``Enum`` can have members of any type, once you mix in an + additional type, all the members must have values of that type, e.g. + ``int`` above. This restriction does not apply to mix-ins which only + add methods and don't specify another data type such as ``int`` or + ``str``. +3. When another data type is mixed in, the ``value`` attribute is *not the + same* as the enum member itself, although it is equivalant and will compare + equal. +4. 
%-style formatting: ``%s`` and ``%r`` call ``Enum``'s ``__str__`` and + ``__repr__`` respectively; other codes (such as ``%i`` or ``%h`` for + IntEnum) treat the enum member as its mixed-in type. + + Note: Prior to Python 3.4 there is a bug in ``str``'s %-formatting: ``int`` + subclasses are printed as strings and not numbers when the ``%d``, ``%i``, + or ``%u`` codes are used. +5. ``str.__format__`` (or ``format``) will use the mixed-in + type's ``__format__``. If the ``Enum``'s ``str`` or + ``repr`` is desired use the ``!s`` or ``!r`` ``str`` format codes. + + +Decorators +---------- + +unique +^^^^^^ + +A ``class`` decorator specifically for enumerations. It searches an +enumeration's ``__members__`` gathering any aliases it finds; if any are +found ``ValueError`` is raised with the details:: + + >>> @unique + ... class NoDupes(Enum): + ... first = 'one' + ... second = 'two' + ... third = 'two' + Traceback (most recent call last): + ... + ValueError: duplicate names found in <enum 'NoDupes'>: third -> second + + +Interesting examples +-------------------- + +While ``Enum`` and ``IntEnum`` are expected to cover the majority of +use-cases, they cannot cover them all. Here are recipes for some different +types of enumerations that can be used directly, or as examples for creating +one's own. + + +AutoNumber +^^^^^^^^^^ + +Avoids having to specify the value for each enumeration member:: + + >>> class AutoNumber(Enum): + ... def __new__(cls): + ... value = len(cls.__members__) + 1 + ... obj = object.__new__(cls) + ... obj._value_ = value + ... return obj + ... + >>> class Color(AutoNumber): + ... __order__ = "red green blue" # only needed in 2.x + ... red = () + ... green = () + ... blue = () + ... + >>> Color.green.value == 2 + True + +Note: + + The `__new__` method, if defined, is used during creation of the Enum + members; it is then replaced by Enum's `__new__` which is used after + class creation for lookup of existing members. 
Due to the way Enums are + supposed to behave, there is no way to customize Enum's `__new__`. + + +UniqueEnum +^^^^^^^^^^ + +Raises an error if a duplicate member name is found instead of creating an +alias:: + + >>> class UniqueEnum(Enum): + ... def __init__(self, *args): + ... cls = self.__class__ + ... if any(self.value == e.value for e in cls): + ... a = self.name + ... e = cls(self.value).name + ... raise ValueError( + ... "aliases not allowed in UniqueEnum: %r --> %r" + ... % (a, e)) + ... + >>> class Color(UniqueEnum): + ... red = 1 + ... green = 2 + ... blue = 3 + ... grene = 2 + Traceback (most recent call last): + ... + ValueError: aliases not allowed in UniqueEnum: 'grene' --> 'green' + + +OrderedEnum +^^^^^^^^^^^ + +An ordered enumeration that is not based on ``IntEnum`` and so maintains +the normal ``Enum`` invariants (such as not being comparable to other +enumerations):: + + >>> class OrderedEnum(Enum): + ... def __ge__(self, other): + ... if self.__class__ is other.__class__: + ... return self._value_ >= other._value_ + ... return NotImplemented + ... def __gt__(self, other): + ... if self.__class__ is other.__class__: + ... return self._value_ > other._value_ + ... return NotImplemented + ... def __le__(self, other): + ... if self.__class__ is other.__class__: + ... return self._value_ <= other._value_ + ... return NotImplemented + ... def __lt__(self, other): + ... if self.__class__ is other.__class__: + ... return self._value_ < other._value_ + ... return NotImplemented + ... + >>> class Grade(OrderedEnum): + ... __ordered__ = 'A B C D F' + ... A = 5 + ... B = 4 + ... C = 3 + ... D = 2 + ... F = 1 + ... + >>> Grade.C < Grade.A + True + + +Planet +^^^^^^ + +If ``__new__`` or ``__init__`` is defined the value of the enum member +will be passed to those methods:: + + >>> class Planet(Enum): + ... MERCURY = (3.303e+23, 2.4397e6) + ... VENUS = (4.869e+24, 6.0518e6) + ... EARTH = (5.976e+24, 6.37814e6) + ... MARS = (6.421e+23, 3.3972e6) + ... 
JUPITER = (1.9e+27, 7.1492e7) + ... SATURN = (5.688e+26, 6.0268e7) + ... URANUS = (8.686e+25, 2.5559e7) + ... NEPTUNE = (1.024e+26, 2.4746e7) + ... def __init__(self, mass, radius): + ... self.mass = mass # in kilograms + ... self.radius = radius # in meters + ... @property + ... def surface_gravity(self): + ... # universal gravitational constant (m3 kg-1 s-2) + ... G = 6.67300E-11 + ... return G * self.mass / (self.radius * self.radius) + ... + >>> Planet.EARTH.value + (5.976e+24, 6378140.0) + >>> Planet.EARTH.surface_gravity + 9.802652743337129 + + +How are Enums different? +------------------------ + +Enums have a custom metaclass that affects many aspects of both derived Enum +classes and their instances (members). + + +Enum Classes +^^^^^^^^^^^^ + +The ``EnumMeta`` metaclass is responsible for providing the +``__contains__``, ``__dir__``, ``__iter__`` and other methods that +allow one to do things with an ``Enum`` class that fail on a typical +class, such as ``list(Color)`` or ``some_var in Color``. ``EnumMeta`` is +responsible for ensuring that various other methods on the final ``Enum`` +class are correct (such as ``__new__``, ``__getnewargs__``, +``__str__`` and ``__repr__``) + + +Enum Members (aka instances) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The most interesting thing about Enum members is that they are singletons. +``EnumMeta`` creates them all while it is creating the ``Enum`` +class itself, and then puts a custom ``__new__`` in place to ensure +that no new ones are ever instantiated by returning only the existing +member instances. + + +Finer Points +^^^^^^^^^^^^ + +Enum members are instances of an Enum class, and even though they are +accessible as ``EnumClass.member``, they are not accessible directly from +the member:: + + >>> Color.red + <Color.red: 1> + >>> Color.red.blue + Traceback (most recent call last): + ... + AttributeError: 'Color' object has no attribute 'blue' + +Likewise, ``__members__`` is only available on the class. 
+ +In Python 3.x ``__members__`` is always an ``OrderedDict``, with the order being +the definition order. In Python 2.7 ``__members__`` is an ``OrderedDict`` if +``__order__`` was specified, and a plain ``dict`` otherwise. In all other Python +2.x versions ``__members__`` is a plain ``dict`` even if ``__order__`` was specified +as the ``OrderedDict`` type didn't exist yet. + +If you give your ``Enum`` subclass extra methods, like the `Planet`_ +class above, those methods will show up in a `dir` of the member, +but not of the class:: + + >>> dir(Planet) + ['EARTH', 'JUPITER', 'MARS', 'MERCURY', 'NEPTUNE', 'SATURN', 'URANUS', + 'VENUS', '__class__', '__doc__', '__members__', '__module__'] + >>> dir(Planet.EARTH) + ['__class__', '__doc__', '__module__', 'name', 'surface_gravity', 'value'] + +A ``__new__`` method will only be used for the creation of the +``Enum`` members -- after that it is replaced. This means if you wish to +change how ``Enum`` members are looked up you either have to write a +helper function or a ``classmethod``. diff --git a/anknotes/enum/enum.py b/anknotes/enum/enum.py new file mode 100644 index 0000000..6a327a8 --- /dev/null +++ b/anknotes/enum/enum.py @@ -0,0 +1,790 @@ +"""Python Enumerations""" + +import sys as _sys + +__all__ = ['Enum', 'IntEnum', 'unique'] + +version = 1, 0, 4 + +pyver = float('%s.%s' % _sys.version_info[:2]) + +try: + any +except NameError: + def any(iterable): + for element in iterable: + if element: + return True + return False + +try: + from collections import OrderedDict +except ImportError: + OrderedDict = None + +try: + basestring +except NameError: + # In Python 2 basestring is the ancestor of both str and unicode + # in Python 3 it's just str, but was missing in 3.1 + basestring = str + +try: + unicode +except NameError: + # In Python 3 unicode no longer exists (it's just str) + unicode = str + +class _RouteClassAttributeToGetattr(object): + """Route attribute access on a class to __getattr__. 
+ + This is a descriptor, used to define attributes that act differently when + accessed through an instance and through a class. Instance access remains + normal, but access to an attribute through a class will be routed to the + class's __getattr__ method; this is done by raising AttributeError. + + """ + def __init__(self, fget=None): + self.fget = fget + + def __get__(self, instance, ownerclass=None): + if instance is None: + raise AttributeError() + return self.fget(instance) + + def __set__(self, instance, value): + raise AttributeError("can't set attribute") + + def __delete__(self, instance): + raise AttributeError("can't delete attribute") + + +def _is_descriptor(obj): + """Returns True if obj is a descriptor, False otherwise.""" + return ( + hasattr(obj, '__get__') or + hasattr(obj, '__set__') or + hasattr(obj, '__delete__')) + + +def _is_dunder(name): + """Returns True if a __dunder__ name, False otherwise.""" + return (name[:2] == name[-2:] == '__' and + name[2:3] != '_' and + name[-3:-2] != '_' and + len(name) > 4) + + +def _is_sunder(name): + """Returns True if a _sunder_ name, False otherwise.""" + return (name[0] == name[-1] == '_' and + name[1:2] != '_' and + name[-2:-1] != '_' and + len(name) > 2) + + +def _make_class_unpicklable(cls): + """Make the given class un-picklable.""" + def _break_on_call_reduce(self, protocol=None): + raise TypeError('%r cannot be pickled' % self) + cls.__reduce_ex__ = _break_on_call_reduce + cls.__module__ = '<unknown>' + + +class _EnumDict(dict): + """Track enum member order and ensure member names are not reused. + + EnumMeta will use the names found in self._member_names as the + enumeration member names. + + """ + def __init__(self): + super(_EnumDict, self).__init__() + self._member_names = [] + + def __setitem__(self, key, value): + """Changes anything not dundered or not a descriptor. 
+
+        If a descriptor is added with the same name as an enum member, the name
+        is removed from _member_names (this may leave a hole in the numerical
+        sequence of values).
+
+        If an enum member name is used twice, an error is raised; duplicate
+        values are not checked for.
+
+        Single underscore (sunder) names are reserved.
+
+        Note:   in 3.x __order__ is simply discarded as a not necessary piece
+                leftover from 2.x
+
+        """
+        if pyver >= 3.0 and key == '__order__':
+            return
+        if _is_sunder(key):
+            raise ValueError('_names_ are reserved for future Enum use')
+        elif _is_dunder(key):
+            pass
+        elif key in self._member_names:
+            # descriptor overwriting an enum?
+            raise TypeError('Attempted to reuse key: %r' % key)
+        elif not _is_descriptor(value):
+            if key in self:
+                # enum overwriting a descriptor?
+                raise TypeError('Key already defined as: %r' % self[key])
+            self._member_names.append(key)
+        super(_EnumDict, self).__setitem__(key, value)
+
+
+# Dummy value for Enum as EnumMeta explicitly checks for it, but of course until
+# EnumMeta finishes running the first time the Enum class doesn't exist. This
+# is also why there are checks in EnumMeta like `if Enum is not None`
+Enum = None
+
+
+class EnumMeta(type):
+    """Metaclass for Enum"""
+    @classmethod
+    def __prepare__(metacls, cls, bases):
+        return _EnumDict()
+
+    def __new__(metacls, cls, bases, classdict):
+        # an Enum class is final once enumeration items have been defined; it
+        # cannot be mixed with other types (int, float, etc.) if it has an
+        # inherited __new__ unless a new __new__ is defined (or the resulting
+        # class will fail). 
+ if type(classdict) is dict: + original_dict = classdict + classdict = _EnumDict() + for k, v in original_dict.items(): + classdict[k] = v + + member_type, first_enum = metacls._get_mixins_(bases) + __new__, save_new, use_args = metacls._find_new_(classdict, member_type, + first_enum) + # save enum items into separate mapping so they don't get baked into + # the new class + members = dict((k, classdict[k]) for k in classdict._member_names) + for name in classdict._member_names: + del classdict[name] + + # py2 support for definition order + __order__ = classdict.get('__order__') + if __order__ is None: + if pyver < 3.0: + try: + __order__ = [name for (name, value) in sorted(members.items(), key=lambda item: item[1])] + except TypeError: + __order__ = [name for name in sorted(members.keys())] + else: + __order__ = classdict._member_names + else: + del classdict['__order__'] + if pyver < 3.0: + __order__ = __order__.replace(',', ' ').split() + aliases = [name for name in members if name not in __order__] + __order__ += aliases + + # check for illegal enum names (any others?) + invalid_names = set(members) & set(['mro']) + if invalid_names: + raise ValueError('Invalid enum member name(s): %s' % ( + ', '.join(invalid_names), )) + + # create our new Enum type + enum_class = super(EnumMeta, metacls).__new__(metacls, cls, bases, classdict) + enum_class._member_names_ = [] # names in random order + if OrderedDict is not None: + enum_class._member_map_ = OrderedDict() + else: + enum_class._member_map_ = {} # name->value map + enum_class._member_type_ = member_type + + # Reverse value->name map for hashable values. 
+ enum_class._value2member_map_ = {} + + # instantiate them, checking for duplicates as we go + # we instantiate first instead of checking for duplicates first in case + # a custom __new__ is doing something funky with the values -- such as + # auto-numbering ;) + if __new__ is None: + __new__ = enum_class.__new__ + for member_name in __order__: + value = members[member_name] + if not isinstance(value, tuple): + args = (value, ) + else: + args = value + if member_type is tuple: # special case for tuple enums + args = (args, ) # wrap it one more time + if not use_args or not args: + enum_member = __new__(enum_class) + if not hasattr(enum_member, '_value_'): + enum_member._value_ = value + else: + enum_member = __new__(enum_class, *args) + if not hasattr(enum_member, '_value_'): + enum_member._value_ = member_type(*args) + value = enum_member._value_ + enum_member._name_ = member_name + enum_member.__objclass__ = enum_class + enum_member.__init__(*args) + # If another member with the same value was already defined, the + # new member becomes an alias to the existing one. + for name, canonical_member in enum_class._member_map_.items(): + if canonical_member.value == enum_member._value_: + enum_member = canonical_member + break + else: + # Aliases don't appear in member names (only in __members__). + enum_class._member_names_.append(member_name) + enum_class._member_map_[member_name] = enum_member + try: + # This may fail if value is not hashable. We can't add the value + # to the map, and by-value lookups for this value will be + # linear. + enum_class._value2member_map_[value] = enum_member + except TypeError: + pass + + + # If a custom type is mixed into the Enum, and it does not know how + # to pickle itself, pickle.dumps will succeed but pickle.loads will + # fail. Rather than have the error show up later and possibly far + # from the source, sabotage the pickle protocol for this class so + # that pickle.dumps also fails. 
+ # + # However, if the new class implements its own __reduce_ex__, do not + # sabotage -- it's on them to make sure it works correctly. We use + # __reduce_ex__ instead of any of the others as it is preferred by + # pickle over __reduce__, and it handles all pickle protocols. + unpicklable = False + if '__reduce_ex__' not in classdict: + if member_type is not object: + methods = ('__getnewargs_ex__', '__getnewargs__', + '__reduce_ex__', '__reduce__') + if not any(m in member_type.__dict__ for m in methods): + _make_class_unpicklable(enum_class) + unpicklable = True + + + # double check that repr and friends are not the mixin's or various + # things break (such as pickle) + for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): + class_method = getattr(enum_class, name) + obj_method = getattr(member_type, name, None) + enum_method = getattr(first_enum, name, None) + if name not in classdict and class_method is not enum_method: + if name == '__reduce_ex__' and unpicklable: + continue + setattr(enum_class, name, enum_method) + + # method resolution and int's are not playing nice + # Python's less than 2.6 use __cmp__ + + if pyver < 2.6: + + if issubclass(enum_class, int): + setattr(enum_class, '__cmp__', getattr(int, '__cmp__')) + + elif pyver < 3.0: + + if issubclass(enum_class, int): + for method in ( + '__le__', + '__lt__', + '__gt__', + '__ge__', + '__eq__', + '__ne__', + '__hash__', + ): + setattr(enum_class, method, getattr(int, method)) + + # replace any other __new__ with our own (as long as Enum is not None, + # anyway) -- again, this is to support pickle + if Enum is not None: + # if the user defined their own __new__, save it before it gets + # clobbered in case they subclass later + if save_new: + setattr(enum_class, '__member_new__', enum_class.__dict__['__new__']) + setattr(enum_class, '__new__', Enum.__dict__['__new__']) + return enum_class + + def __call__(cls, value, names=None, module=None, type=None): + """Either returns an existing 
member, or creates a new enum class. + + This method is used both when an enum class is given a value to match + to an enumeration member (i.e. Color(3)) and for the functional API + (i.e. Color = Enum('Color', names='red green blue')). + + When used for the functional API: `module`, if set, will be stored in + the new class' __module__ attribute; `type`, if set, will be mixed in + as the first base class. + + Note: if `module` is not set this routine will attempt to discover the + calling module by walking the frame stack; if this is unsuccessful + the resulting class will not be pickleable. + + """ + if names is None: # simple value lookup + return cls.__new__(cls, value) + # otherwise, functional API: we're creating a new Enum type + return cls._create_(value, names, module=module, type=type) + + def __contains__(cls, member): + return isinstance(member, cls) and member.name in cls._member_map_ + + def __delattr__(cls, attr): + # nicer error message when someone tries to delete an attribute + # (see issue19025). + if attr in cls._member_map_: + raise AttributeError( + "%s: cannot delete Enum member." % cls.__name__) + super(EnumMeta, cls).__delattr__(attr) + + def __dir__(self): + return (['__class__', '__doc__', '__members__', '__module__'] + + self._member_names_) + + @property + def __members__(cls): + """Returns a mapping of member name->value. + + This mapping lists all enum members, including aliases. Note that this + is a copy of the internal mapping. + + """ + return cls._member_map_.copy() + + def __getattr__(cls, name): + """Return the enum member matching `name` + + We use __getattr__ instead of descriptors or inserting into the enum + class' __dict__ in order to support `name` and `value` being both + properties for enum members (which live in the class' __dict__) and + enum members themselves. 
+ + """ + if _is_dunder(name): + raise AttributeError(name) + try: + return cls._member_map_[name] + except KeyError: + raise AttributeError(name) + + def __getitem__(cls, name): + return cls._member_map_[name] + + def __iter__(cls): + return (cls._member_map_[name] for name in cls._member_names_) + + def __reversed__(cls): + return (cls._member_map_[name] for name in reversed(cls._member_names_)) + + def __len__(cls): + return len(cls._member_names_) + + def __repr__(cls): + return "<enum %r>" % cls.__name__ + + def __setattr__(cls, name, value): + """Block attempts to reassign Enum members. + + A simple assignment to the class namespace only changes one of the + several possible ways to get an Enum member from the Enum class, + resulting in an inconsistent Enumeration. + + """ + member_map = cls.__dict__.get('_member_map_', {}) + if name in member_map: + raise AttributeError('Cannot reassign members.') + super(EnumMeta, cls).__setattr__(name, value) + + def _create_(cls, class_name, names=None, module=None, type=None): + """Convenience method to create a new Enum class. + + `names` can be: + + * A string containing member names, separated either with spaces or + commas. Values are auto-numbered from 1. + * An iterable of member names. Values are auto-numbered from 1. + * An iterable of (member name, value) pairs. + * A mapping of member name -> value. + + """ + if pyver < 3.0: + # if class_name is unicode, attempt a conversion to ASCII + if isinstance(class_name, unicode): + try: + class_name = class_name.encode('ascii') + except UnicodeEncodeError: + raise TypeError('%r is not representable in ASCII' % class_name) + metacls = cls.__class__ + if type is None: + bases = (cls, ) + else: + bases = (type, cls) + classdict = metacls.__prepare__(class_name, bases) + __order__ = [] + + # special processing needed for names? 
+ if isinstance(names, basestring): + names = names.replace(',', ' ').split() + if isinstance(names, (tuple, list)) and isinstance(names[0], basestring): + names = [(e, i+1) for (i, e) in enumerate(names)] + + # Here, names is either an iterable of (name, value) or a mapping. + for item in names: + if isinstance(item, basestring): + member_name, member_value = item, names[item] + else: + member_name, member_value = item + classdict[member_name] = member_value + __order__.append(member_name) + # only set __order__ in classdict if name/value was not from a mapping + if not isinstance(item, basestring): + classdict['__order__'] = ' '.join(__order__) + enum_class = metacls.__new__(metacls, class_name, bases, classdict) + + # TODO: replace the frame hack if a blessed way to know the calling + # module is ever developed + if module is None: + try: + module = _sys._getframe(2).f_globals['__name__'] + except (AttributeError, ValueError): + pass + if module is None: + _make_class_unpicklable(enum_class) + else: + enum_class.__module__ = module + + return enum_class + + @staticmethod + def _get_mixins_(bases): + """Returns the type for creating enum members, and the first inherited + enum class. 
+ + bases: the tuple of bases that was given to __new__ + + """ + if not bases or Enum is None: + return object, Enum + + + # double check that we are not subclassing a class with existing + # enumeration members; while we're at it, see if any other data + # type has been mixed in so we can use the correct __new__ + member_type = first_enum = None + for base in bases: + if (base is not Enum and + issubclass(base, Enum) and + base._member_names_): + raise TypeError("Cannot extend enumerations") + # base is now the last base in bases + if not issubclass(base, Enum): + raise TypeError("new enumerations must be created as " + "`ClassName([mixin_type,] enum_type)`") + + # get correct mix-in type (either mix-in type of Enum subclass, or + # first base if last base is Enum) + if not issubclass(bases[0], Enum): + member_type = bases[0] # first data type + first_enum = bases[-1] # enum type + else: + for base in bases[0].__mro__: + # most common: (IntEnum, int, Enum, object) + # possible: (<Enum 'AutoIntEnum'>, <Enum 'IntEnum'>, + # <class 'int'>, <Enum 'Enum'>, + # <class 'object'>) + if issubclass(base, Enum): + if first_enum is None: + first_enum = base + else: + if member_type is None: + member_type = base + + return member_type, first_enum + + if pyver < 3.0: + @staticmethod + def _find_new_(classdict, member_type, first_enum): + """Returns the __new__ to be used for creating the enum members. 
+
+            classdict: the class dictionary given to __new__
+            member_type: the data type whose __new__ will be used by default
+            first_enum: enumeration to check for an overriding __new__
+
+            """
+            # now find the correct __new__, checking to see if one was defined
+            # by the user; also check earlier enum classes in case a __new__ was
+            # saved as __member_new__
+            __new__ = classdict.get('__new__', None)
+            if __new__:
+                return None, True, True      # __new__, save_new, use_args
+
+            N__new__ = getattr(None, '__new__')
+            O__new__ = getattr(object, '__new__')
+            if Enum is None:
+                E__new__ = N__new__
+            else:
+                E__new__ = Enum.__dict__['__new__']
+            # check all possibles for __member_new__ before falling back to
+            # __new__
+            for method in ('__member_new__', '__new__'):
+                for possible in (member_type, first_enum):
+                    try:
+                        target = possible.__dict__[method]
+                    except (AttributeError, KeyError):
+                        target = getattr(possible, method, None)
+                    if target not in [
+                            None,
+                            N__new__,
+                            O__new__,
+                            E__new__,
+                            ]:
+                        if method == '__member_new__':
+                            classdict['__new__'] = target
+                            return None, False, True
+                        if isinstance(target, staticmethod):
+                            target = target.__get__(member_type)
+                        __new__ = target
+                        break
+                if __new__ is not None:
+                    break
+            else:
+                __new__ = object.__new__
+
+            # if a non-object.__new__ is used then whatever value/tuple was
+            # assigned to the enum member name will be passed to __new__ and to the
+            # new enum member's __init__
+            if __new__ is object.__new__:
+                use_args = False
+            else:
+                use_args = True
+
+            return __new__, False, use_args
+    else:
+        @staticmethod
+        def _find_new_(classdict, member_type, first_enum):
+            """Returns the __new__ to be used for creating the enum members. 
+
+            classdict: the class dictionary given to __new__
+            member_type: the data type whose __new__ will be used by default
+            first_enum: enumeration to check for an overriding __new__
+
+            """
+            # now find the correct __new__, checking to see if one was defined
+            # by the user; also check earlier enum classes in case a __new__ was
+            # saved as __member_new__
+            __new__ = classdict.get('__new__', None)
+
+            # should __new__ be saved as __member_new__ later?
+            save_new = __new__ is not None
+
+            if __new__ is None:
+                # check all possibles for __member_new__ before falling back to
+                # __new__
+                for method in ('__member_new__', '__new__'):
+                    for possible in (member_type, first_enum):
+                        target = getattr(possible, method, None)
+                        if target not in (
+                                None,
+                                None.__new__,
+                                object.__new__,
+                                Enum.__new__,
+                                ):
+                            __new__ = target
+                            break
+                    if __new__ is not None:
+                        break
+                else:
+                    __new__ = object.__new__
+
+            # if a non-object.__new__ is used then whatever value/tuple was
+            # assigned to the enum member name will be passed to __new__ and to the
+            # new enum member's __init__
+            if __new__ is object.__new__:
+                use_args = False
+            else:
+                use_args = True
+
+            return __new__, save_new, use_args
+
+
+########################################################
+# In order to support Python 2 and 3 with a single
+# codebase we have to create the Enum methods separately
+# and then use the `type(name, bases, dict)` method to
+# create the class.
+########################################################
+temp_enum_dict = {}
+temp_enum_dict['__doc__'] = "Generic enumeration.\n\n    Derive from this class to define new enumerations.\n\n"
+
+def __new__(cls, value):
+    # all enum instances are actually created during class construction
+    # without calling this method; this method is called by the metaclass'
+    # __call__ (i.e. 
Color(3) ), and by pickle + if type(value) is cls: + # For lookups like Color(Color.red) + value = value.value + #return value + # by-value search for a matching enum member + # see if it's in the reverse mapping (for hashable values) + try: + if value in cls._value2member_map_: + return cls._value2member_map_[value] + except TypeError: + # not there, now do long search -- O(n) behavior + for member in cls._member_map_.values(): + if member.value == value: + return member + raise ValueError("%s is not a valid %s" % (value, cls.__name__)) +temp_enum_dict['__new__'] = __new__ +del __new__ + +def __repr__(self): + return "<%s.%s: %r>" % ( + self.__class__.__name__, self._name_, self._value_) +temp_enum_dict['__repr__'] = __repr__ +del __repr__ + +def __str__(self): + return "%s.%s" % (self.__class__.__name__, self._name_) +temp_enum_dict['__str__'] = __str__ +del __str__ + +def __dir__(self): + added_behavior = [ + m + for cls in self.__class__.mro() + for m in cls.__dict__ + if m[0] != '_' + ] + return (['__class__', '__doc__', '__module__', ] + added_behavior) +temp_enum_dict['__dir__'] = __dir__ +del __dir__ + +def __format__(self, format_spec): + # mixed-in Enums should use the mixed-in type's __format__, otherwise + # we can get strange results with the Enum name showing up instead of + # the value + + # pure Enum branch + if self._member_type_ is object: + cls = str + val = str(self) + # mix-in branch + else: + cls = self._member_type_ + val = self.value + return cls.__format__(val, format_spec) +temp_enum_dict['__format__'] = __format__ +del __format__ + + +#################################### +# Python's less than 2.6 use __cmp__ + +if pyver < 2.6: + + def __cmp__(self, other): + if type(other) is self.__class__: + if self is other: + return 0 + return -1 + return NotImplemented + raise TypeError("unorderable types: %s() and %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__cmp__'] = __cmp__ + del __cmp__ + +else: + + def 
__le__(self, other): + raise TypeError("unorderable types: %s() <= %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__le__'] = __le__ + del __le__ + + def __lt__(self, other): + raise TypeError("unorderable types: %s() < %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__lt__'] = __lt__ + del __lt__ + + def __ge__(self, other): + raise TypeError("unorderable types: %s() >= %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__ge__'] = __ge__ + del __ge__ + + def __gt__(self, other): + raise TypeError("unorderable types: %s() > %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__gt__'] = __gt__ + del __gt__ + + +def __eq__(self, other): + if type(other) is self.__class__: + return self is other + return NotImplemented +temp_enum_dict['__eq__'] = __eq__ +del __eq__ + +def __ne__(self, other): + if type(other) is self.__class__: + return self is not other + return NotImplemented +temp_enum_dict['__ne__'] = __ne__ +del __ne__ + +def __hash__(self): + return hash(self._name_) +temp_enum_dict['__hash__'] = __hash__ +del __hash__ + +def __reduce_ex__(self, proto): + return self.__class__, (self._value_, ) +temp_enum_dict['__reduce_ex__'] = __reduce_ex__ +del __reduce_ex__ + +# _RouteClassAttributeToGetattr is used to provide access to the `name` +# and `value` properties of enum members while keeping some measure of +# protection from modification, while still allowing for an enumeration +# to have members named `name` and `value`. This works because enumeration +# members are not set directly on the enum class -- __getattr__ is +# used to look them up. 
+ +@_RouteClassAttributeToGetattr +def name(self): + return self._name_ +temp_enum_dict['name'] = name +del name + +@_RouteClassAttributeToGetattr +def value(self): + return self._value_ +temp_enum_dict['value'] = value +del value + +Enum = EnumMeta('Enum', (object, ), temp_enum_dict) +del temp_enum_dict + +# Enum has now been created +########################### + +class IntEnum(int, Enum): + """Enum where members are also (and must be) ints""" + + +def unique(enumeration): + """Class decorator that ensures only unique members exist in an enumeration.""" + duplicates = [] + for name, member in enumeration.__members__.items(): + if name != member.name: + duplicates.append((name, member.name)) + if duplicates: + duplicate_names = ', '.join( + ["%s -> %s" % (alias, name) for (alias, name) in duplicates] + ) + raise ValueError('duplicate names found in %r: %s' % + (enumeration, duplicate_names) + ) + return enumeration diff --git a/anknotes/enum/test_enum.py b/anknotes/enum/test_enum.py new file mode 100644 index 0000000..d7a9794 --- /dev/null +++ b/anknotes/enum/test_enum.py @@ -0,0 +1,1690 @@ +import enum +import sys +import unittest +from enum import Enum, IntEnum, unique, EnumMeta +from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL + +pyver = float('%s.%s' % sys.version_info[:2]) + +try: + any +except NameError: + def any(iterable): + for element in iterable: + if element: + return True + return False + +try: + unicode +except NameError: + unicode = str + +try: + from collections import OrderedDict +except ImportError: + OrderedDict = None + +# for pickle tests +try: + class Stooges(Enum): + LARRY = 1 + CURLY = 2 + MOE = 3 +except Exception: + Stooges = sys.exc_info()[1] + +try: + class IntStooges(int, Enum): + LARRY = 1 + CURLY = 2 + MOE = 3 +except Exception: + IntStooges = sys.exc_info()[1] + +try: + class FloatStooges(float, Enum): + LARRY = 1.39 + CURLY = 2.72 + MOE = 3.142596 +except Exception: + FloatStooges = sys.exc_info()[1] + +# for pickle 
test and subclass tests +try: + class StrEnum(str, Enum): + 'accepts only string values' + class Name(StrEnum): + BDFL = 'Guido van Rossum' + FLUFL = 'Barry Warsaw' +except Exception: + Name = sys.exc_info()[1] + +try: + Question = Enum('Question', 'who what when where why', module=__name__) +except Exception: + Question = sys.exc_info()[1] + +try: + Answer = Enum('Answer', 'him this then there because') +except Exception: + Answer = sys.exc_info()[1] + +try: + Theory = Enum('Theory', 'rule law supposition', qualname='spanish_inquisition') +except Exception: + Theory = sys.exc_info()[1] + +# for doctests +try: + class Fruit(Enum): + tomato = 1 + banana = 2 + cherry = 3 +except Exception: + pass + +def test_pickle_dump_load(assertion, source, target=None, + protocol=(0, HIGHEST_PROTOCOL)): + start, stop = protocol + failures = [] + for protocol in range(start, stop+1): + try: + if target is None: + assertion(loads(dumps(source, protocol=protocol)) is source) + else: + assertion(loads(dumps(source, protocol=protocol)), target) + except Exception: + exc, tb = sys.exc_info()[1:] + failures.append('%2d: %s' %(protocol, exc)) + if failures: + raise ValueError('Failed with protocols: %s' % ', '.join(failures)) + +def test_pickle_exception(assertion, exception, obj, + protocol=(0, HIGHEST_PROTOCOL)): + start, stop = protocol + failures = [] + for protocol in range(start, stop+1): + try: + assertion(exception, dumps, obj, protocol=protocol) + except Exception: + exc = sys.exc_info()[1] + failures.append('%d: %s %s' % (protocol, exc.__class__.__name__, exc)) + if failures: + raise ValueError('Failed with protocols: %s' % ', '.join(failures)) + + +class TestHelpers(unittest.TestCase): + # _is_descriptor, _is_sunder, _is_dunder + + def test_is_descriptor(self): + class foo: + pass + for attr in ('__get__','__set__','__delete__'): + obj = foo() + self.assertFalse(enum._is_descriptor(obj)) + setattr(obj, attr, 1) + self.assertTrue(enum._is_descriptor(obj)) + + def 
test_is_sunder(self): + for s in ('_a_', '_aa_'): + self.assertTrue(enum._is_sunder(s)) + + for s in ('a', 'a_', '_a', '__a', 'a__', '__a__', '_a__', '__a_', '_', + '__', '___', '____', '_____',): + self.assertFalse(enum._is_sunder(s)) + + def test_is_dunder(self): + for s in ('__a__', '__aa__'): + self.assertTrue(enum._is_dunder(s)) + for s in ('a', 'a_', '_a', '__a', 'a__', '_a_', '_a__', '__a_', '_', + '__', '___', '____', '_____',): + self.assertFalse(enum._is_dunder(s)) + + +class TestEnum(unittest.TestCase): + def setUp(self): + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = 3 + WINTER = 4 + self.Season = Season + + class Konstants(float, Enum): + E = 2.7182818 + PI = 3.1415926 + TAU = 2 * PI + self.Konstants = Konstants + + class Grades(IntEnum): + A = 5 + B = 4 + C = 3 + D = 2 + F = 0 + self.Grades = Grades + + class Directional(str, Enum): + EAST = 'east' + WEST = 'west' + NORTH = 'north' + SOUTH = 'south' + self.Directional = Directional + + from datetime import date + class Holiday(date, Enum): + NEW_YEAR = 2013, 1, 1 + IDES_OF_MARCH = 2013, 3, 15 + self.Holiday = Holiday + + if pyver >= 2.6: # cannot specify custom `dir` on previous versions + def test_dir_on_class(self): + Season = self.Season + self.assertEqual( + set(dir(Season)), + set(['__class__', '__doc__', '__members__', '__module__', + 'SPRING', 'SUMMER', 'AUTUMN', 'WINTER']), + ) + + def test_dir_on_item(self): + Season = self.Season + self.assertEqual( + set(dir(Season.WINTER)), + set(['__class__', '__doc__', '__module__', 'name', 'value']), + ) + + def test_dir_on_sub_with_behavior_on_super(self): + # see issue22506 + class SuperEnum(Enum): + def invisible(self): + return "did you see me?" 
+ class SubEnum(SuperEnum): + sample = 5 + self.assertEqual( + set(dir(SubEnum.sample)), + set(['__class__', '__doc__', '__module__', 'name', 'value', 'invisible']), + ) + + if pyver >= 2.7: # OrderedDict first available here + def test_members_is_ordereddict_if_ordered(self): + class Ordered(Enum): + __order__ = 'first second third' + first = 'bippity' + second = 'boppity' + third = 'boo' + self.assertTrue(type(Ordered.__members__) is OrderedDict) + + def test_members_is_ordereddict_if_not_ordered(self): + class Unordered(Enum): + this = 'that' + these = 'those' + self.assertTrue(type(Unordered.__members__) is OrderedDict) + + if pyver >= 3.0: # all objects are ordered in Python 2.x + def test_members_is_always_ordered(self): + class AlwaysOrdered(Enum): + first = 1 + second = 2 + third = 3 + self.assertTrue(type(AlwaysOrdered.__members__) is OrderedDict) + + def test_comparisons(self): + def bad_compare(): + Season.SPRING > 4 + Season = self.Season + self.assertNotEqual(Season.SPRING, 1) + self.assertRaises(TypeError, bad_compare) + + class Part(Enum): + SPRING = 1 + CLIP = 2 + BARREL = 3 + + self.assertNotEqual(Season.SPRING, Part.SPRING) + def bad_compare(): + Season.SPRING < Part.CLIP + self.assertRaises(TypeError, bad_compare) + + def test_enum_in_enum_out(self): + Season = self.Season + self.assertTrue(Season(Season.WINTER) is Season.WINTER) + + def test_enum_value(self): + Season = self.Season + self.assertEqual(Season.SPRING.value, 1) + + def test_intenum_value(self): + self.assertEqual(IntStooges.CURLY.value, 2) + + def test_enum(self): + Season = self.Season + lst = list(Season) + self.assertEqual(len(lst), len(Season)) + self.assertEqual(len(Season), 4, Season) + self.assertEqual( + [Season.SPRING, Season.SUMMER, Season.AUTUMN, Season.WINTER], lst) + + for i, season in enumerate('SPRING SUMMER AUTUMN WINTER'.split()): + i += 1 + e = Season(i) + self.assertEqual(e, getattr(Season, season)) + self.assertEqual(e.value, i) + self.assertNotEqual(e, i) + 
self.assertEqual(e.name, season) + self.assertTrue(e in Season) + self.assertTrue(type(e) is Season) + self.assertTrue(isinstance(e, Season)) + self.assertEqual(str(e), 'Season.' + season) + self.assertEqual( + repr(e), + '<Season.%s: %s>' % (season, i), + ) + + def test_value_name(self): + Season = self.Season + self.assertEqual(Season.SPRING.name, 'SPRING') + self.assertEqual(Season.SPRING.value, 1) + def set_name(obj, new_value): + obj.name = new_value + def set_value(obj, new_value): + obj.value = new_value + self.assertRaises(AttributeError, set_name, Season.SPRING, 'invierno', ) + self.assertRaises(AttributeError, set_value, Season.SPRING, 2) + + def test_attribute_deletion(self): + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = 3 + WINTER = 4 + + def spam(cls): + pass + + self.assertTrue(hasattr(Season, 'spam')) + del Season.spam + self.assertFalse(hasattr(Season, 'spam')) + + self.assertRaises(AttributeError, delattr, Season, 'SPRING') + self.assertRaises(AttributeError, delattr, Season, 'DRY') + self.assertRaises(AttributeError, delattr, Season.SPRING, 'name') + + def test_invalid_names(self): + def create_bad_class_1(): + class Wrong(Enum): + mro = 9 + def create_bad_class_2(): + class Wrong(Enum): + _reserved_ = 3 + self.assertRaises(ValueError, create_bad_class_1) + self.assertRaises(ValueError, create_bad_class_2) + + def test_contains(self): + Season = self.Season + self.assertTrue(Season.AUTUMN in Season) + self.assertTrue(3 not in Season) + + val = Season(3) + self.assertTrue(val in Season) + + class OtherEnum(Enum): + one = 1; two = 2 + self.assertTrue(OtherEnum.two not in Season) + + if pyver >= 2.6: # when `format` came into being + + def test_format_enum(self): + Season = self.Season + self.assertEqual('{0}'.format(Season.SPRING), + '{0}'.format(str(Season.SPRING))) + self.assertEqual( '{0:}'.format(Season.SPRING), + '{0:}'.format(str(Season.SPRING))) + self.assertEqual('{0:20}'.format(Season.SPRING), + 
'{0:20}'.format(str(Season.SPRING))) + self.assertEqual('{0:^20}'.format(Season.SPRING), + '{0:^20}'.format(str(Season.SPRING))) + self.assertEqual('{0:>20}'.format(Season.SPRING), + '{0:>20}'.format(str(Season.SPRING))) + self.assertEqual('{0:<20}'.format(Season.SPRING), + '{0:<20}'.format(str(Season.SPRING))) + + def test_format_enum_custom(self): + class TestFloat(float, Enum): + one = 1.0 + two = 2.0 + def __format__(self, spec): + return 'TestFloat success!' + self.assertEqual('{0}'.format(TestFloat.one), 'TestFloat success!') + + def assertFormatIsValue(self, spec, member): + self.assertEqual(spec.format(member), spec.format(member.value)) + + def test_format_enum_date(self): + Holiday = self.Holiday + self.assertFormatIsValue('{0}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:20}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:^20}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:>20}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:<20}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:%Y %m}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:%Y %m %M:00}', Holiday.IDES_OF_MARCH) + + def test_format_enum_float(self): + Konstants = self.Konstants + self.assertFormatIsValue('{0}', Konstants.TAU) + self.assertFormatIsValue('{0:}', Konstants.TAU) + self.assertFormatIsValue('{0:20}', Konstants.TAU) + self.assertFormatIsValue('{0:^20}', Konstants.TAU) + self.assertFormatIsValue('{0:>20}', Konstants.TAU) + self.assertFormatIsValue('{0:<20}', Konstants.TAU) + self.assertFormatIsValue('{0:n}', Konstants.TAU) + self.assertFormatIsValue('{0:5.2}', Konstants.TAU) + self.assertFormatIsValue('{0:f}', Konstants.TAU) + + def test_format_enum_int(self): + Grades = self.Grades + self.assertFormatIsValue('{0}', Grades.C) + self.assertFormatIsValue('{0:}', Grades.C) + self.assertFormatIsValue('{0:20}', Grades.C) + self.assertFormatIsValue('{0:^20}', Grades.C) + 
self.assertFormatIsValue('{0:>20}', Grades.C) + self.assertFormatIsValue('{0:<20}', Grades.C) + self.assertFormatIsValue('{0:+}', Grades.C) + self.assertFormatIsValue('{0:08X}', Grades.C) + self.assertFormatIsValue('{0:b}', Grades.C) + + def test_format_enum_str(self): + Directional = self.Directional + self.assertFormatIsValue('{0}', Directional.WEST) + self.assertFormatIsValue('{0:}', Directional.WEST) + self.assertFormatIsValue('{0:20}', Directional.WEST) + self.assertFormatIsValue('{0:^20}', Directional.WEST) + self.assertFormatIsValue('{0:>20}', Directional.WEST) + self.assertFormatIsValue('{0:<20}', Directional.WEST) + + def test_hash(self): + Season = self.Season + dates = {} + dates[Season.WINTER] = '1225' + dates[Season.SPRING] = '0315' + dates[Season.SUMMER] = '0704' + dates[Season.AUTUMN] = '1031' + self.assertEqual(dates[Season.AUTUMN], '1031') + + def test_enum_duplicates(self): + __order__ = "SPRING SUMMER AUTUMN WINTER" + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = FALL = 3 + WINTER = 4 + ANOTHER_SPRING = 1 + lst = list(Season) + self.assertEqual( + lst, + [Season.SPRING, Season.SUMMER, + Season.AUTUMN, Season.WINTER, + ]) + self.assertTrue(Season.FALL is Season.AUTUMN) + self.assertEqual(Season.FALL.value, 3) + self.assertEqual(Season.AUTUMN.value, 3) + self.assertTrue(Season(3) is Season.AUTUMN) + self.assertTrue(Season(1) is Season.SPRING) + self.assertEqual(Season.FALL.name, 'AUTUMN') + self.assertEqual( + set([k for k,v in Season.__members__.items() if v.name != k]), + set(['FALL', 'ANOTHER_SPRING']), + ) + + if pyver >= 3.0: + cls = vars() + result = {'Enum':Enum} + exec("""def test_duplicate_name(self): + with self.assertRaises(TypeError): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + red = 4 + + with self.assertRaises(TypeError): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + def red(self): + return 'red' + + with self.assertRaises(TypeError): + class Color(Enum): + @property + + def red(self): + return 
'redder' + red = 1 + green = 2 + blue = 3""", + result) + cls['test_duplicate_name'] = result['test_duplicate_name'] + + def test_enum_with_value_name(self): + class Huh(Enum): + name = 1 + value = 2 + self.assertEqual( + list(Huh), + [Huh.name, Huh.value], + ) + self.assertTrue(type(Huh.name) is Huh) + self.assertEqual(Huh.name.name, 'name') + self.assertEqual(Huh.name.value, 1) + + def test_intenum_from_scratch(self): + class phy(int, Enum): + pi = 3 + tau = 2 * pi + self.assertTrue(phy.pi < phy.tau) + + def test_intenum_inherited(self): + class IntEnum(int, Enum): + pass + class phy(IntEnum): + pi = 3 + tau = 2 * pi + self.assertTrue(phy.pi < phy.tau) + + def test_floatenum_from_scratch(self): + class phy(float, Enum): + pi = 3.1415926 + tau = 2 * pi + self.assertTrue(phy.pi < phy.tau) + + def test_floatenum_inherited(self): + class FloatEnum(float, Enum): + pass + class phy(FloatEnum): + pi = 3.1415926 + tau = 2 * pi + self.assertTrue(phy.pi < phy.tau) + + def test_strenum_from_scratch(self): + class phy(str, Enum): + pi = 'Pi' + tau = 'Tau' + self.assertTrue(phy.pi < phy.tau) + + def test_strenum_inherited(self): + class StrEnum(str, Enum): + pass + class phy(StrEnum): + pi = 'Pi' + tau = 'Tau' + self.assertTrue(phy.pi < phy.tau) + + def test_intenum(self): + class WeekDay(IntEnum): + SUNDAY = 1 + MONDAY = 2 + TUESDAY = 3 + WEDNESDAY = 4 + THURSDAY = 5 + FRIDAY = 6 + SATURDAY = 7 + + self.assertEqual(['a', 'b', 'c'][WeekDay.MONDAY], 'c') + self.assertEqual([i for i in range(WeekDay.TUESDAY)], [0, 1, 2]) + + lst = list(WeekDay) + self.assertEqual(len(lst), len(WeekDay)) + self.assertEqual(len(WeekDay), 7) + target = 'SUNDAY MONDAY TUESDAY WEDNESDAY THURSDAY FRIDAY SATURDAY' + target = target.split() + for i, weekday in enumerate(target): + i += 1 + e = WeekDay(i) + self.assertEqual(e, i) + self.assertEqual(int(e), i) + self.assertEqual(e.name, weekday) + self.assertTrue(e in WeekDay) + self.assertEqual(lst.index(e)+1, i) + self.assertTrue(0 < e < 8) + 
self.assertTrue(type(e) is WeekDay) + self.assertTrue(isinstance(e, int)) + self.assertTrue(isinstance(e, Enum)) + + def test_intenum_duplicates(self): + class WeekDay(IntEnum): + __order__ = 'SUNDAY MONDAY TUESDAY WEDNESDAY THURSDAY FRIDAY SATURDAY' + SUNDAY = 1 + MONDAY = 2 + TUESDAY = TEUSDAY = 3 + WEDNESDAY = 4 + THURSDAY = 5 + FRIDAY = 6 + SATURDAY = 7 + self.assertTrue(WeekDay.TEUSDAY is WeekDay.TUESDAY) + self.assertEqual(WeekDay(3).name, 'TUESDAY') + self.assertEqual([k for k,v in WeekDay.__members__.items() + if v.name != k], ['TEUSDAY', ]) + + def test_pickle_enum(self): + if isinstance(Stooges, Exception): + raise Stooges + test_pickle_dump_load(self.assertTrue, Stooges.CURLY) + test_pickle_dump_load(self.assertTrue, Stooges) + + def test_pickle_int(self): + if isinstance(IntStooges, Exception): + raise IntStooges + test_pickle_dump_load(self.assertTrue, IntStooges.CURLY) + test_pickle_dump_load(self.assertTrue, IntStooges) + + def test_pickle_float(self): + if isinstance(FloatStooges, Exception): + raise FloatStooges + test_pickle_dump_load(self.assertTrue, FloatStooges.CURLY) + test_pickle_dump_load(self.assertTrue, FloatStooges) + + def test_pickle_enum_function(self): + if isinstance(Answer, Exception): + raise Answer + test_pickle_dump_load(self.assertTrue, Answer.him) + test_pickle_dump_load(self.assertTrue, Answer) + + def test_pickle_enum_function_with_module(self): + if isinstance(Question, Exception): + raise Question + test_pickle_dump_load(self.assertTrue, Question.who) + test_pickle_dump_load(self.assertTrue, Question) + + if pyver >= 3.4: + def test_class_nested_enum_and_pickle_protocol_four(self): + # would normally just have this directly in the class namespace + class NestedEnum(Enum): + twigs = 'common' + shiny = 'rare' + + self.__class__.NestedEnum = NestedEnum + self.NestedEnum.__qualname__ = '%s.NestedEnum' % self.__class__.__name__ + test_pickle_exception( + self.assertRaises, PicklingError, self.NestedEnum.twigs, + protocol=(0, 3)) 
+ test_pickle_dump_load(self.assertTrue, self.NestedEnum.twigs, + protocol=(4, HIGHEST_PROTOCOL)) + + def test_exploding_pickle(self): + BadPickle = Enum('BadPickle', 'dill sweet bread-n-butter') + enum._make_class_unpicklable(BadPickle) + globals()['BadPickle'] = BadPickle + test_pickle_exception(self.assertRaises, TypeError, BadPickle.dill) + test_pickle_exception(self.assertRaises, PicklingError, BadPickle) + + def test_string_enum(self): + class SkillLevel(str, Enum): + master = 'what is the sound of one hand clapping?' + journeyman = 'why did the chicken cross the road?' + apprentice = 'knock, knock!' + self.assertEqual(SkillLevel.apprentice, 'knock, knock!') + + def test_getattr_getitem(self): + class Period(Enum): + morning = 1 + noon = 2 + evening = 3 + night = 4 + self.assertTrue(Period(2) is Period.noon) + self.assertTrue(getattr(Period, 'night') is Period.night) + self.assertTrue(Period['morning'] is Period.morning) + + def test_getattr_dunder(self): + Season = self.Season + self.assertTrue(getattr(Season, '__hash__')) + + def test_iteration_order(self): + class Season(Enum): + __order__ = 'SUMMER WINTER AUTUMN SPRING' + SUMMER = 2 + WINTER = 4 + AUTUMN = 3 + SPRING = 1 + self.assertEqual( + list(Season), + [Season.SUMMER, Season.WINTER, Season.AUTUMN, Season.SPRING], + ) + + def test_iteration_order_with_unorderable_values(self): + class Complex(Enum): + a = complex(7, 9) + b = complex(3.14, 2) + c = complex(1, -1) + d = complex(-77, 32) + self.assertEqual( + list(Complex), + [Complex.a, Complex.b, Complex.c, Complex.d], + ) + + def test_programatic_function_string(self): + SummerMonth = Enum('SummerMonth', 'june july august') + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split()): + i += 1 + e = SummerMonth(i) + 
self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_string_list(self): + SummerMonth = Enum('SummerMonth', ['june', 'july', 'august']) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_iterable(self): + SummerMonth = Enum( + 'SummerMonth', + (('june', 1), ('july', 2), ('august', 3)) + ) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_from_dict(self): + SummerMonth = Enum( + 'SummerMonth', + dict((('june', 1), ('july', 2), ('august', 3))) + ) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + if pyver < 3.0: + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + 
self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_type(self): + SummerMonth = Enum('SummerMonth', 'june july august', type=int) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_type_from_subclass(self): + SummerMonth = IntEnum('SummerMonth', 'june july august') + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_unicode(self): + SummerMonth = Enum('SummerMonth', unicode('june july august')) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_unicode_list(self): + SummerMonth = Enum('SummerMonth', [unicode('june'), unicode('july'), unicode('august')]) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + 
self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_unicode_iterable(self): + SummerMonth = Enum( + 'SummerMonth', + ((unicode('june'), 1), (unicode('july'), 2), (unicode('august'), 3)) + ) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_from_unicode_dict(self): + SummerMonth = Enum( + 'SummerMonth', + dict(((unicode('june'), 1), (unicode('july'), 2), (unicode('august'), 3))) + ) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + if pyver < 3.0: + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_unicode_type(self): + SummerMonth = Enum('SummerMonth', unicode('june july august'), type=int) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 
3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_unicode_type_from_subclass(self): + SummerMonth = IntEnum('SummerMonth', unicode('june july august')) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programmatic_function_unicode_class(self): + if pyver < 3.0: + class_names = unicode('SummerMonth'), 'S\xfcmm\xe9rM\xf6nth'.decode('latin1') + else: + class_names = 'SummerMonth', 'S\xfcmm\xe9rM\xf6nth' + for i, class_name in enumerate(class_names): + if pyver < 3.0 and i == 1: + self.assertRaises(TypeError, Enum, class_name, unicode('june july august')) + else: + SummerMonth = Enum(class_name, unicode('june july august')) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(e.value, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_subclassing(self): + if isinstance(Name, Exception): + raise Name + self.assertEqual(Name.BDFL, 'Guido van Rossum') + self.assertTrue(Name.BDFL, Name('Guido van 
Rossum')) + self.assertTrue(Name.BDFL is getattr(Name, 'BDFL')) + test_pickle_dump_load(self.assertTrue, Name.BDFL) + + def test_extending(self): + def bad_extension(): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + class MoreColor(Color): + cyan = 4 + magenta = 5 + yellow = 6 + self.assertRaises(TypeError, bad_extension) + + def test_exclude_methods(self): + class whatever(Enum): + this = 'that' + these = 'those' + def really(self): + return 'no, not %s' % self.value + self.assertFalse(type(whatever.really) is whatever) + self.assertEqual(whatever.this.really(), 'no, not that') + + def test_wrong_inheritance_order(self): + def wrong_inherit(): + class Wrong(Enum, str): + NotHere = 'error before this point' + self.assertRaises(TypeError, wrong_inherit) + + def test_intenum_transitivity(self): + class number(IntEnum): + one = 1 + two = 2 + three = 3 + class numero(IntEnum): + uno = 1 + dos = 2 + tres = 3 + self.assertEqual(number.one, numero.uno) + self.assertEqual(number.two, numero.dos) + self.assertEqual(number.three, numero.tres) + + def test_introspection(self): + class Number(IntEnum): + one = 100 + two = 200 + self.assertTrue(Number.one._member_type_ is int) + self.assertTrue(Number._member_type_ is int) + class String(str, Enum): + yarn = 'soft' + rope = 'rough' + wire = 'hard' + self.assertTrue(String.yarn._member_type_ is str) + self.assertTrue(String._member_type_ is str) + class Plain(Enum): + vanilla = 'white' + one = 1 + self.assertTrue(Plain.vanilla._member_type_ is object) + self.assertTrue(Plain._member_type_ is object) + + def test_wrong_enum_in_call(self): + class Monochrome(Enum): + black = 0 + white = 1 + class Gender(Enum): + male = 0 + female = 1 + self.assertRaises(ValueError, Monochrome, Gender.male) + + def test_wrong_enum_in_mixed_call(self): + class Monochrome(IntEnum): + black = 0 + white = 1 + class Gender(Enum): + male = 0 + female = 1 + self.assertRaises(ValueError, Monochrome, Gender.male) + + def 
test_mixed_enum_in_call_1(self): + class Monochrome(IntEnum): + black = 0 + white = 1 + class Gender(IntEnum): + male = 0 + female = 1 + self.assertTrue(Monochrome(Gender.female) is Monochrome.white) + + def test_mixed_enum_in_call_2(self): + class Monochrome(Enum): + black = 0 + white = 1 + class Gender(IntEnum): + male = 0 + female = 1 + self.assertTrue(Monochrome(Gender.male) is Monochrome.black) + + def test_flufl_enum(self): + class Fluflnum(Enum): + def __int__(self): + return int(self.value) + class MailManOptions(Fluflnum): + option1 = 1 + option2 = 2 + option3 = 3 + self.assertEqual(int(MailManOptions.option1), 1) + + def test_no_such_enum_member(self): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + self.assertRaises(ValueError, Color, 4) + self.assertRaises(KeyError, Color.__getitem__, 'chartreuse') + + def test_new_repr(self): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + def __repr__(self): + return "don't you just love shades of %s?" % self.name + self.assertEqual( + repr(Color.blue), + "don't you just love shades of blue?", + ) + + def test_inherited_repr(self): + class MyEnum(Enum): + def __repr__(self): + return "My name is %s." 
% self.name + class MyIntEnum(int, MyEnum): + this = 1 + that = 2 + theother = 3 + self.assertEqual(repr(MyIntEnum.that), "My name is that.") + + def test_multiple_mixin_mro(self): + class auto_enum(EnumMeta): + def __new__(metacls, cls, bases, classdict): + original_dict = classdict + classdict = enum._EnumDict() + for k, v in original_dict.items(): + classdict[k] = v + temp = type(classdict)() + names = set(classdict._member_names) + i = 0 + for k in classdict._member_names: + v = classdict[k] + if v == (): + v = i + else: + i = v + i += 1 + temp[k] = v + for k, v in classdict.items(): + if k not in names: + temp[k] = v + return super(auto_enum, metacls).__new__( + metacls, cls, bases, temp) + + AutoNumberedEnum = auto_enum('AutoNumberedEnum', (Enum,), {}) + + AutoIntEnum = auto_enum('AutoIntEnum', (IntEnum,), {}) + + class TestAutoNumber(AutoNumberedEnum): + a = () + b = 3 + c = () + + class TestAutoInt(AutoIntEnum): + a = () + b = 3 + c = () + + def test_subclasses_with_getnewargs(self): + class NamedInt(int): + __qualname__ = 'NamedInt' # needed for pickle protocol 4 + def __new__(cls, *args): + _args = args + if len(args) < 1: + raise TypeError("name and value must be specified") + name, args = args[0], args[1:] + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + def __getnewargs__(self): + return self._args + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "%s(%r, %s)" % (type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and 
isinstance(other, NamedInt): + return NamedInt( + '(%s + %s)' % (self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' # needed for pickle protocol 4 + x = ('the-x', 1) + y = ('the-y', 2) + + self.assertTrue(NEI.__new__ is Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + test_pickle_dump_load(self.assertTrue, NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertTrue, NEI.y) + + if pyver >= 3.4: + def test_subclasses_with_getnewargs_ex(self): + class NamedInt(int): + __qualname__ = 'NamedInt' # needed for pickle protocol 4 + def __new__(cls, *args): + _args = args + if len(args) < 2: + raise TypeError("name and value must be specified") + name, args = args[0], args[1:] + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + def __getnewargs_ex__(self): + return self._args, {} + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "{}({!r}, {})".format(type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '({0} + {1})'.format(self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' # needed for pickle protocol 4 + x = ('the-x', 1) + y = ('the-y', 2) + + + self.assertIs(NEI.__new__, Enum.__new__) + 
self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + test_pickle_dump_load(self.assertEqual, NI5, 5, protocol=(4, HIGHEST_PROTOCOL)) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertTrue, NEI.y, protocol=(4, HIGHEST_PROTOCOL)) + + def test_subclasses_with_reduce(self): + class NamedInt(int): + __qualname__ = 'NamedInt' # needed for pickle protocol 4 + def __new__(cls, *args): + _args = args + if len(args) < 1: + raise TypeError("name and value must be specified") + name, args = args[0], args[1:] + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + def __reduce__(self): + return self.__class__, self._args + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "%s(%r, %s)" % (type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '(%s + %s)' % (self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' # needed for pickle protocol 4 + x = ('the-x', 1) + y = ('the-y', 2) + + + self.assertTrue(NEI.__new__ is Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + test_pickle_dump_load(self.assertEqual, NI5, 5) + self.assertEqual(NEI.y.value, 2) + 
test_pickle_dump_load(self.assertTrue, NEI.y) + + def test_subclasses_with_reduce_ex(self): + class NamedInt(int): + __qualname__ = 'NamedInt' # needed for pickle protocol 4 + def __new__(cls, *args): + _args = args + if len(args) < 1: + raise TypeError("name and value must be specified") + name, args = args[0], args[1:] + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + def __reduce_ex__(self, proto): + return self.__class__, self._args + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "%s(%r, %s)" % (type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '(%s + %s)' % (self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' # needed for pickle protocol 4 + x = ('the-x', 1) + y = ('the-y', 2) + + + self.assertTrue(NEI.__new__ is Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + test_pickle_dump_load(self.assertEqual, NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertTrue, NEI.y) + + def test_subclasses_without_direct_pickle_support(self): + class NamedInt(int): + __qualname__ = 'NamedInt' + def __new__(cls, *args): + _args = args + name, args = args[0], args[1:] + if len(args) == 0: + raise TypeError("name and value must be specified") + self = 
int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "%s(%r, %s)" % (type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '(%s + %s)' % (self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' + x = ('the-x', 1) + y = ('the-y', 2) + + self.assertTrue(NEI.__new__ is Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_exception(self.assertRaises, TypeError, NEI.x) + test_pickle_exception(self.assertRaises, PicklingError, NEI) + + def test_subclasses_without_direct_pickle_support_using_name(self): + class NamedInt(int): + __qualname__ = 'NamedInt' + def __new__(cls, *args): + _args = args + name, args = args[0], args[1:] + if len(args) == 0: + raise TypeError("name and value must be specified") + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "%s(%r, %s)" % (type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + 
base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '(%s + %s)' % (self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' + x = ('the-x', 1) + y = ('the-y', 2) + def __reduce_ex__(self, proto): + return getattr, (self.__class__, self._name_) + + self.assertTrue(NEI.__new__ is Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertTrue, NEI.y) + test_pickle_dump_load(self.assertTrue, NEI) + + def test_tuple_subclass(self): + class SomeTuple(tuple, Enum): + __qualname__ = 'SomeTuple' + first = (1, 'for the money') + second = (2, 'for the show') + third = (3, 'for the music') + self.assertTrue(type(SomeTuple.first) is SomeTuple) + self.assertTrue(isinstance(SomeTuple.second, tuple)) + self.assertEqual(SomeTuple.third, (3, 'for the music')) + globals()['SomeTuple'] = SomeTuple + test_pickle_dump_load(self.assertTrue, SomeTuple.first) + + def test_duplicate_values_give_unique_enum_items(self): + class AutoNumber(Enum): + __order__ = 'enum_m enum_d enum_y' + enum_m = () + enum_d = () + enum_y = () + def __new__(cls): + value = len(cls.__members__) + 1 + obj = object.__new__(cls) + obj._value_ = value + return obj + def __int__(self): + return int(self._value_) + self.assertEqual(int(AutoNumber.enum_d), 2) + self.assertEqual(AutoNumber.enum_y.value, 3) + self.assertTrue(AutoNumber(1) is AutoNumber.enum_m) + self.assertEqual( + list(AutoNumber), + [AutoNumber.enum_m, AutoNumber.enum_d, AutoNumber.enum_y], + ) + + 
def test_inherited_new_from_enhanced_enum(self): + class AutoNumber2(Enum): + def __new__(cls): + value = len(cls.__members__) + 1 + obj = object.__new__(cls) + obj._value_ = value + return obj + def __int__(self): + return int(self._value_) + class Color(AutoNumber2): + __order__ = 'red green blue' + red = () + green = () + blue = () + self.assertEqual(len(Color), 3, "wrong number of elements: %d (should be %d)" % (len(Color), 3)) + self.assertEqual(list(Color), [Color.red, Color.green, Color.blue]) + if pyver >= 3.0: + self.assertEqual(list(map(int, Color)), [1, 2, 3]) + + def test_inherited_new_from_mixed_enum(self): + class AutoNumber3(IntEnum): + def __new__(cls): + value = len(cls.__members__) + 1 + obj = int.__new__(cls, value) + obj._value_ = value + return obj + class Color(AutoNumber3): + red = () + green = () + blue = () + self.assertEqual(len(Color), 3, "wrong number of elements: %d (should be %d)" % (len(Color), 3)) + Color.red + Color.green + Color.blue + + def test_ordered_mixin(self): + class OrderedEnum(Enum): + def __ge__(self, other): + if self.__class__ is other.__class__: + return self._value_ >= other._value_ + return NotImplemented + def __gt__(self, other): + if self.__class__ is other.__class__: + return self._value_ > other._value_ + return NotImplemented + def __le__(self, other): + if self.__class__ is other.__class__: + return self._value_ <= other._value_ + return NotImplemented + def __lt__(self, other): + if self.__class__ is other.__class__: + return self._value_ < other._value_ + return NotImplemented + class Grade(OrderedEnum): + __order__ = 'A B C D F' + A = 5 + B = 4 + C = 3 + D = 2 + F = 1 + self.assertEqual(list(Grade), [Grade.A, Grade.B, Grade.C, Grade.D, Grade.F]) + self.assertTrue(Grade.A > Grade.B) + self.assertTrue(Grade.F <= Grade.C) + self.assertTrue(Grade.D < Grade.A) + self.assertTrue(Grade.B >= Grade.B) + + def test_extending2(self): + def bad_extension(): + class Shade(Enum): + def shade(self): + print(self.name) + 
class Color(Shade): + red = 1 + green = 2 + blue = 3 + class MoreColor(Color): + cyan = 4 + magenta = 5 + yellow = 6 + self.assertRaises(TypeError, bad_extension) + + def test_extending3(self): + class Shade(Enum): + def shade(self): + return self.name + class Color(Shade): + def hex(self): + return '%s hexlified!' % self.value + class MoreColor(Color): + cyan = 4 + magenta = 5 + yellow = 6 + self.assertEqual(MoreColor.magenta.hex(), '5 hexlified!') + + def test_no_duplicates(self): + def bad_duplicates(): + class UniqueEnum(Enum): + def __init__(self, *args): + cls = self.__class__ + if any(self.value == e.value for e in cls): + a = self.name + e = cls(self.value).name + raise ValueError( + "aliases not allowed in UniqueEnum: %r --> %r" + % (a, e) + ) + class Color(UniqueEnum): + red = 1 + green = 2 + blue = 3 + class Color(UniqueEnum): + red = 1 + green = 2 + blue = 3 + grene = 2 + self.assertRaises(ValueError, bad_duplicates) + + def test_reversed(self): + self.assertEqual( + list(reversed(self.Season)), + [self.Season.WINTER, self.Season.AUTUMN, self.Season.SUMMER, + self.Season.SPRING] + ) + + def test_init(self): + class Planet(Enum): + MERCURY = (3.303e+23, 2.4397e6) + VENUS = (4.869e+24, 6.0518e6) + EARTH = (5.976e+24, 6.37814e6) + MARS = (6.421e+23, 3.3972e6) + JUPITER = (1.9e+27, 7.1492e7) + SATURN = (5.688e+26, 6.0268e7) + URANUS = (8.686e+25, 2.5559e7) + NEPTUNE = (1.024e+26, 2.4746e7) + def __init__(self, mass, radius): + self.mass = mass # in kilograms + self.radius = radius # in meters + @property + def surface_gravity(self): + # universal gravitational constant (m3 kg-1 s-2) + G = 6.67300E-11 + return G * self.mass / (self.radius * self.radius) + self.assertEqual(round(Planet.EARTH.surface_gravity, 2), 9.80) + self.assertEqual(Planet.EARTH.value, (5.976e+24, 6.37814e6)) + + def test_nonhash_value(self): + class AutoNumberInAList(Enum): + def __new__(cls): + value = [len(cls.__members__) + 1] + obj = object.__new__(cls) + obj._value_ = value + return 
obj + class ColorInAList(AutoNumberInAList): + __order__ = 'red green blue' + red = () + green = () + blue = () + self.assertEqual(list(ColorInAList), [ColorInAList.red, ColorInAList.green, ColorInAList.blue]) + self.assertEqual(ColorInAList.red.value, [1]) + self.assertEqual(ColorInAList([1]), ColorInAList.red) + + def test_conflicting_types_resolved_in_new(self): + class LabelledIntEnum(int, Enum): + def __new__(cls, *args): + value, label = args + obj = int.__new__(cls, value) + obj.label = label + obj._value_ = value + return obj + + class LabelledList(LabelledIntEnum): + unprocessed = (1, "Unprocessed") + payment_complete = (2, "Payment Complete") + + self.assertEqual(list(LabelledList), [LabelledList.unprocessed, LabelledList.payment_complete]) + self.assertEqual(LabelledList.unprocessed, 1) + self.assertEqual(LabelledList(1), LabelledList.unprocessed) + +class TestUnique(unittest.TestCase): + """2.4 doesn't allow class decorators, use function syntax.""" + + def test_unique_clean(self): + class Clean(Enum): + one = 1 + two = 'dos' + tres = 4.0 + unique(Clean) + class Cleaner(IntEnum): + single = 1 + double = 2 + triple = 3 + unique(Cleaner) + + def test_unique_dirty(self): + try: + class Dirty(Enum): + __order__ = 'one two tres' + one = 1 + two = 'dos' + tres = 1 + unique(Dirty) + except ValueError: + exc = sys.exc_info()[1] + message = exc.args[0] + self.assertTrue('tres -> one' in message) + + try: + class Dirtier(IntEnum): + __order__ = 'single double triple turkey' + single = 1 + double = 1 + triple = 3 + turkey = 3 + unique(Dirtier) + except ValueError: + exc = sys.exc_info()[1] + message = exc.args[0] + self.assertTrue('double -> single' in message) + self.assertTrue('turkey -> triple' in message) + + +class TestMe(unittest.TestCase): + + pass + +if __name__ == '__main__': + unittest.main() diff --git a/anknotes/enums.py b/anknotes/enums.py new file mode 100644 index 0000000..79b8247 --- /dev/null +++ b/anknotes/enums.py @@ -0,0 +1,111 @@ +from 
anknotes.enum import Enum, EnumMeta, IntEnum +from anknotes import enum + +class AutoNumber(Enum): + def __new__(cls, *args): + """ + + :param cls: + :return: + :rtype : AutoNumber + """ + value = len(cls.__members__) + 1 + if args and args[0]: + value = args[0] + while value in cls._value2member_map_: value += 1 + obj = object.__new__(cls) + obj._id_ = value + obj._value_ = value + # if obj.name in obj._member_names_: + # raise KeyError + return obj + +class OrderedEnum(Enum): + def __ge__(self, other): + if self.__class__ is other.__class__: + return self._value_ >= other._value_ + return NotImplemented + + def __gt__(self, other): + if self.__class__ is other.__class__: + return self._value_ > other._value_ + return NotImplemented + + def __le__(self, other): + if self.__class__ is other.__class__: + return self._value_ <= other._value_ + return NotImplemented + + def __lt__(self, other): + if self.__class__ is other.__class__: + return self._value_ < other._value_ + return NotImplemented + + +class auto_enum(EnumMeta): + def __new__(metacls, cls, bases, classdict): + original_dict = classdict + classdict = enum._EnumDict() + for k, v in original_dict.items(): + classdict[k] = v + temp = type(classdict)() + names = set(classdict._member_names) + i = 0 + + for k in classdict._member_names: + v = classdict[k] + if v == (): + v = i + else: + i = max(v, i) + i += 1 + temp[k] = v + for k, v in classdict.items(): + if k not in names: + temp[k] = v + return super(auto_enum, metacls).__new__( + metacls, cls, bases, temp) + + def __ge__(self, other): + if self.__class__ is other.__class__: + return self._value_ >= other._value_ + return NotImplemented + + def __gt__(self, other): + if self.__class__ is other.__class__: + return self._value_ > other._value_ + return NotImplemented + + def __le__(self, other): + if self.__class__ is other.__class__: + return self._value_ <= other._value_ + return NotImplemented + + def __lt__(self, other): + if self.__class__ is 
other.__class__: + return self._value_ < other._value_ + return NotImplemented + + +AutoNumberedEnum = auto_enum('AutoNumberedEnum', (OrderedEnum,), {}) + +AutoIntEnum = auto_enum('AutoIntEnum', (IntEnum,), {}) + + +# +# +# class APIStatus(AutoIntEnum): +# Val1=() +# """:type : AutoIntEnum""" +# Val2=() +# """:type : AutoIntEnum""" +# Val3=() +# """:type : AutoIntEnum""" +# Val4=() +# """:type : AutoIntEnum""" +# Val5=() +# """:type : AutoIntEnum""" +# Val6=() +# """:type : AutoIntEnum""" +# +# Val1, Val2, Val3, Val4, Val5, Val6, Val7 = range(1, 8) diff --git a/anknotes/error.py b/anknotes/error.py new file mode 100644 index 0000000..b01b2e5 --- /dev/null +++ b/anknotes/error.py @@ -0,0 +1,97 @@ +import errno +from anknotes.evernote.edam.error.ttypes import EDAMErrorCode +from anknotes.base import str_safe +from anknotes.logging import log_error, log, showInfo, show_tooltip, log_dump +from anknotes.constants import * + +latestSocketError = {'code': 0, 'friendly_error_msg': '', 'constant': ''} + + +def HandleSocketError(e, strErrorBase): + global latestSocketError + errorcode = e[0] + friendly_error_msgs = { + errno.ECONNREFUSED: "Connection was refused", + errno.WSAECONNRESET: "Connection was reset or forcibly closed by the remote host", + errno.ETIMEDOUT: "Connection timed out" + } + if errorcode not in errno.errorcode: + log_error("Unknown socket error (%s) occurred: %s" % (str(errorcode), str(e))) + return False + error_constant = errno.errorcode[errorcode] + if errorcode in friendly_error_msgs: + strError = friendly_error_msgs[errorcode] + else: + strError = "Unhandled socket error (%s) occurred" % error_constant + latestSocketError = {'code': errorcode, 'friendly_error_msg': strError, 'constant': error_constant} + strError = "Error: %s while %s\r\n" % (strError, strErrorBase) + log_error(" SocketError.%s: " % error_constant + strError) + log_error(str(e)) + log(" SocketError.%s: " % error_constant + strError, 'api') + if 
EVERNOTE.API.EDAM_RATE_LIMIT_ERROR_HANDLING is EVERNOTE.API.RateLimitErrorHandling.AlertError: + showInfo(strError) + elif EVERNOTE.API.EDAM_RATE_LIMIT_ERROR_HANDLING is EVERNOTE.API.RateLimitErrorHandling.ToolTipError: + show_tooltip(strError) + return True + + +latestEDAMRateLimit = 0 + + +def HandleEDAMRateLimitError(e, strError): + global latestEDAMRateLimit + if not e.errorCode is EDAMErrorCode.RATE_LIMIT_REACHED: + return False + latestEDAMRateLimit = e.rateLimitDuration + m, s = divmod(e.rateLimitDuration, 60) + strError = "Error: Rate limit has been reached while %s\r\n" % strError + strError += "Please retry your request in {} min".format("%d:%02d" % (m, s)) + log_strError = " EDAMErrorCode.RATE_LIMIT_REACHED: " + strError.replace('\r\n', '\n') + log_error(log_strError) + log(log_strError, 'api') + if EVERNOTE.API.EDAM_RATE_LIMIT_ERROR_HANDLING is EVERNOTE.API.RateLimitErrorHandling.AlertError: + showInfo(strError) + elif EVERNOTE.API.EDAM_RATE_LIMIT_ERROR_HANDLING is EVERNOTE.API.RateLimitErrorHandling.ToolTipError: + show_tooltip(strError) + return True + + +lastUnicodeError = None + + +def HandleUnicodeError(log_header, e, guid, title, action='', attempt=1, content=None, field=None, attempt_max=3, + attempt_min=1): + global lastUnicodeError + object = "" + e_type = e.__class__.__name__ + is_unicode = e_type.find("Unicode") > -1 + if is_unicode: + content_type = e.object.__class__.__name__ + object = e.object[e.start - 20:e.start + 20] + elif not content: + content = "Not Provided" + content_type = "N/A" + else: + content_type = content.__class__.__name__ + log_header += ': ' + e_type + ': {field}' + content_type + (' <%s>' % action if action else '') + save_header = log_header.replace('{field}', '') + ': ' + title + log_header = log_header.format(field='%s: ' % field if field else '') + + new_error = lastUnicodeError != save_header + + if is_unicode: + return_val = 1 if attempt < attempt_max else -1 + if new_error: + log(save_header + '\n' + '-' * 
ANKNOTES.FORMATTING.LINE_LENGTH, 'unicode', replace_newline=False) + lastUnicodeError = save_header + log(ANKNOTES.FORMATTING.TIMESTAMP_PAD + '\t - ' + ( + ('Field %s' % field if field else 'Unknown Field') + ': ').ljust(20) + str_safe(object), 'unicode', + timestamp=False) + else: + return_val = 0 + if attempt is 1 and content: + log_dump(content, log_header, 'NonUnicodeErrors') + if (new_error and attempt >= attempt_min) or not is_unicode: + log_error(log_header + "\n - Error: %s\n - GUID: %s\n - Title: %s%s" % ( + str(e), guid, str_safe(title), '' if not object else "\n - Object: %s" % str_safe(object))) + return return_val diff --git a/anknotes/evernote/edam/notestore/ttypes.py b/anknotes/evernote/edam/notestore/ttypes.py index 9ef9ab7..c94cf54 100644 --- a/anknotes/evernote/edam/notestore/ttypes.py +++ b/anknotes/evernote/edam/notestore/ttypes.py @@ -1605,6 +1605,7 @@ def __init__(self, startIndex=None, totalNotes=None, notes=None, stoppedWords=No self.startIndex = startIndex self.totalNotes = totalNotes self.notes = notes + """:type : list[NoteMetadata]""" self.stoppedWords = stoppedWords self.searchedWords = searchedWords self.updateCount = updateCount diff --git a/anknotes/extra/ancillary/FrontTemplate.htm b/anknotes/extra/ancillary/FrontTemplate.htm new file mode 100644 index 0000000..22b2b30 --- /dev/null +++ b/anknotes/extra/ancillary/FrontTemplate.htm @@ -0,0 +1,147 @@ +<div id='Template-{{Type}}'> + <div id='Card-{{Card}}'> + <div id='Deck-{{Deck}}'> + <div id='Side-Front'> + <section class="header header-avi Field-%(Title)s-Prompt"> + <h2>ANKNOTES</h2> + + <div id='Field-%(Title)s-Prompt-New'>What is the Note's Title???</div> + </section> + <section class="header header-avi Field-%(Title)s" id='Header-Field-%(Title)s'> + <h2>ANKNOTES</h2> + + <div id='Field-%(Title)s-New'>{{%(Title)s}}</div> + </section> + + <hr id=answer> + {{#%(See Also)s}} + <div id='Header-Links'> +<span class='Field-%(See Also)s'> +<a href='javascript:;' 
onclick='scrollToElementVisible("Field-%(See Also)s")' class='header'>See Also</a>: +</span> + {{#%(TOC)s}} +<span class='Field-%(TOC)s'> +<a href='javascript:;' onclick='scrollToElementToggle("Field-%(TOC)s")' class='header'>[TOC]</a> +</span> + {{/%(TOC)s}} + + {{#%(Outline)s}} + {{#%(TOC)s}} + <span class='Field-%(See Also)s'> | </span> + {{/%(TOC)s}} +<span class='Field-%(Outline)s'> +<a href='javascript:;' onclick='scrollToElementToggle("Field-%(Outline)s")' class='header'>(Outline)</a> +</span> + {{/%(Outline)s}} + </div> + {{/%(See Also)s}} + + <div id='Field-%(Content)s'>{{%(Content)s}}</div> + <div id='Field-Cloze-%(Content)s'>{{cloze:%(Content)s}}</div> + + {{#%(Extra)s}} + <div id='Field-%(Extra)s-Front'> + <HR> + <span class='header'><u>Note</u>: Additional Information is Available</span></span> + </div> + <div id='Field-%(Extra)s'> + <HR> + <span class='header'><u>Additional Info</u>: </span></span> + {{%(Extra)s}} + <BR><BR> + </div> + {{/%(Extra)s}} + + <div id='Footer-Line'> + <span id='Link-EN-Self'></span> + {{#Tags}} + <span id='Tags'><span class='header'><u>Tags</u>: </span>{{Tags}}</span> + {{/Tags}} + </div> + + {{#%(See Also)s}} + <div id='Field-%(See Also)s'> + <HR> + {{%(See Also)s}} + </div> + {{/%(See Also)s}} + + {{#%(TOC)s}} + <div id='Field-%(TOC)s'> + <BR> + <HR> + {{%(TOC)s}} + </div> + {{/%(TOC)s}} + + {{#%(Outline)s}} + <div id='Field-%(Outline)s'> + <BR> + <HR> + {{%(Outline)s}} + </div> + {{/%(Outline)s}} + + + </div> + </div> + </div> +</div> + + +<script> + evernote_guid_prefix = '%(Evernote GUID Prefix)s' + evernote_uid = '%(Evernote UID)s' + evernote_shard = '%(Evernote shard)s' + function generateEvernoteLink(guid_field) { + guid = guid_field.replace(evernote_guid_prefix, '') + en_link = 'evernote://view/'+evernote_uid+'/'+evernote_shard+'/'+guid+'/'+guid+'/' + return en_link + } + function setElementDisplay(id,show) { + el = document.getElementById(id) + if (el == null) { return; } + // Assuming if display is not set, it is 
set to none by CSS + if (show === 0) { show = (el.style.display == 'none' || el.style.display == ''); } + el.style.display = (show ? 'block' : 'none') + } + function hideElement(id) { + setElementDisplay(id, false); + } + function showElement(id) { + setElementDisplay(id, true); + } + function toggleElement(id) { + setElementDisplay(id, 0); + } + + function scrollToElement(id, show) { + setElementDisplay(id, show); + el = document.getElementById(id) + if (el == null) { return; } + window.scroll(0,findPos(el)); + } + function scrollToElementToggle(id) { + scrollToElement(id, 0); + } + function scrollToElementVisible(id) { + scrollToElement(id, true); + } +//Finds y value of given object +function findPos(obj) { + var curtop = 0; + if (obj.offsetParent) { + do { + curtop += obj.offsetTop; + } while (obj = obj.offsetParent); + return [curtop]; + } +} + +document.getElementById('Link-EN-Self').innerHTML = "<a href='" + generateEvernoteLink('{{%(Evernote GUID)s}}') + "'>Open in EN</a> <span class='separator'> | </span>" +document.getElementById('Field-%(Title)s-New').innerHTML = "<span class='link'>" + document.getElementById('Field-%(Title)s-New').innerHTML + "</span>" +document.getElementById('Header-Field-%(Title)s').outerHTML = "<a href='" + generateEvernoteLink('{{%(Evernote GUID)s}}') + "'>" + document.getElementById('Header-Field-%(Title)s').outerHTML + "</a>" + + + +</script> diff --git a/anknotes/extra/ancillary/QMessageBox.css b/anknotes/extra/ancillary/QMessageBox.css new file mode 100644 index 0000000..2a37c1f --- /dev/null +++ b/anknotes/extra/ancillary/QMessageBox.css @@ -0,0 +1,15 @@ +table tr.tr0, table tr.tr0 td.td1, table tr.tr0 td.td2, table tr.tr0 td.td3 { background: rgb(78, 124, 39); height: 42px; font-size: 32px; cell-spacing: 0px; padding-top: 1px; padding-bottom: 1px; } +table { border: 1px solid black; border-bottom: 10px solid black; } +table tr td { border: 1px solid black; } +tr.std { background: rgb(105, 170, 53); cell-spacing: 0px; } 
+tr.alt { background: rgb(135, 187, 93); cell-spacing: 0px; } +tr.tr0 { background: rgb(78, 124, 39); height: 42px; font-size: 32px; cell-spacing: 0px; } +tr.tr0 td.td2 { color: rgb(173, 0, 0); } +tr.tr0 td.td3 { color: #444; } +a { color: rgb(105, 170, 53); font-weight:bold; } +a:hover { color: rgb(135, 187, 93); font-weight:bold; text-decoration: none; } +a:active { color: rgb(135, 187, 93); font-weight:bold; text-decoration: none; } +table a { color: rgb(106, 0, 129);} +td.td1 { font-weight: bold; color: #bf0060; text-align: center; padding-left: 10px; padding-right: 10px; } +td.td2 { text-transform: uppercase; font-weight: bold; color: #0060bf; padding-left:20px; padding-right:20px; font-size: 18px; } +td.td3 { color: #666; font-size: 10px; } diff --git a/anknotes/extra/ancillary/_AviAnkiCSS.css b/anknotes/extra/ancillary/_AviAnkiCSS.css new file mode 100644 index 0000000..84f68b7 --- /dev/null +++ b/anknotes/extra/ancillary/_AviAnkiCSS.css @@ -0,0 +1,404 @@ +/* @import url("_AviAnkiCSS.css"); */ +@import url(https://fonts.googleapis.com/css?family=Roboto:400,300,700,100); + +/******************************************************************************************************* + Default Card Rules +*******************************************************************************************************/ + +body { + font-family: Roboto, sans-serif; + background: rgb(44, 61, 81); + padding: 1em; + -webkit-font-smoothing: antialiased; +} + +.card { + font-size: 20px; + text-align: left; + background-color: white; + color: black; +} + +/******************************************************************************************************* + Header rectangles, which sit at the top of every single card +*******************************************************************************************************/ + +/* Positioning: Vertical Align: Middle */ + +section.header { + display: block; + width: 100%; + height: 200px; + position: relative; + text-align: center; + 
-webkit-transform-style: preserve-3d; + -moz-transform-style: preserve-3d; + transform-style: preserve-3d; +} + +section.header div { + padding: 1em; + margin: 0px; + position: relative; + top: 50%; + -webkit-transform: translateY(-50%); + -ms-transform: translateY(-50%); + transform: translateY(-50%); +} + +/* Other Styling */ + +section.header { + background: rgb(231, 76, 60); + margin: 0px; + border-radius: .2em; + /* border:3px #990000 solid; */ +} + +/******************************************************************************************************* + ANKNOTES Top-Right call out +*******************************************************************************************************/ + +section.header h2 { + text-transform: uppercase; + margin: 0; + font-size: 16px; + position: absolute; + top: 5px; + right: 5px; + font-weight: bold; + /* color: rgb(236, 240, 241); */ + color: rgb(44, 61, 81); + /* border:3px #990000 solid; */ +} + +section.header.header-avi h2, section.header.header-bluewhitered h2 { +color: rgb(138, 16, 16); +} + +/******************************************************************************************************* + TITLE Fields +*******************************************************************************************************/ + +.card section.header.Field-Title, .card section.header.Field-Title-Prompt { + text-align: center; + font-family: Tahoma; + font-weight: bold; + font-size: 72px; + font-variant: small-caps; + padding:10px; + /* color: #A40F2D; */ +} + +.card section.header a { + text-decoration: none; +} + +.card section.header:hover #Field-Title-New span.link, .card section.header:hover #Field-Title-Prompt-New span.link { + border-bottom: none; +} + +.card section.header.Field-Title-Prompt { + color: #a90030; +} + +.card section.header #Field-Title-Prompt-New span.link { + border-bottom-color: #a90030; +} + +.card a:hover #Field-Title-Prompt-New { + color: rgb(210, 13, 13); + } + 
+/******************************************************************************************************* + Header bars with custom gradient backgrounds +*******************************************************************************************************/ + +section.header +{ + text-shadow: 1px 1px 2px rgba(0,0,0,0.2); + text-shadow: black 0 1px; +} + +section.header.header-redorange { + color:#990000; +} + +a:hover section.header.header-redorange { + color: rgb(106, 6, 6); +} + +.card section.header.header-redorange.Field-Title span.link { + border-bottom-color:#990000; +} + +section.header.header-redorange { + /* Background gradient code */ + background: -moz-linear-gradient(left, #ff1a00 0%, #fff200 36%, #fff200 58%, #ff1a00 100%); /* FF3.6+ */ + background: -webkit-gradient(linear, left top, right top, color-stop(0%,#ff1a00), color-stop(36%,#fff200), color-stop(58%,#fff200), color-stop(100%,#ff1a00)); /* Chrome,Safari4+ */ + background: -webkit-linear-gradient(left, #ff1a00 0%,#fff200 36%,#fff200 58%,#ff1a00 100%); /* Chrome10+,Safari5.1+ */ + background: -o-linear-gradient(left, #ff1a00 0%,#fff200 36%,#fff200 58%,#ff1a00 100%); /* Opera 11.10+ */ + background: -ms-linear-gradient(left, #ff1a00 0%,#fff200 36%,#fff200 58%,#ff1a00 100%); /* IE10+ */ + background: linear-gradient(to right, #ff1a00 0%,#fff200 36%,#fff200 58%,#ff1a00 100%); /* W3C */ + filter: progid:DXImageTransform.Microsoft.gradient( startColorstr='#ff1a00', endColorstr='#ff1a00',GradientType=1 ); /* IE6-9 */ + /* z-index: 100; /* the stack order: foreground */ + /* border-bottom:1px #990000 solid; */ + border:3px #990000 solid; +} + +section.header.header-bluewhitered, +section.header.header-avi { + color:#004C99; +} + +a:hover section.header.header-bluewhitered, +a:hover section.header.header-avi { +color: rgb(10, 121, 243); +} + +.card section.header.header-bluewhitered.Field-Title span.link, +.card section.header.header-avi.Field-Title span.link { + border-bottom-color:#004C99; +} + 
+section.header.header-bluewhitered, +section.header.header-avi { + /* Background gradient code */ + background: #3b679e; /* Old browsers */ + background: -moz-linear-gradient(left, #3b679e 0%, #ffffff 38%, #ffffff 59%, #ff1111 100%); /* FF3.6+ */ + background: -webkit-gradient(linear, left top, right top, color-stop(0%,#3b679e), color-stop(38%,#ffffff), color-stop(59%,#ffffff), color-stop(100%,#ff1111)); /* Chrome,Safari4+ */ + background: -webkit-linear-gradient(left, #3b679e 0%,#ffffff 38%,#ffffff 59%,#ff1111 100%); /* Chrome10+,Safari5.1+ */ + background: -o-linear-gradient(left, #3b679e 0%,#ffffff 38%,#ffffff 59%,#ff1111 100%); /* Opera 11.10+ */ + background: -ms-linear-gradient(left, #3b679e 0%,#ffffff 38%,#ffffff 59%,#ff1111 100%); /* IE10+ */ + background: linear-gradient(to right, #3b679e 0%,#ffffff 38%,#ffffff 59%,#ff1111 100%); /* W3C */ + filter: progid:DXImageTransform.Microsoft.gradient( startColorstr='#3b679e', endColorstr='#ff1111',GradientType=1 ); /* IE6-9 */ + /* z-index: 100; /* the stack order: foreground */ + /* border-bottom:1px #004C99 solid; */ + border:3px #990000 solid; +} + +section.header.header-avi-bluered { +color:#80A6CC; +} + +a:hover section.header.header-avi-bluered { +color: rgb(241, 135, 154); +} + + +.card section.header.header-avi-bluered.Field-Title span.link { + border-bottom-color:#80A6CC; +} + +section.header.header-avi-bluered { + /* Background gradient code */ + background: -moz-linear-gradient(left, #bf0060 0%, #0060bf 36%, #0060bf 58%, #bf0060 100%); /* FF3.6+ */ + background: -webkit-gradient(linear, left top, right top, color-stop(0%,#bf0060), color-stop(36%,#0060bf), color-stop(58%,#0060bf), color-stop(100%,#bf0060)); /* Chrome,Safari4+ */ + background: -webkit-linear-gradient(left, #bf0060 0%,#0060bf 36%,#0060bf 58%,#bf0060 100%); /* Chrome10+,Safari5.1+ */ + background: -o-linear-gradient(left, #bf0060 0%,#0060bf 36%,#0060bf 58%,#bf0060 100%); /* Opera 11.10+ */ + background: -ms-linear-gradient(left, #bf0060 
0%,#0060bf 36%,#0060bf 58%,#bf0060 100%); /* IE10+ */ + background: linear-gradient(to right, #bf0060 0%,#0060bf 36%,#0060bf 58%,#bf0060 100%); /* W3C */ + filter: progid:DXImageTransform.Microsoft.gradient( startColorstr='#0060bf', endColorstr='#bf0060',GradientType=1 ); /* IE6-9 */ + /* z-index: 100; /* the stack order: foreground */ + /* border-bottom:1px #004C99 solid; */ + border:3px #004C99 solid; +} + +/******************************************************************************************************* + Headers with Links for See Also, TOC, Outline +*******************************************************************************************************/ +.card #Header-Links { + font-size: 14px; + margin-top: -20px; + margin-bottom: 10px; + font-weight: bold; +} + +.card #Field-Header-Links #Field-See_Also-Link { + color: rgb(45, 79, 201); +} + +/******************************************************************************************************* + HTML Link Elements +*******************************************************************************************************/ + +a { + color: rgb(105, 170, 53); + text-decoration: underline; +} + +a:hover { + color: rgb(135, 187, 93); + text-decoration: none; +} + +.card .See_Also a, .card .Field-See_Also a, .card #Field-Header-Links #Field-See_Also-Link a .Note_Link { + color: rgb(45, 79, 201); +} + +.card .See_Also a:hover, .card .Field-See_Also a:hover, .card #Field-Header-Links #Field-See_Also-Link a:hover .Note_Link { + color: rgb(108, 132, 217); +} + +.card .Field-TOC a , .card #Field-Header-Links #Field-TOC-Link a .Note_Link{ + color: rgb(173, 0, 0); +} + +.card .Field-TOC a:hover, .card #Field-Header-Links #Field-TOC-Link a:hover .Note_Link { + color: rgb(196, 71, 71); +} + +.card .Field-Outline a, .card #Field-Header-Links #Field-Outline-Link a .Note_Link, { + color: rgb(105, 170, 53); +} +.card .Field-Outline a:hover , .card #Field-Header-Links #Field-Outline-Link a:hover .Note_Link{ + color: rgb(135, 
187, 93); +} +.card #Link-EN-Self a { + color: rgb(30, 155, 67) +} + +.card #Link-EN-Self a:hover { + color: rgb(107, 226, 143) +} + +/******************************************************************************************************* + TOC/Outline Headers (Automatically generated and placed in TOC/Outline fields when > 1 source note) +*******************************************************************************************************/ + +.card .TOC, .card .Outline { + font-weight: bold; +} + +.card .TOC { + color: rgb(173, 0, 0); +} + +.card .Outline { + color: rgb(105, 170, 53); +} + +.card .TOC .header, .card .Outline .header { + text-decoration: underline; + color: #bf0060; +} + +.card .TOC .header:nth-of-type(1){ + color: rgb(173, 0, 0); +} + +.card .Outline .header:nth-of-type(1) { + color: rgb(105, 170, 53); +} + + + +/******************************************************************************************************* + Per-Field Rules +*******************************************************************************************************/ + +.card #Field-Extra , #Field-Extra-Front, +.card #Field-See_Also , .card #Field-TOC , .card #Field-Outline { + font-size: 14px; +} + +.card #Footer-Line { + font-size: 10px; +} + +.card #Field-See_Also ol { + padding-top: 0px; + margin-top: 0px; +} + +.card #Field-See_Also hr { + padding-top: 0px; + padding-bottom: 0px; + margin-top: 5px; + margin-bottom: 10px; +} + +/******************************************************************************************************* + Extra Field/Tags Rules +*******************************************************************************************************/ + +.card #Field-Extra , #Field-Extra-Front, +.card #Tags { + color: #aaa; +} + +.card #Field-Extra .header, #Field-Extra-Front .header { + color: #666; +} +.card #Field-Extra .header, #Field-Extra-Front .header , #Tags .header { + font-weight: bold; +} + +.card #Field-Extra-Front { + color: #444; +} + 
+/******************************************************************************************************* + Special Span Classes +*******************************************************************************************************/ + +.card .occluded { + color: #555; +} +.card div.occluded { + display: inline-block; + margin: 0px; padding: 0px; +} +.card div.occluded :first-child { + margin-top: 0px; +} +.card div.occluded :last-child { + margin-bottom: 0px; +} + +.card .See_Also, .card .Field-See_Also, .card .separator { + color: rgb(45, 79, 201); + font-weight: bold; +} + +/******************************************************************************************************* + Default Visibility Rules +*******************************************************************************************************/ + + +.card #Field-Cloze-Content, +.card section.header.Field-Title-Prompt, +.card #Side-Front #Footer-Line, +.card #Side-Front #Field-See_Also, +.card #Field-TOC, +.card #Field-Outline, +.card #Side-Front #Field-Extra, +.card #Side-Back #Field-Extra-Front, +.card #Card-EvernoteReviewCloze #Field-Content, +.card #Card-EvernoteReviewReversed #Side-Front section.header.Field-Title +{ + display: none; +} + +.card #Side-Front #Header-Links, +.card #Card-EvernoteReview #Side-Front #Field-Content, +.card #Side-Front .occluded +{ + visibility: hidden; +} + +.card #Card-EvernoteReviewCloze #Field-Cloze-Content, +.card #Card-EvernoteReviewReversed #Side-Front section.header.Field-Title-Prompt, +.card #Side-Back #Field-See_Also +{ + display: block; +} + +/******************************************************************************************************* + Rules for Anki-Generated Classes +*******************************************************************************************************/ + +.cloze { + font-weight: bold; + color: blue; +} \ No newline at end of file diff --git a/anknotes/extra/ancillary/_attributes.css b/anknotes/extra/ancillary/_attributes.css 
new file mode 100644 index 0000000..9ae94de --- /dev/null +++ b/anknotes/extra/ancillary/_attributes.css @@ -0,0 +1,113 @@ +/******************************************************************************************************* + Helpful Attributes +*******************************************************************************************************/ + +/* + + Colors: + <OL> + Levels + 'OL': { + 1: { + 'Default': 'rgb(106, 0, 129);', + 'Hover': 'rgb(168, 0, 204);' + }, + 2: { + 'Default': 'rgb(235, 0, 115);', + 'Hover': 'rgb(255, 94, 174);' + }, + 3: { + 'Default': 'rgb(186, 0, 255);', + 'Hover': 'rgb(213, 100, 255);' + }, + 4: { + 'Default': 'rgb(129, 182, 255);', + 'Hover': 'rgb(36, 130, 255);' + }, + 5: { + 'Default': 'rgb(232, 153, 220);', + 'Hover': 'rgb(142, 32, 125);' + }, + 6: { + 'Default': 'rgb(201, 213, 172);', + 'Hover': 'rgb(130, 153, 77);' + }, + 7: { + 'Default': 'rgb(231, 179, 154);', + 'Hover': 'rgb(215, 129, 87);' + }, + 8: { + 'Default': 'rgb(249, 136, 198);', + 'Hover': 'rgb(215, 11, 123);' + } + Headers + Auto TOC: + color: rgb(11, 59, 225); + Modifiers + Orange: + color: rgb(222, 87, 0); + Orange (Light): + color: rgb(250, 122, 0); + Dark Red/Pink: + color: rgb(164, 15, 45); + Pink Alternative LVL1: + color: rgb(188, 0, 88); + + Header Boxes + Red-Orange: + Gradient Start: + color: rgb(255, 26, 0); + Gradient End: + color: rgb(255, 242, 0); + Title: + color: rgb(153, 0, 0); + color: rgb(106, 6, 6); + Blue-White-Red + Gradient Start: + color: rgb(59, 103, 158); + Gradient End: + color: rgb(255, 17, 17); + Title: + color: rgb(0, 76, 153); + color: rgb(10, 121, 243); + Old Border: + color: rgb(0, 76, 153); + Avi-Blue-Red + Gradient Start: + color: rgb(0, 96, 191); + Gradient End: + color: rgb(191, 0, 96); + Title: + color: rgb(128, 166, 204); + color: rgb(241, 135, 154); + Old Border: + color: rgb(0, 76, 153); + Borders + color: rgb(153, 0, 0); + + Titles: + Field Title Prompt: + color: rgb(169, 0, 48); + + See Also (Link + Hover) + See Also: 
+ color: rgb(45, 79, 201); + color: rgb(108, 132, 217); + TOC: + color: rgb(173, 0, 0); + color: rgb(196, 71, 71); + Outline: + color: rgb(105, 170, 53); + color: rgb(135, 187, 93); + + Evernote Anknotes Self-Referential Link + color: rgb(30, 155, 67) + color: rgb(107, 226, 143) + + Evernote Classic (In-App) Note Link + color: rgb(105, 170, 53); + color: rgb(135, 187, 93); + + Unused: + color: rgb(122, 220, 241); +*/ \ No newline at end of file diff --git a/anknotes/extra/ancillary/enml2.dtd b/anknotes/extra/ancillary/enml2.dtd new file mode 100644 index 0000000..4bff331 --- /dev/null +++ b/anknotes/extra/ancillary/enml2.dtd @@ -0,0 +1,592 @@ +<!-- + + Evernote Markup Language (ENML) 2.0 DTD + + This expresses the structure of an XML document that can be used as the + 'content' of a Note within Evernote's data model. + The Evernote service will reject attempts to create or update notes if + their contents do not validate against this DTD. + + This is based on a subset of XHTML which is intentionally broadened to + reject less real-world HTML, to reduce the likelihood of synchronization + failures. This means that all attributes are defined as CDATA instead of + more-restrictive types, and every HTML element may embed every other + HTML element. + + Copyright (c) 2007-2009 Evernote Corp. 
+ + $Date: 2007/10/15 18:00:00 $ + +--> + + <!--=========== External character mnemonic entities ===================--> + + <!ENTITY % HTMLlat1 PUBLIC + "-//W3C//ENTITIES Latin 1 for XHTML//EN" + "http://www.w3.org/TR/xhtml1/DTD/xhtml-lat1.ent"> + %HTMLlat1; + + <!ENTITY % HTMLsymbol PUBLIC + "-//W3C//ENTITIES Symbols for XHTML//EN" + "http://www.w3.org/TR/xhtml1/DTD/xhtml-symbol.ent"> + %HTMLsymbol; + + <!ENTITY % HTMLspecial PUBLIC + "-//W3C//ENTITIES Special for XHTML//EN" + "http://www.w3.org/TR/xhtml1/DTD/xhtml-special.ent"> + %HTMLspecial; + + <!--=================== Generic Attributes ===============================--> + + <!ENTITY % coreattrs + "style CDATA #IMPLIED + title CDATA #IMPLIED" + > + + <!ENTITY % i18n + "lang CDATA #IMPLIED + xml:lang CDATA #IMPLIED + dir CDATA #IMPLIED" + > + + <!ENTITY % focus + "accesskey CDATA #IMPLIED + tabindex CDATA #IMPLIED" + > + + <!ENTITY % attrs + "%coreattrs; + %i18n;" + > + + <!ENTITY % TextAlign + "align CDATA #IMPLIED" + > + + <!ENTITY % cellhalign + "align CDATA #IMPLIED + char CDATA #IMPLIED + charoff CDATA #IMPLIED" + > + + <!ENTITY % cellvalign + "valign CDATA #IMPLIED" + > + + <!ENTITY % AnyContent + "( #PCDATA | + a | + abbr | + acronym | + address | + area | + b | + bdo | + big | + blockquote | + br | + caption | + center | + cite | + code | + col | + colgroup | + dd | + del | + dfn | + div | + dl | + dt | + em | + en-crypt | + en-media | + en-todo | + font | + h1 | + h2 | + h3 | + h4 | + h5 | + h6 | + hr | + i | + img | + ins | + kbd | + li | + map | + ol | + p | + pre | + q | + s | + samp | + small | + span | + strike | + strong | + sub | + sup | + table | + tbody | + td | + tfoot | + th | + thead | + tr | + tt | + u | + ul | + var )*" + > + + <!--=========== Evernote-specific Elements and Attributes ===============--> + + <!ELEMENT en-note %AnyContent;> + <!ATTLIST en-note + %attrs; + bgcolor CDATA #IMPLIED + text CDATA #IMPLIED + xmlns CDATA #FIXED 'http://xml.evernote.com/pub/enml2.dtd' + > + + 
<!ELEMENT en-crypt (#PCDATA)> + <!ATTLIST en-crypt + hint CDATA #IMPLIED + cipher CDATA "RC2" + length CDATA "64" + > + + <!ELEMENT en-todo EMPTY> + <!ATTLIST en-todo + checked (true|false) "false" + > + + <!ELEMENT en-media EMPTY> + <!ATTLIST en-media + %attrs; + type CDATA #REQUIRED + hash CDATA #REQUIRED + height CDATA #IMPLIED + width CDATA #IMPLIED + usemap CDATA #IMPLIED + align CDATA #IMPLIED + border CDATA #IMPLIED + hspace CDATA #IMPLIED + vspace CDATA #IMPLIED + longdesc CDATA #IMPLIED + alt CDATA #IMPLIED + > + + <!--=========== Simplified HTML Elements and Attributes ===============--> + + <!ELEMENT a %AnyContent;> + <!ATTLIST a + %attrs; + %focus; + charset CDATA #IMPLIED + type CDATA #IMPLIED + name CDATA #IMPLIED + href CDATA #IMPLIED + hreflang CDATA #IMPLIED + rel CDATA #IMPLIED + rev CDATA #IMPLIED + shape CDATA #IMPLIED + coords CDATA #IMPLIED + target CDATA #IMPLIED + > + + <!ELEMENT abbr %AnyContent;> + <!ATTLIST abbr + %attrs; + > + + <!ELEMENT acronym %AnyContent;> + <!ATTLIST acronym + %attrs; + > + + <!ELEMENT address %AnyContent;> + <!ATTLIST address + %attrs; + > + + <!ELEMENT area %AnyContent;> + <!ATTLIST area + %attrs; + %focus; + shape CDATA #IMPLIED + coords CDATA #IMPLIED + href CDATA #IMPLIED + nohref CDATA #IMPLIED + alt CDATA #IMPLIED + target CDATA #IMPLIED + > + + <!ELEMENT b %AnyContent;> + <!ATTLIST b + %attrs; + > + + <!ELEMENT bdo %AnyContent;> + <!ATTLIST bdo + %coreattrs; + lang CDATA #IMPLIED + xml:lang CDATA #IMPLIED + dir CDATA #IMPLIED + > + + <!ELEMENT big %AnyContent;> + <!ATTLIST big + %attrs; + > + + <!ELEMENT blockquote %AnyContent;> + <!ATTLIST blockquote + %attrs; + cite CDATA #IMPLIED + > + + <!ELEMENT br %AnyContent;> + <!ATTLIST br + %coreattrs; + clear CDATA #IMPLIED + > + + <!ELEMENT caption %AnyContent;> + <!ATTLIST caption + %attrs; + align CDATA #IMPLIED + > + + <!ELEMENT center %AnyContent;> + <!ATTLIST center + %attrs; + > + + <!ELEMENT cite %AnyContent;> + <!ATTLIST cite + %attrs; + > + + <!ELEMENT 
code %AnyContent;> + <!ATTLIST code + %attrs; + > + + <!ELEMENT col %AnyContent;> + <!ATTLIST col + %attrs; + %cellhalign; + %cellvalign; + span CDATA #IMPLIED + width CDATA #IMPLIED + > + + <!ELEMENT colgroup %AnyContent;> + <!ATTLIST colgroup + %attrs; + %cellhalign; + %cellvalign; + span CDATA #IMPLIED + width CDATA #IMPLIED + > + + <!ELEMENT dd %AnyContent;> + <!ATTLIST dd + %attrs; + > + + <!ELEMENT del %AnyContent;> + <!ATTLIST del + %attrs; + cite CDATA #IMPLIED + datetime CDATA #IMPLIED + > + + <!ELEMENT dfn %AnyContent;> + <!ATTLIST dfn + %attrs; + > + + <!ELEMENT div %AnyContent;> + <!ATTLIST div + %attrs; + %TextAlign; + > + + <!ELEMENT dl %AnyContent;> + <!ATTLIST dl + %attrs; + compact CDATA #IMPLIED + > + + <!ELEMENT dt %AnyContent;> + <!ATTLIST dt + %attrs; + > + + <!ELEMENT em %AnyContent;> + <!ATTLIST em + %attrs; + > + + <!ELEMENT font %AnyContent;> + <!ATTLIST font + %coreattrs; + %i18n; + size CDATA #IMPLIED + color CDATA #IMPLIED + face CDATA #IMPLIED + > + + <!ELEMENT h1 %AnyContent;> + <!ATTLIST h1 + %attrs; + %TextAlign; + > + + <!ELEMENT h2 %AnyContent;> + <!ATTLIST h2 + %attrs; + %TextAlign; + > + + <!ELEMENT h3 %AnyContent;> + <!ATTLIST h3 + %attrs; + %TextAlign; + > + + <!ELEMENT h4 %AnyContent;> + <!ATTLIST h4 + %attrs; + %TextAlign; + > + + <!ELEMENT h5 %AnyContent;> + <!ATTLIST h5 + %attrs; + %TextAlign; + > + + <!ELEMENT h6 %AnyContent;> + <!ATTLIST h6 + %attrs; + %TextAlign; + > + + <!ELEMENT hr %AnyContent;> + <!ATTLIST hr + %attrs; + align CDATA #IMPLIED + noshade CDATA #IMPLIED + size CDATA #IMPLIED + width CDATA #IMPLIED + > + + <!ELEMENT i %AnyContent;> + <!ATTLIST i + %attrs; + > + + <!ELEMENT img %AnyContent;> + <!ATTLIST img + %attrs; + src CDATA #IMPLIED + alt CDATA #IMPLIED + name CDATA #IMPLIED + longdesc CDATA #IMPLIED + height CDATA #IMPLIED + width CDATA #IMPLIED + usemap CDATA #IMPLIED + ismap CDATA #IMPLIED + align CDATA #IMPLIED + border CDATA #IMPLIED + hspace CDATA #IMPLIED + vspace CDATA #IMPLIED + > + + 
<!ELEMENT ins %AnyContent;> + <!ATTLIST ins + %attrs; + cite CDATA #IMPLIED + datetime CDATA #IMPLIED + > + + <!ELEMENT kbd %AnyContent;> + <!ATTLIST kbd + %attrs; + > + + <!ELEMENT li %AnyContent;> + <!ATTLIST li + %attrs; + type CDATA #IMPLIED + value CDATA #IMPLIED + > + + <!ELEMENT map %AnyContent;> + <!ATTLIST map + %i18n; + title CDATA #IMPLIED + name CDATA #IMPLIED + > + + <!ELEMENT ol %AnyContent;> + <!ATTLIST ol + %attrs; + type CDATA #IMPLIED + compact CDATA #IMPLIED + start CDATA #IMPLIED + > + + <!ELEMENT p %AnyContent;> + <!ATTLIST p + %attrs; + %TextAlign; + > + + <!ELEMENT pre %AnyContent;> + <!ATTLIST pre + %attrs; + width CDATA #IMPLIED + xml:space (preserve) #FIXED 'preserve' + > + + <!ELEMENT q %AnyContent;> + <!ATTLIST q + %attrs; + cite CDATA #IMPLIED + > + + <!ELEMENT s %AnyContent;> + <!ATTLIST s + %attrs; + > + + <!ELEMENT samp %AnyContent;> + <!ATTLIST samp + %attrs; + > + + <!ELEMENT small %AnyContent;> + <!ATTLIST small + %attrs; + > + + <!ELEMENT span %AnyContent;> + <!ATTLIST span + %attrs; + > + + <!ELEMENT strike %AnyContent;> + <!ATTLIST strike + %attrs; + > + + <!ELEMENT strong %AnyContent;> + <!ATTLIST strong + %attrs; + > + + <!ELEMENT sub %AnyContent;> + <!ATTLIST sub + %attrs; + > + + <!ELEMENT sup %AnyContent;> + <!ATTLIST sup + %attrs; + > + + <!ELEMENT table %AnyContent;> + <!ATTLIST table + %attrs; + summary CDATA #IMPLIED + width CDATA #IMPLIED + border CDATA #IMPLIED + cellspacing CDATA #IMPLIED + cellpadding CDATA #IMPLIED + align CDATA #IMPLIED + bgcolor CDATA #IMPLIED + > + + <!ELEMENT tbody %AnyContent;> + <!ATTLIST tbody + %attrs; + %cellhalign; + %cellvalign; + > + + <!ELEMENT td %AnyContent;> + <!ATTLIST td + %attrs; + %cellhalign; + %cellvalign; + abbr CDATA #IMPLIED + rowspan CDATA #IMPLIED + colspan CDATA #IMPLIED + nowrap CDATA #IMPLIED + bgcolor CDATA #IMPLIED + width CDATA #IMPLIED + height CDATA #IMPLIED + > + + <!ELEMENT tfoot %AnyContent;> + <!ATTLIST tfoot + %attrs; + %cellhalign; + %cellvalign; + > + + 
<!ELEMENT th %AnyContent;> + <!ATTLIST th + %attrs; + %cellhalign; + %cellvalign; + abbr CDATA #IMPLIED + rowspan CDATA #IMPLIED + colspan CDATA #IMPLIED + nowrap CDATA #IMPLIED + bgcolor CDATA #IMPLIED + width CDATA #IMPLIED + height CDATA #IMPLIED + > + + <!ELEMENT thead %AnyContent;> + <!ATTLIST thead + %attrs; + %cellhalign; + %cellvalign; + > + + <!ELEMENT tr %AnyContent;> + <!ATTLIST tr + %attrs; + %cellhalign; + %cellvalign; + bgcolor CDATA #IMPLIED + > + + <!ELEMENT tt %AnyContent;> + <!ATTLIST tt + %attrs; + > + + <!ELEMENT u %AnyContent;> + <!ATTLIST u + %attrs; + > + + <!ELEMENT ul %AnyContent;> + <!ATTLIST ul + %attrs; + type CDATA #IMPLIED + compact CDATA #IMPLIED + > + + <!ELEMENT var %AnyContent;> + <!ATTLIST var + %attrs; + > diff --git a/anknotes/extra/ancillary/index.html b/anknotes/extra/ancillary/index.html new file mode 100644 index 0000000..5fa4057 --- /dev/null +++ b/anknotes/extra/ancillary/index.html @@ -0,0 +1,96 @@ +<html lang="en"> +<head> + <meta charset="utf-8"> + <meta http-equiv="X-UA-Compatible" content="IE=edge"> + <meta name="viewport" content="width=device-width, initial-scale=1"> + <meta name="description" content=""> + <meta name="author" content=""> + <link rel="icon" href="../../favicon.ico"> + + <title>Starter Template for Bootstrap + + + + + + + + + + + + + + + + +
      + +
      +

      Evernote Auth

      + +

      Copy the shown key into the Anki Pop up.

      +
      +
      +

      +
      + +
      + + + + + + + + + + diff --git a/anknotes/extra/ancillary/regex-see_also.txt b/anknotes/extra/ancillary/regex-see_also.txt new file mode 100644 index 0000000..e428ccb --- /dev/null +++ b/anknotes/extra/ancillary/regex-see_also.txt @@ -0,0 +1,18 @@ +(?P
      (?:)?]*>(?:)?
      )? +(?P + (?P]*>) + (?P + (?P + (?P(?:<(?:span|b|font|br)[^>]*>){0,5}) + (?P
      (?:\r|\n|\r\n)?)? + (?P(?:){0,2}) + (?P(?:<(?:span|b|font|br)[^>]*>){0,1}) + ) + See.[Aa]lso:?(?:\ | )? + (?P(?:){0,5}) + ) + (?P + .+ # See Also Contents + ) +) +(?P) \ No newline at end of file diff --git a/anknotes/extra/ancillary/regex-see_also2.txt b/anknotes/extra/ancillary/regex-see_also2.txt new file mode 100644 index 0000000..b8c86b0 --- /dev/null +++ b/anknotes/extra/ancillary/regex-see_also2.txt @@ -0,0 +1,21 @@ +(?P
      ]*>
      )? +(?P(?P]*>)(?P + +(?:<(?:b|span|font)[^>]*>){0,3} +(?:]*>) +(?:]+?)?>)? +(?P
      (?:\r|\n|\r\n)?)? + +(?:See.Also:? + +(?:]*> )? +(?: )?) + +(?:)? +(?:) +(?:)? +(?:)? + + + +)(?P.+))(?P) \ No newline at end of file diff --git a/anknotes/extra/ancillary/regex.txt b/anknotes/extra/ancillary/regex.txt new file mode 100644 index 0000000..1a3a26f --- /dev/null +++ b/anknotes/extra/ancillary/regex.txt @@ -0,0 +1,36 @@ +Converting this file to Python: + 1) \(\?:(\r\n|\r|\n){1,2}\)\? + 2) (?< + +Finding Evernote Links + ]+)?>(?P.+?)</a> + https://www.evernote.com/shard/(?P<shard>s\d+)/[\w\d]+/(?P<uid>\d+)/(?P<guid>[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}) + + +Step 6: Process "See Also: " Links + (?<PrefixStrip><div><b><span[^>]*><br/></span></b></div>)?(?: + )?(?<SeeAlso>(?<SeeAlsoPrefix><div[^>]*>)(?<SeeAlsoHeader>(?: + + )?(?:<b[^>]*>)?(?: + )?(?:<(?:span|font)[^>]*>){0,2}(?: + )?(?:<span[^>]*>)(?: + )?(?:<b(?: style=[^>]+?)?>)?(?: + )?(?<SeeAlsoHeaderStripMe><br />(?:\r|\n|\r\n)?)?(?: + + )?(?:See Also:?(?: + + )?(?:<span[^>]*> </span>)?(?: + )?(?: )?)(?: + + )?(?:</b>)?(?: + )?(?:</span>)(?: + )?(?:</(?:span|font)>)?(?: + )?(?:</b>)?(?: + + )?(?: + + )?)(?<SeeAlsoContents>.+))(?<Suffix></en-note>) + +Replace Python Parameters with Reference to Self + ([\w_]+)(?: ?= ?(.+?))?(,|\)) + $1=$1$3 \ No newline at end of file diff --git a/anknotes/extra/ancillary/sorting.txt b/anknotes/extra/ancillary/sorting.txt new file mode 100644 index 0000000..79fbcf1 --- /dev/null +++ b/anknotes/extra/ancillary/sorting.txt @@ -0,0 +1,49 @@ + first = [ + 'Summary', + 'Definition', + 'Classification', + 'Types', + 'Presentation', + 'Age of Onset', + 'Si/Sx', + 'Sx', + 'Sign', + + 'MCC\'s', + 'MCC', + 'Inheritance', + 'Incidence', + 'Prognosis', + 'Mechanism', + 'MOA', + 'Pathophysiology', + + 'Indications', + 'Examples', + 'Cause', + 'Causes', + 'Causative Organisms', + 'Risk Factors', + 'Complication', + 'Complications', + 'Side Effects', + 'Drug S/E', + 'Associated Conditions', + 'A/w', + + 'Dx', + 'Physical Exam', + 'Labs', + 'Hemodynamic 
Parameters', + 'Lab Findings', + 'Imaging', + 'Confirmatory Test', + 'Screening Test' + ] + last = [ + 'Management', + 'Work Up', + 'Tx' + + + ] \ No newline at end of file diff --git a/anknotes/extra/ancillary/sqlite3.dll b/anknotes/extra/ancillary/sqlite3.dll new file mode 100644 index 0000000..3bedc32 Binary files /dev/null and b/anknotes/extra/ancillary/sqlite3.dll differ diff --git a/anknotes/extra/ancillary/xhtml-lat1.ent b/anknotes/extra/ancillary/xhtml-lat1.ent new file mode 100644 index 0000000..c944e2d --- /dev/null +++ b/anknotes/extra/ancillary/xhtml-lat1.ent @@ -0,0 +1,193 @@ +<!-- Portions (C) International Organization for Standardization 1986 + Permission to copy in any form is granted for use with + conforming SGML systems and applications as defined in + ISO 8879, provided this notice is included in all copies. +--> + <!-- Character entity set. Typical invocation: + <!ENTITY % HTMLlat1 PUBLIC + "-//W3C//ENTITIES Latin 1 for XHTML//EN" + "http://www.w3.org/TR/xhtml1/DTD/xhtml-lat1.ent"> + %HTMLlat1; + --> + + <!ENTITY nbsp " "> <!-- no-break space = non-breaking space, + U+00A0 ISOnum --> + <!ENTITY iexcl "¡"> <!-- inverted exclamation mark, U+00A1 ISOnum --> + <!ENTITY cent "¢"> <!-- cent sign, U+00A2 ISOnum --> + <!ENTITY pound "£"> <!-- pound sign, U+00A3 ISOnum --> + <!ENTITY curren "¤"> <!-- currency sign, U+00A4 ISOnum --> + <!ENTITY yen "¥"> <!-- yen sign = yuan sign, U+00A5 ISOnum --> + <!ENTITY brvbar "¦"> <!-- broken bar = broken vertical bar, + U+00A6 ISOnum --> + <!ENTITY sect "§"> <!-- section sign, U+00A7 ISOnum --> + <!ENTITY uml "¨"> <!-- diaeresis = spacing diaeresis, + U+00A8 ISOdia --> + <!ENTITY copy "©"> <!-- copyright sign, U+00A9 ISOnum --> + <!ENTITY ordf "ª"> <!-- feminine ordinal indicator, U+00AA ISOnum --> + <!ENTITY laquo "«"> <!-- left-pointing double angle quotation mark + = left pointing guillemet, U+00AB ISOnum --> + <!ENTITY not "¬"> <!-- not sign = angled dash, + U+00AC ISOnum --> + <!ENTITY shy "­"> <!-- soft 
hyphen = discretionary hyphen, + U+00AD ISOnum --> + <!ENTITY reg "®"> <!-- registered sign = registered trade mark sign, + U+00AE ISOnum --> + <!ENTITY macr "¯"> <!-- macron = spacing macron = overline + = APL overbar, U+00AF ISOdia --> + <!ENTITY deg "°"> <!-- degree sign, U+00B0 ISOnum --> + <!ENTITY plusmn "±"> <!-- plus-minus sign = plus-or-minus sign, + U+00B1 ISOnum --> + <!ENTITY sup2 "²"> <!-- superscript two = superscript digit two + = squared, U+00B2 ISOnum --> + <!ENTITY sup3 "³"> <!-- superscript three = superscript digit three + = cubed, U+00B3 ISOnum --> + <!ENTITY acute "´"> <!-- acute accent = spacing acute, + U+00B4 ISOdia --> + <!ENTITY micro "µ"> <!-- micro sign, U+00B5 ISOnum --> + <!ENTITY para "¶"> <!-- pilcrow sign = paragraph sign, + U+00B6 ISOnum --> + <!ENTITY middot "·"> <!-- middle dot = Georgian comma + = Greek middle dot, U+00B7 ISOnum --> + <!ENTITY cedil "¸"> <!-- cedilla = spacing cedilla, U+00B8 ISOdia --> + <!ENTITY sup1 "¹"> <!-- superscript one = superscript digit one, + U+00B9 ISOnum --> + <!ENTITY ordm "º"> <!-- masculine ordinal indicator, + U+00BA ISOnum --> + <!ENTITY raquo "»"> <!-- right-pointing double angle quotation mark + = right pointing guillemet, U+00BB ISOnum --> + <!ENTITY frac14 "¼"> <!-- vulgar fraction one quarter + = fraction one quarter, U+00BC ISOnum --> + <!ENTITY frac12 "½"> <!-- vulgar fraction one half + = fraction one half, U+00BD ISOnum --> + <!ENTITY frac34 "¾"> <!-- vulgar fraction three quarters + = fraction three quarters, U+00BE ISOnum --> + <!ENTITY iquest "¿"> <!-- inverted question mark + = turned question mark, U+00BF ISOnum --> + <!ENTITY Agrave "À"> <!-- latin capital letter A with grave + = latin capital letter A grave, + U+00C0 ISOlat1 --> + <!ENTITY Aacute "Á"> <!-- latin capital letter A with acute, + U+00C1 ISOlat1 --> + <!ENTITY Acirc "Â"> <!-- latin capital letter A with circumflex, + U+00C2 ISOlat1 --> + <!ENTITY Atilde "Ã"> <!-- latin capital letter A with tilde, + U+00C3 ISOlat1 
--> + <!ENTITY Auml "Ä"> <!-- latin capital letter A with diaeresis, + U+00C4 ISOlat1 --> + <!ENTITY Aring "Å"> <!-- latin capital letter A with ring above + = latin capital letter A ring, + U+00C5 ISOlat1 --> + <!ENTITY AElig "Æ"> <!-- latin capital letter AE + = latin capital ligature AE, + U+00C6 ISOlat1 --> + <!ENTITY Ccedil "Ç"> <!-- latin capital letter C with cedilla, + U+00C7 ISOlat1 --> + <!ENTITY Egrave "È"> <!-- latin capital letter E with grave, + U+00C8 ISOlat1 --> + <!ENTITY Eacute "É"> <!-- latin capital letter E with acute, + U+00C9 ISOlat1 --> + <!ENTITY Ecirc "Ê"> <!-- latin capital letter E with circumflex, + U+00CA ISOlat1 --> + <!ENTITY Euml "Ë"> <!-- latin capital letter E with diaeresis, + U+00CB ISOlat1 --> + <!ENTITY Igrave "Ì"> <!-- latin capital letter I with grave, + U+00CC ISOlat1 --> + <!ENTITY Iacute "Í"> <!-- latin capital letter I with acute, + U+00CD ISOlat1 --> + <!ENTITY Icirc "Î"> <!-- latin capital letter I with circumflex, + U+00CE ISOlat1 --> + <!ENTITY Iuml "Ï"> <!-- latin capital letter I with diaeresis, + U+00CF ISOlat1 --> + <!ENTITY ETH "Ð"> <!-- latin capital letter ETH, U+00D0 ISOlat1 --> + <!ENTITY Ntilde "Ñ"> <!-- latin capital letter N with tilde, + U+00D1 ISOlat1 --> + <!ENTITY Ograve "Ò"> <!-- latin capital letter O with grave, + U+00D2 ISOlat1 --> + <!ENTITY Oacute "Ó"> <!-- latin capital letter O with acute, + U+00D3 ISOlat1 --> + <!ENTITY Ocirc "Ô"> <!-- latin capital letter O with circumflex, + U+00D4 ISOlat1 --> + <!ENTITY Otilde "Õ"> <!-- latin capital letter O with tilde, + U+00D5 ISOlat1 --> + <!ENTITY Ouml "Ö"> <!-- latin capital letter O with diaeresis, + U+00D6 ISOlat1 --> + <!ENTITY times "×"> <!-- multiplication sign, U+00D7 ISOnum --> + <!ENTITY Oslash "Ø"> <!-- latin capital letter O with stroke + = latin capital letter O slash, + U+00D8 ISOlat1 --> + <!ENTITY Ugrave "Ù"> <!-- latin capital letter U with grave, + U+00D9 ISOlat1 --> + <!ENTITY Uacute "Ú"> <!-- latin capital letter U with acute, + 
U+00DA ISOlat1 --> + <!ENTITY Ucirc "Û"> <!-- latin capital letter U with circumflex, + U+00DB ISOlat1 --> + <!ENTITY Uuml "Ü"> <!-- latin capital letter U with diaeresis, + U+00DC ISOlat1 --> + <!ENTITY Yacute "Ý"> <!-- latin capital letter Y with acute, + U+00DD ISOlat1 --> + <!ENTITY THORN "Þ"> <!-- latin capital letter THORN, + U+00DE ISOlat1 --> + <!ENTITY szlig "ß"> <!-- latin small letter sharp s = ess-zed, + U+00DF ISOlat1 --> + <!ENTITY agrave "à"> <!-- latin small letter a with grave + = latin small letter a grave, + U+00E0 ISOlat1 --> + <!ENTITY aacute "á"> <!-- latin small letter a with acute, + U+00E1 ISOlat1 --> + <!ENTITY acirc "â"> <!-- latin small letter a with circumflex, + U+00E2 ISOlat1 --> + <!ENTITY atilde "ã"> <!-- latin small letter a with tilde, + U+00E3 ISOlat1 --> + <!ENTITY auml "ä"> <!-- latin small letter a with diaeresis, + U+00E4 ISOlat1 --> + <!ENTITY aring "å"> <!-- latin small letter a with ring above + = latin small letter a ring, + U+00E5 ISOlat1 --> + <!ENTITY aelig "æ"> <!-- latin small letter ae + = latin small ligature ae, U+00E6 ISOlat1 --> + <!ENTITY ccedil "ç"> <!-- latin small letter c with cedilla, + U+00E7 ISOlat1 --> + <!ENTITY egrave "è"> <!-- latin small letter e with grave, + U+00E8 ISOlat1 --> + <!ENTITY eacute "é"> <!-- latin small letter e with acute, + U+00E9 ISOlat1 --> + <!ENTITY ecirc "ê"> <!-- latin small letter e with circumflex, + U+00EA ISOlat1 --> + <!ENTITY euml "ë"> <!-- latin small letter e with diaeresis, + U+00EB ISOlat1 --> + <!ENTITY igrave "ì"> <!-- latin small letter i with grave, + U+00EC ISOlat1 --> + <!ENTITY iacute "í"> <!-- latin small letter i with acute, + U+00ED ISOlat1 --> + <!ENTITY icirc "î"> <!-- latin small letter i with circumflex, + U+00EE ISOlat1 --> + <!ENTITY iuml "ï"> <!-- latin small letter i with diaeresis, + U+00EF ISOlat1 --> + <!ENTITY eth "ð"> <!-- latin small letter eth, U+00F0 ISOlat1 --> + <!ENTITY ntilde "ñ"> <!-- latin small letter n with tilde, + U+00F1 ISOlat1 
--> + <!ENTITY ograve "ò"> <!-- latin small letter o with grave, + U+00F2 ISOlat1 --> + <!ENTITY oacute "ó"> <!-- latin small letter o with acute, + U+00F3 ISOlat1 --> + <!ENTITY ocirc "ô"> <!-- latin small letter o with circumflex, + U+00F4 ISOlat1 --> + <!ENTITY otilde "õ"> <!-- latin small letter o with tilde, + U+00F5 ISOlat1 --> + <!ENTITY ouml "ö"> <!-- latin small letter o with diaeresis, + U+00F6 ISOlat1 --> + <!ENTITY divide "÷"> <!-- division sign, U+00F7 ISOnum --> + <!ENTITY oslash "ø"> <!-- latin small letter o with stroke, + = latin small letter o slash, + U+00F8 ISOlat1 --> + <!ENTITY ugrave "ù"> <!-- latin small letter u with grave, + U+00F9 ISOlat1 --> + <!ENTITY uacute "ú"> <!-- latin small letter u with acute, + U+00FA ISOlat1 --> + <!ENTITY ucirc "û"> <!-- latin small letter u with circumflex, + U+00FB ISOlat1 --> + <!ENTITY uuml "ü"> <!-- latin small letter u with diaeresis, + U+00FC ISOlat1 --> + <!ENTITY yacute "ý"> <!-- latin small letter y with acute, + U+00FD ISOlat1 --> + <!ENTITY thorn "þ"> <!-- latin sm \ No newline at end of file diff --git a/anknotes/extra/ancillary/xhtml-special.ent b/anknotes/extra/ancillary/xhtml-special.ent new file mode 100644 index 0000000..f9879b6 --- /dev/null +++ b/anknotes/extra/ancillary/xhtml-special.ent @@ -0,0 +1,79 @@ +<!-- Special characters for XHTML --> + + <!-- Character entity set. Typical invocation: + <!ENTITY % HTMLspecial PUBLIC + "-//W3C//ENTITIES Special for XHTML//EN" + "http://www.w3.org/TR/xhtml1/DTD/xhtml-special.ent"> + %HTMLspecial; + --> + + <!-- Portions (C) International Organization for Standardization 1986: + Permission to copy in any form is granted for use with + conforming SGML systems and applications as defined in + ISO 8879, provided this notice is included in all copies. + --> + + <!-- Relevant ISO entity set is given unless names are newly introduced. + New names (i.e., not in ISO 8879 list) do not clash with any + existing ISO 8879 entity names. 
ISO 10646 character numbers + are given for each character, in hex. values are decimal + conversions of the ISO 10646 values and refer to the document + character set. Names are Unicode names. + --> + + <!-- C0 Controls and Basic Latin --> + <!ENTITY quot """> <!-- quotation mark, U+0022 ISOnum --> + <!ENTITY amp "&#38;"> <!-- ampersand, U+0026 ISOnum --> + <!ENTITY lt "&#60;"> <!-- less-than sign, U+003C ISOnum --> + <!ENTITY gt ">"> <!-- greater-than sign, U+003E ISOnum --> + <!ENTITY apos "'"> <!-- apostrophe = APL quote, U+0027 ISOnum --> + + <!-- Latin Extended-A --> + <!ENTITY OElig "Œ"> <!-- latin capital ligature OE, + U+0152 ISOlat2 --> + <!ENTITY oelig "œ"> <!-- latin small ligature oe, U+0153 ISOlat2 --> + <!-- ligature is a misnomer, this is a separate character in some languages --> + <!ENTITY Scaron "Š"> <!-- latin capital letter S with caron, + U+0160 ISOlat2 --> + <!ENTITY scaron "š"> <!-- latin small letter s with caron, + U+0161 ISOlat2 --> + <!ENTITY Yuml "Ÿ"> <!-- latin capital letter Y with diaeresis, + U+0178 ISOlat2 --> + + <!-- Spacing Modifier Letters --> + <!ENTITY circ "ˆ"> <!-- modifier letter circumflex accent, + U+02C6 ISOpub --> + <!ENTITY tilde "˜"> <!-- small tilde, U+02DC ISOdia --> + + <!-- General Punctuation --> + <!ENTITY ensp " "> <!-- en space, U+2002 ISOpub --> + <!ENTITY emsp " "> <!-- em space, U+2003 ISOpub --> + <!ENTITY thinsp " "> <!-- thin space, U+2009 ISOpub --> + <!ENTITY zwnj "‌"> <!-- zero width non-joiner, + U+200C NEW RFC 2070 --> + <!ENTITY zwj "‍"> <!-- zero width joiner, U+200D NEW RFC 2070 --> + <!ENTITY lrm "‎"> <!-- left-to-right mark, U+200E NEW RFC 2070 --> + <!ENTITY rlm "‏"> <!-- right-to-left mark, U+200F NEW RFC 2070 --> + <!ENTITY ndash "–"> <!-- en dash, U+2013 ISOpub --> + <!ENTITY mdash "—"> <!-- em dash, U+2014 ISOpub --> + <!ENTITY lsquo "‘"> <!-- left single quotation mark, + U+2018 ISOnum --> + <!ENTITY rsquo "’"> <!-- right single quotation mark, + U+2019 ISOnum --> + <!ENTITY sbquo "‚"> 
<!-- single low-9 quotation mark, U+201A NEW --> + <!ENTITY ldquo "“"> <!-- left double quotation mark, + U+201C ISOnum --> + <!ENTITY rdquo "”"> <!-- right double quotation mark, + U+201D ISOnum --> + <!ENTITY bdquo "„"> <!-- double low-9 quotation mark, U+201E NEW --> + <!ENTITY dagger "†"> <!-- dagger, U+2020 ISOpub --> + <!ENTITY Dagger "‡"> <!-- double dagger, U+2021 ISOpub --> + <!ENTITY permil "‰"> <!-- per mille sign, U+2030 ISOtech --> + <!ENTITY lsaquo "‹"> <!-- single left-pointing angle quotation mark, + U+2039 ISO proposed --> + <!-- lsaquo is proposed but not yet ISO standardized --> + <!ENTITY rsaquo "›"> <!-- single right-pointing angle quotation mark, + U+203A ISO proposed --> + <!-- rsaquo is proposed but not yet ISO standardized --> + + <!-- Cu \ No newline at end of file diff --git a/anknotes/extra/ancillary/xhtml-symbol.ent b/anknotes/extra/ancillary/xhtml-symbol.ent new file mode 100644 index 0000000..f02c548 --- /dev/null +++ b/anknotes/extra/ancillary/xhtml-symbol.ent @@ -0,0 +1,234 @@ +<!-- Mathematical, Greek and Symbolic characters for XHTML --> + + <!-- Character entity set. Typical invocation: + <!ENTITY % HTMLsymbol PUBLIC + "-//W3C//ENTITIES Symbols for XHTML//EN" + "http://www.w3.org/TR/xhtml1/DTD/xhtml-symbol.ent"> + %HTMLsymbol; + --> + + <!-- Portions (C) International Organization for Standardization 1986: + Permission to copy in any form is granted for use with + conforming SGML systems and applications as defined in + ISO 8879, provided this notice is included in all copies. + --> + + <!-- Relevant ISO entity set is given unless names are newly introduced. + New names (i.e., not in ISO 8879 list) do not clash with any + existing ISO 8879 entity names. ISO 10646 character numbers + are given for each character, in hex. values are decimal + conversions of the ISO 10646 values and refer to the document + character set. Names are Unicode names. 
+ --> + + <!-- Latin Extended-B --> + <!ENTITY fnof "ƒ"> <!-- latin small letter f with hook = function + = florin, U+0192 ISOtech --> + + <!-- Greek --> + <!ENTITY Alpha "Α"> <!-- greek capital letter alpha, U+0391 --> + <!ENTITY Beta "Β"> <!-- greek capital letter beta, U+0392 --> + <!ENTITY Gamma "Γ"> <!-- greek capital letter gamma, + U+0393 ISOgrk3 --> + <!ENTITY Delta "Δ"> <!-- greek capital letter delta, + U+0394 ISOgrk3 --> + <!ENTITY Epsilon "Ε"> <!-- greek capital letter epsilon, U+0395 --> + <!ENTITY Zeta "Ζ"> <!-- greek capital letter zeta, U+0396 --> + <!ENTITY Eta "Η"> <!-- greek capital letter eta, U+0397 --> + <!ENTITY Theta "Θ"> <!-- greek capital letter theta, + U+0398 ISOgrk3 --> + <!ENTITY Iota "Ι"> <!-- greek capital letter iota, U+0399 --> + <!ENTITY Kappa "Κ"> <!-- greek capital letter kappa, U+039A --> + <!ENTITY Lambda "Λ"> <!-- greek capital letter lamda, + U+039B ISOgrk3 --> + <!ENTITY Mu "Μ"> <!-- greek capital letter mu, U+039C --> + <!ENTITY Nu "Ν"> <!-- greek capital letter nu, U+039D --> + <!ENTITY Xi "Ξ"> <!-- greek capital letter xi, U+039E ISOgrk3 --> + <!ENTITY Omicron "Ο"> <!-- greek capital letter omicron, U+039F --> + <!ENTITY Pi "Π"> <!-- greek capital letter pi, U+03A0 ISOgrk3 --> + <!ENTITY Rho "Ρ"> <!-- greek capital letter rho, U+03A1 --> + <!-- there is no Sigmaf, and no U+03A2 character either --> + <!ENTITY Sigma "Σ"> <!-- greek capital letter sigma, + U+03A3 ISOgrk3 --> + <!ENTITY Tau "Τ"> <!-- greek capital letter tau, U+03A4 --> + <!ENTITY Upsilon "Υ"> <!-- greek capital letter upsilon, + U+03A5 ISOgrk3 --> + <!ENTITY Phi "Φ"> <!-- greek capital letter phi, + U+03A6 ISOgrk3 --> + <!ENTITY Chi "Χ"> <!-- greek capital letter chi, U+03A7 --> + <!ENTITY Psi "Ψ"> <!-- greek capital letter psi, + U+03A8 ISOgrk3 --> + <!ENTITY Omega "Ω"> <!-- greek capital letter omega, + U+03A9 ISOgrk3 --> + + <!ENTITY alpha "α"> <!-- greek small letter alpha, + U+03B1 ISOgrk3 --> + <!ENTITY beta "β"> <!-- greek small letter beta, U+03B2 
ISOgrk3 --> + <!ENTITY gamma "γ"> <!-- greek small letter gamma, + U+03B3 ISOgrk3 --> + <!ENTITY delta "δ"> <!-- greek small letter delta, + U+03B4 ISOgrk3 --> + <!ENTITY epsilon "ε"> <!-- greek small letter epsilon, + U+03B5 ISOgrk3 --> + <!ENTITY zeta "ζ"> <!-- greek small letter zeta, U+03B6 ISOgrk3 --> + <!ENTITY eta "η"> <!-- greek small letter eta, U+03B7 ISOgrk3 --> + <!ENTITY theta "θ"> <!-- greek small letter theta, + U+03B8 ISOgrk3 --> + <!ENTITY iota "ι"> <!-- greek small letter iota, U+03B9 ISOgrk3 --> + <!ENTITY kappa "κ"> <!-- greek small letter kappa, + U+03BA ISOgrk3 --> + <!ENTITY lambda "λ"> <!-- greek small letter lamda, + U+03BB ISOgrk3 --> + <!ENTITY mu "μ"> <!-- greek small letter mu, U+03BC ISOgrk3 --> + <!ENTITY nu "ν"> <!-- greek small letter nu, U+03BD ISOgrk3 --> + <!ENTITY xi "ξ"> <!-- greek small letter xi, U+03BE ISOgrk3 --> + <!ENTITY omicron "ο"> <!-- greek small letter omicron, U+03BF NEW --> + <!ENTITY pi "π"> <!-- greek small letter pi, U+03C0 ISOgrk3 --> + <!ENTITY rho "ρ"> <!-- greek small letter rho, U+03C1 ISOgrk3 --> + <!ENTITY sigmaf "ς"> <!-- greek small letter final sigma, + U+03C2 ISOgrk3 --> + <!ENTITY sigma "σ"> <!-- greek small letter sigma, + U+03C3 ISOgrk3 --> + <!ENTITY tau "τ"> <!-- greek small letter tau, U+03C4 ISOgrk3 --> + <!ENTITY upsilon "υ"> <!-- greek small letter upsilon, + U+03C5 ISOgrk3 --> + <!ENTITY phi "φ"> <!-- greek small letter phi, U+03C6 ISOgrk3 --> + <!ENTITY chi "χ"> <!-- greek small letter chi, U+03C7 ISOgrk3 --> + <!ENTITY psi "ψ"> <!-- greek small letter psi, U+03C8 ISOgrk3 --> + <!ENTITY omega "ω"> <!-- greek small letter omega, + U+03C9 ISOgrk3 --> + <!ENTITY thetasym "ϑ"> <!-- greek theta symbol, + U+03D1 NEW --> + <!ENTITY upsih "ϒ"> <!-- greek upsilon with hook symbol, + U+03D2 NEW --> + <!ENTITY piv "ϖ"> <!-- greek pi symbol, U+03D6 ISOgrk3 --> + + <!-- General Punctuation --> + <!ENTITY bull "•"> <!-- bullet = black small circle, + U+2022 ISOpub --> + <!-- bullet is NOT the same as 
bullet operator, U+2219 --> + <!ENTITY hellip "…"> <!-- horizontal ellipsis = three dot leader, + U+2026 ISOpub --> + <!ENTITY prime "′"> <!-- prime = minutes = feet, U+2032 ISOtech --> + <!ENTITY Prime "″"> <!-- double prime = seconds = inches, + U+2033 ISOtech --> + <!ENTITY oline "‾"> <!-- overline = spacing overscore, + U+203E NEW --> + <!ENTITY frasl "⁄"> <!-- fraction slash, U+2044 NEW --> + + <!-- Letterlike Symbols --> + <!ENTITY weierp "℘"> <!-- script capital P = power set + = Weierstrass p, U+2118 ISOamso --> + <!ENTITY image "ℑ"> <!-- black-letter capital I = imaginary part, + U+2111 ISOamso --> + <!ENTITY real "ℜ"> <!-- black-letter capital R = real part symbol, + U+211C ISOamso --> + <!ENTITY trade "™"> <!-- trade mark sign, U+2122 ISOnum --> + <!ENTITY alefsym "ℵ"> <!-- alef symbol = first transfinite cardinal, + U+2135 NEW --> + <!-- alef symbol is NOT the same as hebrew letter alef, + U+05D0 although the same glyph could be used to depict both characters --> + + <!-- Arrows --> + <!ENTITY larr "←"> <!-- leftwards arrow, U+2190 ISOnum --> + <!ENTITY uarr "↑"> <!-- upwards arrow, U+2191 ISOnum--> + <!ENTITY rarr "→"> <!-- rightwards arrow, U+2192 ISOnum --> + <!ENTITY darr "↓"> <!-- downwards arrow, U+2193 ISOnum --> + <!ENTITY harr "↔"> <!-- left right arrow, U+2194 ISOamsa --> + <!ENTITY crarr "↵"> <!-- downwards arrow with corner leftwards + = carriage return, U+21B5 NEW --> + <!ENTITY lArr "⇐"> <!-- leftwards double arrow, U+21D0 ISOtech --> + <!-- Unicode does not say that lArr is the same as the 'is implied by' arrow + but also does not have any other character for that function. 
So lArr can + be used for 'is implied by' as ISOtech suggests --> + <!ENTITY uArr "⇑"> <!-- upwards double arrow, U+21D1 ISOamsa --> + <!ENTITY rArr "⇒"> <!-- rightwards double arrow, + U+21D2 ISOtech --> + <!-- Unicode does not say this is the 'implies' character but does not have + another character with this function so rArr can be used for 'implies' + as ISOtech suggests --> + <!ENTITY dArr "⇓"> <!-- downwards double arrow, U+21D3 ISOamsa --> + <!ENTITY hArr "⇔"> <!-- left right double arrow, + U+21D4 ISOamsa --> + + <!-- Mathematical Operators --> + <!ENTITY forall "∀"> <!-- for all, U+2200 ISOtech --> + <!ENTITY part "∂"> <!-- partial differential, U+2202 ISOtech --> + <!ENTITY exist "∃"> <!-- there exists, U+2203 ISOtech --> + <!ENTITY empty "∅"> <!-- empty set = null set, U+2205 ISOamso --> + <!ENTITY nabla "∇"> <!-- nabla = backward difference, + U+2207 ISOtech --> + <!ENTITY isin "∈"> <!-- element of, U+2208 ISOtech --> + <!ENTITY notin "∉"> <!-- not an element of, U+2209 ISOtech --> + <!ENTITY ni "∋"> <!-- contains as member, U+220B ISOtech --> + <!ENTITY prod "∏"> <!-- n-ary product = product sign, + U+220F ISOamsb --> + <!-- prod is NOT the same character as U+03A0 'greek capital letter pi' though + the same glyph might be used for both --> + <!ENTITY sum "∑"> <!-- n-ary summation, U+2211 ISOamsb --> + <!-- sum is NOT the same character as U+03A3 'greek capital letter sigma' + though the same glyph might be used for both --> + <!ENTITY minus "−"> <!-- minus sign, U+2212 ISOtech --> + <!ENTITY lowast "∗"> <!-- asterisk operator, U+2217 ISOtech --> + <!ENTITY radic "√"> <!-- square root = radical sign, + U+221A ISOtech --> + <!ENTITY prop "∝"> <!-- proportional to, U+221D ISOtech --> + <!ENTITY infin "∞"> <!-- infinity, U+221E ISOtech --> + <!ENTITY ang "∠"> <!-- angle, U+2220 ISOamso --> + <!ENTITY and "∧"> <!-- logical and = wedge, U+2227 ISOtech --> + <!ENTITY or "∨"> <!-- logical or = vee, U+2228 ISOtech --> + <!ENTITY cap "∩"> <!-- intersection = 
cap, U+2229 ISOtech --> + <!ENTITY cup "∪"> <!-- union = cup, U+222A ISOtech --> + <!ENTITY int "∫"> <!-- integral, U+222B ISOtech --> + <!ENTITY there4 "∴"> <!-- therefore, U+2234 ISOtech --> + <!ENTITY sim "∼"> <!-- tilde operator = varies with = similar to, + U+223C ISOtech --> + <!-- tilde operator is NOT the same character as the tilde, U+007E, + although the same glyph might be used to represent both --> + <!ENTITY cong "≅"> <!-- approximately equal to, U+2245 ISOtech --> + <!ENTITY asymp "≈"> <!-- almost equal to = asymptotic to, + U+2248 ISOamsr --> + <!ENTITY ne "≠"> <!-- not equal to, U+2260 ISOtech --> + <!ENTITY equiv "≡"> <!-- identical to, U+2261 ISOtech --> + <!ENTITY le "≤"> <!-- less-than or equal to, U+2264 ISOtech --> + <!ENTITY ge "≥"> <!-- greater-than or equal to, + U+2265 ISOtech --> + <!ENTITY sub "⊂"> <!-- subset of, U+2282 ISOtech --> + <!ENTITY sup "⊃"> <!-- superset of, U+2283 ISOtech --> + <!ENTITY nsub "⊄"> <!-- not a subset of, U+2284 ISOamsn --> + <!ENTITY sube "⊆"> <!-- subset of or equal to, U+2286 ISOtech --> + <!ENTITY supe "⊇"> <!-- superset of or equal to, + U+2287 ISOtech --> + <!ENTITY oplus "⊕"> <!-- circled plus = direct sum, + U+2295 ISOamsb --> + <!ENTITY otimes "⊗"> <!-- circled times = vector product, + U+2297 ISOamsb --> + <!ENTITY perp "⊥"> <!-- up tack = orthogonal to = perpendicular, + U+22A5 ISOtech --> + <!ENTITY sdot "⋅"> <!-- dot operator, U+22C5 ISOamsb --> + <!-- dot operator is NOT the same character as U+00B7 middle dot --> + + <!-- Miscellaneous Technical --> + <!ENTITY lceil "⌈"> <!-- left ceiling = APL upstile, + U+2308 ISOamsc --> + <!ENTITY rceil "⌉"> <!-- right ceiling, U+2309 ISOamsc --> + <!ENTITY lfloor "⌊"> <!-- left floor = APL downstile, + U+230A ISOamsc --> + <!ENTITY rfloor "⌋"> <!-- right floor, U+230B ISOamsc --> + <!ENTITY lang "〈"> <!-- left-pointing angle bracket = bra, + U+2329 ISOtech --> + <!-- lang is NOT the same character as U+003C 'less than sign' + or U+2039 'single left-pointing 
angle quotation mark' --> + <!ENTITY rang "〉"> <!-- right-pointing angle bracket = ket, + U+232A ISOtech --> + <!-- rang is NOT the same character as U+003E 'greater than sign' + or U+203A 'single right-pointing angle quotation mark' --> + + <!-- Geometric Shapes --> + <!ENTITY loz "◊"> <!-- lozenge, U+25CA ISOpub --> + + <!-- Miscellaneous Symbols --> + <!ENTITY spades "♠"> <!-- black spade suit, U+2660 ISOpub --> + <!-- black here seems to mean filled as opposed to hollow --> + <!ENTITY clubs "♣"> <!-- black club suit = shamrock, + \ No newline at end of file diff --git a/anknotes/extra/dev/Restart Anki - How to use.txt b/anknotes/extra/dev/Restart Anki - How to use.txt new file mode 100644 index 0000000..1ed47f9 --- /dev/null +++ b/anknotes/extra/dev/Restart Anki - How to use.txt @@ -0,0 +1 @@ +Create a link that points to invisible.vbs with restart_anki.bat as its argument. Use that link to quickly restart anki while debugging \ No newline at end of file diff --git a/anknotes/extra/dev/anknotes_standAlone_template.py b/anknotes/extra/dev/anknotes_standAlone_template.py new file mode 100644 index 0000000..2d8a259 --- /dev/null +++ b/anknotes/extra/dev/anknotes_standAlone_template.py @@ -0,0 +1,124 @@ +import os +from anknotes import stopwatch +import time + +try: + from lxml import etree + + eTreeImported = True +except: + eTreeImported = False +if eTreeImported: + try: + from pysqlite2 import dbapi2 as sqlite + except ImportError: + from sqlite3 import dbapi2 as sqlite + + # Anknotes Module Imports for Stand Alone Scripts + from anknotes import evernote as evernote + + # Anknotes Shared Imports + from anknotes.shared import * + from anknotes.error import * + from anknotes.toc import TOCHierarchyClass + + # Anknotes Class Imports + from anknotes.AnkiNotePrototype import AnkiNotePrototype + from anknotes.EvernoteNoteTitle import generateTOCTitle + + # Anknotes Main Imports + from anknotes.Anki import Anki + from anknotes.ankEvernote import Evernote + from 
anknotes.EvernoteNoteFetcher import EvernoteNoteFetcher + from anknotes.EvernoteNotes import EvernoteNotes + from anknotes.EvernoteNotePrototype import EvernoteNotePrototype + from anknotes.EvernoteImporter import EvernoteImporter + + # Evernote Imports + from anknotes.evernote.edam.notestore.ttypes import NoteFilter, NotesMetadataResultSpec + from anknotes.evernote.edam.type.ttypes import NoteSortOrder, Note as EvernoteNote + from anknotes.evernote.edam.error.ttypes import EDAMSystemException, EDAMUserException, EDAMNotFoundException + from anknotes.evernote.api.client import EvernoteClient + + ankDBSetLocal() + db = ankDB() + db.Init() + + failed_queued_items = db.all("SELECT * FROM %s WHERE validation_status = 1 " % TABLES.MAKE_NOTE_QUEUE) + pending_queued_items = db.all("SELECT * FROM %s WHERE validation_status = 0" % TABLES.MAKE_NOTE_QUEUE) + success_queued_items = db.all("SELECT * FROM %s WHERE validation_status = -1 " % TABLES.MAKE_NOTE_QUEUE) + + currentLog = 'Successful' + log("------------------------------------------------", 'MakeNoteQueue-' + currentLog, timestamp=False, + do_print=True, clear=True) + log(" CHECKING %3d SUCCESSFUL MAKE NOTE QUEUE ITEMS " % len(success_queued_items), 'MakeNoteQueue-' + currentLog, + timestamp=False, do_print=True) + log("------------------------------------------------", 'MakeNoteQueue-' + currentLog, timestamp=False, + do_print=True) + + for result in success_queued_items: + line = (" [%-30s] " % ((result['guid']) + ':')) if result['guid'] else "NEW [%-30s] " % '' + line += result['title'] + log(line, 'MakeNoteQueue-' + currentLog, timestamp=False, do_print=False) + + currentLog = 'Failed' + log("------------------------------------------------", 'MakeNoteQueue-' + currentLog, timestamp=False, + do_print=True, clear=True) + log(" CHECKING %3d FAILED MAKE NOTE QUEUE ITEMS " % len(failed_queued_items), 'MakeNoteQueue-' + currentLog, + clear=False, timestamp=False, do_print=True) + 
log("------------------------------------------------", 'MakeNoteQueue-' + currentLog, timestamp=False, + do_print=True) + + for result in failed_queued_items: + line = '%-60s ' % (result['title'] + ':') + line += (" [%-30s] " % ((result['guid']) + ':')) if result['guid'] else "NEW" + line += result['validation_result'] + log(line, 'MakeNoteQueue-' + currentLog, timestamp=False, do_print=True) + log("------------------------------------------------\n", 'MakeNoteQueue-' + currentLog, timestamp=False) + log(result['contents'], 'MakeNoteQueue-' + currentLog, timestamp=False) + log("------------------------------------------------\n", 'MakeNoteQueue-' + currentLog, timestamp=False) + + EN = Evernote() + + currentLog = 'Pending' + log("------------------------------------------------", 'MakeNoteQueue-' + currentLog, timestamp=False, + do_print=True, clear=True) + log(" CHECKING %3d PENDING MAKE NOTE QUEUE ITEMS " % len(pending_queued_items), 'MakeNoteQueue-' + currentLog, + clear=False, timestamp=False, do_print=True) + log("------------------------------------------------", 'MakeNoteQueue-' + currentLog, timestamp=False, + do_print=True) + + timerFull = stopwatch.Timer() + for result in pending_queued_items: + guid = result['guid'] + noteContents = result['contents'] + noteTitle = result['title'] + line = (" [%-30s] " % ((result['guid']) + ':')) if result['guid'] else "NEW [%-30s] " % '' + + success, errors = EN.validateNoteContent(noteContents, noteTitle) + validation_status = 1 if success else -1 + + line = " SUCCESS! 
" if success else " FAILURE: " + line += ' ' if result['guid'] else ' NEW ' + # line += ' %-60s ' % (result['title'] + ':') + if not success: + errors = '\n * ' + '\n * '.join(errors) + log(line, 'MakeNoteQueue-' + currentLog, timestamp=False, do_print=True) + else: + errors = '\n'.join(errors) + + sql = "UPDATE %s SET validation_status = %d, validation_result = '%s' WHERE " % ( + TABLES.MAKE_NOTE_QUEUE, validation_status, escape_text_sql(errors)) + if guid: + sql += "guid = '%s'" % guid + else: + sql += "title = '%s' AND contents = '%s'" % (escape_text_sql(noteTitle), escape_text_sql(noteContents)) + + db.execute(sql) + + timerFull.stop() + log("Validation of %d results completed in %s" % (len(pending_queued_items), str(timerFull)), + 'MakeNoteQueue-' + currentLog, timestamp=False, do_print=True) + + db.commit() + db.close() diff --git a/anknotes/extra/dev/anknotes_test.py b/anknotes/extra/dev/anknotes_test.py new file mode 100644 index 0000000..06e4c31 --- /dev/null +++ b/anknotes/extra/dev/anknotes_test.py @@ -0,0 +1,213 @@ +import os +import re +from HTMLParser import HTMLParser + +PATH = os.path.dirname(os.path.abspath(__file__)) +ANKNOTES_TEMPLATE_FRONT = 'FrontTemplate.htm' +MODEL_EVERNOTE_DEFAULT = 'evernote_note' +MODEL_EVERNOTE_REVERSIBLE = 'evernote_note_reversible' +MODEL_EVERNOTE_REVERSE_ONLY = 'evernote_note_reverse_only' +MODEL_EVERNOTE_CLOZE = 'evernote_note_cloze' +MODEL_TYPE_CLOZE = 1 + +TEMPLATE_EVERNOTE_DEFAULT = 'EvernoteReview' +TEMPLATE_EVERNOTE_REVERSED = 'EvernoteReviewReversed' +TEMPLATE_EVERNOTE_CLOZE = 'EvernoteReviewCloze' +FIELD_TITLE = 'title' +FIELD_CONTENT = 'content' +FIELD_SEE_ALSO = 'See Also' +FIELD_EXTRA = 'Extra' +FIELD_EVERNOTE_GUID = 'Evernote GUID' + +EVERNOTE_TAG_REVERSIBLE = '#Reversible' +EVERNOTE_TAG_REVERSE_ONLY = '#Reversible_Only' + +TABLE_SEE_ALSO = "anknotes_see_also" +TABLE_TOC = "anknotes_toc" + +SETTING_KEEP_EVERNOTE_TAGS_DEFAULT_VALUE = True +SETTING_EVERNOTE_TAGS_TO_IMPORT_DEFAULT_VALUE = "#Anki_Import" 
+SETTING_DEFAULT_ANKI_TAG_DEFAULT_VALUE = "#Evernote" +SETTING_DEFAULT_ANKI_DECK_DEFAULT_VALUE = "Evernote" + +SETTING_DELETE_EVERNOTE_TAGS_TO_IMPORT = 'anknotesDeleteEvernoteTagsToImport' +SETTING_UPDATE_EXISTING_NOTES = 'anknotesUpdateExistingNotes' +SETTING_EVERNOTE_AUTH_TOKEN = 'anknotesEvernoteAuthToken' +SETTING_KEEP_EVERNOTE_TAGS = 'anknotesKeepEvernoteTags' +SETTING_EVERNOTE_TAGS_TO_IMPORT = 'anknotesEvernoteTagsToImport' +# Deprecated +# SETTING_DEFAULT_ANKI_TAG = 'anknotesDefaultAnkiTag' +SETTING_DEFAULT_ANKI_DECK = 'anknotesDefaultAnkiDeck' + +evernote_cloze_count = 0 + + +class MLStripper(HTMLParser): + def __init__(self): + self.reset() + self.fed = [] + + def handle_data(self, d): + self.fed.append(d) + + def get_data(self): + return ''.join(self.fed) + + +def strip_tags(html): + s = MLStripper() + s.feed(html) + return s.get_data() + + +class AnkiNotePrototype: + fields = {} + tags = [] + evernote_tags_to_import = [] + model_name = MODEL_EVERNOTE_DEFAULT + + def __init__(self, fields, tags, evernote_tags_to_import=list()): + self.fields = fields + self.tags = tags + self.evernote_tags_to_import = evernote_tags_to_import + + self.process_note() + + @staticmethod + def evernote_cloze_regex(match): + global evernote_cloze_count + matchText = match.group(1) + if matchText[0] == "#": + matchText = matchText[1:] + else: + evernote_cloze_count += 1 + if evernote_cloze_count == 0: + evernote_cloze_count = 1 + + # print "Match: Group #%d: %s" % (evernote_cloze_count, matchText) + return "{{c%d::%s}}" % (evernote_cloze_count, matchText) + + def process_note_see_also(self): + if not FIELD_SEE_ALSO in self.fields or not FIELD_EVERNOTE_GUID in self.fields: + return + + note_guid = self.fields[FIELD_EVERNOTE_GUID] + # mw.col.db.execute("CREATE TABLE IF NOT EXISTS %s(id INTEGER PRIMARY KEY, note_guid TEXT, uid INTEGER, shard TEXT, guid TEXT, html TEXT, text TEXT ) " % TABLE_SEE_ALSO) + # mw.col.db.execute("CREATE TABLE IF NOT EXISTS %s(id INTEGER PRIMARY KEY, 
note_guid TEXT, uid INTEGER, shard TEXT, guid TEXT, title TEXT ) " % TABLE_TOC) + # mw.col.db.execute("DELETE FROM %s WHERE note_guid = '%s' " % (TABLE_SEE_ALSO, note_guid)) + # mw.col.db.execute("DELETE FROM %s WHERE note_guid = '%s' " % (TABLE_TOC, note_guid)) + + + print "Running See Also" + iter = re.finditer( + r'<a href="(?P<URL>evernote:///?view/(?P<uid>[\d]+)/(?P<shard>s\d+)/(?P<guid>[\w\-]+)/(?P=guid)/?)"(?: shape="rect")?>(?P<Title>.+?)</a>', + self.fields[FIELD_SEE_ALSO]) + for match in iter: + title_text = strip_tags(match.group('Title')) + print "Link: %s: %s" % (match.group('guid'), title_text) + # for id, ivl in mw.col.db.execute("select id, ivl from cards limit 3"): + + + + + # .NET Regex: <a href="(?<URL>evernote:///?view/(?<uid>[\d]+)/(?<shard>s\d+)/(?<guid>[\w\-]+)/\k<guid>/?)"(?: shape="rect")?>(?<Title>.+?)</a> + # links_match + + def process_note_content(self): + if not FIELD_CONTENT in self.fields: + return + content = self.fields[FIELD_CONTENT] + ################################## Step 1: Modify Evernote Links + # We need to modify Evernote's "Classic" Style Note Links due to an Anki bug with executing the evernote command with three forward slashes. + # For whatever reason, Anki cannot handle evernote links with three forward slashes, but *can* handle links with two forward slashes. + content = content.replace("evernote:///", "evernote://") + + # Modify Evernote's "New" Style Note links that point to the Evernote website. Normally these links open the note using Evernote's web client. + # The web client then opens the local Evernote executable. Modifying the links as below will skip this step and open the note directly using the local Evernote executable + content = re.sub(r'https://www.evernote.com/shard/(s\d+)/[\w\d]+/(\d+)/([\w\d\-]+)', + r'evernote://view/\2/\1/\3/\3/', content) + + ################################## Step 2: Modify Image Links + # Currently anknotes does not support rendering images embedded into an Evernote note. 
+ # As a work around, this code will convert any link to an image on Dropbox, to an embedded <img> tag. + # This code modifies the Dropbox link so it links to a raw image file rather than an interstitial web page + # Step 2.1: Modify HTML links to Dropbox images + dropbox_image_url_regex = r'(?P<URL>https://www.dropbox.com/s/[\w\d]+/.+\.(jpg|png|jpeg|gif|bmp))(?P<QueryString>\?dl=(?:0|1))?' + dropbox_image_src_subst = r'<a href="\g<URL>}\g<QueryString>}" shape="rect"><img src="\g<URL>?raw=1" alt="Dropbox Link %s Automatically Generated by Anknotes" /></a>' + content = re.sub(r'<a href="%s".*?>(?P<Title>.+?)</a>' % dropbox_image_url_regex, + dropbox_image_src_subst % "'\g<Title>'", content) + + # Step 2.2: Modify Plain-text links to Dropbox images + content = re.sub(dropbox_image_url_regex, dropbox_image_src_subst % "From Plain-Text Link", content) + + # Step 2.3: Modify HTML links with the inner text of exactly "(Image Link)" + content = re.sub(r'<a href="(?P<URL>.+)"[^>]+>(?P<Title>\(Image Link.*\))</a>', + r'''<img src="\g<URL>" alt="'\g<Title>' Automatically Generated by Anknotes" /> <BR><a href="\g<URL>">\g<Title></a>''', + content) + + ################################## Step 3: Change white text to transparent + # I currently use white text in Evernote to display information that I want to be initially hidden, but visible when desired by selecting the white text. + # We will change the white text to a special "occluded" CSS class so it can be visible on the back of cards, and also so we can adjust the color for the front of cards when using night mode + content = content.replace('<span style="color: rgb(255, 255, 255);">', '<span class="occluded">') + + ################################## Step 4: Automatically Occlude Text in <<Double Angle Brackets>> + content = re.sub(r'<<(.+?)>>', r'<<<span class="occluded">$1</span>>>', content) + + ################################## Step 5: Create Cloze fields from shorthand. Syntax is {Text}. 
Optionally {#Text} will prevent the Cloze # from incrementing. + content = re.sub(r'{(.+?)}', self.evernote_cloze_regex, content) + + ################################## Step 6: Process "See Also: " Links + # .NET regex: (?<PrefixStrip><div><b><span style="color: rgb\(\d{1,3}, \d{1,3}, \d{1,3}\);"><br/></span></b></div>)?(?<SeeAlso>(?<SeeAlsoPrefix><div>)(?<SeeAlsoHeader><span style="color: rgb\(45, 79, 201\);"><b>See Also:(?: )?</b></span>|<b><span style="color: rgb\(45, 79, 201\);">See Also:</span></b>)(?<SeeAlsoContents>.+))(?<Suffix></en-note>) + see_also_match = re.search( + r'(?:<div><b><span style="color: rgb\(\d{1,3}, \d{1,3}, \d{1,3}\);"><br/></span></b></div>)?(?P<SeeAlso>(?:<div>)(?:<span style="color: rgb\(45, 79, 201\);"><b>See Also:(?: )?</b></span>|<b><span style="color: rgb\(45, 79, 201\);">See Also:</span></b>) ?(?P<SeeAlsoLinks>.+))(?P<Suffix></en-note>)', + content) + # see_also_match = re.search(r'(?P<PrefixStrip><div><b><span style="color: rgb\(\d{1,3}, \d{1,3}, \d{1,3}\);"><br/></span></b></div>)?(?P<SeeAlso>(?:<div>)(?P<SeeAlsoHeader><span style="color: rgb\(45, 79, 201\);">(?:See Also|<b>See Also:</b>).*?</span>).+?)(?P<Suffix></en-note>)', content) + + if see_also_match: + content = content.replace(see_also_match.group(0), see_also_match.group('Suffix')) + self.fields[FIELD_SEE_ALSO] = see_also_match.group('SeeAlso') + self.process_note_see_also() + + ################################## Note Processing complete. 
+ self.fields[FIELD_CONTENT] = content + + def process_note(self): + self.model_name = MODEL_EVERNOTE_DEFAULT + # Process Note Content + self.process_note_content() + + # Dynamically determine Anki Card Type + if FIELD_CONTENT in self.fields and "{{c1: + :" in self.fields[FIELD_CONTENT]: + self.model_name = MODEL_EVERNOTE_CLOZE + elif EVERNOTE_TAG_REVERSIBLE in self.tags: + self.model_name = MODEL_EVERNOTE_REVERSIBLE + if True: # if mw.col.conf.get(SETTING_DELETE_EVERNOTE_TAGS_TO_IMPORT, True): + self.tags.remove(EVERNOTE_TAG_REVERSIBLE) + elif EVERNOTE_TAG_REVERSE_ONLY in self.tags: + model_name = MODEL_EVERNOTE_REVERSE_ONLY + if True: # if mw.col.conf.get(SETTING_DELETE_EVERNOTE_TAGS_TO_IMPORT, True): + self.tags.remove(EVERNOTE_TAG_REVERSE_ONLY) + + # Remove Evernote Tags to Import + if True: # if mw.col.conf.get(SETTING_DELETE_EVERNOTE_TAGS_TO_IMPORT, True): + for tag in self.evernote_tags_to_import: + self.tags.remove(tag) + + +def test_anki(title, guid, filename=""): + if not filename: + filename = title + fields = { + FIELD_TITLE: title, FIELD_CONTENT: file(os.path.join(PATH, filename + ".enex"), 'r').read(), + FIELD_EVERNOTE_GUID: guid + } + tags = ['NoTags', 'NoTagsToRemove'] + en_tags = ['NoTagsToRemove'] + return AnkiNotePrototype(fields, tags, en_tags) + + +title = "Test title" +content = file(os.path.join(PATH, ANKNOTES_TEMPLATE_FRONT), 'r').read() +anki_note_prototype = test_anki("CNS Lesions Presentations Neuromuscular", '301a42d6-7ce5-4850-a365-cd1f0e98939d') +print "EN GUID: " + anki_note_prototype.fields[FIELD_EVERNOTE_GUID] diff --git a/anknotes/extra/dev/invisible.vbs b/anknotes/extra/dev/invisible.vbs new file mode 100644 index 0000000..ab66213 --- /dev/null +++ b/anknotes/extra/dev/invisible.vbs @@ -0,0 +1,6 @@ +args = "" +For I = 0 to Wscript.Arguments.Count - 1 + args = args & """" & WScript.Arguments(i) & """ " +Next + +CreateObject("Wscript.Shell").Run args, 0, False \ No newline at end of file diff --git 
a/anknotes/extra/dev/restart_anki.bat b/anknotes/extra/dev/restart_anki.bat new file mode 100644 index 0000000..8c833ef --- /dev/null +++ b/anknotes/extra/dev/restart_anki.bat @@ -0,0 +1,2 @@ +taskkill /f /im anki.exe +"C:\Program Files (x86)\Anki\anki.exe" \ No newline at end of file diff --git a/anknotes/extra/dev/restart_anki_automate.bat b/anknotes/extra/dev/restart_anki_automate.bat new file mode 100644 index 0000000..3c341a2 --- /dev/null +++ b/anknotes/extra/dev/restart_anki_automate.bat @@ -0,0 +1,4 @@ +cd /d "%~dp0" +rename anknotes.developer.automate2 anknotes.developer.automate +taskkill /f /im anki.exe +"C:\Program Files (x86)\Anki\anki.exe" \ No newline at end of file diff --git a/anknotes/extra/graphics/Evernote.ico b/anknotes/extra/graphics/Evernote.ico new file mode 100644 index 0000000..5e2e6b3 Binary files /dev/null and b/anknotes/extra/graphics/Evernote.ico differ diff --git a/anknotes/extra/graphics/Evernote.png b/anknotes/extra/graphics/Evernote.png new file mode 100644 index 0000000..5dead3d Binary files /dev/null and b/anknotes/extra/graphics/Evernote.png differ diff --git a/anknotes/extra/graphics/Tomato-icon.ico b/anknotes/extra/graphics/Tomato-icon.ico new file mode 100644 index 0000000..c6c3ee5 Binary files /dev/null and b/anknotes/extra/graphics/Tomato-icon.ico differ diff --git a/anknotes/extra/graphics/Tomato-icon.png b/anknotes/extra/graphics/Tomato-icon.png new file mode 100644 index 0000000..7015a4c Binary files /dev/null and b/anknotes/extra/graphics/Tomato-icon.png differ diff --git a/anknotes/extra/graphics/evernote_artcore.ico b/anknotes/extra/graphics/evernote_artcore.ico new file mode 100644 index 0000000..cbdb775 Binary files /dev/null and b/anknotes/extra/graphics/evernote_artcore.ico differ diff --git a/anknotes/extra/graphics/evernote_artcore.png b/anknotes/extra/graphics/evernote_artcore.png new file mode 100644 index 0000000..302ab56 Binary files /dev/null and b/anknotes/extra/graphics/evernote_artcore.png differ diff 
--git a/anknotes/extra/graphics/evernote_metro.ico b/anknotes/extra/graphics/evernote_metro.ico new file mode 100644 index 0000000..a0369d1 Binary files /dev/null and b/anknotes/extra/graphics/evernote_metro.ico differ diff --git a/anknotes/extra/graphics/evernote_metro.png b/anknotes/extra/graphics/evernote_metro.png new file mode 100644 index 0000000..bcd57ac Binary files /dev/null and b/anknotes/extra/graphics/evernote_metro.png differ diff --git a/anknotes/extra/graphics/evernote_metro_reflected.ico b/anknotes/extra/graphics/evernote_metro_reflected.ico new file mode 100644 index 0000000..6251c63 Binary files /dev/null and b/anknotes/extra/graphics/evernote_metro_reflected.ico differ diff --git a/anknotes/extra/graphics/evernote_metro_reflected.png b/anknotes/extra/graphics/evernote_metro_reflected.png new file mode 100644 index 0000000..982eabc Binary files /dev/null and b/anknotes/extra/graphics/evernote_metro_reflected.png differ diff --git a/anknotes/extra/graphics/evernote_web.ico b/anknotes/extra/graphics/evernote_web.ico new file mode 100644 index 0000000..f1e545a Binary files /dev/null and b/anknotes/extra/graphics/evernote_web.ico differ diff --git a/anknotes/extra/graphics/evernote_web.png b/anknotes/extra/graphics/evernote_web.png new file mode 100644 index 0000000..f448295 Binary files /dev/null and b/anknotes/extra/graphics/evernote_web.png differ diff --git a/anknotes/find_deleted_notes.py b/anknotes/find_deleted_notes.py new file mode 100644 index 0000000..eeca6ca --- /dev/null +++ b/anknotes/find_deleted_notes.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +import os + +try: + from pysqlite2 import dbapi2 as sqlite +except ImportError: + from sqlite3 import dbapi2 as sqlite + +from anknotes.shared import * + + +def do_find_deleted_notes(all_anki_notes=None): + """ + :param all_anki_notes: from Anki.get_evernote_guids_and_anki_fields_from_anki_note_ids() + :type : dict[str, dict[str, str]] + :return: + """ + + Error = sqlite.Error + + if not 
os.path.isfile(FILES.USER.TABLE_OF_CONTENTS_ENEX): + log_error('Unable to proceed with find_deleted_notes: TOC enex does not exist.', do_print=True) + return False + + enTableOfContents = file(FILES.USER.TABLE_OF_CONTENTS_ENEX, 'r').read() + # find = file(os.path.join(PATH, "powergrep-find.txt") , 'r').read().splitlines() + # replace = file(os.path.join(PATH, "powergrep-replace.txt") , 'r').read().replace('https://www.evernote.com/shard/s175/nl/19775535/' , '').splitlines() + db=ankDB() + all_anknotes_notes = db.all(columns='guid, title, tagNames') + find_guids = {} + log_banner(' FIND DELETED EVERNOTE NOTES: UNIMPORTED EVERNOTE NOTES ', FILES.LOGS.FDN.UNIMPORTED_EVERNOTE_NOTES) + log_banner(' FIND DELETED EVERNOTE NOTES: ORPHAN ANKI NOTES ', FILES.LOGS.FDN.ANKI_ORPHANS) + log_banner(' FIND DELETED EVERNOTE NOTES: ORPHAN ANKNOTES DB ENTRIES ', FILES.LOGS.FDN.ANKNOTES_ORPHANS) + log_banner(' FIND DELETED EVERNOTE NOTES: ANKNOTES TITLE MISMATCHES ', FILES.LOGS.FDN.ANKNOTES_TITLE_MISMATCHES) + log_banner(' FIND DELETED EVERNOTE NOTES: ANKI TITLE MISMATCHES ', FILES.LOGS.FDN.ANKI_TITLE_MISMATCHES) + log_banner(' FIND DELETED EVERNOTE NOTES: POSSIBLE TOC NOTES MISSING TAG ', + FILES.LOGS.FDN.ANKI_TITLE_MISMATCHES + '_possibletoc') + anki_mismatch = 0 + is_toc_or_outline = [] + all_anki_notes = db.all("SELECT n.sfld, n.flds FROM notes n WHERE n.flds LIKE ? 
|| '%'", FIELDS.EVERNOTE_GUID_PREFIX) + all_anki_notes = {get_evernote_guid_from_anki_fields(flds): clean_title(sfld) for sfld, flds in all_anki_notes} + delete_title_mismatches = True + for line in all_anknotes_notes: + guid = line['guid'] + title = line['title'] + if not (',' + TAGS.TOC + ',' in line['tagNames']): + if title.upper() == title: + log_plain(guid + '::: %-50s: ' % line['tagNames'][1:-1] + title, + FILES.LOGS.FDN.ANKI_TITLE_MISMATCHES + '_possibletoc', do_print=True) + + title = clean_title(title) + title_safe = str_safe(title) + find_guids[guid] = title + if guid in all_anki_notes: + find_title = clean_title(all_anki_notes[guid]) + find_title_safe = str_safe(find_title) + if find_title_safe == title_safe or find_title == title: + del all_anki_notes[guid] + else: + log_plain(guid + '::: ' + title + '\n ' + ' ' * len(guid) + '::: ' + find_title, + FILES.LOGS.FDN.ANKI_TITLE_MISMATCHES) + log_plain(repr(find_title) + '\n ' + repr(title), FILES.LOGS.FDN.ANKI_TITLE_MISMATCHES + '-2') + anki_mismatch += 1 + if delete_title_mismatches: + del all_anki_notes[guid] + mismatch = 0 + missing_evernote_notes = [] + for enLink in find_evernote_links(enTableOfContents): + guid = enLink.Guid + title = clean_title(enLink.FullTitle) + title_safe = str_safe(title) + + if guid in find_guids: + find_title = clean_title(find_guids[guid]) + find_title_safe = str_safe(find_title) + if find_title_safe == title_safe or find_title == title: + del find_guids[guid] + else: + log_plain(guid + '::: ' + title + '\n ' + ' ' * len(guid) + '::: ' + find_title, + FILES.LOGS.FDN.ANKNOTES_TITLE_MISMATCHES) + if delete_title_mismatches: + del find_guids[guid] + mismatch += 1 + else: + log_plain(guid + '::: ' + title, FILES.LOGS.FDN.UNIMPORTED_EVERNOTE_NOTES) + missing_evernote_notes.append(guid) + + anki_dels, anknotes_dels = [], [] + for guid, title in all_anki_notes.items(): + log_plain(guid + '::: ' + title, FILES.LOGS.FDN.ANKI_ORPHANS) + anki_dels.append(guid) + for guid, title in 
find_guids.items(): + log_plain(guid + '::: ' + title, FILES.LOGS.FDN.ANKNOTES_ORPHANS) + anknotes_dels.append(guid) + + logs = [ + ["Orphan Anknotes DB Note(s)", + + len(anknotes_dels), + FILES.LOGS.FDN.ANKNOTES_ORPHANS, + "(not present in Evernote)" + + ], + + ["Orphan Anki Note(s)", + + len(anki_dels), + FILES.LOGS.FDN.ANKI_ORPHANS, + "(not present in Anknotes DB)" + + ], + + ["Unimported Evernote Note(s)", + + len(missing_evernote_notes), + FILES.LOGS.FDN.UNIMPORTED_EVERNOTE_NOTES, + "(not present in Anknotes DB" + + ], + + ["Anknotes DB Title Mismatches", + + mismatch, + FILES.LOGS.FDN.ANKNOTES_TITLE_MISMATCHES + + ], + + ["Anki Title Mismatches", + + anki_mismatch, + FILES.LOGS.FDN.ANKI_TITLE_MISMATCHES + + ] + ] + results = [ + [ + log[1], + log[0] if log[1] == 0 else '<a href="%s">%s</a>' % (get_log_full_path(log[2], as_url_link=True, filter_disabled=False), log[0]), + log[3] if len(log) > 3 else '' + ] + for log in logs] + + # showInfo(str(results)) + + return { + "Summary": results, "AnknotesOrphans": anknotes_dels, "AnkiOrphans": anki_dels, + "MissingEvernoteNotes": missing_evernote_notes + } diff --git a/anknotes/graphics.py b/anknotes/graphics.py new file mode 100644 index 0000000..58ebc62 --- /dev/null +++ b/anknotes/graphics.py @@ -0,0 +1,15 @@ +from anknotes.constants import * +### Anki Imports +try: + from aqt.qt import QIcon, QPixmap +except Exception: + pass + +try: + icoEvernoteWeb = QIcon(FILES.GRAPHICS.ICON.EVERNOTE_WEB) + icoEvernoteArtcore = QIcon(FILES.GRAPHICS.ICON.EVERNOTE_ARTCORE) + icoTomato = QIcon(FILES.GRAPHICS.ICON.TOMATO) + imgEvernoteWeb = QPixmap(FILES.GRAPHICS.IMAGE.EVERNOTE_WEB, "PNG") + imgEvernoteWebMsgBox = imgEvernoteWeb.scaledToWidth(64) +except Exception: + pass diff --git a/anknotes/html.py b/anknotes/html.py new file mode 100644 index 0000000..f98f56f --- /dev/null +++ b/anknotes/html.py @@ -0,0 +1,300 @@ +import re +from HTMLParser import HTMLParser + + +from anknotes.constants import SETTINGS +from anknotes.base 
class MLStripper(HTMLParser):
    """HTMLParser subclass that keeps only the text (data) nodes."""

    def __init__(self):
        HTMLParser.__init__(self)
        self.reset()
        self.fed = []  # accumulated text fragments, joined by get_data()

    def handle_data(self, d):
        self.fed.append(d)

    def get_data(self):
        return ''.join(self.fed)


def strip_tags(html, strip_entities=False):
    """Remove all HTML tags from *html*; return None for None input.

    When *strip_entities* is False (default), character entities such as
    ``&amp;`` survive untouched: every ``&`` is masked with a placeholder
    before parsing so the parser cannot consume the entity, and the
    placeholder is restored afterwards.
    """
    __html_entity_repl = '_!_DONT_STRIP_HTML_ENTITIES_!_'
    if html is None:
        return None
    if not strip_entities:
        html = html.replace('&', __html_entity_repl)
    s = MLStripper()
    s.feed(html)
    html = s.get_data()
    if not strip_entities:
        html = html.replace(__html_entity_repl, '&')
    return html


def strip_tags_and_new_lines(html):
    """Strip HTML tags and collapse runs of CR/LF into single spaces."""
    if html is None:
        return None
    return re.sub(r'[\r\n]+', ' ', strip_tags(html))


# Pairs of (plain text, HTML entity). '&' comes first so escaping never
# double-escapes the '&' of an entity written by a later pair.
# BUG FIX: the table had degenerated into identical pairs (the entity halves
# were lost), which made escape_text/unescape_text no-ops.
__text_escape_phrases = u'&|&amp;|\'|&apos;|"|&quot;|>|&gt;|<|&lt;'.split('|')


def escape_text(title):
    """Escape the five XML special characters in *title* as HTML entities."""
    global __text_escape_phrases
    for i in range(0, len(__text_escape_phrases), 2):
        title = title.replace(__text_escape_phrases[i], __text_escape_phrases[i + 1])
    return title


def unescape_text(title, try_decoding=False):
    """Replace HTML entities in *title* with their literal characters.

    On a unicode failure the text is decoded once and the replacement
    retried; a second failure propagates as UnicodeError.
    """
    title_orig = title
    global __text_escape_phrases
    if try_decoding:
        title = decode(title)
    try:
        for i in range(0, len(__text_escape_phrases), 2):
            title = title.replace(__text_escape_phrases[i + 1], __text_escape_phrases[i])
        title = title.replace(u"&nbsp;", u" ")
    except Exception:
        if try_decoding:
            raise UnicodeError
        title_new = unescape_text(title, True)
        log(title + '\n' + title_new + '\n\n', 'unicode')
        return title_new
    return title


def clean_title(title):
    """Decode, unescape, and whitespace-normalize an Evernote note title."""
    title = unescape_text(decode(title))
    title = re.sub(r'( |\xa0)+', ' ', decode(title))
    return title


def generate_evernote_url(guid):
    """Build an ``evernote:///view/...`` in-app URL for the note *guid*."""
    ids = get_evernote_account_ids()
    return u'evernote:///view/%s/%s/%s/%s/' % (ids.uid, ids.shard, guid, guid)
def generate_evernote_link_by_type(guid, title=None, link_type=None, value=None, escape=True):
    """Return an HTML anchor linking to the Evernote note *guid*.

    The anchor wraps the (optionally escaped) *title* — looked up from the
    local DB when omitted — in a <span> colored per *link_type*/*value*.
    """
    url = generate_evernote_url(guid)
    if not title:
        title = get_evernote_title_from_guid(guid)
    if escape:
        title = escape_text(title)
    style = generate_evernote_html_element_style_attribute(link_type, value)
    html = u"""<a href="%s"><span style="%s">%s</span></a>""" % (url, style, title)
    return html


def generate_evernote_link(guid, title=None, value=None, escape=True):
    """Anchor colored per the 'Links' color table."""
    return generate_evernote_link_by_type(guid, title, 'Links', value, escape=escape)


def generate_evernote_link_by_level(guid, title=None, value=None, escape=True):
    """Anchor colored per the 'Levels' color table."""
    return generate_evernote_link_by_type(guid, title, 'Levels', value, escape=escape)


def generate_evernote_html_element_style_attribute(link_type, value, bold=True, group=None):
    """Build a ``style="..."`` value (color + optional bold) for a link/span.

    Falls back to ``evernote_link_colors['Default']`` when no specific color
    is configured for *link_type*/*value*.

    BUG FIX: the original compared str literals with ``is``/``is not``
    (``link_type is 'Levels'``, ``colorDefault[-1] is ';'``), which only
    works by CPython string interning; replaced with ``==``/``!=``.
    """
    global evernote_link_colors
    colors = None
    if link_type in evernote_link_colors:
        color_types = evernote_link_colors[link_type]
        if link_type == 'Levels':
            if not value:
                value = 1
            if not group:
                group = 'OL' if isinstance(value, int) else 'Modifiers'
            if value not in color_types[group]:
                group = 'Headers'
            if value in color_types[group]:
                colors = color_types[group][value]
        elif link_type == 'Links':
            if not value:
                value = 'Default'
            if value in color_types:
                colors = color_types[value]
    if not colors:
        colors = evernote_link_colors['Default']
    colorDefault = colors
    if not is_str_type(colorDefault):
        # Entries may be plain strings or {'Default': ..., 'Hover': ...} dicts.
        colorDefault = colorDefault['Default']
    if colorDefault[-1] != ';':
        colorDefault += ';'
    style = 'color: ' + colorDefault
    if bold:
        style += 'font-weight:bold;'
    return style


def generate_evernote_span(title=None, element_type=None, value=None, guid=None, bold=True, escape=True):
    """Return a colored <span> for *title* (or the title looked up by *guid*)."""
    assert title or guid
    if not title:
        title = get_evernote_title_from_guid(guid)
    if escape:
        title = escape_text(title)
    style = generate_evernote_html_element_style_attribute(element_type, value, bold)
    html = u"""<span style="%s">%s</span>""" % (style, title)
    return html
u"""<span style="%s">%s</span>""" % (style, title) + return html + + +evernote_link_colors = { + 'Levels': { + 'OL': { + 1: { + 'Default': 'rgb(106, 0, 129);', + 'Hover': 'rgb(168, 0, 204);' + }, + 2: { + 'Default': 'rgb(235, 0, 115);', + 'Hover': 'rgb(255, 94, 174);' + }, + 3: { + 'Default': 'rgb(186, 0, 255);', + 'Hover': 'rgb(213, 100, 255);' + }, + 4: { + 'Default': 'rgb(129, 182, 255);', + 'Hover': 'rgb(36, 130, 255);' + }, + 5: { + 'Default': 'rgb(232, 153, 220);', + 'Hover': 'rgb(142, 32, 125);' + }, + 6: { + 'Default': 'rgb(201, 213, 172);', + 'Hover': 'rgb(130, 153, 77);' + }, + 7: { + 'Default': 'rgb(231, 179, 154);', + 'Hover': 'rgb(215, 129, 87);' + }, + 8: { + 'Default': 'rgb(249, 136, 198);', + 'Hover': 'rgb(215, 11, 123);' + } + }, + 'Headers': { + 'Auto TOC': 'rgb(11, 59, 225);' + }, + 'Modifiers': { + 'Orange': 'rgb(222, 87, 0);', + 'Orange (Light)': 'rgb(250, 122, 0);', + 'Dark Red/Pink': 'rgb(164, 15, 45);', + 'Pink Alternative LVL1:': 'rgb(188, 0, 88);' + } + }, + 'Titles': { + 'Field Title Prompt': 'rgb(169, 0, 48);' + }, + 'Links': { + 'See Also': { + 'Default': 'rgb(45, 79, 201);', + 'Hover': 'rgb(108, 132, 217);' + }, + 'TOC': { + 'Default': 'rgb(173, 0, 0);', + 'Hover': 'rgb(196, 71, 71);' + }, + 'Outline': { + 'Default': 'rgb(105, 170, 53);', + 'Hover': 'rgb(135, 187, 93);' + }, + 'AnkNotes': { + 'Default': 'rgb(30, 155, 67);', + 'Hover': 'rgb(107, 226, 143);' + } + } +} + +evernote_link_colors['Default'] = evernote_link_colors['Links']['Outline'] +evernote_link_colors['Links']['Default'] = evernote_link_colors['Default'] + +enAccountIDs = None + + +def get_evernote_account_ids(): + global enAccountIDs + if not enAccountIDs or not enAccountIDs.Valid: + enAccountIDs = EvernoteAccountIDs() + return enAccountIDs + + +def tableify_column(column): + return str(column).replace('\n', '\n<BR>').replace(' ', '  ') + + +def tableify_lines(rows, columns=None, tr_index_offset=0, return_html=True): + if columns is None: + columns = [] + elif not 
class EvernoteAccountIDs:
    """The Evernote account's user id and shard, persisted via SETTINGS."""
    uid = SETTINGS.EVERNOTE.ACCOUNT.UID_DEFAULT_VALUE
    shard = SETTINGS.EVERNOTE.ACCOUNT.SHARD_DEFAULT_VALUE

    @property
    def Valid(self):
        return self.is_valid()

    def is_valid(self, uid=None, shard=None):
        """Validate *uid*/*shard* (defaulting to this instance's values).

        A valid uid is a nonzero number that is not the stored placeholder;
        a valid shard looks like 's<number>' and is neither the placeholder
        nor the 's999' sentinel.

        BUG FIX: the placeholder-shard test previously compared
        ``uid == SETTINGS...SHARD.val`` (a uid can never equal a shard
        string), so a placeholder shard was never rejected; compare *shard*.
        """
        if uid is None:
            uid = self.uid
        if shard is None:
            shard = self.shard
        if not uid or not shard:
            return False
        if uid == '0' or uid == SETTINGS.EVERNOTE.ACCOUNT.UID.val or not unicode(uid).isnumeric():
            return False
        if shard == 's999' or shard == SETTINGS.EVERNOTE.ACCOUNT.SHARD.val or shard[0] != 's' or not unicode(
                shard[1:]).isnumeric():
            return False
        return True

    def __init__(self, uid=None, shard=None):
        # Prefer explicit values, then the persisted settings, then the
        # placeholder defaults.
        if uid and shard:
            if self.update(uid, shard):
                return
        try:
            self.uid = SETTINGS.EVERNOTE.ACCOUNT.UID.fetch()
            self.shard = SETTINGS.EVERNOTE.ACCOUNT.SHARD.fetch()
            if self.Valid:
                return
        except Exception:
            pass
        self.uid = SETTINGS.EVERNOTE.ACCOUNT.UID.val
        self.shard = SETTINGS.EVERNOTE.ACCOUNT.SHARD.val

    def update(self, uid, shard):
        """Persist and adopt *uid*/*shard* when valid; return final validity."""
        if not self.is_valid(uid, shard):
            return False
        try:
            SETTINGS.EVERNOTE.ACCOUNT.UID.save(uid)
            SETTINGS.EVERNOTE.ACCOUNT.SHARD.save(shard)
        except Exception:
            return False
        self.uid = uid
        self.shard = shard
        return self.Valid
import os
import sys

lxml = None
etree = None


def in_anki():
    """True when running inside Anki (the 'anki' package is imported)."""
    return 'anki' in sys.modules


def import_module(name, path=None, sublevels=2, path_suffix=''):
    """Locate and load module *name* relative to this file.

    Walks *sublevels* directories up from *path* (default: this file's
    directory), optionally appends *path_suffix*, and loads the module found
    there. Returns the module object, or None when it cannot be loaded.

    BUG FIX: on ImportError the original dropped into ``pdb.set_trace()``,
    which hangs Anki's UI; we now print the traceback and return None.
    """
    print("Import " + str(path) + " Level " + str(sublevels))
    import imp  # deferred: 'imp' is deprecated and only needed on this path
    if path is None:
        path = os.path.dirname(__file__)
        print("Auto Path " + path)
    for i in range(0, sublevels):
        path = os.path.join(path, '..' + os.path.sep)
        print("Path Level " + str(i) + " - " + path)
    if path_suffix:
        path = os.path.join(path, path_suffix)
    path = os.path.abspath(path)
    try:
        modfile, modpath, description = imp.find_module(name, [path + os.path.sep])
        modobject = imp.load_module(name, modfile, modpath, description)
    except ImportError as e:
        import traceback
        print(path + '\n' + str(e))
        print(traceback.format_exc())
        return None
    try:
        modfile.close()
    except Exception:
        pass
    return modobject


def import_anki_module(name):
    """Load *name* from the bundled 'anki_master' source tree."""
    return import_module(name, path_suffix='anki_master' + os.path.sep)


def import_etree():
    """Bind lxml.etree into the module-global `etree`; True on success."""
    global etree
    from anknotes.constants import ANKNOTES
    if not ANKNOTES.LXML.ENABLE_IN_ANKI and in_anki():
        return False
    if not import_lxml():
        return False
    try:
        from lxml import etree
        return True
    except Exception:
        return False


def import_lxml():
    """Ensure the module-global `lxml` is loaded; True on success.

    Tries a normal import first, then falls back to loading a copy bundled
    in the add-on tree. (The original used ``assert lxml`` for control flow
    and re-imported os/imp inside the function.)
    """
    global lxml
    if lxml is not None:
        return True
    try:
        import lxml
        return True
    except ImportError:
        lxml = None
    lxml = import_module('lxml')
    return lxml is not None
<title>Starter Template for Bootstrap - - - - - - - - - - - - - - - - - -
      - -
      -

      Evernote Auth

      -

      Copy the displayed key into the open Anki pop-up dialog.

      -
      -
      -

      -
      - -
      - - - - - - - - - diff --git a/anknotes/logging.py b/anknotes/logging.py new file mode 100644 index 0000000..63f8970 --- /dev/null +++ b/anknotes/logging.py @@ -0,0 +1,600 @@ +# Python Imports +from datetime import datetime, timedelta +import difflib +import pprint +import re +import shutil +import time +from fnmatch import fnmatch + + +# Anknotes Shared Imports +from anknotes.constants import * +from anknotes.logging_base import write_file_contents, rm_log_path, reset_logs, filter_logs +from anknotes.base import item_to_list, is_str, is_str_type, caller_name, create_log_filename, str_safe, encode, decode +from anknotes.methods import create_timer +from anknotes.args import Args +from anknotes.graphics import * +from anknotes.dicts import DictCaseInsensitive + +# Anki Imports +try: + from aqt import mw + from aqt.utils import tooltip + from aqt.qt import QMessageBox, QPushButton, QSizePolicy, QSpacerItem, QGridLayout, QLayout +except Exception: + pass + + +def show_tooltip(text, time_out=7, delay=None, do_log=False, **kwargs): + if not hasattr(show_tooltip, 'enabled'): + show_tooltip.enabled = None + if do_log: + log(text, **kwargs) + if delay: + try: + return create_timer(delay, tooltip, text, time_out * 1000) + except Exception: + pass + if show_tooltip.enabled is not False: + tooltip(text, time_out * 1000) + + +def counts_as_str(count, max_=None): + from anknotes.counters import Counter + if isinstance(count, Counter): + count = count.val + if isinstance(max_, Counter): + max_ = max_.val + if max_ is None or max_ <= 0: + return str(count).center(3) + if count == max_: + return "All %s" % str(count).center(3) + return "Total %s of %s" % (str(count).center(3), str(max_).center(3)) + + +def format_count(format_str, count): + """ + :param format_str: + :type format_str : str | unicode + :param count: + :return: + """ + if not count > 0: + return ' ' * len(format_str % 1) + return format_str % count + + +def show_report(title, header=None, log_lines=None, 
delay=None, log_header_prefix=' ' * 5, + blank_line_before=True, blank_line_after=True, hr_if_empty=False, **kw): + if log_lines is None: + log_lines = [] + if header is None: + header = [] + lines = [] + for line in ('
      '.join(header) if isinstance(header, list) else header).split('
      ') + ( + '
      '.join(log_lines).split('
      ') if log_lines else []): + level = 0 + while line and line[level] is '-': level += 1 + lines.append('\t' * level + ('\t\t- ' if lines else '') + line[level:]) + if len(lines) > 1: + lines[0] += ': ' + log_text = '
      '.join(lines) + if not header and not log_lines: + i = title.find('> ') + show_tooltip(title[0 if i < 0 else i + 2:], delay=delay) + else: + show_tooltip(log_text.replace('\t', '  ' * 4), delay=delay) + if blank_line_before: + log_blank(**kw) + log(title, **kw) + if len(lines) == 1 and not lines[0]: + if hr_if_empty: + log("-" * ANKNOTES.FORMATTING.LINE_LENGTH, timestamp=False, **kw) + return + log("-" * ANKNOTES.FORMATTING.LINE_LENGTH + '\n' + log_header_prefix + log_text.replace('
      ', '\n'), + timestamp=False, replace_newline=True, **kw) + if blank_line_after: + log_blank(**kw) + +def showInfo(message, title="Anknotes: Evernote Importer for Anki", textFormat=0, cancelButton=False, richText=False, + minHeight=None, minWidth=400, styleSheet=None, convertNewLines=True): + global imgEvernoteWebMsgBox, icoEvernoteArtcore, icoEvernoteWeb + msgDefaultButton = QPushButton(icoEvernoteArtcore, "Okay!", mw) + + if not styleSheet: + styleSheet = file(FILES.ANCILLARY.CSS_QMESSAGEBOX, 'r').read() + + if not is_str_type(message): + message = str(message) + + if richText: + textFormat = 1 + message = '\n\n%s' % (styleSheet, message) + global messageBox + messageBox = QMessageBox() + messageBox.addButton(msgDefaultButton, QMessageBox.AcceptRole) + if cancelButton: + msgCancelButton = QPushButton(icoTomato, "No Thanks", mw) + messageBox.addButton(msgCancelButton, QMessageBox.RejectRole) + messageBox.setDefaultButton(msgDefaultButton) + messageBox.setIconPixmap(imgEvernoteWebMsgBox) + messageBox.setTextFormat(textFormat) + + messageBox.setWindowIcon(icoEvernoteWeb) + messageBox.setWindowIconText("Anknotes") + messageBox.setText(message) + messageBox.setWindowTitle(title) + hSpacer = QSpacerItem(minWidth, 0, QSizePolicy.Minimum, QSizePolicy.Expanding) + + layout = messageBox.layout() + """:type : QGridLayout """ + layout.addItem(hSpacer, layout.rowCount() + 1, 0, 1, layout.columnCount()) + ret = messageBox.exec_() + if not cancelButton: + return True + if messageBox.clickedButton() == msgCancelButton or messageBox.clickedButton() == 0: + return False + return True + + +def diffify(content, split=True): + for tag in [u'div', u'ol', u'ul', u'li', u'span']: + content = content.replace(u"<" + tag, u"\n<" + tag).replace(u"" % tag, u"\n" % tag) + content = re.sub(r'[\r\n]+', u'\n', content) + return content.splitlines() if split else content + + +def generate_diff(value_original, value): + try: + return 
def PadList(lst, length=None):
    """Return a copy of *lst* (recursively) with every string centered to *length*.

    *length* defaults to ANKNOTES.FORMATTING.LIST_PAD. FIX: the default is
    now resolved at call time instead of being baked in at import time.
    """
    if length is None:
        length = ANKNOTES.FORMATTING.LIST_PAD
    newLst = []
    for val in lst:
        if isinstance(val, list):
            newLst.append(PadList(val, length))
        else:
            newLst.append(val.center(length))
    return newLst


def JoinList(lst, joiners='\n', pad=0, depth=1):
    """Recursively join nested lists of strings.

    ``joiners[depth-1]`` separates items at each nesting depth (the last
    joiner is reused for deeper levels). *pad* centers leaf strings.
    Non-list, non-padded values are returned unchanged.
    """
    if is_str_type(joiners):
        joiners = [joiners]
    str_ = ''
    if pad and is_str_type(lst):
        return lst.center(pad)
    if not lst or not isinstance(lst, list):
        return lst
    delimit = joiners[min(len(joiners), depth) - 1]
    for val in lst:
        if str_:
            str_ += delimit
        str_ += JoinList(val, joiners, pad, depth + 1)
    return str_


def PadLines(content, line_padding=None, line_padding_plus=0, line_padding_header='',
             pad_char=' ', **kwargs):
    """Indent every line of *content*.

    *line_padding* (count or string) prefixes each line; *line_padding_plus*
    adds extra indent to continuation lines; *line_padding_header* prepends a
    header whose width sets the continuation indent.

    FIXES: the ANKNOTES default is resolved at call time, and a single-line
    *content* no longer raises TypeError (``str.find`` returning -1 was
    concatenated into the replacement string); it now contributes no extra
    continuation indent.
    """
    if line_padding is None:
        line_padding = ANKNOTES.FORMATTING.LINE_PADDING_HEADER
    if not line_padding and not line_padding_plus and not line_padding_header:
        return content
    if not line_padding:
        line_padding = line_padding_plus
        line_padding_plus = True
    if str(line_padding).isdigit():
        line_padding = pad_char * int(line_padding)
    if line_padding_header:
        content = line_padding_header + content
        line_padding_plus = len(line_padding_header) + 1
    elif line_padding_plus is True:
        line_padding_plus = max(content.find('\n'), 0)
    if str(line_padding_plus).isdigit():
        line_padding_plus = pad_char * int(line_padding_plus)
    return line_padding + content.replace('\n', '\n' + line_padding + line_padding_plus)
def obj2log_simple(content):
    """Coerce *content* to a loggable string (non-strings via str())."""
    if not is_str_type(content):
        content = str(content)
    return content


def convert_filename_to_local_link(filename):
    """Turn a filesystem path into a file:/// URL."""
    return 'file:///' + filename.replace("\\", "//")


class Logger(object):
    """Named logger that prefixes filenames with a per-caller base path and
    merges per-instance default kwargs into every call."""
    base_path = None
    path_suffix = None
    caller_info = None
    default_filename = None
    defaults = {}
    auto_header = True
    default_banner = None

    def wrap_filename(self, filename=None, final_suffix='', wrap_fn_auto_header=True, crosspost=None, **kwargs):
        """Resolve *filename* against this logger's base path/suffix.

        Returns ``(filename, kwargs)``; crosspost targets are wrapped too.
        When a banner is configured, it is written once to new log files.
        """
        if filename is None:
            filename = self.default_filename
        if self.base_path is not None:
            filename = os.path.join(self.base_path, filename if filename else '')
        if self.path_suffix is not None:
            # A '*' marks where the suffix is inserted; text after it becomes
            # part of the final suffix.
            i_asterisk = filename.find('*')
            if i_asterisk > -1:
                final_suffix += filename[i_asterisk + 1:]
                filename = filename[:i_asterisk]
            filename += self.path_suffix + final_suffix
        if crosspost is not None:
            crosspost = [self.wrap_filename(cp)[0] for cp in item_to_list(crosspost, False)]
            kwargs['crosspost'] = crosspost
        if wrap_fn_auto_header and self.auto_header and self.default_banner and not os.path.exists(
                get_log_full_path(filename)):
            log_banner(self.default_banner, filename)
        return filename, kwargs

    def error(self, content, crosspost=None, *a, **kw):
        """Log *content* to this logger's error log plus any crossposts.

        BUG FIX: the original called ``list.append()`` with keyword
        arguments (always a TypeError) and would have appended the
        ``(filename, kwargs)`` tuple; append only the wrapped filename.
        """
        if crosspost is None:
            crosspost = []
        crosspost.append(self.wrap_filename('error', **DictCaseInsensitive(self.defaults, kw))[0])
        log_error(content, crosspost=crosspost, *a, **kw)

    def dump(self, obj, title='', filename=None, *args, **kwargs):
        """Pretty-print *obj* into this logger's dump file."""
        filename, kwargs = self.wrap_filename(filename, **DictCaseInsensitive(self.defaults, kwargs))
        # noinspection PyArgumentList
        log_dump(obj, title, filename, *args, **kwargs)

    def blank(self, filename=None, *args, **kwargs):
        """Write an empty (untimestamped) line."""
        filename, kwargs = self.wrap_filename(filename, **DictCaseInsensitive(self.defaults, kwargs))
        log_blank(filename, *args, **kwargs)

    def banner(self, title, filename=None, *args, **kwargs):
        """Write a banner and remember it as this logger's auto-header."""
        filename, kwargs = self.wrap_filename(filename, **DictCaseInsensitive(self.defaults, kwargs,
                                                                              wrap_fn_auto_header=False))
        self.default_banner = title
        log_banner(title, filename, *args, **kwargs)

    def go(self, content=None, filename=None, wrap_filename=True, *args, **kwargs):
        """Core logging entry point; aliased as log/do/add."""
        if wrap_filename:
            filename, kwargs = self.wrap_filename(filename, **DictCaseInsensitive(self.defaults, kwargs))
        log(content, filename, *args, **kwargs)

    def plain(self, content=None, filename=None, *args, **kwargs):
        """Log without a timestamp."""
        filename, kwargs = self.wrap_filename(filename, **DictCaseInsensitive(self.defaults, kwargs))
        log_plain(content, filename, *args, **kwargs)

    log = do = add = go

    def default(self, *args, **kwargs):
        """Log to the module's default target, skipping filename wrapping."""
        self.log(wrap_filename=False, *args, **DictCaseInsensitive(self.defaults, kwargs))

    def __init__(self, base_path=None, default_filename=None, rm_path=False, no_base_path=None, **kwargs):
        self.defaults = kwargs
        if no_base_path and not default_filename:
            default_filename = no_base_path
        self.default_filename = default_filename
        if base_path:
            self.base_path = base_path
        elif not no_base_path:
            # Derive the base path from the calling module/function name.
            self.caller_info = caller_name()
            if self.caller_info:
                self.base_path = create_log_filename(self.caller_info.Base) + os.path.sep
        if rm_path:
            rm_log_path(self.base_path)


def log_blank(*args, **kwargs):
    """Write an empty, untimestamped log line."""
    log(None, *args, **DictCaseInsensitive(kwargs, timestamp=False, delete='content'))


def log_plain(*args, **kwargs):
    """Log without the timestamp prefix."""
    log(*args, **DictCaseInsensitive(kwargs, timestamp=False))


def rm_log_paths(*args, **kwargs):
    """Delete several log paths; see rm_log_path."""
    for arg in args:
        rm_log_path(arg, **kwargs)


def log_banner(title, filename=None, length=None, append_newline=True, timestamp=False,
               chr='-', center=True, clear=True, crosspost=None, prepend_newline=False, *args, **kwargs):
    """Write *title* between two horizontal rules to *filename*'s log.

    *length* defaults to ANKNOTES.FORMATTING.BANNER_MINIMUM (now resolved at
    call time); 0 means full line width. FIX: ``length is 0`` (identity on an
    int literal) replaced with ``== 0``. Note: *chr* shadows the builtin but
    is kept for caller compatibility.
    """
    if length is None:
        length = ANKNOTES.FORMATTING.BANNER_MINIMUM
    if crosspost is not None:
        for cp in item_to_list(crosspost, False):
            log_banner(title, cp, **DictCaseInsensitive(kwargs, locals(), delete='title crosspost kwargs args filename'))
    if length == 0:
        length = ANKNOTES.FORMATTING.LINE_LENGTH + 1
    if center:
        title = title.center(length - (ANKNOTES.FORMATTING.TIMESTAMP_PAD_LENGTH if timestamp else 0))
    if prepend_newline:
        log_blank(filename, **kwargs)
    log(chr * length, filename, clear=clear, timestamp=False, **kwargs)
    log(title, filename, timestamp=timestamp, **kwargs)
    log(chr * length, filename, timestamp=False, **kwargs)
    if append_newline:
        log_blank(filename, **kwargs)
# Stack of implicit log-file targets; log() falls back to the most recent.
_log_filename_history = []


def set_current_log(fn):
    """Push *fn* as the implicit target for subsequent log() calls."""
    global _log_filename_history
    _log_filename_history.append(fn)


def end_current_log(fn=None):
    """Pop *fn* (or the most recent entry) from the implicit-target stack."""
    global _log_filename_history
    if fn:
        _log_filename_history.remove(fn)
    else:
        _log_filename_history = _log_filename_history[:-1]


def get_log_full_path(filename=None, extension='log', as_url_link=False, prefix='', filter_disabled=True, **kwargs):
    """Resolve a log name to an absolute path under FOLDERS.LOGS.

    Name resolution order when *filename* is None: caller-derived name,
    current-log stack, active log, base name, default name. A leading '*'
    turns the rest into a sub-path suffix; a leading '+' prefixes the base
    log name. Returns False when the name is filtered out, a file:/// URL
    when *as_url_link*, otherwise the absolute path (directories created).

    BUG FIX: ``filename[0] is '+'`` and ``filename[-1] is not '.'`` relied on
    CPython string interning; replaced with ``==`` / ``!=``.
    """
    global _log_filename_history
    logging_base_name = FILES.LOGS.BASE_NAME
    filename_suffix = ''
    if filename and filename.startswith('*'):
        filename_suffix = '\\' + filename[1:]
        logging_base_name = ''
        filename = None
    if filename is None:
        if FILES.LOGS.USE_CALLER_NAME:
            caller = caller_name()
            if caller:
                filename = caller.Base.replace('.', '\\')
        if filename is None:
            filename = _log_filename_history[-1] if _log_filename_history else FILES.LOGS.ACTIVE
        if not filename:
            filename = logging_base_name
        if not filename:
            filename = FILES.LOGS.DEFAULT_NAME
    else:
        if filename[0] == '+':
            filename = filename[1:]
            filename = (logging_base_name + '-' if logging_base_name and logging_base_name[-1] != '\\' else '') + filename
    filename += filename_suffix
    if filename and filename.endswith(os.path.sep):
        filename += 'main'
    filename = re.sub(r'[^\w\-_\.\\]', '_', filename)
    if filter_disabled and not filter_logs(filename):
        return False
    filename += ('.' if filename and filename[-1] != '.' else '') + extension
    full_path = os.path.join(FOLDERS.LOGS, filename)
    if prefix:
        parent, fn = os.path.split(full_path)
        if fn != '.' + extension:
            fn = '-' + fn
        full_path = os.path.join(parent, prefix + fn)
    full_path = os.path.abspath(full_path)
    if not os.path.exists(os.path.dirname(full_path)):
        os.makedirs(os.path.dirname(full_path))
    if as_url_link:
        return convert_filename_to_local_link(full_path)
    return full_path


def encode_log_text(content, encode_text=True, **kwargs):
    """Best-effort encode of *content*; on any failure return it unchanged."""
    if not encode_text:
        return content
    try:
        return encode(content)
    except Exception:
        return content


def parse_log_content(content, prefix='', **kwargs):
    """Normalize *content* for logging; returns (content, prefix).

    None becomes ''; non-strings are pretty-formatted; a leading '!' is
    stripped and requests a blank line before the entry.
    """
    if content is None:
        return '', prefix
    if not is_str_type(content):
        content = pf(content, pf_replace_newline=False, pf_encode_text=False)
    if not content:
        content = '{EMPTY STRING}'
    if content.startswith("!"):
        content = content[1:]
        prefix = '\n'
    return content, prefix


def process_log_content(content, prefix='', timestamp=None, do_encode=True, **kwargs):
    """Pad continuation lines and prepend the timestamp; returns
    (text to write, unencoded text for console printing)."""
    content = pad_lines_regex(content, timestamp=timestamp, **kwargs)
    st = '[%s]:\t' % datetime.now().strftime(ANKNOTES.DATE_FORMAT) if timestamp else ''
    return prefix + ' ' + st + (encode_log_text(content, **kwargs) if do_encode else content), content


def crosspost_log(content, filename=None, crosspost_to_default=False, crosspost=None, do_show_tooltip=False, **kwargs):
    """Echo *content* to the default log and/or additional log files."""
    if crosspost_to_default and filename:
        summary = " ** %s%s: " % ('' if filename.upper() == 'ERROR' else 'CROSS-POST TO ', filename.upper()) + content
        log(summary[:200], **kwargs)
    if do_show_tooltip:
        show_tooltip(content)
    if not crosspost:
        return
    for fn in item_to_list(crosspost):
        log(content, fn, **kwargs)


def pad_lines_regex(content, timestamp=None, replace_newline=None, try_decode=True, **kwargs):
    """Apply PadLines, then align continuation lines under the timestamp."""
    content = PadLines(content, **kwargs)
    if not (timestamp and replace_newline is not False) and not replace_newline:
        return content
    try:
        return re.sub(r'[\r\n]+', u'\n' + ANKNOTES.FORMATTING.TIMESTAMP_PAD, content)
    except UnicodeDecodeError:
        if not try_decode:
            raise
        return re.sub(r'[\r\n]+', u'\n' + ANKNOTES.FORMATTING.TIMESTAMP_PAD, decode(content))
# @clockit
def log(content=None, filename=None, **kwargs):
    """Write *content* to the log resolved from *filename*.

    Normalizes kwargs, parses/pads the content, echoes crossposts, and
    appends to the resolved file (no-op when the target is filtered out).
    """
    kwargs = Args(kwargs).set_kwargs('line_padding, line_padding_plus, line_padding_header', timestamp=True)
    write_file_contents('Log Args: ' + str(kwargs.items()), 'args\\log_kwargs', get_log_full_path=get_log_full_path)
    content, prefix = parse_log_content(content, **kwargs)
    crosspost_log(content, filename, **kwargs)
    full_path = get_log_full_path(filename, **kwargs)
    if full_path is False:
        # Target is disabled by the log filters.
        return
    content, print_content = process_log_content(content, prefix, **kwargs)
    write_file_contents(content, full_path, print_content=print_content, get_log_full_path=get_log_full_path, **kwargs)


def log_sql(content, a=None, kw=None, self=None, sql_fn_prefix='', **kwargs):
    """Log a SQL statement to a per-table log file under sql\\.

    The table name is heuristically scraped from the statement text.
    """
    table = re.sub(r'[^A-Z_ ]', ' ', content.upper())
    table = ' %s ' % re.sub(' +', ' ', table).replace(' IF NOT EXISTS ', ' ').replace(' IF EXISTS ', ' ').strip()
    # NOTE(review): `table` is wrapped in spaces just above, so these
    # startswith() checks appear never to match — verify intent.
    if table.startswith('CREATE') or table.startswith('DROP'):
        table = 'TABLES'
    else:
        # Cut the statement at WHERE/VALUES, then take the token following
        # the earliest FROM/INTO/UPDATE/TABLE keyword as the table name.
        for stmt in ' WHERE , VALUES '.split(','):
            i = table.find(stmt)
            if i > -1:
                table = table[:i]
        found = (-1, None)
        for stmt in ' FROM , INTO , UPDATE , TABLE '.split(','):
            i = table.find(stmt)
            if i is -1:
                continue
            if i > found[0] > -1:
                continue
            found = (i, stmt)
        if found[0] > -1:
            table = table[found[0] + len(found[1]):].strip()
            if ' ' in table:
                table = table[:table.find(' ')]
    if a or kw:
        content = u"SQL: %s" % content
        if self:
            content += u"\n\nSelf: " + pf(self, pf_encode_text=False, pf_decode_text=True)
        if a:
            content += u"\n\nArgs: " + pf(a, pf_encode_text=False, pf_decode_text=True)
        if kw:
            content += u"\n\nKwargs: " + pf(kw, pf_encode_text=False, pf_decode_text=True)
    log(content, 'sql\\' + sql_fn_prefix + table, **kwargs)


def log_error(content, *a, **kw):
    """Log *content* to the error log, crossposting to the default log."""
    kwargs = Args(a, kw, set_list=['crosspost_to_default', True],
                  use_set_list_as_arg_list=True, require_all_args=False).kwargs
    log(content, 'error', **kwargs)


def pf(obj, title='', pf_replace_newline=True, pf_encode_text=True, pf_decode_text=False, *a, **kw):
    """Pretty-format *obj* for logging; optionally indent continuation lines
    and encode/decode the result. *title* is prepended when given."""
    content = pprint.pformat(obj, indent=4, width=ANKNOTES.FORMATTING.PPRINT_WIDTH)
    content = content.replace(', ', ', \n ')
    if pf_replace_newline:
        content = content.replace('\r', '\r' + ' ' * 30).replace('\n', '\n' + ' ' * 30)
    if pf_encode_text:
        content = encode_log_text(content)
    elif pf_decode_text:
        content = decode(content, errors='ignore')
    if title:
        content = title + ": " + content
    return content


def print_dump(*a, **kw):
    """Pretty-format, print to stdout, and return the text."""
    content = pf(*a, **kw)
    print content
    return content


pp = print_dump


def log_dump(obj, title="Object", filename='', crosspost_to_default=True, **kwargs):
    """Pretty-print *obj* into a 'dump-' prefixed log file.

    A '+name' filename crossposts a truncated summary to the default log;
    a '-title' suppresses the default-log announcement.
    """
    content = pprint.pformat(obj, indent=4, width=ANKNOTES.FORMATTING.PPRINT_WIDTH)
    try:
        content = decode(content, errors='ignore')
    except Exception:
        pass
    content = content.replace("\\n", '\n').replace('\\r', '\r')
    # NOTE(review): `is '+'` relies on CPython string interning; works here
    # but `== '+'` would be the correct comparison.
    if filename and filename[0] is '+':
        summary = " ** CROSS-POST TO %s: " % filename[1:] + content
        log(summary[:200])
    full_path = get_log_full_path(filename, prefix='dump', **kwargs)
    if full_path is False:
        return
    if not title:
        title = "<%s>" % obj.__class__.__name__
    if title.startswith('-'):
        crosspost_to_default = False; title = title[1:]
    prefix = " **** Dumping %s" % title
    if crosspost_to_default:
        log(prefix + " to " + os.path.splitext(full_path.replace(FOLDERS.LOGS + os.path.sep, ''))[0])

    content = encode_log_text(content)

    # Best-effort reflow of the pformat output; formatting failures are
    # non-fatal.
    try:
        prefix += '\r\n'
        content = prefix + content.replace(', ', ', \n ')
        content = content.replace("': {", "': {\n ")
        content = content.replace('\r', '\r' + ' ' * 30).replace('\n', '\n' + ' ' * 30)
    except Exception:
        pass

    if not os.path.exists(os.path.dirname(full_path)):
        os.makedirs(os.path.dirname(full_path))
    try_print(full_path, content, prefix, **kwargs)


def try_print(full_path, content, prefix='', line_prefix=u'\n ', attempt=0, clear=False, timestamp=True, **kwargs):
    """Append *content* to *full_path*, retrying through a ladder of
    encode/decode fallbacks (attempts 0-7) on unicode failures.

    NOTE(review): `attempt is N` works only via CPython's small-int cache;
    `==` would be the correct comparison.
    """
    try:
        st = '[%s]: ' % datetime.now().strftime(ANKNOTES.DATE_FORMAT) if timestamp else ''
        print_content = line_prefix + (u' <%d>' % attempt if attempt > 0 else u'') + u' ' + st
        if attempt is 0:
            print_content += content
        elif attempt is 1:
            print_content += decode(content)
        elif attempt is 2:
            print_content += encode(content)
        elif attempt is 3:
            print_content = encode(print_content) + encode(content)
        elif attempt is 4:
            print_content = decode(print_content) + decode(content)
        elif attempt is 5:
            print_content += "Error printing content: " + str_safe(content)
        elif attempt is 6:
            print_content += "Error printing content: " + content[:10]
        elif attempt is 7:
            print_content += "Unable to print content."
        with open(full_path, 'w+' if clear else 'a+') as fileLog:
            print>> fileLog, print_content
    except Exception as e:
        if attempt < 8:
            try_print(full_path, content, prefix=prefix, line_prefix=line_prefix, attempt=attempt + 1,
                      clear=clear)
        else:
            log("Try print error to %s: %s" % (os.path.split(full_path)[1], str(e)))


def log_api(method, content='', **kw):
    """Record an Evernote API call (with a rolling 1-hour call count)."""
    if content:
        content = ': ' + content
    log(" API_CALL [%3d]: %10s%s" % (get_api_call_count(), method, content), 'api', **kw)


def get_api_call_count():
    """Count API_CALL entries logged within the last hour (Evernote
    rate-limits API usage after the first 24 hours)."""
    path = get_log_full_path('api')
    if path is False or not os.path.exists(path):
        return 0
    api_log = file(path, 'r').read().splitlines()
    count = 1
    for i in range(len(api_log), 0, -1):
        call = api_log[i - 1]
        if "API_CALL" not in call:
            continue
        # Extract the bracketed timestamp that prefixes each entry.
        ts = call.replace(':\t', ': ').split(': ')[0][2:-1]
        td = datetime.now() - datetime.strptime(ts, ANKNOTES.DATE_FORMAT)
        if td >= timedelta(hours=1):
            break
        count += 1
    return count
# Python Imports
import os
import shutil
from fnmatch import fnmatch
from datetime import datetime
import time

# Anknotes Main Imports
from anknotes.constants_standard import FILES, FOLDERS, ANKNOTES
from anknotes.base import encode, item_to_list


def write_file_contents(content, full_path, clear=False, try_encode=True, do_print=False, print_timestamp=True,
                        print_content=None, wfc_timestamp=True, wfc_crosspost=None, get_log_full_path=None, **kwargs):
    """Append *content* to the log file at *full_path* (created on demand).

    When the target directory does not exist, *full_path* is treated as a
    log *name* and resolved via *get_log_full_path* (or relative to
    FOLDERS.LOGS, subject to filter_logs). *wfc_crosspost* fans the write
    out to additional targets; *wfc_timestamp* prefixes a timestamp;
    *do_print* echoes to stdout.
    """
    all_args = locals()
    if wfc_crosspost:
        # Re-issue the same call (minus the crosspost args) per extra target.
        del all_args['kwargs'], all_args['wfc_crosspost'], all_args['content'], all_args['full_path']
        all_args.update(kwargs)
        for cp in item_to_list(wfc_crosspost):
            write_file_contents(content, cp, **all_args)
    orig_path = full_path
    if not os.path.exists(os.path.dirname(full_path)):
        if callable(get_log_full_path):
            full_path = get_log_full_path(full_path)
            if full_path is False:
                # Target disabled by the log filters.
                return
        else:
            if not filter_logs(full_path):
                return
            full_path = os.path.abspath(os.path.join(FOLDERS.LOGS, full_path + '.log'))
            base_path = os.path.dirname(full_path)
            if not os.path.exists(base_path):
                os.makedirs(base_path)
    if wfc_timestamp:
        # Keep the untimestamped text for console echoing.
        print_content = content
        content = '[%s]: ' % datetime.now().strftime(ANKNOTES.DATE_FORMAT) + content
    with open(full_path, 'w+' if clear else 'a+') as fileLog:
        try:
            print>> fileLog, content
        except UnicodeEncodeError:
            content = encode(content)
            print>> fileLog, content
    if do_print:
        print content if print_timestamp or not print_content else print_content


def filter_logs(filename):
    """True when *filename* matches an enabled log pattern and no disabled one."""
    def do_filter(x): return fnmatch(filename, x)
    return (filter(do_filter, item_to_list(FILES.LOGS.ENABLED)) and not
            filter(do_filter, item_to_list(FILES.LOGS.DISABLED)))


def reset_logs(folder='', banner='', clear=True, *a, **kw):
    """Truncate (or delete) the log files under *folder*.

    Files whose names start with an entry of `keep` are preserved. With
    *clear*, files are truncated (optionally re-seeded with *banner*);
    otherwise they are deleted. Subdirectories are removed via rm_log_path.
    """
    # Clears the console on Windows; the assignment only silences linters.
    absolutely_unused_variable = os.system("cls")
    keep = ['anknotes', 'api', 'automation']
    folder = os.path.join(FOLDERS.LOGS, folder)
    logs = os.listdir(folder)
    for fn in logs:
        full_path = os.path.join(folder, fn)
        if os.path.isfile(full_path):
            if filter(lambda x: fnmatch(fn, x + '*'), keep):
                continue
            if clear:
                with open(full_path, 'w+') as myFile:
                    if banner:
                        print >> myFile, banner
            else:
                os.unlink(full_path)
        else:
            rm_log_path(fn)


def rm_log_path(filename='*', subfolders_only=False, retry_errors=0, get_log_full_path=None, *args, **kwargs):
    """Recursively delete the log directory resolved from *filename*.

    Refuses to touch anything at or outside FOLDERS.LOGS itself. Deletion
    errors are collected and the whole operation retried up to 5 times
    (files may be transiently locked on Windows).
    """
    def del_subfolder(arg=None, dirname=None, filenames=None, is_subfolder=True):
        def rmtree_error(f, p, e):
            rm_log_path.errors += [p]

        # Begin del_subfolder
        # NOTE(review): `dirname is path` is an identity check; it is True
        # here only because os.path.walk passes the same object — `==`
        # would be safer.
        if is_subfolder and dirname is path:
            return
        shutil.rmtree(dirname, onerror=rmtree_error)

    # Begin rm_log_path
    if callable(get_log_full_path):
        path = get_log_full_path(filename, filter_disabled=False)
    else:
        path = filename
        if FOLDERS.LOGS not in path.strip(os.path.sep):
            path = os.path.join(FOLDERS.LOGS, path)
    path = os.path.abspath(path)
    if not os.path.isdir(path):
        path = os.path.dirname(path)
    # Safety net: never delete the logs root itself or anything outside it.
    # NOTE(review): `path is FOLDERS.LOGS` is an identity check that is
    # almost always False; the containment test after it does the real work.
    if path is FOLDERS.LOGS or FOLDERS.LOGS not in path:
        return
    rm_log_path.errors = []

    if not subfolders_only:
        del_subfolder(dirname=path, is_subfolder=False)
    else:
        os.path.walk(path, del_subfolder, None)
    if rm_log_path.errors:
        if retry_errors > 5:
            print "Unable to delete log path: " + path + ' -> ' + filename
            write_file_contents("Unable to delete log path as requested", filename)
            return
        time.sleep(1)
        rm_log_path(filename, subfolders_only, retry_errors + 1)
# Anknotes Main Imports
import anknotes.Controller
import anknotes.create_subnotes

# Anki Imports
from aqt.qt import SIGNAL, QMenu, QAction
from aqt import mw
from aqt.utils import getText


# noinspection PyTypeChecker
def anknotes_setup_menu():
    """Build the top-level 'Anknotes' menu and attach it to Anki's menubar.

    The nested-list spec is [title, callback-or-submenu-list]; a dict value
    carries extra options (currently only 'checkable').
    """
    menu_items = [
        [u"&Anknotes",
         [
             ["&Import from Evernote", import_from_evernote],
             ["&Enable Auto Import On Profile Load",
              {'action': anknotes_menu_auto_import_changed, 'checkable': True}],
             ["Note &Validation",
              [
                  ["Validate &And Upload Pending Notes", validate_pending_notes],
                  ["SEPARATOR", None],
                  ["&Validate Pending Notes", lambda: validate_pending_notes(True, False)],
                  ["&Upload Validated Notes", upload_validated_notes]
              ]
              ],
             ["Process &See Also Footer Links [Power Users Only!]",
              [
                  ["Complete All &Steps", see_also],
                  ["SEPARATOR", None],
                  ["Step &1: Process Anki Notes Without See Also Field", lambda: see_also(1)],
                  ["SEPARATOR", None],
                  ["Step &2: Create Auto TOC Evernote Notes", lambda: see_also(2)],
                  ["Step &3: Validate and Upload Auto TOC Notes", lambda: see_also(3)],
                  ["Step &4: Extract Links from TOC Notes", lambda: see_also(4)],
                  ["SEPARATOR", None],
                  ["Step &5: Insert TOC/Outline Links Into Anki Notes", lambda: see_also(5)],
                  ["Step &6: Update See Also Footer In Evernote Notes", lambda: see_also(6)],
                  ["Step &7: Validate and Upload Modified Evernote Notes", lambda: see_also(7)],
                  ["SEPARATOR", None],
                  ["Step &8: Insert TOC and Outline Content Into Anki Notes", lambda: see_also(8)]
              ]
              ],
             ["&Maintenance Tasks",
              [
                  ["Find &Deleted Notes", find_deleted_notes],
                  ["Res&ync with Local DB", resync_with_local_db],
                  ["Update Evernote &Ancillary Data", update_ancillary_data],
                  ["&lxml Test", lxml_test]
              ]
              ]
         ]
         ]
    ]
    add_menu_items(menu_items)


def auto_reload_wrapper(function):
    """Wrap a menu callback so developer-mode module reloading runs first."""
    return lambda: auto_reload_modules(function)


def auto_reload_modules(function):
    """In developer mode, reload the anknotes modules before invoking
    ``function`` so code changes take effect without restarting Anki."""
    if ANKNOTES.DEVELOPER_MODE.ENABLED and ANKNOTES.DEVELOPER_MODE.AUTO_RELOAD_MODULES:
        # BUGFIX: keyword argument was misspelled 'claar', which raised a
        # TypeError whenever developer auto-reload was enabled.
        log_banner('AUTO RELOAD MODULES - RELOADING', 'automation', clear=True)
        anknotes.shared = reload(anknotes.shared)
        if not anknotes.Controller:
            from anknotes.imports import import_module
            import_module('anknotes.Controller', sublevels=1)
        if not anknotes.create_subnotes:
            from anknotes.imports import import_module
            import_module('anknotes.create_subnotes', sublevels=1)
        reload(anknotes.Controller)
        reload(anknotes.create_subnotes)
    else:
        log_banner('AUTO RELOAD MODULES - SKIPPING RELOAD', 'automation', clear=True)
    function()


def add_menu_items(menu_items, parent=None):
    """Recursively create QMenu/QAction objects from the ``menu_items`` spec.

    A list value becomes a submenu; "SEPARATOR" inserts a separator; a dict
    value supplies the callback plus options (e.g. 'checkable').
    """
    if not parent:
        parent = mw.form.menubar
    for title, action in menu_items:
        if title == "SEPARATOR":
            parent.addSeparator()
        elif isinstance(action, list):
            menu = QMenu(_(title), parent)
            parent.insertMenu(mw.form.menuTools.menuAction(), menu)
            add_menu_items(action, menu)
        else:
            checkable = False
            if isinstance(action, dict):
                options = action
                action = options['action']
                if 'checkable' in options:
                    checkable = options['checkable']
            # Always route through the reload wrapper; it no-ops outside
            # developer mode.
            action = auto_reload_wrapper(action)
            # noinspection PyArgumentList
            menu_action = QAction(_(title), mw, checkable=checkable)
            parent.addAction(menu_action)
            parent.connect(menu_action, SIGNAL("triggered()"), action)
            if checkable:
                anknotes_checkable_menu_items[title] = menu_action


def anknotes_menu_auto_import_changed():
    """Persist the 'Enable Auto Import On Profile Load' checkbox state into
    the collection's conf and mark the collection modified."""
    title = "&Enable Auto Import On Profile Load"
    doAutoImport = anknotes_checkable_menu_items[title].isChecked()
    mw.col.conf[
        SETTINGS.ANKNOTES_CHECKABLE_MENU_ITEMS_PREFIX.getDefault() + '_' +
        title.replace(' ', '_').replace('&', '')] = doAutoImport
    mw.col.setMod()
    mw.col.save()


def anknotes_load_menu_settings():
    """Restore persisted checked state for all checkable menu items."""
    global anknotes_checkable_menu_items
    for title, menu_action in anknotes_checkable_menu_items.items():
        # BUGFIX: read back under the same key used when saving — the save
        # path uses ...PREFIX.getDefault(); the old load path omitted it,
        # so the stored state was never found.
        menu_action.setChecked(mw.col.conf.get(
            SETTINGS.ANKNOTES_CHECKABLE_MENU_ITEMS_PREFIX.getDefault() + '_' +
            title.replace(' ', '_').replace('&', ''), False))
SETTINGS.ANKNOTES_CHECKABLE_MENU_ITEMS_PREFIX + '_' + title.replace(' ', '_').replace('&', ''), False)) + + +def import_from_evernote_manual_metadata(guids=None): + if not guids: + guids = find_evernote_guids(file(FILES.LOGS.FDN.UNIMPORTED_EVERNOTE_NOTES, 'r').read()) + log("Manually downloading %d Notes" % len(guids)) + controller = anknotes.Controller.Controller() + controller.forceAutoPage = True + controller.currentPage = 1 + controller.ManualGUIDs = guids + controller.proceed() + + +def import_from_evernote(auto_page_callback=None): + controller = anknotes.Controller.Controller() + controller.auto_page_callback = auto_page_callback + if auto_page_callback: + controller.forceAutoPage = True + controller.currentPage = 1 + else: + controller.forceAutoPage = False + controller.currentPage = SETTINGS.EVERNOTE.PAGINATION_CURRENT_PAGE.fetch(1) + controller.proceed() + + +def lxml_test(): + log("Creating Subnotes", 'automation') + guids = ankDB().list("tagNames LIKE '{t_out}' ORDER BY title ASC ", columns='guid') + anknotes.create_subnotes.create_subnotes(guids) + + +def upload_validated_notes(automated=False, **kwargs): + controller = anknotes.Controller.Controller() + controller.upload_validated_notes(automated) + + +def find_deleted_notes(automated=False): + if not automated: + showInfo("""In order for this to work, you must create a 'Table of Contents' Note using the Evernote desktop application. Include all notes that you want to sync with Anki. + +Export this note to the following path: +%s + +Press Okay to save and close your Anki collection, open the command-line deleted notes detection tool, and then re-open your Anki collection. + +Once the command line tool is done running, you will get a summary of the results, and will be prompted to delete Anki Orphan Notes or download Missing Evernote Notes""".replace( + '\n', '\n
      ') % FILES.USER.TABLE_OF_CONTENTS_ENEX, richText=True) + from anknotes import find_deleted_notes + returnedData = find_deleted_notes.do_find_deleted_notes() + if returnedData is False: + showInfo("An error occurred while executing the script. Please ensure you created the TOC note and saved it as instructed in the previous dialog.") + return + lines = returnedData['Summary'] + info = tableify_lines(lines, '#|Type|Info') + # info = '%s
      #Type
      ' % '\n'.join(lines) + # info = info.replace('\n', '\n
      ').replace(' ', '    ') + anknotes_dels = returnedData['AnknotesOrphans'] + anknotes_dels_count = len(anknotes_dels) + anki_dels = returnedData['AnkiOrphans'] + anki_dels_count = len(anki_dels) + missing_evernote_notes = returnedData['MissingEvernoteNotes'] + missing_evernote_notes_count = len(missing_evernote_notes) + showInfo(info, richText=True, minWidth=600) + db_changed = False + if anknotes_dels_count > 0: + correct_code = 'ANKNOTES_DEL_%d' % anknotes_dels_count + code = getText("Please enter code '%s' to delete your orphan Anknotes DB note(s)" % correct_code)[0] + if code == correct_code: + ankDB().executemany("DELETE FROM {n} WHERE guid = ?", [[x] for x in anknotes_dels]) + delete_anki_notes_and_cards_by_guid(anknotes_dels) + db_changed = True + show_tooltip("Deleted all %d Orphan Anknotes DB Notes" % anknotes_dels_count, 5, 3) + if anki_dels_count > 0: + correct_code = 'ANKI_DEL_%d' % anki_dels_count + code = getText("Please enter code '%s' to delete your orphan Anki note(s)" % correct_code)[0] + if code == correct_code: + delete_anki_notes_and_cards_by_guid(anki_dels) + db_changed = True + show_tooltip("Deleted all %d Orphan Anki Notes" % anki_dels_count, 5, 6) + if db_changed: + ankDB().commit() + if missing_evernote_notes_count > 0: + text = "Would you like to import %d missing Evernote Notes?

      Click to view results" % ( + missing_evernote_notes_count, + convert_filename_to_local_link(get_log_full_path(FILES.LOGS.FDN.UNIMPORTED_EVERNOTE_NOTES, filter_disabled=False))) + if showInfo(text, cancelButton=True, richText=True): + import_from_evernote_manual_metadata(missing_evernote_notes) + + +def validate_pending_notes(showAlerts=True, uploadAfterValidation=True, callback=None, unloadedCollection=False, + reload_delay=10): + if not unloadedCollection: + return unload_collection( + lambda *xargs, **xkwargs: validate_pending_notes(showAlerts, uploadAfterValidation, callback(*xargs, **xkwargs), + True)) + log("Validating Notes", 'automation') + if showAlerts: + showInfo("""Press Okay to save and close your Anki collection, open the command-line note validation tool, and then re-open your Anki collection.%s + +Anki will be unresponsive until the validation tool completes. This will take at least 45 seconds. The tool's output will be displayed upon completion. """ + % ( + ' You will be given the option of uploading successfully validated notes once your Anki collection is reopened.' if uploadAfterValidation else '')) + handle = Popen(['python', FILES.SCRIPTS.VALIDATION], stdin=PIPE, stderr=PIPE, stdout=PIPE, shell=True) + stdoutdata, stderrdata = handle.communicate() + stdoutdata = re.sub(' +', ' ', stdoutdata) + info = ("ERROR: {%s}
      " % stderrdata) if stderrdata else '' + allowUpload = True + if showAlerts: + tds = [[str(count), 'VIEW %s VALIDATIONS LOG' % (fn, key.upper())] for key, fn, count in [ + [key, get_log_full_path('MakeNoteQueue\\' + key, filter_disabled=False, as_url_link=True), + int(re.search(r'CHECKING +(\d{1,3}) +' + key.upper() + ' MAKE NOTE QUEUE ITEMS', stdoutdata).group(1))] + for key in ['Pending', 'Successful', 'Failed']] if count > 0] + if not tds: + show_tooltip("No notes found in the validation queue.") + allowUpload = False + else: + info += tableify_lines(tds, '#|Results') + successful = int( + re.search(r'CHECKING +(\d{1,3}) +' + 'Successful'.upper() + ' MAKE NOTE QUEUE ITEMS', stdoutdata).group( + 1)) + allowUpload = (uploadAfterValidation and successful > 0) + allowUpload = allowUpload & showInfo("Completed: %s
      %s" % ( + 'Press Okay to begin uploading %d successfully validated note(s) to the Evernote Servers' % successful if ( + uploadAfterValidation and successful > 0) else '', + info), cancelButton=(successful > 0), richText=True) + log("Validate Notes completed", 'automation') + if callback is None and allowUpload: + def callback(*xa, **xkw): return upload_validated_notes() + mw.progress.timer(reload_delay * 1000, lambda: reload_collection(callback), False) + + +def modify_collection(collection_operation, action_str='modifying collection', callback=None, callback_failure=False, + callback_delay=0, delay=30, attempt=1, max_attempts=5, **kwargs): + passed = False + retry = ( + "Will try again in %ds" % delay + ' (Attempt #%d)' % attempt if attempt > 0 else '') if attempt <= max_attempts else "Max attempts of %d exceeded... Aborting operation" % max_attempts + return_val = None + try: + return_val = collection_operation() + passed = True + except (sqlite.OperationalError, sqlite.ProgrammingError, Exception) as e: + if e.message.replace(".", "") == 'database is locked': + friendly_message = 'sqlite database is locked' + elif e.message == "Cannot operate on a closed database.": + friendly_message = 'sqlite database is closed' + else: + if e.message.replace('.', '') == 'database is locked': + log_error('**locked', crosspost='automation', + crosspost_to_default=False) + import traceback + type = str(e.__class__) + type = type[type.find("'") + 1:type.rfind("'")] + friendly_message = ('Unhandled Error' if type == 'Exception' else type) + ':\n Full Error: ' + ' '.join( + str(e).split()) + '\n Message: "%s"' % e.message + '\n Trace: ' + traceback.format_exc() + '\n' + log_error(" > Modify Collection: Error %s: %s. 
def reload_collection(callback=None, reopen_delay=0, callback_delay=30, *a, **kw):
    """Reload the Anki collection if it is closed.

    Returns True (and invokes ``callback(True)``) when the collection is
    already open and the anknotes DB answers a probe query; otherwise
    schedules/performs a reload through ``modify_collection`` so sqlite
    lock errors are retried.
    """
    if mw.col is not None:  # idiom fix: was 'not mw.col is None'
        try:
            myDB = ankDB(reset=True)
            db = myDB._db
            assert db is not None
            # Probe query: proves both the wrapper and the cursor work.
            cur = myDB.execute("SELECT title FROM {t} WHERE 1 ORDER BY RANDOM() LIMIT 1")
            assert cur is not None
            result = cur.fetchone()
            log(" > Reload Collection: Not needed: ankDB exists and cursor created: %s" % (str_safe(result[0])),
                'automation')
            if callback:
                callback(True)
            return True
        except (sqlite.ProgrammingError, Exception) as e:
            # BUGFIX: sqlite's message ends with a period ("...closed
            # database."); the old comparison without the period never
            # matched, so a closed DB was logged as an unexpected failure.
            if e.message == "Cannot operate on a closed database.":
                log(" > Reloading Collection Check: DB is Closed. Proceed with reload. Col: " + str(mw.col),
                    'automation')
            else:
                import traceback
                log(" > Reloading Collection Check Failed : " + str(e) + '\n - Trace: ' + traceback.format_exc(),
                    'automation')
    log(" > Initiating Reload: %sInitiated: %s" % (
        '%ds Timer ' % reopen_delay if reopen_delay > 0 else '', str(mw.col)), 'automation')
    if reopen_delay > 0:
        def callback_reopen():
            def inner_callback(*xa, **xkw):
                return callback(*a, **kw)
            return modify_collection(do_load_collection, 'reload collection', inner_callback,
                                     callback_delay=callback_delay, *a, **kw)
        return create_timer(reopen_delay, callback_reopen)
    return modify_collection(do_load_collection, 'Reloading Collection', callback,
                             callback_delay=callback_delay, *a, **kw)


def do_load_collection():
    """Reopen the Anki collection and reset the anknotes DB wrapper."""
    log(" > Do Load Collection: Attempting mw.loadCollection()", 'automation')
    mw.loadCollection()
    log(" > Do Load Collection: Attempting ankDB(True)", 'automation')
    ankDB(reset=True)


def do_unload_collection():
    """Thin wrapper so the unload call can be passed around by name."""
    mw.unloadCollection()


def unload_collection(*args, **kwargs):
    """Save and close the Anki collection via modify_collection's retry logic."""
    log("Initiating Unload Collection:", 'automation')
    modify_collection(mw.unloadCollection, 'Unload Collection', *args, **kwargs)


def load_controller(callback=None, callback_failure=True, *args, **kwargs):
    """Instantiate the anknotes Controller through modify_collection so
    sqlite lock/closed-database errors are retried instead of raised."""
    modify_collection(anknotes.Controller.Controller, 'Loading Controller', callback,
                      callback_failure=callback_failure)
Loading Collection", 'automation') + reload_collection(lambda *xa, **xkw: see_also(**all_args)) + return False + if not steps: + steps = range(1, 10) + if isinstance(steps, int): + steps = [steps] + + if not upload: + if 3 in steps: + steps.remove(3) + if 7 in steps: + steps.remove(7) + steps = list(steps) + log("See Also --> 3. Proceeding: " + ', '.join(map(str, steps)), 'automation') + multipleSteps = (len(steps) > 1) + if showAlerts is None: + showAlerts = not multipleSteps + if multipleSteps: + show_tooltip_enabled = show_tooltip.enabled if hasattr(show_tooltip, 'enabled') else None + show_tooltip.enabled = False + remaining_steps = steps + if 1 in steps: + # Should be unnecessary once See Also algorithms are finalized + log(" > See Also: Step 1: Process Un Added See Also Notes", crosspost='automation') + controller.process_unadded_see_also_notes() + if 2 in steps: + log(" > See Also: Step 2: Create Auto TOC Evernote Notes", crosspost='automation') + controller.create_toc_auto() + if 3 in steps: + if validationComplete: + log(" > See Also: Step 3B: Validate and Upload Auto TOC Notes: Upload Validated Notes", + crosspost='automation') + upload_validated_notes(multipleSteps) + validationComplete = False + else: + steps = [-3] + if 4 in steps: + log(" > See Also: Step 4: Extract Links from TOC", crosspost='automation') + controller.anki.extract_links_from_toc() + if 5 in steps: + log(" > See Also: Step 5: Insert TOC/Outline Links Into Anki Notes' See Also Field", crosspost='automation') + controller.anki.insert_toc_into_see_also() + if 6 in steps: + log(" > See Also: Step 6: Update See Also Footer In Evernote Notes", crosspost='automation') + from anknotes import detect_see_also_changes + detect_see_also_changes.main() + if 7 in steps: + if validationComplete: + log(" > See Also: Step 7B: Validate and Upload Modified Evernote Notes: Upload Validated Notes", + crosspost='automation') + upload_validated_notes(multipleSteps) + else: + steps = [-7] + if 8 in steps: + 
log(" > See Also: Step 8: Insert TOC/Outline Contents Into Anki Notes", crosspost='automation') + controller.anki.insert_toc_and_outline_contents_into_notes() + + do_validation = steps[0] * -1 + if do_validation > 0: + log(" > See Also: Step %dA: Validate and Upload %s Notes: Validate Notes" % ( + do_validation, {3: 'Auto TOC', 7: 'Modified Evernote'}[do_validation]), crosspost='automation') + remaining_steps = remaining_steps[remaining_steps.index(do_validation):] + callback_args = all_args + callback_args.update(dict(steps=remaining_steps, showAlerts=False, validationComplete=True)) + validate_pending_notes(showAlerts, callback=lambda *xargs, **xkwargs: see_also(**callback_args)) + if multipleSteps: + show_tooltip.enabled = show_tooltip_enabled + +def update_ancillary_data(): + controller = anknotes.Controller.Controller() + log("Ancillary data - loaded controller - " + str(controller.evernote) + " - " + str(controller.evernote.client), + 'client') + controller.update_ancillary_data() + + +def resync_with_local_db(): + controller = anknotes.Controller.Controller() + controller.resync_with_local_db() + + +anknotes_checkable_menu_items = {} diff --git a/anknotes/methods.py b/anknotes/methods.py new file mode 100644 index 0000000..352c0ad --- /dev/null +++ b/anknotes/methods.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +import re +from datetime import datetime + +### Anknotes Imports +from anknotes.constants import * +from anknotes.args import Args +from anknotes.imports import in_anki + +### Anki Imports +if in_anki(): + from aqt import mw + +def create_timer(delay, callback, *a, **kw): + kw, repeat = Args(kw).get_kwargs(['repeat', False]) + if a or kw: + def cb(): return callback(*a, **kw) + else: + cb = callback + return mw.progress.timer(abs(delay) * 1000, cb, repeat) \ No newline at end of file diff --git a/anknotes/oauth2/__init__.py b/anknotes/oauth2/__init__.py index 088e107..a2745c1 100644 --- a/anknotes/oauth2/__init__.py +++ b/anknotes/oauth2/__init__.py 
class EvernoteQueryLocationValueQSpinBox(QSpinBox):
    """QSpinBox whose displayed text is '<prefix>-<value>' (e.g. 'Week-3'
    for a relative last-updated Evernote query); a zero value shows only
    the prefix."""
    __prefix = ""

    def setPrefix(self, text):
        # Store the prefix privately; QSpinBox's own prefix handling is
        # bypassed by the textFromValue/valueFromText overrides below.
        self.__prefix = text

    def prefix(self):
        return self.__prefix

    def valueFromText(self, text):
        """Parse the spinner text back into its integer value."""
        # BUGFIX: compare strings with '==', not 'is' (identity on strings
        # is unreliable).
        if text == self.prefix():
            return 0
        # BUGFIX: Qt requires an int here; the old code returned the raw
        # string slice. Skip the prefix plus the '-' separator.
        return int(text[len(self.prefix()) + 1:])

    def textFromValue(self, value):
        """Render the value as '<prefix>-<value>', or just the prefix for 0."""
        return self.prefix() + ("-%d" % value if value else "")
update_evernote_query_visibilities() + setting.save(elements[setting].isChecked()) + # mw.col.conf[setting] = + if setting is QUERY.USE_TAGS: + update_evernote_query_visibilities() + if setting is QUERY.LAST_UPDATED.USE: + evernote_query_last_updated_value_set_visibilities() + + def create_checkbox(setting, label=" ", default_value=False, is_fixed_size=False, fixed_width=None): + if isinstance(label, bool): + default_value = label + label = " " + checkbox = QCheckBox(label, self) + sval = setting.fetch() + if not isinstance(sval, bool): + sval = default_value + checkbox.setChecked(sval) + # noinspection PyUnresolvedReferences + checkbox.stateChanged.connect(lambda: update_checkbox(setting)) + if is_fixed_size or fixed_width: + checkbox.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + if fixed_width: + checkbox.setFixedWidth(fixed_width) + elements[setting] = checkbox + return checkbox + + def create_checked_checkbox(*a, **kw): + kw['default_value'] = True + return create_checkbox(*a, **kw) + + def update_text(setting, text): + text = text.strip() + setting.save(text) + if setting is DECKS.BASE: + update_anki_deck_visibilities() + if setting.get.startswith(QUERY.get): + if text: + use_key = getattr(QUERY, 'USE_' + setting.label.name) + elements[use_key].setChecked(True) + evernote_query_text_changed() + if setting is QUERY.SEARCH_TERMS: + update_evernote_query_visibilities() + + def create_textbox(setting, default_value=""): + textbox = QLineEdit() + textbox.setText(setting.fetch(default_value)) + textbox.connect(textbox, + SIGNAL("textEdited(QString)"), + lambda text: update_text(setting, text)) + elements[setting] = textbox + return textbox + + def add_query_row(setting, is_checked=False, **kw): + try: + default_value = setting.val + except: + default_value = '' + row_label = ' '.join(x.capitalize() for x in setting.replace('_', ' ').split()) + hbox = QHBoxLayout() + hbox.addWidget(create_checkbox(getattr(QUERY, 'USE_' + setting), + default_value=is_checked, 
**kw)) + hbox.addWidget(create_textbox(getattr(QUERY, setting), default_value)) + form.addRow(row_label, hbox) + + def gen_qt_hr(): + vbox = QVBoxLayout() + hr = QFrame() + hr.setAutoFillBackground(True) + hr.setFrameShape(QFrame.HLine) + hr.setStyleSheet("QFrame { background-color: #0060bf; color: #0060bf; }") + hr.setFixedHeight(2) + vbox.addWidget(hr) + vbox.addSpacing(4) + return vbox + + # Begin setup_evernote() + widget = QWidget() + layout = QVBoxLayout() + elements = {} + rm_log_path('Dicts\\') + evernote_query_last_updated = DictCaseInsensitive() + + + ########################## QUERY ########################## + ##################### QUERY: TEXTBOXES #################### + group = QGroupBox("EVERNOTE SEARCH OPTIONS:") + group.setStyleSheet('QGroupBox{ font-size: 10px; font-weight: bold; color: rgb(105, 170, 53);}') + form = QFormLayout() + + form.addRow(gen_qt_hr()) + + # Show Generated Evernote Query Button + button_show_generated_evernote_query = QPushButton(icoEvernoteWeb, "Show Full Query", self) + button_show_generated_evernote_query.setAutoDefault(False) + button_show_generated_evernote_query.connect(button_show_generated_evernote_query, + SIGNAL("clicked()"), + handle_show_generated_evernote_query) + + + # Add Form Row for Match Any Terms + hbox = QHBoxLayout() + hbox.addWidget(create_checked_checkbox(QUERY.ANY, " Match Any Terms", is_fixed_size=True)) + hbox.addWidget(button_show_generated_evernote_query) + form.addRow("Search Parameters:", hbox) + + # Add Form Rows for Evernote Query Textboxes + for el in QUERY_TEXTBOXES: + add_query_row(el, 'TAGS' in el) + + ################### QUERY: LAST UPDATED ################### + # Evernote Query: Last Updated Type + evernote_query_last_updated.type = QComboBox() + evernote_query_last_updated.type.setStyleSheet(' QComboBox { color: rgb(45, 79, 201); font-weight: bold; } ') + evernote_query_last_updated.type.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + evernote_query_last_updated.type.addItems([u"Δ 
Day", u"Δ Week", u"Δ Month", u"Δ Year", "Date", "+ Time"]) + evernote_query_last_updated.type.setCurrentIndex(QUERY.LAST_UPDATED.TYPE.fetch(EvernoteQueryLocationType.RelativeDay)) + evernote_query_last_updated.type.activated.connect(update_evernote_query_last_updated_type) + + + # Evernote Query: Last Updated Type: Relative Date + evernote_query_last_updated.value.relative.spinner = EvernoteQueryLocationValueQSpinBox() + evernote_query_last_updated.value.relative.spinner.setVisible(False) + evernote_query_last_updated.value.relative.spinner.setStyleSheet( + " QSpinBox, EvernoteQueryLocationValueQSpinBox { font-weight: bold; color: rgb(173, 0, 0); } ") + evernote_query_last_updated.value.relative.spinner.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) + evernote_query_last_updated.value.relative.spinner.connect(evernote_query_last_updated.value.relative.spinner, + SIGNAL("valueChanged(int)"), + update_evernote_query_last_updated_value_relative_spinner) + + # Evernote Query: Last Updated Type: Absolute Date + evernote_query_last_updated.value.absolute.date = QDateEdit() + evernote_query_last_updated.value.absolute.date.setDisplayFormat('M/d/yy') + evernote_query_last_updated.value.absolute.date.setCalendarPopup(True) + evernote_query_last_updated.value.absolute.date.setVisible(False) + evernote_query_last_updated.value.absolute.date.setStyleSheet( + "QDateEdit { font-weight: bold; color: rgb(173, 0, 0); } ") + evernote_query_last_updated.value.absolute.date.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) + evernote_query_last_updated.value.absolute.date.connect(evernote_query_last_updated.value.absolute.date, + SIGNAL("dateChanged(QDate)"), + update_evernote_query_last_updated_value_absolute_date) + + # Evernote Query: Last Updated Type: Absolute DateTime + evernote_query_last_updated.value.absolute.datetime = QDateTimeEdit() + evernote_query_last_updated.value.absolute.datetime.setDisplayFormat('M/d/yy h:mm AP') + 
evernote_query_last_updated.value.absolute.datetime.setCalendarPopup(True) + evernote_query_last_updated.value.absolute.datetime.setVisible(False) + evernote_query_last_updated.value.absolute.datetime.setStyleSheet( + "QDateTimeEdit { font-weight: bold; color: rgb(173, 0, 0); } ") + evernote_query_last_updated.value.absolute.datetime.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) + evernote_query_last_updated.value.absolute.datetime.connect(evernote_query_last_updated.value.absolute.datetime, + SIGNAL("dateTimeChanged(QDateTime)"), + update_evernote_query_last_updated_value_absolute_datetime) + + + + # Evernote Query: Last Updated Type: Absolute Time + evernote_query_last_updated.value.absolute.time = QTimeEdit() + evernote_query_last_updated.value.absolute.time.setDisplayFormat('h:mm AP') + evernote_query_last_updated.value.absolute.time.setVisible(False) + evernote_query_last_updated.value.absolute.time.setStyleSheet( + "QTimeEdit { font-weight: bold; color: rgb(143, 0, 30); } ") + evernote_query_last_updated.value.absolute.time.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) + evernote_query_last_updated.value.absolute.time.connect(evernote_query_last_updated.value.absolute.time, + SIGNAL("timeChanged(QTime)"), + update_evernote_query_last_updated_value_absolute_time) + + # Create HBox for Separated Date & Time + hbox_datetime = QHBoxLayout() + hbox_datetime.addWidget(evernote_query_last_updated.value.absolute.date) + hbox_datetime.addWidget(evernote_query_last_updated.value.absolute.time) + + # Evernote Query: Last Updated Type + evernote_query_last_updated.value.stacked_layout = QStackedLayout() + evernote_query_last_updated.value.stacked_layout.addWidget(evernote_query_last_updated.value.relative.spinner) + evernote_query_last_updated.value.stacked_layout.addItem(hbox_datetime) + + # Add Form Row for Evernote Query: Last Updated + hbox = QHBoxLayout() + label = QLabel("Last Updated: ") + 
label.setMinimumWidth(SETTINGS.FORM.LABEL_MINIMUM_WIDTH.val) + hbox.addWidget(create_checkbox(QUERY.LAST_UPDATED.USE, is_fixed_size=True)) + hbox.addWidget(evernote_query_last_updated.type) + hbox.addWidget(evernote_query_last_updated.value.relative.spinner) + hbox.addWidget(evernote_query_last_updated.value.absolute.date) + hbox.addWidget(evernote_query_last_updated.value.absolute.time) + form.addRow(label, hbox) + + # Add Horizontal Row Separator + form.addRow(gen_qt_hr()) + + ############################ PAGINATION ########################## + # Evernote Pagination: Current Page + evernote_pagination_current_page_spinner = QSpinBox() + evernote_pagination_current_page_spinner.setStyleSheet("QSpinBox { font-weight: bold; color: rgb(173, 0, 0); } ") + evernote_pagination_current_page_spinner.setPrefix("PAGE: ") + evernote_pagination_current_page_spinner.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) + evernote_pagination_current_page_spinner.setValue(EVERNOTE.PAGINATION_CURRENT_PAGE.fetch(1)) + evernote_pagination_current_page_spinner.connect(evernote_pagination_current_page_spinner, + SIGNAL("valueChanged(int)"), + update_evernote_pagination_current_page_spinner) + + # Evernote Pagination: Automation + hbox = QHBoxLayout() + hbox.addWidget(create_checked_checkbox(EVERNOTE.AUTO_PAGING, " Automate", fixed_width=105)) + hbox.addWidget(evernote_pagination_current_page_spinner) + + # Add Form Row for Evernote Pagination + form.addRow("Pagination:", hbox) + + # Add Query Form to Group Box + group.setLayout(form) + + # Add Query Group Box to Main Layout + layout.addWidget(group) + + ########################## DECK ########################## + # Setup Group Box and Form + group = QGroupBox("ANKI NOTE OPTIONS:") + group.setStyleSheet('QGroupBox{ font-size: 10px; font-weight: bold; color: rgb(105, 170, 53);}') + form = QFormLayout() + + # Add Horizontal Row Separator + form.addRow(gen_qt_hr()) + + # Add Form Row for Default Anki Deck + hbox = QHBoxLayout() + 
hbox.insertSpacing(0, 33) + hbox.addWidget(create_textbox(DECKS.BASE, DECKS.BASE_DEFAULT_VALUE)) + label_deck = QLabel("Anki Deck:") + label_deck.setMinimumWidth(SETTINGS.FORM.LABEL_MINIMUM_WIDTH.val) + form.addRow(label_deck, hbox) + + # Add Form Row for Evernote Notebook Integration + label_deck = QLabel("Evernote Notebook:") + label_deck.setMinimumWidth(SETTINGS.FORM.LABEL_MINIMUM_WIDTH.val) + form.addRow("", create_checked_checkbox(DECKS.EVERNOTE_NOTEBOOK_INTEGRATION, " Append Evernote Notebook")) + + # Add Horizontal Row Separator + form.addRow(gen_qt_hr()) + + ############################ TAGS ########################## + # Add Form Row for Evernote Tag Options + label = QLabel("Evernote Tags:") + label.setMinimumWidth(SETTINGS.FORM.LABEL_MINIMUM_WIDTH.val) + + # Tags: Save To Anki Note + form.addRow(label, create_checkbox(TAGS.KEEP_TAGS, " Save To Anki Note", TAGS.KEEP_TAGS_DEFAULT_VALUE)) + hbox = QHBoxLayout() + hbox.insertSpacing(0, 33) + hbox.addWidget(create_textbox(TAGS.TO_DELETE)) + + # Tags: Tags To Delete + form.addRow("Tags to Delete:", hbox) + form.addRow(" ", create_checkbox(TAGS.DELETE_EVERNOTE_QUERY_TAGS, " Also Delete Search Tags")) + + # Add Horizontal Row Separator + form.addRow(gen_qt_hr()) + + ############################ NOTE UPDATING ########################## + # Note Update Method + update_existing_notes = QComboBox() + update_existing_notes.setStyleSheet( + ' QComboBox { color: #3b679e; font-weight: bold; } QComboBoxItem { color: #A40F2D; font-weight: bold; } ') + update_existing_notes.addItems(["Ignore Existing Notes", "Update In-Place", + "Delete and Re-Add"]) + sval = ANKI.UPDATE_EXISTING_NOTES.fetch() + if not isinstance(sval, int): + sval = ANKI.UPDATE_EXISTING_NOTES.val + update_existing_notes.setCurrentIndex(sval) + update_existing_notes.activated.connect(update_update_existing_notes) + + # Add Form Row for Note Update Method + hbox = QHBoxLayout() + hbox.insertSpacing(0, 33) + hbox.addWidget(update_existing_notes) + 
form.addRow("Note Updating:", hbox) + + # Add Note Update Method Form to Group Box + group.setLayout(form) + + # Add Note Update Method Group Box to Main Layout + layout.addWidget(group) + + ######################### UPDATE VISIBILITIES ####################### + # Update Visibilities of Anki Deck Options + update_anki_deck_visibilities() + + # Update Visibilities of Query Options + evernote_query_text_changed() + update_evernote_query_visibilities() + + ######################## ADD TO SETTINGS PANEL ###################### + # Vertical Spacer + vertical_spacer = QSpacerItem(20, 0, QSizePolicy.Minimum, QSizePolicy.Expanding) + layout.addItem(vertical_spacer) + + # Parent Widget + widget.setLayout(layout) + + # New Tab + self.form.tabWidget.addTab(widget, "Anknotes") + +def update_anki_deck_visibilities(): + if not elements[DECKS.BASE].text(): + elements[DECKS.EVERNOTE_NOTEBOOK_INTEGRATION].setChecked(True) + elements[DECKS.EVERNOTE_NOTEBOOK_INTEGRATION].setEnabled(False) + else: + elements[DECKS.EVERNOTE_NOTEBOOK_INTEGRATION].setEnabled(True) + elements[DECKS.EVERNOTE_NOTEBOOK_INTEGRATION].setChecked( + DECKS.EVERNOTE_NOTEBOOK_INTEGRATION.fetch(True)) + +def update_evernote_pagination_current_page_spinner(value): + if value < 1: + value = 1 + evernote_pagination_current_page_spinner.setValue(1) + EVERNOTE.PAGINATION_CURRENT_PAGE.save(value) + + +def update_update_existing_notes(index): + ANKI.UPDATE_EXISTING_NOTES.save(index) + + +def evernote_query_text_changed(): + for key in QUERY_TEXTBOXES: + setting_use = getattr(QUERY, 'USE_' + key) + el_use = elements[setting_use] + is_enabled = is_checked = bool(elements[getattr(QUERY, key)].text()) + if is_checked: + is_checked = setting_use.fetch(True) + el_use.setEnabled(is_enabled) + el_use.setChecked(is_checked) + +def update_evernote_query_visibilities(): + for key in QUERY_TEXTBOXES: + el_use = elements[getattr(QUERY, 'USE_' + key)] + elements[getattr(QUERY, key)].setEnabled(el_use.isChecked() or not 
el_use.isEnabled()) + evernote_query_last_updated_value_set_visibilities() + + +def update_evernote_query_last_updated_type(index): + QUERY.LAST_UPDATED.TYPE.save(index) + evernote_query_last_updated_value_set_visibilities() + + +def evernote_query_last_updated_get_current_value(): + index = QUERY.LAST_UPDATED.TYPE.fetch(0) + if index < EvernoteQueryLocationType.AbsoluteDate: + spinner_text = ['day', 'week', 'month', 'year'][index] + spinner_val = QUERY.LAST_UPDATED.VALUE.RELATIVE.fetch(0) + if spinner_val > 0: + spinner_text += "-" + str(spinner_val) + return spinner_text + + absolute_date_str = QUERY.LAST_UPDATED.VALUE.ABSOLUTE.DATE.fetch().replace(' ', '') + if index is EvernoteQueryLocationType.AbsoluteDate: + return absolute_date_str + absolute_time_str = QUERY.LAST_UPDATED.VALUE.ABSOLUTE.TIME.fetch("{:HH mm ss}".format(datetime.now())).replace(' ', '') + return absolute_date_str + "'T'" + absolute_time_str + + +def evernote_query_last_updated_value_set_visibilities(): + index = QUERY.LAST_UPDATED.TYPE.fetch(0) + use_last_updated = elements[QUERY.LAST_UPDATED.USE].isChecked() + with evernote_query_last_updated as lu, lu.value as v, QUERY.LAST_UPDATED.VALUE as LUV: + lu.type.setEnabled(use_last_updated) + v.absolute.date.setEnabled(use_last_updated) + v.absolute.time.setEnabled(use_last_updated) + v.relative.spinner.setEnabled(use_last_updated) + if not use_last_updated: + return + + absolute_date = LUV.ABSOLUTE.DATE.fetch() + absolute_date = QDate().fromString(absolute_date, 'yyyy MM dd') + if index < EvernoteQueryLocationType.AbsoluteDate: + v.absolute.date.setVisible(False) + v.absolute.time.setVisible(False) + spinner_prefix = ['day', 'week', 'month', 'year'][index] + v.relative.spinner.setPrefix(spinner_prefix) + v.relative.spinner.setValue(int(LUV.RELATIVE.fetch(0))) + v.stacked_layout.setCurrentIndex(0) + else: + v.relative.spinner.setVisible(False) + v.absolute.date.setDate(absolute_date) + v.stacked_layout.setCurrentIndex(1) + if index is 
EvernoteQueryLocationType.AbsoluteDate: + v.absolute.time.setVisible(False) + v.absolute.datetime.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) + else: + # absolute_time = "{:HH mm ss}".format(datetime.now()) + # absolute_time = QUERY.LAST_UPDATED_VALUE_ABSOLUTE_TIME.fetch(absolute_time) + absolute_time = LUV.ABSOLUTE.TIME.fetch("{:HH mm ss}".format(datetime.now())) + # absolute_time = QTime().fromString(absolute_time, 'HH mm ss') + v.absolute.time.setTime(QTime().fromString(absolute_time, 'HH mm ss')) + v.absolute.datetime.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + + +def update_evernote_query_last_updated_value_relative_spinner(value): + if value < 0: + value = 0 + evernote_query_last_updated.value.relative.spinner.setValue(0) + QUERY.LAST_UPDATED.VALUE.RELATIVE.save(value) + + +def update_evernote_query_last_updated_value_absolute_date(date): + QUERY.LAST_UPDATED.VALUE.ABSOLUTE.DATE.save(date.toString('yyyy MM dd')) + + +def update_evernote_query_last_updated_value_absolute_datetime(dt): + QUERY.LAST_UPDATED.VALUE.ABSOLUTE.DATE.save(dt.toString('yyyy MM dd')) + QUERY.LAST_UPDATED.VALUE.ABSOLUTE.TIME.save(dt.toString('HH mm ss')) + + +def update_evernote_query_last_updated_value_absolute_time(time_value): + QUERY.LAST_UPDATED.VALUE.ABSOLUTE.TIME.save(time_value.toString('HH mm ss')) + + +def generate_evernote_query(): + def generate_tag_pred(tags, negate=False): + pred = '' + prefix = '-' if negate else '' + if not isinstance(tags, list): + tags = tags.replace(',', ' ').split() + for tag in tags: + tag = tag.strip() + if ' ' in tag: + tag = '"%s"' % tag + pred += prefix + 'tag:%s ' % tag + return pred + + # Begin generate_evernote_query() + query = "" + if QUERY.USE_NOTEBOOK.fetch(False): + query_notebook = QUERY.NOTEBOOK.fetch(QUERY.NOTEBOOK_DEFAULT_VALUE).strip() + query += 'notebook:"%s" ' % query_notebook + if QUERY.ANY.fetch(True): + query += "any: " + if QUERY.USE_NOTE_TITLE.fetch(False): + query_note_title = QUERY.NOTE_TITLE.fetch("") 
+ if not query_note_title.startswith('"') and query_note_title.endswith('"'): + query_note_title = '"%s"' % query_note_title + query += 'intitle:%s ' % query_note_title + if QUERY.USE_TAGS.fetch(True): + query += generate_tag_pred(QUERY.TAGS.fetch(QUERY.TAGS_DEFAULT_VALUE)) + if QUERY.USE_EXCLUDED_TAGS.fetch(True): + query += generate_tag_pred(QUERY.EXCLUDED_TAGS.fetch(''), True) + if QUERY.LAST_UPDATED.USE.fetch(False): + query += " updated:%s " % evernote_query_last_updated_get_current_value() + if QUERY.USE_SEARCH_TERMS.fetch(False): + query += QUERY.SEARCH_TERMS.fetch("") + if not query.replace('any:','').strip(): + query = '*' + return query + + +def handle_show_generated_evernote_query(): + showInfo( + "The Evernote search query for your current options is below. You can press copy the text to your clipboard by pressing the copy keyboard shortcut (CTRL+C in Windows) while this message box has focus.\n\nQuery: %s" % generate_evernote_query(), + "Evernote Search Query") diff --git a/anknotes/shared.py b/anknotes/shared.py new file mode 100644 index 0000000..e69c46f --- /dev/null +++ b/anknotes/shared.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +### Python Imports +try: + from pysqlite2 import dbapi2 as sqlite +except ImportError: + from sqlite3 import dbapi2 as sqlite +import os +import re +from fnmatch import fnmatch +from bs4 import UnicodeDammit + + +### Anknotes Imports +from anknotes.constants import * +from anknotes.imports import * +# write_file_contents('Loading %s: Importing base' % __name__, 'load') +from anknotes.base import * +# write_file_contents('Loading %s: Imported base' % __name__, 'load') +# write_file_contents('Loading %s: Importing logging' % __name__, 'load') +from anknotes.logging import * +from anknotes.db import * +from anknotes.html import * +from anknotes.structs import * + +### Check if in Anki +if in_anki(): + from aqt import mw + from aqt.qt import QIcon, QPixmap, QPushButton, QMessageBox + from 
anknotes.evernote.edam.error.ttypes import EDAMSystemException, EDAMErrorCode, EDAMUserException, \ + EDAMNotFoundException + +class EvernoteQueryLocationType: + RelativeDay, RelativeWeek, RelativeMonth, RelativeYear, AbsoluteDate, AbsoluteDateTime = range(6) + +def get_tag_names_to_import(tagNames, evernoteQueryTags=None, evernoteTagsToDelete=None, keepEvernoteTags=None, + deleteEvernoteQueryTags=None): + def check_tag_name(v, tags_to_delete): + return v not in tags_to_delete and (not hasattr(v, 'Name') or getattr(v, 'Name') not in tags_to_delete) and ( + not hasattr(v, 'name') or getattr(v, 'name') not in tags_to_delete) + if keepEvernoteTags is None: + keepEvernoteTags = SETTINGS.ANKI.TAGS.KEEP_TAGS.fetch() + if not keepEvernoteTags: + return {} if isinstance(tagNames, dict) else [] + if evernoteQueryTags is None: + evernoteQueryTags = SETTINGS.EVERNOTE.QUERY.TAGS.fetch().replace(',', ' ').split() + if deleteEvernoteQueryTags is None: + deleteEvernoteQueryTags = SETTINGS.ANKI.TAGS.DELETE_EVERNOTE_QUERY_TAGS.fetch() + if evernoteTagsToDelete is None: + evernoteTagsToDelete = SETTINGS.ANKI.TAGS.TO_DELETE.fetch() + tags_to_delete = evernoteQueryTags if deleteEvernoteQueryTags else [] + evernoteTagsToDelete + if isinstance(tagNames, dict): + return {k: v for k, v in tagNames.items() if check_tag_name(v, tags_to_delete)} + return sorted([v for v in tagNames if check_tag_name(v, tags_to_delete)]) + + +def find_evernote_guids(content): + return [x.group('guid') for x in + re.finditer(r'\b(?P[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})\b', content)] + + +def find_evernote_links_as_guids(content): + return [x.Guid for x in find_evernote_links(content)] + + +def replace_evernote_web_links(content): + return re.sub( + r'https://www.evernote.com/shard/(s\d+)/[\w\d]+/(\d+)/([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})', + r'evernote:///view/\2/\1/\3/\3/', content) + +def find_evernote_links(content): + """ + + :param content: + :return: + 
:rtype : list[EvernoteLink] + """ + # .NET regex saved to regex.txt as 'Finding Evernote Links' + content = replace_evernote_web_links(content) + regex_str = r"""(?si)evernote:///?view/(?P[\d]+?)/(?Ps\d+)/(?P[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})/(?P=guid)/?)["''](?:[^>]+)?>(?P.+?)</a>""" + ids = get_evernote_account_ids() + if not ids.Valid: + match = re.search(regex_str, content) + if match: + ids.update(match.group('uid'), match.group('shard')) + return [EvernoteLink(m) for m in re.finditer(regex_str, content)] + + +def check_evernote_guid_is_valid(guid): + return ankDB().exists(where="guid = '%s'" % guid) + + +def escape_regex(str_): + return re.sub(r"(?sx)(\(|\||\))", r"\\\1", str_) + + +def remove_evernote_link(link, html): + html = UnicodeDammit(html, ['utf-8'], is_html=True).unicode_markup + link_converted = UnicodeDammit(link.WholeRegexMatch, ['utf-8'], is_html=True).unicode_markup + sep = u'<span style="color: rgb(105, 170, 53);"> | </span>' + sep_regex = escape_regex(sep) + no_start_tag_regex = r'[^<]*' + regex_replace = r'<{0}[^>]*>[^<]*{1}[^<]*</{0}>' + # html = re.sub(regex_replace.format('li', link.WholeRegexMatch), "", html) + # Remove link + html = html.replace(link.WholeRegexMatch, "") + # Remove empty li + html = re.sub(regex_replace.format('li', no_start_tag_regex), "", html) + # Remove dangling separator + + regex_span = regex_replace.format('span', no_start_tag_regex) + no_start_tag_regex + sep_regex + html = re.sub(regex_span, "", html) + # Remove double separator + html = re.sub(sep_regex + no_start_tag_regex + sep_regex, sep_regex, html) + return html + + +def get_dict_from_list(lst, keys_to_ignore=list()): + dic = {} + for key, value in lst: + if not key in keys_to_ignore: + dic[key] = value + return dic + +def update_regex(): + regex_str = file(os.path.join(FOLDERS.ANCILLARY, 'regex-see_also.txt'), 'r').read() + regex_str = regex_str.replace('(?<', '(?P<') + regex_see_also._regex_see_also = re.compile(regex_str, 
re.UNICODE | re.VERBOSE | re.DOTALL) + +def regex_see_also(): + if not hasattr(regex_see_also, '_regex_see_also'): + update_regex() + return regex_see_also._regex_see_also diff --git a/anknotes/stopwatch/__init__.py b/anknotes/stopwatch/__init__.py new file mode 100644 index 0000000..767a969 --- /dev/null +++ b/anknotes/stopwatch/__init__.py @@ -0,0 +1,800 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2008 John Paulett (john -at- 7oars.com) +# All rights reserved. +# +# This software is licensed as described in the file COPYING, which +# you should have received as part of this distribution. + +import time +import re +import os +from anknotes.constants import ANKNOTES +from anknotes.base import is_str, item_to_list +from anknotes.structs import EvernoteAPIStatus +from anknotes.logging import caller_name, log, log_banner, log_blank, show_report, counts_as_str, get_log_full_path +from anknotes.counters import Counter, EvernoteCounter +from anknotes.dicts import DictCaseInsensitive + +"""stopwatch is a very simple Python module for measuring time. +Great for finding out how long code takes to execute. 
+ +>>> import stopwatch +>>> t = stopwatch.Timer() +>>> t.elapsed +3.8274309635162354 +>>> print t +15.9507198334 sec +>>> t.stop() +30.153270959854126 +>>> print t +30.1532709599 sec + +Decorator exists for printing out execution times: +>>> from stopwatch import clockit +>>> @clockit + def mult(a, b): + return a * b +>>> print mult(2, 6) +mult in 1.38282775879e-05 sec +6 + +""" + +__version__ = '0.5' +__author__ = 'Avinash Puchalapalli <http://www.github.com/holycrepe/>' +__info__ = 'Forked from stopwatch 0.3.1 by John Paulett <http://blog.7oars.com>' + + +# class TimerCounts(object): +# Max, Current, Updated, Added, Queued, Error = (0) * 6 + +class ActionInfo(object): + Status = EvernoteAPIStatus.Uninitialized + __created_str = " Added to Anki" + __updated_str = " Updated in Anki" + __queued_str = " for Upload to Evernote" + + @property + def ActionShort(self): + try: + if self.__action_short: + return self.__action_short + finally: + return (self.ActionBase.upper() + 'ING').replace('INGING', 'ING').replace(' OFING', 'ING').replace( + "TIONING", "TING").replace("CREATEING", "CREATING") + ' ' + self.RowItemBase.upper() + + @property + def ActionShortSingle(self): + return self.ActionShort.replace('(s)', '') + + @property + def ActionTemplate(self): + return self._action_template_() + + @property + def ActionTemplateNumeric(self): + return self._action_template_(True) + + def _action_template_(self, numeric=False, short_row=False): + if self.__action_template: + return self.__action_template + return self.ActionBase + (' {num} ' if numeric else ' ') + (self.RowItemBase if short_row else self.RowItemFull) + + @property + def ActionBase(self): + return self.__action_base + + def _action_(self, **kw): + strNum = '' if self.emptyResults else '%3d ' % self.Max + template = re.sub(r'\(([sS])\)', '' if self.Max == 1 else r'\1', self._action_template_(**kw)) + return template.replace('{num} ', strNum) + + @property + def Action(self): + return self._action_() + + 
@property + def ActionNumeric(self): + return self._action_(numeric=True) + + @property + def Automated(self): + return self.__automated + + @property + def RowItemFull(self): + if self.__row_item_full: + if self.RowItemBase in self.__row_item_full: + return self.__row_item_full + return self.RowItemBase + ' ' + self.__row_item_full + return self.RowItemBase + + @property + def RowItemBase(self): + return self.__row_item_base + + @property + def RowSource(self): + if self.__row_source: + return self.__row_source + return self.RowItemFull + + @property + def Label(self): + return self.__label + + @Label.setter + def Label(self, value): + self.__label = value + + @property + def Max(self): + if not self.__max: + return -1 + if isinstance(self.__max, Counter): + return self.__max.val + return self.__max + + @Max.setter + def Max(self, value): self.__max = value + + @property + def Interval(self): + return self.__interval + + @property + def emptyResults(self): + return not self.Max or self.Max < 0 + + @property + def willReportProgress(self): + if self.emptyResults: + return False + if not self.Interval or self.Interval < 1: + return False + return self.Max > self.Interval + + def FormatLine(self, text, num=None): + if isinstance(num, Counter): + num = num.val + return text.format(num=('%' + str(len(str(self.Max))) + 'd ') % num if num else '', + row_sources=self.RowSource.replace('(s)', 's'), + rows=self.RowItemFull.replace('(s)', 's'), + row=self.RowItemFull.replace('(s)', ''), + row_=self.RowItemFull, + r=self.RowItemFull.replace('(s)', '' if num == 1 else 's'), + action=self.Action + ' ' + ) + + def ActionLine(self, title, text='', num=None, short_row=True, **kw): + if num: + kw['numeric'] = True + kw['short_row'] = short_row + action = self._action_(**kw) + if action == action.upper(): + title = title.upper() + if text: + text = ': ' + self.FormatLine(text, num) + return " > %s %s%s" % (action, title, text) + + @property + def Aborted(self): + return 
self.ActionLine("Aborted", "No Qualifying {row_sources} Found") + + @property + def Initiated(self): + return self.ActionLine("Initiated", num=self.Max) + + def BannerHeader(self, append_newline=False, filename=None, crosspost=None, bh_wrap_filename=True, **kw): + if filename is None: + filename = '' + if bh_wrap_filename: + filename = self.Label + filename + if crosspost is not None: + crosspost = [self.Label + cp for cp in item_to_list(crosspost, False)] + log_banner(self.ActionNumeric.upper(), do_print=self.__do_print, **DictCaseInsensitive(kw, locals(), delete='self kw bh_wrap_filename cp')) + + def setStatus(self, status): + self.Status = status + return status + + def displayInitialInfo(self, max=None, interval=None, automated=None, enabled=None, **kw): + if max: + self.__max = max + if interval: + self.__interval = interval + if automated is not None: + self.__automated = automated + if enabled is not None: + self.__enabled = enabled + if self.emptyResults: + if not self.Automated and self.__report_if_empty: + log('report: ' + self.Aborted, self.Label) + show_report(self.Aborted, blank_line_before=False) + else: + log('report: [automated] ' + self.Aborted, self.Label) + return self.setStatus(EvernoteAPIStatus.EmptyRequest) + if self.__enabled is False: + log("Not starting - stopwatch.ActionInfo: enabled = false ", self.Label, do_print=self.__do_print) + if not automated: + show_report(self.ActionLine("Aborted", "Action has been disabled"), + blank_line_before=False) + return self.setStatus(EvernoteAPIStatus.Disabled) + log(self.Initiated, do_print=self.__do_print) + self.BannerHeader() + return self.setStatus(EvernoteAPIStatus.Initialized) + + def str_value(self, str_name): + return getattr(self, '_' + self.__class__.__name__ + '__' + str_name + '_str') + + def __init__(self, action_base='Upload of Validated Evernote Notes', row_item_base=None, row_item_full=None, + action_full=None, action_template=None, label=None, auto_label=True, max=None, 
automated=False, enabled=True, + interval=None, row_source=None, do_print=False, report_if_empty=True, **kw): + self.__action_short = None + if label is None and auto_label: + label = caller_name(return_filename=True) + if row_item_base is None: + actions = action_base.split() + action_base = actions[0] + if len(actions) == 1: + action_base = actions[0] + row_item_base = action_base + else: + if actions[1].lower() == 'of': + action_base += ' ' + actions[1] + actions = actions[1:] + assert len(actions) > 1 + row_item_base = ' '.join(actions[1:]) + if row_item_full is None and len(actions) > 2: + row_item_base = actions[-1] + row_item_full = ' '.join(actions[1:]) + self.__action_base = action_base + self.__action_full = action_full + self.__row_item_base = row_item_base + self.__row_item_full = row_item_full + self.__row_source = row_source + self.__action_template = action_template + self.__automated = automated + self.__enabled = enabled + self.__label = label + self.__max = max + self.__interval = interval + self.__do_print = do_print + self.__report_if_empty = report_if_empty + + +class Timer(object): + __times = [] + __stopped = None + __start = None + __status = EvernoteAPIStatus.Uninitialized + __counts = None + __did_break = True + __laps = 0 + __interval = 100 + __parent_timer = None + __caller = None + __info = None + """:type : Timer""" + + @property + def counts(self): + if self.__counts is None: + log("Init counter from property: " + repr(self.__counts), "counters") + self.__counts = EvernoteCounter() + return self.__counts + + @counts.setter + def counts(self, value): + self.__counts = value + + @property + def laps(self): + return len(self.__times) + + @property + def max(self): + return self.counts.max + + @max.setter + def max(self, value): + self.counts.max = value + if self.counts.max_allowed < 1: + self.counts.max_allowed = value + + @property + def is_success(self): + return self.counts.success + + @property + def parent(self): + return 
self.__parent_timer + + @property + def label(self): + if self.info: + return self.info.Label + return "" + + @label.setter + def label(self, value): + if self.info and isinstance(self.info, ActionInfo): + self.info.Label = value + return + self.__info = ActionInfo(value, label=value) + + @parent.setter + def parent(self, value): + """:type value : Timer""" + self.__parent_timer = value + + @property + def parentTotal(self): + if not self.__parent_timer: + return -1 + return self.__parent_timer.total + + @property + def percentOfParent(self): + if not self.__parent_timer: + return -1 + return float(self.total) / float(self.parentTotal) * 100 + + @property + def percentOfParentStr(self): + return str(int(round(self.percentOfParent))) + '%' + + @property + def percentComplete(self): + if not self.counts.max: + return -1 + return float(self.count) / self.counts.max * 100 + + @property + def percentCompleteStr(self): + return str(int(round(self.percentComplete))) + '%' + + @property + def rate(self): + return self.rateCustom() + + @property + def rateStr(self): + return self.rateStrCustom() + + def rateCustom(self, unit=None): + if unit is None: + unit = self.__interval + return self.elapsed / self.count * unit + + def rateStrCustom(self, unit=None): + if unit is None: + unit = self.__interval + return self.__timetostr(self.rateCustom(unit)) + + @property + def count(self): + return max(self.counts.val, 1) + + @property + def projectedTime(self): + if not self.counts.max: + return -1 + return self.counts.max * self.rateCustom(1) + + @property + def projectedTimeStr(self): + return self.__timetostr(self.projectedTime) + + @property + def remainingTime(self): + return self.projectedTime - self.elapsed + + @property + def remainingTimeStr(self): + return self.__timetostr(self.remainingTime) + + @property + def progress(self): + return '%5s (%3s): @ %3s/%d. 
%3s of %3s remain' % ( + self.__timetostr(short=False), self.percentCompleteStr, self.rateStr, self.__interval, + self.remainingTimeStr, + self.projectedTimeStr) + + @property + def active(self): + return self.__start and not self.__stopped + + @property + def completed(self): + return self.__start and self.__stopped + + @property + def lap_info(self): + strs = [] + if self.active: + strs.append('Active: %s' % self.__timetostr()) + elif self.completed: + strs.append('Latest: %s' % self.__timetostr()) + elif self.laps > 0: + strs.append('Last: %s' % self.__timetostr(self.__times)) + if self.laps > 0 + 0 if self.active or self.completed else 1: + strs.append('%2d Laps: %s' % (self.laps, self.__timetostr(self.history))) + strs.append('Average: %s' % self.__timetostr(self.average)) + if self.__parent_timer: + strs.append("Parent: %s" % self.__timetostr(self.parentTotal)) + strs.append(" (%3s) " % self.percentOfParentStr) + return ' | '.join(strs) + + @property + def isProgressCheck(self): + if not self.counts.max: + return False + return self.count % max(self.__interval, 1) is 0 + + @property + def status(self): + if self.hasActionInfo: + return self.info.Status + return self.__status + + @status.setter + def status(self, value): + if self.hasActionInfo: + self.info.Status = value + + def autoStep(self, returned_tuple, title=None, update=None): + retval = self.extractStatus(returned_tuple, update) + self.step(title) + return retval + + def extractStatus(self, returned_tuple, update=None): + self.report_result = self.reportStatus(returned_tuple[0], update) + if len(returned_tuple) == 2: + return returned_tuple[1] + return returned_tuple[1:] + + def checkLimits(self): + if not -1 < self.counts.max_allowed <= self.counts.updated + self.counts.created: + return True + log("Count exceeded- Breaking with status " + str(self.status), self.label, do_print=self.__do_print) + self.reportStatus(EvernoteAPIStatus.ExceededLocalLimit) + return False + + def reportStatus(self, 
status, update=None, title=None, **kw): + """ + :type status : EvernoteAPIStatus + """ + self.status = status + if status.IsError: + retval = self.reportError(save_status=False) + elif status == EvernoteAPIStatus.RequestQueued: + retval = self.reportQueued(save_status=False) + elif status.IsSuccess: + retval = self.reportSuccess(update, save_status=False) + elif status == EvernoteAPIStatus.ExceededLocalLimit: + retval = status + else: + self.counts.unhandled.step() + retval = False + if title: + self.step(title, **kw) + return retval + + def reportSkipped(self, save_status=True): + if save_status: + self.status = EvernoteAPIStatus.RequestSkipped + return self.counts.skipped.step() + + def reportSuccess(self, update=None, save_status=True): + if save_status: + self.status = EvernoteAPIStatus.Success + if update: + self.counts.updated.completed.step() + else: + self.counts.created.completed.step() + return self.counts.success + + def reportError(self, save_status=True): + if save_status: + self.status = EvernoteAPIStatus.GenericError + return self.counts.error.step() + + def reportQueued(self, save_status=True, update=None): + if save_status: + self.status = EvernoteAPIStatus.RequestQueued + if update: + return self.counts.updated.queued.step() + return self.counts.created.queued.step() + + @property + def ReportHeader(self): + return None if not self.counts.total else self.info.FormatLine( + "%s {r} were processed" % counts_as_str(self.counts.total, self.counts.max), self.counts.total) + + def ReportSingle(self, text, count, subtext='', queued_text='', queued=0, subcount=0, process_subcounts=True): + if not count: + return [] + if isinstance(count, Counter) and process_subcounts: + if count.queued: + queued = count.queued.val + if count.completed.subcount: + subcount = count.completed.subcount.val + if not queued_text: + queued_text = self.info.str_value('queued') + strs = [self.info.FormatLine("%s {r} %s" % (counts_as_str(count), text), self.count)] + if 
process_subcounts: + if queued: + strs.append("-%-3d of these were queued%s" % (queued, queued_text)) + if subcount: + strs.append("-%-3d of these were successfully%s " % (subcount, subtext)) + return strs + + def Report(self, subcount_created=0, subcount_updated=0): + str_tips = [] + self.counts.created.completed.subcount = subcount_created + self.counts.updated.completed.subcount = subcount_updated + str_tips += self.ReportSingle('were newly created', self.counts.created, self.info.str_value('created')) + str_tips += self.ReportSingle('already exist and were updated', self.counts.updated, self.info.str_value('updated')) + str_tips += self.ReportSingle('already exist but were unchanged', self.counts.skipped, process_subcounts=False) + if self.counts.error: + str_tips.append("%d Error(s) occurred " % self.counts.error.val) + if self.status == EvernoteAPIStatus.ExceededLocalLimit: + str_tips.append("Action was prematurely terminated because locally-defined limit of %d was exceeded." % + self.counts.max_allowed) + report_title = " > %s Complete" % self.info.Action + if self.counts.total is 0: + report_title += self.info.FormatLine(": No {r} were processed") + show_report(report_title, self.ReportHeader, str_tips, blank_line_before=False, do_print=self.__do_print) + log_blank('counters') + log(self.counts.fullSummary(self.name + ': End'), 'counters') + + def increment(self, *a, **kw): + self.counts.step(**kw) + return self.step(*a, **kw) + + def step(self, title=None, **kw): + if self.hasActionInfo and self.isProgressCheck and title: + title_str = ("%" + str(len('#' + str(self.max))) + "s: %s") % ('#' + str(self.count), title) + progress_str = ' [%s]' % self.progress + title_len = ANKNOTES.FORMATTING.LINE_LENGTH_TOTAL - 1 - 2 - len(progress_str) + log_path = self.label + ('' if self.label.endswith('\\') else '-') + 'progress' + if not self.__reported_progress: + self.info.BannerHeader(filename=log_path, bh_wrap_filename=False) + self.__reported_progress = True + 
log(title_str.ljust(title_len) + progress_str, log_path, timestamp=False, do_print=self.__do_print, **kw) + return self.isProgressCheck + + @property + def info(self): + """ + :rtype : ActionInfo + """ + return self.__info + + @property + def did_break(self): return self.__did_break + + def reportNoBreak(self): self.__did_break = False + + @property + def should_retry(self): return self.did_break and self.status != EvernoteAPIStatus.ExceededLocalLimit + + @property + def automated(self): + if self.info is None: + return False + return self.info.Automated + + def hasActionInfo(self): + return self.info is not None and self.counts.max > 0 + + def __init__(self, max=None, interval=100, info=None, infoStr=None, automated=None, begin=True, + label=None, display_initial_info=None, max_allowed=None, do_print=False, **kw): + """ + :type info : ActionInfo + """ + args = DictCaseInsensitive(kw, locals(), delete='kw infoStr info max self') + simple_label = False + self.counts = EvernoteCounter() + self.__interval = interval + self.__reported_progress = False + if not isinstance(max, int): + if hasattr(max, '__len__'): + max = len(max) + else: + max = None + self.counts.max = -1 + if max is not None: + self.counts.max = max + args.max = self.counts.max + if is_str(info): + # noinspection PyTypeChecker + info = ActionInfo(info, **args) + elif infoStr and not info: + info = ActionInfo(infoStr, **args) + elif label and not info: + simple_label = True + if display_initial_info is None: + display_initial_info = False + info = ActionInfo(label, **args) + elif label: + info.Label = label + if self.counts.max > 0 and info and (info.Max is None or info.Max < 1): + info.Max = max + self.counts.max_allowed = self.counts.max if max_allowed is None else max_allowed + self.__did_break = True + self.__do_print = do_print + self.__info = info + self.__action_initialized = False + self.__action_attempted = self.hasActionInfo and (display_initial_info is not False) + if self.__action_attempted: 
+ if self.info is None: + log("Unexpected; Timer '%s' has no ActionInfo instance" % label, do_print=True) + else: + self.__action_initialized = self.info.displayInitialInfo(**args) is EvernoteAPIStatus.Initialized + if begin: + self.reset(False) + log_blank(filename='counters') + log(self.counts.fullSummary(self.name + ': Start'), 'counters') + + @property + def name(self): + name = (self.label.strip('\\').replace('\\', ': ') if self.label else self.caller) + return name.replace('.', ': ').replace('-', ': ').replace('_', ' ').capitalize() + + @property + def base_name(self): + return self.name.split(': ')[-1] + + @property + def caller(self): + if self.__caller is None: + self.__caller = caller_name(return_filename=True) + return self.__caller + + @property + def willReportProgress(self): + return self.counts.max and self.counts.max > self.interval + + @property + def actionInitializationFailed(self): + return self.__action_attempted and not self.__action_initialized + + @property + def interval(self): + return max(self.__interval, 1) + + def start(self): + self.reset() + + def reset(self, reset_counter=True): + # keep = [] + # if self.counts: + # keep = [self.counts.max, self.counts.max_allowed] + # del self.__counts + if reset_counter: + log("Resetting counter", 'counters') + if self.counts is None: + self.counts = EvernoteCounter() + else: + self.counts.reset() + # if keep: + # self.counts.max = keep[0] + # self.counts.max_allowed = keep[1] + if not self.__stopped: + self.stop() + self.__stopped = None + self.__start = self.__time() + + def stop(self): + """Stops the clock permanently for the instance of the Timer. + Returns the time at which the instance was stopped. 
+ """ + if not self.__start: + return -1 + self.__stopped = self.__last_time() + self.__times.append(self.elapsed) + return self.elapsed + + @property + def history(self): + return sum(self.__times) + + @property + def total(self): + return self.history + self.elapsed + + @property + def average(self): + return float(self.history) / self.laps + + def elapsed(self): + """The number of seconds since the current time that the Timer + object was created. If stop() was called, it is the number + of seconds from the instance creation until stop() was called. + """ + if not self.__start: + return -1 + return self.__last_time() - self.__start + + elapsed = property(elapsed) + + def start_time(self): + """The time at which the Timer instance was created. + """ + return self.__start + + start_time = property(start_time) + + def stop_time(self): + """The time at which stop() was called, or None if stop was + never called. + """ + return self.__stopped + + stop_time = property(stop_time) + + def __last_time(self): + """Return the current time or the time at which stop() was call, + if called at all. + """ + if self.__stopped is not None: + return self.__stopped + return self.__time() + + def __time(self): + """Wrapper for time.time() to allow unit testing. 
+ """ + return time.time() + + @property + def str_long(self): + return self.__timetostr(short=False) + + def __timetostr(self, total_seconds=None, short=True, pad=True): + if total_seconds is None: + total_seconds = self.elapsed + total_seconds = int(round(total_seconds)) + if total_seconds < 60: + return ['%ds', '%2ds'][pad] % total_seconds + m, s = divmod(total_seconds, 60) + if short: + # if total_seconds < 120: return '%dm' % (m, s) + return ['%dm', '%2dm'][pad] % m + return '%d:%02d' % (m, s) + + def __str__(self): + """Nicely format the elapsed time + """ + return self.__timetostr() + + def __repr__(self): + return "<%s%s> %s" % (self.__class__.__name__, '' if not self.label else ':%s' % self.label, self.str_long) + + +all_clockit_timers = {} + + +def clockit(func): + """Function decorator that times the evaluation of *func* and prints the + execution time. + """ + + def new(*args, **kw): + # fn = func.__name__ + # print "Request to clock %s" % fn + # return func(*args, **kw) + global all_clockit_timers + fn = func.__name__ + if fn not in all_clockit_timers: + all_clockit_timers[fn] = Timer() + else: + all_clockit_timers[fn].reset() + retval = func(*args, **kw) + all_clockit_timers[fn].stop() + # print ('Function %s completed in %s\n > %s' % (fn, all_clockit_timers[fn].__timetostr(short=False), all_clockit_timers[fn].lap_info)) + return retval + + return new diff --git a/anknotes/stopwatch/tests/__init__.py b/anknotes/stopwatch/tests/__init__.py new file mode 100644 index 0000000..2c810c8 --- /dev/null +++ b/anknotes/stopwatch/tests/__init__.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2008 John Paulett (john -at- 7oars.com) +# All rights reserved. +# +# This software is licensed as described in the file COPYING, which +# you should have received as part of this distribution. 
+ +import unittest +import doctest +import stopwatch +from stopwatch import clockit + + +class TimeControlledTimer(stopwatch.Timer): + def __init__(self): + self.__count = 0 + super(TimeControlledTimer, self).__init__() + + def __time(self): + retval = self.__count + self.__count += 1 + return retval + + +class TimerTestCase(unittest.TestCase): + def setUp(self): + self.timer = stopwatch.Timer() + + def test_simple(self): + point1 = self.timer.elapsed + self.assertTrue(point1 > 0) + + point2 = self.timer.elapsed + self.assertTrue(point2 > point1) + + point3 = self.timer.elapsed + self.assertTrue(point3 > point2) + + def test_stop(self): + point1 = self.timer.elapsed + self.assertTrue(point1 > 0) + + self.timer.stop() + point2 = self.timer.elapsed + self.assertTrue(point2 > point1) + + point3 = self.timer.elapsed + self.assertEqual(point2, point3) + + +@clockit +def timed_multiply(a, b): + return a * b + + +class DecoratorTestCase(unittest.TestCase): + def test_clockit(self): + self.assertEqual(6, timed_multiply(2, b=3)) + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TimerTestCase)) + suite.addTest(unittest.makeSuite(DecoratorTestCase)) + # suite.addTest(doctest.DocTestSuite(stopwatch)) + return suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/anknotes/structs.py b/anknotes/structs.py new file mode 100644 index 0000000..917b89a --- /dev/null +++ b/anknotes/structs.py @@ -0,0 +1,770 @@ +import re +# from BeautifulSoup import UnicodeDammit +import anknotes +from bs4 import UnicodeDammit + + +from anknotes.constants import * +from anknotes.base import item_to_set, item_to_list +from anknotes.db import * +from anknotes.enum import Enum +from anknotes.html import strip_tags +from anknotes.logging import PadList, JoinList +from anknotes.enums import * +from anknotes.EvernoteNoteTitle import EvernoteNoteTitle + + +# from evernote.edam.notestore.ttypes import NoteMetadata, NotesMetadataList + +def 
upperFirst(name): + return name[0].upper() + name[1:] + + +def getattrcallable(obj, attr): + val = getattr(obj, attr) + if callable(val): + return val() + return val + + +# from anknotes.EvernoteNotePrototype import EvernoteNotePrototype +# from anknotes.EvernoteNoteTitle import EvernoteNoteTitle + +class EvernoteStruct(object): + success = False + Name = "" + Guid = "" + _sql_columns_ = "name" + _sql_table_ = TABLES.EVERNOTE.TAGS + _sql_where_ = "guid" + _attr_order_ = None + _additional_attr_ = None + _title_is_note_title_ = False + + @staticmethod + def _attr_from_key_(key): + return upperFirst(key) + + def keys(self): + return self._valid_attributes_() + + def items(self): + return [self.getAttribute(key) for key in self._attr_order_] + + def sqlUpdateQuery(self): + columns = self._attr_order_ if self._attr_order_ else self._sql_columns_ + return "INSERT OR REPLACE INTO `%s`(%s) VALUES (%s)" % ( + self._sql_table_, '`' + '`,`'.join(columns) + '`', ', '.join(['?'] * len(columns))) + + def sqlSelectQuery(self, allColumns=True): + return "SELECT %s FROM %s WHERE %s = ?" 
% ( + '*' if allColumns else ','.join(self._sql_columns_), self._sql_table_, self._sql_where_) + + def getFromDB(self, allColumns=True): + ankDB().setrowfactory() + result = ankDB().first(self.sqlSelectQuery(allColumns), self.Where) + if result: + self.success = True + self.setFromKeyedObject(result) + else: + self.success = False + return self.success + + @property + def Where(self): + return self.getAttribute(self._sql_where_) + + @Where.setter + def Where(self, value): + self.setAttribute(self._sql_where_, value) + + def getAttribute(self, key, default=None, raiseIfInvalidKey=False): + if not self.hasAttribute(key): + if raiseIfInvalidKey: + raise KeyError + return default + return getattr(self, self._attr_from_key_(key)) + + def hasAttribute(self, key): + return hasattr(self, self._attr_from_key_(key)) + + def setAttribute(self, key, value): + if key == "fetch_" + self._sql_where_: + self.setAttribute(self._sql_where_, value) + self.getFromDB() + elif self._is_valid_attribute_(key): + setattr(self, self._attr_from_key_(key), value) + else: + raise KeyError("%s: %s is not a valid attribute" % (self.__class__.__name__, key)) + + def setAttributeByObject(self, key, keyed_object): + self.setAttribute(key, keyed_object[key]) + + def setFromKeyedObject(self, keyed_object, keys=None): + """ + + :param keyed_object: + :type: sqlite.Row | dict[str, object] | re.MatchObject | _sre.SRE_Match + :return: + """ + lst = self._valid_attributes_() + if keys or isinstance(keyed_object, dict): + pass + elif isinstance(keyed_object, type(re.search('', ''))): + regex_attr = 'wholeRegexMatch' + self._additional_attr_.add(regex_attr) + whole_match = keyed_object.group(0) + keyed_object = keyed_object.groupdict() + keyed_object[regex_attr] = whole_match + elif hasattr(keyed_object, 'keys'): + keys = getattrcallable(keyed_object, 'keys') + elif hasattr(keyed_object, self._sql_where_): + for key in self.keys(): + if hasattr(keyed_object, key): + self.setAttribute(key, 
getattr(keyed_object, key)) + return True + else: + return False + + if keys is None: + keys = keyed_object + for key in keys: + if key == "fetch_" + self._sql_where_: + self.Where = keyed_object[key] + self.getFromDB() + elif key in lst: + self.setAttributeByObject(key, keyed_object) + return True + + def setFromListByDefaultOrder(self, args): + max = len(self._attr_order_) + for i, value in enumerate(args): + if i > max: + raise Exception("Argument #%d for %s (%s) exceeds the default number of attributes for the class." % ( + i, self.__class__.__name__, str(value))) + self.setAttribute(self._attr_order_[i], value) + + def _valid_attributes_(self): + return self._additional_attr_.union(self._sql_columns_, [self._sql_where_], self._attr_order_) + + def _is_valid_attribute_(self, attribute): + return (attribute[0].lower() + attribute[1:]) in self._valid_attributes_() + + def __init__(self, *args, **kwargs): + if self._attr_order_ is None: + self._attr_order_ = [] + if self._additional_attr_ is None: + self._additional_attr_ = set() + self._sql_columns_ = item_to_list(self._sql_columns_, chrs=' ,;') + self._attr_order_ = item_to_list(self._attr_order_, chrs=' ,;') + self._additional_attr_ = item_to_set(self._additional_attr_, chrs=' ,;') + args = list(args) + if args and self.setFromKeyedObject(args[0]): + del args[0] + self.setFromListByDefaultOrder(args) + self.setFromKeyedObject(kwargs) + + +class EvernoteNotebook(EvernoteStruct): + Stack = "" + _sql_columns_ = ["name", "stack"] + _sql_table_ = TABLES.EVERNOTE.NOTEBOOKS + + +class EvernoteTag(EvernoteStruct): + ParentGuid = "" + UpdateSequenceNum = -1 + _sql_columns_ = ["name", "parentGuid"] + _sql_table_ = TABLES.EVERNOTE.TAGS + _attr_order_ = 'guid|name|parentGuid|updateSequenceNum' + + +class EvernoteLink(EvernoteStruct): + _uid_ = -1 + Shard = 'x999' + Guid = "" + WholeRegexMatch = "" + _title_ = None + """:type: EvernoteNoteTitle.EvernoteNoteTitle """ + _attr_order_ = 'uid|shard|guid|title' + + def 
__init__(self, *args, **kwargs): + super(self.__class__, self).__init__(*args, **kwargs) + + @property + def HTML(self): + return self.Title.HTML + + @property + def Title(self): + """:rtype : EvernoteNoteTitle.EvernoteNoteTitle""" + return self._title_ + + @property + def FullTitle(self): return self.Title.FullTitle + + @Title.setter + def Title(self, value): + """ + :param value: + :type value : EvernoteNoteTitle.EvernoteNoteTitle | str | unicode + :return: + """ + self._title_ = anknotes.EvernoteNoteTitle.EvernoteNoteTitle(value) + """:type : EvernoteNoteTitle.EvernoteNoteTitle""" + + @property + def Uid(self): + return int(self._uid_) + + @Uid.setter + def Uid(self, value): + self._uid_ = int(value) + + @property + def NoteTitle(self): + f = anknotes.EvernoteNoteFetcher.EvernoteNoteFetcher(guid=self.Guid, use_local_db_only=True) + if not f.getNote(): + return "<Invalid Note>" + return f.result.Note.FullTitle + + def __str__(self): + return "<%s> %s: %s" % (self.__class__.__name__, self.Guid, self.FullTitle) + + def __repr__(self): + # id = + return "<%s> %s: %s" % (self.__class__.__name__, self.Guid, self.NoteTitle) + + +class EvernoteTOCEntry(EvernoteStruct): + RealTitle = "" + """:type : str""" + OrderedList = "" + """ + HTML output of Root Title's Ordererd List + :type : str + """ + TagNames = "" + """:type : str""" + NotebookGuid = "" + + def __init__(self, *args, **kwargs): + self._attr_order_ = 'realTitle|orderedList|tagNames|notebookGuid' + super(self.__class__, self).__init__(*args, **kwargs) + + +class EvernoteValidationEntry(EvernoteStruct): + Guid = "" + """:type : str""" + Title = "" + """:type : str""" + Contents = "" + """:type : str""" + TagNames = "" + """:type : str""" + NotebookGuid = "" + NoteType = "" + + def __init__(self, *args, **kwargs): + # spr = super(self.__class__ , self) + # spr._attr_order_ = self._attr_order_ + # spr.__init__(*args, **kwargs) + self._attr_order_ = 'guid|title|contents|tagNames|notebookGuid|noteType' + 
super(self.__class__, self).__init__(*args, **kwargs) + + +class EvernoteAPIStatusOld(AutoNumber): + Uninitialized = -100 + """:type : EvernoteAPIStatus""" + EmptyRequest = -3 + """:type : EvernoteAPIStatus""" + Manual = -2 + """:type : EvernoteAPIStatus""" + RequestQueued = -1 + """:type : EvernoteAPIStatus""" + Success = 0 + """:type : EvernoteAPIStatus""" + RateLimitError = () + """:type : EvernoteAPIStatus""" + SocketError = () + """:type : EvernoteAPIStatus""" + UserError = () + """:type : EvernoteAPIStatus""" + NotFoundError = () + """:type : EvernoteAPIStatus""" + UnhandledError = () + """:type : EvernoteAPIStatus""" + Unknown = 100 + """:type : EvernoteAPIStatus""" + + def __getitem__(self, item): + """:rtype : EvernoteAPIStatus""" + + return super(self.__class__, self).__getitem__(item) + + # def __new__(cls, *args, **kwargs): + # """:rtype : EvernoteAPIStatus""" + # return type(cls).__new__(*args, **kwargs) + + @property + def IsError(self): + return EvernoteAPIStatus.Unknown.value > self.value > EvernoteAPIStatus.Success.value + + @property + def IsSuccessful(self): + return EvernoteAPIStatus.Success.value >= self.value > EvernoteAPIStatus.Uninitialized.value + + @property + def IsSuccess(self): + return self == EvernoteAPIStatus.Success + + +class EvernoteAPIStatus(AutoNumberedEnum): + Uninitialized = -100 + """:type : EvernoteAPIStatus""" + Initialized = -75 + """:type : EvernoteAPIStatus""" + UnableToFindStatus = -70 + """:type : EvernoteAPIStatus""" + InvalidStatus = -60 + """:type : EvernoteAPIStatus""" + Cancelled = -50 + """:type : EvernoteAPIStatus""" + Disabled = -25 + """:type : EvernoteAPIStatus""" + Unchanged = -15 + """:type : EvernoteAPIStatus""" + EmptyRequest = -10 + """:type : EvernoteAPIStatus""" + Manual = -5 + """:type : EvernoteAPIStatus""" + RequestSkipped = -4 + """:type : EvernoteAPIStatus""" + RequestQueued = -3 + """:type : EvernoteAPIStatus""" + ExceededLocalLimit = -2 + """:type : EvernoteAPIStatus""" + DelayedDueToRateLimit = 
-1 + """:type : EvernoteAPIStatus""" + Success = 0 + """:type : EvernoteAPIStatus""" + RateLimitError = () + """:type : EvernoteAPIStatus""" + SocketError = () + """:type : EvernoteAPIStatus""" + UserError = () + """:type : EvernoteAPIStatus""" + UnchangedError = () + """:type : EvernoteAPIStatus""" + NotFoundError = () + """:type : EvernoteAPIStatus""" + MissingDataError = () + """:type : EvernoteAPIStatus""" + UnhandledError = () + """:type : EvernoteAPIStatus""" + GenericError = () + """:type : EvernoteAPIStatus""" + Unknown = 100 + """:type : EvernoteAPIStatus""" + + # def __new__(cls, *args, **kwargs): + # """:rtype : EvernoteAPIStatus""" + # return type(cls).__new__(*args, **kwargs) + + @property + def IsError(self): + return EvernoteAPIStatus.Unknown.value > self.value > EvernoteAPIStatus.Success.value + + @property + def IsDelayableError(self): + return self.value == EvernoteAPIStatus.RateLimitError.value or self.value == EvernoteAPIStatus.SocketError.value + + @property + def IsSuccessful(self): + return EvernoteAPIStatus.Success.value >= self.value >= EvernoteAPIStatus.Manual.value + + @property + def IsSuccess(self): + return self == EvernoteAPIStatus.Success + + +class EvernoteImportType: + Add, UpdateInPlace, DeleteAndUpdate = range(3) + + +class EvernoteNoteFetcherResult(object): + def __init__(self, note=None, status=None, source=-1): + """ + + :type note: EvernoteNotePrototype.EvernoteNotePrototype + :type status: EvernoteAPIStatus + """ + if not status: + status = EvernoteAPIStatus.Uninitialized + self.Note = note + self.Status = status + self.Source = source + + +class EvernoteNoteFetcherResults(object): + Status = EvernoteAPIStatus.Uninitialized + ImportType = EvernoteImportType.Add + Local = 0 + Notes = [] + Imported = 0 + Max = 0 + AlreadyUpToDate = 0 + + @property + def DownloadSuccess(self): + return self.Count == self.Max + + @property + def AnkiSuccess(self): + return self.Imported == self.Count + + @property + def TotalSuccess(self): + 
return self.DownloadSuccess and self.AnkiSuccess + + @property + def LocalDownloadsOccurred(self): + return self.Local > 0 + + @property + def Remote(self): + return self.Count - self.Local + + @property + def SummaryShort(self): + add_update_strs = ['New', "Added"] if self.ImportType == EvernoteImportType.Add else ['Existing', + 'Updated In-Place' if self.ImportType == EvernoteImportType.UpdateInPlace else 'Deleted and Updated'] + return "%d %s Notes Have Been %s" % (self.Imported, add_update_strs[0], add_update_strs[1]) + + @property + def SummaryLines(self): + if self.Max is 0: + return [] + add_update_strs = ['New', "Added to"] if self.ImportType == EvernoteImportType.Add else ['Existing', + "%s in" % ( + 'Updated In-Place' if self.ImportType == EvernoteImportType.UpdateInPlace else 'Deleted and Updated')] + add_update_strs[1] += " Anki" + + ## Evernote Status + if self.DownloadSuccess: + line = "All %3d" % self.Max + else: + line = "%3d of %3d" % (self.Count, self.Max) + lines = [line + " %s Evernote Metadata Results Were Successfully Downloaded%s." % ( + add_update_strs[0], (' And %s' % add_update_strs[1]) if self.AnkiSuccess else '')] + if self.Status.IsError: + lines.append("-An error occurred during download (%s)." % str(self.Status)) + + ## Local Calls + if self.LocalDownloadsOccurred: + lines.append( + "-%3d %s note%s unexpectedly found in the local db and did not require an API call." % ( + self.Local, add_update_strs[0], 's were' if self.Local > 1 else ' was')) + lines.append("-%3d %s note(s) required an API call" % (self.Remote, add_update_strs[0])) + if not self.ImportType == EvernoteImportType.Add and self.AlreadyUpToDate > 0: + lines.append( + "-%3d existing note%s already up-to-date with Evernote's servers, so %s not retrieved." 
% ( + self.AlreadyUpToDate, 's are' if self.Local > 1 else ' is', + 'they were' if self.Local > 1 else 'it was')) + + ## Anki Status + if self.DownloadSuccess: + return lines + if self.AnkiSuccess: + line = "All %3d" % self.Imported + else: + line = "%3d of %3d" % (self.Imported, self.Count) + lines.append(line + " %s Downloaded Evernote Notes Have Been Successfully %s." % ( + add_update_strs[0], add_update_strs[1])) + + return lines + + @property + def Summary(self): + lines = self.SummaryLines + if len(lines) is 0: + return '' + return '<BR> - '.join(lines) + + @property + def Count(self): + return len(self.Notes) + + @property + def EvernoteFails(self): + return self.Max - self.Count + + @property + def AnkiFails(self): + return self.Count - self.Imported + + def __init__(self, status=None, local=None): + """ + :param status: + :type status : EvernoteAPIStatus + :param local: + :return: + """ + if not status: + status = EvernoteAPIStatus.Uninitialized + if not local: + local = 0 + self.Status = status + self.Local = local + self.Imported = 0 + self.Notes = [] + """ + :type : list[EvernoteNotePrototype.EvernoteNotePrototype] + """ + + def reportResult(self, result): + """ + :type result : EvernoteNoteFetcherResult + """ + self.Status = result.Status + if self.Status == EvernoteAPIStatus.Success: + self.Notes.append(result.Note) + if result.Source == 1: + self.Local += 1 + + +class EvernoteImportProgress: + class _GUIDs: + Anki = None + """:type : anknotes.Anki.Anki""" + Local = None + _anki_note_ids_ = None + + class Server: + All = None + New = None + + class Existing: + All = None + UpToDate = None + OutOfDate = None + + def __init__(self, anki=None, anki_note_ids=None): + if anki is None: + return + self.Anki = anki + self._anki_note_ids_ = anki_note_ids + self.Server.All, self.Server.New = set(), set() + self.Server.Existing.All, self.Server.Existing.UpToDate, self.Server.Existing.OutOfDate = set(), set(), set() + + def setup(self, anki_note_ids=None): + if 
not anki_note_ids: + anki_note_ids = self._anki_note_ids_ or self.Anki.get_anknotes_note_ids() + self.Local = self.Anki.get_evernote_guids_from_anki_note_ids(anki_note_ids) + + def loadNew(self, server_evernote_guids=None): + if server_evernote_guids: + self.Server.All = server_evernote_guids + if not self.Server.All: + return + if not self.Local: + self.setup() + setServer = set(self.Server.All) + self.Server.New = setServer - set(self.Local) + self.Server.Existing.All = setServer - set(self.Server.New) + + class Results: + Adding = None + """:type : EvernoteNoteFetcherResults""" + Updating = None + """:type : EvernoteNoteFetcherResults""" + + GUIDs = None + + @property + def Adding(self): + return len(self.GUIDs.Server.New) + + @property + def Updating(self): + return len(self.GUIDs.Server.Existing.OutOfDate) + + @property + def AlreadyUpToDate(self): + return len(self.GUIDs.Server.Existing.UpToDate) + + @property + def Success(self): + return self.Status == EvernoteAPIStatus.Success + + @property + def IsError(self): + return self.Status.IsError + + @property + def Status(self): + s1 = self.Results.Adding.Status + s2 = self.Results.Updating.Status if self.Results.Updating else EvernoteAPIStatus.Uninitialized + if s1 == EvernoteAPIStatus.RateLimitError or s2 == EvernoteAPIStatus.RateLimitError: + return EvernoteAPIStatus.RateLimitError + if s1 == EvernoteAPIStatus.SocketError or s2 == EvernoteAPIStatus.SocketError: + return EvernoteAPIStatus.SocketError + if s1.IsError: + return s1 + if s2.IsError: + return s2 + if s1.IsSuccessful and s2.IsSuccessful: + return EvernoteAPIStatus.Success + if s2 == EvernoteAPIStatus.Uninitialized: + return s1 + if s1 == EvernoteAPIStatus.Success: + return s2 + return s1 + + @property + def SummaryList(self): + return [ + "New Notes: %d" % self.Adding, + "Out-Of-Date Notes: %d" % self.Updating, + "Up-To-Date Notes: %d" % self.AlreadyUpToDate + ] + + @property + def Summary(self): return JoinList(self.SummaryList, ' | ', 
ANKNOTES.FORMATTING.PROGRESS_SUMMARY_PAD) + + def loadAlreadyUpdated(self, db_guids): + self.GUIDs.Server.Existing.UpToDate = db_guids + self.GUIDs.Server.Existing.OutOfDate = set(self.GUIDs.Server.Existing.All) - set( + self.GUIDs.Server.Existing.UpToDate) + + def processUpdateInPlaceResults(self, results): + return self.processResults(results, EvernoteImportType.UpdateInPlace) + + def processDeleteAndUpdateResults(self, results): + return self.processResults(results, EvernoteImportType.DeleteAndUpdate) + + @property + def ResultsSummaryShort(self): + line = self.Results.Adding.SummaryShort + if self.Results.Adding.Status.IsError: + line += " to Anki. Skipping update due to an error (%s)" % self.Results.Adding.Status + elif not self.Results.Updating: + line += " to Anki. Updating is disabled" + else: + line += " and " + self.Results.Updating.SummaryShort + return line + + @property + def ResultsSummaryLines(self): + lines = [self.ResultsSummaryShort] + self.Results.Adding.SummaryLines + if self.Results.Updating: + lines += self.Results.Updating.SummaryLines + return lines + + @property + def APICallCount(self): + return self.Results.Adding.Remote + self.Results.Updating.Remote if self.Results.Updating else 0 + + def processResults(self, results, importType=None): + """ + :type results : EvernoteNoteFetcherResults + :type importType : EvernoteImportType + """ + if not importType: + importType = EvernoteImportType.Add + results.ImportType = importType + if importType == EvernoteImportType.Add: + results.Max = self.Adding + results.AlreadyUpToDate = 0 + self.Results.Adding = results + else: + results.Max = self.Updating + results.AlreadyUpToDate = self.AlreadyUpToDate + self.Results.Updating = results + + def __init__(self, anki=None, metadataProgress=None, server_evernote_guids=None, anki_note_ids=None): + """ + :param anki: Anknotes Main Anki Instance + :type anki: anknotes.Anki.Anki + :type metadataProgress: EvernoteMetadataProgress + :return: + """ + if not anki: 
+ return + self.GUIDs = self._GUIDs(anki, anki_note_ids) + if metadataProgress: + server_evernote_guids = metadataProgress.Guids + if server_evernote_guids: + self.GUIDs.loadNew(server_evernote_guids) + self.Results.Adding = EvernoteNoteFetcherResults() + self.Results.Updating = EvernoteNoteFetcherResults() + + +class EvernoteMetadataProgress: + Page = Total = Current = UpdateCount = -1 + Status = EvernoteAPIStatus.Uninitialized + Guids = [] + NotesMetadata = {} + """ + :type: dict[str, anknotes.evernote.edam.notestore.ttypes.NoteMetadata] + """ + + @property + def IsFinished(self): + return self.Remaining <= 0 + + @property + def SummaryList(self): + return [["Total Notes: %d" % self.Total, + "Total Pages: %d" % self.TotalPages, + "Returned Notes: %d" % self.Current, + "Result Range: %d-%d" % (self.Offset, self.Completed) + ], + ["Remaining Notes: %d" % self.Remaining, + "Remaining Pages: %d" % self.RemainingPages, + "Update Count: %d" % self.UpdateCount]] + + @property + def Summary(self): return JoinList(self.SummaryList, ['\n', ' | '], ANKNOTES.FORMATTING.PROGRESS_SUMMARY_PAD) + + @property + def QueryLimit(self): return EVERNOTE.IMPORT.QUERY_LIMIT + + @property + def Offset(self): return (self.Page - 1) * self.QueryLimit + + @property + def TotalPages(self): + if self.Total is -1: + return -1 + p = float(self.Total) / self.QueryLimit + return int(p) + (1 if p > int(p) else 0) + + @property + def RemainingPages(self): return max(0, self.TotalPages - self.Page) + + @property + def Completed(self): return self.Current + self.Offset + + @property + def Remaining(self): return self.Total - self.Completed + + def __init__(self, page=1): + self.Page = int(page) + + def loadResults(self, result): + """ + :param result: Result Returned by Evernote API Call to getNoteMetadata + :type result: anknotes.evernote.edam.notestore.ttypes.NotesMetadataList + :return: + """ + self.Total = int(result.totalNotes) + self.Current = len(result.notes) + self.UpdateCount = 
result.updateCount + self.Status = EvernoteAPIStatus.Success + self.Guids = [] + self.NotesMetadata = {} + for note in result.notes: + # assert isinstance(note, NoteMetadata) + self.Guids.append(note.guid) + self.NotesMetadata[note.guid] = note diff --git a/anknotes/structs_base.py b/anknotes/structs_base.py new file mode 100644 index 0000000..24767e0 --- /dev/null +++ b/anknotes/structs_base.py @@ -0,0 +1,4 @@ + + +class UpdateExistingNotes: + IgnoreExistingNotes, UpdateNotesInPlace, DeleteAndReAddNotes = range(3) \ No newline at end of file diff --git a/anknotes/thrift/TSCons.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/TSCons.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 2404625..0000000 --- a/anknotes/thrift/TSCons.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,33 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -from os import path -from SCons.Builder import Builder - -def scons_env(env, add=''): - opath = path.dirname(path.abspath('$TARGET')) - lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' - cppbuild = Builder(action = lstr) - env.Append(BUILDERS = {'ThriftCpp' : cppbuild}) - -def gen_cpp(env, dir, file): - scons_env(env) - suffixes = ['_types.h', '_types.cpp'] - targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) - return env.ThriftCpp(targets, dir+file+'.thrift') diff --git a/anknotes/thrift/TSCons.py~HEAD b/anknotes/thrift/TSCons.py~HEAD deleted file mode 100644 index 2404625..0000000 --- a/anknotes/thrift/TSCons.py~HEAD +++ /dev/null @@ -1,33 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -from os import path -from SCons.Builder import Builder - -def scons_env(env, add=''): - opath = path.dirname(path.abspath('$TARGET')) - lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' - cppbuild = Builder(action = lstr) - env.Append(BUILDERS = {'ThriftCpp' : cppbuild}) - -def gen_cpp(env, dir, file): - scons_env(env) - suffixes = ['_types.h', '_types.cpp'] - targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) - return env.ThriftCpp(targets, dir+file+'.thrift') diff --git a/anknotes/thrift/TSerialization.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/TSerialization.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index b19f98a..0000000 --- a/anknotes/thrift/TSerialization.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,34 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -from protocol import TBinaryProtocol -from transport import TTransport - -def serialize(thrift_object, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): - transport = TTransport.TMemoryBuffer() - protocol = protocol_factory.getProtocol(transport) - thrift_object.write(protocol) - return transport.getvalue() - -def deserialize(base, buf, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): - transport = TTransport.TMemoryBuffer(buf) - protocol = protocol_factory.getProtocol(transport) - base.read(protocol) - return base - diff --git a/anknotes/thrift/TSerialization.py~HEAD b/anknotes/thrift/TSerialization.py~HEAD deleted file mode 100644 index b19f98a..0000000 --- a/anknotes/thrift/TSerialization.py~HEAD +++ /dev/null @@ -1,34 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -from protocol import TBinaryProtocol -from transport import TTransport - -def serialize(thrift_object, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): - transport = TTransport.TMemoryBuffer() - protocol = protocol_factory.getProtocol(transport) - thrift_object.write(protocol) - return transport.getvalue() - -def deserialize(base, buf, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): - transport = TTransport.TMemoryBuffer(buf) - protocol = protocol_factory.getProtocol(transport) - base.read(protocol) - return base - diff --git a/anknotes/thrift/Thrift.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/Thrift.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 1d271fc..0000000 --- a/anknotes/thrift/Thrift.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,154 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -import sys - -class TType: - STOP = 0 - VOID = 1 - BOOL = 2 - BYTE = 3 - I08 = 3 - DOUBLE = 4 - I16 = 6 - I32 = 8 - I64 = 10 - STRING = 11 - UTF7 = 11 - STRUCT = 12 - MAP = 13 - SET = 14 - LIST = 15 - UTF8 = 16 - UTF16 = 17 - - _VALUES_TO_NAMES = ( 'STOP', - 'VOID', - 'BOOL', - 'BYTE', - 'DOUBLE', - None, - 'I16', - None, - 'I32', - None, - 'I64', - 'STRING', - 'STRUCT', - 'MAP', - 'SET', - 'LIST', - 'UTF8', - 'UTF16' ) - -class TMessageType: - CALL = 1 - REPLY = 2 - EXCEPTION = 3 - ONEWAY = 4 - -class TProcessor: - - """Base class for procsessor, which works on two streams.""" - - def process(iprot, oprot): - pass - -class TException(Exception): - - """Base class for all thrift exceptions.""" - - # BaseException.message is deprecated in Python v[2.6,3.0) - if (2,6,0) <= sys.version_info < (3,0): - def _get_message(self): - return self._message - def _set_message(self, message): - self._message = message - message = property(_get_message, _set_message) - - def __init__(self, message=None): - Exception.__init__(self, message) - self.message = message - -class TApplicationException(TException): - - """Application level thrift exceptions.""" - - UNKNOWN = 0 - UNKNOWN_METHOD = 1 - INVALID_MESSAGE_TYPE = 2 - WRONG_METHOD_NAME = 3 - BAD_SEQUENCE_ID = 4 - MISSING_RESULT = 5 - INTERNAL_ERROR = 6 - PROTOCOL_ERROR = 7 - - def __init__(self, type=UNKNOWN, message=None): - TException.__init__(self, message) - self.type = type - - def __str__(self): - if self.message: - return self.message - elif self.type == self.UNKNOWN_METHOD: - return 'Unknown method' - elif self.type == self.INVALID_MESSAGE_TYPE: - return 'Invalid message type' - elif self.type == self.WRONG_METHOD_NAME: - return 'Wrong method name' - elif self.type == self.BAD_SEQUENCE_ID: - return 'Bad sequence ID' - elif self.type == self.MISSING_RESULT: - return 'Missing result' - else: - return 'Default (unknown) TApplicationException' - - def read(self, iprot): - iprot.readStructBegin() - while True: - (fname, 
ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRING: - self.message = iprot.readString(); - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.I32: - self.type = iprot.readI32(); - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - oprot.writeStructBegin('TApplicationException') - if self.message != None: - oprot.writeFieldBegin('message', TType.STRING, 1) - oprot.writeString(self.message) - oprot.writeFieldEnd() - if self.type != None: - oprot.writeFieldBegin('type', TType.I32, 2) - oprot.writeI32(self.type) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() diff --git a/anknotes/thrift/Thrift.py~HEAD b/anknotes/thrift/Thrift.py~HEAD deleted file mode 100644 index 1d271fc..0000000 --- a/anknotes/thrift/Thrift.py~HEAD +++ /dev/null @@ -1,154 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -import sys - -class TType: - STOP = 0 - VOID = 1 - BOOL = 2 - BYTE = 3 - I08 = 3 - DOUBLE = 4 - I16 = 6 - I32 = 8 - I64 = 10 - STRING = 11 - UTF7 = 11 - STRUCT = 12 - MAP = 13 - SET = 14 - LIST = 15 - UTF8 = 16 - UTF16 = 17 - - _VALUES_TO_NAMES = ( 'STOP', - 'VOID', - 'BOOL', - 'BYTE', - 'DOUBLE', - None, - 'I16', - None, - 'I32', - None, - 'I64', - 'STRING', - 'STRUCT', - 'MAP', - 'SET', - 'LIST', - 'UTF8', - 'UTF16' ) - -class TMessageType: - CALL = 1 - REPLY = 2 - EXCEPTION = 3 - ONEWAY = 4 - -class TProcessor: - - """Base class for procsessor, which works on two streams.""" - - def process(iprot, oprot): - pass - -class TException(Exception): - - """Base class for all thrift exceptions.""" - - # BaseException.message is deprecated in Python v[2.6,3.0) - if (2,6,0) <= sys.version_info < (3,0): - def _get_message(self): - return self._message - def _set_message(self, message): - self._message = message - message = property(_get_message, _set_message) - - def __init__(self, message=None): - Exception.__init__(self, message) - self.message = message - -class TApplicationException(TException): - - """Application level thrift exceptions.""" - - UNKNOWN = 0 - UNKNOWN_METHOD = 1 - INVALID_MESSAGE_TYPE = 2 - WRONG_METHOD_NAME = 3 - BAD_SEQUENCE_ID = 4 - MISSING_RESULT = 5 - INTERNAL_ERROR = 6 - PROTOCOL_ERROR = 7 - - def __init__(self, type=UNKNOWN, message=None): - TException.__init__(self, message) - self.type = type - - def __str__(self): - if self.message: - return self.message - elif self.type == self.UNKNOWN_METHOD: - return 'Unknown method' - elif self.type == self.INVALID_MESSAGE_TYPE: - return 'Invalid message type' - elif self.type == self.WRONG_METHOD_NAME: - return 'Wrong method name' - elif self.type == self.BAD_SEQUENCE_ID: - return 'Bad sequence ID' - elif self.type == self.MISSING_RESULT: - return 'Missing result' - else: - return 'Default (unknown) TApplicationException' - - def read(self, iprot): - iprot.readStructBegin() - while True: - (fname, 
ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRING: - self.message = iprot.readString(); - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.I32: - self.type = iprot.readI32(); - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - oprot.writeStructBegin('TApplicationException') - if self.message != None: - oprot.writeFieldBegin('message', TType.STRING, 1) - oprot.writeString(self.message) - oprot.writeFieldEnd() - if self.type != None: - oprot.writeFieldBegin('type', TType.I32, 2) - oprot.writeI32(self.type) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() diff --git a/anknotes/thrift/__init__.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/__init__.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 48d659c..0000000 --- a/anknotes/thrift/__init__.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,20 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -__all__ = ['Thrift', 'TSCons'] diff --git a/anknotes/thrift/__init__.py~HEAD b/anknotes/thrift/__init__.py~HEAD deleted file mode 100644 index 48d659c..0000000 --- a/anknotes/thrift/__init__.py~HEAD +++ /dev/null @@ -1,20 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -__all__ = ['Thrift', 'TSCons'] diff --git a/anknotes/thrift/protocol/TBase.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/protocol/TBase.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index e675c7d..0000000 --- a/anknotes/thrift/protocol/TBase.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,72 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from thrift.Thrift import * -from thrift.protocol import TBinaryProtocol -from thrift.transport import TTransport - -try: - from thrift.protocol import fastbinary -except: - fastbinary = None - -class TBase(object): - __slots__ = [] - - def __repr__(self): - L = ['%s=%r' % (key, getattr(self, key)) - for key in self.__slots__ ] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return False - for attr in self.__slots__: - my_val = getattr(self, attr) - other_val = getattr(other, attr) - if my_val != other_val: - return False - return True - - def __ne__(self, other): - return not (self == other) - - def read(self, iprot): - if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: - fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) - return - iprot.readStruct(self, self.thrift_spec) - - def write(self, oprot): - if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: - oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStruct(self, self.thrift_spec) - -class TExceptionBase(Exception): - # old style class so python2.4 can raise exceptions derived from this - # This can't inherit from TBase because of that limitation. 
- __slots__ = [] - - __repr__ = TBase.__repr__.im_func - __eq__ = TBase.__eq__.im_func - __ne__ = TBase.__ne__.im_func - read = TBase.read.im_func - write = TBase.write.im_func - diff --git a/anknotes/thrift/protocol/TBase.py~HEAD b/anknotes/thrift/protocol/TBase.py~HEAD deleted file mode 100644 index e675c7d..0000000 --- a/anknotes/thrift/protocol/TBase.py~HEAD +++ /dev/null @@ -1,72 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -from thrift.Thrift import * -from thrift.protocol import TBinaryProtocol -from thrift.transport import TTransport - -try: - from thrift.protocol import fastbinary -except: - fastbinary = None - -class TBase(object): - __slots__ = [] - - def __repr__(self): - L = ['%s=%r' % (key, getattr(self, key)) - for key in self.__slots__ ] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return False - for attr in self.__slots__: - my_val = getattr(self, attr) - other_val = getattr(other, attr) - if my_val != other_val: - return False - return True - - def __ne__(self, other): - return not (self == other) - - def read(self, iprot): - if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: - fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) - return - iprot.readStruct(self, self.thrift_spec) - - def write(self, oprot): - if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: - oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStruct(self, self.thrift_spec) - -class TExceptionBase(Exception): - # old style class so python2.4 can raise exceptions derived from this - # This can't inherit from TBase because of that limitation. 
- __slots__ = [] - - __repr__ = TBase.__repr__.im_func - __eq__ = TBase.__eq__.im_func - __ne__ = TBase.__ne__.im_func - read = TBase.read.im_func - write = TBase.write.im_func - diff --git a/anknotes/thrift/protocol/TBinaryProtocol.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/protocol/TBinaryProtocol.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 50c6aa8..0000000 --- a/anknotes/thrift/protocol/TBinaryProtocol.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,259 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from TProtocol import * -from struct import pack, unpack - -class TBinaryProtocol(TProtocolBase): - - """Binary implementation of the Thrift protocol driver.""" - - # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be - # positive, converting this into a long. If we hardcode the int value - # instead it'll stay in 32 bit-land. 
- - # VERSION_MASK = 0xffff0000 - VERSION_MASK = -65536 - - # VERSION_1 = 0x80010000 - VERSION_1 = -2147418112 - - TYPE_MASK = 0x000000ff - - def __init__(self, trans, strictRead=False, strictWrite=True): - TProtocolBase.__init__(self, trans) - self.strictRead = strictRead - self.strictWrite = strictWrite - - def writeMessageBegin(self, name, type, seqid): - if self.strictWrite: - self.writeI32(TBinaryProtocol.VERSION_1 | type) - self.writeString(name) - self.writeI32(seqid) - else: - self.writeString(name) - self.writeByte(type) - self.writeI32(seqid) - - def writeMessageEnd(self): - pass - - def writeStructBegin(self, name): - pass - - def writeStructEnd(self): - pass - - def writeFieldBegin(self, name, type, id): - self.writeByte(type) - self.writeI16(id) - - def writeFieldEnd(self): - pass - - def writeFieldStop(self): - self.writeByte(TType.STOP); - - def writeMapBegin(self, ktype, vtype, size): - self.writeByte(ktype) - self.writeByte(vtype) - self.writeI32(size) - - def writeMapEnd(self): - pass - - def writeListBegin(self, etype, size): - self.writeByte(etype) - self.writeI32(size) - - def writeListEnd(self): - pass - - def writeSetBegin(self, etype, size): - self.writeByte(etype) - self.writeI32(size) - - def writeSetEnd(self): - pass - - def writeBool(self, bool): - if bool: - self.writeByte(1) - else: - self.writeByte(0) - - def writeByte(self, byte): - buff = pack("!b", byte) - self.trans.write(buff) - - def writeI16(self, i16): - buff = pack("!h", i16) - self.trans.write(buff) - - def writeI32(self, i32): - buff = pack("!i", i32) - self.trans.write(buff) - - def writeI64(self, i64): - buff = pack("!q", i64) - self.trans.write(buff) - - def writeDouble(self, dub): - buff = pack("!d", dub) - self.trans.write(buff) - - def writeString(self, str): - self.writeI32(len(str)) - self.trans.write(str) - - def readMessageBegin(self): - sz = self.readI32() - if sz < 0: - version = sz & TBinaryProtocol.VERSION_MASK - if version != TBinaryProtocol.VERSION_1: - 
raise TProtocolException(type=TProtocolException.BAD_VERSION, message='Bad version in readMessageBegin: %d' % (sz)) - type = sz & TBinaryProtocol.TYPE_MASK - name = self.readString() - seqid = self.readI32() - else: - if self.strictRead: - raise TProtocolException(type=TProtocolException.BAD_VERSION, message='No protocol version header') - name = self.trans.readAll(sz) - type = self.readByte() - seqid = self.readI32() - return (name, type, seqid) - - def readMessageEnd(self): - pass - - def readStructBegin(self): - pass - - def readStructEnd(self): - pass - - def readFieldBegin(self): - type = self.readByte() - if type == TType.STOP: - return (None, type, 0) - id = self.readI16() - return (None, type, id) - - def readFieldEnd(self): - pass - - def readMapBegin(self): - ktype = self.readByte() - vtype = self.readByte() - size = self.readI32() - return (ktype, vtype, size) - - def readMapEnd(self): - pass - - def readListBegin(self): - etype = self.readByte() - size = self.readI32() - return (etype, size) - - def readListEnd(self): - pass - - def readSetBegin(self): - etype = self.readByte() - size = self.readI32() - return (etype, size) - - def readSetEnd(self): - pass - - def readBool(self): - byte = self.readByte() - if byte == 0: - return False - return True - - def readByte(self): - buff = self.trans.readAll(1) - val, = unpack('!b', buff) - return val - - def readI16(self): - buff = self.trans.readAll(2) - val, = unpack('!h', buff) - return val - - def readI32(self): - buff = self.trans.readAll(4) - val, = unpack('!i', buff) - return val - - def readI64(self): - buff = self.trans.readAll(8) - val, = unpack('!q', buff) - return val - - def readDouble(self): - buff = self.trans.readAll(8) - val, = unpack('!d', buff) - return val - - def readString(self): - len = self.readI32() - str = self.trans.readAll(len) - return str - - -class TBinaryProtocolFactory: - def __init__(self, strictRead=False, strictWrite=True): - self.strictRead = strictRead - self.strictWrite = 
strictWrite - - def getProtocol(self, trans): - prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite) - return prot - - -class TBinaryProtocolAccelerated(TBinaryProtocol): - - """C-Accelerated version of TBinaryProtocol. - - This class does not override any of TBinaryProtocol's methods, - but the generated code recognizes it directly and will call into - our C module to do the encoding, bypassing this object entirely. - We inherit from TBinaryProtocol so that the normal TBinaryProtocol - encoding can happen if the fastbinary module doesn't work for some - reason. (TODO(dreiss): Make this happen sanely in more cases.) - - In order to take advantage of the C module, just use - TBinaryProtocolAccelerated instead of TBinaryProtocol. - - NOTE: This code was contributed by an external developer. - The internal Thrift team has reviewed and tested it, - but we cannot guarantee that it is production-ready. - Please feel free to report bugs and/or success stories - to the public mailing list. - """ - - pass - - -class TBinaryProtocolAcceleratedFactory: - def getProtocol(self, trans): - return TBinaryProtocolAccelerated(trans) diff --git a/anknotes/thrift/protocol/TBinaryProtocol.py~HEAD b/anknotes/thrift/protocol/TBinaryProtocol.py~HEAD deleted file mode 100644 index 50c6aa8..0000000 --- a/anknotes/thrift/protocol/TBinaryProtocol.py~HEAD +++ /dev/null @@ -1,259 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from TProtocol import * -from struct import pack, unpack - -class TBinaryProtocol(TProtocolBase): - - """Binary implementation of the Thrift protocol driver.""" - - # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be - # positive, converting this into a long. If we hardcode the int value - # instead it'll stay in 32 bit-land. - - # VERSION_MASK = 0xffff0000 - VERSION_MASK = -65536 - - # VERSION_1 = 0x80010000 - VERSION_1 = -2147418112 - - TYPE_MASK = 0x000000ff - - def __init__(self, trans, strictRead=False, strictWrite=True): - TProtocolBase.__init__(self, trans) - self.strictRead = strictRead - self.strictWrite = strictWrite - - def writeMessageBegin(self, name, type, seqid): - if self.strictWrite: - self.writeI32(TBinaryProtocol.VERSION_1 | type) - self.writeString(name) - self.writeI32(seqid) - else: - self.writeString(name) - self.writeByte(type) - self.writeI32(seqid) - - def writeMessageEnd(self): - pass - - def writeStructBegin(self, name): - pass - - def writeStructEnd(self): - pass - - def writeFieldBegin(self, name, type, id): - self.writeByte(type) - self.writeI16(id) - - def writeFieldEnd(self): - pass - - def writeFieldStop(self): - self.writeByte(TType.STOP); - - def writeMapBegin(self, ktype, vtype, size): - self.writeByte(ktype) - self.writeByte(vtype) - self.writeI32(size) - - def writeMapEnd(self): - pass - - def writeListBegin(self, etype, size): - self.writeByte(etype) - self.writeI32(size) - - def writeListEnd(self): - pass - - def writeSetBegin(self, etype, size): - self.writeByte(etype) - self.writeI32(size) 
- - def writeSetEnd(self): - pass - - def writeBool(self, bool): - if bool: - self.writeByte(1) - else: - self.writeByte(0) - - def writeByte(self, byte): - buff = pack("!b", byte) - self.trans.write(buff) - - def writeI16(self, i16): - buff = pack("!h", i16) - self.trans.write(buff) - - def writeI32(self, i32): - buff = pack("!i", i32) - self.trans.write(buff) - - def writeI64(self, i64): - buff = pack("!q", i64) - self.trans.write(buff) - - def writeDouble(self, dub): - buff = pack("!d", dub) - self.trans.write(buff) - - def writeString(self, str): - self.writeI32(len(str)) - self.trans.write(str) - - def readMessageBegin(self): - sz = self.readI32() - if sz < 0: - version = sz & TBinaryProtocol.VERSION_MASK - if version != TBinaryProtocol.VERSION_1: - raise TProtocolException(type=TProtocolException.BAD_VERSION, message='Bad version in readMessageBegin: %d' % (sz)) - type = sz & TBinaryProtocol.TYPE_MASK - name = self.readString() - seqid = self.readI32() - else: - if self.strictRead: - raise TProtocolException(type=TProtocolException.BAD_VERSION, message='No protocol version header') - name = self.trans.readAll(sz) - type = self.readByte() - seqid = self.readI32() - return (name, type, seqid) - - def readMessageEnd(self): - pass - - def readStructBegin(self): - pass - - def readStructEnd(self): - pass - - def readFieldBegin(self): - type = self.readByte() - if type == TType.STOP: - return (None, type, 0) - id = self.readI16() - return (None, type, id) - - def readFieldEnd(self): - pass - - def readMapBegin(self): - ktype = self.readByte() - vtype = self.readByte() - size = self.readI32() - return (ktype, vtype, size) - - def readMapEnd(self): - pass - - def readListBegin(self): - etype = self.readByte() - size = self.readI32() - return (etype, size) - - def readListEnd(self): - pass - - def readSetBegin(self): - etype = self.readByte() - size = self.readI32() - return (etype, size) - - def readSetEnd(self): - pass - - def readBool(self): - byte = 
self.readByte() - if byte == 0: - return False - return True - - def readByte(self): - buff = self.trans.readAll(1) - val, = unpack('!b', buff) - return val - - def readI16(self): - buff = self.trans.readAll(2) - val, = unpack('!h', buff) - return val - - def readI32(self): - buff = self.trans.readAll(4) - val, = unpack('!i', buff) - return val - - def readI64(self): - buff = self.trans.readAll(8) - val, = unpack('!q', buff) - return val - - def readDouble(self): - buff = self.trans.readAll(8) - val, = unpack('!d', buff) - return val - - def readString(self): - len = self.readI32() - str = self.trans.readAll(len) - return str - - -class TBinaryProtocolFactory: - def __init__(self, strictRead=False, strictWrite=True): - self.strictRead = strictRead - self.strictWrite = strictWrite - - def getProtocol(self, trans): - prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite) - return prot - - -class TBinaryProtocolAccelerated(TBinaryProtocol): - - """C-Accelerated version of TBinaryProtocol. - - This class does not override any of TBinaryProtocol's methods, - but the generated code recognizes it directly and will call into - our C module to do the encoding, bypassing this object entirely. - We inherit from TBinaryProtocol so that the normal TBinaryProtocol - encoding can happen if the fastbinary module doesn't work for some - reason. (TODO(dreiss): Make this happen sanely in more cases.) - - In order to take advantage of the C module, just use - TBinaryProtocolAccelerated instead of TBinaryProtocol. - - NOTE: This code was contributed by an external developer. - The internal Thrift team has reviewed and tested it, - but we cannot guarantee that it is production-ready. - Please feel free to report bugs and/or success stories - to the public mailing list. 
- """ - - pass - - -class TBinaryProtocolAcceleratedFactory: - def getProtocol(self, trans): - return TBinaryProtocolAccelerated(trans) diff --git a/anknotes/thrift/protocol/TCompactProtocol.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/protocol/TCompactProtocol.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 016a331..0000000 --- a/anknotes/thrift/protocol/TCompactProtocol.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,395 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -from TProtocol import * -from struct import pack, unpack - -__all__ = ['TCompactProtocol', 'TCompactProtocolFactory'] - -CLEAR = 0 -FIELD_WRITE = 1 -VALUE_WRITE = 2 -CONTAINER_WRITE = 3 -BOOL_WRITE = 4 -FIELD_READ = 5 -CONTAINER_READ = 6 -VALUE_READ = 7 -BOOL_READ = 8 - -def make_helper(v_from, container): - def helper(func): - def nested(self, *args, **kwargs): - assert self.state in (v_from, container), (self.state, v_from, container) - return func(self, *args, **kwargs) - return nested - return helper -writer = make_helper(VALUE_WRITE, CONTAINER_WRITE) -reader = make_helper(VALUE_READ, CONTAINER_READ) - -def makeZigZag(n, bits): - return (n << 1) ^ (n >> (bits - 1)) - -def fromZigZag(n): - return (n >> 1) ^ -(n & 1) - -def writeVarint(trans, n): - out = [] - while True: - if n & ~0x7f == 0: - out.append(n) - break - else: - out.append((n & 0xff) | 0x80) - n = n >> 7 - trans.write(''.join(map(chr, out))) - -def readVarint(trans): - result = 0 - shift = 0 - while True: - x = trans.readAll(1) - byte = ord(x) - result |= (byte & 0x7f) << shift - if byte >> 7 == 0: - return result - shift += 7 - -class CompactType: - STOP = 0x00 - TRUE = 0x01 - FALSE = 0x02 - BYTE = 0x03 - I16 = 0x04 - I32 = 0x05 - I64 = 0x06 - DOUBLE = 0x07 - BINARY = 0x08 - LIST = 0x09 - SET = 0x0A - MAP = 0x0B - STRUCT = 0x0C - -CTYPES = {TType.STOP: CompactType.STOP, - TType.BOOL: CompactType.TRUE, # used for collection - TType.BYTE: CompactType.BYTE, - TType.I16: CompactType.I16, - TType.I32: CompactType.I32, - TType.I64: CompactType.I64, - TType.DOUBLE: CompactType.DOUBLE, - TType.STRING: CompactType.BINARY, - TType.STRUCT: CompactType.STRUCT, - TType.LIST: CompactType.LIST, - TType.SET: CompactType.SET, - TType.MAP: CompactType.MAP - } - -TTYPES = {} -for k, v in CTYPES.items(): - TTYPES[v] = k -TTYPES[CompactType.FALSE] = TType.BOOL -del k -del v - -class TCompactProtocol(TProtocolBase): - "Compact implementation of the Thrift protocol driver." 
- - PROTOCOL_ID = 0x82 - VERSION = 1 - VERSION_MASK = 0x1f - TYPE_MASK = 0xe0 - TYPE_SHIFT_AMOUNT = 5 - - def __init__(self, trans): - TProtocolBase.__init__(self, trans) - self.state = CLEAR - self.__last_fid = 0 - self.__bool_fid = None - self.__bool_value = None - self.__structs = [] - self.__containers = [] - - def __writeVarint(self, n): - writeVarint(self.trans, n) - - def writeMessageBegin(self, name, type, seqid): - assert self.state == CLEAR - self.__writeUByte(self.PROTOCOL_ID) - self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT)) - self.__writeVarint(seqid) - self.__writeString(name) - self.state = VALUE_WRITE - - def writeMessageEnd(self): - assert self.state == VALUE_WRITE - self.state = CLEAR - - def writeStructBegin(self, name): - assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state - self.__structs.append((self.state, self.__last_fid)) - self.state = FIELD_WRITE - self.__last_fid = 0 - - def writeStructEnd(self): - assert self.state == FIELD_WRITE - self.state, self.__last_fid = self.__structs.pop() - - def writeFieldStop(self): - self.__writeByte(0) - - def __writeFieldHeader(self, type, fid): - delta = fid - self.__last_fid - if 0 < delta <= 15: - self.__writeUByte(delta << 4 | type) - else: - self.__writeByte(type) - self.__writeI16(fid) - self.__last_fid = fid - - def writeFieldBegin(self, name, type, fid): - assert self.state == FIELD_WRITE, self.state - if type == TType.BOOL: - self.state = BOOL_WRITE - self.__bool_fid = fid - else: - self.state = VALUE_WRITE - self.__writeFieldHeader(CTYPES[type], fid) - - def writeFieldEnd(self): - assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state - self.state = FIELD_WRITE - - def __writeUByte(self, byte): - self.trans.write(pack('!B', byte)) - - def __writeByte(self, byte): - self.trans.write(pack('!b', byte)) - - def __writeI16(self, i16): - self.__writeVarint(makeZigZag(i16, 16)) - - def __writeSize(self, i32): - self.__writeVarint(i32) - - def 
writeCollectionBegin(self, etype, size): - assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state - if size <= 14: - self.__writeUByte(size << 4 | CTYPES[etype]) - else: - self.__writeUByte(0xf0 | CTYPES[etype]) - self.__writeSize(size) - self.__containers.append(self.state) - self.state = CONTAINER_WRITE - writeSetBegin = writeCollectionBegin - writeListBegin = writeCollectionBegin - - def writeMapBegin(self, ktype, vtype, size): - assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state - if size == 0: - self.__writeByte(0) - else: - self.__writeSize(size) - self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype]) - self.__containers.append(self.state) - self.state = CONTAINER_WRITE - - def writeCollectionEnd(self): - assert self.state == CONTAINER_WRITE, self.state - self.state = self.__containers.pop() - writeMapEnd = writeCollectionEnd - writeSetEnd = writeCollectionEnd - writeListEnd = writeCollectionEnd - - def writeBool(self, bool): - if self.state == BOOL_WRITE: - if bool: - ctype = CompactType.TRUE - else: - ctype = CompactType.FALSE - self.__writeFieldHeader(ctype, self.__bool_fid) - elif self.state == CONTAINER_WRITE: - if bool: - self.__writeByte(CompactType.TRUE) - else: - self.__writeByte(CompactType.FALSE) - else: - raise AssertionError, "Invalid state in compact protocol" - - writeByte = writer(__writeByte) - writeI16 = writer(__writeI16) - - @writer - def writeI32(self, i32): - self.__writeVarint(makeZigZag(i32, 32)) - - @writer - def writeI64(self, i64): - self.__writeVarint(makeZigZag(i64, 64)) - - @writer - def writeDouble(self, dub): - self.trans.write(pack('!d', dub)) - - def __writeString(self, s): - self.__writeSize(len(s)) - self.trans.write(s) - writeString = writer(__writeString) - - def readFieldBegin(self): - assert self.state == FIELD_READ, self.state - type = self.__readUByte() - if type & 0x0f == TType.STOP: - return (None, 0, 0) - delta = type >> 4 - if delta == 0: - fid = self.__readI16() - else: - fid = 
self.__last_fid + delta - self.__last_fid = fid - type = type & 0x0f - if type == CompactType.TRUE: - self.state = BOOL_READ - self.__bool_value = True - elif type == CompactType.FALSE: - self.state = BOOL_READ - self.__bool_value = False - else: - self.state = VALUE_READ - return (None, self.__getTType(type), fid) - - def readFieldEnd(self): - assert self.state in (VALUE_READ, BOOL_READ), self.state - self.state = FIELD_READ - - def __readUByte(self): - result, = unpack('!B', self.trans.readAll(1)) - return result - - def __readByte(self): - result, = unpack('!b', self.trans.readAll(1)) - return result - - def __readVarint(self): - return readVarint(self.trans) - - def __readZigZag(self): - return fromZigZag(self.__readVarint()) - - def __readSize(self): - result = self.__readVarint() - if result < 0: - raise TException("Length < 0") - return result - - def readMessageBegin(self): - assert self.state == CLEAR - proto_id = self.__readUByte() - if proto_id != self.PROTOCOL_ID: - raise TProtocolException(TProtocolException.BAD_VERSION, - 'Bad protocol id in the message: %d' % proto_id) - ver_type = self.__readUByte() - type = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT - version = ver_type & self.VERSION_MASK - if version != self.VERSION: - raise TProtocolException(TProtocolException.BAD_VERSION, - 'Bad version: %d (expect %d)' % (version, self.VERSION)) - seqid = self.__readVarint() - name = self.__readString() - return (name, type, seqid) - - def readMessageEnd(self): - assert self.state == CLEAR - assert len(self.__structs) == 0 - - def readStructBegin(self): - assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state - self.__structs.append((self.state, self.__last_fid)) - self.state = FIELD_READ - self.__last_fid = 0 - - def readStructEnd(self): - assert self.state == FIELD_READ - self.state, self.__last_fid = self.__structs.pop() - - def readCollectionBegin(self): - assert self.state in (VALUE_READ, CONTAINER_READ), self.state - size_type = 
self.__readUByte() - size = size_type >> 4 - type = self.__getTType(size_type) - if size == 15: - size = self.__readSize() - self.__containers.append(self.state) - self.state = CONTAINER_READ - return type, size - readSetBegin = readCollectionBegin - readListBegin = readCollectionBegin - - def readMapBegin(self): - assert self.state in (VALUE_READ, CONTAINER_READ), self.state - size = self.__readSize() - types = 0 - if size > 0: - types = self.__readUByte() - vtype = self.__getTType(types) - ktype = self.__getTType(types >> 4) - self.__containers.append(self.state) - self.state = CONTAINER_READ - return (ktype, vtype, size) - - def readCollectionEnd(self): - assert self.state == CONTAINER_READ, self.state - self.state = self.__containers.pop() - readSetEnd = readCollectionEnd - readListEnd = readCollectionEnd - readMapEnd = readCollectionEnd - - def readBool(self): - if self.state == BOOL_READ: - return self.__bool_value == CompactType.TRUE - elif self.state == CONTAINER_READ: - return self.__readByte() == CompactType.TRUE - else: - raise AssertionError, "Invalid state in compact protocol: %d" % self.state - - readByte = reader(__readByte) - __readI16 = __readZigZag - readI16 = reader(__readZigZag) - readI32 = reader(__readZigZag) - readI64 = reader(__readZigZag) - - @reader - def readDouble(self): - buff = self.trans.readAll(8) - val, = unpack('!d', buff) - return val - - def __readString(self): - len = self.__readSize() - return self.trans.readAll(len) - readString = reader(__readString) - - def __getTType(self, byte): - return TTYPES[byte & 0x0f] - - -class TCompactProtocolFactory: - def __init__(self): - pass - - def getProtocol(self, trans): - return TCompactProtocol(trans) diff --git a/anknotes/thrift/protocol/TCompactProtocol.py~HEAD b/anknotes/thrift/protocol/TCompactProtocol.py~HEAD deleted file mode 100644 index 016a331..0000000 --- a/anknotes/thrift/protocol/TCompactProtocol.py~HEAD +++ /dev/null @@ -1,395 +0,0 @@ -# -# Licensed to the Apache Software 
Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from TProtocol import * -from struct import pack, unpack - -__all__ = ['TCompactProtocol', 'TCompactProtocolFactory'] - -CLEAR = 0 -FIELD_WRITE = 1 -VALUE_WRITE = 2 -CONTAINER_WRITE = 3 -BOOL_WRITE = 4 -FIELD_READ = 5 -CONTAINER_READ = 6 -VALUE_READ = 7 -BOOL_READ = 8 - -def make_helper(v_from, container): - def helper(func): - def nested(self, *args, **kwargs): - assert self.state in (v_from, container), (self.state, v_from, container) - return func(self, *args, **kwargs) - return nested - return helper -writer = make_helper(VALUE_WRITE, CONTAINER_WRITE) -reader = make_helper(VALUE_READ, CONTAINER_READ) - -def makeZigZag(n, bits): - return (n << 1) ^ (n >> (bits - 1)) - -def fromZigZag(n): - return (n >> 1) ^ -(n & 1) - -def writeVarint(trans, n): - out = [] - while True: - if n & ~0x7f == 0: - out.append(n) - break - else: - out.append((n & 0xff) | 0x80) - n = n >> 7 - trans.write(''.join(map(chr, out))) - -def readVarint(trans): - result = 0 - shift = 0 - while True: - x = trans.readAll(1) - byte = ord(x) - result |= (byte & 0x7f) << shift - if byte >> 7 == 0: - return result - shift += 7 - -class CompactType: - STOP = 0x00 - TRUE = 0x01 - FALSE = 0x02 - BYTE = 0x03 - I16 = 0x04 - I32 = 0x05 - I64 = 0x06 - 
DOUBLE = 0x07 - BINARY = 0x08 - LIST = 0x09 - SET = 0x0A - MAP = 0x0B - STRUCT = 0x0C - -CTYPES = {TType.STOP: CompactType.STOP, - TType.BOOL: CompactType.TRUE, # used for collection - TType.BYTE: CompactType.BYTE, - TType.I16: CompactType.I16, - TType.I32: CompactType.I32, - TType.I64: CompactType.I64, - TType.DOUBLE: CompactType.DOUBLE, - TType.STRING: CompactType.BINARY, - TType.STRUCT: CompactType.STRUCT, - TType.LIST: CompactType.LIST, - TType.SET: CompactType.SET, - TType.MAP: CompactType.MAP - } - -TTYPES = {} -for k, v in CTYPES.items(): - TTYPES[v] = k -TTYPES[CompactType.FALSE] = TType.BOOL -del k -del v - -class TCompactProtocol(TProtocolBase): - "Compact implementation of the Thrift protocol driver." - - PROTOCOL_ID = 0x82 - VERSION = 1 - VERSION_MASK = 0x1f - TYPE_MASK = 0xe0 - TYPE_SHIFT_AMOUNT = 5 - - def __init__(self, trans): - TProtocolBase.__init__(self, trans) - self.state = CLEAR - self.__last_fid = 0 - self.__bool_fid = None - self.__bool_value = None - self.__structs = [] - self.__containers = [] - - def __writeVarint(self, n): - writeVarint(self.trans, n) - - def writeMessageBegin(self, name, type, seqid): - assert self.state == CLEAR - self.__writeUByte(self.PROTOCOL_ID) - self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT)) - self.__writeVarint(seqid) - self.__writeString(name) - self.state = VALUE_WRITE - - def writeMessageEnd(self): - assert self.state == VALUE_WRITE - self.state = CLEAR - - def writeStructBegin(self, name): - assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state - self.__structs.append((self.state, self.__last_fid)) - self.state = FIELD_WRITE - self.__last_fid = 0 - - def writeStructEnd(self): - assert self.state == FIELD_WRITE - self.state, self.__last_fid = self.__structs.pop() - - def writeFieldStop(self): - self.__writeByte(0) - - def __writeFieldHeader(self, type, fid): - delta = fid - self.__last_fid - if 0 < delta <= 15: - self.__writeUByte(delta << 4 | type) - else: - 
self.__writeByte(type) - self.__writeI16(fid) - self.__last_fid = fid - - def writeFieldBegin(self, name, type, fid): - assert self.state == FIELD_WRITE, self.state - if type == TType.BOOL: - self.state = BOOL_WRITE - self.__bool_fid = fid - else: - self.state = VALUE_WRITE - self.__writeFieldHeader(CTYPES[type], fid) - - def writeFieldEnd(self): - assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state - self.state = FIELD_WRITE - - def __writeUByte(self, byte): - self.trans.write(pack('!B', byte)) - - def __writeByte(self, byte): - self.trans.write(pack('!b', byte)) - - def __writeI16(self, i16): - self.__writeVarint(makeZigZag(i16, 16)) - - def __writeSize(self, i32): - self.__writeVarint(i32) - - def writeCollectionBegin(self, etype, size): - assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state - if size <= 14: - self.__writeUByte(size << 4 | CTYPES[etype]) - else: - self.__writeUByte(0xf0 | CTYPES[etype]) - self.__writeSize(size) - self.__containers.append(self.state) - self.state = CONTAINER_WRITE - writeSetBegin = writeCollectionBegin - writeListBegin = writeCollectionBegin - - def writeMapBegin(self, ktype, vtype, size): - assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state - if size == 0: - self.__writeByte(0) - else: - self.__writeSize(size) - self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype]) - self.__containers.append(self.state) - self.state = CONTAINER_WRITE - - def writeCollectionEnd(self): - assert self.state == CONTAINER_WRITE, self.state - self.state = self.__containers.pop() - writeMapEnd = writeCollectionEnd - writeSetEnd = writeCollectionEnd - writeListEnd = writeCollectionEnd - - def writeBool(self, bool): - if self.state == BOOL_WRITE: - if bool: - ctype = CompactType.TRUE - else: - ctype = CompactType.FALSE - self.__writeFieldHeader(ctype, self.__bool_fid) - elif self.state == CONTAINER_WRITE: - if bool: - self.__writeByte(CompactType.TRUE) - else: - self.__writeByte(CompactType.FALSE) - else: - raise 
AssertionError, "Invalid state in compact protocol" - - writeByte = writer(__writeByte) - writeI16 = writer(__writeI16) - - @writer - def writeI32(self, i32): - self.__writeVarint(makeZigZag(i32, 32)) - - @writer - def writeI64(self, i64): - self.__writeVarint(makeZigZag(i64, 64)) - - @writer - def writeDouble(self, dub): - self.trans.write(pack('!d', dub)) - - def __writeString(self, s): - self.__writeSize(len(s)) - self.trans.write(s) - writeString = writer(__writeString) - - def readFieldBegin(self): - assert self.state == FIELD_READ, self.state - type = self.__readUByte() - if type & 0x0f == TType.STOP: - return (None, 0, 0) - delta = type >> 4 - if delta == 0: - fid = self.__readI16() - else: - fid = self.__last_fid + delta - self.__last_fid = fid - type = type & 0x0f - if type == CompactType.TRUE: - self.state = BOOL_READ - self.__bool_value = True - elif type == CompactType.FALSE: - self.state = BOOL_READ - self.__bool_value = False - else: - self.state = VALUE_READ - return (None, self.__getTType(type), fid) - - def readFieldEnd(self): - assert self.state in (VALUE_READ, BOOL_READ), self.state - self.state = FIELD_READ - - def __readUByte(self): - result, = unpack('!B', self.trans.readAll(1)) - return result - - def __readByte(self): - result, = unpack('!b', self.trans.readAll(1)) - return result - - def __readVarint(self): - return readVarint(self.trans) - - def __readZigZag(self): - return fromZigZag(self.__readVarint()) - - def __readSize(self): - result = self.__readVarint() - if result < 0: - raise TException("Length < 0") - return result - - def readMessageBegin(self): - assert self.state == CLEAR - proto_id = self.__readUByte() - if proto_id != self.PROTOCOL_ID: - raise TProtocolException(TProtocolException.BAD_VERSION, - 'Bad protocol id in the message: %d' % proto_id) - ver_type = self.__readUByte() - type = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT - version = ver_type & self.VERSION_MASK - if version != self.VERSION: - raise 
TProtocolException(TProtocolException.BAD_VERSION, - 'Bad version: %d (expect %d)' % (version, self.VERSION)) - seqid = self.__readVarint() - name = self.__readString() - return (name, type, seqid) - - def readMessageEnd(self): - assert self.state == CLEAR - assert len(self.__structs) == 0 - - def readStructBegin(self): - assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state - self.__structs.append((self.state, self.__last_fid)) - self.state = FIELD_READ - self.__last_fid = 0 - - def readStructEnd(self): - assert self.state == FIELD_READ - self.state, self.__last_fid = self.__structs.pop() - - def readCollectionBegin(self): - assert self.state in (VALUE_READ, CONTAINER_READ), self.state - size_type = self.__readUByte() - size = size_type >> 4 - type = self.__getTType(size_type) - if size == 15: - size = self.__readSize() - self.__containers.append(self.state) - self.state = CONTAINER_READ - return type, size - readSetBegin = readCollectionBegin - readListBegin = readCollectionBegin - - def readMapBegin(self): - assert self.state in (VALUE_READ, CONTAINER_READ), self.state - size = self.__readSize() - types = 0 - if size > 0: - types = self.__readUByte() - vtype = self.__getTType(types) - ktype = self.__getTType(types >> 4) - self.__containers.append(self.state) - self.state = CONTAINER_READ - return (ktype, vtype, size) - - def readCollectionEnd(self): - assert self.state == CONTAINER_READ, self.state - self.state = self.__containers.pop() - readSetEnd = readCollectionEnd - readListEnd = readCollectionEnd - readMapEnd = readCollectionEnd - - def readBool(self): - if self.state == BOOL_READ: - return self.__bool_value == CompactType.TRUE - elif self.state == CONTAINER_READ: - return self.__readByte() == CompactType.TRUE - else: - raise AssertionError, "Invalid state in compact protocol: %d" % self.state - - readByte = reader(__readByte) - __readI16 = __readZigZag - readI16 = reader(__readZigZag) - readI32 = reader(__readZigZag) - readI64 = 
reader(__readZigZag) - - @reader - def readDouble(self): - buff = self.trans.readAll(8) - val, = unpack('!d', buff) - return val - - def __readString(self): - len = self.__readSize() - return self.trans.readAll(len) - readString = reader(__readString) - - def __getTType(self, byte): - return TTYPES[byte & 0x0f] - - -class TCompactProtocolFactory: - def __init__(self): - pass - - def getProtocol(self, trans): - return TCompactProtocol(trans) diff --git a/anknotes/thrift/protocol/TProtocol.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/protocol/TProtocol.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index d6d3938..0000000 --- a/anknotes/thrift/protocol/TProtocol.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,404 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -from anknotes.thrift.Thrift import * - -class TProtocolException(TException): - - """Custom Protocol Exception class""" - - UNKNOWN = 0 - INVALID_DATA = 1 - NEGATIVE_SIZE = 2 - SIZE_LIMIT = 3 - BAD_VERSION = 4 - - def __init__(self, type=UNKNOWN, message=None): - TException.__init__(self, message) - self.type = type - -class TProtocolBase: - - """Base class for Thrift protocol driver.""" - - def __init__(self, trans): - self.trans = trans - - def writeMessageBegin(self, name, type, seqid): - pass - - def writeMessageEnd(self): - pass - - def writeStructBegin(self, name): - pass - - def writeStructEnd(self): - pass - - def writeFieldBegin(self, name, type, id): - pass - - def writeFieldEnd(self): - pass - - def writeFieldStop(self): - pass - - def writeMapBegin(self, ktype, vtype, size): - pass - - def writeMapEnd(self): - pass - - def writeListBegin(self, etype, size): - pass - - def writeListEnd(self): - pass - - def writeSetBegin(self, etype, size): - pass - - def writeSetEnd(self): - pass - - def writeBool(self, bool): - pass - - def writeByte(self, byte): - pass - - def writeI16(self, i16): - pass - - def writeI32(self, i32): - pass - - def writeI64(self, i64): - pass - - def writeDouble(self, dub): - pass - - def writeString(self, str): - pass - - def readMessageBegin(self): - pass - - def readMessageEnd(self): - pass - - def readStructBegin(self): - pass - - def readStructEnd(self): - pass - - def readFieldBegin(self): - pass - - def readFieldEnd(self): - pass - - def readMapBegin(self): - pass - - def readMapEnd(self): - pass - - def readListBegin(self): - pass - - def readListEnd(self): - pass - - def readSetBegin(self): - pass - - def readSetEnd(self): - pass - - def readBool(self): - pass - - def readByte(self): - pass - - def readI16(self): - pass - - def readI32(self): - pass - - def readI64(self): - pass - - def readDouble(self): - pass - - def readString(self): - pass - - def skip(self, type): - if type == TType.STOP: - return - elif type == 
TType.BOOL: - self.readBool() - elif type == TType.BYTE: - self.readByte() - elif type == TType.I16: - self.readI16() - elif type == TType.I32: - self.readI32() - elif type == TType.I64: - self.readI64() - elif type == TType.DOUBLE: - self.readDouble() - elif type == TType.STRING: - self.readString() - elif type == TType.STRUCT: - name = self.readStructBegin() - while True: - (name, type, id) = self.readFieldBegin() - if type == TType.STOP: - break - self.skip(type) - self.readFieldEnd() - self.readStructEnd() - elif type == TType.MAP: - (ktype, vtype, size) = self.readMapBegin() - for i in range(size): - self.skip(ktype) - self.skip(vtype) - self.readMapEnd() - elif type == TType.SET: - (etype, size) = self.readSetBegin() - for i in range(size): - self.skip(etype) - self.readSetEnd() - elif type == TType.LIST: - (etype, size) = self.readListBegin() - for i in range(size): - self.skip(etype) - self.readListEnd() - - # tuple of: ( 'reader method' name, is_container boolean, 'writer_method' name ) - _TTYPE_HANDLERS = ( - (None, None, False), # 0 == TType,STOP - (None, None, False), # 1 == TType.VOID # TODO: handle void? 
- ('readBool', 'writeBool', False), # 2 == TType.BOOL - ('readByte', 'writeByte', False), # 3 == TType.BYTE and I08 - ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE - (None, None, False), # 5, undefined - ('readI16', 'writeI16', False), # 6 == TType.I16 - (None, None, False), # 7, undefined - ('readI32', 'writeI32', False), # 8 == TType.I32 - (None, None, False), # 9, undefined - ('readI64', 'writeI64', False), # 10 == TType.I64 - ('readString', 'writeString', False), # 11 == TType.STRING and UTF7 - ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT - ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP - ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET - ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST - (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types? - (None, None, False)# 17 == TType.UTF16 # TODO: handle utf16 types? - ) - - def readFieldByTType(self, ttype, spec): - try: - (r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype] - except IndexError: - raise TProtocolException(type=TProtocolException.INVALID_DATA, - message='Invalid field type %d' % (ttype)) - if r_handler is None: - raise TProtocolException(type=TProtocolException.INVALID_DATA, - message='Invalid field type %d' % (ttype)) - reader = getattr(self, r_handler) - if not is_container: - return reader() - return reader(spec) - - def readContainerList(self, spec): - results = [] - ttype, tspec = spec[0], spec[1] - r_handler = self._TTYPE_HANDLERS[ttype][0] - reader = getattr(self, r_handler) - (list_type, list_len) = self.readListBegin() - if tspec is None: - # list values are simple types - for idx in xrange(list_len): - results.append(reader()) - else: - # this is like an inlined readFieldByTType - container_reader = self._TTYPE_HANDLERS[list_type][0] - val_reader = getattr(self, container_reader) - for idx in xrange(list_len): - val = val_reader(tspec) - results.append(val) - 
self.readListEnd() - return results - - def readContainerSet(self, spec): - results = set() - ttype, tspec = spec[0], spec[1] - r_handler = self._TTYPE_HANDLERS[ttype][0] - reader = getattr(self, r_handler) - (set_type, set_len) = self.readSetBegin() - if tspec is None: - # set members are simple types - for idx in xrange(set_len): - results.add(reader()) - else: - container_reader = self._TTYPE_HANDLERS[set_type][0] - val_reader = getattr(self, container_reader) - for idx in xrange(set_len): - results.add(val_reader(tspec)) - self.readSetEnd() - return results - - def readContainerStruct(self, spec): - (obj_class, obj_spec) = spec - obj = obj_class() - obj.read(self) - return obj - - def readContainerMap(self, spec): - results = dict() - key_ttype, key_spec = spec[0], spec[1] - val_ttype, val_spec = spec[2], spec[3] - (map_ktype, map_vtype, map_len) = self.readMapBegin() - # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree - key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0]) - val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0]) - # list values are simple types - for idx in xrange(map_len): - if key_spec is None: - k_val = key_reader() - else: - k_val = self.readFieldByTType(key_ttype, key_spec) - if val_spec is None: - v_val = val_reader() - else: - v_val = self.readFieldByTType(val_ttype, val_spec) - # this raises a TypeError with unhashable keys types. i.e. 
d=dict(); d[[0,1]] = 2 fails - results[k_val] = v_val - self.readMapEnd() - return results - - def readStruct(self, obj, thrift_spec): - self.readStructBegin() - while True: - (fname, ftype, fid) = self.readFieldBegin() - if ftype == TType.STOP: - break - try: - field = thrift_spec[fid] - except IndexError: - self.skip(ftype) - else: - if field is not None and ftype == field[1]: - fname = field[2] - fspec = field[3] - val = self.readFieldByTType(ftype, fspec) - setattr(obj, fname, val) - else: - self.skip(ftype) - self.readFieldEnd() - self.readStructEnd() - - def writeContainerStruct(self, val, spec): - val.write(self) - - def writeContainerList(self, val, spec): - self.writeListBegin(spec[0], len(val)) - r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] - e_writer = getattr(self, w_handler) - if not is_container: - for elem in val: - e_writer(elem) - else: - for elem in val: - e_writer(elem, spec[1]) - self.writeListEnd() - - def writeContainerSet(self, val, spec): - self.writeSetBegin(spec[0], len(val)) - r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] - e_writer = getattr(self, w_handler) - if not is_container: - for elem in val: - e_writer(elem) - else: - for elem in val: - e_writer(elem, spec[1]) - self.writeSetEnd() - - def writeContainerMap(self, val, spec): - k_type = spec[0] - v_type = spec[2] - ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type] - ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type] - k_writer = getattr(self, ktype_name) - v_writer = getattr(self, vtype_name) - self.writeMapBegin(k_type, v_type, len(val)) - for m_key, m_val in val.iteritems(): - if not k_is_container: - k_writer(m_key) - else: - k_writer(m_key, spec[1]) - if not v_is_container: - v_writer(m_val) - else: - v_writer(m_val, spec[3]) - self.writeMapEnd() - - def writeStruct(self, obj, thrift_spec): - self.writeStructBegin(obj.__class__.__name__) - for field in thrift_spec: - if field is None: - continue - fname = 
field[2] - val = getattr(obj, fname) - if val is None: - # skip writing out unset fields - continue - fid = field[0] - ftype = field[1] - fspec = field[3] - # get the writer method for this value - self.writeFieldBegin(fname, ftype, fid) - self.writeFieldByTType(ftype, val, fspec) - self.writeFieldEnd() - self.writeFieldStop() - self.writeStructEnd() - - def writeFieldByTType(self, ttype, val, spec): - r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype] - writer = getattr(self, w_handler) - if is_container: - writer(val, spec) - else: - writer(val) - -class TProtocolFactory: - def getProtocol(self, trans): - pass - diff --git a/anknotes/thrift/protocol/TProtocol.py~HEAD b/anknotes/thrift/protocol/TProtocol.py~HEAD deleted file mode 100644 index 7338ff6..0000000 --- a/anknotes/thrift/protocol/TProtocol.py~HEAD +++ /dev/null @@ -1,404 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -from thrift.Thrift import * - -class TProtocolException(TException): - - """Custom Protocol Exception class""" - - UNKNOWN = 0 - INVALID_DATA = 1 - NEGATIVE_SIZE = 2 - SIZE_LIMIT = 3 - BAD_VERSION = 4 - - def __init__(self, type=UNKNOWN, message=None): - TException.__init__(self, message) - self.type = type - -class TProtocolBase: - - """Base class for Thrift protocol driver.""" - - def __init__(self, trans): - self.trans = trans - - def writeMessageBegin(self, name, type, seqid): - pass - - def writeMessageEnd(self): - pass - - def writeStructBegin(self, name): - pass - - def writeStructEnd(self): - pass - - def writeFieldBegin(self, name, type, id): - pass - - def writeFieldEnd(self): - pass - - def writeFieldStop(self): - pass - - def writeMapBegin(self, ktype, vtype, size): - pass - - def writeMapEnd(self): - pass - - def writeListBegin(self, etype, size): - pass - - def writeListEnd(self): - pass - - def writeSetBegin(self, etype, size): - pass - - def writeSetEnd(self): - pass - - def writeBool(self, bool): - pass - - def writeByte(self, byte): - pass - - def writeI16(self, i16): - pass - - def writeI32(self, i32): - pass - - def writeI64(self, i64): - pass - - def writeDouble(self, dub): - pass - - def writeString(self, str): - pass - - def readMessageBegin(self): - pass - - def readMessageEnd(self): - pass - - def readStructBegin(self): - pass - - def readStructEnd(self): - pass - - def readFieldBegin(self): - pass - - def readFieldEnd(self): - pass - - def readMapBegin(self): - pass - - def readMapEnd(self): - pass - - def readListBegin(self): - pass - - def readListEnd(self): - pass - - def readSetBegin(self): - pass - - def readSetEnd(self): - pass - - def readBool(self): - pass - - def readByte(self): - pass - - def readI16(self): - pass - - def readI32(self): - pass - - def readI64(self): - pass - - def readDouble(self): - pass - - def readString(self): - pass - - def skip(self, type): - if type == TType.STOP: - return - elif type == TType.BOOL: 
- self.readBool() - elif type == TType.BYTE: - self.readByte() - elif type == TType.I16: - self.readI16() - elif type == TType.I32: - self.readI32() - elif type == TType.I64: - self.readI64() - elif type == TType.DOUBLE: - self.readDouble() - elif type == TType.STRING: - self.readString() - elif type == TType.STRUCT: - name = self.readStructBegin() - while True: - (name, type, id) = self.readFieldBegin() - if type == TType.STOP: - break - self.skip(type) - self.readFieldEnd() - self.readStructEnd() - elif type == TType.MAP: - (ktype, vtype, size) = self.readMapBegin() - for i in range(size): - self.skip(ktype) - self.skip(vtype) - self.readMapEnd() - elif type == TType.SET: - (etype, size) = self.readSetBegin() - for i in range(size): - self.skip(etype) - self.readSetEnd() - elif type == TType.LIST: - (etype, size) = self.readListBegin() - for i in range(size): - self.skip(etype) - self.readListEnd() - - # tuple of: ( 'reader method' name, is_container boolean, 'writer_method' name ) - _TTYPE_HANDLERS = ( - (None, None, False), # 0 == TType,STOP - (None, None, False), # 1 == TType.VOID # TODO: handle void? - ('readBool', 'writeBool', False), # 2 == TType.BOOL - ('readByte', 'writeByte', False), # 3 == TType.BYTE and I08 - ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE - (None, None, False), # 5, undefined - ('readI16', 'writeI16', False), # 6 == TType.I16 - (None, None, False), # 7, undefined - ('readI32', 'writeI32', False), # 8 == TType.I32 - (None, None, False), # 9, undefined - ('readI64', 'writeI64', False), # 10 == TType.I64 - ('readString', 'writeString', False), # 11 == TType.STRING and UTF7 - ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT - ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP - ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET - ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST - (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types? 
- (None, None, False)# 17 == TType.UTF16 # TODO: handle utf16 types? - ) - - def readFieldByTType(self, ttype, spec): - try: - (r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype] - except IndexError: - raise TProtocolException(type=TProtocolException.INVALID_DATA, - message='Invalid field type %d' % (ttype)) - if r_handler is None: - raise TProtocolException(type=TProtocolException.INVALID_DATA, - message='Invalid field type %d' % (ttype)) - reader = getattr(self, r_handler) - if not is_container: - return reader() - return reader(spec) - - def readContainerList(self, spec): - results = [] - ttype, tspec = spec[0], spec[1] - r_handler = self._TTYPE_HANDLERS[ttype][0] - reader = getattr(self, r_handler) - (list_type, list_len) = self.readListBegin() - if tspec is None: - # list values are simple types - for idx in xrange(list_len): - results.append(reader()) - else: - # this is like an inlined readFieldByTType - container_reader = self._TTYPE_HANDLERS[list_type][0] - val_reader = getattr(self, container_reader) - for idx in xrange(list_len): - val = val_reader(tspec) - results.append(val) - self.readListEnd() - return results - - def readContainerSet(self, spec): - results = set() - ttype, tspec = spec[0], spec[1] - r_handler = self._TTYPE_HANDLERS[ttype][0] - reader = getattr(self, r_handler) - (set_type, set_len) = self.readSetBegin() - if tspec is None: - # set members are simple types - for idx in xrange(set_len): - results.add(reader()) - else: - container_reader = self._TTYPE_HANDLERS[set_type][0] - val_reader = getattr(self, container_reader) - for idx in xrange(set_len): - results.add(val_reader(tspec)) - self.readSetEnd() - return results - - def readContainerStruct(self, spec): - (obj_class, obj_spec) = spec - obj = obj_class() - obj.read(self) - return obj - - def readContainerMap(self, spec): - results = dict() - key_ttype, key_spec = spec[0], spec[1] - val_ttype, val_spec = spec[2], spec[3] - (map_ktype, map_vtype, map_len) = 
self.readMapBegin() - # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree - key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0]) - val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0]) - # list values are simple types - for idx in xrange(map_len): - if key_spec is None: - k_val = key_reader() - else: - k_val = self.readFieldByTType(key_ttype, key_spec) - if val_spec is None: - v_val = val_reader() - else: - v_val = self.readFieldByTType(val_ttype, val_spec) - # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails - results[k_val] = v_val - self.readMapEnd() - return results - - def readStruct(self, obj, thrift_spec): - self.readStructBegin() - while True: - (fname, ftype, fid) = self.readFieldBegin() - if ftype == TType.STOP: - break - try: - field = thrift_spec[fid] - except IndexError: - self.skip(ftype) - else: - if field is not None and ftype == field[1]: - fname = field[2] - fspec = field[3] - val = self.readFieldByTType(ftype, fspec) - setattr(obj, fname, val) - else: - self.skip(ftype) - self.readFieldEnd() - self.readStructEnd() - - def writeContainerStruct(self, val, spec): - val.write(self) - - def writeContainerList(self, val, spec): - self.writeListBegin(spec[0], len(val)) - r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] - e_writer = getattr(self, w_handler) - if not is_container: - for elem in val: - e_writer(elem) - else: - for elem in val: - e_writer(elem, spec[1]) - self.writeListEnd() - - def writeContainerSet(self, val, spec): - self.writeSetBegin(spec[0], len(val)) - r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] - e_writer = getattr(self, w_handler) - if not is_container: - for elem in val: - e_writer(elem) - else: - for elem in val: - e_writer(elem, spec[1]) - self.writeSetEnd() - - def writeContainerMap(self, val, spec): - k_type = spec[0] - v_type = spec[2] - ignore, ktype_name, k_is_container = 
self._TTYPE_HANDLERS[k_type] - ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type] - k_writer = getattr(self, ktype_name) - v_writer = getattr(self, vtype_name) - self.writeMapBegin(k_type, v_type, len(val)) - for m_key, m_val in val.iteritems(): - if not k_is_container: - k_writer(m_key) - else: - k_writer(m_key, spec[1]) - if not v_is_container: - v_writer(m_val) - else: - v_writer(m_val, spec[3]) - self.writeMapEnd() - - def writeStruct(self, obj, thrift_spec): - self.writeStructBegin(obj.__class__.__name__) - for field in thrift_spec: - if field is None: - continue - fname = field[2] - val = getattr(obj, fname) - if val is None: - # skip writing out unset fields - continue - fid = field[0] - ftype = field[1] - fspec = field[3] - # get the writer method for this value - self.writeFieldBegin(fname, ftype, fid) - self.writeFieldByTType(ftype, val, fspec) - self.writeFieldEnd() - self.writeFieldStop() - self.writeStructEnd() - - def writeFieldByTType(self, ttype, val, spec): - r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype] - writer = getattr(self, w_handler) - if is_container: - writer(val, spec) - else: - writer(val) - -class TProtocolFactory: - def getProtocol(self, trans): - pass - diff --git a/anknotes/thrift/protocol/__init__.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/protocol/__init__.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index d53359b..0000000 --- a/anknotes/thrift/protocol/__init__.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,20 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase'] diff --git a/anknotes/thrift/protocol/__init__.py~HEAD b/anknotes/thrift/protocol/__init__.py~HEAD deleted file mode 100644 index d53359b..0000000 --- a/anknotes/thrift/protocol/__init__.py~HEAD +++ /dev/null @@ -1,20 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase'] diff --git a/anknotes/thrift/protocol/fastbinary.c~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/protocol/fastbinary.c~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 2ce5660..0000000 --- a/anknotes/thrift/protocol/fastbinary.c~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,1219 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -#include <Python.h> -#include "cStringIO.h" -#include <stdint.h> -#ifndef _WIN32 -# include <stdbool.h> -# include <netinet/in.h> -#else -# include <WinSock2.h> -# pragma comment (lib, "ws2_32.lib") -# define BIG_ENDIAN (4321) -# define LITTLE_ENDIAN (1234) -# define BYTE_ORDER LITTLE_ENDIAN -# if defined(_MSC_VER) && _MSC_VER < 1600 - typedef int _Bool; -# define bool _Bool -# define false 0 -# define true 1 -# endif -# define inline __inline -#endif - -/* Fix endianness issues on Solaris */ -#if defined (__SVR4) && defined (__sun) - #if defined(__i386) && !defined(__i386__) - #define __i386__ - #endif - - #ifndef BIG_ENDIAN - #define BIG_ENDIAN (4321) - #endif - #ifndef LITTLE_ENDIAN - #define LITTLE_ENDIAN (1234) - #endif - - /* I386 is LE, even on Solaris */ - #if !defined(BYTE_ORDER) && defined(__i386__) - #define BYTE_ORDER LITTLE_ENDIAN - #endif -#endif - -// TODO(dreiss): defval appears to be unused. Look into removing it. -// TODO(dreiss): Make parse_spec_args recursive, and cache the output -// permanently in the object. (Malloc and orphan.) -// TODO(dreiss): Why do we need cStringIO for reading, why not just char*? -// Can cStringIO let us work with a BufferedTransport? -// TODO(dreiss): Don't ignore the rv from cwrite (maybe). - -/* ====== BEGIN UTILITIES ====== */ - -#define INIT_OUTBUF_SIZE 128 - -// Stolen out of TProtocol.h. -// It would be a huge pain to have both get this from one place. 
-typedef enum TType { - T_STOP = 0, - T_VOID = 1, - T_BOOL = 2, - T_BYTE = 3, - T_I08 = 3, - T_I16 = 6, - T_I32 = 8, - T_U64 = 9, - T_I64 = 10, - T_DOUBLE = 4, - T_STRING = 11, - T_UTF7 = 11, - T_STRUCT = 12, - T_MAP = 13, - T_SET = 14, - T_LIST = 15, - T_UTF8 = 16, - T_UTF16 = 17 -} TType; - -#ifndef __BYTE_ORDER -# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN) -# define __BYTE_ORDER BYTE_ORDER -# define __LITTLE_ENDIAN LITTLE_ENDIAN -# define __BIG_ENDIAN BIG_ENDIAN -# else -# error "Cannot determine endianness" -# endif -#endif - -// Same comment as the enum. Sorry. -#if __BYTE_ORDER == __BIG_ENDIAN -# define ntohll(n) (n) -# define htonll(n) (n) -#elif __BYTE_ORDER == __LITTLE_ENDIAN -# if defined(__GNUC__) && defined(__GLIBC__) -# include <byteswap.h> -# define ntohll(n) bswap_64(n) -# define htonll(n) bswap_64(n) -# else /* GNUC & GLIBC */ -# define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) ) -# define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) ) -# endif /* GNUC & GLIBC */ -#else /* __BYTE_ORDER */ -# error "Can't define htonll or ntohll!" -#endif - -// Doing a benchmark shows that interning actually makes a difference, amazingly. -#define INTERN_STRING(value) _intern_ ## value - -#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() ) -#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) ) - -// Py_ssize_t was not defined before Python 2.5 -#if (PY_VERSION_HEX < 0x02050000) -typedef int Py_ssize_t; -#endif - -/** - * A cache of the spec_args for a set or list, - * so we don't have to keep calling PyTuple_GET_ITEM. - */ -typedef struct { - TType element_type; - PyObject* typeargs; -} SetListTypeArgs; - -/** - * A cache of the spec_args for a map, - * so we don't have to keep calling PyTuple_GET_ITEM. 
- */ -typedef struct { - TType ktag; - TType vtag; - PyObject* ktypeargs; - PyObject* vtypeargs; -} MapTypeArgs; - -/** - * A cache of the spec_args for a struct, - * so we don't have to keep calling PyTuple_GET_ITEM. - */ -typedef struct { - PyObject* klass; - PyObject* spec; -} StructTypeArgs; - -/** - * A cache of the item spec from a struct specification, - * so we don't have to keep calling PyTuple_GET_ITEM. - */ -typedef struct { - int tag; - TType type; - PyObject* attrname; - PyObject* typeargs; - PyObject* defval; -} StructItemSpec; - -/** - * A cache of the two key attributes of a CReadableTransport, - * so we don't have to keep calling PyObject_GetAttr. - */ -typedef struct { - PyObject* stringiobuf; - PyObject* refill_callable; -} DecodeBuffer; - -/** Pointer to interned string to speed up attribute lookup. */ -static PyObject* INTERN_STRING(cstringio_buf); -/** Pointer to interned string to speed up attribute lookup. */ -static PyObject* INTERN_STRING(cstringio_refill); - -static inline bool -check_ssize_t_32(Py_ssize_t len) { - // error from getting the int - if (INT_CONV_ERROR_OCCURRED(len)) { - return false; - } - if (!CHECK_RANGE(len, 0, INT32_MAX)) { - PyErr_SetString(PyExc_OverflowError, "string size out of range"); - return false; - } - return true; -} - -static inline bool -parse_pyint(PyObject* o, int32_t* ret, int32_t min, int32_t max) { - long val = PyInt_AsLong(o); - - if (INT_CONV_ERROR_OCCURRED(val)) { - return false; - } - if (!CHECK_RANGE(val, min, max)) { - PyErr_SetString(PyExc_OverflowError, "int out of range"); - return false; - } - - *ret = (int32_t) val; - return true; -} - - -/* --- FUNCTIONS TO PARSE STRUCT SPECIFICATOINS --- */ - -static bool -parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) { - if (PyTuple_Size(typeargs) != 2) { - PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args"); - return false; - } - - dest->element_type = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); - if 
(INT_CONV_ERROR_OCCURRED(dest->element_type)) { - return false; - } - - dest->typeargs = PyTuple_GET_ITEM(typeargs, 1); - - return true; -} - -static bool -parse_map_args(MapTypeArgs* dest, PyObject* typeargs) { - if (PyTuple_Size(typeargs) != 4) { - PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map"); - return false; - } - - dest->ktag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); - if (INT_CONV_ERROR_OCCURRED(dest->ktag)) { - return false; - } - - dest->vtag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2)); - if (INT_CONV_ERROR_OCCURRED(dest->vtag)) { - return false; - } - - dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1); - dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3); - - return true; -} - -static bool -parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) { - if (PyTuple_Size(typeargs) != 2) { - PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args"); - return false; - } - - dest->klass = PyTuple_GET_ITEM(typeargs, 0); - dest->spec = PyTuple_GET_ITEM(typeargs, 1); - - return true; -} - -static int -parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) { - - // i'd like to use ParseArgs here, but it seems to be a bottleneck. 
- if (PyTuple_Size(spec_tuple) != 5) { - PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple"); - return false; - } - - dest->tag = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0)); - if (INT_CONV_ERROR_OCCURRED(dest->tag)) { - return false; - } - - dest->type = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1)); - if (INT_CONV_ERROR_OCCURRED(dest->type)) { - return false; - } - - dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2); - dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3); - dest->defval = PyTuple_GET_ITEM(spec_tuple, 4); - return true; -} - -/* ====== END UTILITIES ====== */ - - -/* ====== BEGIN WRITING FUNCTIONS ====== */ - -/* --- LOW-LEVEL WRITING FUNCTIONS --- */ - -static void writeByte(PyObject* outbuf, int8_t val) { - int8_t net = val; - PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int8_t)); -} - -static void writeI16(PyObject* outbuf, int16_t val) { - int16_t net = (int16_t)htons(val); - PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int16_t)); -} - -static void writeI32(PyObject* outbuf, int32_t val) { - int32_t net = (int32_t)htonl(val); - PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int32_t)); -} - -static void writeI64(PyObject* outbuf, int64_t val) { - int64_t net = (int64_t)htonll(val); - PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int64_t)); -} - -static void writeDouble(PyObject* outbuf, double dub) { - // Unfortunately, bitwise_cast doesn't work in C. Bad C! - union { - double f; - int64_t t; - } transfer; - transfer.f = dub; - writeI64(outbuf, transfer.t); -} - - -/* --- MAIN RECURSIVE OUTPUT FUCNTION -- */ - -static int -output_val(PyObject* output, PyObject* value, TType type, PyObject* typeargs) { - /* - * Refcounting Strategy: - * - * We assume that elements of the thrift_spec tuple are not going to be - * mutated, so we don't ref count those at all. Other than that, we try to - * keep a reference to all the user-created objects while we work with them. - * output_val assumes that a reference is already held. 
The *caller* is - * responsible for handling references - */ - - switch (type) { - - case T_BOOL: { - int v = PyObject_IsTrue(value); - if (v == -1) { - return false; - } - - writeByte(output, (int8_t) v); - break; - } - case T_I08: { - int32_t val; - - if (!parse_pyint(value, &val, INT8_MIN, INT8_MAX)) { - return false; - } - - writeByte(output, (int8_t) val); - break; - } - case T_I16: { - int32_t val; - - if (!parse_pyint(value, &val, INT16_MIN, INT16_MAX)) { - return false; - } - - writeI16(output, (int16_t) val); - break; - } - case T_I32: { - int32_t val; - - if (!parse_pyint(value, &val, INT32_MIN, INT32_MAX)) { - return false; - } - - writeI32(output, val); - break; - } - case T_I64: { - int64_t nval = PyLong_AsLongLong(value); - - if (INT_CONV_ERROR_OCCURRED(nval)) { - return false; - } - - if (!CHECK_RANGE(nval, INT64_MIN, INT64_MAX)) { - PyErr_SetString(PyExc_OverflowError, "int out of range"); - return false; - } - - writeI64(output, nval); - break; - } - - case T_DOUBLE: { - double nval = PyFloat_AsDouble(value); - if (nval == -1.0 && PyErr_Occurred()) { - return false; - } - - writeDouble(output, nval); - break; - } - - case T_STRING: { - Py_ssize_t len = PyString_Size(value); - - if (!check_ssize_t_32(len)) { - return false; - } - - writeI32(output, (int32_t) len); - PycStringIO->cwrite(output, PyString_AsString(value), (int32_t) len); - break; - } - - case T_LIST: - case T_SET: { - Py_ssize_t len; - SetListTypeArgs parsedargs; - PyObject *item; - PyObject *iterator; - - if (!parse_set_list_args(&parsedargs, typeargs)) { - return false; - } - - len = PyObject_Length(value); - - if (!check_ssize_t_32(len)) { - return false; - } - - writeByte(output, parsedargs.element_type); - writeI32(output, (int32_t) len); - - iterator = PyObject_GetIter(value); - if (iterator == NULL) { - return false; - } - - while ((item = PyIter_Next(iterator))) { - if (!output_val(output, item, parsedargs.element_type, parsedargs.typeargs)) { - Py_DECREF(item); - 
Py_DECREF(iterator); - return false; - } - Py_DECREF(item); - } - - Py_DECREF(iterator); - - if (PyErr_Occurred()) { - return false; - } - - break; - } - - case T_MAP: { - PyObject *k, *v; - Py_ssize_t pos = 0; - Py_ssize_t len; - - MapTypeArgs parsedargs; - - len = PyDict_Size(value); - if (!check_ssize_t_32(len)) { - return false; - } - - if (!parse_map_args(&parsedargs, typeargs)) { - return false; - } - - writeByte(output, parsedargs.ktag); - writeByte(output, parsedargs.vtag); - writeI32(output, len); - - // TODO(bmaurer): should support any mapping, not just dicts - while (PyDict_Next(value, &pos, &k, &v)) { - // TODO(dreiss): Think hard about whether these INCREFs actually - // turn any unsafe scenarios into safe scenarios. - Py_INCREF(k); - Py_INCREF(v); - - if (!output_val(output, k, parsedargs.ktag, parsedargs.ktypeargs) - || !output_val(output, v, parsedargs.vtag, parsedargs.vtypeargs)) { - Py_DECREF(k); - Py_DECREF(v); - return false; - } - Py_DECREF(k); - Py_DECREF(v); - } - break; - } - - // TODO(dreiss): Consider breaking this out as a function - // the way we did for decode_struct. 
- case T_STRUCT: { - StructTypeArgs parsedargs; - Py_ssize_t nspec; - Py_ssize_t i; - - if (!parse_struct_args(&parsedargs, typeargs)) { - return false; - } - - nspec = PyTuple_Size(parsedargs.spec); - - if (nspec == -1) { - return false; - } - - for (i = 0; i < nspec; i++) { - StructItemSpec parsedspec; - PyObject* spec_tuple; - PyObject* instval = NULL; - - spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i); - if (spec_tuple == Py_None) { - continue; - } - - if (!parse_struct_item_spec (&parsedspec, spec_tuple)) { - return false; - } - - instval = PyObject_GetAttr(value, parsedspec.attrname); - - if (!instval) { - return false; - } - - if (instval == Py_None) { - Py_DECREF(instval); - continue; - } - - writeByte(output, (int8_t) parsedspec.type); - writeI16(output, parsedspec.tag); - - if (!output_val(output, instval, parsedspec.type, parsedspec.typeargs)) { - Py_DECREF(instval); - return false; - } - - Py_DECREF(instval); - } - - writeByte(output, (int8_t)T_STOP); - break; - } - - case T_STOP: - case T_VOID: - case T_UTF16: - case T_UTF8: - case T_U64: - default: - PyErr_SetString(PyExc_TypeError, "Unexpected TType"); - return false; - - } - - return true; -} - - -/* --- TOP-LEVEL WRAPPER FOR OUTPUT -- */ - -static PyObject * -encode_binary(PyObject *self, PyObject *args) { - PyObject* enc_obj; - PyObject* type_args; - PyObject* buf; - PyObject* ret = NULL; - - if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) { - return NULL; - } - - buf = PycStringIO->NewOutput(INIT_OUTBUF_SIZE); - if (output_val(buf, enc_obj, T_STRUCT, type_args)) { - ret = PycStringIO->cgetvalue(buf); - } - - Py_DECREF(buf); - return ret; -} - -/* ====== END WRITING FUNCTIONS ====== */ - - -/* ====== BEGIN READING FUNCTIONS ====== */ - -/* --- LOW-LEVEL READING FUNCTIONS --- */ - -static void -free_decodebuf(DecodeBuffer* d) { - Py_XDECREF(d->stringiobuf); - Py_XDECREF(d->refill_callable); -} - -static bool -decode_buffer_from_obj(DecodeBuffer* dest, PyObject* obj) { - 
dest->stringiobuf = PyObject_GetAttr(obj, INTERN_STRING(cstringio_buf)); - if (!dest->stringiobuf) { - return false; - } - - if (!PycStringIO_InputCheck(dest->stringiobuf)) { - free_decodebuf(dest); - PyErr_SetString(PyExc_TypeError, "expecting stringio input"); - return false; - } - - dest->refill_callable = PyObject_GetAttr(obj, INTERN_STRING(cstringio_refill)); - - if(!dest->refill_callable) { - free_decodebuf(dest); - return false; - } - - if (!PyCallable_Check(dest->refill_callable)) { - free_decodebuf(dest); - PyErr_SetString(PyExc_TypeError, "expecting callable"); - return false; - } - - return true; -} - -static bool readBytes(DecodeBuffer* input, char** output, int len) { - int read; - - // TODO(dreiss): Don't fear the malloc. Think about taking a copy of - // the partial read instead of forcing the transport - // to prepend it to its buffer. - - read = PycStringIO->cread(input->stringiobuf, output, len); - - if (read == len) { - return true; - } else if (read == -1) { - return false; - } else { - PyObject* newiobuf; - - // using building functions as this is a rare codepath - newiobuf = PyObject_CallFunction( - input->refill_callable, "s#i", *output, read, len, NULL); - if (newiobuf == NULL) { - return false; - } - - // must do this *AFTER* the call so that we don't deref the io buffer - Py_CLEAR(input->stringiobuf); - input->stringiobuf = newiobuf; - - read = PycStringIO->cread(input->stringiobuf, output, len); - - if (read == len) { - return true; - } else if (read == -1) { - return false; - } else { - // TODO(dreiss): This could be a valid code path for big binary blobs. 
- PyErr_SetString(PyExc_TypeError, - "refill claimed to have refilled the buffer, but didn't!!"); - return false; - } - } -} - -static int8_t readByte(DecodeBuffer* input) { - char* buf; - if (!readBytes(input, &buf, sizeof(int8_t))) { - return -1; - } - - return *(int8_t*) buf; -} - -static int16_t readI16(DecodeBuffer* input) { - char* buf; - if (!readBytes(input, &buf, sizeof(int16_t))) { - return -1; - } - - return (int16_t) ntohs(*(int16_t*) buf); -} - -static int32_t readI32(DecodeBuffer* input) { - char* buf; - if (!readBytes(input, &buf, sizeof(int32_t))) { - return -1; - } - return (int32_t) ntohl(*(int32_t*) buf); -} - - -static int64_t readI64(DecodeBuffer* input) { - char* buf; - if (!readBytes(input, &buf, sizeof(int64_t))) { - return -1; - } - - return (int64_t) ntohll(*(int64_t*) buf); -} - -static double readDouble(DecodeBuffer* input) { - union { - int64_t f; - double t; - } transfer; - - transfer.f = readI64(input); - if (transfer.f == -1) { - return -1; - } - return transfer.t; -} - -static bool -checkTypeByte(DecodeBuffer* input, TType expected) { - TType got = readByte(input); - if (INT_CONV_ERROR_OCCURRED(got)) { - return false; - } - - if (expected != got) { - PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field"); - return false; - } - return true; -} - -static bool -skip(DecodeBuffer* input, TType type) { -#define SKIPBYTES(n) \ - do { \ - if (!readBytes(input, &dummy_buf, (n))) { \ - return false; \ - } \ - } while(0) - - char* dummy_buf; - - switch (type) { - - case T_BOOL: - case T_I08: SKIPBYTES(1); break; - case T_I16: SKIPBYTES(2); break; - case T_I32: SKIPBYTES(4); break; - case T_I64: - case T_DOUBLE: SKIPBYTES(8); break; - - case T_STRING: { - // TODO(dreiss): Find out if these check_ssize_t32s are really necessary. 
- int len = readI32(input); - if (!check_ssize_t_32(len)) { - return false; - } - SKIPBYTES(len); - break; - } - - case T_LIST: - case T_SET: { - TType etype; - int len, i; - - etype = readByte(input); - if (etype == -1) { - return false; - } - - len = readI32(input); - if (!check_ssize_t_32(len)) { - return false; - } - - for (i = 0; i < len; i++) { - if (!skip(input, etype)) { - return false; - } - } - break; - } - - case T_MAP: { - TType ktype, vtype; - int len, i; - - ktype = readByte(input); - if (ktype == -1) { - return false; - } - - vtype = readByte(input); - if (vtype == -1) { - return false; - } - - len = readI32(input); - if (!check_ssize_t_32(len)) { - return false; - } - - for (i = 0; i < len; i++) { - if (!(skip(input, ktype) && skip(input, vtype))) { - return false; - } - } - break; - } - - case T_STRUCT: { - while (true) { - TType type; - - type = readByte(input); - if (type == -1) { - return false; - } - - if (type == T_STOP) - break; - - SKIPBYTES(2); // tag - if (!skip(input, type)) { - return false; - } - } - break; - } - - case T_STOP: - case T_VOID: - case T_UTF16: - case T_UTF8: - case T_U64: - default: - PyErr_SetString(PyExc_TypeError, "Unexpected TType"); - return false; - - } - - return true; - -#undef SKIPBYTES -} - - -/* --- HELPER FUNCTION FOR DECODE_VAL --- */ - -static PyObject* -decode_val(DecodeBuffer* input, TType type, PyObject* typeargs); - -static bool -decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) { - int spec_seq_len = PyTuple_Size(spec_seq); - if (spec_seq_len == -1) { - return false; - } - - while (true) { - TType type; - int16_t tag; - PyObject* item_spec; - PyObject* fieldval = NULL; - StructItemSpec parsedspec; - - type = readByte(input); - if (type == -1) { - return false; - } - if (type == T_STOP) { - break; - } - tag = readI16(input); - if (INT_CONV_ERROR_OCCURRED(tag)) { - return false; - } - if (tag >= 0 && tag < spec_seq_len) { - item_spec = PyTuple_GET_ITEM(spec_seq, tag); - } else { - 
item_spec = Py_None; - } - - if (item_spec == Py_None) { - if (!skip(input, type)) { - return false; - } else { - continue; - } - } - - if (!parse_struct_item_spec(&parsedspec, item_spec)) { - return false; - } - if (parsedspec.type != type) { - if (!skip(input, type)) { - PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped"); - return false; - } else { - continue; - } - } - - fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs); - if (fieldval == NULL) { - return false; - } - - if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) { - Py_DECREF(fieldval); - return false; - } - Py_DECREF(fieldval); - } - return true; -} - - -/* --- MAIN RECURSIVE INPUT FUCNTION --- */ - -// Returns a new reference. -static PyObject* -decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { - switch (type) { - - case T_BOOL: { - int8_t v = readByte(input); - if (INT_CONV_ERROR_OCCURRED(v)) { - return NULL; - } - - switch (v) { - case 0: Py_RETURN_FALSE; - case 1: Py_RETURN_TRUE; - // Don't laugh. This is a potentially serious issue. - default: PyErr_SetString(PyExc_TypeError, "boolean out of range"); return NULL; - } - break; - } - case T_I08: { - int8_t v = readByte(input); - if (INT_CONV_ERROR_OCCURRED(v)) { - return NULL; - } - - return PyInt_FromLong(v); - } - case T_I16: { - int16_t v = readI16(input); - if (INT_CONV_ERROR_OCCURRED(v)) { - return NULL; - } - return PyInt_FromLong(v); - } - case T_I32: { - int32_t v = readI32(input); - if (INT_CONV_ERROR_OCCURRED(v)) { - return NULL; - } - return PyInt_FromLong(v); - } - - case T_I64: { - int64_t v = readI64(input); - if (INT_CONV_ERROR_OCCURRED(v)) { - return NULL; - } - // TODO(dreiss): Find out if we can take this fastpath always when - // sizeof(long) == sizeof(long long). 
- if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) { - return PyInt_FromLong((long) v); - } - - return PyLong_FromLongLong(v); - } - - case T_DOUBLE: { - double v = readDouble(input); - if (v == -1.0 && PyErr_Occurred()) { - return false; - } - return PyFloat_FromDouble(v); - } - - case T_STRING: { - Py_ssize_t len = readI32(input); - char* buf; - if (!readBytes(input, &buf, len)) { - return NULL; - } - - return PyString_FromStringAndSize(buf, len); - } - - case T_LIST: - case T_SET: { - SetListTypeArgs parsedargs; - int32_t len; - PyObject* ret = NULL; - int i; - - if (!parse_set_list_args(&parsedargs, typeargs)) { - return NULL; - } - - if (!checkTypeByte(input, parsedargs.element_type)) { - return NULL; - } - - len = readI32(input); - if (!check_ssize_t_32(len)) { - return NULL; - } - - ret = PyList_New(len); - if (!ret) { - return NULL; - } - - for (i = 0; i < len; i++) { - PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs); - if (!item) { - Py_DECREF(ret); - return NULL; - } - PyList_SET_ITEM(ret, i, item); - } - - // TODO(dreiss): Consider biting the bullet and making two separate cases - // for list and set, avoiding this post facto conversion. 
- if (type == T_SET) { - PyObject* setret; -#if (PY_VERSION_HEX < 0x02050000) - // hack needed for older versions - setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL); -#else - // official version - setret = PySet_New(ret); -#endif - Py_DECREF(ret); - return setret; - } - return ret; - } - - case T_MAP: { - int32_t len; - int i; - MapTypeArgs parsedargs; - PyObject* ret = NULL; - - if (!parse_map_args(&parsedargs, typeargs)) { - return NULL; - } - - if (!checkTypeByte(input, parsedargs.ktag)) { - return NULL; - } - if (!checkTypeByte(input, parsedargs.vtag)) { - return NULL; - } - - len = readI32(input); - if (!check_ssize_t_32(len)) { - return false; - } - - ret = PyDict_New(); - if (!ret) { - goto error; - } - - for (i = 0; i < len; i++) { - PyObject* k = NULL; - PyObject* v = NULL; - k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs); - if (k == NULL) { - goto loop_error; - } - v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs); - if (v == NULL) { - goto loop_error; - } - if (PyDict_SetItem(ret, k, v) == -1) { - goto loop_error; - } - - Py_DECREF(k); - Py_DECREF(v); - continue; - - // Yuck! Destructors, anyone? 
- loop_error: - Py_XDECREF(k); - Py_XDECREF(v); - goto error; - } - - return ret; - - error: - Py_XDECREF(ret); - return NULL; - } - - case T_STRUCT: { - StructTypeArgs parsedargs; - PyObject* ret; - if (!parse_struct_args(&parsedargs, typeargs)) { - return NULL; - } - - ret = PyObject_CallObject(parsedargs.klass, NULL); - if (!ret) { - return NULL; - } - - if (!decode_struct(input, ret, parsedargs.spec)) { - Py_DECREF(ret); - return NULL; - } - - return ret; - } - - case T_STOP: - case T_VOID: - case T_UTF16: - case T_UTF8: - case T_U64: - default: - PyErr_SetString(PyExc_TypeError, "Unexpected TType"); - return NULL; - } -} - - -/* --- TOP-LEVEL WRAPPER FOR INPUT -- */ - -static PyObject* -decode_binary(PyObject *self, PyObject *args) { - PyObject* output_obj = NULL; - PyObject* transport = NULL; - PyObject* typeargs = NULL; - StructTypeArgs parsedargs; - DecodeBuffer input = {0, 0}; - - if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) { - return NULL; - } - - if (!parse_struct_args(&parsedargs, typeargs)) { - return NULL; - } - - if (!decode_buffer_from_obj(&input, transport)) { - return NULL; - } - - if (!decode_struct(&input, output_obj, parsedargs.spec)) { - free_decodebuf(&input); - return NULL; - } - - free_decodebuf(&input); - - Py_RETURN_NONE; -} - -/* ====== END READING FUNCTIONS ====== */ - - -/* -- PYTHON MODULE SETUP STUFF --- */ - -static PyMethodDef ThriftFastBinaryMethods[] = { - - {"encode_binary", encode_binary, METH_VARARGS, ""}, - {"decode_binary", decode_binary, METH_VARARGS, ""}, - - {NULL, NULL, 0, NULL} /* Sentinel */ -}; - -PyMODINIT_FUNC -initfastbinary(void) { -#define INIT_INTERN_STRING(value) \ - do { \ - INTERN_STRING(value) = PyString_InternFromString(#value); \ - if(!INTERN_STRING(value)) return; \ - } while(0) - - INIT_INTERN_STRING(cstringio_buf); - INIT_INTERN_STRING(cstringio_refill); -#undef INIT_INTERN_STRING - - PycString_IMPORT; - if (PycStringIO == NULL) return; - - (void) 
Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods); -} diff --git a/anknotes/thrift/protocol/fastbinary.c~HEAD b/anknotes/thrift/protocol/fastbinary.c~HEAD deleted file mode 100644 index 2ce5660..0000000 --- a/anknotes/thrift/protocol/fastbinary.c~HEAD +++ /dev/null @@ -1,1219 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -#include <Python.h> -#include "cStringIO.h" -#include <stdint.h> -#ifndef _WIN32 -# include <stdbool.h> -# include <netinet/in.h> -#else -# include <WinSock2.h> -# pragma comment (lib, "ws2_32.lib") -# define BIG_ENDIAN (4321) -# define LITTLE_ENDIAN (1234) -# define BYTE_ORDER LITTLE_ENDIAN -# if defined(_MSC_VER) && _MSC_VER < 1600 - typedef int _Bool; -# define bool _Bool -# define false 0 -# define true 1 -# endif -# define inline __inline -#endif - -/* Fix endianness issues on Solaris */ -#if defined (__SVR4) && defined (__sun) - #if defined(__i386) && !defined(__i386__) - #define __i386__ - #endif - - #ifndef BIG_ENDIAN - #define BIG_ENDIAN (4321) - #endif - #ifndef LITTLE_ENDIAN - #define LITTLE_ENDIAN (1234) - #endif - - /* I386 is LE, even on Solaris */ - #if !defined(BYTE_ORDER) && defined(__i386__) - #define BYTE_ORDER LITTLE_ENDIAN - #endif -#endif - -// TODO(dreiss): defval appears to be unused. Look into removing it. -// TODO(dreiss): Make parse_spec_args recursive, and cache the output -// permanently in the object. (Malloc and orphan.) -// TODO(dreiss): Why do we need cStringIO for reading, why not just char*? -// Can cStringIO let us work with a BufferedTransport? -// TODO(dreiss): Don't ignore the rv from cwrite (maybe). - -/* ====== BEGIN UTILITIES ====== */ - -#define INIT_OUTBUF_SIZE 128 - -// Stolen out of TProtocol.h. -// It would be a huge pain to have both get this from one place. 
-typedef enum TType { - T_STOP = 0, - T_VOID = 1, - T_BOOL = 2, - T_BYTE = 3, - T_I08 = 3, - T_I16 = 6, - T_I32 = 8, - T_U64 = 9, - T_I64 = 10, - T_DOUBLE = 4, - T_STRING = 11, - T_UTF7 = 11, - T_STRUCT = 12, - T_MAP = 13, - T_SET = 14, - T_LIST = 15, - T_UTF8 = 16, - T_UTF16 = 17 -} TType; - -#ifndef __BYTE_ORDER -# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN) -# define __BYTE_ORDER BYTE_ORDER -# define __LITTLE_ENDIAN LITTLE_ENDIAN -# define __BIG_ENDIAN BIG_ENDIAN -# else -# error "Cannot determine endianness" -# endif -#endif - -// Same comment as the enum. Sorry. -#if __BYTE_ORDER == __BIG_ENDIAN -# define ntohll(n) (n) -# define htonll(n) (n) -#elif __BYTE_ORDER == __LITTLE_ENDIAN -# if defined(__GNUC__) && defined(__GLIBC__) -# include <byteswap.h> -# define ntohll(n) bswap_64(n) -# define htonll(n) bswap_64(n) -# else /* GNUC & GLIBC */ -# define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) ) -# define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) ) -# endif /* GNUC & GLIBC */ -#else /* __BYTE_ORDER */ -# error "Can't define htonll or ntohll!" -#endif - -// Doing a benchmark shows that interning actually makes a difference, amazingly. -#define INTERN_STRING(value) _intern_ ## value - -#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() ) -#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) ) - -// Py_ssize_t was not defined before Python 2.5 -#if (PY_VERSION_HEX < 0x02050000) -typedef int Py_ssize_t; -#endif - -/** - * A cache of the spec_args for a set or list, - * so we don't have to keep calling PyTuple_GET_ITEM. - */ -typedef struct { - TType element_type; - PyObject* typeargs; -} SetListTypeArgs; - -/** - * A cache of the spec_args for a map, - * so we don't have to keep calling PyTuple_GET_ITEM. 
- */ -typedef struct { - TType ktag; - TType vtag; - PyObject* ktypeargs; - PyObject* vtypeargs; -} MapTypeArgs; - -/** - * A cache of the spec_args for a struct, - * so we don't have to keep calling PyTuple_GET_ITEM. - */ -typedef struct { - PyObject* klass; - PyObject* spec; -} StructTypeArgs; - -/** - * A cache of the item spec from a struct specification, - * so we don't have to keep calling PyTuple_GET_ITEM. - */ -typedef struct { - int tag; - TType type; - PyObject* attrname; - PyObject* typeargs; - PyObject* defval; -} StructItemSpec; - -/** - * A cache of the two key attributes of a CReadableTransport, - * so we don't have to keep calling PyObject_GetAttr. - */ -typedef struct { - PyObject* stringiobuf; - PyObject* refill_callable; -} DecodeBuffer; - -/** Pointer to interned string to speed up attribute lookup. */ -static PyObject* INTERN_STRING(cstringio_buf); -/** Pointer to interned string to speed up attribute lookup. */ -static PyObject* INTERN_STRING(cstringio_refill); - -static inline bool -check_ssize_t_32(Py_ssize_t len) { - // error from getting the int - if (INT_CONV_ERROR_OCCURRED(len)) { - return false; - } - if (!CHECK_RANGE(len, 0, INT32_MAX)) { - PyErr_SetString(PyExc_OverflowError, "string size out of range"); - return false; - } - return true; -} - -static inline bool -parse_pyint(PyObject* o, int32_t* ret, int32_t min, int32_t max) { - long val = PyInt_AsLong(o); - - if (INT_CONV_ERROR_OCCURRED(val)) { - return false; - } - if (!CHECK_RANGE(val, min, max)) { - PyErr_SetString(PyExc_OverflowError, "int out of range"); - return false; - } - - *ret = (int32_t) val; - return true; -} - - -/* --- FUNCTIONS TO PARSE STRUCT SPECIFICATOINS --- */ - -static bool -parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) { - if (PyTuple_Size(typeargs) != 2) { - PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args"); - return false; - } - - dest->element_type = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); - if 
(INT_CONV_ERROR_OCCURRED(dest->element_type)) { - return false; - } - - dest->typeargs = PyTuple_GET_ITEM(typeargs, 1); - - return true; -} - -static bool -parse_map_args(MapTypeArgs* dest, PyObject* typeargs) { - if (PyTuple_Size(typeargs) != 4) { - PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map"); - return false; - } - - dest->ktag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); - if (INT_CONV_ERROR_OCCURRED(dest->ktag)) { - return false; - } - - dest->vtag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2)); - if (INT_CONV_ERROR_OCCURRED(dest->vtag)) { - return false; - } - - dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1); - dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3); - - return true; -} - -static bool -parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) { - if (PyTuple_Size(typeargs) != 2) { - PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args"); - return false; - } - - dest->klass = PyTuple_GET_ITEM(typeargs, 0); - dest->spec = PyTuple_GET_ITEM(typeargs, 1); - - return true; -} - -static int -parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) { - - // i'd like to use ParseArgs here, but it seems to be a bottleneck. 
- if (PyTuple_Size(spec_tuple) != 5) { - PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple"); - return false; - } - - dest->tag = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0)); - if (INT_CONV_ERROR_OCCURRED(dest->tag)) { - return false; - } - - dest->type = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1)); - if (INT_CONV_ERROR_OCCURRED(dest->type)) { - return false; - } - - dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2); - dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3); - dest->defval = PyTuple_GET_ITEM(spec_tuple, 4); - return true; -} - -/* ====== END UTILITIES ====== */ - - -/* ====== BEGIN WRITING FUNCTIONS ====== */ - -/* --- LOW-LEVEL WRITING FUNCTIONS --- */ - -static void writeByte(PyObject* outbuf, int8_t val) { - int8_t net = val; - PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int8_t)); -} - -static void writeI16(PyObject* outbuf, int16_t val) { - int16_t net = (int16_t)htons(val); - PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int16_t)); -} - -static void writeI32(PyObject* outbuf, int32_t val) { - int32_t net = (int32_t)htonl(val); - PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int32_t)); -} - -static void writeI64(PyObject* outbuf, int64_t val) { - int64_t net = (int64_t)htonll(val); - PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int64_t)); -} - -static void writeDouble(PyObject* outbuf, double dub) { - // Unfortunately, bitwise_cast doesn't work in C. Bad C! - union { - double f; - int64_t t; - } transfer; - transfer.f = dub; - writeI64(outbuf, transfer.t); -} - - -/* --- MAIN RECURSIVE OUTPUT FUCNTION -- */ - -static int -output_val(PyObject* output, PyObject* value, TType type, PyObject* typeargs) { - /* - * Refcounting Strategy: - * - * We assume that elements of the thrift_spec tuple are not going to be - * mutated, so we don't ref count those at all. Other than that, we try to - * keep a reference to all the user-created objects while we work with them. - * output_val assumes that a reference is already held. 
The *caller* is - * responsible for handling references - */ - - switch (type) { - - case T_BOOL: { - int v = PyObject_IsTrue(value); - if (v == -1) { - return false; - } - - writeByte(output, (int8_t) v); - break; - } - case T_I08: { - int32_t val; - - if (!parse_pyint(value, &val, INT8_MIN, INT8_MAX)) { - return false; - } - - writeByte(output, (int8_t) val); - break; - } - case T_I16: { - int32_t val; - - if (!parse_pyint(value, &val, INT16_MIN, INT16_MAX)) { - return false; - } - - writeI16(output, (int16_t) val); - break; - } - case T_I32: { - int32_t val; - - if (!parse_pyint(value, &val, INT32_MIN, INT32_MAX)) { - return false; - } - - writeI32(output, val); - break; - } - case T_I64: { - int64_t nval = PyLong_AsLongLong(value); - - if (INT_CONV_ERROR_OCCURRED(nval)) { - return false; - } - - if (!CHECK_RANGE(nval, INT64_MIN, INT64_MAX)) { - PyErr_SetString(PyExc_OverflowError, "int out of range"); - return false; - } - - writeI64(output, nval); - break; - } - - case T_DOUBLE: { - double nval = PyFloat_AsDouble(value); - if (nval == -1.0 && PyErr_Occurred()) { - return false; - } - - writeDouble(output, nval); - break; - } - - case T_STRING: { - Py_ssize_t len = PyString_Size(value); - - if (!check_ssize_t_32(len)) { - return false; - } - - writeI32(output, (int32_t) len); - PycStringIO->cwrite(output, PyString_AsString(value), (int32_t) len); - break; - } - - case T_LIST: - case T_SET: { - Py_ssize_t len; - SetListTypeArgs parsedargs; - PyObject *item; - PyObject *iterator; - - if (!parse_set_list_args(&parsedargs, typeargs)) { - return false; - } - - len = PyObject_Length(value); - - if (!check_ssize_t_32(len)) { - return false; - } - - writeByte(output, parsedargs.element_type); - writeI32(output, (int32_t) len); - - iterator = PyObject_GetIter(value); - if (iterator == NULL) { - return false; - } - - while ((item = PyIter_Next(iterator))) { - if (!output_val(output, item, parsedargs.element_type, parsedargs.typeargs)) { - Py_DECREF(item); - 
Py_DECREF(iterator); - return false; - } - Py_DECREF(item); - } - - Py_DECREF(iterator); - - if (PyErr_Occurred()) { - return false; - } - - break; - } - - case T_MAP: { - PyObject *k, *v; - Py_ssize_t pos = 0; - Py_ssize_t len; - - MapTypeArgs parsedargs; - - len = PyDict_Size(value); - if (!check_ssize_t_32(len)) { - return false; - } - - if (!parse_map_args(&parsedargs, typeargs)) { - return false; - } - - writeByte(output, parsedargs.ktag); - writeByte(output, parsedargs.vtag); - writeI32(output, len); - - // TODO(bmaurer): should support any mapping, not just dicts - while (PyDict_Next(value, &pos, &k, &v)) { - // TODO(dreiss): Think hard about whether these INCREFs actually - // turn any unsafe scenarios into safe scenarios. - Py_INCREF(k); - Py_INCREF(v); - - if (!output_val(output, k, parsedargs.ktag, parsedargs.ktypeargs) - || !output_val(output, v, parsedargs.vtag, parsedargs.vtypeargs)) { - Py_DECREF(k); - Py_DECREF(v); - return false; - } - Py_DECREF(k); - Py_DECREF(v); - } - break; - } - - // TODO(dreiss): Consider breaking this out as a function - // the way we did for decode_struct. 
- case T_STRUCT: { - StructTypeArgs parsedargs; - Py_ssize_t nspec; - Py_ssize_t i; - - if (!parse_struct_args(&parsedargs, typeargs)) { - return false; - } - - nspec = PyTuple_Size(parsedargs.spec); - - if (nspec == -1) { - return false; - } - - for (i = 0; i < nspec; i++) { - StructItemSpec parsedspec; - PyObject* spec_tuple; - PyObject* instval = NULL; - - spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i); - if (spec_tuple == Py_None) { - continue; - } - - if (!parse_struct_item_spec (&parsedspec, spec_tuple)) { - return false; - } - - instval = PyObject_GetAttr(value, parsedspec.attrname); - - if (!instval) { - return false; - } - - if (instval == Py_None) { - Py_DECREF(instval); - continue; - } - - writeByte(output, (int8_t) parsedspec.type); - writeI16(output, parsedspec.tag); - - if (!output_val(output, instval, parsedspec.type, parsedspec.typeargs)) { - Py_DECREF(instval); - return false; - } - - Py_DECREF(instval); - } - - writeByte(output, (int8_t)T_STOP); - break; - } - - case T_STOP: - case T_VOID: - case T_UTF16: - case T_UTF8: - case T_U64: - default: - PyErr_SetString(PyExc_TypeError, "Unexpected TType"); - return false; - - } - - return true; -} - - -/* --- TOP-LEVEL WRAPPER FOR OUTPUT -- */ - -static PyObject * -encode_binary(PyObject *self, PyObject *args) { - PyObject* enc_obj; - PyObject* type_args; - PyObject* buf; - PyObject* ret = NULL; - - if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) { - return NULL; - } - - buf = PycStringIO->NewOutput(INIT_OUTBUF_SIZE); - if (output_val(buf, enc_obj, T_STRUCT, type_args)) { - ret = PycStringIO->cgetvalue(buf); - } - - Py_DECREF(buf); - return ret; -} - -/* ====== END WRITING FUNCTIONS ====== */ - - -/* ====== BEGIN READING FUNCTIONS ====== */ - -/* --- LOW-LEVEL READING FUNCTIONS --- */ - -static void -free_decodebuf(DecodeBuffer* d) { - Py_XDECREF(d->stringiobuf); - Py_XDECREF(d->refill_callable); -} - -static bool -decode_buffer_from_obj(DecodeBuffer* dest, PyObject* obj) { - 
dest->stringiobuf = PyObject_GetAttr(obj, INTERN_STRING(cstringio_buf)); - if (!dest->stringiobuf) { - return false; - } - - if (!PycStringIO_InputCheck(dest->stringiobuf)) { - free_decodebuf(dest); - PyErr_SetString(PyExc_TypeError, "expecting stringio input"); - return false; - } - - dest->refill_callable = PyObject_GetAttr(obj, INTERN_STRING(cstringio_refill)); - - if(!dest->refill_callable) { - free_decodebuf(dest); - return false; - } - - if (!PyCallable_Check(dest->refill_callable)) { - free_decodebuf(dest); - PyErr_SetString(PyExc_TypeError, "expecting callable"); - return false; - } - - return true; -} - -static bool readBytes(DecodeBuffer* input, char** output, int len) { - int read; - - // TODO(dreiss): Don't fear the malloc. Think about taking a copy of - // the partial read instead of forcing the transport - // to prepend it to its buffer. - - read = PycStringIO->cread(input->stringiobuf, output, len); - - if (read == len) { - return true; - } else if (read == -1) { - return false; - } else { - PyObject* newiobuf; - - // using building functions as this is a rare codepath - newiobuf = PyObject_CallFunction( - input->refill_callable, "s#i", *output, read, len, NULL); - if (newiobuf == NULL) { - return false; - } - - // must do this *AFTER* the call so that we don't deref the io buffer - Py_CLEAR(input->stringiobuf); - input->stringiobuf = newiobuf; - - read = PycStringIO->cread(input->stringiobuf, output, len); - - if (read == len) { - return true; - } else if (read == -1) { - return false; - } else { - // TODO(dreiss): This could be a valid code path for big binary blobs. 
- PyErr_SetString(PyExc_TypeError, - "refill claimed to have refilled the buffer, but didn't!!"); - return false; - } - } -} - -static int8_t readByte(DecodeBuffer* input) { - char* buf; - if (!readBytes(input, &buf, sizeof(int8_t))) { - return -1; - } - - return *(int8_t*) buf; -} - -static int16_t readI16(DecodeBuffer* input) { - char* buf; - if (!readBytes(input, &buf, sizeof(int16_t))) { - return -1; - } - - return (int16_t) ntohs(*(int16_t*) buf); -} - -static int32_t readI32(DecodeBuffer* input) { - char* buf; - if (!readBytes(input, &buf, sizeof(int32_t))) { - return -1; - } - return (int32_t) ntohl(*(int32_t*) buf); -} - - -static int64_t readI64(DecodeBuffer* input) { - char* buf; - if (!readBytes(input, &buf, sizeof(int64_t))) { - return -1; - } - - return (int64_t) ntohll(*(int64_t*) buf); -} - -static double readDouble(DecodeBuffer* input) { - union { - int64_t f; - double t; - } transfer; - - transfer.f = readI64(input); - if (transfer.f == -1) { - return -1; - } - return transfer.t; -} - -static bool -checkTypeByte(DecodeBuffer* input, TType expected) { - TType got = readByte(input); - if (INT_CONV_ERROR_OCCURRED(got)) { - return false; - } - - if (expected != got) { - PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field"); - return false; - } - return true; -} - -static bool -skip(DecodeBuffer* input, TType type) { -#define SKIPBYTES(n) \ - do { \ - if (!readBytes(input, &dummy_buf, (n))) { \ - return false; \ - } \ - } while(0) - - char* dummy_buf; - - switch (type) { - - case T_BOOL: - case T_I08: SKIPBYTES(1); break; - case T_I16: SKIPBYTES(2); break; - case T_I32: SKIPBYTES(4); break; - case T_I64: - case T_DOUBLE: SKIPBYTES(8); break; - - case T_STRING: { - // TODO(dreiss): Find out if these check_ssize_t32s are really necessary. 
- int len = readI32(input); - if (!check_ssize_t_32(len)) { - return false; - } - SKIPBYTES(len); - break; - } - - case T_LIST: - case T_SET: { - TType etype; - int len, i; - - etype = readByte(input); - if (etype == -1) { - return false; - } - - len = readI32(input); - if (!check_ssize_t_32(len)) { - return false; - } - - for (i = 0; i < len; i++) { - if (!skip(input, etype)) { - return false; - } - } - break; - } - - case T_MAP: { - TType ktype, vtype; - int len, i; - - ktype = readByte(input); - if (ktype == -1) { - return false; - } - - vtype = readByte(input); - if (vtype == -1) { - return false; - } - - len = readI32(input); - if (!check_ssize_t_32(len)) { - return false; - } - - for (i = 0; i < len; i++) { - if (!(skip(input, ktype) && skip(input, vtype))) { - return false; - } - } - break; - } - - case T_STRUCT: { - while (true) { - TType type; - - type = readByte(input); - if (type == -1) { - return false; - } - - if (type == T_STOP) - break; - - SKIPBYTES(2); // tag - if (!skip(input, type)) { - return false; - } - } - break; - } - - case T_STOP: - case T_VOID: - case T_UTF16: - case T_UTF8: - case T_U64: - default: - PyErr_SetString(PyExc_TypeError, "Unexpected TType"); - return false; - - } - - return true; - -#undef SKIPBYTES -} - - -/* --- HELPER FUNCTION FOR DECODE_VAL --- */ - -static PyObject* -decode_val(DecodeBuffer* input, TType type, PyObject* typeargs); - -static bool -decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) { - int spec_seq_len = PyTuple_Size(spec_seq); - if (spec_seq_len == -1) { - return false; - } - - while (true) { - TType type; - int16_t tag; - PyObject* item_spec; - PyObject* fieldval = NULL; - StructItemSpec parsedspec; - - type = readByte(input); - if (type == -1) { - return false; - } - if (type == T_STOP) { - break; - } - tag = readI16(input); - if (INT_CONV_ERROR_OCCURRED(tag)) { - return false; - } - if (tag >= 0 && tag < spec_seq_len) { - item_spec = PyTuple_GET_ITEM(spec_seq, tag); - } else { - 
item_spec = Py_None; - } - - if (item_spec == Py_None) { - if (!skip(input, type)) { - return false; - } else { - continue; - } - } - - if (!parse_struct_item_spec(&parsedspec, item_spec)) { - return false; - } - if (parsedspec.type != type) { - if (!skip(input, type)) { - PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped"); - return false; - } else { - continue; - } - } - - fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs); - if (fieldval == NULL) { - return false; - } - - if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) { - Py_DECREF(fieldval); - return false; - } - Py_DECREF(fieldval); - } - return true; -} - - -/* --- MAIN RECURSIVE INPUT FUCNTION --- */ - -// Returns a new reference. -static PyObject* -decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { - switch (type) { - - case T_BOOL: { - int8_t v = readByte(input); - if (INT_CONV_ERROR_OCCURRED(v)) { - return NULL; - } - - switch (v) { - case 0: Py_RETURN_FALSE; - case 1: Py_RETURN_TRUE; - // Don't laugh. This is a potentially serious issue. - default: PyErr_SetString(PyExc_TypeError, "boolean out of range"); return NULL; - } - break; - } - case T_I08: { - int8_t v = readByte(input); - if (INT_CONV_ERROR_OCCURRED(v)) { - return NULL; - } - - return PyInt_FromLong(v); - } - case T_I16: { - int16_t v = readI16(input); - if (INT_CONV_ERROR_OCCURRED(v)) { - return NULL; - } - return PyInt_FromLong(v); - } - case T_I32: { - int32_t v = readI32(input); - if (INT_CONV_ERROR_OCCURRED(v)) { - return NULL; - } - return PyInt_FromLong(v); - } - - case T_I64: { - int64_t v = readI64(input); - if (INT_CONV_ERROR_OCCURRED(v)) { - return NULL; - } - // TODO(dreiss): Find out if we can take this fastpath always when - // sizeof(long) == sizeof(long long). 
- if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) { - return PyInt_FromLong((long) v); - } - - return PyLong_FromLongLong(v); - } - - case T_DOUBLE: { - double v = readDouble(input); - if (v == -1.0 && PyErr_Occurred()) { - return false; - } - return PyFloat_FromDouble(v); - } - - case T_STRING: { - Py_ssize_t len = readI32(input); - char* buf; - if (!readBytes(input, &buf, len)) { - return NULL; - } - - return PyString_FromStringAndSize(buf, len); - } - - case T_LIST: - case T_SET: { - SetListTypeArgs parsedargs; - int32_t len; - PyObject* ret = NULL; - int i; - - if (!parse_set_list_args(&parsedargs, typeargs)) { - return NULL; - } - - if (!checkTypeByte(input, parsedargs.element_type)) { - return NULL; - } - - len = readI32(input); - if (!check_ssize_t_32(len)) { - return NULL; - } - - ret = PyList_New(len); - if (!ret) { - return NULL; - } - - for (i = 0; i < len; i++) { - PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs); - if (!item) { - Py_DECREF(ret); - return NULL; - } - PyList_SET_ITEM(ret, i, item); - } - - // TODO(dreiss): Consider biting the bullet and making two separate cases - // for list and set, avoiding this post facto conversion. 
- if (type == T_SET) { - PyObject* setret; -#if (PY_VERSION_HEX < 0x02050000) - // hack needed for older versions - setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL); -#else - // official version - setret = PySet_New(ret); -#endif - Py_DECREF(ret); - return setret; - } - return ret; - } - - case T_MAP: { - int32_t len; - int i; - MapTypeArgs parsedargs; - PyObject* ret = NULL; - - if (!parse_map_args(&parsedargs, typeargs)) { - return NULL; - } - - if (!checkTypeByte(input, parsedargs.ktag)) { - return NULL; - } - if (!checkTypeByte(input, parsedargs.vtag)) { - return NULL; - } - - len = readI32(input); - if (!check_ssize_t_32(len)) { - return false; - } - - ret = PyDict_New(); - if (!ret) { - goto error; - } - - for (i = 0; i < len; i++) { - PyObject* k = NULL; - PyObject* v = NULL; - k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs); - if (k == NULL) { - goto loop_error; - } - v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs); - if (v == NULL) { - goto loop_error; - } - if (PyDict_SetItem(ret, k, v) == -1) { - goto loop_error; - } - - Py_DECREF(k); - Py_DECREF(v); - continue; - - // Yuck! Destructors, anyone? 
- loop_error: - Py_XDECREF(k); - Py_XDECREF(v); - goto error; - } - - return ret; - - error: - Py_XDECREF(ret); - return NULL; - } - - case T_STRUCT: { - StructTypeArgs parsedargs; - PyObject* ret; - if (!parse_struct_args(&parsedargs, typeargs)) { - return NULL; - } - - ret = PyObject_CallObject(parsedargs.klass, NULL); - if (!ret) { - return NULL; - } - - if (!decode_struct(input, ret, parsedargs.spec)) { - Py_DECREF(ret); - return NULL; - } - - return ret; - } - - case T_STOP: - case T_VOID: - case T_UTF16: - case T_UTF8: - case T_U64: - default: - PyErr_SetString(PyExc_TypeError, "Unexpected TType"); - return NULL; - } -} - - -/* --- TOP-LEVEL WRAPPER FOR INPUT -- */ - -static PyObject* -decode_binary(PyObject *self, PyObject *args) { - PyObject* output_obj = NULL; - PyObject* transport = NULL; - PyObject* typeargs = NULL; - StructTypeArgs parsedargs; - DecodeBuffer input = {0, 0}; - - if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) { - return NULL; - } - - if (!parse_struct_args(&parsedargs, typeargs)) { - return NULL; - } - - if (!decode_buffer_from_obj(&input, transport)) { - return NULL; - } - - if (!decode_struct(&input, output_obj, parsedargs.spec)) { - free_decodebuf(&input); - return NULL; - } - - free_decodebuf(&input); - - Py_RETURN_NONE; -} - -/* ====== END READING FUNCTIONS ====== */ - - -/* -- PYTHON MODULE SETUP STUFF --- */ - -static PyMethodDef ThriftFastBinaryMethods[] = { - - {"encode_binary", encode_binary, METH_VARARGS, ""}, - {"decode_binary", decode_binary, METH_VARARGS, ""}, - - {NULL, NULL, 0, NULL} /* Sentinel */ -}; - -PyMODINIT_FUNC -initfastbinary(void) { -#define INIT_INTERN_STRING(value) \ - do { \ - INTERN_STRING(value) = PyString_InternFromString(#value); \ - if(!INTERN_STRING(value)) return; \ - } while(0) - - INIT_INTERN_STRING(cstringio_buf); - INIT_INTERN_STRING(cstringio_refill); -#undef INIT_INTERN_STRING - - PycString_IMPORT; - if (PycStringIO == NULL) return; - - (void) 
Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods); -} diff --git a/anknotes/thrift/server/THttpServer.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/server/THttpServer.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 3047d9c..0000000 --- a/anknotes/thrift/server/THttpServer.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,82 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import BaseHTTPServer - -from thrift.server import TServer -from thrift.transport import TTransport - -class ResponseException(Exception): - """Allows handlers to override the HTTP response - - Normally, THttpServer always sends a 200 response. If a handler wants - to override this behavior (e.g., to simulate a misconfigured or - overloaded web server during testing), it can raise a ResponseException. - The function passed to the constructor will be called with the - RequestHandler as its only argument. 
- """ - def __init__(self, handler): - self.handler = handler - - -class THttpServer(TServer.TServer): - """A simple HTTP-based Thrift server - - This class is not very performant, but it is useful (for example) for - acting as a mock version of an Apache-based PHP Thrift endpoint.""" - - def __init__(self, processor, server_address, - inputProtocolFactory, outputProtocolFactory = None, - server_class = BaseHTTPServer.HTTPServer): - """Set up protocol factories and HTTP server. - - See BaseHTTPServer for server_address. - See TServer for protocol factories.""" - - if outputProtocolFactory is None: - outputProtocolFactory = inputProtocolFactory - - TServer.TServer.__init__(self, processor, None, None, None, - inputProtocolFactory, outputProtocolFactory) - - thttpserver = self - - class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler): - def do_POST(self): - # Don't care about the request path. - itrans = TTransport.TFileObjectTransport(self.rfile) - otrans = TTransport.TFileObjectTransport(self.wfile) - itrans = TTransport.TBufferedTransport(itrans, int(self.headers['Content-Length'])) - otrans = TTransport.TMemoryBuffer() - iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) - oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) - try: - thttpserver.processor.process(iprot, oprot) - except ResponseException, exn: - exn.handler(self) - else: - self.send_response(200) - self.send_header("content-type", "application/x-thrift") - self.end_headers() - self.wfile.write(otrans.getvalue()) - - self.httpd = server_class(server_address, RequestHander) - - def serve(self): - self.httpd.serve_forever() diff --git a/anknotes/thrift/server/THttpServer.py~HEAD b/anknotes/thrift/server/THttpServer.py~HEAD deleted file mode 100644 index 3047d9c..0000000 --- a/anknotes/thrift/server/THttpServer.py~HEAD +++ /dev/null @@ -1,82 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import BaseHTTPServer - -from thrift.server import TServer -from thrift.transport import TTransport - -class ResponseException(Exception): - """Allows handlers to override the HTTP response - - Normally, THttpServer always sends a 200 response. If a handler wants - to override this behavior (e.g., to simulate a misconfigured or - overloaded web server during testing), it can raise a ResponseException. - The function passed to the constructor will be called with the - RequestHandler as its only argument. - """ - def __init__(self, handler): - self.handler = handler - - -class THttpServer(TServer.TServer): - """A simple HTTP-based Thrift server - - This class is not very performant, but it is useful (for example) for - acting as a mock version of an Apache-based PHP Thrift endpoint.""" - - def __init__(self, processor, server_address, - inputProtocolFactory, outputProtocolFactory = None, - server_class = BaseHTTPServer.HTTPServer): - """Set up protocol factories and HTTP server. - - See BaseHTTPServer for server_address. 
- See TServer for protocol factories.""" - - if outputProtocolFactory is None: - outputProtocolFactory = inputProtocolFactory - - TServer.TServer.__init__(self, processor, None, None, None, - inputProtocolFactory, outputProtocolFactory) - - thttpserver = self - - class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler): - def do_POST(self): - # Don't care about the request path. - itrans = TTransport.TFileObjectTransport(self.rfile) - otrans = TTransport.TFileObjectTransport(self.wfile) - itrans = TTransport.TBufferedTransport(itrans, int(self.headers['Content-Length'])) - otrans = TTransport.TMemoryBuffer() - iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) - oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) - try: - thttpserver.processor.process(iprot, oprot) - except ResponseException, exn: - exn.handler(self) - else: - self.send_response(200) - self.send_header("content-type", "application/x-thrift") - self.end_headers() - self.wfile.write(otrans.getvalue()) - - self.httpd = server_class(server_address, RequestHander) - - def serve(self): - self.httpd.serve_forever() diff --git a/anknotes/thrift/server/TNonblockingServer.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/server/TNonblockingServer.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index ea348a0..0000000 --- a/anknotes/thrift/server/TNonblockingServer.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,310 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -"""Implementation of non-blocking server. - -The main idea of the server is reciving and sending requests -only from main thread. - -It also makes thread pool server in tasks terms, not connections. -""" -import threading -import socket -import Queue -import select -import struct -import logging - -from thrift.transport import TTransport -from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory - -__all__ = ['TNonblockingServer'] - -class Worker(threading.Thread): - """Worker is a small helper to process incoming connection.""" - def __init__(self, queue): - threading.Thread.__init__(self) - self.queue = queue - - def run(self): - """Process queries from task queue, stop if processor is None.""" - while True: - try: - processor, iprot, oprot, otrans, callback = self.queue.get() - if processor is None: - break - processor.process(iprot, oprot) - callback(True, otrans.getvalue()) - except Exception: - logging.exception("Exception while processing request") - callback(False, '') - -WAIT_LEN = 0 -WAIT_MESSAGE = 1 -WAIT_PROCESS = 2 -SEND_ANSWER = 3 -CLOSED = 4 - -def locked(func): - "Decorator which locks self.lock." - def nested(self, *args, **kwargs): - self.lock.acquire() - try: - return func(self, *args, **kwargs) - finally: - self.lock.release() - return nested - -def socket_exception(func): - "Decorator close object on socket.error." - def read(self, *args, **kwargs): - try: - return func(self, *args, **kwargs) - except socket.error: - self.close() - return read - -class Connection: - """Basic class is represented connection. 
- - It can be in state: - WAIT_LEN --- connection is reading request len. - WAIT_MESSAGE --- connection is reading request. - WAIT_PROCESS --- connection has just read whole request and - waits for call ready routine. - SEND_ANSWER --- connection is sending answer string (including length - of answer). - CLOSED --- socket was closed and connection should be deleted. - """ - def __init__(self, new_socket, wake_up): - self.socket = new_socket - self.socket.setblocking(False) - self.status = WAIT_LEN - self.len = 0 - self.message = '' - self.lock = threading.Lock() - self.wake_up = wake_up - - def _read_len(self): - """Reads length of request. - - It's really paranoic routine and it may be replaced by - self.socket.recv(4).""" - read = self.socket.recv(4 - len(self.message)) - if len(read) == 0: - # if we read 0 bytes and self.message is empty, it means client close - # connection - if len(self.message) != 0: - logging.error("can't read frame size from socket") - self.close() - return - self.message += read - if len(self.message) == 4: - self.len, = struct.unpack('!i', self.message) - if self.len < 0: - logging.error("negative frame size, it seems client"\ - " doesn't use FramedTransport") - self.close() - elif self.len == 0: - logging.error("empty frame, it's really strange") - self.close() - else: - self.message = '' - self.status = WAIT_MESSAGE - - @socket_exception - def read(self): - """Reads data from stream and switch state.""" - assert self.status in (WAIT_LEN, WAIT_MESSAGE) - if self.status == WAIT_LEN: - self._read_len() - # go back to the main loop here for simplicity instead of - # falling through, even though there is a good chance that - # the message is already available - elif self.status == WAIT_MESSAGE: - read = self.socket.recv(self.len - len(self.message)) - if len(read) == 0: - logging.error("can't read frame from socket (get %d of %d bytes)" % - (len(self.message), self.len)) - self.close() - return - self.message += read - if len(self.message) 
== self.len: - self.status = WAIT_PROCESS - - @socket_exception - def write(self): - """Writes data from socket and switch state.""" - assert self.status == SEND_ANSWER - sent = self.socket.send(self.message) - if sent == len(self.message): - self.status = WAIT_LEN - self.message = '' - self.len = 0 - else: - self.message = self.message[sent:] - - @locked - def ready(self, all_ok, message): - """Callback function for switching state and waking up main thread. - - This function is the only function witch can be called asynchronous. - - The ready can switch Connection to three states: - WAIT_LEN if request was oneway. - SEND_ANSWER if request was processed in normal way. - CLOSED if request throws unexpected exception. - - The one wakes up main thread. - """ - assert self.status == WAIT_PROCESS - if not all_ok: - self.close() - self.wake_up() - return - self.len = '' - if len(message) == 0: - # it was a oneway request, do not write answer - self.message = '' - self.status = WAIT_LEN - else: - self.message = struct.pack('!i', len(message)) + message - self.status = SEND_ANSWER - self.wake_up() - - @locked - def is_writeable(self): - "Returns True if connection should be added to write list of select." - return self.status == SEND_ANSWER - - # it's not necessary, but... - @locked - def is_readable(self): - "Returns True if connection should be added to read list of select." - return self.status in (WAIT_LEN, WAIT_MESSAGE) - - @locked - def is_closed(self): - "Returns True if connection is closed." - return self.status == CLOSED - - def fileno(self): - "Returns the file descriptor of the associated socket." 
- return self.socket.fileno() - - def close(self): - "Closes connection" - self.status = CLOSED - self.socket.close() - -class TNonblockingServer: - """Non-blocking server.""" - def __init__(self, processor, lsocket, inputProtocolFactory=None, - outputProtocolFactory=None, threads=10): - self.processor = processor - self.socket = lsocket - self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory() - self.out_protocol = outputProtocolFactory or self.in_protocol - self.threads = int(threads) - self.clients = {} - self.tasks = Queue.Queue() - self._read, self._write = socket.socketpair() - self.prepared = False - - def setNumThreads(self, num): - """Set the number of worker threads that should be created.""" - # implement ThreadPool interface - assert not self.prepared, "You can't change number of threads for working server" - self.threads = num - - def prepare(self): - """Prepares server for serve requests.""" - self.socket.listen() - for _ in xrange(self.threads): - thread = Worker(self.tasks) - thread.setDaemon(True) - thread.start() - self.prepared = True - - def wake_up(self): - """Wake up main thread. - - The server usualy waits in select call in we should terminate one. - The simplest way is using socketpair. - - Select always wait to read from the first socket of socketpair. - - In this case, we can just write anything to the second socket from - socketpair.""" - self._write.send('1') - - def _select(self): - """Does select on open connections.""" - readable = [self.socket.handle.fileno(), self._read.fileno()] - writable = [] - for i, connection in self.clients.items(): - if connection.is_readable(): - readable.append(connection.fileno()) - if connection.is_writeable(): - writable.append(connection.fileno()) - if connection.is_closed(): - del self.clients[i] - return select.select(readable, writable, readable) - - def handle(self): - """Handle requests. - - WARNING! You must call prepare BEFORE calling handle. 
- """ - assert self.prepared, "You have to call prepare before handle" - rset, wset, xset = self._select() - for readable in rset: - if readable == self._read.fileno(): - # don't care i just need to clean readable flag - self._read.recv(1024) - elif readable == self.socket.handle.fileno(): - client = self.socket.accept().handle - self.clients[client.fileno()] = Connection(client, self.wake_up) - else: - connection = self.clients[readable] - connection.read() - if connection.status == WAIT_PROCESS: - itransport = TTransport.TMemoryBuffer(connection.message) - otransport = TTransport.TMemoryBuffer() - iprot = self.in_protocol.getProtocol(itransport) - oprot = self.out_protocol.getProtocol(otransport) - self.tasks.put([self.processor, iprot, oprot, - otransport, connection.ready]) - for writeable in wset: - self.clients[writeable].write() - for oob in xset: - self.clients[oob].close() - del self.clients[oob] - - def close(self): - """Closes the server.""" - for _ in xrange(self.threads): - self.tasks.put([None, None, None, None, None]) - self.socket.close() - self.prepared = False - - def serve(self): - """Serve forever.""" - self.prepare() - while True: - self.handle() diff --git a/anknotes/thrift/server/TNonblockingServer.py~HEAD b/anknotes/thrift/server/TNonblockingServer.py~HEAD deleted file mode 100644 index ea348a0..0000000 --- a/anknotes/thrift/server/TNonblockingServer.py~HEAD +++ /dev/null @@ -1,310 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -"""Implementation of non-blocking server. - -The main idea of the server is reciving and sending requests -only from main thread. - -It also makes thread pool server in tasks terms, not connections. -""" -import threading -import socket -import Queue -import select -import struct -import logging - -from thrift.transport import TTransport -from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory - -__all__ = ['TNonblockingServer'] - -class Worker(threading.Thread): - """Worker is a small helper to process incoming connection.""" - def __init__(self, queue): - threading.Thread.__init__(self) - self.queue = queue - - def run(self): - """Process queries from task queue, stop if processor is None.""" - while True: - try: - processor, iprot, oprot, otrans, callback = self.queue.get() - if processor is None: - break - processor.process(iprot, oprot) - callback(True, otrans.getvalue()) - except Exception: - logging.exception("Exception while processing request") - callback(False, '') - -WAIT_LEN = 0 -WAIT_MESSAGE = 1 -WAIT_PROCESS = 2 -SEND_ANSWER = 3 -CLOSED = 4 - -def locked(func): - "Decorator which locks self.lock." - def nested(self, *args, **kwargs): - self.lock.acquire() - try: - return func(self, *args, **kwargs) - finally: - self.lock.release() - return nested - -def socket_exception(func): - "Decorator close object on socket.error." - def read(self, *args, **kwargs): - try: - return func(self, *args, **kwargs) - except socket.error: - self.close() - return read - -class Connection: - """Basic class is represented connection. 
- - It can be in state: - WAIT_LEN --- connection is reading request len. - WAIT_MESSAGE --- connection is reading request. - WAIT_PROCESS --- connection has just read whole request and - waits for call ready routine. - SEND_ANSWER --- connection is sending answer string (including length - of answer). - CLOSED --- socket was closed and connection should be deleted. - """ - def __init__(self, new_socket, wake_up): - self.socket = new_socket - self.socket.setblocking(False) - self.status = WAIT_LEN - self.len = 0 - self.message = '' - self.lock = threading.Lock() - self.wake_up = wake_up - - def _read_len(self): - """Reads length of request. - - It's really paranoic routine and it may be replaced by - self.socket.recv(4).""" - read = self.socket.recv(4 - len(self.message)) - if len(read) == 0: - # if we read 0 bytes and self.message is empty, it means client close - # connection - if len(self.message) != 0: - logging.error("can't read frame size from socket") - self.close() - return - self.message += read - if len(self.message) == 4: - self.len, = struct.unpack('!i', self.message) - if self.len < 0: - logging.error("negative frame size, it seems client"\ - " doesn't use FramedTransport") - self.close() - elif self.len == 0: - logging.error("empty frame, it's really strange") - self.close() - else: - self.message = '' - self.status = WAIT_MESSAGE - - @socket_exception - def read(self): - """Reads data from stream and switch state.""" - assert self.status in (WAIT_LEN, WAIT_MESSAGE) - if self.status == WAIT_LEN: - self._read_len() - # go back to the main loop here for simplicity instead of - # falling through, even though there is a good chance that - # the message is already available - elif self.status == WAIT_MESSAGE: - read = self.socket.recv(self.len - len(self.message)) - if len(read) == 0: - logging.error("can't read frame from socket (get %d of %d bytes)" % - (len(self.message), self.len)) - self.close() - return - self.message += read - if len(self.message) 
== self.len: - self.status = WAIT_PROCESS - - @socket_exception - def write(self): - """Writes data from socket and switch state.""" - assert self.status == SEND_ANSWER - sent = self.socket.send(self.message) - if sent == len(self.message): - self.status = WAIT_LEN - self.message = '' - self.len = 0 - else: - self.message = self.message[sent:] - - @locked - def ready(self, all_ok, message): - """Callback function for switching state and waking up main thread. - - This function is the only function witch can be called asynchronous. - - The ready can switch Connection to three states: - WAIT_LEN if request was oneway. - SEND_ANSWER if request was processed in normal way. - CLOSED if request throws unexpected exception. - - The one wakes up main thread. - """ - assert self.status == WAIT_PROCESS - if not all_ok: - self.close() - self.wake_up() - return - self.len = '' - if len(message) == 0: - # it was a oneway request, do not write answer - self.message = '' - self.status = WAIT_LEN - else: - self.message = struct.pack('!i', len(message)) + message - self.status = SEND_ANSWER - self.wake_up() - - @locked - def is_writeable(self): - "Returns True if connection should be added to write list of select." - return self.status == SEND_ANSWER - - # it's not necessary, but... - @locked - def is_readable(self): - "Returns True if connection should be added to read list of select." - return self.status in (WAIT_LEN, WAIT_MESSAGE) - - @locked - def is_closed(self): - "Returns True if connection is closed." - return self.status == CLOSED - - def fileno(self): - "Returns the file descriptor of the associated socket." 
- return self.socket.fileno() - - def close(self): - "Closes connection" - self.status = CLOSED - self.socket.close() - -class TNonblockingServer: - """Non-blocking server.""" - def __init__(self, processor, lsocket, inputProtocolFactory=None, - outputProtocolFactory=None, threads=10): - self.processor = processor - self.socket = lsocket - self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory() - self.out_protocol = outputProtocolFactory or self.in_protocol - self.threads = int(threads) - self.clients = {} - self.tasks = Queue.Queue() - self._read, self._write = socket.socketpair() - self.prepared = False - - def setNumThreads(self, num): - """Set the number of worker threads that should be created.""" - # implement ThreadPool interface - assert not self.prepared, "You can't change number of threads for working server" - self.threads = num - - def prepare(self): - """Prepares server for serve requests.""" - self.socket.listen() - for _ in xrange(self.threads): - thread = Worker(self.tasks) - thread.setDaemon(True) - thread.start() - self.prepared = True - - def wake_up(self): - """Wake up main thread. - - The server usualy waits in select call in we should terminate one. - The simplest way is using socketpair. - - Select always wait to read from the first socket of socketpair. - - In this case, we can just write anything to the second socket from - socketpair.""" - self._write.send('1') - - def _select(self): - """Does select on open connections.""" - readable = [self.socket.handle.fileno(), self._read.fileno()] - writable = [] - for i, connection in self.clients.items(): - if connection.is_readable(): - readable.append(connection.fileno()) - if connection.is_writeable(): - writable.append(connection.fileno()) - if connection.is_closed(): - del self.clients[i] - return select.select(readable, writable, readable) - - def handle(self): - """Handle requests. - - WARNING! You must call prepare BEFORE calling handle. 
- """ - assert self.prepared, "You have to call prepare before handle" - rset, wset, xset = self._select() - for readable in rset: - if readable == self._read.fileno(): - # don't care i just need to clean readable flag - self._read.recv(1024) - elif readable == self.socket.handle.fileno(): - client = self.socket.accept().handle - self.clients[client.fileno()] = Connection(client, self.wake_up) - else: - connection = self.clients[readable] - connection.read() - if connection.status == WAIT_PROCESS: - itransport = TTransport.TMemoryBuffer(connection.message) - otransport = TTransport.TMemoryBuffer() - iprot = self.in_protocol.getProtocol(itransport) - oprot = self.out_protocol.getProtocol(otransport) - self.tasks.put([self.processor, iprot, oprot, - otransport, connection.ready]) - for writeable in wset: - self.clients[writeable].write() - for oob in xset: - self.clients[oob].close() - del self.clients[oob] - - def close(self): - """Closes the server.""" - for _ in xrange(self.threads): - self.tasks.put([None, None, None, None, None]) - self.socket.close() - self.prepared = False - - def serve(self): - """Serve forever.""" - self.prepare() - while True: - self.handle() diff --git a/anknotes/thrift/server/TProcessPoolServer.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/server/TProcessPoolServer.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 7ed814a..0000000 --- a/anknotes/thrift/server/TProcessPoolServer.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,125 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - - -import logging -from multiprocessing import Process, Value, Condition, reduction - -from TServer import TServer -from thrift.transport.TTransport import TTransportException - -class TProcessPoolServer(TServer): - - """ - Server with a fixed size pool of worker subprocesses which service requests. - Note that if you need shared state between the handlers - it's up to you! - Written by Dvir Volk, doat.com - """ - - def __init__(self, * args): - TServer.__init__(self, *args) - self.numWorkers = 10 - self.workers = [] - self.isRunning = Value('b', False) - self.stopCondition = Condition() - self.postForkCallback = None - - def setPostForkCallback(self, callback): - if not callable(callback): - raise TypeError("This is not a callback!") - self.postForkCallback = callback - - def setNumWorkers(self, num): - """Set the number of worker threads that should be created""" - self.numWorkers = num - - def workerProcess(self): - """Loop around getting clients from the shared queue and process them.""" - - if self.postForkCallback: - self.postForkCallback() - - while self.isRunning.value == True: - try: - client = self.serverTransport.accept() - self.serveClient(client) - except (KeyboardInterrupt, SystemExit): - return 0 - except Exception, x: - logging.exception(x) - - def serveClient(self, client): - """Process input/output from a client for as long as possible""" - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = 
self.outputProtocolFactory.getProtocol(otrans) - - try: - while True: - self.processor.process(iprot, oprot) - except TTransportException, tx: - pass - except Exception, x: - logging.exception(x) - - itrans.close() - otrans.close() - - - def serve(self): - """Start a fixed number of worker threads and put client into a queue""" - - #this is a shared state that can tell the workers to exit when set as false - self.isRunning.value = True - - #first bind and listen to the port - self.serverTransport.listen() - - #fork the children - for i in range(self.numWorkers): - try: - w = Process(target=self.workerProcess) - w.daemon = True - w.start() - self.workers.append(w) - except Exception, x: - logging.exception(x) - - #wait until the condition is set by stop() - - while True: - - self.stopCondition.acquire() - try: - self.stopCondition.wait() - break - except (SystemExit, KeyboardInterrupt): - break - except Exception, x: - logging.exception(x) - - self.isRunning.value = False - - def stop(self): - self.isRunning.value = False - self.stopCondition.acquire() - self.stopCondition.notify() - self.stopCondition.release() - diff --git a/anknotes/thrift/server/TProcessPoolServer.py~HEAD b/anknotes/thrift/server/TProcessPoolServer.py~HEAD deleted file mode 100644 index 7ed814a..0000000 --- a/anknotes/thrift/server/TProcessPoolServer.py~HEAD +++ /dev/null @@ -1,125 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - - -import logging -from multiprocessing import Process, Value, Condition, reduction - -from TServer import TServer -from thrift.transport.TTransport import TTransportException - -class TProcessPoolServer(TServer): - - """ - Server with a fixed size pool of worker subprocesses which service requests. - Note that if you need shared state between the handlers - it's up to you! - Written by Dvir Volk, doat.com - """ - - def __init__(self, * args): - TServer.__init__(self, *args) - self.numWorkers = 10 - self.workers = [] - self.isRunning = Value('b', False) - self.stopCondition = Condition() - self.postForkCallback = None - - def setPostForkCallback(self, callback): - if not callable(callback): - raise TypeError("This is not a callback!") - self.postForkCallback = callback - - def setNumWorkers(self, num): - """Set the number of worker threads that should be created""" - self.numWorkers = num - - def workerProcess(self): - """Loop around getting clients from the shared queue and process them.""" - - if self.postForkCallback: - self.postForkCallback() - - while self.isRunning.value == True: - try: - client = self.serverTransport.accept() - self.serveClient(client) - except (KeyboardInterrupt, SystemExit): - return 0 - except Exception, x: - logging.exception(x) - - def serveClient(self, client): - """Process input/output from a client for as long as possible""" - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = 
self.outputProtocolFactory.getProtocol(otrans) - - try: - while True: - self.processor.process(iprot, oprot) - except TTransportException, tx: - pass - except Exception, x: - logging.exception(x) - - itrans.close() - otrans.close() - - - def serve(self): - """Start a fixed number of worker threads and put client into a queue""" - - #this is a shared state that can tell the workers to exit when set as false - self.isRunning.value = True - - #first bind and listen to the port - self.serverTransport.listen() - - #fork the children - for i in range(self.numWorkers): - try: - w = Process(target=self.workerProcess) - w.daemon = True - w.start() - self.workers.append(w) - except Exception, x: - logging.exception(x) - - #wait until the condition is set by stop() - - while True: - - self.stopCondition.acquire() - try: - self.stopCondition.wait() - break - except (SystemExit, KeyboardInterrupt): - break - except Exception, x: - logging.exception(x) - - self.isRunning.value = False - - def stop(self): - self.isRunning.value = False - self.stopCondition.acquire() - self.stopCondition.notify() - self.stopCondition.release() - diff --git a/anknotes/thrift/server/TServer.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/server/TServer.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 8456e2d..0000000 --- a/anknotes/thrift/server/TServer.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,274 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import logging -import sys -import os -import traceback -import threading -import Queue - -from thrift.Thrift import TProcessor -from thrift.transport import TTransport -from thrift.protocol import TBinaryProtocol - -class TServer: - - """Base interface for a server, which must have a serve method.""" - - """ 3 constructors for all servers: - 1) (processor, serverTransport) - 2) (processor, serverTransport, transportFactory, protocolFactory) - 3) (processor, serverTransport, - inputTransportFactory, outputTransportFactory, - inputProtocolFactory, outputProtocolFactory)""" - def __init__(self, *args): - if (len(args) == 2): - self.__initArgs__(args[0], args[1], - TTransport.TTransportFactoryBase(), - TTransport.TTransportFactoryBase(), - TBinaryProtocol.TBinaryProtocolFactory(), - TBinaryProtocol.TBinaryProtocolFactory()) - elif (len(args) == 4): - self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3]) - elif (len(args) == 6): - self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) - - def __initArgs__(self, processor, serverTransport, - inputTransportFactory, outputTransportFactory, - inputProtocolFactory, outputProtocolFactory): - self.processor = processor - self.serverTransport = serverTransport - self.inputTransportFactory = inputTransportFactory - self.outputTransportFactory = outputTransportFactory - self.inputProtocolFactory = inputProtocolFactory - self.outputProtocolFactory = outputProtocolFactory - - def serve(self): - pass - -class TSimpleServer(TServer): - - """Simple single-threaded server that just 
pumps around one transport.""" - - def __init__(self, *args): - TServer.__init__(self, *args) - - def serve(self): - self.serverTransport.listen() - while True: - client = self.serverTransport.accept() - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - try: - while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException, tx: - pass - except Exception, x: - logging.exception(x) - - itrans.close() - otrans.close() - -class TThreadedServer(TServer): - - """Threaded server that spawns a new thread per each connection.""" - - def __init__(self, *args, **kwargs): - TServer.__init__(self, *args) - self.daemon = kwargs.get("daemon", False) - - def serve(self): - self.serverTransport.listen() - while True: - try: - client = self.serverTransport.accept() - t = threading.Thread(target = self.handle, args=(client,)) - t.setDaemon(self.daemon) - t.start() - except KeyboardInterrupt: - raise - except Exception, x: - logging.exception(x) - - def handle(self, client): - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - try: - while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException, tx: - pass - except Exception, x: - logging.exception(x) - - itrans.close() - otrans.close() - -class TThreadPoolServer(TServer): - - """Server with a fixed size pool of threads which service requests.""" - - def __init__(self, *args, **kwargs): - TServer.__init__(self, *args) - self.clients = Queue.Queue() - self.threads = 10 - self.daemon = kwargs.get("daemon", False) - - def setNumThreads(self, num): - """Set the number of worker threads that should be created""" - self.threads = 
num - - def serveThread(self): - """Loop around getting clients from the shared queue and process them.""" - while True: - try: - client = self.clients.get() - self.serveClient(client) - except Exception, x: - logging.exception(x) - - def serveClient(self, client): - """Process input/output from a client for as long as possible""" - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - try: - while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException, tx: - pass - except Exception, x: - logging.exception(x) - - itrans.close() - otrans.close() - - def serve(self): - """Start a fixed number of worker threads and put client into a queue""" - for i in range(self.threads): - try: - t = threading.Thread(target = self.serveThread) - t.setDaemon(self.daemon) - t.start() - except Exception, x: - logging.exception(x) - - # Pump the socket for clients - self.serverTransport.listen() - while True: - try: - client = self.serverTransport.accept() - self.clients.put(client) - except Exception, x: - logging.exception(x) - - -class TForkingServer(TServer): - - """A Thrift server that forks a new process for each request""" - """ - This is more scalable than the threaded server as it does not cause - GIL contention. - - Note that this has different semantics from the threading server. - Specifically, updates to shared variables will no longer be shared. - It will also not work on windows. - - This code is heavily inspired by SocketServer.ForkingMixIn in the - Python stdlib. 
- """ - - def __init__(self, *args): - TServer.__init__(self, *args) - self.children = [] - - def serve(self): - def try_close(file): - try: - file.close() - except IOError, e: - logging.warning(e, exc_info=True) - - - self.serverTransport.listen() - while True: - client = self.serverTransport.accept() - try: - pid = os.fork() - - if pid: # parent - # add before collect, otherwise you race w/ waitpid - self.children.append(pid) - self.collect_children() - - # Parent must close socket or the connection may not get - # closed promptly - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - try_close(itrans) - try_close(otrans) - else: - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - - ecode = 0 - try: - try: - while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException, tx: - pass - except Exception, e: - logging.exception(e) - ecode = 1 - finally: - try_close(itrans) - try_close(otrans) - - os._exit(ecode) - - except TTransport.TTransportException, tx: - pass - except Exception, x: - logging.exception(x) - - - def collect_children(self): - while self.children: - try: - pid, status = os.waitpid(0, os.WNOHANG) - except os.error: - pid = None - - if pid: - self.children.remove(pid) - else: - break - - diff --git a/anknotes/thrift/server/TServer.py~HEAD b/anknotes/thrift/server/TServer.py~HEAD deleted file mode 100644 index 8456e2d..0000000 --- a/anknotes/thrift/server/TServer.py~HEAD +++ /dev/null @@ -1,274 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import logging -import sys -import os -import traceback -import threading -import Queue - -from thrift.Thrift import TProcessor -from thrift.transport import TTransport -from thrift.protocol import TBinaryProtocol - -class TServer: - - """Base interface for a server, which must have a serve method.""" - - """ 3 constructors for all servers: - 1) (processor, serverTransport) - 2) (processor, serverTransport, transportFactory, protocolFactory) - 3) (processor, serverTransport, - inputTransportFactory, outputTransportFactory, - inputProtocolFactory, outputProtocolFactory)""" - def __init__(self, *args): - if (len(args) == 2): - self.__initArgs__(args[0], args[1], - TTransport.TTransportFactoryBase(), - TTransport.TTransportFactoryBase(), - TBinaryProtocol.TBinaryProtocolFactory(), - TBinaryProtocol.TBinaryProtocolFactory()) - elif (len(args) == 4): - self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3]) - elif (len(args) == 6): - self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) - - def __initArgs__(self, processor, serverTransport, - inputTransportFactory, outputTransportFactory, - inputProtocolFactory, outputProtocolFactory): - self.processor = processor - self.serverTransport = serverTransport - self.inputTransportFactory = inputTransportFactory - self.outputTransportFactory = outputTransportFactory - self.inputProtocolFactory = inputProtocolFactory - 
self.outputProtocolFactory = outputProtocolFactory - - def serve(self): - pass - -class TSimpleServer(TServer): - - """Simple single-threaded server that just pumps around one transport.""" - - def __init__(self, *args): - TServer.__init__(self, *args) - - def serve(self): - self.serverTransport.listen() - while True: - client = self.serverTransport.accept() - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - try: - while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException, tx: - pass - except Exception, x: - logging.exception(x) - - itrans.close() - otrans.close() - -class TThreadedServer(TServer): - - """Threaded server that spawns a new thread per each connection.""" - - def __init__(self, *args, **kwargs): - TServer.__init__(self, *args) - self.daemon = kwargs.get("daemon", False) - - def serve(self): - self.serverTransport.listen() - while True: - try: - client = self.serverTransport.accept() - t = threading.Thread(target = self.handle, args=(client,)) - t.setDaemon(self.daemon) - t.start() - except KeyboardInterrupt: - raise - except Exception, x: - logging.exception(x) - - def handle(self, client): - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - try: - while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException, tx: - pass - except Exception, x: - logging.exception(x) - - itrans.close() - otrans.close() - -class TThreadPoolServer(TServer): - - """Server with a fixed size pool of threads which service requests.""" - - def __init__(self, *args, **kwargs): - TServer.__init__(self, *args) - self.clients = Queue.Queue() - self.threads = 10 
- self.daemon = kwargs.get("daemon", False) - - def setNumThreads(self, num): - """Set the number of worker threads that should be created""" - self.threads = num - - def serveThread(self): - """Loop around getting clients from the shared queue and process them.""" - while True: - try: - client = self.clients.get() - self.serveClient(client) - except Exception, x: - logging.exception(x) - - def serveClient(self, client): - """Process input/output from a client for as long as possible""" - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - try: - while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException, tx: - pass - except Exception, x: - logging.exception(x) - - itrans.close() - otrans.close() - - def serve(self): - """Start a fixed number of worker threads and put client into a queue""" - for i in range(self.threads): - try: - t = threading.Thread(target = self.serveThread) - t.setDaemon(self.daemon) - t.start() - except Exception, x: - logging.exception(x) - - # Pump the socket for clients - self.serverTransport.listen() - while True: - try: - client = self.serverTransport.accept() - self.clients.put(client) - except Exception, x: - logging.exception(x) - - -class TForkingServer(TServer): - - """A Thrift server that forks a new process for each request""" - """ - This is more scalable than the threaded server as it does not cause - GIL contention. - - Note that this has different semantics from the threading server. - Specifically, updates to shared variables will no longer be shared. - It will also not work on windows. - - This code is heavily inspired by SocketServer.ForkingMixIn in the - Python stdlib. 
- """ - - def __init__(self, *args): - TServer.__init__(self, *args) - self.children = [] - - def serve(self): - def try_close(file): - try: - file.close() - except IOError, e: - logging.warning(e, exc_info=True) - - - self.serverTransport.listen() - while True: - client = self.serverTransport.accept() - try: - pid = os.fork() - - if pid: # parent - # add before collect, otherwise you race w/ waitpid - self.children.append(pid) - self.collect_children() - - # Parent must close socket or the connection may not get - # closed promptly - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - try_close(itrans) - try_close(otrans) - else: - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - - ecode = 0 - try: - try: - while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException, tx: - pass - except Exception, e: - logging.exception(e) - ecode = 1 - finally: - try_close(itrans) - try_close(otrans) - - os._exit(ecode) - - except TTransport.TTransportException, tx: - pass - except Exception, x: - logging.exception(x) - - - def collect_children(self): - while self.children: - try: - pid, status = os.waitpid(0, os.WNOHANG) - except os.error: - pid = None - - if pid: - self.children.remove(pid) - else: - break - - diff --git a/anknotes/thrift/server/__init__.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/server/__init__.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 1bf6e25..0000000 --- a/anknotes/thrift/server/__init__.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,20 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -__all__ = ['TServer', 'TNonblockingServer'] diff --git a/anknotes/thrift/server/__init__.py~HEAD b/anknotes/thrift/server/__init__.py~HEAD deleted file mode 100644 index 1bf6e25..0000000 --- a/anknotes/thrift/server/__init__.py~HEAD +++ /dev/null @@ -1,20 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -__all__ = ['TServer', 'TNonblockingServer'] diff --git a/anknotes/thrift/transport/THttpClient.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/transport/THttpClient.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index d74baa4..0000000 --- a/anknotes/thrift/transport/THttpClient.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,161 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from TTransport import TTransportBase -from cStringIO import StringIO - -import urlparse -import httplib -import warnings -import socket - - -class THttpClient(TTransportBase): - - """Http implementation of TTransport base.""" - - def __init__( - self, - uri_or_host, - port=None, - path=None, - proxy_host=None, - proxy_port=None - ): - """THttpClient supports two different types constructor parameters. 
- - THttpClient(host, port, path) - deprecated - THttpClient(uri) - - Only the second supports https.""" - - """THttpClient supports proxy - THttpClient(host, port, path, proxy_host, proxy_port) - deprecated - ThttpClient(uri, None, None, proxy_host, proxy_port)""" - - if port is not None: - warnings.warn( - "Please use the THttpClient('http://host:port/path') syntax", - DeprecationWarning, - stacklevel=2) - self.host = uri_or_host - self.port = port - assert path - self.path = path - self.scheme = 'http' - else: - parsed = urlparse.urlparse(uri_or_host) - self.scheme = parsed.scheme - assert self.scheme in ('http', 'https') - if self.scheme == 'http': - self.port = parsed.port or httplib.HTTP_PORT - elif self.scheme == 'https': - self.port = parsed.port or httplib.HTTPS_PORT - self.host = parsed.hostname - self.path = parsed.path - if parsed.query: - self.path += '?%s' % parsed.query - - if proxy_host is not None and proxy_port is not None: - self.endpoint_host = proxy_host - self.endpoint_port = proxy_port - self.path = urlparse.urlunparse(( - self.scheme, - "%s:%i" % (self.host, self.port), - self.path, - None, - None, - None - )) - else: - self.endpoint_host = self.host - self.endpoint_port = self.port - - self.__wbuf = StringIO() - self.__http = None - self.__timeout = None - self.__headers = {} - - def open(self): - protocol = httplib.HTTP if self.scheme == 'http' else httplib.HTTPS - self.__http = protocol(self.endpoint_host, self.endpoint_port) - - def close(self): - self.__http.close() - self.__http = None - - def isOpen(self): - return self.__http is not None - - def setTimeout(self, ms): - if not hasattr(socket, 'getdefaulttimeout'): - raise NotImplementedError - - if ms is None: - self.__timeout = None - else: - self.__timeout = ms / 1000.0 - - def read(self, sz): - return self.__http.file.read(sz) - - def write(self, buf): - self.__wbuf.write(buf) - - def __withTimeout(f): - def _f(*args, **kwargs): - orig_timeout = socket.getdefaulttimeout() - 
socket.setdefaulttimeout(args[0].__timeout) - result = f(*args, **kwargs) - socket.setdefaulttimeout(orig_timeout) - return result - return _f - - def addHeaders(self, **kwargs): - self.__headers.update(kwargs) - - def flush(self): - if self.isOpen(): - self.close() - self.open() - - # Pull data out of buffer - data = self.__wbuf.getvalue() - self.__wbuf = StringIO() - - # HTTP request - self.__http.putrequest('POST', self.path) - - # Write headers - self.__http.putheader('Host', self.host) - self.__http.putheader('Content-Type', 'application/x-thrift') - self.__http.putheader('Content-Length', str(len(data))) - for key, value in self.__headers.iteritems(): - self.__http.putheader(key, value) - self.__http.endheaders() - - # Write payload - self.__http.send(data) - - # Get reply to flush the request - self.code, self.message, self.headers = self.__http.getreply() - - # Decorate if we know how to timeout - if hasattr(socket, 'getdefaulttimeout'): - flush = __withTimeout(flush) diff --git a/anknotes/thrift/transport/THttpClient.py~HEAD b/anknotes/thrift/transport/THttpClient.py~HEAD deleted file mode 100644 index d74baa4..0000000 --- a/anknotes/thrift/transport/THttpClient.py~HEAD +++ /dev/null @@ -1,161 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from TTransport import TTransportBase -from cStringIO import StringIO - -import urlparse -import httplib -import warnings -import socket - - -class THttpClient(TTransportBase): - - """Http implementation of TTransport base.""" - - def __init__( - self, - uri_or_host, - port=None, - path=None, - proxy_host=None, - proxy_port=None - ): - """THttpClient supports two different types constructor parameters. - - THttpClient(host, port, path) - deprecated - THttpClient(uri) - - Only the second supports https.""" - - """THttpClient supports proxy - THttpClient(host, port, path, proxy_host, proxy_port) - deprecated - ThttpClient(uri, None, None, proxy_host, proxy_port)""" - - if port is not None: - warnings.warn( - "Please use the THttpClient('http://host:port/path') syntax", - DeprecationWarning, - stacklevel=2) - self.host = uri_or_host - self.port = port - assert path - self.path = path - self.scheme = 'http' - else: - parsed = urlparse.urlparse(uri_or_host) - self.scheme = parsed.scheme - assert self.scheme in ('http', 'https') - if self.scheme == 'http': - self.port = parsed.port or httplib.HTTP_PORT - elif self.scheme == 'https': - self.port = parsed.port or httplib.HTTPS_PORT - self.host = parsed.hostname - self.path = parsed.path - if parsed.query: - self.path += '?%s' % parsed.query - - if proxy_host is not None and proxy_port is not None: - self.endpoint_host = proxy_host - self.endpoint_port = proxy_port - self.path = urlparse.urlunparse(( - self.scheme, - "%s:%i" % (self.host, self.port), - self.path, - None, - None, - None - )) - else: - self.endpoint_host = self.host - self.endpoint_port = self.port - - self.__wbuf = StringIO() - self.__http = None - self.__timeout = None - self.__headers = {} - - def open(self): - protocol = httplib.HTTP if self.scheme == 'http' else httplib.HTTPS - self.__http = protocol(self.endpoint_host, self.endpoint_port) - - 
def close(self): - self.__http.close() - self.__http = None - - def isOpen(self): - return self.__http is not None - - def setTimeout(self, ms): - if not hasattr(socket, 'getdefaulttimeout'): - raise NotImplementedError - - if ms is None: - self.__timeout = None - else: - self.__timeout = ms / 1000.0 - - def read(self, sz): - return self.__http.file.read(sz) - - def write(self, buf): - self.__wbuf.write(buf) - - def __withTimeout(f): - def _f(*args, **kwargs): - orig_timeout = socket.getdefaulttimeout() - socket.setdefaulttimeout(args[0].__timeout) - result = f(*args, **kwargs) - socket.setdefaulttimeout(orig_timeout) - return result - return _f - - def addHeaders(self, **kwargs): - self.__headers.update(kwargs) - - def flush(self): - if self.isOpen(): - self.close() - self.open() - - # Pull data out of buffer - data = self.__wbuf.getvalue() - self.__wbuf = StringIO() - - # HTTP request - self.__http.putrequest('POST', self.path) - - # Write headers - self.__http.putheader('Host', self.host) - self.__http.putheader('Content-Type', 'application/x-thrift') - self.__http.putheader('Content-Length', str(len(data))) - for key, value in self.__headers.iteritems(): - self.__http.putheader(key, value) - self.__http.endheaders() - - # Write payload - self.__http.send(data) - - # Get reply to flush the request - self.code, self.message, self.headers = self.__http.getreply() - - # Decorate if we know how to timeout - if hasattr(socket, 'getdefaulttimeout'): - flush = __withTimeout(flush) diff --git a/anknotes/thrift/transport/TSSLSocket.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/transport/TSSLSocket.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index be35844..0000000 --- a/anknotes/thrift/transport/TSSLSocket.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,176 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import os -import socket -import ssl - -from thrift.transport import TSocket -from thrift.transport.TTransport import TTransportException - -class TSSLSocket(TSocket.TSocket): - """ - SSL implementation of client-side TSocket - - This class creates outbound sockets wrapped using the - python standard ssl module for encrypted connections. - - The protocol used is set using the class variable - SSL_VERSION, which must be one of ssl.PROTOCOL_* and - defaults to ssl.PROTOCOL_TLSv1 for greatest security. - """ - SSL_VERSION = ssl.PROTOCOL_TLSv1 - - def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, unix_socket=None): - """ - @param validate: Set to False to disable SSL certificate validation entirely. - @type validate: bool - @param ca_certs: Filename to the Certificate Authority pem file, possibly a - file downloaded from: http://curl.haxx.se/ca/cacert.pem This is passed to - the ssl_wrap function as the 'ca_certs' parameter. - @type ca_certs: str - - Raises an IOError exception if validate is True and the ca_certs file is - None, not present or unreadable. 
- """ - self.validate = validate - self.is_valid = False - self.peercert = None - if not validate: - self.cert_reqs = ssl.CERT_NONE - else: - self.cert_reqs = ssl.CERT_REQUIRED - self.ca_certs = ca_certs - if validate: - if ca_certs is None or not os.access(ca_certs, os.R_OK): - raise IOError('Certificate Authority ca_certs file "%s" is not readable, cannot validate SSL certificates.' % (ca_certs)) - TSocket.TSocket.__init__(self, host, port, unix_socket) - - def open(self): - try: - res0 = self._resolveAddr() - for res in res0: - sock_family, sock_type= res[0:2] - ip_port = res[4] - plain_sock = socket.socket(sock_family, sock_type) - self.handle = ssl.wrap_socket(plain_sock, ssl_version=self.SSL_VERSION, - do_handshake_on_connect=True, ca_certs=self.ca_certs, cert_reqs=self.cert_reqs) - self.handle.settimeout(self._timeout) - try: - self.handle.connect(ip_port) - except socket.error, e: - if res is not res0[-1]: - continue - else: - raise e - break - except socket.error, e: - if self._unix_socket: - message = 'Could not connect to secure socket %s' % self._unix_socket - else: - message = 'Could not connect to %s:%d' % (self.host, self.port) - raise TTransportException(type=TTransportException.NOT_OPEN, message=message) - if self.validate: - self._validate_cert() - - def _validate_cert(self): - """internal method to validate the peer's SSL certificate, and to check the - commonName of the certificate to ensure it matches the hostname we - used to make this connection. Does not support subjectAltName records - in certificates. 
- - raises TTransportException if the certificate fails validation.""" - cert = self.handle.getpeercert() - self.peercert = cert - if 'subject' not in cert: - raise TTransportException(type=TTransportException.NOT_OPEN, - message='No SSL certificate found from %s:%s' % (self.host, self.port)) - fields = cert['subject'] - for field in fields: - # ensure structure we get back is what we expect - if not isinstance(field, tuple): - continue - cert_pair = field[0] - if len(cert_pair) < 2: - continue - cert_key, cert_value = cert_pair[0:2] - if cert_key != 'commonName': - continue - certhost = cert_value - if certhost == self.host: - # success, cert commonName matches desired hostname - self.is_valid = True - return - else: - raise TTransportException(type=TTransportException.UNKNOWN, - message='Host name we connected to "%s" doesn\'t match certificate provided commonName "%s"' % (self.host, certhost)) - raise TTransportException(type=TTransportException.UNKNOWN, - message='Could not validate SSL certificate from host "%s". Cert=%s' % (self.host, cert)) - -class TSSLServerSocket(TSocket.TServerSocket): - """ - SSL implementation of TServerSocket - - This uses the ssl module's wrap_socket() method to provide SSL - negotiated encryption. - """ - SSL_VERSION = ssl.PROTOCOL_TLSv1 - - def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None): - """Initialize a TSSLServerSocket - - @param certfile: The filename of the server certificate file, defaults to cert.pem - @type certfile: str - @param host: The hostname or IP to bind the listen socket to, i.e. 'localhost' for only allowing - local network connections. Pass None to bind to all interfaces. - @type host: str - @param port: The port to listen on for inbound connections. - @type port: int - """ - self.setCertfile(certfile) - TSocket.TServerSocket.__init__(self, host, port) - - def setCertfile(self, certfile): - """Set or change the server certificate file used to wrap new connections. 
- - @param certfile: The filename of the server certificate, i.e. '/etc/certs/server.pem' - @type certfile: str - - Raises an IOError exception if the certfile is not present or unreadable. - """ - if not os.access(certfile, os.R_OK): - raise IOError('No such certfile found: %s' % (certfile)) - self.certfile = certfile - - def accept(self): - plain_client, addr = self.handle.accept() - try: - client = ssl.wrap_socket(plain_client, certfile=self.certfile, - server_side=True, ssl_version=self.SSL_VERSION) - except ssl.SSLError, ssl_exc: - # failed handshake/ssl wrap, close socket to client - plain_client.close() - # raise ssl_exc - # We can't raise the exception, because it kills most TServer derived serve() - # methods. - # Instead, return None, and let the TServer instance deal with it in - # other exception handling. (but TSimpleServer dies anyway) - return None - result = TSocket.TSocket() - result.setHandle(client) - return result diff --git a/anknotes/thrift/transport/TSSLSocket.py~HEAD b/anknotes/thrift/transport/TSSLSocket.py~HEAD deleted file mode 100644 index be35844..0000000 --- a/anknotes/thrift/transport/TSSLSocket.py~HEAD +++ /dev/null @@ -1,176 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -import os -import socket -import ssl - -from thrift.transport import TSocket -from thrift.transport.TTransport import TTransportException - -class TSSLSocket(TSocket.TSocket): - """ - SSL implementation of client-side TSocket - - This class creates outbound sockets wrapped using the - python standard ssl module for encrypted connections. - - The protocol used is set using the class variable - SSL_VERSION, which must be one of ssl.PROTOCOL_* and - defaults to ssl.PROTOCOL_TLSv1 for greatest security. - """ - SSL_VERSION = ssl.PROTOCOL_TLSv1 - - def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, unix_socket=None): - """ - @param validate: Set to False to disable SSL certificate validation entirely. - @type validate: bool - @param ca_certs: Filename to the Certificate Authority pem file, possibly a - file downloaded from: http://curl.haxx.se/ca/cacert.pem This is passed to - the ssl_wrap function as the 'ca_certs' parameter. - @type ca_certs: str - - Raises an IOError exception if validate is True and the ca_certs file is - None, not present or unreadable. - """ - self.validate = validate - self.is_valid = False - self.peercert = None - if not validate: - self.cert_reqs = ssl.CERT_NONE - else: - self.cert_reqs = ssl.CERT_REQUIRED - self.ca_certs = ca_certs - if validate: - if ca_certs is None or not os.access(ca_certs, os.R_OK): - raise IOError('Certificate Authority ca_certs file "%s" is not readable, cannot validate SSL certificates.' 
% (ca_certs)) - TSocket.TSocket.__init__(self, host, port, unix_socket) - - def open(self): - try: - res0 = self._resolveAddr() - for res in res0: - sock_family, sock_type= res[0:2] - ip_port = res[4] - plain_sock = socket.socket(sock_family, sock_type) - self.handle = ssl.wrap_socket(plain_sock, ssl_version=self.SSL_VERSION, - do_handshake_on_connect=True, ca_certs=self.ca_certs, cert_reqs=self.cert_reqs) - self.handle.settimeout(self._timeout) - try: - self.handle.connect(ip_port) - except socket.error, e: - if res is not res0[-1]: - continue - else: - raise e - break - except socket.error, e: - if self._unix_socket: - message = 'Could not connect to secure socket %s' % self._unix_socket - else: - message = 'Could not connect to %s:%d' % (self.host, self.port) - raise TTransportException(type=TTransportException.NOT_OPEN, message=message) - if self.validate: - self._validate_cert() - - def _validate_cert(self): - """internal method to validate the peer's SSL certificate, and to check the - commonName of the certificate to ensure it matches the hostname we - used to make this connection. Does not support subjectAltName records - in certificates. 
- - raises TTransportException if the certificate fails validation.""" - cert = self.handle.getpeercert() - self.peercert = cert - if 'subject' not in cert: - raise TTransportException(type=TTransportException.NOT_OPEN, - message='No SSL certificate found from %s:%s' % (self.host, self.port)) - fields = cert['subject'] - for field in fields: - # ensure structure we get back is what we expect - if not isinstance(field, tuple): - continue - cert_pair = field[0] - if len(cert_pair) < 2: - continue - cert_key, cert_value = cert_pair[0:2] - if cert_key != 'commonName': - continue - certhost = cert_value - if certhost == self.host: - # success, cert commonName matches desired hostname - self.is_valid = True - return - else: - raise TTransportException(type=TTransportException.UNKNOWN, - message='Host name we connected to "%s" doesn\'t match certificate provided commonName "%s"' % (self.host, certhost)) - raise TTransportException(type=TTransportException.UNKNOWN, - message='Could not validate SSL certificate from host "%s". Cert=%s' % (self.host, cert)) - -class TSSLServerSocket(TSocket.TServerSocket): - """ - SSL implementation of TServerSocket - - This uses the ssl module's wrap_socket() method to provide SSL - negotiated encryption. - """ - SSL_VERSION = ssl.PROTOCOL_TLSv1 - - def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None): - """Initialize a TSSLServerSocket - - @param certfile: The filename of the server certificate file, defaults to cert.pem - @type certfile: str - @param host: The hostname or IP to bind the listen socket to, i.e. 'localhost' for only allowing - local network connections. Pass None to bind to all interfaces. - @type host: str - @param port: The port to listen on for inbound connections. - @type port: int - """ - self.setCertfile(certfile) - TSocket.TServerSocket.__init__(self, host, port) - - def setCertfile(self, certfile): - """Set or change the server certificate file used to wrap new connections. 
- - @param certfile: The filename of the server certificate, i.e. '/etc/certs/server.pem' - @type certfile: str - - Raises an IOError exception if the certfile is not present or unreadable. - """ - if not os.access(certfile, os.R_OK): - raise IOError('No such certfile found: %s' % (certfile)) - self.certfile = certfile - - def accept(self): - plain_client, addr = self.handle.accept() - try: - client = ssl.wrap_socket(plain_client, certfile=self.certfile, - server_side=True, ssl_version=self.SSL_VERSION) - except ssl.SSLError, ssl_exc: - # failed handshake/ssl wrap, close socket to client - plain_client.close() - # raise ssl_exc - # We can't raise the exception, because it kills most TServer derived serve() - # methods. - # Instead, return None, and let the TServer instance deal with it in - # other exception handling. (but TSimpleServer dies anyway) - return None - result = TSocket.TSocket() - result.setHandle(client) - return result diff --git a/anknotes/thrift/transport/TSocket.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/transport/TSocket.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 4e0e187..0000000 --- a/anknotes/thrift/transport/TSocket.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,163 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from TTransport import * -import os -import errno -import socket -import sys - -class TSocketBase(TTransportBase): - def _resolveAddr(self): - if self._unix_socket is not None: - return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, self._unix_socket)] - else: - return socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE | socket.AI_ADDRCONFIG) - - def close(self): - if self.handle: - self.handle.close() - self.handle = None - -class TSocket(TSocketBase): - """Socket implementation of TTransport base.""" - - def __init__(self, host='localhost', port=9090, unix_socket=None): - """Initialize a TSocket - - @param host(str) The host to connect to. - @param port(int) The (TCP) port to connect to. - @param unix_socket(str) The filename of a unix socket to connect to. - (host and port will be ignored.) - """ - - self.host = host - self.port = port - self.handle = None - self._unix_socket = unix_socket - self._timeout = None - - def setHandle(self, h): - self.handle = h - - def isOpen(self): - return self.handle is not None - - def setTimeout(self, ms): - if ms is None: - self._timeout = None - else: - self._timeout = ms/1000.0 - - if self.handle is not None: - self.handle.settimeout(self._timeout) - - def open(self): - try: - res0 = self._resolveAddr() - for res in res0: - self.handle = socket.socket(res[0], res[1]) - self.handle.settimeout(self._timeout) - try: - self.handle.connect(res[4]) - except socket.error, e: - if res is not res0[-1]: - continue - else: - raise e - break - except socket.error, e: - if self._unix_socket: - message = 'Could not connect to socket %s' % self._unix_socket - else: - message = 'Could not connect to %s:%d' % (self.host, self.port) - raise TTransportException(type=TTransportException.NOT_OPEN, message=message) - - def read(self, sz): - try: - buff = self.handle.recv(sz) - except 
socket.error, e: - if (e.args[0] == errno.ECONNRESET and - (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))): - # freebsd and Mach don't follow POSIX semantic of recv - # and fail with ECONNRESET if peer performed shutdown. - # See corresponding comment and code in TSocket::read() - # in lib/cpp/src/transport/TSocket.cpp. - self.close() - # Trigger the check to raise the END_OF_FILE exception below. - buff = '' - else: - raise - if len(buff) == 0: - raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes') - return buff - - def write(self, buff): - if not self.handle: - raise TTransportException(type=TTransportException.NOT_OPEN, message='Transport not open') - sent = 0 - have = len(buff) - while sent < have: - plus = self.handle.send(buff) - if plus == 0: - raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket sent 0 bytes') - sent += plus - buff = buff[plus:] - - def flush(self): - pass - -class TServerSocket(TSocketBase, TServerTransportBase): - """Socket implementation of TServerTransport base.""" - - def __init__(self, host=None, port=9090, unix_socket=None): - self.host = host - self.port = port - self._unix_socket = unix_socket - self.handle = None - - def listen(self): - res0 = self._resolveAddr() - for res in res0: - if res[0] is socket.AF_INET6 or res is res0[-1]: - break - - # We need remove the old unix socket if the file exists and - # nobody is listening on it. 
- if self._unix_socket: - tmp = socket.socket(res[0], res[1]) - try: - tmp.connect(res[4]) - except socket.error, err: - eno, message = err.args - if eno == errno.ECONNREFUSED: - os.unlink(res[4]) - - self.handle = socket.socket(res[0], res[1]) - self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if hasattr(self.handle, 'settimeout'): - self.handle.settimeout(None) - self.handle.bind(res[4]) - self.handle.listen(128) - - def accept(self): - client, addr = self.handle.accept() - result = TSocket() - result.setHandle(client) - return result diff --git a/anknotes/thrift/transport/TSocket.py~HEAD b/anknotes/thrift/transport/TSocket.py~HEAD deleted file mode 100644 index 4e0e187..0000000 --- a/anknotes/thrift/transport/TSocket.py~HEAD +++ /dev/null @@ -1,163 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -from TTransport import * -import os -import errno -import socket -import sys - -class TSocketBase(TTransportBase): - def _resolveAddr(self): - if self._unix_socket is not None: - return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, self._unix_socket)] - else: - return socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE | socket.AI_ADDRCONFIG) - - def close(self): - if self.handle: - self.handle.close() - self.handle = None - -class TSocket(TSocketBase): - """Socket implementation of TTransport base.""" - - def __init__(self, host='localhost', port=9090, unix_socket=None): - """Initialize a TSocket - - @param host(str) The host to connect to. - @param port(int) The (TCP) port to connect to. - @param unix_socket(str) The filename of a unix socket to connect to. - (host and port will be ignored.) - """ - - self.host = host - self.port = port - self.handle = None - self._unix_socket = unix_socket - self._timeout = None - - def setHandle(self, h): - self.handle = h - - def isOpen(self): - return self.handle is not None - - def setTimeout(self, ms): - if ms is None: - self._timeout = None - else: - self._timeout = ms/1000.0 - - if self.handle is not None: - self.handle.settimeout(self._timeout) - - def open(self): - try: - res0 = self._resolveAddr() - for res in res0: - self.handle = socket.socket(res[0], res[1]) - self.handle.settimeout(self._timeout) - try: - self.handle.connect(res[4]) - except socket.error, e: - if res is not res0[-1]: - continue - else: - raise e - break - except socket.error, e: - if self._unix_socket: - message = 'Could not connect to socket %s' % self._unix_socket - else: - message = 'Could not connect to %s:%d' % (self.host, self.port) - raise TTransportException(type=TTransportException.NOT_OPEN, message=message) - - def read(self, sz): - try: - buff = self.handle.recv(sz) - except socket.error, e: - if (e.args[0] == errno.ECONNRESET and - (sys.platform == 'darwin' or 
sys.platform.startswith('freebsd'))): - # freebsd and Mach don't follow POSIX semantic of recv - # and fail with ECONNRESET if peer performed shutdown. - # See corresponding comment and code in TSocket::read() - # in lib/cpp/src/transport/TSocket.cpp. - self.close() - # Trigger the check to raise the END_OF_FILE exception below. - buff = '' - else: - raise - if len(buff) == 0: - raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes') - return buff - - def write(self, buff): - if not self.handle: - raise TTransportException(type=TTransportException.NOT_OPEN, message='Transport not open') - sent = 0 - have = len(buff) - while sent < have: - plus = self.handle.send(buff) - if plus == 0: - raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket sent 0 bytes') - sent += plus - buff = buff[plus:] - - def flush(self): - pass - -class TServerSocket(TSocketBase, TServerTransportBase): - """Socket implementation of TServerTransport base.""" - - def __init__(self, host=None, port=9090, unix_socket=None): - self.host = host - self.port = port - self._unix_socket = unix_socket - self.handle = None - - def listen(self): - res0 = self._resolveAddr() - for res in res0: - if res[0] is socket.AF_INET6 or res is res0[-1]: - break - - # We need remove the old unix socket if the file exists and - # nobody is listening on it. 
- if self._unix_socket: - tmp = socket.socket(res[0], res[1]) - try: - tmp.connect(res[4]) - except socket.error, err: - eno, message = err.args - if eno == errno.ECONNREFUSED: - os.unlink(res[4]) - - self.handle = socket.socket(res[0], res[1]) - self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if hasattr(self.handle, 'settimeout'): - self.handle.settimeout(None) - self.handle.bind(res[4]) - self.handle.listen(128) - - def accept(self): - client, addr = self.handle.accept() - result = TSocket() - result.setHandle(client) - return result diff --git a/anknotes/thrift/transport/TTransport.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/transport/TTransport.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 9ffdc05..0000000 --- a/anknotes/thrift/transport/TTransport.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,331 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -from cStringIO import StringIO -from struct import pack,unpack -from anknotes.thrift.Thrift import TException - -class TTransportException(TException): - - """Custom Transport Exception class""" - - UNKNOWN = 0 - NOT_OPEN = 1 - ALREADY_OPEN = 2 - TIMED_OUT = 3 - END_OF_FILE = 4 - - def __init__(self, type=UNKNOWN, message=None): - TException.__init__(self, message) - self.type = type - -class TTransportBase: - - """Base class for Thrift transport layer.""" - - def isOpen(self): - pass - - def open(self): - pass - - def close(self): - pass - - def read(self, sz): - pass - - def readAll(self, sz): - buff = '' - have = 0 - while (have < sz): - chunk = self.read(sz-have) - have += len(chunk) - buff += chunk - - if len(chunk) == 0: - raise EOFError() - - return buff - - def write(self, buf): - pass - - def flush(self): - pass - -# This class should be thought of as an interface. -class CReadableTransport: - """base class for transports that are readable from C""" - - # TODO(dreiss): Think about changing this interface to allow us to use - # a (Python, not c) StringIO instead, because it allows - # you to write after reading. - - # NOTE: This is a classic class, so properties will NOT work - # correctly for setting. - @property - def cstringio_buf(self): - """A cStringIO buffer that contains the current chunk we are reading.""" - pass - - def cstringio_refill(self, partialread, reqlen): - """Refills cstringio_buf. - - Returns the currently used buffer (which can but need not be the same as - the old cstringio_buf). partialread is what the C code has read from the - buffer, and should be inserted into the buffer before any more reads. The - return value must be a new, not borrowed reference. Something along the - lines of self._buf should be fine. - - If reqlen bytes can't be read, throw EOFError. 
- """ - pass - -class TServerTransportBase: - - """Base class for Thrift server transports.""" - - def listen(self): - pass - - def accept(self): - pass - - def close(self): - pass - -class TTransportFactoryBase: - - """Base class for a Transport Factory""" - - def getTransport(self, trans): - return trans - -class TBufferedTransportFactory: - - """Factory transport that builds buffered transports""" - - def getTransport(self, trans): - buffered = TBufferedTransport(trans) - return buffered - - -class TBufferedTransport(TTransportBase,CReadableTransport): - - """Class that wraps another transport and buffers its I/O. - - The implementation uses a (configurable) fixed-size read buffer - but buffers all writes until a flush is performed. - """ - - DEFAULT_BUFFER = 4096 - - def __init__(self, trans, rbuf_size = DEFAULT_BUFFER): - self.__trans = trans - self.__wbuf = StringIO() - self.__rbuf = StringIO("") - self.__rbuf_size = rbuf_size - - def isOpen(self): - return self.__trans.isOpen() - - def open(self): - return self.__trans.open() - - def close(self): - return self.__trans.close() - - def read(self, sz): - ret = self.__rbuf.read(sz) - if len(ret) != 0: - return ret - - self.__rbuf = StringIO(self.__trans.read(max(sz, self.__rbuf_size))) - return self.__rbuf.read(sz) - - def write(self, buf): - self.__wbuf.write(buf) - - def flush(self): - out = self.__wbuf.getvalue() - # reset wbuf before write/flush to preserve state on underlying failure - self.__wbuf = StringIO() - self.__trans.write(out) - self.__trans.flush() - - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - return self.__rbuf - - def cstringio_refill(self, partialread, reqlen): - retstring = partialread - if reqlen < self.__rbuf_size: - # try to make a read of as much as we can. - retstring += self.__trans.read(self.__rbuf_size) - - # but make sure we do read reqlen bytes. 
- if len(retstring) < reqlen: - retstring += self.__trans.readAll(reqlen - len(retstring)) - - self.__rbuf = StringIO(retstring) - return self.__rbuf - -class TMemoryBuffer(TTransportBase, CReadableTransport): - """Wraps a cStringIO object as a TTransport. - - NOTE: Unlike the C++ version of this class, you cannot write to it - then immediately read from it. If you want to read from a - TMemoryBuffer, you must either pass a string to the constructor. - TODO(dreiss): Make this work like the C++ version. - """ - - def __init__(self, value=None): - """value -- a value to read from for stringio - - If value is set, this will be a transport for reading, - otherwise, it is for writing""" - if value is not None: - self._buffer = StringIO(value) - else: - self._buffer = StringIO() - - def isOpen(self): - return not self._buffer.closed - - def open(self): - pass - - def close(self): - self._buffer.close() - - def read(self, sz): - return self._buffer.read(sz) - - def write(self, buf): - self._buffer.write(buf) - - def flush(self): - pass - - def getvalue(self): - return self._buffer.getvalue() - - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - return self._buffer - - def cstringio_refill(self, partialread, reqlen): - # only one shot at reading... 
- raise EOFError() - -class TFramedTransportFactory: - - """Factory transport that builds framed transports""" - - def getTransport(self, trans): - framed = TFramedTransport(trans) - return framed - - -class TFramedTransport(TTransportBase, CReadableTransport): - - """Class that wraps another transport and frames its I/O when writing.""" - - def __init__(self, trans,): - self.__trans = trans - self.__rbuf = StringIO() - self.__wbuf = StringIO() - - def isOpen(self): - return self.__trans.isOpen() - - def open(self): - return self.__trans.open() - - def close(self): - return self.__trans.close() - - def read(self, sz): - ret = self.__rbuf.read(sz) - if len(ret) != 0: - return ret - - self.readFrame() - return self.__rbuf.read(sz) - - def readFrame(self): - buff = self.__trans.readAll(4) - sz, = unpack('!i', buff) - self.__rbuf = StringIO(self.__trans.readAll(sz)) - - def write(self, buf): - self.__wbuf.write(buf) - - def flush(self): - wout = self.__wbuf.getvalue() - wsz = len(wout) - # reset wbuf before write/flush to preserve state on underlying failure - self.__wbuf = StringIO() - # N.B.: Doing this string concatenation is WAY cheaper than making - # two separate calls to the underlying socket object. Socket writes in - # Python turn out to be REALLY expensive, but it seems to do a pretty - # good job of managing string buffer operations without excessive copies - buf = pack("!i", wsz) + wout - self.__trans.write(buf) - self.__trans.flush() - - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - return self.__rbuf - - def cstringio_refill(self, prefix, reqlen): - # self.__rbuf will already be empty here because fastbinary doesn't - # ask for a refill until the previous buffer is empty. Therefore, - # we can start reading new frames immediately. 
- while len(prefix) < reqlen: - self.readFrame() - prefix += self.__rbuf.getvalue() - self.__rbuf = StringIO(prefix) - return self.__rbuf - - -class TFileObjectTransport(TTransportBase): - """Wraps a file-like object to make it work as a Thrift transport.""" - - def __init__(self, fileobj): - self.fileobj = fileobj - - def isOpen(self): - return True - - def close(self): - self.fileobj.close() - - def read(self, sz): - return self.fileobj.read(sz) - - def write(self, buf): - self.fileobj.write(buf) - - def flush(self): - self.fileobj.flush() diff --git a/anknotes/thrift/transport/TTransport.py~HEAD b/anknotes/thrift/transport/TTransport.py~HEAD deleted file mode 100644 index 12e51a9..0000000 --- a/anknotes/thrift/transport/TTransport.py~HEAD +++ /dev/null @@ -1,331 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -from cStringIO import StringIO -from struct import pack,unpack -from thrift.Thrift import TException - -class TTransportException(TException): - - """Custom Transport Exception class""" - - UNKNOWN = 0 - NOT_OPEN = 1 - ALREADY_OPEN = 2 - TIMED_OUT = 3 - END_OF_FILE = 4 - - def __init__(self, type=UNKNOWN, message=None): - TException.__init__(self, message) - self.type = type - -class TTransportBase: - - """Base class for Thrift transport layer.""" - - def isOpen(self): - pass - - def open(self): - pass - - def close(self): - pass - - def read(self, sz): - pass - - def readAll(self, sz): - buff = '' - have = 0 - while (have < sz): - chunk = self.read(sz-have) - have += len(chunk) - buff += chunk - - if len(chunk) == 0: - raise EOFError() - - return buff - - def write(self, buf): - pass - - def flush(self): - pass - -# This class should be thought of as an interface. -class CReadableTransport: - """base class for transports that are readable from C""" - - # TODO(dreiss): Think about changing this interface to allow us to use - # a (Python, not c) StringIO instead, because it allows - # you to write after reading. - - # NOTE: This is a classic class, so properties will NOT work - # correctly for setting. - @property - def cstringio_buf(self): - """A cStringIO buffer that contains the current chunk we are reading.""" - pass - - def cstringio_refill(self, partialread, reqlen): - """Refills cstringio_buf. - - Returns the currently used buffer (which can but need not be the same as - the old cstringio_buf). partialread is what the C code has read from the - buffer, and should be inserted into the buffer before any more reads. The - return value must be a new, not borrowed reference. Something along the - lines of self._buf should be fine. - - If reqlen bytes can't be read, throw EOFError. 
- """ - pass - -class TServerTransportBase: - - """Base class for Thrift server transports.""" - - def listen(self): - pass - - def accept(self): - pass - - def close(self): - pass - -class TTransportFactoryBase: - - """Base class for a Transport Factory""" - - def getTransport(self, trans): - return trans - -class TBufferedTransportFactory: - - """Factory transport that builds buffered transports""" - - def getTransport(self, trans): - buffered = TBufferedTransport(trans) - return buffered - - -class TBufferedTransport(TTransportBase,CReadableTransport): - - """Class that wraps another transport and buffers its I/O. - - The implementation uses a (configurable) fixed-size read buffer - but buffers all writes until a flush is performed. - """ - - DEFAULT_BUFFER = 4096 - - def __init__(self, trans, rbuf_size = DEFAULT_BUFFER): - self.__trans = trans - self.__wbuf = StringIO() - self.__rbuf = StringIO("") - self.__rbuf_size = rbuf_size - - def isOpen(self): - return self.__trans.isOpen() - - def open(self): - return self.__trans.open() - - def close(self): - return self.__trans.close() - - def read(self, sz): - ret = self.__rbuf.read(sz) - if len(ret) != 0: - return ret - - self.__rbuf = StringIO(self.__trans.read(max(sz, self.__rbuf_size))) - return self.__rbuf.read(sz) - - def write(self, buf): - self.__wbuf.write(buf) - - def flush(self): - out = self.__wbuf.getvalue() - # reset wbuf before write/flush to preserve state on underlying failure - self.__wbuf = StringIO() - self.__trans.write(out) - self.__trans.flush() - - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - return self.__rbuf - - def cstringio_refill(self, partialread, reqlen): - retstring = partialread - if reqlen < self.__rbuf_size: - # try to make a read of as much as we can. - retstring += self.__trans.read(self.__rbuf_size) - - # but make sure we do read reqlen bytes. 
- if len(retstring) < reqlen: - retstring += self.__trans.readAll(reqlen - len(retstring)) - - self.__rbuf = StringIO(retstring) - return self.__rbuf - -class TMemoryBuffer(TTransportBase, CReadableTransport): - """Wraps a cStringIO object as a TTransport. - - NOTE: Unlike the C++ version of this class, you cannot write to it - then immediately read from it. If you want to read from a - TMemoryBuffer, you must either pass a string to the constructor. - TODO(dreiss): Make this work like the C++ version. - """ - - def __init__(self, value=None): - """value -- a value to read from for stringio - - If value is set, this will be a transport for reading, - otherwise, it is for writing""" - if value is not None: - self._buffer = StringIO(value) - else: - self._buffer = StringIO() - - def isOpen(self): - return not self._buffer.closed - - def open(self): - pass - - def close(self): - self._buffer.close() - - def read(self, sz): - return self._buffer.read(sz) - - def write(self, buf): - self._buffer.write(buf) - - def flush(self): - pass - - def getvalue(self): - return self._buffer.getvalue() - - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - return self._buffer - - def cstringio_refill(self, partialread, reqlen): - # only one shot at reading... 
- raise EOFError() - -class TFramedTransportFactory: - - """Factory transport that builds framed transports""" - - def getTransport(self, trans): - framed = TFramedTransport(trans) - return framed - - -class TFramedTransport(TTransportBase, CReadableTransport): - - """Class that wraps another transport and frames its I/O when writing.""" - - def __init__(self, trans,): - self.__trans = trans - self.__rbuf = StringIO() - self.__wbuf = StringIO() - - def isOpen(self): - return self.__trans.isOpen() - - def open(self): - return self.__trans.open() - - def close(self): - return self.__trans.close() - - def read(self, sz): - ret = self.__rbuf.read(sz) - if len(ret) != 0: - return ret - - self.readFrame() - return self.__rbuf.read(sz) - - def readFrame(self): - buff = self.__trans.readAll(4) - sz, = unpack('!i', buff) - self.__rbuf = StringIO(self.__trans.readAll(sz)) - - def write(self, buf): - self.__wbuf.write(buf) - - def flush(self): - wout = self.__wbuf.getvalue() - wsz = len(wout) - # reset wbuf before write/flush to preserve state on underlying failure - self.__wbuf = StringIO() - # N.B.: Doing this string concatenation is WAY cheaper than making - # two separate calls to the underlying socket object. Socket writes in - # Python turn out to be REALLY expensive, but it seems to do a pretty - # good job of managing string buffer operations without excessive copies - buf = pack("!i", wsz) + wout - self.__trans.write(buf) - self.__trans.flush() - - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - return self.__rbuf - - def cstringio_refill(self, prefix, reqlen): - # self.__rbuf will already be empty here because fastbinary doesn't - # ask for a refill until the previous buffer is empty. Therefore, - # we can start reading new frames immediately. 
- while len(prefix) < reqlen: - self.readFrame() - prefix += self.__rbuf.getvalue() - self.__rbuf = StringIO(prefix) - return self.__rbuf - - -class TFileObjectTransport(TTransportBase): - """Wraps a file-like object to make it work as a Thrift transport.""" - - def __init__(self, fileobj): - self.fileobj = fileobj - - def isOpen(self): - return True - - def close(self): - self.fileobj.close() - - def read(self, sz): - return self.fileobj.read(sz) - - def write(self, buf): - self.fileobj.write(buf) - - def flush(self): - self.fileobj.flush() diff --git a/anknotes/thrift/transport/TTwisted.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/transport/TTwisted.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index b6dcb4e..0000000 --- a/anknotes/thrift/transport/TTwisted.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,219 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -from zope.interface import implements, Interface, Attribute -from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \ - connectionDone -from twisted.internet import defer -from twisted.protocols import basic -from twisted.python import log -from twisted.web import server, resource, http - -from thrift.transport import TTransport -from cStringIO import StringIO - - -class TMessageSenderTransport(TTransport.TTransportBase): - - def __init__(self): - self.__wbuf = StringIO() - - def write(self, buf): - self.__wbuf.write(buf) - - def flush(self): - msg = self.__wbuf.getvalue() - self.__wbuf = StringIO() - self.sendMessage(msg) - - def sendMessage(self, message): - raise NotImplementedError - - -class TCallbackTransport(TMessageSenderTransport): - - def __init__(self, func): - TMessageSenderTransport.__init__(self) - self.func = func - - def sendMessage(self, message): - self.func(message) - - -class ThriftClientProtocol(basic.Int32StringReceiver): - - MAX_LENGTH = 2 ** 31 - 1 - - def __init__(self, client_class, iprot_factory, oprot_factory=None): - self._client_class = client_class - self._iprot_factory = iprot_factory - if oprot_factory is None: - self._oprot_factory = iprot_factory - else: - self._oprot_factory = oprot_factory - - self.recv_map = {} - self.started = defer.Deferred() - - def dispatch(self, msg): - self.sendString(msg) - - def connectionMade(self): - tmo = TCallbackTransport(self.dispatch) - self.client = self._client_class(tmo, self._oprot_factory) - self.started.callback(self.client) - - def connectionLost(self, reason=connectionDone): - for k,v in self.client._reqs.iteritems(): - tex = TTransport.TTransportException( - type=TTransport.TTransportException.END_OF_FILE, - message='Connection closed') - v.errback(tex) - - def stringReceived(self, frame): - tr = TTransport.TMemoryBuffer(frame) - iprot = self._iprot_factory.getProtocol(tr) - (fname, mtype, rseqid) = iprot.readMessageBegin() - - try: - method = 
self.recv_map[fname] - except KeyError: - method = getattr(self.client, 'recv_' + fname) - self.recv_map[fname] = method - - method(iprot, mtype, rseqid) - - -class ThriftServerProtocol(basic.Int32StringReceiver): - - MAX_LENGTH = 2 ** 31 - 1 - - def dispatch(self, msg): - self.sendString(msg) - - def processError(self, error): - self.transport.loseConnection() - - def processOk(self, _, tmo): - msg = tmo.getvalue() - - if len(msg) > 0: - self.dispatch(msg) - - def stringReceived(self, frame): - tmi = TTransport.TMemoryBuffer(frame) - tmo = TTransport.TMemoryBuffer() - - iprot = self.factory.iprot_factory.getProtocol(tmi) - oprot = self.factory.oprot_factory.getProtocol(tmo) - - d = self.factory.processor.process(iprot, oprot) - d.addCallbacks(self.processOk, self.processError, - callbackArgs=(tmo,)) - - -class IThriftServerFactory(Interface): - - processor = Attribute("Thrift processor") - - iprot_factory = Attribute("Input protocol factory") - - oprot_factory = Attribute("Output protocol factory") - - -class IThriftClientFactory(Interface): - - client_class = Attribute("Thrift client class") - - iprot_factory = Attribute("Input protocol factory") - - oprot_factory = Attribute("Output protocol factory") - - -class ThriftServerFactory(ServerFactory): - - implements(IThriftServerFactory) - - protocol = ThriftServerProtocol - - def __init__(self, processor, iprot_factory, oprot_factory=None): - self.processor = processor - self.iprot_factory = iprot_factory - if oprot_factory is None: - self.oprot_factory = iprot_factory - else: - self.oprot_factory = oprot_factory - - -class ThriftClientFactory(ClientFactory): - - implements(IThriftClientFactory) - - protocol = ThriftClientProtocol - - def __init__(self, client_class, iprot_factory, oprot_factory=None): - self.client_class = client_class - self.iprot_factory = iprot_factory - if oprot_factory is None: - self.oprot_factory = iprot_factory - else: - self.oprot_factory = oprot_factory - - def buildProtocol(self, addr): 
- p = self.protocol(self.client_class, self.iprot_factory, - self.oprot_factory) - p.factory = self - return p - - -class ThriftResource(resource.Resource): - - allowedMethods = ('POST',) - - def __init__(self, processor, inputProtocolFactory, - outputProtocolFactory=None): - resource.Resource.__init__(self) - self.inputProtocolFactory = inputProtocolFactory - if outputProtocolFactory is None: - self.outputProtocolFactory = inputProtocolFactory - else: - self.outputProtocolFactory = outputProtocolFactory - self.processor = processor - - def getChild(self, path, request): - return self - - def _cbProcess(self, _, request, tmo): - msg = tmo.getvalue() - request.setResponseCode(http.OK) - request.setHeader("content-type", "application/x-thrift") - request.write(msg) - request.finish() - - def render_POST(self, request): - request.content.seek(0, 0) - data = request.content.read() - tmi = TTransport.TMemoryBuffer(data) - tmo = TTransport.TMemoryBuffer() - - iprot = self.inputProtocolFactory.getProtocol(tmi) - oprot = self.outputProtocolFactory.getProtocol(tmo) - - d = self.processor.process(iprot, oprot) - d.addCallback(self._cbProcess, request, tmo) - return server.NOT_DONE_YET diff --git a/anknotes/thrift/transport/TTwisted.py~HEAD b/anknotes/thrift/transport/TTwisted.py~HEAD deleted file mode 100644 index b6dcb4e..0000000 --- a/anknotes/thrift/transport/TTwisted.py~HEAD +++ /dev/null @@ -1,219 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -from zope.interface import implements, Interface, Attribute -from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \ - connectionDone -from twisted.internet import defer -from twisted.protocols import basic -from twisted.python import log -from twisted.web import server, resource, http - -from thrift.transport import TTransport -from cStringIO import StringIO - - -class TMessageSenderTransport(TTransport.TTransportBase): - - def __init__(self): - self.__wbuf = StringIO() - - def write(self, buf): - self.__wbuf.write(buf) - - def flush(self): - msg = self.__wbuf.getvalue() - self.__wbuf = StringIO() - self.sendMessage(msg) - - def sendMessage(self, message): - raise NotImplementedError - - -class TCallbackTransport(TMessageSenderTransport): - - def __init__(self, func): - TMessageSenderTransport.__init__(self) - self.func = func - - def sendMessage(self, message): - self.func(message) - - -class ThriftClientProtocol(basic.Int32StringReceiver): - - MAX_LENGTH = 2 ** 31 - 1 - - def __init__(self, client_class, iprot_factory, oprot_factory=None): - self._client_class = client_class - self._iprot_factory = iprot_factory - if oprot_factory is None: - self._oprot_factory = iprot_factory - else: - self._oprot_factory = oprot_factory - - self.recv_map = {} - self.started = defer.Deferred() - - def dispatch(self, msg): - self.sendString(msg) - - def connectionMade(self): - tmo = TCallbackTransport(self.dispatch) - self.client = self._client_class(tmo, self._oprot_factory) - self.started.callback(self.client) - - def connectionLost(self, 
reason=connectionDone): - for k,v in self.client._reqs.iteritems(): - tex = TTransport.TTransportException( - type=TTransport.TTransportException.END_OF_FILE, - message='Connection closed') - v.errback(tex) - - def stringReceived(self, frame): - tr = TTransport.TMemoryBuffer(frame) - iprot = self._iprot_factory.getProtocol(tr) - (fname, mtype, rseqid) = iprot.readMessageBegin() - - try: - method = self.recv_map[fname] - except KeyError: - method = getattr(self.client, 'recv_' + fname) - self.recv_map[fname] = method - - method(iprot, mtype, rseqid) - - -class ThriftServerProtocol(basic.Int32StringReceiver): - - MAX_LENGTH = 2 ** 31 - 1 - - def dispatch(self, msg): - self.sendString(msg) - - def processError(self, error): - self.transport.loseConnection() - - def processOk(self, _, tmo): - msg = tmo.getvalue() - - if len(msg) > 0: - self.dispatch(msg) - - def stringReceived(self, frame): - tmi = TTransport.TMemoryBuffer(frame) - tmo = TTransport.TMemoryBuffer() - - iprot = self.factory.iprot_factory.getProtocol(tmi) - oprot = self.factory.oprot_factory.getProtocol(tmo) - - d = self.factory.processor.process(iprot, oprot) - d.addCallbacks(self.processOk, self.processError, - callbackArgs=(tmo,)) - - -class IThriftServerFactory(Interface): - - processor = Attribute("Thrift processor") - - iprot_factory = Attribute("Input protocol factory") - - oprot_factory = Attribute("Output protocol factory") - - -class IThriftClientFactory(Interface): - - client_class = Attribute("Thrift client class") - - iprot_factory = Attribute("Input protocol factory") - - oprot_factory = Attribute("Output protocol factory") - - -class ThriftServerFactory(ServerFactory): - - implements(IThriftServerFactory) - - protocol = ThriftServerProtocol - - def __init__(self, processor, iprot_factory, oprot_factory=None): - self.processor = processor - self.iprot_factory = iprot_factory - if oprot_factory is None: - self.oprot_factory = iprot_factory - else: - self.oprot_factory = oprot_factory - - 
-class ThriftClientFactory(ClientFactory): - - implements(IThriftClientFactory) - - protocol = ThriftClientProtocol - - def __init__(self, client_class, iprot_factory, oprot_factory=None): - self.client_class = client_class - self.iprot_factory = iprot_factory - if oprot_factory is None: - self.oprot_factory = iprot_factory - else: - self.oprot_factory = oprot_factory - - def buildProtocol(self, addr): - p = self.protocol(self.client_class, self.iprot_factory, - self.oprot_factory) - p.factory = self - return p - - -class ThriftResource(resource.Resource): - - allowedMethods = ('POST',) - - def __init__(self, processor, inputProtocolFactory, - outputProtocolFactory=None): - resource.Resource.__init__(self) - self.inputProtocolFactory = inputProtocolFactory - if outputProtocolFactory is None: - self.outputProtocolFactory = inputProtocolFactory - else: - self.outputProtocolFactory = outputProtocolFactory - self.processor = processor - - def getChild(self, path, request): - return self - - def _cbProcess(self, _, request, tmo): - msg = tmo.getvalue() - request.setResponseCode(http.OK) - request.setHeader("content-type", "application/x-thrift") - request.write(msg) - request.finish() - - def render_POST(self, request): - request.content.seek(0, 0) - data = request.content.read() - tmi = TTransport.TMemoryBuffer(data) - tmo = TTransport.TMemoryBuffer() - - iprot = self.inputProtocolFactory.getProtocol(tmi) - oprot = self.outputProtocolFactory.getProtocol(tmo) - - d = self.processor.process(iprot, oprot) - d.addCallback(self._cbProcess, request, tmo) - return server.NOT_DONE_YET diff --git a/anknotes/thrift/transport/TZlibTransport.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 b/anknotes/thrift/transport/TZlibTransport.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 784d4e1..0000000 --- a/anknotes/thrift/transport/TZlibTransport.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,261 +0,0 @@ -# -# Licensed to the Apache Software 
Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -''' -TZlibTransport provides a compressed transport and transport factory -class, using the python standard library zlib module to implement -data compression. -''' - -from __future__ import division -import zlib -from cStringIO import StringIO -from TTransport import TTransportBase, CReadableTransport - -class TZlibTransportFactory(object): - ''' - Factory transport that builds zlib compressed transports. - - This factory caches the last single client/transport that it was passed - and returns the same TZlibTransport object that was created. - - This caching means the TServer class will get the _same_ transport - object for both input and output transports from this factory. - (For non-threaded scenarios only, since the cache only holds one object) - - The purpose of this caching is to allocate only one TZlibTransport where - only one is really needed (since it must have separate read/write buffers), - and makes the statistics from getCompSavings() and getCompRatio() - easier to understand. 
- ''' - - # class scoped cache of last transport given and zlibtransport returned - _last_trans = None - _last_z = None - - def getTransport(self, trans, compresslevel=9): - '''Wrap a transport , trans, with the TZlibTransport - compressed transport class, returning a new - transport to the caller. - - @param compresslevel: The zlib compression level, ranging - from 0 (no compression) to 9 (best compression). Defaults to 9. - @type compresslevel: int - - This method returns a TZlibTransport which wraps the - passed C{trans} TTransport derived instance. - ''' - if trans == self._last_trans: - return self._last_z - ztrans = TZlibTransport(trans, compresslevel) - self._last_trans = trans - self._last_z = ztrans - return ztrans - - -class TZlibTransport(TTransportBase, CReadableTransport): - ''' - Class that wraps a transport with zlib, compressing writes - and decompresses reads, using the python standard - library zlib module. - ''' - - # Read buffer size for the python fastbinary C extension, - # the TBinaryProtocolAccelerated class. - DEFAULT_BUFFSIZE = 4096 - - def __init__(self, trans, compresslevel=9): - ''' - Create a new TZlibTransport, wrapping C{trans}, another - TTransport derived object. - - @param trans: A thrift transport object, i.e. a TSocket() object. - @type trans: TTransport - @param compresslevel: The zlib compression level, ranging - from 0 (no compression) to 9 (best compression). Default is 9. - @type compresslevel: int - ''' - self.__trans = trans - self.compresslevel = compresslevel - self.__rbuf = StringIO() - self.__wbuf = StringIO() - self._init_zlib() - self._init_stats() - - def _reinit_buffers(self): - ''' - Internal method to initialize/reset the internal StringIO objects - for read and write buffers. - ''' - self.__rbuf = StringIO() - self.__wbuf = StringIO() - - def _init_stats(self): - ''' - Internal method to reset the internal statistics counters - for compression ratios and bandwidth savings. 
- ''' - self.bytes_in = 0 - self.bytes_out = 0 - self.bytes_in_comp = 0 - self.bytes_out_comp = 0 - - def _init_zlib(self): - ''' - Internal method for setting up the zlib compression and - decompression objects. - ''' - self._zcomp_read = zlib.decompressobj() - self._zcomp_write = zlib.compressobj(self.compresslevel) - - def getCompRatio(self): - ''' - Get the current measured compression ratios (in,out) from - this transport. - - Returns a tuple of: - (inbound_compression_ratio, outbound_compression_ratio) - - The compression ratios are computed as: - compressed / uncompressed - - E.g., data that compresses by 10x will have a ratio of: 0.10 - and data that compresses to half of ts original size will - have a ratio of 0.5 - - None is returned if no bytes have yet been processed in - a particular direction. - ''' - r_percent, w_percent = (None, None) - if self.bytes_in > 0: - r_percent = self.bytes_in_comp / self.bytes_in - if self.bytes_out > 0: - w_percent = self.bytes_out_comp / self.bytes_out - return (r_percent, w_percent) - - def getCompSavings(self): - ''' - Get the current count of saved bytes due to data - compression. - - Returns a tuple of: - (inbound_saved_bytes, outbound_saved_bytes) - - Note: if compression is actually expanding your - data (only likely with very tiny thrift objects), then - the values returned will be negative. 
- ''' - r_saved = self.bytes_in - self.bytes_in_comp - w_saved = self.bytes_out - self.bytes_out_comp - return (r_saved, w_saved) - - def isOpen(self): - '''Return the underlying transport's open status''' - return self.__trans.isOpen() - - def open(self): - """Open the underlying transport""" - self._init_stats() - return self.__trans.open() - - def listen(self): - '''Invoke the underlying transport's listen() method''' - self.__trans.listen() - - def accept(self): - '''Accept connections on the underlying transport''' - return self.__trans.accept() - - def close(self): - '''Close the underlying transport,''' - self._reinit_buffers() - self._init_zlib() - return self.__trans.close() - - def read(self, sz): - ''' - Read up to sz bytes from the decompressed bytes buffer, and - read from the underlying transport if the decompression - buffer is empty. - ''' - ret = self.__rbuf.read(sz) - if len(ret) > 0: - return ret - # keep reading from transport until something comes back - while True: - if self.readComp(sz): - break - ret = self.__rbuf.read(sz) - return ret - - def readComp(self, sz): - ''' - Read compressed data from the underlying transport, then - decompress it and append it to the internal StringIO read buffer - ''' - zbuf = self.__trans.read(sz) - zbuf = self._zcomp_read.unconsumed_tail + zbuf - buf = self._zcomp_read.decompress(zbuf) - self.bytes_in += len(zbuf) - self.bytes_in_comp += len(buf) - old = self.__rbuf.read() - self.__rbuf = StringIO(old + buf) - if len(old) + len(buf) == 0: - return False - return True - - def write(self, buf): - ''' - Write some bytes, putting them into the internal write - buffer for eventual compression. 
- ''' - self.__wbuf.write(buf) - - def flush(self): - ''' - Flush any queued up data in the write buffer and ensure the - compression buffer is flushed out to the underlying transport - ''' - wout = self.__wbuf.getvalue() - if len(wout) > 0: - zbuf = self._zcomp_write.compress(wout) - self.bytes_out += len(wout) - self.bytes_out_comp += len(zbuf) - else: - zbuf = '' - ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH) - self.bytes_out_comp += len(ztail) - if (len(zbuf) + len(ztail)) > 0: - self.__wbuf = StringIO() - self.__trans.write(zbuf + ztail) - self.__trans.flush() - - @property - def cstringio_buf(self): - '''Implement the CReadableTransport interface''' - return self.__rbuf - - def cstringio_refill(self, partialread, reqlen): - '''Implement the CReadableTransport interface for refill''' - retstring = partialread - if reqlen < self.DEFAULT_BUFFSIZE: - retstring += self.read(self.DEFAULT_BUFFSIZE) - while len(retstring) < reqlen: - retstring += self.read(reqlen - len(retstring)) - self.__rbuf = StringIO(retstring) - return self.__rbuf diff --git a/anknotes/thrift/transport/TZlibTransport.py~HEAD b/anknotes/thrift/transport/TZlibTransport.py~HEAD deleted file mode 100644 index 784d4e1..0000000 --- a/anknotes/thrift/transport/TZlibTransport.py~HEAD +++ /dev/null @@ -1,261 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -# -''' -TZlibTransport provides a compressed transport and transport factory -class, using the python standard library zlib module to implement -data compression. -''' - -from __future__ import division -import zlib -from cStringIO import StringIO -from TTransport import TTransportBase, CReadableTransport - -class TZlibTransportFactory(object): - ''' - Factory transport that builds zlib compressed transports. - - This factory caches the last single client/transport that it was passed - and returns the same TZlibTransport object that was created. - - This caching means the TServer class will get the _same_ transport - object for both input and output transports from this factory. - (For non-threaded scenarios only, since the cache only holds one object) - - The purpose of this caching is to allocate only one TZlibTransport where - only one is really needed (since it must have separate read/write buffers), - and makes the statistics from getCompSavings() and getCompRatio() - easier to understand. - ''' - - # class scoped cache of last transport given and zlibtransport returned - _last_trans = None - _last_z = None - - def getTransport(self, trans, compresslevel=9): - '''Wrap a transport , trans, with the TZlibTransport - compressed transport class, returning a new - transport to the caller. - - @param compresslevel: The zlib compression level, ranging - from 0 (no compression) to 9 (best compression). Defaults to 9. - @type compresslevel: int - - This method returns a TZlibTransport which wraps the - passed C{trans} TTransport derived instance. 
- ''' - if trans == self._last_trans: - return self._last_z - ztrans = TZlibTransport(trans, compresslevel) - self._last_trans = trans - self._last_z = ztrans - return ztrans - - -class TZlibTransport(TTransportBase, CReadableTransport): - ''' - Class that wraps a transport with zlib, compressing writes - and decompresses reads, using the python standard - library zlib module. - ''' - - # Read buffer size for the python fastbinary C extension, - # the TBinaryProtocolAccelerated class. - DEFAULT_BUFFSIZE = 4096 - - def __init__(self, trans, compresslevel=9): - ''' - Create a new TZlibTransport, wrapping C{trans}, another - TTransport derived object. - - @param trans: A thrift transport object, i.e. a TSocket() object. - @type trans: TTransport - @param compresslevel: The zlib compression level, ranging - from 0 (no compression) to 9 (best compression). Default is 9. - @type compresslevel: int - ''' - self.__trans = trans - self.compresslevel = compresslevel - self.__rbuf = StringIO() - self.__wbuf = StringIO() - self._init_zlib() - self._init_stats() - - def _reinit_buffers(self): - ''' - Internal method to initialize/reset the internal StringIO objects - for read and write buffers. - ''' - self.__rbuf = StringIO() - self.__wbuf = StringIO() - - def _init_stats(self): - ''' - Internal method to reset the internal statistics counters - for compression ratios and bandwidth savings. - ''' - self.bytes_in = 0 - self.bytes_out = 0 - self.bytes_in_comp = 0 - self.bytes_out_comp = 0 - - def _init_zlib(self): - ''' - Internal method for setting up the zlib compression and - decompression objects. - ''' - self._zcomp_read = zlib.decompressobj() - self._zcomp_write = zlib.compressobj(self.compresslevel) - - def getCompRatio(self): - ''' - Get the current measured compression ratios (in,out) from - this transport. 
- - Returns a tuple of: - (inbound_compression_ratio, outbound_compression_ratio) - - The compression ratios are computed as: - compressed / uncompressed - - E.g., data that compresses by 10x will have a ratio of: 0.10 - and data that compresses to half of ts original size will - have a ratio of 0.5 - - None is returned if no bytes have yet been processed in - a particular direction. - ''' - r_percent, w_percent = (None, None) - if self.bytes_in > 0: - r_percent = self.bytes_in_comp / self.bytes_in - if self.bytes_out > 0: - w_percent = self.bytes_out_comp / self.bytes_out - return (r_percent, w_percent) - - def getCompSavings(self): - ''' - Get the current count of saved bytes due to data - compression. - - Returns a tuple of: - (inbound_saved_bytes, outbound_saved_bytes) - - Note: if compression is actually expanding your - data (only likely with very tiny thrift objects), then - the values returned will be negative. - ''' - r_saved = self.bytes_in - self.bytes_in_comp - w_saved = self.bytes_out - self.bytes_out_comp - return (r_saved, w_saved) - - def isOpen(self): - '''Return the underlying transport's open status''' - return self.__trans.isOpen() - - def open(self): - """Open the underlying transport""" - self._init_stats() - return self.__trans.open() - - def listen(self): - '''Invoke the underlying transport's listen() method''' - self.__trans.listen() - - def accept(self): - '''Accept connections on the underlying transport''' - return self.__trans.accept() - - def close(self): - '''Close the underlying transport,''' - self._reinit_buffers() - self._init_zlib() - return self.__trans.close() - - def read(self, sz): - ''' - Read up to sz bytes from the decompressed bytes buffer, and - read from the underlying transport if the decompression - buffer is empty. 
- ''' - ret = self.__rbuf.read(sz) - if len(ret) > 0: - return ret - # keep reading from transport until something comes back - while True: - if self.readComp(sz): - break - ret = self.__rbuf.read(sz) - return ret - - def readComp(self, sz): - ''' - Read compressed data from the underlying transport, then - decompress it and append it to the internal StringIO read buffer - ''' - zbuf = self.__trans.read(sz) - zbuf = self._zcomp_read.unconsumed_tail + zbuf - buf = self._zcomp_read.decompress(zbuf) - self.bytes_in += len(zbuf) - self.bytes_in_comp += len(buf) - old = self.__rbuf.read() - self.__rbuf = StringIO(old + buf) - if len(old) + len(buf) == 0: - return False - return True - - def write(self, buf): - ''' - Write some bytes, putting them into the internal write - buffer for eventual compression. - ''' - self.__wbuf.write(buf) - - def flush(self): - ''' - Flush any queued up data in the write buffer and ensure the - compression buffer is flushed out to the underlying transport - ''' - wout = self.__wbuf.getvalue() - if len(wout) > 0: - zbuf = self._zcomp_write.compress(wout) - self.bytes_out += len(wout) - self.bytes_out_comp += len(zbuf) - else: - zbuf = '' - ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH) - self.bytes_out_comp += len(ztail) - if (len(zbuf) + len(ztail)) > 0: - self.__wbuf = StringIO() - self.__trans.write(zbuf + ztail) - self.__trans.flush() - - @property - def cstringio_buf(self): - '''Implement the CReadableTransport interface''' - return self.__rbuf - - def cstringio_refill(self, partialread, reqlen): - '''Implement the CReadableTransport interface for refill''' - retstring = partialread - if reqlen < self.DEFAULT_BUFFSIZE: - retstring += self.read(self.DEFAULT_BUFFSIZE) - while len(retstring) < reqlen: - retstring += self.read(reqlen - len(retstring)) - self.__rbuf = StringIO(retstring) - return self.__rbuf diff --git a/anknotes/thrift/transport/__init__.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 
b/anknotes/thrift/transport/__init__.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 deleted file mode 100644 index 46e54fe..0000000 --- a/anknotes/thrift/transport/__init__.py~155d40b1f21ee8336f1c8d81dbef09df4cb39236 +++ /dev/null @@ -1,20 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -__all__ = ['TTransport', 'TSocket', 'THttpClient','TZlibTransport'] diff --git a/anknotes/thrift/transport/__init__.py~HEAD b/anknotes/thrift/transport/__init__.py~HEAD deleted file mode 100644 index 46e54fe..0000000 --- a/anknotes/thrift/transport/__init__.py~HEAD +++ /dev/null @@ -1,20 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -__all__ = ['TTransport', 'TSocket', 'THttpClient','TZlibTransport'] diff --git a/anknotes/toc.py b/anknotes/toc.py new file mode 100644 index 0000000..1440078 --- /dev/null +++ b/anknotes/toc.py @@ -0,0 +1,333 @@ +# -*- coding: utf-8 -*- +try: + from pysqlite2 import dbapi2 as sqlite +except ImportError: + from sqlite3 import dbapi2 as sqlite +from anknotes.constants import * +from anknotes.html import generate_evernote_link, generate_evernote_span +from anknotes.logging import log_dump +from anknotes.base import matches_list +from anknotes.EvernoteNoteTitle import EvernoteNoteTitle, generateTOCTitle +from anknotes.EvernoteNotePrototype import EvernoteNotePrototype + + +def TOCNamePriority(title): + all_headings = [HEADINGS.TOP, [], HEADINGS.BOTTOM] + for heading_index, headings in enumerate(all_headings): + match = matches_list(title, headings) + if match: + return heading_index + float(match) / len(headings) + return all_headings.index([]) + + +def TOCNameSort(title1, title2): + priority1 = TOCNamePriority(title1) + priority2 = TOCNamePriority(title2) + # Lower value for item 1 = item 1 placed BEFORE item 2 + if priority1 != priority2: + return int((priority1 - priority2)*1000) + return cmp(title1, title2) + + +def TOCSort(hash1, hash2): + lvl1 = hash1.Level + lvl2 = hash2.Level + names1 = hash1.TitleParts + names2 = hash2.TitleParts + for i in range(0, min(lvl1, lvl2)): + name1 = names1[i] + name2 = names2[i] + if name1 != name2: + return TOCNameSort(name1, name2) + # Lower value for item 1 = item 1 placed BEFORE item 2 + return lvl1 - lvl2 + + 
+class TOCHierarchyClass: + Title = None + """:type : EvernoteNoteTitle""" + Note = None + """:type : EvernoteNotePrototype.EvernoteNotePrototype""" + Outline = None + """:type : TOCHierarchyClass""" + Number = 1 + Children = [] + """:type : list[TOCHierarchyClass]""" + Parent = None + """:type : TOCHierarchyClass""" + __isSorted = False + + @staticmethod + def TOCItemSort(tocHierarchy1, tocHierarchy2): + lvl1 = tocHierarchy1.Level + lvl2 = tocHierarchy2.Level + names1 = tocHierarchy1.TitleParts + names2 = tocHierarchy2.TitleParts + for i in range(0, min(lvl1, lvl2)): + name1 = names1[i] + name2 = names2[i] + if name1 != name2: + return TOCNameSort(name1, name2) + # Lower value for item 1 = item 1 placed BEFORE item 2 + return lvl1 - lvl2 + + @property + def IsOutline(self): + if not self.Note: + return False + return TAGS.OUTLINE in self.Note.Tags + + def sortIfNeeded(self): + if self.__isSorted: + return + self.sortChildren() + + @property + def FullTitle(self): return self.Title.FullTitle if self.Title else "" + + @property + def Level(self): + return self.Title.Level + + @property + def ChildrenCount(self): + return len(self.Children) + + @property + def TitleParts(self): + return self.Title.TitleParts + + def addNote(self, note): + tocHierarchy = TOCHierarchyClass(note=note) + self.addHierarchy(tocHierarchy) + + def getChildIndex(self, tocChildHierarchy): + if not tocChildHierarchy in self.Children: + return -1 + self.sortIfNeeded() + return self.Children.index(tocChildHierarchy) + + @property + def ListPrefix(self): + index = self.Index + isSingleItem = self.IsSingleItem + if isSingleItem is 0: + return "" + if isSingleItem is 1: + return "*" + return str(index) + "." 
+ + @property + def IsSingleItem(self): + index = self.Index + if index is 0: + return 0 + if index is 1 and len(self.Parent.Children) is 1: + return 1 + return -1 + + @property + def Index(self): + if not self.Parent: + return 0 + return self.Parent.getChildIndex(self) + 1 + + def addTitle(self, title): + self.addHierarchy(TOCHierarchyClass(title)) + + def addHierarchy(self, tocHierarchy): + tocNewTitle = tocHierarchy.Title + tocNewLevel = tocNewTitle.Level + selfLevel = self.Title.Level + tocTestBase = tocHierarchy.FullTitle.replace(self.FullTitle, '') + if tocTestBase[: + 2] == ': ': + tocTestBase = tocTestBase[2:] + + print " \nAdd Hierarchy: %-70s --> %-40s\n-------------------------------------" % ( + self.FullTitle, tocTestBase) + + if selfLevel > tocHierarchy.Title.Level: + print "New Title Level is Below current level" + return False + + selfTOCTitle = self.Title.TOCTitle + tocSelfSibling = tocNewTitle.Parents(self.Title.Level) + + if tocSelfSibling.TOCTitle != selfTOCTitle: + print "New Title doesn't match current path" + return False + + if tocNewLevel is self.Title.Level: + if tocHierarchy.IsOutline: + tocHierarchy.Parent = self + self.Outline = tocHierarchy + print "SUCCESS: Outline added" + return True + print "New Title Level is current level, but New Title is not Outline" + return False + + tocNewSelfChild = tocNewTitle.Parents(self.Title.Level + 1) + tocNewSelfChildTOCName = tocNewSelfChild.TOCName + isDirectChild = (tocHierarchy.Level == self.Level + 1) + if isDirectChild: + tocNewChildNamesTitle = "N/A" + print "New Title is a direct child of the current title" + else: + tocNewChildNamesTitle = tocHierarchy.Title.Names(self.Title.Level + 1).FullTitle + print "New Title is a Grandchild or deeper of the current title " + + for tocChild in self.Children: + assert (isinstance(tocChild, TOCHierarchyClass)) + if tocChild.Title.TOCName == tocNewSelfChildTOCName: + print "%-60s Child %-20s Match Succeeded for %s." 
% ( + self.FullTitle + ':', tocChild.Title.Name + ':', tocNewChildNamesTitle) + success = tocChild.addHierarchy(tocHierarchy) + if success: + return True + print "%-60s Child %-20s Match Succeeded for %s: However, unable to add to matched child" % ( + self.FullTitle + ':', tocChild.Title.Name + ':', tocNewChildNamesTitle) + print "%-60s Child %-20s Search failed for %s" % ( + self.FullTitle + ':', tocNewSelfChild.Name, tocNewChildNamesTitle) + + newChild = tocHierarchy if isDirectChild else TOCHierarchyClass(tocNewSelfChild) + newChild.parent = self + if isDirectChild: + print "%-60s Child %-20s Created Direct Child for %s." % ( + self.FullTitle + ':', newChild.Title.Name, tocNewChildNamesTitle) + success = True + else: + print "%-60s Child %-20s Created Title-Only Child for %-40ss." % ( + self.FullTitle + ':', newChild.Title.Name, tocNewChildNamesTitle) + success = newChild.addHierarchy(tocHierarchy) + print "%-60s Child %-20s Created Title-Only Child for %-40s: Match %s." % ( + self.FullTitle + ':', newChild.Title.Name, tocNewChildNamesTitle, + "succeeded" if success else "failed") + self.__isSorted = False + self.Children.append(newChild) + + print "%-60s Child %-20s Appended Child for %s. Operation was an overall %s." 
% ( + self.FullTitle + ':', newChild.Title.Name + ':', tocNewChildNamesTitle, + "success" if success else "failure") + return success + + def sortChildren(self): + self.Children = sorted(self.Children, self.TOCItemSort) + self.__isSorted = True + + def __strsingle(self, fullTitle=False): + selfTitleStr = self.FullTitle + selfNameStr = self.Title.Name + selfLevel = self.Title.Level + selfDepth = self.Title.Depth + selfListPrefix = self.ListPrefix + str_ = '' + if selfLevel == 1: + str_ += ' [%d] ' % len(self.Children) + else: + if len(self.Children): + str_ += ' [%d:%2d] ' % (selfDepth, len(self.Children)) + else: + str_ += ' [%d] ' % selfDepth + str_ += ' ' * (selfDepth * 3) + str_ += ' %s ' % selfListPrefix + + str_ += '%-60s %s' % (selfTitleStr if fullTitle else selfNameStr, '' if self.Note else '(No Note)') + return str_ + + def __str__(self, fullTitle=True, fullChildrenTitles=False): + self.sortIfNeeded() + lst = [self.__strsingle(fullTitle)] + for child in self.Children: + lst.append(child.__str__(fullChildrenTitles, fullChildrenTitles)) + return '\n'.join(lst) + + def GetOrderedListItem(self, title=None): + if not title: + title = self.Title.Name + selfTitleStr = title + selfLevel = self.Title.Level + selfDepth = self.Title.Depth + if selfLevel == 1: + guid = 'guid-pending' + if self.Note: + guid = self.Note.Guid + link = generate_evernote_link(guid, generateTOCTitle(selfTitleStr), 'TOC') + if self.Outline: + link += ' ' + generate_evernote_link(self.Outline.Note.Guid, + '(<span style="color: rgb(255, 255, 255);">O</span>)', 'Outline', + escape=False) + return link + if self.Note: + return self.Note.generateLevelLink(selfDepth) + else: + return generate_evernote_span(selfTitleStr, 'Levels', selfDepth) + + def GetOrderedList(self, title=None): + self.sortIfNeeded() + lst = [] + header = (self.GetOrderedListItem(title)) + if self.ChildrenCount > 0: + for child in self.Children: + lst.append(child.GetOrderedList()) + childHTML = '\n'.join(lst) + else: + 
childHTML = '' + if childHTML: + tag = 'ol' if self.ChildrenCount > 1 else 'ul' + base = '<%s>\r\n%s\r\n</%s>\r\n' + childHTML = base % (tag, childHTML, tag) + + if self.Level is 1: + base = '<div> %s </div>\r\n %s \r\n' + base = base % (header, childHTML) + return base + base = '<li> %s \r\n %s \r\n</li> \r\n' + base = base % (header, childHTML) + return base + + def __reprsingle(self, fullTitle=True): + selfTitleStr = self.FullTitle + selfNameStr = self.Title.Name + selfListPrefix = self.ListPrefix + str_ = "<%s:%s[%d] %s%s>" % ( + self.__class__.__name__, selfListPrefix, len(self.Children), selfTitleStr if fullTitle else selfNameStr, + '' if self.Note else ' *') + return str_ + + def __repr__(self, fullTitle=True, fullChildrenTitles=False): + self.sortIfNeeded() + lst = [self.__reprsingle(fullTitle)] + for child in self.Children: + lst.append(child.__repr__(fullChildrenTitles, fullChildrenTitles)) + return '\n'.join(lst) + + def __init__(self, title=None, note=None, number=1): + """ + :type title: EvernoteNoteTitle + :type note: EvernoteNotePrototype.EvernoteNotePrototype + """ + assert note or title + self.Outline = None + if note: + if (isinstance(note, sqlite.Row)): + note = EvernoteNotePrototype(db_note=note) + + self.Note = note + self.Title = EvernoteNoteTitle(note) + else: + self.Title = EvernoteNoteTitle(title) + self.Note = None + self.Number = number + self.Children = [] + self.__isSorted = False + + # + # tocTest = TOCHierarchyClass("My Root Title") + # tocTest.addTitle("My Root Title: Somebody") + # tocTest.addTitle("My Root Title: Somebody: Else") + # tocTest.addTitle("My Root Title: Someone") + # tocTest.addTitle("My Root Title: Someone: Else") + # tocTest.addTitle("My Root Title: Someone: Else: Entirely") + # tocTest.addTitle("My Root Title: Z This: HasNo: Direct Parent") + # pass diff --git a/anknotes/version.py b/anknotes/version.py index 0fb5b6e..2bdceaa 100644 --- a/anknotes/version.py +++ b/anknotes/version.py @@ -26,9 +26,11 @@ of the same 
class, thus must follow the same rules) """ -import string, re +import string +import re from types import StringType + class Version: """Abstract base class for version numbering classes. Just provides constructor (__init__) and reproducer (__repr__), because those diff --git a/anknotes_remove_tags.py b/anknotes_remove_tags.py new file mode 100644 index 0000000..b6b33e0 --- /dev/null +++ b/anknotes_remove_tags.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +import sys + +if not 'anki' in sys.modules: + from anknotes.shared import * + try: + from pysqlite2 import dbapi2 as sqlite + except ImportError: + from sqlite3 import dbapi2 as sqlite + + Error = sqlite.Error + ankDBSetLocal() + + tags = ',#Imported,#Anki_Import,#Anki_Import_High_Priority,' + # ankDB().setrowfactory() + db = ankDB(TABLES.EVERNOTE.TAGS) + dbRows = db.all("SELECT * FROM {t} WHERE ? LIKE '%%,' || name || ',%%' ", tags) + + for dbRow in dbRows: + db.execute(fmt("UPDATE {n} SET tagNames = REPLACE(tagNames, ',{row[name]},', ','), tagGuids = " + "REPLACE(tagGuids, ',{row[guid]},', ',') WHERE tagGuids LIKE '%,{row[guid]},%'", row=dbRow)) + db.commit() diff --git a/anknotes_start.py b/anknotes_start.py index d596053..268133d 100644 --- a/anknotes_start.py +++ b/anknotes_start.py @@ -1 +1,7 @@ -from anknotes import __main__ \ No newline at end of file +from anknotes.logging_base import reset_logs +reset_logs() +# from anknotes.logging_base import write_file_contents +# write_file_contents('Loading ' + __name__, 'load') + +from anknotes import __main__ +# write_file_contents('Loaded ' + __name__, 'load') diff --git a/anknotes_start_detect_see_also_changes.py b/anknotes_start_detect_see_also_changes.py new file mode 100644 index 0000000..538f733 --- /dev/null +++ b/anknotes_start_detect_see_also_changes.py @@ -0,0 +1,5 @@ +import sys + +if not 'anki' in sys.modules: + from anknotes import detect_see_also_changes + detect_see_also_changes.main() \ No newline at end of file diff --git 
a/anknotes_start_find_deleted_notes.py b/anknotes_start_find_deleted_notes.py new file mode 100644 index 0000000..20834d0 --- /dev/null +++ b/anknotes_start_find_deleted_notes.py @@ -0,0 +1,8 @@ +import sys + +if not 'anki' in sys.modules: + from anknotes import find_deleted_notes + from anknotes.db import ankDBSetLocal + + ankDBSetLocal() + find_deleted_notes.do_find_deleted_notes() diff --git a/anknotes_start_note_validation.py b/anknotes_start_note_validation.py new file mode 100644 index 0000000..998cea0 --- /dev/null +++ b/anknotes_start_note_validation.py @@ -0,0 +1,114 @@ +import os +from anknotes import stopwatch +from anknotes.imports import import_etree +import time + +if import_etree(): + try: + from pysqlite2 import dbapi2 as sqlite + except ImportError: + from sqlite3 import dbapi2 as sqlite + from anknotes.imports import lxml, etree + + ### Anknotes Module Imports for Stand Alone Scripts + from anknotes import evernote as evernote + ### Anknotes Shared Imports + from anknotes.shared import * + from anknotes.error import * + from anknotes.toc import TOCHierarchyClass + + ### Anknotes Class Imports + from anknotes.AnkiNotePrototype import AnkiNotePrototype + from anknotes.EvernoteNoteTitle import generateTOCTitle + + ### Anknotes Main Imports + from anknotes.Anki import Anki + from anknotes.ankEvernote import Evernote + # from anknotes.EvernoteNoteFetcher import EvernoteNoteFetcher + # from anknotes.EvernoteNotes import EvernoteNotes + # from anknotes.EvernoteNotePrototype import EvernoteNotePrototype + # from anknotes.EvernoteImporter import EvernoteImporter + # + # ### Evernote Imports + # from anknotes.evernote.edam.notestore.ttypes import NoteFilter, NotesMetadataResultSpec + # from anknotes.evernote.edam.type.ttypes import NoteSortOrder, Note as EvernoteNote + from anknotes.evernote.edam.error.ttypes import EDAMSystemException, EDAMUserException, EDAMNotFoundException + # from anknotes.evernote.api.client import EvernoteClient + + def mk_banner(fn, 
display_initial_info=False): + l.default_filename = fn + my_info_str = info_str % l.default_filename.upper() + myTmr = stopwatch.Timer(len(queued_items[fn]), 10, infoStr=my_info_str, do_print=True, label=l.base_path, + display_initial_info=display_initial_info) + l.go("------------------------------------------------", clear=True) + l.go(my_info_str) + l.go("------------------------------------------------") + return myTmr + + ankDBSetLocal() + db = ankDB(TABLES.NOTE_VALIDATION_QUEUE) + db.Init() + + queued_items = {'Failed': db.all("validation_status = -1"), + 'Pending': db.all("validation_status = 0"), + 'Successful': db.all("validation_status = 1")} + info_str = 'CHECKING {num} %s MAKE NOTE QUEUE ITEMS' + + l = Logger('Validation\\validate_notes\\', rm_path=True, do_print=True, timestamp=False) + + tmr = mk_banner('Successful') + for result in queued_items[l.default_filename]: + line = (" [%-30s] " % ((result['guid']) + ':')) if result['guid'] else "NEW [%-30s] " % '' + line += result['title'] + l.go(line) + + tmr = mk_banner('Failed') + for result in queued_items[l.default_filename]: + line = '%-60s ' % (result['title'] + ':') + line += (" [%-30s] " % ((result['guid']) + ':')) if result['guid'] else "NEW" + line += '\n' + result['validation_result'] + l.go(line) + l.go("------------------------------------------------\n") + l.go(result['contents']) + l.go("------------------------------------------------\n") + + EN = Evernote() + + tmr = mk_banner('Pending', display_initial_info=True) + timerFull = stopwatch.Timer() + for result in queued_items[l.default_filename]: + guid = result['guid'] + noteContents = result['contents'] + noteTitle = result['title'] + line = (" [%-30s] " % ((result['guid']) + ':')) if result['guid'] else "NEW [%-30s] " % '' + errors = tmr.autoStep(EN.validateNoteContent(noteContents, noteTitle), noteTitle) + validation_status = 1 if tmr.status.IsSuccess else -1 + + line = " SUCCESS! 
" if tmr.status.IsSuccess else " FAILURE: " + line += ' ' if result['guid'] else ' NEW ' + # line += ' %-60s ' % (result['title'] + ':') + l.dump(errors, 'LXML ERRORS', 'lxml_errors', wrap_filename=False, crosspost_to_default=False) + if not tmr.status.IsSuccess: + if not is_str_type(errors): + errors = '\n * ' + '\n * '.join(errors) + l.go(line + errors) + else: + if not is_str_type(errors): + errors = '\n'.join(errors) + + sql = "UPDATE {t} SET validation_status = ?, validation_result = ? WHERE " + data = [validation_status, errors] + if guid: + sql += "guid = ?" + data += [guid] + else: + sql += "title = ? AND contents = ?" + data += [noteTitle, noteContents] + + db.execute(sql, data) + + timerFull.stop() + l.go("Validation of %d results completed in %s" % (tmr.max, str(timerFull))) + + db.commit() + db.close() diff --git a/bs4/__init__.py b/bs4/__init__.py new file mode 100644 index 0000000..af8c718 --- /dev/null +++ b/bs4/__init__.py @@ -0,0 +1,355 @@ +"""Beautiful Soup +Elixir and Tonic +"The Screen-Scraper's Friend" +http://www.crummy.com/software/BeautifulSoup/ + +Beautiful Soup uses a pluggable XML or HTML parser to parse a +(possibly invalid) document into a tree representation. Beautiful Soup +provides provides methods and Pythonic idioms that make it easy to +navigate, search, and modify the parse tree. + +Beautiful Soup works with Python 2.6 and up. It works better if lxml +and/or html5lib is installed. 
+ +For more than you ever wanted to know about Beautiful Soup, see the +documentation: +http://www.crummy.com/software/BeautifulSoup/bs4/doc/ +""" + +__author__ = "Leonard Richardson (leonardr@segfault.org)" +__version__ = "4.1.0" +__copyright__ = "Copyright (c) 2004-2012 Leonard Richardson" +__license__ = "MIT" + +__all__ = ['BeautifulSoup'] + +import re +import warnings + +from .builder import builder_registry +from .dammit import UnicodeDammit +from .element import ( + CData, + Comment, + DEFAULT_OUTPUT_ENCODING, + Declaration, + Doctype, + NavigableString, + PageElement, + ProcessingInstruction, + ResultSet, + SoupStrainer, + Tag, + ) + +# The very first thing we do is give a useful error if someone is +# running this code under Python 3 without converting it. +syntax_error = u'You are trying to run the Python 2 version of Beautiful Soup under Python 3. This will not work. You need to convert the code, either by installing it (`python setup.py install`) or by running 2to3 (`2to3 -w bs4`).' + +class BeautifulSoup(Tag): + """ + This class defines the basic interface called by the tree builders. + + These methods will be called by the parser: + reset() + feed(markup) + + The tree builder may call these methods from its feed() implementation: + handle_starttag(name, attrs) # See note about return value + handle_endtag(name) + handle_data(data) # Appends to the current data node + endData(containerClass=NavigableString) # Ends the current data node + + No matter how complicated the underlying parser is, you should be + able to build a tree using 'start tag' events, 'end tag' events, + 'data' events, and "done with data" events. + + If you encounter an empty-element tag (aka a self-closing tag, + like HTML's <br> tag), call handle_starttag and then + handle_endtag. + """ + ROOT_TAG_NAME = u'[document]' + + # If the end-user gives no indication which tree builder they + # want, look for one with these features. 
+ DEFAULT_BUILDER_FEATURES = ['html', 'fast'] + + # Used when determining whether a text node is all whitespace and + # can be replaced with a single space. A text node that contains + # fancy Unicode spaces (usually non-breaking) should be left + # alone. + STRIP_ASCII_SPACES = {9: None, 10: None, 12: None, 13: None, 32: None, } + + def __init__(self, markup="", features=None, builder=None, + parse_only=None, from_encoding=None, **kwargs): + """The Soup object is initialized as the 'root tag', and the + provided markup (which can be a string or a file-like object) + is fed into the underlying parser.""" + + if 'convertEntities' in kwargs: + warnings.warn( + "BS4 does not respect the convertEntities argument to the " + "BeautifulSoup constructor. Entities are always converted " + "to Unicode characters.") + + if 'markupMassage' in kwargs: + del kwargs['markupMassage'] + warnings.warn( + "BS4 does not respect the markupMassage argument to the " + "BeautifulSoup constructor. The tree builder is responsible " + "for any necessary markup massage.") + + if 'smartQuotesTo' in kwargs: + del kwargs['smartQuotesTo'] + warnings.warn( + "BS4 does not respect the smartQuotesTo argument to the " + "BeautifulSoup constructor. Smart quotes are always converted " + "to Unicode characters.") + + if 'selfClosingTags' in kwargs: + del kwargs['selfClosingTags'] + warnings.warn( + "BS4 does not respect the selfClosingTags argument to the " + "BeautifulSoup constructor. The tree builder is responsible " + "for understanding self-closing tags.") + + if 'isHTML' in kwargs: + del kwargs['isHTML'] + warnings.warn( + "BS4 does not respect the isHTML argument to the " + "BeautifulSoup constructor. 
You can pass in features='html' " + "or features='xml' to get a builder capable of handling " + "one or the other.") + + def deprecated_argument(old_name, new_name): + if old_name in kwargs: + warnings.warn( + 'The "%s" argument to the BeautifulSoup constructor ' + 'has been renamed to "%s."' % (old_name, new_name)) + value = kwargs[old_name] + del kwargs[old_name] + return value + return None + + parse_only = parse_only or deprecated_argument( + "parseOnlyThese", "parse_only") + + from_encoding = from_encoding or deprecated_argument( + "fromEncoding", "from_encoding") + + if len(kwargs) > 0: + arg = kwargs.keys().pop() + raise TypeError( + "__init__() got an unexpected keyword argument '%s'" % arg) + + if builder is None: + if isinstance(features, basestring): + features = [features] + if features is None or len(features) == 0: + features = self.DEFAULT_BUILDER_FEATURES + builder_class = builder_registry.lookup(*features) + if builder_class is None: + raise ValueError( + "Couldn't find a tree builder with the features you " + "requested: %s. Do you need to install a parser library?" + % ",".join(features)) + builder = builder_class() + self.builder = builder + self.is_xml = builder.is_xml + self.builder.soup = self + + self.parse_only = parse_only + + self.reset() + + if hasattr(markup, 'read'): # It's a file-type object. + markup = markup.read() + (self.markup, self.original_encoding, self.declared_html_encoding, + self.contains_replacement_characters) = ( + self.builder.prepare_markup(markup, from_encoding)) + + try: + self._feed() + except StopParsing: + pass + + # Clear out the markup and remove the builder's circular + # reference to this object. + self.markup = None + self.builder.soup = None + + def _feed(self): + # Convert the document to Unicode. + self.builder.reset() + + self.builder.feed(self.markup) + # Close out any unfinished strings and close all the open tags. 
+ self.endData() + while self.currentTag.name != self.ROOT_TAG_NAME: + self.popTag() + + def reset(self): + Tag.__init__(self, self, self.builder, self.ROOT_TAG_NAME) + self.hidden = 1 + self.builder.reset() + self.currentData = [] + self.currentTag = None + self.tagStack = [] + self.pushTag(self) + + def new_tag(self, name, namespace=None, nsprefix=None, **attrs): + """Create a new tag associated with this soup.""" + return Tag(None, self.builder, name, namespace, nsprefix, attrs) + + def new_string(self, s): + """Create a new NavigableString associated with this soup.""" + navigable = NavigableString(s) + navigable.setup() + return navigable + + def insert_before(self, successor): + raise ValueError("BeautifulSoup objects don't support insert_before().") + + def insert_after(self, successor): + raise ValueError("BeautifulSoup objects don't support insert_after().") + + def popTag(self): + tag = self.tagStack.pop() + #print "Pop", tag.name + if self.tagStack: + self.currentTag = self.tagStack[-1] + return self.currentTag + + def pushTag(self, tag): + #print "Push", tag.name + if self.currentTag: + self.currentTag.contents.append(tag) + self.tagStack.append(tag) + self.currentTag = self.tagStack[-1] + + def endData(self, containerClass=NavigableString): + if self.currentData: + currentData = u''.join(self.currentData) + if (currentData.translate(self.STRIP_ASCII_SPACES) == '' and + not set([tag.name for tag in self.tagStack]).intersection( + self.builder.preserve_whitespace_tags)): + if '\n' in currentData: + currentData = '\n' + else: + currentData = ' ' + self.currentData = [] + if self.parse_only and len(self.tagStack) <= 1 and \ + (not self.parse_only.text or \ + not self.parse_only.search(currentData)): + return + o = containerClass(currentData) + self.object_was_parsed(o) + + def object_was_parsed(self, o): + """Add an object to the parse tree.""" + o.setup(self.currentTag, self.previous_element) + if self.previous_element: + 
self.previous_element.next_element = o + self.previous_element = o + self.currentTag.contents.append(o) + + def _popToTag(self, name, nsprefix=None, inclusivePop=True): + """Pops the tag stack up to and including the most recent + instance of the given tag. If inclusivePop is false, pops the tag + stack up to but *not* including the most recent instqance of + the given tag.""" + #print "Popping to %s" % name + if name == self.ROOT_TAG_NAME: + return + + numPops = 0 + mostRecentTag = None + + for i in range(len(self.tagStack) - 1, 0, -1): + if (name == self.tagStack[i].name + and nsprefix == self.tagStack[i].nsprefix == nsprefix): + numPops = len(self.tagStack) - i + break + if not inclusivePop: + numPops = numPops - 1 + + for i in range(0, numPops): + mostRecentTag = self.popTag() + return mostRecentTag + + def handle_starttag(self, name, namespace, nsprefix, attrs): + """Push a start tag on to the stack. + + If this method returns None, the tag was rejected by the + SoupStrainer. You should proceed as if the tag had not occured + in the document. For instance, if this was a self-closing tag, + don't call handle_endtag. + """ + + # print "Start tag %s: %s" % (name, attrs) + self.endData() + + if (self.parse_only and len(self.tagStack) <= 1 + and (self.parse_only.text + or not self.parse_only.search_tag(name, attrs))): + return None + + tag = Tag(self, self.builder, name, namespace, nsprefix, attrs, + self.currentTag, self.previous_element) + if tag is None: + return tag + if self.previous_element: + self.previous_element.next_element = tag + self.previous_element = tag + self.pushTag(tag) + return tag + + def handle_endtag(self, name, nsprefix=None): + #print "End tag: " + name + self.endData() + self._popToTag(name, nsprefix) + + def handle_data(self, data): + self.currentData.append(data) + + def decode(self, pretty_print=False, + eventual_encoding=DEFAULT_OUTPUT_ENCODING, + formatter="minimal"): + """Returns a string or Unicode representation of this document. 
+ To get Unicode, pass None for encoding.""" + + if self.is_xml: + # Print the XML declaration + encoding_part = '' + if eventual_encoding != None: + encoding_part = ' encoding="%s"' % eventual_encoding + prefix = u'<?xml version="1.0"%s?>\n' % encoding_part + else: + prefix = u'' + if not pretty_print: + indent_level = None + else: + indent_level = 0 + return prefix + super(BeautifulSoup, self).decode( + indent_level, eventual_encoding, formatter) + +class BeautifulStoneSoup(BeautifulSoup): + """Deprecated interface to an XML parser.""" + + def __init__(self, *args, **kwargs): + kwargs['features'] = 'xml' + warnings.warn( + 'The BeautifulStoneSoup class is deprecated. Instead of using ' + 'it, pass features="xml" into the BeautifulSoup constructor.') + super(BeautifulStoneSoup, self).__init__(*args, **kwargs) + + +class StopParsing(Exception): + pass + + +#By default, act as an HTML pretty-printer. +if __name__ == '__main__': + import sys + soup = BeautifulSoup(sys.stdin) + print soup.prettify() diff --git a/bs4/builder/__init__.py b/bs4/builder/__init__.py new file mode 100644 index 0000000..4c22b86 --- /dev/null +++ b/bs4/builder/__init__.py @@ -0,0 +1,307 @@ +from collections import defaultdict +import itertools +import sys +from bs4.element import ( + CharsetMetaAttributeValue, + ContentMetaAttributeValue, + whitespace_re + ) + +__all__ = [ + 'HTMLTreeBuilder', + 'SAXTreeBuilder', + 'TreeBuilder', + 'TreeBuilderRegistry', + ] + +# Some useful features for a TreeBuilder to have. 
+FAST = 'fast' +PERMISSIVE = 'permissive' +STRICT = 'strict' +XML = 'xml' +HTML = 'html' +HTML_5 = 'html5' + + +class TreeBuilderRegistry(object): + + def __init__(self): + self.builders_for_feature = defaultdict(list) + self.builders = [] + + def register(self, treebuilder_class): + """Register a treebuilder based on its advertised features.""" + for feature in treebuilder_class.features: + self.builders_for_feature[feature].insert(0, treebuilder_class) + self.builders.insert(0, treebuilder_class) + + def lookup(self, *features): + if len(self.builders) == 0: + # There are no builders at all. + return None + + if len(features) == 0: + # They didn't ask for any features. Give them the most + # recently registered builder. + return self.builders[0] + + # Go down the list of features in order, and eliminate any builders + # that don't match every feature. + features = list(features) + features.reverse() + candidates = None + candidate_set = None + while len(features) > 0: + feature = features.pop() + we_have_the_feature = self.builders_for_feature.get(feature, []) + if len(we_have_the_feature) > 0: + if candidates is None: + candidates = we_have_the_feature + candidate_set = set(candidates) + else: + # Eliminate any candidates that don't have this feature. + candidate_set = candidate_set.intersection( + set(we_have_the_feature)) + + # The only valid candidates are the ones in candidate_set. + # Go through the original list of candidates and pick the first one + # that's in candidate_set. + if candidate_set is None: + return None + for candidate in candidates: + if candidate in candidate_set: + return candidate + return None + +# The BeautifulSoup class will take feature lists from developers and use them +# to look up builders in this registry. 
+builder_registry = TreeBuilderRegistry() + +class TreeBuilder(object): + """Turn a document into a Beautiful Soup object tree.""" + + features = [] + + is_xml = False + preserve_whitespace_tags = set() + empty_element_tags = None # A tag will be considered an empty-element + # tag when and only when it has no contents. + + # A value for these tag/attribute combinations is a space- or + # comma-separated list of CDATA, rather than a single CDATA. + cdata_list_attributes = {} + + + def __init__(self): + self.soup = None + + def reset(self): + pass + + def can_be_empty_element(self, tag_name): + """Might a tag with this name be an empty-element tag? + + The final markup may or may not actually present this tag as + self-closing. + + For instance: an HTMLBuilder does not consider a <p> tag to be + an empty-element tag (it's not in + HTMLBuilder.empty_element_tags). This means an empty <p> tag + will be presented as "<p></p>", not "<p />". + + The default implementation has no opinion about which tags are + empty-element tags, so a tag will be presented as an + empty-element tag if and only if it has no contents. + "<foo></foo>" will become "<foo />", and "<foo>bar</foo>" will + be left alone. + """ + if self.empty_element_tags is None: + return True + return tag_name in self.empty_element_tags + + def feed(self, markup): + raise NotImplementedError() + + def prepare_markup(self, markup, user_specified_encoding=None, + document_declared_encoding=None): + return markup, None, None, False + + def test_fragment_to_document(self, fragment): + """Wrap an HTML fragment to make it look like a document. + + Different parsers do this differently. For instance, lxml + introduces an empty <head> tag, and html5lib + doesn't. Abstracting this away lets us write simple tests + which run HTML fragments through the parser and compare the + results against other HTML fragments. + + This method should not be used outside of tests. 
+ """ + return fragment + + def set_up_substitutions(self, tag): + return False + + def _replace_cdata_list_attribute_values(self, tag_name, attrs): + """Replaces class="foo bar" with class=["foo", "bar"] + + Modifies its input in place. + """ + if self.cdata_list_attributes: + universal = self.cdata_list_attributes.get('*', []) + tag_specific = self.cdata_list_attributes.get( + tag_name.lower(), []) + for cdata_list_attr in itertools.chain(universal, tag_specific): + if cdata_list_attr in dict(attrs): + # Basically, we have a "class" attribute whose + # value is a whitespace-separated list of CSS + # classes. Split it into a list. + value = attrs[cdata_list_attr] + values = whitespace_re.split(value) + attrs[cdata_list_attr] = values + return attrs + +class SAXTreeBuilder(TreeBuilder): + """A Beautiful Soup treebuilder that listens for SAX events.""" + + def feed(self, markup): + raise NotImplementedError() + + def close(self): + pass + + def startElement(self, name, attrs): + attrs = dict((key[1], value) for key, value in list(attrs.items())) + #print "Start %s, %r" % (name, attrs) + self.soup.handle_starttag(name, attrs) + + def endElement(self, name): + #print "End %s" % name + self.soup.handle_endtag(name) + + def startElementNS(self, nsTuple, nodeName, attrs): + # Throw away (ns, nodeName) for now. + self.startElement(nodeName, attrs) + + def endElementNS(self, nsTuple, nodeName): + # Throw away (ns, nodeName) for now. + self.endElement(nodeName) + #handler.endElementNS((ns, node.nodeName), node.nodeName) + + def startPrefixMapping(self, prefix, nodeValue): + # Ignore the prefix for now. + pass + + def endPrefixMapping(self, prefix): + # Ignore the prefix for now. + # handler.endPrefixMapping(prefix) + pass + + def characters(self, content): + self.soup.handle_data(content) + + def startDocument(self): + pass + + def endDocument(self): + pass + + +class HTMLTreeBuilder(TreeBuilder): + """This TreeBuilder knows facts about HTML. 
+ + Such as which tags are empty-element tags. + """ + + preserve_whitespace_tags = set(['pre', 'textarea']) + empty_element_tags = set(['br' , 'hr', 'input', 'img', 'meta', + 'spacer', 'link', 'frame', 'base']) + + # The HTML standard defines these attributes as containing a + # space-separated list of values, not a single value. That is, + # class="foo bar" means that the 'class' attribute has two values, + # 'foo' and 'bar', not the single value 'foo bar'. When we + # encounter one of these attributes, we will parse its value into + # a list of values if possible. Upon output, the list will be + # converted back into a string. + cdata_list_attributes = { + "*" : ['class', 'accesskey', 'dropzone'], + "a" : ['rel', 'rev'], + "link" : ['rel', 'rev'], + "td" : ["headers"], + "th" : ["headers"], + "td" : ["headers"], + "form" : ["accept-charset"], + "object" : ["archive"], + + # These are HTML5 specific, as are *.accesskey and *.dropzone above. + "area" : ["rel"], + "icon" : ["sizes"], + "iframe" : ["sandbox"], + "output" : ["for"], + } + + def set_up_substitutions(self, tag): + # We are only interested in <meta> tags + if tag.name != 'meta': + return False + + http_equiv = tag.get('http-equiv') + content = tag.get('content') + charset = tag.get('charset') + + # We are interested in <meta> tags that say what encoding the + # document was originally in. This means HTML 5-style <meta> + # tags that provide the "charset" attribute. It also means + # HTML 4-style <meta> tags that provide the "content" + # attribute and have "http-equiv" set to "content-type". + # + # In both cases we will replace the value of the appropriate + # attribute with a standin object that can take on any + # encoding. 
+ meta_encoding = None + if charset is not None: + # HTML 5 style: + # <meta charset="utf8"> + meta_encoding = charset + tag['charset'] = CharsetMetaAttributeValue(charset) + + elif (content is not None and http_equiv is not None + and http_equiv.lower() == 'content-type'): + # HTML 4 style: + # <meta http-equiv="content-type" content="text/html; charset=utf8"> + tag['content'] = ContentMetaAttributeValue(content) + + return (meta_encoding is not None) + +def register_treebuilders_from(module): + """Copy TreeBuilders from the given module into this module.""" + # I'm fairly sure this is not the best way to do this. + this_module = sys.modules['bs4.builder'] + for name in module.__all__: + obj = getattr(module, name) + + if issubclass(obj, TreeBuilder): + setattr(this_module, name, obj) + this_module.__all__.append(name) + # Register the builder while we're at it. + this_module.builder_registry.register(obj) + +# Builders are registered in reverse order of priority, so that custom +# builder registrations will take precedence. In general, we want lxml +# to take precedence over html5lib, because it's faster. And we only +# want to use HTMLParser as a last result. +from . import _htmlparser +register_treebuilders_from(_htmlparser) +try: + from . import _html5lib + register_treebuilders_from(_html5lib) +except ImportError: + # They don't have html5lib installed. + pass +try: + from . import _lxml + register_treebuilders_from(_lxml) +except ImportError: + # They don't have lxml installed. 
+ pass diff --git a/bs4/builder/_html5lib.py b/bs4/builder/_html5lib.py new file mode 100644 index 0000000..6001e38 --- /dev/null +++ b/bs4/builder/_html5lib.py @@ -0,0 +1,222 @@ +__all__ = [ + 'HTML5TreeBuilder', + ] + +import warnings +from bs4.builder import ( + PERMISSIVE, + HTML, + HTML_5, + HTMLTreeBuilder, + ) +from bs4.element import NamespacedAttribute +import html5lib +from html5lib.constants import namespaces +from bs4.element import ( + Comment, + Doctype, + NavigableString, + Tag, + ) + +class HTML5TreeBuilder(HTMLTreeBuilder): + """Use html5lib to build a tree.""" + + features = ['html5lib', PERMISSIVE, HTML_5, HTML] + + def prepare_markup(self, markup, user_specified_encoding): + # Store the user-specified encoding for use later on. + self.user_specified_encoding = user_specified_encoding + return markup, None, None, False + + # These methods are defined by Beautiful Soup. + def feed(self, markup): + if self.soup.parse_only is not None: + warnings.warn("You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. The entire document will be parsed.") + parser = html5lib.HTMLParser(tree=self.create_treebuilder) + doc = parser.parse(markup, encoding=self.user_specified_encoding) + + # Set the character encoding detected by the tokenizer. + if isinstance(markup, unicode): + # We need to special-case this because html5lib sets + # charEncoding to UTF-8 if it gets Unicode input. 
+ doc.original_encoding = None + else: + doc.original_encoding = parser.tokenizer.stream.charEncoding[0] + + def create_treebuilder(self, namespaceHTMLElements): + self.underlying_builder = TreeBuilderForHtml5lib( + self.soup, namespaceHTMLElements) + return self.underlying_builder + + def test_fragment_to_document(self, fragment): + """See `TreeBuilder`.""" + return u'<html><head></head><body>%s</body></html>' % fragment + + +class TreeBuilderForHtml5lib(html5lib.treebuilders._base.TreeBuilder): + + def __init__(self, soup, namespaceHTMLElements): + self.soup = soup + super(TreeBuilderForHtml5lib, self).__init__(namespaceHTMLElements) + + def documentClass(self): + self.soup.reset() + return Element(self.soup, self.soup, None) + + def insertDoctype(self, token): + name = token["name"] + publicId = token["publicId"] + systemId = token["systemId"] + + doctype = Doctype.for_name_and_ids(name, publicId, systemId) + self.soup.object_was_parsed(doctype) + + def elementClass(self, name, namespace): + tag = self.soup.new_tag(name, namespace) + return Element(tag, self.soup, namespace) + + def commentClass(self, data): + return TextNode(Comment(data), self.soup) + + def fragmentClass(self): + self.soup = BeautifulSoup("") + self.soup.name = "[document_fragment]" + return Element(self.soup, self.soup, None) + + def appendChild(self, node): + # XXX This code is not covered by the BS4 tests. 
+ self.soup.append(node.element) + + def getDocument(self): + return self.soup + + def getFragment(self): + return html5lib.treebuilders._base.TreeBuilder.getFragment(self).element + +class AttrList(object): + def __init__(self, element): + self.element = element + self.attrs = dict(self.element.attrs) + def __iter__(self): + return list(self.attrs.items()).__iter__() + def __setitem__(self, name, value): + "set attr", name, value + self.element[name] = value + def items(self): + return list(self.attrs.items()) + def keys(self): + return list(self.attrs.keys()) + def __len__(self): + return len(self.attrs) + def __getitem__(self, name): + return self.attrs[name] + def __contains__(self, name): + return name in list(self.attrs.keys()) + + +class Element(html5lib.treebuilders._base.Node): + def __init__(self, element, soup, namespace): + html5lib.treebuilders._base.Node.__init__(self, element.name) + self.element = element + self.soup = soup + self.namespace = namespace + + def appendChild(self, node): + if (node.element.__class__ == NavigableString and self.element.contents + and self.element.contents[-1].__class__ == NavigableString): + # Concatenate new text onto old text node + # XXX This has O(n^2) performance, for input like + # "a</a>a</a>a</a>..." 
+ old_element = self.element.contents[-1] + new_element = self.soup.new_string(old_element + node.element) + old_element.replace_with(new_element) + else: + self.element.append(node.element) + node.parent = self + + def getAttributes(self): + return AttrList(self.element) + + def setAttributes(self, attributes): + if attributes is not None and len(attributes) > 0: + + converted_attributes = [] + for name, value in list(attributes.items()): + if isinstance(name, tuple): + new_name = NamespacedAttribute(*name) + del attributes[name] + attributes[new_name] = value + + self.soup.builder._replace_cdata_list_attribute_values( + self.name, attributes) + for name, value in attributes.items(): + self.element[name] = value + + # The attributes may contain variables that need substitution. + # Call set_up_substitutions manually. + # + # The Tag constructor called this method when the Tag was created, + # but we just set/changed the attributes, so call it again. + self.soup.builder.set_up_substitutions(self.element) + attributes = property(getAttributes, setAttributes) + + def insertText(self, data, insertBefore=None): + text = TextNode(self.soup.new_string(data), self.soup) + if insertBefore: + self.insertBefore(text, insertBefore) + else: + self.appendChild(text) + + def insertBefore(self, node, refNode): + index = self.element.index(refNode.element) + if (node.element.__class__ == NavigableString and self.element.contents + and self.element.contents[index-1].__class__ == NavigableString): + # (See comments in appendChild) + old_node = self.element.contents[index-1] + new_str = self.soup.new_string(old_node + node.element) + old_node.replace_with(new_str) + else: + self.element.insert(index, node.element) + node.parent = self + + def removeChild(self, node): + node.element.extract() + + def reparentChildren(self, newParent): + while self.element.contents: + child = self.element.contents[0] + child.extract() + if isinstance(child, Tag): + newParent.appendChild( + 
Element(child, self.soup, namespaces["html"])) + else: + newParent.appendChild( + TextNode(child, self.soup)) + + def cloneNode(self): + tag = self.soup.new_tag(self.element.name, self.namespace) + node = Element(tag, self.soup, self.namespace) + for key,value in self.attributes: + node.attributes[key] = value + return node + + def hasContent(self): + return self.element.contents + + def getNameTuple(self): + if self.namespace == None: + return namespaces["html"], self.name + else: + return self.namespace, self.name + + nameTuple = property(getNameTuple) + +class TextNode(Element): + def __init__(self, element, soup): + html5lib.treebuilders._base.Node.__init__(self, None) + self.element = element + self.soup = soup + + def cloneNode(self): + raise NotImplementedError diff --git a/bs4/builder/_htmlparser.py b/bs4/builder/_htmlparser.py new file mode 100644 index 0000000..ede5cec --- /dev/null +++ b/bs4/builder/_htmlparser.py @@ -0,0 +1,244 @@ +"""Use the HTMLParser library to parse HTML files that aren't too bad.""" + +__all__ = [ + 'HTMLParserTreeBuilder', + ] + +from HTMLParser import ( + HTMLParser, + HTMLParseError, + ) +import sys +import warnings + +# Starting in Python 3.2, the HTMLParser constructor takes a 'strict' +# argument, which we'd like to set to False. Unfortunately, +# http://bugs.python.org/issue13273 makes strict=True a better bet +# before Python 3.2.3. +# +# At the end of this file, we monkeypatch HTMLParser so that +# strict=True works well on Python 3.2.2. 
+major, minor, release = sys.version_info[:3] +CONSTRUCTOR_TAKES_STRICT = ( + major > 3 + or (major == 3 and minor > 2) + or (major == 3 and minor == 2 and release >= 3)) + +from bs4.element import ( + CData, + Comment, + Declaration, + Doctype, + ProcessingInstruction, + ) +from bs4.dammit import EntitySubstitution, UnicodeDammit + +from bs4.builder import ( + HTML, + HTMLTreeBuilder, + STRICT, + ) + + +HTMLPARSER = 'html.parser' + +class BeautifulSoupHTMLParser(HTMLParser): + def handle_starttag(self, name, attrs): + # XXX namespace + self.soup.handle_starttag(name, None, None, dict(attrs)) + + def handle_endtag(self, name): + self.soup.handle_endtag(name) + + def handle_data(self, data): + self.soup.handle_data(data) + + def handle_charref(self, name): + # XXX workaround for a bug in HTMLParser. Remove this once + # it's fixed. + if name.startswith('x'): + real_name = int(name.lstrip('x'), 16) + else: + real_name = int(name) + + try: + data = unichr(real_name) + except (ValueError, OverflowError), e: + data = u"\N{REPLACEMENT CHARACTER}" + + self.handle_data(data) + + def handle_entityref(self, name): + character = EntitySubstitution.HTML_ENTITY_TO_CHARACTER.get(name) + if character is not None: + data = character + else: + data = "&%s;" % name + self.handle_data(data) + + def handle_comment(self, data): + self.soup.endData() + self.soup.handle_data(data) + self.soup.endData(Comment) + + def handle_decl(self, data): + self.soup.endData() + if data.startswith("DOCTYPE "): + data = data[len("DOCTYPE "):] + self.soup.handle_data(data) + self.soup.endData(Doctype) + + def unknown_decl(self, data): + if data.upper().startswith('CDATA['): + cls = CData + data = data[len('CDATA['):] + else: + cls = Declaration + self.soup.endData() + self.soup.handle_data(data) + self.soup.endData(cls) + + def handle_pi(self, data): + self.soup.endData() + if data.endswith("?") and data.lower().startswith("xml"): + # "An XHTML processing instruction using the trailing '?' 
+ # will cause the '?' to be included in data." - HTMLParser + # docs. + # + # Strip the question mark so we don't end up with two + # question marks. + data = data[:-1] + self.soup.handle_data(data) + self.soup.endData(ProcessingInstruction) + + +class HTMLParserTreeBuilder(HTMLTreeBuilder): + + is_xml = False + features = [HTML, STRICT, HTMLPARSER] + + def __init__(self, *args, **kwargs): + if CONSTRUCTOR_TAKES_STRICT: + kwargs['strict'] = False + self.parser_args = (args, kwargs) + + def prepare_markup(self, markup, user_specified_encoding=None, + document_declared_encoding=None): + """ + :return: A 4-tuple (markup, original encoding, encoding + declared within markup, whether any characters had to be + replaced with REPLACEMENT CHARACTER). + """ + if isinstance(markup, unicode): + return markup, None, None, False + + try_encodings = [user_specified_encoding, document_declared_encoding] + dammit = UnicodeDammit(markup, try_encodings, is_html=True) + return (dammit.markup, dammit.original_encoding, + dammit.declared_html_encoding, + dammit.contains_replacement_characters) + + def feed(self, markup): + args, kwargs = self.parser_args + parser = BeautifulSoupHTMLParser(*args, **kwargs) + parser.soup = self.soup + try: + parser.feed(markup) + except HTMLParseError, e: + warnings.warn(RuntimeWarning( + "Python's built-in HTMLParser cannot parse the given document. This is not a bug in Beautiful Soup. The best solution is to install an external parser (lxml or html5lib), and use Beautiful Soup with that parser. See http://www.crummy.com/software/BeautifulSoup/bs4/doc/#installing-a-parser for help.")) + raise e + +# Patch 3.2 versions of HTMLParser earlier than 3.2.3 to use some +# 3.2.3 code. This ensures they don't treat markup like <p></p> as a +# string. +# +# XXX This code can be removed once most Python 3 users are on 3.2.3. 
+if major == 3 and minor == 2 and not CONSTRUCTOR_TAKES_STRICT: + import re + attrfind_tolerant = re.compile( + r'\s*((?<=[\'"\s])[^\s/>][^\s/=>]*)(\s*=+\s*' + r'(\'[^\']*\'|"[^"]*"|(?![\'"])[^>\s]*))?') + HTMLParserTreeBuilder.attrfind_tolerant = attrfind_tolerant + + locatestarttagend = re.compile(r""" + <[a-zA-Z][-.a-zA-Z0-9:_]* # tag name + (?:\s+ # whitespace before attribute name + (?:[a-zA-Z_][-.:a-zA-Z0-9_]* # attribute name + (?:\s*=\s* # value indicator + (?:'[^']*' # LITA-enclosed value + |\"[^\"]*\" # LIT-enclosed value + |[^'\">\s]+ # bare value + ) + )? + ) + )* + \s* # trailing whitespace +""", re.VERBOSE) + BeautifulSoupHTMLParser.locatestarttagend = locatestarttagend + + from html.parser import tagfind, attrfind + + def parse_starttag(self, i): + self.__starttag_text = None + endpos = self.check_for_whole_start_tag(i) + if endpos < 0: + return endpos + rawdata = self.rawdata + self.__starttag_text = rawdata[i:endpos] + + # Now parse the data between i+1 and j into a tag and attrs + attrs = [] + match = tagfind.match(rawdata, i+1) + assert match, 'unexpected call to parse_starttag()' + k = match.end() + self.lasttag = tag = rawdata[i+1:k].lower() + while k < endpos: + if self.strict: + m = attrfind.match(rawdata, k) + else: + m = attrfind_tolerant.match(rawdata, k) + if not m: + break + attrname, rest, attrvalue = m.group(1, 2, 3) + if not rest: + attrvalue = None + elif attrvalue[:1] == '\'' == attrvalue[-1:] or \ + attrvalue[:1] == '"' == attrvalue[-1:]: + attrvalue = attrvalue[1:-1] + if attrvalue: + attrvalue = self.unescape(attrvalue) + attrs.append((attrname.lower(), attrvalue)) + k = m.end() + + end = rawdata[k:endpos].strip() + if end not in (">", "/>"): + lineno, offset = self.getpos() + if "\n" in self.__starttag_text: + lineno = lineno + self.__starttag_text.count("\n") + offset = len(self.__starttag_text) \ + - self.__starttag_text.rfind("\n") + else: + offset = offset + len(self.__starttag_text) + if self.strict: + self.error("junk 
characters in start tag: %r" + % (rawdata[k:endpos][:20],)) + self.handle_data(rawdata[i:endpos]) + return endpos + if end.endswith('/>'): + # XHTML-style empty tag: <span attr="value" /> + self.handle_startendtag(tag, attrs) + else: + self.handle_starttag(tag, attrs) + if tag in self.CDATA_CONTENT_ELEMENTS: + self.set_cdata_mode(tag) + return endpos + + def set_cdata_mode(self, elem): + self.cdata_elem = elem.lower() + self.interesting = re.compile(r'</\s*%s\s*>' % self.cdata_elem, re.I) + + BeautifulSoupHTMLParser.parse_starttag = parse_starttag + BeautifulSoupHTMLParser.set_cdata_mode = set_cdata_mode + + CONSTRUCTOR_TAKES_STRICT = True diff --git a/bs4/builder/_lxml.py b/bs4/builder/_lxml.py new file mode 100644 index 0000000..c78fdff --- /dev/null +++ b/bs4/builder/_lxml.py @@ -0,0 +1,179 @@ +__all__ = [ + 'LXMLTreeBuilderForXML', + 'LXMLTreeBuilder', + ] + +from StringIO import StringIO +import collections +from lxml import etree +from bs4.element import Comment, Doctype, NamespacedAttribute +from bs4.builder import ( + FAST, + HTML, + HTMLTreeBuilder, + PERMISSIVE, + TreeBuilder, + XML) +from bs4.dammit import UnicodeDammit + +LXML = 'lxml' + +class LXMLTreeBuilderForXML(TreeBuilder): + DEFAULT_PARSER_CLASS = etree.XMLParser + + is_xml = True + + # Well, it's permissive by XML parser standards. + features = [LXML, XML, FAST, PERMISSIVE] + + CHUNK_SIZE = 512 + + @property + def default_parser(self): + # This can either return a parser object or a class, which + # will be instantiated with default arguments. + return etree.XMLParser(target=self, strip_cdata=False, recover=True) + + def __init__(self, parser=None, empty_element_tags=None): + if empty_element_tags is not None: + self.empty_element_tags = set(empty_element_tags) + if parser is None: + # Use the default parser. 
+ parser = self.default_parser + if isinstance(parser, collections.Callable): + # Instantiate the parser with default arguments + parser = parser(target=self, strip_cdata=False) + self.parser = parser + self.soup = None + self.nsmaps = None + + def _getNsTag(self, tag): + # Split the namespace URL out of a fully-qualified lxml tag + # name. Copied from lxml's src/lxml/sax.py. + if tag[0] == '{': + return tuple(tag[1:].split('}', 1)) + else: + return (None, tag) + + def prepare_markup(self, markup, user_specified_encoding=None, + document_declared_encoding=None): + """ + :return: A 3-tuple (markup, original encoding, encoding + declared within markup). + """ + if isinstance(markup, unicode): + return markup, None, None, False + + try_encodings = [user_specified_encoding, document_declared_encoding] + dammit = UnicodeDammit(markup, try_encodings, is_html=True) + return (dammit.markup, dammit.original_encoding, + dammit.declared_html_encoding, + dammit.contains_replacement_characters) + + def feed(self, markup): + if isinstance(markup, basestring): + markup = StringIO(markup) + # Call feed() at least once, even if the markup is empty, + # or the parser won't be initialized. + data = markup.read(self.CHUNK_SIZE) + self.parser.feed(data) + while data != '': + # Now call feed() on the rest of the data, chunk by chunk. + data = markup.read(self.CHUNK_SIZE) + if data != '': + self.parser.feed(data) + self.parser.close() + + def close(self): + self.nsmaps = None + + def start(self, name, attrs, nsmap={}): + # Make sure attrs is a mutable dict--lxml may send an immutable dictproxy. + attrs = dict(attrs) + + nsprefix = None + # Invert each namespace map as it comes in. + if len(nsmap) == 0 and self.nsmaps != None: + # There are no new namespaces for this tag, but namespaces + # are in play, so we need a separate tag stack to know + # when they end. + self.nsmaps.append(None) + elif len(nsmap) > 0: + # A new namespace mapping has come into play. 
+ if self.nsmaps is None: + self.nsmaps = [] + inverted_nsmap = dict((value, key) for key, value in nsmap.items()) + self.nsmaps.append(inverted_nsmap) + # Also treat the namespace mapping as a set of attributes on the + # tag, so we can recreate it later. + attrs = attrs.copy() + for prefix, namespace in nsmap.items(): + attribute = NamespacedAttribute( + "xmlns", prefix, "http://www.w3.org/2000/xmlns/") + attrs[attribute] = namespace + namespace, name = self._getNsTag(name) + if namespace is not None: + for inverted_nsmap in reversed(self.nsmaps): + if inverted_nsmap is not None and namespace in inverted_nsmap: + nsprefix = inverted_nsmap[namespace] + break + self.soup.handle_starttag(name, namespace, nsprefix, attrs) + + def end(self, name): + self.soup.endData() + completed_tag = self.soup.tagStack[-1] + namespace, name = self._getNsTag(name) + nsprefix = None + if namespace is not None: + for inverted_nsmap in reversed(self.nsmaps): + if inverted_nsmap is not None and namespace in inverted_nsmap: + nsprefix = inverted_nsmap[namespace] + break + self.soup.handle_endtag(name, nsprefix) + if self.nsmaps != None: + # This tag, or one of its parents, introduced a namespace + # mapping, so pop it off the stack. + self.nsmaps.pop() + if len(self.nsmaps) == 0: + # Namespaces are no longer in play, so don't bother keeping + # track of the namespace stack. + self.nsmaps = None + + def pi(self, target, data): + pass + + def data(self, content): + self.soup.handle_data(content) + + def doctype(self, name, pubid, system): + self.soup.endData() + doctype = Doctype.for_name_and_ids(name, pubid, system) + self.soup.object_was_parsed(doctype) + + def comment(self, content): + "Handle comments as Comment objects." 
+ self.soup.endData() + self.soup.handle_data(content) + self.soup.endData(Comment) + + def test_fragment_to_document(self, fragment): + """See `TreeBuilder`.""" + return u'<?xml version="1.0" encoding="utf-8"?>\n%s' % fragment + + +class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML): + + features = [LXML, HTML, FAST, PERMISSIVE] + is_xml = False + + @property + def default_parser(self): + return etree.HTMLParser + + def feed(self, markup): + self.parser.feed(markup) + self.parser.close() + + def test_fragment_to_document(self, fragment): + """See `TreeBuilder`.""" + return u'<html><body>%s</body></html>' % fragment diff --git a/bs4/dammit.py b/bs4/dammit.py new file mode 100644 index 0000000..58cad9b --- /dev/null +++ b/bs4/dammit.py @@ -0,0 +1,792 @@ +# -*- coding: utf-8 -*- +"""Beautiful Soup bonus library: Unicode, Dammit + +This class forces XML data into a standard format (usually to UTF-8 or +Unicode). It is heavily based on code from Mark Pilgrim's Universal +Feed Parser. It does not rewrite the XML or HTML to reflect a new +encoding; that's the tree builder's job. +""" + +import codecs +from htmlentitydefs import codepoint2name +import re +import warnings + +# Autodetects character encodings. Very useful. +# Download from http://chardet.feedparser.org/ +# or 'apt-get install python-chardet' +# or 'easy_install chardet' +try: + import chardet + #import chardet.constants + #chardet.constants._debug = 1 +except ImportError: + chardet = None + +# Available from http://cjkpython.i18n.org/. 
+try: + import iconv_codec +except ImportError: + pass + +xml_encoding_re = re.compile( + '^<\?.*encoding=[\'"](.*?)[\'"].*\?>'.encode(), re.I) +html_meta_re = re.compile( + '<\s*meta[^>]+charset\s*=\s*["\']?([^>]*?)[ /;\'">]'.encode(), re.I) + +class EntitySubstitution(object): + + """Substitute XML or HTML entities for the corresponding characters.""" + + def _populate_class_variables(): + lookup = {} + reverse_lookup = {} + characters_for_re = [] + for codepoint, name in list(codepoint2name.items()): + character = unichr(codepoint) + if codepoint != 34: + # There's no point in turning the quotation mark into + # ", unless it happens within an attribute value, which + # is handled elsewhere. + characters_for_re.append(character) + lookup[character] = name + # But we do want to turn " into the quotation mark. + reverse_lookup[name] = character + re_definition = "[%s]" % "".join(characters_for_re) + return lookup, reverse_lookup, re.compile(re_definition) + (CHARACTER_TO_HTML_ENTITY, HTML_ENTITY_TO_CHARACTER, + CHARACTER_TO_HTML_ENTITY_RE) = _populate_class_variables() + + CHARACTER_TO_XML_ENTITY = { + "'": "apos", + '"': "quot", + "&": "amp", + "<": "lt", + ">": "gt", + } + + BARE_AMPERSAND_OR_BRACKET = re.compile("([<>]|" + "&(?!#\d+;|#x[0-9a-fA-F]+;|\w+;)" + ")") + + @classmethod + def _substitute_html_entity(cls, matchobj): + entity = cls.CHARACTER_TO_HTML_ENTITY.get(matchobj.group(0)) + return "&%s;" % entity + + @classmethod + def _substitute_xml_entity(cls, matchobj): + """Used with a regular expression to substitute the + appropriate XML entity for an XML special character.""" + entity = cls.CHARACTER_TO_XML_ENTITY[matchobj.group(0)] + return "&%s;" % entity + + @classmethod + def quoted_attribute_value(self, value): + """Make a value into a quoted XML attribute, possibly escaping it. + + Most strings will be quoted using double quotes. + + Bob's Bar -> "Bob's Bar" + + If a string contains double quotes, it will be quoted using + single quotes. 
+ + Welcome to "my bar" -> 'Welcome to "my bar"' + + If a string contains both single and double quotes, the + double quotes will be escaped, and the string will be quoted + using double quotes. + + Welcome to "Bob's Bar" -> "Welcome to "Bob's bar" + """ + quote_with = '"' + if '"' in value: + if "'" in value: + # The string contains both single and double + # quotes. Turn the double quotes into + # entities. We quote the double quotes rather than + # the single quotes because the entity name is + # """ whether this is HTML or XML. If we + # quoted the single quotes, we'd have to decide + # between ' and &squot;. + replace_with = """ + value = value.replace('"', replace_with) + else: + # There are double quotes but no single quotes. + # We can use single quotes to quote the attribute. + quote_with = "'" + return quote_with + value + quote_with + + @classmethod + def substitute_xml(cls, value, make_quoted_attribute=False): + """Substitute XML entities for special XML characters. + + :param value: A string to be substituted. The less-than sign will + become <, the greater-than sign will become >, and any + ampersands that are not part of an entity defition will + become &. + + :param make_quoted_attribute: If True, then the string will be + quoted, as befits an attribute value. + """ + # Escape angle brackets, and ampersands that aren't part of + # entities. + value = cls.BARE_AMPERSAND_OR_BRACKET.sub( + cls._substitute_xml_entity, value) + + if make_quoted_attribute: + value = cls.quoted_attribute_value(value) + return value + + @classmethod + def substitute_html(cls, s): + """Replace certain Unicode characters with named HTML entities. + + This differs from data.encode(encoding, 'xmlcharrefreplace') + in that the goal is to make the result more readable (to those + with ASCII displays) rather than to recover from + errors. 
There's absolutely nothing wrong with a UTF-8 string + containg a LATIN SMALL LETTER E WITH ACUTE, but replacing that + character with "é" will make it more readable to some + people. + """ + return cls.CHARACTER_TO_HTML_ENTITY_RE.sub( + cls._substitute_html_entity, s) + + +class UnicodeDammit: + """A class for detecting the encoding of a *ML document and + converting it to a Unicode string. If the source encoding is + windows-1252, can replace MS smart quotes with their HTML or XML + equivalents.""" + + # This dictionary maps commonly seen values for "charset" in HTML + # meta tags to the corresponding Python codec names. It only covers + # values that aren't in Python's aliases and can't be determined + # by the heuristics in find_codec. + CHARSET_ALIASES = {"macintosh": "mac-roman", + "x-sjis": "shift-jis"} + + ENCODINGS_WITH_SMART_QUOTES = [ + "windows-1252", + "iso-8859-1", + "iso-8859-2", + ] + + def __init__(self, markup, override_encodings=[], + smart_quotes_to=None, is_html=False): + self.declared_html_encoding = None + self.smart_quotes_to = smart_quotes_to + self.tried_encodings = [] + self.contains_replacement_characters = False + + if markup == '' or isinstance(markup, unicode): + self.markup = markup + self.unicode_markup = unicode(markup) + self.original_encoding = None + return + + new_markup, document_encoding, sniffed_encoding = \ + self._detectEncoding(markup, is_html) + self.markup = new_markup + + u = None + if new_markup != markup: + # _detectEncoding modified the markup, then converted it to + # Unicode and then to UTF-8. So convert it from UTF-8. 
+ u = self._convert_from("utf8") + self.original_encoding = sniffed_encoding + + if not u: + for proposed_encoding in ( + override_encodings + [document_encoding, sniffed_encoding]): + if proposed_encoding is not None: + u = self._convert_from(proposed_encoding) + if u: + break + + # If no luck and we have auto-detection library, try that: + if not u and chardet and not isinstance(self.markup, unicode): + u = self._convert_from(chardet.detect(self.markup)['encoding']) + + # As a last resort, try utf-8 and windows-1252: + if not u: + for proposed_encoding in ("utf-8", "windows-1252"): + u = self._convert_from(proposed_encoding) + if u: + break + + # As an absolute last resort, try the encodings again with + # character replacement. + if not u: + for proposed_encoding in ( + override_encodings + [ + document_encoding, sniffed_encoding, "utf-8", "windows-1252"]): + if proposed_encoding != "ascii": + u = self._convert_from(proposed_encoding, "replace") + if u is not None: + warnings.warn( + UnicodeWarning( + "Some characters could not be decoded, and were " + "replaced with REPLACEMENT CHARACTER.")) + self.contains_replacement_characters = True + break + + # We could at this point force it to ASCII, but that would + # destroy so much data that I think giving up is better + self.unicode_markup = u + if not u: + self.original_encoding = None + + def _sub_ms_char(self, match): + """Changes a MS smart quote character to an XML or HTML + entity, or an ASCII character.""" + orig = match.group(1) + if self.smart_quotes_to == 'ascii': + sub = self.MS_CHARS_TO_ASCII.get(orig).encode() + else: + sub = self.MS_CHARS.get(orig) + if type(sub) == tuple: + if self.smart_quotes_to == 'xml': + sub = '&#x'.encode() + sub[1].encode() + ';'.encode() + else: + sub = '&'.encode() + sub[0].encode() + ';'.encode() + else: + sub = sub.encode() + return sub + + def _convert_from(self, proposed, errors="strict"): + proposed = self.find_codec(proposed) + if not proposed or (proposed, errors) in 
self.tried_encodings: + return None + self.tried_encodings.append((proposed, errors)) + markup = self.markup + + # Convert smart quotes to HTML if coming from an encoding + # that might have them. + if (self.smart_quotes_to is not None + and proposed.lower() in self.ENCODINGS_WITH_SMART_QUOTES): + smart_quotes_re = b"([\x80-\x9f])" + smart_quotes_compiled = re.compile(smart_quotes_re) + markup = smart_quotes_compiled.sub(self._sub_ms_char, markup) + + try: + #print "Trying to convert document to %s (errors=%s)" % ( + # proposed, errors) + u = self._to_unicode(markup, proposed, errors) + self.markup = u + self.original_encoding = proposed + except Exception as e: + #print "That didn't work!" + #print e + return None + #print "Correct encoding: %s" % proposed + return self.markup + + def _to_unicode(self, data, encoding, errors="strict"): + '''Given a string and its encoding, decodes the string into Unicode. + %encoding is a string recognized by encodings.aliases''' + + # strip Byte Order Mark (if present) + if (len(data) >= 4) and (data[:2] == '\xfe\xff') \ + and (data[2:4] != '\x00\x00'): + encoding = 'utf-16be' + data = data[2:] + elif (len(data) >= 4) and (data[:2] == '\xff\xfe') \ + and (data[2:4] != '\x00\x00'): + encoding = 'utf-16le' + data = data[2:] + elif data[:3] == '\xef\xbb\xbf': + encoding = 'utf-8' + data = data[3:] + elif data[:4] == '\x00\x00\xfe\xff': + encoding = 'utf-32be' + data = data[4:] + elif data[:4] == '\xff\xfe\x00\x00': + encoding = 'utf-32le' + data = data[4:] + newdata = unicode(data, encoding, errors) + return newdata + + def _detectEncoding(self, xml_data, is_html=False): + """Given a document, tries to detect its XML encoding.""" + xml_encoding = sniffed_xml_encoding = None + try: + if xml_data[:4] == b'\x4c\x6f\xa7\x94': + # EBCDIC + xml_data = self._ebcdic_to_ascii(xml_data) + elif xml_data[:4] == b'\x00\x3c\x00\x3f': + # UTF-16BE + sniffed_xml_encoding = 'utf-16be' + xml_data = unicode(xml_data, 'utf-16be').encode('utf-8') + elif 
(len(xml_data) >= 4) and (xml_data[:2] == b'\xfe\xff') \ + and (xml_data[2:4] != b'\x00\x00'): + # UTF-16BE with BOM + sniffed_xml_encoding = 'utf-16be' + xml_data = unicode(xml_data[2:], 'utf-16be').encode('utf-8') + elif xml_data[:4] == b'\x3c\x00\x3f\x00': + # UTF-16LE + sniffed_xml_encoding = 'utf-16le' + xml_data = unicode(xml_data, 'utf-16le').encode('utf-8') + elif (len(xml_data) >= 4) and (xml_data[:2] == b'\xff\xfe') and \ + (xml_data[2:4] != b'\x00\x00'): + # UTF-16LE with BOM + sniffed_xml_encoding = 'utf-16le' + xml_data = unicode(xml_data[2:], 'utf-16le').encode('utf-8') + elif xml_data[:4] == b'\x00\x00\x00\x3c': + # UTF-32BE + sniffed_xml_encoding = 'utf-32be' + xml_data = unicode(xml_data, 'utf-32be').encode('utf-8') + elif xml_data[:4] == b'\x3c\x00\x00\x00': + # UTF-32LE + sniffed_xml_encoding = 'utf-32le' + xml_data = unicode(xml_data, 'utf-32le').encode('utf-8') + elif xml_data[:4] == b'\x00\x00\xfe\xff': + # UTF-32BE with BOM + sniffed_xml_encoding = 'utf-32be' + xml_data = unicode(xml_data[4:], 'utf-32be').encode('utf-8') + elif xml_data[:4] == b'\xff\xfe\x00\x00': + # UTF-32LE with BOM + sniffed_xml_encoding = 'utf-32le' + xml_data = unicode(xml_data[4:], 'utf-32le').encode('utf-8') + elif xml_data[:3] == b'\xef\xbb\xbf': + # UTF-8 with BOM + sniffed_xml_encoding = 'utf-8' + xml_data = unicode(xml_data[3:], 'utf-8').encode('utf-8') + else: + sniffed_xml_encoding = 'ascii' + pass + except: + xml_encoding_match = None + xml_encoding_match = xml_encoding_re.match(xml_data) + if not xml_encoding_match and is_html: + xml_encoding_match = html_meta_re.search(xml_data) + if xml_encoding_match is not None: + xml_encoding = xml_encoding_match.groups()[0].decode( + 'ascii').lower() + if is_html: + self.declared_html_encoding = xml_encoding + if sniffed_xml_encoding and \ + (xml_encoding in ('iso-10646-ucs-2', 'ucs-2', 'csunicode', + 'iso-10646-ucs-4', 'ucs-4', 'csucs4', + 'utf-16', 'utf-32', 'utf_16', 'utf_32', + 'utf16', 'u16')): + xml_encoding = 
sniffed_xml_encoding + return xml_data, xml_encoding, sniffed_xml_encoding + + def find_codec(self, charset): + return self._codec(self.CHARSET_ALIASES.get(charset, charset)) \ + or (charset and self._codec(charset.replace("-", ""))) \ + or (charset and self._codec(charset.replace("-", "_"))) \ + or charset + + def _codec(self, charset): + if not charset: + return charset + codec = None + try: + codecs.lookup(charset) + codec = charset + except (LookupError, ValueError): + pass + return codec + + EBCDIC_TO_ASCII_MAP = None + + def _ebcdic_to_ascii(self, s): + c = self.__class__ + if not c.EBCDIC_TO_ASCII_MAP: + emap = (0,1,2,3,156,9,134,127,151,141,142,11,12,13,14,15, + 16,17,18,19,157,133,8,135,24,25,146,143,28,29,30,31, + 128,129,130,131,132,10,23,27,136,137,138,139,140,5,6,7, + 144,145,22,147,148,149,150,4,152,153,154,155,20,21,158,26, + 32,160,161,162,163,164,165,166,167,168,91,46,60,40,43,33, + 38,169,170,171,172,173,174,175,176,177,93,36,42,41,59,94, + 45,47,178,179,180,181,182,183,184,185,124,44,37,95,62,63, + 186,187,188,189,190,191,192,193,194,96,58,35,64,39,61,34, + 195,97,98,99,100,101,102,103,104,105,196,197,198,199,200, + 201,202,106,107,108,109,110,111,112,113,114,203,204,205, + 206,207,208,209,126,115,116,117,118,119,120,121,122,210, + 211,212,213,214,215,216,217,218,219,220,221,222,223,224, + 225,226,227,228,229,230,231,123,65,66,67,68,69,70,71,72, + 73,232,233,234,235,236,237,125,74,75,76,77,78,79,80,81, + 82,238,239,240,241,242,243,92,159,83,84,85,86,87,88,89, + 90,244,245,246,247,248,249,48,49,50,51,52,53,54,55,56,57, + 250,251,252,253,254,255) + import string + c.EBCDIC_TO_ASCII_MAP = string.maketrans( + ''.join(map(chr, list(range(256)))), ''.join(map(chr, emap))) + return s.translate(c.EBCDIC_TO_ASCII_MAP) + + # A partial mapping of ISO-Latin-1 to HTML entities/XML numeric entities. 
+ MS_CHARS = {b'\x80': ('euro', '20AC'), + b'\x81': ' ', + b'\x82': ('sbquo', '201A'), + b'\x83': ('fnof', '192'), + b'\x84': ('bdquo', '201E'), + b'\x85': ('hellip', '2026'), + b'\x86': ('dagger', '2020'), + b'\x87': ('Dagger', '2021'), + b'\x88': ('circ', '2C6'), + b'\x89': ('permil', '2030'), + b'\x8A': ('Scaron', '160'), + b'\x8B': ('lsaquo', '2039'), + b'\x8C': ('OElig', '152'), + b'\x8D': '?', + b'\x8E': ('#x17D', '17D'), + b'\x8F': '?', + b'\x90': '?', + b'\x91': ('lsquo', '2018'), + b'\x92': ('rsquo', '2019'), + b'\x93': ('ldquo', '201C'), + b'\x94': ('rdquo', '201D'), + b'\x95': ('bull', '2022'), + b'\x96': ('ndash', '2013'), + b'\x97': ('mdash', '2014'), + b'\x98': ('tilde', '2DC'), + b'\x99': ('trade', '2122'), + b'\x9a': ('scaron', '161'), + b'\x9b': ('rsaquo', '203A'), + b'\x9c': ('oelig', '153'), + b'\x9d': '?', + b'\x9e': ('#x17E', '17E'), + b'\x9f': ('Yuml', ''),} + + # A parochial partial mapping of ISO-Latin-1 to ASCII. Contains + # horrors like stripping diacritical marks to turn á into a, but also + # contains non-horrors like turning “ into ". + MS_CHARS_TO_ASCII = { + b'\x80' : 'EUR', + b'\x81' : ' ', + b'\x82' : ',', + b'\x83' : 'f', + b'\x84' : ',,', + b'\x85' : '...', + b'\x86' : '+', + b'\x87' : '++', + b'\x88' : '^', + b'\x89' : '%', + b'\x8a' : 'S', + b'\x8b' : '<', + b'\x8c' : 'OE', + b'\x8d' : '?', + b'\x8e' : 'Z', + b'\x8f' : '?', + b'\x90' : '?', + b'\x91' : "'", + b'\x92' : "'", + b'\x93' : '"', + b'\x94' : '"', + b'\x95' : '*', + b'\x96' : '-', + b'\x97' : '--', + b'\x98' : '~', + b'\x99' : '(TM)', + b'\x9a' : 's', + b'\x9b' : '>', + b'\x9c' : 'oe', + b'\x9d' : '?', + b'\x9e' : 'z', + b'\x9f' : 'Y', + b'\xa0' : ' ', + b'\xa1' : '!', + b'\xa2' : 'c', + b'\xa3' : 'GBP', + b'\xa4' : '$', #This approximation is especially parochial--this is the + #generic currency symbol. 
+ b'\xa5' : 'YEN', + b'\xa6' : '|', + b'\xa7' : 'S', + b'\xa8' : '..', + b'\xa9' : '', + b'\xaa' : '(th)', + b'\xab' : '<<', + b'\xac' : '!', + b'\xad' : ' ', + b'\xae' : '(R)', + b'\xaf' : '-', + b'\xb0' : 'o', + b'\xb1' : '+-', + b'\xb2' : '2', + b'\xb3' : '3', + b'\xb4' : ("'", 'acute'), + b'\xb5' : 'u', + b'\xb6' : 'P', + b'\xb7' : '*', + b'\xb8' : ',', + b'\xb9' : '1', + b'\xba' : '(th)', + b'\xbb' : '>>', + b'\xbc' : '1/4', + b'\xbd' : '1/2', + b'\xbe' : '3/4', + b'\xbf' : '?', + b'\xc0' : 'A', + b'\xc1' : 'A', + b'\xc2' : 'A', + b'\xc3' : 'A', + b'\xc4' : 'A', + b'\xc5' : 'A', + b'\xc6' : 'AE', + b'\xc7' : 'C', + b'\xc8' : 'E', + b'\xc9' : 'E', + b'\xca' : 'E', + b'\xcb' : 'E', + b'\xcc' : 'I', + b'\xcd' : 'I', + b'\xce' : 'I', + b'\xcf' : 'I', + b'\xd0' : 'D', + b'\xd1' : 'N', + b'\xd2' : 'O', + b'\xd3' : 'O', + b'\xd4' : 'O', + b'\xd5' : 'O', + b'\xd6' : 'O', + b'\xd7' : '*', + b'\xd8' : 'O', + b'\xd9' : 'U', + b'\xda' : 'U', + b'\xdb' : 'U', + b'\xdc' : 'U', + b'\xdd' : 'Y', + b'\xde' : 'b', + b'\xdf' : 'B', + b'\xe0' : 'a', + b'\xe1' : 'a', + b'\xe2' : 'a', + b'\xe3' : 'a', + b'\xe4' : 'a', + b'\xe5' : 'a', + b'\xe6' : 'ae', + b'\xe7' : 'c', + b'\xe8' : 'e', + b'\xe9' : 'e', + b'\xea' : 'e', + b'\xeb' : 'e', + b'\xec' : 'i', + b'\xed' : 'i', + b'\xee' : 'i', + b'\xef' : 'i', + b'\xf0' : 'o', + b'\xf1' : 'n', + b'\xf2' : 'o', + b'\xf3' : 'o', + b'\xf4' : 'o', + b'\xf5' : 'o', + b'\xf6' : 'o', + b'\xf7' : '/', + b'\xf8' : 'o', + b'\xf9' : 'u', + b'\xfa' : 'u', + b'\xfb' : 'u', + b'\xfc' : 'u', + b'\xfd' : 'y', + b'\xfe' : 'b', + b'\xff' : 'y', + } + + # A map used when removing rogue Windows-1252/ISO-8859-1 + # characters in otherwise UTF-8 documents. + # + # Note that \x81, \x8d, \x8f, \x90, and \x9d are undefined in + # Windows-1252. 
+ WINDOWS_1252_TO_UTF8 = { + 0x80 : b'\xe2\x82\xac', # € + 0x82 : b'\xe2\x80\x9a', # ‚ + 0x83 : b'\xc6\x92', # ƒ + 0x84 : b'\xe2\x80\x9e', # „ + 0x85 : b'\xe2\x80\xa6', # … + 0x86 : b'\xe2\x80\xa0', # † + 0x87 : b'\xe2\x80\xa1', # ‡ + 0x88 : b'\xcb\x86', # ˆ + 0x89 : b'\xe2\x80\xb0', # ‰ + 0x8a : b'\xc5\xa0', # Š + 0x8b : b'\xe2\x80\xb9', # ‹ + 0x8c : b'\xc5\x92', # Œ + 0x8e : b'\xc5\xbd', # Ž + 0x91 : b'\xe2\x80\x98', # ‘ + 0x92 : b'\xe2\x80\x99', # ’ + 0x93 : b'\xe2\x80\x9c', # “ + 0x94 : b'\xe2\x80\x9d', # ” + 0x95 : b'\xe2\x80\xa2', # • + 0x96 : b'\xe2\x80\x93', # – + 0x97 : b'\xe2\x80\x94', # — + 0x98 : b'\xcb\x9c', # ˜ + 0x99 : b'\xe2\x84\xa2', # ™ + 0x9a : b'\xc5\xa1', # š + 0x9b : b'\xe2\x80\xba', # › + 0x9c : b'\xc5\x93', # œ + 0x9e : b'\xc5\xbe', # ž + 0x9f : b'\xc5\xb8', # Ÿ + 0xa0 : b'\xc2\xa0', #   + 0xa1 : b'\xc2\xa1', # ¡ + 0xa2 : b'\xc2\xa2', # ¢ + 0xa3 : b'\xc2\xa3', # £ + 0xa4 : b'\xc2\xa4', # ¤ + 0xa5 : b'\xc2\xa5', # ¥ + 0xa6 : b'\xc2\xa6', # ¦ + 0xa7 : b'\xc2\xa7', # § + 0xa8 : b'\xc2\xa8', # ¨ + 0xa9 : b'\xc2\xa9', # © + 0xaa : b'\xc2\xaa', # ª + 0xab : b'\xc2\xab', # « + 0xac : b'\xc2\xac', # ¬ + 0xad : b'\xc2\xad', # ­ + 0xae : b'\xc2\xae', # ® + 0xaf : b'\xc2\xaf', # ¯ + 0xb0 : b'\xc2\xb0', # ° + 0xb1 : b'\xc2\xb1', # ± + 0xb2 : b'\xc2\xb2', # ² + 0xb3 : b'\xc2\xb3', # ³ + 0xb4 : b'\xc2\xb4', # ´ + 0xb5 : b'\xc2\xb5', # µ + 0xb6 : b'\xc2\xb6', # ¶ + 0xb7 : b'\xc2\xb7', # · + 0xb8 : b'\xc2\xb8', # ¸ + 0xb9 : b'\xc2\xb9', # ¹ + 0xba : b'\xc2\xba', # º + 0xbb : b'\xc2\xbb', # » + 0xbc : b'\xc2\xbc', # ¼ + 0xbd : b'\xc2\xbd', # ½ + 0xbe : b'\xc2\xbe', # ¾ + 0xbf : b'\xc2\xbf', # ¿ + 0xc0 : b'\xc3\x80', # À + 0xc1 : b'\xc3\x81', # Á + 0xc2 : b'\xc3\x82', #  + 0xc3 : b'\xc3\x83', # à + 0xc4 : b'\xc3\x84', # Ä + 0xc5 : b'\xc3\x85', # Å + 0xc6 : b'\xc3\x86', # Æ + 0xc7 : b'\xc3\x87', # Ç + 0xc8 : b'\xc3\x88', # È + 0xc9 : b'\xc3\x89', # É + 0xca : b'\xc3\x8a', # Ê + 0xcb : b'\xc3\x8b', # Ë + 0xcc : b'\xc3\x8c', # Ì + 0xcd : b'\xc3\x8d', # Í + 0xce 
: b'\xc3\x8e', # Î + 0xcf : b'\xc3\x8f', # Ï + 0xd0 : b'\xc3\x90', # Ð + 0xd1 : b'\xc3\x91', # Ñ + 0xd2 : b'\xc3\x92', # Ò + 0xd3 : b'\xc3\x93', # Ó + 0xd4 : b'\xc3\x94', # Ô + 0xd5 : b'\xc3\x95', # Õ + 0xd6 : b'\xc3\x96', # Ö + 0xd7 : b'\xc3\x97', # × + 0xd8 : b'\xc3\x98', # Ø + 0xd9 : b'\xc3\x99', # Ù + 0xda : b'\xc3\x9a', # Ú + 0xdb : b'\xc3\x9b', # Û + 0xdc : b'\xc3\x9c', # Ü + 0xdd : b'\xc3\x9d', # Ý + 0xde : b'\xc3\x9e', # Þ + 0xdf : b'\xc3\x9f', # ß + 0xe0 : b'\xc3\xa0', # à + 0xe1 : b'\xa1', # á + 0xe2 : b'\xc3\xa2', # â + 0xe3 : b'\xc3\xa3', # ã + 0xe4 : b'\xc3\xa4', # ä + 0xe5 : b'\xc3\xa5', # å + 0xe6 : b'\xc3\xa6', # æ + 0xe7 : b'\xc3\xa7', # ç + 0xe8 : b'\xc3\xa8', # è + 0xe9 : b'\xc3\xa9', # é + 0xea : b'\xc3\xaa', # ê + 0xeb : b'\xc3\xab', # ë + 0xec : b'\xc3\xac', # ì + 0xed : b'\xc3\xad', # í + 0xee : b'\xc3\xae', # î + 0xef : b'\xc3\xaf', # ï + 0xf0 : b'\xc3\xb0', # ð + 0xf1 : b'\xc3\xb1', # ñ + 0xf2 : b'\xc3\xb2', # ò + 0xf3 : b'\xc3\xb3', # ó + 0xf4 : b'\xc3\xb4', # ô + 0xf5 : b'\xc3\xb5', # õ + 0xf6 : b'\xc3\xb6', # ö + 0xf7 : b'\xc3\xb7', # ÷ + 0xf8 : b'\xc3\xb8', # ø + 0xf9 : b'\xc3\xb9', # ù + 0xfa : b'\xc3\xba', # ú + 0xfb : b'\xc3\xbb', # û + 0xfc : b'\xc3\xbc', # ü + 0xfd : b'\xc3\xbd', # ý + 0xfe : b'\xc3\xbe', # þ + } + + MULTIBYTE_MARKERS_AND_SIZES = [ + (0xc2, 0xdf, 2), # 2-byte characters start with a byte C2-DF + (0xe0, 0xef, 3), # 3-byte characters start with E0-EF + (0xf0, 0xf4, 4), # 4-byte characters start with F0-F4 + ] + + FIRST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[0][0] + LAST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[-1][1] + + @classmethod + def detwingle(cls, in_bytes, main_encoding="utf8", + embedded_encoding="windows-1252"): + """Fix characters from one encoding embedded in some other encoding. + + Currently the only situation supported is Windows-1252 (or its + subset ISO-8859-1), embedded in UTF-8. + + The input must be a bytestring. If you've already converted + the document to Unicode, you're too late. 
+ + The output is a bytestring in which `embedded_encoding` + characters have been converted to their `main_encoding` + equivalents. + """ + if embedded_encoding.replace('_', '-').lower() not in ( + 'windows-1252', 'windows_1252'): + raise NotImplementedError( + "Windows-1252 and ISO-8859-1 are the only currently supported " + "embedded encodings.") + + if main_encoding.lower() not in ('utf8', 'utf-8'): + raise NotImplementedError( + "UTF-8 is the only currently supported main encoding.") + + byte_chunks = [] + + chunk_start = 0 + pos = 0 + while pos < len(in_bytes): + byte = in_bytes[pos] + if not isinstance(byte, int): + # Python 2.x + byte = ord(byte) + if (byte >= cls.FIRST_MULTIBYTE_MARKER + and byte <= cls.LAST_MULTIBYTE_MARKER): + # This is the start of a UTF-8 multibyte character. Skip + # to the end. + for start, end, size in cls.MULTIBYTE_MARKERS_AND_SIZES: + if byte >= start and byte <= end: + pos += size + break + elif byte >= 0x80 and byte in cls.WINDOWS_1252_TO_UTF8: + # We found a Windows-1252 character! + # Save the string up to this point as a chunk. + byte_chunks.append(in_bytes[chunk_start:pos]) + + # Now translate the Windows-1252 character into UTF-8 + # and add it as another, one-byte chunk. + byte_chunks.append(cls.WINDOWS_1252_TO_UTF8[byte]) + pos += 1 + chunk_start = pos + else: + # Go on to the next character. + pos += 1 + if chunk_start == 0: + # The string is unchanged. + return in_bytes + else: + # Store the final chunk. 
+ byte_chunks.append(in_bytes[chunk_start:]) + return b''.join(byte_chunks) + diff --git a/bs4/element.py b/bs4/element.py new file mode 100644 index 0000000..91a4007 --- /dev/null +++ b/bs4/element.py @@ -0,0 +1,1347 @@ +import collections +import re +import sys +import warnings +from bs4.dammit import EntitySubstitution + +DEFAULT_OUTPUT_ENCODING = "utf-8" +PY3K = (sys.version_info[0] > 2) + +whitespace_re = re.compile("\s+") + +def _alias(attr): + """Alias one attribute name to another for backward compatibility""" + @property + def alias(self): + return getattr(self, attr) + + @alias.setter + def alias(self): + return setattr(self, attr) + return alias + + +class NamespacedAttribute(unicode): + + def __new__(cls, prefix, name, namespace=None): + if name is None: + obj = unicode.__new__(cls, prefix) + else: + obj = unicode.__new__(cls, prefix + ":" + name) + obj.prefix = prefix + obj.name = name + obj.namespace = namespace + return obj + +class AttributeValueWithCharsetSubstitution(unicode): + """A stand-in object for a character encoding specified in HTML.""" + +class CharsetMetaAttributeValue(AttributeValueWithCharsetSubstitution): + """A generic stand-in for the value of a meta tag's 'charset' attribute. + + When Beautiful Soup parses the markup '<meta charset="utf8">', the + value of the 'charset' attribute will be one of these objects. + """ + + def __new__(cls, original_value): + obj = unicode.__new__(cls, original_value) + obj.original_value = original_value + return obj + + def encode(self, encoding): + return encoding + + +class ContentMetaAttributeValue(AttributeValueWithCharsetSubstitution): + """A generic stand-in for the value of a meta tag's 'content' attribute. + + When Beautiful Soup parses the markup: + <meta http-equiv="content-type" content="text/html; charset=utf8"> + + The value of the 'content' attribute will be one of these objects. 
+ """ + + CHARSET_RE = re.compile("((^|;)\s*charset=)([^;]*)", re.M) + + def __new__(cls, original_value): + match = cls.CHARSET_RE.search(original_value) + if match is None: + # No substitution necessary. + return unicode.__new__(unicode, original_value) + + obj = unicode.__new__(cls, original_value) + obj.original_value = original_value + return obj + + def encode(self, encoding): + def rewrite(match): + return match.group(1) + encoding + return self.CHARSET_RE.sub(rewrite, self.original_value) + + +class PageElement(object): + """Contains the navigational information for some part of the page + (either a tag or a piece of text)""" + + # There are five possible values for the "formatter" argument passed in + # to methods like encode() and prettify(): + # + # "html" - All Unicode characters with corresponding HTML entities + # are converted to those entities on output. + # "minimal" - Bare ampersands and angle brackets are converted to + # XML entities: & < > + # None - The null formatter. Unicode characters are never + # converted to entities. This is not recommended, but it's + # faster than "minimal". 
+ # A function - This function will be called on every string that + # needs to undergo entity substition + FORMATTERS = { + "html" : EntitySubstitution.substitute_html, + "minimal" : EntitySubstitution.substitute_xml, + None : None + } + + @classmethod + def format_string(self, s, formatter='minimal'): + """Format the given string using the given formatter.""" + if not callable(formatter): + formatter = self.FORMATTERS.get( + formatter, EntitySubstitution.substitute_xml) + if formatter is None: + output = s + else: + output = formatter(s) + return output + + def setup(self, parent=None, previous_element=None): + """Sets up the initial relations between this element and + other elements.""" + self.parent = parent + self.previous_element = previous_element + if previous_element is not None: + self.previous_element.next_element = self + self.next_element = None + self.previous_sibling = None + self.next_sibling = None + if self.parent is not None and self.parent.contents: + self.previous_sibling = self.parent.contents[-1] + self.previous_sibling.next_sibling = self + + nextSibling = _alias("next_sibling") # BS3 + previousSibling = _alias("previous_sibling") # BS3 + + def replace_with(self, replace_with): + if replace_with is self: + return + if replace_with is self.parent: + raise ValueError("Cannot replace a Tag with its parent.") + old_parent = self.parent + my_index = self.parent.index(self) + self.extract() + old_parent.insert(my_index, replace_with) + return self + replaceWith = replace_with # BS3 + + def unwrap(self): + my_parent = self.parent + my_index = self.parent.index(self) + self.extract() + for child in reversed(self.contents[:]): + my_parent.insert(my_index, child) + return self + replace_with_children = unwrap + replaceWithChildren = unwrap # BS3 + + def wrap(self, wrap_inside): + me = self.replace_with(wrap_inside) + wrap_inside.append(me) + return wrap_inside + + def extract(self): + """Destructively rips this element out of the tree.""" + if 
self.parent is not None: + del self.parent.contents[self.parent.index(self)] + + #Find the two elements that would be next to each other if + #this element (and any children) hadn't been parsed. Connect + #the two. + last_child = self._last_descendant() + next_element = last_child.next_element + + if self.previous_element is not None: + self.previous_element.next_element = next_element + if next_element is not None: + next_element.previous_element = self.previous_element + self.previous_element = None + last_child.next_element = None + + self.parent = None + if self.previous_sibling is not None: + self.previous_sibling.next_sibling = self.next_sibling + if self.next_sibling is not None: + self.next_sibling.previous_sibling = self.previous_sibling + self.previous_sibling = self.next_sibling = None + return self + + def _last_descendant(self): + "Finds the last element beneath this object to be parsed." + last_child = self + while hasattr(last_child, 'contents') and last_child.contents: + last_child = last_child.contents[-1] + return last_child + # BS3: Not part of the API! + _lastRecursiveChild = _last_descendant + + def insert(self, position, new_child): + if new_child is self: + raise ValueError("Cannot insert a tag into itself.") + if (isinstance(new_child, basestring) + and not isinstance(new_child, NavigableString)): + new_child = NavigableString(new_child) + + position = min(position, len(self.contents)) + if hasattr(new_child, 'parent') and new_child.parent is not None: + # We're 'inserting' an element that's already one + # of this object's children. + if new_child.parent is self: + current_index = self.index(new_child) + if current_index < position: + # We're moving this element further down the list + # of this object's children. That means that when + # we extract this element, our target index will + # jump down one. 
+ position -= 1 + new_child.extract() + + new_child.parent = self + previous_child = None + if position == 0: + new_child.previous_sibling = None + new_child.previous_element = self + else: + previous_child = self.contents[position - 1] + new_child.previous_sibling = previous_child + new_child.previous_sibling.next_sibling = new_child + new_child.previous_element = previous_child._last_descendant() + if new_child.previous_element is not None: + new_child.previous_element.next_element = new_child + + new_childs_last_element = new_child._last_descendant() + + if position >= len(self.contents): + new_child.next_sibling = None + + parent = self + parents_next_sibling = None + while parents_next_sibling is None and parent is not None: + parents_next_sibling = parent.next_sibling + parent = parent.parent + if parents_next_sibling is not None: + # We found the element that comes next in the document. + break + if parents_next_sibling is not None: + new_childs_last_element.next_element = parents_next_sibling + else: + # The last element of this tag is the last element in + # the document. + new_childs_last_element.next_element = None + else: + next_child = self.contents[position] + new_child.next_sibling = next_child + if new_child.next_sibling is not None: + new_child.next_sibling.previous_sibling = new_child + new_childs_last_element.next_element = next_child + + if new_childs_last_element.next_element is not None: + new_childs_last_element.next_element.previous_element = new_childs_last_element + self.contents.insert(position, new_child) + + def append(self, tag): + """Appends the given tag to the contents of this tag.""" + self.insert(len(self.contents), tag) + + def insert_before(self, predecessor): + """Makes the given element the immediate predecessor of this one. + + The two elements will have the same parent, and the given element + will be immediately before this one. 
+ """ + if self is predecessor: + raise ValueError("Can't insert an element before itself.") + parent = self.parent + if parent is None: + raise ValueError( + "Element has no parent, so 'before' has no meaning.") + # Extract first so that the index won't be screwed up if they + # are siblings. + if isinstance(predecessor, PageElement): + predecessor.extract() + index = parent.index(self) + parent.insert(index, predecessor) + + def insert_after(self, successor): + """Makes the given element the immediate successor of this one. + + The two elements will have the same parent, and the given element + will be immediately after this one. + """ + if self is successor: + raise ValueError("Can't insert an element after itself.") + parent = self.parent + if parent is None: + raise ValueError( + "Element has no parent, so 'after' has no meaning.") + # Extract first so that the index won't be screwed up if they + # are siblings. + if isinstance(successor, PageElement): + successor.extract() + index = parent.index(self) + parent.insert(index+1, successor) + + def find_next(self, name=None, attrs={}, text=None, **kwargs): + """Returns the first item that matches the given criteria and + appears after this Tag in the document.""" + return self._find_one(self.find_all_next, name, attrs, text, **kwargs) + findNext = find_next # BS3 + + def find_all_next(self, name=None, attrs={}, text=None, limit=None, + **kwargs): + """Returns all items that match the given criteria and appear + after this Tag in the document.""" + return self._find_all(name, attrs, text, limit, self.next_elements, + **kwargs) + findAllNext = find_all_next # BS3 + + def find_next_sibling(self, name=None, attrs={}, text=None, **kwargs): + """Returns the closest sibling to this Tag that matches the + given criteria and appears after this Tag in the document.""" + return self._find_one(self.find_next_siblings, name, attrs, text, + **kwargs) + findNextSibling = find_next_sibling # BS3 + + def find_next_siblings(self, 
name=None, attrs={}, text=None, limit=None, + **kwargs): + """Returns the siblings of this Tag that match the given + criteria and appear after this Tag in the document.""" + return self._find_all(name, attrs, text, limit, + self.next_siblings, **kwargs) + findNextSiblings = find_next_siblings # BS3 + fetchNextSiblings = find_next_siblings # BS2 + + def find_previous(self, name=None, attrs={}, text=None, **kwargs): + """Returns the first item that matches the given criteria and + appears before this Tag in the document.""" + return self._find_one( + self.find_all_previous, name, attrs, text, **kwargs) + findPrevious = find_previous # BS3 + + def find_all_previous(self, name=None, attrs={}, text=None, limit=None, + **kwargs): + """Returns all items that match the given criteria and appear + before this Tag in the document.""" + return self._find_all(name, attrs, text, limit, self.previous_elements, + **kwargs) + findAllPrevious = find_all_previous # BS3 + fetchPrevious = find_all_previous # BS2 + + def find_previous_sibling(self, name=None, attrs={}, text=None, **kwargs): + """Returns the closest sibling to this Tag that matches the + given criteria and appears before this Tag in the document.""" + return self._find_one(self.find_previous_siblings, name, attrs, text, + **kwargs) + findPreviousSibling = find_previous_sibling # BS3 + + def find_previous_siblings(self, name=None, attrs={}, text=None, + limit=None, **kwargs): + """Returns the siblings of this Tag that match the given + criteria and appear before this Tag in the document.""" + return self._find_all(name, attrs, text, limit, + self.previous_siblings, **kwargs) + findPreviousSiblings = find_previous_siblings # BS3 + fetchPreviousSiblings = find_previous_siblings # BS2 + + def find_parent(self, name=None, attrs={}, **kwargs): + """Returns the closest parent of this Tag that matches the given + criteria.""" + # NOTE: We can't use _find_one because findParents takes a different + # set of arguments. 
+ r = None + l = self.find_parents(name, attrs, 1) + if l: + r = l[0] + return r + findParent = find_parent # BS3 + + def find_parents(self, name=None, attrs={}, limit=None, **kwargs): + """Returns the parents of this Tag that match the given + criteria.""" + + return self._find_all(name, attrs, None, limit, self.parents, + **kwargs) + findParents = find_parents # BS3 + fetchParents = find_parents # BS2 + + @property + def next(self): + return self.next_element + + @property + def previous(self): + return self.previous_element + + #These methods do the real heavy lifting. + + def _find_one(self, method, name, attrs, text, **kwargs): + r = None + l = method(name, attrs, text, 1, **kwargs) + if l: + r = l[0] + return r + + def _find_all(self, name, attrs, text, limit, generator, **kwargs): + "Iterates over a generator looking for things that match." + + if isinstance(name, SoupStrainer): + strainer = name + elif text is None and not limit and not attrs and not kwargs: + # Optimization to find all tags. + if name is True or name is None: + return [element for element in generator + if isinstance(element, Tag)] + # Optimization to find all tags with a given name. + elif isinstance(name, basestring): + return [element for element in generator + if isinstance(element, Tag) and element.name == name] + else: + strainer = SoupStrainer(name, attrs, text, **kwargs) + else: + # Build a SoupStrainer + strainer = SoupStrainer(name, attrs, text, **kwargs) + results = ResultSet(strainer) + while True: + try: + i = next(generator) + except StopIteration: + break + if i: + found = strainer.search(i) + if found: + results.append(found) + if limit and len(results) >= limit: + break + return results + + #These generators can be used to navigate starting from both + #NavigableStrings and Tags. 
+ @property + def next_elements(self): + i = self.next_element + while i is not None: + yield i + i = i.next_element + + @property + def next_siblings(self): + i = self.next_sibling + while i is not None: + yield i + i = i.next_sibling + + @property + def previous_elements(self): + i = self.previous_element + while i is not None: + yield i + i = i.previous_element + + @property + def previous_siblings(self): + i = self.previous_sibling + while i is not None: + yield i + i = i.previous_sibling + + @property + def parents(self): + i = self.parent + while i is not None: + yield i + i = i.parent + + # Methods for supporting CSS selectors. + + tag_name_re = re.compile('^[a-z0-9]+$') + + # /^(\w+)\[(\w+)([=~\|\^\$\*]?)=?"?([^\]"]*)"?\]$/ + # \---/ \---/\-------------/ \-------/ + # | | | | + # | | | The value + # | | ~,|,^,$,* or = + # | Attribute + # Tag + attribselect_re = re.compile( + r'^(?P<tag>\w+)?\[(?P<attribute>\w+)(?P<operator>[=~\|\^\$\*]?)' + + r'=?"?(?P<value>[^\]"]*)"?\]$' + ) + + def _attr_value_as_string(self, value, default=None): + """Force an attribute value into a string representation. + + A multi-valued attribute will be converted into a + space-separated stirng. + """ + value = self.get(value, default) + if isinstance(value, list) or isinstance(value, tuple): + value =" ".join(value) + return value + + def _attribute_checker(self, operator, attribute, value=''): + """Create a function that performs a CSS selector operation. + + Takes an operator, attribute and optional value. Returns a + function that will return True for elements that match that + combination. 
+ """ + if operator == '=': + # string representation of `attribute` is equal to `value` + return lambda el: el._attr_value_as_string(attribute) == value + elif operator == '~': + # space-separated list representation of `attribute` + # contains `value` + def _includes_value(element): + attribute_value = element.get(attribute, []) + if not isinstance(attribute_value, list): + attribute_value = attribute_value.split() + return value in attribute_value + return _includes_value + elif operator == '^': + # string representation of `attribute` starts with `value` + return lambda el: el._attr_value_as_string( + attribute, '').startswith(value) + elif operator == '$': + # string represenation of `attribute` ends with `value` + return lambda el: el._attr_value_as_string( + attribute, '').endswith(value) + elif operator == '*': + # string representation of `attribute` contains `value` + return lambda el: value in el._attr_value_as_string(attribute, '') + elif operator == '|': + # string representation of `attribute` is either exactly + # `value` or starts with `value` and then a dash. + def _is_or_starts_with_dash(element): + attribute_value = element._attr_value_as_string(attribute, '') + return (attribute_value == value or attribute_value.startswith( + value + '-')) + return _is_or_starts_with_dash + else: + return lambda el: el.has_attr(attribute) + + def select(self, selector): + """Perform a CSS selection operation on the current element.""" + tokens = selector.split() + current_context = [self] + for index, token in enumerate(tokens): + if tokens[index - 1] == '>': + # already found direct descendants in last step. skip this + # step. 
+ continue + m = self.attribselect_re.match(token) + if m is not None: + # Attribute selector + tag, attribute, operator, value = m.groups() + if not tag: + tag = True + checker = self._attribute_checker(operator, attribute, value) + found = [] + for context in current_context: + found.extend( + [el for el in context.find_all(tag) if checker(el)]) + current_context = found + continue + + if '#' in token: + # ID selector + tag, id = token.split('#', 1) + if tag == "": + tag = True + el = current_context[0].find(tag, {'id': id}) + if el is None: + return [] # No match + current_context = [el] + continue + + if '.' in token: + # Class selector + tag_name, klass = token.split('.', 1) + if not tag_name: + tag_name = True + classes = set(klass.split('.')) + found = [] + def classes_match(tag): + if tag_name is not True and tag.name != tag_name: + return False + if not tag.has_attr('class'): + return False + return classes.issubset(tag['class']) + for context in current_context: + found.extend(context.find_all(classes_match)) + current_context = found + continue + + if token == '*': + # Star selector + found = [] + for context in current_context: + found.extend(context.findAll(True)) + current_context = found + continue + + if token == '>': + # Child selector + tag = tokens[index + 1] + if not tag: + tag = True + + found = [] + for context in current_context: + found.extend(context.find_all(tag, recursive=False)) + current_context = found + continue + + # Here we should just have a regular tag + if not self.tag_name_re.match(token): + return [] + found = [] + for context in current_context: + found.extend(context.findAll(token)) + current_context = found + return current_context + + # Old non-property versions of the generators, for backwards + # compatibility with BS3. 
+ def nextGenerator(self): + return self.next_elements + + def nextSiblingGenerator(self): + return self.next_siblings + + def previousGenerator(self): + return self.previous_elements + + def previousSiblingGenerator(self): + return self.previous_siblings + + def parentGenerator(self): + return self.parents + + +class NavigableString(unicode, PageElement): + + PREFIX = '' + SUFFIX = '' + + def __new__(cls, value): + """Create a new NavigableString. + + When unpickling a NavigableString, this method is called with + the string in DEFAULT_OUTPUT_ENCODING. That encoding needs to be + passed in to the superclass's __new__ or the superclass won't know + how to handle non-ASCII characters. + """ + if isinstance(value, unicode): + return unicode.__new__(cls, value) + return unicode.__new__(cls, value, DEFAULT_OUTPUT_ENCODING) + + def __getnewargs__(self): + return (unicode(self),) + + def __getattr__(self, attr): + """text.string gives you text. This is for backwards + compatibility for Navigable*String, but for CData* it lets you + get the string without the CData wrapper.""" + if attr == 'string': + return self + else: + raise AttributeError( + "'%s' object has no attribute '%s'" % ( + self.__class__.__name__, attr)) + + def output_ready(self, formatter="minimal"): + output = self.format_string(self, formatter) + return self.PREFIX + output + self.SUFFIX + + +class PreformattedString(NavigableString): + """A NavigableString not subject to the normal formatting rules. + + The string will be passed into the formatter (to trigger side effects), + but the return value will be ignored. + """ + + def output_ready(self, formatter="minimal"): + """CData strings are passed into the formatter. + But the return value is ignored.""" + self.format_string(self, formatter) + return self.PREFIX + self + self.SUFFIX + +class CData(PreformattedString): + + PREFIX = u'<![CDATA[' + SUFFIX = u']]>' + +class ProcessingInstruction(PreformattedString): + + PREFIX = u'<?' 
+ SUFFIX = u'?>' + +class Comment(PreformattedString): + + PREFIX = u'<!--' + SUFFIX = u'-->' + + +class Declaration(PreformattedString): + PREFIX = u'<!' + SUFFIX = u'!>' + + +class Doctype(PreformattedString): + + @classmethod + def for_name_and_ids(cls, name, pub_id, system_id): + value = name + if pub_id is not None: + value += ' PUBLIC "%s"' % pub_id + if system_id is not None: + value += ' "%s"' % system_id + elif system_id is not None: + value += ' SYSTEM "%s"' % system_id + + return Doctype(value) + + PREFIX = u'<!DOCTYPE ' + SUFFIX = u'>\n' + + +class Tag(PageElement): + + """Represents a found HTML tag with its attributes and contents.""" + + def __init__(self, parser=None, builder=None, name=None, namespace=None, + prefix=None, attrs=None, parent=None, previous=None): + "Basic constructor." + + if parser is None: + self.parser_class = None + else: + # We don't actually store the parser object: that lets extracted + # chunks be garbage-collected. + self.parser_class = parser.__class__ + if name is None: + raise ValueError("No value provided for new tag's name.") + self.name = name + self.namespace = namespace + self.prefix = prefix + if attrs is None: + attrs = {} + elif builder.cdata_list_attributes: + attrs = builder._replace_cdata_list_attribute_values( + self.name, attrs) + else: + attrs = dict(attrs) + self.attrs = attrs + self.contents = [] + self.setup(parent, previous) + self.hidden = False + + # Set up any substitutions, such as the charset in a META tag. + if builder is not None: + builder.set_up_substitutions(self) + self.can_be_empty_element = builder.can_be_empty_element(name) + else: + self.can_be_empty_element = False + + parserClass = _alias("parser_class") # BS3 + + @property + def is_empty_element(self): + """Is this tag an empty-element tag? (aka a self-closing tag) + + A tag that has contents is never an empty-element tag. + + A tag that has no contents may or may not be an empty-element + tag. 
It depends on the builder used to create the tag. If the + builder has a designated list of empty-element tags, then only + a tag whose name shows up in that list is considered an + empty-element tag. + + If the builder has no designated list of empty-element tags, + then any tag with no contents is an empty-element tag. + """ + return len(self.contents) == 0 and self.can_be_empty_element + isSelfClosing = is_empty_element # BS3 + + @property + def string(self): + """Convenience property to get the single string within this tag. + + :Return: If this tag has a single string child, return value + is that string. If this tag has no children, or more than one + child, return value is None. If this tag has one child tag, + return value is the 'string' attribute of the child tag, + recursively. + """ + if len(self.contents) != 1: + return None + child = self.contents[0] + if isinstance(child, NavigableString): + return child + return child.string + + @string.setter + def string(self, string): + self.clear() + self.append(string.__class__(string)) + + def _all_strings(self, strip=False): + """Yield all child strings, possibly stripping them.""" + for descendant in self.descendants: + if not isinstance(descendant, NavigableString): + continue + if strip: + descendant = descendant.strip() + if len(descendant) == 0: + continue + yield descendant + strings = property(_all_strings) + + @property + def stripped_strings(self): + for string in self._all_strings(True): + yield string + + def get_text(self, separator="", strip=False): + """ + Get all child strings, concatenated using the given separator. + """ + return separator.join([s for s in self._all_strings(strip)]) + getText = get_text + text = property(get_text) + + def decompose(self): + """Recursively destroys the contents of this tree.""" + self.extract() + i = self + while i is not None: + next = i.next_element + i.__dict__.clear() + i = next + + def clear(self, decompose=False): + """ + Extract all children. 
If decompose is True, decompose instead. + """ + if decompose: + for element in self.contents[:]: + if isinstance(element, Tag): + element.decompose() + else: + element.extract() + else: + for element in self.contents[:]: + element.extract() + + def index(self, element): + """ + Find the index of a child by identity, not value. Avoids issues with + tag.contents.index(element) getting the index of equal elements. + """ + for i, child in enumerate(self.contents): + if child is element: + return i + raise ValueError("Tag.index: element not in tag") + + def get(self, key, default=None): + """Returns the value of the 'key' attribute for the tag, or + the value given for 'default' if it doesn't have that + attribute.""" + return self.attrs.get(key, default) + + def has_attr(self, key): + return key in self.attrs + + def __hash__(self): + return str(self).__hash__() + + def __getitem__(self, key): + """tag[key] returns the value of the 'key' attribute for the tag, + and throws an exception if it's not there.""" + return self.attrs[key] + + def __iter__(self): + "Iterating over a tag iterates over its contents." + return iter(self.contents) + + def __len__(self): + "The length of a tag is the length of its list of contents." + return len(self.contents) + + def __contains__(self, x): + return x in self.contents + + def __nonzero__(self): + "A tag is non-None even if it has no contents." + return True + + def __setitem__(self, key, value): + """Setting tag[key] sets the value of the 'key' attribute for the + tag.""" + self.attrs[key] = value + + def __delitem__(self, key): + "Deleting tag[key] deletes all 'key' attributes for the tag." + self.attrs.pop(key, None) + + def __call__(self, *args, **kwargs): + """Calling a tag like a function is the same as calling its + find_all() method. Eg. 
tag('a') returns a list of all the A tags + found within this tag.""" + return self.find_all(*args, **kwargs) + + def __getattr__(self, tag): + #print "Getattr %s.%s" % (self.__class__, tag) + if len(tag) > 3 and tag.endswith('Tag'): + # BS3: soup.aTag -> "soup.find("a") + tag_name = tag[:-3] + warnings.warn( + '.%sTag is deprecated, use .find("%s") instead.' % ( + tag_name, tag_name)) + return self.find(tag_name) + # We special case contents to avoid recursion. + elif not tag.startswith("__") and not tag=="contents": + return self.find(tag) + raise AttributeError( + "'%s' object has no attribute '%s'" % (self.__class__, tag)) + + def __eq__(self, other): + """Returns true iff this tag has the same name, the same attributes, + and the same contents (recursively) as the given tag.""" + if self is other: + return True + if (not hasattr(other, 'name') or + not hasattr(other, 'attrs') or + not hasattr(other, 'contents') or + self.name != other.name or + self.attrs != other.attrs or + len(self) != len(other)): + return False + for i, my_child in enumerate(self.contents): + if my_child != other.contents[i]: + return False + return True + + def __ne__(self, other): + """Returns true iff this tag is not identical to the other tag, + as defined in __eq__.""" + return not self == other + + def __repr__(self, encoding=DEFAULT_OUTPUT_ENCODING): + """Renders this tag as a string.""" + return self.encode(encoding) + + def __unicode__(self): + return self.decode() + + def __str__(self): + return self.encode() + + if PY3K: + __str__ = __repr__ = __unicode__ + + def encode(self, encoding=DEFAULT_OUTPUT_ENCODING, + indent_level=None, formatter="minimal", + errors="xmlcharrefreplace"): + # Turn the data structure into Unicode, then encode the + # Unicode. 
+ u = self.decode(indent_level, encoding, formatter) + return u.encode(encoding, errors) + + def decode(self, indent_level=None, + eventual_encoding=DEFAULT_OUTPUT_ENCODING, + formatter="minimal"): + """Returns a Unicode representation of this tag and its contents. + + :param eventual_encoding: The tag is destined to be + encoded into this encoding. This method is _not_ + responsible for performing that encoding. This information + is passed in so that it can be substituted in if the + document contains a <META> tag that mentions the document's + encoding. + """ + attrs = [] + if self.attrs: + for key, val in sorted(self.attrs.items()): + if val is None: + decoded = key + else: + if isinstance(val, list) or isinstance(val, tuple): + val = ' '.join(val) + elif not isinstance(val, basestring): + val = str(val) + elif ( + isinstance(val, AttributeValueWithCharsetSubstitution) + and eventual_encoding is not None): + val = val.encode(eventual_encoding) + + text = self.format_string(val, formatter) + decoded = ( + str(key) + '=' + + EntitySubstitution.quoted_attribute_value(text)) + attrs.append(decoded) + close = '' + closeTag = '' + if self.is_empty_element: + close = '/' + else: + closeTag = '</%s>' % self.name + + prefix = '' + if self.prefix: + prefix = self.prefix + ":" + + pretty_print = (indent_level is not None) + if pretty_print: + space = (' ' * (indent_level - 1)) + indent_contents = indent_level + 1 + else: + space = '' + indent_contents = None + contents = self.decode_contents( + indent_contents, eventual_encoding, formatter) + + if self.hidden: + # This is the 'document root' object. 
+ s = contents + else: + s = [] + attribute_string = '' + if attrs: + attribute_string = ' ' + ' '.join(attrs) + if pretty_print: + s.append(space) + s.append('<%s%s%s%s>' % ( + prefix, self.name, attribute_string, close)) + if pretty_print: + s.append("\n") + s.append(contents) + if pretty_print and contents and contents[-1] != "\n": + s.append("\n") + if pretty_print and closeTag: + s.append(space) + s.append(closeTag) + if pretty_print and closeTag and self.next_sibling: + s.append("\n") + s = ''.join(s) + return s + + def prettify(self, encoding=None, formatter="minimal"): + if encoding is None: + return self.decode(True, formatter=formatter) + else: + return self.encode(encoding, True, formatter=formatter) + + def decode_contents(self, indent_level=None, + eventual_encoding=DEFAULT_OUTPUT_ENCODING, + formatter="minimal"): + """Renders the contents of this tag as a Unicode string. + + :param eventual_encoding: The tag is destined to be + encoded into this encoding. This method is _not_ + responsible for performing that encoding. This information + is passed in so that it can be substituted in if the + document contains a <META> tag that mentions the document's + encoding. 
+ """ + pretty_print = (indent_level is not None) + s = [] + for c in self: + text = None + if isinstance(c, NavigableString): + text = c.output_ready(formatter) + elif isinstance(c, Tag): + s.append(c.decode(indent_level, eventual_encoding, + formatter)) + if text and indent_level: + text = text.strip() + if text: + if pretty_print: + s.append(" " * (indent_level - 1)) + s.append(text) + if pretty_print: + s.append("\n") + return ''.join(s) + + def encode_contents( + self, indent_level=None, encoding=DEFAULT_OUTPUT_ENCODING, + formatter="minimal"): + """Renders the contents of this tag as a bytestring.""" + contents = self.decode_contents(indent_level, encoding, formatter) + return contents.encode(encoding) + + # Old method for BS3 compatibility + def renderContents(self, encoding=DEFAULT_OUTPUT_ENCODING, + prettyPrint=False, indentLevel=0): + if not prettyPrint: + indentLevel = None + return self.encode_contents( + indent_level=indentLevel, encoding=encoding) + + #Soup methods + + def find(self, name=None, attrs={}, recursive=True, text=None, + **kwargs): + """Return only the first child of this Tag matching the given + criteria.""" + r = None + l = self.find_all(name, attrs, recursive, text, 1, **kwargs) + if l: + r = l[0] + return r + findChild = find + + def find_all(self, name=None, attrs={}, recursive=True, text=None, + limit=None, **kwargs): + """Extracts a list of Tag objects that match the given + criteria. You can specify the name of the Tag and any + attributes you want the Tag to have. + + The value of a key-value pair in the 'attrs' map can be a + string, a list of strings, a regular expression object, or a + callable that takes a string and returns whether or not the + string matches for some custom definition of 'matches'. 
The + same is true of the tag name.""" + generator = self.descendants + if not recursive: + generator = self.children + return self._find_all(name, attrs, text, limit, generator, **kwargs) + findAll = find_all # BS3 + findChildren = find_all # BS2 + + #Generator methods + @property + def children(self): + # return iter() to make the purpose of the method clear + return iter(self.contents) # XXX This seems to be untested. + + @property + def descendants(self): + if not len(self.contents): + return + stopNode = self._last_descendant().next_element + current = self.contents[0] + while current is not stopNode: + yield current + current = current.next_element + + # Old names for backwards compatibility + def childGenerator(self): + return self.children + + def recursiveChildGenerator(self): + return self.descendants + + # This was kind of misleading because has_key() (attributes) was + # different from __in__ (contents). has_key() is gone in Python 3, + # anyway. + has_key = has_attr + +# Next, a couple classes to represent queries and their results. +class SoupStrainer(object): + """Encapsulates a number of ways of matching a markup element (tag or + text).""" + + def __init__(self, name=None, attrs={}, text=None, **kwargs): + self.name = self._normalize_search_value(name) + if not isinstance(attrs, dict): + # Treat a non-dict value for attrs as a search for the 'class' + # attribute. + kwargs['class'] = attrs + attrs = None + + if kwargs: + if attrs: + attrs = attrs.copy() + attrs.update(kwargs) + else: + attrs = kwargs + normalized_attrs = {} + for key, value in attrs.items(): + normalized_attrs[key] = self._normalize_search_value(value) + + self.attrs = normalized_attrs + self.text = self._normalize_search_value(text) + + def _normalize_search_value(self, value): + # Leave it alone if it's a Unicode string, a callable, a + # regular expression, a boolean, or None. 
+ if (isinstance(value, unicode) or callable(value) or hasattr(value, 'match') + or isinstance(value, bool) or value is None): + return value + + # If it's a bytestring, convert it to Unicode, treating it as UTF-8. + if isinstance(value, bytes): + return value.decode("utf8") + + # If it's listlike, convert it into a list of strings. + if hasattr(value, '__iter__'): + new_value = [] + for v in value: + if (hasattr(v, '__iter__') and not isinstance(v, bytes) + and not isinstance(v, unicode)): + # This is almost certainly the user's mistake. In the + # interests of avoiding infinite loops, we'll let + # it through as-is rather than doing a recursive call. + new_value.append(v) + else: + new_value.append(self._normalize_search_value(v)) + return new_value + + # Otherwise, convert it into a Unicode string. + # The unicode(str()) thing is so this will do the same thing on Python 2 + # and Python 3. + return unicode(str(value)) + + def __str__(self): + if self.text: + return self.text + else: + return "%s|%s" % (self.name, self.attrs) + + def search_tag(self, markup_name=None, markup_attrs={}): + found = None + markup = None + if isinstance(markup_name, Tag): + markup = markup_name + markup_attrs = markup + call_function_with_tag_data = ( + isinstance(self.name, collections.Callable) + and not isinstance(markup_name, Tag)) + + if ((not self.name) + or call_function_with_tag_data + or (markup and self._matches(markup, self.name)) + or (not markup and self._matches(markup_name, self.name))): + if call_function_with_tag_data: + match = self.name(markup_name, markup_attrs) + else: + match = True + markup_attr_map = None + for attr, match_against in list(self.attrs.items()): + if not markup_attr_map: + if hasattr(markup_attrs, 'get'): + markup_attr_map = markup_attrs + else: + markup_attr_map = {} + for k, v in markup_attrs: + markup_attr_map[k] = v + attr_value = markup_attr_map.get(attr) + if not self._matches(attr_value, match_against): + match = False + break + if match: + 
if markup: + found = markup + else: + found = markup_name + if found and self.text and not self._matches(found.string, self.text): + found = None + return found + searchTag = search_tag + + def search(self, markup): + # print 'looking for %s in %s' % (self, markup) + found = None + # If given a list of items, scan it for a text element that + # matches. + if hasattr(markup, '__iter__') and not isinstance(markup, (Tag, basestring)): + for element in markup: + if isinstance(element, NavigableString) \ + and self.search(element): + found = element + break + # If it's a Tag, make sure its name or attributes match. + # Don't bother with Tags if we're searching for text. + elif isinstance(markup, Tag): + if not self.text or self.name or self.attrs: + found = self.search_tag(markup) + # If it's text, make sure the text matches. + elif isinstance(markup, NavigableString) or \ + isinstance(markup, basestring): + if not self.name and not self.attrs and self._matches(markup, self.text): + found = markup + else: + raise Exception( + "I don't know how to match against a %s" % markup.__class__) + return found + + def _matches(self, markup, match_against): + # print u"Matching %s against %s" % (markup, match_against) + result = False + if isinstance(markup, list) or isinstance(markup, tuple): + # This should only happen when searching a multi-valued attribute + # like 'class'. + if (isinstance(match_against, unicode) + and ' ' in match_against): + # A bit of a special case. If they try to match "foo + # bar" on a multivalue attribute's value, only accept + # the literal value "foo bar" + # + # XXX This is going to be pretty slow because we keep + # splitting match_against. But it shouldn't come up + # too often. + return (whitespace_re.split(match_against) == markup) + else: + for item in markup: + if self._matches(item, match_against): + return True + return False + + if match_against is True: + # True matches any non-None value. 
+ return markup is not None + + if isinstance(match_against, collections.Callable): + return match_against(markup) + + # Custom callables take the tag as an argument, but all + # other ways of matching match the tag name as a string. + if isinstance(markup, Tag): + markup = markup.name + + # Ensure that `markup` is either a Unicode string, or None. + markup = self._normalize_search_value(markup) + + if markup is None: + # None matches None, False, an empty string, an empty list, and so on. + return not match_against + + if isinstance(match_against, unicode): + # Exact string match + return markup == match_against + + if hasattr(match_against, 'match'): + # Regexp match + return match_against.search(markup) + + if hasattr(match_against, '__iter__'): + # The markup must be an exact match against something + # in the iterable. + return markup in match_against + + +class ResultSet(list): + """A ResultSet is just a list that keeps track of the SoupStrainer + that created it.""" + def __init__(self, source): + list.__init__([]) + self.source = source diff --git a/bs4/testing.py b/bs4/testing.py new file mode 100644 index 0000000..5a84b0b --- /dev/null +++ b/bs4/testing.py @@ -0,0 +1,515 @@ +"""Helper classes for tests.""" + +import copy +import functools +import unittest +from unittest import TestCase +from bs4 import BeautifulSoup +from bs4.element import ( + CharsetMetaAttributeValue, + Comment, + ContentMetaAttributeValue, + Doctype, + SoupStrainer, +) + +from bs4.builder import HTMLParserTreeBuilder +default_builder = HTMLParserTreeBuilder + + +class SoupTest(unittest.TestCase): + + @property + def default_builder(self): + return default_builder() + + def soup(self, markup, **kwargs): + """Build a Beautiful Soup object from markup.""" + builder = kwargs.pop('builder', self.default_builder) + return BeautifulSoup(markup, builder=builder, **kwargs) + + def document_for(self, markup): + """Turn an HTML fragment into a document. + + The details depend on the builder. 
+ """ + return self.default_builder.test_fragment_to_document(markup) + + def assertSoupEquals(self, to_parse, compare_parsed_to=None): + builder = self.default_builder + obj = BeautifulSoup(to_parse, builder=builder) + if compare_parsed_to is None: + compare_parsed_to = to_parse + + self.assertEqual(obj.decode(), self.document_for(compare_parsed_to)) + + +class HTMLTreeBuilderSmokeTest(object): + + """A basic test of a treebuilder's competence. + + Any HTML treebuilder, present or future, should be able to pass + these tests. With invalid markup, there's room for interpretation, + and different parsers can handle it differently. But with the + markup in these tests, there's not much room for interpretation. + """ + + def assertDoctypeHandled(self, doctype_fragment): + """Assert that a given doctype string is handled correctly.""" + doctype_str, soup = self._document_with_doctype(doctype_fragment) + + # Make sure a Doctype object was created. + doctype = soup.contents[0] + self.assertEqual(doctype.__class__, Doctype) + self.assertEqual(doctype, doctype_fragment) + self.assertEqual(str(soup)[:len(doctype_str)], doctype_str) + + # Make sure that the doctype was correctly associated with the + # parse tree and that the rest of the document parsed. 
+ self.assertEqual(soup.p.contents[0], 'foo') + + def _document_with_doctype(self, doctype_fragment): + """Generate and parse a document with the given doctype.""" + doctype = '<!DOCTYPE %s>' % doctype_fragment + markup = doctype + '\n<p>foo</p>' + soup = self.soup(markup) + return doctype, soup + + def test_normal_doctypes(self): + """Make sure normal, everyday HTML doctypes are handled correctly.""" + self.assertDoctypeHandled("html") + self.assertDoctypeHandled( + 'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"') + + def test_public_doctype_with_url(self): + doctype = 'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"' + self.assertDoctypeHandled(doctype) + + def test_system_doctype(self): + self.assertDoctypeHandled('foo SYSTEM "http://www.example.com/"') + + def test_namespaced_system_doctype(self): + # We can handle a namespaced doctype with a system ID. + self.assertDoctypeHandled('xsl:stylesheet SYSTEM "htmlent.dtd"') + + def test_namespaced_public_doctype(self): + # Test a namespaced doctype with a public id. + self.assertDoctypeHandled('xsl:stylesheet PUBLIC "htmlent.dtd"') + + def test_real_xhtml_document(self): + """A real XHTML document should come out more or less the same as it went in.""" + markup = b"""<?xml version="1.0" encoding="utf-8"?> +<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"> +<html xmlns="http://www.w3.org/1999/xhtml"> +<head><title>Hello. +Goodbye. +""" + soup = self.soup(markup) + self.assertEqual( + soup.encode("utf-8").replace(b"\n", b""), + markup.replace(b"\n", b"")) + + def test_deepcopy(self): + """Make sure you can copy the tree builder. + + This is important because the builder is part of a + BeautifulSoup object, and we want to be able to copy that. + """ + copy.deepcopy(self.default_builder) + + def test_p_tag_is_never_empty_element(self): + """A

      tag is never designated as an empty-element tag. + + Even if the markup shows it as an empty-element tag, it + shouldn't be presented that way. + """ + soup = self.soup("

      ") + self.assertFalse(soup.p.is_empty_element) + self.assertEqual(str(soup.p), "

      ") + + def test_unclosed_tags_get_closed(self): + """A tag that's not closed by the end of the document should be closed. + + This applies to all tags except empty-element tags. + """ + self.assertSoupEquals("

      ", "

      ") + self.assertSoupEquals("", "") + + self.assertSoupEquals("
      ", "
      ") + + def test_br_is_always_empty_element_tag(self): + """A
      tag is designated as an empty-element tag. + + Some parsers treat

      as one
      tag, some parsers as + two tags, but it should always be an empty-element tag. + """ + soup = self.soup("

      ") + self.assertTrue(soup.br.is_empty_element) + self.assertEqual(str(soup.br), "
      ") + + def test_nested_formatting_elements(self): + self.assertSoupEquals("") + + def test_comment(self): + # Comments are represented as Comment objects. + markup = "

      foobaz

      " + self.assertSoupEquals(markup) + + soup = self.soup(markup) + comment = soup.find(text="foobar") + self.assertEqual(comment.__class__, Comment) + + def test_preserved_whitespace_in_pre_and_textarea(self): + """Whitespace must be preserved in
       and ")
      +
      +    def test_nested_inline_elements(self):
      +        """Inline elements can be nested indefinitely."""
      +        b_tag = "Inside a B tag"
      +        self.assertSoupEquals(b_tag)
      +
      +        nested_b_tag = "

      A nested tag

      " + self.assertSoupEquals(nested_b_tag) + + double_nested_b_tag = "

      A doubly nested tag

      " + self.assertSoupEquals(nested_b_tag) + + def test_nested_block_level_elements(self): + """Block elements can be nested.""" + soup = self.soup('

      Foo

      ') + blockquote = soup.blockquote + self.assertEqual(blockquote.p.b.string, 'Foo') + self.assertEqual(blockquote.b.string, 'Foo') + + def test_correctly_nested_tables(self): + """One table can go inside another one.""" + markup = ('' + '' + "') + + self.assertSoupEquals( + markup, + '
      Here's another table:" + '' + '' + '
      foo
      Here\'s another table:' + '
      foo
      ' + '
      ') + + self.assertSoupEquals( + "" + "" + "
      Foo
      Bar
      Baz
      ") + + def test_angle_brackets_in_attribute_values_are_escaped(self): + self.assertSoupEquals('', '') + + def test_entities_in_attributes_converted_to_unicode(self): + expect = u'

      ' + self.assertSoupEquals('

      ', expect) + self.assertSoupEquals('

      ', expect) + self.assertSoupEquals('

      ', expect) + + def test_entities_in_text_converted_to_unicode(self): + expect = u'

      pi\N{LATIN SMALL LETTER N WITH TILDE}ata

      ' + self.assertSoupEquals("

      piñata

      ", expect) + self.assertSoupEquals("

      piñata

      ", expect) + self.assertSoupEquals("

      piñata

      ", expect) + + def test_quot_entity_converted_to_quotation_mark(self): + self.assertSoupEquals("

      I said "good day!"

      ", + '

      I said "good day!"

      ') + + def test_out_of_range_entity(self): + expect = u"\N{REPLACEMENT CHARACTER}" + self.assertSoupEquals("�", expect) + self.assertSoupEquals("�", expect) + self.assertSoupEquals("�", expect) + + def test_basic_namespaces(self): + """Parsers don't need to *understand* namespaces, but at the + very least they should not choke on namespaces or lose + data.""" + + markup = b'4' + soup = self.soup(markup) + self.assertEqual(markup, soup.encode()) + html = soup.html + self.assertEqual('http://www.w3.org/1999/xhtml', soup.html['xmlns']) + self.assertEqual( + 'http://www.w3.org/1998/Math/MathML', soup.html['xmlns:mathml']) + self.assertEqual( + 'http://www.w3.org/2000/svg', soup.html['xmlns:svg']) + + def test_multivalued_attribute_value_becomes_list(self): + markup = b'' + soup = self.soup(markup) + self.assertEqual(['foo', 'bar'], soup.a['class']) + + # + # Generally speaking, tests below this point are more tests of + # Beautiful Soup than tests of the tree builders. But parsers are + # weird, so we run these tests separately for every tree builder + # to detect any differences between them. 
+ # + + def test_soupstrainer(self): + """Parsers should be able to work with SoupStrainers.""" + strainer = SoupStrainer("b") + soup = self.soup("A bold statement", + parse_only=strainer) + self.assertEqual(soup.decode(), "bold") + + def test_single_quote_attribute_values_become_double_quotes(self): + self.assertSoupEquals("", + '') + + def test_attribute_values_with_nested_quotes_are_left_alone(self): + text = """a""" + self.assertSoupEquals(text) + + def test_attribute_values_with_double_nested_quotes_get_quoted(self): + text = """a""" + soup = self.soup(text) + soup.foo['attr'] = 'Brawls happen at "Bob\'s Bar"' + self.assertSoupEquals( + soup.foo.decode(), + """a""") + + def test_ampersand_in_attribute_value_gets_escaped(self): + self.assertSoupEquals('', + '') + + self.assertSoupEquals( + 'foo', + 'foo') + + def test_escaped_ampersand_in_attribute_value_is_left_alone(self): + self.assertSoupEquals('') + + def test_entities_in_strings_converted_during_parsing(self): + # Both XML and HTML entities are converted to Unicode characters + # during parsing. + text = "

      <<sacré bleu!>>

      " + expected = u"

      <<sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>>

      " + self.assertSoupEquals(text, expected) + + def test_smart_quotes_converted_on_the_way_in(self): + # Microsoft smart quotes are converted to Unicode characters during + # parsing. + quote = b"

      \x91Foo\x92

      " + soup = self.soup(quote) + self.assertEqual( + soup.p.string, + u"\N{LEFT SINGLE QUOTATION MARK}Foo\N{RIGHT SINGLE QUOTATION MARK}") + + def test_non_breaking_spaces_converted_on_the_way_in(self): + soup = self.soup("  ") + self.assertEqual(soup.a.string, u"\N{NO-BREAK SPACE}" * 2) + + def test_entities_converted_on_the_way_out(self): + text = "

      <<sacré bleu!>>

      " + expected = u"

      <<sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>>

      ".encode("utf-8") + soup = self.soup(text) + self.assertEqual(soup.p.encode("utf-8"), expected) + + def test_real_iso_latin_document(self): + # Smoke test of interrelated functionality, using an + # easy-to-understand document. + + # Here it is in Unicode. Note that it claims to be in ISO-Latin-1. + unicode_html = u'

      Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!

      ' + + # That's because we're going to encode it into ISO-Latin-1, and use + # that to test. + iso_latin_html = unicode_html.encode("iso-8859-1") + + # Parse the ISO-Latin-1 HTML. + soup = self.soup(iso_latin_html) + # Encode it to UTF-8. + result = soup.encode("utf-8") + + # What do we expect the result to look like? Well, it would + # look like unicode_html, except that the META tag would say + # UTF-8 instead of ISO-Latin-1. + expected = unicode_html.replace("ISO-Latin-1", "utf-8") + + # And, of course, it would be in UTF-8, not Unicode. + expected = expected.encode("utf-8") + + # Ta-da! + self.assertEqual(result, expected) + + def test_real_shift_jis_document(self): + # Smoke test to make sure the parser can handle a document in + # Shift-JIS encoding, without choking. + shift_jis_html = ( + b'
      '
      +            b'\x82\xb1\x82\xea\x82\xcdShift-JIS\x82\xc5\x83R\x81[\x83f'
      +            b'\x83B\x83\x93\x83O\x82\xb3\x82\xea\x82\xbd\x93\xfa\x96{\x8c'
      +            b'\xea\x82\xcc\x83t\x83@\x83C\x83\x8b\x82\xc5\x82\xb7\x81B'
      +            b'
      ') + unicode_html = shift_jis_html.decode("shift-jis") + soup = self.soup(unicode_html) + + # Make sure the parse tree is correctly encoded to various + # encodings. + self.assertEqual(soup.encode("utf-8"), unicode_html.encode("utf-8")) + self.assertEqual(soup.encode("euc_jp"), unicode_html.encode("euc_jp")) + + def test_real_hebrew_document(self): + # A real-world test to make sure we can convert ISO-8859-9 (a + # Hebrew encoding) to UTF-8. + hebrew_document = b'Hebrew (ISO 8859-8) in Visual Directionality

      Hebrew (ISO 8859-8) in Visual Directionality

      \xed\xe5\xec\xf9' + soup = self.soup( + hebrew_document, from_encoding="iso8859-8") + self.assertEqual(soup.original_encoding, 'iso8859-8') + self.assertEqual( + soup.encode('utf-8'), + hebrew_document.decode("iso8859-8").encode("utf-8")) + + def test_meta_tag_reflects_current_encoding(self): + # Here's the tag saying that a document is + # encoded in Shift-JIS. + meta_tag = ('') + + # Here's a document incorporating that meta tag. + shift_jis_html = ( + '\n%s\n' + '' + 'Shift-JIS markup goes here.') % meta_tag + soup = self.soup(shift_jis_html) + + # Parse the document, and the charset is seemingly unaffected. + parsed_meta = soup.find('meta', {'http-equiv': 'Content-type'}) + content = parsed_meta['content'] + self.assertEqual('text/html; charset=x-sjis', content) + + # But that value is actually a ContentMetaAttributeValue object. + self.assertTrue(isinstance(content, ContentMetaAttributeValue)) + + # And it will take on a value that reflects its current + # encoding. + self.assertEqual('text/html; charset=utf8', content.encode("utf8")) + + # For the rest of the story, see TestSubstitutions in + # test_tree.py. + + def test_html5_style_meta_tag_reflects_current_encoding(self): + # Here's the tag saying that a document is + # encoded in Shift-JIS. + meta_tag = ('') + + # Here's a document incorporating that meta tag. + shift_jis_html = ( + '\n%s\n' + '' + 'Shift-JIS markup goes here.') % meta_tag + soup = self.soup(shift_jis_html) + + # Parse the document, and the charset is seemingly unaffected. + parsed_meta = soup.find('meta', id="encoding") + charset = parsed_meta['charset'] + self.assertEqual('x-sjis', charset) + + # But that value is actually a CharsetMetaAttributeValue object. + self.assertTrue(isinstance(charset, CharsetMetaAttributeValue)) + + # And it will take on a value that reflects its current + # encoding. 
+ self.assertEqual('utf8', charset.encode("utf8")) + + def test_tag_with_no_attributes_can_have_attributes_added(self): + data = self.soup("text") + data.a['foo'] = 'bar' + self.assertEqual('text', data.a.decode()) + +class XMLTreeBuilderSmokeTest(object): + + def test_docstring_generated(self): + soup = self.soup("") + self.assertEqual( + soup.encode(), b'\n') + + def test_real_xhtml_document(self): + """A real XHTML document should come out *exactly* the same as it went in.""" + markup = b""" + + +Hello. +Goodbye. +""" + soup = self.soup(markup) + self.assertEqual( + soup.encode("utf-8"), markup) + + + def test_docstring_includes_correct_encoding(self): + soup = self.soup("") + self.assertEqual( + soup.encode("latin1"), + b'\n') + + def test_large_xml_document(self): + """A large XML document should come out the same as it went in.""" + markup = (b'\n' + + b'0' * (2**12) + + b'') + soup = self.soup(markup) + self.assertEqual(soup.encode("utf-8"), markup) + + + def test_tags_are_empty_element_if_and_only_if_they_are_empty(self): + self.assertSoupEquals("

      ", "

      ") + self.assertSoupEquals("

      foo

      ") + + def test_namespaces_are_preserved(self): + markup = 'This tag is in the a namespaceThis tag is in the b namespace' + soup = self.soup(markup) + root = soup.root + self.assertEqual("http://example.com/", root['xmlns:a']) + self.assertEqual("http://example.net/", root['xmlns:b']) + + +class HTML5TreeBuilderSmokeTest(HTMLTreeBuilderSmokeTest): + """Smoke test for a tree builder that supports HTML5.""" + + def test_real_xhtml_document(self): + # Since XHTML is not HTML5, HTML5 parsers are not tested to handle + # XHTML documents in any particular way. + pass + + def test_html_tags_have_namespace(self): + markup = "" + soup = self.soup(markup) + self.assertEqual("http://www.w3.org/1999/xhtml", soup.a.namespace) + + def test_svg_tags_have_namespace(self): + markup = '' + soup = self.soup(markup) + namespace = "http://www.w3.org/2000/svg" + self.assertEqual(namespace, soup.svg.namespace) + self.assertEqual(namespace, soup.circle.namespace) + + + def test_mathml_tags_have_namespace(self): + markup = '5' + soup = self.soup(markup) + namespace = 'http://www.w3.org/1998/Math/MathML' + self.assertEqual(namespace, soup.math.namespace) + self.assertEqual(namespace, soup.msqrt.namespace) + + +def skipIf(condition, reason): + def nothing(test, *args, **kwargs): + return None + + def decorator(test_item): + if condition: + return nothing + else: + return test_item + + return decorator diff --git a/bs4/tests/__init__.py b/bs4/tests/__init__.py new file mode 100644 index 0000000..142c8cc --- /dev/null +++ b/bs4/tests/__init__.py @@ -0,0 +1 @@ +"The beautifulsoup tests." 
diff --git a/bs4/tests/test_builder_registry.py b/bs4/tests/test_builder_registry.py new file mode 100644 index 0000000..92ad10f --- /dev/null +++ b/bs4/tests/test_builder_registry.py @@ -0,0 +1,141 @@ +"""Tests of the builder registry.""" + +import unittest + +from bs4 import BeautifulSoup +from bs4.builder import ( + builder_registry as registry, + HTMLParserTreeBuilder, + TreeBuilderRegistry, +) + +try: + from bs4.builder import HTML5TreeBuilder + HTML5LIB_PRESENT = True +except ImportError: + HTML5LIB_PRESENT = False + +try: + from bs4.builder import ( + LXMLTreeBuilderForXML, + LXMLTreeBuilder, + ) + LXML_PRESENT = True +except ImportError: + LXML_PRESENT = False + + +class BuiltInRegistryTest(unittest.TestCase): + """Test the built-in registry with the default builders registered.""" + + def test_combination(self): + if LXML_PRESENT: + self.assertEqual(registry.lookup('fast', 'html'), + LXMLTreeBuilder) + + if LXML_PRESENT: + self.assertEqual(registry.lookup('permissive', 'xml'), + LXMLTreeBuilderForXML) + self.assertEqual(registry.lookup('strict', 'html'), + HTMLParserTreeBuilder) + if HTML5LIB_PRESENT: + self.assertEqual(registry.lookup('html5lib', 'html'), + HTML5TreeBuilder) + + def test_lookup_by_markup_type(self): + if LXML_PRESENT: + self.assertEqual(registry.lookup('html'), LXMLTreeBuilder) + self.assertEqual(registry.lookup('xml'), LXMLTreeBuilderForXML) + else: + self.assertEqual(registry.lookup('xml'), None) + if HTML5LIB_PRESENT: + self.assertEqual(registry.lookup('html'), HTML5TreeBuilder) + else: + self.assertEqual(registry.lookup('html'), HTMLParserTreeBuilder) + + def test_named_library(self): + if LXML_PRESENT: + self.assertEqual(registry.lookup('lxml', 'xml'), + LXMLTreeBuilderForXML) + self.assertEqual(registry.lookup('lxml', 'html'), + LXMLTreeBuilder) + if HTML5LIB_PRESENT: + self.assertEqual(registry.lookup('html5lib'), + HTML5TreeBuilder) + + self.assertEqual(registry.lookup('html.parser'), + HTMLParserTreeBuilder) + + def 
test_beautifulsoup_constructor_does_lookup(self): + # You can pass in a string. + BeautifulSoup("", features="html") + # Or a list of strings. + BeautifulSoup("", features=["html", "fast"]) + + # You'll get an exception if BS can't find an appropriate + # builder. + self.assertRaises(ValueError, BeautifulSoup, + "", features="no-such-feature") + +class RegistryTest(unittest.TestCase): + """Test the TreeBuilderRegistry class in general.""" + + def setUp(self): + self.registry = TreeBuilderRegistry() + + def builder_for_features(self, *feature_list): + cls = type('Builder_' + '_'.join(feature_list), + (object,), {'features' : feature_list}) + + self.registry.register(cls) + return cls + + def test_register_with_no_features(self): + builder = self.builder_for_features() + + # Since the builder advertises no features, you can't find it + # by looking up features. + self.assertEqual(self.registry.lookup('foo'), None) + + # But you can find it by doing a lookup with no features, if + # this happens to be the only registered builder. 
+ self.assertEqual(self.registry.lookup(), builder) + + def test_register_with_features_makes_lookup_succeed(self): + builder = self.builder_for_features('foo', 'bar') + self.assertEqual(self.registry.lookup('foo'), builder) + self.assertEqual(self.registry.lookup('bar'), builder) + + def test_lookup_fails_when_no_builder_implements_feature(self): + builder = self.builder_for_features('foo', 'bar') + self.assertEqual(self.registry.lookup('baz'), None) + + def test_lookup_gets_most_recent_registration_when_no_feature_specified(self): + builder1 = self.builder_for_features('foo') + builder2 = self.builder_for_features('bar') + self.assertEqual(self.registry.lookup(), builder2) + + def test_lookup_fails_when_no_tree_builders_registered(self): + self.assertEqual(self.registry.lookup(), None) + + def test_lookup_gets_most_recent_builder_supporting_all_features(self): + has_one = self.builder_for_features('foo') + has_the_other = self.builder_for_features('bar') + has_both_early = self.builder_for_features('foo', 'bar', 'baz') + has_both_late = self.builder_for_features('foo', 'bar', 'quux') + lacks_one = self.builder_for_features('bar') + has_the_other = self.builder_for_features('foo') + + # There are two builders featuring 'foo' and 'bar', but + # the one that also features 'quux' was registered later. + self.assertEqual(self.registry.lookup('foo', 'bar'), + has_both_late) + + # There is only one builder featuring 'foo', 'bar', and 'baz'. + self.assertEqual(self.registry.lookup('foo', 'bar', 'baz'), + has_both_early) + + def test_lookup_fails_when_cannot_reconcile_requested_features(self): + builder1 = self.builder_for_features('foo', 'bar') + builder2 = self.builder_for_features('foo', 'baz') + self.assertEqual(self.registry.lookup('bar', 'baz'), None) diff --git a/bs4/tests/test_docs.py b/bs4/tests/test_docs.py new file mode 100644 index 0000000..5b9f677 --- /dev/null +++ b/bs4/tests/test_docs.py @@ -0,0 +1,36 @@ +"Test harness for doctests." 
+ +# pylint: disable-msg=E0611,W0142 + +__metaclass__ = type +__all__ = [ + 'additional_tests', + ] + +import atexit +import doctest +import os +#from pkg_resources import ( +# resource_filename, resource_exists, resource_listdir, cleanup_resources) +import unittest + +DOCTEST_FLAGS = ( + doctest.ELLIPSIS | + doctest.NORMALIZE_WHITESPACE | + doctest.REPORT_NDIFF) + + +# def additional_tests(): +# "Run the doc tests (README.txt and docs/*, if any exist)" +# doctest_files = [ +# os.path.abspath(resource_filename('bs4', 'README.txt'))] +# if resource_exists('bs4', 'docs'): +# for name in resource_listdir('bs4', 'docs'): +# if name.endswith('.txt'): +# doctest_files.append( +# os.path.abspath( +# resource_filename('bs4', 'docs/%s' % name))) +# kwargs = dict(module_relative=False, optionflags=DOCTEST_FLAGS) +# atexit.register(cleanup_resources) +# return unittest.TestSuite(( +# doctest.DocFileSuite(*doctest_files, **kwargs))) diff --git a/bs4/tests/test_html5lib.py b/bs4/tests/test_html5lib.py new file mode 100644 index 0000000..f195f7d --- /dev/null +++ b/bs4/tests/test_html5lib.py @@ -0,0 +1,58 @@ +"""Tests to ensure that the html5lib tree builder generates good trees.""" + +import warnings + +try: + from bs4.builder import HTML5TreeBuilder + HTML5LIB_PRESENT = True +except ImportError, e: + HTML5LIB_PRESENT = False +from bs4.element import SoupStrainer +from bs4.testing import ( + HTML5TreeBuilderSmokeTest, + SoupTest, + skipIf, +) + +@skipIf( + not HTML5LIB_PRESENT, + "html5lib seems not to be present, not testing its tree builder.") +class HTML5LibBuilderSmokeTest(SoupTest, HTML5TreeBuilderSmokeTest): + """See ``HTML5TreeBuilderSmokeTest``.""" + + @property + def default_builder(self): + return HTML5TreeBuilder() + + def test_soupstrainer(self): + # The html5lib tree builder does not support SoupStrainers. + strainer = SoupStrainer("b") + markup = "

      A bold statement.

      " + with warnings.catch_warnings(record=True) as w: + soup = self.soup(markup, parse_only=strainer) + self.assertEqual( + soup.decode(), self.document_for(markup)) + + self.assertTrue( + "the html5lib tree builder doesn't support parse_only" in + str(w[0].message)) + + def test_correctly_nested_tables(self): + """html5lib inserts tags where other parsers don't.""" + markup = ('' + '' + "') + + self.assertSoupEquals( + markup, + '
      Here's another table:" + '' + '' + '
      foo
      Here\'s another table:' + '
      foo
      ' + '
      ') + + self.assertSoupEquals( + "" + "" + "
      Foo
      Bar
      Baz
      ") diff --git a/bs4/tests/test_htmlparser.py b/bs4/tests/test_htmlparser.py new file mode 100644 index 0000000..bcb5ed2 --- /dev/null +++ b/bs4/tests/test_htmlparser.py @@ -0,0 +1,19 @@ +"""Tests to ensure that the html.parser tree builder generates good +trees.""" + +from bs4.testing import SoupTest, HTMLTreeBuilderSmokeTest +from bs4.builder import HTMLParserTreeBuilder + +class HTMLParserTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest): + + @property + def default_builder(self): + return HTMLParserTreeBuilder() + + def test_namespaced_system_doctype(self): + # html.parser can't handle namespaced doctypes, so skip this one. + pass + + def test_namespaced_public_doctype(self): + # html.parser can't handle namespaced doctypes, so skip this one. + pass diff --git a/bs4/tests/test_lxml.py b/bs4/tests/test_lxml.py new file mode 100644 index 0000000..39e26bf --- /dev/null +++ b/bs4/tests/test_lxml.py @@ -0,0 +1,75 @@ +"""Tests to ensure that the lxml tree builder generates good trees.""" + +import re +import warnings + +try: + from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML + LXML_PRESENT = True +except ImportError, e: + LXML_PRESENT = False + +from bs4 import ( + BeautifulSoup, + BeautifulStoneSoup, + ) +from bs4.element import Comment, Doctype, SoupStrainer +from bs4.testing import skipIf +from bs4.tests import test_htmlparser +from bs4.testing import ( + HTMLTreeBuilderSmokeTest, + XMLTreeBuilderSmokeTest, + SoupTest, + skipIf, +) + +@skipIf( + not LXML_PRESENT, + "lxml seems not to be present, not testing its tree builder.") +class LXMLTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest): + """See ``HTMLTreeBuilderSmokeTest``.""" + + @property + def default_builder(self): + return LXMLTreeBuilder() + + def test_out_of_range_entity(self): + self.assertSoupEquals( + "

      foo�bar

      ", "

      foobar

      ") + self.assertSoupEquals( + "

      foo�bar

      ", "

      foobar

      ") + self.assertSoupEquals( + "

      foo�bar

      ", "

      foobar

      ") + + def test_beautifulstonesoup_is_xml_parser(self): + # Make sure that the deprecated BSS class uses an xml builder + # if one is installed. + with warnings.catch_warnings(record=False) as w: + soup = BeautifulStoneSoup("") + self.assertEqual(u"", unicode(soup.b)) + + def test_real_xhtml_document(self): + """lxml strips the XML definition from an XHTML doc, which is fine.""" + markup = b""" + + +Hello. +Goodbye. +""" + soup = self.soup(markup) + self.assertEqual( + soup.encode("utf-8").replace(b"\n", b''), + markup.replace(b'\n', b'').replace( + b'', b'')) + + +@skipIf( + not LXML_PRESENT, + "lxml seems not to be present, not testing its XML tree builder.") +class LXMLXMLTreeBuilderSmokeTest(SoupTest, XMLTreeBuilderSmokeTest): + """See ``HTMLTreeBuilderSmokeTest``.""" + + @property + def default_builder(self): + return LXMLTreeBuilderForXML() + diff --git a/bs4/tests/test_soup.py b/bs4/tests/test_soup.py new file mode 100644 index 0000000..23a664e --- /dev/null +++ b/bs4/tests/test_soup.py @@ -0,0 +1,368 @@ +# -*- coding: utf-8 -*- +"""Tests of Beautiful Soup as a whole.""" + +import unittest +from bs4 import ( + BeautifulSoup, + BeautifulStoneSoup, +) +from bs4.element import ( + CharsetMetaAttributeValue, + ContentMetaAttributeValue, + SoupStrainer, + NamespacedAttribute, + ) +import bs4.dammit +from bs4.dammit import EntitySubstitution, UnicodeDammit +from bs4.testing import ( + SoupTest, + skipIf, +) +import warnings + +try: + from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML + LXML_PRESENT = True +except ImportError, e: + LXML_PRESENT = False + +class TestDeprecatedConstructorArguments(SoupTest): + + def test_parseOnlyThese_renamed_to_parse_only(self): + with warnings.catch_warnings(record=True) as w: + soup = self.soup("
      ", parseOnlyThese=SoupStrainer("b")) + msg = str(w[0].message) + self.assertTrue("parseOnlyThese" in msg) + self.assertTrue("parse_only" in msg) + self.assertEqual(b"", soup.encode()) + + def test_fromEncoding_renamed_to_from_encoding(self): + with warnings.catch_warnings(record=True) as w: + utf8 = b"\xc3\xa9" + soup = self.soup(utf8, fromEncoding="utf8") + msg = str(w[0].message) + self.assertTrue("fromEncoding" in msg) + self.assertTrue("from_encoding" in msg) + self.assertEqual("utf8", soup.original_encoding) + + def test_unrecognized_keyword_argument(self): + self.assertRaises( + TypeError, self.soup, "", no_such_argument=True) + + @skipIf( + not LXML_PRESENT, + "lxml not present, not testing BeautifulStoneSoup.") + def test_beautifulstonesoup(self): + with warnings.catch_warnings(record=True) as w: + soup = BeautifulStoneSoup("") + self.assertTrue(isinstance(soup, BeautifulSoup)) + self.assertTrue("BeautifulStoneSoup class is deprecated") + +class TestSelectiveParsing(SoupTest): + + def test_parse_with_soupstrainer(self): + markup = "NoYesNoYes Yes" + strainer = SoupStrainer("b") + soup = self.soup(markup, parse_only=strainer) + self.assertEqual(soup.encode(), b"YesYes Yes") + + +class TestEntitySubstitution(unittest.TestCase): + """Standalone tests of the EntitySubstitution class.""" + def setUp(self): + self.sub = EntitySubstitution + + def test_simple_html_substitution(self): + # Unicode characters corresponding to named HTML entites + # are substituted, and no others. + s = u"foo\u2200\N{SNOWMAN}\u00f5bar" + self.assertEqual(self.sub.substitute_html(s), + u"foo∀\N{SNOWMAN}õbar") + + def test_smart_quote_substitution(self): + # MS smart quotes are a common source of frustration, so we + # give them a special test. 
+ quotes = b"\x91\x92foo\x93\x94" + dammit = UnicodeDammit(quotes) + self.assertEqual(self.sub.substitute_html(dammit.markup), + "‘’foo“”") + + def test_xml_converstion_includes_no_quotes_if_make_quoted_attribute_is_false(self): + s = 'Welcome to "my bar"' + self.assertEqual(self.sub.substitute_xml(s, False), s) + + def test_xml_attribute_quoting_normally_uses_double_quotes(self): + self.assertEqual(self.sub.substitute_xml("Welcome", True), + '"Welcome"') + self.assertEqual(self.sub.substitute_xml("Bob's Bar", True), + '"Bob\'s Bar"') + + def test_xml_attribute_quoting_uses_single_quotes_when_value_contains_double_quotes(self): + s = 'Welcome to "my bar"' + self.assertEqual(self.sub.substitute_xml(s, True), + "'Welcome to \"my bar\"'") + + def test_xml_attribute_quoting_escapes_single_quotes_when_value_contains_both_single_and_double_quotes(self): + s = 'Welcome to "Bob\'s Bar"' + self.assertEqual( + self.sub.substitute_xml(s, True), + '"Welcome to "Bob\'s Bar""') + + def test_xml_quotes_arent_escaped_when_value_is_not_being_quoted(self): + quoted = 'Welcome to "Bob\'s Bar"' + self.assertEqual(self.sub.substitute_xml(quoted), quoted) + + def test_xml_quoting_handles_angle_brackets(self): + self.assertEqual( + self.sub.substitute_xml("foo"), + "foo<bar>") + + def test_xml_quoting_handles_ampersands(self): + self.assertEqual(self.sub.substitute_xml("AT&T"), "AT&T") + + def test_xml_quoting_ignores_ampersands_when_they_are_part_of_an_entity(self): + self.assertEqual( + self.sub.substitute_xml("ÁT&T"), + "ÁT&T") + + def test_quotes_not_html_substituted(self): + """There's no need to do this except inside attribute values.""" + text = 'Bob\'s "bar"' + self.assertEqual(self.sub.substitute_html(text), text) + + +class TestEncodingConversion(SoupTest): + # Test Beautiful Soup's ability to decode and encode from various + # encodings. + + def setUp(self): + super(TestEncodingConversion, self).setUp() + self.unicode_data = u"Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!" 
+ self.utf8_data = self.unicode_data.encode("utf-8") + # Just so you know what it looks like. + self.assertEqual( + self.utf8_data, + b"Sacr\xc3\xa9 bleu!") + + def test_ascii_in_unicode_out(self): + # ASCII input is converted to Unicode. The original_encoding + # attribute is set. + ascii = b"a" + soup_from_ascii = self.soup(ascii) + unicode_output = soup_from_ascii.decode() + self.assertTrue(isinstance(unicode_output, unicode)) + self.assertEqual(unicode_output, self.document_for(ascii.decode())) + self.assertEqual(soup_from_ascii.original_encoding, "ascii") + + def test_unicode_in_unicode_out(self): + # Unicode input is left alone. The original_encoding attribute + # is not set. + soup_from_unicode = self.soup(self.unicode_data) + self.assertEqual(soup_from_unicode.decode(), self.unicode_data) + self.assertEqual(soup_from_unicode.foo.string, u'Sacr\xe9 bleu!') + self.assertEqual(soup_from_unicode.original_encoding, None) + + def test_utf8_in_unicode_out(self): + # UTF-8 input is converted to Unicode. The original_encoding + # attribute is set. + soup_from_utf8 = self.soup(self.utf8_data) + self.assertEqual(soup_from_utf8.decode(), self.unicode_data) + self.assertEqual(soup_from_utf8.foo.string, u'Sacr\xe9 bleu!') + + def test_utf8_out(self): + # The internal data structures can be encoded as UTF-8. 
+ soup_from_unicode = self.soup(self.unicode_data) + self.assertEqual(soup_from_unicode.encode('utf-8'), self.utf8_data) + + +class TestUnicodeDammit(unittest.TestCase): + """Standalone tests of Unicode, Dammit.""" + + def test_smart_quotes_to_unicode(self): + markup = b"\x91\x92\x93\x94" + dammit = UnicodeDammit(markup) + self.assertEqual( + dammit.unicode_markup, u"\u2018\u2019\u201c\u201d") + + def test_smart_quotes_to_xml_entities(self): + markup = b"\x91\x92\x93\x94" + dammit = UnicodeDammit(markup, smart_quotes_to="xml") + self.assertEqual( + dammit.unicode_markup, "‘’“”") + + def test_smart_quotes_to_html_entities(self): + markup = b"\x91\x92\x93\x94" + dammit = UnicodeDammit(markup, smart_quotes_to="html") + self.assertEqual( + dammit.unicode_markup, "‘’“”") + + def test_smart_quotes_to_ascii(self): + markup = b"\x91\x92\x93\x94" + dammit = UnicodeDammit(markup, smart_quotes_to="ascii") + self.assertEqual( + dammit.unicode_markup, """''""""") + + def test_detect_utf8(self): + utf8 = b"\xc3\xa9" + dammit = UnicodeDammit(utf8) + self.assertEqual(dammit.unicode_markup, u'\xe9') + self.assertEqual(dammit.original_encoding, 'utf-8') + + def test_convert_hebrew(self): + hebrew = b"\xed\xe5\xec\xf9" + dammit = UnicodeDammit(hebrew, ["iso-8859-8"]) + self.assertEqual(dammit.original_encoding, 'iso-8859-8') + self.assertEqual(dammit.unicode_markup, u'\u05dd\u05d5\u05dc\u05e9') + + def test_dont_see_smart_quotes_where_there_are_none(self): + utf_8 = b"\343\202\261\343\203\274\343\202\277\343\202\244 Watch" + dammit = UnicodeDammit(utf_8) + self.assertEqual(dammit.original_encoding, 'utf-8') + self.assertEqual(dammit.unicode_markup.encode("utf-8"), utf_8) + + def test_ignore_inappropriate_codecs(self): + utf8_data = u"Räksmörgås".encode("utf-8") + dammit = UnicodeDammit(utf8_data, ["iso-8859-8"]) + self.assertEqual(dammit.original_encoding, 'utf-8') + + def test_ignore_invalid_codecs(self): + utf8_data = u"Räksmörgås".encode("utf-8") + for bad_encoding in ['.utf8', 
'...', 'utF---16.!']: + dammit = UnicodeDammit(utf8_data, [bad_encoding]) + self.assertEqual(dammit.original_encoding, 'utf-8') + + def test_detect_html5_style_meta_tag(self): + + for data in ( + b'', + b"", + b"", + b""): + dammit = UnicodeDammit(data, is_html=True) + self.assertEqual( + "euc-jp", dammit.original_encoding) + + def test_last_ditch_entity_replacement(self): + # This is a UTF-8 document that contains bytestrings + # completely incompatible with UTF-8 (ie. encoded with some other + # encoding). + # + # Since there is no consistent encoding for the document, + # Unicode, Dammit will eventually encode the document as UTF-8 + # and encode the incompatible characters as REPLACEMENT + # CHARACTER. + # + # If chardet is installed, it will detect that the document + # can be converted into ISO-8859-1 without errors. This happens + # to be the wrong encoding, but it is a consistent encoding, so the + # code we're testing here won't run. + # + # So we temporarily disable chardet if it's present. + doc = b"""\357\273\277 +\330\250\330\252\330\261 +\310\322\321\220\312\321\355\344""" + chardet = bs4.dammit.chardet + try: + bs4.dammit.chardet = None + with warnings.catch_warnings(record=True) as w: + dammit = UnicodeDammit(doc) + self.assertEqual(True, dammit.contains_replacement_characters) + self.assertTrue(u"\ufffd" in dammit.unicode_markup) + + soup = BeautifulSoup(doc, "html.parser") + self.assertTrue(soup.contains_replacement_characters) + + msg = w[0].message + self.assertTrue(isinstance(msg, UnicodeWarning)) + self.assertTrue("Some characters could not be decoded" in str(msg)) + finally: + bs4.dammit.chardet = chardet + + def test_sniffed_xml_encoding(self): + # A document written in UTF-16LE will be converted by a different + # code path that sniffs the byte order markers. 
+ data = b'\xff\xfe<\x00a\x00>\x00\xe1\x00\xe9\x00<\x00/\x00a\x00>\x00' + dammit = UnicodeDammit(data) + self.assertEqual(u"áé", dammit.unicode_markup) + self.assertEqual("utf-16le", dammit.original_encoding) + + def test_detwingle(self): + # Here's a UTF8 document. + utf8 = (u"\N{SNOWMAN}" * 3).encode("utf8") + + # Here's a Windows-1252 document. + windows_1252 = ( + u"\N{LEFT DOUBLE QUOTATION MARK}Hi, I like Windows!" + u"\N{RIGHT DOUBLE QUOTATION MARK}").encode("windows_1252") + + # Through some unholy alchemy, they've been stuck together. + doc = utf8 + windows_1252 + utf8 + + # The document can't be turned into UTF-8: + self.assertRaises(UnicodeDecodeError, doc.decode, "utf8") + + # Unicode, Dammit thinks the whole document is Windows-1252, + # and decodes it into "☃☃☃“Hi, I like Windows!”☃☃☃" + + # But if we run it through fix_embedded_windows_1252, it's fixed: + + fixed = UnicodeDammit.detwingle(doc) + self.assertEqual( + u"☃☃☃“Hi, I like Windows!”☃☃☃", fixed.decode("utf8")) + + def test_detwingle_ignores_multibyte_characters(self): + # Each of these characters has a UTF-8 representation ending + # in \x93. \x93 is a smart quote if interpreted as + # Windows-1252. But our code knows to skip over multibyte + # UTF-8 characters, so they'll survive the process unscathed. + for tricky_unicode_char in ( + u"\N{LATIN SMALL LIGATURE OE}", # 2-byte char '\xc5\x93' + u"\N{LATIN SUBSCRIPT SMALL LETTER X}", # 3-byte char '\xe2\x82\x93' + u"\xf0\x90\x90\x93", # This is a CJK character, not sure which one. 
+ ): + input = tricky_unicode_char.encode("utf8") + self.assertTrue(input.endswith(b'\x93')) + output = UnicodeDammit.detwingle(input) + self.assertEqual(output, input) + +class TestNamedspacedAttribute(SoupTest): + + def test_name_may_be_none(self): + a = NamespacedAttribute("xmlns", None) + self.assertEqual(a, "xmlns") + + def test_attribute_is_equivalent_to_colon_separated_string(self): + a = NamespacedAttribute("a", "b") + self.assertEqual("a:b", a) + + def test_attributes_are_equivalent_if_prefix_and_name_identical(self): + a = NamespacedAttribute("a", "b", "c") + b = NamespacedAttribute("a", "b", "c") + self.assertEqual(a, b) + + # The actual namespace is not considered. + c = NamespacedAttribute("a", "b", None) + self.assertEqual(a, c) + + # But name and prefix are important. + d = NamespacedAttribute("a", "z", "c") + self.assertNotEqual(a, d) + + e = NamespacedAttribute("z", "b", "c") + self.assertNotEqual(a, e) + + +class TestAttributeValueWithCharsetSubstitution(unittest.TestCase): + + def test_content_meta_attribute_value(self): + value = CharsetMetaAttributeValue("euc-jp") + self.assertEqual("euc-jp", value) + self.assertEqual("euc-jp", value.original_value) + self.assertEqual("utf8", value.encode("utf8")) + + + def test_content_meta_attribute_value(self): + value = ContentMetaAttributeValue("text/html; charset=euc-jp") + self.assertEqual("text/html; charset=euc-jp", value) + self.assertEqual("text/html; charset=euc-jp", value.original_value) + self.assertEqual("text/html; charset=utf8", value.encode("utf8")) diff --git a/bs4/tests/test_tree.py b/bs4/tests/test_tree.py new file mode 100644 index 0000000..cc573ed --- /dev/null +++ b/bs4/tests/test_tree.py @@ -0,0 +1,1695 @@ +# -*- coding: utf-8 -*- +"""Tests for Beautiful Soup's tree traversal methods. + +The tree traversal methods are the main advantage of using Beautiful +Soup over just using a parser. 
+ +Different parsers will build different Beautiful Soup trees given the +same markup, but all Beautiful Soup trees can be traversed with the +methods tested here. +""" + +import copy +import pickle +import re +import warnings +from bs4 import BeautifulSoup +from bs4.builder import ( + builder_registry, + HTMLParserTreeBuilder, +) +from bs4.element import ( + CData, + Doctype, + NavigableString, + SoupStrainer, + Tag, +) +from bs4.testing import ( + SoupTest, + skipIf, +) + +XML_BUILDER_PRESENT = (builder_registry.lookup("xml") is not None) +LXML_PRESENT = (builder_registry.lookup("lxml") is not None) + +class TreeTest(SoupTest): + + def assertSelects(self, tags, should_match): + """Make sure that the given tags have the correct text. + + This is used in tests that define a bunch of tags, each + containing a single string, and then select certain strings by + some mechanism. + """ + self.assertEqual([tag.string for tag in tags], should_match) + + def assertSelectsIDs(self, tags, should_match): + """Make sure that the given tags have the correct IDs. + + This is used in tests that define a bunch of tags, each + containing a single string, and then select certain strings by + some mechanism. + """ + self.assertEqual([tag['id'] for tag in tags], should_match) + + +class TestFind(TreeTest): + """Basic tests of the find() method. + + find() just calls find_all() with limit=1, so it's not tested all + that thouroughly here. + """ + + def test_find_tag(self): + soup = self.soup("1234") + self.assertEqual(soup.find("b").string, "2") + + def test_unicode_text_find(self): + soup = self.soup(u'

      Räksmörgås

      ') + self.assertEqual(soup.find(text=u'Räksmörgås'), u'Räksmörgås') + +class TestFindAll(TreeTest): + """Basic tests of the find_all() method.""" + + def test_find_all_text_nodes(self): + """You can search the tree for text nodes.""" + soup = self.soup("Foobar\xbb") + # Exact match. + self.assertEqual(soup.find_all(text="bar"), [u"bar"]) + # Match any of a number of strings. + self.assertEqual( + soup.find_all(text=["Foo", "bar"]), [u"Foo", u"bar"]) + # Match a regular expression. + self.assertEqual(soup.find_all(text=re.compile('.*')), + [u"Foo", u"bar", u'\xbb']) + # Match anything. + self.assertEqual(soup.find_all(text=True), + [u"Foo", u"bar", u'\xbb']) + + def test_find_all_limit(self): + """You can limit the number of items returned by find_all.""" + soup = self.soup("12345") + self.assertSelects(soup.find_all('a', limit=3), ["1", "2", "3"]) + self.assertSelects(soup.find_all('a', limit=1), ["1"]) + self.assertSelects( + soup.find_all('a', limit=10), ["1", "2", "3", "4", "5"]) + + # A limit of 0 means no limit. + self.assertSelects( + soup.find_all('a', limit=0), ["1", "2", "3", "4", "5"]) + + def test_calling_a_tag_is_calling_findall(self): + soup = self.soup("123") + self.assertSelects(soup('a', limit=1), ["1"]) + self.assertSelects(soup.b(id="foo"), ["3"]) + + def test_find_all_with_self_referential_data_structure_does_not_cause_infinite_recursion(self): + soup = self.soup("") + # Create a self-referential list. + l = [] + l.append(l) + + # Without special code in _normalize_search_value, this would cause infinite + # recursion. 
+ self.assertEqual([], soup.find_all(l)) + +class TestFindAllBasicNamespaces(TreeTest): + + def test_find_by_namespaced_name(self): + soup = self.soup('4') + self.assertEqual("4", soup.find("mathml:msqrt").string) + self.assertEqual("a", soup.find(attrs= { "svg:fill" : "red" }).name) + + +class TestFindAllByName(TreeTest): + """Test ways of finding tags by tag name.""" + + def setUp(self): + super(TreeTest, self).setUp() + self.tree = self.soup("""First tag. + Second tag. + Third Nested tag. tag.""") + + def test_find_all_by_tag_name(self): + # Find all the tags. + self.assertSelects( + self.tree.find_all('a'), ['First tag.', 'Nested tag.']) + + def test_find_all_by_name_and_text(self): + self.assertSelects( + self.tree.find_all('a', text='First tag.'), ['First tag.']) + + self.assertSelects( + self.tree.find_all('a', text=True), ['First tag.', 'Nested tag.']) + + self.assertSelects( + self.tree.find_all('a', text=re.compile("tag")), + ['First tag.', 'Nested tag.']) + + + def test_find_all_on_non_root_element(self): + # You can call find_all on any node, not just the root. 
+ self.assertSelects(self.tree.c.find_all('a'), ['Nested tag.']) + + def test_calling_element_invokes_find_all(self): + self.assertSelects(self.tree('a'), ['First tag.', 'Nested tag.']) + + def test_find_all_by_tag_strainer(self): + self.assertSelects( + self.tree.find_all(SoupStrainer('a')), + ['First tag.', 'Nested tag.']) + + def test_find_all_by_tag_names(self): + self.assertSelects( + self.tree.find_all(['a', 'b']), + ['First tag.', 'Second tag.', 'Nested tag.']) + + def test_find_all_by_tag_dict(self): + self.assertSelects( + self.tree.find_all({'a' : True, 'b' : True}), + ['First tag.', 'Second tag.', 'Nested tag.']) + + def test_find_all_by_tag_re(self): + self.assertSelects( + self.tree.find_all(re.compile('^[ab]$')), + ['First tag.', 'Second tag.', 'Nested tag.']) + + def test_find_all_with_tags_matching_method(self): + # You can define an oracle method that determines whether + # a tag matches the search. + def id_matches_name(tag): + return tag.name == tag.get('id') + + tree = self.soup("""Match 1. + Does not match. + Match 2.""") + + self.assertSelects( + tree.find_all(id_matches_name), ["Match 1.", "Match 2."]) + + +class TestFindAllByAttribute(TreeTest): + + def test_find_all_by_attribute_name(self): + # You can pass in keyword arguments to find_all to search by + # attribute. + tree = self.soup(""" + Matching a. + + Non-matching Matching b.a. + """) + self.assertSelects(tree.find_all(id='first'), + ["Matching a.", "Matching b."]) + + def test_find_all_by_utf8_attribute_value(self): + peace = u"םולש".encode("utf8") + data = u''.encode("utf8") + soup = self.soup(data) + self.assertEqual([soup.a], soup.find_all(title=peace)) + self.assertEqual([soup.a], soup.find_all(title=peace.decode("utf8"))) + self.assertEqual([soup.a], soup.find_all(title=[peace, "something else"])) + + def test_find_all_by_attribute_dict(self): + # You can pass in a dictionary as the argument 'attrs'. 
This + # lets you search for attributes like 'name' (a fixed argument + # to find_all) and 'class' (a reserved word in Python.) + tree = self.soup(""" + Name match. + Class match. + Non-match. + A tag called 'name1'. + """) + + # This doesn't do what you want. + self.assertSelects(tree.find_all(name='name1'), + ["A tag called 'name1'."]) + # This does what you want. + self.assertSelects(tree.find_all(attrs={'name' : 'name1'}), + ["Name match."]) + + # Passing class='class2' would cause a syntax error. + self.assertSelects(tree.find_all(attrs={'class' : 'class2'}), + ["Class match."]) + + def test_find_all_by_class(self): + # Passing in a string to 'attrs' will search the CSS class. + tree = self.soup(""" + Class 1. + Class 2. + Class 1. + Class 3 and 4. + """) + self.assertSelects(tree.find_all('a', '1'), ['Class 1.']) + self.assertSelects(tree.find_all(attrs='1'), ['Class 1.', 'Class 1.']) + self.assertSelects(tree.find_all('c', '3'), ['Class 3 and 4.']) + self.assertSelects(tree.find_all('c', '4'), ['Class 3 and 4.']) + + def test_find_by_class_when_multiple_classes_present(self): + tree = self.soup("Found it") + + attrs = { 'class' : re.compile("o") } + f = tree.find_all("gar", attrs=attrs) + self.assertSelects(f, ["Found it"]) + + f = tree.find_all("gar", re.compile("a")) + self.assertSelects(f, ["Found it"]) + + # Since the class is not the string "foo bar", but the two + # strings "foo" and "bar", this will not find anything. 
+ attrs = { 'class' : re.compile("o b") } + f = tree.find_all("gar", attrs=attrs) + self.assertSelects(f, []) + + def test_find_all_with_non_dictionary_for_attrs_finds_by_class(self): + soup = self.soup("Found it") + + self.assertSelects(soup.find_all("a", re.compile("ba")), ["Found it"]) + + def big_attribute_value(value): + return len(value) > 3 + + self.assertSelects(soup.find_all("a", big_attribute_value), []) + + def small_attribute_value(value): + return len(value) <= 3 + + self.assertSelects( + soup.find_all("a", small_attribute_value), ["Found it"]) + + def test_find_all_with_string_for_attrs_finds_multiple_classes(self): + soup = self.soup('') + a, a2 = soup.find_all("a") + self.assertEqual([a, a2], soup.find_all("a", "foo")) + self.assertEqual([a], soup.find_all("a", "bar")) + + # If you specify the attribute as a string that contains a + # space, only that specific value will be found. + self.assertEqual([a], soup.find_all("a", "foo bar")) + self.assertEqual([], soup.find_all("a", "bar foo")) + + def test_find_all_by_attribute_soupstrainer(self): + tree = self.soup(""" + Match. + Non-match.""") + + strainer = SoupStrainer(attrs={'id' : 'first'}) + self.assertSelects(tree.find_all(strainer), ['Match.']) + + def test_find_all_with_missing_atribute(self): + # You can pass in None as the value of an attribute to find_all. + # This will match tags that do not have that attribute set. + tree = self.soup("""ID present. + No ID present. + ID is empty.""") + self.assertSelects(tree.find_all('a', id=None), ["No ID present."]) + + def test_find_all_with_defined_attribute(self): + # You can pass in None as the value of an attribute to find_all. + # This will match tags that have that attribute set to any value. + tree = self.soup("""ID present. + No ID present. 
+ ID is empty.""") + self.assertSelects( + tree.find_all(id=True), ["ID present.", "ID is empty."]) + + def test_find_all_with_numeric_attribute(self): + # If you search for a number, it's treated as a string. + tree = self.soup("""Unquoted attribute. + Quoted attribute.""") + + expected = ["Unquoted attribute.", "Quoted attribute."] + self.assertSelects(tree.find_all(id=1), expected) + self.assertSelects(tree.find_all(id="1"), expected) + + def test_find_all_with_list_attribute_values(self): + # You can pass a list of attribute values instead of just one, + # and you'll get tags that match any of the values. + tree = self.soup("""1 + 2 + 3 + No ID.""") + self.assertSelects(tree.find_all(id=["1", "3", "4"]), + ["1", "3"]) + + def test_find_all_with_regular_expression_attribute_value(self): + # You can pass a regular expression as an attribute value, and + # you'll get tags whose values for that attribute match the + # regular expression. + tree = self.soup("""One a. + Two as. + Mixed as and bs. + One b. + No ID.""") + + self.assertSelects(tree.find_all(id=re.compile("^a+$")), + ["One a.", "Two as."]) + + def test_find_by_name_and_containing_string(self): + soup = self.soup("foobarfoo") + a = soup.a + + self.assertEqual([a], soup.find_all("a", text="foo")) + self.assertEqual([], soup.find_all("a", text="bar")) + self.assertEqual([], soup.find_all("a", text="bar")) + + def test_find_by_name_and_containing_string_when_string_is_buried(self): + soup = self.soup("foofoo") + self.assertEqual(soup.find_all("a"), soup.find_all("a", text="foo")) + + def test_find_by_attribute_and_containing_string(self): + soup = self.soup('foofoo') + a = soup.a + + self.assertEqual([a], soup.find_all(id=2, text="foo")) + self.assertEqual([], soup.find_all(id=1, text="bar")) + + + + +class TestIndex(TreeTest): + """Test Tag.index""" + def test_index(self): + tree = self.soup("""
      + Identical + Not identical + Identical + + Identical with child + Also not identical + Identical with child +
      """) + div = tree.div + for i, element in enumerate(div.contents): + self.assertEqual(i, div.index(element)) + self.assertRaises(ValueError, tree.index, 1) + + +class TestParentOperations(TreeTest): + """Test navigation and searching through an element's parents.""" + + def setUp(self): + super(TestParentOperations, self).setUp() + self.tree = self.soup('''
        +
          +
            +
              + Start here +
            +
          ''') + self.start = self.tree.b + + + def test_parent(self): + self.assertEqual(self.start.parent['id'], 'bottom') + self.assertEqual(self.start.parent.parent['id'], 'middle') + self.assertEqual(self.start.parent.parent.parent['id'], 'top') + + def test_parent_of_top_tag_is_soup_object(self): + top_tag = self.tree.contents[0] + self.assertEqual(top_tag.parent, self.tree) + + def test_soup_object_has_no_parent(self): + self.assertEqual(None, self.tree.parent) + + def test_find_parents(self): + self.assertSelectsIDs( + self.start.find_parents('ul'), ['bottom', 'middle', 'top']) + self.assertSelectsIDs( + self.start.find_parents('ul', id="middle"), ['middle']) + + def test_find_parent(self): + self.assertEqual(self.start.find_parent('ul')['id'], 'bottom') + + def test_parent_of_text_element(self): + text = self.tree.find(text="Start here") + self.assertEqual(text.parent.name, 'b') + + def test_text_element_find_parent(self): + text = self.tree.find(text="Start here") + self.assertEqual(text.find_parent('ul')['id'], 'bottom') + + def test_parent_generator(self): + parents = [parent['id'] for parent in self.start.parents + if parent is not None and 'id' in parent.attrs] + self.assertEqual(parents, ['bottom', 'middle', 'top']) + + +class ProximityTest(TreeTest): + + def setUp(self): + super(TreeTest, self).setUp() + self.tree = self.soup( + 'OneTwoThree') + + +class TestNextOperations(ProximityTest): + + def setUp(self): + super(TestNextOperations, self).setUp() + self.start = self.tree.b + + def test_next(self): + self.assertEqual(self.start.next_element, "One") + self.assertEqual(self.start.next_element.next_element['id'], "2") + + def test_next_of_last_item_is_none(self): + last = self.tree.find(text="Three") + self.assertEqual(last.next_element, None) + + def test_next_of_root_is_none(self): + # The document root is outside the next/previous chain. 
+ self.assertEqual(self.tree.next_element, None) + + def test_find_all_next(self): + self.assertSelects(self.start.find_all_next('b'), ["Two", "Three"]) + self.start.find_all_next(id=3) + self.assertSelects(self.start.find_all_next(id=3), ["Three"]) + + def test_find_next(self): + self.assertEqual(self.start.find_next('b')['id'], '2') + self.assertEqual(self.start.find_next(text="Three"), "Three") + + def test_find_next_for_text_element(self): + text = self.tree.find(text="One") + self.assertEqual(text.find_next("b").string, "Two") + self.assertSelects(text.find_all_next("b"), ["Two", "Three"]) + + def test_next_generator(self): + start = self.tree.find(text="Two") + successors = [node for node in start.next_elements] + # There are two successors: the final tag and its text contents. + tag, contents = successors + self.assertEqual(tag['id'], '3') + self.assertEqual(contents, "Three") + +class TestPreviousOperations(ProximityTest): + + def setUp(self): + super(TestPreviousOperations, self).setUp() + self.end = self.tree.find(text="Three") + + def test_previous(self): + self.assertEqual(self.end.previous_element['id'], "3") + self.assertEqual(self.end.previous_element.previous_element, "Two") + + def test_previous_of_first_item_is_none(self): + first = self.tree.find('html') + self.assertEqual(first.previous_element, None) + + def test_previous_of_root_is_none(self): + # The document root is outside the next/previous chain. + # XXX This is broken! + #self.assertEqual(self.tree.previous_element, None) + pass + + def test_find_all_previous(self): + # The tag containing the "Three" node is the predecessor + # of the "Three" node itself, which is why "Three" shows up + # here. 
+ self.assertSelects( + self.end.find_all_previous('b'), ["Three", "Two", "One"]) + self.assertSelects(self.end.find_all_previous(id=1), ["One"]) + + def test_find_previous(self): + self.assertEqual(self.end.find_previous('b')['id'], '3') + self.assertEqual(self.end.find_previous(text="One"), "One") + + def test_find_previous_for_text_element(self): + text = self.tree.find(text="Three") + self.assertEqual(text.find_previous("b").string, "Three") + self.assertSelects( + text.find_all_previous("b"), ["Three", "Two", "One"]) + + def test_previous_generator(self): + start = self.tree.find(text="One") + predecessors = [node for node in start.previous_elements] + + # There are four predecessors: the tag containing "One" + # the tag, the tag, and the tag. + b, body, head, html = predecessors + self.assertEqual(b['id'], '1') + self.assertEqual(body.name, "body") + self.assertEqual(head.name, "head") + self.assertEqual(html.name, "html") + + +class SiblingTest(TreeTest): + + def setUp(self): + super(SiblingTest, self).setUp() + markup = ''' + + + + + + + + + + + ''' + # All that whitespace looks good but makes the tests more + # difficult. Get rid of it. + markup = re.compile("\n\s*").sub("", markup) + self.tree = self.soup(markup) + + +class TestNextSibling(SiblingTest): + + def setUp(self): + super(TestNextSibling, self).setUp() + self.start = self.tree.find(id="1") + + def test_next_sibling_of_root_is_none(self): + self.assertEqual(self.tree.next_sibling, None) + + def test_next_sibling(self): + self.assertEqual(self.start.next_sibling['id'], '2') + self.assertEqual(self.start.next_sibling.next_sibling['id'], '3') + + # Note the difference between next_sibling and next_element. 
+ self.assertEqual(self.start.next_element['id'], '1.1') + + def test_next_sibling_may_not_exist(self): + self.assertEqual(self.tree.html.next_sibling, None) + + nested_span = self.tree.find(id="1.1") + self.assertEqual(nested_span.next_sibling, None) + + last_span = self.tree.find(id="4") + self.assertEqual(last_span.next_sibling, None) + + def test_find_next_sibling(self): + self.assertEqual(self.start.find_next_sibling('span')['id'], '2') + + def test_next_siblings(self): + self.assertSelectsIDs(self.start.find_next_siblings("span"), + ['2', '3', '4']) + + self.assertSelectsIDs(self.start.find_next_siblings(id='3'), ['3']) + + def test_next_sibling_for_text_element(self): + soup = self.soup("Foobarbaz") + start = soup.find(text="Foo") + self.assertEqual(start.next_sibling.name, 'b') + self.assertEqual(start.next_sibling.next_sibling, 'baz') + + self.assertSelects(start.find_next_siblings('b'), ['bar']) + self.assertEqual(start.find_next_sibling(text="baz"), "baz") + self.assertEqual(start.find_next_sibling(text="nonesuch"), None) + + +class TestPreviousSibling(SiblingTest): + + def setUp(self): + super(TestPreviousSibling, self).setUp() + self.end = self.tree.find(id="4") + + def test_previous_sibling_of_root_is_none(self): + self.assertEqual(self.tree.previous_sibling, None) + + def test_previous_sibling(self): + self.assertEqual(self.end.previous_sibling['id'], '3') + self.assertEqual(self.end.previous_sibling.previous_sibling['id'], '2') + + # Note the difference between previous_sibling and previous_element. 
+ self.assertEqual(self.end.previous_element['id'], '3.1') + + def test_previous_sibling_may_not_exist(self): + self.assertEqual(self.tree.html.previous_sibling, None) + + nested_span = self.tree.find(id="1.1") + self.assertEqual(nested_span.previous_sibling, None) + + first_span = self.tree.find(id="1") + self.assertEqual(first_span.previous_sibling, None) + + def test_find_previous_sibling(self): + self.assertEqual(self.end.find_previous_sibling('span')['id'], '3') + + def test_previous_siblings(self): + self.assertSelectsIDs(self.end.find_previous_siblings("span"), + ['3', '2', '1']) + + self.assertSelectsIDs(self.end.find_previous_siblings(id='1'), ['1']) + + def test_previous_sibling_for_text_element(self): + soup = self.soup("Foobarbaz") + start = soup.find(text="baz") + self.assertEqual(start.previous_sibling.name, 'b') + self.assertEqual(start.previous_sibling.previous_sibling, 'Foo') + + self.assertSelects(start.find_previous_siblings('b'), ['bar']) + self.assertEqual(start.find_previous_sibling(text="Foo"), "Foo") + self.assertEqual(start.find_previous_sibling(text="nonesuch"), None) + + +class TestTagCreation(SoupTest): + """Test the ability to create new tags.""" + def test_new_tag(self): + soup = self.soup("") + new_tag = soup.new_tag("foo", bar="baz") + self.assertTrue(isinstance(new_tag, Tag)) + self.assertEqual("foo", new_tag.name) + self.assertEqual(dict(bar="baz"), new_tag.attrs) + self.assertEqual(None, new_tag.parent) + + def test_tag_inherits_self_closing_rules_from_builder(self): + if XML_BUILDER_PRESENT: + xml_soup = BeautifulSoup("", "xml") + xml_br = xml_soup.new_tag("br") + xml_p = xml_soup.new_tag("p") + + # Both the
          and

          tag are empty-element, just because + # they have no contents. + self.assertEqual(b"
          ", xml_br.encode()) + self.assertEqual(b"

          ", xml_p.encode()) + + html_soup = BeautifulSoup("", "html") + html_br = html_soup.new_tag("br") + html_p = html_soup.new_tag("p") + + # The HTML builder users HTML's rules about which tags are + # empty-element tags, and the new tags reflect these rules. + self.assertEqual(b"
          ", html_br.encode()) + self.assertEqual(b"

          ", html_p.encode()) + + def test_new_string_creates_navigablestring(self): + soup = self.soup("") + s = soup.new_string("foo") + self.assertEqual("foo", s) + self.assertTrue(isinstance(s, NavigableString)) + +class TestTreeModification(SoupTest): + + def test_attribute_modification(self): + soup = self.soup('') + soup.a['id'] = 2 + self.assertEqual(soup.decode(), self.document_for('')) + del(soup.a['id']) + self.assertEqual(soup.decode(), self.document_for('')) + soup.a['id2'] = 'foo' + self.assertEqual(soup.decode(), self.document_for('')) + + def test_new_tag_creation(self): + builder = builder_registry.lookup('html')() + soup = self.soup("", builder=builder) + a = Tag(soup, builder, 'a') + ol = Tag(soup, builder, 'ol') + a['href'] = 'http://foo.com/' + soup.body.insert(0, a) + soup.body.insert(1, ol) + self.assertEqual( + soup.body.encode(), + b'
            ') + + def test_append_to_contents_moves_tag(self): + doc = """

            Don't leave me here.

            +

            Don\'t leave!

            """ + soup = self.soup(doc) + second_para = soup.find(id='2') + bold = soup.b + + # Move the tag to the end of the second paragraph. + soup.find(id='2').append(soup.b) + + # The tag is now a child of the second paragraph. + self.assertEqual(bold.parent, second_para) + + self.assertEqual( + soup.decode(), self.document_for( + '

            Don\'t leave me .

            \n' + '

            Don\'t leave!here

            ')) + + def test_replace_with_returns_thing_that_was_replaced(self): + text = "" + soup = self.soup(text) + a = soup.a + new_a = a.replace_with(soup.c) + self.assertEqual(a, new_a) + + def test_unwrap_returns_thing_that_was_replaced(self): + text = "" + soup = self.soup(text) + a = soup.a + new_a = a.unwrap() + self.assertEqual(a, new_a) + + def test_replace_tag_with_itself(self): + text = "Foo" + soup = self.soup(text) + c = soup.c + soup.c.replace_with(c) + self.assertEqual(soup.decode(), self.document_for(text)) + + def test_replace_tag_with_its_parent_raises_exception(self): + text = "" + soup = self.soup(text) + self.assertRaises(ValueError, soup.b.replace_with, soup.a) + + def test_insert_tag_into_itself_raises_exception(self): + text = "" + soup = self.soup(text) + self.assertRaises(ValueError, soup.a.insert, 0, soup.a) + + def test_replace_with_maintains_next_element_throughout(self): + soup = self.soup('

            onethree

            ') + a = soup.a + b = a.contents[0] + # Make it so the tag has two text children. + a.insert(1, "two") + + # Now replace each one with the empty string. + left, right = a.contents + left.replaceWith('') + right.replaceWith('') + + # The tag is still connected to the tree. + self.assertEqual("three", soup.b.string) + + def test_replace_final_node(self): + soup = self.soup("Argh!") + soup.find(text="Argh!").replace_with("Hooray!") + new_text = soup.find(text="Hooray!") + b = soup.b + self.assertEqual(new_text.previous_element, b) + self.assertEqual(new_text.parent, b) + self.assertEqual(new_text.previous_element.next_element, new_text) + self.assertEqual(new_text.next_element, None) + + def test_consecutive_text_nodes(self): + # A builder should never create two consecutive text nodes, + # but if you insert one next to another, Beautiful Soup will + # handle it correctly. + soup = self.soup("Argh!") + soup.b.insert(1, "Hooray!") + + self.assertEqual( + soup.decode(), self.document_for( + "Argh!Hooray!")) + + new_text = soup.find(text="Hooray!") + self.assertEqual(new_text.previous_element, "Argh!") + self.assertEqual(new_text.previous_element.next_element, new_text) + + self.assertEqual(new_text.previous_sibling, "Argh!") + self.assertEqual(new_text.previous_sibling.next_sibling, new_text) + + self.assertEqual(new_text.next_sibling, None) + self.assertEqual(new_text.next_element, soup.c) + + def test_insert_string(self): + soup = self.soup("") + soup.a.insert(0, "bar") + soup.a.insert(0, "foo") + # The string were added to the tag. + self.assertEqual(["foo", "bar"], soup.a.contents) + # And they were converted to NavigableStrings. 
+ self.assertEqual(soup.a.contents[0].next_element, "bar") + + def test_insert_tag(self): + builder = self.default_builder + soup = self.soup( + "Findlady!", builder=builder) + magic_tag = Tag(soup, builder, 'magictag') + magic_tag.insert(0, "the") + soup.a.insert(1, magic_tag) + + self.assertEqual( + soup.decode(), self.document_for( + "Findthelady!")) + + # Make sure all the relationships are hooked up correctly. + b_tag = soup.b + self.assertEqual(b_tag.next_sibling, magic_tag) + self.assertEqual(magic_tag.previous_sibling, b_tag) + + find = b_tag.find(text="Find") + self.assertEqual(find.next_element, magic_tag) + self.assertEqual(magic_tag.previous_element, find) + + c_tag = soup.c + self.assertEqual(magic_tag.next_sibling, c_tag) + self.assertEqual(c_tag.previous_sibling, magic_tag) + + the = magic_tag.find(text="the") + self.assertEqual(the.parent, magic_tag) + self.assertEqual(the.next_element, c_tag) + self.assertEqual(c_tag.previous_element, the) + + def test_append_child_thats_already_at_the_end(self): + data = "" + soup = self.soup(data) + soup.a.append(soup.b) + self.assertEqual(data, soup.decode()) + + def test_move_tag_to_beginning_of_parent(self): + data = "" + soup = self.soup(data) + soup.a.insert(0, soup.d) + self.assertEqual("", soup.decode()) + + def test_insert_works_on_empty_element_tag(self): + # This is a little strange, since most HTML parsers don't allow + # markup like this to come through. But in general, we don't + # know what the parser would or wouldn't have allowed, so + # I'm letting this succeed for now. + soup = self.soup("
            ") + soup.br.insert(1, "Contents") + self.assertEqual(str(soup.br), "
            Contents
            ") + + def test_insert_before(self): + soup = self.soup("foobar") + soup.b.insert_before("BAZ") + soup.a.insert_before("QUUX") + self.assertEqual( + soup.decode(), self.document_for("QUUXfooBAZbar")) + + soup.a.insert_before(soup.b) + self.assertEqual( + soup.decode(), self.document_for("QUUXbarfooBAZ")) + + def test_insert_after(self): + soup = self.soup("foobar") + soup.b.insert_after("BAZ") + soup.a.insert_after("QUUX") + self.assertEqual( + soup.decode(), self.document_for("fooQUUXbarBAZ")) + soup.b.insert_after(soup.a) + self.assertEqual( + soup.decode(), self.document_for("QUUXbarfooBAZ")) + + def test_insert_after_raises_valueerror_if_after_has_no_meaning(self): + soup = self.soup("") + tag = soup.new_tag("a") + string = soup.new_string("") + self.assertRaises(ValueError, string.insert_after, tag) + self.assertRaises(ValueError, soup.insert_after, tag) + self.assertRaises(ValueError, tag.insert_after, tag) + + def test_insert_before_raises_valueerror_if_before_has_no_meaning(self): + soup = self.soup("") + tag = soup.new_tag("a") + string = soup.new_string("") + self.assertRaises(ValueError, string.insert_before, tag) + self.assertRaises(ValueError, soup.insert_before, tag) + self.assertRaises(ValueError, tag.insert_before, tag) + + def test_replace_with(self): + soup = self.soup( + "

            There's no business like show business

            ") + no, show = soup.find_all('b') + show.replace_with(no) + self.assertEqual( + soup.decode(), + self.document_for( + "

            There's business like no business

            ")) + + self.assertEqual(show.parent, None) + self.assertEqual(no.parent, soup.p) + self.assertEqual(no.next_element, "no") + self.assertEqual(no.next_sibling, " business") + + def test_replace_first_child(self): + data = "" + soup = self.soup(data) + soup.b.replace_with(soup.c) + self.assertEqual("", soup.decode()) + + def test_replace_last_child(self): + data = "" + soup = self.soup(data) + soup.c.replace_with(soup.b) + self.assertEqual("", soup.decode()) + + def test_nested_tag_replace_with(self): + soup = self.soup( + """Wereservetherighttorefuseservice""") + + # Replace the entire tag and its contents ("reserve the + # right") with the tag ("refuse"). + remove_tag = soup.b + move_tag = soup.f + remove_tag.replace_with(move_tag) + + self.assertEqual( + soup.decode(), self.document_for( + "Werefusetoservice")) + + # The tag is now an orphan. + self.assertEqual(remove_tag.parent, None) + self.assertEqual(remove_tag.find(text="right").next_element, None) + self.assertEqual(remove_tag.previous_element, None) + self.assertEqual(remove_tag.next_sibling, None) + self.assertEqual(remove_tag.previous_sibling, None) + + # The tag is now connected to the tag. + self.assertEqual(move_tag.parent, soup.a) + self.assertEqual(move_tag.previous_element, "We") + self.assertEqual(move_tag.next_element.next_element, soup.e) + self.assertEqual(move_tag.next_sibling, None) + + # The gap where the tag used to be has been mended, and + # the word "to" is now connected to the tag. + to_text = soup.find(text="to") + g_tag = soup.g + self.assertEqual(to_text.next_element, g_tag) + self.assertEqual(to_text.next_sibling, g_tag) + self.assertEqual(g_tag.previous_element, to_text) + self.assertEqual(g_tag.previous_sibling, to_text) + + def test_unwrap(self): + tree = self.soup(""" +

            Unneeded formatting is unneeded

            + """) + tree.em.unwrap() + self.assertEqual(tree.em, None) + self.assertEqual(tree.p.text, "Unneeded formatting is unneeded") + + def test_wrap(self): + soup = self.soup("I wish I was bold.") + value = soup.string.wrap(soup.new_tag("b")) + self.assertEqual(value.decode(), "I wish I was bold.") + self.assertEqual( + soup.decode(), self.document_for("I wish I was bold.")) + + def test_wrap_extracts_tag_from_elsewhere(self): + soup = self.soup("I wish I was bold.") + soup.b.next_sibling.wrap(soup.b) + self.assertEqual( + soup.decode(), self.document_for("I wish I was bold.")) + + def test_wrap_puts_new_contents_at_the_end(self): + soup = self.soup("I like being bold.I wish I was bold.") + soup.b.next_sibling.wrap(soup.b) + self.assertEqual(2, len(soup.b.contents)) + self.assertEqual( + soup.decode(), self.document_for( + "I like being bold.I wish I was bold.")) + + def test_extract(self): + soup = self.soup( + 'Some content. More content.') + + self.assertEqual(len(soup.body.contents), 3) + extracted = soup.find(id="nav").extract() + + self.assertEqual( + soup.decode(), "Some content. More content.") + self.assertEqual(extracted.decode(), '') + + # The extracted tag is now an orphan. + self.assertEqual(len(soup.body.contents), 2) + self.assertEqual(extracted.parent, None) + self.assertEqual(extracted.previous_element, None) + self.assertEqual(extracted.next_element.next_element, None) + + # The gap where the extracted tag used to be has been mended. + content_1 = soup.find(text="Some content. ") + content_2 = soup.find(text=" More content.") + self.assertEqual(content_1.next_element, content_2) + self.assertEqual(content_1.next_sibling, content_2) + self.assertEqual(content_2.previous_element, content_1) + self.assertEqual(content_2.previous_sibling, content_1) + + def test_extract_distinguishes_between_identical_strings(self): + soup = self.soup("
            foobar") + foo_1 = soup.a.string + bar_1 = soup.b.string + foo_2 = soup.new_string("foo") + bar_2 = soup.new_string("bar") + soup.a.append(foo_2) + soup.b.append(bar_2) + + # Now there are two identical strings in the tag, and two + # in the tag. Let's remove the first "foo" and the second + # "bar". + foo_1.extract() + bar_2.extract() + self.assertEqual(foo_2, soup.a.string) + self.assertEqual(bar_2, soup.b.string) + + def test_clear(self): + """Tag.clear()""" + soup = self.soup("

            String Italicized and another

            ") + # clear using extract() + a = soup.a + soup.p.clear() + self.assertEqual(len(soup.p.contents), 0) + self.assertTrue(hasattr(a, "contents")) + + # clear using decompose() + em = a.em + a.clear(decompose=True) + self.assertFalse(hasattr(em, "contents")) + + def test_string_set(self): + """Tag.string = 'string'""" + soup = self.soup(" ") + soup.a.string = "foo" + self.assertEqual(soup.a.contents, ["foo"]) + soup.b.string = "bar" + self.assertEqual(soup.b.contents, ["bar"]) + + def test_string_set_does_not_affect_original_string(self): + soup = self.soup("foobar") + soup.b.string = soup.c.string + self.assertEqual(soup.a.encode(), b"barbar") + + def test_set_string_preserves_class_of_string(self): + soup = self.soup("") + cdata = CData("foo") + soup.a.string = cdata + self.assertTrue(isinstance(soup.a.string, CData)) + +class TestElementObjects(SoupTest): + """Test various features of element objects.""" + + def test_len(self): + """The length of an element is its number of children.""" + soup = self.soup("123") + + # The BeautifulSoup object itself contains one element: the + # tag. + self.assertEqual(len(soup.contents), 1) + self.assertEqual(len(soup), 1) + + # The tag contains three elements: the text node "1", the + # tag, and the text node "3". + self.assertEqual(len(soup.top), 3) + self.assertEqual(len(soup.top.contents), 3) + + def test_member_access_invokes_find(self): + """Accessing a Python member .foo invokes find('foo')""" + soup = self.soup('') + self.assertEqual(soup.b, soup.find('b')) + self.assertEqual(soup.b.i, soup.find('b').find('i')) + self.assertEqual(soup.a, None) + + def test_deprecated_member_access(self): + soup = self.soup('') + with warnings.catch_warnings(record=True) as w: + tag = soup.bTag + self.assertEqual(soup.b, tag) + self.assertEqual( + '.bTag is deprecated, use .find("b") instead.', + str(w[0].message)) + + def test_has_attr(self): + """has_attr() checks for the presence of an attribute. 
+ + Please note note: has_attr() is different from + __in__. has_attr() checks the tag's attributes and __in__ + checks the tag's chidlren. + """ + soup = self.soup("") + self.assertTrue(soup.foo.has_attr('attr')) + self.assertFalse(soup.foo.has_attr('attr2')) + + + def test_attributes_come_out_in_alphabetical_order(self): + markup = '' + self.assertSoupEquals(markup, '') + + def test_string(self): + # A tag that contains only a text node makes that node + # available as .string. + soup = self.soup("foo") + self.assertEqual(soup.b.string, 'foo') + + def test_empty_tag_has_no_string(self): + # A tag with no children has no .stirng. + soup = self.soup("") + self.assertEqual(soup.b.string, None) + + def test_tag_with_multiple_children_has_no_string(self): + # A tag with no children has no .string. + soup = self.soup("foo") + self.assertEqual(soup.b.string, None) + + soup = self.soup("foobar
            ") + self.assertEqual(soup.b.string, None) + + # Even if all the children are strings, due to trickery, + # it won't work--but this would be a good optimization. + soup = self.soup("foo
            ") + soup.a.insert(1, "bar") + self.assertEqual(soup.a.string, None) + + def test_tag_with_recursive_string_has_string(self): + # A tag with a single child which has a .string inherits that + # .string. + soup = self.soup("foo") + self.assertEqual(soup.a.string, "foo") + self.assertEqual(soup.string, "foo") + + def test_lack_of_string(self): + """Only a tag containing a single text node has a .string.""" + soup = self.soup("feo") + self.assertFalse(soup.b.string) + + soup = self.soup("") + self.assertFalse(soup.b.string) + + def test_all_text(self): + """Tag.text and Tag.get_text(sep=u"") -> all child text, concatenated""" + soup = self.soup("ar t ") + self.assertEqual(soup.a.text, "ar t ") + self.assertEqual(soup.a.get_text(strip=True), "art") + self.assertEqual(soup.a.get_text(","), "a,r, , t ") + self.assertEqual(soup.a.get_text(",", strip=True), "a,r,t") + +class TestCDAtaListAttributes(SoupTest): + + """Testing cdata-list attributes like 'class'. + """ + def test_single_value_becomes_list(self): + soup = self.soup("") + self.assertEqual(["foo"],soup.a['class']) + + def test_multiple_values_becomes_list(self): + soup = self.soup("") + self.assertEqual(["foo", "bar"], soup.a['class']) + + def test_multiple_values_separated_by_weird_whitespace(self): + soup = self.soup("") + self.assertEqual(["foo", "bar", "baz"],soup.a['class']) + + def test_attributes_joined_into_string_on_output(self): + soup = self.soup("") + self.assertEqual(b'', soup.a.encode()) + + def test_accept_charset(self): + soup = self.soup('
            ') + self.assertEqual(['ISO-8859-1', 'UTF-8'], soup.form['accept-charset']) + + def test_cdata_attribute_applying_only_to_one_tag(self): + data = '' + soup = self.soup(data) + # We saw in another test that accept-charset is a cdata-list + # attribute for the tag. But it's not a cdata-list + # attribute for any other tag. + self.assertEqual('ISO-8859-1 UTF-8', soup.a['accept-charset']) + + +class TestPersistence(SoupTest): + "Testing features like pickle and deepcopy." + + def setUp(self): + super(TestPersistence, self).setUp() + self.page = """ + + + +Beautiful Soup: We called him Tortoise because he taught us. + + + + + + +foo +bar + +""" + self.tree = self.soup(self.page) + + def test_pickle_and_unpickle_identity(self): + # Pickling a tree, then unpickling it, yields a tree identical + # to the original. + dumped = pickle.dumps(self.tree, 2) + loaded = pickle.loads(dumped) + self.assertEqual(loaded.__class__, BeautifulSoup) + self.assertEqual(loaded.decode(), self.tree.decode()) + + def test_deepcopy_identity(self): + # Making a deepcopy of a tree yields an identical tree. + copied = copy.deepcopy(self.tree) + self.assertEqual(copied.decode(), self.tree.decode()) + + def test_unicode_pickle(self): + # A tree containing Unicode characters can be pickled. + html = u"\N{SNOWMAN}" + soup = self.soup(html) + dumped = pickle.dumps(soup, pickle.HIGHEST_PROTOCOL) + loaded = pickle.loads(dumped) + self.assertEqual(loaded.decode(), soup.decode()) + + +class TestSubstitutions(SoupTest): + + def test_default_formatter_is_minimal(self): + markup = u"<<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>>" + soup = self.soup(markup) + decoded = soup.decode(formatter="minimal") + # The < is converted back into < but the e-with-acute is left alone. 
+ self.assertEqual( + decoded, + self.document_for( + u"<<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>>")) + + def test_formatter_html(self): + markup = u"<<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>>" + soup = self.soup(markup) + decoded = soup.decode(formatter="html") + self.assertEqual( + decoded, + self.document_for("<<Sacré bleu!>>")) + + def test_formatter_minimal(self): + markup = u"<<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>>" + soup = self.soup(markup) + decoded = soup.decode(formatter="minimal") + # The < is converted back into < but the e-with-acute is left alone. + self.assertEqual( + decoded, + self.document_for( + u"<<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>>")) + + def test_formatter_null(self): + markup = u"<<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>>" + soup = self.soup(markup) + decoded = soup.decode(formatter=None) + # Neither the angle brackets nor the e-with-acute are converted. + # This is not valid HTML, but it's what the user wanted. + self.assertEqual(decoded, + self.document_for(u"<>")) + + def test_formatter_custom(self): + markup = u"<foo>bar" + soup = self.soup(markup) + decoded = soup.decode(formatter = lambda x: x.upper()) + # Instead of normal entity conversion code, the custom + # callable is called on every string. 
+ self.assertEqual( + decoded, + self.document_for(u"BAR")) + + def test_formatter_is_run_on_attribute_values(self): + markup = u'e' + soup = self.soup(markup) + a = soup.a + + expect_minimal = u'e' + + self.assertEqual(expect_minimal, a.decode()) + self.assertEqual(expect_minimal, a.decode(formatter="minimal")) + + expect_html = u'e' + self.assertEqual(expect_html, a.decode(formatter="html")) + + self.assertEqual(markup, a.decode(formatter=None)) + expect_upper = u'E' + self.assertEqual(expect_upper, a.decode(formatter=lambda x: x.upper())) + + def test_prettify_accepts_formatter(self): + soup = BeautifulSoup("foo") + pretty = soup.prettify(formatter = lambda x: x.upper()) + self.assertTrue("FOO" in pretty) + + def test_prettify_outputs_unicode_by_default(self): + soup = self.soup("") + self.assertEqual(unicode, type(soup.prettify())) + + def test_prettify_can_encode_data(self): + soup = self.soup("") + self.assertEqual(bytes, type(soup.prettify("utf-8"))) + + def test_html_entity_substitution_off_by_default(self): + markup = u"Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!" + soup = self.soup(markup) + encoded = soup.b.encode("utf-8") + self.assertEqual(encoded, markup.encode('utf-8')) + + def test_encoding_substitution(self): + # Here's the tag saying that a document is + # encoded in Shift-JIS. + meta_tag = ('') + soup = self.soup(meta_tag) + + # Parse the document, and the charset apprears unchanged. + self.assertEqual(soup.meta['content'], 'text/html; charset=x-sjis') + + # Encode the document into some encoding, and the encoding is + # substituted into the meta tag. 
+ utf_8 = soup.encode("utf-8") + self.assertTrue(b"charset=utf-8" in utf_8) + + euc_jp = soup.encode("euc_jp") + self.assertTrue(b"charset=euc_jp" in euc_jp) + + shift_jis = soup.encode("shift-jis") + self.assertTrue(b"charset=shift-jis" in shift_jis) + + utf_16_u = soup.encode("utf-16").decode("utf-16") + self.assertTrue("charset=utf-16" in utf_16_u) + + def test_encoding_substitution_doesnt_happen_if_tag_is_strained(self): + markup = ('
            foo
            ') + + # Beautiful Soup used to try to rewrite the meta tag even if the + # meta tag got filtered out by the strainer. This test makes + # sure that doesn't happen. + strainer = SoupStrainer('pre') + soup = self.soup(markup, parse_only=strainer) + self.assertEqual(soup.contents[0].name, 'pre') + +class TestEncoding(SoupTest): + """Test the ability to encode objects into strings.""" + + def test_unicode_string_can_be_encoded(self): + html = u"\N{SNOWMAN}" + soup = self.soup(html) + self.assertEqual(soup.b.string.encode("utf-8"), + u"\N{SNOWMAN}".encode("utf-8")) + + def test_tag_containing_unicode_string_can_be_encoded(self): + html = u"\N{SNOWMAN}" + soup = self.soup(html) + self.assertEqual( + soup.b.encode("utf-8"), html.encode("utf-8")) + + def test_encoding_substitutes_unrecognized_characters_by_default(self): + html = u"\N{SNOWMAN}" + soup = self.soup(html) + self.assertEqual(soup.b.encode("ascii"), b"") + + def test_encoding_can_be_made_strict(self): + html = u"\N{SNOWMAN}" + soup = self.soup(html) + self.assertRaises( + UnicodeEncodeError, soup.encode, "ascii", errors="strict") + + def test_decode_contents(self): + html = u"\N{SNOWMAN}" + soup = self.soup(html) + self.assertEqual(u"\N{SNOWMAN}", soup.b.decode_contents()) + + def test_encode_contents(self): + html = u"\N{SNOWMAN}" + soup = self.soup(html) + self.assertEqual( + u"\N{SNOWMAN}".encode("utf8"), soup.b.encode_contents( + encoding="utf8")) + + def test_deprecated_renderContents(self): + html = u"\N{SNOWMAN}" + soup = self.soup(html) + self.assertEqual( + u"\N{SNOWMAN}".encode("utf8"), soup.b.renderContents()) + +class TestNavigableStringSubclasses(SoupTest): + + def test_cdata(self): + # None of the current builders turn CDATA sections into CData + # objects, but you can create them manually. 
+ soup = self.soup("") + cdata = CData("foo") + soup.insert(1, cdata) + self.assertEqual(str(soup), "") + self.assertEqual(soup.find(text="foo"), "foo") + self.assertEqual(soup.contents[0], "foo") + + def test_cdata_is_never_formatted(self): + """Text inside a CData object is passed into the formatter. + + But the return value is ignored. + """ + + self.count = 0 + def increment(*args): + self.count += 1 + return "BITTER FAILURE" + + soup = self.soup("") + cdata = CData("<><><>") + soup.insert(1, cdata) + self.assertEqual( + b"<><>]]>", soup.encode(formatter=increment)) + self.assertEqual(1, self.count) + + def test_doctype_ends_in_newline(self): + # Unlike other NavigableString subclasses, a DOCTYPE always ends + # in a newline. + doctype = Doctype("foo") + soup = self.soup("") + soup.insert(1, doctype) + self.assertEqual(soup.encode(), b"\n") + + +class TestSoupSelector(TreeTest): + + HTML = """ + + + +The title + + + + +
            +
            +

            An H1

            +

            Some text

            +

            Some more text

            +

            An H2

            +

            Another

            +Bob +

            Another H2

            +me + +span1a1 +span1a2 test + +span2a1 + + + +
            +

            English

            +

            English UK

            +

            English US

            +

            French

            +
            + + +""" + + def setUp(self): + self.soup = BeautifulSoup(self.HTML) + + def assertSelects(self, selector, expected_ids): + el_ids = [el['id'] for el in self.soup.select(selector)] + el_ids.sort() + expected_ids.sort() + self.assertEqual(expected_ids, el_ids, + "Selector %s, expected [%s], got [%s]" % ( + selector, ', '.join(expected_ids), ', '.join(el_ids) + ) + ) + + assertSelect = assertSelects + + def assertSelectMultiple(self, *tests): + for selector, expected_ids in tests: + self.assertSelect(selector, expected_ids) + + def test_one_tag_one(self): + els = self.soup.select('title') + self.assertEqual(len(els), 1) + self.assertEqual(els[0].name, 'title') + self.assertEqual(els[0].contents, [u'The title']) + + def test_one_tag_many(self): + els = self.soup.select('div') + self.assertEqual(len(els), 3) + for div in els: + self.assertEqual(div.name, 'div') + + def test_tag_in_tag_one(self): + els = self.soup.select('div div') + self.assertSelects('div div', ['inner']) + + def test_tag_in_tag_many(self): + for selector in ('html div', 'html body div', 'body div'): + self.assertSelects(selector, ['main', 'inner', 'footer']) + + def test_tag_no_match(self): + self.assertEqual(len(self.soup.select('del')), 0) + + def test_invalid_tag(self): + self.assertEqual(len(self.soup.select('tag%t')), 0) + + def test_header_tags(self): + self.assertSelectMultiple( + ('h1', ['header1']), + ('h2', ['header2', 'header3']), + ) + + def test_class_one(self): + for selector in ('.onep', 'p.onep', 'html p.onep'): + els = self.soup.select(selector) + self.assertEqual(len(els), 1) + self.assertEqual(els[0].name, 'p') + self.assertEqual(els[0]['class'], ['onep']) + + def test_class_mismatched_tag(self): + els = self.soup.select('div.onep') + self.assertEqual(len(els), 0) + + def test_one_id(self): + for selector in ('div#inner', '#inner', 'div div#inner'): + self.assertSelects(selector, ['inner']) + + def test_bad_id(self): + els = self.soup.select('#doesnotexist') + 
self.assertEqual(len(els), 0) + + def test_items_in_id(self): + els = self.soup.select('div#inner p') + self.assertEqual(len(els), 3) + for el in els: + self.assertEqual(el.name, 'p') + self.assertEqual(els[1]['class'], ['onep']) + self.assertFalse(els[0].has_key('class')) + + def test_a_bunch_of_emptys(self): + for selector in ('div#main del', 'div#main div.oops', 'div div#main'): + self.assertEqual(len(self.soup.select(selector)), 0) + + def test_multi_class_support(self): + for selector in ('.class1', 'p.class1', '.class2', 'p.class2', + '.class3', 'p.class3', 'html p.class2', 'div#inner .class2'): + self.assertSelects(selector, ['pmulti']) + + def test_multi_class_selection(self): + for selector in ('.class1.class3', '.class3.class2', + '.class1.class2.class3'): + self.assertSelects(selector, ['pmulti']) + + def test_child_selector(self): + self.assertSelects('.s1 > a', ['s1a1', 's1a2']) + self.assertSelects('.s1 > a span', ['s1a2s1']) + + def test_attribute_equals(self): + self.assertSelectMultiple( + ('p[class="onep"]', ['p1']), + ('p[id="p1"]', ['p1']), + ('[class="onep"]', ['p1']), + ('[id="p1"]', ['p1']), + ('link[rel="stylesheet"]', ['l1']), + ('link[type="text/css"]', ['l1']), + ('link[href="blah.css"]', ['l1']), + ('link[href="no-blah.css"]', []), + ('[rel="stylesheet"]', ['l1']), + ('[type="text/css"]', ['l1']), + ('[href="blah.css"]', ['l1']), + ('[href="no-blah.css"]', []), + ('p[href="no-blah.css"]', []), + ('[href="no-blah.css"]', []), + ) + + def test_attribute_tilde(self): + self.assertSelectMultiple( + ('p[class~="class1"]', ['pmulti']), + ('p[class~="class2"]', ['pmulti']), + ('p[class~="class3"]', ['pmulti']), + ('[class~="class1"]', ['pmulti']), + ('[class~="class2"]', ['pmulti']), + ('[class~="class3"]', ['pmulti']), + ('a[rel~="friend"]', ['bob']), + ('a[rel~="met"]', ['bob']), + ('[rel~="friend"]', ['bob']), + ('[rel~="met"]', ['bob']), + ) + + def test_attribute_startswith(self): + self.assertSelectMultiple( + ('[rel^="style"]', ['l1']), 
+ ('link[rel^="style"]', ['l1']), + ('notlink[rel^="notstyle"]', []), + ('[rel^="notstyle"]', []), + ('link[rel^="notstyle"]', []), + ('link[href^="bla"]', ['l1']), + ('a[href^="http://"]', ['bob', 'me']), + ('[href^="http://"]', ['bob', 'me']), + ('[id^="p"]', ['pmulti', 'p1']), + ('[id^="m"]', ['me', 'main']), + ('div[id^="m"]', ['main']), + ('a[id^="m"]', ['me']), + ) + + def test_attribute_endswith(self): + self.assertSelectMultiple( + ('[href$=".css"]', ['l1']), + ('link[href$=".css"]', ['l1']), + ('link[id$="1"]', ['l1']), + ('[id$="1"]', ['l1', 'p1', 'header1', 's1a1', 's2a1', 's1a2s1']), + ('div[id$="1"]', []), + ('[id$="noending"]', []), + ) + + def test_attribute_contains(self): + self.assertSelectMultiple( + # From test_attribute_startswith + ('[rel*="style"]', ['l1']), + ('link[rel*="style"]', ['l1']), + ('notlink[rel*="notstyle"]', []), + ('[rel*="notstyle"]', []), + ('link[rel*="notstyle"]', []), + ('link[href*="bla"]', ['l1']), + ('a[href*="http://"]', ['bob', 'me']), + ('[href*="http://"]', ['bob', 'me']), + ('[id*="p"]', ['pmulti', 'p1']), + ('div[id*="m"]', ['main']), + ('a[id*="m"]', ['me']), + # From test_attribute_endswith + ('[href*=".css"]', ['l1']), + ('link[href*=".css"]', ['l1']), + ('link[id*="1"]', ['l1']), + ('[id*="1"]', ['l1', 'p1', 'header1', 's1a1', 's1a2', 's2a1', 's1a2s1']), + ('div[id*="1"]', []), + ('[id*="noending"]', []), + # New for this test + ('[href*="."]', ['bob', 'me', 'l1']), + ('a[href*="."]', ['bob', 'me']), + ('link[href*="."]', ['l1']), + ('div[id*="n"]', ['main', 'inner']), + ('div[id*="nn"]', ['inner']), + ) + + def test_attribute_exact_or_hypen(self): + self.assertSelectMultiple( + ('p[lang|="en"]', ['lang-en', 'lang-en-gb', 'lang-en-us']), + ('[lang|="en"]', ['lang-en', 'lang-en-gb', 'lang-en-us']), + ('p[lang|="fr"]', ['lang-fr']), + ('p[lang|="gb"]', []), + ) + + def test_attribute_exists(self): + self.assertSelectMultiple( + ('[rel]', ['l1', 'bob', 'me']), + ('link[rel]', ['l1']), + ('a[rel]', ['bob', 'me']), 
+ ('[lang]', ['lang-en', 'lang-en-gb', 'lang-en-us', 'lang-fr']), + ('p[class]', ['p1', 'pmulti']), + ('[blah]', []), + ('p[blah]', []), + ) + + def test_select_on_element(self): + # Other tests operate on the tree; this operates on an element + # within the tree. + inner = self.soup.find("div", id="main") + selected = inner.select("div") + # The
            tag was selected. The