OnnxRuntime
onnxruntime_cxx_api.h
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4// Summary: The Ort C++ API is a header only wrapper around the Ort C API.
5//
6// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
7// and automatically releasing resources in the destructors.
8//
9// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
10// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};).
11//
12// Only move assignment between objects is allowed, there are no copy constructors. Some objects have explicit 'Clone'
13// methods for this purpose.
14
15#pragma once
16#include "onnxruntime_c_api.h"
17#include <cstddef>
18#include <array>
19#include <memory>
20#include <stdexcept>
21#include <string>
22#include <vector>
23#include <unordered_map>
24#include <utility>
25#include <type_traits>
26
27#ifdef ORT_NO_EXCEPTIONS
28#include <iostream>
29#endif
30
34namespace Ort {
35
// Exception type of the C++ API: all C++ methods that can fail throw an exception of this type.
// It carries a message (returned by what()) together with the OrtErrorCode given at construction.
40struct Exception : std::exception {
// Moves `string` into the stored message and keeps `code` as-is.
41 Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
42
// Returns the error code passed to the constructor.
43 OrtErrorCode GetOrtErrorCode() const { return code_; }
// Returns the stored message; the pointer refers to message_'s own buffer.
44 const char* what() const noexcept override { return message_.c_str(); }
45
46 private:
47 std::string message_;
48 OrtErrorCode code_;
49};
50
// ORT_CXX_API_THROW(string, code) is the single error-reporting hook used by this header.
51#ifdef ORT_NO_EXCEPTIONS
52// The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
53// NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
// Default handler when ORT_NO_EXCEPTIONS is defined: builds an Ort::Exception only to obtain its
// message, writes that message to std::cerr and then calls abort().
54#ifndef ORT_CXX_API_THROW
55#define ORT_CXX_API_THROW(string, code) \
56 do { \
57 std::cerr << Ort::Exception(string, code) \
58 .what() \
59 << std::endl; \
60 abort(); \
61 } while (false)
62#endif
63#else
// Exceptions enabled: throws an Ort::Exception constructed from the message and the error code.
64#define ORT_CXX_API_THROW(string, code) \
65 throw Ort::Exception(string, code)
66#endif
67
68// Used internally by the C++ API. This class template holds the global variable that points to the OrtApi.
69// It is a template so that the global variable can be defined in a header while remaining transparent to the users of the API.
70template <typename T>
71struct Global {
// Pointer to the OrtApi in use; GetApi() dereferences Global<void>::api_.
72 static const OrtApi* api_;
73};
74
75// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
76template <typename T>
77#ifdef ORT_API_MANUAL_INIT
78const OrtApi* Global<T>::api_{};
79inline void InitApi() { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
80#else
81#if defined(_MSC_VER) && !defined(__clang__)
82#pragma warning(push)
83// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
84// Please define ORT_API_MANUAL_INIT if it concerns you.
85#pragma warning(disable : 26426)
86#endif
88#if defined(_MSC_VER) && !defined(__clang__)
89#pragma warning(pop)
90#endif
91#endif
92
// Returns a reference to the OrtApi interface in use, i.e. the object Global<void>::api_ points to.
// NOTE(review): the pointer is dereferenced without a null check; when ORT_API_MANUAL_INIT is defined,
// InitApi() has to be called before this function is used.
94inline const OrtApi& GetApi() { return *Global<void>::api_; }
95
// C++ wrapper for OrtApi::GetAvailableProviders(); the result is returned as a vector of strings.
97std::vector<std::string> GetAvailableProviders();
98
99// Used internally by the C++ API. For a given NAME this macro generates an inline OrtRelease(Ort<NAME>*) overload that forwards to GetApi().Release<NAME>(), so every Ort* type can be released through the same overloaded function name.
100// This can't be done in the C API since C doesn't have function overloading.
101#define ORT_DEFINE_RELEASE(NAME) \
102 inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
103
// One OrtRelease overload per wrapped C type.
104ORT_DEFINE_RELEASE(Allocator);
105ORT_DEFINE_RELEASE(MemoryInfo);
106ORT_DEFINE_RELEASE(CustomOpDomain);
107ORT_DEFINE_RELEASE(Env);
108ORT_DEFINE_RELEASE(RunOptions);
109ORT_DEFINE_RELEASE(Session);
110ORT_DEFINE_RELEASE(SessionOptions);
111ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
112ORT_DEFINE_RELEASE(SequenceTypeInfo);
113ORT_DEFINE_RELEASE(MapTypeInfo);
114ORT_DEFINE_RELEASE(TypeInfo);
115ORT_DEFINE_RELEASE(Value);
116ORT_DEFINE_RELEASE(ModelMetadata);
117ORT_DEFINE_RELEASE(ThreadingOptions);
118ORT_DEFINE_RELEASE(IoBinding);
119ORT_DEFINE_RELEASE(ArenaCfg);
120
// The helper macro is only needed for the list above, so it is removed again.
121#undef ORT_DEFINE_RELEASE
122
// IEEE 754 half-precision floating point data type, held in a single uint16_t member.
// The struct only stores, converts and compares that 16-bit value; no arithmetic is defined here.
162struct Float16_t {
163 uint16_t value;
// Default construction sets `value` to 0.
164 constexpr Float16_t() noexcept : value(0) {}
// Implicit construction from a uint16_t: `v` is stored unchanged.
165 constexpr Float16_t(uint16_t v) noexcept : value(v) {}
// Implicit conversion back to the stored uint16_t.
166 constexpr operator uint16_t() const noexcept { return value; }
// Equality and inequality compare the stored `value` members directly.
167 constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; };
168 constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; };
169};
170
171static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
172
182 uint16_t value;
183 constexpr BFloat16_t() noexcept : value(0) {}
184 constexpr BFloat16_t(uint16_t v) noexcept : value(v) {}
185 constexpr operator uint16_t() const noexcept { return value; }
186 constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; };
187 constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; };
188};
189
190static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
191
// Used internally by the C++ API. C++ wrapper types inherit from this.
// Base stores one pointer to the underlying C object of type T. It can be moved but not copied;
// the move operations are protected, so they are reached through the derived wrapper types.
199template <typename T>
200struct Base {
201 using contained_type = T;
202
// Creates an empty wrapper (p_ is nullptr).
203 Base() = default;
// Wraps `p`. A null `p` is reported through ORT_CXX_API_THROW as "Allocation failure" with ORT_FAIL.
204 Base(T* p) : p_{p} {
205 if (!p)
206 ORT_CXX_API_THROW("Allocation failure", ORT_FAIL);
207 }
209
// Implicit conversions to the raw C pointer; ownership stays with this object.
210 operator T*() { return p_; }
211 operator const T*() const { return p_; }
212
// Releases ownership of the contained pointer: returns it and leaves this object empty.
214 T* release() {
215 T* p = p_;
216 p_ = nullptr;
217 return p;
218 }
219
220 protected:
221 Base(const Base&) = delete;
222 Base& operator=(const Base&) = delete;
// Move construction steals the pointer and leaves `v` empty.
223 Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
// Move assignment: releases the currently held object through OrtRelease and takes over v's pointer.
// Self-assignment is a no-op. Without the guard, the held object would be released first and the
// now-dangling pointer stored back into p_, leading to a later use-after-free / double release.
224 void operator=(Base&& v) noexcept {
 if (this != &v) {
225 OrtRelease(p_);
226 p_ = v.release();
 }
227 }
228
// The wrapped C object; nullptr when the wrapper is empty.
229 T* p_{};
230
231 template <typename>
232 friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error
233};
234
// Wraps a pointer in the wrapper type T without keeping ownership of it.
// T has to provide `contained_type`, a constructor taking `contained_type*`, a `p_` member and
// release(), as the wrappers derived from Base do.
239template <typename T>
240struct Unowned : T {
// Constructs the T base from `p` via T's pointer constructor.
241 Unowned(typename T::contained_type* p) : T{p} {}
// Copies the pointer out of `v` without clearing it; each object drops its copy in ~Unowned().
242 Unowned(Unowned&& v) : T{v.p_} {}
// Calls release() so the stored pointer is cleared before the base-class destructors run
// (presumably so that the base class does not release an object this wrapper never owned - TODO confirm).
243 ~Unowned() { this->release(); }
244};
245
246struct AllocatorWithDefaultOptions;
247struct MemoryInfo;
248struct Env;
249struct TypeInfo;
250struct Value;
251struct ModelMetadata;
252
253namespace detail {
254// Light functor to release memory with OrtAllocator
257 explicit AllocatedFree(OrtAllocator* allocator)
258 : allocator_(allocator) {}
259 void operator()(void* ptr) const {
260 if (ptr) allocator_->Free(allocator_, ptr);
261 }
262};
263} // namespace detail
264
269using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
270
// The Env (Environment): C++ wrapper around an OrtEnv pointer.
276struct Env : Base<OrtEnv> {
// Create an empty Env object; it must be assigned a valid one to be used.
277 explicit Env(std::nullptr_t) {}
278
// Wraps OrtApi::CreateEnv.
280 Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
281
// Wraps OrtApi::CreateEnvWithCustomLogger.
283 Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
284
// Wraps OrtApi::CreateEnvWithGlobalThreadPools.
286 Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
287
// Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools.
289 Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
290 OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
291
// C interop helper: wraps an existing OrtEnv pointer by passing it to the Base<OrtEnv> constructor.
293 explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
294
297
// Wraps OrtApi::CreateAndRegisterAllocator; returns Env& (presumably *this, so calls can be chained - TODO confirm).
298 Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg);
299};
300
// Custom Op Domain: C++ wrapper around an OrtCustomOpDomain pointer.
304struct CustomOpDomain : Base<OrtCustomOpDomain> {
// Create an empty CustomOpDomain object; it must be assigned a valid one to be used.
305 explicit CustomOpDomain(std::nullptr_t) {}
306
// Wraps OrtApi::CreateCustomOpDomain.
308 explicit CustomOpDomain(const char* domain);
309
// Wraps CustomOpDomain_Add.
310 void Add(OrtCustomOp* op);
311};
312
313struct RunOptions : Base<OrtRunOptions> {
314 explicit RunOptions(std::nullptr_t) {}
316
319
322
323 RunOptions& SetRunTag(const char* run_tag);
324 const char* GetRunTag() const;
325
326 RunOptions& AddConfigEntry(const char* config_key, const char* config_value);
327
334
340};
341
346struct SessionOptions : Base<OrtSessionOptions> {
347 explicit SessionOptions(std::nullptr_t) {}
350
352
353 SessionOptions& SetIntraOpNumThreads(int intra_op_num_threads);
354 SessionOptions& SetInterOpNumThreads(int inter_op_num_threads);
356
359
360 SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file);
361
362 SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
364
366
369
371
372 SessionOptions& SetLogId(const char* logid);
374
376
378
379 SessionOptions& AddConfigEntry(const char* config_key, const char* config_value);
380 SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
381 SessionOptions& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values);
382
391 SessionOptions& AppendExecutionProvider(const std::string& provider_name,
392 const std::unordered_map<std::string, std::string>& provider_options = {});
393
395 SessionOptions& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
397};
398
402struct ModelMetadata : Base<OrtModelMetadata> {
403 explicit ModelMetadata(std::nullptr_t) {}
405
411 char* GetProducerName(OrtAllocator* allocator) const;
412
420
426 char* GetGraphName(OrtAllocator* allocator) const;
427
435
441 char* GetDomain(OrtAllocator* allocator) const;
442
450
456 char* GetDescription(OrtAllocator* allocator) const;
457
465
471 char* GetGraphDescription(OrtAllocator* allocator) const;
472
480
486 char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const;
487
488 std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const;
489
495 char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const;
496
507
508 int64_t GetVersion() const;
509};
510
514struct Session : Base<OrtSession> {
515 explicit Session(std::nullptr_t) {}
516 Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
517 Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container);
518 Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options);
519 Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
520 OrtPrepackedWeightsContainer* prepacked_weights_container);
521
539 std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
540 const char* const* output_names, size_t output_count);
541
545 void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
546 const char* const* output_names, Value* output_values, size_t output_count);
547
548 void Run(const RunOptions& run_options, const struct IoBinding&);
549
550 size_t GetInputCount() const;
551 size_t GetOutputCount() const;
553
559 char* GetInputName(size_t index, OrtAllocator* allocator) const;
560
569
575 char* GetOutputName(size_t index, OrtAllocator* allocator) const;
576
585
591 char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const;
592
601
607 char* EndProfiling(OrtAllocator* allocator) const;
608
616 uint64_t GetProfilingStartTimeNs() const;
618
619 TypeInfo GetInputTypeInfo(size_t index) const;
620 TypeInfo GetOutputTypeInfo(size_t index) const;
622};
623
627struct TensorTypeAndShapeInfo : Base<OrtTensorTypeAndShapeInfo> {
628 explicit TensorTypeAndShapeInfo(std::nullptr_t) {}
630
632 size_t GetElementCount() const;
633
634 size_t GetDimensionsCount() const;
635 void GetDimensions(int64_t* values, size_t values_count) const;
636 void GetSymbolicDimensions(const char** values, size_t values_count) const;
637
638 std::vector<int64_t> GetShape() const;
639};
640
644struct SequenceTypeInfo : Base<OrtSequenceTypeInfo> {
645 explicit SequenceTypeInfo(std::nullptr_t) {}
647
649};
650
654struct MapTypeInfo : Base<OrtMapTypeInfo> {
655 explicit MapTypeInfo(std::nullptr_t) {}
657
660};
661
// C++ wrapper around an OrtTypeInfo pointer.
662struct TypeInfo : Base<OrtTypeInfo> {
// Creates an empty TypeInfo (the Base default constructor leaves the pointer null).
663 explicit TypeInfo(std::nullptr_t) {}
// Wraps an existing OrtTypeInfo pointer by passing it to the Base<OrtTypeInfo> constructor.
664 explicit TypeInfo(OrtTypeInfo* p) : Base<OrtTypeInfo>{p} {}
665
669
671};
672
673struct Value : Base<OrtValue> {
674 // This structure is used to feed sparse tensor values
675 // information for use with FillSparseTensor<Format>() API
676 // if the data type for the sparse tensor values is numeric
677 // use data.p_data, otherwise, use data.str pointer to feed
678 // values. data.str is an array of const char* that are zero terminated.
679 // number of strings in the array must match shape size.
680 // For fully sparse tensors use shape {0} and set p_data/str
681 // to nullptr.
683 const int64_t* values_shape;
685 union {
686 const void* p_data;
687 const char** str;
689 };
690
691 // Provides a way to pass shape in a single
692 // argument
693 struct Shape {
694 const int64_t* shape;
695 size_t shape_len;
696 };
697
706 template <typename T>
707 static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
708
717 static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
719
720#if !defined(DISABLE_SPARSE_TENSORS)
731 template <typename T>
732 static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
733 const Shape& values_shape);
734
751 static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
752 const Shape& values_shape, ONNXTensorElementDataType type);
753
762 void UseCooIndices(int64_t* indices_data, size_t indices_num);
763
774 void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
775
784 void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
785
786#endif // !defined(DISABLE_SPARSE_TENSORS)
787
794 template <typename T>
795 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
796
803 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
804
805#if !defined(DISABLE_SPARSE_TENSORS)
815 template <typename T>
816 static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
817
829 static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
830
840 void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
841 const int64_t* indices_data, size_t indices_num);
842
854 void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
855 const OrtSparseValuesParam& values,
856 const int64_t* inner_indices_data, size_t inner_indices_num,
857 const int64_t* outer_indices_data, size_t outer_indices_num);
858
869 const OrtSparseValuesParam& values,
870 const Shape& indices_shape,
871 const int32_t* indices_data);
872
880
887
896
906 template <typename T>
907 const T* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
908
909#endif // !defined(DISABLE_SPARSE_TENSORS)
910
911 static Value CreateMap(Value& keys, Value& values);
912 static Value CreateSequence(std::vector<Value>& values);
913
914 template <typename T>
915 static Value CreateOpaque(const char* domain, const char* type_name, const T&);
916
917 template <typename T>
918 void GetOpaqueData(const char* domain, const char* type_name, T&) const;
919
920 explicit Value(std::nullptr_t) {}
921 explicit Value(OrtValue* p) : Base<OrtValue>{p} {}
922 Value(Value&&) = default;
923 Value& operator=(Value&&) = default;
924
925 bool IsTensor() const;
926 bool HasValue() const;
927
928#if !defined(DISABLE_SPARSE_TENSORS)
933 bool IsSparseTensor() const;
934#endif
935
936 size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
937 Value GetValue(int index, OrtAllocator* allocator) const;
938
946
961 void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
962
963 template <typename T>
965
966 template <typename T>
967 const T* GetTensorData() const;
968
969#if !defined(DISABLE_SPARSE_TENSORS)
978 template <typename T>
979 const T* GetSparseTensorValues() const;
980#endif
981
982 template <typename T>
983 T& At(const std::vector<int64_t>& location);
984
992
1000
1007 size_t GetStringTensorElementLength(size_t element_index) const;
1008
1017 void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
1018
1019 void FillStringTensor(const char* const* s, size_t s_len);
1020 void FillStringTensorElement(const char* s, size_t index);
1021};
1022
1023// Represents native memory allocation
1025 MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
1030 MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
1031
1032 void* get() { return p_; }
1033 size_t size() const { return size_; }
1034
1035 private:
1036 OrtAllocator* allocator_;
1037 void* p_;
1038 size_t size_;
1039};
1040
1043
1044 operator OrtAllocator*() { return p_; }
1045 operator const OrtAllocator*() const { return p_; }
1046
1047 void* Alloc(size_t size);
1048 // The return value will own the allocation
1050 void Free(void* p);
1051
1052 const OrtMemoryInfo* GetInfo() const;
1053
1054 private:
1055 OrtAllocator* p_{};
1056};
1057
1058struct MemoryInfo : Base<OrtMemoryInfo> {
1060
1061 explicit MemoryInfo(std::nullptr_t) {}
1063 MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
1064
1065 std::string GetAllocatorName() const;
1067 int GetDeviceId() const;
1069
1070 bool operator==(const MemoryInfo& o) const;
1071};
1072
1073struct Allocator : public Base<OrtAllocator> {
1074 Allocator(const Session& session, const MemoryInfo&);
1075
1076 void* Alloc(size_t size) const;
1077 // The return value will own the allocation
1079 void Free(void* p) const;
1081};
1082
1083struct IoBinding : public Base<OrtIoBinding> {
1084 explicit IoBinding(Session& session);
1085 void BindInput(const char* name, const Value&);
1086 void BindOutput(const char* name, const Value&);
1087 void BindOutput(const char* name, const MemoryInfo&);
1088 std::vector<std::string> GetOutputNames() const;
1089 std::vector<std::string> GetOutputNames(Allocator&) const;
1090 std::vector<Value> GetOutputValues() const;
1091 std::vector<Value> GetOutputValues(Allocator&) const;
1096
1097 private:
1098 std::vector<std::string> GetOutputNamesHelper(OrtAllocator*) const;
1099 std::vector<Value> GetOutputValuesHelper(OrtAllocator*) const;
1100};
1101
// A structure that represents the configuration of an arena-based allocator (wraps an OrtArenaCfg pointer).
1106struct ArenaCfg : Base<OrtArenaCfg> {
// Create an empty ArenaCfg object; it must be assigned a valid one to be used.
1107 explicit ArenaCfg(std::nullptr_t) {}
// Builds a configuration from the four arena settings named by the parameters.
// NOTE(review): the units and accepted values of these parameters are not visible in this listing - confirm against the C API.
1116 ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
1117};
1118
1119//
1120// Custom OPs (only needed to implement custom OPs)
1121//
1122
1124 CustomOpApi(const OrtApi& api) : api_(api) {}
1125
1126 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
1127 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
1128
1133 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
1134 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
1135
1136 template <typename T>
1137 T* GetTensorMutableData(_Inout_ OrtValue* value);
1138 template <typename T>
1139 const T* GetTensorData(_Inout_ const OrtValue* value);
1140
1141 const OrtMemoryInfo* GetTensorMemoryInfo(_In_ const OrtValue* value);
1142
1143 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info);
1146 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index);
1148 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count);
1150
1152
1153 OrtOpAttr* CreateOpAttr(_In_ const char* name,
1154 _In_ const void* data,
1155 _In_ int len,
1156 _In_ OrtOpAttrType type);
1157
1158 void ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr);
1159
1160 OrtOp* CreateOp(_In_ const OrtKernelInfo* info,
1161 _In_ const char* op_name,
1162 _In_ const char* domain,
1163 _In_ int version,
1164 _In_opt_ const char** type_constraint_names,
1165 _In_opt_ const ONNXTensorElementDataType* type_constraint_values,
1166 _In_opt_ int type_constraint_count,
1167 _In_opt_ const OrtOpAttr* const* attr_values,
1168 _In_opt_ int attr_count,
1169 _In_ int input_count,
1170 _In_ int output_count);
1171
1172 void InvokeOp(_In_ const OrtKernelContext* context,
1173 _In_ const OrtOp* ort_op,
1174 _In_ const OrtValue* const* input_values,
1175 _In_ int input_count,
1176 _Inout_ OrtValue* const* output_values,
1177 _In_ int output_count);
1178
1179 void ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op);
1180
1182
1183 void ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy);
1184
1185 private:
1186 const OrtApi& api_;
1187};
1188
1189template <typename TOp, typename TKernel>
1193 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
1194 OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
1195
1196 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
1197
1198 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
1199 OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
1200
1201 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
1202 OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
1203
1204 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
1205#if defined(_MSC_VER) && !defined(__clang__)
1206#pragma warning(push)
1207#pragma warning(disable : 26409)
1208#endif
1209 OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
1210#if defined(_MSC_VER) && !defined(__clang__)
1211#pragma warning(pop)
1212#endif
1213 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
1214 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
1215 }
1216
1217 // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
1218 const char* GetExecutionProviderType() const { return nullptr; }
1219
1220 // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
1221 // (inputs and outputs are required by default)
1224 }
1225
1228 }
1229};
1230
1231} // namespace Ort
1232
1233#include "onnxruntime_cxx_inline.h"
struct OrtMemoryInfo OrtMemoryInfo
Definition: onnxruntime_c_api.h:252
struct OrtKernelInfo OrtKernelInfo
Definition: onnxruntime_c_api.h:327
OrtLoggingLevel
Logging severity levels.
Definition: onnxruntime_c_api.h:207
void(* OrtLoggingFunction)(void *param, OrtLoggingLevel severity, const char *category, const char *logid, const char *code_location, const char *message)
Definition: onnxruntime_c_api.h:292
void(* OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle)
Custom thread join function.
Definition: onnxruntime_c_api.h:586
OrtCustomOpInputOutputCharacteristic
Definition: onnxruntime_c_api.h:3515
struct OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsV2
Definition: onnxruntime_c_api.h:268
struct OrtOpAttr OrtOpAttr
Definition: onnxruntime_c_api.h:271
struct OrtThreadingOptions OrtThreadingOptions
Definition: onnxruntime_c_api.h:265
struct OrtSequenceTypeInfo OrtSequenceTypeInfo
Definition: onnxruntime_c_api.h:262
OrtSparseIndicesFormat
Definition: onnxruntime_c_api.h:196
struct OrtPrepackedWeightsContainer OrtPrepackedWeightsContainer
Definition: onnxruntime_c_api.h:267
struct OrtCustomOpDomain OrtCustomOpDomain
Definition: onnxruntime_c_api.h:260
OrtAllocatorType
Definition: onnxruntime_c_api.h:333
struct OrtOp OrtOp
Definition: onnxruntime_c_api.h:270
struct OrtModelMetadata OrtModelMetadata
Definition: onnxruntime_c_api.h:263
struct OrtTypeInfo OrtTypeInfo
Definition: onnxruntime_c_api.h:257
struct OrtTensorTypeAndShapeInfo OrtTensorTypeAndShapeInfo
Definition: onnxruntime_c_api.h:258
struct OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsV2
Definition: onnxruntime_c_api.h:269
struct OrtKernelContext OrtKernelContext
Definition: onnxruntime_c_api.h:329
struct OrtSessionOptions OrtSessionOptions
Definition: onnxruntime_c_api.h:259
struct OrtValue OrtValue
Definition: onnxruntime_c_api.h:255
GraphOptimizationLevel
Graph optimization level.
Definition: onnxruntime_c_api.h:301
OrtMemType
Memory types for allocated memory, execution provider specific types should be extended in each provi...
Definition: onnxruntime_c_api.h:342
OrtSparseFormat
Definition: onnxruntime_c_api.h:188
ONNXType
Definition: onnxruntime_c_api.h:176
struct OrtEnv OrtEnv
Definition: onnxruntime_c_api.h:250
OrtErrorCode
Definition: onnxruntime_c_api.h:215
struct OrtStatus OrtStatus
Definition: onnxruntime_c_api.h:251
#define ORT_API_VERSION
The API version defined in this header.
Definition: onnxruntime_c_api.h:33
struct OrtMapTypeInfo OrtMapTypeInfo
Definition: onnxruntime_c_api.h:261
struct OrtArenaCfg OrtArenaCfg
Definition: onnxruntime_c_api.h:266
ExecutionMode
Definition: onnxruntime_c_api.h:308
OrtOpAttrType
Definition: onnxruntime_c_api.h:230
OrtCustomThreadHandle(* OrtCustomCreateThreadFn)(void *ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void *ort_worker_fn_param)
Ort custom thread creation function.
Definition: onnxruntime_c_api.h:579
ONNXTensorElementDataType
Definition: onnxruntime_c_api.h:155
const OrtApiBase * OrtGetApiBase(void)
The Onnxruntime library's entry point to access the C API.
@ ORT_LOGGING_LEVEL_WARNING
Warning messages.
Definition: onnxruntime_c_api.h:210
@ INPUT_OUTPUT_REQUIRED
Definition: onnxruntime_c_api.h:3517
@ ORT_FAIL
Definition: onnxruntime_c_api.h:217
All C++ Onnxruntime APIs are defined inside this namespace.
Definition: onnxruntime_cxx_api.h:34
std::unique_ptr< char, detail::AllocatedFree > AllocatedStringPtr
unique_ptr typedef used to own strings allocated by OrtAllocators and release them at the end of the ...
Definition: onnxruntime_cxx_api.h:269
const OrtApi & GetApi()
This returns a reference to the OrtApi interface in use.
Definition: onnxruntime_cxx_api.h:94
void OrtRelease(OrtAllocator *ptr)
Definition: onnxruntime_cxx_api.h:104
std::vector< std::string > GetAvailableProviders()
This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representin...
Definition: onnxruntime_cxx_api.h:1073
void Free(void *p) const
MemoryAllocation GetAllocation(size_t size)
void * Alloc(size_t size) const
Unowned< const MemoryInfo > GetInfo() const
Allocator(const Session &session, const MemoryInfo &)
Definition: onnxruntime_cxx_api.h:1041
const OrtMemoryInfo * GetInfo() const
MemoryAllocation GetAllocation(size_t size)
It is a structure that represents the configuration of an arena-based allocator.
Definition: onnxruntime_cxx_api.h:1106
ArenaCfg(std::nullptr_t)
Create an empty ArenaCfg object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:1107
ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk)
bfloat16 (Brain Floating Point) data type
Definition: onnxruntime_cxx_api.h:181
uint16_t value
Definition: onnxruntime_cxx_api.h:182
constexpr bool operator!=(const BFloat16_t &rhs) const noexcept
Definition: onnxruntime_cxx_api.h:187
constexpr BFloat16_t(uint16_t v) noexcept
Definition: onnxruntime_cxx_api.h:184
constexpr bool operator==(const BFloat16_t &rhs) const noexcept
Definition: onnxruntime_cxx_api.h:186
constexpr BFloat16_t() noexcept
Definition: onnxruntime_cxx_api.h:183
Used internally by the C++ API. C++ wrapper types inherit from this.
Definition: onnxruntime_cxx_api.h:200
Base & operator=(const Base &)=delete
T * release()
Releases ownership of the contained pointer.
Definition: onnxruntime_cxx_api.h:214
~Base()
Definition: onnxruntime_cxx_api.h:208
Base()=default
Base(Base &&v) noexcept
Definition: onnxruntime_cxx_api.h:223
T * p_
Definition: onnxruntime_cxx_api.h:229
Base(const Base &)=delete
Base(T *p)
Definition: onnxruntime_cxx_api.h:204
void operator=(Base &&v) noexcept
Definition: onnxruntime_cxx_api.h:224
T contained_type
Definition: onnxruntime_cxx_api.h:201
Definition: onnxruntime_cxx_api.h:1123
size_t KernelContext_GetOutputCount(const OrtKernelContext *context)
size_t GetDimensionsCount(const OrtTensorTypeAndShapeInfo *info)
void * KernelContext_GetGPUComputeStream(const OrtKernelContext *context)
size_t KernelContext_GetInputCount(const OrtKernelContext *context)
void InvokeOp(const OrtKernelContext *context, const OrtOp *ort_op, const OrtValue *const *input_values, int input_count, OrtValue *const *output_values, int output_count)
OrtOpAttr * CreateOpAttr(const char *name, const void *data, int len, OrtOpAttrType type)
void ReleaseOp(OrtOp *ort_op)
OrtValue * KernelContext_GetOutput(OrtKernelContext *context, size_t index, const int64_t *dim_values, size_t dim_count)
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *input)
T KernelInfoGetAttribute(const OrtKernelInfo *info, const char *name)
OrtTensorTypeAndShapeInfo * GetTensorTypeAndShape(const OrtValue *value)
std::vector< int64_t > GetTensorShape(const OrtTensorTypeAndShapeInfo *info)
void GetDimensions(const OrtTensorTypeAndShapeInfo *info, int64_t *dim_values, size_t dim_values_length)
void ReleaseOpAttr(OrtOpAttr *op_attr)
void ThrowOnError(OrtStatus *result)
size_t GetTensorShapeElementCount(const OrtTensorTypeAndShapeInfo *info)
ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo *info)
OrtOp * CreateOp(const OrtKernelInfo *info, const char *op_name, const char *domain, int version, const char **type_constraint_names, const ONNXTensorElementDataType *type_constraint_values, int type_constraint_count, const OrtOpAttr *const *attr_values, int attr_count, int input_count, int output_count)
CustomOpApi(const OrtApi &api)
Definition: onnxruntime_cxx_api.h:1124
void SetDimensions(OrtTensorTypeAndShapeInfo *info, const int64_t *dim_values, size_t dim_count)
OrtKernelInfo * CopyKernelInfo(const OrtKernelInfo *info)
void ReleaseKernelInfo(OrtKernelInfo *info_copy)
T * GetTensorMutableData(OrtValue *value)
const OrtValue * KernelContext_GetInput(const OrtKernelContext *context, size_t index)
const OrtMemoryInfo * GetTensorMemoryInfo(const OrtValue *value)
const T * GetTensorData(const OrtValue *value)
Definition: onnxruntime_cxx_api.h:1190
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const
Definition: onnxruntime_cxx_api.h:1226
CustomOpBase()
Definition: onnxruntime_cxx_api.h:1191
const char * GetExecutionProviderType() const
Definition: onnxruntime_cxx_api.h:1218
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const
Definition: onnxruntime_cxx_api.h:1222
Custom Op Domain.
Definition: onnxruntime_cxx_api.h:304
CustomOpDomain(std::nullptr_t)
Create an empty CustomOpDomain object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:305
CustomOpDomain(const char *domain)
Wraps OrtApi::CreateCustomOpDomain.
void Add(OrtCustomOp *op)
Wraps CustomOpDomain_Add.
The Env (Environment)
Definition: onnxruntime_cxx_api.h:276
Env & EnableTelemetryEvents()
Wraps OrtApi::EnableTelemetryEvents.
Env(OrtEnv *p)
C Interop Helper.
Definition: onnxruntime_cxx_api.h:293
Env(std::nullptr_t)
Create an empty Env object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:277
Env(OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnv.
Env(const OrtThreadingOptions *tp_options, OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnvWithGlobalThreadPools.
Env(const OrtThreadingOptions *tp_options, OrtLoggingFunction logging_function, void *logger_param, OrtLoggingLevel logging_level=ORT_LOGGING_LEVEL_WARNING, const char *logid="")
Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools.
Env(OrtLoggingLevel logging_level, const char *logid, OrtLoggingFunction logging_function, void *logger_param)
Wraps OrtApi::CreateEnvWithCustomLogger.
Env & CreateAndRegisterAllocator(const OrtMemoryInfo *mem_info, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocator.
Env & DisableTelemetryEvents()
Wraps OrtApi::DisableTelemetryEvents.
All C++ methods that can fail will throw an exception of this type.
Definition: onnxruntime_cxx_api.h:40
const char * what() const noexcept override
Definition: onnxruntime_cxx_api.h:44
OrtErrorCode GetOrtErrorCode() const
Definition: onnxruntime_cxx_api.h:43
Exception(std::string &&string, OrtErrorCode code)
Definition: onnxruntime_cxx_api.h:41
IEEE 754 half-precision floating point data type.
Definition: onnxruntime_cxx_api.h:162
constexpr bool operator!=(const Float16_t &rhs) const noexcept
Definition: onnxruntime_cxx_api.h:168
constexpr Float16_t(uint16_t v) noexcept
Definition: onnxruntime_cxx_api.h:165
uint16_t value
Definition: onnxruntime_cxx_api.h:163
constexpr bool operator==(const Float16_t &rhs) const noexcept
Definition: onnxruntime_cxx_api.h:167
constexpr Float16_t() noexcept
Definition: onnxruntime_cxx_api.h:164
Definition: onnxruntime_cxx_api.h:71
static const OrtApi * api_
Definition: onnxruntime_cxx_api.h:72
Definition: onnxruntime_cxx_api.h:1083
void BindInput(const char *name, const Value &)
std::vector< Value > GetOutputValues() const
void SynchronizeOutputs()
std::vector< std::string > GetOutputNames() const
std::vector< Value > GetOutputValues(Allocator &) const
std::vector< std::string > GetOutputNames(Allocator &) const
void BindOutput(const char *name, const MemoryInfo &)
void ClearBoundOutputs()
void SynchronizeInputs()
void ClearBoundInputs()
void BindOutput(const char *name, const Value &)
IoBinding(Session &session)
Wrapper around OrtMapTypeInfo.
Definition: onnxruntime_cxx_api.h:654
ONNXTensorElementDataType GetMapKeyType() const
Wraps OrtApi::GetMapKeyType.
TypeInfo GetMapValueType() const
Wraps OrtApi::GetMapValueType.
MapTypeInfo(OrtMapTypeInfo *p)
Used for interop with the C API.
Definition: onnxruntime_cxx_api.h:656
MapTypeInfo(std::nullptr_t)
Create an empty MapTypeInfo object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:655
Definition: onnxruntime_cxx_api.h:1024
MemoryAllocation(MemoryAllocation &&) noexcept
MemoryAllocation & operator=(const MemoryAllocation &)=delete
void * get()
Definition: onnxruntime_cxx_api.h:1032
MemoryAllocation(const MemoryAllocation &)=delete
MemoryAllocation(OrtAllocator *allocator, void *p, size_t size)
size_t size() const
Definition: onnxruntime_cxx_api.h:1033
Definition: onnxruntime_cxx_api.h:1058
OrtAllocatorType GetAllocatorType() const
MemoryInfo(const char *name, OrtAllocatorType type, int id, OrtMemType mem_type)
MemoryInfo(std::nullptr_t)
Definition: onnxruntime_cxx_api.h:1061
MemoryInfo(OrtMemoryInfo *p)
Used for interop with the C API.
Definition: onnxruntime_cxx_api.h:1062
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1)
std::string GetAllocatorName() const
bool operator==(const MemoryInfo &o) const
int GetDeviceId() const
OrtMemType GetMemoryType() const
Wrapper around OrtModelMetadata.
Definition: onnxruntime_cxx_api.h:402
char * GetDomain(OrtAllocator *allocator) const
Wraps OrtApi::ModelMetadataGetDomain.
AllocatedStringPtr GetDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the description.
std::vector< AllocatedStringPtr > GetCustomMetadataMapKeysAllocated(OrtAllocator *allocator) const
Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys.
char * LookupCustomMetadataMap(const char *key, OrtAllocator *allocator) const
Wraps OrtApi::ModelMetadataLookupCustomMetadataMap.
ModelMetadata(std::nullptr_t)
Create an empty ModelMetadata object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:403
char * GetProducerName(OrtAllocator *allocator) const
Wraps OrtApi::ModelMetadataGetProducerName.
AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the graph description.
char * GetGraphName(OrtAllocator *allocator) const
Wraps OrtApi::ModelMetadataGetGraphName.
AllocatedStringPtr GetProducerNameAllocated(OrtAllocator *allocator) const
Returns a copy of the producer name.
AllocatedStringPtr GetGraphNameAllocated(OrtAllocator *allocator) const
Returns a copy of the graph name.
char * GetGraphDescription(OrtAllocator *allocator) const
Wraps OrtApi::ModelMetadataGetGraphDescription.
char * GetDescription(OrtAllocator *allocator) const
Wraps OrtApi::ModelMetadataGetDescription.
char ** GetCustomMetadataMapKeys(OrtAllocator *allocator, int64_t &num_keys) const
Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys.
AllocatedStringPtr LookupCustomMetadataMapAllocated(const char *key, OrtAllocator *allocator) const
Looks up a value by a key in the Custom Metadata map.
ModelMetadata(OrtModelMetadata *p)
Used for interop with the C API.
Definition: onnxruntime_cxx_api.h:404
AllocatedStringPtr GetDomainAllocated(OrtAllocator *allocator) const
Returns a copy of the domain name.
int64_t GetVersion() const
Wraps OrtApi::ModelMetadataGetVersion.
Definition: onnxruntime_cxx_api.h:313
int GetRunLogSeverityLevel() const
Wraps OrtApi::RunOptionsGetRunLogSeverityLevel.
RunOptions & SetTerminate()
Terminates all currently executing Session::Run calls that were made using this RunOptions instance.
RunOptions & SetRunTag(const char *run_tag)
Wraps OrtApi::RunOptionsSetRunTag.
RunOptions & UnsetTerminate()
Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without ...
int GetRunLogVerbosityLevel() const
Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel.
RunOptions(std::nullptr_t)
Create an empty RunOptions object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:314
RunOptions & SetRunLogVerbosityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel.
RunOptions & SetRunLogSeverityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogSeverityLevel.
RunOptions & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddRunConfigEntry.
const char * GetRunTag() const
Wraps OrtApi::RunOptionsGetRunTag.
RunOptions()
Wraps OrtApi::CreateRunOptions.
Wrapper around OrtSequenceTypeInfo.
Definition: onnxruntime_cxx_api.h:644
TypeInfo GetSequenceElementType() const
Wraps OrtApi::GetSequenceElementType.
SequenceTypeInfo(std::nullptr_t)
Create an empty SequenceTypeInfo object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:645
SequenceTypeInfo(OrtSequenceTypeInfo *p)
Used for interop with the C API.
Definition: onnxruntime_cxx_api.h:646
Wrapper around OrtSession.
Definition: onnxruntime_cxx_api.h:514
Session(Env &env, const char *model_path, const SessionOptions &options)
Wraps OrtApi::CreateSession.
char * GetInputName(size_t index, OrtAllocator *allocator) const
Wraps OrtApi::SessionGetInputName.
char * GetOutputName(size_t index, OrtAllocator *allocator) const
Wraps OrtApi::SessionGetOutputName.
size_t GetInputCount() const
Returns the number of model inputs.
Session(Env &env, const char *model_path, const SessionOptions &options, OrtPrepackedWeightsContainer *prepacked_weights_container)
Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer.
Session(std::nullptr_t)
Create an empty Session object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:515
TypeInfo GetOutputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOutputTypeInfo.
ModelMetadata GetModelMetadata() const
Wraps OrtApi::SessionGetModelMetadata.
AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the overridable initializer name at the specified index.
void Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, Value *output_values, size_t output_count)
Run the model returning results in user provided outputs. Same as Run(const RunOptions&,...
size_t GetOverridableInitializerCount() const
Returns the number of inputs that have defaults that can be overridden.
size_t GetOutputCount() const
Returns the number of model outputs.
AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the output name at the specified index.
Session(Env &env, const void *model_data, size_t model_data_length, const SessionOptions &options)
Wraps OrtApi::CreateSessionFromArray.
uint64_t GetProfilingStartTimeNs() const
Wraps OrtApi::SessionGetProfilingStartTimeNs.
char * EndProfiling(OrtAllocator *allocator) const
Wraps OrtApi::SessionEndProfiling.
TypeInfo GetInputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetInputTypeInfo.
char * GetOverridableInitializerName(size_t index, OrtAllocator *allocator) const
Wraps OrtApi::SessionGetOverridableInitializerName.
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOverridableInitializerTypeInfo.
AllocatedStringPtr EndProfilingAllocated(OrtAllocator *allocator) const
Returns a copy of the profiling file name.
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the input name at the specified index.
void Run(const RunOptions &run_options, const struct IoBinding &)
Wraps OrtApi::RunWithBinding.
std::vector< Value > Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, size_t output_count)
Run the model returning results in an Ort allocated vector.
Session(Env &env, const void *model_data, size_t model_data_length, const SessionOptions &options, OrtPrepackedWeightsContainer *prepacked_weights_container)
Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer.
Options object used when creating a new Session object.
Definition: onnxruntime_cxx_api.h:346
SessionOptions & SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level)
Wraps OrtApi::SetSessionGraphOptimizationLevel.
SessionOptions & EnableMemPattern()
Wraps OrtApi::EnableMemPattern.
SessionOptions & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddSessionConfigEntry.
SessionOptions & AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT.
SessionOptions & SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn.
SessionOptions & SetIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetIntraOpNumThreads.
SessionOptions & DisableProfiling()
Wraps OrtApi::DisableProfiling.
SessionOptions & DisablePerSessionThreads()
Wraps OrtApi::DisablePerSessionThreads.
SessionOptions & AddExternalInitializers(const std::vector< std::string > &names, const std::vector< Value > &ort_values)
Wraps OrtApi::AddExternalInitializers.
SessionOptions Clone() const
Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions.
SessionOptions(std::nullptr_t)
Create an empty SessionOptions object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:347
SessionOptions & AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2.
SessionOptions & EnableOrtCustomOps()
Wraps OrtApi::EnableOrtCustomOps.
SessionOptions & AppendExecutionProvider(const std::string &provider_name, const std::unordered_map< std::string, std::string > &provider_options={})
Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports SNPE and XNNPACK.
SessionOptions()
Wraps OrtApi::CreateSessionOptions.
SessionOptions & EnableProfiling(const char *profile_file_prefix)
Wraps OrtApi::EnableProfiling.
SessionOptions & SetOptimizedModelFilePath(const char *optimized_model_file)
Wraps OrtApi::SetOptimizedModelFilePath.
SessionOptions & EnableCpuMemArena()
Wraps OrtApi::EnableCpuMemArena.
SessionOptions & AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO.
SessionOptions & DisableMemPattern()
Wraps OrtApi::DisableMemPattern.
SessionOptions & AddInitializer(const char *name, const OrtValue *ort_val)
Wraps OrtApi::AddInitializer.
SessionOptions & SetLogSeverityLevel(int level)
Wraps OrtApi::SetSessionLogSeverityLevel.
SessionOptions & SetInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetInterOpNumThreads.
SessionOptions & SetCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions.
SessionOptions & AppendExecutionProvider_ROCM(const OrtROCMProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM.
SessionOptions & AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions &provider_options)
SessionOptions & DisableCpuMemArena()
Wraps OrtApi::DisableCpuMemArena.
SessionOptions & AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2.
SessionOptions & SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn.
SessionOptions & SetExecutionMode(ExecutionMode execution_mode)
Wraps OrtApi::SetSessionExecutionMode.
SessionOptions & SetLogId(const char *logid)
Wraps OrtApi::SetSessionLogId.
SessionOptions(OrtSessionOptions *p)
Used for interop with the C API.
Definition: onnxruntime_cxx_api.h:349
SessionOptions & Add(OrtCustomOpDomain *custom_op_domain)
Wraps OrtApi::AddCustomOpDomain.
SessionOptions & AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA.
Wrapper around OrtTensorTypeAndShapeInfo.
Definition: onnxruntime_cxx_api.h:627
TensorTypeAndShapeInfo(std::nullptr_t)
Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:628
TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *p)
Used for interop with the C API.
Definition: onnxruntime_cxx_api.h:629
void GetDimensions(int64_t *values, size_t values_count) const
Wraps OrtApi::GetDimensions.
std::vector< int64_t > GetShape() const
Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape.
size_t GetDimensionsCount() const
Wraps OrtApi::GetDimensionsCount.
ONNXTensorElementDataType GetElementType() const
Wraps OrtApi::GetTensorElementType.
size_t GetElementCount() const
Wraps OrtApi::GetTensorShapeElementCount.
void GetSymbolicDimensions(const char **values, size_t values_count) const
Wraps OrtApi::GetSymbolicDimensions.
Definition: onnxruntime_cxx_api.h:662
Unowned< MapTypeInfo > GetMapTypeInfo() const
Wraps OrtApi::CastTypeInfoToMapTypeInfo.
Unowned< SequenceTypeInfo > GetSequenceTypeInfo() const
Wraps OrtApi::CastTypeInfoToSequenceTypeInfo.
ONNXType GetONNXType() const
Unowned< TensorTypeAndShapeInfo > GetTensorTypeAndShapeInfo() const
Wraps OrtApi::CastTypeInfoToTensorInfo.
TypeInfo(std::nullptr_t)
Create an empty TypeInfo object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:663
TypeInfo(OrtTypeInfo *p)
C API Interop.
Definition: onnxruntime_cxx_api.h:664
Wraps an object that inherits from Ort::Base and stops it from deleting the contained pointer on dest...
Definition: onnxruntime_cxx_api.h:240
Unowned(Unowned &&v)
Definition: onnxruntime_cxx_api.h:242
~Unowned()
Definition: onnxruntime_cxx_api.h:243
Unowned(typename T::contained_type *p)
Definition: onnxruntime_cxx_api.h:241
Definition: onnxruntime_cxx_api.h:682
const char ** str
Definition: onnxruntime_cxx_api.h:687
size_t values_shape_len
Definition: onnxruntime_cxx_api.h:684
const int64_t * values_shape
Definition: onnxruntime_cxx_api.h:683
const void * p_data
Definition: onnxruntime_cxx_api.h:686
union Ort::Value::OrtSparseValuesParam::@0 data
Definition: onnxruntime_cxx_api.h:693
const int64_t * shape
Definition: onnxruntime_cxx_api.h:694
size_t shape_len
Definition: onnxruntime_cxx_api.h:695
Definition: onnxruntime_cxx_api.h:673
T * GetTensorMutableData()
Wraps OrtApi::GetTensorMutableData.
static Value CreateMap(Value &keys, Value &values)
Wraps OrtApi::CreateValue.
static Value CreateSparseTensor(const OrtMemoryInfo *info, void *p_data, const Shape &dense_shape, const Shape &values_shape, ONNXTensorElementDataType type)
Creates an OrtValue instance containing SparseTensor. This constructs a sparse tensor that makes use ...
static Value CreateSparseTensor(const OrtMemoryInfo *info, T *p_data, const Shape &dense_shape, const Shape &values_shape)
This is a simple forwarding method to the other overload that helps deducing data type enum value fro...
const T * GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t &num_indices) const
The API retrieves a pointer to the internal indices buffer. The API merely performs a convenience dat...
Value & operator=(Value &&)=default
static Value CreateSparseTensor(OrtAllocator *allocator, const Shape &dense_shape, ONNXTensorElementDataType type)
Creates an instance of OrtValue containing sparse tensor. The created instance has no data....
void UseCooIndices(int64_t *indices_data, size_t indices_num)
Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tens...
Value(Value &&)=default
Value(std::nullptr_t)
Create an empty Value object, must be assigned a valid one to be used.
Definition: onnxruntime_cxx_api.h:920
bool IsTensor() const
Returns true if Value is a tensor, false for other types like map/sequence/etc.
const T * GetTensorData() const
Wraps OrtApi::GetTensorMutableData.
static Value CreateTensor(const OrtMemoryInfo *info, T *p_data, size_t p_data_element_count, const int64_t *shape, size_t shape_len)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
void UseCsrIndices(int64_t *inner_data, size_t inner_num, int64_t *outer_data, size_t outer_num)
Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tens...
Value(OrtValue *p)
Used for interop with the C API.
Definition: onnxruntime_cxx_api.h:921
void GetStringTensorContent(void *buffer, size_t buffer_length, size_t *offsets, size_t offsets_count) const
The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor into...
static Value CreateSparseTensor(OrtAllocator *allocator, const Shape &dense_shape)
This is a simple forwarding method to the below CreateSparseTensor. This helps to specify data type enum...
static Value CreateTensor(OrtAllocator *allocator, const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type)
Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
size_t GetCount() const
bool IsSparseTensor() const
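Returns true if the OrtValue contains data and false if the OrtValue is None. (NOTE: this text appears to describe HasValue(); a malformed trailing Doxygen comment attached it to IsSparseTensor() — confirm against the header.)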
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const
The API returns type and shape information for the specified indices. Each supported indices have the...
size_t GetStringTensorElementLength(size_t element_index) const
The API returns a byte length of UTF-8 encoded string element contained in either a tensor or a sparse...
TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const
The API returns type and shape information for stored non-zero values of the sparse tensor....
void UseBlockSparseIndices(const Shape &indices_shape, int32_t *indices_data)
Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSp...
void FillStringTensor(const char *const *s, size_t s_len)
void FillSparseTensorBlockSparse(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const Shape &indices_shape, const int32_t *indices_data)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
T & At(const std::vector< int64_t > &location)
TypeInfo GetTypeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
void FillSparseTensorCoo(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values_param, const int64_t *indices_data, size_t indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
static Value CreateTensor(const OrtMemoryInfo *info, void *p_data, size_t p_data_byte_count, const int64_t *shape, size_t shape_len, ONNXTensorElementDataType type)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
static Value CreateOpaque(const char *domain, const char *type_name, const T &)
Wraps OrtApi::CreateOpaqueValue.
Value GetValue(int index, OrtAllocator *allocator) const
static Value CreateTensor(OrtAllocator *allocator, const int64_t *shape, size_t shape_len)
Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
bool HasValue() const
static Value CreateSequence(std::vector< Value > &values)
Wraps OrtApi::CreateValue.
void GetStringTensorElement(size_t buffer_length, size_t element_index, void *buffer) const
The API copies UTF-8 encoded bytes for the requested string element contained within a tensor or a sp...
void FillSparseTensorCsr(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const int64_t *inner_indices_data, size_t inner_indices_num, const int64_t *outer_indices_data, size_t outer_indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
void FillStringTensorElement(const char *s, size_t index)
const T * GetSparseTensorValues() const
The API returns a pointer to an internal buffer of the sparse tensor containing non-zero values....
void GetOpaqueData(const char *domain, const char *type_name, T &) const
Wraps OrtApi::GetOpaqueValue.
OrtSparseFormat GetSparseFormat() const
The API returns the sparse data format this OrtValue holds in a sparse tensor. If the sparse tensor w...
size_t GetStringTensorDataLength() const
This API returns the full length of string data contained within either a tensor or a sparse Tensor....
Definition: onnxruntime_cxx_api.h:255
AllocatedFree(OrtAllocator *allocator)
Definition: onnxruntime_cxx_api.h:257
OrtAllocator * allocator_
Definition: onnxruntime_cxx_api.h:256
void operator()(void *ptr) const
Definition: onnxruntime_cxx_api.h:259
Memory allocation interface.
Definition: onnxruntime_c_api.h:285
void(* Free)(struct OrtAllocator *this_, void *p)
Free a block of memory previously allocated with OrtAllocator::Alloc.
Definition: onnxruntime_c_api.h:288
const OrtApi *(* GetApi)(uint32_t version)
Get a pointer to the requested version of the OrtApi.
Definition: onnxruntime_c_api.h:554
The C API.
Definition: onnxruntime_c_api.h:595
CUDA Provider Options.
Definition: onnxruntime_c_api.h:361
Definition: onnxruntime_c_api.h:3525
OrtCustomOpInputOutputCharacteristic(* GetOutputCharacteristic)(const struct OrtCustomOp *op, size_t index)
Definition: onnxruntime_c_api.h:3550
size_t(* GetInputTypeCount)(const struct OrtCustomOp *op)
Definition: onnxruntime_c_api.h:3540
const char *(* GetName)(const struct OrtCustomOp *op)
Definition: onnxruntime_c_api.h:3533
size_t(* GetOutputTypeCount)(const struct OrtCustomOp *op)
Definition: onnxruntime_c_api.h:3542
void(* KernelDestroy)(void *op_kernel)
Definition: onnxruntime_c_api.h:3546
void *(* CreateKernel)(const struct OrtCustomOp *op, const OrtApi *api, const OrtKernelInfo *info)
Definition: onnxruntime_c_api.h:3529
uint32_t version
Definition: onnxruntime_c_api.h:3526
ONNXTensorElementDataType(* GetInputType)(const struct OrtCustomOp *op, size_t index)
Definition: onnxruntime_c_api.h:3539
OrtCustomOpInputOutputCharacteristic(* GetInputCharacteristic)(const struct OrtCustomOp *op, size_t index)
Definition: onnxruntime_c_api.h:3549
const char *(* GetExecutionProviderType)(const struct OrtCustomOp *op)
Definition: onnxruntime_c_api.h:3536
ONNXTensorElementDataType(* GetOutputType)(const struct OrtCustomOp *op, size_t index)
Definition: onnxruntime_c_api.h:3541
void(* KernelCompute)(void *op_kernel, OrtKernelContext *context)
Definition: onnxruntime_c_api.h:3545
MIGraphX Provider Options.
Definition: onnxruntime_c_api.h:506
OpenVINO Provider Options.
Definition: onnxruntime_c_api.h:516
ROCM Provider Options.
Definition: onnxruntime_c_api.h:420
TensorRT Provider Options.
Definition: onnxruntime_c_api.h:478