MLIR 23.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
24#include "mlir/IR/Builders.h"
28#include "mlir/IR/Operation.h"
31#include "llvm/ADT/APFloat.h"
32#include "llvm/ADT/APInt.h"
33#include "llvm/ADT/ArrayRef.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/StringExtras.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/InterleavedRange.h"
38#include <cassert>
39#include <numeric>
40#include <optional>
41
42using namespace mlir;
43using namespace mlir::spirv::AttrNames;
44
45//===----------------------------------------------------------------------===//
46// Common utility functions
47//===----------------------------------------------------------------------===//
48
49LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
50 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
51 if (!constOp) {
52 return failure();
53 }
54 auto valueAttr = constOp.getValue();
55 auto integerValueAttr = dyn_cast<IntegerAttr>(valueAttr);
56 if (!integerValueAttr) {
57 return failure();
58 }
59
60 if (integerValueAttr.getType().isSignlessInteger())
61 value = integerValueAttr.getInt();
62 else
63 value = integerValueAttr.getSInt();
64
65 return success();
66}
67
// Checks the memory-semantics mask `memorySemantics` on behalf of `op`: emits
// an error on `op` and returns failure when more than one of the mutually
// exclusive ordering bits (Acquire, Release, AcquireRelease,
// SequentiallyConsistent) is set; otherwise returns success.
// NOTE(review): the adjacent string literals in the error message below join
// as "`Release`,`AcquireRelease`" with no space after the comma.
68LogicalResult
70 spirv::MemorySemantics memorySemantics) {
71 // According to the SPIR-V specification:
72 // "Despite being a mask and allowing multiple bits to be combined, it is
73 // invalid for more than one of these four bits to be set: Acquire, Release,
74 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
75 // Release semantics is done by setting the AcquireRelease bit, not by setting
76 // two bits."
77 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
78 spirv::MemorySemantics::Release |
79 spirv::MemorySemantics::AcquireRelease |
80 spirv::MemorySemantics::SequentiallyConsistent;
81
// Mask off everything but the four exclusive ordering bits and count them.
82 auto bitCount =
83 llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
84 if (bitCount > 1) {
85 return op->emitError(
86 "expected at most one of these four memory constraints "
87 "to be set: `Acquire`, `Release`,"
88 "`AcquireRelease` or `SequentiallyConsistent`");
89 }
90 return success();
91}
92
94 SmallVectorImpl<StringRef> &elidedAttrs) {
// Prints the decorations that `op` carries as attributes (attribute names are
// the snake_case forms of the decoration names):
//   ` bind(<descriptor set>, <binding>)` when BOTH integer attributes exist,
//   ` <builtin attr name>("<value>")` when the BuiltIn string attribute exists.
// Each attribute printed in custom form is appended to `elidedAttrs`; all
// remaining attributes are then printed as an optional attribute dictionary.
// NOTE(review): the StringRefs appended to `elidedAttrs` refer to the local
// strings returned by convertToSnakeFromCamelCase (presumably std::string),
// so they dangle if a caller reads `elidedAttrs` after this returns.
95 // Print optional descriptor binding
96 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
97 stringifyDecoration(spirv::Decoration::DescriptorSet));
98 auto bindingName = llvm::convertToSnakeFromCamelCase(
99 stringifyDecoration(spirv::Decoration::Binding));
100 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
101 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
// Only the combined form is printed; a lone set or binding attribute is not
// elided and therefore ends up in the attribute dictionary below.
102 if (descriptorSet && binding) {
103 elidedAttrs.push_back(descriptorSetName);
104 elidedAttrs.push_back(bindingName);
105 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
106 << ")";
107 }
108
109 // Print BuiltIn attribute if present
110 auto builtInName = llvm::convertToSnakeFromCamelCase(
111 stringifyDecoration(spirv::Decoration::BuiltIn));
112 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
113 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
114 elidedAttrs.push_back(builtInName);
115 }
116
117 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
118}
119
// Parses an op in one of two textual forms:
//  - generic fallback, recognized by a leading `(`:
//      `(` operands `)` attr-dict `:` function-type
//    operands are resolved against the function type's inputs and the
//    function type's results become the op's result types;
//  - compact form:
//      operands attr-dict `:` type
//    where the single `type` is used for every operand and for the result.
123 Type type;
124 // If the operand list is in-between parentheses, then we have a generic form.
125 // (see the fallback in `printOneResultOp`).
126 SMLoc loc = parser.getCurrentLocation();
// parseOptionalLParen() returns success (false) when `(` is present.
127 if (!parser.parseOptionalLParen()) {
128 if (parser.parseOperandList(ops) || parser.parseRParen() ||
129 parser.parseOptionalAttrDict(result.attributes) ||
130 parser.parseColon() || parser.parseType(type))
131 return failure();
132 auto fnType = dyn_cast<FunctionType>(type);
133 if (!fnType) {
134 parser.emitError(loc, "expected function type");
135 return failure();
136 }
137 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
138 return failure();
139 result.addTypes(fnType.getResults());
140 return success();
141 }
142 return failure(parser.parseOperandList(ops) ||
143 parser.parseOptionalAttrDict(result.attributes) ||
144 parser.parseColonType(type) ||
145 parser.resolveOperands(ops, type, result.operands) ||
146 parser.addTypeToList(type, result.types));
147}
148
// Prints `op`, which must have exactly one result. If any operand type
// differs from the result type, the generic form (without the op name) is
// printed so no type information is lost; otherwise the operands are printed
// followed by a single ` : <result type>`.
150 assert(op->getNumResults() == 1 && "op should have one result");
151
152 // If not all the operand and result types are the same, just use the
153 // generic assembly form to avoid omitting information in printing.
154 auto resultType = op->getResult(0).getType();
155 if (llvm::any_of(op->getOperandTypes(),
156 [&](Type type) { return type != resultType; })) {
157 p.printGenericOp(op, /*printOpName=*/false);
158 return;
159 }
160
161 p << ' ';
162 p.printOperands(op->getOperands());
164 // Now we can output only one type for all operands and the result.
165 p << " : " << resultType;
166}
167
168template <typename BlockReadWriteOpTy>
169static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
170 Value ptr, Value val) {
171 auto valType = val.getType();
172 if (auto valVecTy = dyn_cast<VectorType>(valType))
173 valType = valVecTy.getElementType();
174
175 if (valType != cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
176 return op.emitOpError("mismatch in result type and pointer type");
177 }
178 return success();
179}
180
181/// Walks the given type hierarchy with the given indices, potentially down
182/// to component granularity, to select an element type. Returns null type and
183/// reports errors through `emitErrorFn` on failure.
/// Indices are bounds-checked only for composites whose element count is
/// known at compile time; other composites are indexed without a check.
/// NOTE(review): the empty-indices message names spirv.CompositeExtract even
/// when the caller is verifying spirv.CompositeInsert.
184static Type
186 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
187 if (indices.empty()) {
188 emitErrorFn("expected at least one index for spirv.CompositeExtract");
189 return nullptr;
190 }
191
// Descend one composite level per index.
192 for (auto index : indices) {
193 if (auto cType = dyn_cast<spirv::CompositeType>(type)) {
194 if (cType.hasCompileTimeKnownNumElements() &&
195 (index < 0 ||
196 static_cast<uint64_t>(index) >= cType.getNumElements())) {
197 emitErrorFn("index ") << index << " out of bounds for " << type;
198 return nullptr;
199 }
200 type = cType.getElementType(index);
201 } else {
202 emitErrorFn("cannot extract from non-composite type ")
203 << type << " with index " << index;
204 return nullptr;
205 }
206 }
207 return type;
208}
209
// Attribute-indices overload: requires `indices` to be a non-empty ArrayAttr
// whose entries are all IntegerAttrs, collects them as int32_t and forwards to
// the int32_t-indices overload. Returns a null type after reporting through
// `emitErrorFn` on failure.
// NOTE(review): the integer width is not checked to be 32 bits; getInt()
// results are narrowed into int32_t by push_back.
210static Type
212 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
213 auto indicesArrayAttr = dyn_cast<ArrayAttr>(indices);
214 if (!indicesArrayAttr) {
215 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
216 return nullptr;
217 }
218 if (indicesArrayAttr.empty()) {
219 emitErrorFn("expected at least one index for spirv.CompositeExtract");
220 return nullptr;
221 }
222
223 SmallVector<int32_t, 2> indexVals;
224 for (auto indexAttr : indicesArrayAttr) {
225 auto indexIntAttr = dyn_cast<IntegerAttr>(indexAttr);
226 if (!indexIntAttr) {
227 emitErrorFn("expected an 32-bit integer for index, but found '")
228 << indexAttr << "'";
229 return nullptr;
230 }
231 indexVals.push_back(indexIntAttr.getInt());
232 }
233 return getElementType(type, indexVals, emitErrorFn);
234}
235
// Reports errors at the MLIR location `loc` via mlir::emitError, then defers
// to the callback-based overload.
237 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
238 return ::mlir::emitError(loc, err);
239 };
240 return getElementType(type, indices, errorFn);
241}
242
244 SMLoc loc) {
// Reports errors through `parser.emitError` at source location `loc`, then
// defers to the callback-based overload.
245 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
246 return parser.emitError(loc, err);
247 };
248 return getElementType(type, indices, errorFn);
249}
250
251template <typename ExtendedBinaryOp>
252static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
253 auto resultType = cast<spirv::StructType>(op.getType());
254 if (resultType.getNumElements() != 2)
255 return op.emitOpError("expected result struct type containing two members");
256
257 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
258 resultType.getElementType(0),
259 resultType.getElementType(1)}))
260 return op.emitOpError(
261 "expected all operand types and struct member types are the same");
262
263 return success();
264}
265
// Parses the custom form of an extended arithmetic binary op:
//   attr-dict operands `:` struct-type
// The type must be a spirv.struct with exactly two members; both operands are
// resolved to the struct's first member type and the struct is the result.
269 if (parser.parseOptionalAttrDict(result.attributes) ||
270 parser.parseOperandList(operands) || parser.parseColon())
271 return failure();
272
273 Type resultType;
274 SMLoc loc = parser.getCurrentLocation();
275 if (parser.parseType(resultType))
276 return failure();
277
278 auto structType = dyn_cast<spirv::StructType>(resultType);
279 if (!structType || structType.getNumElements() != 2)
280 return parser.emitError(loc, "expected spirv.struct type with two members");
281
282 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
283 if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
284 return failure();
285
286 result.addTypes(resultType);
287 return success();
288}
289
291 OpAsmPrinter &printer) {
// Prints ` <attr-dict> <operands> : <first result type>` (attribute
// dictionary before the operands).
292 printer << ' ';
293 printer.printOptionalAttrDict(op->getAttrs());
294 printer.printOperands(op->getOperands());
295 printer << " : " << op->getResultTypes().front();
296}
297
298static LogicalResult verifyShiftOp(Operation *op) {
299 if (op->getOperand(0).getType() != op->getResult(0).getType()) {
300 return op->emitError("expected the same type for the first operand and "
301 "result, but provided ")
302 << op->getOperand(0).getType() << " and "
303 << op->getResult(0).getType();
304 }
305 return success();
306}
307
308//===----------------------------------------------------------------------===//
309// spirv.mlir.addressof
310//===----------------------------------------------------------------------===//
311
312void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
313 spirv::GlobalVariableOp var) {
314 build(builder, state, var.getType(), SymbolRefAttr::get(var));
315}
316
317LogicalResult spirv::AddressOfOp::verify() {
318 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
319 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
320 getVariableAttr()));
321 if (!varOp) {
322 return emitOpError("expected spirv.GlobalVariable symbol");
323 }
324 if (getPointer().getType() != varOp.getType()) {
325 return emitOpError(
326 "result type mismatch with the referenced global variable's type");
327 }
328 return success();
329}
330
331//===----------------------------------------------------------------------===//
332// spirv.CompositeConstruct
333//===----------------------------------------------------------------------===//
334
// Verifies that the constituents can construct the result composite: a
// cooperative matrix takes exactly one constituent of its element type;
// otherwise either one constituent per element with matching types, or (for
// vector results only) scalars/vectors of the vector's element type whose
// total component count equals the vector length.
335LogicalResult spirv::CompositeConstructOp::verify() {
336 operand_range constituents = this->getConstituents();
337
338 // There are 4 cases with varying verification rules:
339 // 1. Cooperative Matrices (1 constituent)
340 // 2. Structs (1 constituent for each member)
341 // 3. Arrays (1 constituent for each array element)
342 // 4. Vectors (1 constituent (sub-)element for each vector element)
343
344 auto coopElementType = llvm::TypeSwitch<Type, Type>(getType())
345 .Case([](spirv::CooperativeMatrixType coopType) {
346 return coopType.getElementType();
347 })
348 .Default(nullptr);
349
350 // Case 1. -- cooperative matrices.
351 if (coopElementType) {
352 if (constituents.size() != 1)
353 return emitOpError("has incorrect number of operands: expected ")
354 << "1, but provided " << constituents.size();
355 if (coopElementType != constituents.front().getType())
356 return emitOpError("operand type mismatch: expected operand type ")
357 << coopElementType << ", but provided "
358 << constituents.front().getType();
359 return success();
360 }
361
362 // Case 2./3./4. -- number of constituents matches the number of elements.
363 auto cType = cast<spirv::CompositeType>(getType());
364 if (constituents.size() == cType.getNumElements()) {
365 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
366 if (constituents[index].getType() != cType.getElementType(index)) {
367 return emitOpError("operand type mismatch: expected operand type ")
368 << cType.getElementType(index) << ", but provided "
369 << constituents[index].getType();
370 }
371 }
372 return success();
373 }
374
375 // Case 4. -- check that all constituents add up to the expected vector type.
376 auto resultType = dyn_cast<VectorType>(cType);
377 if (!resultType)
378 return emitOpError(
379 "expected to return a vector or cooperative matrix when the number of "
380 "constituents is less than what the result needs");
381
// Each constituent contributes 1 (scalar) or its length (vector) components;
// `sizes` collects these counts so they can be summed afterwards.
383 for (Value component : constituents) {
384 if (!isa<VectorType>(component.getType()) &&
385 !component.getType().isIntOrFloat())
386 return emitOpError("operand type mismatch: expected operand to have "
387 "a scalar or vector type, but provided ")
388 << component.getType();
389
390 Type elementType = component.getType();
391 if (auto vectorType = dyn_cast<VectorType>(component.getType())) {
392 sizes.push_back(vectorType.getNumElements());
393 elementType = vectorType.getElementType();
394 } else {
395 sizes.push_back(1);
396 }
397
398 if (elementType != resultType.getElementType())
399 return emitOpError("operand element type mismatch: expected to be ")
400 << resultType.getElementType() << ", but provided " << elementType;
401 }
402 unsigned totalCount = llvm::sum_of(sizes);
403 if (totalCount != cType.getNumElements())
404 return emitOpError("has incorrect number of operands: expected ")
405 << cType.getNumElements() << ", but provided " << totalCount;
406 return success();
407}
408
409//===----------------------------------------------------------------------===//
410// spirv.CompositeExtractOp
411//===----------------------------------------------------------------------===//
412
// Builds a composite-extract from `composite` with int32 `indices`, inferring
// the result type by walking the composite type. If inference fails, an error
// is emitted at `state.location` and `state` is returned without calling the
// generated builder (no operands, attributes or result types are added).
413void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
414 Value composite,
416 auto indexAttr = builder.getI32ArrayAttr(indices);
417 auto elementType =
418 getElementType(composite.getType(), indexAttr, state.location);
419 if (!elementType) {
420 return;
421 }
422 build(builder, state, elementType, composite, indexAttr);
423}
424
// Parses `<composite operand> <indices attr> : <composite type>`. The result
// type is derived from the composite type and the indices; derivation errors
// are reported at the location of the indices attribute.
425ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
427 OpAsmParser::UnresolvedOperand compositeInfo;
428 Attribute indicesAttr;
429 StringRef indicesAttrName =
430 spirv::CompositeExtractOp::getIndicesAttrName(result.name);
431 Type compositeType;
432 SMLoc attrLocation;
433
434 if (parser.parseOperand(compositeInfo) ||
435 parser.getCurrentLocation(&attrLocation) ||
436 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
437 parser.parseColonType(compositeType) ||
438 parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
439 return failure();
440 }
441
442 Type resultType =
443 getElementType(compositeType, indicesAttr, parser, attrLocation);
444 if (!resultType) {
445 return failure();
446 }
447 result.addTypes(resultType);
448 return success();
449}
450
451void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
452 printer << ' ' << getComposite() << getIndices() << " : "
453 << getComposite().getType();
454}
455
456LogicalResult spirv::CompositeExtractOp::verify() {
457 auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
458 auto resultType =
459 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
460 if (!resultType)
461 return failure();
462
463 if (resultType != getType()) {
464 return emitOpError("invalid result type: expected ")
465 << resultType << " but provided " << getType();
466 }
467
468 return success();
469}
470
471//===----------------------------------------------------------------------===//
472// spirv.CompositeInsert
473//===----------------------------------------------------------------------===//
474
// Builds a composite-insert of `object` into `composite` at the int32
// `indices`; the result type is taken to be `composite`'s type.
475void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
476 Value object, Value composite,
478 auto indexAttr = builder.getI32ArrayAttr(indices);
479 build(builder, state, composite.getType(), object, composite, indexAttr);
480}
481
// Parses `<object>, <composite> <indices attr> : <object type> into
// <composite type>`. Exactly two operands are required; the result type is
// the composite type.
482ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
485 Type objectType, compositeType;
486 Attribute indicesAttr;
487 StringRef indicesAttrName =
488 spirv::CompositeInsertOp::getIndicesAttrName(result.name);
489 auto loc = parser.getCurrentLocation();
490
491 return failure(
492 parser.parseOperandList(operands, 2) ||
493 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
494 parser.parseColonType(objectType) ||
495 parser.parseKeywordType("into", compositeType) ||
496 parser.resolveOperands(operands, {objectType, compositeType}, loc,
497 result.operands) ||
498 parser.addTypesToList(compositeType, result.types));
499}
500
501LogicalResult spirv::CompositeInsertOp::verify() {
502 auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
503 auto objectType =
504 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
505 if (!objectType)
506 return failure();
507
508 if (objectType != getObject().getType()) {
509 return emitOpError("object operand type should be ")
510 << objectType << ", but found " << getObject().getType();
511 }
512
513 if (getComposite().getType() != getType()) {
514 return emitOpError("result type should be the same as "
515 "the composite type, but found ")
516 << getComposite().getType() << " vs " << getType();
517 }
518
519 return success();
520}
521
522void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
523 printer << " " << getObject() << ", " << getComposite() << getIndices()
524 << " : " << getObject().getType() << " into "
525 << getComposite().getType();
526}
527
528//===----------------------------------------------------------------------===//
529// spirv.Constant
530//===----------------------------------------------------------------------===//
531
// Parses `<value attr> [: <type>]`. The result type is the attribute's own
// type when it is a TypedAttr. An explicit `: type` is required when that
// type is NoneType (untyped attribute) or a builtin tensor type. It is
// optional when the type is a TensorArmType, and if present it replaces
// that type.
532ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
534 Attribute value;
535 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
536 if (parser.parseAttribute(value, valueAttrName, result.attributes))
537 return failure();
538
539 Type type = NoneType::get(parser.getContext());
540 if (auto typedAttr = dyn_cast<TypedAttr>(value))
541 type = typedAttr.getType();
542 if (isa<NoneType, TensorType>(type)) {
543 if (parser.parseColonType(type))
544 return failure();
545 }
546
547 if (isa<TensorArmType>(type)) {
548 if (parser.parseOptionalColon().succeeded())
549 if (parser.parseType(type))
550 return failure();
551 }
552
553 return parser.addTypeToList(type, result.types);
554}
555
556void spirv::ConstantOp::print(OpAsmPrinter &printer) {
557 printer << ' ' << getValue();
558 if (isa<spirv::ArrayType>(getType()))
559 printer << " : " << getType();
560}
561
562static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
563 Type opType) {
564 if (isa<spirv::CooperativeMatrixType>(opType)) {
565 auto denseAttr = dyn_cast<DenseElementsAttr>(value);
566 if (!denseAttr || !denseAttr.isSplat())
567 return op.emitOpError("expected a splat dense attribute for cooperative "
568 "matrix constant, but found ")
569 << denseAttr;
570 }
571 if (isa<IntegerAttr, FloatAttr>(value)) {
572 auto valueType = cast<TypedAttr>(value).getType();
573 if (valueType != opType)
574 return op.emitOpError("result type (")
575 << opType << ") does not match value type (" << valueType << ")";
576 return success();
577 }
578 if (isa<DenseTypedElementsAttr, SparseElementsAttr>(value)) {
579 auto valueType = cast<TypedAttr>(value).getType();
580 if (valueType == opType)
581 return success();
582 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
583 auto shapedType = dyn_cast<ShapedType>(valueType);
584 if (!arrayType)
585 return op.emitOpError("result or element type (")
586 << opType << ") does not match value type (" << valueType
587 << "), must be the same or spirv.array";
588
589 int numElements = arrayType.getNumElements();
590 auto opElemType = arrayType.getElementType();
591 while (auto t = dyn_cast<spirv::ArrayType>(opElemType)) {
592 numElements *= t.getNumElements();
593 opElemType = t.getElementType();
594 }
595 if (!opElemType.isIntOrFloat())
596 return op.emitOpError("only support nested array result type");
597
598 auto valueElemType = shapedType.getElementType();
599 if (valueElemType != opElemType) {
600 return op.emitOpError("result element type (")
601 << opElemType << ") does not match value element type ("
602 << valueElemType << ")";
603 }
604
605 if (numElements != shapedType.getNumElements()) {
606 return op.emitOpError("result number of elements (")
607 << numElements << ") does not match value number of elements ("
608 << shapedType.getNumElements() << ")";
609 }
610 return success();
611 }
612 if (auto arrayAttr = dyn_cast<ArrayAttr>(value)) {
613 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
614 if (!arrayType)
615 return op.emitOpError(
616 "must have spirv.array result type for array value");
617 Type elemType = arrayType.getElementType();
618 for (Attribute element : arrayAttr.getValue()) {
619 // Verify array elements recursively.
620 if (failed(verifyConstantType(op, element, elemType)))
621 return failure();
622 }
623 return success();
624 }
625 return op.emitOpError("cannot have attribute: ") << value;
626}
627
628LogicalResult spirv::ConstantOp::verify() {
629 // ODS already generates checks to make sure the result type is valid. We just
630 // need to additionally check that the value's attribute type is consistent
631 // with the result type.
632 return verifyConstantType(*this, getValueAttr(), getType());
633}
634
635bool spirv::ConstantOp::isBuildableWith(Type type) {
636 // Must be valid SPIR-V type first.
637 if (!isa<spirv::SPIRVType>(type))
638 return false;
639
640 if (isa<SPIRVDialect>(type.getDialect())) {
641 // TODO: support constant struct
642 return isa<spirv::ArrayType>(type);
643 }
644
645 return true;
646}
647
648spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
649 OpBuilder &builder) {
650 if (auto intType = dyn_cast<IntegerType>(type)) {
651 unsigned width = intType.getWidth();
652 if (width == 1)
653 return spirv::ConstantOp::create(builder, loc, type,
654 builder.getBoolAttr(false));
655 return spirv::ConstantOp::create(
656 builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
657 }
658 if (auto floatType = dyn_cast<FloatType>(type)) {
659 return spirv::ConstantOp::create(builder, loc, type,
660 builder.getFloatAttr(floatType, 0.0));
661 }
662 if (auto vectorType = dyn_cast<VectorType>(type)) {
663 Type elemType = vectorType.getElementType();
664 if (isa<IntegerType>(elemType)) {
665 return spirv::ConstantOp::create(
666 builder, loc, type,
667 DenseElementsAttr::get(vectorType,
668 IntegerAttr::get(elemType, 0).getValue()));
669 }
670 if (isa<FloatType>(elemType)) {
671 return spirv::ConstantOp::create(
672 builder, loc, type,
673 DenseFPElementsAttr::get(vectorType,
674 FloatAttr::get(elemType, 0.0).getValue()));
675 }
676 }
677
678 llvm_unreachable("unimplemented types for ConstantOp::getZero()");
679}
680
681spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
682 OpBuilder &builder) {
683 if (auto intType = dyn_cast<IntegerType>(type)) {
684 unsigned width = intType.getWidth();
685 if (width == 1)
686 return spirv::ConstantOp::create(builder, loc, type,
687 builder.getBoolAttr(true));
688 return spirv::ConstantOp::create(
689 builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
690 }
691 if (auto floatType = dyn_cast<FloatType>(type)) {
692 return spirv::ConstantOp::create(builder, loc, type,
693 builder.getFloatAttr(floatType, 1.0));
694 }
695 if (auto vectorType = dyn_cast<VectorType>(type)) {
696 Type elemType = vectorType.getElementType();
697 if (isa<IntegerType>(elemType)) {
698 return spirv::ConstantOp::create(
699 builder, loc, type,
700 DenseElementsAttr::get(vectorType,
701 IntegerAttr::get(elemType, 1).getValue()));
702 }
703 if (isa<FloatType>(elemType)) {
704 return spirv::ConstantOp::create(
705 builder, loc, type,
706 DenseFPElementsAttr::get(vectorType,
707 FloatAttr::get(elemType, 1.0).getValue()));
708 }
709 }
710
711 llvm_unreachable("unimplemented types for ConstantOp::getOne()");
712}
713
714void mlir::spirv::ConstantOp::getAsmResultNames(
715 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
716 Type type = getType();
717
718 SmallString<32> specialNameBuffer;
719 llvm::raw_svector_ostream specialName(specialNameBuffer);
720 specialName << "cst";
721
722 IntegerType intTy = dyn_cast<IntegerType>(type);
723
724 if (IntegerAttr intCst = dyn_cast<IntegerAttr>(getValue())) {
725 assert(intTy);
726
727 if (intTy.getWidth() == 1) {
728 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
729 }
730
731 if (intTy.isSignless()) {
732 specialName << intCst.getInt();
733 } else if (intTy.isUnsigned()) {
734 specialName << intCst.getUInt();
735 } else {
736 specialName << intCst.getSInt();
737 }
738 }
739
740 if (intTy || isa<FloatType>(type)) {
741 specialName << '_' << type;
742 }
743
744 if (auto vecType = dyn_cast<VectorType>(type)) {
745 specialName << "_vec_";
746 specialName << vecType.getDimSize(0);
747
748 Type elementType = vecType.getElementType();
749
750 if (isa<IntegerType>(elementType) || isa<FloatType>(elementType)) {
751 specialName << "x" << elementType;
752 }
753 }
754
755 setNameFn(getResult(), specialName.str());
756}
757
758void mlir::spirv::AddressOfOp::getAsmResultNames(
759 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
760 SmallString<32> specialNameBuffer;
761 llvm::raw_svector_ostream specialName(specialNameBuffer);
762 specialName << getVariable() << "_addr";
763 setNameFn(getResult(), specialName.str());
764}
765
766//===----------------------------------------------------------------------===//
767// spirv.EXTConstantCompositeReplicate
768//===----------------------------------------------------------------------===//
769
770// Returns the type of an attribute. For a TypedAttr this simply returns
771// its type. For an ArrayAttr, which is untyped and can be multidimensional,
772// it builds the spirv.array type recursively.
// Returns a null Type for any other attribute kind.
// NOTE(review): the element type is derived from `arrayAttr[0]` only, so an
// empty ArrayAttr would be indexed out of bounds, and arrays with differing
// element types are not detected here.
774 if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
775 return typedAttr.getType();
776 }
777
778 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
779 return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
780 }
781
782 return nullptr;
783}
784
785LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
786 Type valueType = getValueType(getValue());
787 if (!valueType)
788 return emitError("unknown value attribute type");
789
790 auto compositeType = dyn_cast<spirv::CompositeType>(getType());
791 if (!compositeType)
792 return emitError("result type is not a composite type");
793
794 Type compositeElementType = compositeType.getElementType(0);
795
796 SmallVector<Type, 3> possibleTypes = {compositeElementType};
797 while (auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
798 compositeElementType = type.getElementType(0);
799 possibleTypes.push_back(compositeElementType);
800 }
801
802 if (!is_contained(possibleTypes, valueType)) {
803 return emitError("expected value attribute type ")
804 << interleaved(possibleTypes, " or ") << ", but got: " << valueType;
805 }
806
807 return success();
808}
809
810//===----------------------------------------------------------------------===//
811// spirv.ControlBarrierOp
812//===----------------------------------------------------------------------===//
813
814LogicalResult spirv::ControlBarrierOp::verify() {
815 return verifyMemorySemantics(getOperation(), getMemorySemantics());
816}
817
818//===----------------------------------------------------------------------===//
819// spirv.EntryPoint
820//===----------------------------------------------------------------------===//
821
822void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
823 spirv::ExecutionModel executionModel,
824 spirv::FuncOp function,
825 ArrayRef<Attribute> interfaceVars) {
826 build(builder, state,
827 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
828 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
829}
830
// Parses the execution-model string and the function symbol, optionally
// followed by `,` and a comma-separated list of flat symbol references (the
// interface variables). The interface attribute is always added, and is an
// empty ArrayAttr when no variables are listed.
831ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
833 spirv::ExecutionModel execModel;
834 SmallVector<Attribute, 4> interfaceVars;
835
837 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
838 parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
839 return failure();
840 }
841
842 if (!parser.parseOptionalComma()) {
843 // Parse the interface variables
844 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
845 // The name of the interface variable attribute isn't important
846 FlatSymbolRefAttr var;
847 NamedAttrList attrs;
848 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
849 return failure();
850 interfaceVars.push_back(var);
851 return success();
852 }))
853 return failure();
854 }
855 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
856 parser.getBuilder().getArrayAttr(interfaceVars));
857 return success();
858}
859
860void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
861 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
862 printer.printSymbolName(getFn());
863 auto interfaceVars = getInterface().getValue();
864 if (!interfaceVars.empty())
865 printer << ", " << llvm::interleaved(interfaceVars);
866}
867
868LogicalResult spirv::EntryPointOp::verify() {
869 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
870 // verification.
871 return success();
872}
873
874//===----------------------------------------------------------------------===//
875// spirv.ExecutionMode
876//===----------------------------------------------------------------------===//
877
878void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
879 spirv::FuncOp function,
880 spirv::ExecutionMode executionMode,
881 ArrayRef<int32_t> params) {
882 build(builder, state, SymbolRefAttr::get(function),
883 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
884 builder.getI32ArrayAttr(params));
885}
886
// Parses the function symbol and the execution mode, then zero or more
// `, <integer>` values parsed with type i32; the collected integers are
// stored as an i32 ArrayAttr (empty when none are given).
887ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
889 spirv::ExecutionMode execMode;
890 Attribute fn;
891 if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
893 return failure();
894 }
895
897 Type i32Type = parser.getBuilder().getIntegerType(32);
898 while (!parser.parseOptionalComma()) {
899 NamedAttrList attr;
900 Attribute value;
901 if (parser.parseAttribute(value, i32Type, "value", attr)) {
902 return failure();
903 }
// NOTE(review): cast<> asserts if the parsed attribute is not an IntegerAttr
// (e.g. a string literal); a diagnostic may be preferable here.
904 values.push_back(cast<IntegerAttr>(value).getInt());
905 }
906 StringRef valuesAttrName =
907 spirv::ExecutionModeOp::getValuesAttrName(result.name);
908 result.addAttribute(valuesAttrName,
909 parser.getBuilder().getI32ArrayAttr(values));
910 return success();
911}
912
913void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
914 printer << " ";
915 printer.printSymbolName(getFn());
916 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
917 ArrayAttr values = this->getValues();
918 if (!values.empty())
919 printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
920}
921
922//===----------------------------------------------------------------------===//
923// spirv.ExecutionModeId
924//===----------------------------------------------------------------------===//
925
// Parses the function symbol and the execution-mode string, followed
// directly by a comma-separated list of flat symbol references, which is
// stored as an ArrayAttr.
926ParseResult spirv::ExecutionModeIdOp::parse(OpAsmParser &parser,
928 ExecutionMode execMode;
929 if (Attribute fn;
930 parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
931 parseEnumStrAttr<ExecutionModeAttr>(execMode, parser, result)) {
932 return failure();
933 }
934
936 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
937 FlatSymbolRefAttr attr;
938 if (parser.parseAttribute(attr))
939 return failure();
940 values.push_back(attr);
941 return success();
942 })) {
943 return failure();
944 }
945
946 StringRef valuesAttrName = getValuesAttrName(result.name);
947 ArrayAttr valuesAttr = parser.getBuilder().getArrayAttr(values);
948 result.addAttribute(valuesAttrName, valuesAttr);
949 return success();
950}
951
952void spirv::ExecutionModeIdOp::print(OpAsmPrinter &printer) {
953 printer << " ";
954 printer.printSymbolName(getFn());
955 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\" ";
956
957 llvm::interleaveComma(
958 getValues().getAsValueRange<FlatSymbolRefAttr>(), printer,
959 [&](StringRef value) { printer.printSymbolName(value); });
960}
961
// Verifies spirv.ExecutionModeId. The execution mode must be one of the modes
// that take <id> operands, at least one operand is required, and every operand
// must be a flat symbol reference that resolves to an operation when looked up
// from the parent op.
962LogicalResult spirv::ExecutionModeIdOp::verify() {
963 // Only these execution modes take <id> operands (valid as of SPIR-V 1.6).
964 switch (getExecutionMode()) {
965 case ExecutionMode::SubgroupsPerWorkgroupId:
966 case ExecutionMode::LocalSizeId:
967 case ExecutionMode::LocalSizeHintId:
968 break;
969 default:
970 return emitOpError("expected ExecutionMode that takes extra operands that "
971 "are <id> operands, got: ")
972 << stringifyExecutionMode(getExecutionMode());
973 }
974
975 if (getValues().empty())
976 return emitOpError("expected at least one value operand");
977
// Each operand must be a symbol reference whose target can be found starting
// from the enclosing op.
978 for (Attribute value : getValues()) {
979 auto valueSymbol = dyn_cast<FlatSymbolRefAttr>(value);
980 if (!valueSymbol)
981 return emitOpError("expected value operands to be symbol reference");
983 (*this)->getParentOp(), valueSymbol);
984 if (!valueOp)
985 return emitOpError("cannot find symbol referenced by value operand: ")
986 << valueSymbol.getValue();
987 }
988
989 return success();
990}
991
992//===----------------------------------------------------------------------===//
993// spirv.func
994//===----------------------------------------------------------------------===//
995
// Parses the custom form of spirv.func, in this order: a symbol name, a
// function signature (variadic arguments are rejected), the function-control
// enum, an optional `attributes {...}` dictionary, and an optional body region
// whose entry arguments come from the signature.
996ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
998 SmallVector<DictionaryAttr> resultAttrs;
999 SmallVector<Type> resultTypes;
1000 auto &builder = parser.getBuilder();
1001
1002 // Parse the name as a symbol.
1003 StringAttr nameAttr;
1004 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1005 result.attributes))
1006 return failure();
1007
1008 // Parse the function signature.
1009 bool isVariadic = false;
1011 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1012 resultAttrs))
1013 return failure();
1014
// Derive the FunctionType from the parsed entry-argument and result types.
1015 SmallVector<Type> argTypes;
1016 for (auto &arg : entryArgs)
1017 argTypes.push_back(arg.type);
1018 auto fnType = builder.getFunctionType(argTypes, resultTypes);
1019 result.addAttribute(getFunctionTypeAttrName(result.name),
1020 TypeAttr::get(fnType));
1021
1022 // Parse the optional function control keyword.
1023 spirv::FunctionControl fnControl;
1025 return failure();
1026
1027 // If additional attributes are present, parse them.
1028 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1029 return failure();
1030
1031 // Add the attributes to the function arguments.
1032 assert(resultAttrs.size() == resultTypes.size());
1034 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1035 getResAttrsAttrName(result.name));
1036
1037 // Parse the optional function body.
1038 auto *body = result.addRegion();
1039 OptionalParseResult parseResult =
1040 parser.parseOptionalRegion(*body, entryArgs);
// A missing body is accepted (external function); only a body that is present
// but malformed makes parsing fail.
1041 return failure(parseResult.has_value() && failed(*parseResult));
1042}
1043
// Prints spirv.func as `@name(<signature>) "<FunctionControl>"`, then any
// remaining attributes, then the body region when the function is not
// external. The list containing the function-type, arg/result-attrs and
// function-control attribute names appears to be the set elided from the
// printed attribute dictionary (the call taking it is not visible here).
1044void spirv::FuncOp::print(OpAsmPrinter &printer) {
1045 // Print function name, signature, and control.
1046 printer << " ";
1047 printer.printSymbolName(getSymName());
1048 auto fnType = getFunctionType();
1050 printer, *this, fnType.getInputs(),
1051 /*isVariadic=*/false, fnType.getResults());
// The function control is always printed, quoted, right after the signature.
1052 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
1053 << "\"";
1055 printer, *this,
1056 {spirv::attributeName<spirv::FunctionControl>(),
1057 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
1058 getFunctionControlAttrName()});
1059
1060 // Print the body if this is not an external function.
1061 Region &body = this->getBody();
1062 if (!body.empty()) {
1063 printer << ' ';
// Entry block arguments are omitted because the parser recreates them from
// the signature's argument list.
1064 printer.printRegion(body, /*printEntryBlockArgs=*/false,
1065 /*printBlockTerminators=*/true);
1066 }
1067}
1068
1069LogicalResult spirv::FuncOp::verifyType() {
1070 FunctionType fnType = getFunctionType();
1071 if (fnType.getNumResults() > 1)
1072 return emitOpError("cannot have more than one result");
1073
1074 auto hasDecorationAttr = [&](spirv::Decoration decoration,
1075 unsigned argIndex) {
1076 auto func = cast<FunctionOpInterface>(getOperation());
1077 for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
1078 if (argAttr.getName() != spirv::DecorationAttr::name)
1079 continue;
1080 if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
1081 return decAttr.getValue() == decoration;
1082 }
1083 return false;
1084 };
1085
1086 for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1087 Type param = fnType.getInputs()[i];
1088 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1089 if (!inputPtrType)
1090 continue;
1091
1092 auto pointeePtrType =
1093 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1094 if (pointeePtrType) {
1095 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1096 // > If an OpFunctionParameter is a pointer (or contains a pointer)
1097 // > and the type it points to is a pointer in the PhysicalStorageBuffer
1098 // > storage class, the function parameter must be decorated with exactly
1099 // > one of AliasedPointer or RestrictPointer.
1100 if (pointeePtrType.getStorageClass() !=
1101 spirv::StorageClass::PhysicalStorageBuffer)
1102 continue;
1103
1104 bool hasAliasedPtr =
1105 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1106 bool hasRestrictPtr =
1107 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1108 if (!hasAliasedPtr && !hasRestrictPtr)
1109 return emitOpError()
1110 << "with a pointer points to a physical buffer pointer must "
1111 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1112 continue;
1113 }
1114 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1115 // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1116 // > the PhysicalStorageBuffer storage class, the function parameter must
1117 // > be decorated with exactly one of Aliased or Restrict.
1118 if (auto pointeeArrayType =
1119 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1120 pointeePtrType =
1121 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1122 } else {
1123 pointeePtrType = inputPtrType;
1124 }
1125
1126 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1127 spirv::StorageClass::PhysicalStorageBuffer)
1128 continue;
1129
1130 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1131 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1132 if (!hasAliased && !hasRestrict)
1133 return emitOpError() << "with physical buffer pointer must be decorated "
1134 "either 'Aliased' or 'Restrict'";
1135 }
1136
1137 return success();
1138}
1139
1140LogicalResult spirv::FuncOp::verifyBody() {
1141 FunctionType fnType = getFunctionType();
1142 if (!isExternal()) {
1143 Block &entryBlock = front();
1144
1145 unsigned numArguments = this->getNumArguments();
1146 if (entryBlock.getNumArguments() != numArguments)
1147 return emitOpError("entry block must have ")
1148 << numArguments << " arguments to match function signature";
1149
1150 for (auto [index, fnArgType, blockArgType] :
1151 llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1152 if (blockArgType != fnArgType) {
1153 return emitOpError("type of entry block argument #")
1154 << index << '(' << blockArgType
1155 << ") must match the type of the corresponding argument in "
1156 << "function signature(" << fnArgType << ')';
1157 }
1158 }
1159 }
1160
1161 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1162 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1163 if (fnType.getNumResults() != 0)
1164 return retOp.emitOpError("cannot be used in functions returning value");
1165 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1166 if (fnType.getNumResults() != 1)
1167 return retOp.emitOpError(
1168 "returns 1 value but enclosing function requires ")
1169 << fnType.getNumResults() << " results";
1170
1171 auto retOperandType = retOp.getValue().getType();
1172 auto fnResultType = fnType.getResult(0);
1173 if (retOperandType != fnResultType)
1174 return retOp.emitOpError(" return value's type (")
1175 << retOperandType << ") mismatch with function's result type ("
1176 << fnResultType << ")";
1177 }
1178 return WalkResult::advance();
1179 });
1180
1181 // TODO: verify other bits like linkage type.
1182
1183 return failure(walkResult.wasInterrupted());
1184}
1185
// Builds a spirv.func with the given name, function type and function
// control. Any caller-supplied `attrs` are appended, and an empty body region
// (with no blocks) is added.
1186void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1187 StringRef name, FunctionType type,
1188 spirv::FunctionControl control,
1191 builder.getStringAttr(name));
1192 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1193 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1194 builder.getAttr<spirv::FunctionControlAttr>(control));
// Caller-supplied attributes are appended after the core ones above.
1195 state.attributes.append(attrs.begin(), attrs.end());
// The body region starts empty; callers add blocks as needed.
1196 state.addRegion();
1197}
1198
1199//===----------------------------------------------------------------------===//
1200// spirv.GLFClampOp
1201//===----------------------------------------------------------------------===//
1202
// Custom assembly for GLFClampOp; printing delegates to the shared
// printOneResultOp helper.
1203ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1206}
1207void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1208
1209//===----------------------------------------------------------------------===//
1210// spirv.GLUClampOp
1211//===----------------------------------------------------------------------===//
1212
// Custom assembly for GLUClampOp; printing delegates to the shared
// printOneResultOp helper.
1213ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1216}
1217void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1218
1219//===----------------------------------------------------------------------===//
1220// spirv.GLSClampOp
1221//===----------------------------------------------------------------------===//
1222
// Custom assembly for GLSClampOp; printing delegates to the shared
// printOneResultOp helper.
1223ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1226}
1227void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1228
1229//===----------------------------------------------------------------------===//
1230// spirv.GLFmaOp
1231//===----------------------------------------------------------------------===//
1232
// Custom assembly for GLFmaOp; printing delegates to the shared
// printOneResultOp helper.
1233ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1235}
1236void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1237
1238//===----------------------------------------------------------------------===//
1239// spirv.GlobalVariable
1240//===----------------------------------------------------------------------===//
1241
1242void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1243 Type type, StringRef name,
1244 unsigned descriptorSet, unsigned binding) {
1245 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1246 state.addAttribute(
1247 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1248 builder.getI32IntegerAttr(descriptorSet));
1249 state.addAttribute(
1250 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1251 builder.getI32IntegerAttr(binding));
1252}
1253
1254void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1255 Type type, StringRef name,
1256 spirv::BuiltIn builtin) {
1257 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1258 state.addAttribute(
1259 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1260 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1261}
1262
// Parses `@name [<initializer-keyword>(@sym)] <decorations> : <type>`, where
// <type> must be a spirv.ptr type and is stored as a TypeAttr.
1263ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1265 // Parse variable name.
1266 StringAttr nameAttr;
1267 StringRef initializerAttrName =
1268 spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1269 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1270 result.attributes)) {
1271 return failure();
1272 }
1273
1274 // Parse the optional initializer clause; its keyword is the initializer
// attribute's own name, followed by a parenthesized symbol reference.
1275 if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1276 FlatSymbolRefAttr initSymbol;
1277 if (parser.parseLParen() ||
1278 parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1279 result.attributes) ||
1280 parser.parseRParen())
1281 return failure();
1282 }
1283
// Remaining variable decorations are parsed by the shared helper.
1284 if (parseVariableDecorations(parser, result)) {
1285 return failure();
1286 }
1287
// The location is captured before the type so that a non-pointer type is
// reported at the type itself.
1288 Type type;
1289 StringRef typeAttrName =
1290 spirv::GlobalVariableOp::getTypeAttrName(result.name);
1291 auto loc = parser.getCurrentLocation();
1292 if (parser.parseColonType(type)) {
1293 return failure();
1294 }
1295 if (!isa<spirv::PointerType>(type)) {
1296 return parser.emitError(loc, "expected spirv.ptr type");
1297 }
1298 result.addAttribute(typeAttrName, TypeAttr::get(type));
1299
1300 return success();
1301}
1302
1303void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
1304 SmallVector<StringRef, 4> elidedAttrs{
1305 spirv::attributeName<spirv::StorageClass>()};
1306
1307 // Print variable name.
1308 printer << ' ';
1309 printer.printSymbolName(getSymName());
1310 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1311
1312 StringRef initializerAttrName = this->getInitializerAttrName();
1313 // Print optional initializer
1314 if (auto initializer = this->getInitializer()) {
1315 printer << " " << initializerAttrName << '(';
1316 printer.printSymbolName(*initializer);
1317 printer << ')';
1318 elidedAttrs.push_back(initializerAttrName);
1319 }
1320
1321 StringRef typeAttrName = this->getTypeAttrName();
1322 elidedAttrs.push_back(typeAttrName);
1323 spirv::printVariableDecorations(*this, printer, elidedAttrs);
1324 printer << " : " << getType();
1325}
1326
// Verifies spirv.GlobalVariable. The type must be a pointer whose storage
// class is neither Generic nor Function. A variable with Import linkage must
// not have an initializer, and an initializer must name a spirv.SpecConstant
// or spirv.SpecConstantComposite.
1327LogicalResult spirv::GlobalVariableOp::verify() {
// NOTE(review): this message uses the legacy `!spv.ptr` spelling, while the
// parser reports "expected spirv.ptr type"; consider aligning it (check lit
// tests that match the text first).
1328 if (!isa<spirv::PointerType>(getType()))
1329 return emitOpError("result must be of a !spv.ptr type");
1330
1331 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1332 // object. It cannot be Generic. It must be the same as the Storage Class
1333 // operand of the Result Type."
1334 // Also, Function storage class is reserved by spirv.Variable.
1335 auto storageClass = this->storageClass();
1336 if (storageClass == spirv::StorageClass::Generic ||
1337 storageClass == spirv::StorageClass::Function) {
1338 return emitOpError("storage class cannot be '")
1339 << stringifyStorageClass(storageClass) << "'";
1340 }
1341
1342 // SPIR-V spec: "A module-scope OpVariable with an Initializer operand must
1343 // not be decorated with the Import Linkage Type."
1344 if (std::optional<spirv::LinkageAttributesAttr> linkage =
1345 getLinkageAttributes()) {
1346 if (linkage->getLinkageType().getValue() == spirv::LinkageType::Import &&
1347 getInitializer()) {
1348 return emitOpError(
1349 "with Import linkage type must not have an initializer");
1350 }
1351 }
1352
1353 if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1354 this->getInitializerAttrName())) {
1356 (*this)->getParentOp(), init.getAttr());
1357 // TODO: Currently only variable initialization with specialization
1358 // constants is supported. There could be normal constants in the module
1359 // scope as well.
1360 //
1361 // In the current setup we also cannot initialize one global variable with
1362 // another. The problem is that if we try to initialize pointer of type X
1363 // with another pointer type, the validator fails because it expects the
1364 // variable to be initialized to be type X, not pointer to X. Now
1365 // `spirv.GlobalVariable` only allows pointer type, so in the current design
1366 // we cannot initialize one `spirv.GlobalVariable` with another.
1367 if (!initOp ||
1368 !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
1369 return emitOpError("initializer must be result of a "
1370 "spirv.SpecConstant or "
1371 "spirv.SpecConstantCompositeOp op");
1372 }
1373 }
1374
1375 return success();
1376}
1377
1378//===----------------------------------------------------------------------===//
1379// spirv.INTEL.SubgroupBlockRead
1380//===----------------------------------------------------------------------===//
1381
1382LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
1383 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1384 return failure();
1385
1386 return success();
1387}
1388
1389//===----------------------------------------------------------------------===//
1390// spirv.INTEL.SubgroupBlockWrite
1391//===----------------------------------------------------------------------===//
1392
// Parses `"<StorageClass>" %ptr, %value : <value-type>`. The pointer operand's
// type is not written out; it is reconstructed from the storage class and the
// value type.
1393ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
1395 // Parse the storage class specification
1396 spirv::StorageClass storageClass;
1398 auto loc = parser.getCurrentLocation();
1399 Type elementType;
1400 if (parseEnumStrAttr(storageClass, parser) ||
1401 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1402 parser.parseType(elementType)) {
1403 return failure();
1404 }
1405
// For a vector value the pointer points at the vector's element type;
// otherwise it points at the value type itself.
1406 auto ptrType = spirv::PointerType::get(elementType, storageClass);
1407 if (auto valVecTy = dyn_cast<VectorType>(elementType))
1408 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1409
1410 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1411 result.operands)) {
1412 return failure();
1413 }
1414 return success();
1415}
1416
1417void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
1418 printer << " " << getPtr() << ", " << getValue() << " : "
1419 << getValue().getType();
1420}
1421
1422LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
1423 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1424 return failure();
1425
1426 return success();
1427}
1428
1429//===----------------------------------------------------------------------===//
1430// spirv.IAddCarryOp
1431//===----------------------------------------------------------------------===//
1432
// IAddCarryOp verifies, parses and prints through the file-level
// ::verify/::parse/::printArithmeticExtendedBinaryOp helpers shared by the
// other extended binary arithmetic ops.
1433LogicalResult spirv::IAddCarryOp::verify() {
1434 return ::verifyArithmeticExtendedBinaryOp(*this);
1435}
1436
1437ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1439 return ::parseArithmeticExtendedBinaryOp(parser, result);
1440}
1441
1442void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1443 ::printArithmeticExtendedBinaryOp(*this, printer);
1444}
1445
1446//===----------------------------------------------------------------------===//
1447// spirv.ISubBorrowOp
1448//===----------------------------------------------------------------------===//
1449
// ISubBorrowOp verifies, parses and prints through the file-level
// ::verify/::parse/::printArithmeticExtendedBinaryOp helpers shared by the
// other extended binary arithmetic ops.
1450LogicalResult spirv::ISubBorrowOp::verify() {
1451 return ::verifyArithmeticExtendedBinaryOp(*this);
1452}
1453
1454ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1456 return ::parseArithmeticExtendedBinaryOp(parser, result);
1457}
1458
1459void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
1460 ::printArithmeticExtendedBinaryOp(*this, printer);
1461}
1462
1463//===----------------------------------------------------------------------===//
1464// spirv.SMulExtended
1465//===----------------------------------------------------------------------===//
1466
// SMulExtendedOp verifies, parses and prints through the file-level
// ::verify/::parse/::printArithmeticExtendedBinaryOp helpers shared by the
// other extended binary arithmetic ops.
1467LogicalResult spirv::SMulExtendedOp::verify() {
1468 return ::verifyArithmeticExtendedBinaryOp(*this);
1469}
1470
1471ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1473 return ::parseArithmeticExtendedBinaryOp(parser, result);
1474}
1475
1476void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
1477 ::printArithmeticExtendedBinaryOp(*this, printer);
1478}
1479
1480//===----------------------------------------------------------------------===//
1481// spirv.UMulExtended
1482//===----------------------------------------------------------------------===//
1483
// UMulExtendedOp verifies, parses and prints through the file-level
// ::verify/::parse/::printArithmeticExtendedBinaryOp helpers shared by the
// other extended binary arithmetic ops.
1484LogicalResult spirv::UMulExtendedOp::verify() {
1485 return ::verifyArithmeticExtendedBinaryOp(*this);
1486}
1487
1488ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1490 return ::parseArithmeticExtendedBinaryOp(parser, result);
1491}
1492
1493void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
1494 ::printArithmeticExtendedBinaryOp(*this, printer);
1495}
1496
1497//===----------------------------------------------------------------------===//
1498// spirv.MemoryBarrierOp
1499//===----------------------------------------------------------------------===//
1500
1501LogicalResult spirv::MemoryBarrierOp::verify() {
1502 return verifyMemorySemantics(getOperation(), getMemorySemantics());
1503}
1504
1505//===----------------------------------------------------------------------===//
1506// spirv.module
1507//===----------------------------------------------------------------------===//
1508
// Builds a spirv.module whose body region holds one empty block. If `name` is
// given, it is recorded as a string attribute (presumably the symbol name).
1509void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1510 std::optional<StringRef> name) {
// The guard restores the caller's insertion point after createBlock moves it
// into the new region.
1511 OpBuilder::InsertionGuard guard(builder);
1512 builder.createBlock(state.addRegion());
1513 if (name) {
1515 builder.getStringAttr(*name));
1516 }
1517}
1518
// Builds a spirv.module with the given addressing and memory models, one empty
// body block, and optionally a version/capability/extension triple and a name.
1519void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1520 spirv::AddressingModel addressingModel,
1521 spirv::MemoryModel memoryModel,
1522 std::optional<VerCapExtAttr> vceTriple,
1523 std::optional<StringRef> name) {
1524 state.addAttribute(
1525 "addressing_model",
1526 builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1527 state.addAttribute("memory_model",
1528 builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
// Create the body block; the guard restores the caller's insertion point.
1529 OpBuilder::InsertionGuard guard(builder);
1530 builder.createBlock(state.addRegion());
// Optional attributes are added only when provided.
1531 if (vceTriple)
1532 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1533 if (name)
1535 builder.getStringAttr(*name));
1536}
1537
// Parses `[@name] <addressing-model> <memory-model> [requires <vce-triple>]
// [attributes {...}] <region>`. If the parsed region is empty, an empty block
// is added so the body always has at least one block.
1538ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1540 Region *body = result.addRegion();
1541
1542 // If the name is present, parse it.
1543 StringAttr nameAttr;
1545 nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1546
1547 // Parse the addressing model and memory model.
1548 spirv::AddressingModel addrModel;
1549 spirv::MemoryModel memoryModel;
1551 result) ||
1553 result))
1554 return failure();
1555
// Optional `requires` clause carrying the version/capability/extension triple.
1556 if (succeeded(parser.parseOptionalKeyword("requires"))) {
1557 spirv::VerCapExtAttr vceTriple;
1558 if (parser.parseAttribute(vceTriple,
1559 spirv::ModuleOp::getVCETripleAttrName(),
1560 result.attributes))
1561 return failure();
1562 }
1563
1564 if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1565 parser.parseRegion(*body, /*arguments=*/{}))
1566 return failure();
1567
1568 // Make sure we have at least one block.
1569 if (body->empty())
1570 body->push_back(new Block());
1571
1572 return success();
1573}
1574
// Prints `[@name] <addressing-model> <memory-model> [requires <vce-triple>]
// [attributes {...}] <region>`. Attributes shown by this custom syntax are
// elided from the trailing attribute dictionary.
1575void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1576 if (std::optional<StringRef> name = getName()) {
1577 printer << ' ';
1578 printer.printSymbolName(*name);
1579 }
1580
1581 SmallVector<StringRef, 2> elidedAttrs;
1582
// Addressing and memory models are printed as bare keywords, and their
// attribute names are elided.
1583 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1584 << spirv::stringifyMemoryModel(getMemoryModel());
1585 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1586 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1587 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1589
// When present, the VCE triple is printed after `requires` and is also elided.
1590 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1591 printer << " requires " << *triple;
1592 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1593 }
1594
// Any remaining attributes go into an `attributes {...}` dictionary.
1595 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1596 printer << ' ';
1597 printer.printRegion(getRegion());
1598}
1599
// Verifies the module body:
//  - only ops from this (SPIR-V) dialect may appear at module scope;
//  - each EntryPoint must name an existing spirv.func, its interface entries
//    must be symbol references to spirv.GlobalVariable ops, and each
//    (function, execution model) pair may appear only once;
//  - an external spirv.func needs Import linkage;
//  - ops directly inside a function's blocks must be from this dialect.
1600LogicalResult spirv::ModuleOp::verifyRegions() {
1601 Dialect *dialect = (*this)->getDialect();
1603 entryPoints;
// Symbol table over this module, used for function and variable lookups.
1604 mlir::SymbolTable table(*this);
1605
1606 for (auto &op : *getBody()) {
1607 if (op.getDialect() != dialect)
1608 return op.emitError("'spirv.module' can only contain spirv.* ops");
1609
1610 // For EntryPoint op, check that the function and execution model is not
1611 // duplicated in EntryPointOps. Also verify that the interface specified
1612 // comes from globalVariables here to make this check cheaper.
1613 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1614 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1615 if (!funcOp) {
1616 return entryPointOp.emitError("function '")
1617 << entryPointOp.getFn() << "' not found in 'spirv.module'";
1618 }
// NOTE(review): the first diagnostic below opens a quote ("instead of '")
// that is never closed, and the second lacks a space ("instead of'"). Consider
// fixing both together with any lit tests that match them.
1619 if (auto interface = entryPointOp.getInterface()) {
1620 for (Attribute varRef : interface) {
1621 auto varSymRef = dyn_cast<FlatSymbolRefAttr>(varRef);
1622 if (!varSymRef) {
1623 return entryPointOp.emitError(
1624 "expected symbol reference for interface "
1625 "specification instead of '")
1626 << varRef;
1627 }
1628 auto variableOp =
1629 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1630 if (!variableOp) {
1631 return entryPointOp.emitError("expected spirv.GlobalVariable "
1632 "symbol reference instead of'")
1633 << varSymRef << "'";
1634 }
1635 }
1636 }
1637
// Keying on (function, execution model) lets one function be an entry point
// for several models, but rejects a repeat of the same pair.
1638 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1639 funcOp, entryPointOp.getExecutionModel());
1640 if (!entryPoints.try_emplace(key, entryPointOp).second)
1641 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1642 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1643 // If the function is external and does not have 'Import'
1644 // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1645 // LinkageAttributes is used to import external functions.
1646 auto linkageAttr = funcOp.getLinkageAttributes();
1647 auto hasImportLinkage =
1648 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1649 spirv::LinkageType::Import);
1650 if (funcOp.isExternal() && !hasImportLinkage)
1651 return op.emitError(
1652 "'spirv.module' cannot contain external functions "
1653 "without 'Import' linkage_attributes (LinkageAttributes)");
1654
1655 // TODO: move this check to spirv.func.
// Only ops directly in the function's blocks are checked; ops inside nested
// regions are not visited here. NOTE(review): the inner `op` shadows the
// outer loop variable of the same name.
1656 for (auto &block : funcOp)
1657 for (auto &op : block) {
1658 if (op.getDialect() != dialect)
1659 return op.emitError(
1660 "functions in 'spirv.module' can only contain spirv.* ops");
1661 }
1662 }
1663 }
1664
1665 return success();
1666}
1667
1668//===----------------------------------------------------------------------===//
1669// spirv.mlir.referenceof
1670//===----------------------------------------------------------------------===//
1671
1672LogicalResult spirv::ReferenceOfOp::verify() {
1673 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1674 (*this)->getParentOp(), getSpecConstAttr());
1675 Type constType;
1676
1677 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1678 if (specConstOp)
1679 constType = specConstOp.getDefaultValue().getType();
1680
1681 auto specConstCompositeOp =
1682 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1683 if (specConstCompositeOp)
1684 constType = specConstCompositeOp.getType();
1685
1686 if (!specConstOp && !specConstCompositeOp)
1687 return emitOpError(
1688 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1689
1690 if (getReference().getType() != constType)
1691 return emitOpError("result type mismatch with the referenced "
1692 "specialization constant's type");
1693
1694 return success();
1695}
1696
1697//===----------------------------------------------------------------------===//
1698// spirv.SpecConstant
1699//===----------------------------------------------------------------------===//
1700
// Parses `@name [<kSpecIdAttrName>(<integer>)] = <default-value-attr>`. The
// optional spec id is stored under kSpecIdAttrName, and the default value
// under the op's default-value attribute name.
1701ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1703 StringAttr nameAttr;
1704 Attribute valueAttr;
1705 StringRef defaultValueAttrName =
1706 spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1707
1708 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1709 result.attributes))
1710 return failure();
1711
1712 // Parse optional spec_id.
1713 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1714 IntegerAttr specIdAttr;
1715 if (parser.parseLParen() ||
1716 parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1717 parser.parseRParen())
1718 return failure();
1719 }
1720
// `= <attr>`: the default value of the specialization constant.
1721 if (parser.parseEqual() ||
1722 parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1723 return failure();
1724
1725 return success();
1726}
1727
1728void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
1729 printer << ' ';
1730 printer.printSymbolName(getSymName());
1731 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1732 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1733 printer << " = " << getDefaultValue();
1734}
1735
1736LogicalResult spirv::SpecConstantOp::verify() {
1737 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1738 if (specID.getValue().isNegative())
1739 return emitOpError("SpecId cannot be negative");
1740
1741 auto value = getDefaultValue();
1742 if (isa<IntegerAttr, FloatAttr>(value)) {
1743 // Make sure bitwidth is allowed.
1744 if (!isa<spirv::SPIRVType>(value.getType()))
1745 return emitOpError("default value bitwidth disallowed");
1746 return success();
1747 }
1748 return emitOpError(
1749 "default value can only be a bool, integer, or float scalar");
1750}
1751
1752//===----------------------------------------------------------------------===//
1753// spirv.VectorShuffle
1754//===----------------------------------------------------------------------===//
1755
1756LogicalResult spirv::VectorShuffleOp::verify() {
1757 VectorType resultType = cast<VectorType>(getType());
1758
1759 size_t numResultElements = resultType.getNumElements();
1760 if (numResultElements != getComponents().size())
1761 return emitOpError("result type element count (")
1762 << numResultElements
1763 << ") mismatch with the number of component selectors ("
1764 << getComponents().size() << ")";
1765
1766 size_t totalSrcElements =
1767 cast<VectorType>(getVector1().getType()).getNumElements() +
1768 cast<VectorType>(getVector2().getType()).getNumElements();
1769
1770 for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1771 uint32_t index = selector.getZExtValue();
1772 if (index >= totalSrcElements &&
1773 index != std::numeric_limits<uint32_t>().max())
1774 return emitOpError("component selector ")
1775 << index << " out of range: expected to be in [0, "
1776 << totalSrcElements << ") or 0xffffffff";
1777 }
1778 return success();
1779}
1780
1781//===----------------------------------------------------------------------===//
1782// spirv.SpecConstantComposite
1783//===----------------------------------------------------------------------===//
1784
// Parses `@name (@c0, @c1, ...) : <type>`. At least one constituent symbol
// reference is required. The constituents are stored as an array attribute
// and the type as a TypeAttr.
1785ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
1787
1788 StringAttr compositeName;
1789 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1790 result.attributes))
1791 return failure();
1792
1793 if (parser.parseLParen())
1794 return failure();
1795
1796 SmallVector<Attribute, 4> constituents;
1797
// do/while: the first constituent is mandatory; later ones follow commas.
1798 do {
1799 // The name of the constituent attribute isn't important
1800 const char *attrName = "spec_const";
1801 FlatSymbolRefAttr specConstRef;
1802 NamedAttrList attrs;
1803
1804 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1805 return failure();
1806
1807 constituents.push_back(specConstRef);
1808 } while (!parser.parseOptionalComma());
1809
1810 if (parser.parseRParen())
1811 return failure();
1812
1813 StringAttr compositeSpecConstituentsName =
1814 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1815 result.addAttribute(compositeSpecConstituentsName,
1816 parser.getBuilder().getArrayAttr(constituents));
1817
1818 Type type;
1819 if (parser.parseColonType(type))
1820 return failure();
1821
1822 StringAttr typeAttrName =
1823 spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1824 result.addAttribute(typeAttrName, TypeAttr::get(type));
1825
1826 return success();
1827}
1828
1829void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
1830 printer << " ";
1831 printer.printSymbolName(getSymName());
1832 printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1833 << ") : " << getType();
1834}
1835
1836LogicalResult spirv::SpecConstantCompositeOp::verify() {
1837 auto cType = dyn_cast<spirv::CompositeType>(getType());
1838 auto constituents = this->getConstituents().getValue();
1839
1840 if (!cType)
1841 return emitError("result type must be a composite type, but provided ")
1842 << getType();
1843
1844 if (isa<spirv::CooperativeMatrixType>(cType))
1845 return emitError("unsupported composite type ") << cType;
1846 if (constituents.size() != cType.getNumElements())
1847 return emitError("has incorrect number of operands: expected ")
1848 << cType.getNumElements() << ", but provided "
1849 << constituents.size();
1850
1851 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1852 auto constituent = cast<FlatSymbolRefAttr>(constituents[index]);
1853
1855 (*this)->getParentOp(), constituent.getAttr());
1856
1857 if (!constituentOp)
1858 return emitError("unknown constituent symbol ") << constituent.getAttr();
1859
1860 Type constituentType;
1861 if (auto specConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp)) {
1862 constituentType = specConstOp.getDefaultValue().getType();
1863 } else if (auto specConstCompositeOp =
1864 dyn_cast<spirv::SpecConstantCompositeOp>(constituentOp)) {
1865 constituentType = specConstCompositeOp.getType();
1866 } else {
1867 return emitError("unsupported constituent ")
1868 << constituent.getAttr()
1869 << ": must reference a spirv.SpecConstant or "
1870 "spirv.SpecConstantComposite";
1871 }
1872
1873 if (constituentType != cType.getElementType(index))
1874 return emitError("has incorrect types of operands: expected ")
1875 << cType.getElementType(index) << ", but provided "
1876 << constituentType;
1877 }
1878
1879 return success();
1880}
1881
1882//===----------------------------------------------------------------------===//
1883// spirv.EXTSpecConstantCompositeReplicateOp
1884//===----------------------------------------------------------------------===//
1885
1886ParseResult
1887spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
1889 StringAttr compositeName;
1890 FlatSymbolRefAttr specConstRef;
1891 const char *attrName = "spec_const";
1892 NamedAttrList attrs;
1893 Type type;
1894
1895 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1896 result.attributes) ||
1897 parser.parseLParen() ||
1898 parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
1899 parser.parseRParen() || parser.parseColonType(type))
1900 return failure();
1901
1902 StringAttr compositeSpecConstituentName =
1903 spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1904 result.name);
1905 result.addAttribute(compositeSpecConstituentName, specConstRef);
1906
1907 StringAttr typeAttrName =
1908 spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
1909 result.addAttribute(typeAttrName, TypeAttr::get(type));
1910
1911 return success();
1912}
1913
1914void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
1915 printer << " ";
1916 printer.printSymbolName(getSymName());
1917 printer << " (" << this->getConstituent() << ") : " << getType();
1918}
1919
1920LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
1921 auto compositeType = dyn_cast<spirv::CompositeType>(getType());
1922 if (!compositeType)
1923 return emitError("result type must be a composite type, but provided ")
1924 << getType();
1925
1927 (*this)->getParentOp(), this->getConstituent());
1928 if (!constituentOp)
1929 return emitError(
1930 "splat spec constant reference defining constituent not found");
1931
1932 auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
1933 if (!constituentSpecConstOp)
1934 return emitError("constituent is not a spec constant");
1935
1936 Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
1937 Type compositeElementType = compositeType.getElementType(0);
1938 if (constituentType != compositeElementType)
1939 return emitError("constituent has incorrect type: expected ")
1940 << compositeElementType << ", but provided " << constituentType;
1941
1942 return success();
1943}
1944
1945//===----------------------------------------------------------------------===//
1946// spirv.SpecConstantOperation
1947//===----------------------------------------------------------------------===//
1948
1949ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
1951 Region *body = result.addRegion();
1952
1953 if (parser.parseKeyword("wraps"))
1954 return failure();
1955
1956 body->push_back(new Block);
1957 Block &block = body->back();
1958 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1959
1960 if (!wrappedOp)
1961 return failure();
1962
1963 OpBuilder builder(parser.getContext());
1964 builder.setInsertionPointToEnd(&block);
1965 spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0));
1966 result.location = wrappedOp->getLoc();
1967
1968 result.addTypes(wrappedOp->getResult(0).getType());
1969
1970 if (parser.parseOptionalAttrDict(result.attributes))
1971 return failure();
1972
1973 return success();
1974}
1975
1976void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
1977 printer << " wraps ";
1978 printer.printGenericOp(&getBody().front().front());
1979}
1980
1981LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1982 Block &block = getRegion().getBlocks().front();
1983
1984 if (block.getOperations().size() != 2)
1985 return emitOpError("expected exactly 2 nested ops");
1986
1987 Operation &enclosedOp = block.getOperations().front();
1988
1990 return emitOpError("invalid enclosed op");
1991
1992 for (auto operand : enclosedOp.getOperands())
1993 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1994 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1995 return emitOpError(
1996 "invalid operand, must be defined by a constant operation");
1997
1998 return success();
1999}
2000
2001//===----------------------------------------------------------------------===//
2002// spirv.GL.FrexpStruct
2003//===----------------------------------------------------------------------===//
2004
2005LogicalResult spirv::GLFrexpStructOp::verify() {
2006 spirv::StructType structTy =
2007 dyn_cast<spirv::StructType>(getResult().getType());
2008
2009 if (structTy.getNumElements() != 2)
2010 return emitError("result type must be a struct type with two memebers");
2011
2012 Type significandTy = structTy.getElementType(0);
2013 Type exponentTy = structTy.getElementType(1);
2014 VectorType exponentVecTy = dyn_cast<VectorType>(exponentTy);
2015 IntegerType exponentIntTy = dyn_cast<IntegerType>(exponentTy);
2016
2017 Type operandTy = getOperand().getType();
2018 VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
2019 FloatType operandFTy = dyn_cast<FloatType>(operandTy);
2020
2021 if (significandTy != operandTy)
2022 return emitError("member zero of the resulting struct type must be the "
2023 "same type as the operand");
2024
2025 if (exponentVecTy) {
2026 IntegerType componentIntTy =
2027 dyn_cast<IntegerType>(exponentVecTy.getElementType());
2028 if (!componentIntTy || componentIntTy.getWidth() != 32)
2029 return emitError("member one of the resulting struct type must"
2030 "be a scalar or vector of 32 bit integer type");
2031 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2032 return emitError("member one of the resulting struct type "
2033 "must be a scalar or vector of 32 bit integer type");
2034 }
2035
2036 // Check that the two member types have the same number of components
2037 if (operandVecTy && exponentVecTy &&
2038 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2039 return success();
2040
2041 if (operandFTy && exponentIntTy)
2042 return success();
2043
2044 return emitError("member one of the resulting struct type must have the same "
2045 "number of components as the operand type");
2046}
2047
2048//===----------------------------------------------------------------------===//
2049// spirv.GL.Ldexp
2050//===----------------------------------------------------------------------===//
2051
2052static LogicalResult verifyFloatIntegerBuiltin(Operation *op, Type floatType,
2053 Type integerType) {
2054 if (isa<FloatType>(floatType) != isa<IntegerType>(integerType))
2055 return op->emitOpError("operands must both be scalars or vectors");
2056
2057 auto getNumElements = [](Type type) -> unsigned {
2058 if (auto vectorType = dyn_cast<VectorType>(type))
2059 return vectorType.getNumElements();
2060 return 1;
2061 };
2062
2063 if (getNumElements(floatType) != getNumElements(integerType))
2064 return op->emitOpError("operands must have the same number of elements");
2065
2066 return success();
2067}
2068
2069LogicalResult spirv::GLLdexpOp::verify() {
2070 return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
2071 getExp().getType());
2072}
2073
2074//===----------------------------------------------------------------------===//
2075// spirv.CL.ldexp
2076//===----------------------------------------------------------------------===//
2077
2078LogicalResult spirv::CLLdexpOp::verify() {
2079 return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
2080 getExp().getType());
2081}
2082
2083//===----------------------------------------------------------------------===//
2084// spirv.CL.pown
2085//===----------------------------------------------------------------------===//
2086
2087LogicalResult spirv::CLPownOp::verify() {
2088 return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
2089 getY().getType());
2090}
2091
2092//===----------------------------------------------------------------------===//
2093// spirv.CL.rootn
2094//===----------------------------------------------------------------------===//
2095
2096LogicalResult spirv::CLRootnOp::verify() {
2097 return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
2098 getN().getType());
2099}
2100
2101//===----------------------------------------------------------------------===//
2102// spirv.ShiftLeftLogicalOp
2103//===----------------------------------------------------------------------===//
2104
2105LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2106 return verifyShiftOp(*this);
2107}
2108
2109//===----------------------------------------------------------------------===//
2110// spirv.ShiftRightArithmeticOp
2111//===----------------------------------------------------------------------===//
2112
2113LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2114 return verifyShiftOp(*this);
2115}
2116
2117//===----------------------------------------------------------------------===//
2118// spirv.ShiftRightLogicalOp
2119//===----------------------------------------------------------------------===//
2120
2121LogicalResult spirv::ShiftRightLogicalOp::verify() {
2122 return verifyShiftOp(*this);
2123}
2124
2125//===----------------------------------------------------------------------===//
2126// spirv.VectorTimesScalarOp
2127//===----------------------------------------------------------------------===//
2128
2129LogicalResult spirv::VectorTimesScalarOp::verify() {
2130 if (getVector().getType() != getType())
2131 return emitOpError("vector operand and result type mismatch");
2132 auto scalarType = cast<VectorType>(getType()).getElementType();
2133 if (getScalar().getType() != scalarType)
2134 return emitOpError("scalar operand and result element type match");
2135 return success();
2136}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
ArrayAttr()
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition SPIRVOps.cpp:266
static Type getValueType(Attribute attr)
Definition SPIRVOps.cpp:773
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition SPIRVOps.cpp:562
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition SPIRVOps.cpp:120
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition SPIRVOps.cpp:252
static LogicalResult verifyFloatIntegerBuiltin(Operation *op, Type floatType, Type integerType)
static LogicalResult verifyShiftOp(Operation *op)
Definition SPIRVOps.cpp:298
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition SPIRVOps.cpp:169
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition SPIRVOps.cpp:185
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition SPIRVOps.cpp:149
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition SPIRVOps.cpp:290
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
iterator begin()
Definition Block.h:153
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:280
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:238
Value getOperand(unsigned idx)
Definition Operation.h:376
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:576
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
Definition Operation.h:423
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
void push_back(Block *block)
Definition Region.h:61
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
Type front()
Return first type in the range.
Definition TypeRange.h:164
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
SPIR-V struct type.
Definition SPIRVTypes.h:263
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
uint64_t getN(LevelType lt)
Definition Enums.h:442
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition SPIRVOps.cpp:69
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition SPIRVOps.cpp:93
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition SPIRVOps.cpp:49
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.