MLIR 23.0.0git
SPIRVDialect.cpp
Go to the documentation of this file.
//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPIR-V dialect in MLIR.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "SPIRVParsingUtils.h"
16
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/MLIRContext.h"
26#include "mlir/Parser/Parser.h"
28#include "llvm/ADT/Sequence.h"
29#include "llvm/ADT/StringExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31
32using namespace mlir;
33using namespace mlir::spirv;
34
35#include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
36
37//===----------------------------------------------------------------------===//
38// InlinerInterface
39//===----------------------------------------------------------------------===//
40
41/// Returns true if the given region contains spirv.Return or spirv.ReturnValue
42/// ops.
43static inline bool containsReturn(Region &region) {
44 return llvm::any_of(region, [](Block &block) {
45 Operation *terminator = block.getTerminator();
46 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
47 });
48}
49
50namespace {
51/// This class defines the interface for inlining within the SPIR-V dialect.
52struct SPIRVInlinerInterface : public DialectInlinerInterface {
53 using DialectInlinerInterface::DialectInlinerInterface;
54
55 /// All call operations within SPIRV can be inlined.
56 bool isLegalToInline(Operation *call, Operation *callable,
57 bool wouldBeCloned) const final {
58 return true;
59 }
60
61 /// Returns true if the given region 'src' can be inlined into the region
62 /// 'dest' that is attached to an operation registered to the current dialect.
63 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
64 IRMapping &) const final {
65 // Return true here when inlining into spirv.func, spirv.mlir.selection, and
66 // spirv.mlir.loop operations.
67 auto *op = dest->getParentOp();
68 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
69 }
70
71 /// Returns true if the given operation 'op', that is registered to this
72 /// dialect, can be inlined into the region 'dest' that is attached to an
73 /// operation registered to the current dialect.
74 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
75 IRMapping &) const final {
76 // TODO: Enable inlining structured control flows with return.
77 if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
78 containsReturn(op->getRegion(0)))
79 return false;
80 // TODO: we need to filter OpKill here to avoid inlining it to
81 // a loop continue construct:
82 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
83 // For now, we just disallow inlining OpKill anywhere in the code,
84 // but this restriction should be relaxed, as pointed above.
85 if (isa<spirv::KillOp>(op))
86 return false;
87
88 return true;
89 }
90
91 /// Handle the given inlined terminator by replacing it with a new operation
92 /// as necessary.
93 void handleTerminator(Operation *op, Block *newDest) const final {
94 if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
95 auto builder = OpBuilder(op);
96 spirv::BranchOp::create(builder, op->getLoc(), newDest);
97 op->erase();
98 } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
99 auto builder = OpBuilder(op);
100 spirv::BranchOp::create(builder, retValOp->getLoc(), newDest,
101 retValOp->getOperands());
102 op->erase();
103 }
104 }
105
106 /// Handle the given inlined terminator by replacing it with a new operation
107 /// as necessary.
108 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
109 // Only spirv.ReturnValue needs to be handled here.
110 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
111 if (!retValOp)
112 return;
113
114 // Replace the values directly with the return operands.
115 assert(valuesToRepl.size() == 1 &&
116 "spirv.ReturnValue expected to only handle one result");
117 valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
118 }
119};
120} // namespace
121
122//===----------------------------------------------------------------------===//
123// SPIR-V Dialect
124//===----------------------------------------------------------------------===//
125
/// Registers everything the SPIR-V dialect provides with its context:
/// attributes, types, ops, the inliner interface, and a promised GPU target
/// interface.
void SPIRVDialect::initialize() {
  // Attribute and type registration live in helpers defined elsewhere.
  registerAttributes();
  registerTypes();

  // Add SPIR-V ops. The op list is pulled in from SPIRVOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
      >();

  // Hook SPIR-V ops into the generic inliner.
  addInterfaces<SPIRVInlinerInterface>();

  // Allow unknown operations because SPIR-V is extensible.
  allowUnknownOperations();
  // Promise that TargetEnvAttr implements gpu::TargetAttrInterface; the
  // implementation is attached outside this file. NOTE(review): confirm
  // where it is registered.
  declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
}
142
143std::string SPIRVDialect::getAttributeName(Decoration decoration) {
144 return getDecorationString(decoration);
145}
146
147//===----------------------------------------------------------------------===//
148// Type Parsing
149//===----------------------------------------------------------------------===//
150
// Forward declarations.
//
// `parseAndVerify<ValTy>` parses one parameter of a composite SPIR-V type.
// The primary template (defined further below) treats `ValTy` as an enum
// spelled as a keyword. The `Type` and `unsigned` explicit specializations
// are declared here so that parsers appearing before their definitions
// (e.g. parseOptionalArrayStride) can call them.
template <typename ValTy>
static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
                                           DialectAsmParser &parser);
