MLIR 22.0.0git
BuiltinTypes.cpp
Go to the documentation of this file.
1//===- BuiltinTypes.cpp - C Interface to MLIR Builtin Types ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

#include <algorithm>

using namespace mlir;
23
24//===----------------------------------------------------------------------===//
25// Integer types.
26//===----------------------------------------------------------------------===//
27
28MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(IntegerType::getTypeID()); }
29
30bool mlirTypeIsAInteger(MlirType type) {
31 return llvm::isa<IntegerType>(unwrap(type));
32}
33
34MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
35 return wrap(IntegerType::get(unwrap(ctx), bitwidth));
36}
37
38MlirStringRef mlirIntegerTypeGetName(void) { return wrap(IntegerType::name); }
39
40MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
41 return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Signed));
42}
43
44MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
45 return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Unsigned));
46}
47
48unsigned mlirIntegerTypeGetWidth(MlirType type) {
49 return llvm::cast<IntegerType>(unwrap(type)).getWidth();
50}
51
52bool mlirIntegerTypeIsSignless(MlirType type) {
53 return llvm::cast<IntegerType>(unwrap(type)).isSignless();
54}
55
56bool mlirIntegerTypeIsSigned(MlirType type) {
57 return llvm::cast<IntegerType>(unwrap(type)).isSigned();
58}
59
60bool mlirIntegerTypeIsUnsigned(MlirType type) {
61 return llvm::cast<IntegerType>(unwrap(type)).isUnsigned();
62}
63
64//===----------------------------------------------------------------------===//
65// Index type.
66//===----------------------------------------------------------------------===//
67
68MlirTypeID mlirIndexTypeGetTypeID() { return wrap(IndexType::getTypeID()); }
69
70bool mlirTypeIsAIndex(MlirType type) {
71 return llvm::isa<IndexType>(unwrap(type));
72}
73
74MlirType mlirIndexTypeGet(MlirContext ctx) {
75 return wrap(IndexType::get(unwrap(ctx)));
76}
77
78MlirStringRef mlirIndexTypeGetName(void) { return wrap(IndexType::name); }
79
80//===----------------------------------------------------------------------===//
81// Floating-point types.
82//===----------------------------------------------------------------------===//
83
84bool mlirTypeIsAFloat(MlirType type) {
85 return llvm::isa<FloatType>(unwrap(type));
86}
87
88unsigned mlirFloatTypeGetWidth(MlirType type) {
89 return llvm::cast<FloatType>(unwrap(type)).getWidth();
90}
91
93 return wrap(Float4E2M1FNType::getTypeID());
94}
95
96bool mlirTypeIsAFloat4E2M1FN(MlirType type) {
97 return llvm::isa<Float4E2M1FNType>(unwrap(type));
98}
99
100MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) {
101 return wrap(Float4E2M1FNType::get(unwrap(ctx)));
102}
103
105 return wrap(Float4E2M1FNType::name);
106}
107
109 return wrap(Float6E2M3FNType::getTypeID());
110}
111
112bool mlirTypeIsAFloat6E2M3FN(MlirType type) {
113 return llvm::isa<Float6E2M3FNType>(unwrap(type));
114}
115
116MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
117 return wrap(Float6E2M3FNType::get(unwrap(ctx)));
118}
119
121 return wrap(Float6E2M3FNType::name);
122}
123
125 return wrap(Float6E3M2FNType::getTypeID());
126}
127
128bool mlirTypeIsAFloat6E3M2FN(MlirType type) {
129 return llvm::isa<Float6E3M2FNType>(unwrap(type));
130}
131
132MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) {
133 return wrap(Float6E3M2FNType::get(unwrap(ctx)));
134}
135
137 return wrap(Float6E3M2FNType::name);
138}
139
141 return wrap(Float8E5M2Type::getTypeID());
142}
143
144bool mlirTypeIsAFloat8E5M2(MlirType type) {
145 return llvm::isa<Float8E5M2Type>(unwrap(type));
146}
147
148MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
149 return wrap(Float8E5M2Type::get(unwrap(ctx)));
150}
151
153 return wrap(Float8E5M2Type::name);
154}
155
157 return wrap(Float8E4M3Type::getTypeID());
158}
159
160bool mlirTypeIsAFloat8E4M3(MlirType type) {
161 return llvm::isa<Float8E4M3Type>(unwrap(type));
162}
163
164MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) {
165 return wrap(Float8E4M3Type::get(unwrap(ctx)));
166}
167
169 return wrap(Float8E4M3Type::name);
170}
171
173 return wrap(Float8E4M3FNType::getTypeID());
174}
175
176bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
177 return llvm::isa<Float8E4M3FNType>(unwrap(type));
178}
179
180MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
181 return wrap(Float8E4M3FNType::get(unwrap(ctx)));
182}
183
185 return wrap(Float8E4M3FNType::name);
186}
187
189 return wrap(Float8E5M2FNUZType::getTypeID());
190}
191
192bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
193 return llvm::isa<Float8E5M2FNUZType>(unwrap(type));
194}
195
196MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
197 return wrap(Float8E5M2FNUZType::get(unwrap(ctx)));
198}
199
201 return wrap(Float8E5M2FNUZType::name);
202}
203
205 return wrap(Float8E4M3FNUZType::getTypeID());
206}
207
208bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
209 return llvm::isa<Float8E4M3FNUZType>(unwrap(type));
210}
211
212MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
213 return wrap(Float8E4M3FNUZType::get(unwrap(ctx)));
214}
215
217 return wrap(Float8E4M3FNUZType::name);
218}
219
221 return wrap(Float8E4M3B11FNUZType::getTypeID());
222}
223
224bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
225 return llvm::isa<Float8E4M3B11FNUZType>(unwrap(type));
226}
227
228MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
229 return wrap(Float8E4M3B11FNUZType::get(unwrap(ctx)));
230}
231
233 return wrap(Float8E4M3B11FNUZType::name);
234}
235
237 return wrap(Float8E3M4Type::getTypeID());
238}
239
240bool mlirTypeIsAFloat8E3M4(MlirType type) {
241 return llvm::isa<Float8E3M4Type>(unwrap(type));
242}
243
244MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
245 return wrap(Float8E3M4Type::get(unwrap(ctx)));
246}
247
249 return wrap(Float8E3M4Type::name);
250}
251
253 return wrap(Float8E8M0FNUType::getTypeID());
254}
255
256bool mlirTypeIsAFloat8E8M0FNU(MlirType type) {
257 return llvm::isa<Float8E8M0FNUType>(unwrap(type));
258}
259
260MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) {
261 return wrap(Float8E8M0FNUType::get(unwrap(ctx)));
262}
263
265 return wrap(Float8E8M0FNUType::name);
266}
267
269 return wrap(BFloat16Type::getTypeID());
270}
271
272bool mlirTypeIsABF16(MlirType type) {
273 return llvm::isa<BFloat16Type>(unwrap(type));
274}
275
276MlirType mlirBF16TypeGet(MlirContext ctx) {
277 return wrap(BFloat16Type::get(unwrap(ctx)));
278}
279
280MlirStringRef mlirBF16TypeGetName(void) { return wrap(BFloat16Type::name); }
281
282MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }
283
284bool mlirTypeIsAF16(MlirType type) {
285 return llvm::isa<Float16Type>(unwrap(type));
286}
287
288MlirType mlirF16TypeGet(MlirContext ctx) {
289 return wrap(Float16Type::get(unwrap(ctx)));
290}
291
292MlirStringRef mlirF16TypeGetName(void) { return wrap(Float16Type::name); }
293
295 return wrap(FloatTF32Type::getTypeID());
296}
297
298bool mlirTypeIsATF32(MlirType type) {
299 return llvm::isa<FloatTF32Type>(unwrap(type));
300}
301
302MlirType mlirTF32TypeGet(MlirContext ctx) {
303 return wrap(FloatTF32Type::get(unwrap(ctx)));
304}
305
306MlirStringRef mlirTF32TypeGetName(void) { return wrap(FloatTF32Type::name); }
307
308MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
309
310bool mlirTypeIsAF32(MlirType type) {
311 return llvm::isa<Float32Type>(unwrap(type));
312}
313
314MlirType mlirF32TypeGet(MlirContext ctx) {
315 return wrap(Float32Type::get(unwrap(ctx)));
316}
317
318MlirStringRef mlirF32TypeGetName(void) { return wrap(Float32Type::name); }
319
320MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }
321
322bool mlirTypeIsAF64(MlirType type) {
323 return llvm::isa<Float64Type>(unwrap(type));
324}
325
326MlirType mlirF64TypeGet(MlirContext ctx) {
327 return wrap(Float64Type::get(unwrap(ctx)));
328}
329
330MlirStringRef mlirF64TypeGetName(void) { return wrap(Float64Type::name); }
331
332//===----------------------------------------------------------------------===//
333// None type.
334//===----------------------------------------------------------------------===//
335
336MlirTypeID mlirNoneTypeGetTypeID() { return wrap(NoneType::getTypeID()); }
337
338bool mlirTypeIsANone(MlirType type) {
339 return llvm::isa<NoneType>(unwrap(type));
340}
341
342MlirType mlirNoneTypeGet(MlirContext ctx) {
343 return wrap(NoneType::get(unwrap(ctx)));
344}
345
346MlirStringRef mlirNoneTypeGetName(void) { return wrap(NoneType::name); }
347
348//===----------------------------------------------------------------------===//
349// Complex type.
350//===----------------------------------------------------------------------===//
351
352MlirTypeID mlirComplexTypeGetTypeID() { return wrap(ComplexType::getTypeID()); }
353
354bool mlirTypeIsAComplex(MlirType type) {
355 return llvm::isa<ComplexType>(unwrap(type));
356}
357
358MlirType mlirComplexTypeGet(MlirType elementType) {
359 return wrap(ComplexType::get(unwrap(elementType)));
360}
361
362MlirStringRef mlirComplexTypeGetName(void) { return wrap(ComplexType::name); }
363
364MlirType mlirComplexTypeGetElementType(MlirType type) {
365 return wrap(llvm::cast<ComplexType>(unwrap(type)).getElementType());
366}
367
368//===----------------------------------------------------------------------===//
369// Shaped type.
370//===----------------------------------------------------------------------===//
371
372bool mlirTypeIsAShaped(MlirType type) {
373 return llvm::isa<ShapedType>(unwrap(type));
374}
375
376MlirType mlirShapedTypeGetElementType(MlirType type) {
377 return wrap(llvm::cast<ShapedType>(unwrap(type)).getElementType());
378}
379
380bool mlirShapedTypeHasRank(MlirType type) {
381 return llvm::cast<ShapedType>(unwrap(type)).hasRank();
382}
383
385 return llvm::cast<ShapedType>(unwrap(type)).getRank();
386}
387
388bool mlirShapedTypeHasStaticShape(MlirType type) {
389 return llvm::cast<ShapedType>(unwrap(type)).hasStaticShape();
390}
391
392bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
393 return llvm::cast<ShapedType>(unwrap(type))
394 .isDynamicDim(static_cast<unsigned>(dim));
395}
396
397bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim) {
398 return llvm::cast<ShapedType>(unwrap(type))
399 .isStaticDim(static_cast<unsigned>(dim));
400}
401
403 return llvm::cast<ShapedType>(unwrap(type))
404 .getDimSize(static_cast<unsigned>(dim));
405}
406
407int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; }
408
410 return ShapedType::isDynamic(size);
411}
412
414 return ShapedType::isStatic(size);
415}
416
418 return ShapedType::isDynamic(val);
419}
420
422 return ShapedType::isStatic(val);
423}
424
426 return ShapedType::kDynamic;
427}
428
429//===----------------------------------------------------------------------===//
430// Vector type.
431//===----------------------------------------------------------------------===//
432
433MlirTypeID mlirVectorTypeGetTypeID() { return wrap(VectorType::getTypeID()); }
434
435bool mlirTypeIsAVector(MlirType type) {
436 return llvm::isa<VectorType>(unwrap(type));
437}
438
440 MlirType elementType) {
441 return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
442 unwrap(elementType)));
443}
444
445MlirStringRef mlirVectorTypeGetName(void) { return wrap(VectorType::name); }
446
447MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
448 const int64_t *shape, MlirType elementType) {
449 return wrap(VectorType::getChecked(
450 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
451 unwrap(elementType)));
452}
453
455 const bool *scalable, MlirType elementType) {
456 return wrap(VectorType::get(
457 llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
458 llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
459}
460
461MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
462 const int64_t *shape,
463 const bool *scalable,
464 MlirType elementType) {
465 return wrap(VectorType::getChecked(
466 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
467 unwrap(elementType),
468 llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
469}
470
471bool mlirVectorTypeIsScalable(MlirType type) {
472 return cast<VectorType>(unwrap(type)).isScalable();
473}
474
475bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
476 return cast<VectorType>(unwrap(type)).getScalableDims()[dim];
477}
478
479//===----------------------------------------------------------------------===//
480// Ranked / Unranked tensor type.
481//===----------------------------------------------------------------------===//
482
483bool mlirTypeIsATensor(MlirType type) {
484 return llvm::isa<TensorType>(unwrap(type));
485}
486
488 return wrap(RankedTensorType::getTypeID());
489}
490
491bool mlirTypeIsARankedTensor(MlirType type) {
492 return llvm::isa<RankedTensorType>(unwrap(type));
493}
494
496 return wrap(UnrankedTensorType::getTypeID());
497}
498
499bool mlirTypeIsAUnrankedTensor(MlirType type) {
500 return llvm::isa<UnrankedTensorType>(unwrap(type));
501}
502
504 MlirType elementType, MlirAttribute encoding) {
505 return wrap(
506 RankedTensorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
507 unwrap(elementType), unwrap(encoding)));
508}
509
511 return wrap(RankedTensorType::name);
512}
513
514MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
515 const int64_t *shape,
516 MlirType elementType,
517 MlirAttribute encoding) {
518 return wrap(RankedTensorType::getChecked(
519 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
520 unwrap(elementType), unwrap(encoding)));
521}
522
523MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) {
524 return wrap(llvm::cast<RankedTensorType>(unwrap(type)).getEncoding());
525}
526
527MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
528 return wrap(UnrankedTensorType::get(unwrap(elementType)));
529}
530
532 return wrap(UnrankedTensorType::name);
533}
534
535MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
536 MlirType elementType) {
537 return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
538}
539
540//===----------------------------------------------------------------------===//
541// Ranked / Unranked MemRef type.
542//===----------------------------------------------------------------------===//
543
544MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(MemRefType::getTypeID()); }
545
546bool mlirTypeIsAMemRef(MlirType type) {
547 return llvm::isa<MemRefType>(unwrap(type));
548}
549
550MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
551 const int64_t *shape, MlirAttribute layout,
552 MlirAttribute memorySpace) {
553 return wrap(MemRefType::get(
554 llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
555 mlirAttributeIsNull(layout)
556 ? MemRefLayoutAttrInterface()
557 : llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout)),
558 unwrap(memorySpace)));
559}
560
561MlirStringRef mlirMemRefTypeGetName(void) { return wrap(MemRefType::name); }
562
563MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
564 intptr_t rank, const int64_t *shape,
565 MlirAttribute layout,
566 MlirAttribute memorySpace) {
567 return wrap(MemRefType::getChecked(
568 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
569 unwrap(elementType),
570 mlirAttributeIsNull(layout)
571 ? MemRefLayoutAttrInterface()
572 : llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout)),
573 unwrap(memorySpace)));
574}
575
576MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
577 const int64_t *shape,
578 MlirAttribute memorySpace) {
579 return wrap(MemRefType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
580 unwrap(elementType), MemRefLayoutAttrInterface(),
581 unwrap(memorySpace)));
582}
583
584MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
585 MlirType elementType, intptr_t rank,
586 const int64_t *shape,
587 MlirAttribute memorySpace) {
588 return wrap(MemRefType::getChecked(
589 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
590 unwrap(elementType), MemRefLayoutAttrInterface(), unwrap(memorySpace)));
591}
592
593MlirAttribute mlirMemRefTypeGetLayout(MlirType type) {
594 return wrap(llvm::cast<MemRefType>(unwrap(type)).getLayout());
595}
596
597MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) {
598 return wrap(llvm::cast<MemRefType>(unwrap(type)).getLayout().getAffineMap());
599}
600
601MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
602 return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
603}
604
606 int64_t *strides,
607 int64_t *offset) {
608 MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
609 SmallVector<int64_t> strides_;
610 if (failed(memrefType.getStridesAndOffset(strides_, *offset)))
612
613 (void)llvm::copy(strides_, strides);
615}
616
618 return wrap(UnrankedMemRefType::getTypeID());
619}
620
621bool mlirTypeIsAUnrankedMemRef(MlirType type) {
622 return llvm::isa<UnrankedMemRefType>(unwrap(type));
623}
624
625MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
626 MlirAttribute memorySpace) {
627 return wrap(
628 UnrankedMemRefType::get(unwrap(elementType), unwrap(memorySpace)));
629}
630
632 return wrap(UnrankedMemRefType::name);
633}
634
635MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
636 MlirType elementType,
637 MlirAttribute memorySpace) {
638 return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType),
639 unwrap(memorySpace)));
640}
641
642MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
643 return wrap(llvm::cast<UnrankedMemRefType>(unwrap(type)).getMemorySpace());
644}
645
646//===----------------------------------------------------------------------===//
647// Tuple type.
648//===----------------------------------------------------------------------===//
649
650MlirTypeID mlirTupleTypeGetTypeID() { return wrap(TupleType::getTypeID()); }
651
652bool mlirTypeIsATuple(MlirType type) {
653 return llvm::isa<TupleType>(unwrap(type));
654}
655
656MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
657 MlirType const *elements) {
659 ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
660 return wrap(TupleType::get(unwrap(ctx), typeRef));
661}
662
663MlirStringRef mlirTupleTypeGetName(void) { return wrap(TupleType::name); }
664
666 return llvm::cast<TupleType>(unwrap(type)).size();
667}
668
669MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
670 return wrap(
671 llvm::cast<TupleType>(unwrap(type)).getType(static_cast<size_t>(pos)));
672}
673
674//===----------------------------------------------------------------------===//
675// Function type.
676//===----------------------------------------------------------------------===//
677
679 return wrap(FunctionType::getTypeID());
680}
681
682bool mlirTypeIsAFunction(MlirType type) {
683 return llvm::isa<FunctionType>(unwrap(type));
684}
685
686MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
687 MlirType const *inputs, intptr_t numResults,
688 MlirType const *results) {
689 SmallVector<Type, 4> inputsList;
690 SmallVector<Type, 4> resultsList;
691 (void)unwrapList(numInputs, inputs, inputsList);
692 (void)unwrapList(numResults, results, resultsList);
693 return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList));
694}
695
696MlirStringRef mlirFunctionTypeGetName(void) { return wrap(FunctionType::name); }
697
699 return llvm::cast<FunctionType>(unwrap(type)).getNumInputs();
700}
701
703 return llvm::cast<FunctionType>(unwrap(type)).getNumResults();
704}
705
706MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
707 assert(pos >= 0 && "pos in array must be positive");
708 return wrap(llvm::cast<FunctionType>(unwrap(type))
709 .getInput(static_cast<unsigned>(pos)));
710}
711
712MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
713 assert(pos >= 0 && "pos in array must be positive");
714 return wrap(llvm::cast<FunctionType>(unwrap(type))
715 .getResult(static_cast<unsigned>(pos)));
716}
717
718//===----------------------------------------------------------------------===//
719// Opaque type.
720//===----------------------------------------------------------------------===//
721
722MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(OpaqueType::getTypeID()); }
723
724bool mlirTypeIsAOpaque(MlirType type) {
725 return llvm::isa<OpaqueType>(unwrap(type));
726}
727
728MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace,
729 MlirStringRef typeData) {
730 return wrap(
731 OpaqueType::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)),
732 unwrap(typeData)));
733}
734
735MlirStringRef mlirOpaqueTypeGetName(void) { return wrap(OpaqueType::name); }
736
738 return wrap(
739 llvm::cast<OpaqueType>(unwrap(type)).getDialectNamespace().strref());
740}
741
743 return wrap(llvm::cast<OpaqueType>(unwrap(type)).getTypeData());
744}
bool mlirTypeIsAF16(MlirType type)
Checks whether the given type is an f16 type.
MlirTypeID mlirFloat6E2M3FNTypeGetTypeID()
Returns the typeID of a Float6E2M3FN type.
bool mlirTypeIsAF64(MlirType type)
Checks whether the given type is an f64 type.
MlirStringRef mlirF16TypeGetName(void)
bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type)
Checks whether the given type is an f8E4M3FNUZ type.
bool mlirIntegerTypeIsUnsigned(MlirType type)
Checks whether the given integer type is unsigned.
MlirType mlirF32TypeGet(MlirContext ctx)
Creates an f32 type in the given context.
MlirTypeID mlirIntegerTypeGetTypeID()
Returns the typeID of an Integer type.
MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides, int64_t *offset)
Returns the strides of the MemRef if the layout map is in strided form.
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType)
Same as "mlirVectorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MlirTypeID mlirBFloat16TypeGetTypeID()
Returns the typeID of a BFloat16 type.
MlirStringRef mlirVectorTypeGetName(void)
MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth)
Creates a signless integer type of the given bitwidth in the context.
intptr_t mlirFunctionTypeGetNumResults(MlirType type)
Returns the number of result types.
unsigned mlirIntegerTypeGetWidth(MlirType type)
Returns the bitwidth of an integer type.
bool mlirTypeIsAUnrankedTensor(MlirType type)
Checks whether the given type is an unranked tensor type.
MlirStringRef mlirUnrankedTensorTypeGetName(void)
MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx)
Creates an f8E8M0FNU type in the given context.
int64_t mlirShapedTypeGetDynamicStrideOrOffset()
Returns the value indicating a dynamic stride or offset in a shaped type.
MlirType mlirF64TypeGet(MlirContext ctx)
Creates an f64 type in the given context.
MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type)
Returns the memory space of the given Unranked MemRef type.
bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type)
Checks whether the given type is an f8E4M3B11FNUZ type.
bool mlirTypeIsAFloat8E5M2(MlirType type)
Checks whether the given type is an f8E5M2 type.
MlirTypeID mlirFloatTF32TypeGetTypeID()
Returns the typeID of a TF32 type.
MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth)
Creates a signed integer type of the given bitwidth in the context.
MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx)
Creates an f6E2M3FN type in the given context.
MlirTypeID mlirComplexTypeGetTypeID()
Returns the typeID of a Complex type.
int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim)
Returns the dim-th dimension of the given ranked shaped type.
MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace)
Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType on illegal arguments,...
MlirType mlirUnrankedTensorTypeGet(MlirType elementType)
Creates an unranked tensor type with the given element type in the same context as the element type.
MlirStringRef mlirOpaqueTypeGetData(MlirType type)
Returns the raw data as a string reference.
bool mlirTypeIsAF32(MlirType type)
Checks whether the given type is an f32 type.
MlirTypeID mlirFloat8E3M4TypeGetTypeID()
Returns the typeID of a Float8E3M4 type.
MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID()
Returns the typeID of a Float8E4M3B11FNUZ type.
MlirTypeID mlirFunctionTypeGetTypeID()
Returns the typeID of a Function type.
bool mlirTypeIsAFunction(MlirType type)
Checks whether the given type is a function type.
MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type)
Returns the affine map of the given MemRef type.
MlirTypeID mlirNoneTypeGetTypeID()
Returns the typeID of a None type.
MlirTypeID mlirUnrankedTensorTypeGetTypeID()
Returns the typeID of an UnrankedTensor type.
MlirTypeID mlirMemRefTypeGetTypeID()
Returns the typeID of a MemRef type.
MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, MlirAttribute memorySpace)
Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MlirStringRef mlirFloat8E8M0FNUTypeGetName(void)
bool mlirTypeIsAMemRef(MlirType type)
Checks whether the given type is a MemRef type.
MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos)
Returns the pos-th result type.
MlirType mlirF16TypeGet(MlirContext ctx)
Creates an f16 type in the given context.
MlirStringRef mlirFloat8E3M4TypeGetName(void)
bool mlirTypeIsAComplex(MlirType type)
Checks whether the given type is a Complex type.
MlirStringRef mlirUnrankedMemRefTypeGetName(void)
MlirType mlirNoneTypeGet(MlirContext ctx)
Creates a None type in the given context.
MlirType mlirFloat8E3M4TypeGet(MlirContext ctx)
Creates an f8E3M4 type in the given context.
bool mlirTypeIsATF32(MlirType type)
Checks whether the given type is a TF32 type.
bool mlirShapedTypeHasRank(MlirType type)
Checks whether the given shaped type is ranked.
MlirStringRef mlirBF16TypeGetName(void)
bool mlirTypeIsAShaped(MlirType type)
Checks whether the given type is a Shaped type.
MlirTypeID mlirIndexTypeGetTypeID()
Returns the typeID of an Index type.
MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType)
Same as "mlirUnrankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MlirStringRef mlirFloat8E5M2FNUZTypeGetName(void)
bool mlirIntegerTypeIsSignless(MlirType type)
Checks whether the given integer type is signless.
MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Creates a tensor type of a fixed rank with the given shape, element type, and optional encoding in th...
MlirTypeID mlirFloat6E3M2FNTypeGetTypeID()
Returns the typeID of a Float6E3M2FN type.
bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim)
Checks whether the dim-th dimension of the given shaped type is dynamic.
MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3B11FNUZ type in the given context.
bool mlirVectorTypeIsScalable(MlirType type)
Checks whether the given vector type is scalable, i.e., has at least one scalable dimension.
MlirStringRef mlirFloat6E2M3FNTypeGetName(void)
MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, MlirType const *inputs, intptr_t numResults, MlirType const *results)
Creates a function type, mapping a list of input types to result types.
bool mlirTypeIsAUnrankedMemRef(MlirType type)
Checks whether the given type is an UnrankedMemRef type.
MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID()
Returns the typeID of a Float8E4M3FNUZ type.
MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth)
Creates an unsigned integer type of the given bitwidth in the context.
MlirType mlirShapedTypeGetElementType(MlirType type)
Returns the element type of the shaped type.
int64_t mlirShapedTypeGetDynamicSize()
Returns the value indicating a dynamic size in a shaped type.
MlirTypeID mlirFloat8E4M3FNTypeGetTypeID()
Returns the typeID of a Float8E4M3FN type.
bool mlirTypeIsAVector(MlirType type)
Checks whether the given type is a Vector type.
MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3FNUZ type in the given context.
MlirAttribute mlirMemRefTypeGetLayout(MlirType type)
Returns the layout of the given MemRef type.
MlirTypeID mlirFloat64TypeGetTypeID()
Returns the typeID of a Float64 type.
MlirStringRef mlirIndexTypeGetName(void)
MlirStringRef mlirFloat4E2M1FNTypeGetName(void)
MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx)
Creates an f8E5M2FNUZ type in the given context.
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, const bool *scalable, MlirType elementType)
Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType on illegal arguments,...
MlirStringRef mlirF32TypeGetName(void)
MlirStringRef mlirFloat8E4M3TypeGetName(void)
MlirStringRef mlirFloat8E5M2TypeGetName(void)
bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val)
Checks whether the given value is used as a placeholder for dynamic strides and offsets in shaped typ...
MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace)
Creates a MemRef type with the given rank and shape, a potentially empty list of affine layout maps,...
bool mlirTypeIsABF16(MlirType type)
Checks whether the given type is a bf16 type.
MlirStringRef mlirF64TypeGetName(void)
intptr_t mlirTupleTypeGetNumTypes(MlirType type)
Returns the number of types contained in a tuple.
MlirType mlirFloat8E4M3TypeGet(MlirContext ctx)
Creates an f8E4M3 type in the given context.
bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim)
Checks whether the "dim"-th dimension of the given vector is scalable.
MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type)
Returns the namespace of the dialect with which the given opaque type is associated.
bool mlirTypeIsAFloat4E2M1FN(MlirType type)
Checks whether the given type is an f4E2M1FN type.
bool mlirTypeIsAFloat(MlirType type)
Checks whether the given type is a floating-point type.
MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos)
Returns the pos-th input type.
MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute memorySpace)
Same as "mlirMemRefTypeContiguousGet" but returns a nullptr wrapping MlirType on illegal arguments,...
bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim)
Checks whether the dim-th dimension of the given shaped type is static.
bool mlirTypeIsAInteger(MlirType type)
Checks whether the given type is an integer type.
MlirStringRef mlirTupleTypeGetName(void)
MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx)
Creates an f6E3M2FN type in the given context.
MlirTypeID mlirFloat8E4M3TypeGetTypeID()
Returns the typeID of a Float8E4M3 type.
MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type)
Gets the 'encoding' attribute from the ranked tensor type, returning a null attribute if none.
MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID()
Returns the typeID of a Float8E5M2FNUZ type.
MlirStringRef mlirFunctionTypeGetName(void)
intptr_t mlirFunctionTypeGetNumInputs(MlirType type)
Returns the number of input types.
bool mlirTypeIsATuple(MlirType type)
Checks whether the given type is a tuple type.
MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, MlirType const *elements)
Creates a tuple type that consists of the given list of elemental types.
MlirTypeID mlirTupleTypeGetTypeID()
Returns the typeID of a Tuple type.
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape, const bool *scalable, MlirType elementType)
Creates a scalable vector type with the shape identified by its rank and dimensions.
bool mlirTypeIsAFloat8E3M4(MlirType type)
Checks whether the given type is an f8E3M4 type.
MlirTypeID mlirFloat16TypeGetTypeID()
Returns the typeID of a Float16 type.
MlirStringRef mlirFloat8E4M3FNUZTypeGetName(void)
MlirType mlirComplexTypeGet(MlirType elementType)
Creates a complex type with the given element type in the same context as the element type.
MlirTypeID mlirFloat8E5M2TypeGetTypeID()
Returns the typeID of a Float8E5M2 type.
MlirTypeID mlirFloat32TypeGetTypeID()
Returns the typeID of a Float32 type.
bool mlirShapedTypeIsStaticSize(int64_t size)
Checks whether the given shaped type dimension value is statically sized.
bool mlirTypeIsAFloat8E4M3(MlirType type)
Checks whether the given type is an f8E4M3 type.
int64_t mlirShapedTypeGetRank(MlirType type)
Returns the rank of the given ranked shaped type.
bool mlirTypeIsAFloat8E4M3FN(MlirType type)
Checks whether the given type is an f8E4M3FN type.
bool mlirTypeIsAOpaque(MlirType type)
Checks whether the given type is an opaque type.
MlirStringRef mlirFloat6E3M2FNTypeGetName(void)
bool mlirTypeIsAFloat8E8M0FNU(MlirType type)
Checks whether the given type is an f8E8M0FNU type.
MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos)
Returns the pos-th type in the tuple type.
bool mlirTypeIsATensor(MlirType type)
Checks whether the given type is a Tensor type.
MlirStringRef mlirMemRefTypeGetName(void)
bool mlirIntegerTypeIsSigned(MlirType type)
Checks whether the given integer type is signed.
bool mlirShapedTypeHasStaticShape(MlirType type)
Checks whether the given shaped type has a static shape.
MlirTypeID mlirRankedTensorTypeGetTypeID()
Returns the typeID of a RankedTensor type.
MlirType mlirBF16TypeGet(MlirContext ctx)
Creates a bf16 type in the given context.
MlirType mlirComplexTypeGetElementType(MlirType type)
Returns the element type of the given complex type.
bool mlirTypeIsAFloat6E3M2FN(MlirType type)
Checks whether the given type is an f6E3M2FN type.
bool mlirTypeIsAIndex(MlirType type)
Checks whether the given type is an index type.
bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type)
Checks whether the given type is an f8E5M2FNUZ type.
MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, MlirStringRef typeData)
Creates an opaque type in the given context associated with the dialect identified by its namespace.
MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type)
Returns the memory space of the given MemRef type.
bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val)
Checks whether the given dimension value of a stride or an offset is statically sized.
MlirStringRef mlirRankedTensorTypeGetName(void)
MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID()
Returns the typeID of a Float8E8M0FNU type.
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType)
Creates a vector type of the shape identified by its rank and dimensions, with the given element type...
MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute memorySpace)
Creates a MemRef type with the given rank, shape, memory space and element type in the same context a...
MlirStringRef mlirTF32TypeGetName(void)
MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx)
Creates an f4E2M1FN type in the given context.
MlirStringRef mlirComplexTypeGetName(void)
MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx)
Creates an f8E4M3FN type in the given context.
MlirStringRef mlirFloat8E4M3B11FNUZTypeGetName(void)
MlirStringRef mlirNoneTypeGetName(void)
MlirStringRef mlirFloat8E4M3FNTypeGetName(void)
MlirType mlirIndexTypeGet(MlirContext ctx)
Creates an index type in the given context.
MlirTypeID mlirVectorTypeGetTypeID()
Returns the typeID of a Vector type.
bool mlirTypeIsARankedTensor(MlirType type)
Checks whether the given type is a ranked tensor type.
MlirTypeID mlirOpaqueTypeGetTypeID()
Returns the typeID of an Opaque type.
MlirTypeID mlirUnrankedMemRefTypeGetTypeID()
Returns the typeID of an UnrankedMemRef type.
MlirStringRef mlirIntegerTypeGetName(void)
bool mlirShapedTypeIsDynamicSize(int64_t size)
Checks whether the given value is used as a placeholder for dynamic sizes in shaped types.
bool mlirTypeIsANone(MlirType type)
Checks whether the given type is a None type.
MlirStringRef mlirOpaqueTypeGetName(void)
MlirType mlirTF32TypeGet(MlirContext ctx)
Creates a TF32 type in the given context.
bool mlirTypeIsAFloat6E2M3FN(MlirType type)
Checks whether the given type is an f6E2M3FN type.
MlirType mlirFloat8E5M2TypeGet(MlirContext ctx)
Creates an f8E5M2 type in the given context.
MlirTypeID mlirFloat4E2M1FNTypeGetTypeID()
Returns the typeID of a Float4E2M1FN type.
unsigned mlirFloatTypeGetWidth(MlirType type)
Returns the bitwidth of a floating-point type.
MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace)
Creates an Unranked MemRef type with the given element type and in the given memory space.
static Type getElementType(Type type)
Determine the element type of type.
static llvm::ArrayRef< CppTy > unwrapList(size_t size, CTy *first, llvm::SmallVectorImpl< CppTy > &storage)
Definition Wrap.h:40
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition Diagnostics.h:24
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition Diagnostics.h:19
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:140
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:134
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
A logical result value, essentially a boolean with named states.
Definition Support.h:118
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:75