MLIR 22.0.0git
SPIRVDialect.cpp
Go to the documentation of this file.
1//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPIR-V dialect in MLIR.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "SPIRVParsingUtils.h"
16
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/MLIRContext.h"
26#include "mlir/Parser/Parser.h"
28#include "llvm/ADT/Sequence.h"
29#include "llvm/ADT/StringExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31
32using namespace mlir;
33using namespace mlir::spirv;
34
35#include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
36
37//===----------------------------------------------------------------------===//
38// InlinerInterface
39//===----------------------------------------------------------------------===//
40
41/// Returns true if the given region contains spirv.Return or spirv.ReturnValue
42/// ops.
43static inline bool containsReturn(Region &region) {
44 return llvm::any_of(region, [](Block &block) {
45 Operation *terminator = block.getTerminator();
46 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
47 });
48}
49
50namespace {
51/// This class defines the interface for inlining within the SPIR-V dialect.
52struct SPIRVInlinerInterface : public DialectInlinerInterface {
54
55 /// All call operations within SPIRV can be inlined.
56 bool isLegalToInline(Operation *call, Operation *callable,
57 bool wouldBeCloned) const final {
58 return true;
59 }
60
61 /// Returns true if the given region 'src' can be inlined into the region
62 /// 'dest' that is attached to an operation registered to the current dialect.
63 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
64 IRMapping &) const final {
65 // Return true here when inlining into spirv.func, spirv.mlir.selection, and
66 // spirv.mlir.loop operations.
67 auto *op = dest->getParentOp();
68 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
69 }
70
71 /// Returns true if the given operation 'op', that is registered to this
72 /// dialect, can be inlined into the region 'dest' that is attached to an
73 /// operation registered to the current dialect.
74 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
75 IRMapping &) const final {
76 // TODO: Enable inlining structured control flows with return.
77 if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
78 containsReturn(op->getRegion(0)))
79 return false;
80 // TODO: we need to filter OpKill here to avoid inlining it to
81 // a loop continue construct:
82 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
83 // For now, we just disallow inlining OpKill anywhere in the code,
84 // but this restriction should be relaxed, as pointed above.
85 if (isa<spirv::KillOp>(op))
86 return false;
87
88 return true;
89 }
90
91 /// Handle the given inlined terminator by replacing it with a new operation
92 /// as necessary.
93 void handleTerminator(Operation *op, Block *newDest) const final {
94 if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
95 auto builder = OpBuilder(op);
96 spirv::BranchOp::create(builder, op->getLoc(), newDest);
97 op->erase();
98 } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
99 auto builder = OpBuilder(op);
100 spirv::BranchOp::create(builder, retValOp->getLoc(), newDest,
101 retValOp->getOperands());
102 op->erase();
103 }
104 }
105
106 /// Handle the given inlined terminator by replacing it with a new operation
107 /// as necessary.
108 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
109 // Only spirv.ReturnValue needs to be handled here.
110 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
111 if (!retValOp)
112 return;
113
114 // Replace the values directly with the return operands.
115 assert(valuesToRepl.size() == 1 &&
116 "spirv.ReturnValue expected to only handle one result");
117 valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
118 }
119};
120} // namespace
121
122//===----------------------------------------------------------------------===//
123// SPIR-V Dialect
124//===----------------------------------------------------------------------===//
125
/// Registers the dialect's attributes, types, operations (from the generated
/// op list) and the inliner interface, and configures dialect-wide behavior.
void SPIRVDialect::initialize() {
  registerAttributes();
  registerTypes();

  // Add SPIR-V ops.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
      >();

  // Inlining hooks for SPIR-V calls and structured control flow.
  addInterfaces<SPIRVInlinerInterface>();

  // Allow unknown operations because SPIR-V is extensible.
  allowUnknownOperations();
  // TargetEnvAttr promises gpu::TargetAttrInterface; the implementation is
  // expected to be attached elsewhere (presumably by a separate extension —
  // NOTE(review): confirm where it is registered).
  declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
}
142
143std::string SPIRVDialect::getAttributeName(Decoration decoration) {
144 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
145}
146
147//===----------------------------------------------------------------------===//
148// Type Parsing
149//===----------------------------------------------------------------------===//
150
// Forward declarations. The primary template parses an enum keyword; the
// explicit specializations handle `Type` and `unsigned` parameters. They are
// declared here because they are used (e.g. by parseOptionalArrayStride)
// before their definitions further down.
template <typename ValTy>
static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
                                           DialectAsmParser &parser);