template <>
std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
                                         DialectAsmParser &parser);

template <>
std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
                                                 DialectAsmParser &parser);
162
163static Type parseAndVerifyType(SPIRVDialect const &dialect,
164 DialectAsmParser &parser) {
165 Type type;
166 SMLoc typeLoc = parser.getCurrentLocation();
167 if (parser.parseType(type))
168 return Type();
169
170 // Allow SPIR-V dialect types.
171 if (&type.getDialect() == &dialect)
172 return type;
173
174 // Check other allowed types.
175 if (auto t = dyn_cast<FloatType>(type)) {
176 // TODO: All float types are allowed for now, but this should be fixed.
177 } else if (auto t = dyn_cast<IntegerType>(type)) {
178 if (!ScalarType::isValid(t)) {
179 parser.emitError(typeLoc,
180 "only 1/8/16/32/64-bit integer type allowed but found ")
181 << type;
182 return Type();
183 }
184 } else if (auto t = dyn_cast<VectorType>(type)) {
185 if (t.getRank() != 1) {
186 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
187 return Type();
188 }
189 if (t.getNumElements() < 2) {
190 parser.emitError(typeLoc, "SPIR-V does not allow one-element vectors");
191 return Type();
192 }
193 if (t.getNumElements() > 4) {
194 parser.emitError(
195 typeLoc, "vector length has to be less than or equal to 4 but found ")
196 << t.getNumElements();
197 return Type();
198 }
199 if (!isa<ScalarType>(t.getElementType())) {
200 parser.emitError(
201 typeLoc,
202 "vector element type must be a SPIR-V scalar type but found ")
203 << t.getElementType();
204 return Type();
205 }
206 } else if (auto t = dyn_cast<TensorArmType>(type)) {
207 if (!isa<ScalarType>(t.getElementType())) {
208 parser.emitError(
209 typeLoc, "only scalar element type allowed in tensor type but found ")
210 << t.getElementType();
211 return Type();
212 }
213 } else {
214 parser.emitError(typeLoc, "cannot use ")
215 << type << " to compose SPIR-V types";
216 return Type();
217 }
218
219 return type;
220}
221
222static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
223 DialectAsmParser &parser) {
224 Type type;
225 SMLoc typeLoc = parser.getCurrentLocation();
226 if (parser.parseType(type))
227 return Type();
228
229 if (auto t = dyn_cast<VectorType>(type)) {
230 if (t.getRank() != 1) {
231 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
232 return Type();
233 }
234 if (t.getNumElements() > 4 || t.getNumElements() < 2) {
235 parser.emitError(typeLoc,
236 "matrix columns size has to be less than or equal "
237 "to 4 and greater than or equal 2, but found ")
238 << t.getNumElements();
239 return Type();
240 }
241
242 if (!isa<FloatType>(t.getElementType())) {
243 parser.emitError(typeLoc, "matrix columns' elements must be of "
244 "Float type, got ")
245 << t.getElementType();
246 return Type();
247 }
248 } else {
249 parser.emitError(typeLoc, "matrix must be composed using vector "
250 "type, got ")
251 << type;
252 return Type();
253 }
254
255 return type;
256}
257
258static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
259 DialectAsmParser &parser) {
260 Type type;
261 SMLoc typeLoc = parser.getCurrentLocation();
262 if (parser.parseType(type))
263 return Type();
264
265 auto imageType = dyn_cast<ImageType>(type);
266 if (!imageType) {
267 parser.emitError(typeLoc,
268 "sampled image must be composed using image type, got ")
269 << type;
270 return Type();
271 }
272
273 if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, imageType.getDim())) {
274 parser.emitError(
275 typeLoc, "sampled image Dim must not be SubpassData or Buffer, got ")
276 << stringifyDim(imageType.getDim());
277 return Type();
278 }
279
280 return type;
281}
282
283/// Parses an optional `, stride = N` assembly segment. If no parsing failure
284/// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
285/// missing.
286static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
287 DialectAsmParser &parser,
288 unsigned &stride) {
289 if (failed(parser.parseOptionalComma())) {
290 stride = 0;
291 return success();
292 }
293
294 if (parser.parseKeyword("stride") || parser.parseEqual())
295 return failure();
296
297 SMLoc strideLoc = parser.getCurrentLocation();
298 std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
299 if (!optStride)
300 return failure();
301
302 if (!(stride = *optStride)) {
303 parser.emitError(strideLoc, "ArrayStride must be greater than zero");
304 return failure();
305 }
306 return success();
307}
308
309// element-type ::= integer-type
310// | floating-point-type
311// | vector-type
312// | spirv-type
313//
314// array-type ::= `!spirv.array` `<` integer-literal `x` element-type
315// (`,` `stride` `=` integer-literal)? `>`
316static Type parseArrayType(SPIRVDialect const &dialect,
317 DialectAsmParser &parser) {
318 if (parser.parseLess())
319 return Type();
320
321 SmallVector<int64_t, 1> countDims;
322 SMLoc countLoc = parser.getCurrentLocation();
323 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
324 return Type();
325 if (countDims.size() != 1) {
326 parser.emitError(countLoc,
327 "expected single integer for array element count");
328 return Type();
329 }
330
331 // According to the SPIR-V spec:
332 // "Length is the number of elements in the array. It must be at least 1."
333 int64_t count = countDims[0];
334 if (count == 0) {
335 parser.emitError(countLoc, "expected array length greater than 0");
336 return Type();
337 }
338
339 Type elementType = parseAndVerifyType(dialect, parser);
340 if (!elementType)
341 return Type();
342
343 unsigned stride = 0;
344 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
345 return Type();
346
347 if (parser.parseGreater())
348 return Type();
349 return ArrayType::get(elementType, count, stride);
350}
351
352// cooperative-matrix-type ::=
353// `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
354// scope `,` use `>`
355static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
356 DialectAsmParser &parser) {
357 if (parser.parseLess())
358 return {};
359
361 SMLoc countLoc = parser.getCurrentLocation();
362 if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
363 return {};
364
365 if (dims.size() != 2) {
366 parser.emitError(countLoc, "expected row and column count");
367 return {};
368 }
369
370 auto elementTy = parseAndVerifyType(dialect, parser);
371 if (!elementTy)
372 return {};
373
374 Scope scope;
375 if (parser.parseComma() ||
376 spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
377 return {};
378
379 CooperativeMatrixUseKHR use;
380 if (parser.parseComma() ||
381 spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
382 return {};
383
384 if (parser.parseGreater())
385 return {};
386
387 return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
388}
389
390// tensor-arm-type ::=
391// `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
392static Type parseTensorArmType(SPIRVDialect const &dialect,
393 DialectAsmParser &parser) {
394 if (parser.parseLess())
395 return {};
396
397 bool unranked = false;
399 SMLoc countLoc = parser.getCurrentLocation();
400
401 if (parser.parseOptionalStar().succeeded()) {
402 unranked = true;
403 if (parser.parseXInDimensionList())
404 return {};
405 } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
406 return {};
407 }
408
409 if (!unranked && dims.empty()) {
410 parser.emitError(countLoc, "arm.tensors do not support rank zero");
411 return {};
412 }
413
414 if (llvm::is_contained(dims, 0)) {
415 parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
416 return {};
417 }
418
419 if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
420 llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
421 parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
422 "fully dynamic or completed shaped");
423 return {};
424 }
425
426 auto elementTy = parseAndVerifyType(dialect, parser);
427 if (!elementTy)
428 return {};
429
430 if (parser.parseGreater())
431 return {};
432
433 return TensorArmType::get(dims, elementTy);
434}
435
436// TODO: Reorder methods to be utilities first and parse*Type
437// methods in alphabetical order
438//
439// storage-class ::= `UniformConstant`
440// | `Uniform`
441// | `Workgroup`
442// | <and other storage classes...>
443//
444// pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
445static Type parsePointerType(SPIRVDialect const &dialect,
446 DialectAsmParser &parser) {
447 if (parser.parseLess())
448 return Type();
449
450 auto pointeeType = parseAndVerifyType(dialect, parser);
451 if (!pointeeType)
452 return Type();
453
454 StringRef storageClassSpec;
455 SMLoc storageClassLoc = parser.getCurrentLocation();
456 if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
457 return Type();
458
459 auto storageClass = symbolizeStorageClass(storageClassSpec);
460 if (!storageClass) {
461 parser.emitError(storageClassLoc, "unknown storage class: ")
462 << storageClassSpec;
463 return Type();
464 }
465 if (parser.parseGreater())
466 return Type();
467 return PointerType::get(pointeeType, *storageClass);
468}
469
470// runtime-array-type ::= `!spirv.rtarray` `<` element-type
471// (`,` `stride` `=` integer-literal)? `>`
472static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
473 DialectAsmParser &parser) {
474 if (parser.parseLess())
475 return Type();
476
477 Type elementType = parseAndVerifyType(dialect, parser);
478 if (!elementType)
479 return Type();
480
481 unsigned stride = 0;
482 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
483 return Type();
484
485 if (parser.parseGreater())
486 return Type();
487 return RuntimeArrayType::get(elementType, stride);
488}
489
490// matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
491static Type parseMatrixType(SPIRVDialect const &dialect,
492 DialectAsmParser &parser) {
493 if (parser.parseLess())
494 return Type();
495
496 SmallVector<int64_t, 1> countDims;
497 SMLoc countLoc = parser.getCurrentLocation();
498 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
499 return Type();
500 if (countDims.size() != 1) {
501 parser.emitError(countLoc, "expected single unsigned "
502 "integer for number of columns");
503 return Type();
504 }
505
506 int64_t columnCount = countDims[0];
507 // According to the specification, Matrices can have 2, 3, or 4 columns
508 if (columnCount < 2 || columnCount > 4) {
509 parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
510 "columns");
511 return Type();
512 }
513
514 Type columnType = parseAndVerifyMatrixType(dialect, parser);
515 if (!columnType)
516 return Type();
517
518 if (parser.parseGreater())
519 return Type();
520
521 return MatrixType::get(columnType, columnCount);
522}
523
524// Specialize this function to parse each of the parameters that define an
525// ImageType. By default it assumes this is an enum type.
526template <typename ValTy>
527static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
528 DialectAsmParser &parser) {
529 StringRef enumSpec;
530 SMLoc enumLoc = parser.getCurrentLocation();
531 if (parser.parseKeyword(&enumSpec)) {
532 return std::nullopt;
533 }
534
535 auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
536 if (!val)
537 parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
538 return val;
539}
540
541template <>
542std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
543 DialectAsmParser &parser) {
544 // TODO: Further verify that the element type can be sampled
545 auto ty = parseAndVerifyType(dialect, parser);
546 if (!ty)
547 return std::nullopt;
548 return ty;
549}
550
551template <typename IntTy>
552static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
553 DialectAsmParser &parser) {
554 IntTy offsetVal = std::numeric_limits<IntTy>::max();
555 if (parser.parseInteger(offsetVal))
556 return std::nullopt;
557 return offsetVal;
558}
559
560template <>
561std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
562 DialectAsmParser &parser) {
563 return parseAndVerifyInteger<unsigned>(dialect, parser);
564}
565
566namespace {
567// Functor object to parse a comma separated list of specs. The function
568// parseAndVerify does the actual parsing and verification of individual
569// elements. This is a functor since parsing the last element of the list
570// (termination condition) needs partial specialization.
571template <typename ParseType, typename... Args>
572struct ParseCommaSeparatedList {
573 std::optional<std::tuple<ParseType, Args...>>
574 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
575 auto parseVal = parseAndVerify<ParseType>(dialect, parser);
576 if (!parseVal)
577 return std::nullopt;
578
579 auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
580 if (numArgs != 0 && failed(parser.parseComma()))
581 return std::nullopt;
582 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
583 if (!remainingValues)
584 return std::nullopt;
585 return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
586 remainingValues.value());
587 }
588};
589
590// Partial specialization of the function to parse a comma separated list of
591// specs to parse the last element of the list.
592template <typename ParseType>
593struct ParseCommaSeparatedList<ParseType> {
594 std::optional<std::tuple<ParseType>>
595 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
596 if (auto value = parseAndVerify<ParseType>(dialect, parser))
597 return std::tuple<ParseType>(*value);
598 return std::nullopt;
599 }
600};
601} // namespace
602
603// dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
604//
605// depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
606//
607// arrayed-info ::= `NonArrayed` | `Arrayed`
608//
609// sampling-info ::= `SingleSampled` | `MultiSampled`
610//
611// sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
612//
613// format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
614//
615// image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
616// arrayed-info `,` sampling-info `,`
617// sampler-use-info `,` format `>`
618static Type parseImageType(SPIRVDialect const &dialect,
619 DialectAsmParser &parser) {
620 if (parser.parseLess())
621 return Type();
622
623 auto value =
624 ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
625 ImageSamplingInfo, ImageSamplerUseInfo,
626 ImageFormat>{}(dialect, parser);
627 if (!value)
628 return Type();
629
630 if (parser.parseGreater())
631 return Type();
632 return ImageType::get(*value);
633}
634
635// sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
636static Type parseSampledImageType(SPIRVDialect const &dialect,
637 DialectAsmParser &parser) {
638 if (parser.parseLess())
639 return Type();
640
641 Type parsedType = parseAndVerifySampledImageType(dialect, parser);
642 if (!parsedType)
643 return Type();
644
645 if (parser.parseGreater())
646 return Type();
647 return SampledImageType::get(parsedType);
648}
649
650// Parse decorations associated with a member.
652 SPIRVDialect const &dialect, DialectAsmParser &parser,
653 ArrayRef<Type> memberTypes,
656
657 // Check if the first element is offset.
658 SMLoc offsetLoc = parser.getCurrentLocation();
659 StructType::OffsetInfo offset = 0;
660 OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
661 if (offsetParseResult.has_value()) {
662 if (failed(*offsetParseResult))
663 return failure();
664
665 if (offsetInfo.size() != memberTypes.size() - 1) {
666 return parser.emitError(offsetLoc,
667 "offset specification must be given for "
668 "all members");
669 }
670 offsetInfo.push_back(offset);
671 }
672
673 // Check for no spirv::Decorations.
674 if (succeeded(parser.parseOptionalRSquare()))
675 return success();
676
677 // If there was an offset, make sure to parse the comma.
678 if (offsetParseResult.has_value() && parser.parseComma())
679 return failure();
680
681 // Check for spirv::Decorations.
682 auto parseDecorations = [&]() {
683 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
684 if (!memberDecoration)
685 return failure();
686
687 // Parse member decoration value if it exists.
688 if (succeeded(parser.parseOptionalEqual())) {
689 Attribute memberDecorationValue;
690 if (failed(parser.parseAttribute(memberDecorationValue)))
691 return failure();
692
693 memberDecorationInfo.emplace_back(
694 static_cast<uint32_t>(memberTypes.size() - 1),
695 memberDecoration.value(), memberDecorationValue);
696 } else {
697 memberDecorationInfo.emplace_back(
698 static_cast<uint32_t>(memberTypes.size() - 1),
699 memberDecoration.value(), UnitAttr::get(dialect.getContext()));
700 }
701 return success();
702 };
703 if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
704 failed(parser.parseRSquare()))
705 return failure();
706
707 return success();
708}
709
710// struct-member-decoration ::= integer-literal? spirv-decoration*
711// struct-type ::=
712// `!spirv.struct<` (id `,`)?
713// `(`
714// (spirv-type (`[` struct-member-decoration `]`)?)*
715// `)`
716// (`,` struct-decoration)?
717// `>`
718static Type parseStructType(SPIRVDialect const &dialect,
719 DialectAsmParser &parser) {
720 // TODO: This function is quite lengthy. Break it down into smaller chunks.
721
722 if (parser.parseLess())
723 return Type();
724
725 StringRef identifier;
726 FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
727
728 // Check if this is an identified struct type.
729 if (succeeded(parser.parseOptionalKeyword(&identifier))) {
730 // Check if this is a possible recursive reference.
731 auto structType =
732 StructType::getIdentified(dialect.getContext(), identifier);
733 cyclicParse = parser.tryStartCyclicParse(structType);
734 if (succeeded(parser.parseOptionalGreater())) {
735 if (succeeded(cyclicParse)) {
736 parser.emitError(
737 parser.getNameLoc(),
738 "recursive struct reference not nested in struct definition");
739
740 return Type();
741 }
742
743 return structType;
744 }
745
746 if (failed(parser.parseComma()))
747 return Type();
748
749 if (failed(cyclicParse)) {
750 parser.emitError(parser.getNameLoc(),
751 "identifier already used for an enclosing struct");
752 return Type();
753 }
754 }
755
756 if (failed(parser.parseLParen()))
757 return Type();
758
759 if (succeeded(parser.parseOptionalRParen()) &&
760 succeeded(parser.parseOptionalGreater())) {
761 return StructType::getEmpty(dialect.getContext(), identifier);
762 }
763
764 StructType idStructTy;
765
766 if (!identifier.empty())
767 idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
768
769 SmallVector<Type, 4> memberTypes;
772
773 do {
774 Type memberType;
775 if (parser.parseType(memberType))
776 return Type();
777 memberTypes.push_back(memberType);
778
779 if (succeeded(parser.parseOptionalLSquare()))
780 if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
781 memberDecorationInfo))
782 return Type();
783 } while (succeeded(parser.parseOptionalComma()));
784
785 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
786 parser.emitError(parser.getNameLoc(),
787 "offset specification must be given for all members");
788 return Type();
789 }
790
791 if (failed(parser.parseRParen()))
792 return Type();
793
795
796 auto parseStructDecoration = [&]() {
797 std::optional<spirv::Decoration> decoration =
798 parseAndVerify<spirv::Decoration>(dialect, parser);
799 if (!decoration)
800 return failure();
801
802 // Parse decoration value if it exists.
803 if (succeeded(parser.parseOptionalEqual())) {
804 Attribute decorationValue;
805 if (failed(parser.parseAttribute(decorationValue)))
806 return failure();
807
808 structDecorationInfo.emplace_back(decoration.value(), decorationValue);
809 } else {
810 structDecorationInfo.emplace_back(decoration.value(),
811 UnitAttr::get(dialect.getContext()));
812 }
813 return success();
814 };
815
816 while (succeeded(parser.parseOptionalComma()))
817 if (failed(parseStructDecoration()))
818 return Type();
819
820 if (failed(parser.parseGreater()))
821 return Type();
822
823 if (!identifier.empty()) {
824 if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
825 memberDecorationInfo,
826 structDecorationInfo)))
827 return Type();
828 return idStructTy;
829 }
830
831 return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
832 structDecorationInfo);
833}
834
835// spirv-type ::= array-type
836// | element-type
837// | image-type
838// | pointer-type
839// | runtime-array-type
840// | sampled-image-type
841// | struct-type
842Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
843 StringRef keyword;
844 if (parser.parseKeyword(&keyword))
845 return Type();
846
847 if (keyword == "array")
848 return parseArrayType(*this, parser);
849 if (keyword == "coopmatrix")
850 return parseCooperativeMatrixType(*this, parser);
851 if (keyword == "image")
852 return parseImageType(*this, parser);
853 if (keyword == "ptr")
854 return parsePointerType(*this, parser);
855 if (keyword == "rtarray")
856 return parseRuntimeArrayType(*this, parser);
857 if (keyword == "sampled_image")
858 return parseSampledImageType(*this, parser);
859 if (keyword == "struct")
860 return parseStructType(*this, parser);
861 if (keyword == "matrix")
862 return parseMatrixType(*this, parser);
863 if (keyword == "arm.tensor")
864 return parseTensorArmType(*this, parser);
865 parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
866 return Type();
867}
868
869//===----------------------------------------------------------------------===//
870// Type Printing
871//===----------------------------------------------------------------------===//
872
873static void print(ArrayType type, DialectAsmPrinter &os) {
874 os << "array<" << type.getNumElements() << " x " << type.getElementType();
875 if (unsigned stride = type.getArrayStride())
876 os << ", stride=" << stride;
877 os << ">";
878}
879
881 os << "rtarray<" << type.getElementType();
882 if (unsigned stride = type.getArrayStride())
883 os << ", stride=" << stride;
884 os << ">";
885}
886
887static void print(PointerType type, DialectAsmPrinter &os) {
888 os << "ptr<" << type.getPointeeType() << ", "
889 << stringifyStorageClass(type.getStorageClass()) << ">";
890}
891
892static void print(ImageType type, DialectAsmPrinter &os) {
893 os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
894 << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
895 << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
896 << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
897 << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
898 << stringifyImageFormat(type.getImageFormat()) << ">";
899}
900
902 os << "sampled_image<" << type.getImageType() << ">";
903}
904
905static void print(StructType type, DialectAsmPrinter &os) {
906 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
907
908 os << "struct<";
909
910 if (type.isIdentified()) {
911 os << type.getIdentifier();
912
913 cyclicPrint = os.tryStartCyclicPrint(type);
914 if (failed(cyclicPrint)) {
915 os << ">";
916 return;
917 }
918
919 os << ", ";
920 }
921
922 os << "(";
923
924 auto printMember = [&](unsigned i) {
925 os << type.getElementType(i);
927 type.getMemberDecorations(i, decorations);
928 if (type.hasOffset() || !decorations.empty()) {
929 os << " [";
930 if (type.hasOffset()) {
931 os << type.getMemberOffset(i);
932 if (!decorations.empty())
933 os << ", ";
934 }
935 auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
936 os << stringifyDecoration(decoration.decoration);
937 if (decoration.hasValue()) {
938 os << "=";
939 os.printAttributeWithoutType(decoration.decorationValue);
940 }
941 };
942 llvm::interleaveComma(decorations, os, eachFn);
943 os << "]";
944 }
945 };
946 llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
947 printMember);
948 os << ")";
949
951 type.getStructDecorations(decorations);
952 if (!decorations.empty()) {
953 os << ", ";
954 auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
955 os << stringifyDecoration(decoration.decoration);
956 if (decoration.hasValue()) {
957 os << "=";
958 os.printAttributeWithoutType(decoration.decorationValue);
959 }
960 };
961 llvm::interleaveComma(decorations, os, eachFn);
962 }
963
964 os << ">";
965}
966
968 os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
969 << type.getElementType() << ", " << type.getScope() << ", "
970 << type.getUse() << ">";
971}
972
973static void print(MatrixType type, DialectAsmPrinter &os) {
974 os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
975 os << ">";
976}
977
978static void print(TensorArmType type, DialectAsmPrinter &os) {
979 os << "arm.tensor<";
980
981 llvm::interleave(
982 type.getShape(), os,
983 [&](int64_t dim) {
984 if (ShapedType::isDynamic(dim))
985 os << '?';
986 else
987 os << dim;
988 },
989 "x");
990 if (!type.hasRank()) {
991 os << "*";
992 }
993 os << "x" << type.getElementType() << ">";
994}
995
996void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
997 TypeSwitch<Type>(type)
1000 [&](auto type) { print(type, os); })
1001 .DefaultUnreachable("Unhandled SPIR-V type");
1002}
1003
1004//===----------------------------------------------------------------------===//
1005// Constant
1006//===----------------------------------------------------------------------===//
1007
1008Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
1009 Attribute value, Type type,
1010 Location loc) {
1011 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
1012 return ub::PoisonOp::create(builder, loc, type, poison);
1013
1014 if (!spirv::ConstantOp::isBuildableWith(type))
1015 return nullptr;
1016
1017 return spirv::ConstantOp::create(builder, loc, type, value);
1018}
1019
1020//===----------------------------------------------------------------------===//
1021// Shader Interface ABI
1022//===----------------------------------------------------------------------===//
1023
1024LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
1025 NamedAttribute attribute) {
1026 StringRef symbol = attribute.getName().strref();
1027 Attribute attr = attribute.getValue();
1028
1029 if (symbol == spirv::getEntryPointABIAttrName()) {
1030 if (!isa<spirv::EntryPointABIAttr>(attr)) {
1031 return op->emitError("'")
1032 << symbol << "' attribute must be an entry point ABI attribute";
1033 }
1034 } else if (symbol == spirv::getTargetEnvAttrName()) {
1035 if (!isa<spirv::TargetEnvAttr>(attr))
1036 return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
1037 } else {
1038 return op->emitError("found unsupported '")
1039 << symbol << "' attribute on operation";
1040 }
1041
1042 return success();
1043}
1044
1045/// Verifies the given SPIR-V `attribute` attached to a value of the given
1046/// `valueType` is valid.
1047static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
1048 NamedAttribute attribute) {
1049 StringRef symbol = attribute.getName().strref();
1050 Attribute attr = attribute.getValue();
1051
1052 if (symbol == spirv::getInterfaceVarABIAttrName()) {
1053 auto varABIAttr = dyn_cast<spirv::InterfaceVarABIAttr>(attr);
1054 if (!varABIAttr)
1055 return emitError(loc, "'")
1056 << symbol << "' must be a spirv::InterfaceVarABIAttr";
1057
1058 if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
1059 return emitError(loc, "'") << symbol
1060 << "' attribute cannot specify storage class "
1061 "when attaching to a non-scalar value";
1062 return success();
1063 }
1064 if (symbol == spirv::DecorationAttr::name) {
1065 if (!isa<spirv::DecorationAttr>(attr))
1066 return emitError(loc, "'")
1067 << symbol << "' must be a spirv::DecorationAttr";
1068 return success();
1069 }
1070
1071 return emitError(loc, "found unsupported '")
1072 << symbol << "' attribute on region argument";
1073}
1074
1075LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
1076 unsigned regionIndex,
1077 unsigned argIndex,
1078 NamedAttribute attribute) {
1079 auto funcOp = dyn_cast<FunctionOpInterface>(op);
1080 if (!funcOp)
1081 return success();
1082 Type argType = funcOp.getArgumentTypes()[argIndex];
1083
1084 return verifyRegionAttribute(op->getLoc(), argType, attribute);
1085}
1086
1087LogicalResult SPIRVDialect::verifyRegionResultAttribute(
1088 Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
1089 NamedAttribute attribute) {
1090 if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op))
1091 return verifyRegionAttribute(
1092 op->getLoc(), graphOp.getResultTypes()[resultIndex], attribute);
1093 return op->emitError(
1094 "cannot attach SPIR-V attributes to region result which is "
1095 "not part of a spirv::GraphARMOp type");
1096}
return success()
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
std::optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional `, stride = N` assembly segment.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid.
static Type parseTensorArmType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static void print(ArrayType type, DialectAsmPrinter &os)
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
std::optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static bool containsReturn(Region &region)
Returns true if the given region contains spirv.Return or spirv.ReturnValue ops.
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalStar()=0
Parse a '*' token if present.
FailureOr< CyclicParseReset > tryStartCyclicParse(AttrOrTypeT attrOrType)
Attempts to start a cyclic parsing region for attrOrType.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseXInDimensionList()=0
Parse an 'x' token in a dimension list, handling the case where the x is juxtaposed with an element t...
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
FailureOr< CyclicPrintReset > tryStartCyclicPrint(AttrOrTypeT attrOrType)
Attempts to start a cyclic printing region for attrOrType.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:147
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType get(Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
SPIR-V struct type.
Definition SPIRVTypes.h:251
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
SPIR-V TensorARM Type.
Definition SPIRVTypes.h:465
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
ArrayRef< int64_t > getShape() const
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136