From a878d7dc564a69fe31ea51d0b064151893a17989 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 26 Jul 2025 08:40:45 -0700 Subject: [PATCH 01/45] Working on extension with pybind11 --- src/tfc/utils/BF/BF.cxx | 8 ++- src/tfc/utils/BF/BF.h | 19 ++++--- src/tfc/utils/BF/BF_Py.cc | 93 +++++++++++++++++++++++++++++++++ src/tfc/utils/BF/CMakeLists.txt | 45 ++++++++++++++++ 4 files changed, 157 insertions(+), 8 deletions(-) create mode 100644 src/tfc/utils/BF/BF_Py.cc create mode 100644 src/tfc/utils/BF/CMakeLists.txt diff --git a/src/tfc/utils/BF/BF.cxx b/src/tfc/utils/BF/BF.cxx index e9b6f0a..ac16678 100644 --- a/src/tfc/utils/BF/BF.cxx +++ b/src/tfc/utils/BF/BF.cxx @@ -22,7 +22,7 @@ void xlaWrapper(void* out, void** in){ #endif // Parent basis function class: ********************************************************************** -BasisFunc::BasisFunc(double x0in, double xf, int* nCin, int ncDim0, int min, double z0in, double zf){ +BasisFunc::BasisFunc(double x0in, double xf, const int* nCin, int ncDim0, int min, double z0in, double zf){ // Initialize internal variables based on user givens nC = new int[ncDim0]; @@ -55,7 +55,7 @@ BasisFunc::~BasisFunc(){ delete[] nC; }; -void BasisFunc::H(double* x, int n, const int d, int* nOut, int* mOut, double** F, bool full){ +void BasisFunc::H(const double* x, int n, const int d, int* nOut, int* mOut, double** F, bool full){ *nOut = n; *mOut = full ? m : m-numC; @@ -1005,6 +1005,10 @@ void nBasisFunc::H(double* x, int in, int xDim1, int* d, int dDim0, int* nOut, i nHint(x,xDim1,d,dDim0,numBasis,*F,full); }; +void nBasisFunc::H(const double* x, int n, const int d, int* nOut, int* mOut, double** F, bool full) { + throw std::runtime_error("This version of \"H\" should never be called from an n-dimensional basis class."); +} + void nBasisFunc::xla(void* out, void** in){ double* out_buf = reinterpret_cast(out); double* x = reinterpret_cast(in[1]); diff --git a/src/tfc/utils/BF/BF.h b/src/tfc/utils/BF/BF.h index 92ce708..21a678e 100644 --- a/src/tfc/utils/BF/BF.h +++ b/src/tfc/utils/BF/BF.h @@ -62,7 +62,7 @@ class BasisFunc{ * - Stores variables based on user supplied givens * - Stores a pointer to itself using static variables * - Creates PyCapsule for xla function. */ - BasisFunc(double x0in, double xf, int* nCin, int ncDim0, int min, double z0in=0., double zf=DBL_MAX); + BasisFunc(double x0in, double xf, const int* nCin, int ncDim0, int min, double z0in=0., double zf=DBL_MAX); /** Dummy empty constructor allows derived classes without calling constructor explicitly. */ BasisFunc(){}; @@ -81,7 +81,7 @@ class BasisFunc{ * - If true, uses the x values given * - If false, uses the z values from the class * Note that this function is used to hook into Python, thus the extra arguments. */ - virtual void H(double* x, int n, const int d, int* nOut, int* mOut, double** F, bool full); + virtual void H(const double* x, int n, const int d, int* nOut, int* mOut, double** F, bool full); /** This function is an XLA version of the basis function. */ virtual void xla(void* out, void** in); @@ -125,7 +125,7 @@ typedef void(*xlaFnType)(void*,void**); class CP: virtual public BasisFunc { public: /** CP class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ - CP(double x0, double xf, int* nCin, int ncDim0, int min): + CP(double x0, double xf, const int* nCin, int ncDim0, int min): BasisFunc(x0,xf,nCin,ncDim0,min,-1.,1.){}; /** Dummy CP class constructor. Used only in n-dimensions. */ @@ -415,11 +415,18 @@ class nBasisFunc: virtual public BasisFunc{ /** n-D basis function class destructor. */ virtual ~nBasisFunc(); + /** + * Including override of BasisFunc so we don't have issues with hidden virtual overloads. + * However, this should never be called from nBasisFunc. + * If it is, it will throw an error. + */ + void H(const double* x, int n, const int d, int* nOut, int* mOut, double** F, bool full) override; + /** This function is used to create a basis function matrix and its derivatives. */ void H(double* x, int in, int xDim1, int* d, int dDim0, int* nOut, int* mOut, double** F, const bool full); /** This function is an XLA version of the basis function. */ - void xla(void* out, void** in); + void xla(void* out, void** in) override; /** Python hook to return domain mapping constants. */ void getC(double** arrOut, int* nOut); @@ -435,10 +442,10 @@ class nBasisFunc: virtual public BasisFunc{ virtual void nHint(double* x, int in, const int* d, int dDim0, int numBasis, double*& F, const bool full); /** Function used internally to create the basis function matrices. */ - virtual void Hint(const int d, const double* x, const int nOut, double* dark) = 0; + virtual void Hint(const int d, const double* x, const int nOut, double* dark) override = 0; /** Function used internally to create derivatives of the basis function matrices. */ - virtual void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut) = 0; + virtual void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut) override = 0; }; diff --git a/src/tfc/utils/BF/BF_Py.cc b/src/tfc/utils/BF/BF_Py.cc new file mode 100644 index 0000000..ad13f61 --- /dev/null +++ b/src/tfc/utils/BF/BF_Py.cc @@ -0,0 +1,93 @@ +#include +#include +#include "BF.h" + +namespace py = pybind11; + +template +void add1DInit(auto& c) { + c.def(py::init([](double x0, double xf, py::array_t nC, int min){ + return T(x0, xf, nC.data(), nC.size(), min); + }), + py::arg("x0"), + py::arg("xf"), + py::arg("nC"), + py::arg("min"), + R"( + BasisFunc constructor. + + Parameters: + x0: Start of domain + xf: End of domain + nC: Array of indices to remove (1D numpy array) + min: Number of basis functions to use + )" + ); +} + +PYBIND11_MODULE(BF, m) { + + py::class_(m, "BasisFunc") + .def_readwrite("z0", &BasisFunc::z0) + .def_readwrite("x0", &BasisFunc::x0) + .def_readwrite("c", &BasisFunc::c) + .def_readwrite("m", &BasisFunc::m) + .def_readwrite("numC", &BasisFunc::numC) + .def_readwrite("identifier", &BasisFunc::identifier) + .def_property_readonly("xlaCapsule", [](BasisFunc& self) { + py::object capsule = py::reinterpret_borrow(self.xlaCapsule); + return capsule; + }) + // GPU Capsule (only if available) + #ifdef HAS_CUDA + .def_property_readonly("xlaGpuCapsule", [](BasisFunc& self) { + return py::reinterpret_borrow(self.xlaGpuCapsule); + }) + #else + .def_property_readonly("xlaGpuCapsule", [](BasisFunc&) { + return "CUDA NOT FOUND, GPU NOT IMPLEMENTED."; + }) + #endif + // Static members + .def_readonly_static("nIdentifier", &BasisFunc::nIdentifier) + .def_readonly_static("BasisFuncContainer", &BasisFunc::BasisFuncContainer) + // Methods + .def("H", + [](BasisFunc& self, + py::array_t x, + int d, + bool full) { + if (x.ndim() != 1) { + throw py::value_error("The \"x\" input array must be 1-dimensional."); + } + int n = x.size(); + int nOut = 0; + int mOut = 0; + double* F = nullptr; + self.H(x.data(), n, d, &nOut, &mOut, &F, full); + + // Wrap data in a py::capsule to ensure it gets deleted + auto capsule = py::capsule(F, [](void* f) { + double* d = reinterpret_cast(f); + free(d); + }); + + return py::array_t({mOut, nOut}, F, capsule); + }, + py::arg("x"), py::arg("d"), py::arg("full"), + R"( + Compute basis function matrix. + + Parameters: + x: Points (1D numpy array) + d: Derivative order + full: Whether to return full matrix (not removing nC columns) + + Returns: + mOut x nOut NumPy array. + )" + ); + + auto PyCP = py::class_ (m, "CP"); + add1DInit(PyCP); +} diff --git a/src/tfc/utils/BF/CMakeLists.txt b/src/tfc/utils/BF/CMakeLists.txt new file mode 100644 index 0000000..e212925 --- /dev/null +++ b/src/tfc/utils/BF/CMakeLists.txt @@ -0,0 +1,45 @@ +cmake_minimum_required(VERSION 3.25) + +project(tfc) + +# TODO: Change for release +add_compile_options(-Wall -Werror) + +# Contorl whether we build with shared libraries or static libraries +option(BUILD_SHARED_LIBS "Build using shared libraries" ON) + +# If not building with shared libs, set POSITION_INDEPENDENT_CODE +# so that -fPIC gets used. This will +# allow linking the static libraries into the shared pybind11 libraries. +if(NOT BUILD_SHARED_LIBS) + set(CMAKE_POSITION_INDEPENDENT_CODE ON) +endif() + + +# Use C++ 20 +set(CMAKE_CXX_STANDARD 20) + +# Turn on generation of compile_commands.json for LSP +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Turn on colored diagnostics +set(CMAKE_COLOR_DIAGNOSTICS ON) + +# Find Python in the system +find_package(Python3 REQUIRED COMPONENTS Interpreter Development) + +# Find pybind11 in the system +# This is needed for CMake < 3.27. +# After Cmake 3.27+, can remove setting PYBIND11_FINDPYTHON. +set(PYBIND11_FINDPYTHON ON) +find_package(pybind11 3.0 REQUIRED CONFIG) + +# Create the bf library +add_library(bf BF.cxx) +target_link_libraries(bf PUBLIC Python3::Python) + +# Create the BF.py Python file +pybind11_add_module(BF BF_Py.cc) +target_link_libraries(BF PRIVATE bf) + +# TODO: Stubgen From 650c0d2850e90a8576d3d432f3b5d9e2ee048163 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 26 Jul 2025 09:15:20 -0700 Subject: [PATCH 02/45] Can now install the extension module and stubs. --- setup.py | 67 +++++++++++++++++++-------------- src/tfc/utils/BF/BF_Py.cc | 1 + src/tfc/utils/BF/CMakeLists.txt | 5 ++- utils/Makefile | 11 ++++++ 4 files changed, 54 insertions(+), 30 deletions(-) diff --git a/setup.py b/setup.py index 9ee3819..5f43977 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,15 @@ import sys -from os import path, name +import os +from pathlib import Path import numpy from setuptools import setup, Extension, find_packages from setuptools.command.build_py import build_py as _build_py +from setuptools.command.build_ext import build_ext +from subprocess import check_call # Get long description -this_directory = path.abspath(path.dirname(__file__)) -with open(path.join(this_directory, "README.md"), encoding="utf-8") as f: +this_directory = os.path.abspath(os.path.dirname(__file__)) +with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f: long_description = f.read() long_description = long_description.replace( '', @@ -14,12 +17,6 @@ 1, ) -# Get numpy directory -try: - numpy_include = numpy.get_include() -except AttributeError: - numpy_include = numpy.get_numpy_include() - # Get version info version_dict = {} with open("src/tfc/version.py") as f: @@ -27,7 +24,7 @@ version = version_dict["__version__"] # In the future, can add -DHAS_CUDA to this to enable GPU support -if name == "nt": +if os.name == "nt": # Windows compile flags cxxFlags = ["/O2", "/std:c++17", "/Wall", "/DWINDOWS_MSVC"] else: @@ -38,22 +35,38 @@ else: numpy_version = "numpy>=1.21.0" -# Create basis function c++ extension -BF = Extension( - "tfc.utils.BF._BF", - sources=["src/tfc/utils/BF/BF.i", "src/tfc/utils/BF/BF.cxx"], - include_dirs=["src/tfc/utils/BF", numpy_include], - swig_opts=["-c++", "-doxygen", "-O", "-olddefs"], - extra_compile_args=cxxFlags, - extra_link_args=cxxFlags, -) +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=""): + super().__init__(name, sources=[]) + self.sourcedir = str((Path(sourcedir) / "src" / "tfc" / "utils" / "BF").absolute()) + + +class CMakeBuild(build_ext): + def build_extension(self, ext): + extdir = Path(self.get_ext_fullpath(ext.name)).parents[0].absolute() + bf_dir = extdir / "tfc" / "utils" / "BF" + + cfg = "Debug" if self.debug else "Release" + cmake_args = [ + f"-DCMAKE_BUILD_TYPE={cfg}", + f"-DCMAKE_INSTALL_PREFIX={bf_dir}", + ] + + # Optional: use Ninja if available + if "CMAKE_GENERATOR" not in os.environ: + cmake_args += ["-G", "Ninja"] + + build_temp = Path(self.build_temp) + build_temp.mkdir(parents=True, exist_ok=True) + + # Run CMake configuration + check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp) + # Run CMake build + check_call(["cmake", "--build", ".", "--config", cfg], cwd=build_temp) -# Custom build options to include swig Python files -class build_py(_build_py): - def run(self): - self.run_command("build_ext") - super(build_py, self).run() + # Run CMake install + check_call(["cmake", "--install", "."], cwd=build_temp) # Setup @@ -72,7 +85,8 @@ def run(self): package_data={"": ["src/tfc/py.typed"]}, python_requires=">=3.10", include_package_data=True, - ext_modules=[BF], + ext_modules=[CMakeExtension("BF")], + cmdclass={"build_ext": CMakeBuild}, install_requires=[ numpy_version, "jax ~= 0.6.0", @@ -94,7 +108,4 @@ def run(self): "Topic :: Scientific/Engineering", "Topic :: Education", ], - cmdclass={ - "build_py": build_py, - }, ) diff --git a/src/tfc/utils/BF/BF_Py.cc b/src/tfc/utils/BF/BF_Py.cc index ad13f61..1a8be41 100644 --- a/src/tfc/utils/BF/BF_Py.cc +++ b/src/tfc/utils/BF/BF_Py.cc @@ -1,5 +1,6 @@ #include #include +#include #include "BF.h" namespace py = pybind11; diff --git a/src/tfc/utils/BF/CMakeLists.txt b/src/tfc/utils/BF/CMakeLists.txt index e212925..72ab0f0 100644 --- a/src/tfc/utils/BF/CMakeLists.txt +++ b/src/tfc/utils/BF/CMakeLists.txt @@ -6,7 +6,7 @@ project(tfc) add_compile_options(-Wall -Werror) # Contorl whether we build with shared libraries or static libraries -option(BUILD_SHARED_LIBS "Build using shared libraries" ON) +option(BUILD_SHARED_LIBS "Build using shared libraries" OFF) # If not building with shared libs, set POSITION_INDEPENDENT_CODE # so that -fPIC gets used. This will @@ -42,4 +42,5 @@ target_link_libraries(bf PUBLIC Python3::Python) pybind11_add_module(BF BF_Py.cc) target_link_libraries(BF PRIVATE bf) -# TODO: Stubgen +install(TARGETS bf BF DESTINATION .) +install(CODE [=[execute_process(COMMAND stubgen -m BF -o . --include-docstring WORKING_DIRECTORY $ENV{DESTDIR}${CMAKE_INSTALL_PREFIX})]=]) diff --git a/utils/Makefile b/utils/Makefile index 4ba93bc..975dbe6 100644 --- a/utils/Makefile +++ b/utils/Makefile @@ -7,12 +7,23 @@ PYTHON_PKG_FILES=$(shell find $(SRC_DIR)) PYTHON_WHEEL=tfc-$(VERSION)-*.whl PYTHON_WHEEL_DIST=../dist/$(PYTHON_WHEEL) +ifdef VIRTUAL_ENV + USE_VENV := true +else + USE_VENV := false +endif ../dist: mkdir -p ../dist +ifeq ($(USE_VENV),true) +install: ../dist + cd ../; python setup.py bdist_wheel + uv pip uninstall tfc; uv pip install ../dist/*.whl +else install: ../dist cd ../; python setup.py bdist_wheel pip uninstall -y tfc; pip install ../dist/*.whl +endif clean-python: rm -f ../dist/*.whl From 6fd3e55fa178221aeb628d8ea993dbd5a64bebff Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 26 Jul 2025 11:37:44 -0700 Subject: [PATCH 03/45] Got this working. Can at least pass the basic test. --- src/tfc/utils/BF/BF.h | 27 ++++++++++++++++++++++----- src/tfc/utils/BF/BF_Py.cc | 9 +++------ 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/tfc/utils/BF/BF.h b/src/tfc/utils/BF/BF.h index 21a678e..03baa3b 100644 --- a/src/tfc/utils/BF/BF.h +++ b/src/tfc/utils/BF/BF.h @@ -64,12 +64,17 @@ class BasisFunc{ * - Creates PyCapsule for xla function. */ BasisFunc(double x0in, double xf, const int* nCin, int ncDim0, int min, double z0in=0., double zf=DBL_MAX); - /** Dummy empty constructor allows derived classes without calling constructor explicitly. */ - BasisFunc(){}; - /** Basis function class destructor. Removes memory used by the basis function class. */ virtual ~BasisFunc(); + // Prevent copying + BasisFunc(const BasisFunc&) = delete; + BasisFunc& operator=(const BasisFunc&) = delete; + + // Prevent moving + BasisFunc(BasisFunc&&) = delete; + BasisFunc& operator=(BasisFunc&&) = delete; + /** Function is used to create a basis function matrix and its derivatives. This matrix is is an m x N matrix where: * - m is the number of basis functions * - N = in is the number of points in x @@ -92,6 +97,9 @@ class BasisFunc{ #endif protected: + /** Dummy empty constructor allows derived classes without calling constructor explicitly. */ + BasisFunc(){}; + /** This function creates a PyCapsule object that wraps the XLA verison of the basis function. */ PyObject* GetXlaCapsule(); @@ -128,13 +136,22 @@ class CP: virtual public BasisFunc { CP(double x0, double xf, const int* nCin, int ncDim0, int min): BasisFunc(x0,xf,nCin,ncDim0,min,-1.,1.){}; - /** Dummy CP class constructor. Used only in n-dimensions. */ - CP(){}; /** CP class destructor.*/ virtual ~CP(){}; + // Prevent copying + CP(const CP&) = delete; + CP& operator=(const CP&) = delete; + + // Prevent moving + CP(CP&&) = delete; + CP& operator=(CP&&) = delete; + protected: + /** Dummy CP class constructor. Used only in n-dimensions. */ + CP(){}; + /** Function used internally to create the basis function matrices. */ void Hint(const int d, const double* x, const int nOut, double* dark); diff --git a/src/tfc/utils/BF/BF_Py.cc b/src/tfc/utils/BF/BF_Py.cc index 1a8be41..beadbcb 100644 --- a/src/tfc/utils/BF/BF_Py.cc +++ b/src/tfc/utils/BF/BF_Py.cc @@ -8,7 +8,7 @@ namespace py = pybind11; template void add1DInit(auto& c) { c.def(py::init([](double x0, double xf, py::array_t nC, int min){ - return T(x0, xf, nC.data(), nC.size(), min); + return std::make_unique(x0, xf, nC.data(), nC.size(), min); }), py::arg("x0"), py::arg("xf"), @@ -49,9 +49,6 @@ PYBIND11_MODULE(BF, m) { return "CUDA NOT FOUND, GPU NOT IMPLEMENTED."; }) #endif - // Static members - .def_readonly_static("nIdentifier", &BasisFunc::nIdentifier) - .def_readonly_static("BasisFuncContainer", &BasisFunc::BasisFuncContainer) // Methods .def("H", [](BasisFunc& self, @@ -73,7 +70,7 @@ PYBIND11_MODULE(BF, m) { free(d); }); - return py::array_t({mOut, nOut}, F, capsule); + return py::array_t({nOut, mOut}, F, capsule); }, py::arg("x"), py::arg("d"), py::arg("full"), R"( @@ -89,6 +86,6 @@ PYBIND11_MODULE(BF, m) { )" ); - auto PyCP = py::class_ (m, "CP"); + auto PyCP = py::class_ (m, "CP", py::multiple_inheritance()); add1DInit(PyCP); } From 90a7ff4baedae0e67b5de29487e4dadbd2c84312 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 26 Jul 2025 11:38:53 -0700 Subject: [PATCH 04/45] Removing files we don't need anymore. --- src/tfc/utils/BF/BF.i | 79 - src/tfc/utils/BF/numpy.i | 2970 -------------------------------------- 2 files changed, 3049 deletions(-) delete mode 100644 src/tfc/utils/BF/BF.i delete mode 100644 src/tfc/utils/BF/numpy.i diff --git a/src/tfc/utils/BF/BF.i b/src/tfc/utils/BF/BF.i deleted file mode 100644 index 018bade..0000000 --- a/src/tfc/utils/BF/BF.i +++ /dev/null @@ -1,79 +0,0 @@ -// BF.i - -%module BF -%{ -#define SWIG_FILE_WITH_INIT -#include -#include -#include -#include -#ifdef HAS_CUDA - #include - #include - #include -#endif -#include "BF.h" -%} - -%feature("python:annotations", "c"); - -%include "numpy.i" -%include -%include -%apply bool* INPUT {bool* useVal}; - -%init %{ - import_array(); -%} -%ignore xlaGpuWrapper(CUstream stream, void** buffers, const char* opaque, size_t opaque_len); - -// Apply typemaps to allow hooks into Python -%apply (int* IN_ARRAY1, int DIM1){(int* d, int dDim0),(int* nCin, int ncDim0),(int* useVal, int useValDim0)}; -%apply (double* IN_ARRAY1, int DIM1){(double* x, int n),(double* cin, int cDim0),(double* arrIn, int nIn),(double* zin, int zDim0),(double* x0in, int x0Dim0),(double* xf, int xfDim0)}; -%apply (int* IN_ARRAY2, int DIM1, int DIM2){(int* nCin, int ncDim0, int ncDim1)}; -%apply (double* IN_ARRAY2, int DIM1, int DIM2){(double* zin, int zDim0, int zDim1),(double* x, int in, int xDim1),(double* arrIn, int dimIn, int nIn)}; - -%apply (int* DIM1, int* DIM2, double** ARGOUTVIEWM_ARRAY2){(int* nOut, int* mOut, double** F),(int* dimOut, int* nOut, double** arrOut)}; // Switch to ARGOUTVIEWM when you can to avoid memory leaks -%apply (double** ARGOUTVIEWM_ARRAY1, int* DIM1){(double** arrOut, int* nOut)}; - -// Add getter and setter methods - -%extend nBasisFunc{ - %rename(_getC) getC(double** arrOut, int* nOut); - - %pythoncode %{ - c = property(_getC) - %} -}; - -%extend ELM{ - %rename(_getW) getW(double** arrOut, int* nOut); - %rename(_setW) setW(double* arrIn, int nIn); - %rename(_getB) getB(double** arrOut, int* nOut); - %rename(_setB) setB(double* arrIn, int nIn); - %ignore w; - %ignore b; - - %pythoncode %{ - w = property(_getW,_setW) - b = property(_getB,_setB) - %} - -}; - -%extend nELM{ - %rename(_getW) getW(int* dimOut, int* nOut, double** arrOut); - %rename(_setW) setW(double* arrIn, int dimIn, int nIn); - %rename(_getB) getB(double** arrOut, int* nOut); - %rename(_setB) setB(double* arrIn, int nIn); - %ignore w; - %ignore b; - - %pythoncode %{ - w = property(_getW,_setW) - b = property(_getB,_setB) - %} - -}; - -%include "BF.h" diff --git a/src/tfc/utils/BF/numpy.i b/src/tfc/utils/BF/numpy.i deleted file mode 100644 index 7474466..0000000 --- a/src/tfc/utils/BF/numpy.i +++ /dev/null @@ -1,2970 +0,0 @@ -/* -*- C -*- (not really, but good for syntax highlighting) */ - -/* - * Copyright (c) 2005-2015, NumPy Developers. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * * Redistributions in binary form must reproduce the above - * copyright notice, this list of conditions and the following - * disclaimer in the documentation and/or other materials provided - * with the distribution. - * - * * Neither the name of the NumPy Developers nor the names of any - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -#ifdef SWIGPYTHON - -%{ -#ifndef SWIG_FILE_WITH_INIT -#define NO_IMPORT_ARRAY -#endif -#include "stdio.h" -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION -#include -%} - -/**********************************************************************/ - -%fragment("NumPy_Backward_Compatibility", "header") -{ -%#if NPY_API_VERSION < NPY_1_7_API_VERSION -%#define NPY_ARRAY_DEFAULT NPY_DEFAULT -%#define NPY_ARRAY_FARRAY NPY_FARRAY -%#define NPY_FORTRANORDER NPY_FORTRAN -%#endif -} - -/**********************************************************************/ - -/* The following code originally appeared in - * enthought/kiva/agg/src/numeric.i written by Eric Jones. It was - * translated from C++ to C by John Hunter. Bill Spotz has modified - * it to fix some minor bugs, upgrade from Numeric to numpy (all - * versions), add some comments and functionality, and convert from - * direct code insertion to SWIG fragments. - */ - -%fragment("NumPy_Macros", "header") -{ -/* Macros to extract array attributes. - */ -%#if NPY_API_VERSION < NPY_1_7_API_VERSION -%#define is_array(a) ((a) && PyArray_Check((PyArrayObject*)a)) -%#define array_type(a) (int)(PyArray_TYPE((PyArrayObject*)a)) -%#define array_numdims(a) (((PyArrayObject*)a)->nd) -%#define array_dimensions(a) (((PyArrayObject*)a)->dimensions) -%#define array_size(a,i) (((PyArrayObject*)a)->dimensions[i]) -%#define array_strides(a) (((PyArrayObject*)a)->strides) -%#define array_stride(a,i) (((PyArrayObject*)a)->strides[i]) -%#define array_data(a) (((PyArrayObject*)a)->data) -%#define array_descr(a) (((PyArrayObject*)a)->descr) -%#define array_flags(a) (((PyArrayObject*)a)->flags) -%#define array_clearflags(a,f) (((PyArrayObject*)a)->flags) &= ~f -%#define array_enableflags(a,f) (((PyArrayObject*)a)->flags) = f -%#define array_is_fortran(a) (PyArray_ISFORTRAN((PyArrayObject*)a)) -%#else -%#define is_array(a) ((a) && PyArray_Check(a)) -%#define array_type(a) PyArray_TYPE((PyArrayObject*)a) -%#define array_numdims(a) PyArray_NDIM((PyArrayObject*)a) -%#define array_dimensions(a) PyArray_DIMS((PyArrayObject*)a) -%#define array_strides(a) PyArray_STRIDES((PyArrayObject*)a) -%#define array_stride(a,i) PyArray_STRIDE((PyArrayObject*)a,i) -%#define array_size(a,i) PyArray_DIM((PyArrayObject*)a,i) -%#define array_data(a) PyArray_DATA((PyArrayObject*)a) -%#define array_descr(a) PyArray_DESCR((PyArrayObject*)a) -%#define array_flags(a) PyArray_FLAGS((PyArrayObject*)a) -%#define array_enableflags(a,f) PyArray_ENABLEFLAGS((PyArrayObject*)a,f) -%#define array_clearflags(a,f) PyArray_CLEARFLAGS((PyArrayObject*)a,f) -%#define array_is_fortran(a) (PyArray_IS_F_CONTIGUOUS((PyArrayObject*)a)) -%#endif -%#define array_is_contiguous(a) (PyArray_ISCONTIGUOUS((PyArrayObject*)a)) -%#define array_is_native(a) (PyArray_ISNOTSWAPPED((PyArrayObject*)a)) -} - -/**********************************************************************/ - -%fragment("NumPy_Utilities", - "header") -{ - /* Given a PyObject, return a string describing its type. - */ - const char* pytype_string(PyObject* py_obj) - { - if (py_obj == NULL ) return "C NULL value"; - if (py_obj == Py_None ) return "Python None" ; - if (PyCallable_Check(py_obj)) return "callable" ; - if (PyBytes_Check( py_obj)) return "string" ; - if (PyLong_Check( py_obj)) return "int" ; - if (PyFloat_Check( py_obj)) return "float" ; - if (PyDict_Check( py_obj)) return "dict" ; - if (PyList_Check( py_obj)) return "list" ; - if (PyTuple_Check( py_obj)) return "tuple" ; - - return "unknown type"; - } - - /* Given a NumPy typecode, return a string describing the type. - */ - const char* typecode_string(int typecode) - { - static const char* type_names[25] = {"bool", - "byte", - "unsigned byte", - "short", - "unsigned short", - "int", - "unsigned int", - "long", - "unsigned long", - "long long", - "unsigned long long", - "float", - "double", - "long double", - "complex float", - "complex double", - "complex long double", - "object", - "string", - "unicode", - "void", - "ntypes", - "notype", - "char", - "unknown"}; - return typecode < 24 ? type_names[typecode] : type_names[24]; - } - - /* Make sure input has correct numpy type. This now just calls - PyArray_EquivTypenums(). - */ - int type_match(int actual_type, - int desired_type) - { - return PyArray_EquivTypenums(actual_type, desired_type); - } - -void free_cap(PyObject * cap) - { - void* array = (void*) PyCapsule_GetPointer(cap,SWIGPY_CAPSULE_NAME); - if (array != NULL) free(array); - } - - -} - -/**********************************************************************/ - -%fragment("NumPy_Object_to_Array", - "header", - fragment="NumPy_Backward_Compatibility", - fragment="NumPy_Macros", - fragment="NumPy_Utilities") -{ - /* Given a PyObject pointer, cast it to a PyArrayObject pointer if - * legal. If not, set the python error string appropriately and - * return NULL. - */ - PyArrayObject* obj_to_array_no_conversion(PyObject* input, - int typecode) - { - PyArrayObject* ary = NULL; - if (is_array(input) && (typecode == NPY_NOTYPE || - PyArray_EquivTypenums(array_type(input), typecode))) - { - ary = (PyArrayObject*) input; - } - else if is_array(input) - { - const char* desired_type = typecode_string(typecode); - const char* actual_type = typecode_string(array_type(input)); - PyErr_Format(PyExc_TypeError, - "Array of type '%s' required. Array of type '%s' given", - desired_type, actual_type); - ary = NULL; - } - else - { - const char* desired_type = typecode_string(typecode); - const char* actual_type = pytype_string(input); - PyErr_Format(PyExc_TypeError, - "Array of type '%s' required. A '%s' was given", - desired_type, - actual_type); - ary = NULL; - } - return ary; - } - - /* Convert the given PyObject to a NumPy array with the given - * typecode. On success, return a valid PyArrayObject* with the - * correct type. On failure, the python error string will be set and - * the routine returns NULL. - */ - PyArrayObject* obj_to_array_allow_conversion(PyObject* input, - int typecode, - int* is_new_object) - { - PyArrayObject* ary = NULL; - PyObject* py_obj; - if (is_array(input) && (typecode == NPY_NOTYPE || - PyArray_EquivTypenums(array_type(input),typecode))) - { - ary = (PyArrayObject*) input; - *is_new_object = 0; - } - else - { - py_obj = PyArray_FROMANY(input, typecode, 0, 0, NPY_ARRAY_DEFAULT); - /* If NULL, PyArray_FromObject will have set python error value.*/ - ary = (PyArrayObject*) py_obj; - *is_new_object = 1; - } - return ary; - } - - /* Given a PyArrayObject, check to see if it is contiguous. If so, - * return the input pointer and flag it as not a new object. If it is - * not contiguous, create a new PyArrayObject using the original data, - * flag it as a new object and return the pointer. - */ - PyArrayObject* make_contiguous(PyArrayObject* ary, - int* is_new_object, - int min_dims, - int max_dims) - { - PyArrayObject* result; - if (array_is_contiguous(ary)) - { - result = ary; - *is_new_object = 0; - } - else - { - result = (PyArrayObject*) PyArray_ContiguousFromObject((PyObject*)ary, - array_type(ary), - min_dims, - max_dims); - *is_new_object = 1; - } - return result; - } - - /* Given a PyArrayObject, check to see if it is Fortran-contiguous. - * If so, return the input pointer, but do not flag it as not a new - * object. If it is not Fortran-contiguous, create a new - * PyArrayObject using the original data, flag it as a new object - * and return the pointer. - */ - PyArrayObject* make_fortran(PyArrayObject* ary, - int* is_new_object) - { - PyArrayObject* result; - if (array_is_fortran(ary)) - { - result = ary; - *is_new_object = 0; - } - else - { - Py_INCREF(array_descr(ary)); - result = (PyArrayObject*) PyArray_FromArray(ary, - array_descr(ary), -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - NPY_FORTRANORDER); -%#else - NPY_ARRAY_F_CONTIGUOUS); -%#endif - *is_new_object = 1; - } - return result; - } - - /* Convert a given PyObject to a contiguous PyArrayObject of the - * specified type. If the input object is not a contiguous - * PyArrayObject, a new one will be created and the new object flag - * will be set. - */ - PyArrayObject* obj_to_array_contiguous_allow_conversion(PyObject* input, - int typecode, - int* is_new_object) - { - int is_new1 = 0; - int is_new2 = 0; - PyArrayObject* ary2; - PyArrayObject* ary1 = obj_to_array_allow_conversion(input, - typecode, - &is_new1); - if (ary1) - { - ary2 = make_contiguous(ary1, &is_new2, 0, 0); - if ( is_new1 && is_new2) - { - Py_DECREF(ary1); - } - ary1 = ary2; - } - *is_new_object = is_new1 || is_new2; - return ary1; - } - - /* Convert a given PyObject to a Fortran-ordered PyArrayObject of the - * specified type. If the input object is not a Fortran-ordered - * PyArrayObject, a new one will be created and the new object flag - * will be set. - */ - PyArrayObject* obj_to_array_fortran_allow_conversion(PyObject* input, - int typecode, - int* is_new_object) - { - int is_new1 = 0; - int is_new2 = 0; - PyArrayObject* ary2; - PyArrayObject* ary1 = obj_to_array_allow_conversion(input, - typecode, - &is_new1); - if (ary1) - { - ary2 = make_fortran(ary1, &is_new2); - if (is_new1 && is_new2) - { - Py_DECREF(ary1); - } - ary1 = ary2; - } - *is_new_object = is_new1 || is_new2; - return ary1; - } -} /* end fragment */ - -/**********************************************************************/ - -%fragment("NumPy_Array_Requirements", - "header", - fragment="NumPy_Backward_Compatibility", - fragment="NumPy_Macros") -{ - /* Test whether a python object is contiguous. If array is - * contiguous, return 1. Otherwise, set the python error string and - * return 0. - */ - int require_contiguous(PyArrayObject* ary) - { - int contiguous = 1; - if (!array_is_contiguous(ary)) - { - PyErr_SetString(PyExc_TypeError, - "Array must be contiguous. A non-contiguous array was given"); - contiguous = 0; - } - return contiguous; - } - - /* Test whether a python object is (C_ or F_) contiguous. If array is - * contiguous, return 1. Otherwise, set the python error string and - * return 0. - */ - int require_c_or_f_contiguous(PyArrayObject* ary) - { - int contiguous = 1; - if (!(array_is_contiguous(ary) || array_is_fortran(ary))) - { - PyErr_SetString(PyExc_TypeError, - "Array must be contiguous (C_ or F_). A non-contiguous array was given"); - contiguous = 0; - } - return contiguous; - } - - /* Require that a numpy array is not byte-swapped. If the array is - * not byte-swapped, return 1. Otherwise, set the python error string - * and return 0. - */ - int require_native(PyArrayObject* ary) - { - int native = 1; - if (!array_is_native(ary)) - { - PyErr_SetString(PyExc_TypeError, - "Array must have native byteorder. " - "A byte-swapped array was given"); - native = 0; - } - return native; - } - - /* Require the given PyArrayObject to have a specified number of - * dimensions. If the array has the specified number of dimensions, - * return 1. Otherwise, set the python error string and return 0. - */ - int require_dimensions(PyArrayObject* ary, - int exact_dimensions) - { - int success = 1; - if (array_numdims(ary) != exact_dimensions) - { - PyErr_Format(PyExc_TypeError, - "Array must have %d dimensions. Given array has %d dimensions", - exact_dimensions, - array_numdims(ary)); - success = 0; - } - return success; - } - - /* Require the given PyArrayObject to have one of a list of specified - * number of dimensions. If the array has one of the specified number - * of dimensions, return 1. Otherwise, set the python error string - * and return 0. - */ - int require_dimensions_n(PyArrayObject* ary, - int* exact_dimensions, - int n) - { - int success = 0; - int i; - char dims_str[255] = ""; - char s[255]; - for (i = 0; i < n && !success; i++) - { - if (array_numdims(ary) == exact_dimensions[i]) - { - success = 1; - } - } - if (!success) - { - for (i = 0; i < n-1; i++) - { - sprintf(s, "%d, ", exact_dimensions[i]); - strcat(dims_str,s); - } - sprintf(s, " or %d", exact_dimensions[n-1]); - strcat(dims_str,s); - PyErr_Format(PyExc_TypeError, - "Array must have %s dimensions. Given array has %d dimensions", - dims_str, - array_numdims(ary)); - } - return success; - } - - /* Require the given PyArrayObject to have a specified shape. If the - * array has the specified shape, return 1. Otherwise, set the python - * error string and return 0. - */ - int require_size(PyArrayObject* ary, - npy_intp* size, - int n) - { - int i; - int success = 1; - size_t len; - char desired_dims[255] = "["; - char s[255]; - char actual_dims[255] = "["; - for(i=0; i < n;i++) - { - if (size[i] != -1 && size[i] != array_size(ary,i)) - { - success = 0; - } - } - if (!success) - { - for (i = 0; i < n; i++) - { - if (size[i] == -1) - { - sprintf(s, "*,"); - } - else - { - sprintf(s, "%ld,", (long int)size[i]); - } - strcat(desired_dims,s); - } - len = strlen(desired_dims); - desired_dims[len-1] = ']'; - for (i = 0; i < n; i++) - { - sprintf(s, "%ld,", (long int)array_size(ary,i)); - strcat(actual_dims,s); - } - len = strlen(actual_dims); - actual_dims[len-1] = ']'; - PyErr_Format(PyExc_TypeError, - "Array must have shape of %s. Given array has shape of %s", - desired_dims, - actual_dims); - } - return success; - } - - /* Require the given PyArrayObject to be Fortran ordered. If the - * the PyArrayObject is already Fortran ordered, do nothing. Else, - * set the Fortran ordering flag and recompute the strides. - */ - int require_fortran(PyArrayObject* ary) - { - int success = 1; - int nd = array_numdims(ary); - int i; - npy_intp * strides = array_strides(ary); - if (array_is_fortran(ary)) return success; - int n_non_one = 0; - /* Set the Fortran ordered flag */ - const npy_intp *dims = array_dimensions(ary); - for (i=0; i < nd; ++i) - n_non_one += (dims[i] != 1) ? 1 : 0; - if (n_non_one > 1) - array_clearflags(ary,NPY_ARRAY_CARRAY); - array_enableflags(ary,NPY_ARRAY_FARRAY); - /* Recompute the strides */ - strides[0] = strides[nd-1]; - for (i=1; i < nd; ++i) - strides[i] = strides[i-1] * array_size(ary,i-1); - return success; - } -} - -/* Combine all NumPy fragments into one for convenience */ -%fragment("NumPy_Fragments", - "header", - fragment="NumPy_Backward_Compatibility", - fragment="NumPy_Macros", - fragment="NumPy_Utilities", - fragment="NumPy_Object_to_Array", - fragment="NumPy_Array_Requirements") -{ -} - -/* End John Hunter translation (with modifications by Bill Spotz) - */ - -/* %numpy_typemaps() macro - * - * This macro defines a family of 75 typemaps that allow C arguments - * of the form - * - * 1. (DATA_TYPE IN_ARRAY1[ANY]) - * 2. (DATA_TYPE* IN_ARRAY1, DIM_TYPE DIM1) - * 3. (DIM_TYPE DIM1, DATA_TYPE* IN_ARRAY1) - * - * 4. (DATA_TYPE IN_ARRAY2[ANY][ANY]) - * 5. (DATA_TYPE* IN_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) - * 6. (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_ARRAY2) - * 7. (DATA_TYPE* IN_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) - * 8. (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_FARRAY2) - * - * 9. (DATA_TYPE IN_ARRAY3[ANY][ANY][ANY]) - * 10. (DATA_TYPE* IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) - * 11. (DATA_TYPE** IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) - * 12. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_ARRAY3) - * 13. (DATA_TYPE* IN_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) - * 14. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_FARRAY3) - * - * 15. (DATA_TYPE IN_ARRAY4[ANY][ANY][ANY][ANY]) - * 16. (DATA_TYPE* IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) - * 17. (DATA_TYPE** IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) - * 18. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, , DIM_TYPE DIM4, DATA_TYPE* IN_ARRAY4) - * 19. (DATA_TYPE* IN_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) - * 20. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_FARRAY4) - * - * 21. (DATA_TYPE INPLACE_ARRAY1[ANY]) - * 22. (DATA_TYPE* INPLACE_ARRAY1, DIM_TYPE DIM1) - * 23. (DIM_TYPE DIM1, DATA_TYPE* INPLACE_ARRAY1) - * - * 24. (DATA_TYPE INPLACE_ARRAY2[ANY][ANY]) - * 25. (DATA_TYPE* INPLACE_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) - * 26. (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_ARRAY2) - * 27. (DATA_TYPE* INPLACE_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) - * 28. (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_FARRAY2) - * - * 29. (DATA_TYPE INPLACE_ARRAY3[ANY][ANY][ANY]) - * 30. (DATA_TYPE* INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) - * 31. (DATA_TYPE** INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) - * 32. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* INPLACE_ARRAY3) - * 33. (DATA_TYPE* INPLACE_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) - * 34. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* INPLACE_FARRAY3) - * - * 35. (DATA_TYPE INPLACE_ARRAY4[ANY][ANY][ANY][ANY]) - * 36. (DATA_TYPE* INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) - * 37. (DATA_TYPE** INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) - * 38. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* INPLACE_ARRAY4) - * 39. (DATA_TYPE* INPLACE_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) - * 40. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* INPLACE_FARRAY4) - * - * 41. (DATA_TYPE ARGOUT_ARRAY1[ANY]) - * 42. (DATA_TYPE* ARGOUT_ARRAY1, DIM_TYPE DIM1) - * 43. (DIM_TYPE DIM1, DATA_TYPE* ARGOUT_ARRAY1) - * - * 44. (DATA_TYPE ARGOUT_ARRAY2[ANY][ANY]) - * - * 45. (DATA_TYPE ARGOUT_ARRAY3[ANY][ANY][ANY]) - * - * 46. (DATA_TYPE ARGOUT_ARRAY4[ANY][ANY][ANY][ANY]) - * - * 47. (DATA_TYPE** ARGOUTVIEW_ARRAY1, DIM_TYPE* DIM1) - * 48. (DIM_TYPE* DIM1, DATA_TYPE** ARGOUTVIEW_ARRAY1) - * - * 49. (DATA_TYPE** ARGOUTVIEW_ARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2) - * 50. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEW_ARRAY2) - * 51. (DATA_TYPE** ARGOUTVIEW_FARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2) - * 52. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEW_FARRAY2) - * - * 53. (DATA_TYPE** ARGOUTVIEW_ARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3) - * 54. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEW_ARRAY3) - * 55. (DATA_TYPE** ARGOUTVIEW_FARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3) - * 56. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEW_FARRAY3) - * - * 57. (DATA_TYPE** ARGOUTVIEW_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4) - * 58. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEW_ARRAY4) - * 59. (DATA_TYPE** ARGOUTVIEW_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4) - * 60. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEW_FARRAY4) - * - * 61. (DATA_TYPE** ARGOUTVIEWM_ARRAY1, DIM_TYPE* DIM1) - * 62. (DIM_TYPE* DIM1, DATA_TYPE** ARGOUTVIEWM_ARRAY1) - * - * 63. (DATA_TYPE** ARGOUTVIEWM_ARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2) - * 64. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEWM_ARRAY2) - * 65. (DATA_TYPE** ARGOUTVIEWM_FARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2) - * 66. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEWM_FARRAY2) - * - * 67. (DATA_TYPE** ARGOUTVIEWM_ARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3) - * 68. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEWM_ARRAY3) - * 69. (DATA_TYPE** ARGOUTVIEWM_FARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3) - * 70. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEWM_FARRAY3) - * - * 71. (DATA_TYPE** ARGOUTVIEWM_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4) - * 72. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEWM_ARRAY4) - * 73. (DATA_TYPE** ARGOUTVIEWM_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4) - * 74. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEWM_FARRAY4) - * - * 75. (DATA_TYPE* INPLACE_ARRAY_FLAT, DIM_TYPE DIM_FLAT) - * - * where "DATA_TYPE" is any type supported by the NumPy module, and - * "DIM_TYPE" is any int-like type suitable for specifying dimensions. - * The difference between "ARRAY" typemaps and "FARRAY" typemaps is - * that the "FARRAY" typemaps expect Fortran ordering of - * multidimensional arrays. In python, the dimensions will not need - * to be specified (except for the "DATA_TYPE* ARGOUT_ARRAY1" - * typemaps). The IN_ARRAYs can be a numpy array or any sequence that - * can be converted to a numpy array of the specified type. The - * INPLACE_ARRAYs must be numpy arrays of the appropriate type. The - * ARGOUT_ARRAYs will be returned as new numpy arrays of the - * appropriate type. - * - * These typemaps can be applied to existing functions using the - * %apply directive. For example: - * - * %apply (double* IN_ARRAY1, int DIM1) {(double* series, int length)}; - * double prod(double* series, int length); - * - * %apply (int DIM1, int DIM2, double* INPLACE_ARRAY2) - * {(int rows, int cols, double* matrix )}; - * void floor(int rows, int cols, double* matrix, double f); - * - * %apply (double IN_ARRAY3[ANY][ANY][ANY]) - * {(double tensor[2][2][2] )}; - * %apply (double ARGOUT_ARRAY3[ANY][ANY][ANY]) - * {(double low[2][2][2] )}; - * %apply (double ARGOUT_ARRAY3[ANY][ANY][ANY]) - * {(double upp[2][2][2] )}; - * void luSplit(double tensor[2][2][2], - * double low[2][2][2], - * double upp[2][2][2] ); - * - * or directly with - * - * double prod(double* IN_ARRAY1, int DIM1); - * - * void floor(int DIM1, int DIM2, double* INPLACE_ARRAY2, double f); - * - * void luSplit(double IN_ARRAY3[ANY][ANY][ANY], - * double ARGOUT_ARRAY3[ANY][ANY][ANY], - * double ARGOUT_ARRAY3[ANY][ANY][ANY]); - */ - -%define %numpy_typemaps(DATA_TYPE, DATA_TYPECODE, DIM_TYPE) - -/************************/ -/* Input Array Typemaps */ -/************************/ - -/* Typemap suite for (DATA_TYPE IN_ARRAY1[ANY]) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE IN_ARRAY1[ANY]) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE IN_ARRAY1[ANY]) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[1] = { $1_dim0 }; - array = obj_to_array_contiguous_allow_conversion($input, - DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 1) || - !require_size(array, size, 1)) SWIG_fail; - $1 = ($1_ltype) array_data(array); -} -%typemap(freearg) - (DATA_TYPE IN_ARRAY1[ANY]) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DATA_TYPE* IN_ARRAY1, DIM_TYPE DIM1) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* IN_ARRAY1, DIM_TYPE DIM1) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* IN_ARRAY1, DIM_TYPE DIM1) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[1] = { -1 }; - array = obj_to_array_contiguous_allow_conversion($input, - DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 1) || - !require_size(array, size, 1)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = (DIM_TYPE) array_size(array,0); -} -%typemap(freearg) - (DATA_TYPE* IN_ARRAY1, DIM_TYPE DIM1) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DIM_TYPE DIM1, DATA_TYPE* IN_ARRAY1) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DATA_TYPE* IN_ARRAY1) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DATA_TYPE* IN_ARRAY1) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[1] = {-1}; - array = obj_to_array_contiguous_allow_conversion($input, - DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 1) || - !require_size(array, size, 1)) SWIG_fail; - $1 = (DIM_TYPE) array_size(array,0); - $2 = (DATA_TYPE*) array_data(array); -} -%typemap(freearg) - (DIM_TYPE DIM1, DATA_TYPE* IN_ARRAY1) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DATA_TYPE IN_ARRAY2[ANY][ANY]) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE IN_ARRAY2[ANY][ANY]) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE IN_ARRAY2[ANY][ANY]) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[2] = { $1_dim0, $1_dim1 }; - array = obj_to_array_contiguous_allow_conversion($input, - DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 2) || - !require_size(array, size, 2)) SWIG_fail; - $1 = ($1_ltype) array_data(array); -} -%typemap(freearg) - (DATA_TYPE IN_ARRAY2[ANY][ANY]) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DATA_TYPE* IN_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* IN_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* IN_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[2] = { -1, -1 }; - array = obj_to_array_contiguous_allow_conversion($input, DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 2) || - !require_size(array, size, 2)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = (DIM_TYPE) array_size(array,0); - $3 = (DIM_TYPE) array_size(array,1); -} -%typemap(freearg) - (DATA_TYPE* IN_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_ARRAY2) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_ARRAY2) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_ARRAY2) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[2] = { -1, -1 }; - array = obj_to_array_contiguous_allow_conversion($input, - DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 2) || - !require_size(array, size, 2)) SWIG_fail; - $1 = (DIM_TYPE) array_size(array,0); - $2 = (DIM_TYPE) array_size(array,1); - $3 = (DATA_TYPE*) array_data(array); -} -%typemap(freearg) - (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_ARRAY2) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DATA_TYPE* IN_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* IN_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* IN_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[2] = { -1, -1 }; - array = obj_to_array_fortran_allow_conversion($input, - DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 2) || - !require_size(array, size, 2) || !require_fortran(array)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = (DIM_TYPE) array_size(array,0); - $3 = (DIM_TYPE) array_size(array,1); -} -%typemap(freearg) - (DATA_TYPE* IN_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_FARRAY2) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_FARRAY2) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_FARRAY2) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[2] = { -1, -1 }; - array = obj_to_array_fortran_allow_conversion($input, - DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 2) || - !require_size(array, size, 2) || !require_fortran(array)) SWIG_fail; - $1 = (DIM_TYPE) array_size(array,0); - $2 = (DIM_TYPE) array_size(array,1); - $3 = (DATA_TYPE*) array_data(array); -} -%typemap(freearg) - (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_FARRAY2) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DATA_TYPE IN_ARRAY3[ANY][ANY][ANY]) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE IN_ARRAY3[ANY][ANY][ANY]) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE IN_ARRAY3[ANY][ANY][ANY]) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[3] = { $1_dim0, $1_dim1, $1_dim2 }; - array = obj_to_array_contiguous_allow_conversion($input, - DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 3) || - !require_size(array, size, 3)) SWIG_fail; - $1 = ($1_ltype) array_data(array); -} -%typemap(freearg) - (DATA_TYPE IN_ARRAY3[ANY][ANY][ANY]) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DATA_TYPE* IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, - * DIM_TYPE DIM3) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[3] = { -1, -1, -1 }; - array = obj_to_array_contiguous_allow_conversion($input, DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 3) || - !require_size(array, size, 3)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = (DIM_TYPE) array_size(array,0); - $3 = (DIM_TYPE) array_size(array,1); - $4 = (DIM_TYPE) array_size(array,2); -} -%typemap(freearg) - (DATA_TYPE* IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DATA_TYPE** IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, - * DIM_TYPE DIM3) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE** IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) -{ - /* for now, only concerned with lists */ - $1 = PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE** IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) - (DATA_TYPE** array=NULL, PyArrayObject** object_array=NULL, int* is_new_object_array=NULL) -{ - npy_intp size[2] = { -1, -1 }; - PyArrayObject* temp_array; - Py_ssize_t i; - int is_new_object; - - /* length of the list */ - $2 = PyList_Size($input); - - /* the arrays */ - array = (DATA_TYPE **)malloc($2*sizeof(DATA_TYPE *)); - object_array = (PyArrayObject **)calloc($2,sizeof(PyArrayObject *)); - is_new_object_array = (int *)calloc($2,sizeof(int)); - - if (array == NULL || object_array == NULL || is_new_object_array == NULL) - { - SWIG_fail; - } - - for (i=0; i<$2; i++) - { - temp_array = obj_to_array_contiguous_allow_conversion(PySequence_GetItem($input,i), DATA_TYPECODE, &is_new_object); - - /* the new array must be stored so that it can be destroyed in freearg */ - object_array[i] = temp_array; - is_new_object_array[i] = is_new_object; - - if (!temp_array || !require_dimensions(temp_array, 2)) SWIG_fail; - - /* store the size of the first array in the list, then use that for comparison. */ - if (i == 0) - { - size[0] = array_size(temp_array,0); - size[1] = array_size(temp_array,1); - } - - if (!require_size(temp_array, size, 2)) SWIG_fail; - - array[i] = (DATA_TYPE*) array_data(temp_array); - } - - $1 = (DATA_TYPE**) array; - $3 = (DIM_TYPE) size[0]; - $4 = (DIM_TYPE) size[1]; -} -%typemap(freearg) - (DATA_TYPE** IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) -{ - Py_ssize_t i; - - if (array$argnum!=NULL) free(array$argnum); - - /*freeing the individual arrays if needed */ - if (object_array$argnum!=NULL) - { - if (is_new_object_array$argnum!=NULL) - { - for (i=0; i<$2; i++) - { - if (object_array$argnum[i] != NULL && is_new_object_array$argnum[i]) - { Py_DECREF(object_array$argnum[i]); } - } - free(is_new_object_array$argnum); - } - free(object_array$argnum); - } -} - -/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, - * DATA_TYPE* IN_ARRAY3) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_ARRAY3) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_ARRAY3) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[3] = { -1, -1, -1 }; - array = obj_to_array_contiguous_allow_conversion($input, DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 3) || - !require_size(array, size, 3)) SWIG_fail; - $1 = (DIM_TYPE) array_size(array,0); - $2 = (DIM_TYPE) array_size(array,1); - $3 = (DIM_TYPE) array_size(array,2); - $4 = (DATA_TYPE*) array_data(array); -} -%typemap(freearg) - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_ARRAY3) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DATA_TYPE* IN_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, - * DIM_TYPE DIM3) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* IN_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* IN_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[3] = { -1, -1, -1 }; - array = obj_to_array_fortran_allow_conversion($input, DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 3) || - !require_size(array, size, 3) | !require_fortran(array)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = (DIM_TYPE) array_size(array,0); - $3 = (DIM_TYPE) array_size(array,1); - $4 = (DIM_TYPE) array_size(array,2); -} -%typemap(freearg) - (DATA_TYPE* IN_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, - * DATA_TYPE* IN_FARRAY3) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_FARRAY3) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_FARRAY3) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[3] = { -1, -1, -1 }; - array = obj_to_array_fortran_allow_conversion($input, - DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 3) || - !require_size(array, size, 3) || !require_fortran(array)) SWIG_fail; - $1 = (DIM_TYPE) array_size(array,0); - $2 = (DIM_TYPE) array_size(array,1); - $3 = (DIM_TYPE) array_size(array,2); - $4 = (DATA_TYPE*) array_data(array); -} -%typemap(freearg) - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_FARRAY3) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DATA_TYPE IN_ARRAY4[ANY][ANY][ANY][ANY]) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE IN_ARRAY4[ANY][ANY][ANY][ANY]) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE IN_ARRAY4[ANY][ANY][ANY][ANY]) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[4] = { $1_dim0, $1_dim1, $1_dim2 , $1_dim3}; - array = obj_to_array_contiguous_allow_conversion($input, DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 4) || - !require_size(array, size, 4)) SWIG_fail; - $1 = ($1_ltype) array_data(array); -} -%typemap(freearg) - (DATA_TYPE IN_ARRAY4[ANY][ANY][ANY][ANY]) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DATA_TYPE* IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, - * DIM_TYPE DIM3, DIM_TYPE DIM4) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[4] = { -1, -1, -1, -1 }; - array = obj_to_array_contiguous_allow_conversion($input, DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 4) || - !require_size(array, size, 4)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = (DIM_TYPE) array_size(array,0); - $3 = (DIM_TYPE) array_size(array,1); - $4 = (DIM_TYPE) array_size(array,2); - $5 = (DIM_TYPE) array_size(array,3); -} -%typemap(freearg) - (DATA_TYPE* IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DATA_TYPE** IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, - * DIM_TYPE DIM3, DIM_TYPE DIM4) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE** IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) -{ - /* for now, only concerned with lists */ - $1 = PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE** IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) - (DATA_TYPE** array=NULL, PyArrayObject** object_array=NULL, int* is_new_object_array=NULL) -{ - npy_intp size[3] = { -1, -1, -1 }; - PyArrayObject* temp_array; - Py_ssize_t i; - int is_new_object; - - /* length of the list */ - $2 = PyList_Size($input); - - /* the arrays */ - array = (DATA_TYPE **)malloc($2*sizeof(DATA_TYPE *)); - object_array = (PyArrayObject **)calloc($2,sizeof(PyArrayObject *)); - is_new_object_array = (int *)calloc($2,sizeof(int)); - - if (array == NULL || object_array == NULL || is_new_object_array == NULL) - { - SWIG_fail; - } - - for (i=0; i<$2; i++) - { - temp_array = obj_to_array_contiguous_allow_conversion(PySequence_GetItem($input,i), DATA_TYPECODE, &is_new_object); - - /* the new array must be stored so that it can be destroyed in freearg */ - object_array[i] = temp_array; - is_new_object_array[i] = is_new_object; - - if (!temp_array || !require_dimensions(temp_array, 3)) SWIG_fail; - - /* store the size of the first array in the list, then use that for comparison. */ - if (i == 0) - { - size[0] = array_size(temp_array,0); - size[1] = array_size(temp_array,1); - size[2] = array_size(temp_array,2); - } - - if (!require_size(temp_array, size, 3)) SWIG_fail; - - array[i] = (DATA_TYPE*) array_data(temp_array); - } - - $1 = (DATA_TYPE**) array; - $3 = (DIM_TYPE) size[0]; - $4 = (DIM_TYPE) size[1]; - $5 = (DIM_TYPE) size[2]; -} -%typemap(freearg) - (DATA_TYPE** IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) -{ - Py_ssize_t i; - - if (array$argnum!=NULL) free(array$argnum); - - /*freeing the individual arrays if needed */ - if (object_array$argnum!=NULL) - { - if (is_new_object_array$argnum!=NULL) - { - for (i=0; i<$2; i++) - { - if (object_array$argnum[i] != NULL && is_new_object_array$argnum[i]) - { Py_DECREF(object_array$argnum[i]); } - } - free(is_new_object_array$argnum); - } - free(object_array$argnum); - } -} - -/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, - * DATA_TYPE* IN_ARRAY4) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_ARRAY4) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_ARRAY4) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[4] = { -1, -1, -1 , -1}; - array = obj_to_array_contiguous_allow_conversion($input, DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 4) || - !require_size(array, size, 4)) SWIG_fail; - $1 = (DIM_TYPE) array_size(array,0); - $2 = (DIM_TYPE) array_size(array,1); - $3 = (DIM_TYPE) array_size(array,2); - $4 = (DIM_TYPE) array_size(array,3); - $5 = (DATA_TYPE*) array_data(array); -} -%typemap(freearg) - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_ARRAY4) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DATA_TYPE* IN_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, - * DIM_TYPE DIM3, DIM_TYPE DIM4) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* IN_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* IN_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[4] = { -1, -1, -1, -1 }; - array = obj_to_array_fortran_allow_conversion($input, DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 4) || - !require_size(array, size, 4) | !require_fortran(array)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = (DIM_TYPE) array_size(array,0); - $3 = (DIM_TYPE) array_size(array,1); - $4 = (DIM_TYPE) array_size(array,2); - $5 = (DIM_TYPE) array_size(array,3); -} -%typemap(freearg) - (DATA_TYPE* IN_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, - * DATA_TYPE* IN_FARRAY4) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_FARRAY4) -{ - $1 = is_array($input) || PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_FARRAY4) - (PyArrayObject* array=NULL, int is_new_object=0) -{ - npy_intp size[4] = { -1, -1, -1 , -1 }; - array = obj_to_array_fortran_allow_conversion($input, DATA_TYPECODE, - &is_new_object); - if (!array || !require_dimensions(array, 4) || - !require_size(array, size, 4) || !require_fortran(array)) SWIG_fail; - $1 = (DIM_TYPE) array_size(array,0); - $2 = (DIM_TYPE) array_size(array,1); - $3 = (DIM_TYPE) array_size(array,2); - $4 = (DIM_TYPE) array_size(array,3); - $5 = (DATA_TYPE*) array_data(array); -} -%typemap(freearg) - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_FARRAY4) -{ - if (is_new_object$argnum && array$argnum) - { Py_DECREF(array$argnum); } -} - -/***************************/ -/* In-Place Array Typemaps */ -/***************************/ - -/* Typemap suite for (DATA_TYPE INPLACE_ARRAY1[ANY]) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE INPLACE_ARRAY1[ANY]) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE INPLACE_ARRAY1[ANY]) - (PyArrayObject* array=NULL) -{ - npy_intp size[1] = { $1_dim0 }; - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,1) || !require_size(array, size, 1) || - !require_contiguous(array) || !require_native(array)) SWIG_fail; - $1 = ($1_ltype) array_data(array); -} - -/* Typemap suite for (DATA_TYPE* INPLACE_ARRAY1, DIM_TYPE DIM1) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* INPLACE_ARRAY1, DIM_TYPE DIM1) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* INPLACE_ARRAY1, DIM_TYPE DIM1) - (PyArrayObject* array=NULL, int i=1) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,1) || !require_contiguous(array) - || !require_native(array)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = 1; - for (i=0; i < array_numdims(array); ++i) $2 *= array_size(array,i); -} - -/* Typemap suite for (DIM_TYPE DIM1, DATA_TYPE* INPLACE_ARRAY1) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DATA_TYPE* INPLACE_ARRAY1) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DATA_TYPE* INPLACE_ARRAY1) - (PyArrayObject* array=NULL, int i=0) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,1) || !require_contiguous(array) - || !require_native(array)) SWIG_fail; - $1 = 1; - for (i=0; i < array_numdims(array); ++i) $1 *= array_size(array,i); - $2 = (DATA_TYPE*) array_data(array); -} - -/* Typemap suite for (DATA_TYPE INPLACE_ARRAY2[ANY][ANY]) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE INPLACE_ARRAY2[ANY][ANY]) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE INPLACE_ARRAY2[ANY][ANY]) - (PyArrayObject* array=NULL) -{ - npy_intp size[2] = { $1_dim0, $1_dim1 }; - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,2) || !require_size(array, size, 2) || - !require_contiguous(array) || !require_native(array)) SWIG_fail; - $1 = ($1_ltype) array_data(array); -} - -/* Typemap suite for (DATA_TYPE* INPLACE_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* INPLACE_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* INPLACE_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) - (PyArrayObject* array=NULL) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,2) || !require_contiguous(array) - || !require_native(array)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = (DIM_TYPE) array_size(array,0); - $3 = (DIM_TYPE) array_size(array,1); -} - -/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_ARRAY2) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_ARRAY2) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_ARRAY2) - (PyArrayObject* array=NULL) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,2) || !require_contiguous(array) || - !require_native(array)) SWIG_fail; - $1 = (DIM_TYPE) array_size(array,0); - $2 = (DIM_TYPE) array_size(array,1); - $3 = (DATA_TYPE*) array_data(array); -} - -/* Typemap suite for (DATA_TYPE* INPLACE_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* INPLACE_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* INPLACE_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2) - (PyArrayObject* array=NULL) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,2) || !require_contiguous(array) - || !require_native(array) || !require_fortran(array)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = (DIM_TYPE) array_size(array,0); - $3 = (DIM_TYPE) array_size(array,1); -} - -/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_FARRAY2) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_FARRAY2) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_FARRAY2) - (PyArrayObject* array=NULL) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,2) || !require_contiguous(array) || - !require_native(array) || !require_fortran(array)) SWIG_fail; - $1 = (DIM_TYPE) array_size(array,0); - $2 = (DIM_TYPE) array_size(array,1); - $3 = (DATA_TYPE*) array_data(array); -} - -/* Typemap suite for (DATA_TYPE INPLACE_ARRAY3[ANY][ANY][ANY]) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE INPLACE_ARRAY3[ANY][ANY][ANY]) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE INPLACE_ARRAY3[ANY][ANY][ANY]) - (PyArrayObject* array=NULL) -{ - npy_intp size[3] = { $1_dim0, $1_dim1, $1_dim2 }; - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,3) || !require_size(array, size, 3) || - !require_contiguous(array) || !require_native(array)) SWIG_fail; - $1 = ($1_ltype) array_data(array); -} - -/* Typemap suite for (DATA_TYPE* INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, - * DIM_TYPE DIM3) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) - (PyArrayObject* array=NULL) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,3) || !require_contiguous(array) || - !require_native(array)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = (DIM_TYPE) array_size(array,0); - $3 = (DIM_TYPE) array_size(array,1); - $4 = (DIM_TYPE) array_size(array,2); -} - -/* Typemap suite for (DATA_TYPE** INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, - * DIM_TYPE DIM3) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE** INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) -{ - $1 = PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE** INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) - (DATA_TYPE** array=NULL, PyArrayObject** object_array=NULL) -{ - npy_intp size[2] = { -1, -1 }; - PyArrayObject* temp_array; - Py_ssize_t i; - - /* length of the list */ - $2 = PyList_Size($input); - - /* the arrays */ - array = (DATA_TYPE **)malloc($2*sizeof(DATA_TYPE *)); - object_array = (PyArrayObject **)calloc($2,sizeof(PyArrayObject *)); - - if (array == NULL || object_array == NULL) - { - SWIG_fail; - } - - for (i=0; i<$2; i++) - { - temp_array = obj_to_array_no_conversion(PySequence_GetItem($input,i), DATA_TYPECODE); - - /* the new array must be stored so that it can be destroyed in freearg */ - object_array[i] = temp_array; - - if ( !temp_array || !require_dimensions(temp_array, 2) || - !require_contiguous(temp_array) || - !require_native(temp_array) || - !PyArray_EquivTypenums(array_type(temp_array), DATA_TYPECODE) - ) SWIG_fail; - - /* store the size of the first array in the list, then use that for comparison. */ - if (i == 0) - { - size[0] = array_size(temp_array,0); - size[1] = array_size(temp_array,1); - } - - if (!require_size(temp_array, size, 2)) SWIG_fail; - - array[i] = (DATA_TYPE*) array_data(temp_array); - } - - $1 = (DATA_TYPE**) array; - $3 = (DIM_TYPE) size[0]; - $4 = (DIM_TYPE) size[1]; -} -%typemap(freearg) - (DATA_TYPE** INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) -{ - if (array$argnum!=NULL) free(array$argnum); - if (object_array$argnum!=NULL) free(object_array$argnum); -} - -/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, - * DATA_TYPE* INPLACE_ARRAY3) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* INPLACE_ARRAY3) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* INPLACE_ARRAY3) - (PyArrayObject* array=NULL) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,3) || !require_contiguous(array) - || !require_native(array)) SWIG_fail; - $1 = (DIM_TYPE) array_size(array,0); - $2 = (DIM_TYPE) array_size(array,1); - $3 = (DIM_TYPE) array_size(array,2); - $4 = (DATA_TYPE*) array_data(array); -} - -/* Typemap suite for (DATA_TYPE* INPLACE_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, - * DIM_TYPE DIM3) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* INPLACE_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* INPLACE_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3) - (PyArrayObject* array=NULL) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,3) || !require_contiguous(array) || - !require_native(array) || !require_fortran(array)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = (DIM_TYPE) array_size(array,0); - $3 = (DIM_TYPE) array_size(array,1); - $4 = (DIM_TYPE) array_size(array,2); -} - -/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, - * DATA_TYPE* INPLACE_FARRAY3) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* INPLACE_FARRAY3) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* INPLACE_FARRAY3) - (PyArrayObject* array=NULL) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,3) || !require_contiguous(array) - || !require_native(array) || !require_fortran(array)) SWIG_fail; - $1 = (DIM_TYPE) array_size(array,0); - $2 = (DIM_TYPE) array_size(array,1); - $3 = (DIM_TYPE) array_size(array,2); - $4 = (DATA_TYPE*) array_data(array); -} - -/* Typemap suite for (DATA_TYPE INPLACE_ARRAY4[ANY][ANY][ANY][ANY]) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE INPLACE_ARRAY4[ANY][ANY][ANY][ANY]) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE INPLACE_ARRAY4[ANY][ANY][ANY][ANY]) - (PyArrayObject* array=NULL) -{ - npy_intp size[4] = { $1_dim0, $1_dim1, $1_dim2 , $1_dim3 }; - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,4) || !require_size(array, size, 4) || - !require_contiguous(array) || !require_native(array)) SWIG_fail; - $1 = ($1_ltype) array_data(array); -} - -/* Typemap suite for (DATA_TYPE* INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, - * DIM_TYPE DIM3, DIM_TYPE DIM4) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) - (PyArrayObject* array=NULL) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,4) || !require_contiguous(array) || - !require_native(array)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = (DIM_TYPE) array_size(array,0); - $3 = (DIM_TYPE) array_size(array,1); - $4 = (DIM_TYPE) array_size(array,2); - $5 = (DIM_TYPE) array_size(array,3); -} - -/* Typemap suite for (DATA_TYPE** INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, - * DIM_TYPE DIM3, DIM_TYPE DIM4) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE** INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) -{ - $1 = PySequence_Check($input); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE** INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) - (DATA_TYPE** array=NULL, PyArrayObject** object_array=NULL) -{ - npy_intp size[3] = { -1, -1, -1 }; - PyArrayObject* temp_array; - Py_ssize_t i; - - /* length of the list */ - $2 = PyList_Size($input); - - /* the arrays */ - array = (DATA_TYPE **)malloc($2*sizeof(DATA_TYPE *)); - object_array = (PyArrayObject **)calloc($2,sizeof(PyArrayObject *)); - - if (array == NULL || object_array == NULL) - { - SWIG_fail; - } - - for (i=0; i<$2; i++) - { - temp_array = obj_to_array_no_conversion(PySequence_GetItem($input,i), DATA_TYPECODE); - - /* the new array must be stored so that it can be destroyed in freearg */ - object_array[i] = temp_array; - - if ( !temp_array || !require_dimensions(temp_array, 3) || - !require_contiguous(temp_array) || - !require_native(temp_array) || - !PyArray_EquivTypenums(array_type(temp_array), DATA_TYPECODE) - ) SWIG_fail; - - /* store the size of the first array in the list, then use that for comparison. */ - if (i == 0) - { - size[0] = array_size(temp_array,0); - size[1] = array_size(temp_array,1); - size[2] = array_size(temp_array,2); - } - - if (!require_size(temp_array, size, 3)) SWIG_fail; - - array[i] = (DATA_TYPE*) array_data(temp_array); - } - - $1 = (DATA_TYPE**) array; - $3 = (DIM_TYPE) size[0]; - $4 = (DIM_TYPE) size[1]; - $5 = (DIM_TYPE) size[2]; -} -%typemap(freearg) - (DATA_TYPE** INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) -{ - if (array$argnum!=NULL) free(array$argnum); - if (object_array$argnum!=NULL) free(object_array$argnum); -} - -/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, - * DATA_TYPE* INPLACE_ARRAY4) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* INPLACE_ARRAY4) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* INPLACE_ARRAY4) - (PyArrayObject* array=NULL) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,4) || !require_contiguous(array) - || !require_native(array)) SWIG_fail; - $1 = (DIM_TYPE) array_size(array,0); - $2 = (DIM_TYPE) array_size(array,1); - $3 = (DIM_TYPE) array_size(array,2); - $4 = (DIM_TYPE) array_size(array,3); - $5 = (DATA_TYPE*) array_data(array); -} - -/* Typemap suite for (DATA_TYPE* INPLACE_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, - * DIM_TYPE DIM3, DIM_TYPE DIM4) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* INPLACE_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* INPLACE_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4) - (PyArrayObject* array=NULL) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,4) || !require_contiguous(array) || - !require_native(array) || !require_fortran(array)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = (DIM_TYPE) array_size(array,0); - $3 = (DIM_TYPE) array_size(array,1); - $4 = (DIM_TYPE) array_size(array,2); - $5 = (DIM_TYPE) array_size(array,3); -} - -/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, - * DATA_TYPE* INPLACE_FARRAY4) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* INPLACE_FARRAY4) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* INPLACE_FARRAY4) - (PyArrayObject* array=NULL) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_dimensions(array,4) || !require_contiguous(array) - || !require_native(array) || !require_fortran(array)) SWIG_fail; - $1 = (DIM_TYPE) array_size(array,0); - $2 = (DIM_TYPE) array_size(array,1); - $3 = (DIM_TYPE) array_size(array,2); - $4 = (DIM_TYPE) array_size(array,3); - $5 = (DATA_TYPE*) array_data(array); -} - -/*************************/ -/* Argout Array Typemaps */ -/*************************/ - -/* Typemap suite for (DATA_TYPE ARGOUT_ARRAY1[ANY]) - */ -%typemap(in,numinputs=0, - fragment="NumPy_Backward_Compatibility,NumPy_Macros") - (DATA_TYPE ARGOUT_ARRAY1[ANY]) - (PyObject* array = NULL) -{ - npy_intp dims[1] = { $1_dim0 }; - array = PyArray_SimpleNew(1, dims, DATA_TYPECODE); - if (!array) SWIG_fail; - $1 = ($1_ltype) array_data(array); -} -%typemap(argout) - (DATA_TYPE ARGOUT_ARRAY1[ANY]) -{ - $result = SWIG_AppendOutput($result,(PyObject*)array$argnum); -} - -/* Typemap suite for (DATA_TYPE* ARGOUT_ARRAY1, DIM_TYPE DIM1) - */ -%typemap(in,numinputs=1, - fragment="NumPy_Fragments") - (DATA_TYPE* ARGOUT_ARRAY1, DIM_TYPE DIM1) - (PyObject* array = NULL) -{ - npy_intp dims[1]; - if (!PyLong_Check($input)) - { - const char* typestring = pytype_string($input); - PyErr_Format(PyExc_TypeError, - "Int dimension expected. '%s' given.", - typestring); - SWIG_fail; - } - $2 = (DIM_TYPE) PyLong_AsSsize_t($input); - if ($2 == -1 && PyErr_Occurred()) SWIG_fail; - dims[0] = (npy_intp) $2; - array = PyArray_SimpleNew(1, dims, DATA_TYPECODE); - if (!array) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); -} -%typemap(argout) - (DATA_TYPE* ARGOUT_ARRAY1, DIM_TYPE DIM1) -{ - $result = SWIG_AppendOutput($result,(PyObject*)array$argnum); -} - -/* Typemap suite for (DIM_TYPE DIM1, DATA_TYPE* ARGOUT_ARRAY1) - */ -%typemap(in,numinputs=1, - fragment="NumPy_Fragments") - (DIM_TYPE DIM1, DATA_TYPE* ARGOUT_ARRAY1) - (PyObject* array = NULL) -{ - npy_intp dims[1]; - if (!PyLong_Check($input)) - { - const char* typestring = pytype_string($input); - PyErr_Format(PyExc_TypeError, - "Int dimension expected. '%s' given.", - typestring); - SWIG_fail; - } - $1 = (DIM_TYPE) PyLong_AsSsize_t($input); - if ($1 == -1 && PyErr_Occurred()) SWIG_fail; - dims[0] = (npy_intp) $1; - array = PyArray_SimpleNew(1, dims, DATA_TYPECODE); - if (!array) SWIG_fail; - $2 = (DATA_TYPE*) array_data(array); -} -%typemap(argout) - (DIM_TYPE DIM1, DATA_TYPE* ARGOUT_ARRAY1) -{ - $result = SWIG_AppendOutput($result,(PyObject*)array$argnum); -} - -/* Typemap suite for (DATA_TYPE ARGOUT_ARRAY2[ANY][ANY]) - */ -%typemap(in,numinputs=0, - fragment="NumPy_Backward_Compatibility,NumPy_Macros") - (DATA_TYPE ARGOUT_ARRAY2[ANY][ANY]) - (PyObject* array = NULL) -{ - npy_intp dims[2] = { $1_dim0, $1_dim1 }; - array = PyArray_SimpleNew(2, dims, DATA_TYPECODE); - if (!array) SWIG_fail; - $1 = ($1_ltype) array_data(array); -} -%typemap(argout) - (DATA_TYPE ARGOUT_ARRAY2[ANY][ANY]) -{ - $result = SWIG_AppendOutput($result,(PyObject*)array$argnum); -} - -/* Typemap suite for (DATA_TYPE ARGOUT_ARRAY3[ANY][ANY][ANY]) - */ -%typemap(in,numinputs=0, - fragment="NumPy_Backward_Compatibility,NumPy_Macros") - (DATA_TYPE ARGOUT_ARRAY3[ANY][ANY][ANY]) - (PyObject* array = NULL) -{ - npy_intp dims[3] = { $1_dim0, $1_dim1, $1_dim2 }; - array = PyArray_SimpleNew(3, dims, DATA_TYPECODE); - if (!array) SWIG_fail; - $1 = ($1_ltype) array_data(array); -} -%typemap(argout) - (DATA_TYPE ARGOUT_ARRAY3[ANY][ANY][ANY]) -{ - $result = SWIG_AppendOutput($result,(PyObject*)array$argnum); -} - -/* Typemap suite for (DATA_TYPE ARGOUT_ARRAY4[ANY][ANY][ANY][ANY]) - */ -%typemap(in,numinputs=0, - fragment="NumPy_Backward_Compatibility,NumPy_Macros") - (DATA_TYPE ARGOUT_ARRAY4[ANY][ANY][ANY][ANY]) - (PyObject* array = NULL) -{ - npy_intp dims[4] = { $1_dim0, $1_dim1, $1_dim2, $1_dim3 }; - array = PyArray_SimpleNew(4, dims, DATA_TYPECODE); - if (!array) SWIG_fail; - $1 = ($1_ltype) array_data(array); -} -%typemap(argout) - (DATA_TYPE ARGOUT_ARRAY4[ANY][ANY][ANY][ANY]) -{ - $result = SWIG_AppendOutput($result,(PyObject*)array$argnum); -} - -/*****************************/ -/* Argoutview Array Typemaps */ -/*****************************/ - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_ARRAY1, DIM_TYPE* DIM1) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEW_ARRAY1, DIM_TYPE* DIM1 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim_temp) -{ - $1 = &data_temp; - $2 = &dim_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility") - (DATA_TYPE** ARGOUTVIEW_ARRAY1, DIM_TYPE* DIM1) -{ - npy_intp dims[1] = { *$2 }; - PyObject* obj = PyArray_SimpleNewFromData(1, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DATA_TYPE** ARGOUTVIEW_ARRAY1) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1 , DATA_TYPE** ARGOUTVIEW_ARRAY1) - (DIM_TYPE dim_temp, DATA_TYPE* data_temp = NULL ) -{ - $1 = &dim_temp; - $2 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility") - (DIM_TYPE* DIM1, DATA_TYPE** ARGOUTVIEW_ARRAY1) -{ - npy_intp dims[1] = { *$1 }; - PyObject* obj = PyArray_SimpleNewFromData(1, dims, DATA_TYPECODE, (void*)(*$2)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_ARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEW_ARRAY2, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp) -{ - $1 = &data_temp; - $2 = &dim1_temp; - $3 = &dim2_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility") - (DATA_TYPE** ARGOUTVIEW_ARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2) -{ - npy_intp dims[2] = { *$2, *$3 }; - PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEW_ARRAY2) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DATA_TYPE** ARGOUTVIEW_ARRAY2) - (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DATA_TYPE* data_temp = NULL ) -{ - $1 = &dim1_temp; - $2 = &dim2_temp; - $3 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility") - (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEW_ARRAY2) -{ - npy_intp dims[2] = { *$1, *$2 }; - PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$3)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_FARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEW_FARRAY2, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp) -{ - $1 = &data_temp; - $2 = &dim1_temp; - $3 = &dim2_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements") - (DATA_TYPE** ARGOUTVIEW_FARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2) -{ - npy_intp dims[2] = { *$2, *$3 }; - PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array || !require_fortran(array)) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEW_FARRAY2) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DATA_TYPE** ARGOUTVIEW_FARRAY2) - (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DATA_TYPE* data_temp = NULL ) -{ - $1 = &dim1_temp; - $2 = &dim2_temp; - $3 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements") - (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEW_FARRAY2) -{ - npy_intp dims[2] = { *$1, *$2 }; - PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$3)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array || !require_fortran(array)) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_ARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, - DIM_TYPE* DIM3) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEW_ARRAY3, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp) -{ - $1 = &data_temp; - $2 = &dim1_temp; - $3 = &dim2_temp; - $4 = &dim3_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility") - (DATA_TYPE** ARGOUTVIEW_ARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3) -{ - npy_intp dims[3] = { *$2, *$3, *$4 }; - PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, - DATA_TYPE** ARGOUTVIEW_ARRAY3) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEW_ARRAY3) - (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DATA_TYPE* data_temp = NULL) -{ - $1 = &dim1_temp; - $2 = &dim2_temp; - $3 = &dim3_temp; - $4 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility") - (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEW_ARRAY3) -{ - npy_intp dims[3] = { *$1, *$2, *$3 }; - PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$4)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_FARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, - DIM_TYPE* DIM3) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEW_FARRAY3, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp) -{ - $1 = &data_temp; - $2 = &dim1_temp; - $3 = &dim2_temp; - $4 = &dim3_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements") - (DATA_TYPE** ARGOUTVIEW_FARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3) -{ - npy_intp dims[3] = { *$2, *$3, *$4 }; - PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array || !require_fortran(array)) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, - DATA_TYPE** ARGOUTVIEW_FARRAY3) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DATA_TYPE** ARGOUTVIEW_FARRAY3) - (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DATA_TYPE* data_temp = NULL ) -{ - $1 = &dim1_temp; - $2 = &dim2_temp; - $3 = &dim3_temp; - $4 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements") - (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEW_FARRAY3) -{ - npy_intp dims[3] = { *$1, *$2, *$3 }; - PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$4)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array || !require_fortran(array)) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, - DIM_TYPE* DIM3, DIM_TYPE* DIM4) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEW_ARRAY4, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp) -{ - $1 = &data_temp; - $2 = &dim1_temp; - $3 = &dim2_temp; - $4 = &dim3_temp; - $5 = &dim4_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility") - (DATA_TYPE** ARGOUTVIEW_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4) -{ - npy_intp dims[4] = { *$2, *$3, *$4 , *$5 }; - PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, - DATA_TYPE** ARGOUTVIEW_ARRAY4) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 , DATA_TYPE** ARGOUTVIEW_ARRAY4) - (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp, DATA_TYPE* data_temp = NULL ) -{ - $1 = &dim1_temp; - $2 = &dim2_temp; - $3 = &dim3_temp; - $4 = &dim4_temp; - $5 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility") - (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEW_ARRAY4) -{ - npy_intp dims[4] = { *$1, *$2, *$3 , *$4 }; - PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$5)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, - DIM_TYPE* DIM3, DIM_TYPE* DIM4) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEW_FARRAY4, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp) -{ - $1 = &data_temp; - $2 = &dim1_temp; - $3 = &dim2_temp; - $4 = &dim3_temp; - $5 = &dim4_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements") - (DATA_TYPE** ARGOUTVIEW_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4) -{ - npy_intp dims[4] = { *$2, *$3, *$4 , *$5 }; - PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array || !require_fortran(array)) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, - DATA_TYPE** ARGOUTVIEW_FARRAY4) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 , DATA_TYPE** ARGOUTVIEW_FARRAY4) - (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp, DATA_TYPE* data_temp = NULL ) -{ - $1 = &dim1_temp; - $2 = &dim2_temp; - $3 = &dim3_temp; - $4 = &dim4_temp; - $5 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements") - (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEW_FARRAY4) -{ - npy_intp dims[4] = { *$1, *$2, *$3 , *$4 }; - PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$5)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array || !require_fortran(array)) SWIG_fail; - $result = SWIG_AppendOutput($result,obj); -} - -/*************************************/ -/* Managed Argoutview Array Typemaps */ -/*************************************/ - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_ARRAY1, DIM_TYPE* DIM1) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEWM_ARRAY1, DIM_TYPE* DIM1 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim_temp) -{ - $1 = &data_temp; - $2 = &dim_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Utilities") - (DATA_TYPE** ARGOUTVIEWM_ARRAY1, DIM_TYPE* DIM1) -{ - npy_intp dims[1] = { *$2 }; - PyObject* obj = PyArray_SimpleNewFromData(1, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DATA_TYPE** ARGOUTVIEWM_ARRAY1) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1 , DATA_TYPE** ARGOUTVIEWM_ARRAY1) - (DIM_TYPE dim_temp, DATA_TYPE* data_temp = NULL ) -{ - $1 = &dim_temp; - $2 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Utilities") - (DIM_TYPE* DIM1, DATA_TYPE** ARGOUTVIEWM_ARRAY1) -{ - npy_intp dims[1] = { *$1 }; - PyObject* obj = PyArray_SimpleNewFromData(1, dims, DATA_TYPECODE, (void*)(*$2)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$2), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_ARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEWM_ARRAY2, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp) -{ - $1 = &data_temp; - $2 = &dim1_temp; - $3 = &dim2_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Utilities") - (DATA_TYPE** ARGOUTVIEWM_ARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2) -{ - npy_intp dims[2] = { *$2, *$3 }; - PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEWM_ARRAY2) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DATA_TYPE** ARGOUTVIEWM_ARRAY2) - (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DATA_TYPE* data_temp = NULL ) -{ - $1 = &dim1_temp; - $2 = &dim2_temp; - $3 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Utilities") - (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEWM_ARRAY2) -{ - npy_intp dims[2] = { *$1, *$2 }; - PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$3)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$3), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_FARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEWM_FARRAY2, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp) -{ - $1 = &data_temp; - $2 = &dim1_temp; - $3 = &dim2_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities") - (DATA_TYPE** ARGOUTVIEWM_FARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2) -{ - npy_intp dims[2] = { *$2, *$3 }; - PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array || !require_fortran(array)) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEWM_FARRAY2) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DATA_TYPE** ARGOUTVIEWM_FARRAY2) - (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DATA_TYPE* data_temp = NULL ) -{ - $1 = &dim1_temp; - $2 = &dim2_temp; - $3 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities") - (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEWM_FARRAY2) -{ - npy_intp dims[2] = { *$1, *$2 }; - PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$3)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array || !require_fortran(array)) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$3), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_ARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, - DIM_TYPE* DIM3) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEWM_ARRAY3, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp) -{ - $1 = &data_temp; - $2 = &dim1_temp; - $3 = &dim2_temp; - $4 = &dim3_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Utilities") - (DATA_TYPE** ARGOUTVIEWM_ARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3) -{ - npy_intp dims[3] = { *$2, *$3, *$4 }; - PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, - DATA_TYPE** ARGOUTVIEWM_ARRAY3) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DATA_TYPE** ARGOUTVIEWM_ARRAY3) - (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DATA_TYPE* data_temp = NULL ) -{ - $1 = &dim1_temp; - $2 = &dim2_temp; - $3 = &dim3_temp; - $4 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Utilities") - (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEWM_ARRAY3) -{ - npy_intp dims[3] = { *$1, *$2, *$3 }; - PyObject* obj= PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$4)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$4), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_FARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, - DIM_TYPE* DIM3) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEWM_FARRAY3, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp) -{ - $1 = &data_temp; - $2 = &dim1_temp; - $3 = &dim2_temp; - $4 = &dim3_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities") - (DATA_TYPE** ARGOUTVIEWM_FARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3) -{ - npy_intp dims[3] = { *$2, *$3, *$4 }; - PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array || !require_fortran(array)) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, - DATA_TYPE** ARGOUTVIEWM_FARRAY3) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DATA_TYPE** ARGOUTVIEWM_FARRAY3) - (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DATA_TYPE* data_temp = NULL ) -{ - $1 = &dim1_temp; - $2 = &dim2_temp; - $3 = &dim3_temp; - $4 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities") - (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEWM_FARRAY3) -{ - npy_intp dims[3] = { *$1, *$2, *$3 }; - PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$4)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array || !require_fortran(array)) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$4), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, - DIM_TYPE* DIM3, DIM_TYPE* DIM4) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEWM_ARRAY4, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp) -{ - $1 = &data_temp; - $2 = &dim1_temp; - $3 = &dim2_temp; - $4 = &dim3_temp; - $5 = &dim4_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Utilities") - (DATA_TYPE** ARGOUTVIEWM_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4) -{ - npy_intp dims[4] = { *$2, *$3, *$4 , *$5 }; - PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, - DATA_TYPE** ARGOUTVIEWM_ARRAY4) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 , DATA_TYPE** ARGOUTVIEWM_ARRAY4) - (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp, DATA_TYPE* data_temp = NULL ) -{ - $1 = &dim1_temp; - $2 = &dim2_temp; - $3 = &dim3_temp; - $4 = &dim4_temp; - $5 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Utilities") - (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEWM_ARRAY4) -{ - npy_intp dims[4] = { *$1, *$2, *$3 , *$4 }; - PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$5)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$5), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, - DIM_TYPE* DIM3, DIM_TYPE* DIM4) - */ -%typemap(in,numinputs=0) - (DATA_TYPE** ARGOUTVIEWM_FARRAY4, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 ) - (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp) -{ - $1 = &data_temp; - $2 = &dim1_temp; - $3 = &dim2_temp; - $4 = &dim3_temp; - $5 = &dim4_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities") - (DATA_TYPE** ARGOUTVIEWM_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4) -{ - npy_intp dims[4] = { *$2, *$3, *$4 , *$5 }; - PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$1)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array || !require_fortran(array)) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, - DATA_TYPE** ARGOUTVIEWM_FARRAY4) - */ -%typemap(in,numinputs=0) - (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 , DATA_TYPE** ARGOUTVIEWM_FARRAY4) - (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp, DATA_TYPE* data_temp = NULL ) -{ - $1 = &dim1_temp; - $2 = &dim2_temp; - $3 = &dim3_temp; - $4 = &dim4_temp; - $5 = &data_temp; -} -%typemap(argout, - fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities") - (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEWM_FARRAY4) -{ - npy_intp dims[4] = { *$1, *$2, *$3 , *$4 }; - PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$5)); - PyArrayObject* array = (PyArrayObject*) obj; - - if (!array || !require_fortran(array)) SWIG_fail; - -PyObject* cap = PyCapsule_New((void*)(*$5), SWIGPY_CAPSULE_NAME, free_cap); - -%#if NPY_API_VERSION < NPY_1_7_API_VERSION - PyArray_BASE(array) = cap; -%#else - PyArray_SetBaseObject(array,cap); -%#endif - - $result = SWIG_AppendOutput($result,obj); -} - -/**************************************/ -/* In-Place Array Typemap - flattened */ -/**************************************/ - -/* Typemap suite for (DATA_TYPE* INPLACE_ARRAY_FLAT, DIM_TYPE DIM_FLAT) - */ -%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY, - fragment="NumPy_Macros") - (DATA_TYPE* INPLACE_ARRAY_FLAT, DIM_TYPE DIM_FLAT) -{ - $1 = is_array($input) && PyArray_EquivTypenums(array_type($input), - DATA_TYPECODE); -} -%typemap(in, - fragment="NumPy_Fragments") - (DATA_TYPE* INPLACE_ARRAY_FLAT, DIM_TYPE DIM_FLAT) - (PyArrayObject* array=NULL, int i=1) -{ - array = obj_to_array_no_conversion($input, DATA_TYPECODE); - if (!array || !require_c_or_f_contiguous(array) - || !require_native(array)) SWIG_fail; - $1 = (DATA_TYPE*) array_data(array); - $2 = 1; - for (i=0; i < array_numdims(array); ++i) $2 *= array_size(array,i); -} - -%enddef /* %numpy_typemaps() macro */ -/* *************************************************************** */ - -/* Concrete instances of the %numpy_typemaps() macro: Each invocation - * below applies all of the typemaps above to the specified data type. - */ -%numpy_typemaps(signed char , NPY_BYTE , int) -%numpy_typemaps(unsigned char , NPY_UBYTE , int) -%numpy_typemaps(short , NPY_SHORT , int) -%numpy_typemaps(unsigned short , NPY_USHORT , int) -%numpy_typemaps(int , NPY_INT , int) -%numpy_typemaps(unsigned int , NPY_UINT , int) -%numpy_typemaps(long , NPY_LONG , int) -%numpy_typemaps(unsigned long , NPY_ULONG , int) -%numpy_typemaps(long long , NPY_LONGLONG , int) -%numpy_typemaps(unsigned long long, NPY_ULONGLONG, int) -%numpy_typemaps(float , NPY_FLOAT , int) -%numpy_typemaps(double , NPY_DOUBLE , int) -%numpy_typemaps(int8_t , NPY_INT8 , int) -%numpy_typemaps(int16_t , NPY_INT16 , int) -%numpy_typemaps(int32_t , NPY_INT32 , int) -%numpy_typemaps(int64_t , NPY_INT64 , int) -%numpy_typemaps(uint8_t , NPY_UINT8 , int) -%numpy_typemaps(uint16_t , NPY_UINT16 , int) -%numpy_typemaps(uint32_t , NPY_UINT32 , int) -%numpy_typemaps(uint64_t , NPY_UINT64 , int) - - -/* *************************************************************** - * The follow macro expansion does not work, because C++ bool is 4 - * bytes and NPY_BOOL is 1 byte - * - * %numpy_typemaps(bool, NPY_BOOL, int) - */ - -/* *************************************************************** - * On my Mac, I get the following warning for this macro expansion: - * 'swig/python detected a memory leak of type 'long double *', no destructor found.' - * - * %numpy_typemaps(long double, NPY_LONGDOUBLE, int) - */ - -#ifdef __cplusplus - -%include - -%numpy_typemaps(std::complex, NPY_CFLOAT , int) -%numpy_typemaps(std::complex, NPY_CDOUBLE, int) - -#endif - -#endif /* SWIGPYTHON */ From 35c0c244eda2abe0ac03eb9465ec3e2ec05d93c2 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 26 Jul 2025 11:47:33 -0700 Subject: [PATCH 05/45] Adding other non-ELM 1-D basis functions. --- src/tfc/utils/BF/BF.h | 10 +++++----- src/tfc/utils/BF/BF_Py.cc | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/tfc/utils/BF/BF.h b/src/tfc/utils/BF/BF.h index 03baa3b..6cbfe27 100644 --- a/src/tfc/utils/BF/BF.h +++ b/src/tfc/utils/BF/BF.h @@ -164,7 +164,7 @@ class CP: virtual public BasisFunc { class LeP: virtual public BasisFunc { public: /** LeP class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ - LeP(double x0, double xf, int* nCin, int ncDim0, int min): + LeP(double x0, double xf, const int* nCin, int ncDim0, int min): BasisFunc(x0,xf,nCin,ncDim0,min,-1.,1.){}; /** Dummy LeP class constructor. Used only in n-dimensions. */ @@ -186,7 +186,7 @@ class LeP: virtual public BasisFunc { class LaP: public BasisFunc { public: /** LaP class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ - LaP(double x0, double xf, int* nCin, int ncDim0, int min): + LaP(double x0, double xf, const int* nCin, int ncDim0, int min): BasisFunc(x0,xf,nCin,ncDim0,min){}; /** LaP class destructor.*/ ~LaP(){}; @@ -204,7 +204,7 @@ class LaP: public BasisFunc { class HoPpro: public BasisFunc { public: /** HoPpro class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ - HoPpro(double x0, double xf, int* nCin, int ncDim0, int min): + HoPpro(double x0, double xf, const int* nCin, int ncDim0, int min): BasisFunc(x0,xf,nCin,ncDim0,min){}; /** HoPpro class destructor.*/ ~HoPpro(){}; @@ -222,7 +222,7 @@ class HoPpro: public BasisFunc { class HoPphy: public BasisFunc { public: /** HoPphy class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ - HoPphy(double x0, double xf, int* nCin, int ncDim0, int min): + HoPphy(double x0, double xf, const int* nCin, int ncDim0, int min): BasisFunc(x0,xf,nCin,ncDim0,min){}; /** HoPphy class destructor.*/ ~HoPphy(){}; @@ -240,7 +240,7 @@ class HoPphy: public BasisFunc { class FS: virtual public BasisFunc { public: /** FS class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ - FS(double x0, double xf, int* nCin, int ncDim0, int min): + FS(double x0, double xf, const int* nCin, int ncDim0, int min): BasisFunc(x0,xf,nCin,ncDim0,min,-M_PI,M_PI){}; /** Dummy FS class constructor. Used only in n-dimensions. */ diff --git a/src/tfc/utils/BF/BF_Py.cc b/src/tfc/utils/BF/BF_Py.cc index beadbcb..141e24f 100644 --- a/src/tfc/utils/BF/BF_Py.cc +++ b/src/tfc/utils/BF/BF_Py.cc @@ -88,4 +88,19 @@ PYBIND11_MODULE(BF, m) { auto PyCP = py::class_ (m, "CP", py::multiple_inheritance()); add1DInit(PyCP); + + auto PyLeP = py::class_ (m, "LeP", py::multiple_inheritance()); + add1DInit(PyLeP); + + auto PyLaP = py::class_ (m, "LaP", py::multiple_inheritance()); + add1DInit(PyLaP); + + auto PyHoPpro = py::class_ (m, "HoPpro", py::multiple_inheritance()); + add1DInit(PyHoPpro); + + auto PyHoPphy = py::class_ (m, "HoPphy", py::multiple_inheritance()); + add1DInit(PyHoPphy); + + auto PyFS = py::class_ (m, "FS", py::multiple_inheritance()); + add1DInit(PyFS); } From f09c06fd15bc2d932ed24434956b6f2d7a2d3ddb Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 26 Jul 2025 11:47:47 -0700 Subject: [PATCH 06/45] Updating so we can find pybind11 from anywhere on different machines. --- setup.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/setup.py b/setup.py index 5f43977..91d3f4d 100644 --- a/setup.py +++ b/setup.py @@ -46,10 +46,16 @@ def build_extension(self, ext): extdir = Path(self.get_ext_fullpath(ext.name)).parents[0].absolute() bf_dir = extdir / "tfc" / "utils" / "BF" + import pybind11 + dark = Path(pybind11.__file__).parents[0] + pybind11_dir = dark / "share" / "cmake" / "pybind11" + + cfg = "Debug" if self.debug else "Release" cmake_args = [ f"-DCMAKE_BUILD_TYPE={cfg}", f"-DCMAKE_INSTALL_PREFIX={bf_dir}", + f"-Dpybind11_DIR={pybind11_dir}" ] # Optional: use Ninja if available From 42f490457aec1cba385ff220020dadfbcd53a613 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 26 Jul 2025 11:59:29 -0700 Subject: [PATCH 07/45] Adding pybind11 and removing unecessary imports. --- pyproject.toml | 1 + requirements.txt | 1 + setup.py | 2 -- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a3d1f7f..bb9fa2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,6 +2,7 @@ requires = ["setuptools>=42", "wheel", "numpy>=2.1", + "pybind11~=3.0.0", ] build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt index 1a8cb36..6d33da3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,7 @@ tqdm pandas openpyxl sympy +pybind11 # Optional # pdfCropMargins # Used to crop pdfs in PlotlyMakePlot diff --git a/setup.py b/setup.py index 91d3f4d..b8492bf 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,7 @@ import sys import os from pathlib import Path -import numpy from setuptools import setup, Extension, find_packages -from setuptools.command.build_py import build_py as _build_py from setuptools.command.build_ext import build_ext from subprocess import check_call From d9956df8c053137dc287ab6d331032ad6a1c59c6 Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 27 Jul 2025 07:27:42 -0700 Subject: [PATCH 08/45] Adding in 1-D ELM functions. --- src/tfc/utils/BF/BF.cxx | 10 ++++---- src/tfc/utils/BF/BF.h | 20 ++++++++-------- src/tfc/utils/BF/BF_Py.cc | 49 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 63 insertions(+), 16 deletions(-) diff --git a/src/tfc/utils/BF/BF.cxx b/src/tfc/utils/BF/BF.cxx index ac16678..93ea82d 100644 --- a/src/tfc/utils/BF/BF.cxx +++ b/src/tfc/utils/BF/BF.cxx @@ -578,7 +578,7 @@ void FS::Hint(const int d, const double* x, const int nOut, double* dark){ // ELM: ********************************************************************** // ELM base class -ELM::ELM(double x0, double xf, int* nCin, int ncDim0, int min): +ELM::ELM(double x0, double xf, const int* nCin, int ncDim0, int min): BasisFunc(x0,xf,nCin,ncDim0,min,0.,1.){ int k; @@ -596,7 +596,7 @@ ELM::~ELM(){ delete[] w; }; -void ELM::setW(double* arrIn, int nIn){ +void ELM::setW(const double* arrIn, int nIn){ if (nIn != m){ printf("Failure in setW function. Weight vector is the wrong size. Exiting program.\n"); exit(EXIT_FAILURE); @@ -605,7 +605,7 @@ void ELM::setW(double* arrIn, int nIn){ w[k] = arrIn[k]; }; -void ELM::setB(double* arrIn, int nIn){ +void ELM::setB(const double* arrIn, int nIn){ if (nIn != m){ printf("Failure in setB function. Bias vector is the wrong size. Exiting program.\n"); exit(EXIT_FAILURE); @@ -1238,7 +1238,7 @@ nELM::~nELM(){ delete[] w; }; -void nELM::setW(double* arrIn, int dimIn, int nIn){ +void nELM::setW(const double* arrIn, int dimIn, int nIn){ if ((nIn != m)||(dimIn != dim)){ printf("Failure in setW function. Weight vector is the wrong size. Exiting program.\n"); exit(EXIT_FAILURE); @@ -1247,7 +1247,7 @@ void nELM::setW(double* arrIn, int dimIn, int nIn){ w[k] = arrIn[k]; }; -void nELM::setB(double* arrIn, int nIn){ +void nELM::setB(const double* arrIn, int nIn){ if (nIn != m){ printf("Failure in setB function. Bias vector is the wrong size. Exiting program.\n"); exit(EXIT_FAILURE); diff --git a/src/tfc/utils/BF/BF.h b/src/tfc/utils/BF/BF.h index 6cbfe27..1bda52d 100644 --- a/src/tfc/utils/BF/BF.h +++ b/src/tfc/utils/BF/BF.h @@ -272,7 +272,7 @@ class ELM: public BasisFunc { double *b; /** ELM class constructor. Calls BasisFunc class constructor and sets up weights and biases for the ELM. See BasisFunc class for more details. */ - ELM(double x0, double xf, int* nCin, int ncDim0, int min); + ELM(double x0, double xf, const int* nCin, int ncDim0, int min); /** ELM class destructor.*/ virtual ~ELM(); @@ -281,13 +281,13 @@ class ELM: public BasisFunc { void getW(double** arrOut, int* nOut); /** Python hook to set ELM weights. */ - void setW(double* arrIn, int nIn); + void setW(const double* arrIn, int nIn); /** Python hook to return ELM biases. */ void getB(double** arrOut, int* nOut); /** Python hook to set ELM biases. */ - void setB(double* arrIn, int nIn); + void setB(const double* arrIn, int nIn); protected: /** Function used internally to create the basis function matrices. */ @@ -307,7 +307,7 @@ class ELMSigmoid: public ELM { public: /** ELMSigmoid class constructor. Calls ELM class constructor. See ELM class for more details. */ - ELMSigmoid(double x0, double xf, int* nCin, int ncDim0, int min): + ELMSigmoid(double x0, double xf, const int* nCin, int ncDim0, int min): ELM(x0,xf,nCin,ncDim0,min){}; /** ELMSigmoid class destructor.*/ @@ -325,7 +325,7 @@ class ELMReLU: public ELM { public: /** ELMReLU class constructor. Calls ELM class constructor. See ELM class for more details. */ - ELMReLU(double x0, double xf, int* nCin, int ncDim0, int min): + ELMReLU(double x0, double xf, const int* nCin, int ncDim0, int min): ELM(x0,xf,nCin,ncDim0,min){}; /** ELMReLU class destructor.*/ @@ -343,7 +343,7 @@ class ELMTanh: public ELM { public: /** ELMTanh class constructor. Calls ELM class constructor. See ELM class for more details. */ - ELMTanh(double x0, double xf, int* nCin, int ncDim0, int min): + ELMTanh(double x0, double xf, const int* nCin, int ncDim0, int min): ELM(x0,xf,nCin,ncDim0,min){}; /** ELMTanh class destructor.*/ @@ -361,7 +361,7 @@ class ELMSin: public ELM { public: /** ELMSin class constructor. Calls ELM class constructor. See ELM class for more details. */ - ELMSin(double x0, double xf, int* nCin, int ncDim0, int min): + ELMSin(double x0, double xf, const int* nCin, int ncDim0, int min): ELM(x0,xf,nCin,ncDim0,min){}; /** ELMSin class destructor.*/ @@ -379,7 +379,7 @@ class ELMSwish: public ELM { public: /** ELMSwish class constructor. Calls ELM class constructor. See ELM class for more details. */ - ELMSwish(double x0, double xf, int* nCin, int ncDim0, int min): + ELMSwish(double x0, double xf, const int* nCin, int ncDim0, int min): ELM(x0,xf,nCin,ncDim0,min){}; /** ELMSwish class destructor.*/ @@ -552,7 +552,7 @@ class nELM: public nBasisFunc { virtual ~nELM(); /** Python hook to return nELM weights. */ - void setW(double* arrIn, int dimIn, int nIn); + void setW(const double* arrIn, int dimIn, int nIn); /** Python hook to set nELM weights. */ void getW(int* dimOut, int* nOut, double** arrOut); @@ -561,7 +561,7 @@ class nELM: public nBasisFunc { void getB(double** arrOut, int* nOut); /** Python hook to set nELM biases. */ - void setB(double* arrIn, int nIn); + void setB(const double* arrIn, int nIn); private: diff --git a/src/tfc/utils/BF/BF_Py.cc b/src/tfc/utils/BF/BF_Py.cc index 141e24f..af0e8d3 100644 --- a/src/tfc/utils/BF/BF_Py.cc +++ b/src/tfc/utils/BF/BF_Py.cc @@ -15,7 +15,7 @@ void add1DInit(auto& c) { py::arg("nC"), py::arg("min"), R"( - BasisFunc constructor. + Constructor. Parameters: x0: Start of domain @@ -103,4 +103,51 @@ PYBIND11_MODULE(BF, m) { auto PyFS = py::class_ (m, "FS", py::multiple_inheritance()); add1DInit(PyFS); + + py::class_ (m, "ELM") + .def_property("b", + [](ELM& self) { + double* data = nullptr; + int nOut; + self.getB(&data, &nOut); + + auto capsule = py::capsule(data, [](void* f) { + double* d = reinterpret_cast(f); + free(d); + }); + return py::array_t(self.m, data, capsule); + }, + [](ELM& self, py::array_t b) { + self.setB(b.data(), b.size()); + }) + .def_property("w", + [](ELM& self) { + double* data = nullptr; + int nOut; + self.getW(&data, &nOut); + + auto capsule = py::capsule(data, [](void* f) { + double* d = reinterpret_cast(f); + free(d); + }); + return py::array_t(self.m, data, capsule); + }, + [](ELM& self, py::array_t w) { + self.setW(w.data(), w.size()); + }); + + auto PyELMSigmoid = py::class_ (m, "ELMSigmoid"); + add1DInit(PyELMSigmoid); + + auto PyELMReLU = py::class_ (m, "ELMReLU"); + add1DInit(PyELMReLU); + + auto PyELMTanh = py::class_ (m, "ELMTanh"); + add1DInit(PyELMTanh); + + auto PyELMSin = py::class_ (m, "ELMSin"); + add1DInit(PyELMSin); + + auto PyELMSwish = py::class_ (m, "ELMSwish"); + add1DInit(PyELMSwish); } From fc928ecd69b159c2db3b7b42e82588bdfa226526 Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 27 Jul 2025 07:29:32 -0700 Subject: [PATCH 09/45] Fixing classifiers warning. --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index b8492bf..6fac8a3 100644 --- a/setup.py +++ b/setup.py @@ -105,7 +105,6 @@ def build_extension(self, ext): ], classifiers=[ "Development Status :: 4 - Beta", - "License :: OSI Approved :: MIT License", "Natural Language :: English", "Programming Language :: C++", "Programming Language :: Python :: 3 :: Only", From 3e77389841d1674b1a037013e5a29b6207fef410 Mon Sep 17 00:00:00 2001 From: leakec Date: Tue, 29 Jul 2025 16:35:28 -0700 Subject: [PATCH 10/45] Working on updating ci to build. --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 26872c7..cf957a1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,8 +21,8 @@ jobs: - name: lint run: "python -m black --line-length 100 --check ./src/tfc" - - run: "sudo apt-get update && sudo apt-get install -y swig gcc g++" - - run: python -m pip install wheel setuptools numpy pytest + - run: "sudo apt-get update && sudo apt-get install -y gcc g++" + - run: python -m pip install wheel setuptools numpy pytest pybind11 - run: python setup.py bdist_wheel - run: pip install ./dist/*.whl From 4d66266454e28ecb971c6e93363c35af4e8d947e Mon Sep 17 00:00:00 2001 From: leakec Date: Tue, 29 Jul 2025 16:35:46 -0700 Subject: [PATCH 11/45] Got nBasisFuncs to compile and pass reg tests. Now, only nELMs remain. --- src/tfc/utils/BF/BF.cxx | 8 ++--- src/tfc/utils/BF/BF.h | 35 +++++++++--------- src/tfc/utils/BF/BF_Py.cc | 75 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 21 deletions(-) diff --git a/src/tfc/utils/BF/BF.cxx b/src/tfc/utils/BF/BF.cxx index 93ea82d..227a0c1 100644 --- a/src/tfc/utils/BF/BF.cxx +++ b/src/tfc/utils/BF/BF.cxx @@ -937,7 +937,7 @@ void ELMSwish::Hint(const int d, const double* x, const int nOut, double* dark){ }; // Parent n-dimensional basis function class: ********************************************************************** -nBasisFunc::nBasisFunc(double* x0in, int x0Dim0, double* xf, int xfDim0, int* nCin, int ncDim0, int ncDim1, int min, double z0in, double zfin){ +nBasisFunc::nBasisFunc(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int ncDim1, int min, double z0in, double zfin){ // Initialize internal variables based on user givens dim = x0Dim0; @@ -997,7 +997,7 @@ void nBasisFunc::getC(double** arrOut, int* nOut){ return; }; -void nBasisFunc::H(double* x, int in, int xDim1, int* d, int dDim0, int* nOut, int* mOut, double** F, const bool full){ +void nBasisFunc::H(const double* x, int in, int xDim1, const int* d, int dDim0, int* nOut, int* mOut, double** F, const bool full){ int numBasis = full ? numBasisFuncFull : numBasisFunc; *mOut = numBasis; *nOut = xDim1; @@ -1022,7 +1022,7 @@ void nBasisFunc::xla(void* out, void** in){ }; -void nBasisFunc::nHint(double* x, int n, const int* d, int dDim0, int numBasis, double*& F, const bool full){ +void nBasisFunc::nHint(const double* x, int n, const int* d, int dDim0, int numBasis, double*& F, const bool full){ int j,k; double* dark = new double[n*m]; @@ -1273,7 +1273,7 @@ void nELM::getB(double** arrOut, int* nOut){ return; }; -void nELM::nHint(double* x, int n, const int* d, int dDim0, int numBasis, double*& F, const bool full){ +void nELM::nHint(const double* x, int n, const int* d, int dDim0, int numBasis, double*& F, const bool full){ int j,k; double* z = new double[n*dim]; diff --git a/src/tfc/utils/BF/BF.h b/src/tfc/utils/BF/BF.h index 1bda52d..d90d3e4 100644 --- a/src/tfc/utils/BF/BF.h +++ b/src/tfc/utils/BF/BF.h @@ -424,23 +424,13 @@ class nBasisFunc: virtual public BasisFunc{ public: /** n-D basis function class constructor. */ - nBasisFunc(double* x0in, int x0Dim0, double* xf, int xfDim0, int* nCin, int ncDim0, int ncDim1, int min, double z0in=0., double zfin=0.); - - /** Dummy nBasisFunc constructor used by nELM only. */ - nBasisFunc(){}; + nBasisFunc(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int ncDim1, int min, double z0in=0., double zfin=0.); /** n-D basis function class destructor. */ virtual ~nBasisFunc(); - /** - * Including override of BasisFunc so we don't have issues with hidden virtual overloads. - * However, this should never be called from nBasisFunc. - * If it is, it will throw an error. - */ - void H(const double* x, int n, const int d, int* nOut, int* mOut, double** F, bool full) override; - /** This function is used to create a basis function matrix and its derivatives. */ - void H(double* x, int in, int xDim1, int* d, int dDim0, int* nOut, int* mOut, double** F, const bool full); + void H(const double* x, int in, int xDim1, const int* d, int dDim0, int* nOut, int* mOut, double** F, const bool full); /** This function is an XLA version of the basis function. */ void xla(void* out, void** in) override; @@ -448,7 +438,18 @@ class nBasisFunc: virtual public BasisFunc{ /** Python hook to return domain mapping constants. */ void getC(double** arrOut, int* nOut); + protected: + /** Dummy nBasisFunc constructor used by nELM only. */ + nBasisFunc(){}; + private: + /** + * Including override of BasisFunc so we don't have issues with hidden virtual overloads. + * However, this should never be called from nBasisFunc. + * If it is, it will throw an error. + */ + void H(const double* x, int n, const int d, int* nOut, int* mOut, double** F, bool full) override; + /** Recursive function used to perform the tensor product of univarite basis functions to form multivariate basis functions. */ void RecurseBasis(int dimCurr, int* vec, int &count, const bool full, const int in, const int numBasis, const double* T, double* out); @@ -456,7 +457,7 @@ class nBasisFunc: virtual public BasisFunc{ void NumBasisFunc(int dimCurr, int* vec, int &count, const bool full); /** Internal function used to calculate dim sets of univariate basis functions with specified derivatives. Note, that if dDim0 < dim, then 0's will be used for the tail end.*/ - virtual void nHint(double* x, int in, const int* d, int dDim0, int numBasis, double*& F, const bool full); + virtual void nHint(const double* x, int in, const int* d, int dDim0, int numBasis, double*& F, const bool full); /** Function used internally to create the basis function matrices. */ virtual void Hint(const int d, const double* x, const int nOut, double* dark) override = 0; @@ -473,7 +474,7 @@ class nCP: public nBasisFunc, public CP { public: /** nCP class constructor. Calls nBasisFunc class constructor and dummy constructors of remaining parents. See nBasisFunc class for more details. */ - nCP(double* x0in, int x0Dim0, double* xf, int xfDim0, int* nCin, int ncDim0, int ncDim1, int min):nBasisFunc(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,ncDim1,min,-1.,1.){}; + nCP(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int ncDim1, int min):nBasisFunc(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,ncDim1,min,-1.,1.){}; /** nCP class destructor.*/ ~nCP(){}; @@ -494,7 +495,7 @@ class nLeP: public nBasisFunc, public LeP { public: /** nLeP class constructor. Calls nBasisFunc class constructor and dummy constructors of remaining parents. See nBasisFunc class for more details. */ - nLeP(double* x0in, int x0Dim0, double* xf, int xfDim0, int* nCin, int ncDim0, int ncDim1, int min):nBasisFunc(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,ncDim1,min,-1.,1.){}; + nLeP(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int ncDim1, int min):nBasisFunc(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,ncDim1,min,-1.,1.){}; /** nLeP class destructor.*/ ~nLeP(){}; @@ -514,7 +515,7 @@ class nFS: public nBasisFunc, public FS { public: /** nFS class constructor. Calls nBasisFunc class constructor and dummy constructors of remaining parents. See nBasisFunc class for more details. */ - nFS(double* x0in, int x0Dim0, double* xf, int xfDim0, int* nCin, int ncDim0, int ncDim1, int min):nBasisFunc(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,ncDim1,min,-M_PI,M_PI){}; + nFS(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int ncDim1, int min):nBasisFunc(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,ncDim1,min,-M_PI,M_PI){}; /** nFS class destructor.*/ ~nFS(){}; @@ -566,7 +567,7 @@ class nELM: public nBasisFunc { private: /** Internal function used to calculate dim sets of univariate basis functions with specified derivatives. Note, that if dDim0 < dim, then 0's will be used for the tail end.*/ - void nHint(double* x, int in, const int* d, int dDim0, int numBasis, double*& F, const bool full) override; + void nHint(const double* x, int in, const int* d, int dDim0, int numBasis, double*& F, const bool full) override; /** This function handles creating a full matrix of nELM basis functions. */ virtual void nElmHint(const int* d, int dDim0, const double* x, const int in, double* F) = 0; diff --git a/src/tfc/utils/BF/BF_Py.cc b/src/tfc/utils/BF/BF_Py.cc index af0e8d3..3a80324 100644 --- a/src/tfc/utils/BF/BF_Py.cc +++ b/src/tfc/utils/BF/BF_Py.cc @@ -26,6 +26,27 @@ void add1DInit(auto& c) { ); } +template +void addNdInit(auto& c) { + c.def(py::init([](py::array_t x0, py::array_t xf, py::array_t nC, int min){ + return std::make_unique(x0.data(), x0.size(), xf.data(), xf.size(), nC.data(), nC.shape()[0], nC.shape()[1], min); + }), + py::arg("x0"), + py::arg("xf"), + py::arg("nC"), + py::arg("min"), + R"( + Constructor. + + Parameters: + x0: Start of domain + xf: End of domain + nC: Array of indices to remove (2D numpy array) + min: Number of basis functions to use + )" + ); +} + PYBIND11_MODULE(BF, m) { py::class_(m, "BasisFunc") @@ -150,4 +171,58 @@ PYBIND11_MODULE(BF, m) { auto PyELMSwish = py::class_ (m, "ELMSwish"); add1DInit(PyELMSwish); + + // TODO: Finish members and add methods. + py::class_(m, "nBasisFunc", py::multiple_inheritance()) + .def_readwrite("z0", &nBasisFunc::z0) + .def_readwrite("zf", &nBasisFunc::zf) + .def_readwrite("dim", &nBasisFunc::dim) + .def_readwrite("numBasisFunc", &nBasisFunc::numBasisFunc) + .def_readwrite("numBasisFuncFull", &nBasisFunc::numBasisFuncFull) + .def("H", + [](nBasisFunc& self, + py::array_t x, + py::array_t d, + bool full) { + if (x.ndim() != 2) { + throw py::value_error("The \"x\" input array must be 1-dimensional."); + } + if (d.ndim() != 1) { + throw py::value_error("The \"d\" input array must be 1-dimensional."); + } + int nOut = 0; + int mOut = 0; + double* F = nullptr; + self.H(x.data(), x.shape()[0], x.shape()[1], d.data(), d.shape()[0], &nOut, &mOut, &F, full); + + // Wrap data in a py::capsule to ensure it gets deleted + auto capsule = py::capsule(F, [](void* f) { + double* d = reinterpret_cast(f); + free(d); + }); + + return py::array_t({nOut, mOut}, F, capsule); + }, + py::arg("x"), py::arg("d"), py::arg("full"), + R"( + Compute basis function matrix. + + Parameters: + x: Points (1D numpy array) + d: Derivative order + full: Whether to return full matrix (not removing nC columns) + + Returns: + mOut x nOut NumPy array. + )" + ); + + auto PynCP = py::class_ (m, "nCP"); + addNdInit(PynCP); + + auto PynLeP = py::class_ (m, "nLeP"); + addNdInit(PynLeP); + + auto PynFS = py::class_ (m, "nFS"); + addNdInit(PynFS); } From 464ca6fb5874bd8ca86706ff1f618375bf3d854f Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 07:07:55 -0700 Subject: [PATCH 12/45] Adding nELM functions. --- src/tfc/utils/BF/BF.cxx | 2 +- src/tfc/utils/BF/BF.h | 12 ++--- src/tfc/utils/BF/BF_Py.cc | 95 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 101 insertions(+), 8 deletions(-) diff --git a/src/tfc/utils/BF/BF.cxx b/src/tfc/utils/BF/BF.cxx index 227a0c1..a0fee14 100644 --- a/src/tfc/utils/BF/BF.cxx +++ b/src/tfc/utils/BF/BF.cxx @@ -1178,7 +1178,7 @@ void nBasisFunc::RecurseBasis(int dimCurr, int* vec, int &count, const bool full }; // nELM base class: *********************************************************************************** -nELM::nELM(double* x0in, int x0Dim0, double* xf, int xfDim0, int* nCin, int ncDim0, int min, double z0in, double zfin){ +nELM::nELM(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int min, double z0in, double zfin){ int k; bool flag = true; diff --git a/src/tfc/utils/BF/BF.h b/src/tfc/utils/BF/BF.h index d90d3e4..50ee20e 100644 --- a/src/tfc/utils/BF/BF.h +++ b/src/tfc/utils/BF/BF.h @@ -547,7 +547,7 @@ class nELM: public nBasisFunc { double *b; /** n-D ELM class constructor. */ - nELM(double* x0in, int x0Dim0, double* xf, int xfDim0, int* nCin, int ncDim0, int min, double z0in=0., double zfin=1.); + nELM(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int min, double z0in=0., double zfin=1.); /** n-D ELM class destructor. */ virtual ~nELM(); @@ -592,7 +592,7 @@ class nELMSigmoid: public nELM { public: /** nELMSigmoid class constructor. Calls nELM class constructor. See nELM class for more details. */ - nELMSigmoid(double* x0in, int x0Dim0, double* xf, int xfDim0, int* nCin, int ncDim0,int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; + nELMSigmoid(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0,int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; /** nELMSigmoid class destructor.*/ ~nELMSigmoid(){}; @@ -608,7 +608,7 @@ class nELMTanh: public nELM { public: /** nELMTanh class constructor. Calls nELM class constructor. See nELM class for more details. */ - nELMTanh(double* x0in, int x0Dim0, double* xf, int xfDim0, int* nCin, int ncDim0, int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; + nELMTanh(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; /** nELMTanh class destructor.*/ ~nELMTanh(){}; @@ -624,7 +624,7 @@ class nELMSin: public nELM { public: /** nELMSin class constructor. Calls nELM class constructor. See nELM class for more details. */ - nELMSin(double* x0in, int x0Dim0, double* xf, int xfDim0, int* nCin, int ncDim0, int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; + nELMSin(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; /** nELMSin class destructor.*/ ~nELMSin(){}; @@ -640,7 +640,7 @@ class nELMSwish: public nELM { public: /** nELMSwish class constructor. Calls nELM class constructor. See nELM class for more details. */ - nELMSwish(double* x0in, int x0Dim0, double* xf, int xfDim0, int* nCin, int ncDim0, int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; + nELMSwish(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; /** nELMSwish class destructor.*/ ~nELMSwish(){}; @@ -656,7 +656,7 @@ class nELMReLU: public nELM { public: /** nELMReLU class constructor. Calls nELM class constructor. See nELM class for more details. */ - nELMReLU(double* x0in, int x0Dim0, double* xf, int xfDim0, int* nCin, int ncDim0, int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; + nELMReLU(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; /** nELMReLU class destructor.*/ ~nELMReLU(){}; diff --git a/src/tfc/utils/BF/BF_Py.cc b/src/tfc/utils/BF/BF_Py.cc index 3a80324..9d2bdb8 100644 --- a/src/tfc/utils/BF/BF_Py.cc +++ b/src/tfc/utils/BF/BF_Py.cc @@ -8,6 +8,9 @@ namespace py = pybind11; template void add1DInit(auto& c) { c.def(py::init([](double x0, double xf, py::array_t nC, int min){ + if (nC.ndim() != 1) { + throw py::value_error("The \"nC\" input array must be 1-dimensional."); + } return std::make_unique(x0, xf, nC.data(), nC.size(), min); }), py::arg("x0"), @@ -29,6 +32,15 @@ void add1DInit(auto& c) { template void addNdInit(auto& c) { c.def(py::init([](py::array_t x0, py::array_t xf, py::array_t nC, int min){ + if (x0.ndim() != 1) { + throw py::value_error("The \"x0\" input array must be 1-dimensional."); + } + if (xf.ndim() != 1) { + throw py::value_error("The \"xf\" input array must be 1-dimensional."); + } + if (nC.ndim() != 2) { + throw py::value_error("The \"nC\" input array must be 2-dimensional."); + } return std::make_unique(x0.data(), x0.size(), xf.data(), xf.size(), nC.data(), nC.shape()[0], nC.shape()[1], min); }), py::arg("x0"), @@ -47,6 +59,36 @@ void addNdInit(auto& c) { ); } +template +void addNdElmInit(auto& c) { + c.def(py::init([](py::array_t x0, py::array_t xf, py::array_t nC, int min){ + if (x0.ndim() != 1) { + throw py::value_error("The \"x0\" input array must be 1-dimensional."); + } + if (xf.ndim() != 1) { + throw py::value_error("The \"xf\" input array must be 1-dimensional."); + } + if (nC.ndim() != 1) { + throw py::value_error("The \"nC\" input array must be 1-dimensional."); + } + return std::make_unique(x0.data(), x0.size(), xf.data(), xf.size(), nC.data(), nC.size(), min); + }), + py::arg("x0"), + py::arg("xf"), + py::arg("nC"), + py::arg("min"), + R"( + Constructor. + + Parameters: + x0: Start of domain (1D numpy array) + xf: End of domain (1D numpy array) + nC: Array of indices to remove (1D numpy array) + min: Number of basis functions to use + )" + ); +} + PYBIND11_MODULE(BF, m) { py::class_(m, "BasisFunc") @@ -185,7 +227,7 @@ PYBIND11_MODULE(BF, m) { py::array_t d, bool full) { if (x.ndim() != 2) { - throw py::value_error("The \"x\" input array must be 1-dimensional."); + throw py::value_error("The \"x\" input array must be 2-dimensional."); } if (d.ndim() != 1) { throw py::value_error("The \"d\" input array must be 1-dimensional."); @@ -225,4 +267,55 @@ PYBIND11_MODULE(BF, m) { auto PynFS = py::class_ (m, "nFS"); addNdInit(PynFS); + + py::class_ (m, "nELM") + .def_property("b", + [](nELM& self) { + double* data = nullptr; + int nOut; + self.getB(&data, &nOut); + + auto capsule = py::capsule(data, [](void* f) { + double* d = reinterpret_cast(f); + free(d); + }); + return py::array_t(self.m, data, capsule); + }, + [](nELM& self, py::array_t b) { + self.setB(b.data(), b.size()); + }) + .def_property("w", + [](nELM& self) { + double* data = nullptr; + int nOut; + int dimOut; + self.getW(&dimOut, &nOut, &data); + + auto capsule = py::capsule(data, [](void* f) { + double* d = reinterpret_cast(f); + free(d); + }); + return py::array_t({dimOut, nOut}, data, capsule); + }, + [](nELM& self, py::array_t w) { + if (w.ndim() != 2) { + throw py::value_error("The \"w\" input array must be 2-dimensional."); + } + self.setW(w.data(), w.shape()[0], w.shape()[1]); + }); + + auto PynELMSigmoid = py::class_ (m, "nELMSigmoid"); + addNdElmInit(PynELMSigmoid); + + auto PynELMTanh = py::class_ (m, "nELMTanh"); + addNdElmInit(PynELMTanh); + + auto PynELMSin = py::class_ (m, "nELMSin"); + addNdElmInit(PynELMSin); + + auto PynELMSwish = py::class_ (m, "nELMSwish"); + addNdElmInit(PynELMSwish); + + auto PynELMReLU = py::class_ (m, "nELMReLU"); + addNdElmInit(PynELMReLU); } From a34e9410c19598bdb1aa6485a00ece5a92fa8f95 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 07:43:35 -0700 Subject: [PATCH 13/45] Updating so unit tests pass. --- requirements.txt | 4 ++-- setup.py | 4 ++-- src/tfc/mtfc.py | 40 +++++++++++++++++---------------------- src/tfc/utils/BF/BF_Py.cc | 16 ++++++++++++++++ 4 files changed, 37 insertions(+), 27 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6d33da3..e42d76d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ numpy >= 1.25 scipy >= 1.11 -jax ~= 0.6 -jaxlib ~= 0.6.0 +jax >= 0.6 +jaxlib >= 0.6 jaxtyping annotated-types matplotlib diff --git a/setup.py b/setup.py index 6fac8a3..7d9e89c 100644 --- a/setup.py +++ b/setup.py @@ -93,8 +93,8 @@ def build_extension(self, ext): cmdclass={"build_ext": CMakeBuild}, install_requires=[ numpy_version, - "jax ~= 0.6.0", - "jaxlib ~= 0.6.0", + "jax >= 0.6.0", + "jaxlib >= 0.6.0", "jaxtyping", "annotated-types", "matplotlib", diff --git a/src/tfc/mtfc.py b/src/tfc/mtfc.py index 60df6b0..add2d5b 100644 --- a/src/tfc/mtfc.py +++ b/src/tfc/mtfc.py @@ -377,7 +377,7 @@ def H(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: H : NDArray Basis function matrix. """ - d = onp.zeros(self.dim, dtype=np.int32) + d = tuple(0 for _ in range(self.dim)) return self._Hjax(*x, d=d, full=full) def Hx(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: @@ -397,8 +397,7 @@ def Hx(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: Hx : NDArray Derivative of the basis function matrix with respect to the first variable. """ - d = onp.zeros(self.dim, dtype=np.int32) - d[0] = 1 + d = tuple(1 if k == 0 else 0 for k in range(self.dim)) return self._Hjax(*x, d=d, full=full) def Hx2(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: @@ -418,8 +417,7 @@ def Hx2(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: Hx2 : NDArray Second derivative of the basis function matrix with respect to the first variable. """ - d = onp.zeros(self.dim, dtype=np.int32) - d[0] = 2 + d = tuple(2 if k == 0 else 0 for k in range(self.dim)) return self._Hjax(*x, d=d, full=full) def Hy2(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: @@ -439,8 +437,7 @@ def Hy2(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: Hy2 : NDArray Second derivative of the basis function matrix with respect to the second variable. """ - d = onp.zeros(self.dim, dtype=np.int32) - d[1] = 2 + d = tuple(2 if k == 1 else 0 for k in range(self.dim)) return self._Hjax(*x, d=d, full=full) def Hx2y(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: @@ -461,10 +458,10 @@ def Hx2y(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: Hx2y : NDArray Mixed derivative of the basis function matrix with respect to the first variable. """ - d = onp.zeros(self.dim, dtype=np.int32) + d = [0 for _ in range(self.dim)] d[0] = 2 d[1] = 1 - return self._Hjax(*x, d=d, full=full) + return self._Hjax(*x, d=tuple(d), full=full) def Hy(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: """ @@ -483,8 +480,7 @@ def Hy(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: Hy : NDArray Derivative of the basis function matrix with respect to the second variable. """ - d = onp.zeros(self.dim, dtype=np.int32) - d[1] = 1 + d = tuple(1 if k == 1 else 0 for k in range(self.dim)) return self._Hjax(*x, d=d, full=full) def Hxy(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: @@ -505,10 +501,10 @@ def Hxy(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: Hxy : NDArray Mixed derivative of the basis function matrix with respect to the first variable. """ - d = onp.zeros(self.dim, dtype=np.int32) + d = [0 for _ in range(self.dim)] d[0] = 1 d[1] = 1 - return self._Hjax(*x, d=d, full=full) + return self._Hjax(*x, d=tuple(d), full=full) def Hz(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: """ @@ -527,15 +523,14 @@ def Hz(self, *x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray: Hz : NDArray Derivative of the basis function matrix with respect to the third variable. """ - d = onp.zeros(self.dim, dtype=np.int32) - d[2] = 1 + d = tuple(1 if k == 2 else 0 for k in range(self.dim)) return self._Hjax(*x, d=d, full=full) def SetupJAX(self): """This function is used internally by TFC to setup autograd primatives and create desired behavior when taking derivatives of TFC constrained expressions.""" # Helper variables - d0 = onp.zeros(self.dim, dtype=np.int32) + d0 = tuple(0 for _ in range(self.dim)) # Regiser XLA function if self._backend == "C++": @@ -546,18 +541,18 @@ def SetupJAX(self): # Create Primitives H_p = Primitive("H") - def Hjax(*x: JaxOrNumpyArray, d: npt.NDArray[onp.int32] = d0, full: bool = False): + def Hjax(*x: JaxOrNumpyArray, d: tuple[int, ...] = d0, full: bool = False): return cast(npt.NDArray, H_p.bind(*x, d=d, full=full)) # Implicit translations - def H_impl(*x: npt.NDArray, d: npt.NDArray[onp.int32] = d0, full: bool = False): + def H_impl(*x: npt.NDArray, d: tuple[int, ...] = d0, full: bool = False): return self.basisClass.H(np.array(x), d, full) H_p.def_impl(H_impl) # Define abstract evaluation def H_abstract_eval( - *x, d: npt.NDArray[onp.int32] = d0, full: bool = False + *x, d: tuple[int, ...] = d0, full: bool = False ) -> core.ShapedArray: if full: dim1 = self.basisClass.numBasisFuncFull @@ -614,13 +609,13 @@ def H_xla(ctx, *x, d: uint = 0, full: bool = False): mlir.register_lowering(H_p, H_xla, platform="cpu") # Batching translation - def H_batch(vec, batch, d: npt.NDArray[onp.int32] = d0, full: bool = False): + def H_batch(vec, batch, d: tuple[int, ...] = d0, full: bool = False): return Hjax(*vec, d=d, full=full), batch[0] batching.primitive_batchers[H_p] = H_batch # Jacobian vector translation - def H_jvp(arg_vals, arg_tans, d: npt.NDArray[onp.int32] = d0, full: bool = False): + def H_jvp(arg_vals, arg_tans, d: tuple[int, ...] = d0, full: bool = False): n = len(arg_vals) flat = len(arg_vals[0].shape) == 1 dim0 = arg_vals[0].shape[0] @@ -636,8 +631,7 @@ def H_jvp(arg_vals, arg_tans, d: npt.NDArray[onp.int32] = d0, full: bool = False else: flag = onp.any(arg_tans[k] != 0) if flag: - dark = copy(d) - dark[k] += 1 + dark = tuple(d[j]+1 if k == j else d[j] for j in range(len(d))) if flat: out_tans += Hjax(*arg_vals, d=dark, full=full) * np.expand_dims( arg_tans[k], 1 diff --git a/src/tfc/utils/BF/BF_Py.cc b/src/tfc/utils/BF/BF_Py.cc index 9d2bdb8..d06a1e7 100644 --- a/src/tfc/utils/BF/BF_Py.cc +++ b/src/tfc/utils/BF/BF_Py.cc @@ -1,3 +1,4 @@ +#include #include #include #include @@ -219,6 +220,21 @@ PYBIND11_MODULE(BF, m) { .def_readwrite("z0", &nBasisFunc::z0) .def_readwrite("zf", &nBasisFunc::zf) .def_readwrite("dim", &nBasisFunc::dim) + .def_property("c", + [](nBasisFunc& self){ + // Return c, and ensure the nBasisFunc stays around as long as c does. + return py::array_t(self.dim, self.c, py::cast(self)); + }, + [](nBasisFunc& self, py::array_t c) + { + if (c.ndim() != 1) { + throw py::value_error("The \"c\" input array must be 1-dimensional."); + } + if (c.size() != self.dim) { + throw py::value_error(std::format("The \"c\" input array must be size {}, but got size {}.", self.dim, c.size())); + } + } + ) .def_readwrite("numBasisFunc", &nBasisFunc::numBasisFunc) .def_readwrite("numBasisFuncFull", &nBasisFunc::numBasisFuncFull) .def("H", From 76efc4a1f6af2767441059a38768917b11d00cf5 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 07:47:23 -0700 Subject: [PATCH 14/45] Fixing API version to avoid warnings. --- src/tfc/mtfc.py | 1 + src/tfc/utfc.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tfc/mtfc.py b/src/tfc/mtfc.py index add2d5b..3a9a544 100644 --- a/src/tfc/mtfc.py +++ b/src/tfc/mtfc.py @@ -604,6 +604,7 @@ def H_xla(ctx, *x, d: uint = 0, full: bool = False): result_layouts=[ default_layout((dim0, dim1)), ], + api_version=3, ).results mlir.register_lowering(H_p, H_xla, platform="cpu") diff --git a/src/tfc/utfc.py b/src/tfc/utfc.py index b767510..e2eb711 100644 --- a/src/tfc/utfc.py +++ b/src/tfc/utfc.py @@ -344,7 +344,7 @@ def H_xla(ctx, x, d: uint = 0, full: bool = False): mlir.ir_constant(np.int32(dim1)), ], has_side_effect=False, - api_version=2, + api_version=3, ) return custom_call_op.results From f2838f780cc934936696459f740f9a14b931225d Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 07:47:39 -0700 Subject: [PATCH 15/45] Formatting with black. --- src/tfc/mtfc.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/tfc/mtfc.py b/src/tfc/mtfc.py index 3a9a544..ff21304 100644 --- a/src/tfc/mtfc.py +++ b/src/tfc/mtfc.py @@ -551,9 +551,7 @@ def H_impl(*x: npt.NDArray, d: tuple[int, ...] = d0, full: bool = False): H_p.def_impl(H_impl) # Define abstract evaluation - def H_abstract_eval( - *x, d: tuple[int, ...] = d0, full: bool = False - ) -> core.ShapedArray: + def H_abstract_eval(*x, d: tuple[int, ...] = d0, full: bool = False) -> core.ShapedArray: if full: dim1 = self.basisClass.numBasisFuncFull else: @@ -632,7 +630,7 @@ def H_jvp(arg_vals, arg_tans, d: tuple[int, ...] = d0, full: bool = False): else: flag = onp.any(arg_tans[k] != 0) if flag: - dark = tuple(d[j]+1 if k == j else d[j] for j in range(len(d))) + dark = tuple(d[j] + 1 if k == j else d[j] for j in range(len(d))) if flat: out_tans += Hjax(*arg_vals, d=dark, full=full) * np.expand_dims( arg_tans[k], 1 From fbeb5df29543d6e346b4780ea48983335ea1f014 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 07:48:49 -0700 Subject: [PATCH 16/45] Adding pybind11 to requirements for docs. --- docs/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/requirements.txt b/docs/requirements.txt index 765fab3..9187886 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -9,6 +9,7 @@ breathe exhale nbsphinx ipykernel +pybind11 jax jaxlib jaxtyping From a04413f3c22f30f3e77d6952ae6adcc6175feaba Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 07:48:58 -0700 Subject: [PATCH 17/45] Removing this file as I don't think we need it anymore. --- src/tfc/utils/BF/BF.py | 77 ------------------------------------------ 1 file changed, 77 deletions(-) delete mode 100644 src/tfc/utils/BF/BF.py diff --git a/src/tfc/utils/BF/BF.py b/src/tfc/utils/BF/BF.py deleted file mode 100644 index f70670b..0000000 --- a/src/tfc/utils/BF/BF.py +++ /dev/null @@ -1,77 +0,0 @@ -""" This is a dummy file used only to avoid errors in ReadTheDocs. The real BF.py is created during the setup once swig is run. """ - - -def CP(): - pass - - -def LeP(): - pass - - -def LaP(): - pass - - -def HoPpro(): - pass - - -def HoPphy(): - pass - - -def FS(): - pass - - -def ELMReLU(): - pass - - -def ELMSigmoid(): - pass - - -def ELMTanh(): - pass - - -def ELMSin(): - pass - - -def ELMSwish(): - pass - - -def nCP(): - pass - - -def nLeP(): - pass - - -def nFS(): - pass - - -def nELMReLU(): - pass - - -def nELMSigmoid(): - pass - - -def nELMTanh(): - pass - - -def nELMSin(): - pass - - -def nELMSwish(): - pass From 7ac2f9050b82b54004678b731e814cc6d189fa43 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 07:55:36 -0700 Subject: [PATCH 18/45] Adding pybind11 and mypy. --- .github/workflows/ci.yml | 2 +- .github/workflows/publish_wheels.yml | 4 ++-- .github/workflows/publish_wheels_test_pypi.yml | 4 ++-- .github/workflows/run_most_examples.yml | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cf957a1..378c827 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,7 +22,7 @@ jobs: run: "python -m black --line-length 100 --check ./src/tfc" - run: "sudo apt-get update && sudo apt-get install -y gcc g++" - - run: python -m pip install wheel setuptools numpy pytest pybind11 + - run: python -m pip install wheel setuptools numpy pytest pybind11 mypy - run: python setup.py bdist_wheel - run: pip install ./dist/*.whl diff --git a/.github/workflows/publish_wheels.yml b/.github/workflows/publish_wheels.yml index fdb5009..e44dce9 100644 --- a/.github/workflows/publish_wheels.yml +++ b/.github/workflows/publish_wheels.yml @@ -12,13 +12,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Install libraries - run: "sudo apt-get update && sudo apt-get install -y swig gcc g++" + run: "sudo apt-get update && sudo apt-get install -y gcc g++" - uses: actions/checkout@v4 - uses: actions/setup-python@v5 - name: Checkout dependencies - run: python -m pip install wheel setuptools numpy + run: python -m pip install wheel setuptools numpy pybind11 mypy - name: Create source distribution run: python setup.py sdist diff --git a/.github/workflows/publish_wheels_test_pypi.yml b/.github/workflows/publish_wheels_test_pypi.yml index a057cbd..8b75b35 100644 --- a/.github/workflows/publish_wheels_test_pypi.yml +++ b/.github/workflows/publish_wheels_test_pypi.yml @@ -9,13 +9,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Install libraries - run: "sudo apt-get update && sudo apt-get install -y swig gcc g++" + run: "sudo apt-get update && sudo apt-get install -y gcc g++" - uses: actions/checkout@v4 - uses: actions/setup-python@v5 - name: Checkout dependencies - run: python -m pip install wheel setuptools numpy + run: python -m pip install wheel setuptools numpy pybind11 mypy - name: Create source distribution run: python setup.py sdist diff --git a/.github/workflows/run_most_examples.yml b/.github/workflows/run_most_examples.yml index cbb9178..ef812c7 100644 --- a/.github/workflows/run_most_examples.yml +++ b/.github/workflows/run_most_examples.yml @@ -14,8 +14,8 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 - - run: "sudo apt-get update && sudo apt-get install -y swig gcc g++ graphviz" - - run: python -m pip install wheel setuptools numpy pytest + - run: "sudo apt-get update && sudo apt-get install -y gcc g++ graphviz" + - run: python -m pip install wheel setuptools numpy pytest pybind11 mypy - run: python setup.py bdist_wheel - run: pip install ./dist/*.whl - run: pip install -r ./requirements.txt From deb9f4993729ba1cba53a85eb5434a5baf59aad4 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 08:06:47 -0700 Subject: [PATCH 19/45] Fixing types now that we are on 3.10 or greater. --- src/tfc/utils/types.py | 45 ++++++++++++------------------------------ 1 file changed, 13 insertions(+), 32 deletions(-) diff --git a/src/tfc/utils/types.py b/src/tfc/utils/types.py index 5cfeff8..1fc0842 100644 --- a/src/tfc/utils/types.py +++ b/src/tfc/utils/types.py @@ -6,21 +6,7 @@ from sympy.core.function import AppliedUndef from sympy import Expr -if sys.version_info >= (3, 8): - from typing import Literal, Protocol, TypedDict -else: - from typing_extensions import Literal, Protocol, TypedDict - -if sys.version_info >= (3, 9): - from typing import Annotated - - List = list - Tuple = tuple - Dict = dict - Tuple = tuple -else: - from typing_extensions import Annotated - from typing import List, Tuple, Dict, Tuple +from typing import Literal, Protocol, TypedDict, Annotated from annotated_types import Gt, Ge, Lt, Le @@ -37,26 +23,21 @@ # General number type Number = Union[int, float, complex] -if sys.version_info >= (3, 8): - from numpy._typing._array_like import _ArrayLikeStr_co, _ArrayLikeInt_co +from numpy._typing._array_like import _ArrayLikeStr_co, _ArrayLikeInt_co - # Array-like of strings - StrArrayLike = _ArrayLikeStr_co +# Array-like of strings +StrArrayLike = _ArrayLikeStr_co - # Array-like of integers - IntArrayLike = _ArrayLikeInt_co -else: - # Hacks to keep things working for Python 3.7 - StrArrayLike = Any - IntArrayLike = Any +# Array-like of integers +IntArrayLike = _ArrayLikeInt_co # List or array like -NumberListOrArray = Union[Tuple[Number, ...], List[Number], npt.NDArray[Any], Array] +NumberListOrArray = Union[tuple[Number, ...], list[Number], npt.NDArray[Any], Array] # List or array of integers IntListOrArray = Union[ - Tuple[int, ...], - List[int], + tuple[int, ...], + list[int], npt.NDArray[np.int32], npt.NDArray[np.int64], npt.NDArray[np.int16], @@ -67,14 +48,14 @@ JaxOrNumpyArray = Union[npt.NDArray, Array] # Tuple or list of array -TupleOrListOfArray = Union[Tuple[JaxOrNumpyArray, ...], List[JaxOrNumpyArray]] -TupleOrListOfNumpyArray = Union[Tuple[npt.NDArray, ...], List[npt.NDArray]] +TupleOrListOfArray = Union[tuple[JaxOrNumpyArray, ...], list[JaxOrNumpyArray]] +TupleOrListOfNumpyArray = Union[tuple[npt.NDArray, ...], list[npt.NDArray]] # Sympy constraint operator # Adding in Any here since sympy types are a bit funky at the moment ConstraintOperator = Callable[[Union[AppliedUndef, Expr, Any]], Union[AppliedUndef, Any]] -ConstraintOperators = Union[List[ConstraintOperator], Tuple[ConstraintOperator, ...]] +ConstraintOperators = Union[list[ConstraintOperator], tuple[ConstraintOperator, ...]] # List or tuple of sympy expressions # Adding in Any here since sympy types are a bit funky at the moment -Exprs = Union[List[Union[Expr, Any]], Tuple[Union[Expr, Any], ...]] +Exprs = Union[list[Union[Expr, Any]], tuple[Union[Expr, Any], ...]] From f1fd8286bd4dccd2cf83bd712c9a5805dc127952 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 08:19:03 -0700 Subject: [PATCH 20/45] Reworkign types and moving BF into utils. No need for separate folder. --- setup.py | 4 +- src/tfc/mtfc.py | 8 ++-- src/tfc/utfc.py | 4 +- src/tfc/utils/{BF => }/BF.cxx | 0 src/tfc/utils/{BF => }/BF.h | 0 src/tfc/utils/BF/__init__.py | 2 - src/tfc/utils/{BF => }/BF_Py.cc | 0 src/tfc/utils/{BF => }/BF_Py.py | 2 +- src/tfc/utils/{BF => }/CMakeLists.txt | 0 src/tfc/utils/CeSolver.py | 5 +-- src/tfc/utils/Html.py | 7 ++- src/tfc/utils/Latex.py | 17 ++++---- src/tfc/utils/MakePlot.py | 6 +-- src/tfc/utils/MayaviMakePlot.py | 16 +++---- src/tfc/utils/PlotlyMakePlot.py | 6 +-- src/tfc/utils/TFCUtils.py | 54 ++++++++++++------------ src/tfc/utils/__init__.py | 1 - src/tfc/utils/{types.py => tfc_types.py} | 0 tests/test_BF.py | 22 +++++----- tests/test_nBF.py | 16 +++---- 20 files changed, 81 insertions(+), 89 deletions(-) rename src/tfc/utils/{BF => }/BF.cxx (100%) rename src/tfc/utils/{BF => }/BF.h (100%) delete mode 100644 src/tfc/utils/BF/__init__.py rename src/tfc/utils/{BF => }/BF_Py.cc (100%) rename src/tfc/utils/{BF => }/BF_Py.py (99%) rename src/tfc/utils/{BF => }/CMakeLists.txt (100%) rename src/tfc/utils/{types.py => tfc_types.py} (100%) diff --git a/setup.py b/setup.py index 7d9e89c..6b3e8a7 100644 --- a/setup.py +++ b/setup.py @@ -36,13 +36,13 @@ class CMakeExtension(Extension): def __init__(self, name, sourcedir=""): super().__init__(name, sources=[]) - self.sourcedir = str((Path(sourcedir) / "src" / "tfc" / "utils" / "BF").absolute()) + self.sourcedir = str((Path(sourcedir) / "src" / "tfc" / "utils").absolute()) class CMakeBuild(build_ext): def build_extension(self, ext): extdir = Path(self.get_ext_fullpath(ext.name)).parents[0].absolute() - bf_dir = extdir / "tfc" / "utils" / "BF" + bf_dir = extdir / "tfc" / "utils" import pybind11 dark = Path(pybind11.__file__).parents[0] diff --git a/src/tfc/mtfc.py b/src/tfc/mtfc.py index ff21304..8a0649d 100644 --- a/src/tfc/mtfc.py +++ b/src/tfc/mtfc.py @@ -2,12 +2,11 @@ config.update("jax_enable_x64", True) -from copy import copy import numpy as onp import jax.numpy as np from typing import cast import numpy.typing as npt -from .utils.types import ( +from .utils.tfc_types import ( Literal, uint, IntListOrArray, @@ -16,7 +15,6 @@ JaxOrNumpyArray, IntArrayLike, Array, - Tuple, ) from jax import core from jax.extend.core import Primitive @@ -280,7 +278,7 @@ def __init__( if backend == "C++": from .utils import BF elif backend == "Python": - from .utils.BF import BF_Py as BF + from .utils import BF_Py as BF else: TFCPrint.Error( f'The backend {backend} was specified, but can only be one of "C++" or "Python".' @@ -354,7 +352,7 @@ def __init__( x[k][:] = (self.z[k, :] - z0) / self.c[k] + self.x0[k] self.z: Array = cast(Array, np.array(self.z.tolist())) - self.x: Tuple[Array, ...] = tuple( + self.x: tuple[Array, ...] = tuple( [cast(Array, np.array(x[k].tolist())) for k in range(self.dim)] ) diff --git a/src/tfc/utfc.py b/src/tfc/utfc.py index e2eb711..67920e7 100644 --- a/src/tfc/utfc.py +++ b/src/tfc/utfc.py @@ -6,7 +6,7 @@ import jax.numpy as np import numpy.typing as npt from typing import Optional, cast -from .utils.types import Literal, uint, IntArrayLike, JaxOrNumpyArray +from .utils.tfc_types import Literal, uint, IntArrayLike, JaxOrNumpyArray from jax import core from jax.extend.core import Primitive from jax.interpreters import ad, batching, mlir @@ -132,7 +132,7 @@ def __init__( if backend == "C++": from .utils import BF elif backend == "Python": - from .utils.BF import BF_Py as BF + from .utils import BF_Py as BF else: TFCPrint.Error( f'The backend {backend} was specified, but can only be one of "C++" or "Python".' diff --git a/src/tfc/utils/BF/BF.cxx b/src/tfc/utils/BF.cxx similarity index 100% rename from src/tfc/utils/BF/BF.cxx rename to src/tfc/utils/BF.cxx diff --git a/src/tfc/utils/BF/BF.h b/src/tfc/utils/BF.h similarity index 100% rename from src/tfc/utils/BF/BF.h rename to src/tfc/utils/BF.h diff --git a/src/tfc/utils/BF/__init__.py b/src/tfc/utils/BF/__init__.py deleted file mode 100644 index d747c6d..0000000 --- a/src/tfc/utils/BF/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .BF import CP, LeP, LaP, HoPpro, HoPphy, FS, ELMReLU, ELMSigmoid, ELMTanh, ELMSin, ELMSwish -from .BF import nCP, nLeP, nFS, nELMSigmoid, nELMTanh, nELMSin, nELMSwish, nELMReLU diff --git a/src/tfc/utils/BF/BF_Py.cc b/src/tfc/utils/BF_Py.cc similarity index 100% rename from src/tfc/utils/BF/BF_Py.cc rename to src/tfc/utils/BF_Py.cc diff --git a/src/tfc/utils/BF/BF_Py.py b/src/tfc/utils/BF_Py.py similarity index 99% rename from src/tfc/utils/BF/BF_Py.py rename to src/tfc/utils/BF_Py.py index ff40fdb..b1372aa 100644 --- a/src/tfc/utils/BF/BF_Py.py +++ b/src/tfc/utils/BF_Py.py @@ -2,7 +2,7 @@ import jax.numpy as jnp from abc import ABC, abstractmethod from numpy import typing as npt -from tfc.utils.types import uint, Number +from tfc.utils.tfc_types import uint, Number from typing import Callable, Tuple diff --git a/src/tfc/utils/BF/CMakeLists.txt b/src/tfc/utils/CMakeLists.txt similarity index 100% rename from src/tfc/utils/BF/CMakeLists.txt rename to src/tfc/utils/CMakeLists.txt diff --git a/src/tfc/utils/CeSolver.py b/src/tfc/utils/CeSolver.py index be188fd..f48fe59 100644 --- a/src/tfc/utils/CeSolver.py +++ b/src/tfc/utils/CeSolver.py @@ -3,9 +3,8 @@ from sympy.core.function import AppliedUndef from sympy.printing.pycode import PythonCodePrinter from sympy.simplify.simplify import nc_simplify -from .types import ConstraintOperators, Exprs, Union, Any, Literal, ConstraintOperator +from .tfc_types import ConstraintOperators, Exprs, Union, Any, Literal, ConstraintOperator from .TFCUtils import TFCPrint -from sympy import latex class CeSolver: @@ -182,7 +181,7 @@ def S(self) -> sp.Matrix: the constrained expression is defined. """ - def _applyC(c, s) -> Any: + def _applyC(c: ConstraintOperator, s) -> Any: """ Apply the constraint operator to the switching function. diff --git a/src/tfc/utils/Html.py b/src/tfc/utils/Html.py index 0d229d6..18015c5 100644 --- a/src/tfc/utils/Html.py +++ b/src/tfc/utils/Html.py @@ -1,8 +1,7 @@ import os from graphviz import Digraph from yattag import Doc, indent -from .types import List -from .types import Path +from .tfc_types import Path class HTML: @@ -93,13 +92,13 @@ def __init__(self, outFile: Path, name: str): self._name = name self.dot = Digraph(name=self._name) - def Render(self, formats: List[str] = ["cmapx", "svg"]): + def Render(self, formats: list[str] = ["cmapx", "svg"]): """ This function renders the dot graph as a .svg and as a .cmapx. Parameters ---------- - formats : List[str], optional + formats : list[str], optional List whose elementts dictate which formats to render the dot graph in. (Default value = ["cmapx", "svg"]) """ for f in formats: diff --git a/src/tfc/utils/Latex.py b/src/tfc/utils/Latex.py index afa6041..819cac1 100644 --- a/src/tfc/utils/Latex.py +++ b/src/tfc/utils/Latex.py @@ -1,6 +1,5 @@ import numpy as np from numpy import typing as npt -from .types import List from typing import Optional @@ -24,13 +23,13 @@ def _Header(numCols: int) -> str: return "\\begin{center}\n\\begin{tabular}{" + "|c" * numCols + "|}\n" @staticmethod - def _colHeader(strIn: List[str]) -> str: + def _colHeader(strIn: list[str]) -> str: """This function creates the column header based on the list of strings that are passed in via the input strIn. Parameters ---------- - strIn : List[str] + strIn : list[str] List of strings that form the column headers. Returns @@ -41,7 +40,7 @@ def _colHeader(strIn: List[str]) -> str: return " & ".join(strIn) + "\\\\\n" @staticmethod - def _Arr2Tab(arrIn: npt.NDArray, form: str = "%.4E", rowHeader: Optional[List[str]] = None): + def _Arr2Tab(arrIn: npt.NDArray, form: str = "%.4E", rowHeader: Optional[list[str]] = None): """ This function transforms the 2-D numpy array (arrIn) into latex tabular format. The "form" argument specifies the number format to be used in the tabular environment. @@ -57,7 +56,7 @@ def _Arr2Tab(arrIn: npt.NDArray, form: str = "%.4E", rowHeader: Optional[List[st form : str, optional Format string for the table numbers. (Default value = "%.4E") - rowHeader : Optional[List[str]] + rowHeader : Optional[list[str]] List of strings to use as the row headers. (Default value = None) Returns @@ -106,8 +105,8 @@ def _Footer() -> str: def SimpleTable( arrIn: npt.NDArray, form: str = "%.4E", - colHeader: Optional[List[str]] = None, - rowHeader: Optional[List[str]] = None, + colHeader: Optional[list[str]] = None, + rowHeader: Optional[list[str]] = None, ) -> str: """This function creates a simple latex table for the 2D numpy array arrIn. The "form" argument specifies the number format to be used in the tabular environment. @@ -124,10 +123,10 @@ def SimpleTable( form : str, optional Format string for the table numbers. (Default value = "%.4E") - colHeader : Optional[List[str]] + colHeader : Optional[list[str]] List of strings that form the column headers. (Default value = None) - rowHeader : Optional[List[str]] + rowHeader : Optional[list[str]] List of strings to use as the row headers. (Default value = None) Returns diff --git a/src/tfc/utils/MakePlot.py b/src/tfc/utils/MakePlot.py index c5f7922..5f3e9ab 100644 --- a/src/tfc/utils/MakePlot.py +++ b/src/tfc/utils/MakePlot.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt from .TFCUtils import TFCPrint -from .types import StrArrayLike, Path, List, Dict, Literal, pint +from .tfc_types import StrArrayLike, Path, Literal, pint from typing import Optional, Union, Generator, Callable TFCPrint() @@ -24,7 +24,7 @@ def __init__( titles: Optional[StrArrayLike] = None, twinYlabs: Optional[StrArrayLike] = None, zlabs: Optional[StrArrayLike] = None, - style: Optional[Union[str, Dict, Path, List[str], List[Dict], List[Path]]] = None, + style: Optional[Union[str, dict, Path, list[str], list[dict], list[Path]]] = None, ): """ This function initializes the plot/subplots based on the inputs provided. @@ -41,7 +41,7 @@ def __init__( The twin y-axes labels for the plots. Setting this forces twin axis y-axes. (Default value = None) zlabs: StrArrayLike, optional The z-axes labels of for the plots. Setting this forces subplots to be 3D. (Default value = None) - style : Union[str, Dict, Path, List[str], List[Dict], List[Path]] + style : Union[str, dict, Path, list[str], list[dict], list[Path]] Matplotlib style. (Default value = None) """ diff --git a/src/tfc/utils/MayaviMakePlot.py b/src/tfc/utils/MayaviMakePlot.py index aecdd82..54e2976 100644 --- a/src/tfc/utils/MayaviMakePlot.py +++ b/src/tfc/utils/MayaviMakePlot.py @@ -3,11 +3,11 @@ import mayavi from mayavi import mlab from matplotlib import colors as mcolors -from .types import Dict, Tuple, Path, Ge, Le, Annotated, Literal +from .tfc_types import Path, Ge, Le, Annotated, Literal from typing import Optional, Any, Union, Generator, Callable from .TFCUtils import TFCPrint -Color = Union[str, Tuple[float, float, float, float], npt.NDArray[np.float64]] +Color = Union[str, tuple[float, float, float, float], npt.NDArray[np.float64]] TFCPrint() @@ -15,7 +15,7 @@ class MakePlot: """MakePlot class for Mayavi.""" @staticmethod - def _str_to_rgb(color: str) -> Tuple[float, float, float]: + def _str_to_rgb(color: str) -> tuple[float, float, float]: """Call matplotlib's colorConverter.to_rgb on input string. Parameters @@ -25,13 +25,13 @@ def _str_to_rgb(color: str) -> Tuple[float, float, float]: Returns ------- - color_rgb : Tuple[float, float, float] + color_rgb : tuple[float, float, float] 3-tuple of the RGB for the color """ return mcolors.colorConverter.to_rgb(color) @staticmethod - def _str_to_rgba(color, alpha: Optional[float] = None) -> Tuple[float, float, float, float]: + def _str_to_rgba(color, alpha: Optional[float] = None) -> tuple[float, float, float, float]: """Call matplotlib's colorConverter.to_rgba on input string. Parameters @@ -44,13 +44,13 @@ def _str_to_rgba(color, alpha: Optional[float] = None) -> Tuple[float, float, fl Returns ------- - color_rgba : Tuple[float, float, float, float] + color_rgba : tuple[float, float, float, float] 4-tuple of the RGB for the color """ return mcolors.colorConverter.to_rgba(color, alpha=alpha) @staticmethod - def _ProcessKwargs(**kwargs: Any) -> Dict[str, Any]: + def _ProcessKwargs(**kwargs: Any) -> dict[str, Any]: """This function effectively extends common mlab keywords. Parameters @@ -60,7 +60,7 @@ def _ProcessKwargs(**kwargs: Any) -> Dict[str, Any]: Returns ------- - kwargs : Dict[str, any] + kwargs : dict[str, any] Same as input keyword arguments but color has been transformed to an RGB if it was a string. """ # Process color argument if it exists diff --git a/src/tfc/utils/PlotlyMakePlot.py b/src/tfc/utils/PlotlyMakePlot.py index b07039a..b820a81 100644 --- a/src/tfc/utils/PlotlyMakePlot.py +++ b/src/tfc/utils/PlotlyMakePlot.py @@ -3,7 +3,7 @@ import plotly.graph_objects as go from .TFCUtils import TFCPrint -from .types import StrArrayLike, uint, Path, Literal, List +from .tfc_types import StrArrayLike, uint, Path, Literal from typing import Optional TFCPrint() @@ -663,7 +663,7 @@ def PartScreen(self, width: float, height: float, units: Literal["in", "mm", "px def NormalizeColorScale( self, - types: List[str] = [], + types: list[str] = [], data: Optional[str] = None, cmax: Optional[float] = None, cmin: Optional[float] = None, @@ -680,7 +680,7 @@ def NormalizeColorScale( Parameters ---------- - types: List[str] + types: list[str] Plot types to set cmax and cmin for. data: Optional[str] Data type to use to calculate cmax and cmin if not already specified. (Default value = None) diff --git a/src/tfc/utils/TFCUtils.py b/src/tfc/utils/TFCUtils.py index a4a6e0d..71ff568 100644 --- a/src/tfc/utils/TFCUtils.py +++ b/src/tfc/utils/TFCUtils.py @@ -22,15 +22,15 @@ from jax.interpreters.partial_eval import trace_to_jaxpr_nounits, PartialVal from jax.experimental import io_callback from typing import Any, Callable, Optional, Union, cast -from .types import uint, List, Literal, Tuple, TypedDict, Path, Dict +from .tfc_types import uint, Literal, TypedDict, Path from jaxtyping import PyTree from typing import cast # Types that can be added to a TFCDict -TFCDictAddable = Union[np.ndarray, Dict[Any, Any], "TFCDict"] +TFCDictAddable = Union[np.ndarray, dict[Any, Any], "TFCDict"] # Types that can be added to a TFCDictRobust -TFCDictRobustAddable = Union[np.ndarray, Dict[Any, Any], "TFCDictRobust"] +TFCDictRobustAddable = Union[np.ndarray, dict[Any, Any], "TFCDictRobust"] class TFCPrint: @@ -204,7 +204,7 @@ def wrapped(*args: Any) -> Any: return wrapped -def pe(*args: Any, constant_arg_nums: List[int] = []) -> Any: +def pe(*args: Any, constant_arg_nums: list[int] = []) -> Any: """ Decorator that returns a function evaluated such that the arg numbers specified in constant_arg_nums and all functions that utilizes only those arguments are treated as compile time constants. @@ -213,7 +213,7 @@ def pe(*args: Any, constant_arg_nums: List[int] = []) -> Any: ---------- *args : Any Arguments for the function that pe is applied to. - constant_arg_nums : List[int], optional + constant_arg_nums : list[int], optional The arguments whose values and functions that depend only on these values should be treated as cached constants. @@ -298,7 +298,7 @@ def get_arg(a, unknown): return wrapper -def pejit(*args: Any, constant_arg_nums: List[int] = [], **kwargs) -> Any: +def pejit(*args: Any, constant_arg_nums: list[int] = [], **kwargs) -> Any: """ Works like :func:`pe `, but also JITs the returned function. See :func:`pe ` for more details. @@ -306,7 +306,7 @@ def pejit(*args: Any, constant_arg_nums: List[int] = [], **kwargs) -> Any: ----------- *args: Any Arguments for the function that pe is applied to. - constant_arg_nums: List[int], optional + constant_arg_nums: list[int], optional The arguments whose values and functions that depend only on these values should be treated as cached constants. **kwargs: Any @@ -746,13 +746,13 @@ def LS( zXi: PyTree, res: Callable, *args: Any, - constant_arg_nums: List[int] = [], + constant_arg_nums: list[int] = [], J: Optional[Callable[..., np.ndarray]] = None, method: Literal["pinv", "lstsq"] = "pinv", timer: bool = False, timerType: str = "process_time", holomorphic: bool = False, -) -> Union[PyTree, Tuple[PyTree, float]]: +) -> Union[PyTree, tuple[PyTree, float]]: """ JITed least squares. This function takes in an initial guess of zeros, zXi, and a residual function, res, and @@ -771,7 +771,7 @@ def LS( *args : Any Any additional arguments taken by res other than the first PyTree argument. - constant_arg_nums: List[int], optional + constant_arg_nums: list[int], optional These arguments will be removed from the residual function and treated as constant. See :func:`pejit ` for more details. J : Optional[Callable[...,np.ndarray]] @@ -839,7 +839,7 @@ def J(xi, *args): # Make arguments constant if desired ls = pe(zXi, *args, constant_arg_nums=constant_arg_nums)(ls) - args: List[Any] = list(args) + args: list[Any] = list(args) constant_arg_nums.sort() constant_arg_nums.reverse() for k in constant_arg_nums: @@ -877,7 +877,7 @@ def __init__( zXi: PyTree, res: Callable, *args: Any, - constant_arg_nums: List[int] = [], + constant_arg_nums: list[int] = [], J: Optional[Callable[..., np.ndarray]] = None, method: Literal["pinv", "lstsq"] = "pinv", timer: bool = False, @@ -902,7 +902,7 @@ def __init__( J : Optional[Callable[...,np.ndarray]] User specified Jacobian function. If None, then the Jacobian of res with respect to xi will be calculated via automatic differentiation. (Default value = None) - constant_arg_nums: List[int], optional + constant_arg_nums: list[int], optional These arguments will be removed from the residual function and treated as constant. See :func:`pejit ` for more details. method : Literal["pinv","lstsq"], optional @@ -966,7 +966,7 @@ def J(xi, *args): # Make arguments constant if desired ls = pe(zXi, *args, constant_arg_nums=constant_arg_nums)(ls) - args: List[Any] = list(args) + args: list[Any] = list(args) constant_arg_nums.sort() constant_arg_nums.reverse() for k in constant_arg_nums: @@ -976,7 +976,7 @@ def J(xi, *args): self._compiled = False - def run(self, zXi: PyTree, *args: Any) -> Union[PyTree, Tuple[PyTree, float]]: + def run(self, zXi: PyTree, *args: Any) -> Union[PyTree, tuple[PyTree, float]]: """ Runs the JIT-ed least-squares function and times it if desired. @@ -1030,7 +1030,7 @@ def NLLS( xiInit: PyTree, res: Callable, *args: Any, - constant_arg_nums: List[int] = [], + constant_arg_nums: list[int] = [], J: Optional[Callable[..., np.ndarray]] = None, cond: Optional[Callable[[PyTree], bool]] = None, body: Optional[Callable[[PyTree], PyTree]] = None, @@ -1042,7 +1042,7 @@ def NLLS( printOutEnd: str = "\n", timerType: str = "process_time", holomorphic: bool = False, -) -> Union[Tuple[PyTree, int], Tuple[PyTree, int, float]]: +) -> Union[tuple[PyTree, int], tuple[PyTree, int, float]]: """ JIT-ed non-linear least squares. This function takes in an initial guess, xiInit (initial values of xi), and a residual function, res, and @@ -1063,7 +1063,7 @@ def NLLS( *args : iterable Any additional arguments taken by res other than xi. - constant_arg_nums: List[int], optional + constant_arg_nums: list[int], optional These arguments will be removed from the residual function and treated as constant. See :func:`pejit ` for more details. J : function @@ -1171,7 +1171,7 @@ def J(xi, *args): LS = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(LS) res = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(res) - args: List[Any] = list(args) + args: list[Any] = list(args) constant_arg_nums.sort() constant_arg_nums.reverse() for k in constant_arg_nums: @@ -1239,7 +1239,7 @@ def __init__( xiInit: PyTree, res: Callable, *args: Any, - constant_arg_nums: List[int] = [], + constant_arg_nums: list[int] = [], J: Optional[Callable[..., np.ndarray]] = None, cond: Optional[Callable[[PyTree], bool]] = None, body: Optional[Callable[[PyTree], PyTree]] = None, @@ -1266,7 +1266,7 @@ def __init__( *args : iterable Any additional arguments taken by res other than xi. - constant_arg_nums: List[int], optional + constant_arg_nums: list[int], optional These arguments will be removed from the residual function and treated as constant. See :func:`pejit ` for more details. J : function @@ -1370,7 +1370,7 @@ def J(xi, *args): LS = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(LS) res = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(res) - args: List[Any] = list(args) + args: list[Any] = list(args) constant_arg_nums.sort() constant_arg_nums.reverse() for k in constant_arg_nums: @@ -1404,7 +1404,7 @@ def body(val): def run( self, xiInit: PyTree, *args: Any - ) -> Union[Tuple[PyTree, int], Tuple[PyTree, int, float]]: + ) -> Union[tuple[PyTree, int], tuple[PyTree, int, float]]: """Runs the JIT-ed nonlinear least-squares function and times it if desired. Parameters @@ -1472,16 +1472,16 @@ class ComponentConstraintGraph: Creates a graph of all valid ways in which component constraints can be embedded. """ - def __init__(self, N: List[str], E: List[ComponentConstraintDict]) -> None: + def __init__(self, N: list[str], E: list[ComponentConstraintDict]) -> None: """ Class constructor. Parameters ---------- - N : List[str] + N : list[str] A list of strings that specify the node names. These node names typically coincide with the names of the dependent variables. - E : List[ComponentConstraintDict] + E : list[ComponentConstraintDict] The ComponentConstraintDict is a dictionary with the following fields: * name - Name of the component constraint. * node0 - The name of one of the nodes that makes up the component constraint. Must correspond with an element of the list given in N. @@ -1629,7 +1629,7 @@ def SaveGraphs(self, outputDir: Path, allGraphs: bool = False, savePDFs: bool = treeHtml.WriteFile() -def ScaledQrLs(A: np.ndarray, B: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: +def ScaledQrLs(A: np.ndarray, B: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """This function performs least-squares using a scaled QR method. Parameters diff --git a/src/tfc/utils/__init__.py b/src/tfc/utils/__init__.py index c9ac1f6..7a13254 100644 --- a/src/tfc/utils/__init__.py +++ b/src/tfc/utils/__init__.py @@ -16,5 +16,4 @@ ) from .MakePlot import MakePlot from . import Latex -from . import BF from .CeSolver import CeSolver diff --git a/src/tfc/utils/types.py b/src/tfc/utils/tfc_types.py similarity index 100% rename from src/tfc/utils/types.py rename to src/tfc/utils/tfc_types.py diff --git a/tests/test_BF.py b/tests/test_BF.py index 155c371..41a4c87 100644 --- a/tests/test_BF.py +++ b/tests/test_BF.py @@ -17,7 +17,7 @@ def test_CP(): - from tfc.utils.BF.BF_Py import CP as pCP + from tfc.utils.BF_Py import CP as pCP x = np.linspace(0, 2, num=10) @@ -36,7 +36,7 @@ def test_CP(): def test_LeP(): - from tfc.utils.BF.BF_Py import LeP as pLeP + from tfc.utils.BF_Py import LeP as pLeP x = np.linspace(0, 2, num=10) lep1 = LeP(0.0, 2.0, np.array([], dtype=np.int32), 5) @@ -54,7 +54,7 @@ def test_LeP(): def test_LaP(): - from tfc.utils.BF.BF_Py import LaP as pLaP + from tfc.utils.BF_Py import LaP as pLaP x = np.linspace(0, 5, num=10) lap1 = LaP(0.0, 5.0, np.array([], dtype=np.int32), 5) @@ -72,7 +72,7 @@ def test_LaP(): def test_HoPpro(): - from tfc.utils.BF.BF_Py import HoPpro as pHoPpro + from tfc.utils.BF_Py import HoPpro as pHoPpro x = np.linspace(0, 5, num=10) hoppro1 = HoPpro(0.0, 5.0, np.array([], dtype=np.int32), 5) @@ -90,7 +90,7 @@ def test_HoPpro(): def test_HoPphy(): - from tfc.utils.BF.BF_Py import HoPphy as pHoPphy + from tfc.utils.BF_Py import HoPphy as pHoPphy x = np.linspace(0, 5, num=10) hopphy1 = HoPphy(0.0, 5.0, np.array([], dtype=np.int32), 5) @@ -108,7 +108,7 @@ def test_HoPphy(): def test_FS(): - from tfc.utils.BF.BF_Py import FS as pFS + from tfc.utils.BF_Py import FS as pFS x = np.linspace(0, 2 * np.pi, num=10) fs1 = FS(0.0, 2.0 * np.pi, np.array([], dtype=np.int32), 5) @@ -135,7 +135,7 @@ def test_FS(): def test_ELMReLU(): - from tfc.utils.BF.BF_Py import ELMReLU as pELMReLU + from tfc.utils.BF_Py import ELMReLU as pELMReLU x = np.linspace(0, 1, num=10) elm = ELMReLU(0.0, 1.0, np.array([], dtype=np.int32), 10) @@ -157,7 +157,7 @@ def test_ELMReLU(): def test_ELMSigmoid(): - from tfc.utils.BF.BF_Py import ELMSigmoid as pELMSigmoid + from tfc.utils.BF_Py import ELMSigmoid as pELMSigmoid x = np.linspace(0, 1, num=10) elm = ELMSigmoid(0.0, 1.0, np.array([], dtype=np.int32), 10) @@ -182,7 +182,7 @@ def test_ELMSigmoid(): def test_ELMTanh(): - from tfc.utils.BF.BF_Py import ELMTanh as pELMTanh + from tfc.utils.BF_Py import ELMTanh as pELMTanh x = np.linspace(0, 1, num=10) elm = ELMTanh(0.0, 1.0, np.array([], dtype=np.int32), 10) @@ -206,7 +206,7 @@ def test_ELMTanh(): def test_ELMSin(): - from tfc.utils.BF.BF_Py import ELMSin as pELMSin + from tfc.utils.BF_Py import ELMSin as pELMSin x = np.linspace(0, 1, num=10) elm = ELMSin(0.0, 1.0, np.array([], dtype=np.int32), 10) @@ -233,7 +233,7 @@ def test_ELMSin(): def test_ELMSwish(): - from tfc.utils.BF.BF_Py import ELMSwish as pELMSwish + from tfc.utils.BF_Py import ELMSwish as pELMSwish x = np.linspace(0, 1, num=10) elm = ELMSwish(0.0, 1.0, np.array([], dtype=np.int32), 10) diff --git a/tests/test_nBF.py b/tests/test_nBF.py index 32511d5..3bb1f40 100644 --- a/tests/test_nBF.py +++ b/tests/test_nBF.py @@ -5,7 +5,7 @@ def test_nCP(): - from tfc.utils.BF.BF_Py import nCP as pnCP + from tfc.utils.BF_Py import nCP as pnCP dim = 2 nC = -1 * np.ones((dim, 1), dtype=np.int32) @@ -40,7 +40,7 @@ def test_nCP(): def test_nLeP(): - from tfc.utils.BF.BF_Py import nLeP as pnLeP + from tfc.utils.BF_Py import nLeP as pnLeP dim = 2 nC = -1 * np.ones((dim, 1), dtype=np.int32) @@ -74,7 +74,7 @@ def test_nLeP(): def test_nFS(): - from tfc.utils.BF.BF_Py import nFS as pnFS + from tfc.utils.BF_Py import nFS as pnFS dim = 2 nC = -1 * np.ones((dim, 1), dtype=np.int32) @@ -109,7 +109,7 @@ def test_nFS(): def test_nELMSigmoid(): - from tfc.utils.BF.BF_Py import nELMSigmoid as pnELMSigmoid + from tfc.utils.BF_Py import nELMSigmoid as pnELMSigmoid dim = 2 nC = -1 * np.ones(1, dtype=np.int32) @@ -151,7 +151,7 @@ def test_nELMSigmoid(): def test_nELMTanh(): - from tfc.utils.BF.BF_Py import nELMTanh as pnELMTanh + from tfc.utils.BF_Py import nELMTanh as pnELMTanh dim = 2 nC = -1 * np.ones(1, dtype=np.int32) @@ -194,7 +194,7 @@ def test_nELMTanh(): def test_nELMSin(): - from tfc.utils.BF.BF_Py import nELMSin as pnELMSin + from tfc.utils.BF_Py import nELMSin as pnELMSin dim = 2 nC = -1 * np.ones(1, dtype=np.int32) @@ -238,7 +238,7 @@ def test_nELMSin(): def test_nELMSwish(): - from tfc.utils.BF.BF_Py import nELMSwish as pnELMSwish + from tfc.utils.BF_Py import nELMSwish as pnELMSwish dim = 2 nC = -1 * np.ones(1, dtype=np.int32) @@ -281,7 +281,7 @@ def test_nELMSwish(): def test_nELMReLU(): - from tfc.utils.BF.BF_Py import nELMReLU as pnELMReLU + from tfc.utils.BF_Py import nELMReLU as pnELMReLU dim = 2 nC = -1 * np.ones(1, dtype=np.int32) From 8f433af271b6959504b0eb62ab7f7170ecf793a6 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 08:36:55 -0700 Subject: [PATCH 21/45] Trying to install package now in readthedocs since we use cmake. --- .readthedocs.yml | 2 ++ docs/requirements.txt | 1 + 2 files changed, 3 insertions(+) diff --git a/.readthedocs.yml b/.readthedocs.yml index f8759e7..1adae1a 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -20,3 +20,5 @@ formats: all python: install: - requirements: docs/requirements.txt + - method: pip + path: . diff --git a/docs/requirements.txt b/docs/requirements.txt index 9187886..5ac8ad5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -10,6 +10,7 @@ exhale nbsphinx ipykernel pybind11 +mypy jax jaxlib jaxtyping From c7a3a05bb0408ba1d4c275f091e58db155c965dc Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 08:39:12 -0700 Subject: [PATCH 22/45] Switching runner to ubuntu 24, using python 3.12, and adding cmake for readthedocs. Adding ninja. Adding ninja. --- .readthedocs.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 1adae1a..1afb725 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -3,11 +3,13 @@ version: 2 # Set the version of Python and other tools you might need build: - os: ubuntu-22.04 + os: ubuntu-24.04 tools: - python: "3.10" + python: "3.12" apt_packages: - graphviz + - cmake + - ninja-build # Build documentation in the docs/ directory with Sphinx sphinx: From d172687847e005a49a88ae48eb5de8cc879a5c6b Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 08:43:00 -0700 Subject: [PATCH 23/45] Changing minimum version to 3.11. --- pyproject.toml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bb9fa2f..7930ae7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "tfc" version = "1.2.1" -requires-python = ">=3.10" +requires-python = ">=3.11" readme = "README.md" dynamic = ["dependencies", "classifiers", "authors", "license", "description"] diff --git a/setup.py b/setup.py index 6b3e8a7..45dd87b 100644 --- a/setup.py +++ b/setup.py @@ -87,7 +87,7 @@ def build_extension(self, ext): packages=find_packages("src"), package_dir={"": "src"}, package_data={"": ["src/tfc/py.typed"]}, - python_requires=">=3.10", + python_requires=">=3.11", include_package_data=True, ext_modules=[CMakeExtension("BF")], cmdclass={"build_ext": CMakeBuild}, From 40ba76e1b37392f241584d71312045c1d7c7419f Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 08:47:50 -0700 Subject: [PATCH 24/45] Changing to BF.cc --- src/tfc/utils/{BF.cxx => BF.cc} | 0 src/tfc/utils/CMakeLists.txt | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename src/tfc/utils/{BF.cxx => BF.cc} (100%) diff --git a/src/tfc/utils/BF.cxx b/src/tfc/utils/BF.cc similarity index 100% rename from src/tfc/utils/BF.cxx rename to src/tfc/utils/BF.cc diff --git a/src/tfc/utils/CMakeLists.txt b/src/tfc/utils/CMakeLists.txt index 72ab0f0..ed268b6 100644 --- a/src/tfc/utils/CMakeLists.txt +++ b/src/tfc/utils/CMakeLists.txt @@ -35,7 +35,7 @@ set(PYBIND11_FINDPYTHON ON) find_package(pybind11 3.0 REQUIRED CONFIG) # Create the bf library -add_library(bf BF.cxx) +add_library(bf BF.cc) target_link_libraries(bf PUBLIC Python3::Python) # Create the BF.py Python file From 4456a02855d3393efd325b42b9b3eec36888b6ab Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 08:47:53 -0700 Subject: [PATCH 25/45] Fixing input path. --- docs/Doxyfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Doxyfile b/docs/Doxyfile index 6333c8a..9c400fa 100644 --- a/docs/Doxyfile +++ b/docs/Doxyfile @@ -794,7 +794,7 @@ WARN_LOGFILE = # Note: If this tag is empty the current directory is searched. #INPUT = "../../../src/cxx" "../../../src/tfc" "../../../src/tfc/utils" -INPUT = "../../../src/tfc/utils/BF" +INPUT = "../../../src/tfc/utils" # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses From f417eba09bd5f8d2dc1e238bbba6327ead4b449f Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 09:35:31 -0700 Subject: [PATCH 26/45] Formatting C++ code using clang. --- clang-format.yaml | 14 + src/tfc/utils/BF.cc | 3403 +++++++++++++++++++++------------------- src/tfc/utils/BF.h | 1150 +++++++------- src/tfc/utils/BF_Py.cc | 352 ++--- utils/Makefile | 4 +- 5 files changed, 2593 insertions(+), 2330 deletions(-) create mode 100644 clang-format.yaml diff --git a/clang-format.yaml b/clang-format.yaml new file mode 100644 index 0000000..d0ba256 --- /dev/null +++ b/clang-format.yaml @@ -0,0 +1,14 @@ +BasedOnStyle: LLVM +IndentWidth: 4 +ColumnLimit: 120 +AllowShortIfStatementsOnASingleLine: false + +# Make method/function arguments line up if they are broken onto a new line +BinPackArguments: false +BinPackParameters: false + +# If we break up constructor initializers, do so before the comma. +BreakConstructorInitializers: BeforeComma + +# Indent all namespaces +NamespaceIndentation: All diff --git a/src/tfc/utils/BF.cc b/src/tfc/utils/BF.cc index a0fee14..5566a9a 100644 --- a/src/tfc/utils/BF.cc +++ b/src/tfc/utils/BF.cc @@ -2,1776 +2,1959 @@ // Initialize static BasisFunc variables int BasisFunc::nIdentifier = 0; -std::vector BasisFunc::BasisFuncContainer; +std::vector BasisFunc::BasisFuncContainer; // xlaWrapper function -void xlaWrapper(void* out, void** in){ - int N = (reinterpret_cast(in[0]))[0]; - BasisFunc::BasisFuncContainer[N]->xla(out,in); +void xlaWrapper(void *out, void **in) { + int N = (reinterpret_cast(in[0]))[0]; + BasisFunc::BasisFuncContainer[N]->xla(out, in); }; #ifdef HAS_CUDA - // xlaGpuWrapper function - void xlaGpuWrapper(CUstream stream, void** buffers, const char* opaque, size_t opaque_len){ - int* N = new int[1]; - N[0] = 0; - //cudaMemcpy(N,reinterpret_cast(buffers[6]),1*sizeof(int),cudaMemcpyDeviceToHost); - BasisFunc::BasisFuncContainer[*N]->xlaGpu(stream,buffers,opaque,opaque_len); - delete[] N; - }; +// xlaGpuWrapper function +void xlaGpuWrapper(CUstream stream, void **buffers, const char *opaque, size_t opaque_len) { + int *N = new int[1]; + N[0] = 0; + // cudaMemcpy(N,reinterpret_cast(buffers[6]),1*sizeof(int),cudaMemcpyDeviceToHost); + BasisFunc::BasisFuncContainer[*N]->xlaGpu(stream, buffers, opaque, opaque_len); + delete[] N; +}; #endif // Parent basis function class: ********************************************************************** -BasisFunc::BasisFunc(double x0in, double xf, const int* nCin, int ncDim0, int min, double z0in, double zf){ - - // Initialize internal variables based on user givens - nC = new int[ncDim0]; - memcpy(nC,nCin,ncDim0*sizeof(int)); - numC = ncDim0; - - z0 = z0in; - - m = min; - if (zf == DBL_MAX){ - c = 1.; x0 = 0.; - } else { - x0 = x0in; - c = (zf-z0)/(xf-x0); - } - - // Track this instance of BasisFunc - BasisFuncContainer.push_back(this); - identifier = nIdentifier; - nIdentifier++; - - // Create a PyCapsule with xla function for XLA compilation - xlaCapsule = GetXlaCapsule(); - #ifdef HAS_CUDA - xlaGpuCapsule = GetXlaCapsuleGpu(); - #endif +BasisFunc::BasisFunc(double x0in, double xf, const int *nCin, int ncDim0, int min, double z0in, double zf) { + + // Initialize internal variables based on user givens + nC = new int[ncDim0]; + memcpy(nC, nCin, ncDim0 * sizeof(int)); + numC = ncDim0; + + z0 = z0in; + + m = min; + if (zf == DBL_MAX) { + c = 1.; + x0 = 0.; + } else { + x0 = x0in; + c = (zf - z0) / (xf - x0); + } + + // Track this instance of BasisFunc + BasisFuncContainer.push_back(this); + identifier = nIdentifier; + nIdentifier++; + + // Create a PyCapsule with xla function for XLA compilation + xlaCapsule = GetXlaCapsule(); +#ifdef HAS_CUDA + xlaGpuCapsule = GetXlaCapsuleGpu(); +#endif }; -BasisFunc::~BasisFunc(){ - delete[] nC; +BasisFunc::~BasisFunc() { delete[] nC; }; + +void BasisFunc::H(const double *x, int n, const int d, int *nOut, int *mOut, double **F, bool full) { + *nOut = n; + *mOut = full ? m : m - numC; + + int j, k; + double dMult = pow(c, d); + double *dark = new double[n * m]; + double *z = new double[n]; + + for (k = 0; k < n; k++) + z[k] = (x[k] - x0) * c + z0; + + *F = (double *)malloc((*mOut) * n * sizeof(double)); + Hint(d, z, n, dark); + + if (!full) { + int i = -1; + bool flag; + for (j = 0; j < m; j++) { + flag = false; + for (k = 0; k < numC; k++) { + if (j == nC[k]) { + flag = true; + break; + } + } + if (flag) + continue; + else + i++; + for (k = 0; k < n; k++) + (*F)[(*mOut) * k + i] = dark[m * k + j] * dMult; + } + } else { + for (j = 0; j < m; j++) { + for (k = 0; k < n; k++) + (*F)[m * k + j] = dark[m * k + j] * dMult; + } + } + delete[] dark; + delete[] z; }; -void BasisFunc::H(const double* x, int n, const int d, int* nOut, int* mOut, double** F, bool full){ - *nOut = n; - *mOut = full ? m : m-numC; - - int j,k; - double dMult = pow(c,d); - double* dark = new double[n*m]; - double* z = new double[n]; - - for (k=0;k(out); + double *x = reinterpret_cast(in[1]); + int d = (reinterpret_cast(in[2]))[0]; + bool full = (reinterpret_cast(in[3]))[0]; + int n = (reinterpret_cast(in[4]))[0]; + int mOut = (reinterpret_cast(in[5]))[0]; + + int j, k; + double dMult = pow(c, d); + double *dark = new double[n * m]; + double *z = new double[n]; + + for (k = 0; k < n; k++) + z[k] = (x[k] - x0) * c + z0; + + Hint(d, z, n, dark); + + if (!full) { + int i = -1; + bool flag; + for (j = 0; j < m; j++) { + flag = false; + for (k = 0; k < numC; k++) { + if (j == nC[k]) { + flag = true; + break; + } + } + if (flag) + continue; + else + i++; + for (k = 0; k < n; k++) + out_buf[mOut * k + i] = dark[m * k + j] * dMult; + } + } else { + for (j = 0; j < m; j++) { + for (k = 0; k < n; k++) + out_buf[m * k + j] = dark[m * k + j] * dMult; + } + } + delete[] dark; + delete[] z; }; -void BasisFunc::xla(void* out, void** in){ - double* out_buf = reinterpret_cast(out); - double* x = reinterpret_cast(in[1]); - int d = (reinterpret_cast(in[2]))[0]; - bool full = (reinterpret_cast(in[3]))[0]; - int n = (reinterpret_cast(in[4]))[0]; - int mOut = (reinterpret_cast(in[5]))[0]; - - int j,k; - double dMult = pow(c,d); - double* dark = new double[n*m]; - double* z = new double[n]; - - for (k=0;k(xlaFnPtr), name, NULL); - return capsule; +PyObject *BasisFunc::GetXlaCapsule() { + xlaFnType xlaFnPtr = xlaWrapper; + const char *name = "xla._CUSTOM_CALL_TARGET"; + PyObject *capsule; + capsule = PyCapsule_New(reinterpret_cast(xlaFnPtr), name, NULL); + return capsule; }; #ifdef HAS_CUDA - PyObject* BasisFunc::GetXlaCapsuleGpu(){ - xlaGpuFnType xlaFnPtr = xlaGpuWrapper; - const char* name = "xla._CUSTOM_CALL_TARGET"; - PyObject* capsule; - capsule = PyCapsule_New(reinterpret_cast(xlaFnPtr), name, NULL); - return capsule; - }; +PyObject *BasisFunc::GetXlaCapsuleGpu() { + xlaGpuFnType xlaFnPtr = xlaGpuWrapper; + const char *name = "xla._CUSTOM_CALL_TARGET"; + PyObject *capsule; + capsule = PyCapsule_New(reinterpret_cast(xlaFnPtr), name, NULL); + return capsule; +}; #endif // COP: ********************************************************************** -void CP::Hint(const int d, const double* x, const int nOut, double* dark){ - - int j,k; - int deg = m-1; - if (deg == 0){ - if (d >0){ - for (k=0;k 1){ - for (k=0;k 0){ - for (k=0;k 0) { + for (k = 0; k < nOut; k++) + dark[k] = 0.; + } else { + for (k = 0; k < nOut; k++) + dark[k] = 1.; + } + } else if (deg == 1) { + if (d > 1) { + for (k = 0; k < m * nOut; k++) + dark[k] = 0.; + } else if (d > 0) { + for (k = 0; k < nOut; k++) { + dark[m * k] = 0.; + dark[m * k + 1] = 1.; + } + } else { + for (k = 0; k < nOut; k++) { + dark[m * k] = 1; + dark[m * k + 1] = x[k]; + } + } + } else { + for (k = 0; k < nOut; k++) { + dark[m * k] = 1; + dark[m * k + 1] = x[k]; + } + for (k = 2; k < m; k++) { + for (j = 0; j < nOut; j++) + dark[m * j + k] = 2. * x[j] * dark[m * j + k - 1] - dark[m * j + k - 2]; + } + RecurseDeriv(d, 0, x, nOut, dark, m); + } + return; }; -void CP::RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut){ - if (dCurr != d){ - int j, k; - double* dark = new double[mOut*nOut]; - memcpy(&dark[0],F,mOut*nOut*sizeof(double)); - if (dCurr == 0){ - for (k=0;k0){ - for (k=0;k 1){ - for (k=0;k 0){ - for (k=0;k 0) { + for (k = 0; k < nOut; k++) + dark[k] = 0.; + } else { + for (k = 0; k < nOut; k++) + dark[k] = 1.; + } + } else if (deg == 1) { + if (d > 1) { + for (k = 0; k < m * nOut; k++) + dark[k] = 0.; + } else if (d > 0) { + for (k = 0; k < nOut; k++) { + dark[m * k] = 0.; + dark[m * k + 1] = 1.; + } + } else { + for (k = 0; k < nOut; k++) { + dark[m * k] = 1; + dark[m * k + 1] = x[k]; + } + } + } else { + for (k = 0; k < nOut; k++) { + dark[m * k] = 1; + dark[m * k + 1] = x[k]; + } + for (k = 1; k < m - 1; k++) { + for (j = 0; j < nOut; j++) + dark[m * j + k + 1] = ((2. * k + 1.) * x[j] * dark[m * j + k] - k * dark[m * j + k - 1]) / (k + 1.); + } + RecurseDeriv(d, 0, x, nOut, dark, m); + } + return; }; -void LeP::RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut){ - - if (dCurr != d){ - int j, k; - double* dark = new double[mOut*nOut]; - memcpy(&dark[0],F,mOut*nOut*sizeof(double)); - if (dCurr == 0){ - for (k=0;k0){ - for (k=0;k 1){ - for (k=0;k 0){ - for (k=0;k 0) { + for (k = 0; k < nOut; k++) + dark[k] = 0.; + } else { + for (k = 0; k < nOut; k++) + dark[k] = 1.; + } + } else if (deg == 1) { + if (d > 1) { + for (k = 0; k < m * nOut; k++) + dark[k] = 0.; + } else if (d > 0) { + for (k = 0; k < nOut; k++) { + dark[m * k] = 0.; + dark[m * k + 1] = -1.; + } + } else { + for (k = 0; k < nOut; k++) { + dark[m * k] = 1.; + dark[m * k + 1] = 1. - x[k]; + } + } + } else { + for (k = 0; k < nOut; k++) { + dark[m * k] = 1.; + dark[m * k + 1] = 1. - x[k]; + } + for (k = 1; k < m - 1; k++) { + for (j = 0; j < nOut; j++) + dark[m * j + k + 1] = ((2. * k + 1. - x[j]) * dark[m * j + k] - k * dark[m * j + k - 1]) / (k + 1.); + } + RecurseDeriv(d, 0, x, nOut, dark, m); + } + return; }; -void LaP::RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut){ - - if (dCurr != d){ - int j, k; - double* dark = new double[mOut*nOut]; - memcpy(&dark[0],F,mOut*nOut*sizeof(double)); - if (dCurr == 0){ - for (k=0;k0){ - for (k=0;k 1){ - for (k=0;k 0){ - for (k=0;k 0) { + for (k = 0; k < nOut; k++) + dark[k] = 0.; + } else { + for (k = 0; k < nOut; k++) + dark[k] = 1.; + } + } else if (deg == 1) { + if (d > 1) { + for (k = 0; k < m * nOut; k++) + dark[k] = 0.; + } else if (d > 0) { + for (k = 0; k < nOut; k++) { + dark[m * k] = 0.; + dark[m * k + 1] = 1.; + } + } else { + for (k = 0; k < nOut; k++) { + dark[m * k] = 1.; + dark[m * k + 1] = x[k]; + } + } + } else { + for (k = 0; k < nOut; k++) { + dark[m * k] = 1.; + dark[m * k + 1] = x[k]; + } + for (k = 1; k < m - 1; k++) { + for (j = 0; j < nOut; j++) + dark[m * j + k + 1] = x[j] * dark[m * j + k] - k * dark[m * j + k - 1]; + } + RecurseDeriv(d, 0, x, nOut, dark, m); + } + return; }; -void HoPpro::RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut){ - - if (dCurr != d){ - int j, k; - double* dark = new double[mOut*nOut]; - memcpy(&dark[0],F,mOut*nOut*sizeof(double)); - if (dCurr == 0){ - for (k=0;k0){ - for (k=0;k 1){ - for (k=0;k 0){ - for (k=0;k 0) { + for (k = 0; k < nOut; k++) + dark[k] = 0.; + } else { + for (k = 0; k < nOut; k++) + dark[k] = 1.; + } + } else if (deg == 1) { + if (d > 1) { + for (k = 0; k < m * nOut; k++) + dark[k] = 0.; + } else if (d > 0) { + for (k = 0; k < nOut; k++) { + dark[m * k] = 0.; + dark[m * k + 1] = 2.; + } + } else { + for (k = 0; k < nOut; k++) { + dark[m * k] = 1.; + dark[m * k + 1] = 2. * x[k]; + } + } + } else { + for (k = 0; k < nOut; k++) { + dark[m * k] = 1.; + dark[m * k + 1] = 2. * x[k]; + } + for (k = 1; k < m - 1; k++) { + for (j = 0; j < nOut; j++) + dark[m * j + k + 1] = 2. * x[j] * dark[m * j + k] - 2. * k * dark[m * j + k - 1]; + } + RecurseDeriv(d, 0, x, nOut, dark, m); + } + return; }; -void HoPphy::RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut){ - - if (dCurr != d){ - int j, k; - double* dark = new double[mOut*nOut]; - memcpy(&dark[0],F,mOut*nOut*sizeof(double)); - if (dCurr == 0){ - for (k=0;k(out); - double* x = reinterpret_cast(in[1]); - int* d = reinterpret_cast(in[2]); - int dDim0 = (reinterpret_cast(in[3]))[0]; - bool full = (reinterpret_cast(in[4]))[0]; - int nOut = (reinterpret_cast(in[5]))[0]; - int mOut = (reinterpret_cast(in[6]))[0]; - - nHint(x,nOut,d,dDim0,mOut,out_buf,full); +void nBasisFunc::xla(void *out, void **in) { + double *out_buf = reinterpret_cast(out); + double *x = reinterpret_cast(in[1]); + int *d = reinterpret_cast(in[2]); + int dDim0 = (reinterpret_cast(in[3]))[0]; + bool full = (reinterpret_cast(in[4]))[0]; + int nOut = (reinterpret_cast(in[5]))[0]; + int mOut = (reinterpret_cast(in[6]))[0]; + nHint(x, nOut, d, dDim0, mOut, out_buf, full); }; -void nBasisFunc::nHint(const double* x, int n, const int* d, int dDim0, int numBasis, double*& F, const bool full){ - - int j,k; - double* dark = new double[n*m]; - double* T = new double[n*m*dim]; - double* z = new double[n*dim]; - double dMult; - - // Calculate the basis function domain points - for (j=0;j= dDim0){ - Hint(0,z+k*n,n,dark); - dMult = 1.; - } else { - Hint(d[k],z+k*n,n,dark); - dMult = pow(c[k],d[k]); - } - for (j=0;j= dDim0) { + Hint(0, z + k * n, n, dark); + dMult = 1.; + } else { + Hint(d[k], z + k * n, n, dark); + dMult = pow(c[k], d[k]); + } + for (j = 0; j < n * m; j++) + T[j + k * m * n] = dark[j] * dMult; + } + + for (k = 0; k < n * numBasis; k++) + F[k] = 1.; + + int count = 0; #ifdef WINDOWS_MSVC - int* vec = new int[dim]; + int *vec = new int[dim]; #else - int vec[dim]; + int vec[dim]; #endif - RecurseBasis(dim-1, vec, count, full, n, numBasis, &T[0], F); + RecurseBasis(dim - 1, vec, count, full, n, numBasis, &T[0], F); #ifdef WINDOWS_MSVC - delete[] vec; + delete[] vec; #endif - delete[] dark; delete[] T; delete[] z; + delete[] dark; + delete[] T; + delete[] z; }; -void nBasisFunc::NumBasisFunc(int dimCurr, int* vec, int &count, const bool full){ - int k; - if (dimCurr > 0){ - for (k=0;k 0) { + for (k = 0; k < m; k++) { + vec[dimCurr] = k; + NumBasisFunc(dimCurr - 1, vec, count, full); + } + } else { + int j, g; + int sum; + bool flag, flag1; + for (k = 0; k < m; k++) { + vec[dimCurr] = k; + flag = false; + sum = 0; + if (full) { + for (j = 0; j < dim; j++) + sum += vec[j]; + if (sum <= m - 1) + count++; + } else { + + // If at least one of the dimensions' basis functions is not a constraint, then + // set flag = true + for (j = 0; j < dim; j++) { + flag1 = true; + for (g = 0; g < numC; g++) { + if (vec[j] == nC[j * numC + g]) { + flag1 = false; + break; + } + } + if (flag1) + flag = true; + } + + // If flag is true and the degree of the product of univariate basis + // functions is less than the degree specified, add one to count + if (flag) { + for (j = 0; j < dim; j++) + sum += vec[j]; + if (sum <= m - 1) + count++; + } + } + } + } + return; }; -void nBasisFunc::RecurseBasis(int dimCurr, int* vec, int &count, const bool full, const int in, const int numBasis, const double* T, double* out){ - int k; - if (dimCurr > 0){ - for (k=0;k 0) { + for (k = 0; k < m; k++) { + vec[dimCurr] = k; + RecurseBasis(dimCurr - 1, vec, count, full, in, numBasis, T, out); + } + } else { + int j, g, h, l; + int sum; + bool flag, flag1; + for (k = 0; k < m; k++) { + vec[dimCurr] = k; + flag = false; + sum = 0; + if (full) { + for (j = 0; j < dim; j++) + sum += vec[j]; + if (sum <= m - 1) { + for (h = 0; h < in; h++) { + for (l = 0; l < dim; l++) + out[h * numBasis + count] *= T[m * in * l + vec[l] + h * m]; + } + count++; + } + } else { + + // If at least one of the dimensions' basis functions is not a constraint, then + // set flag = true + for (j = 0; j < dim; j++) { + flag1 = true; + for (g = 0; g < numC; g++) { + if (vec[j] == nC[j * numC + g]) { + flag1 = false; + break; + } + } + if (flag1) + flag = true; + } + + // If flag is true and the degree of the product of univariate basis + // functions is less than the degree specified, add one to count + if (flag) { + for (j = 0; j < dim; j++) + sum += vec[j]; + if (sum <= m - 1) { + for (h = 0; h < in; h++) { + for (l = 0; l < dim; l++) + out[h * numBasis + count] *= T[m * in * l + vec[l] + h * m]; + } + count++; + } + } + } + } + } + return; }; // nELM base class: *********************************************************************************** -nELM::nELM(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int min, double z0in, double zfin){ - - int k; - bool flag = true; - - // Initialize internal variables based on user givens - dim = x0Dim0; - m = min; - numC = ncDim0; - z0 = z0in; - zf = zfin; - - x0 = new double[dim]; - memcpy(x0,x0in,dim*sizeof(double)); - - nC = new int[dim*numC]; - memcpy(nC,nCin,ncDim0*sizeof(int)); - - c = new double[dim]; - for (k=0; k 1 || ((d[i] == 1) && (h != -1))){ - zeroFlag = true; - break; - } else if (d[i] == 1) { - h = i; - dark1 = c[i]; - } - } - - if (zeroFlag) { - for (j=0;j 1 || ((d[i] == 1) && (h != -1))) { + zeroFlag = true; + break; + } else if (d[i] == 1) { + h = i; + dark1 = c[i]; + } + } + + if (zeroFlag) { + for (j = 0; j < in; j++) { + for (k = 0; k < m; k++) + F[m * j + k] = 0.; + } + } else { + for (j = 0; j < in; j++) { + for (k = 0; k < m; k++) { + dark = 0.; + for (i = 0; i < dim; i++) + dark += w[i * m + k] * x[i * in + j]; + F[m * j + k] = std::max(0., dark + b[k]); + } + } + } + + if (h != -1) { + for (j = 0; j < in; j++) { + for (k = 0; k < m; k++) { + if (F[m * j + k] != 0.) { + F[m * j + k] = dark1 * w[h * m + k]; + } + } + } + } + return; }; diff --git a/src/tfc/utils/BF.h b/src/tfc/utils/BF.h index 50ee20e..00b24dd 100644 --- a/src/tfc/utils/BF.h +++ b/src/tfc/utils/BF.h @@ -1,669 +1,739 @@ #define _USE_MATH_DEFINES // Needed by Windows +#include +#include #include -#include #include -#include -#include +#include #ifdef HAS_CUDA - #include - #include - #include +#include +#include +#include #endif - #ifndef BF_H #define BF_H -// BasisFunc ************************************************************************************************************************** -/** This class is an abstract class used to create all other basis function classes. It defines standard methods to call the basis function and its - * derivatives, as well as provides wrappers for XLA computation. */ -class BasisFunc{ - - public: - /** Beginning of the basis function domain. */ - double z0; - - /** Start of the problem domain. */ - double x0; - - /** Multiplier in the linear domain map. */ - double c; - - /** Array that specifies which basis functions to remove. */ - int* nC; - - /** Number of basis functions to be removed. */ - int numC; - - /** Number of basis functions to use. */ - int m; - - /** Unique identifier for this instance of BasisFunc. */ - int identifier; - - /** PyObject that contains the XLA version of the basis function. */ - PyObject* xlaCapsule; - - #ifdef HAS_CUDA - /** PyObject that contains the XLA version of the basis function that uses a CUDA GPU kernel. */ - PyObject* xlaGpuCapsule; - #else - const char* xlaGpuCapsule = "CUDA NOT FOUND, GPU NOT IMPLEMENTED."; - #endif - - /** Counter that increments each time a new instance of BasisFunc is created. */ - static int nIdentifier; - - /** Vector that contains pointers to all BasisFunc classes. */ - static std::vector BasisFuncContainer; - - public: - /** Basis function class constructor. - * - Stores variables based on user supplied givens - * - Stores a pointer to itself using static variables - * - Creates PyCapsule for xla function. */ - BasisFunc(double x0in, double xf, const int* nCin, int ncDim0, int min, double z0in=0., double zf=DBL_MAX); - - /** Basis function class destructor. Removes memory used by the basis function class. */ - virtual ~BasisFunc(); - - // Prevent copying - BasisFunc(const BasisFunc&) = delete; - BasisFunc& operator=(const BasisFunc&) = delete; - - // Prevent moving - BasisFunc(BasisFunc&&) = delete; - BasisFunc& operator=(BasisFunc&&) = delete; - - /** Function is used to create a basis function matrix and its derivatives. This matrix is is an m x N matrix where: - * - m is the number of basis functions - * - N = in is the number of points in x - * - d is used to specify the derivative - * - full is a bool that specifies: - * - If true, full matrix with no basis functions removed is returned - * - If false, matrix columns corresponding to the values in nC are removed - * - useVal is a bool that specifies: - * - If true, uses the x values given - * - If false, uses the z values from the class - * Note that this function is used to hook into Python, thus the extra arguments. */ - virtual void H(const double* x, int n, const int d, int* nOut, int* mOut, double** F, bool full); - - /** This function is an XLA version of the basis function. */ - virtual void xla(void* out, void** in); - - #ifdef HAS_CUDA - /** This function is an XLA version of the basis function that uses a CUDA GPU kernel. */ - void xlaGpu(CUstream stream, void** buffers, const char* opaque, size_t opaque_len); - #endif - - protected: - /** Dummy empty constructor allows derived classes without calling constructor explicitly. */ - BasisFunc(){}; - - /** This function creates a PyCapsule object that wraps the XLA verison of the basis function. */ - PyObject* GetXlaCapsule(); - - #ifdef HAS_CUDA - /** This function creates a PyCapsule object that wraps the XLA verison of the basis function that uses a CUDA GPU kernel. */ - PyObject* GetXlaCapsuleGpu(); - #endif - - private: - /** Function used internally to create the basis function matrices. */ - virtual void Hint(const int d, const double* x, const int nOut, double* dark) = 0; - - /** Function used internally to create derivatives of the basis function matrices. */ - virtual void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut) = 0; -}; +// BasisFunc +// ************************************************************************************************************************** +/** This class is an abstract class used to create all other basis function classes. It defines standard methods to call + * the basis function and its derivatives, as well as provides wrappers for XLA computation. */ +class BasisFunc { -// XLA related declarations: ********************************************************************************************************** -/** Pointer for XLA-type function that can be cast to void* and put in a PyCapsule. */ -typedef void(*xlaFnType)(void*,void**); + public: + /** Beginning of the basis function domain. */ + double z0; -#ifdef HAS_CUDA - /** Pointer for GPU compatible XLA-type function that can be cast to void* and put in a PyCapsule. */ - typedef void(*xlaGpuFnType)(CUstream,void**,const char*,size_t); - - /** Function used to wrap BasisFunc->xlaGpu in C-style function that can be cast to void*. */ - void xlaGpuWrapper(CUstream stream, void** buffers, const char* opaque, size_t opaque_len); -#endif + /** Start of the problem domain. */ + double x0; -// CP: ******************************************************************************************************************************** -/** Class for Chebyshev orthogonal polynomials. */ -class CP: virtual public BasisFunc { - public: - /** CP class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ - CP(double x0, double xf, const int* nCin, int ncDim0, int min): - BasisFunc(x0,xf,nCin,ncDim0,min,-1.,1.){}; + /** Multiplier in the linear domain map. */ + double c; + /** Array that specifies which basis functions to remove. */ + int *nC; - /** CP class destructor.*/ - virtual ~CP(){}; + /** Number of basis functions to be removed. */ + int numC; - // Prevent copying - CP(const CP&) = delete; - CP& operator=(const CP&) = delete; + /** Number of basis functions to use. */ + int m; - // Prevent moving - CP(CP&&) = delete; - CP& operator=(CP&&) = delete; + /** Unique identifier for this instance of BasisFunc. */ + int identifier; - protected: - /** Dummy CP class constructor. Used only in n-dimensions. */ - CP(){}; + /** PyObject that contains the XLA version of the basis function. */ + PyObject *xlaCapsule; - /** Function used internally to create the basis function matrices. */ - void Hint(const int d, const double* x, const int nOut, double* dark); +#ifdef HAS_CUDA + /** PyObject that contains the XLA version of the basis function that uses a CUDA GPU kernel. */ + PyObject *xlaGpuCapsule; +#else + const char *xlaGpuCapsule = "CUDA NOT FOUND, GPU NOT IMPLEMENTED."; +#endif - /** Function used internally to create derivatives of the basis function matrices. */ - void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut); -}; + /** Counter that increments each time a new instance of BasisFunc is created. */ + static int nIdentifier; + + /** Vector that contains pointers to all BasisFunc classes. */ + static std::vector BasisFuncContainer; + + public: + /** Basis function class constructor. + * - Stores variables based on user supplied givens + * - Stores a pointer to itself using static variables + * - Creates PyCapsule for xla function. */ + BasisFunc(double x0in, double xf, const int *nCin, int ncDim0, int min, double z0in = 0., double zf = DBL_MAX); + + /** Basis function class destructor. Removes memory used by the basis function class. */ + virtual ~BasisFunc(); + + // Prevent copying + BasisFunc(const BasisFunc &) = delete; + BasisFunc &operator=(const BasisFunc &) = delete; + + // Prevent moving + BasisFunc(BasisFunc &&) = delete; + BasisFunc &operator=(BasisFunc &&) = delete; + + /** Function is used to create a basis function matrix and its derivatives. This matrix is is an m x N matrix where: + * - m is the number of basis functions + * - N = in is the number of points in x + * - d is used to specify the derivative + * - full is a bool that specifies: + * - If true, full matrix with no basis functions removed is returned + * - If false, matrix columns corresponding to the values in nC are removed + * - useVal is a bool that specifies: + * - If true, uses the x values given + * - If false, uses the z values from the class + * Note that this function is used to hook into Python, thus the extra arguments. */ + virtual void H(const double *x, int n, const int d, int *nOut, int *mOut, double **F, bool full); + + /** This function is an XLA version of the basis function. */ + virtual void xla(void *out, void **in); -// LeP: ******************************************************************************************************************************** -/** Class for Legendre orthogonal polynomials. */ -class LeP: virtual public BasisFunc { - public: - /** LeP class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ - LeP(double x0, double xf, const int* nCin, int ncDim0, int min): - BasisFunc(x0,xf,nCin,ncDim0,min,-1.,1.){}; +#ifdef HAS_CUDA + /** This function is an XLA version of the basis function that uses a CUDA GPU kernel. */ + void xlaGpu(CUstream stream, void **buffers, const char *opaque, size_t opaque_len); +#endif - /** Dummy LeP class constructor. Used only in n-dimensions. */ - LeP(){}; + protected: + /** Dummy empty constructor allows derived classes without calling constructor explicitly. */ + BasisFunc() {}; - /** LeP class destructor.*/ - ~LeP(){}; + /** This function creates a PyCapsule object that wraps the XLA verison of the basis function. */ + PyObject *GetXlaCapsule(); - protected: - /** Function used internally to create the basis function matrices. */ - void Hint(const int d, const double* x, const int nOut, double* dark); +#ifdef HAS_CUDA + /** This function creates a PyCapsule object that wraps the XLA verison of the basis function that uses a CUDA GPU + * kernel. */ + PyObject *GetXlaCapsuleGpu(); +#endif - /** Function used internally to create derivatives of the basis function matrices. */ - void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut); -}; + private: + /** Function used internally to create the basis function matrices. */ + virtual void Hint(const int d, const double *x, const int nOut, double *dark) = 0; -// LaP: ******************************************************************************************************************************** -/** Class for Laguerre orthogonal polynomials. */ -class LaP: public BasisFunc { - public: - /** LaP class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ - LaP(double x0, double xf, const int* nCin, int ncDim0, int min): - BasisFunc(x0,xf,nCin,ncDim0,min){}; - /** LaP class destructor.*/ - ~LaP(){}; - - private: - /** Function used internally to create the basis function matrices. */ - void Hint(const int d, const double* x, const int nOut, double* dark); - - /** Function used internally to create derivatives of the basis function matrices. */ - void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut); + /** Function used internally to create derivatives of the basis function matrices. */ + virtual void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) = 0; }; -// HoPpro: ******************************************************************************************************************************** -/** Class for Hermite probablist orthogonal polynomials. */ -class HoPpro: public BasisFunc { - public: - /** HoPpro class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ - HoPpro(double x0, double xf, const int* nCin, int ncDim0, int min): - BasisFunc(x0,xf,nCin,ncDim0,min){}; - /** HoPpro class destructor.*/ - ~HoPpro(){}; - - private: - /** Function used internally to create the basis function matrices. */ - void Hint(const int d, const double* x, const int nOut, double* dark); - - /** Function used internally to create derivatives of the basis function matrices. */ - void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut); -}; +// XLA related declarations: +// ********************************************************************************************************** +/** Pointer for XLA-type function that can be cast to void* and put in a PyCapsule. */ +typedef void (*xlaFnType)(void *, void **); -// HoPphy: ******************************************************************************************************************************** -/** Class for Hermite physicist orthogonal polynomials. */ -class HoPphy: public BasisFunc { - public: - /** HoPphy class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ - HoPphy(double x0, double xf, const int* nCin, int ncDim0, int min): - BasisFunc(x0,xf,nCin,ncDim0,min){}; - /** HoPphy class destructor.*/ - ~HoPphy(){}; - - private: - /** Function used internally to create the basis function matrices. */ - void Hint(const int d, const double* x, const int nOut, double* dark); - - /** Function used internally to create derivatives of the basis function matrices. */ - void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut); -}; +#ifdef HAS_CUDA +/** Pointer for GPU compatible XLA-type function that can be cast to void* and put in a PyCapsule. */ +typedef void (*xlaGpuFnType)(CUstream, void **, const char *, size_t); -// FS: ******************************************************************************************************************************** -/** Class for Fourier Series basis. */ -class FS: virtual public BasisFunc { - public: - /** FS class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ - FS(double x0, double xf, const int* nCin, int ncDim0, int min): - BasisFunc(x0,xf,nCin,ncDim0,min,-M_PI,M_PI){}; +/** Function used to wrap BasisFunc->xlaGpu in C-style function that can be cast to void*. */ +void xlaGpuWrapper(CUstream stream, void **buffers, const char *opaque, size_t opaque_len); +#endif - /** Dummy FS class constructor. Used only in n-dimensions. */ - FS(){}; +// CP: +// ******************************************************************************************************************************** +/** Class for Chebyshev orthogonal polynomials. */ +class CP : virtual public BasisFunc { + public: + /** CP class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ + CP(double x0, double xf, const int *nCin, int ncDim0, int min) + : BasisFunc(x0, xf, nCin, ncDim0, min, -1., 1.) {}; - /** FS class destructor.*/ - ~FS(){}; + /** CP class destructor.*/ + virtual ~CP() {}; - protected: - /** Function used internally to create the basis function matrices and derivatives. */ - void Hint(const int d, const double* x, const int nOut, double* dark); + // Prevent copying + CP(const CP &) = delete; + CP &operator=(const CP &) = delete; - /** This function is unecessary for FS as it is all handled in Hint. Therefore, this is just an empty function that returns a warning. */ - void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut){ - fprintf(stderr, "Warning, this function from FS should never be called. It seems it has been called by accident. Please check that this function was intended to be called.\n"); - printf("Warning, this function from FS should never be called. It seems it has been called by accident. Please check that this function was intended to be called.\n"); - }; + // Prevent moving + CP(CP &&) = delete; + CP &operator=(CP &&) = delete; -}; + protected: + /** Dummy CP class constructor. Used only in n-dimensions. */ + CP() {}; -// ELM base class: ******************************************************************************************************************************** -/** ELM base class. */ -class ELM: public BasisFunc { - public: - /** ELM weights. */ - double *w; + /** Function used internally to create the basis function matrices. */ + void Hint(const int d, const double *x, const int nOut, double *dark); - /** ELM biases. */ - double *b; + /** Function used internally to create derivatives of the basis function matrices. */ + void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut); +}; - /** ELM class constructor. Calls BasisFunc class constructor and sets up weights and biases for the ELM. See BasisFunc class for more details. */ - ELM(double x0, double xf, const int* nCin, int ncDim0, int min); +// LeP: +// ******************************************************************************************************************************** +/** Class for Legendre orthogonal polynomials. */ +class LeP : virtual public BasisFunc { + public: + /** LeP class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ + LeP(double x0, double xf, const int *nCin, int ncDim0, int min) + : BasisFunc(x0, xf, nCin, ncDim0, min, -1., 1.) {}; - /** ELM class destructor.*/ - virtual ~ELM(); + /** Dummy LeP class constructor. Used only in n-dimensions. */ + LeP() {}; - /** Python hook to return ELM weights. */ - void getW(double** arrOut, int* nOut); + /** LeP class destructor.*/ + ~LeP() {}; - /** Python hook to set ELM weights. */ - void setW(const double* arrIn, int nIn); + protected: + /** Function used internally to create the basis function matrices. */ + void Hint(const int d, const double *x, const int nOut, double *dark); - /** Python hook to return ELM biases. */ - void getB(double** arrOut, int* nOut); + /** Function used internally to create derivatives of the basis function matrices. */ + void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut); +}; - /** Python hook to set ELM biases. */ - void setB(const double* arrIn, int nIn); +// LaP: +// ******************************************************************************************************************************** +/** Class for Laguerre orthogonal polynomials. */ +class LaP : public BasisFunc { + public: + /** LaP class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ + LaP(double x0, double xf, const int *nCin, int ncDim0, int min) + : BasisFunc(x0, xf, nCin, ncDim0, min) {}; + /** LaP class destructor.*/ + ~LaP() {}; + + private: + /** Function used internally to create the basis function matrices. */ + void Hint(const int d, const double *x, const int nOut, double *dark); + + /** Function used internally to create derivatives of the basis function matrices. */ + void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut); +}; - protected: - /** Function used internally to create the basis function matrices. */ - virtual void Hint(const int d, const double* x, const int nOut, double* dark) = 0; +// HoPpro: +// ******************************************************************************************************************************** +/** Class for Hermite probablist orthogonal polynomials. */ +class HoPpro : public BasisFunc { + public: + /** HoPpro class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ + HoPpro(double x0, double xf, const int *nCin, int ncDim0, int min) + : BasisFunc(x0, xf, nCin, ncDim0, min) {}; + /** HoPpro class destructor.*/ + ~HoPpro() {}; + + private: + /** Function used internally to create the basis function matrices. */ + void Hint(const int d, const double *x, const int nOut, double *dark); + + /** Function used internally to create derivatives of the basis function matrices. */ + void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut); +}; - /** This function is unecessary for ELM as it is all handled in Hint. Therefore, this is just an empty function that returns a warning. */ - void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut){ - fprintf(stderr, "Warning, this function from ELM should never be called. It seems it has been called by accident. Please check that this function was intended to be called.\n"); - printf("Warning, this function from ELM should never be called. It seems it has been called by accident. Please check that this function was intended to be called.\n"); - }; +// HoPphy: +// ******************************************************************************************************************************** +/** Class for Hermite physicist orthogonal polynomials. */ +class HoPphy : public BasisFunc { + public: + /** HoPphy class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ + HoPphy(double x0, double xf, const int *nCin, int ncDim0, int min) + : BasisFunc(x0, xf, nCin, ncDim0, min) {}; + /** HoPphy class destructor.*/ + ~HoPphy() {}; + + private: + /** Function used internally to create the basis function matrices. */ + void Hint(const int d, const double *x, const int nOut, double *dark); + + /** Function used internally to create derivatives of the basis function matrices. */ + void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut); +}; +// FS: +// ******************************************************************************************************************************** +/** Class for Fourier Series basis. */ +class FS : virtual public BasisFunc { + public: + /** FS class constructor. Calls BasisFunc class constructor. See BasisFunc class for more details. */ + FS(double x0, double xf, const int *nCin, int ncDim0, int min) + : BasisFunc(x0, xf, nCin, ncDim0, min, -M_PI, M_PI) {}; + + /** Dummy FS class constructor. Used only in n-dimensions. */ + FS() {}; + + /** FS class destructor.*/ + ~FS() {}; + + protected: + /** Function used internally to create the basis function matrices and derivatives. */ + void Hint(const int d, const double *x, const int nOut, double *dark); + + /** This function is unecessary for FS as it is all handled in Hint. Therefore, this is just an empty function that + * returns a warning. */ + void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) { + fprintf(stderr, + "Warning, this function from FS should never be called. It seems it has been called by accident. " + "Please check that this function was intended to be called.\n"); + printf("Warning, this function from FS should never be called. It seems it has been called by accident. Please " + "check that this function was intended to be called.\n"); + }; }; -// ELM sigmoid: ******************************************************************************************************************************** -/** ELM that uses the sigmoid activation function. */ -class ELMSigmoid: public ELM { +// ELM base class: +// ******************************************************************************************************************************** +/** ELM base class. */ +class ELM : public BasisFunc { + public: + /** ELM weights. */ + double *w; + + /** ELM biases. */ + double *b; + + /** ELM class constructor. Calls BasisFunc class constructor and sets up weights and biases for the ELM. See + * BasisFunc class for more details. */ + ELM(double x0, double xf, const int *nCin, int ncDim0, int min); + + /** ELM class destructor.*/ + virtual ~ELM(); + + /** Python hook to return ELM weights. */ + void getW(double **arrOut, int *nOut); + + /** Python hook to set ELM weights. */ + void setW(const double *arrIn, int nIn); + + /** Python hook to return ELM biases. */ + void getB(double **arrOut, int *nOut); + + /** Python hook to set ELM biases. */ + void setB(const double *arrIn, int nIn); + + protected: + /** Function used internally to create the basis function matrices. */ + virtual void Hint(const int d, const double *x, const int nOut, double *dark) = 0; + + /** This function is unecessary for ELM as it is all handled in Hint. Therefore, this is just an empty function that + * returns a warning. */ + void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) { + fprintf(stderr, + "Warning, this function from ELM should never be called. It seems it has been called by accident. " + "Please check that this function was intended to be called.\n"); + printf("Warning, this function from ELM should never be called. It seems it has been called by accident. " + "Please check that this function was intended to be called.\n"); + }; +}; - public: - /** ELMSigmoid class constructor. Calls ELM class constructor. See ELM class for more details. */ - ELMSigmoid(double x0, double xf, const int* nCin, int ncDim0, int min): - ELM(x0,xf,nCin,ncDim0,min){}; +// ELM sigmoid: +// ******************************************************************************************************************************** +/** ELM that uses the sigmoid activation function. */ +class ELMSigmoid : public ELM { - /** ELMSigmoid class destructor.*/ - ~ELMSigmoid(){}; + public: + /** ELMSigmoid class constructor. Calls ELM class constructor. See ELM class for more details. */ + ELMSigmoid(double x0, double xf, const int *nCin, int ncDim0, int min) + : ELM(x0, xf, nCin, ncDim0, min) {}; - protected: - /** Function used internally to create the basis function matrices and derivatives. */ - void Hint(const int d, const double* x, const int nOut, double* dark); + /** ELMSigmoid class destructor.*/ + ~ELMSigmoid() {}; + protected: + /** Function used internally to create the basis function matrices and derivatives. */ + void Hint(const int d, const double *x, const int nOut, double *dark); }; -// ELM ReLU: ******************************************************************************************************************************** +// ELM ReLU: +// ******************************************************************************************************************************** /** ELM that uses the recitified linear unit activation function. */ -class ELMReLU: public ELM { - - public: - /** ELMReLU class constructor. Calls ELM class constructor. See ELM class for more details. */ - ELMReLU(double x0, double xf, const int* nCin, int ncDim0, int min): - ELM(x0,xf,nCin,ncDim0,min){}; +class ELMReLU : public ELM { - /** ELMReLU class destructor.*/ - ~ELMReLU(){}; + public: + /** ELMReLU class constructor. Calls ELM class constructor. See ELM class for more details. */ + ELMReLU(double x0, double xf, const int *nCin, int ncDim0, int min) + : ELM(x0, xf, nCin, ncDim0, min) {}; - protected: - /** Function used internally to create the basis function matrices and derivatives. */ - void Hint(const int d, const double* x, const int nOut, double* dark); + /** ELMReLU class destructor.*/ + ~ELMReLU() {}; + protected: + /** Function used internally to create the basis function matrices and derivatives. */ + void Hint(const int d, const double *x, const int nOut, double *dark); }; -// ELM Tanh: ******************************************************************************************************************************** +// ELM Tanh: +// ******************************************************************************************************************************** /** ELM that uses the tanh activation function. */ -class ELMTanh: public ELM { - - public: - /** ELMTanh class constructor. Calls ELM class constructor. See ELM class for more details. */ - ELMTanh(double x0, double xf, const int* nCin, int ncDim0, int min): - ELM(x0,xf,nCin,ncDim0,min){}; +class ELMTanh : public ELM { - /** ELMTanh class destructor.*/ - ~ELMTanh(){}; + public: + /** ELMTanh class constructor. Calls ELM class constructor. See ELM class for more details. */ + ELMTanh(double x0, double xf, const int *nCin, int ncDim0, int min) + : ELM(x0, xf, nCin, ncDim0, min) {}; - private: - /** Function used internally to create the basis function matrices and derivatives. */ - void Hint(const int d, const double* x, const int nOut, double* dark); + /** ELMTanh class destructor.*/ + ~ELMTanh() {}; + private: + /** Function used internally to create the basis function matrices and derivatives. */ + void Hint(const int d, const double *x, const int nOut, double *dark); }; -// ELM Sin: ******************************************************************************************************************************** +// ELM Sin: +// ******************************************************************************************************************************** /** ELM that uses the sin activation function. */ -class ELMSin: public ELM { +class ELMSin : public ELM { - public: - /** ELMSin class constructor. Calls ELM class constructor. See ELM class for more details. */ - ELMSin(double x0, double xf, const int* nCin, int ncDim0, int min): - ELM(x0,xf,nCin,ncDim0,min){}; + public: + /** ELMSin class constructor. Calls ELM class constructor. See ELM class for more details. */ + ELMSin(double x0, double xf, const int *nCin, int ncDim0, int min) + : ELM(x0, xf, nCin, ncDim0, min) {}; - /** ELMSin class destructor.*/ - ~ELMSin(){}; - - private: - /** Function used internally to create the basis function matrices and derivatives. */ - void Hint(const int d, const double* x, const int nOut, double* dark); + /** ELMSin class destructor.*/ + ~ELMSin() {}; + private: + /** Function used internally to create the basis function matrices and derivatives. */ + void Hint(const int d, const double *x, const int nOut, double *dark); }; -// ELM Swish: ******************************************************************************************************************************** +// ELM Swish: +// ******************************************************************************************************************************** /** ELM that uses the swish activation function. */ -class ELMSwish: public ELM { - - public: - /** ELMSwish class constructor. Calls ELM class constructor. See ELM class for more details. */ - ELMSwish(double x0, double xf, const int* nCin, int ncDim0, int min): - ELM(x0,xf,nCin,ncDim0,min){}; +class ELMSwish : public ELM { - /** ELMSwish class destructor.*/ - ~ELMSwish(){}; + public: + /** ELMSwish class constructor. Calls ELM class constructor. See ELM class for more details. */ + ELMSwish(double x0, double xf, const int *nCin, int ncDim0, int min) + : ELM(x0, xf, nCin, ncDim0, min) {}; - private: - /** Function used internally to create the basis function matrices and derivatives. */ - void Hint(const int d, const double* x, const int nOut, double* dark); + /** ELMSwish class destructor.*/ + ~ELMSwish() {}; + private: + /** Function used internally to create the basis function matrices and derivatives. */ + void Hint(const int d, const double *x, const int nOut, double *dark); }; -// n-D Basis function base class: *************************************************************************************************** +// n-D Basis function base class: +// *************************************************************************************************** /** Base class for n-dimensional basis functions. This class inherits from BasisFunc, and contains - * methods that are used for all n-dimensional basis fuctions. This is an abstract class. + * methods that are used for all n-dimensional basis fuctions. This is an abstract class. * Concrete n-dimensional basis functions will inherit from this class and one of the concrete * 1-dimensional basis function classes. */ -class nBasisFunc: virtual public BasisFunc{ - - public: - - /** Beginning of the basis function domain. */ - double z0; - - /** Beginning of the basis function domain. */ - double zf; - - /** Multipliers for the linear domain map. */ - double* c; - - /** Initial value of the domain */ - double* x0; - - /** Number of dimensions. */ - int dim; - - /** Number of basis functions in H matrix. */ - int numBasisFunc; - - /** Number of basis functions in full H matrix. */ - int numBasisFuncFull; - - public: - /** n-D basis function class constructor. */ - nBasisFunc(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int ncDim1, int min, double z0in=0., double zfin=0.); - - /** n-D basis function class destructor. */ - virtual ~nBasisFunc(); - - /** This function is used to create a basis function matrix and its derivatives. */ - void H(const double* x, int in, int xDim1, const int* d, int dDim0, int* nOut, int* mOut, double** F, const bool full); - - /** This function is an XLA version of the basis function. */ - void xla(void* out, void** in) override; - - /** Python hook to return domain mapping constants. */ - void getC(double** arrOut, int* nOut); - - protected: - /** Dummy nBasisFunc constructor used by nELM only. */ - nBasisFunc(){}; - - private: - /** - * Including override of BasisFunc so we don't have issues with hidden virtual overloads. - * However, this should never be called from nBasisFunc. - * If it is, it will throw an error. - */ - void H(const double* x, int n, const int d, int* nOut, int* mOut, double** F, bool full) override; - - /** Recursive function used to perform the tensor product of univarite basis functions to form multivariate basis functions. */ - void RecurseBasis(int dimCurr, int* vec, int &count, const bool full, const int in, const int numBasis, const double* T, double* out); - - /** Recursive function used to calculate the size of the multivariate basis function matrix. */ - void NumBasisFunc(int dimCurr, int* vec, int &count, const bool full); - - /** Internal function used to calculate dim sets of univariate basis functions with specified derivatives. Note, that if dDim0 < dim, then 0's will be used for the tail end.*/ - virtual void nHint(const double* x, int in, const int* d, int dDim0, int numBasis, double*& F, const bool full); - - /** Function used internally to create the basis function matrices. */ - virtual void Hint(const int d, const double* x, const int nOut, double* dark) override = 0; - - /** Function used internally to create derivatives of the basis function matrices. */ - virtual void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut) override = 0; - +class nBasisFunc : virtual public BasisFunc { + + public: + /** Beginning of the basis function domain. */ + double z0; + + /** Beginning of the basis function domain. */ + double zf; + + /** Multipliers for the linear domain map. */ + double *c; + + /** Initial value of the domain */ + double *x0; + + /** Number of dimensions. */ + int dim; + + /** Number of basis functions in H matrix. */ + int numBasisFunc; + + /** Number of basis functions in full H matrix. */ + int numBasisFuncFull; + + public: + /** n-D basis function class constructor. */ + nBasisFunc(const double *x0in, + int x0Dim0, + const double *xf, + int xfDim0, + const int *nCin, + int ncDim0, + int ncDim1, + int min, + double z0in = 0., + double zfin = 0.); + + /** n-D basis function class destructor. */ + virtual ~nBasisFunc(); + + /** This function is used to create a basis function matrix and its derivatives. */ + void + H(const double *x, int in, int xDim1, const int *d, int dDim0, int *nOut, int *mOut, double **F, const bool full); + + /** This function is an XLA version of the basis function. */ + void xla(void *out, void **in) override; + + /** Python hook to return domain mapping constants. */ + void getC(double **arrOut, int *nOut); + + protected: + /** Dummy nBasisFunc constructor used by nELM only. */ + nBasisFunc() {}; + + private: + /** + * Including override of BasisFunc so we don't have issues with hidden virtual overloads. + * However, this should never be called from nBasisFunc. + * If it is, it will throw an error. + */ + void H(const double *x, int n, const int d, int *nOut, int *mOut, double **F, bool full) override; + + /** Recursive function used to perform the tensor product of univarite basis functions to form multivariate basis + * functions. */ + void RecurseBasis(int dimCurr, + int *vec, + int &count, + const bool full, + const int in, + const int numBasis, + const double *T, + double *out); + + /** Recursive function used to calculate the size of the multivariate basis function matrix. */ + void NumBasisFunc(int dimCurr, int *vec, int &count, const bool full); + + /** Internal function used to calculate dim sets of univariate basis functions with specified derivatives. Note, + * that if dDim0 < dim, then 0's will be used for the tail end.*/ + virtual void nHint(const double *x, int in, const int *d, int dDim0, int numBasis, double *&F, const bool full); + + /** Function used internally to create the basis function matrices. */ + virtual void Hint(const int d, const double *x, const int nOut, double *dark) override = 0; + + /** Function used internally to create derivatives of the basis function matrices. */ + virtual void + RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) override = 0; }; -// n-D CP class: ****************************************************************************************************************** +// n-D CP class: +// ****************************************************************************************************************** /** Class for n-dimensional Chebyshev orthogonal polynomials. */ -class nCP: public nBasisFunc, public CP { - - public: +class nCP : public nBasisFunc, public CP { - /** nCP class constructor. Calls nBasisFunc class constructor and dummy constructors of remaining parents. See nBasisFunc class for more details. */ - nCP(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int ncDim1, int min):nBasisFunc(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,ncDim1,min,-1.,1.){}; + public: + /** nCP class constructor. Calls nBasisFunc class constructor and dummy constructors of remaining parents. See + * nBasisFunc class for more details. */ + nCP(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int ncDim1, int min) + : nBasisFunc(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, ncDim1, min, -1., 1.) {}; - /** nCP class destructor.*/ - ~nCP(){}; - - private: - /** Function used internally to create the basis function matrices. */ - void Hint(const int d, const double* x, const int nOut, double* dark){CP::Hint(d,x,nOut,dark);}; - - /** Function used internally to create derivatives of the basis function matrices. */ - void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut){CP::RecurseDeriv(d,dCurr,x,nOut,F,mOut);}; + /** nCP class destructor.*/ + ~nCP() {}; + private: + /** Function used internally to create the basis function matrices. */ + void Hint(const int d, const double *x, const int nOut, double *dark) { CP::Hint(d, x, nOut, dark); }; + /** Function used internally to create derivatives of the basis function matrices. */ + void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) { + CP::RecurseDeriv(d, dCurr, x, nOut, F, mOut); + }; }; -// n-D LeP class: ****************************************************************************************************************** +// n-D LeP class: +// ****************************************************************************************************************** /** Class for n-dimensional Legendre orthogonal polynomials. */ -class nLeP: public nBasisFunc, public LeP { - - public: - /** nLeP class constructor. Calls nBasisFunc class constructor and dummy constructors of remaining parents. See nBasisFunc class for more details. */ - nLeP(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int ncDim1, int min):nBasisFunc(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,ncDim1,min,-1.,1.){}; +class nLeP : public nBasisFunc, public LeP { - /** nLeP class destructor.*/ - ~nLeP(){}; + public: + /** nLeP class constructor. Calls nBasisFunc class constructor and dummy constructors of remaining parents. See + * nBasisFunc class for more details. */ + nLeP(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int ncDim1, int min) + : nBasisFunc(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, ncDim1, min, -1., 1.) {}; - private: - /** Function used internally to create the basis function matrices. */ - void Hint(const int d, const double* x, const int nOut, double* dark){LeP::Hint(d,x,nOut,dark);}; + /** nLeP class destructor.*/ + ~nLeP() {}; - /** Function used internally to create derivatives of the basis function matrices. */ - void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut){LeP::RecurseDeriv(d,dCurr,x,nOut,F,mOut);}; + private: + /** Function used internally to create the basis function matrices. */ + void Hint(const int d, const double *x, const int nOut, double *dark) { LeP::Hint(d, x, nOut, dark); }; + /** Function used internally to create derivatives of the basis function matrices. */ + void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) { + LeP::RecurseDeriv(d, dCurr, x, nOut, F, mOut); + }; }; -// n-D FS class: ****************************************************************************************************************** +// n-D FS class: +// ****************************************************************************************************************** /** Class for n-dimensional Fourier Series. */ -class nFS: public nBasisFunc, public FS { - - public: - /** nFS class constructor. Calls nBasisFunc class constructor and dummy constructors of remaining parents. See nBasisFunc class for more details. */ - nFS(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int ncDim1, int min):nBasisFunc(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,ncDim1,min,-M_PI,M_PI){}; +class nFS : public nBasisFunc, public FS { - /** nFS class destructor.*/ - ~nFS(){}; + public: + /** nFS class constructor. Calls nBasisFunc class constructor and dummy constructors of remaining parents. See + * nBasisFunc class for more details. */ + nFS(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int ncDim1, int min) + : nBasisFunc(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, ncDim1, min, -M_PI, M_PI) {}; - private: - /** Function used internally to create the basis function matrices. */ - void Hint(const int d, const double* x, const int nOut, double* dark){FS::Hint(d,x,nOut,dark);}; + /** nFS class destructor.*/ + ~nFS() {}; - /** Function used internally to create derivatives of the basis function matrices. */ - void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut){FS::RecurseDeriv(d,dCurr,x,nOut,F,mOut);}; + private: + /** Function used internally to create the basis function matrices. */ + void Hint(const int d, const double *x, const int nOut, double *dark) { FS::Hint(d, x, nOut, dark); }; + /** Function used internally to create derivatives of the basis function matrices. */ + void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) { + FS::RecurseDeriv(d, dCurr, x, nOut, F, mOut); + }; }; -// n-D ELM base class: ******************************************************************************************************************************************************* +// n-D ELM base class: +// ******************************************************************************************************************************************************* /** n-D ELM base class. */ -class nELM: public nBasisFunc { +class nELM : public nBasisFunc { + + public: + /** Beginning of the basis function domain. */ + double z0; + + /** Beginning of the basis function domain. */ + double zf; + + /** nELM weights. */ + double *w; + + /** nELM biases. */ + double *b; + + /** n-D ELM class constructor. */ + nELM(const double *x0in, + int x0Dim0, + const double *xf, + int xfDim0, + const int *nCin, + int ncDim0, + int min, + double z0in = 0., + double zfin = 1.); + + /** n-D ELM class destructor. */ + virtual ~nELM(); + + /** Python hook to return nELM weights. */ + void setW(const double *arrIn, int dimIn, int nIn); + + /** Python hook to set nELM weights. */ + void getW(int *dimOut, int *nOut, double **arrOut); + + /** Python hook to return nELM biases. */ + void getB(double **arrOut, int *nOut); + + /** Python hook to set nELM biases. */ + void setB(const double *arrIn, int nIn); + + private: + /** Internal function used to calculate dim sets of univariate basis functions with specified derivatives. Note, + * that if dDim0 < dim, then 0's will be used for the tail end.*/ + void nHint(const double *x, int in, const int *d, int dDim0, int numBasis, double *&F, const bool full) override; + + /** This function handles creating a full matrix of nELM basis functions. */ + virtual void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) = 0; + + /** This function is unecessary for nELM as it is all handled in nElmHint. Therefore, this is just an empty function + * that returns a warning. */ + void Hint(const int d, const double *x, const int nOut, double *dark) override { + fprintf(stderr, + "Warning, this function from nELM should never be called. It seems it has been called by accident. " + "Please check that this function was intended to be called.\n"); + printf("Warning, this function from nELM should never be called. It seems it has been called by accident. " + "Please check that this function was intended to be called.\n"); + }; + + /** This function is unecessary for nELM as it is all handled in nElmHint. Therefore, this is just an empty function + * that returns a warning. */ + void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) override { + fprintf(stderr, + "Warning, this function from nELM should never be called. It seems it has been called by accident. " + "Please check that this function was intended to be called.\n"); + printf("Warning, this function from nELM should never be called. It seems it has been called by accident. " + "Please check that this function was intended to be called.\n"); + }; +}; - public: - /** Beginning of the basis function domain. */ - double z0; +// n-D ELM sigmoid class: +// ******************************************************************************************************************************************************* +/** n-D ELM that uses the sigmoid activation function. */ +class nELMSigmoid : public nELM { - /** Beginning of the basis function domain. */ - double zf; + public: + /** nELMSigmoid class constructor. Calls nELM class constructor. See nELM class for more details. */ + nELMSigmoid(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min) + : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}; - /** nELM weights. */ - double *w; + /** nELMSigmoid class destructor.*/ + ~nELMSigmoid() {}; - /** nELM biases. */ - double *b; + private: + /** Function used internally to create the basis function matrices. */ + void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override; +}; - /** n-D ELM class constructor. */ - nELM(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int min, double z0in=0., double zfin=1.); +// n-D ELM Tanh class: +// ******************************************************************************************************************************************************* +/** n-D ELM that uses the tanh activation function. */ +class nELMTanh : public nELM { - /** n-D ELM class destructor. */ - virtual ~nELM(); + public: + /** nELMTanh class constructor. Calls nELM class constructor. See nELM class for more details. */ + nELMTanh(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min) + : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}; - /** Python hook to return nELM weights. */ - void setW(const double* arrIn, int dimIn, int nIn); + /** nELMTanh class destructor.*/ + ~nELMTanh() {}; - /** Python hook to set nELM weights. */ - void getW(int* dimOut, int* nOut, double** arrOut); + private: + /** Function used internally to create the basis function matrices. */ + void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override; +}; - /** Python hook to return nELM biases. */ - void getB(double** arrOut, int* nOut); +// n-D ELM Sin class: +// ******************************************************************************************************************************************************* +/** n-D ELM that uses the sine activation function. */ +class nELMSin : public nELM { - /** Python hook to set nELM biases. */ - void setB(const double* arrIn, int nIn); + public: + /** nELMSin class constructor. Calls nELM class constructor. See nELM class for more details. */ + nELMSin(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min) + : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}; - private: + /** nELMSin class destructor.*/ + ~nELMSin() {}; - /** Internal function used to calculate dim sets of univariate basis functions with specified derivatives. Note, that if dDim0 < dim, then 0's will be used for the tail end.*/ - void nHint(const double* x, int in, const int* d, int dDim0, int numBasis, double*& F, const bool full) override; + private: + /** Function used internally to create the basis function matrices. */ + void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override; +}; - /** This function handles creating a full matrix of nELM basis functions. */ - virtual void nElmHint(const int* d, int dDim0, const double* x, const int in, double* F) = 0; +// n-D ELM Swish class: +// ******************************************************************************************************************************************************* +/** n-D ELM that uses the swish activation function. */ +class nELMSwish : public nELM { - /** This function is unecessary for nELM as it is all handled in nElmHint. Therefore, this is just an empty function that returns a warning. */ - void Hint(const int d, const double* x, const int nOut, double* dark) override { - fprintf(stderr, "Warning, this function from nELM should never be called. It seems it has been called by accident. Please check that this function was intended to be called.\n"); - printf("Warning, this function from nELM should never be called. It seems it has been called by accident. Please check that this function was intended to be called.\n"); - }; + public: + /** nELMSwish class constructor. Calls nELM class constructor. See nELM class for more details. */ + nELMSwish(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min) + : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}; - /** This function is unecessary for nELM as it is all handled in nElmHint. Therefore, this is just an empty function that returns a warning. */ - void RecurseDeriv(const int d, int dCurr, const double* x, const int nOut, double*& F, const int mOut) override { - fprintf(stderr, "Warning, this function from nELM should never be called. It seems it has been called by accident. Please check that this function was intended to be called.\n"); - printf("Warning, this function from nELM should never be called. It seems it has been called by accident. Please check that this function was intended to be called.\n"); - }; + /** nELMSwish class destructor.*/ + ~nELMSwish() {}; + private: + /** Function used internally to create the basis function matrices. */ + void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override; }; -// n-D ELM sigmoid class: ******************************************************************************************************************************************************* -/** n-D ELM that uses the sigmoid activation function. */ -class nELMSigmoid: public nELM { - - public: - /** nELMSigmoid class constructor. Calls nELM class constructor. See nELM class for more details. */ - nELMSigmoid(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0,int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; - - /** nELMSigmoid class destructor.*/ - ~nELMSigmoid(){}; - - private: - /** Function used internally to create the basis function matrices. */ - void nElmHint(const int* d, int dDim0, const double* x, const int in, double* F) override ; -}; +// n-D ELM ReLU class: +// ******************************************************************************************************************************************************* +/** n-D ELM that uses the rectified linear activation function. */ +class nELMReLU : public nELM { -// n-D ELM Tanh class: ******************************************************************************************************************************************************* -/** n-D ELM that uses the tanh activation function. */ -class nELMTanh: public nELM { - - public: - /** nELMTanh class constructor. Calls nELM class constructor. See nELM class for more details. */ - nELMTanh(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; - - /** nELMTanh class destructor.*/ - ~nELMTanh(){}; - - private: - /** Function used internally to create the basis function matrices. */ - void nElmHint(const int* d, int dDim0, const double* x, const int in, double* F) override ; -}; + public: + /** nELMReLU class constructor. Calls nELM class constructor. See nELM class for more details. */ + nELMReLU(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min) + : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}; -// n-D ELM Sin class: ******************************************************************************************************************************************************* -/** n-D ELM that uses the sine activation function. */ -class nELMSin: public nELM { - - public: - /** nELMSin class constructor. Calls nELM class constructor. See nELM class for more details. */ - nELMSin(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; - - /** nELMSin class destructor.*/ - ~nELMSin(){}; - - private: - /** Function used internally to create the basis function matrices. */ - void nElmHint(const int* d, int dDim0, const double* x, const int in, double* F) override ; -}; + /** nELMReLU class destructor.*/ + ~nELMReLU() {}; -// n-D ELM Swish class: ******************************************************************************************************************************************************* -/** n-D ELM that uses the swish activation function. */ -class nELMSwish: public nELM { - - public: - /** nELMSwish class constructor. Calls nELM class constructor. See nELM class for more details. */ - nELMSwish(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; - - /** nELMSwish class destructor.*/ - ~nELMSwish(){}; - - private: - /** Function used internally to create the basis function matrices. */ - void nElmHint(const int* d, int dDim0, const double* x, const int in, double* F) override ; + private: + /** Function used internally to create the basis function matrices. */ + void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override; }; -// n-D ELM ReLU class: ******************************************************************************************************************************************************* -/** n-D ELM that uses the rectified linear activation function. */ -class nELMReLU: public nELM { - - public: - /** nELMReLU class constructor. Calls nELM class constructor. See nELM class for more details. */ - nELMReLU(const double* x0in, int x0Dim0, const double* xf, int xfDim0, const int* nCin, int ncDim0, int min):nELM(x0in,x0Dim0,xf,xfDim0,nCin,ncDim0,min){}; - - /** nELMReLU class destructor.*/ - ~nELMReLU(){}; - - private: - /** Function used internally to create the basis function matrices. */ - void nElmHint(const int* d, int dDim0, const double* x, const int in, double* F) override ; -}; - #endif diff --git a/src/tfc/utils/BF_Py.cc b/src/tfc/utils/BF_Py.cc index d06a1e7..ce0c326 100644 --- a/src/tfc/utils/BF_Py.cc +++ b/src/tfc/utils/BF_Py.cc @@ -1,24 +1,23 @@ +#include "BF.h" #include -#include #include +#include #include -#include "BF.h" namespace py = pybind11; -template -void add1DInit(auto& c) { - c.def(py::init([](double x0, double xf, py::array_t nC, int min){ - if (nC.ndim() != 1) { - throw py::value_error("The \"nC\" input array must be 1-dimensional."); - } - return std::make_unique(x0, xf, nC.data(), nC.size(), min); - }), - py::arg("x0"), - py::arg("xf"), - py::arg("nC"), - py::arg("min"), - R"( +template void add1DInit(auto &c) { + c.def(py::init([](double x0, double xf, py::array_t nC, int min) { + if (nC.ndim() != 1) { + throw py::value_error("The \"nC\" input array must be 1-dimensional."); + } + return std::make_unique(x0, xf, nC.data(), nC.size(), min); + }), + py::arg("x0"), + py::arg("xf"), + py::arg("nC"), + py::arg("min"), + R"( Constructor. Parameters: @@ -26,29 +25,28 @@ void add1DInit(auto& c) { xf: End of domain nC: Array of indices to remove (1D numpy array) min: Number of basis functions to use - )" - ); + )"); } -template -void addNdInit(auto& c) { - c.def(py::init([](py::array_t x0, py::array_t xf, py::array_t nC, int min){ - if (x0.ndim() != 1) { - throw py::value_error("The \"x0\" input array must be 1-dimensional."); - } - if (xf.ndim() != 1) { - throw py::value_error("The \"xf\" input array must be 1-dimensional."); - } - if (nC.ndim() != 2) { - throw py::value_error("The \"nC\" input array must be 2-dimensional."); - } - return std::make_unique(x0.data(), x0.size(), xf.data(), xf.size(), nC.data(), nC.shape()[0], nC.shape()[1], min); - }), - py::arg("x0"), - py::arg("xf"), - py::arg("nC"), - py::arg("min"), - R"( +template void addNdInit(auto &c) { + c.def(py::init([](py::array_t x0, py::array_t xf, py::array_t nC, int min) { + if (x0.ndim() != 1) { + throw py::value_error("The \"x0\" input array must be 1-dimensional."); + } + if (xf.ndim() != 1) { + throw py::value_error("The \"xf\" input array must be 1-dimensional."); + } + if (nC.ndim() != 2) { + throw py::value_error("The \"nC\" input array must be 2-dimensional."); + } + return std::make_unique( + x0.data(), x0.size(), xf.data(), xf.size(), nC.data(), nC.shape()[0], nC.shape()[1], min); + }), + py::arg("x0"), + py::arg("xf"), + py::arg("nC"), + py::arg("min"), + R"( Constructor. Parameters: @@ -56,29 +54,27 @@ void addNdInit(auto& c) { xf: End of domain nC: Array of indices to remove (2D numpy array) min: Number of basis functions to use - )" - ); + )"); } -template -void addNdElmInit(auto& c) { - c.def(py::init([](py::array_t x0, py::array_t xf, py::array_t nC, int min){ - if (x0.ndim() != 1) { - throw py::value_error("The \"x0\" input array must be 1-dimensional."); - } - if (xf.ndim() != 1) { - throw py::value_error("The \"xf\" input array must be 1-dimensional."); - } - if (nC.ndim() != 1) { - throw py::value_error("The \"nC\" input array must be 1-dimensional."); - } - return std::make_unique(x0.data(), x0.size(), xf.data(), xf.size(), nC.data(), nC.size(), min); - }), - py::arg("x0"), - py::arg("xf"), - py::arg("nC"), - py::arg("min"), - R"( +template void addNdElmInit(auto &c) { + c.def(py::init([](py::array_t x0, py::array_t xf, py::array_t nC, int min) { + if (x0.ndim() != 1) { + throw py::value_error("The \"x0\" input array must be 1-dimensional."); + } + if (xf.ndim() != 1) { + throw py::value_error("The \"xf\" input array must be 1-dimensional."); + } + if (nC.ndim() != 1) { + throw py::value_error("The \"nC\" input array must be 1-dimensional."); + } + return std::make_unique(x0.data(), x0.size(), xf.data(), xf.size(), nC.data(), nC.size(), min); + }), + py::arg("x0"), + py::arg("xf"), + py::arg("nC"), + py::arg("min"), + R"( Constructor. Parameters: @@ -86,8 +82,7 @@ void addNdElmInit(auto& c) { xf: End of domain (1D numpy array) nC: Array of indices to remove (1D numpy array) min: Number of basis functions to use - )" - ); + )"); } PYBIND11_MODULE(BF, m) { @@ -99,44 +94,44 @@ PYBIND11_MODULE(BF, m) { .def_readwrite("m", &BasisFunc::m) .def_readwrite("numC", &BasisFunc::numC) .def_readwrite("identifier", &BasisFunc::identifier) - .def_property_readonly("xlaCapsule", [](BasisFunc& self) { - py::object capsule = py::reinterpret_borrow(self.xlaCapsule); - return capsule; - }) - // GPU Capsule (only if available) - #ifdef HAS_CUDA - .def_property_readonly("xlaGpuCapsule", [](BasisFunc& self) { - return py::reinterpret_borrow(self.xlaGpuCapsule); - }) - #else + .def_property_readonly("xlaCapsule", + [](BasisFunc &self) { + py::object capsule = py::reinterpret_borrow(self.xlaCapsule); + return capsule; + }) +// GPU Capsule (only if available) +#ifdef HAS_CUDA + .def_property_readonly("xlaGpuCapsule", + [](BasisFunc &self) { return py::reinterpret_borrow(self.xlaGpuCapsule); }) +#else .def_property_readonly("xlaGpuCapsule", [](BasisFunc&) { return "CUDA NOT FOUND, GPU NOT IMPLEMENTED."; }) - #endif +#endif // Methods - .def("H", - [](BasisFunc& self, - py::array_t x, - int d, - bool full) { + .def( + "H", + [](BasisFunc &self, py::array_t x, int d, bool full) { if (x.ndim() != 1) { throw py::value_error("The \"x\" input array must be 1-dimensional."); } int n = x.size(); int nOut = 0; int mOut = 0; - double* F = nullptr; + double *F = nullptr; self.H(x.data(), n, d, &nOut, &mOut, &F, full); // Wrap data in a py::capsule to ensure it gets deleted - auto capsule = py::capsule(F, [](void* f) { - double* d = reinterpret_cast(f); + auto capsule = py::capsule(F, [](void *f) { + double *d = reinterpret_cast(f); free(d); }); return py::array_t({nOut, mOut}, F, capsule); }, - py::arg("x"), py::arg("d"), py::arg("full"), + py::arg("x"), + py::arg("d"), + py::arg("full"), R"( Compute basis function matrix. @@ -147,72 +142,69 @@ PYBIND11_MODULE(BF, m) { Returns: mOut x nOut NumPy array. - )" - ); + )"); - auto PyCP = py::class_ (m, "CP", py::multiple_inheritance()); + auto PyCP = py::class_(m, "CP", py::multiple_inheritance()); add1DInit(PyCP); - auto PyLeP = py::class_ (m, "LeP", py::multiple_inheritance()); + auto PyLeP = py::class_(m, "LeP", py::multiple_inheritance()); add1DInit(PyLeP); - auto PyLaP = py::class_ (m, "LaP", py::multiple_inheritance()); + auto PyLaP = py::class_(m, "LaP", py::multiple_inheritance()); add1DInit(PyLaP); - auto PyHoPpro = py::class_ (m, "HoPpro", py::multiple_inheritance()); + auto PyHoPpro = py::class_(m, "HoPpro", py::multiple_inheritance()); add1DInit(PyHoPpro); - auto PyHoPphy = py::class_ (m, "HoPphy", py::multiple_inheritance()); + auto PyHoPphy = py::class_(m, "HoPphy", py::multiple_inheritance()); add1DInit(PyHoPphy); - auto PyFS = py::class_ (m, "FS", py::multiple_inheritance()); + auto PyFS = py::class_(m, "FS", py::multiple_inheritance()); add1DInit(PyFS); - py::class_ (m, "ELM") - .def_property("b", - [](ELM& self) { - double* data = nullptr; - int nOut; - self.getB(&data, &nOut); + py::class_(m, "ELM") + .def_property( + "b", + [](ELM &self) { + double *data = nullptr; + int nOut; + self.getB(&data, &nOut); - auto capsule = py::capsule(data, [](void* f) { - double* d = reinterpret_cast(f); - free(d); - }); - return py::array_t(self.m, data, capsule); - }, - [](ELM& self, py::array_t b) { - self.setB(b.data(), b.size()); - }) - .def_property("w", - [](ELM& self) { - double* data = nullptr; - int nOut; - self.getW(&data, &nOut); - - auto capsule = py::capsule(data, [](void* f) { - double* d = reinterpret_cast(f); - free(d); - }); - return py::array_t(self.m, data, capsule); - }, - [](ELM& self, py::array_t w) { - self.setW(w.data(), w.size()); - }); + auto capsule = py::capsule(data, [](void *f) { + double *d = reinterpret_cast(f); + free(d); + }); + return py::array_t(self.m, data, capsule); + }, + [](ELM &self, py::array_t b) { self.setB(b.data(), b.size()); }) + .def_property( + "w", + [](ELM &self) { + double *data = nullptr; + int nOut; + self.getW(&data, &nOut); + + auto capsule = py::capsule(data, [](void *f) { + double *d = reinterpret_cast(f); + free(d); + }); + return py::array_t(self.m, data, capsule); + }, + [](ELM &self, py::array_t w) { self.setW(w.data(), w.size()); }); - auto PyELMSigmoid = py::class_ (m, "ELMSigmoid"); + auto PyELMSigmoid = py::class_(m, "ELMSigmoid"); add1DInit(PyELMSigmoid); - auto PyELMReLU = py::class_ (m, "ELMReLU"); + auto PyELMReLU = py::class_(m, "ELMReLU"); add1DInit(PyELMReLU); - auto PyELMTanh = py::class_ (m, "ELMTanh"); + auto PyELMTanh = py::class_(m, "ELMTanh"); add1DInit(PyELMTanh); - auto PyELMSin = py::class_ (m, "ELMSin"); + auto PyELMSin = py::class_(m, "ELMSin"); add1DInit(PyELMSin); - auto PyELMSwish = py::class_ (m, "ELMSwish"); + auto PyELMSwish = py::class_(m, "ELMSwish"); add1DInit(PyELMSwish); // TODO: Finish members and add methods. @@ -220,25 +212,26 @@ PYBIND11_MODULE(BF, m) { .def_readwrite("z0", &nBasisFunc::z0) .def_readwrite("zf", &nBasisFunc::zf) .def_readwrite("dim", &nBasisFunc::dim) - .def_property("c", - [](nBasisFunc& self){ + .def_property( + "c", + [](nBasisFunc &self) { // Return c, and ensure the nBasisFunc stays around as long as c does. return py::array_t(self.dim, self.c, py::cast(self)); }, - [](nBasisFunc& self, py::array_t c) - { + [](nBasisFunc &self, py::array_t c) { if (c.ndim() != 1) { throw py::value_error("The \"c\" input array must be 1-dimensional."); } if (c.size() != self.dim) { - throw py::value_error(std::format("The \"c\" input array must be size {}, but got size {}.", self.dim, c.size())); + throw py::value_error( + std::format("The \"c\" input array must be size {}, but got size {}.", self.dim, c.size())); } - } - ) + }) .def_readwrite("numBasisFunc", &nBasisFunc::numBasisFunc) .def_readwrite("numBasisFuncFull", &nBasisFunc::numBasisFuncFull) - .def("H", - [](nBasisFunc& self, + .def( + "H", + [](nBasisFunc &self, py::array_t x, py::array_t d, bool full) { @@ -250,18 +243,20 @@ PYBIND11_MODULE(BF, m) { } int nOut = 0; int mOut = 0; - double* F = nullptr; + double *F = nullptr; self.H(x.data(), x.shape()[0], x.shape()[1], d.data(), d.shape()[0], &nOut, &mOut, &F, full); // Wrap data in a py::capsule to ensure it gets deleted - auto capsule = py::capsule(F, [](void* f) { - double* d = reinterpret_cast(f); + auto capsule = py::capsule(F, [](void *f) { + double *d = reinterpret_cast(f); free(d); }); return py::array_t({nOut, mOut}, F, capsule); }, - py::arg("x"), py::arg("d"), py::arg("full"), + py::arg("x"), + py::arg("d"), + py::arg("full"), R"( Compute basis function matrix. @@ -272,66 +267,65 @@ PYBIND11_MODULE(BF, m) { Returns: mOut x nOut NumPy array. - )" - ); + )"); - auto PynCP = py::class_ (m, "nCP"); + auto PynCP = py::class_(m, "nCP"); addNdInit(PynCP); - auto PynLeP = py::class_ (m, "nLeP"); + auto PynLeP = py::class_(m, "nLeP"); addNdInit(PynLeP); - auto PynFS = py::class_ (m, "nFS"); + auto PynFS = py::class_(m, "nFS"); addNdInit(PynFS); - py::class_ (m, "nELM") - .def_property("b", - [](nELM& self) { - double* data = nullptr; - int nOut; - self.getB(&data, &nOut); + py::class_(m, "nELM") + .def_property( + "b", + [](nELM &self) { + double *data = nullptr; + int nOut; + self.getB(&data, &nOut); - auto capsule = py::capsule(data, [](void* f) { - double* d = reinterpret_cast(f); - free(d); - }); - return py::array_t(self.m, data, capsule); - }, - [](nELM& self, py::array_t b) { - self.setB(b.data(), b.size()); - }) - .def_property("w", - [](nELM& self) { - double* data = nullptr; - int nOut; - int dimOut; - self.getW(&dimOut, &nOut, &data); - - auto capsule = py::capsule(data, [](void* f) { - double* d = reinterpret_cast(f); - free(d); + auto capsule = py::capsule(data, [](void *f) { + double *d = reinterpret_cast(f); + free(d); + }); + return py::array_t(self.m, data, capsule); + }, + [](nELM &self, py::array_t b) { self.setB(b.data(), b.size()); }) + .def_property( + "w", + [](nELM &self) { + double *data = nullptr; + int nOut; + int dimOut; + self.getW(&dimOut, &nOut, &data); + + auto capsule = py::capsule(data, [](void *f) { + double *d = reinterpret_cast(f); + free(d); + }); + return py::array_t({dimOut, nOut}, data, capsule); + }, + [](nELM &self, py::array_t w) { + if (w.ndim() != 2) { + throw py::value_error("The \"w\" input array must be 2-dimensional."); + } + self.setW(w.data(), w.shape()[0], w.shape()[1]); }); - return py::array_t({dimOut, nOut}, data, capsule); - }, - [](nELM& self, py::array_t w) { - if (w.ndim() != 2) { - throw py::value_error("The \"w\" input array must be 2-dimensional."); - } - self.setW(w.data(), w.shape()[0], w.shape()[1]); - }); - - auto PynELMSigmoid = py::class_ (m, "nELMSigmoid"); + + auto PynELMSigmoid = py::class_(m, "nELMSigmoid"); addNdElmInit(PynELMSigmoid); - auto PynELMTanh = py::class_ (m, "nELMTanh"); + auto PynELMTanh = py::class_(m, "nELMTanh"); addNdElmInit(PynELMTanh); - auto PynELMSin = py::class_ (m, "nELMSin"); + auto PynELMSin = py::class_(m, "nELMSin"); addNdElmInit(PynELMSin); - auto PynELMSwish = py::class_ (m, "nELMSwish"); + auto PynELMSwish = py::class_(m, "nELMSwish"); addNdElmInit(PynELMSwish); - auto PynELMReLU = py::class_ (m, "nELMReLU"); + auto PynELMReLU = py::class_(m, "nELMReLU"); addNdElmInit(PynELMReLU); } diff --git a/utils/Makefile b/utils/Makefile index 975dbe6..cf0b4f0 100644 --- a/utils/Makefile +++ b/utils/Makefile @@ -31,4 +31,6 @@ clean-python: clean: clean-python - +format: + @cd ../src/tfc && fd -e cc -e h | xargs clang-format -i --verbose --style="file:../../clang-format.yaml" + @cd ../src/tfc && black . From 224f1181f8f08b6c8d5c9f7a6699c077146d0447 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 15:29:54 -0700 Subject: [PATCH 27/45] Trying cast to int to make Mac OS happy. --- src/tfc/utils/BF_Py.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tfc/utils/BF_Py.cc b/src/tfc/utils/BF_Py.cc index ce0c326..5dd5940 100644 --- a/src/tfc/utils/BF_Py.cc +++ b/src/tfc/utils/BF_Py.cc @@ -224,7 +224,7 @@ PYBIND11_MODULE(BF, m) { } if (c.size() != self.dim) { throw py::value_error( - std::format("The \"c\" input array must be size {}, but got size {}.", self.dim, c.size())); + std::format("The \"c\" input array must be size {}, but got size {}.", self.dim, int(c.size()))); } }) .def_readwrite("numBasisFunc", &nBasisFunc::numBasisFunc) From df33beb30cc6de3884f42268f54d620ce3e3f2b0 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 15:31:47 -0700 Subject: [PATCH 28/45] Changes so this builds with pedantic. --- src/tfc/utils/BF.cc | 19 ------------------- src/tfc/utils/BF_Py.cc | 4 ++-- src/tfc/utils/CMakeLists.txt | 2 +- 3 files changed, 3 insertions(+), 22 deletions(-) diff --git a/src/tfc/utils/BF.cc b/src/tfc/utils/BF.cc index 5566a9a..cfa88b8 100644 --- a/src/tfc/utils/BF.cc +++ b/src/tfc/utils/BF.cc @@ -882,12 +882,8 @@ void ELMSin::Hint(const int d, const double *x, const int nOut, double *dark) { void ELMSwish::Hint(const int d, const double *x, const int nOut, double *dark) { int j, k; -#ifdef WINDOWS_MSVC double *sig = new double[nOut * m]; double *zint = new double[nOut * m]; -#else - double sig[nOut * m], zint[nOut * m]; -#endif if (d == 0) { for (j = 0; j < nOut; j++) { @@ -1007,10 +1003,8 @@ void ELMSwish::Hint(const int d, const double *x, const int nOut, double *dark) } } -#ifdef WINDOWS_MSVC delete[] sig; delete[] zint; -#endif return; }; @@ -1048,18 +1042,12 @@ nBasisFunc::nBasisFunc(const double *x0in, numBasisFunc = 0; numBasisFuncFull = 0; -#ifdef WINDOWS_MSVC int *vec = new int[dim]; -#else - int vec[dim]; -#endif NumBasisFunc(dim - 1, &vec[0], numBasisFunc, false); NumBasisFunc(dim - 1, &vec[0], numBasisFuncFull, true); -#ifdef WINDOWS_MSVC delete[] vec; -#endif // Track this instance of BasisFunc BasisFuncContainer.push_back(this); @@ -1140,18 +1128,11 @@ void nBasisFunc::nHint(const double *x, int n, const int *d, int dDim0, int numB int count = 0; -#ifdef WINDOWS_MSVC int *vec = new int[dim]; -#else - int vec[dim]; -#endif RecurseBasis(dim - 1, vec, count, full, n, numBasis, &T[0], F); -#ifdef WINDOWS_MSVC delete[] vec; -#endif - delete[] dark; delete[] T; delete[] z; diff --git a/src/tfc/utils/BF_Py.cc b/src/tfc/utils/BF_Py.cc index 5dd5940..88d69b7 100644 --- a/src/tfc/utils/BF_Py.cc +++ b/src/tfc/utils/BF_Py.cc @@ -223,8 +223,8 @@ PYBIND11_MODULE(BF, m) { throw py::value_error("The \"c\" input array must be 1-dimensional."); } if (c.size() != self.dim) { - throw py::value_error( - std::format("The \"c\" input array must be size {}, but got size {}.", self.dim, int(c.size()))); + throw py::value_error(std::format( + "The \"c\" input array must be size {}, but got size {}.", self.dim, int(c.size()))); } }) .def_readwrite("numBasisFunc", &nBasisFunc::numBasisFunc) diff --git a/src/tfc/utils/CMakeLists.txt b/src/tfc/utils/CMakeLists.txt index ed268b6..24009a0 100644 --- a/src/tfc/utils/CMakeLists.txt +++ b/src/tfc/utils/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.25) project(tfc) # TODO: Change for release -add_compile_options(-Wall -Werror) +add_compile_options(-Wall -Werror -pedantic) # Contorl whether we build with shared libraries or static libraries option(BUILD_SHARED_LIBS "Build using shared libraries" OFF) From 685883341077c3a2339666aca56321d35c34ee82 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 15:35:21 -0700 Subject: [PATCH 29/45] Adding files to manifest for source distribution. --- MANIFEST.in | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 190ed53..2862af2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,5 @@ -include src/tfc/utils/BF/BF.h -include src/tfc/utils/BF/numpy.i +include src/tfc/utils/BF.h +include src/tfc/utils/BF.cc +include src/tfc/utils/BF_Py.cc +include src/tfc/utils/CMakeLists.txt include src/tfc/py.typed From 07ff93a18694398fa9a93af91ca1dab5fe25303a Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 15:43:55 -0700 Subject: [PATCH 30/45] Switching to stringstream, as Mac OS seems to be having issues with std::format. --- pyproject.toml | 2 +- src/tfc/utils/BF_Py.cc | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7930ae7..d07e3b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dynamic = ["dependencies", "classifiers", "authors", "license", "description"] [tool.black] line-length = 100 -target-version = ['py310','py311'] +target-version = ['py311'] [tool.cibuildwheel] before-build = "pip install setuptools wheel numpy" diff --git a/src/tfc/utils/BF_Py.cc b/src/tfc/utils/BF_Py.cc index 88d69b7..8e294a8 100644 --- a/src/tfc/utils/BF_Py.cc +++ b/src/tfc/utils/BF_Py.cc @@ -1,5 +1,4 @@ #include "BF.h" -#include #include #include #include @@ -223,8 +222,9 @@ PYBIND11_MODULE(BF, m) { throw py::value_error("The \"c\" input array must be 1-dimensional."); } if (c.size() != self.dim) { - throw py::value_error(std::format( - "The \"c\" input array must be size {}, but got size {}.", self.dim, int(c.size()))); + std::stringstream ss; + ss << "The \"c\" input array must be size " << self.dim << ", but got size " << c.size() << "." << std::endl; + throw py::value_error(ss.str()); } }) .def_readwrite("numBasisFunc", &nBasisFunc::numBasisFunc) From 55c04b60456d24877971d7c45d419fee4ebc447f Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 16:10:54 -0700 Subject: [PATCH 31/45] Adding other dependencies and removing swig. --- pyproject.toml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d07e3b8..9d1fe53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ line-length = 100 target-version = ['py311'] [tool.cibuildwheel] -before-build = "pip install setuptools wheel numpy" +before-build = "pip install setuptools wheel numpy pybind11 mypy" skip = "pp* *-musllinux*" manylinux-x86_64-image = "manylinux2014" test-requires = ["pytest"] @@ -26,11 +26,9 @@ test-command = "pytest {package}/tests" test-skip = "*-macosx_arm64" [tool.cibuildwheel.linux] -before-all = "yum install -y swig" archs = ["x86_64"] [tool.cibuildwheel.macos] -before-all = "brew install swig" archs = ["arm64"] [tool.cibuildwheel.windows] @@ -38,5 +36,3 @@ archs = ["AMD64"] [[tool.cibuildwheel.overrides]] select = "*-musllinux*" -before-all = "apk add swig" - From 9611ec300eb5b0b64d0153c029f8c77326cb5989 Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 16:25:51 -0700 Subject: [PATCH 32/45] Removing 3.14 for now. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9d1fe53..a753e6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "tfc" version = "1.2.1" -requires-python = ">=3.11" +requires-python = ">=3.11,<3.14" readme = "README.md" dynamic = ["dependencies", "classifiers", "authors", "license", "description"] From 04b1840c48fc0e6d53b76e465780365687ffabbe Mon Sep 17 00:00:00 2001 From: leakec Date: Sat, 2 Aug 2025 16:28:56 -0700 Subject: [PATCH 33/45] Adding Python version to setup Python to get rid of warnings. --- .github/workflows/ci.yml | 2 ++ .github/workflows/publish_wheels.yml | 4 ++++ .github/workflows/publish_wheels_test_pypi.yml | 4 ++++ .github/workflows/run_most_examples.yml | 2 ++ 4 files changed, 12 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 378c827..e0db1cf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,6 +16,8 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 + with: + python-version: "3.12" - run: "python -m pip install black" - name: lint diff --git a/.github/workflows/publish_wheels.yml b/.github/workflows/publish_wheels.yml index e44dce9..8ffc33e 100644 --- a/.github/workflows/publish_wheels.yml +++ b/.github/workflows/publish_wheels.yml @@ -16,6 +16,8 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 + with: + python-version: "3.12" - name: Checkout dependencies run: python -m pip install wheel setuptools numpy pybind11 mypy @@ -48,6 +50,8 @@ jobs: steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 + with: + python-version: "3.12" - name: Install cibuildwheel run: python -m pip install cibuildwheel diff --git a/.github/workflows/publish_wheels_test_pypi.yml b/.github/workflows/publish_wheels_test_pypi.yml index 8b75b35..de321b0 100644 --- a/.github/workflows/publish_wheels_test_pypi.yml +++ b/.github/workflows/publish_wheels_test_pypi.yml @@ -13,6 +13,8 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 + with: + python-version: "3.12" - name: Checkout dependencies run: python -m pip install wheel setuptools numpy pybind11 mypy @@ -45,6 +47,8 @@ jobs: steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 + with: + python-version: "3.12" - name: Install cibuildwheel run: python -m pip install cibuildwheel diff --git a/.github/workflows/run_most_examples.yml b/.github/workflows/run_most_examples.yml index ef812c7..055cad7 100644 --- a/.github/workflows/run_most_examples.yml +++ b/.github/workflows/run_most_examples.yml @@ -13,6 +13,8 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 + with: + python-version: "3.12" - run: "sudo apt-get update && sudo apt-get install -y gcc g++ graphviz" - run: python -m pip install wheel setuptools numpy pytest pybind11 mypy From bb53026f9bfad9c624ebc58afd0348a70162e6ed Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 3 Aug 2025 06:54:03 -0700 Subject: [PATCH 34/45] Updating to build with Windows. Thank you to https://stackoverflow.com/questions/78124769/pybind11-how-do-i-truly-fix-missing-dll-error for the tip on statically linking lpthread. --- src/tfc/utils/CMakeLists.txt | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/tfc/utils/CMakeLists.txt b/src/tfc/utils/CMakeLists.txt index 24009a0..f1dded0 100644 --- a/src/tfc/utils/CMakeLists.txt +++ b/src/tfc/utils/CMakeLists.txt @@ -40,7 +40,11 @@ target_link_libraries(bf PUBLIC Python3::Python) # Create the BF.py Python file pybind11_add_module(BF BF_Py.cc) -target_link_libraries(BF PRIVATE bf) +if (MINGW) + target_link_libraries(BF PRIVATE bf -static -lpthread -static-libgcc -static-libstdc++) +else() + target_link_libraries(BF PRIVATE bf) +endif() -install(TARGETS bf BF DESTINATION .) +install(TARGETS BF DESTINATION .) install(CODE [=[execute_process(COMMAND stubgen -m BF -o . --include-docstring WORKING_DIRECTORY $ENV{DESTDIR}${CMAKE_INSTALL_PREFIX})]=]) From c59b09bd3cbee1ae76621fe116303913867c8755 Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 3 Aug 2025 07:14:05 -0700 Subject: [PATCH 35/45] Trying to fix finding the right executable with cibuildwheel. --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 45dd87b..824589e 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,8 @@ def build_extension(self, ext): cmake_args = [ f"-DCMAKE_BUILD_TYPE={cfg}", f"-DCMAKE_INSTALL_PREFIX={bf_dir}", - f"-Dpybind11_DIR={pybind11_dir}" + f"-Dpybind11_DIR={pybind11_dir}", + f"-DPython_EXECUTABLE={sys.executable}" ] # Optional: use Ninja if available From 16b2ad7fc4f7138b9c8004e3c8591ef8e45c853b Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 3 Aug 2025 07:22:49 -0700 Subject: [PATCH 36/45] Using Development.Module instread. See notes in pybind11 here: https://pybind11.readthedocs.io/en/stable/compiling.html#findpython-mode --- src/tfc/utils/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tfc/utils/CMakeLists.txt b/src/tfc/utils/CMakeLists.txt index f1dded0..0533c2e 100644 --- a/src/tfc/utils/CMakeLists.txt +++ b/src/tfc/utils/CMakeLists.txt @@ -26,7 +26,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_COLOR_DIAGNOSTICS ON) # Find Python in the system -find_package(Python3 REQUIRED COMPONENTS Interpreter Development) +find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module) # Find pybind11 in the system # This is needed for CMake < 3.27. From c50a8165a3eb96ff84a06e6636740f16513cef39 Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 3 Aug 2025 07:25:10 -0700 Subject: [PATCH 37/45] Undoing this, as this breaks source builds. --- src/tfc/utils/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tfc/utils/CMakeLists.txt b/src/tfc/utils/CMakeLists.txt index 0533c2e..f1dded0 100644 --- a/src/tfc/utils/CMakeLists.txt +++ b/src/tfc/utils/CMakeLists.txt @@ -26,7 +26,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_COLOR_DIAGNOSTICS ON) # Find Python in the system -find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module) +find_package(Python3 REQUIRED COMPONENTS Interpreter Development) # Find pybind11 in the system # This is needed for CMake < 3.27. From 025f96d17ef34bcf83820edde1b5060ce38565de Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 3 Aug 2025 17:23:23 -0700 Subject: [PATCH 38/45] Letting pybind11 find Python. Hoping this will fix cibuildwheel finding the wrong version. --- src/tfc/utils/CMakeLists.txt | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/tfc/utils/CMakeLists.txt b/src/tfc/utils/CMakeLists.txt index f1dded0..d8d8b16 100644 --- a/src/tfc/utils/CMakeLists.txt +++ b/src/tfc/utils/CMakeLists.txt @@ -25,9 +25,6 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # Turn on colored diagnostics set(CMAKE_COLOR_DIAGNOSTICS ON) -# Find Python in the system -find_package(Python3 REQUIRED COMPONENTS Interpreter Development) - # Find pybind11 in the system # This is needed for CMake < 3.27. # After Cmake 3.27+, can remove setting PYBIND11_FINDPYTHON. @@ -36,7 +33,7 @@ find_package(pybind11 3.0 REQUIRED CONFIG) # Create the bf library add_library(bf BF.cc) -target_link_libraries(bf PUBLIC Python3::Python) +target_link_libraries(bf PUBLIC pybind11::pybind11) # Create the BF.py Python file pybind11_add_module(BF BF_Py.cc) From 586ed2cc37315943b3354f9dc67c4d5a4dc417ab Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 3 Aug 2025 17:33:21 -0700 Subject: [PATCH 39/45] Adding ninja to cibuildwheel runners. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index a753e6d..a24d51e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ test-command = "pytest {package}/tests" test-skip = "*-macosx_arm64" [tool.cibuildwheel.linux] +before-all = "yum install -y ninja-build" archs = ["x86_64"] [tool.cibuildwheel.macos] From 1e4d696db21e34b29c33157a23684eb5705636d9 Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 3 Aug 2025 17:59:19 -0700 Subject: [PATCH 40/45] Adding Wextra and fixing the errors. Still need to fix pedantic, but need an older version of gcc. --- src/tfc/utils/BF.cc | 17 ++++++++++++----- src/tfc/utils/BF.h | 8 ++++---- src/tfc/utils/BF_Py.cc | 3 ++- src/tfc/utils/CMakeLists.txt | 2 +- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/tfc/utils/BF.cc b/src/tfc/utils/BF.cc index cfa88b8..5de1fa7 100644 --- a/src/tfc/utils/BF.cc +++ b/src/tfc/utils/BF.cc @@ -1013,7 +1013,7 @@ void ELMSwish::Hint(const int d, const double *x, const int nOut, double *dark) nBasisFunc::nBasisFunc(const double *x0in, int x0Dim0, const double *xf, - int xfDim0, + int /*xfDim0*/, const int *nCin, int ncDim0, int ncDim1, @@ -1071,8 +1071,15 @@ void nBasisFunc::getC(double **arrOut, int *nOut) { return; }; -void nBasisFunc::H( - const double *x, int in, int xDim1, const int *d, int dDim0, int *nOut, int *mOut, double **F, const bool full) { +void nBasisFunc::H(const double *x, + int /*in*/, + int xDim1, + const int *d, + int dDim0, + int *nOut, + int *mOut, + double **F, + const bool full) { int numBasis = full ? numBasisFuncFull : numBasisFunc; *mOut = numBasis; *nOut = xDim1; @@ -1080,7 +1087,7 @@ void nBasisFunc::H( nHint(x, xDim1, d, dDim0, numBasis, *F, full); }; -void nBasisFunc::H(const double *x, int n, const int d, int *nOut, int *mOut, double **F, bool full) { +void nBasisFunc::H(const double *, int, const int, int *, int *, double **, bool) { throw std::runtime_error("This version of \"H\" should never be called from an n-dimensional basis class."); } @@ -1259,7 +1266,7 @@ void nBasisFunc::RecurseBasis(int dimCurr, nELM::nELM(const double *x0in, int x0Dim0, const double *xf, - int xfDim0, + int /*xfDim0*/, const int *nCin, int ncDim0, int min, diff --git a/src/tfc/utils/BF.h b/src/tfc/utils/BF.h index 00b24dd..519970f 100644 --- a/src/tfc/utils/BF.h +++ b/src/tfc/utils/BF.h @@ -262,7 +262,7 @@ class FS : virtual public BasisFunc { /** This function is unecessary for FS as it is all handled in Hint. Therefore, this is just an empty function that * returns a warning. */ - void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) { + void RecurseDeriv(const int, int, const double *, const int, double *&, const int) { fprintf(stderr, "Warning, this function from FS should never be called. It seems it has been called by accident. " "Please check that this function was intended to be called.\n"); @@ -307,7 +307,7 @@ class ELM : public BasisFunc { /** This function is unecessary for ELM as it is all handled in Hint. Therefore, this is just an empty function that * returns a warning. */ - void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) { + void RecurseDeriv(const int, int, const double *, const int, double *&, const int) { fprintf(stderr, "Warning, this function from ELM should never be called. It seems it has been called by accident. " "Please check that this function was intended to be called.\n"); @@ -627,7 +627,7 @@ class nELM : public nBasisFunc { /** This function is unecessary for nELM as it is all handled in nElmHint. Therefore, this is just an empty function * that returns a warning. */ - void Hint(const int d, const double *x, const int nOut, double *dark) override { + void Hint(const int, const double *, const int, double *) override { fprintf(stderr, "Warning, this function from nELM should never be called. It seems it has been called by accident. " "Please check that this function was intended to be called.\n"); @@ -637,7 +637,7 @@ class nELM : public nBasisFunc { /** This function is unecessary for nELM as it is all handled in nElmHint. Therefore, this is just an empty function * that returns a warning. */ - void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) override { + void RecurseDeriv(const int, int, const double *, const int, double *&, const int) override { fprintf(stderr, "Warning, this function from nELM should never be called. It seems it has been called by accident. " "Please check that this function was intended to be called.\n"); diff --git a/src/tfc/utils/BF_Py.cc b/src/tfc/utils/BF_Py.cc index 8e294a8..de1a403 100644 --- a/src/tfc/utils/BF_Py.cc +++ b/src/tfc/utils/BF_Py.cc @@ -223,7 +223,8 @@ PYBIND11_MODULE(BF, m) { } if (c.size() != self.dim) { std::stringstream ss; - ss << "The \"c\" input array must be size " << self.dim << ", but got size " << c.size() << "." << std::endl; + ss << "The \"c\" input array must be size " << self.dim << ", but got size " << c.size() << "." + << std::endl; throw py::value_error(ss.str()); } }) diff --git a/src/tfc/utils/CMakeLists.txt b/src/tfc/utils/CMakeLists.txt index d8d8b16..2c03dfc 100644 --- a/src/tfc/utils/CMakeLists.txt +++ b/src/tfc/utils/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.25) project(tfc) # TODO: Change for release -add_compile_options(-Wall -Werror -pedantic) +add_compile_options(-Wall -Wextra -Werror -pedantic) # Contorl whether we build with shared libraries or static libraries option(BUILD_SHARED_LIBS "Build using shared libraries" OFF) From 6eaf591d18f18a13273525614185434469fcfc46 Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 3 Aug 2025 18:07:37 -0700 Subject: [PATCH 41/45] Fixing pedantic errors. --- src/tfc/utils/BF.cc | 94 ++++++++++++++++++++++----------------------- 1 file changed, 47 insertions(+), 47 deletions(-) diff --git a/src/tfc/utils/BF.cc b/src/tfc/utils/BF.cc index 5de1fa7..82504d0 100644 --- a/src/tfc/utils/BF.cc +++ b/src/tfc/utils/BF.cc @@ -8,7 +8,7 @@ std::vector BasisFunc::BasisFuncContainer; void xlaWrapper(void *out, void **in) { int N = (reinterpret_cast(in[0]))[0]; BasisFunc::BasisFuncContainer[N]->xla(out, in); -}; +} #ifdef HAS_CUDA // xlaGpuWrapper function @@ -50,9 +50,9 @@ BasisFunc::BasisFunc(double x0in, double xf, const int *nCin, int ncDim0, int mi #ifdef HAS_CUDA xlaGpuCapsule = GetXlaCapsuleGpu(); #endif -}; +} -BasisFunc::~BasisFunc() { delete[] nC; }; +BasisFunc::~BasisFunc() { delete[] nC; } void BasisFunc::H(const double *x, int n, const int d, int *nOut, int *mOut, double **F, bool full) { *nOut = n; @@ -95,7 +95,7 @@ void BasisFunc::H(const double *x, int n, const int d, int *nOut, int *mOut, dou } delete[] dark; delete[] z; -}; +} void BasisFunc::xla(void *out, void **in) { double *out_buf = reinterpret_cast(out); @@ -141,7 +141,7 @@ void BasisFunc::xla(void *out, void **in) { } delete[] dark; delete[] z; -}; +} #ifdef HAS_CUDA void BasisFunc::xlaGpu(CUstream stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -155,7 +155,7 @@ PyObject *BasisFunc::GetXlaCapsule() { PyObject *capsule; capsule = PyCapsule_New(reinterpret_cast(xlaFnPtr), name, NULL); return capsule; -}; +} #ifdef HAS_CUDA PyObject *BasisFunc::GetXlaCapsuleGpu() { @@ -207,7 +207,7 @@ void CP::Hint(const int d, const double *x, const int nOut, double *dark) { RecurseDeriv(d, 0, x, nOut, dark, m); } return; -}; +} void CP::RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) { if (dCurr != d) { @@ -234,7 +234,7 @@ void CP::RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, d RecurseDeriv(d, dCurr, x, nOut, F, mOut); } return; -}; +} // LeP: ********************************************************************** void LeP::Hint(const int d, const double *x, const int nOut, double *dark) { @@ -276,7 +276,7 @@ void LeP::Hint(const int d, const double *x, const int nOut, double *dark) { RecurseDeriv(d, 0, x, nOut, dark, m); } return; -}; +} void LeP::RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) { @@ -305,7 +305,7 @@ void LeP::RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, RecurseDeriv(d, dCurr, x, nOut, F, mOut); } return; -}; +} // LaP: ********************************************************************** void LaP::Hint(const int d, const double *x, const int nOut, double *dark) { @@ -347,7 +347,7 @@ void LaP::Hint(const int d, const double *x, const int nOut, double *dark) { RecurseDeriv(d, 0, x, nOut, dark, m); } return; -}; +} void LaP::RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) { @@ -376,7 +376,7 @@ void LaP::RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, RecurseDeriv(d, dCurr, x, nOut, F, mOut); } return; -}; +} // HoPpro: ********************************************************************** // Hermite polynomials, probablists @@ -419,7 +419,7 @@ void HoPpro::Hint(const int d, const double *x, const int nOut, double *dark) { RecurseDeriv(d, 0, x, nOut, dark, m); } return; -}; +} void HoPpro::RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) { @@ -446,7 +446,7 @@ void HoPpro::RecurseDeriv(const int d, int dCurr, const double *x, const int nOu RecurseDeriv(d, dCurr, x, nOut, F, mOut); } return; -}; +} // HoPphy: ********************************************************************** // Hermite polynomials, physicists @@ -489,7 +489,7 @@ void HoPphy::Hint(const int d, const double *x, const int nOut, double *dark) { RecurseDeriv(d, 0, x, nOut, dark, m); } return; -}; +} void HoPphy::RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) { @@ -517,7 +517,7 @@ void HoPphy::RecurseDeriv(const int d, int dCurr, const double *x, const int nOu RecurseDeriv(d, dCurr, x, nOut, F, mOut); } return; -}; +} // FS: ********************************************************************** // Fourier Series @@ -587,7 +587,7 @@ void FS::Hint(const int d, const double *x, const int nOut, double *dark) { } } return; -}; +} // ELM: ********************************************************************** // ELM base class @@ -602,12 +602,12 @@ ELM::ELM(double x0, double xf, const int *nCin, int ncDim0, int min) w[k] = 20. * ((double)rand() / (double)RAND_MAX) - 10.; b[k] = 20. * ((double)rand() / (double)RAND_MAX) - 10.; } -}; +} ELM::~ELM() { delete[] b; delete[] w; -}; +} void ELM::setW(const double *arrIn, int nIn) { if (nIn != m) { @@ -616,7 +616,7 @@ void ELM::setW(const double *arrIn, int nIn) { } for (int k = 0; k < m; k++) w[k] = arrIn[k]; -}; +} void ELM::setB(const double *arrIn, int nIn) { if (nIn != m) { @@ -625,7 +625,7 @@ void ELM::setB(const double *arrIn, int nIn) { } for (int k = 0; k < m; k++) b[k] = arrIn[k]; -}; +} void ELM::getW(double **arrOut, int *nOut) { *nOut = m; @@ -633,7 +633,7 @@ void ELM::getW(double **arrOut, int *nOut) { for (int k = 0; k < m; k++) (*arrOut)[k] = w[k]; return; -}; +} void ELM::getB(double **arrOut, int *nOut) { *nOut = m; @@ -641,7 +641,7 @@ void ELM::getB(double **arrOut, int *nOut) { for (int k = 0; k < m; k++) (*arrOut)[k] = b[k]; return; -}; +} // ELM ReLU: ********************************************************************** void ELMReLU::Hint(const int d, const double *x, const int nOut, double *dark) { @@ -673,7 +673,7 @@ void ELMReLU::Hint(const int d, const double *x, const int nOut, double *dark) { } } return; -}; +} // ELM Sigmoid: ********************************************************************** void ELMSigmoid::Hint(const int d, const double *x, const int nOut, double *dark) { @@ -762,7 +762,7 @@ void ELMSigmoid::Hint(const int d, const double *x, const int nOut, double *dark } } return; -}; +} // ELM Tanh: ********************************************************************** @@ -843,7 +843,7 @@ void ELMTanh::Hint(const int d, const double *x, const int nOut, double *dark) { } } return; -}; +} // ELM Sin: ********************************************************************** @@ -876,7 +876,7 @@ void ELMSin::Hint(const int d, const double *x, const int nOut, double *dark) { } } return; -}; +} // ELM Swish: ********************************************************************** @@ -1007,7 +1007,7 @@ void ELMSwish::Hint(const int d, const double *x, const int nOut, double *dark) delete[] zint; return; -}; +} // Parent n-dimensional basis function class: ********************************************************************** nBasisFunc::nBasisFunc(const double *x0in, @@ -1061,7 +1061,7 @@ nBasisFunc::nBasisFunc(const double *x0in, #endif } -nBasisFunc::~nBasisFunc() { delete[] c; }; +nBasisFunc::~nBasisFunc() { delete[] c; } void nBasisFunc::getC(double **arrOut, int *nOut) { *nOut = dim; @@ -1069,7 +1069,7 @@ void nBasisFunc::getC(double **arrOut, int *nOut) { for (int k = 0; k < dim; k++) (*arrOut)[k] = c[k]; return; -}; +} void nBasisFunc::H(const double *x, int /*in*/, @@ -1085,7 +1085,7 @@ void nBasisFunc::H(const double *x, *nOut = xDim1; *F = (double *)malloc(numBasis * xDim1 * sizeof(double)); nHint(x, xDim1, d, dDim0, numBasis, *F, full); -}; +} void nBasisFunc::H(const double *, int, const int, int *, int *, double **, bool) { throw std::runtime_error("This version of \"H\" should never be called from an n-dimensional basis class."); @@ -1101,7 +1101,7 @@ void nBasisFunc::xla(void *out, void **in) { int mOut = (reinterpret_cast(in[6]))[0]; nHint(x, nOut, d, dDim0, mOut, out_buf, full); -}; +} void nBasisFunc::nHint(const double *x, int n, const int *d, int dDim0, int numBasis, double *&F, const bool full) { @@ -1143,7 +1143,7 @@ void nBasisFunc::nHint(const double *x, int n, const int *d, int dDim0, int numB delete[] dark; delete[] T; delete[] z; -}; +} void nBasisFunc::NumBasisFunc(int dimCurr, int *vec, int &count, const bool full) { int k; @@ -1193,7 +1193,7 @@ void nBasisFunc::NumBasisFunc(int dimCurr, int *vec, int &count, const bool full } } return; -}; +} void nBasisFunc::RecurseBasis(int dimCurr, int *vec, @@ -1260,7 +1260,7 @@ void nBasisFunc::RecurseBasis(int dimCurr, } } return; -}; +} // nELM base class: *********************************************************************************** nELM::nELM(const double *x0in, @@ -1324,12 +1324,12 @@ nELM::nELM(const double *x0in, w[k] = 2. * ((double)rand() / (double)RAND_MAX) - 1.; for (k = 0; k < m; k++) b[k] = 2. * ((double)rand() / (double)RAND_MAX) - 1.; -}; +} nELM::~nELM() { delete[] b; delete[] w; -}; +} void nELM::setW(const double *arrIn, int dimIn, int nIn) { if ((nIn != m) || (dimIn != dim)) { @@ -1338,7 +1338,7 @@ void nELM::setW(const double *arrIn, int dimIn, int nIn) { } for (int k = 0; k < m * dim; k++) w[k] = arrIn[k]; -}; +} void nELM::setB(const double *arrIn, int nIn) { if (nIn != m) { @@ -1347,7 +1347,7 @@ void nELM::setB(const double *arrIn, int nIn) { } for (int k = 0; k < m; k++) b[k] = arrIn[k]; -}; +} void nELM::getW(int *dimOut, int *nOut, double **arrOut) { *dimOut = dim; @@ -1356,7 +1356,7 @@ void nELM::getW(int *dimOut, int *nOut, double **arrOut) { for (int k = 0; k < m * dim; k++) (*arrOut)[k] = w[k]; return; -}; +} void nELM::getB(double **arrOut, int *nOut) { *nOut = m; @@ -1364,7 +1364,7 @@ void nELM::getB(double **arrOut, int *nOut) { for (int k = 0; k < m; k++) (*arrOut)[k] = b[k]; return; -}; +} void nELM::nHint(const double *x, int n, const int *d, int dDim0, int numBasis, double *&F, const bool full) { @@ -1405,7 +1405,7 @@ void nELM::nHint(const double *x, int n, const int *d, int dDim0, int numBasis, delete[] dark; } delete[] z; -}; +} // nELM Sigmoid ******************************************************************************************* void nELMSigmoid::nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) { @@ -1539,7 +1539,7 @@ void nELMSigmoid::nElmHint(const int *d, int dDim0, const double *x, const int i } } return; -}; +} // nELM Tanh ******************************************************************************************* void nELMTanh::nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) { @@ -1663,7 +1663,7 @@ void nELMTanh::nElmHint(const int *d, int dDim0, const double *x, const int in, } } return; -}; +} // nELM Sin ******************************************************************************************* void nELMSin::nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) { @@ -1727,7 +1727,7 @@ void nELMSin::nElmHint(const int *d, int dDim0, const double *x, const int in, d } } return; -}; +} // nELM Swish ******************************************************************************************* void nELMSwish::nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) { @@ -1901,7 +1901,7 @@ void nELMSwish::nElmHint(const int *d, int dDim0, const double *x, const int in, delete[] sig; delete[] zint; return; -}; +} // nELM Swish ******************************************************************************************* void nELMReLU::nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) { @@ -1945,4 +1945,4 @@ void nELMReLU::nElmHint(const int *d, int dDim0, const double *x, const int in, } } return; -}; +} From 244ce6a0e00fe94e5598d9fc6bb9e0bc5f0b24e8 Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 3 Aug 2025 18:12:38 -0700 Subject: [PATCH 42/45] No need to skip pp*, as these have been deprecated. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a24d51e..19cbf2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ target-version = ['py311'] [tool.cibuildwheel] before-build = "pip install setuptools wheel numpy pybind11 mypy" -skip = "pp* *-musllinux*" +skip = "*-musllinux*" manylinux-x86_64-image = "manylinux2014" test-requires = ["pytest"] test-command = "pytest {package}/tests" From 3fac0a8a81898bdb6b9d8609fa077b160404e7ac Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 3 Aug 2025 18:40:20 -0700 Subject: [PATCH 43/45] Some quick typing improvements. More typing fixes. --- src/tfc/utils/BF_Py.py | 167 ++++++++++++++++---------------- src/tfc/utils/CeSolver.py | 14 +-- src/tfc/utils/MakePlot.py | 6 +- src/tfc/utils/MayaviMakePlot.py | 4 +- src/tfc/utils/TFCUtils.py | 79 +++++++++++++-- src/tfc/utils/tfc_types.py | 33 +++---- 6 files changed, 177 insertions(+), 126 deletions(-) diff --git a/src/tfc/utils/BF_Py.py b/src/tfc/utils/BF_Py.py index b1372aa..8556a15 100644 --- a/src/tfc/utils/BF_Py.py +++ b/src/tfc/utils/BF_Py.py @@ -1,8 +1,7 @@ import numpy as np import jax.numpy as jnp from abc import ABC, abstractmethod -from numpy import typing as npt -from tfc.utils.tfc_types import uint, Number +from tfc.utils.tfc_types import uint, Number, JaxOrNumpyArray from typing import Callable, Tuple @@ -21,7 +20,7 @@ def __init__( self, x0: Number, xf: Number, - nC: npt.NDArray, + nC: JaxOrNumpyArray, m: uint, z0: Number = 0, zf: Number = float("inf"), @@ -35,7 +34,7 @@ def __init__( Start of the problem domain. xf : Number End of the problem domain. - nC : npt.NDArray + nC : JaxOrNumpyArray Basis functions to be removed m : uint Number of basis functions. @@ -57,7 +56,7 @@ def __init__( self._x0 = x0 self._c = (zf - z0) / (xf - x0) - def H(self, x: npt.NDArray, d: uint = 0, full: bool = False) -> npt.NDArray: + def H(self, x: JaxOrNumpyArray, d: uint = 0, full: bool = False) -> JaxOrNumpyArray: """ Returns the basis function matrix for the x with a derivative of order d. @@ -87,7 +86,7 @@ def H(self, x: npt.NDArray, d: uint = 0, full: bool = False) -> npt.NDArray: return F @abstractmethod - def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: + def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray: """ Internal method used to calcualte the basis function value. @@ -130,7 +129,7 @@ def __init__( self, x0: Number, xf: Number, - nC: npt.NDArray, + nC: JaxOrNumpyArray, m: uint, ) -> None: """ @@ -142,14 +141,14 @@ def __init__( Start of the problem domain. xf : Number End of the problem domain. - nC : npt.NDArray + nC : JaxOrNumpyArray Basis functions to be removed m: uint Number of basis functions. """ super().__init__(x0, xf, nC, m, -1.0, 1.0) - def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: + def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray: """ Internal method used to calcualte the CP basis function values. @@ -187,7 +186,7 @@ def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: for k in range(2, self._m): F[:, k : k + 1] = 2 * z * F[:, k - 1 : k] - F[:, k - 2 : k - 1] - def Recurse(dark: npt.NDArray, d: uint, dCurr: uint = 0) -> npt.NDArray: + def Recurse(dark: JaxOrNumpyArray, d: uint, dCurr: uint = 0) -> JaxOrNumpyArray: """ Take derivative recursively. """ @@ -219,7 +218,7 @@ def __init__( self, x0: Number, xf: Number, - nC: npt.NDArray, + nC: JaxOrNumpyArray, m: uint, ) -> None: """ @@ -231,14 +230,14 @@ def __init__( Start of the problem domain. xf : Number End of the problem domain. - nC : npt.NDArray + nC : JaxOrNumpyArray Basis functions to be removed m : uint Number of basis functions. """ super().__init__(x0, xf, nC, m, -1.0, 1.0) - def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: + def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray: """ Internal method used to calcualte the LeP basis function values. @@ -278,7 +277,7 @@ def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: (2.0 * k + 1.0) * z * F[:, k : k + 1] - k * F[:, k - 1 : k] ) / (k + 1.0) - def Recurse(dark: npt.NDArray, d: uint, dCurr: uint = 0) -> npt.NDArray: + def Recurse(dark: JaxOrNumpyArray, d: uint, dCurr: uint = 0) -> JaxOrNumpyArray: """ Take derivative recursively. """ @@ -306,7 +305,7 @@ class LaP(BasisFunc): Laguerre polynomial basis functions. """ - def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: + def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray: """ Internal method used to calcualte the LaP basis function values. @@ -346,7 +345,7 @@ def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: (2.0 * k + 1.0 - z) * F[:, k : k + 1] - k * F[:, k - 1 : k] ) / (k + 1.0) - def Recurse(dark: npt.NDArray, d: uint, dCurr: uint = 0) -> npt.NDArray: + def Recurse(dark: JaxOrNumpyArray, d: uint, dCurr: uint = 0) -> JaxOrNumpyArray: """ Take derivative recursively. """ @@ -374,7 +373,7 @@ class HoPpro(BasisFunc): Hermite probablist polynomial basis functions. """ - def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: + def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray: """ Internal method used to calcualte the HoPpro basis function values. @@ -412,7 +411,7 @@ def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: for k in range(1, self._m - 1): F[:, k + 1 : k + 2] = z * F[:, k : k + 1] - k * F[:, k - 1 : k] - def Recurse(dark: npt.NDArray, d: uint, dCurr: uint = 0) -> npt.NDArray: + def Recurse(dark: JaxOrNumpyArray, d: uint, dCurr: uint = 0) -> JaxOrNumpyArray: """ Take derivative recursively. """ @@ -440,7 +439,7 @@ class HoPphy(BasisFunc): Hermite physicist polynomial basis functions. """ - def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: + def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray: """ Internal method used to calcualte the HoPpro basis function values. @@ -478,7 +477,7 @@ def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: for k in range(1, self._m - 1): F[:, k + 1 : k + 2] = 2.0 * z * F[:, k : k + 1] - 2.0 * k * F[:, k - 1 : k] - def Recurse(dark: npt.NDArray, d: uint, dCurr: uint = 0) -> npt.NDArray: + def Recurse(dark: JaxOrNumpyArray, d: uint, dCurr: uint = 0) -> JaxOrNumpyArray: """ Take derivative recursively. """ @@ -512,7 +511,7 @@ def __init__( self, x0: Number, xf: Number, - nC: npt.NDArray, + nC: JaxOrNumpyArray, m: uint, ) -> None: """ @@ -524,14 +523,14 @@ def __init__( Start of the problem domain. xf : Number End of the problem domain. - nC : npt.NDArray + nC : JaxOrNumpyArray Basis functions to be removed m : uint Number of basis functions. """ super().__init__(x0, xf, nC, m, -np.pi, np.pi) - def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: + def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray: """ Internal method used to calcualte the CP basis function values. @@ -599,7 +598,7 @@ def __init__( self, x0: Number, xf: Number, - nC: npt.NDArray, + nC: JaxOrNumpyArray, m: uint, ) -> None: """ @@ -611,7 +610,7 @@ def __init__( Start of the problem domain. xf : Number End of the problem domain. - nC : npt.NDArray + nC : JaxOrNumpyArray Basis functions to be removed m : uint Number of basis functions. @@ -627,7 +626,7 @@ def __init__( self._b = self._b.reshape((1, self._m)) @property - def w(self) -> npt.NDArray: + def w(self) -> JaxOrNumpyArray: """ Weights of the ELM @@ -639,7 +638,7 @@ def w(self) -> npt.NDArray: return self._w @property - def b(self) -> npt.NDArray: + def b(self) -> JaxOrNumpyArray: """ Biases of the ELM @@ -651,7 +650,7 @@ def b(self) -> npt.NDArray: return self._b @w.setter - def w(self, val: npt.NDArray) -> None: + def w(self, val: JaxOrNumpyArray) -> None: """ Weights of the ELM. @@ -670,7 +669,7 @@ def w(self, val: npt.NDArray) -> None: ) @b.setter - def b(self, val: npt.NDArray) -> None: + def b(self, val: JaxOrNumpyArray) -> None: """ Biases of the ELM. @@ -690,7 +689,7 @@ def b(self, val: npt.NDArray) -> None: class ELMReLU(ELM): - def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: + def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray: """ Internal method used to calcualte the ELMRelu basis function values. @@ -716,7 +715,7 @@ def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: class ELMSigmoid(ELM): - def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: + def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray: """ Internal method used to calcualte the ELMSigmoid basis function values. @@ -738,8 +737,8 @@ def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: f = lambda x: 1.0 / (1.0 + jnp.exp(-self._w * x - self._b)) def Recurse( - dark: Callable[[npt.NDArray], jnp.ndarray], d: uint, dCurr: uint = 0 - ) -> Callable[[npt.NDArray], jnp.ndarray]: + dark: Callable[[JaxOrNumpyArray], jnp.ndarray], d: uint, dCurr: uint = 0 + ) -> Callable[[JaxOrNumpyArray], jnp.ndarray]: """ Take derivative recursively. """ @@ -754,7 +753,7 @@ def Recurse( class ELMTanh(ELM): - def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: + def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray: """ Internal method used to calcualte the ELMTanh basis function values. @@ -776,8 +775,8 @@ def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: f = lambda x: jnp.tanh(self._w * x + self._b) def Recurse( - dark: Callable[[npt.NDArray], jnp.ndarray], d: uint, dCurr: uint = 0 - ) -> Callable[[npt.NDArray], jnp.ndarray]: + dark: Callable[[JaxOrNumpyArray], jnp.ndarray], d: uint, dCurr: uint = 0 + ) -> Callable[[JaxOrNumpyArray], jnp.ndarray]: """ Take derivative recursively. """ @@ -792,7 +791,7 @@ def Recurse( class ELMSin(ELM): - def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: + def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray: """ Internal method used to calcualte the ELMSin basis function values. @@ -814,8 +813,8 @@ def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: f = lambda x: jnp.sin(self._w * x + self._b) def Recurse( - dark: Callable[[npt.NDArray], jnp.ndarray], d: uint, dCurr: uint = 0 - ) -> Callable[[npt.NDArray], jnp.ndarray]: + dark: Callable[[JaxOrNumpyArray], jnp.ndarray], d: uint, dCurr: uint = 0 + ) -> Callable[[JaxOrNumpyArray], jnp.ndarray]: """ Take derivative recursively. """ @@ -830,7 +829,7 @@ def Recurse( class ELMSwish(ELM): - def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: + def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray: """ Internal method used to calcualte the ELMSwish basis function values. @@ -852,8 +851,8 @@ def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: f = lambda x: (self._w * x + self._b) / (1.0 + jnp.exp(-self._w * x - self._b)) def Recurse( - dark: Callable[[npt.NDArray], jnp.ndarray], d: uint, dCurr: uint = 0 - ) -> Callable[[npt.NDArray], jnp.ndarray]: + dark: Callable[[JaxOrNumpyArray], jnp.ndarray], d: uint, dCurr: uint = 0 + ) -> Callable[[JaxOrNumpyArray], jnp.ndarray]: """ Take derivative recursively. """ @@ -875,9 +874,9 @@ class nBasisFunc(BasisFunc): def __init__( self, - x0: npt.NDArray, - xf: npt.NDArray, - nC: npt.NDArray, + x0: JaxOrNumpyArray, + xf: JaxOrNumpyArray, + nC: JaxOrNumpyArray, m: uint, z0: Number = 0.0, zf: Number = 0.0, @@ -919,7 +918,7 @@ def __init__( self._numBasisFunc = self._NumBasisFunc(self._dim - 1, vec, full=False) self._numBasisFuncFull = self._NumBasisFunc(self._dim - 1, vec, full=True) - def _NumBasisFunc(self, dim: int, vec: npt.NDArray, n: int = 0, full: bool = False) -> int: + def _NumBasisFunc(self, dim: int, vec: JaxOrNumpyArray, n: int = 0, full: bool = False) -> int: """ Calculate the number of basis functions. @@ -960,14 +959,14 @@ def _NumBasisFunc(self, dim: int, vec: npt.NDArray, n: int = 0, full: bool = Fal return n @property - def c(self) -> npt.NDArray: + def c(self) -> JaxOrNumpyArray: """ Return the constants that map the problem domain to the basis function domain. Returns ------- - npt.NDArray + JaxOrNumpyArray The constants that map the problem domain to the basis function domain. """ @@ -1004,7 +1003,7 @@ def numBasisFuncFull(self) -> float: return self._numBasisFuncFull - def H(self, x: npt.NDArray, d: npt.NDArray, full: bool = False) -> npt.NDArray: + def H(self, x: JaxOrNumpyArray, d: JaxOrNumpyArray, full: bool = False) -> JaxOrNumpyArray: """ Returns the basis function matrix for the x with a derivative of order d. @@ -1041,7 +1040,7 @@ def H(self, x: npt.NDArray, d: npt.NDArray, full: bool = False) -> npt.NDArray: T[:, :, k] = self._Hint(z[k : k + 1, :].T, d[k]) * self._c[k] ** d[k] # Define functions for use in generating the CP sheet - def MultT(vec: npt.NDArray) -> npt.NDArray: + def MultT(vec: JaxOrNumpyArray) -> JaxOrNumpyArray: """ Creates basis functions for the multidimensional case by mulitplying the basis functions for the single dimensional cases together. @@ -1062,8 +1061,8 @@ def MultT(vec: npt.NDArray) -> npt.NDArray: return tout def Recurse( - dim: int, out: npt.NDArray, vec: npt.NDArray, n: int = 0, full: bool = False - ) -> Tuple[npt.NDArray, int]: + dim: int, out: JaxOrNumpyArray, vec: JaxOrNumpyArray, n: int = 0, full: bool = False + ) -> Tuple[JaxOrNumpyArray, int]: """ Creates basis functions for the multidimensional case given the basis functions for the single dimensional cases. @@ -1129,9 +1128,9 @@ class nCP(nBasisFunc, CP): def __init__( self, - x0: npt.NDArray, - xf: npt.NDArray, - nC: npt.NDArray, + x0: JaxOrNumpyArray, + xf: JaxOrNumpyArray, + nC: JaxOrNumpyArray, m: uint, ) -> None: """ @@ -1159,9 +1158,9 @@ class nLeP(nBasisFunc, LeP): def __init__( self, - x0: npt.NDArray, - xf: npt.NDArray, - nC: npt.NDArray, + x0: JaxOrNumpyArray, + xf: JaxOrNumpyArray, + nC: JaxOrNumpyArray, m: uint, ) -> None: """ @@ -1189,9 +1188,9 @@ class nFS(nBasisFunc, FS): def __init__( self, - x0: npt.NDArray, - xf: npt.NDArray, - nC: npt.NDArray, + x0: JaxOrNumpyArray, + xf: JaxOrNumpyArray, + nC: JaxOrNumpyArray, m: uint, ) -> None: """ @@ -1219,9 +1218,9 @@ class nELM(nBasisFunc): def __init__( self, - x0: npt.NDArray, - xf: npt.NDArray, - nC: npt.NDArray, + x0: JaxOrNumpyArray, + xf: JaxOrNumpyArray, + nC: JaxOrNumpyArray, m: uint, z0: Number = 0.0, zf: Number = 1.0, @@ -1273,7 +1272,7 @@ def __init__( self._b = self._b.reshape((1, self._m)) @property - def w(self) -> npt.NDArray: + def w(self) -> JaxOrNumpyArray: """ Weights of the nELM @@ -1285,7 +1284,7 @@ def w(self) -> npt.NDArray: return self._w @property - def b(self) -> npt.NDArray: + def b(self) -> JaxOrNumpyArray: """ Biases of the nELM @@ -1297,7 +1296,7 @@ def b(self) -> npt.NDArray: return self._b @w.setter - def w(self, val: npt.NDArray) -> None: + def w(self, val: JaxOrNumpyArray) -> None: """ Weights of the nELM. @@ -1316,7 +1315,7 @@ def w(self, val: npt.NDArray) -> None: ) @b.setter - def b(self, val: npt.NDArray) -> None: + def b(self, val: JaxOrNumpyArray) -> None: """ Biases of the nELM. @@ -1334,7 +1333,7 @@ def b(self, val: npt.NDArray) -> None: f"Input array of size {val.size} was received, but size {self._m} was expected." ) - def H(self, x: npt.NDArray, d: npt.NDArray, full: bool = False) -> npt.NDArray: + def H(self, x: JaxOrNumpyArray, d: JaxOrNumpyArray, full: bool = False) -> JaxOrNumpyArray: """ Returns the basis function matrix for the x with a derivative of order d. @@ -1370,7 +1369,7 @@ def H(self, x: npt.NDArray, d: npt.NDArray, full: bool = False) -> npt.NDArray: return F @abstractmethod - def _nHint(self, z: npt.NDArray, d: npt.NDArray) -> npt.NDArray: + def _nHint(self, z: JaxOrNumpyArray, d: JaxOrNumpyArray) -> JaxOrNumpyArray: """ Internal method used to calcualte the basis function value. @@ -1388,7 +1387,7 @@ def _nHint(self, z: npt.NDArray, d: npt.NDArray) -> npt.NDArray: """ pass - def _Hint(self, z: npt.NDArray, d: uint) -> npt.NDArray: + def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray: """ Dummy function, this should never be called! """ @@ -1400,7 +1399,7 @@ class nELMReLU(nELM): n-dimensional ELM ReLU basis functions. """ - def _nHint(self, z: npt.NDArray, d: npt.NDArray) -> npt.NDArray: + def _nHint(self, z: JaxOrNumpyArray, d: JaxOrNumpyArray) -> JaxOrNumpyArray: """ Internal method used to calcualte the basis function value. @@ -1443,7 +1442,7 @@ class nELMSin(nELM): n-dimensional ELM sin basis functions. """ - def _nHint(self, z: npt.NDArray, d: npt.NDArray) -> npt.NDArray: + def _nHint(self, z: JaxOrNumpyArray, d: JaxOrNumpyArray) -> JaxOrNumpyArray: """ Internal method used to calcualte the basis function value. @@ -1467,8 +1466,8 @@ def _nHint(self, z: npt.NDArray, d: npt.NDArray) -> npt.NDArray: z = jnp.split(z, z.shape[1], axis=1) def Recurse( - dark: Callable[[npt.NDArray], jnp.ndarray], d: uint, dim: uint, dCurr: uint = 0 - ) -> Callable[[npt.NDArray], jnp.ndarray]: + dark: Callable[[JaxOrNumpyArray], jnp.ndarray], d: uint, dim: uint, dCurr: uint = 0 + ) -> Callable[[JaxOrNumpyArray], jnp.ndarray]: if dCurr == d: return dark else: @@ -1490,7 +1489,7 @@ class nELMTanh(nELM): n-dimensional ELM tanh basis functions. """ - def _nHint(self, z: npt.NDArray, d: npt.NDArray) -> npt.NDArray: + def _nHint(self, z: JaxOrNumpyArray, d: JaxOrNumpyArray) -> JaxOrNumpyArray: """ Internal method used to calcualte the basis function value. @@ -1514,8 +1513,8 @@ def _nHint(self, z: npt.NDArray, d: npt.NDArray) -> npt.NDArray: z = jnp.split(z, z.shape[1], axis=1) def Recurse( - dark: Callable[[npt.NDArray], jnp.ndarray], d: uint, dim: uint, dCurr: uint = 0 - ) -> Callable[[npt.NDArray], jnp.ndarray]: + dark: Callable[[JaxOrNumpyArray], jnp.ndarray], d: uint, dim: uint, dCurr: uint = 0 + ) -> Callable[[JaxOrNumpyArray], jnp.ndarray]: if dCurr == d: return dark else: @@ -1537,7 +1536,7 @@ class nELMSigmoid(nELM): n-dimensional ELM sigmoid basis functions. """ - def _nHint(self, z: npt.NDArray, d: npt.NDArray) -> npt.NDArray: + def _nHint(self, z: JaxOrNumpyArray, d: JaxOrNumpyArray) -> JaxOrNumpyArray: """ Internal method used to calcualte the basis function value. @@ -1561,8 +1560,8 @@ def _nHint(self, z: npt.NDArray, d: npt.NDArray) -> npt.NDArray: z = jnp.split(z, z.shape[1], axis=1) def Recurse( - dark: Callable[[npt.NDArray], jnp.ndarray], d: uint, dim: uint, dCurr: uint = 0 - ) -> Callable[[npt.NDArray], jnp.ndarray]: + dark: Callable[[JaxOrNumpyArray], jnp.ndarray], d: uint, dim: uint, dCurr: uint = 0 + ) -> Callable[[JaxOrNumpyArray], jnp.ndarray]: if dCurr == d: return dark else: @@ -1584,7 +1583,7 @@ class nELMSwish(nELM): n-dimensional ELM swish basis functions. """ - def _nHint(self, z: npt.NDArray, d: npt.NDArray) -> npt.NDArray: + def _nHint(self, z: JaxOrNumpyArray, d: JaxOrNumpyArray) -> JaxOrNumpyArray: """ Internal method used to calcualte the basis function value. @@ -1610,8 +1609,8 @@ def f(*x): z = jnp.split(z, z.shape[1], axis=1) def Recurse( - dark: Callable[[npt.NDArray], jnp.ndarray], d: uint, dim: uint, dCurr: uint = 0 - ) -> Callable[[npt.NDArray], jnp.ndarray]: + dark: Callable[[JaxOrNumpyArray], jnp.ndarray], d: uint, dim: uint, dCurr: uint = 0 + ) -> Callable[[JaxOrNumpyArray], jnp.ndarray]: if dCurr == d: return dark else: diff --git a/src/tfc/utils/CeSolver.py b/src/tfc/utils/CeSolver.py index f48fe59..6b4c2ea 100644 --- a/src/tfc/utils/CeSolver.py +++ b/src/tfc/utils/CeSolver.py @@ -3,7 +3,7 @@ from sympy.core.function import AppliedUndef from sympy.printing.pycode import PythonCodePrinter from sympy.simplify.simplify import nc_simplify -from .tfc_types import ConstraintOperators, Exprs, Union, Any, Literal, ConstraintOperator +from .tfc_types import ConstraintOperators, Exprs, Any, Literal, ConstraintOperator from .TFCUtils import TFCPrint @@ -31,7 +31,7 @@ class CeSolver: in terms of sympy symbols and constants. For example, if we wanted to use the constant function x = 1 as a support function, then we would use sympy.re(1) in this iterable. - g : Union[AppliedUndef, Any] + g : AppliedUndef| Any This is the free function used in the constrained expression. For example, `g(x)`. @@ -62,7 +62,7 @@ class CeSolver: In the above code example, `ce` is the constrained expression that satisfies these constraints. """ - def __init__(self, C: ConstraintOperators, kappa: Exprs, s: Exprs, g: Union[AppliedUndef, Any]): + def __init__(self, C: ConstraintOperators, kappa: Exprs, s: Exprs, g: AppliedUndef| Any): self._C = C self._K = kappa self._s = s @@ -317,25 +317,25 @@ def C(self, C: ConstraintOperators) -> None: self._ce_stale = True @property - def g(self) -> Union[AppliedUndef, Any]: + def g(self) -> AppliedUndef| Any: """ Free function. Returns ------- - Union[AppliedUndef, Any] + AppliedUndef | Any Free function. """ return self._g @g.setter - def g(self, g: Union[AppliedUndef, Any]) -> None: + def g(self, g: AppliedUndef| Any) -> None: """ Set the free function. Parameters ---------- - g : Union[AppliedUndef, Any] + g : AppliedUndef | Any This is the free function used in the constrained expression. For example, `g(x)`. """ diff --git a/src/tfc/utils/MakePlot.py b/src/tfc/utils/MakePlot.py index 5f3e9ab..776298f 100644 --- a/src/tfc/utils/MakePlot.py +++ b/src/tfc/utils/MakePlot.py @@ -6,7 +6,7 @@ from .TFCUtils import TFCPrint from .tfc_types import StrArrayLike, Path, Literal, pint -from typing import Optional, Union, Generator, Callable +from typing import Optional, Generator, Callable TFCPrint() @@ -24,7 +24,7 @@ def __init__( titles: Optional[StrArrayLike] = None, twinYlabs: Optional[StrArrayLike] = None, zlabs: Optional[StrArrayLike] = None, - style: Optional[Union[str, dict, Path, list[str], list[dict], list[Path]]] = None, + style: Optional[str| dict| Path| list[str]| list[dict]| list[Path]] = None, ): """ This function initializes the plot/subplots based on the inputs provided. @@ -41,7 +41,7 @@ def __init__( The twin y-axes labels for the plots. Setting this forces twin axis y-axes. (Default value = None) zlabs: StrArrayLike, optional The z-axes labels of for the plots. Setting this forces subplots to be 3D. (Default value = None) - style : Union[str, dict, Path, list[str], list[dict], list[Path]] + style : str| dict| Path| list[str]| list[dict]| list[Path] Matplotlib style. (Default value = None) """ diff --git a/src/tfc/utils/MayaviMakePlot.py b/src/tfc/utils/MayaviMakePlot.py index 54e2976..ada2154 100644 --- a/src/tfc/utils/MayaviMakePlot.py +++ b/src/tfc/utils/MayaviMakePlot.py @@ -4,10 +4,10 @@ from mayavi import mlab from matplotlib import colors as mcolors from .tfc_types import Path, Ge, Le, Annotated, Literal -from typing import Optional, Any, Union, Generator, Callable +from typing import Optional, Any, Generator, Callable from .TFCUtils import TFCPrint -Color = Union[str, tuple[float, float, float, float], npt.NDArray[np.float64]] +Color = str| tuple[float| float| float| float]| npt.NDArray[np.float64] TFCPrint() diff --git a/src/tfc/utils/TFCUtils.py b/src/tfc/utils/TFCUtils.py index 71ff568..8b56a61 100644 --- a/src/tfc/utils/TFCUtils.py +++ b/src/tfc/utils/TFCUtils.py @@ -21,16 +21,15 @@ from jax.core import get_aval, eval_jaxpr from jax.interpreters.partial_eval import trace_to_jaxpr_nounits, PartialVal from jax.experimental import io_callback -from typing import Any, Callable, Optional, Union, cast +from typing import Any, Callable, Optional, cast, Union, overload from .tfc_types import uint, Literal, TypedDict, Path from jaxtyping import PyTree -from typing import cast # Types that can be added to a TFCDict -TFCDictAddable = Union[np.ndarray, dict[Any, Any], "TFCDict"] +TFCDictAddable = Union[np.ndarray , dict[Any, Any] , "TFCDict"] # Types that can be added to a TFCDictRobust -TFCDictRobustAddable = Union[np.ndarray, dict[Any, Any], "TFCDictRobust"] +TFCDictRobustAddable = Union[np.ndarray , dict[Any, Any] , "TFCDictRobust"] class TFCPrint: @@ -742,6 +741,30 @@ def __sub__(self, o: TFCDictRobustAddable) -> "TFCDictRobust": ) +@overload +def LS( + zXi: PyTree, + res: Callable, + *args: Any, + constant_arg_nums: list[int] = [], + J: Optional[Callable[..., np.ndarray]] = None, + method: Literal["pinv", "lstsq"] = "pinv", + timer: Literal[False] = False, + timerType: str = "process_time", + holomorphic: bool = False, +) -> PyTree: ... +@overload +def LS( + zXi: PyTree, + res: Callable, + *args: Any, + constant_arg_nums: list[int] = [], + J: Optional[Callable[..., np.ndarray]] = None, + method: Literal["pinv", "lstsq"] = "pinv", + timer: Literal[True] = True, + timerType: str = "process_time", + holomorphic: bool = False, + ) -> tuple[PyTree, float]: ... def LS( zXi: PyTree, res: Callable, @@ -752,7 +775,7 @@ def LS( timer: bool = False, timerType: str = "process_time", holomorphic: bool = False, -) -> Union[PyTree, tuple[PyTree, float]]: +) -> PyTree | tuple[PyTree, float]: """ JITed least squares. This function takes in an initial guess of zeros, zXi, and a residual function, res, and @@ -976,7 +999,7 @@ def J(xi, *args): self._compiled = False - def run(self, zXi: PyTree, *args: Any) -> Union[PyTree, tuple[PyTree, float]]: + def run(self, zXi: PyTree, *args: Any) -> PyTree | tuple[PyTree, float]: """ Runs the JIT-ed least-squares function and times it if desired. @@ -1026,6 +1049,42 @@ def nlls_id_print(it: int, x, end: str = "\n"): print("Iteration: {0}\tmax(abs(res)): {1}".format(it, x), end=end) +@overload +def NLLS( + xiInit: PyTree, + res: Callable, + *args: Any, + constant_arg_nums: list[int] = [], + J: Optional[Callable[..., np.ndarray]] = None, + cond: Optional[Callable[[PyTree], bool]] = None, + body: Optional[Callable[[PyTree], PyTree]] = None, + tol: float = 1e-13, + maxIter: uint = 50, + method: Literal["pinv", "lstsq"] = "pinv", + timer: Literal[False] = False, + printOut: bool = False, + printOutEnd: str = "\n", + timerType: str = "process_time", + holomorphic: bool = False, + ) -> tuple[PyTree, int]: ... +@overload +def NLLS( + xiInit: PyTree, + res: Callable, + *args: Any, + constant_arg_nums: list[int] = [], + J: Optional[Callable[..., np.ndarray]] = None, + cond: Optional[Callable[[PyTree], bool]] = None, + body: Optional[Callable[[PyTree], PyTree]] = None, + tol: float = 1e-13, + maxIter: uint = 50, + method: Literal["pinv", "lstsq"] = "pinv", + timer: Literal[True] = True, + printOut: bool = False, + printOutEnd: str = "\n", + timerType: str = "process_time", + holomorphic: bool = False, + ) -> tuple[PyTree, int, float]: ... def NLLS( xiInit: PyTree, res: Callable, @@ -1042,7 +1101,7 @@ def NLLS( printOutEnd: str = "\n", timerType: str = "process_time", holomorphic: bool = False, -) -> Union[tuple[PyTree, int], tuple[PyTree, int, float]]: +) -> tuple[PyTree, int] | tuple[PyTree, int, float]: """ JIT-ed non-linear least squares. This function takes in an initial guess, xiInit (initial values of xi), and a residual function, res, and @@ -1203,7 +1262,7 @@ def body(val): nlls = jit(lambda val: lax.while_loop(cond, body, val)) if dictFlag: - dxi = np.ones_like(cast(Union[TFCDict, TFCDictRobust], xiInit).toArray()) + dxi = np.ones_like(cast(TFCDict | TFCDictRobust, xiInit).toArray()) else: dxi = np.ones_like(xiInit) @@ -1404,7 +1463,7 @@ def body(val): def run( self, xiInit: PyTree, *args: Any - ) -> Union[tuple[PyTree, int], tuple[PyTree, int, float]]: + ) -> tuple[PyTree, int] | tuple[PyTree, int, float]: """Runs the JIT-ed nonlinear least-squares function and times it if desired. Parameters @@ -1429,7 +1488,7 @@ def run( """ if self._dictFlag: - dxi = np.ones_like(cast(Union[TFCDict, TFCDictRobust], xiInit).toArray()) + dxi = np.ones_like(cast(TFCDict | TFCDictRobust, xiInit).toArray()) else: dxi = np.ones_like(xiInit) diff --git a/src/tfc/utils/tfc_types.py b/src/tfc/utils/tfc_types.py index 1fc0842..0d2d0ec 100644 --- a/src/tfc/utils/tfc_types.py +++ b/src/tfc/utils/tfc_types.py @@ -1,5 +1,5 @@ import sys -from typing import Union, Any, Callable +from typing import Any, Callable import numpy as np import numpy.typing as npt from jax import Array @@ -11,7 +11,7 @@ from annotated_types import Gt, Ge, Lt, Le # Path -# Path = Union[str, os.PathLike] +# Path = str | os.PathLike Path = str # Integer > 0 @@ -21,41 +21,34 @@ uint = Annotated[int, Ge(0)] # General number type -Number = Union[int, float, complex] +Number = int| float| complex -from numpy._typing._array_like import _ArrayLikeStr_co, _ArrayLikeInt_co +from numpy._typing._array_like import _ArrayLikeStr_co # Array-like of strings StrArrayLike = _ArrayLikeStr_co # Array-like of integers -IntArrayLike = _ArrayLikeInt_co +IntArrayLike = Annotated[npt.ArrayLike, np.int32] # List or array like -NumberListOrArray = Union[tuple[Number, ...], list[Number], npt.NDArray[Any], Array] +NumberListOrArray = tuple[Number, ...]| list[Number]| npt.NDArray[Any]| Array # List or array of integers -IntListOrArray = Union[ - tuple[int, ...], - list[int], - npt.NDArray[np.int32], - npt.NDArray[np.int64], - npt.NDArray[np.int16], - npt.NDArray[np.int8], -] +IntListOrArray = IntArrayLike # JAX array or numpy array -JaxOrNumpyArray = Union[npt.NDArray, Array] +JaxOrNumpyArray = npt.NDArray | Array # Tuple or list of array -TupleOrListOfArray = Union[tuple[JaxOrNumpyArray, ...], list[JaxOrNumpyArray]] -TupleOrListOfNumpyArray = Union[tuple[npt.NDArray, ...], list[npt.NDArray]] +TupleOrListOfArray = tuple[JaxOrNumpyArray, ...] | list[JaxOrNumpyArray] +TupleOrListOfNumpyArray = tuple[npt.NDArray, ...] | list[npt.NDArray] # Sympy constraint operator # Adding in Any here since sympy types are a bit funky at the moment -ConstraintOperator = Callable[[Union[AppliedUndef, Expr, Any]], Union[AppliedUndef, Any]] -ConstraintOperators = Union[list[ConstraintOperator], tuple[ConstraintOperator, ...]] +ConstraintOperator = Callable[[AppliedUndef| Expr| Any], AppliedUndef| Any] +ConstraintOperators = list[ConstraintOperator]| tuple[ConstraintOperator, ...] # List or tuple of sympy expressions # Adding in Any here since sympy types are a bit funky at the moment -Exprs = Union[list[Union[Expr, Any]], tuple[Union[Expr, Any], ...]] +Exprs = list[Expr| Any] | tuple[Expr| Any, ...] From 501a39cfa860c53f894a65e34835ca28fb56b19b Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 3 Aug 2025 18:57:03 -0700 Subject: [PATCH 44/45] Formatting --- src/tfc/utils/CeSolver.py | 6 +++--- src/tfc/utils/MakePlot.py | 2 +- src/tfc/utils/MayaviMakePlot.py | 2 +- src/tfc/utils/TFCUtils.py | 14 ++++++-------- src/tfc/utils/tfc_types.py | 10 +++++----- 5 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/tfc/utils/CeSolver.py b/src/tfc/utils/CeSolver.py index 6b4c2ea..015c523 100644 --- a/src/tfc/utils/CeSolver.py +++ b/src/tfc/utils/CeSolver.py @@ -62,7 +62,7 @@ class CeSolver: In the above code example, `ce` is the constrained expression that satisfies these constraints. """ - def __init__(self, C: ConstraintOperators, kappa: Exprs, s: Exprs, g: AppliedUndef| Any): + def __init__(self, C: ConstraintOperators, kappa: Exprs, s: Exprs, g: AppliedUndef | Any): self._C = C self._K = kappa self._s = s @@ -317,7 +317,7 @@ def C(self, C: ConstraintOperators) -> None: self._ce_stale = True @property - def g(self) -> AppliedUndef| Any: + def g(self) -> AppliedUndef | Any: """ Free function. @@ -329,7 +329,7 @@ def g(self) -> AppliedUndef| Any: return self._g @g.setter - def g(self, g: AppliedUndef| Any) -> None: + def g(self, g: AppliedUndef | Any) -> None: """ Set the free function. diff --git a/src/tfc/utils/MakePlot.py b/src/tfc/utils/MakePlot.py index 776298f..2b646d5 100644 --- a/src/tfc/utils/MakePlot.py +++ b/src/tfc/utils/MakePlot.py @@ -24,7 +24,7 @@ def __init__( titles: Optional[StrArrayLike] = None, twinYlabs: Optional[StrArrayLike] = None, zlabs: Optional[StrArrayLike] = None, - style: Optional[str| dict| Path| list[str]| list[dict]| list[Path]] = None, + style: Optional[str | dict | Path | list[str] | list[dict] | list[Path]] = None, ): """ This function initializes the plot/subplots based on the inputs provided. diff --git a/src/tfc/utils/MayaviMakePlot.py b/src/tfc/utils/MayaviMakePlot.py index ada2154..e965606 100644 --- a/src/tfc/utils/MayaviMakePlot.py +++ b/src/tfc/utils/MayaviMakePlot.py @@ -7,7 +7,7 @@ from typing import Optional, Any, Generator, Callable from .TFCUtils import TFCPrint -Color = str| tuple[float| float| float| float]| npt.NDArray[np.float64] +Color = str | tuple[float | float | float | float] | npt.NDArray[np.float64] TFCPrint() diff --git a/src/tfc/utils/TFCUtils.py b/src/tfc/utils/TFCUtils.py index 8b56a61..81eb855 100644 --- a/src/tfc/utils/TFCUtils.py +++ b/src/tfc/utils/TFCUtils.py @@ -26,10 +26,10 @@ from jaxtyping import PyTree # Types that can be added to a TFCDict -TFCDictAddable = Union[np.ndarray , dict[Any, Any] , "TFCDict"] +TFCDictAddable = Union[np.ndarray, dict[Any, Any], "TFCDict"] # Types that can be added to a TFCDictRobust -TFCDictRobustAddable = Union[np.ndarray , dict[Any, Any] , "TFCDictRobust"] +TFCDictRobustAddable = Union[np.ndarray, dict[Any, Any], "TFCDictRobust"] class TFCPrint: @@ -764,7 +764,7 @@ def LS( timer: Literal[True] = True, timerType: str = "process_time", holomorphic: bool = False, - ) -> tuple[PyTree, float]: ... +) -> tuple[PyTree, float]: ... def LS( zXi: PyTree, res: Callable, @@ -1066,7 +1066,7 @@ def NLLS( printOutEnd: str = "\n", timerType: str = "process_time", holomorphic: bool = False, - ) -> tuple[PyTree, int]: ... +) -> tuple[PyTree, int]: ... @overload def NLLS( xiInit: PyTree, @@ -1084,7 +1084,7 @@ def NLLS( printOutEnd: str = "\n", timerType: str = "process_time", holomorphic: bool = False, - ) -> tuple[PyTree, int, float]: ... +) -> tuple[PyTree, int, float]: ... def NLLS( xiInit: PyTree, res: Callable, @@ -1461,9 +1461,7 @@ def body(val): self._nlls = jit(lambda val: lax.while_loop(cond, body, val)) self._compiled = False - def run( - self, xiInit: PyTree, *args: Any - ) -> tuple[PyTree, int] | tuple[PyTree, int, float]: + def run(self, xiInit: PyTree, *args: Any) -> tuple[PyTree, int] | tuple[PyTree, int, float]: """Runs the JIT-ed nonlinear least-squares function and times it if desired. Parameters diff --git a/src/tfc/utils/tfc_types.py b/src/tfc/utils/tfc_types.py index 0d2d0ec..d86a4a7 100644 --- a/src/tfc/utils/tfc_types.py +++ b/src/tfc/utils/tfc_types.py @@ -21,7 +21,7 @@ uint = Annotated[int, Ge(0)] # General number type -Number = int| float| complex +Number = int | float | complex from numpy._typing._array_like import _ArrayLikeStr_co @@ -32,7 +32,7 @@ IntArrayLike = Annotated[npt.ArrayLike, np.int32] # List or array like -NumberListOrArray = tuple[Number, ...]| list[Number]| npt.NDArray[Any]| Array +NumberListOrArray = tuple[Number, ...] | list[Number] | npt.NDArray[Any] | Array # List or array of integers IntListOrArray = IntArrayLike @@ -46,9 +46,9 @@ # Sympy constraint operator # Adding in Any here since sympy types are a bit funky at the moment -ConstraintOperator = Callable[[AppliedUndef| Expr| Any], AppliedUndef| Any] -ConstraintOperators = list[ConstraintOperator]| tuple[ConstraintOperator, ...] +ConstraintOperator = Callable[[AppliedUndef | Expr | Any], AppliedUndef | Any] +ConstraintOperators = list[ConstraintOperator] | tuple[ConstraintOperator, ...] # List or tuple of sympy expressions # Adding in Any here since sympy types are a bit funky at the moment -Exprs = list[Expr| Any] | tuple[Expr| Any, ...] +Exprs = list[Expr | Any] | tuple[Expr | Any, ...] From b1a726fac32faea7ae54d6979295a7f4d71f68fa Mon Sep 17 00:00:00 2001 From: leakec Date: Sun, 3 Aug 2025 18:57:08 -0700 Subject: [PATCH 45/45] Getting rid of unecessary imports. --- tests/test_BF.py | 1 - tests/test_step.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/tests/test_BF.py b/tests/test_BF.py index 41a4c87..2e4f985 100644 --- a/tests/test_BF.py +++ b/tests/test_BF.py @@ -13,7 +13,6 @@ ELMSin, ELMSwish, ) -from tfc.utils import egrad def test_CP(): diff --git a/tests/test_step.py b/tests/test_step.py index 1578681..42be0ca 100644 --- a/tests/test_step.py +++ b/tests/test_step.py @@ -1,7 +1,5 @@ import jax.numpy as np -from tfc import utfc as TFC -from tfc import mtfc as nTFC from tfc.utils import egrad, step def test_step():