MLIR 23.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
24#include "mlir/IR/Builders.h"
28#include "mlir/IR/Operation.h"
31#include "llvm/ADT/APFloat.h"
32#include "llvm/ADT/APInt.h"
33#include "llvm/ADT/ArrayRef.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/StringExtras.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/InterleavedRange.h"
38#include <cassert>
39#include <numeric>
40#include <optional>
41
42using namespace mlir;
43using namespace mlir::spirv::AttrNames;
44
45//===----------------------------------------------------------------------===//
46// Common utility functions
47//===----------------------------------------------------------------------===//
48
49LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
50 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
51 if (!constOp) {
52 return failure();
53 }
54 auto valueAttr = constOp.getValue();
55 auto integerValueAttr = dyn_cast<IntegerAttr>(valueAttr);
56 if (!integerValueAttr) {
57 return failure();
58 }
59
60 if (integerValueAttr.getType().isSignlessInteger())
61 value = integerValueAttr.getInt();
62 else
63 value = integerValueAttr.getSInt();
64
65 return success();
66}
67
68LogicalResult
70 spirv::MemorySemantics memorySemantics) {
71 // According to the SPIR-V specification:
72 // "Despite being a mask and allowing multiple bits to be combined, it is
73 // invalid for more than one of these four bits to be set: Acquire, Release,
74 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
75 // Release semantics is done by setting the AcquireRelease bit, not by setting
76 // two bits."
77 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
78 spirv::MemorySemantics::Release |
79 spirv::MemorySemantics::AcquireRelease |
80 spirv::MemorySemantics::SequentiallyConsistent;
81
82 auto bitCount =
83 llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
84 if (bitCount > 1) {
85 return op->emitError(
86 "expected at most one of these four memory constraints "
87 "to be set: `Acquire`, `Release`,"
88 "`AcquireRelease` or `SequentiallyConsistent`");
89 }
90 return success();
91}
92
94 SmallVectorImpl<StringRef> &elidedAttrs) {
95 // Print optional descriptor binding
96 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
97 stringifyDecoration(spirv::Decoration::DescriptorSet));
98 auto bindingName = llvm::convertToSnakeFromCamelCase(
99 stringifyDecoration(spirv::Decoration::Binding));
100 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
101 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
102 if (descriptorSet && binding) {
103 elidedAttrs.push_back(descriptorSetName);
104 elidedAttrs.push_back(bindingName);
105 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
106 << ")";
107 }
108
109 // Print BuiltIn attribute if present
110 auto builtInName = llvm::convertToSnakeFromCamelCase(
111 stringifyDecoration(spirv::Decoration::BuiltIn));
112 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
113 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
114 elidedAttrs.push_back(builtInName);
115 }
116
117 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
118}
119
123 Type type;
124 // If the operand list is in-between parentheses, then we have a generic form.
125 // (see the fallback in `printOneResultOp`).
126 SMLoc loc = parser.getCurrentLocation();
127 if (!parser.parseOptionalLParen()) {
128 if (parser.parseOperandList(ops) || parser.parseRParen() ||
129 parser.parseOptionalAttrDict(result.attributes) ||
130 parser.parseColon() || parser.parseType(type))
131 return failure();
132 auto fnType = dyn_cast<FunctionType>(type);
133 if (!fnType) {
134 parser.emitError(loc, "expected function type");
135 return failure();
136 }
137 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
138 return failure();
139 result.addTypes(fnType.getResults());
140 return success();
141 }
142 return failure(parser.parseOperandList(ops) ||
143 parser.parseOptionalAttrDict(result.attributes) ||
144 parser.parseColonType(type) ||
145 parser.resolveOperands(ops, type, result.operands) ||
146 parser.addTypeToList(type, result.types));
147}
148
150 assert(op->getNumResults() == 1 && "op should have one result");
151
152 // If not all the operand and result types are the same, just use the
153 // generic assembly form to avoid omitting information in printing.
154 auto resultType = op->getResult(0).getType();
155 if (llvm::any_of(op->getOperandTypes(),
156 [&](Type type) { return type != resultType; })) {
157 p.printGenericOp(op, /*printOpName=*/false);
158 return;
159 }
160
161 p << ' ';
162 p.printOperands(op->getOperands());
164 // Now we can output only one type for all operands and the result.
165 p << " : " << resultType;
166}
167
168template <typename BlockReadWriteOpTy>
169static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
170 Value ptr, Value val) {
171 auto valType = val.getType();
172 if (auto valVecTy = dyn_cast<VectorType>(valType))
173 valType = valVecTy.getElementType();
174
175 if (valType != cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
176 return op.emitOpError("mismatch in result type and pointer type");
177 }
178 return success();
179}
180
181/// Walks the given type hierarchy with the given indices, potentially down
182/// to component granularity, to select an element type. Returns null type and
183/// emits errors with the given loc on failure.
184static Type
186 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
187 if (indices.empty()) {
188 emitErrorFn("expected at least one index for spirv.CompositeExtract");
189 return nullptr;
190 }
191
192 for (auto index : indices) {
193 if (auto cType = dyn_cast<spirv::CompositeType>(type)) {
194 if (cType.hasCompileTimeKnownNumElements() &&
195 (index < 0 ||
196 static_cast<uint64_t>(index) >= cType.getNumElements())) {
197 emitErrorFn("index ") << index << " out of bounds for " << type;
198 return nullptr;
199 }
200 type = cType.getElementType(index);
201 } else {
202 emitErrorFn("cannot extract from non-composite type ")
203 << type << " with index " << index;
204 return nullptr;
205 }
206 }
207 return type;
208}
209
210static Type
212 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
213 auto indicesArrayAttr = dyn_cast<ArrayAttr>(indices);
214 if (!indicesArrayAttr) {
215 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
216 return nullptr;
217 }
218 if (indicesArrayAttr.empty()) {
219 emitErrorFn("expected at least one index for spirv.CompositeExtract");
220 return nullptr;
221 }
222
223 SmallVector<int32_t, 2> indexVals;
224 for (auto indexAttr : indicesArrayAttr) {
225 auto indexIntAttr = dyn_cast<IntegerAttr>(indexAttr);
226 if (!indexIntAttr) {
227 emitErrorFn("expected an 32-bit integer for index, but found '")
228 << indexAttr << "'";
229 return nullptr;
230 }
231 indexVals.push_back(indexIntAttr.getInt());
232 }
233 return getElementType(type, indexVals, emitErrorFn);
234}
235
237 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
238 return ::mlir::emitError(loc, err);
239 };
240 return getElementType(type, indices, errorFn);
241}
242
244 SMLoc loc) {
245 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
246 return parser.emitError(loc, err);
247 };
248 return getElementType(type, indices, errorFn);
249}
250
251template <typename ExtendedBinaryOp>
252static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
253 auto resultType = cast<spirv::StructType>(op.getType());
254 if (resultType.getNumElements() != 2)
255 return op.emitOpError("expected result struct type containing two members");
256
257 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
258 resultType.getElementType(0),
259 resultType.getElementType(1)}))
260 return op.emitOpError(
261 "expected all operand types and struct member types are the same");
262
263 return success();
264}
265
269 if (parser.parseOptionalAttrDict(result.attributes) ||
270 parser.parseOperandList(operands) || parser.parseColon())
271 return failure();
272
273 Type resultType;
274 SMLoc loc = parser.getCurrentLocation();
275 if (parser.parseType(resultType))
276 return failure();
277
278 auto structType = dyn_cast<spirv::StructType>(resultType);
279 if (!structType || structType.getNumElements() != 2)
280 return parser.emitError(loc, "expected spirv.struct type with two members");
281
282 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
283 if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
284 return failure();
285
286 result.addTypes(resultType);
287 return success();
288}
289
291 OpAsmPrinter &printer) {
292 printer << ' ';
293 printer.printOptionalAttrDict(op->getAttrs());
294 printer.printOperands(op->getOperands());
295 printer << " : " << op->getResultTypes().front();
296}
297
298static LogicalResult verifyShiftOp(Operation *op) {
299 if (op->getOperand(0).getType() != op->getResult(0).getType()) {
300 return op->emitError("expected the same type for the first operand and "
301 "result, but provided ")
302 << op->getOperand(0).getType() << " and "
303 << op->getResult(0).getType();
304 }
305 return success();
306}
307
308//===----------------------------------------------------------------------===//
309// spirv.mlir.addressof
310//===----------------------------------------------------------------------===//
311
312void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
313 spirv::GlobalVariableOp var) {
314 build(builder, state, var.getType(), SymbolRefAttr::get(var));
315}
316
317LogicalResult spirv::AddressOfOp::verify() {
318 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
319 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
320 getVariableAttr()));
321 if (!varOp) {
322 return emitOpError("expected spirv.GlobalVariable symbol");
323 }
324 if (getPointer().getType() != varOp.getType()) {
325 return emitOpError(
326 "result type mismatch with the referenced global variable's type");
327 }
328 return success();
329}
330
331//===----------------------------------------------------------------------===//
332// spirv.CompositeConstruct
333//===----------------------------------------------------------------------===//
334
335LogicalResult spirv::CompositeConstructOp::verify() {
336 operand_range constituents = this->getConstituents();
337
338 // There are 4 cases with varying verification rules:
339 // 1. Cooperative Matrices (1 constituent)
340 // 2. Structs (1 constituent for each member)
341 // 3. Arrays (1 constituent for each array element)
342 // 4. Vectors (1 constituent (sub-)element for each vector element)
343
344 auto coopElementType = llvm::TypeSwitch<Type, Type>(getType())
345 .Case([](spirv::CooperativeMatrixType coopType) {
346 return coopType.getElementType();
347 })
348 .Default(nullptr);
349
350 // Case 1. -- matrices.
351 if (coopElementType) {
352 if (constituents.size() != 1)
353 return emitOpError("has incorrect number of operands: expected ")
354 << "1, but provided " << constituents.size();
355 if (coopElementType != constituents.front().getType())
356 return emitOpError("operand type mismatch: expected operand type ")
357 << coopElementType << ", but provided "
358 << constituents.front().getType();
359 return success();
360 }
361
362 // Case 2./3./4. -- number of constituents matches the number of elements.
363 auto cType = cast<spirv::CompositeType>(getType());
364 if (constituents.size() == cType.getNumElements()) {
365 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
366 if (constituents[index].getType() != cType.getElementType(index)) {
367 return emitOpError("operand type mismatch: expected operand type ")
368 << cType.getElementType(index) << ", but provided "
369 << constituents[index].getType();
370 }
371 }
372 return success();
373 }
374
375 // Case 4. -- check that all constituents add up to the expected vector type.
376 auto resultType = dyn_cast<VectorType>(cType);
377 if (!resultType)
378 return emitOpError(
379 "expected to return a vector or cooperative matrix when the number of "
380 "constituents is less than what the result needs");
381
383 for (Value component : constituents) {
384 if (!isa<VectorType>(component.getType()) &&
385 !component.getType().isIntOrFloat())
386 return emitOpError("operand type mismatch: expected operand to have "
387 "a scalar or vector type, but provided ")
388 << component.getType();
389
390 Type elementType = component.getType();
391 if (auto vectorType = dyn_cast<VectorType>(component.getType())) {
392 sizes.push_back(vectorType.getNumElements());
393 elementType = vectorType.getElementType();
394 } else {
395 sizes.push_back(1);
396 }
397
398 if (elementType != resultType.getElementType())
399 return emitOpError("operand element type mismatch: expected to be ")
400 << resultType.getElementType() << ", but provided " << elementType;
401 }
402 unsigned totalCount = llvm::sum_of(sizes);
403 if (totalCount != cType.getNumElements())
404 return emitOpError("has incorrect number of operands: expected ")
405 << cType.getNumElements() << ", but provided " << totalCount;
406 return success();
407}
408
409//===----------------------------------------------------------------------===//
410// spirv.CompositeExtract
411//===----------------------------------------------------------------------===//
412
413void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
414 Value composite,
416 auto indexAttr = builder.getI32ArrayAttr(indices);
417 auto elementType =
418 getElementType(composite.getType(), indexAttr, state.location);
419 if (!elementType) {
420 return;
421 }
422 build(builder, state, elementType, composite, indexAttr);
423}
424
425ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
427 OpAsmParser::UnresolvedOperand compositeInfo;
428 Attribute indicesAttr;
429 StringRef indicesAttrName =
430 spirv::CompositeExtractOp::getIndicesAttrName(result.name);
431 Type compositeType;
432 SMLoc attrLocation;
433
434 if (parser.parseOperand(compositeInfo) ||
435 parser.getCurrentLocation(&attrLocation) ||
436 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
437 parser.parseColonType(compositeType) ||
438 parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
439 return failure();
440 }
441
442 Type resultType =
443 getElementType(compositeType, indicesAttr, parser, attrLocation);
444 if (!resultType) {
445 return failure();
446 }
447 result.addTypes(resultType);
448 return success();
449}
450
451void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
452 printer << ' ' << getComposite() << getIndices() << " : "
453 << getComposite().getType();
454}
455
456LogicalResult spirv::CompositeExtractOp::verify() {
457 auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
458 auto resultType =
459 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
460 if (!resultType)
461 return failure();
462
463 if (resultType != getType()) {
464 return emitOpError("invalid result type: expected ")
465 << resultType << " but provided " << getType();
466 }
467
468 return success();
469}
470
471//===----------------------------------------------------------------------===//
472// spirv.CompositeInsert
473//===----------------------------------------------------------------------===//
474
475void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
476 Value object, Value composite,
478 auto indexAttr = builder.getI32ArrayAttr(indices);
479 build(builder, state, composite.getType(), object, composite, indexAttr);
480}
481
482ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
485 Type objectType, compositeType;
486 Attribute indicesAttr;
487 StringRef indicesAttrName =
488 spirv::CompositeInsertOp::getIndicesAttrName(result.name);
489 auto loc = parser.getCurrentLocation();
490
491 return failure(
492 parser.parseOperandList(operands, 2) ||
493 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
494 parser.parseColonType(objectType) ||
495 parser.parseKeywordType("into", compositeType) ||
496 parser.resolveOperands(operands, {objectType, compositeType}, loc,
497 result.operands) ||
498 parser.addTypesToList(compositeType, result.types));
499}
500
501LogicalResult spirv::CompositeInsertOp::verify() {
502 auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
503 auto objectType =
504 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
505 if (!objectType)
506 return failure();
507
508 if (objectType != getObject().getType()) {
509 return emitOpError("object operand type should be ")
510 << objectType << ", but found " << getObject().getType();
511 }
512
513 if (getComposite().getType() != getType()) {
514 return emitOpError("result type should be the same as "
515 "the composite type, but found ")
516 << getComposite().getType() << " vs " << getType();
517 }
518
519 return success();
520}
521
522void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
523 printer << " " << getObject() << ", " << getComposite() << getIndices()
524 << " : " << getObject().getType() << " into "
525 << getComposite().getType();
526}
527
528//===----------------------------------------------------------------------===//
529// spirv.Constant
530//===----------------------------------------------------------------------===//
531
532ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
534 Attribute value;
535 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
536 if (parser.parseAttribute(value, valueAttrName, result.attributes))
537 return failure();
538
539 Type type = NoneType::get(parser.getContext());
540 if (auto typedAttr = dyn_cast<TypedAttr>(value))
541 type = typedAttr.getType();
542 if (isa<NoneType, TensorType>(type)) {
543 if (parser.parseColonType(type))
544 return failure();
545 }
546
547 if (isa<TensorArmType>(type)) {
548 if (parser.parseOptionalColon().succeeded())
549 if (parser.parseType(type))
550 return failure();
551 }
552
553 return parser.addTypeToList(type, result.types);
554}
555
556void spirv::ConstantOp::print(OpAsmPrinter &printer) {
557 printer << ' ' << getValue();
558 if (isa<spirv::ArrayType, spirv::StructType>(getType()))
559 printer << " : " << getType();
560}
561
562static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
563 Type opType) {
564 if (isa<spirv::CooperativeMatrixType>(opType)) {
565 auto denseAttr = dyn_cast<DenseElementsAttr>(value);
566 if (!denseAttr || !denseAttr.isSplat())
567 return op.emitOpError("expected a splat dense attribute for cooperative "
568 "matrix constant, but found ")
569 << denseAttr;
570 }
571 if (isa<IntegerAttr, FloatAttr>(value)) {
572 auto valueType = cast<TypedAttr>(value).getType();
573 if (valueType != opType)
574 return op.emitOpError("result type (")
575 << opType << ") does not match value type (" << valueType << ")";
576 return success();
577 }
578 if (isa<DenseTypedElementsAttr, SparseElementsAttr>(value)) {
579 auto valueType = cast<TypedAttr>(value).getType();
580 if (valueType == opType)
581 return success();
582 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
583 auto shapedType = dyn_cast<ShapedType>(valueType);
584 if (!arrayType)
585 return op.emitOpError("result or element type (")
586 << opType << ") does not match value type (" << valueType
587 << "), must be the same or spirv.array";
588
589 int numElements = arrayType.getNumElements();
590 auto opElemType = arrayType.getElementType();
591 while (auto t = dyn_cast<spirv::ArrayType>(opElemType)) {
592 numElements *= t.getNumElements();
593 opElemType = t.getElementType();
594 }
595 if (!opElemType.isIntOrFloat())
596 return op.emitOpError("only support nested array result type");
597
598 auto valueElemType = shapedType.getElementType();
599 if (valueElemType != opElemType) {
600 return op.emitOpError("result element type (")
601 << opElemType << ") does not match value element type ("
602 << valueElemType << ")";
603 }
604
605 if (numElements != shapedType.getNumElements()) {
606 return op.emitOpError("result number of elements (")
607 << numElements << ") does not match value number of elements ("
608 << shapedType.getNumElements() << ")";
609 }
610 return success();
611 }
612 if (auto arrayAttr = dyn_cast<ArrayAttr>(value)) {
613 if (auto structType = dyn_cast<spirv::StructType>(opType)) {
614 // Identified (possibly recursive) structs are not supported as constants.
615 if (structType.isIdentified())
616 return op.emitOpError(
617 "cannot have an identified struct as a constant type");
618 if (arrayAttr.size() != structType.getNumElements())
619 return op.emitOpError("number of constituents (")
620 << arrayAttr.size()
621 << ") does not match number of struct members ("
622 << structType.getNumElements() << ")";
623 for (auto [idx, element] : llvm::enumerate(arrayAttr.getValue())) {
624 if (failed(verifyConstantType(op, element,
625 structType.getElementType(idx))))
626 return failure();
627 }
628 return success();
629 }
630 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
631 if (!arrayType)
632 return op.emitOpError(
633 "must have spirv.array or spirv.struct result type for array value");
634 Type elemType = arrayType.getElementType();
635 for (Attribute element : arrayAttr.getValue()) {
636 // Verify array elements recursively.
637 if (failed(verifyConstantType(op, element, elemType)))
638 return failure();
639 }
640 return success();
641 }
642 return op.emitOpError("cannot have attribute: ") << value;
643}
644
645LogicalResult spirv::ConstantOp::verify() {
646 // ODS already generates checks to make sure the result type is valid. We just
647 // need to additionally check that the value's attribute type is consistent
648 // with the result type.
649 return verifyConstantType(*this, getValueAttr(), getType());
650}
651
652bool spirv::ConstantOp::isBuildableWith(Type type) {
653 // Must be valid SPIR-V type first.
654 if (!isa<spirv::SPIRVType>(type))
655 return false;
656
657 if (isa<SPIRVDialect>(type.getDialect())) {
658 if (auto structType = dyn_cast<spirv::StructType>(type))
659 return !structType.isIdentified();
660 return isa<spirv::ArrayType>(type);
661 }
662
663 return true;
664}
665
666spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
667 OpBuilder &builder) {
668 if (auto intType = dyn_cast<IntegerType>(type)) {
669 unsigned width = intType.getWidth();
670 if (width == 1)
671 return spirv::ConstantOp::create(builder, loc, type,
672 builder.getBoolAttr(false));
673 return spirv::ConstantOp::create(
674 builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
675 }
676 if (auto floatType = dyn_cast<FloatType>(type)) {
677 return spirv::ConstantOp::create(builder, loc, type,
678 builder.getFloatAttr(floatType, 0.0));
679 }
680 if (auto vectorType = dyn_cast<VectorType>(type)) {
681 Type elemType = vectorType.getElementType();
682 if (isa<IntegerType>(elemType)) {
683 return spirv::ConstantOp::create(
684 builder, loc, type,
685 DenseElementsAttr::get(vectorType,
686 IntegerAttr::get(elemType, 0).getValue()));
687 }
688 if (isa<FloatType>(elemType)) {
689 return spirv::ConstantOp::create(
690 builder, loc, type,
691 DenseFPElementsAttr::get(vectorType,
692 FloatAttr::get(elemType, 0.0).getValue()));
693 }
694 }
695
696 llvm_unreachable("unimplemented types for ConstantOp::getZero()");
697}
698
699spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
700 OpBuilder &builder) {
701 if (auto intType = dyn_cast<IntegerType>(type)) {
702 unsigned width = intType.getWidth();
703 if (width == 1)
704 return spirv::ConstantOp::create(builder, loc, type,
705 builder.getBoolAttr(true));
706 return spirv::ConstantOp::create(
707 builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
708 }
709 if (auto floatType = dyn_cast<FloatType>(type)) {
710 return spirv::ConstantOp::create(builder, loc, type,
711 builder.getFloatAttr(floatType, 1.0));
712 }
713 if (auto vectorType = dyn_cast<VectorType>(type)) {
714 Type elemType = vectorType.getElementType();
715 if (isa<IntegerType>(elemType)) {
716 return spirv::ConstantOp::create(
717 builder, loc, type,
718 DenseElementsAttr::get(vectorType,
719 IntegerAttr::get(elemType, 1).getValue()));
720 }
721 if (isa<FloatType>(elemType)) {
722 return spirv::ConstantOp::create(
723 builder, loc, type,
724 DenseFPElementsAttr::get(vectorType,
725 FloatAttr::get(elemType, 1.0).getValue()));
726 }
727 }
728
729 llvm_unreachable("unimplemented types for ConstantOp::getOne()");
730}
731
732void mlir::spirv::ConstantOp::getAsmResultNames(
733 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
734 Type type = getType();
735
736 SmallString<32> specialNameBuffer;
737 llvm::raw_svector_ostream specialName(specialNameBuffer);
738 specialName << "cst";
739
740 IntegerType intTy = dyn_cast<IntegerType>(type);
741
742 if (IntegerAttr intCst = dyn_cast<IntegerAttr>(getValue())) {
743 assert(intTy);
744
745 if (intTy.getWidth() == 1) {
746 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
747 }
748
749 if (intTy.isSignless()) {
750 specialName << intCst.getInt();
751 } else if (intTy.isUnsigned()) {
752 specialName << intCst.getUInt();
753 } else {
754 specialName << intCst.getSInt();
755 }
756 }
757
758 if (intTy || isa<FloatType>(type)) {
759 specialName << '_' << type;
760 }
761
762 if (auto vecType = dyn_cast<VectorType>(type)) {
763 specialName << "_vec_";
764 specialName << vecType.getDimSize(0);
765
766 Type elementType = vecType.getElementType();
767
768 if (isa<IntegerType>(elementType) || isa<FloatType>(elementType)) {
769 specialName << "x" << elementType;
770 }
771 }
772
773 setNameFn(getResult(), specialName.str());
774}
775
776void mlir::spirv::AddressOfOp::getAsmResultNames(
777 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
778 SmallString<32> specialNameBuffer;
779 llvm::raw_svector_ostream specialName(specialNameBuffer);
780 specialName << getVariable() << "_addr";
781 setNameFn(getResult(), specialName.str());
782}
783
784//===----------------------------------------------------------------------===//
785// spirv.EXTConstantCompositeReplicate
786//===----------------------------------------------------------------------===//
787
788// Returns type of attribute. In case of a TypedAttr this will simply return
789// the type. But for an ArrayAttr which is untyped and can be multidimensional
790// it creates the ArrayType recursively.
792 if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
793 return typedAttr.getType();
794 }
795
796 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
797 return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
798 }
799
800 return nullptr;
801}
802
803LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
804 Type valueType = getValueType(getValue());
805 if (!valueType)
806 return emitError("unknown value attribute type");
807
808 auto compositeType = dyn_cast<spirv::CompositeType>(getType());
809 if (!compositeType)
810 return emitError("result type is not a composite type");
811
812 Type compositeElementType = compositeType.getElementType(0);
813
814 SmallVector<Type, 3> possibleTypes = {compositeElementType};
815 while (auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
816 compositeElementType = type.getElementType(0);
817 possibleTypes.push_back(compositeElementType);
818 }
819
820 if (!is_contained(possibleTypes, valueType)) {
821 return emitError("expected value attribute type ")
822 << interleaved(possibleTypes, " or ") << ", but got: " << valueType;
823 }
824
825 return success();
826}
827
828//===----------------------------------------------------------------------===//
829// spirv.ControlBarrier
830//===----------------------------------------------------------------------===//
831
832LogicalResult spirv::ControlBarrierOp::verify() {
833 return verifyMemorySemantics(getOperation(), getMemorySemantics());
834}
835
836//===----------------------------------------------------------------------===//
837// spirv.EntryPoint
838//===----------------------------------------------------------------------===//
839
840void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
841 spirv::ExecutionModel executionModel,
842 spirv::FuncOp function,
843 ArrayRef<Attribute> interfaceVars) {
844 build(builder, state,
845 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
846 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
847}
848
849ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
851 spirv::ExecutionModel execModel;
852 SmallVector<Attribute, 4> interfaceVars;
853
855 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
856 parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
857 return failure();
858 }
859
860 if (!parser.parseOptionalComma()) {
861 // Parse the interface variables
862 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
863 // The name of the interface variable attribute isn't important
864 FlatSymbolRefAttr var;
865 NamedAttrList attrs;
866 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
867 return failure();
868 interfaceVars.push_back(var);
869 return success();
870 }))
871 return failure();
872 }
873 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
874 parser.getBuilder().getArrayAttr(interfaceVars));
875 return success();
876}
877
878void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
879 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
880 printer.printSymbolName(getFn());
881 auto interfaceVars = getInterface().getValue();
882 if (!interfaceVars.empty())
883 printer << ", " << llvm::interleaved(interfaceVars);
884}
885
886LogicalResult spirv::EntryPointOp::verify() {
887 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
888 // verification.
889 return success();
890}
891
892//===----------------------------------------------------------------------===//
893// spirv.ExecutionMode
894//===----------------------------------------------------------------------===//
895
896void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
897 spirv::FuncOp function,
898 spirv::ExecutionMode executionMode,
899 ArrayRef<int32_t> params) {
900 build(builder, state, SymbolRefAttr::get(function),
901 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
902 builder.getI32ArrayAttr(params));
903}
904
905ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
907 spirv::ExecutionMode execMode;
908 Attribute fn;
909 if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
911 return failure();
912 }
913
915 Type i32Type = parser.getBuilder().getIntegerType(32);
916 while (!parser.parseOptionalComma()) {
917 NamedAttrList attr;
918 Attribute value;
919 if (parser.parseAttribute(value, i32Type, "value", attr)) {
920 return failure();
921 }
922 values.push_back(cast<IntegerAttr>(value).getInt());
923 }
924 StringRef valuesAttrName =
925 spirv::ExecutionModeOp::getValuesAttrName(result.name);
926 result.addAttribute(valuesAttrName,
927 parser.getBuilder().getI32ArrayAttr(values));
928 return success();
929}
930
931void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
932 printer << " ";
933 printer.printSymbolName(getFn());
934 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
935 ArrayAttr values = this->getValues();
936 if (!values.empty())
937 printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
938}
939
940//===----------------------------------------------------------------------===//
941// spirv.ExecutionModeId
942//===----------------------------------------------------------------------===//
943
944ParseResult spirv::ExecutionModeIdOp::parse(OpAsmParser &parser,
946 ExecutionMode execMode;
947 if (Attribute fn;
948 parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
949 parseEnumStrAttr<ExecutionModeAttr>(execMode, parser, result)) {
950 return failure();
951 }
952
954 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
955 FlatSymbolRefAttr attr;
956 if (parser.parseAttribute(attr))
957 return failure();
958 values.push_back(attr);
959 return success();
960 })) {
961 return failure();
962 }
963
964 StringRef valuesAttrName = getValuesAttrName(result.name);
965 ArrayAttr valuesAttr = parser.getBuilder().getArrayAttr(values);
966 result.addAttribute(valuesAttrName, valuesAttr);
967 return success();
968}
969
970void spirv::ExecutionModeIdOp::print(OpAsmPrinter &printer) {
971 printer << " ";
972 printer.printSymbolName(getFn());
973 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\" ";
974
975 llvm::interleaveComma(
976 getValues().getAsValueRange<FlatSymbolRefAttr>(), printer,
977 [&](StringRef value) { printer.printSymbolName(value); });
978}
979
980LogicalResult spirv::ExecutionModeIdOp::verify() {
981 // Valid as of SPIRV 1.6
982 switch (getExecutionMode()) {
983 case ExecutionMode::SubgroupsPerWorkgroupId:
984 case ExecutionMode::LocalSizeId:
985 case ExecutionMode::LocalSizeHintId:
986 break;
987 default:
988 return emitOpError("expected ExecutionMode that takes extra operands that "
989 "are <id> operands, got: ")
990 << stringifyExecutionMode(getExecutionMode());
991 }
992
993 if (getValues().empty())
994 return emitOpError("expected at least one value operand");
995
996 for (Attribute value : getValues()) {
997 auto valueSymbol = dyn_cast<FlatSymbolRefAttr>(value);
998 if (!valueSymbol)
999 return emitOpError("expected value operands to be symbol reference");
1001 (*this)->getParentOp(), valueSymbol);
1002 if (!valueOp)
1003 return emitOpError("cannot find symbol referenced by value operand: ")
1004 << valueSymbol.getValue();
1005 }
1006
1007 return success();
1008}
1009
1010//===----------------------------------------------------------------------===//
1011// spirv.func
1012//===----------------------------------------------------------------------===//
1013
1014ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
1016 SmallVector<DictionaryAttr> resultAttrs;
1017 SmallVector<Type> resultTypes;
1018 auto &builder = parser.getBuilder();
1019
1020 // Parse the name as a symbol.
1021 StringAttr nameAttr;
1022 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1023 result.attributes))
1024 return failure();
1025
1026 // Parse the function signature.
1027 bool isVariadic = false;
1029 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1030 resultAttrs))
1031 return failure();
1032
1033 SmallVector<Type> argTypes;
1034 for (auto &arg : entryArgs)
1035 argTypes.push_back(arg.type);
1036 auto fnType = builder.getFunctionType(argTypes, resultTypes);
1037 result.addAttribute(getFunctionTypeAttrName(result.name),
1038 TypeAttr::get(fnType));
1039
1040 // Parse the optional function control keyword.
1041 spirv::FunctionControl fnControl;
1043 return failure();
1044
1045 // If additional attributes are present, parse them.
1046 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1047 return failure();
1048
1049 // Add the attributes to the function arguments.
1050 assert(resultAttrs.size() == resultTypes.size());
1052 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1053 getResAttrsAttrName(result.name));
1054
1055 // Parse the optional function body.
1056 auto *body = result.addRegion();
1057 OptionalParseResult parseResult =
1058 parser.parseOptionalRegion(*body, entryArgs);
1059 return failure(parseResult.has_value() && failed(*parseResult));
1060}
1061
1062void spirv::FuncOp::print(OpAsmPrinter &printer) {
1063 // Print function name, signature, and control.
1064 printer << " ";
1065 printer.printSymbolName(getSymName());
1066 auto fnType = getFunctionType();
1068 printer, *this, fnType.getInputs(),
1069 /*isVariadic=*/false, fnType.getResults());
1070 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
1071 << "\"";
1073 printer, *this,
1074 {spirv::attributeName<spirv::FunctionControl>(),
1075 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
1076 getFunctionControlAttrName()});
1077
1078 // Print the body if this is not an external function.
1079 Region &body = this->getBody();
1080 if (!body.empty()) {
1081 printer << ' ';
1082 printer.printRegion(body, /*printEntryBlockArgs=*/false,
1083 /*printBlockTerminators=*/true);
1084 }
1085}
1086
1087LogicalResult spirv::FuncOp::verifyType() {
1088 FunctionType fnType = getFunctionType();
1089 if (fnType.getNumResults() > 1)
1090 return emitOpError("cannot have more than one result");
1091
1092 auto hasDecorationAttr = [&](spirv::Decoration decoration,
1093 unsigned argIndex) {
1094 auto func = cast<FunctionOpInterface>(getOperation());
1095 for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
1096 if (argAttr.getName() != spirv::DecorationAttr::name)
1097 continue;
1098 if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
1099 return decAttr.getValue() == decoration;
1100 }
1101 return false;
1102 };
1103
1104 for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1105 Type param = fnType.getInputs()[i];
1106 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1107 if (!inputPtrType)
1108 continue;
1109
1110 auto pointeePtrType =
1111 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1112 if (pointeePtrType) {
1113 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1114 // > If an OpFunctionParameter is a pointer (or contains a pointer)
1115 // > and the type it points to is a pointer in the PhysicalStorageBuffer
1116 // > storage class, the function parameter must be decorated with exactly
1117 // > one of AliasedPointer or RestrictPointer.
1118 if (pointeePtrType.getStorageClass() !=
1119 spirv::StorageClass::PhysicalStorageBuffer)
1120 continue;
1121
1122 bool hasAliasedPtr =
1123 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1124 bool hasRestrictPtr =
1125 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1126 if (!hasAliasedPtr && !hasRestrictPtr)
1127 return emitOpError()
1128 << "with a pointer points to a physical buffer pointer must "
1129 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1130 continue;
1131 }
1132 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1133 // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1134 // > the PhysicalStorageBuffer storage class, the function parameter must
1135 // > be decorated with exactly one of Aliased or Restrict.
1136 if (auto pointeeArrayType =
1137 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1138 pointeePtrType =
1139 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1140 } else {
1141 pointeePtrType = inputPtrType;
1142 }
1143
1144 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1145 spirv::StorageClass::PhysicalStorageBuffer)
1146 continue;
1147
1148 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1149 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1150 if (!hasAliased && !hasRestrict)
1151 return emitOpError() << "with physical buffer pointer must be decorated "
1152 "either 'Aliased' or 'Restrict'";
1153 }
1154
1155 return success();
1156}
1157
1158LogicalResult spirv::FuncOp::verifyBody() {
1159 FunctionType fnType = getFunctionType();
1160 if (!isExternal()) {
1161 Block &entryBlock = front();
1162
1163 unsigned numArguments = this->getNumArguments();
1164 if (entryBlock.getNumArguments() != numArguments)
1165 return emitOpError("entry block must have ")
1166 << numArguments << " arguments to match function signature";
1167
1168 for (auto [index, fnArgType, blockArgType] :
1169 llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1170 if (blockArgType != fnArgType) {
1171 return emitOpError("type of entry block argument #")
1172 << index << '(' << blockArgType
1173 << ") must match the type of the corresponding argument in "
1174 << "function signature(" << fnArgType << ')';
1175 }
1176 }
1177 }
1178
1179 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1180 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1181 if (fnType.getNumResults() != 0)
1182 return retOp.emitOpError("cannot be used in functions returning value");
1183 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1184 if (fnType.getNumResults() != 1)
1185 return retOp.emitOpError(
1186 "returns 1 value but enclosing function requires ")
1187 << fnType.getNumResults() << " results";
1188
1189 auto retOperandType = retOp.getValue().getType();
1190 auto fnResultType = fnType.getResult(0);
1191 if (retOperandType != fnResultType)
1192 return retOp.emitOpError(" return value's type (")
1193 << retOperandType << ") mismatch with function's result type ("
1194 << fnResultType << ")";
1195 }
1196 return WalkResult::advance();
1197 });
1198
1199 // TODO: verify other bits like linkage type.
1200
1201 return failure(walkResult.wasInterrupted());
1202}
1203
1204void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1205 StringRef name, FunctionType type,
1206 spirv::FunctionControl control,
1209 builder.getStringAttr(name));
1210 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1211 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1212 builder.getAttr<spirv::FunctionControlAttr>(control));
1213 state.attributes.append(attrs.begin(), attrs.end());
1214 state.addRegion();
1215}
1216
1217//===----------------------------------------------------------------------===//
1218// spirv.GLFClampOp
1219//===----------------------------------------------------------------------===//
1220
1221ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1224}
1225void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1226
1227//===----------------------------------------------------------------------===//
1228// spirv.GLUClampOp
1229//===----------------------------------------------------------------------===//
1230
1231ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1234}
1235void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1236
1237//===----------------------------------------------------------------------===//
1238// spirv.GLSClampOp
1239//===----------------------------------------------------------------------===//
1240
1241ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1244}
1245void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1246
1247//===----------------------------------------------------------------------===//
1248// spirv.GLFmaOp
1249//===----------------------------------------------------------------------===//
1250
1251ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1253}
1254void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1255
1256//===----------------------------------------------------------------------===//
1257// spirv.GlobalVariable
1258//===----------------------------------------------------------------------===//
1259
1260void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1261 Type type, StringRef name,
1262 unsigned descriptorSet, unsigned binding) {
1263 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1264 state.addAttribute(
1265 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1266 builder.getI32IntegerAttr(descriptorSet));
1267 state.addAttribute(
1268 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1269 builder.getI32IntegerAttr(binding));
1270}
1271
1272void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1273 Type type, StringRef name,
1274 spirv::BuiltIn builtin) {
1275 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1276 state.addAttribute(
1277 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1278 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1279}
1280
1281ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1283 // Parse variable name.
1284 StringAttr nameAttr;
1285 StringRef initializerAttrName =
1286 spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1287 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1288 result.attributes)) {
1289 return failure();
1290 }
1291
1292 // Parse optional initializer
1293 if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1294 FlatSymbolRefAttr initSymbol;
1295 if (parser.parseLParen() ||
1296 parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1297 result.attributes) ||
1298 parser.parseRParen())
1299 return failure();
1300 }
1301
1302 if (parseVariableDecorations(parser, result)) {
1303 return failure();
1304 }
1305
1306 Type type;
1307 StringRef typeAttrName =
1308 spirv::GlobalVariableOp::getTypeAttrName(result.name);
1309 auto loc = parser.getCurrentLocation();
1310 if (parser.parseColonType(type)) {
1311 return failure();
1312 }
1313 if (!isa<spirv::PointerType>(type)) {
1314 return parser.emitError(loc, "expected spirv.ptr type");
1315 }
1316 result.addAttribute(typeAttrName, TypeAttr::get(type));
1317
1318 return success();
1319}
1320
1321void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
1322 SmallVector<StringRef, 4> elidedAttrs{
1323 spirv::attributeName<spirv::StorageClass>()};
1324
1325 // Print variable name.
1326 printer << ' ';
1327 printer.printSymbolName(getSymName());
1328 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1329
1330 StringRef initializerAttrName = this->getInitializerAttrName();
1331 // Print optional initializer
1332 if (auto initializer = this->getInitializer()) {
1333 printer << " " << initializerAttrName << '(';
1334 printer.printSymbolName(*initializer);
1335 printer << ')';
1336 elidedAttrs.push_back(initializerAttrName);
1337 }
1338
1339 StringRef typeAttrName = this->getTypeAttrName();
1340 elidedAttrs.push_back(typeAttrName);
1341 spirv::printVariableDecorations(*this, printer, elidedAttrs);
1342 printer << " : " << getType();
1343}
1344
1345LogicalResult spirv::GlobalVariableOp::verify() {
1346 if (!isa<spirv::PointerType>(getType()))
1347 return emitOpError("result must be of a !spv.ptr type");
1348
1349 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1350 // object. It cannot be Generic. It must be the same as the Storage Class
1351 // operand of the Result Type."
1352 // Also, Function storage class is reserved by spirv.Variable.
1353 auto storageClass = this->storageClass();
1354 if (storageClass == spirv::StorageClass::Generic ||
1355 storageClass == spirv::StorageClass::Function) {
1356 return emitOpError("storage class cannot be '")
1357 << stringifyStorageClass(storageClass) << "'";
1358 }
1359
1360 // SPIR-V spec: "A module-scope OpVariable with an Initializer operand must
1361 // not be decorated with the Import Linkage Type."
1362 if (std::optional<spirv::LinkageAttributesAttr> linkage =
1363 getLinkageAttributes()) {
1364 if (linkage->getLinkageType().getValue() == spirv::LinkageType::Import &&
1365 getInitializer()) {
1366 return emitOpError(
1367 "with Import linkage type must not have an initializer");
1368 }
1369 }
1370
1371 if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1372 this->getInitializerAttrName())) {
1374 (*this)->getParentOp(), init.getAttr());
1375 // TODO: Currently only variable initialization with specialization
1376 // constants is supported. There could be normal constants in the module
1377 // scope as well.
1378 //
1379 // In the current setup we also cannot initialize one global variable with
1380 // another. The problem is that if we try to initialize pointer of type X
1381 // with another pointer type, the validator fails because it expects the
1382 // variable to be initialized to be type X, not pointer to X. Now
1383 // `spirv.GlobalVariable` only allows pointer type, so in the current design
1384 // we cannot initialize one `spirv.GlobalVariable` with another.
1385 if (!initOp ||
1386 !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
1387 return emitOpError("initializer must be result of a "
1388 "spirv.SpecConstant or "
1389 "spirv.SpecConstantCompositeOp op");
1390 }
1391 }
1392
1393 return success();
1394}
1395
1396//===----------------------------------------------------------------------===//
1397// spirv.INTEL.SubgroupBlockRead
1398//===----------------------------------------------------------------------===//
1399
1400LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
1401 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1402 return failure();
1403
1404 return success();
1405}
1406
1407//===----------------------------------------------------------------------===//
1408// spirv.INTEL.SubgroupBlockWrite
1409//===----------------------------------------------------------------------===//
1410
1411ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
1413 // Parse the storage class specification
1414 spirv::StorageClass storageClass;
1416 auto loc = parser.getCurrentLocation();
1417 Type elementType;
1418 if (parseEnumStrAttr(storageClass, parser) ||
1419 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1420 parser.parseType(elementType)) {
1421 return failure();
1422 }
1423
1424 auto ptrType = spirv::PointerType::get(elementType, storageClass);
1425 if (auto valVecTy = dyn_cast<VectorType>(elementType))
1426 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1427
1428 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1429 result.operands)) {
1430 return failure();
1431 }
1432 return success();
1433}
1434
1435void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
1436 printer << " " << getPtr() << ", " << getValue() << " : "
1437 << getValue().getType();
1438}
1439
1440LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
1441 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1442 return failure();
1443
1444 return success();
1445}
1446
1447//===----------------------------------------------------------------------===//
1448// spirv.IAddCarryOp
1449//===----------------------------------------------------------------------===//
1450
1451LogicalResult spirv::IAddCarryOp::verify() {
1452 return ::verifyArithmeticExtendedBinaryOp(*this);
1453}
1454
1455ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1457 return ::parseArithmeticExtendedBinaryOp(parser, result);
1458}
1459
1460void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1461 ::printArithmeticExtendedBinaryOp(*this, printer);
1462}
1463
1464//===----------------------------------------------------------------------===//
1465// spirv.ISubBorrowOp
1466//===----------------------------------------------------------------------===//
1467
1468LogicalResult spirv::ISubBorrowOp::verify() {
1469 return ::verifyArithmeticExtendedBinaryOp(*this);
1470}
1471
1472ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1474 return ::parseArithmeticExtendedBinaryOp(parser, result);
1475}
1476
1477void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
1478 ::printArithmeticExtendedBinaryOp(*this, printer);
1479}
1480
1481//===----------------------------------------------------------------------===//
1482// spirv.SMulExtended
1483//===----------------------------------------------------------------------===//
1484
1485LogicalResult spirv::SMulExtendedOp::verify() {
1486 return ::verifyArithmeticExtendedBinaryOp(*this);
1487}
1488
1489ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1491 return ::parseArithmeticExtendedBinaryOp(parser, result);
1492}
1493
1494void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
1495 ::printArithmeticExtendedBinaryOp(*this, printer);
1496}
1497
1498//===----------------------------------------------------------------------===//
1499// spirv.UMulExtended
1500//===----------------------------------------------------------------------===//
1501
1502LogicalResult spirv::UMulExtendedOp::verify() {
1503 return ::verifyArithmeticExtendedBinaryOp(*this);
1504}
1505
1506ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1508 return ::parseArithmeticExtendedBinaryOp(parser, result);
1509}
1510
1511void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
1512 ::printArithmeticExtendedBinaryOp(*this, printer);
1513}
1514
1515//===----------------------------------------------------------------------===//
1516// spirv.MemoryBarrierOp
1517//===----------------------------------------------------------------------===//
1518
1519LogicalResult spirv::MemoryBarrierOp::verify() {
1520 return verifyMemorySemantics(getOperation(), getMemorySemantics());
1521}
1522
1523//===----------------------------------------------------------------------===//
1524// spirv.MemoryNamedBarrierOp
1525//===----------------------------------------------------------------------===//
1526
1527LogicalResult spirv::MemoryNamedBarrierOp::verify() {
1528 return verifyMemorySemantics(getOperation(), getMemorySemantics());
1529}
1530
1531//===----------------------------------------------------------------------===//
1532// spirv.module
1533//===----------------------------------------------------------------------===//
1534
1535void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1536 std::optional<StringRef> name) {
1537 OpBuilder::InsertionGuard guard(builder);
1538 builder.createBlock(state.addRegion());
1539 if (name) {
1541 builder.getStringAttr(*name));
1542 }
1543}
1544
1545void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1546 spirv::AddressingModel addressingModel,
1547 spirv::MemoryModel memoryModel,
1548 std::optional<VerCapExtAttr> vceTriple,
1549 std::optional<StringRef> name) {
1550 state.addAttribute(
1551 "addressing_model",
1552 builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1553 state.addAttribute("memory_model",
1554 builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1555 OpBuilder::InsertionGuard guard(builder);
1556 builder.createBlock(state.addRegion());
1557 if (vceTriple)
1558 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1559 if (name)
1561 builder.getStringAttr(*name));
1562}
1563
1564ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1566 Region *body = result.addRegion();
1567
1568 // If the name is present, parse it.
1569 StringAttr nameAttr;
1571 nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1572
1573 // Parse attributes
1574 spirv::AddressingModel addrModel;
1575 spirv::MemoryModel memoryModel;
1577 result) ||
1579 result))
1580 return failure();
1581
1582 if (succeeded(parser.parseOptionalKeyword("requires"))) {
1583 spirv::VerCapExtAttr vceTriple;
1584 if (parser.parseAttribute(vceTriple,
1585 spirv::ModuleOp::getVCETripleAttrName(),
1586 result.attributes))
1587 return failure();
1588 }
1589
1590 if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1591 parser.parseRegion(*body, /*arguments=*/{}))
1592 return failure();
1593
1594 // Make sure we have at least one block.
1595 if (body->empty())
1596 body->push_back(new Block());
1597
1598 return success();
1599}
1600
1601void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1602 if (std::optional<StringRef> name = getName()) {
1603 printer << ' ';
1604 printer.printSymbolName(*name);
1605 }
1606
1607 SmallVector<StringRef, 2> elidedAttrs;
1608
1609 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1610 << spirv::stringifyMemoryModel(getMemoryModel());
1611 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1612 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1613 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1615
1616 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1617 printer << " requires " << *triple;
1618 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1619 }
1620
1621 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1622 printer << ' ';
1623 printer.printRegion(getRegion());
1624}
1625
1626LogicalResult spirv::ModuleOp::verifyRegions() {
1627 Dialect *dialect = (*this)->getDialect();
1629 entryPoints;
1630 mlir::SymbolTable table(*this);
1631
1632 for (auto &op : *getBody()) {
1633 if (op.getDialect() != dialect)
1634 return op.emitError("'spirv.module' can only contain spirv.* ops");
1635
1636 // For EntryPoint op, check that the function and execution model is not
1637 // duplicated in EntryPointOps. Also verify that the interface specified
1638 // comes from globalVariables here to make this check cheaper.
1639 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1640 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1641 if (!funcOp) {
1642 return entryPointOp.emitError("function '")
1643 << entryPointOp.getFn() << "' not found in 'spirv.module'";
1644 }
1645 if (auto interface = entryPointOp.getInterface()) {
1646 for (Attribute varRef : interface) {
1647 auto varSymRef = dyn_cast<FlatSymbolRefAttr>(varRef);
1648 if (!varSymRef) {
1649 return entryPointOp.emitError(
1650 "expected symbol reference for interface "
1651 "specification instead of '")
1652 << varRef;
1653 }
1654 auto variableOp =
1655 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1656 if (!variableOp) {
1657 return entryPointOp.emitError("expected spirv.GlobalVariable "
1658 "symbol reference instead of'")
1659 << varSymRef << "'";
1660 }
1661 }
1662 }
1663
1664 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1665 funcOp, entryPointOp.getExecutionModel());
1666 if (!entryPoints.try_emplace(key, entryPointOp).second)
1667 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1668 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1669 // If the function is external and does not have 'Import'
1670 // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1671 // LinkageAttributes is used to import external functions.
1672 auto linkageAttr = funcOp.getLinkageAttributes();
1673 auto hasImportLinkage =
1674 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1675 spirv::LinkageType::Import);
1676 if (funcOp.isExternal() && !hasImportLinkage)
1677 return op.emitError(
1678 "'spirv.module' cannot contain external functions "
1679 "without 'Import' linkage_attributes (LinkageAttributes)");
1680
1681 // TODO: move this check to spirv.func.
1682 for (auto &block : funcOp)
1683 for (auto &op : block) {
1684 if (op.getDialect() != dialect)
1685 return op.emitError(
1686 "functions in 'spirv.module' can only contain spirv.* ops");
1687 }
1688 }
1689 }
1690
1691 return success();
1692}
1693
1694//===----------------------------------------------------------------------===//
1695// spirv.mlir.referenceof
1696//===----------------------------------------------------------------------===//
1697
1698LogicalResult spirv::ReferenceOfOp::verify() {
1699 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1700 (*this)->getParentOp(), getSpecConstAttr());
1701 Type constType;
1702
1703 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1704 if (specConstOp)
1705 constType = specConstOp.getDefaultValue().getType();
1706
1707 auto specConstCompositeOp =
1708 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1709 if (specConstCompositeOp)
1710 constType = specConstCompositeOp.getType();
1711
1712 if (!specConstOp && !specConstCompositeOp)
1713 return emitOpError(
1714 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1715
1716 if (getReference().getType() != constType)
1717 return emitOpError("result type mismatch with the referenced "
1718 "specialization constant's type");
1719
1720 return success();
1721}
1722
1723//===----------------------------------------------------------------------===//
1724// spirv.SpecConstant
1725//===----------------------------------------------------------------------===//
1726
1727ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1729 StringAttr nameAttr;
1730 Attribute valueAttr;
1731 StringRef defaultValueAttrName =
1732 spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1733
1734 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1735 result.attributes))
1736 return failure();
1737
1738 // Parse optional spec_id.
1739 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1740 IntegerAttr specIdAttr;
1741 if (parser.parseLParen() ||
1742 parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1743 parser.parseRParen())
1744 return failure();
1745 }
1746
1747 if (parser.parseEqual() ||
1748 parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1749 return failure();
1750
1751 return success();
1752}
1753
1754void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
1755 printer << ' ';
1756 printer.printSymbolName(getSymName());
1757 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1758 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1759 printer << " = " << getDefaultValue();
1760}
1761
1762LogicalResult spirv::SpecConstantOp::verify() {
1763 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1764 if (specID.getValue().isNegative())
1765 return emitOpError("SpecId cannot be negative");
1766
1767 auto value = getDefaultValue();
1768 if (isa<IntegerAttr, FloatAttr>(value)) {
1769 // Make sure bitwidth is allowed.
1770 if (!isa<spirv::SPIRVType>(value.getType()))
1771 return emitOpError("default value bitwidth disallowed");
1772 return success();
1773 }
1774 return emitOpError(
1775 "default value can only be a bool, integer, or float scalar");
1776}
1777
1778//===----------------------------------------------------------------------===//
1779// spirv.VectorShuffle
1780//===----------------------------------------------------------------------===//
1781
1782LogicalResult spirv::VectorShuffleOp::verify() {
1783 VectorType resultType = cast<VectorType>(getType());
1784
1785 size_t numResultElements = resultType.getNumElements();
1786 if (numResultElements != getComponents().size())
1787 return emitOpError("result type element count (")
1788 << numResultElements
1789 << ") mismatch with the number of component selectors ("
1790 << getComponents().size() << ")";
1791
1792 size_t totalSrcElements =
1793 cast<VectorType>(getVector1().getType()).getNumElements() +
1794 cast<VectorType>(getVector2().getType()).getNumElements();
1795
1796 for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1797 uint32_t index = selector.getZExtValue();
1798 if (index >= totalSrcElements &&
1799 index != std::numeric_limits<uint32_t>().max())
1800 return emitOpError("component selector ")
1801 << index << " out of range: expected to be in [0, "
1802 << totalSrcElements << ") or 0xffffffff";
1803 }
1804 return success();
1805}
1806
1807//===----------------------------------------------------------------------===//
1808// spirv.SpecConstantComposite
1809//===----------------------------------------------------------------------===//
1810
1811ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
1813
1814 StringAttr compositeName;
1815 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1816 result.attributes))
1817 return failure();
1818
1819 if (parser.parseLParen())
1820 return failure();
1821
1822 SmallVector<Attribute, 4> constituents;
1823
1824 do {
1825 // The name of the constituent attribute isn't important
1826 const char *attrName = "spec_const";
1827 FlatSymbolRefAttr specConstRef;
1828 NamedAttrList attrs;
1829
1830 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1831 return failure();
1832
1833 constituents.push_back(specConstRef);
1834 } while (!parser.parseOptionalComma());
1835
1836 if (parser.parseRParen())
1837 return failure();
1838
1839 StringAttr compositeSpecConstituentsName =
1840 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1841 result.addAttribute(compositeSpecConstituentsName,
1842 parser.getBuilder().getArrayAttr(constituents));
1843
1844 Type type;
1845 if (parser.parseColonType(type))
1846 return failure();
1847
1848 StringAttr typeAttrName =
1849 spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1850 result.addAttribute(typeAttrName, TypeAttr::get(type));
1851
1852 return success();
1853}
1854
1855void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
1856 printer << " ";
1857 printer.printSymbolName(getSymName());
1858 printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1859 << ") : " << getType();
1860}
1861
1862LogicalResult spirv::SpecConstantCompositeOp::verify() {
1863 auto cType = dyn_cast<spirv::CompositeType>(getType());
1864 auto constituents = this->getConstituents().getValue();
1865
1866 if (!cType)
1867 return emitError("result type must be a composite type, but provided ")
1868 << getType();
1869
1870 if (isa<spirv::CooperativeMatrixType>(cType))
1871 return emitError("unsupported composite type ") << cType;
1872 if (constituents.size() != cType.getNumElements())
1873 return emitError("has incorrect number of operands: expected ")
1874 << cType.getNumElements() << ", but provided "
1875 << constituents.size();
1876
1877 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1878 auto constituent = cast<FlatSymbolRefAttr>(constituents[index]);
1879
1881 (*this)->getParentOp(), constituent.getAttr());
1882
1883 if (!constituentOp)
1884 return emitError("unknown constituent symbol ") << constituent.getAttr();
1885
1886 Type constituentType;
1887 if (auto specConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp)) {
1888 constituentType = specConstOp.getDefaultValue().getType();
1889 } else if (auto specConstCompositeOp =
1890 dyn_cast<spirv::SpecConstantCompositeOp>(constituentOp)) {
1891 constituentType = specConstCompositeOp.getType();
1892 } else {
1893 return emitError("unsupported constituent ")
1894 << constituent.getAttr()
1895 << ": must reference a spirv.SpecConstant or "
1896 "spirv.SpecConstantComposite";
1897 }
1898
1899 if (constituentType != cType.getElementType(index))
1900 return emitError("has incorrect types of operands: expected ")
1901 << cType.getElementType(index) << ", but provided "
1902 << constituentType;
1903 }
1904
1905 return success();
1906}
1907
1908//===----------------------------------------------------------------------===//
1909// spirv.EXTSpecConstantCompositeReplicateOp
1910//===----------------------------------------------------------------------===//
1911
1912ParseResult
1913spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
1915 StringAttr compositeName;
1916 FlatSymbolRefAttr specConstRef;
1917 const char *attrName = "spec_const";
1918 NamedAttrList attrs;
1919 Type type;
1920
1921 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1922 result.attributes) ||
1923 parser.parseLParen() ||
1924 parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
1925 parser.parseRParen() || parser.parseColonType(type))
1926 return failure();
1927
1928 StringAttr compositeSpecConstituentName =
1929 spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1930 result.name);
1931 result.addAttribute(compositeSpecConstituentName, specConstRef);
1932
1933 StringAttr typeAttrName =
1934 spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
1935 result.addAttribute(typeAttrName, TypeAttr::get(type));
1936
1937 return success();
1938}
1939
1940void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
1941 printer << " ";
1942 printer.printSymbolName(getSymName());
1943 printer << " (" << this->getConstituent() << ") : " << getType();
1944}
1945
1946LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
1947 auto compositeType = dyn_cast<spirv::CompositeType>(getType());
1948 if (!compositeType)
1949 return emitError("result type must be a composite type, but provided ")
1950 << getType();
1951
1953 (*this)->getParentOp(), this->getConstituent());
1954 if (!constituentOp)
1955 return emitError(
1956 "splat spec constant reference defining constituent not found");
1957
1958 auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
1959 if (!constituentSpecConstOp)
1960 return emitError("constituent is not a spec constant");
1961
1962 Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
1963 Type compositeElementType = compositeType.getElementType(0);
1964 if (constituentType != compositeElementType)
1965 return emitError("constituent has incorrect type: expected ")
1966 << compositeElementType << ", but provided " << constituentType;
1967
1968 return success();
1969}
1970
1971//===----------------------------------------------------------------------===//
1972// spirv.SpecConstantOperation
1973//===----------------------------------------------------------------------===//
1974
1975ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
1977 Region *body = result.addRegion();
1978
1979 if (parser.parseKeyword("wraps"))
1980 return failure();
1981
1982 body->push_back(new Block);
1983 Block &block = body->back();
1984 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1985
1986 if (!wrappedOp)
1987 return failure();
1988
1989 OpBuilder builder(parser.getContext());
1990 builder.setInsertionPointToEnd(&block);
1991 spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0));
1992 result.location = wrappedOp->getLoc();
1993
1994 result.addTypes(wrappedOp->getResult(0).getType());
1995
1996 if (parser.parseOptionalAttrDict(result.attributes))
1997 return failure();
1998
1999 return success();
2000}
2001
2002void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
2003 printer << " wraps ";
2004 printer.printGenericOp(&getBody().front().front());
2005}
2006
2007LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
2008 Block &block = getRegion().getBlocks().front();
2009
2010 if (block.getOperations().size() != 2)
2011 return emitOpError("expected exactly 2 nested ops");
2012
2013 Operation &enclosedOp = block.getOperations().front();
2014
2016 return emitOpError("invalid enclosed op");
2017
2018 for (auto operand : enclosedOp.getOperands())
2019 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
2020 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
2021 return emitOpError(
2022 "invalid operand, must be defined by a constant operation");
2023
2024 return success();
2025}
2026
2027//===----------------------------------------------------------------------===//
2028// spirv.GL.FrexpStruct
2029//===----------------------------------------------------------------------===//
2030
2031LogicalResult spirv::GLFrexpStructOp::verify() {
2032 spirv::StructType structTy =
2033 dyn_cast<spirv::StructType>(getResult().getType());
2034
2035 if (structTy.getNumElements() != 2)
2036 return emitError("result type must be a struct type with two memebers");
2037
2038 Type significandTy = structTy.getElementType(0);
2039 Type exponentTy = structTy.getElementType(1);
2040 VectorType exponentVecTy = dyn_cast<VectorType>(exponentTy);
2041 IntegerType exponentIntTy = dyn_cast<IntegerType>(exponentTy);
2042
2043 Type operandTy = getOperand().getType();
2044 VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
2045 FloatType operandFTy = dyn_cast<FloatType>(operandTy);
2046
2047 if (significandTy != operandTy)
2048 return emitError("member zero of the resulting struct type must be the "
2049 "same type as the operand");
2050
2051 if (exponentVecTy) {
2052 IntegerType componentIntTy =
2053 dyn_cast<IntegerType>(exponentVecTy.getElementType());
2054 if (!componentIntTy || componentIntTy.getWidth() != 32)
2055 return emitError("member one of the resulting struct type must"
2056 "be a scalar or vector of 32 bit integer type");
2057 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2058 return emitError("member one of the resulting struct type "
2059 "must be a scalar or vector of 32 bit integer type");
2060 }
2061
2062 // Check that the two member types have the same number of components
2063 if (operandVecTy && exponentVecTy &&
2064 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2065 return success();
2066
2067 if (operandFTy && exponentIntTy)
2068 return success();
2069
2070 return emitError("member one of the resulting struct type must have the same "
2071 "number of components as the operand type");
2072}
2073
2074//===----------------------------------------------------------------------===//
2075// spirv.GL.Ldexp
2076//===----------------------------------------------------------------------===//
2077
2078static LogicalResult verifyFloatIntegerBuiltin(Operation *op, Type floatType,
2079 Type integerType) {
2080 if (isa<FloatType>(floatType) != isa<IntegerType>(integerType))
2081 return op->emitOpError("operands must both be scalars or vectors");
2082
2083 auto getNumElements = [](Type type) -> unsigned {
2084 if (auto vectorType = dyn_cast<VectorType>(type))
2085 return vectorType.getNumElements();
2086 return 1;
2087 };
2088
2089 if (getNumElements(floatType) != getNumElements(integerType))
2090 return op->emitOpError("operands must have the same number of elements");
2091
2092 return success();
2093}
2094
2095LogicalResult spirv::GLLdexpOp::verify() {
2096 return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
2097 getExp().getType());
2098}
2099
2100//===----------------------------------------------------------------------===//
2101// spirv.CL.ldexp
2102//===----------------------------------------------------------------------===//
2103
2104LogicalResult spirv::CLLdexpOp::verify() {
2105 return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
2106 getExp().getType());
2107}
2108
2109//===----------------------------------------------------------------------===//
2110// spirv.CL.pown
2111//===----------------------------------------------------------------------===//
2112
2113LogicalResult spirv::CLPownOp::verify() {
2114 return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
2115 getY().getType());
2116}
2117
2118//===----------------------------------------------------------------------===//
2119// spirv.CL.rootn
2120//===----------------------------------------------------------------------===//
2121
2122LogicalResult spirv::CLRootnOp::verify() {
2123 return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
2124 getN().getType());
2125}
2126
2127//===----------------------------------------------------------------------===//
2128// spirv.ShiftLeftLogicalOp
2129//===----------------------------------------------------------------------===//
2130
2131LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2132 return verifyShiftOp(*this);
2133}
2134
2135//===----------------------------------------------------------------------===//
2136// spirv.ShiftRightArithmeticOp
2137//===----------------------------------------------------------------------===//
2138
2139LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2140 return verifyShiftOp(*this);
2141}
2142
2143//===----------------------------------------------------------------------===//
2144// spirv.ShiftRightLogicalOp
2145//===----------------------------------------------------------------------===//
2146
2147LogicalResult spirv::ShiftRightLogicalOp::verify() {
2148 return verifyShiftOp(*this);
2149}
2150
2151//===----------------------------------------------------------------------===//
2152// spirv.VectorTimesScalarOp
2153//===----------------------------------------------------------------------===//
2154
2155LogicalResult spirv::VectorTimesScalarOp::verify() {
2156 if (getVector().getType() != getType())
2157 return emitOpError("vector operand and result type mismatch");
2158 auto scalarType = cast<VectorType>(getType()).getElementType();
2159 if (getScalar().getType() != scalarType)
2160 return emitOpError("scalar operand and result element type match");
2161 return success();
2162}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
ArrayAttr()
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition SPIRVOps.cpp:266
static Type getValueType(Attribute attr)
Definition SPIRVOps.cpp:791
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition SPIRVOps.cpp:562
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition SPIRVOps.cpp:120
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition SPIRVOps.cpp:252
static LogicalResult verifyFloatIntegerBuiltin(Operation *op, Type floatType, Type integerType)
static LogicalResult verifyShiftOp(Operation *op)
Definition SPIRVOps.cpp:298
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition SPIRVOps.cpp:169
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition SPIRVOps.cpp:185
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition SPIRVOps.cpp:149
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition SPIRVOps.cpp:290
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
iterator begin()
Definition Block.h:153
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:281
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:259
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:435
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:237
Value getOperand(unsigned idx)
Definition Operation.h:375
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:774
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:575
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
Definition Operation.h:422
result_type_range getResultTypes()
Definition Operation.h:453
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
void push_back(Block *block)
Definition Region.h:61
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
Type front()
Return first type in the range.
Definition TypeRange.h:164
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
SPIR-V struct type.
Definition SPIRVTypes.h:274
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
uint64_t getN(LevelType lt)
Definition Enums.h:442
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition SPIRVOps.cpp:69
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition SPIRVOps.cpp:93
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition SPIRVOps.cpp:49
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.