template <>
std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
                                         DialectAsmParser &parser);

template <>
std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
                                                 DialectAsmParser &parser);
162
163static Type parseAndVerifyType(SPIRVDialect const &dialect,
164 DialectAsmParser &parser) {
165 Type type;
166 SMLoc typeLoc = parser.getCurrentLocation();
167 if (parser.parseType(type))
168 return Type();
169
170 // Allow SPIR-V dialect types
171 if (&type.getDialect() == &dialect)
172 return type;
173
174 // Check other allowed types
175 if (auto t = llvm::dyn_cast<FloatType>(type)) {
176 // TODO: All float types are allowed for now, but this should be fixed.
177 } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
178 if (!ScalarType::isValid(t)) {
179 parser.emitError(typeLoc,
180 "only 1/8/16/32/64-bit integer type allowed but found ")
181 << type;
182 return Type();
183 }
184 } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
185 if (t.getRank() != 1) {
186 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
187 return Type();
188 }
189 if (t.getNumElements() > 4) {
190 parser.emitError(
191 typeLoc, "vector length has to be less than or equal to 4 but found ")
192 << t.getNumElements();
193 return Type();
194 }
195 } else if (auto t = dyn_cast<TensorArmType>(type)) {
196 if (!isa<ScalarType>(t.getElementType())) {
197 parser.emitError(
198 typeLoc, "only scalar element type allowed in tensor type but found ")
199 << t.getElementType();
200 return Type();
201 }
202 } else {
203 parser.emitError(typeLoc, "cannot use ")
204 << type << " to compose SPIR-V types";
205 return Type();
206 }
207
208 return type;
209}
210
211static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
212 DialectAsmParser &parser) {
213 Type type;
214 SMLoc typeLoc = parser.getCurrentLocation();
215 if (parser.parseType(type))
216 return Type();
217
218 if (auto t = llvm::dyn_cast<VectorType>(type)) {
219 if (t.getRank() != 1) {
220 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
221 return Type();
222 }
223 if (t.getNumElements() > 4 || t.getNumElements() < 2) {
224 parser.emitError(typeLoc,
225 "matrix columns size has to be less than or equal "
226 "to 4 and greater than or equal 2, but found ")
227 << t.getNumElements();
228 return Type();
229 }
230
231 if (!llvm::isa<FloatType>(t.getElementType())) {
232 parser.emitError(typeLoc, "matrix columns' elements must be of "
233 "Float type, got ")
234 << t.getElementType();
235 return Type();
236 }
237 } else {
238 parser.emitError(typeLoc, "matrix must be composed using vector "
239 "type, got ")
240 << type;
241 return Type();
242 }
243
244 return type;
245}
246
247static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
248 DialectAsmParser &parser) {
249 Type type;
250 SMLoc typeLoc = parser.getCurrentLocation();
251 if (parser.parseType(type))
252 return Type();
253
254 auto imageType = dyn_cast<ImageType>(type);
255 if (!imageType) {
256 parser.emitError(typeLoc,
257 "sampled image must be composed using image type, got ")
258 << type;
259 return Type();
260 }
261
262 if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, imageType.getDim())) {
263 parser.emitError(
264 typeLoc, "sampled image Dim must not be SubpassData or Buffer, got ")
265 << stringifyDim(imageType.getDim());
266 return Type();
267 }
268
269 return type;
270}
271
272/// Parses an optional `, stride = N` assembly segment. If no parsing failure
273/// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
274/// missing.
275static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
276 DialectAsmParser &parser,
277 unsigned &stride) {
278 if (failed(parser.parseOptionalComma())) {
279 stride = 0;
280 return success();
281 }
282
283 if (parser.parseKeyword("stride") || parser.parseEqual())
284 return failure();
285
286 SMLoc strideLoc = parser.getCurrentLocation();
287 std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
288 if (!optStride)
289 return failure();
290
291 if (!(stride = *optStride)) {
292 parser.emitError(strideLoc, "ArrayStride must be greater than zero");
293 return failure();
294 }
295 return success();
296}
297
298// element-type ::= integer-type
299// | floating-point-type
300// | vector-type
301// | spirv-type
302//
303// array-type ::= `!spirv.array` `<` integer-literal `x` element-type
304// (`,` `stride` `=` integer-literal)? `>`
305static Type parseArrayType(SPIRVDialect const &dialect,
306 DialectAsmParser &parser) {
307 if (parser.parseLess())
308 return Type();
309
310 SmallVector<int64_t, 1> countDims;
311 SMLoc countLoc = parser.getCurrentLocation();
312 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
313 return Type();
314 if (countDims.size() != 1) {
315 parser.emitError(countLoc,
316 "expected single integer for array element count");
317 return Type();
318 }
319
320 // According to the SPIR-V spec:
321 // "Length is the number of elements in the array. It must be at least 1."
322 int64_t count = countDims[0];
323 if (count == 0) {
324 parser.emitError(countLoc, "expected array length greater than 0");
325 return Type();
326 }
327
328 Type elementType = parseAndVerifyType(dialect, parser);
329 if (!elementType)
330 return Type();
331
332 unsigned stride = 0;
333 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
334 return Type();
335
336 if (parser.parseGreater())
337 return Type();
338 return ArrayType::get(elementType, count, stride);
339}
340
341// cooperative-matrix-type ::=
342// `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
343// scope `,` use `>`
344static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
345 DialectAsmParser &parser) {
346 if (parser.parseLess())
347 return {};
348
350 SMLoc countLoc = parser.getCurrentLocation();
351 if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
352 return {};
353
354 if (dims.size() != 2) {
355 parser.emitError(countLoc, "expected row and column count");
356 return {};
357 }
358
359 auto elementTy = parseAndVerifyType(dialect, parser);
360 if (!elementTy)
361 return {};
362
363 Scope scope;
364 if (parser.parseComma() ||
365 spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
366 return {};
367
368 CooperativeMatrixUseKHR use;
369 if (parser.parseComma() ||
370 spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
371 return {};
372
373 if (parser.parseGreater())
374 return {};
375
376 return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
377}
378
379// tensor-arm-type ::=
380// `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
381static Type parseTensorArmType(SPIRVDialect const &dialect,
382 DialectAsmParser &parser) {
383 if (parser.parseLess())
384 return {};
385
386 bool unranked = false;
388 SMLoc countLoc = parser.getCurrentLocation();
389
390 if (parser.parseOptionalStar().succeeded()) {
391 unranked = true;
392 if (parser.parseXInDimensionList())
393 return {};
394 } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
395 return {};
396 }
397
398 if (!unranked && dims.empty()) {
399 parser.emitError(countLoc, "arm.tensors do not support rank zero");
400 return {};
401 }
402
403 if (llvm::is_contained(dims, 0)) {
404 parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
405 return {};
406 }
407
408 if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
409 llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
410 parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
411 "fully dynamic or completed shaped");
412 return {};
413 }
414
415 auto elementTy = parseAndVerifyType(dialect, parser);
416 if (!elementTy)
417 return {};
418
419 if (parser.parseGreater())
420 return {};
421
422 return TensorArmType::get(dims, elementTy);
423}
424
425// TODO: Reorder methods to be utilities first and parse*Type
426// methods in alphabetical order
427//
428// storage-class ::= `UniformConstant`
429// | `Uniform`
430// | `Workgroup`
431// | <and other storage classes...>
432//
433// pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
434static Type parsePointerType(SPIRVDialect const &dialect,
435 DialectAsmParser &parser) {
436 if (parser.parseLess())
437 return Type();
438
439 auto pointeeType = parseAndVerifyType(dialect, parser);
440 if (!pointeeType)
441 return Type();
442
443 StringRef storageClassSpec;
444 SMLoc storageClassLoc = parser.getCurrentLocation();
445 if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
446 return Type();
447
448 auto storageClass = symbolizeStorageClass(storageClassSpec);
449 if (!storageClass) {
450 parser.emitError(storageClassLoc, "unknown storage class: ")
451 << storageClassSpec;
452 return Type();
453 }
454 if (parser.parseGreater())
455 return Type();
456 return PointerType::get(pointeeType, *storageClass);
457}
458
459// runtime-array-type ::= `!spirv.rtarray` `<` element-type
460// (`,` `stride` `=` integer-literal)? `>`
461static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
462 DialectAsmParser &parser) {
463 if (parser.parseLess())
464 return Type();
465
466 Type elementType = parseAndVerifyType(dialect, parser);
467 if (!elementType)
468 return Type();
469
470 unsigned stride = 0;
471 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
472 return Type();
473
474 if (parser.parseGreater())
475 return Type();
476 return RuntimeArrayType::get(elementType, stride);
477}
478
479// matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
480static Type parseMatrixType(SPIRVDialect const &dialect,
481 DialectAsmParser &parser) {
482 if (parser.parseLess())
483 return Type();
484
485 SmallVector<int64_t, 1> countDims;
486 SMLoc countLoc = parser.getCurrentLocation();
487 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
488 return Type();
489 if (countDims.size() != 1) {
490 parser.emitError(countLoc, "expected single unsigned "
491 "integer for number of columns");
492 return Type();
493 }
494
495 int64_t columnCount = countDims[0];
496 // According to the specification, Matrices can have 2, 3, or 4 columns
497 if (columnCount < 2 || columnCount > 4) {
498 parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
499 "columns");
500 return Type();
501 }
502
503 Type columnType = parseAndVerifyMatrixType(dialect, parser);
504 if (!columnType)
505 return Type();
506
507 if (parser.parseGreater())
508 return Type();
509
510 return MatrixType::get(columnType, columnCount);
511}
512
513// Specialize this function to parse each of the parameters that define an
514// ImageType. By default it assumes this is an enum type.
515template <typename ValTy>
516static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
517 DialectAsmParser &parser) {
518 StringRef enumSpec;
519 SMLoc enumLoc = parser.getCurrentLocation();
520 if (parser.parseKeyword(&enumSpec)) {
521 return std::nullopt;
522 }
523
524 auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
525 if (!val)
526 parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
527 return val;
528}
529
530template <>
531std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
532 DialectAsmParser &parser) {
533 // TODO: Further verify that the element type can be sampled
534 auto ty = parseAndVerifyType(dialect, parser);
535 if (!ty)
536 return std::nullopt;
537 return ty;
538}
539
540template <typename IntTy>
541static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
542 DialectAsmParser &parser) {
543 IntTy offsetVal = std::numeric_limits<IntTy>::max();
544 if (parser.parseInteger(offsetVal))
545 return std::nullopt;
546 return offsetVal;
547}
548
549template <>
550std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
551 DialectAsmParser &parser) {
552 return parseAndVerifyInteger<unsigned>(dialect, parser);
553}
554
555namespace {
556// Functor object to parse a comma separated list of specs. The function
557// parseAndVerify does the actual parsing and verification of individual
558// elements. This is a functor since parsing the last element of the list
559// (termination condition) needs partial specialization.
560template <typename ParseType, typename... Args>
561struct ParseCommaSeparatedList {
562 std::optional<std::tuple<ParseType, Args...>>
563 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
564 auto parseVal = parseAndVerify<ParseType>(dialect, parser);
565 if (!parseVal)
566 return std::nullopt;
567
568 auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
569 if (numArgs != 0 && failed(parser.parseComma()))
570 return std::nullopt;
571 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
572 if (!remainingValues)
573 return std::nullopt;
574 return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
575 remainingValues.value());
576 }
577};
578
579// Partial specialization of the function to parse a comma separated list of
580// specs to parse the last element of the list.
581template <typename ParseType>
582struct ParseCommaSeparatedList<ParseType> {
583 std::optional<std::tuple<ParseType>>
584 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
585 if (auto value = parseAndVerify<ParseType>(dialect, parser))
586 return std::tuple<ParseType>(*value);
587 return std::nullopt;
588 }
589};
590} // namespace
591
592// dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
593//
594// depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
595//
596// arrayed-info ::= `NonArrayed` | `Arrayed`
597//
598// sampling-info ::= `SingleSampled` | `MultiSampled`
599//
600// sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
601//
602// format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
603//
604// image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
605// arrayed-info `,` sampling-info `,`
606// sampler-use-info `,` format `>`
607static Type parseImageType(SPIRVDialect const &dialect,
608 DialectAsmParser &parser) {
609 if (parser.parseLess())
610 return Type();
611
612 auto value =
613 ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
614 ImageSamplingInfo, ImageSamplerUseInfo,
615 ImageFormat>{}(dialect, parser);
616 if (!value)
617 return Type();
618
619 if (parser.parseGreater())
620 return Type();
621 return ImageType::get(*value);
622}
623
624// sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
625static Type parseSampledImageType(SPIRVDialect const &dialect,
626 DialectAsmParser &parser) {
627 if (parser.parseLess())
628 return Type();
629
630 Type parsedType = parseAndVerifySampledImageType(dialect, parser);
631 if (!parsedType)
632 return Type();
633
634 if (parser.parseGreater())
635 return Type();
636 return SampledImageType::get(parsedType);
637}
638
639// Parse decorations associated with a member.
641 SPIRVDialect const &dialect, DialectAsmParser &parser,
642 ArrayRef<Type> memberTypes,
645
646 // Check if the first element is offset.
647 SMLoc offsetLoc = parser.getCurrentLocation();
648 StructType::OffsetInfo offset = 0;
649 OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
650 if (offsetParseResult.has_value()) {
651 if (failed(*offsetParseResult))
652 return failure();
653
654 if (offsetInfo.size() != memberTypes.size() - 1) {
655 return parser.emitError(offsetLoc,
656 "offset specification must be given for "
657 "all members");
658 }
659 offsetInfo.push_back(offset);
660 }
661
662 // Check for no spirv::Decorations.
663 if (succeeded(parser.parseOptionalRSquare()))
664 return success();
665
666 // If there was an offset, make sure to parse the comma.
667 if (offsetParseResult.has_value() && parser.parseComma())
668 return failure();
669
670 // Check for spirv::Decorations.
671 auto parseDecorations = [&]() {
672 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
673 if (!memberDecoration)
674 return failure();
675
676 // Parse member decoration value if it exists.
677 if (succeeded(parser.parseOptionalEqual())) {
678 Attribute memberDecorationValue;
679 if (failed(parser.parseAttribute(memberDecorationValue)))
680 return failure();
681
682 memberDecorationInfo.emplace_back(
683 static_cast<uint32_t>(memberTypes.size() - 1),
684 memberDecoration.value(), memberDecorationValue);
685 } else {
686 memberDecorationInfo.emplace_back(
687 static_cast<uint32_t>(memberTypes.size() - 1),
688 memberDecoration.value(), UnitAttr::get(dialect.getContext()));
689 }
690 return success();
691 };
692 if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
693 failed(parser.parseRSquare()))
694 return failure();
695
696 return success();
697}
698
699// struct-member-decoration ::= integer-literal? spirv-decoration*
700// struct-type ::=
701// `!spirv.struct<` (id `,`)?
702// `(`
703// (spirv-type (`[` struct-member-decoration `]`)?)*
704// `)`
705// (`,` struct-decoration)?
706// `>`
707static Type parseStructType(SPIRVDialect const &dialect,
708 DialectAsmParser &parser) {
709 // TODO: This function is quite lengthy. Break it down into smaller chunks.
710
711 if (parser.parseLess())
712 return Type();
713
714 StringRef identifier;
715 FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
716
717 // Check if this is an identified struct type.
718 if (succeeded(parser.parseOptionalKeyword(&identifier))) {
719 // Check if this is a possible recursive reference.
720 auto structType =
721 StructType::getIdentified(dialect.getContext(), identifier);
722 cyclicParse = parser.tryStartCyclicParse(structType);
723 if (succeeded(parser.parseOptionalGreater())) {
724 if (succeeded(cyclicParse)) {
725 parser.emitError(
726 parser.getNameLoc(),
727 "recursive struct reference not nested in struct definition");
728
729 return Type();
730 }
731
732 return structType;
733 }
734
735 if (failed(parser.parseComma()))
736 return Type();
737
738 if (failed(cyclicParse)) {
739 parser.emitError(parser.getNameLoc(),
740 "identifier already used for an enclosing struct");
741 return Type();
742 }
743 }
744
745 if (failed(parser.parseLParen()))
746 return Type();
747
748 if (succeeded(parser.parseOptionalRParen()) &&
749 succeeded(parser.parseOptionalGreater())) {
750 return StructType::getEmpty(dialect.getContext(), identifier);
751 }
752
753 StructType idStructTy;
754
755 if (!identifier.empty())
756 idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
757
758 SmallVector<Type, 4> memberTypes;
761
762 do {
763 Type memberType;
764 if (parser.parseType(memberType))
765 return Type();
766 memberTypes.push_back(memberType);
767
768 if (succeeded(parser.parseOptionalLSquare()))
769 if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
770 memberDecorationInfo))
771 return Type();
772 } while (succeeded(parser.parseOptionalComma()));
773
774 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
775 parser.emitError(parser.getNameLoc(),
776 "offset specification must be given for all members");
777 return Type();
778 }
779
780 if (failed(parser.parseRParen()))
781 return Type();
782
784
785 auto parseStructDecoration = [&]() {
786 std::optional<spirv::Decoration> decoration =
787 parseAndVerify<spirv::Decoration>(dialect, parser);
788 if (!decoration)
789 return failure();
790
791 // Parse decoration value if it exists.
792 if (succeeded(parser.parseOptionalEqual())) {
793 Attribute decorationValue;
794 if (failed(parser.parseAttribute(decorationValue)))
795 return failure();
796
797 structDecorationInfo.emplace_back(decoration.value(), decorationValue);
798 } else {
799 structDecorationInfo.emplace_back(decoration.value(),
800 UnitAttr::get(dialect.getContext()));
801 }
802 return success();
803 };
804
805 while (succeeded(parser.parseOptionalComma()))
806 if (failed(parseStructDecoration()))
807 return Type();
808
809 if (failed(parser.parseGreater()))
810 return Type();
811
812 if (!identifier.empty()) {
813 if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
814 memberDecorationInfo,
815 structDecorationInfo)))
816 return Type();
817 return idStructTy;
818 }
819
820 return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
821 structDecorationInfo);
822}
823
824// spirv-type ::= array-type
825// | element-type
826// | image-type
827// | pointer-type
828// | runtime-array-type
829// | sampled-image-type
830// | struct-type
831Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
832 StringRef keyword;
833 if (parser.parseKeyword(&keyword))
834 return Type();
835
836 if (keyword == "array")
837 return parseArrayType(*this, parser);
838 if (keyword == "coopmatrix")
839 return parseCooperativeMatrixType(*this, parser);
840 if (keyword == "image")
841 return parseImageType(*this, parser);
842 if (keyword == "ptr")
843 return parsePointerType(*this, parser);
844 if (keyword == "rtarray")
845 return parseRuntimeArrayType(*this, parser);
846 if (keyword == "sampled_image")
847 return parseSampledImageType(*this, parser);
848 if (keyword == "struct")
849 return parseStructType(*this, parser);
850 if (keyword == "matrix")
851 return parseMatrixType(*this, parser);
852 if (keyword == "arm.tensor")
853 return parseTensorArmType(*this, parser);
854 parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
855 return Type();
856}
857
858//===----------------------------------------------------------------------===//
859// Type Printing
860//===----------------------------------------------------------------------===//
861
862static void print(ArrayType type, DialectAsmPrinter &os) {
863 os << "array<" << type.getNumElements() << " x " << type.getElementType();
864 if (unsigned stride = type.getArrayStride())
865 os << ", stride=" << stride;
866 os << ">";
867}
868
870 os << "rtarray<" << type.getElementType();
871 if (unsigned stride = type.getArrayStride())
872 os << ", stride=" << stride;
873 os << ">";
874}
875
876static void print(PointerType type, DialectAsmPrinter &os) {
877 os << "ptr<" << type.getPointeeType() << ", "
878 << stringifyStorageClass(type.getStorageClass()) << ">";
879}
880
881static void print(ImageType type, DialectAsmPrinter &os) {
882 os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
883 << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
884 << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
885 << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
886 << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
887 << stringifyImageFormat(type.getImageFormat()) << ">";
888}
889
891 os << "sampled_image<" << type.getImageType() << ">";
892}
893
894static void print(StructType type, DialectAsmPrinter &os) {
895 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
896
897 os << "struct<";
898
899 if (type.isIdentified()) {
900 os << type.getIdentifier();
901
902 cyclicPrint = os.tryStartCyclicPrint(type);
903 if (failed(cyclicPrint)) {
904 os << ">";
905 return;
906 }
907
908 os << ", ";
909 }
910
911 os << "(";
912
913 auto printMember = [&](unsigned i) {
914 os << type.getElementType(i);
916 type.getMemberDecorations(i, decorations);
917 if (type.hasOffset() || !decorations.empty()) {
918 os << " [";
919 if (type.hasOffset()) {
920 os << type.getMemberOffset(i);
921 if (!decorations.empty())
922 os << ", ";
923 }
924 auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
925 os << stringifyDecoration(decoration.decoration);
926 if (decoration.hasValue()) {
927 os << "=";
928 os.printAttributeWithoutType(decoration.decorationValue);
929 }
930 };
931 llvm::interleaveComma(decorations, os, eachFn);
932 os << "]";
933 }
934 };
935 llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
936 printMember);
937 os << ")";
938
940 type.getStructDecorations(decorations);
941 if (!decorations.empty()) {
942 os << ", ";
943 auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
944 os << stringifyDecoration(decoration.decoration);
945 if (decoration.hasValue()) {
946 os << "=";
947 os.printAttributeWithoutType(decoration.decorationValue);
948 }
949 };
950 llvm::interleaveComma(decorations, os, eachFn);
951 }
952
953 os << ">";
954}
955
957 os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
958 << type.getElementType() << ", " << type.getScope() << ", "
959 << type.getUse() << ">";
960}
961
962static void print(MatrixType type, DialectAsmPrinter &os) {
963 os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
964 os << ">";
965}
966
967static void print(TensorArmType type, DialectAsmPrinter &os) {
968 os << "arm.tensor<";
969
970 llvm::interleave(
971 type.getShape(), os,
972 [&](int64_t dim) {
973 if (ShapedType::isDynamic(dim))
974 os << '?';
975 else
976 os << dim;
977 },
978 "x");
979 if (!type.hasRank()) {
980 os << "*";
981 }
982 os << "x" << type.getElementType() << ">";
983}
984
985void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
986 TypeSwitch<Type>(type)
989 [&](auto type) { print(type, os); })
990 .DefaultUnreachable("Unhandled SPIR-V type");
991}
992
993//===----------------------------------------------------------------------===//
994// Constant
995//===----------------------------------------------------------------------===//
996
997Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
998 Attribute value, Type type,
999 Location loc) {
1000 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
1001 return ub::PoisonOp::create(builder, loc, type, poison);
1002
1003 if (!spirv::ConstantOp::isBuildableWith(type))
1004 return nullptr;
1005
1006 return spirv::ConstantOp::create(builder, loc, type, value);
1007}
1008
1009//===----------------------------------------------------------------------===//
1010// Shader Interface ABI
1011//===----------------------------------------------------------------------===//
1012
1013LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
1014 NamedAttribute attribute) {
1015 StringRef symbol = attribute.getName().strref();
1016 Attribute attr = attribute.getValue();
1017
1018 if (symbol == spirv::getEntryPointABIAttrName()) {
1019 if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
1020 return op->emitError("'")
1021 << symbol << "' attribute must be an entry point ABI attribute";
1022 }
1023 } else if (symbol == spirv::getTargetEnvAttrName()) {
1024 if (!llvm::isa<spirv::TargetEnvAttr>(attr))
1025 return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
1026 } else {
1027 return op->emitError("found unsupported '")
1028 << symbol << "' attribute on operation";
1029 }
1030
1031 return success();
1032}
1033
1034/// Verifies the given SPIR-V `attribute` attached to a value of the given
1035/// `valueType` is valid.
1036static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
1037 NamedAttribute attribute) {
1038 StringRef symbol = attribute.getName().strref();
1039 Attribute attr = attribute.getValue();
1040
1041 if (symbol == spirv::getInterfaceVarABIAttrName()) {
1042 auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
1043 if (!varABIAttr)
1044 return emitError(loc, "'")
1045 << symbol << "' must be a spirv::InterfaceVarABIAttr";
1046
1047 if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
1048 return emitError(loc, "'") << symbol
1049 << "' attribute cannot specify storage class "
1050 "when attaching to a non-scalar value";
1051 return success();
1052 }
1053 if (symbol == spirv::DecorationAttr::name) {
1054 if (!isa<spirv::DecorationAttr>(attr))
1055 return emitError(loc, "'")
1056 << symbol << "' must be a spirv::DecorationAttr";
1057 return success();
1058 }
1059
1060 return emitError(loc, "found unsupported '")
1061 << symbol << "' attribute on region argument";
1062}
1063
1064LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
1065 unsigned regionIndex,
1066 unsigned argIndex,
1067 NamedAttribute attribute) {
1068 auto funcOp = dyn_cast<FunctionOpInterface>(op);
1069 if (!funcOp)
1070 return success();
1071 Type argType = funcOp.getArgumentTypes()[argIndex];
1072
1073 return verifyRegionAttribute(op->getLoc(), argType, attribute);
1074}
1075
1076LogicalResult SPIRVDialect::verifyRegionResultAttribute(
1077 Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
1078 NamedAttribute attribute) {
1079 if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op))
1080 return verifyRegionAttribute(
1081 op->getLoc(), graphOp.getResultTypes()[resultIndex], attribute);
1082 return op->emitError(
1083 "cannot attach SPIR-V attributes to region result which is "
1084 "not part of a spirv::GraphARMOp type");
1085}
return success()
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
std::optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional , stride = N assembly segment.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid.
static Type parseTensorArmType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static void print(ArrayType type, DialectAsmPrinter &os)
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
std::optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static bool containsReturn(Region &region)
Returns true if the given region contains spirv.Return or spirv.ReturnValue ops.
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalStar()=0
Parse a '*' token if present.
FailureOr< CyclicParseReset > tryStartCyclicParse(AttrOrTypeT attrOrType)
Attempts to start a cyclic parsing region for attrOrType.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseXInDimensionList()=0
Parse an 'x' token in a dimension list, handling the case where the x is juxtaposed with an element t...
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
FailureOr< CyclicPrintReset > tryStartCyclicPrint(AttrOrTypeT attrOrType)
Attempts to start a cyclic printing region for attrOrType.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:120
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:147
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType get(Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static bool isValid(FloatType)
Returns true if the given integer type is valid for the SPIR-V dialect.
SPIR-V struct type.
Definition SPIRVTypes.h:251
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
SPIR-V TensorARM Type.
Definition SPIRVTypes.h:465
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
ArrayRef< int64_t > getShape() const
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144