MLIR 23.0.0git
SPIRVDialect.cpp
Go to the documentation of this file.
//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPIR-V dialect in MLIR.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "SPIRVParsingUtils.h"
16
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/MLIRContext.h"
26#include "mlir/Parser/Parser.h"
28#include "llvm/ADT/Sequence.h"
29#include "llvm/ADT/StringExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31
32using namespace mlir;
33using namespace mlir::spirv;
34
35#include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
36
37//===----------------------------------------------------------------------===//
38// InlinerInterface
39//===----------------------------------------------------------------------===//
40
41/// Returns true if the given region contains spirv.Return or spirv.ReturnValue
42/// ops.
43static inline bool containsReturn(Region &region) {
44 return llvm::any_of(region, [](Block &block) {
45 Operation *terminator = block.getTerminator();
46 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
47 });
48}
49
50namespace {
51/// This class defines the interface for inlining within the SPIR-V dialect.
52struct SPIRVInlinerInterface : public DialectInlinerInterface {
53 using DialectInlinerInterface::DialectInlinerInterface;
54
55 /// All call operations within SPIRV can be inlined.
56 bool isLegalToInline(Operation *call, Operation *callable,
57 bool wouldBeCloned) const final {
58 return true;
59 }
60
61 /// Returns true if the given region 'src' can be inlined into the region
62 /// 'dest' that is attached to an operation registered to the current dialect.
63 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
64 IRMapping &) const final {
65 // Return true here when inlining into spirv.func, spirv.mlir.selection, and
66 // spirv.mlir.loop operations.
67 auto *op = dest->getParentOp();
68 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
69 }
70
71 /// Returns true if the given operation 'op', that is registered to this
72 /// dialect, can be inlined into the region 'dest' that is attached to an
73 /// operation registered to the current dialect.
74 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
75 IRMapping &) const final {
76 // TODO: Enable inlining structured control flows with return.
77 if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
78 containsReturn(op->getRegion(0)))
79 return false;
80 // TODO: we need to filter OpKill here to avoid inlining it to
81 // a loop continue construct:
82 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
83 // For now, we just disallow inlining OpKill anywhere in the code,
84 // but this restriction should be relaxed, as pointed above.
85 if (isa<spirv::KillOp>(op))
86 return false;
87
88 return true;
89 }
90
91 /// Handle the given inlined terminator by replacing it with a new operation
92 /// as necessary.
93 void handleTerminator(Operation *op, Block *newDest) const final {
94 if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
95 auto builder = OpBuilder(op);
96 spirv::BranchOp::create(builder, op->getLoc(), newDest);
97 op->erase();
98 } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
99 auto builder = OpBuilder(op);
100 spirv::BranchOp::create(builder, retValOp->getLoc(), newDest,
101 retValOp->getOperands());
102 op->erase();
103 }
104 }
105
106 /// Handle the given inlined terminator by replacing it with a new operation
107 /// as necessary.
108 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
109 // Only spirv.ReturnValue needs to be handled here.
110 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
111 if (!retValOp)
112 return;
113
114 // Replace the values directly with the return operands.
115 assert(valuesToRepl.size() == 1 &&
116 "spirv.ReturnValue expected to only handle one result");
117 valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
118 }
119};
120} // namespace
121
122//===----------------------------------------------------------------------===//
123// SPIR-V Dialect
124//===----------------------------------------------------------------------===//
125
126void SPIRVDialect::initialize() {
127 registerAttributes();
128 registerTypes();
129
130 // Add SPIR-V ops.
131 addOperations<
132#define GET_OP_LIST
133#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
134 >();
135
136 addInterfaces<SPIRVInlinerInterface>();
137
138 // Allow unknown operations because SPIR-V is extensible.
139 allowUnknownOperations();
140 declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
141}
142
143std::string SPIRVDialect::getAttributeName(Decoration decoration) {
144 return getDecorationString(decoration);
145}
146
147//===----------------------------------------------------------------------===//
148// Type Parsing
149//===----------------------------------------------------------------------===//
150
151// Forward declarations.
152template <typename ValTy>
153static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
154 DialectAsmParser &parser);
155template <>
156std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
157 DialectAsmParser &parser);
158
159template <>
160std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
161 DialectAsmParser &parser);
162
163static Type parseAndVerifyType(SPIRVDialect const &dialect,
164 DialectAsmParser &parser) {
165 Type type;
166 SMLoc typeLoc = parser.getCurrentLocation();
167 if (parser.parseType(type))
168 return Type();
169
170 // Allow SPIR-V dialect types.
171 if (&type.getDialect() == &dialect)
172 return type;
173
174 // Check other allowed types.
175 if (auto t = dyn_cast<FloatType>(type)) {
176 // TODO: All float types are allowed for now, but this should be fixed.
177 } else if (auto t = dyn_cast<IntegerType>(type)) {
178 if (!ScalarType::isValid(t)) {
179 parser.emitError(typeLoc,
180 "only 1/8/16/32/64-bit integer type allowed but found ")
181 << type;
182 return Type();
183 }
184 } else if (auto t = dyn_cast<VectorType>(type)) {
185 if (t.getRank() != 1) {
186 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
187 return Type();
188 }
189 if (t.getNumElements() < 2) {
190 parser.emitError(typeLoc, "SPIR-V does not allow one-element vectors");
191 return Type();
192 }
193 if (t.getNumElements() > 4) {
194 parser.emitError(
195 typeLoc, "vector length has to be less than or equal to 4 but found ")
196 << t.getNumElements();
197 return Type();
198 }
199 if (!isa<ScalarType>(t.getElementType())) {
200 parser.emitError(
201 typeLoc,
202 "vector element type must be a SPIR-V scalar type but found ")
203 << t.getElementType();
204 return Type();
205 }
206 } else if (auto t = dyn_cast<TensorArmType>(type)) {
207 if (!isa<ScalarType>(t.getElementType())) {
208 parser.emitError(
209 typeLoc, "only scalar element type allowed in tensor type but found ")
210 << t.getElementType();
211 return Type();
212 }
213 } else {
214 parser.emitError(typeLoc, "cannot use ")
215 << type << " to compose SPIR-V types";
216 return Type();
217 }
218
219 return type;
220}
221
222static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
223 DialectAsmParser &parser) {
224 Type type;
225 SMLoc typeLoc = parser.getCurrentLocation();
226 if (parser.parseType(type))
227 return Type();
228
229 if (auto t = dyn_cast<VectorType>(type)) {
230 if (t.getRank() != 1) {
231 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
232 return Type();
233 }
234 if (t.getNumElements() > 4 || t.getNumElements() < 2) {
235 parser.emitError(typeLoc,
236 "matrix columns size has to be less than or equal "
237 "to 4 and greater than or equal 2, but found ")
238 << t.getNumElements();
239 return Type();
240 }
241
242 if (!isa<FloatType>(t.getElementType())) {
243 parser.emitError(typeLoc, "matrix columns' elements must be of "
244 "Float type, got ")
245 << t.getElementType();
246 return Type();
247 }
248 } else {
249 parser.emitError(typeLoc, "matrix must be composed using vector "
250 "type, got ")
251 << type;
252 return Type();
253 }
254
255 return type;
256}
257
258static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
259 DialectAsmParser &parser) {
260 Type type;
261 SMLoc typeLoc = parser.getCurrentLocation();
262 if (parser.parseType(type))
263 return Type();
264
265 auto imageType = dyn_cast<ImageType>(type);
266 if (!imageType) {
267 parser.emitError(typeLoc,
268 "sampled image must be composed using image type, got ")
269 << type;
270 return Type();
271 }
272
273 if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, imageType.getDim())) {
274 parser.emitError(
275 typeLoc, "sampled image Dim must not be SubpassData or Buffer, got ")
276 << stringifyDim(imageType.getDim());
277 return Type();
278 }
279
280 return type;
281}
282
283/// Parses an optional `, stride = N` assembly segment. If no parsing failure
284/// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
285/// missing.
286static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
287 DialectAsmParser &parser,
288 unsigned &stride) {
289 if (failed(parser.parseOptionalComma())) {
290 stride = 0;
291 return success();
292 }
293
294 if (parser.parseKeyword("stride") || parser.parseEqual())
295 return failure();
296
297 SMLoc strideLoc = parser.getCurrentLocation();
298 std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
299 if (!optStride)
300 return failure();
301
302 if (!(stride = *optStride)) {
303 parser.emitError(strideLoc, "ArrayStride must be greater than zero");
304 return failure();
305 }
306 return success();
307}
308
309// element-type ::= integer-type
310// | floating-point-type
311// | vector-type
312// | spirv-type
313//
314// array-type ::= `!spirv.array` `<` integer-literal `x` element-type
315// (`,` `stride` `=` integer-literal)? `>`
316static Type parseArrayType(SPIRVDialect const &dialect,
317 DialectAsmParser &parser) {
318 if (parser.parseLess())
319 return Type();
320
321 SmallVector<int64_t, 1> countDims;
322 SMLoc countLoc = parser.getCurrentLocation();
323 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
324 return Type();
325 if (countDims.size() != 1) {
326 parser.emitError(countLoc,
327 "expected single integer for array element count");
328 return Type();
329 }
330
331 // According to the SPIR-V spec:
332 // "Length is the number of elements in the array. It must be at least 1."
333 int64_t count = countDims[0];
334 if (count == 0) {
335 parser.emitError(countLoc, "expected array length greater than 0");
336 return Type();
337 }
338
339 Type elementType = parseAndVerifyType(dialect, parser);
340 if (!elementType)
341 return Type();
342
343 unsigned stride = 0;
344 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
345 return Type();
346
347 if (parser.parseGreater())
348 return Type();
349 return ArrayType::get(elementType, count, stride);
350}
351
352// cooperative-matrix-type ::=
353// `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
354// scope `,` use `>`
355static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
356 DialectAsmParser &parser) {
357 if (parser.parseLess())
358 return {};
359
361 SMLoc countLoc = parser.getCurrentLocation();
362 if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
363 return {};
364
365 if (dims.size() != 2) {
366 parser.emitError(countLoc, "expected row and column count");
367 return {};
368 }
369
370 auto elementTy = parseAndVerifyType(dialect, parser);
371 if (!elementTy)
372 return {};
373
374 Scope scope;
375 if (parser.parseComma() ||
376 spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
377 return {};
378
379 CooperativeMatrixUseKHR use;
380 if (parser.parseComma() ||
381 spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
382 return {};
383
384 if (parser.parseGreater())
385 return {};
386
387 return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
388}
389
390// tensor-arm-type ::=
391// `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
392static Type parseTensorArmType(SPIRVDialect const &dialect,
393 DialectAsmParser &parser) {
394 if (parser.parseLess())
395 return {};
396
397 bool unranked = false;
399 SMLoc countLoc = parser.getCurrentLocation();
400
401 if (parser.parseOptionalStar().succeeded()) {
402 unranked = true;
403 if (parser.parseXInDimensionList())
404 return {};
405 } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
406 return {};
407 }
408
409 if (!unranked && dims.empty()) {
410 parser.emitError(countLoc, "arm.tensors do not support rank zero");
411 return {};
412 }
413
414 if (llvm::is_contained(dims, 0)) {
415 parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
416 return {};
417 }
418
419 if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
420 llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
421 parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
422 "fully dynamic or completed shaped");
423 return {};
424 }
425
426 auto elementTy = parseAndVerifyType(dialect, parser);
427 if (!elementTy)
428 return {};
429
430 if (parser.parseGreater())
431 return {};
432
433 return TensorArmType::get(dims, elementTy);
434}
435
436// TODO: Reorder methods to be utilities first and parse*Type
437// methods in alphabetical order
438//
439// storage-class ::= `UniformConstant`
440// | `Uniform`
441// | `Workgroup`
442// | <and other storage classes...>
443//
444// pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
445static Type parsePointerType(SPIRVDialect const &dialect,
446 DialectAsmParser &parser) {
447 if (parser.parseLess())
448 return Type();
449
450 auto pointeeType = parseAndVerifyType(dialect, parser);
451 if (!pointeeType)
452 return Type();
453
454 StringRef storageClassSpec;
455 SMLoc storageClassLoc = parser.getCurrentLocation();
456 if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
457 return Type();
458
459 auto storageClass = symbolizeStorageClass(storageClassSpec);
460 if (!storageClass) {
461 parser.emitError(storageClassLoc, "unknown storage class: ")
462 << storageClassSpec;
463 return Type();
464 }
465 if (parser.parseGreater())
466 return Type();
467 return PointerType::get(pointeeType, *storageClass);
468}
469
470// runtime-array-type ::= `!spirv.rtarray` `<` element-type
471// (`,` `stride` `=` integer-literal)? `>`
472static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
473 DialectAsmParser &parser) {
474 if (parser.parseLess())
475 return Type();
476
477 Type elementType = parseAndVerifyType(dialect, parser);
478 if (!elementType)
479 return Type();
480
481 unsigned stride = 0;
482 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
483 return Type();
484
485 if (parser.parseGreater())
486 return Type();
487 return RuntimeArrayType::get(elementType, stride);
488}
489
490// matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
491static Type parseMatrixType(SPIRVDialect const &dialect,
492 DialectAsmParser &parser) {
493 if (parser.parseLess())
494 return Type();
495
496 SmallVector<int64_t, 1> countDims;
497 SMLoc countLoc = parser.getCurrentLocation();
498 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
499 return Type();
500 if (countDims.size() != 1) {
501 parser.emitError(countLoc, "expected single unsigned "
502 "integer for number of columns");
503 return Type();
504 }
505
506 int64_t columnCount = countDims[0];
507 // According to the specification, Matrices can have 2, 3, or 4 columns
508 if (columnCount < 2 || columnCount > 4) {
509 parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
510 "columns");
511 return Type();
512 }
513
514 Type columnType = parseAndVerifyMatrixType(dialect, parser);
515 if (!columnType)
516 return Type();
517
518 if (parser.parseGreater())
519 return Type();
520
521 return MatrixType::get(columnType, columnCount);
522}
523
524// Specialize this function to parse each of the parameters that define an
525// ImageType. By default it assumes this is an enum type.
526template <typename ValTy>
527static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
528 DialectAsmParser &parser) {
529 StringRef enumSpec;
530 SMLoc enumLoc = parser.getCurrentLocation();
531 if (parser.parseKeyword(&enumSpec)) {
532 return std::nullopt;
533 }
534
535 auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
536 if (!val)
537 parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
538 return val;
539}
540
541template <>
542std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
543 DialectAsmParser &parser) {
544 // TODO: Further verify that the element type can be sampled
545 auto ty = parseAndVerifyType(dialect, parser);
546 if (!ty)
547 return std::nullopt;
548 return ty;
549}
550
551template <typename IntTy>
552static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
553 DialectAsmParser &parser) {
554 IntTy offsetVal = std::numeric_limits<IntTy>::max();
555 if (parser.parseInteger(offsetVal))
556 return std::nullopt;
557 return offsetVal;
558}
559
560template <>
561std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
562 DialectAsmParser &parser) {
563 return parseAndVerifyInteger<unsigned>(dialect, parser);
564}
565
566namespace {
567// Functor object to parse a comma separated list of specs. The function
568// parseAndVerify does the actual parsing and verification of individual
569// elements. This is a functor since parsing the last element of the list
570// (termination condition) needs partial specialization.
571template <typename ParseType, typename... Args>
572struct ParseCommaSeparatedList {
573 std::optional<std::tuple<ParseType, Args...>>
574 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
575 auto parseVal = parseAndVerify<ParseType>(dialect, parser);
576 if (!parseVal)
577 return std::nullopt;
578
579 auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
580 if (numArgs != 0 && failed(parser.parseComma()))
581 return std::nullopt;
582 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
583 if (!remainingValues)
584 return std::nullopt;
585 return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
586 remainingValues.value());
587 }
588};
589
590// Partial specialization of the function to parse a comma separated list of
591// specs to parse the last element of the list.
592template <typename ParseType>
593struct ParseCommaSeparatedList<ParseType> {
594 std::optional<std::tuple<ParseType>>
595 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
596 if (auto value = parseAndVerify<ParseType>(dialect, parser))
597 return std::tuple<ParseType>(*value);
598 return std::nullopt;
599 }
600};
601} // namespace
602
603// dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
604//
605// depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
606//
607// arrayed-info ::= `NonArrayed` | `Arrayed`
608//
609// sampling-info ::= `SingleSampled` | `MultiSampled`
610//
611// sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
612//
613// format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
614//
615// image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
616// arrayed-info `,` sampling-info `,`
617// sampler-use-info `,` format `>`
618static Type parseImageType(SPIRVDialect const &dialect,
619 DialectAsmParser &parser) {
620 if (parser.parseLess())
621 return Type();
622
623 auto value =
624 ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
625 ImageSamplingInfo, ImageSamplerUseInfo,
626 ImageFormat>{}(dialect, parser);
627 if (!value)
628 return Type();
629
630 if (parser.parseGreater())
631 return Type();
632 return ImageType::get(*value);
633}
634
635// sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
636static Type parseSampledImageType(SPIRVDialect const &dialect,
637 DialectAsmParser &parser) {
638 if (parser.parseLess())
639 return Type();
640
641 Type parsedType = parseAndVerifySampledImageType(dialect, parser);
642 if (!parsedType)
643 return Type();
644
645 if (parser.parseGreater())
646 return Type();
647 return SampledImageType::get(parsedType);
648}
649
650// Parse decorations associated with a member.
652 SPIRVDialect const &dialect, DialectAsmParser &parser,
653 ArrayRef<Type> memberTypes,
656
657 // Check if the first element is offset.
658 SMLoc offsetLoc = parser.getCurrentLocation();
659 StructType::OffsetInfo offset = 0;
660 OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
661 if (offsetParseResult.has_value()) {
662 if (failed(*offsetParseResult))
663 return failure();
664
665 if (offsetInfo.size() != memberTypes.size() - 1) {
666 return parser.emitError(offsetLoc,
667 "offset specification must be given for "
668 "all members");
669 }
670 offsetInfo.push_back(offset);
671 }
672
673 // Check for no spirv::Decorations.
674 if (succeeded(parser.parseOptionalRSquare()))
675 return success();
676
677 // If there was an offset, make sure to parse the comma.
678 if (offsetParseResult.has_value() && parser.parseComma())
679 return failure();
680
681 // Check for spirv::Decorations.
682 auto parseDecorations = [&]() {
683 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
684 if (!memberDecoration)
685 return failure();
686
687 // Parse member decoration value if it exists.
688 if (succeeded(parser.parseOptionalEqual())) {
689 Attribute memberDecorationValue;
690 if (failed(parser.parseAttribute(memberDecorationValue)))
691 return failure();
692
693 memberDecorationInfo.emplace_back(
694 static_cast<uint32_t>(memberTypes.size() - 1),
695 memberDecoration.value(), memberDecorationValue);
696 } else {
697 memberDecorationInfo.emplace_back(
698 static_cast<uint32_t>(memberTypes.size() - 1),
699 memberDecoration.value(), UnitAttr::get(dialect.getContext()));
700 }
701 return success();
702 };
703 if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
704 failed(parser.parseRSquare()))
705 return failure();
706
707 return success();
708}
709
710// struct-member-decoration ::= integer-literal? spirv-decoration*
711// struct-type ::=
712// `!spirv.struct<` (id `,`)?
713// `(`
714// (spirv-type (`[` struct-member-decoration `]`)?)*
715// `)`
716// (`,` struct-decoration)?
717// `>`
718static Type parseStructType(SPIRVDialect const &dialect,
719 DialectAsmParser &parser) {
720 // TODO: This function is quite lengthy. Break it down into smaller chunks.
721
722 if (parser.parseLess())
723 return Type();
724
725 StringRef identifier;
726 FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
727
728 // Check if this is an identified struct type.
729 if (succeeded(parser.parseOptionalKeyword(&identifier))) {
730 // Check if this is a possible recursive reference.
731 auto structType =
732 StructType::getIdentified(dialect.getContext(), identifier);
733 cyclicParse = parser.tryStartCyclicParse(structType);
734 if (succeeded(parser.parseOptionalGreater())) {
735 if (succeeded(cyclicParse)) {
736 parser.emitError(
737 parser.getNameLoc(),
738 "recursive struct reference not nested in struct definition");
739
740 return Type();
741 }
742
743 return structType;
744 }
745
746 if (failed(parser.parseComma()))
747 return Type();
748
749 if (failed(cyclicParse)) {
750 parser.emitError(parser.getNameLoc(),
751 "identifier already used for an enclosing struct");
752 return Type();
753 }
754 }
755
756 if (failed(parser.parseLParen()))
757 return Type();
758
759 if (succeeded(parser.parseOptionalRParen()) &&
760 succeeded(parser.parseOptionalGreater())) {
761 return StructType::getEmpty(dialect.getContext(), identifier);
762 }
763
764 StructType idStructTy;
765
766 if (!identifier.empty())
767 idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
768
769 SmallVector<Type, 4> memberTypes;
772
773 do {
774 Type memberType;
775 if (parser.parseType(memberType))
776 return Type();
777 if (!isa<SPIRVType>(memberType)) {
778 parser.emitError(parser.getNameLoc(),
779 "member type must be a valid SPIR-V type");
780 return Type();
781 }
782 memberTypes.push_back(memberType);
783
784 if (succeeded(parser.parseOptionalLSquare()))
785 if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
786 memberDecorationInfo))
787 return Type();
788 } while (succeeded(parser.parseOptionalComma()));
789
790 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
791 parser.emitError(parser.getNameLoc(),
792 "offset specification must be given for all members");
793 return Type();
794 }
795
796 if (failed(parser.parseRParen()))
797 return Type();
798
800
801 auto parseStructDecoration = [&]() {
802 std::optional<spirv::Decoration> decoration =
803 parseAndVerify<spirv::Decoration>(dialect, parser);
804 if (!decoration)
805 return failure();
806
807 // Parse decoration value if it exists.
808 if (succeeded(parser.parseOptionalEqual())) {
809 Attribute decorationValue;
810 if (failed(parser.parseAttribute(decorationValue)))
811 return failure();
812
813 structDecorationInfo.emplace_back(decoration.value(), decorationValue);
814 } else {
815 structDecorationInfo.emplace_back(decoration.value(),
816 UnitAttr::get(dialect.getContext()));
817 }
818 return success();
819 };
820
821 while (succeeded(parser.parseOptionalComma()))
822 if (failed(parseStructDecoration()))
823 return Type();
824
825 if (failed(parser.parseGreater()))
826 return Type();
827
828 if (!identifier.empty()) {
829 if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
830 memberDecorationInfo,
831 structDecorationInfo)))
832 return Type();
833 return idStructTy;
834 }
835
836 return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
837 structDecorationInfo);
838}
839
840// spirv-type ::= array-type
841// | element-type
842// | image-type
843// | pointer-type
844// | runtime-array-type
845// | sampled-image-type
846// | struct-type
847Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
848 StringRef keyword;
849 if (parser.parseKeyword(&keyword))
850 return Type();
851
852 if (keyword == "array")
853 return parseArrayType(*this, parser);
854 if (keyword == "coopmatrix")
855 return parseCooperativeMatrixType(*this, parser);
856 if (keyword == "image")
857 return parseImageType(*this, parser);
858 if (keyword == "ptr")
859 return parsePointerType(*this, parser);
860 if (keyword == "rtarray")
861 return parseRuntimeArrayType(*this, parser);
862 if (keyword == "sampled_image")
863 return parseSampledImageType(*this, parser);
864 if (keyword == "struct")
865 return parseStructType(*this, parser);
866 if (keyword == "matrix")
867 return parseMatrixType(*this, parser);
868 if (keyword == "arm.tensor")
869 return parseTensorArmType(*this, parser);
870 parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
871 return Type();
872}
873
874//===----------------------------------------------------------------------===//
875// Type Printing
876//===----------------------------------------------------------------------===//
877
878static void print(ArrayType type, DialectAsmPrinter &os) {
879 os << "array<" << type.getNumElements() << " x " << type.getElementType();
880 if (unsigned stride = type.getArrayStride())
881 os << ", stride=" << stride;
882 os << ">";
883}
884
886 os << "rtarray<" << type.getElementType();
887 if (unsigned stride = type.getArrayStride())
888 os << ", stride=" << stride;
889 os << ">";
890}
891
892static void print(PointerType type, DialectAsmPrinter &os) {
893 os << "ptr<" << type.getPointeeType() << ", "
894 << stringifyStorageClass(type.getStorageClass()) << ">";
895}
896
897static void print(ImageType type, DialectAsmPrinter &os) {
898 os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
899 << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
900 << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
901 << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
902 << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
903 << stringifyImageFormat(type.getImageFormat()) << ">";
904}
905
907 os << "sampled_image<" << type.getImageType() << ">";
908}
909
910static void print(StructType type, DialectAsmPrinter &os) {
911 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
912
913 os << "struct<";
914
915 if (type.isIdentified()) {
916 os << type.getIdentifier();
917
918 cyclicPrint = os.tryStartCyclicPrint(type);
919 if (failed(cyclicPrint)) {
920 os << ">";
921 return;
922 }
923
924 os << ", ";
925 }
926
927 os << "(";
928
929 auto printMember = [&](unsigned i) {
930 os << type.getElementType(i);
932 type.getMemberDecorations(i, decorations);
933 if (type.hasOffset() || !decorations.empty()) {
934 os << " [";
935 if (type.hasOffset()) {
936 os << type.getMemberOffset(i);
937 if (!decorations.empty())
938 os << ", ";
939 }
940 auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
941 os << stringifyDecoration(decoration.decoration);
942 if (decoration.hasValue()) {
943 os << "=";
944 os.printAttributeWithoutType(decoration.decorationValue);
945 }
946 };
947 llvm::interleaveComma(decorations, os, eachFn);
948 os << "]";
949 }
950 };
951 llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
952 printMember);
953 os << ")";
954
956 type.getStructDecorations(decorations);
957 if (!decorations.empty()) {
958 os << ", ";
959 auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
960 os << stringifyDecoration(decoration.decoration);
961 if (decoration.hasValue()) {
962 os << "=";
963 os.printAttributeWithoutType(decoration.decorationValue);
964 }
965 };
966 llvm::interleaveComma(decorations, os, eachFn);
967 }
968
969 os << ">";
970}
971
973 os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
974 << type.getElementType() << ", " << type.getScope() << ", "
975 << type.getUse() << ">";
976}
977
978static void print(MatrixType type, DialectAsmPrinter &os) {
979 os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
980 os << ">";
981}
982
983static void print(TensorArmType type, DialectAsmPrinter &os) {
984 os << "arm.tensor<";
985
986 llvm::interleave(
987 type.getShape(), os,
988 [&](int64_t dim) {
989 if (ShapedType::isDynamic(dim))
990 os << '?';
991 else
992 os << dim;
993 },
994 "x");
995 if (!type.hasRank()) {
996 os << "*";
997 }
998 os << "x" << type.getElementType() << ">";
999}
1000
1001void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
1002 TypeSwitch<Type>(type)
1005 [&](auto type) { print(type, os); })
1006 .DefaultUnreachable("Unhandled SPIR-V type");
1007}
1008
1009//===----------------------------------------------------------------------===//
1010// Constant
1011//===----------------------------------------------------------------------===//
1012
1013Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
1014 Attribute value, Type type,
1015 Location loc) {
1016 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
1017 return ub::PoisonOp::create(builder, loc, type, poison);
1018
1019 if (!spirv::ConstantOp::isBuildableWith(type))
1020 return nullptr;
1021
1022 return spirv::ConstantOp::create(builder, loc, type, value);
1023}
1024
1025//===----------------------------------------------------------------------===//
1026// Shader Interface ABI
1027//===----------------------------------------------------------------------===//
1028
1029LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
1030 NamedAttribute attribute) {
1031 StringRef symbol = attribute.getName().strref();
1032 Attribute attr = attribute.getValue();
1033
1034 if (symbol == spirv::getEntryPointABIAttrName()) {
1035 if (!isa<spirv::EntryPointABIAttr>(attr)) {
1036 return op->emitError("'")
1037 << symbol << "' attribute must be an entry point ABI attribute";
1038 }
1039 } else if (symbol == spirv::getTargetEnvAttrName()) {
1040 if (!isa<spirv::TargetEnvAttr>(attr))
1041 return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
1042 } else {
1043 return op->emitError("found unsupported '")
1044 << symbol << "' attribute on operation";
1045 }
1046
1047 return success();
1048}
1049
1050/// Verifies the given SPIR-V `attribute` attached to a value of the given
1051/// `valueType` is valid.
1052static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
1053 NamedAttribute attribute) {
1054 StringRef symbol = attribute.getName().strref();
1055 Attribute attr = attribute.getValue();
1056
1057 if (symbol == spirv::getInterfaceVarABIAttrName()) {
1058 auto varABIAttr = dyn_cast<spirv::InterfaceVarABIAttr>(attr);
1059 if (!varABIAttr)
1060 return emitError(loc, "'")
1061 << symbol << "' must be a spirv::InterfaceVarABIAttr";
1062
1063 if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
1064 return emitError(loc, "'") << symbol
1065 << "' attribute cannot specify storage class "
1066 "when attaching to a non-scalar value";
1067 return success();
1068 }
1069 if (symbol == spirv::DecorationAttr::name) {
1070 if (!isa<spirv::DecorationAttr>(attr))
1071 return emitError(loc, "'")
1072 << symbol << "' must be a spirv::DecorationAttr";
1073 return success();
1074 }
1075
1076 return emitError(loc, "found unsupported '")
1077 << symbol << "' attribute on region argument";
1078}
1079
1080LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
1081 unsigned regionIndex,
1082 unsigned argIndex,
1083 NamedAttribute attribute) {
1084 auto funcOp = dyn_cast<FunctionOpInterface>(op);
1085 if (!funcOp)
1086 return success();
1087 Type argType = funcOp.getArgumentTypes()[argIndex];
1088
1089 return verifyRegionAttribute(op->getLoc(), argType, attribute);
1090}
1091
1092LogicalResult SPIRVDialect::verifyRegionResultAttribute(
1093 Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
1094 NamedAttribute attribute) {
1095 if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op))
1096 return verifyRegionAttribute(
1097 op->getLoc(), graphOp.getResultTypes()[resultIndex], attribute);
1098 return op->emitError(
1099 "cannot attach SPIR-V attributes to region result which is "
1100 "not part of a spirv::GraphARMOp type");
1101}
return success()
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
std::optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional , stride = N assembly segment.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid.
static Type parseTensorArmType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static void print(ArrayType type, DialectAsmPrinter &os)
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
std::optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static bool containsReturn(Region &region)
Returns true if the given region contains spirv.Return or spirv.ReturnValue ops.
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalStar()=0
Parse a '*' token if present.
FailureOr< CyclicParseReset > tryStartCyclicParse(AttrOrTypeT attrOrType)
Attempts to start a cyclic parsing region for attrOrType.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseXInDimensionList()=0
Parse an 'x' token in a dimension list, handling the case where the x is juxtaposed with an element t...
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
FailureOr< CyclicPrintReset > tryStartCyclicPrint(AttrOrTypeT attrOrType)
Attempts to start a cyclic printing region for attrOrType.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:147
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType get(Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
SPIR-V struct type.
Definition SPIRVTypes.h:251
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
SPIR-V TensorARM Type.
Definition SPIRVTypes.h:486
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
ArrayRef< int64_t > getShape() const
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136