MLIR 22.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
24#include "mlir/IR/Builders.h"
28#include "mlir/IR/Operation.h"
31#include "llvm/ADT/APFloat.h"
32#include "llvm/ADT/APInt.h"
33#include "llvm/ADT/ArrayRef.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/StringExtras.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/InterleavedRange.h"
38#include <cassert>
39#include <numeric>
40#include <optional>
41
42using namespace mlir;
43using namespace mlir::spirv::AttrNames;
44
45//===----------------------------------------------------------------------===//
46// Common utility functions
47//===----------------------------------------------------------------------===//
48
49LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
50 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
51 if (!constOp) {
52 return failure();
53 }
54 auto valueAttr = constOp.getValue();
55 auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
56 if (!integerValueAttr) {
57 return failure();
58 }
59
60 if (integerValueAttr.getType().isSignlessInteger())
61 value = integerValueAttr.getInt();
62 else
63 value = integerValueAttr.getSInt();
64
65 return success();
66}
67
68LogicalResult
70 spirv::MemorySemantics memorySemantics) {
71 // According to the SPIR-V specification:
72 // "Despite being a mask and allowing multiple bits to be combined, it is
73 // invalid for more than one of these four bits to be set: Acquire, Release,
74 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
75 // Release semantics is done by setting the AcquireRelease bit, not by setting
76 // two bits."
77 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
78 spirv::MemorySemantics::Release |
79 spirv::MemorySemantics::AcquireRelease |
80 spirv::MemorySemantics::SequentiallyConsistent;
81
82 auto bitCount =
83 llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
84 if (bitCount > 1) {
85 return op->emitError(
86 "expected at most one of these four memory constraints "
87 "to be set: `Acquire`, `Release`,"
88 "`AcquireRelease` or `SequentiallyConsistent`");
89 }
90 return success();
91}
92
94 SmallVectorImpl<StringRef> &elidedAttrs) {
95 // Print optional descriptor binding
96 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
97 stringifyDecoration(spirv::Decoration::DescriptorSet));
98 auto bindingName = llvm::convertToSnakeFromCamelCase(
99 stringifyDecoration(spirv::Decoration::Binding));
100 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
101 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
102 if (descriptorSet && binding) {
103 elidedAttrs.push_back(descriptorSetName);
104 elidedAttrs.push_back(bindingName);
105 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
106 << ")";
107 }
108
109 // Print BuiltIn attribute if present
110 auto builtInName = llvm::convertToSnakeFromCamelCase(
111 stringifyDecoration(spirv::Decoration::BuiltIn));
112 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
113 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
114 elidedAttrs.push_back(builtInName);
115 }
116
117 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
118}
119
123 Type type;
124 // If the operand list is in-between parentheses, then we have a generic form.
125 // (see the fallback in `printOneResultOp`).
126 SMLoc loc = parser.getCurrentLocation();
127 if (!parser.parseOptionalLParen()) {
128 if (parser.parseOperandList(ops) || parser.parseRParen() ||
129 parser.parseOptionalAttrDict(result.attributes) ||
130 parser.parseColon() || parser.parseType(type))
131 return failure();
132 auto fnType = llvm::dyn_cast<FunctionType>(type);
133 if (!fnType) {
134 parser.emitError(loc, "expected function type");
135 return failure();
136 }
137 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
138 return failure();
139 result.addTypes(fnType.getResults());
140 return success();
141 }
142 return failure(parser.parseOperandList(ops) ||
143 parser.parseOptionalAttrDict(result.attributes) ||
144 parser.parseColonType(type) ||
145 parser.resolveOperands(ops, type, result.operands) ||
146 parser.addTypeToList(type, result.types));
147}
148
150 assert(op->getNumResults() == 1 && "op should have one result");
151
152 // If not all the operand and result types are the same, just use the
153 // generic assembly form to avoid omitting information in printing.
154 auto resultType = op->getResult(0).getType();
155 if (llvm::any_of(op->getOperandTypes(),
156 [&](Type type) { return type != resultType; })) {
157 p.printGenericOp(op, /*printOpName=*/false);
158 return;
159 }
160
161 p << ' ';
162 p.printOperands(op->getOperands());
164 // Now we can output only one type for all operands and the result.
165 p << " : " << resultType;
166}
167
168template <typename BlockReadWriteOpTy>
169static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
170 Value ptr, Value val) {
171 auto valType = val.getType();
172 if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
173 valType = valVecTy.getElementType();
174
175 if (valType !=
176 llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
177 return op.emitOpError("mismatch in result type and pointer type");
178 }
179 return success();
180}
181
/// Walks the given type hierarchy with the given indices, potentially down
/// to component granularity, to select an element type. Returns a null type
/// and reports errors through `emitErrorFn` on failure.
185static Type
187 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
188 if (indices.empty()) {
189 emitErrorFn("expected at least one index for spirv.CompositeExtract");
190 return nullptr;
191 }
192
193 for (auto index : indices) {
194 if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
195 if (cType.hasCompileTimeKnownNumElements() &&
196 (index < 0 ||
197 static_cast<uint64_t>(index) >= cType.getNumElements())) {
198 emitErrorFn("index ") << index << " out of bounds for " << type;
199 return nullptr;
200 }
201 type = cType.getElementType(index);
202 } else {
203 emitErrorFn("cannot extract from non-composite type ")
204 << type << " with index " << index;
205 return nullptr;
206 }
207 }
208 return type;
209}
210
211static Type
213 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
214 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
215 if (!indicesArrayAttr) {
216 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
217 return nullptr;
218 }
219 if (indicesArrayAttr.empty()) {
220 emitErrorFn("expected at least one index for spirv.CompositeExtract");
221 return nullptr;
222 }
223
224 SmallVector<int32_t, 2> indexVals;
225 for (auto indexAttr : indicesArrayAttr) {
226 auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
227 if (!indexIntAttr) {
228 emitErrorFn("expected an 32-bit integer for index, but found '")
229 << indexAttr << "'";
230 return nullptr;
231 }
232 indexVals.push_back(indexIntAttr.getInt());
233 }
234 return getElementType(type, indexVals, emitErrorFn);
235}
236
238 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
239 return ::mlir::emitError(loc, err);
240 };
241 return getElementType(type, indices, errorFn);
242}
243
245 SMLoc loc) {
246 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
247 return parser.emitError(loc, err);
248 };
249 return getElementType(type, indices, errorFn);
250}
251
252template <typename ExtendedBinaryOp>
253static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
254 auto resultType = llvm::cast<spirv::StructType>(op.getType());
255 if (resultType.getNumElements() != 2)
256 return op.emitOpError("expected result struct type containing two members");
257
258 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
259 resultType.getElementType(0),
260 resultType.getElementType(1)}))
261 return op.emitOpError(
262 "expected all operand types and struct member types are the same");
263
264 return success();
265}
266
270 if (parser.parseOptionalAttrDict(result.attributes) ||
271 parser.parseOperandList(operands) || parser.parseColon())
272 return failure();
273
274 Type resultType;
275 SMLoc loc = parser.getCurrentLocation();
276 if (parser.parseType(resultType))
277 return failure();
278
279 auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
280 if (!structType || structType.getNumElements() != 2)
281 return parser.emitError(loc, "expected spirv.struct type with two members");
282
283 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
284 if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
285 return failure();
286
287 result.addTypes(resultType);
288 return success();
289}
290
292 OpAsmPrinter &printer) {
293 printer << ' ';
294 printer.printOptionalAttrDict(op->getAttrs());
295 printer.printOperands(op->getOperands());
296 printer << " : " << op->getResultTypes().front();
297}
298
299static LogicalResult verifyShiftOp(Operation *op) {
300 if (op->getOperand(0).getType() != op->getResult(0).getType()) {
301 return op->emitError("expected the same type for the first operand and "
302 "result, but provided ")
303 << op->getOperand(0).getType() << " and "
304 << op->getResult(0).getType();
305 }
306 return success();
307}
308
309//===----------------------------------------------------------------------===//
310// spirv.mlir.addressof
311//===----------------------------------------------------------------------===//
312
313void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
314 spirv::GlobalVariableOp var) {
315 build(builder, state, var.getType(), SymbolRefAttr::get(var));
316}
317
318LogicalResult spirv::AddressOfOp::verify() {
319 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
320 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
321 getVariableAttr()));
322 if (!varOp) {
323 return emitOpError("expected spirv.GlobalVariable symbol");
324 }
325 if (getPointer().getType() != varOp.getType()) {
326 return emitOpError(
327 "result type mismatch with the referenced global variable's type");
328 }
329 return success();
330}
331
332//===----------------------------------------------------------------------===//
333// spirv.CompositeConstruct
334//===----------------------------------------------------------------------===//
335
336LogicalResult spirv::CompositeConstructOp::verify() {
337 operand_range constituents = this->getConstituents();
338
339 // There are 4 cases with varying verification rules:
340 // 1. Cooperative Matrices (1 constituent)
341 // 2. Structs (1 constituent for each member)
342 // 3. Arrays (1 constituent for each array element)
343 // 4. Vectors (1 constituent (sub-)element for each vector element)
344
345 auto coopElementType =
348 [](auto coopType) { return coopType.getElementType(); })
349 .Default(nullptr);
350
351 // Case 1. -- matrices.
352 if (coopElementType) {
353 if (constituents.size() != 1)
354 return emitOpError("has incorrect number of operands: expected ")
355 << "1, but provided " << constituents.size();
356 if (coopElementType != constituents.front().getType())
357 return emitOpError("operand type mismatch: expected operand type ")
358 << coopElementType << ", but provided "
359 << constituents.front().getType();
360 return success();
361 }
362
363 // Case 2./3./4. -- number of constituents matches the number of elements.
364 auto cType = llvm::cast<spirv::CompositeType>(getType());
365 if (constituents.size() == cType.getNumElements()) {
366 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
367 if (constituents[index].getType() != cType.getElementType(index)) {
368 return emitOpError("operand type mismatch: expected operand type ")
369 << cType.getElementType(index) << ", but provided "
370 << constituents[index].getType();
371 }
372 }
373 return success();
374 }
375
376 // Case 4. -- check that all constituents add up tp the expected vector type.
377 auto resultType = llvm::dyn_cast<VectorType>(cType);
378 if (!resultType)
379 return emitOpError(
380 "expected to return a vector or cooperative matrix when the number of "
381 "constituents is less than what the result needs");
382
384 for (Value component : constituents) {
385 if (!llvm::isa<VectorType>(component.getType()) &&
386 !component.getType().isIntOrFloat())
387 return emitOpError("operand type mismatch: expected operand to have "
388 "a scalar or vector type, but provided ")
389 << component.getType();
390
391 Type elementType = component.getType();
392 if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
393 sizes.push_back(vectorType.getNumElements());
394 elementType = vectorType.getElementType();
395 } else {
396 sizes.push_back(1);
397 }
398
399 if (elementType != resultType.getElementType())
400 return emitOpError("operand element type mismatch: expected to be ")
401 << resultType.getElementType() << ", but provided " << elementType;
402 }
403 unsigned totalCount = llvm::sum_of(sizes);
404 if (totalCount != cType.getNumElements())
405 return emitOpError("has incorrect number of operands: expected ")
406 << cType.getNumElements() << ", but provided " << totalCount;
407 return success();
408}
409
410//===----------------------------------------------------------------------===//
411// spirv.CompositeExtractOp
412//===----------------------------------------------------------------------===//
413
414void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
415 Value composite,
417 auto indexAttr = builder.getI32ArrayAttr(indices);
418 auto elementType =
419 getElementType(composite.getType(), indexAttr, state.location);
420 if (!elementType) {
421 return;
422 }
423 build(builder, state, elementType, composite, indexAttr);
424}
425
426ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
428 OpAsmParser::UnresolvedOperand compositeInfo;
429 Attribute indicesAttr;
430 StringRef indicesAttrName =
431 spirv::CompositeExtractOp::getIndicesAttrName(result.name);
432 Type compositeType;
433 SMLoc attrLocation;
434
435 if (parser.parseOperand(compositeInfo) ||
436 parser.getCurrentLocation(&attrLocation) ||
437 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
438 parser.parseColonType(compositeType) ||
439 parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
440 return failure();
441 }
442
443 Type resultType =
444 getElementType(compositeType, indicesAttr, parser, attrLocation);
445 if (!resultType) {
446 return failure();
447 }
448 result.addTypes(resultType);
449 return success();
450}
451
452void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
453 printer << ' ' << getComposite() << getIndices() << " : "
454 << getComposite().getType();
455}
456
457LogicalResult spirv::CompositeExtractOp::verify() {
458 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
459 auto resultType =
460 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
461 if (!resultType)
462 return failure();
463
464 if (resultType != getType()) {
465 return emitOpError("invalid result type: expected ")
466 << resultType << " but provided " << getType();
467 }
468
469 return success();
470}
471
472//===----------------------------------------------------------------------===//
473// spirv.CompositeInsert
474//===----------------------------------------------------------------------===//
475
476void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
477 Value object, Value composite,
479 auto indexAttr = builder.getI32ArrayAttr(indices);
480 build(builder, state, composite.getType(), object, composite, indexAttr);
481}
482
483ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
486 Type objectType, compositeType;
487 Attribute indicesAttr;
488 StringRef indicesAttrName =
489 spirv::CompositeInsertOp::getIndicesAttrName(result.name);
490 auto loc = parser.getCurrentLocation();
491
492 return failure(
493 parser.parseOperandList(operands, 2) ||
494 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
495 parser.parseColonType(objectType) ||
496 parser.parseKeywordType("into", compositeType) ||
497 parser.resolveOperands(operands, {objectType, compositeType}, loc,
498 result.operands) ||
499 parser.addTypesToList(compositeType, result.types));
500}
501
502LogicalResult spirv::CompositeInsertOp::verify() {
503 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
504 auto objectType =
505 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
506 if (!objectType)
507 return failure();
508
509 if (objectType != getObject().getType()) {
510 return emitOpError("object operand type should be ")
511 << objectType << ", but found " << getObject().getType();
512 }
513
514 if (getComposite().getType() != getType()) {
515 return emitOpError("result type should be the same as "
516 "the composite type, but found ")
517 << getComposite().getType() << " vs " << getType();
518 }
519
520 return success();
521}
522
523void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
524 printer << " " << getObject() << ", " << getComposite() << getIndices()
525 << " : " << getObject().getType() << " into "
526 << getComposite().getType();
527}
528
529//===----------------------------------------------------------------------===//
530// spirv.Constant
531//===----------------------------------------------------------------------===//
532
533ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
535 Attribute value;
536 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
537 if (parser.parseAttribute(value, valueAttrName, result.attributes))
538 return failure();
539
540 Type type = NoneType::get(parser.getContext());
541 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
542 type = typedAttr.getType();
543 if (llvm::isa<NoneType, TensorType>(type)) {
544 if (parser.parseColonType(type))
545 return failure();
546 }
547
548 if (llvm::isa<TensorArmType>(type)) {
549 if (parser.parseOptionalColon().succeeded())
550 if (parser.parseType(type))
551 return failure();
552 }
553
554 return parser.addTypeToList(type, result.types);
555}
556
557void spirv::ConstantOp::print(OpAsmPrinter &printer) {
558 printer << ' ' << getValue();
559 if (llvm::isa<spirv::ArrayType>(getType()))
560 printer << " : " << getType();
561}
562
563static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
564 Type opType) {
565 if (isa<spirv::CooperativeMatrixType>(opType)) {
566 auto denseAttr = dyn_cast<DenseElementsAttr>(value);
567 if (!denseAttr || !denseAttr.isSplat())
568 return op.emitOpError("expected a splat dense attribute for cooperative "
569 "matrix constant, but found ")
570 << denseAttr;
571 }
572 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
573 auto valueType = llvm::cast<TypedAttr>(value).getType();
574 if (valueType != opType)
575 return op.emitOpError("result type (")
576 << opType << ") does not match value type (" << valueType << ")";
577 return success();
578 }
579 if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
580 auto valueType = llvm::cast<TypedAttr>(value).getType();
581 if (valueType == opType)
582 return success();
583 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
584 auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
585 if (!arrayType)
586 return op.emitOpError("result or element type (")
587 << opType << ") does not match value type (" << valueType
588 << "), must be the same or spirv.array";
589
590 int numElements = arrayType.getNumElements();
591 auto opElemType = arrayType.getElementType();
592 while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
593 numElements *= t.getNumElements();
594 opElemType = t.getElementType();
595 }
596 if (!opElemType.isIntOrFloat())
597 return op.emitOpError("only support nested array result type");
598
599 auto valueElemType = shapedType.getElementType();
600 if (valueElemType != opElemType) {
601 return op.emitOpError("result element type (")
602 << opElemType << ") does not match value element type ("
603 << valueElemType << ")";
604 }
605
606 if (numElements != shapedType.getNumElements()) {
607 return op.emitOpError("result number of elements (")
608 << numElements << ") does not match value number of elements ("
609 << shapedType.getNumElements() << ")";
610 }
611 return success();
612 }
613 if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
614 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
615 if (!arrayType)
616 return op.emitOpError(
617 "must have spirv.array result type for array value");
618 Type elemType = arrayType.getElementType();
619 for (Attribute element : arrayAttr.getValue()) {
620 // Verify array elements recursively.
621 if (failed(verifyConstantType(op, element, elemType)))
622 return failure();
623 }
624 return success();
625 }
626 return op.emitOpError("cannot have attribute: ") << value;
627}
628
629LogicalResult spirv::ConstantOp::verify() {
630 // ODS already generates checks to make sure the result type is valid. We just
631 // need to additionally check that the value's attribute type is consistent
632 // with the result type.
633 return verifyConstantType(*this, getValueAttr(), getType());
634}
635
636bool spirv::ConstantOp::isBuildableWith(Type type) {
637 // Must be valid SPIR-V type first.
638 if (!llvm::isa<spirv::SPIRVType>(type))
639 return false;
640
641 if (isa<SPIRVDialect>(type.getDialect())) {
642 // TODO: support constant struct
643 return llvm::isa<spirv::ArrayType>(type);
644 }
645
646 return true;
647}
648
649spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
650 OpBuilder &builder) {
651 if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
652 unsigned width = intType.getWidth();
653 if (width == 1)
654 return spirv::ConstantOp::create(builder, loc, type,
655 builder.getBoolAttr(false));
656 return spirv::ConstantOp::create(
657 builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
658 }
659 if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
660 return spirv::ConstantOp::create(builder, loc, type,
661 builder.getFloatAttr(floatType, 0.0));
662 }
663 if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
664 Type elemType = vectorType.getElementType();
665 if (llvm::isa<IntegerType>(elemType)) {
666 return spirv::ConstantOp::create(
667 builder, loc, type,
668 DenseElementsAttr::get(vectorType,
669 IntegerAttr::get(elemType, 0).getValue()));
670 }
671 if (llvm::isa<FloatType>(elemType)) {
672 return spirv::ConstantOp::create(
673 builder, loc, type,
674 DenseFPElementsAttr::get(vectorType,
675 FloatAttr::get(elemType, 0.0).getValue()));
676 }
677 }
678
679 llvm_unreachable("unimplemented types for ConstantOp::getZero()");
680}
681
682spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
683 OpBuilder &builder) {
684 if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
685 unsigned width = intType.getWidth();
686 if (width == 1)
687 return spirv::ConstantOp::create(builder, loc, type,
688 builder.getBoolAttr(true));
689 return spirv::ConstantOp::create(
690 builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
691 }
692 if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
693 return spirv::ConstantOp::create(builder, loc, type,
694 builder.getFloatAttr(floatType, 1.0));
695 }
696 if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
697 Type elemType = vectorType.getElementType();
698 if (llvm::isa<IntegerType>(elemType)) {
699 return spirv::ConstantOp::create(
700 builder, loc, type,
701 DenseElementsAttr::get(vectorType,
702 IntegerAttr::get(elemType, 1).getValue()));
703 }
704 if (llvm::isa<FloatType>(elemType)) {
705 return spirv::ConstantOp::create(
706 builder, loc, type,
707 DenseFPElementsAttr::get(vectorType,
708 FloatAttr::get(elemType, 1.0).getValue()));
709 }
710 }
711
712 llvm_unreachable("unimplemented types for ConstantOp::getOne()");
713}
714
715void mlir::spirv::ConstantOp::getAsmResultNames(
716 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
717 Type type = getType();
718
719 SmallString<32> specialNameBuffer;
720 llvm::raw_svector_ostream specialName(specialNameBuffer);
721 specialName << "cst";
722
723 IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
724
725 if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
726 assert(intTy);
727
728 if (intTy.getWidth() == 1) {
729 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
730 }
731
732 if (intTy.isSignless()) {
733 specialName << intCst.getInt();
734 } else if (intTy.isUnsigned()) {
735 specialName << intCst.getUInt();
736 } else {
737 specialName << intCst.getSInt();
738 }
739 }
740
741 if (intTy || llvm::isa<FloatType>(type)) {
742 specialName << '_' << type;
743 }
744
745 if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
746 specialName << "_vec_";
747 specialName << vecType.getDimSize(0);
748
749 Type elementType = vecType.getElementType();
750
751 if (llvm::isa<IntegerType>(elementType) ||
752 llvm::isa<FloatType>(elementType)) {
753 specialName << "x" << elementType;
754 }
755 }
756
757 setNameFn(getResult(), specialName.str());
758}
759
760void mlir::spirv::AddressOfOp::getAsmResultNames(
761 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
762 SmallString<32> specialNameBuffer;
763 llvm::raw_svector_ostream specialName(specialNameBuffer);
764 specialName << getVariable() << "_addr";
765 setNameFn(getResult(), specialName.str());
766}
767
768//===----------------------------------------------------------------------===//
769// spirv.EXTConstantCompositeReplicate
770//===----------------------------------------------------------------------===//
771
// Returns the type of `attr`. For a TypedAttr this is simply its type. An
// ArrayAttr is untyped and may be nested, so a spirv.array type is built
// recursively from its first element's type and its size. Any other attribute
// yields a null type.
776 if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
777 return typedAttr.getType();
778 }
779
780 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
781 return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
782 }
783
784 return nullptr;
785}
786
787LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
788 Type valueType = getValueType(getValue());
789 if (!valueType)
790 return emitError("unknown value attribute type");
791
792 auto compositeType = dyn_cast<spirv::CompositeType>(getType());
793 if (!compositeType)
794 return emitError("result type is not a composite type");
795
796 Type compositeElementType = compositeType.getElementType(0);
797
798 SmallVector<Type, 3> possibleTypes = {compositeElementType};
799 while (auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
800 compositeElementType = type.getElementType(0);
801 possibleTypes.push_back(compositeElementType);
802 }
803
804 if (!is_contained(possibleTypes, valueType)) {
805 return emitError("expected value attribute type ")
806 << interleaved(possibleTypes, " or ") << ", but got: " << valueType;
807 }
808
809 return success();
810}
811
812//===----------------------------------------------------------------------===//
813// spirv.ControlBarrierOp
814//===----------------------------------------------------------------------===//
815
816LogicalResult spirv::ControlBarrierOp::verify() {
817 return verifyMemorySemantics(getOperation(), getMemorySemantics());
818}
819
820//===----------------------------------------------------------------------===//
821// spirv.EntryPoint
822//===----------------------------------------------------------------------===//
823
824void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
825 spirv::ExecutionModel executionModel,
826 spirv::FuncOp function,
827 ArrayRef<Attribute> interfaceVars) {
828 build(builder, state,
829 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
830 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
831}
832
833ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
835 spirv::ExecutionModel execModel;
836 SmallVector<Attribute, 4> interfaceVars;
837
839 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
840 parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
841 return failure();
842 }
843
844 if (!parser.parseOptionalComma()) {
845 // Parse the interface variables
846 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
847 // The name of the interface variable attribute isnt important
848 FlatSymbolRefAttr var;
849 NamedAttrList attrs;
850 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
851 return failure();
852 interfaceVars.push_back(var);
853 return success();
854 }))
855 return failure();
856 }
857 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
858 parser.getBuilder().getArrayAttr(interfaceVars));
859 return success();
860}
861
862void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
863 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
864 printer.printSymbolName(getFn());
865 auto interfaceVars = getInterface().getValue();
866 if (!interfaceVars.empty())
867 printer << ", " << llvm::interleaved(interfaceVars);
868}
869
870LogicalResult spirv::EntryPointOp::verify() {
871 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
872 // verification.
873 return success();
874}
875
876//===----------------------------------------------------------------------===//
877// spirv.ExecutionMode
878//===----------------------------------------------------------------------===//
879
880void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
881 spirv::FuncOp function,
882 spirv::ExecutionMode executionMode,
883 ArrayRef<int32_t> params) {
884 build(builder, state, SymbolRefAttr::get(function),
885 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
886 builder.getI32ArrayAttr(params));
887}
888
889ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
891 spirv::ExecutionMode execMode;
892 Attribute fn;
893 if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
895 return failure();
896 }
897
899 Type i32Type = parser.getBuilder().getIntegerType(32);
900 while (!parser.parseOptionalComma()) {
901 NamedAttrList attr;
902 Attribute value;
903 if (parser.parseAttribute(value, i32Type, "value", attr)) {
904 return failure();
905 }
906 values.push_back(llvm::cast<IntegerAttr>(value).getInt());
907 }
908 StringRef valuesAttrName =
909 spirv::ExecutionModeOp::getValuesAttrName(result.name);
910 result.addAttribute(valuesAttrName,
911 parser.getBuilder().getI32ArrayAttr(values));
912 return success();
913}
914
915void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
916 printer << " ";
917 printer.printSymbolName(getFn());
918 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
919 ArrayAttr values = this->getValues();
920 if (!values.empty())
921 printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
922}
923
924//===----------------------------------------------------------------------===//
925// spirv.func
926//===----------------------------------------------------------------------===//
927
928ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
930 SmallVector<DictionaryAttr> resultAttrs;
931 SmallVector<Type> resultTypes;
932 auto &builder = parser.getBuilder();
933
934 // Parse the name as a symbol.
935 StringAttr nameAttr;
936 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
937 result.attributes))
938 return failure();
939
940 // Parse the function signature.
941 bool isVariadic = false;
943 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
944 resultAttrs))
945 return failure();
946
947 SmallVector<Type> argTypes;
948 for (auto &arg : entryArgs)
949 argTypes.push_back(arg.type);
950 auto fnType = builder.getFunctionType(argTypes, resultTypes);
951 result.addAttribute(getFunctionTypeAttrName(result.name),
952 TypeAttr::get(fnType));
953
954 // Parse the optional function control keyword.
955 spirv::FunctionControl fnControl;
957 return failure();
958
959 // If additional attributes are present, parse them.
960 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
961 return failure();
962
963 // Add the attributes to the function arguments.
964 assert(resultAttrs.size() == resultTypes.size());
966 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
967 getResAttrsAttrName(result.name));
968
969 // Parse the optional function body.
970 auto *body = result.addRegion();
971 OptionalParseResult parseResult =
972 parser.parseOptionalRegion(*body, entryArgs);
973 return failure(parseResult.has_value() && failed(*parseResult));
974}
975
976void spirv::FuncOp::print(OpAsmPrinter &printer) {
977 // Print function name, signature, and control.
978 printer << " ";
979 printer.printSymbolName(getSymName());
980 auto fnType = getFunctionType();
982 printer, *this, fnType.getInputs(),
983 /*isVariadic=*/false, fnType.getResults());
984 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
985 << "\"";
987 printer, *this,
989 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
990 getFunctionControlAttrName()});
991
992 // Print the body if this is not an external function.
993 Region &body = this->getBody();
994 if (!body.empty()) {
995 printer << ' ';
996 printer.printRegion(body, /*printEntryBlockArgs=*/false,
997 /*printBlockTerminators=*/true);
998 }
999}
1000
1001LogicalResult spirv::FuncOp::verifyType() {
1002 FunctionType fnType = getFunctionType();
1003 if (fnType.getNumResults() > 1)
1004 return emitOpError("cannot have more than one result");
1005
1006 auto hasDecorationAttr = [&](spirv::Decoration decoration,
1007 unsigned argIndex) {
1008 auto func = llvm::cast<FunctionOpInterface>(getOperation());
1009 for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
1010 if (argAttr.getName() != spirv::DecorationAttr::name)
1011 continue;
1012 if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
1013 return decAttr.getValue() == decoration;
1014 }
1015 return false;
1016 };
1017
1018 for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1019 Type param = fnType.getInputs()[i];
1020 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1021 if (!inputPtrType)
1022 continue;
1023
1024 auto pointeePtrType =
1025 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1026 if (pointeePtrType) {
1027 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1028 // > If an OpFunctionParameter is a pointer (or contains a pointer)
1029 // > and the type it points to is a pointer in the PhysicalStorageBuffer
1030 // > storage class, the function parameter must be decorated with exactly
1031 // > one of AliasedPointer or RestrictPointer.
1032 if (pointeePtrType.getStorageClass() !=
1033 spirv::StorageClass::PhysicalStorageBuffer)
1034 continue;
1035
1036 bool hasAliasedPtr =
1037 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1038 bool hasRestrictPtr =
1039 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1040 if (!hasAliasedPtr && !hasRestrictPtr)
1041 return emitOpError()
1042 << "with a pointer points to a physical buffer pointer must "
1043 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1044 continue;
1045 }
1046 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1047 // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1048 // > the PhysicalStorageBuffer storage class, the function parameter must
1049 // > be decorated with exactly one of Aliased or Restrict.
1050 if (auto pointeeArrayType =
1051 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1052 pointeePtrType =
1053 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1054 } else {
1055 pointeePtrType = inputPtrType;
1056 }
1057
1058 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1059 spirv::StorageClass::PhysicalStorageBuffer)
1060 continue;
1061
1062 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1063 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1064 if (!hasAliased && !hasRestrict)
1065 return emitOpError() << "with physical buffer pointer must be decorated "
1066 "either 'Aliased' or 'Restrict'";
1067 }
1068
1069 return success();
1070}
1071
1072LogicalResult spirv::FuncOp::verifyBody() {
1073 FunctionType fnType = getFunctionType();
1074 if (!isExternal()) {
1075 Block &entryBlock = front();
1076
1077 unsigned numArguments = this->getNumArguments();
1078 if (entryBlock.getNumArguments() != numArguments)
1079 return emitOpError("entry block must have ")
1080 << numArguments << " arguments to match function signature";
1081
1082 for (auto [index, fnArgType, blockArgType] :
1083 llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1084 if (blockArgType != fnArgType) {
1085 return emitOpError("type of entry block argument #")
1086 << index << '(' << blockArgType
1087 << ") must match the type of the corresponding argument in "
1088 << "function signature(" << fnArgType << ')';
1089 }
1090 }
1091 }
1092
1093 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1094 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1095 if (fnType.getNumResults() != 0)
1096 return retOp.emitOpError("cannot be used in functions returning value");
1097 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1098 if (fnType.getNumResults() != 1)
1099 return retOp.emitOpError(
1100 "returns 1 value but enclosing function requires ")
1101 << fnType.getNumResults() << " results";
1102
1103 auto retOperandType = retOp.getValue().getType();
1104 auto fnResultType = fnType.getResult(0);
1105 if (retOperandType != fnResultType)
1106 return retOp.emitOpError(" return value's type (")
1107 << retOperandType << ") mismatch with function's result type ("
1108 << fnResultType << ")";
1109 }
1110 return WalkResult::advance();
1111 });
1112
1113 // TODO: verify other bits like linkage type.
1114
1115 return failure(walkResult.wasInterrupted());
1116}
1117
1118void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1119 StringRef name, FunctionType type,
1120 spirv::FunctionControl control,
1123 builder.getStringAttr(name));
1124 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1126 builder.getAttr<spirv::FunctionControlAttr>(control));
1127 state.attributes.append(attrs.begin(), attrs.end());
1128 state.addRegion();
1129}
1130
1131//===----------------------------------------------------------------------===//
1132// spirv.GLFClampOp
1133//===----------------------------------------------------------------------===//
1134
1135ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1138}
1139void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1140
1141//===----------------------------------------------------------------------===//
1142// spirv.GLUClampOp
1143//===----------------------------------------------------------------------===//
1144
1145ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1148}
1149void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1150
1151//===----------------------------------------------------------------------===//
1152// spirv.GLSClampOp
1153//===----------------------------------------------------------------------===//
1154
1155ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1158}
1159void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1160
1161//===----------------------------------------------------------------------===//
1162// spirv.GLFmaOp
1163//===----------------------------------------------------------------------===//
1164
1165ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1167}
1168void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1169
1170//===----------------------------------------------------------------------===//
1171// spirv.GlobalVariable
1172//===----------------------------------------------------------------------===//
1173
1174void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1175 Type type, StringRef name,
1176 unsigned descriptorSet, unsigned binding) {
1177 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1178 state.addAttribute(
1179 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1180 builder.getI32IntegerAttr(descriptorSet));
1181 state.addAttribute(
1182 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1183 builder.getI32IntegerAttr(binding));
1184}
1185
1186void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1187 Type type, StringRef name,
1188 spirv::BuiltIn builtin) {
1189 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1190 state.addAttribute(
1191 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1192 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1193}
1194
1195ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1197 // Parse variable name.
1198 StringAttr nameAttr;
1199 StringRef initializerAttrName =
1200 spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1201 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1202 result.attributes)) {
1203 return failure();
1204 }
1205
1206 // Parse optional initializer
1207 if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1208 FlatSymbolRefAttr initSymbol;
1209 if (parser.parseLParen() ||
1210 parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1211 result.attributes) ||
1212 parser.parseRParen())
1213 return failure();
1214 }
1215
1216 if (parseVariableDecorations(parser, result)) {
1217 return failure();
1218 }
1219
1220 Type type;
1221 StringRef typeAttrName =
1222 spirv::GlobalVariableOp::getTypeAttrName(result.name);
1223 auto loc = parser.getCurrentLocation();
1224 if (parser.parseColonType(type)) {
1225 return failure();
1226 }
1227 if (!llvm::isa<spirv::PointerType>(type)) {
1228 return parser.emitError(loc, "expected spirv.ptr type");
1229 }
1230 result.addAttribute(typeAttrName, TypeAttr::get(type));
1231
1232 return success();
1233}
1234
1235void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
1236 SmallVector<StringRef, 4> elidedAttrs{
1238
1239 // Print variable name.
1240 printer << ' ';
1241 printer.printSymbolName(getSymName());
1242 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1243
1244 StringRef initializerAttrName = this->getInitializerAttrName();
1245 // Print optional initializer
1246 if (auto initializer = this->getInitializer()) {
1247 printer << " " << initializerAttrName << '(';
1248 printer.printSymbolName(*initializer);
1249 printer << ')';
1250 elidedAttrs.push_back(initializerAttrName);
1251 }
1252
1253 StringRef typeAttrName = this->getTypeAttrName();
1254 elidedAttrs.push_back(typeAttrName);
1255 spirv::printVariableDecorations(*this, printer, elidedAttrs);
1256 printer << " : " << getType();
1257}
1258
1259LogicalResult spirv::GlobalVariableOp::verify() {
1260 if (!llvm::isa<spirv::PointerType>(getType()))
1261 return emitOpError("result must be of a !spv.ptr type");
1262
1263 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1264 // object. It cannot be Generic. It must be the same as the Storage Class
1265 // operand of the Result Type."
1266 // Also, Function storage class is reserved by spirv.Variable.
1267 auto storageClass = this->storageClass();
1268 if (storageClass == spirv::StorageClass::Generic ||
1269 storageClass == spirv::StorageClass::Function) {
1270 return emitOpError("storage class cannot be '")
1271 << stringifyStorageClass(storageClass) << "'";
1272 }
1273
1274 if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1275 this->getInitializerAttrName())) {
1277 (*this)->getParentOp(), init.getAttr());
1278 // TODO: Currently only variable initialization with specialization
1279 // constants is supported. There could be normal constants in the module
1280 // scope as well.
1281 //
1282 // In the current setup we also cannot initialize one global variable with
1283 // another. The problem is that if we try to initialize pointer of type X
1284 // with another pointer type, the validator fails because it expects the
1285 // variable to be initialized to be type X, not pointer to X. Now
1286 // `spirv.GlobalVariable` only allows pointer type, so in the current design
1287 // we cannot initialize one `spirv.GlobalVariable` with another.
1288 if (!initOp ||
1289 !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
1290 return emitOpError("initializer must be result of a "
1291 "spirv.SpecConstant or "
1292 "spirv.SpecConstantCompositeOp op");
1293 }
1294 }
1295
1296 return success();
1297}
1298
1299//===----------------------------------------------------------------------===//
1300// spirv.INTEL.SubgroupBlockRead
1301//===----------------------------------------------------------------------===//
1302
1303LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
1304 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1305 return failure();
1306
1307 return success();
1308}
1309
1310//===----------------------------------------------------------------------===//
1311// spirv.INTEL.SubgroupBlockWrite
1312//===----------------------------------------------------------------------===//
1313
1314ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
1316 // Parse the storage class specification
1317 spirv::StorageClass storageClass;
1319 auto loc = parser.getCurrentLocation();
1320 Type elementType;
1321 if (parseEnumStrAttr(storageClass, parser) ||
1322 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1323 parser.parseType(elementType)) {
1324 return failure();
1325 }
1326
1327 auto ptrType = spirv::PointerType::get(elementType, storageClass);
1328 if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1329 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1330
1331 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1332 result.operands)) {
1333 return failure();
1334 }
1335 return success();
1336}
1337
1338void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
1339 printer << " " << getPtr() << ", " << getValue() << " : "
1340 << getValue().getType();
1341}
1342
1343LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
1344 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1345 return failure();
1346
1347 return success();
1348}
1349
1350//===----------------------------------------------------------------------===//
1351// spirv.IAddCarryOp
1352//===----------------------------------------------------------------------===//
1353
1354LogicalResult spirv::IAddCarryOp::verify() {
1355 return ::verifyArithmeticExtendedBinaryOp(*this);
1356}
1357
1358ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1360 return ::parseArithmeticExtendedBinaryOp(parser, result);
1361}
1362
1363void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1364 ::printArithmeticExtendedBinaryOp(*this, printer);
1365}
1366
1367//===----------------------------------------------------------------------===//
1368// spirv.ISubBorrowOp
1369//===----------------------------------------------------------------------===//
1370
1371LogicalResult spirv::ISubBorrowOp::verify() {
1372 return ::verifyArithmeticExtendedBinaryOp(*this);
1373}
1374
1375ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1377 return ::parseArithmeticExtendedBinaryOp(parser, result);
1378}
1379
1380void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
1381 ::printArithmeticExtendedBinaryOp(*this, printer);
1382}
1383
1384//===----------------------------------------------------------------------===//
1385// spirv.SMulExtended
1386//===----------------------------------------------------------------------===//
1387
1388LogicalResult spirv::SMulExtendedOp::verify() {
1389 return ::verifyArithmeticExtendedBinaryOp(*this);
1390}
1391
1392ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1394 return ::parseArithmeticExtendedBinaryOp(parser, result);
1395}
1396
1397void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
1398 ::printArithmeticExtendedBinaryOp(*this, printer);
1399}
1400
1401//===----------------------------------------------------------------------===//
1402// spirv.UMulExtended
1403//===----------------------------------------------------------------------===//
1404
1405LogicalResult spirv::UMulExtendedOp::verify() {
1406 return ::verifyArithmeticExtendedBinaryOp(*this);
1407}
1408
1409ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1411 return ::parseArithmeticExtendedBinaryOp(parser, result);
1412}
1413
1414void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
1415 ::printArithmeticExtendedBinaryOp(*this, printer);
1416}
1417
1418//===----------------------------------------------------------------------===//
1419// spirv.MemoryBarrierOp
1420//===----------------------------------------------------------------------===//
1421
1422LogicalResult spirv::MemoryBarrierOp::verify() {
1423 return verifyMemorySemantics(getOperation(), getMemorySemantics());
1424}
1425
1426//===----------------------------------------------------------------------===//
1427// spirv.module
1428//===----------------------------------------------------------------------===//
1429
1430void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1431 std::optional<StringRef> name) {
1432 OpBuilder::InsertionGuard guard(builder);
1433 builder.createBlock(state.addRegion());
1434 if (name) {
1436 builder.getStringAttr(*name));
1437 }
1438}
1439
1440void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1441 spirv::AddressingModel addressingModel,
1442 spirv::MemoryModel memoryModel,
1443 std::optional<VerCapExtAttr> vceTriple,
1444 std::optional<StringRef> name) {
1445 state.addAttribute(
1446 "addressing_model",
1447 builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1448 state.addAttribute("memory_model",
1449 builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1450 OpBuilder::InsertionGuard guard(builder);
1451 builder.createBlock(state.addRegion());
1452 if (vceTriple)
1453 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1454 if (name)
1456 builder.getStringAttr(*name));
1457}
1458
1459ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1461 Region *body = result.addRegion();
1462
1463 // If the name is present, parse it.
1464 StringAttr nameAttr;
1466 nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1467
1468 // Parse attributes
1469 spirv::AddressingModel addrModel;
1470 spirv::MemoryModel memoryModel;
1472 result) ||
1474 result))
1475 return failure();
1476
1477 if (succeeded(parser.parseOptionalKeyword("requires"))) {
1478 spirv::VerCapExtAttr vceTriple;
1479 if (parser.parseAttribute(vceTriple,
1480 spirv::ModuleOp::getVCETripleAttrName(),
1481 result.attributes))
1482 return failure();
1483 }
1484
1485 if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1486 parser.parseRegion(*body, /*arguments=*/{}))
1487 return failure();
1488
1489 // Make sure we have at least one block.
1490 if (body->empty())
1491 body->push_back(new Block());
1492
1493 return success();
1494}
1495
1496void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1497 if (std::optional<StringRef> name = getName()) {
1498 printer << ' ';
1499 printer.printSymbolName(*name);
1500 }
1501
1502 SmallVector<StringRef, 2> elidedAttrs;
1503
1504 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1505 << spirv::stringifyMemoryModel(getMemoryModel());
1506 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1507 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1508 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1510
1511 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1512 printer << " requires " << *triple;
1513 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1514 }
1515
1516 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1517 printer << ' ';
1518 printer.printRegion(getRegion());
1519}
1520
1521LogicalResult spirv::ModuleOp::verifyRegions() {
1522 Dialect *dialect = (*this)->getDialect();
1524 entryPoints;
1525 mlir::SymbolTable table(*this);
1526
1527 for (auto &op : *getBody()) {
1528 if (op.getDialect() != dialect)
1529 return op.emitError("'spirv.module' can only contain spirv.* ops");
1530
1531 // For EntryPoint op, check that the function and execution model is not
1532 // duplicated in EntryPointOps. Also verify that the interface specified
1533 // comes from globalVariables here to make this check cheaper.
1534 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1535 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1536 if (!funcOp) {
1537 return entryPointOp.emitError("function '")
1538 << entryPointOp.getFn() << "' not found in 'spirv.module'";
1539 }
1540 if (auto interface = entryPointOp.getInterface()) {
1541 for (Attribute varRef : interface) {
1542 auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1543 if (!varSymRef) {
1544 return entryPointOp.emitError(
1545 "expected symbol reference for interface "
1546 "specification instead of '")
1547 << varRef;
1548 }
1549 auto variableOp =
1550 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1551 if (!variableOp) {
1552 return entryPointOp.emitError("expected spirv.GlobalVariable "
1553 "symbol reference instead of'")
1554 << varSymRef << "'";
1555 }
1556 }
1557 }
1558
1559 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1560 funcOp, entryPointOp.getExecutionModel());
1561 if (!entryPoints.try_emplace(key, entryPointOp).second)
1562 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1563 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1564 // If the function is external and does not have 'Import'
1565 // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1566 // LinkageAttributes is used to import external functions.
1567 auto linkageAttr = funcOp.getLinkageAttributes();
1568 auto hasImportLinkage =
1569 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1570 spirv::LinkageType::Import);
1571 if (funcOp.isExternal() && !hasImportLinkage)
1572 return op.emitError(
1573 "'spirv.module' cannot contain external functions "
1574 "without 'Import' linkage_attributes (LinkageAttributes)");
1575
1576 // TODO: move this check to spirv.func.
1577 for (auto &block : funcOp)
1578 for (auto &op : block) {
1579 if (op.getDialect() != dialect)
1580 return op.emitError(
1581 "functions in 'spirv.module' can only contain spirv.* ops");
1582 }
1583 }
1584 }
1585
1586 return success();
1587}
1588
1589//===----------------------------------------------------------------------===//
1590// spirv.mlir.referenceof
1591//===----------------------------------------------------------------------===//
1592
1593LogicalResult spirv::ReferenceOfOp::verify() {
1594 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1595 (*this)->getParentOp(), getSpecConstAttr());
1596 Type constType;
1597
1598 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1599 if (specConstOp)
1600 constType = specConstOp.getDefaultValue().getType();
1601
1602 auto specConstCompositeOp =
1603 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1604 if (specConstCompositeOp)
1605 constType = specConstCompositeOp.getType();
1606
1607 if (!specConstOp && !specConstCompositeOp)
1608 return emitOpError(
1609 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1610
1611 if (getReference().getType() != constType)
1612 return emitOpError("result type mismatch with the referenced "
1613 "specialization constant's type");
1614
1615 return success();
1616}
1617
1618//===----------------------------------------------------------------------===//
1619// spirv.SpecConstant
1620//===----------------------------------------------------------------------===//
1621
1622ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1624 StringAttr nameAttr;
1625 Attribute valueAttr;
1626 StringRef defaultValueAttrName =
1627 spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1628
1629 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1630 result.attributes))
1631 return failure();
1632
1633 // Parse optional spec_id.
1634 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1635 IntegerAttr specIdAttr;
1636 if (parser.parseLParen() ||
1637 parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1638 parser.parseRParen())
1639 return failure();
1640 }
1641
1642 if (parser.parseEqual() ||
1643 parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1644 return failure();
1645
1646 return success();
1647}
1648
1649void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
1650 printer << ' ';
1651 printer.printSymbolName(getSymName());
1652 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1653 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1654 printer << " = " << getDefaultValue();
1655}
1656
1657LogicalResult spirv::SpecConstantOp::verify() {
1658 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1659 if (specID.getValue().isNegative())
1660 return emitOpError("SpecId cannot be negative");
1661
1662 auto value = getDefaultValue();
1663 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1664 // Make sure bitwidth is allowed.
1665 if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1666 return emitOpError("default value bitwidth disallowed");
1667 return success();
1668 }
1669 return emitOpError(
1670 "default value can only be a bool, integer, or float scalar");
1671}
1672
1673//===----------------------------------------------------------------------===//
1674// spirv.VectorShuffle
1675//===----------------------------------------------------------------------===//
1676
1677LogicalResult spirv::VectorShuffleOp::verify() {
1678 VectorType resultType = llvm::cast<VectorType>(getType());
1679
1680 size_t numResultElements = resultType.getNumElements();
1681 if (numResultElements != getComponents().size())
1682 return emitOpError("result type element count (")
1683 << numResultElements
1684 << ") mismatch with the number of component selectors ("
1685 << getComponents().size() << ")";
1686
1687 size_t totalSrcElements =
1688 llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1689 llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1690
1691 for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1692 uint32_t index = selector.getZExtValue();
1693 if (index >= totalSrcElements &&
1694 index != std::numeric_limits<uint32_t>().max())
1695 return emitOpError("component selector ")
1696 << index << " out of range: expected to be in [0, "
1697 << totalSrcElements << ") or 0xffffffff";
1698 }
1699 return success();
1700}
1701
1702//===----------------------------------------------------------------------===//
1703// spirv.MatrixTimesScalar
1704//===----------------------------------------------------------------------===//
1705
1706LogicalResult spirv::MatrixTimesScalarOp::verify() {
1707 Type elementType =
1710 [](auto matrixType) { return matrixType.getElementType(); })
1711 .Default(nullptr);
1712
1713 assert(elementType && "Unhandled type");
1714
1715 // Check that the scalar type is the same as the matrix element type.
1716 if (getScalar().getType() != elementType)
1717 return emitOpError("input matrix components' type and scaling value must "
1718 "have the same type");
1719
1720 return success();
1721}
1722
1723//===----------------------------------------------------------------------===//
1724// spirv.Transpose
1725//===----------------------------------------------------------------------===//
1726
1727LogicalResult spirv::TransposeOp::verify() {
1728 auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1729 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1730
1731 // Verify that the input and output matrices have correct shapes.
1732 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1733 return emitError("input matrix rows count must be equal to "
1734 "output matrix columns count");
1735
1736 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1737 return emitError("input matrix columns count must be equal to "
1738 "output matrix rows count");
1739
1740 // Verify that the input and output matrices have the same component type
1741 if (inputMatrix.getElementType() != resultMatrix.getElementType())
1742 return emitError("input and output matrices must have the same "
1743 "component type");
1744
1745 return success();
1746}
1747
1748//===----------------------------------------------------------------------===//
1749// spirv.MatrixTimesVector
1750//===----------------------------------------------------------------------===//
1751
1752LogicalResult spirv::MatrixTimesVectorOp::verify() {
1753 auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1754 auto vectorType = llvm::cast<VectorType>(getVector().getType());
1755 auto resultType = llvm::cast<VectorType>(getType());
1756
1757 if (matrixType.getNumColumns() != vectorType.getNumElements())
1758 return emitOpError("matrix columns (")
1759 << matrixType.getNumColumns() << ") must match vector operand size ("
1760 << vectorType.getNumElements() << ")";
1761
1762 if (resultType.getNumElements() != matrixType.getNumRows())
1763 return emitOpError("result size (")
1764 << resultType.getNumElements() << ") must match the matrix rows ("
1765 << matrixType.getNumRows() << ")";
1766
1767 if (matrixType.getElementType() != resultType.getElementType())
1768 return emitOpError("matrix and result element types must match");
1769
1770 return success();
1771}
1772
1773//===----------------------------------------------------------------------===//
1774// spirv.VectorTimesMatrix
1775//===----------------------------------------------------------------------===//
1776
1777LogicalResult spirv::VectorTimesMatrixOp::verify() {
1778 auto vectorType = llvm::cast<VectorType>(getVector().getType());
1779 auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1780 auto resultType = llvm::cast<VectorType>(getType());
1781
1782 if (matrixType.getNumRows() != vectorType.getNumElements())
1783 return emitOpError("number of components in vector must equal the number "
1784 "of components in each column in matrix");
1785
1786 if (resultType.getNumElements() != matrixType.getNumColumns())
1787 return emitOpError("number of columns in matrix must equal the number of "
1788 "components in result");
1789
1790 if (matrixType.getElementType() != resultType.getElementType())
1791 return emitOpError("matrix must be a matrix with the same component type "
1792 "as the component type in result");
1793
1794 return success();
1795}
1796
1797//===----------------------------------------------------------------------===//
1798// spirv.MatrixTimesMatrix
1799//===----------------------------------------------------------------------===//
1800
1801LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1802 auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1803 auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1804 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1805
1806 // left matrix columns' count and right matrix rows' count must be equal
1807 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1808 return emitError("left matrix columns' count must be equal to "
1809 "the right matrix rows' count");
1810
1811 // right and result matrices columns' count must be the same
1812 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1813 return emitError(
1814 "right and result matrices must have equal columns' count");
1815
1816 // right and result matrices component type must be the same
1817 if (rightMatrix.getElementType() != resultMatrix.getElementType())
1818 return emitError("right and result matrices' component type must"
1819 " be the same");
1820
1821 // left and result matrices component type must be the same
1822 if (leftMatrix.getElementType() != resultMatrix.getElementType())
1823 return emitError("left and result matrices' component type"
1824 " must be the same");
1825
1826 // left and result matrices rows count must be the same
1827 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1828 return emitError("left and result matrices must have equal rows' count");
1829
1830 return success();
1831}
1832
1833//===----------------------------------------------------------------------===//
1834// spirv.SpecConstantComposite
1835//===----------------------------------------------------------------------===//
1836
1837ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
1839
1840 StringAttr compositeName;
1841 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1842 result.attributes))
1843 return failure();
1844
1845 if (parser.parseLParen())
1846 return failure();
1847
1848 SmallVector<Attribute, 4> constituents;
1849
1850 do {
1851 // The name of the constituent attribute isn't important
1852 const char *attrName = "spec_const";
1853 FlatSymbolRefAttr specConstRef;
1854 NamedAttrList attrs;
1855
1856 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1857 return failure();
1858
1859 constituents.push_back(specConstRef);
1860 } while (!parser.parseOptionalComma());
1861
1862 if (parser.parseRParen())
1863 return failure();
1864
1865 StringAttr compositeSpecConstituentsName =
1866 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1867 result.addAttribute(compositeSpecConstituentsName,
1868 parser.getBuilder().getArrayAttr(constituents));
1869
1870 Type type;
1871 if (parser.parseColonType(type))
1872 return failure();
1873
1874 StringAttr typeAttrName =
1875 spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1876 result.addAttribute(typeAttrName, TypeAttr::get(type));
1877
1878 return success();
1879}
1880
1881void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
1882 printer << " ";
1883 printer.printSymbolName(getSymName());
1884 printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1885 << ") : " << getType();
1886}
1887
1888LogicalResult spirv::SpecConstantCompositeOp::verify() {
1889 auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1890 auto constituents = this->getConstituents().getValue();
1891
1892 if (!cType)
1893 return emitError("result type must be a composite type, but provided ")
1894 << getType();
1895
1896 if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1897 return emitError("unsupported composite type ") << cType;
1898 if (constituents.size() != cType.getNumElements())
1899 return emitError("has incorrect number of operands: expected ")
1900 << cType.getNumElements() << ", but provided "
1901 << constituents.size();
1902
1903 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1904 auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1905
1906 auto constituentSpecConstOp =
1907 dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1908 (*this)->getParentOp(), constituent.getAttr()));
1909
1910 if (constituentSpecConstOp.getDefaultValue().getType() !=
1911 cType.getElementType(index))
1912 return emitError("has incorrect types of operands: expected ")
1913 << cType.getElementType(index) << ", but provided "
1914 << constituentSpecConstOp.getDefaultValue().getType();
1915 }
1916
1917 return success();
1918}
1919
1920//===----------------------------------------------------------------------===//
1921// spirv.EXTSpecConstantCompositeReplicateOp
1922//===----------------------------------------------------------------------===//
1923
1924ParseResult
1925spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
1927 StringAttr compositeName;
1928 FlatSymbolRefAttr specConstRef;
1929 const char *attrName = "spec_const";
1930 NamedAttrList attrs;
1931 Type type;
1932
1933 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1934 result.attributes) ||
1935 parser.parseLParen() ||
1936 parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
1937 parser.parseRParen() || parser.parseColonType(type))
1938 return failure();
1939
1940 StringAttr compositeSpecConstituentName =
1941 spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1942 result.name);
1943 result.addAttribute(compositeSpecConstituentName, specConstRef);
1944
1945 StringAttr typeAttrName =
1946 spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
1947 result.addAttribute(typeAttrName, TypeAttr::get(type));
1948
1949 return success();
1950}
1951
1952void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
1953 printer << " ";
1954 printer.printSymbolName(getSymName());
1955 printer << " (" << this->getConstituent() << ") : " << getType();
1956}
1957
1958LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
1959 auto compositeType = dyn_cast<spirv::CompositeType>(getType());
1960 if (!compositeType)
1961 return emitError("result type must be a composite type, but provided ")
1962 << getType();
1963
1965 (*this)->getParentOp(), this->getConstituent());
1966 if (!constituentOp)
1967 return emitError(
1968 "splat spec constant reference defining constituent not found");
1969
1970 auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
1971 if (!constituentSpecConstOp)
1972 return emitError("constituent is not a spec constant");
1973
1974 Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
1975 Type compositeElementType = compositeType.getElementType(0);
1976 if (constituentType != compositeElementType)
1977 return emitError("constituent has incorrect type: expected ")
1978 << compositeElementType << ", but provided " << constituentType;
1979
1980 return success();
1981}
1982
1983//===----------------------------------------------------------------------===//
1984// spirv.SpecConstantOperation
1985//===----------------------------------------------------------------------===//
1986
1987ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
1989 Region *body = result.addRegion();
1990
1991 if (parser.parseKeyword("wraps"))
1992 return failure();
1993
1994 body->push_back(new Block);
1995 Block &block = body->back();
1996 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1997
1998 if (!wrappedOp)
1999 return failure();
2000
2001 OpBuilder builder(parser.getContext());
2002 builder.setInsertionPointToEnd(&block);
2003 spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0));
2004 result.location = wrappedOp->getLoc();
2005
2006 result.addTypes(wrappedOp->getResult(0).getType());
2007
2008 if (parser.parseOptionalAttrDict(result.attributes))
2009 return failure();
2010
2011 return success();
2012}
2013
2014void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
2015 printer << " wraps ";
2016 printer.printGenericOp(&getBody().front().front());
2017}
2018
2019LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
2020 Block &block = getRegion().getBlocks().front();
2021
2022 if (block.getOperations().size() != 2)
2023 return emitOpError("expected exactly 2 nested ops");
2024
2025 Operation &enclosedOp = block.getOperations().front();
2026
2028 return emitOpError("invalid enclosed op");
2029
2030 for (auto operand : enclosedOp.getOperands())
2031 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
2032 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
2033 return emitOpError(
2034 "invalid operand, must be defined by a constant operation");
2035
2036 return success();
2037}
2038
2039//===----------------------------------------------------------------------===//
2040// spirv.GL.FrexpStruct
2041//===----------------------------------------------------------------------===//
2042
2043LogicalResult spirv::GLFrexpStructOp::verify() {
2044 spirv::StructType structTy =
2045 llvm::dyn_cast<spirv::StructType>(getResult().getType());
2046
2047 if (structTy.getNumElements() != 2)
2048 return emitError("result type must be a struct type with two memebers");
2049
2050 Type significandTy = structTy.getElementType(0);
2051 Type exponentTy = structTy.getElementType(1);
2052 VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
2053 IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
2054
2055 Type operandTy = getOperand().getType();
2056 VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
2057 FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
2058
2059 if (significandTy != operandTy)
2060 return emitError("member zero of the resulting struct type must be the "
2061 "same type as the operand");
2062
2063 if (exponentVecTy) {
2064 IntegerType componentIntTy =
2065 llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
2066 if (!componentIntTy || componentIntTy.getWidth() != 32)
2067 return emitError("member one of the resulting struct type must"
2068 "be a scalar or vector of 32 bit integer type");
2069 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2070 return emitError("member one of the resulting struct type "
2071 "must be a scalar or vector of 32 bit integer type");
2072 }
2073
2074 // Check that the two member types have the same number of components
2075 if (operandVecTy && exponentVecTy &&
2076 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2077 return success();
2078
2079 if (operandFTy && exponentIntTy)
2080 return success();
2081
2082 return emitError("member one of the resulting struct type must have the same "
2083 "number of components as the operand type");
2084}
2085
2086//===----------------------------------------------------------------------===//
2087// spirv.GL.Ldexp
2088//===----------------------------------------------------------------------===//
2089
2090LogicalResult spirv::GLLdexpOp::verify() {
2091 Type significandType = getX().getType();
2092 Type exponentType = getExp().getType();
2093
2094 if (llvm::isa<FloatType>(significandType) !=
2095 llvm::isa<IntegerType>(exponentType))
2096 return emitOpError("operands must both be scalars or vectors");
2097
2098 auto getNumElements = [](Type type) -> unsigned {
2099 if (auto vectorType = llvm::dyn_cast<VectorType>(type))
2100 return vectorType.getNumElements();
2101 return 1;
2102 };
2103
2104 if (getNumElements(significandType) != getNumElements(exponentType))
2105 return emitOpError("operands must have the same number of elements");
2106
2107 return success();
2108}
2109
2110//===----------------------------------------------------------------------===//
2111// spirv.ShiftLeftLogicalOp
2112//===----------------------------------------------------------------------===//
2113
2114LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2115 return verifyShiftOp(*this);
2116}
2117
2118//===----------------------------------------------------------------------===//
2119// spirv.ShiftRightArithmeticOp
2120//===----------------------------------------------------------------------===//
2121
2122LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2123 return verifyShiftOp(*this);
2124}
2125
2126//===----------------------------------------------------------------------===//
2127// spirv.ShiftRightLogicalOp
2128//===----------------------------------------------------------------------===//
2129
2130LogicalResult spirv::ShiftRightLogicalOp::verify() {
2131 return verifyShiftOp(*this);
2132}
2133
2134//===----------------------------------------------------------------------===//
2135// spirv.VectorTimesScalarOp
2136//===----------------------------------------------------------------------===//
2137
2138LogicalResult spirv::VectorTimesScalarOp::verify() {
2139 if (getVector().getType() != getType())
2140 return emitOpError("vector operand and result type mismatch");
2141 auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2142 if (getScalar().getType() != scalarType)
2143 return emitOpError("scalar operand and result element type match");
2144 return success();
2145}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
ArrayAttr()
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition SPIRVOps.cpp:267
static Type getValueType(Attribute attr)
Definition SPIRVOps.cpp:775
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition SPIRVOps.cpp:563
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition SPIRVOps.cpp:120
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition SPIRVOps.cpp:253
static LogicalResult verifyShiftOp(Operation *op)
Definition SPIRVOps.cpp:299
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition SPIRVOps.cpp:169
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition SPIRVOps.cpp:186
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition SPIRVOps.cpp:149
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition SPIRVOps.cpp:291
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional -identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:149
unsigned getNumArguments()
Definition Block.h:128
OpListType & getOperations()
Definition Block.h:137
Operation & front()
Definition Block.h:153
iterator begin()
Definition Block.h:143
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:276
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:98
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
Value getOperand(unsigned idx)
Definition Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
void push_back(Block *block)
Definition Region.h:61
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
Type front()
Return first type in the range.
Definition TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
SPIR-V struct type.
Definition SPIRVTypes.h:251
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition SPIRVOps.cpp:69
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition SPIRVOps.cpp:93
constexpr StringRef attributeName()
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition SPIRVOps.cpp:49
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.