MLIR 23.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
24#include "mlir/IR/Builders.h"
28#include "mlir/IR/Operation.h"
31#include "llvm/ADT/APFloat.h"
32#include "llvm/ADT/APInt.h"
33#include "llvm/ADT/ArrayRef.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/StringExtras.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/InterleavedRange.h"
38#include <cassert>
39#include <numeric>
40#include <optional>
41
42using namespace mlir;
43using namespace mlir::spirv::AttrNames;
44
45//===----------------------------------------------------------------------===//
46// Common utility functions
47//===----------------------------------------------------------------------===//
48
49LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
50 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
51 if (!constOp) {
52 return failure();
53 }
54 auto valueAttr = constOp.getValue();
55 auto integerValueAttr = dyn_cast<IntegerAttr>(valueAttr);
56 if (!integerValueAttr) {
57 return failure();
58 }
59
60 if (integerValueAttr.getType().isSignlessInteger())
61 value = integerValueAttr.getInt();
62 else
63 value = integerValueAttr.getSInt();
64
65 return success();
66}
67
68LogicalResult
70 spirv::MemorySemantics memorySemantics) {
71 // According to the SPIR-V specification:
72 // "Despite being a mask and allowing multiple bits to be combined, it is
73 // invalid for more than one of these four bits to be set: Acquire, Release,
74 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
75 // Release semantics is done by setting the AcquireRelease bit, not by setting
76 // two bits."
77 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
78 spirv::MemorySemantics::Release |
79 spirv::MemorySemantics::AcquireRelease |
80 spirv::MemorySemantics::SequentiallyConsistent;
81
82 auto bitCount =
83 llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
84 if (bitCount > 1) {
85 return op->emitError(
86 "expected at most one of these four memory constraints "
87 "to be set: `Acquire`, `Release`,"
88 "`AcquireRelease` or `SequentiallyConsistent`");
89 }
90 return success();
91}
92
94 Type pointeeType) {
95 // From SPV_KHR_physical_storage_buffer:
96 // > If an OpVariable's pointee type is a pointer (or array of pointers) in
97 // > PhysicalStorageBuffer storage class, then the variable must be decorated
98 // > with exactly one of AliasedPointer or RestrictPointer.
99 auto pointeePtrType = dyn_cast<spirv::PointerType>(pointeeType);
100 if (!pointeePtrType) {
101 if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(pointeeType)) {
102 pointeePtrType =
103 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
104 }
105 }
106
107 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
108 spirv::StorageClass::PhysicalStorageBuffer)
109 return success();
110
111 auto getDecorationAttr = [op](spirv::Decoration decoration) {
112 return op->getAttr(spirv::getDecorationString(decoration));
113 };
114
115 bool hasAliasedPtr =
116 getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr;
117 bool hasRestrictPtr =
118 getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr;
119
120 if (!hasAliasedPtr && !hasRestrictPtr)
121 return op->emitOpError()
122 << " with physical buffer pointer must be decorated "
123 "either 'AliasedPointer' or 'RestrictPointer'";
124
125 if (hasAliasedPtr && hasRestrictPtr)
126 return op->emitOpError()
127 << " with physical buffer pointer must have exactly one "
128 "aliasing decoration";
129
130 return success();
131}
132
134 SmallVectorImpl<StringRef> &elidedAttrs) {
135 // Print optional descriptor binding
136 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
137 stringifyDecoration(spirv::Decoration::DescriptorSet));
138 auto bindingName = llvm::convertToSnakeFromCamelCase(
139 stringifyDecoration(spirv::Decoration::Binding));
140 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
141 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
142 if (descriptorSet && binding) {
143 elidedAttrs.push_back(descriptorSetName);
144 elidedAttrs.push_back(bindingName);
145 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
146 << ")";
147 }
148
149 // Print BuiltIn attribute if present
150 auto builtInName = llvm::convertToSnakeFromCamelCase(
151 stringifyDecoration(spirv::Decoration::BuiltIn));
152 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
153 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
154 elidedAttrs.push_back(builtInName);
155 }
156
157 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
158}
159
163 Type type;
164 // If the operand list is in-between parentheses, then we have a generic form.
165 // (see the fallback in `printOneResultOp`).
166 SMLoc loc = parser.getCurrentLocation();
167 if (!parser.parseOptionalLParen()) {
168 if (parser.parseOperandList(ops) || parser.parseRParen() ||
169 parser.parseOptionalAttrDict(result.attributes) ||
170 parser.parseColon() || parser.parseType(type))
171 return failure();
172 auto fnType = dyn_cast<FunctionType>(type);
173 if (!fnType) {
174 parser.emitError(loc, "expected function type");
175 return failure();
176 }
177 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
178 return failure();
179 result.addTypes(fnType.getResults());
180 return success();
181 }
182 return failure(parser.parseOperandList(ops) ||
183 parser.parseOptionalAttrDict(result.attributes) ||
184 parser.parseColonType(type) ||
185 parser.resolveOperands(ops, type, result.operands) ||
186 parser.addTypeToList(type, result.types));
187}
188
190 assert(op->getNumResults() == 1 && "op should have one result");
191
192 // If not all the operand and result types are the same, just use the
193 // generic assembly form to avoid omitting information in printing.
194 auto resultType = op->getResult(0).getType();
195 if (llvm::any_of(op->getOperandTypes(),
196 [&](Type type) { return type != resultType; })) {
197 p.printGenericOp(op, /*printOpName=*/false);
198 return;
199 }
200
201 p << ' ';
202 p.printOperands(op->getOperands());
204 // Now we can output only one type for all operands and the result.
205 p << " : " << resultType;
206}
207
208template <typename BlockReadWriteOpTy>
209static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
210 Value ptr, Value val) {
211 auto valType = val.getType();
212 if (auto valVecTy = dyn_cast<VectorType>(valType))
213 valType = valVecTy.getElementType();
214
215 if (valType != cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
216 return op.emitOpError("mismatch in result type and pointer type");
217 }
218 return success();
219}
220
221/// Walks the given type hierarchy with the given indices, potentially down
222/// to component granularity, to select an element type. Returns null type and
223/// emits errors with the given loc on failure.
224static Type
226 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
227 if (indices.empty()) {
228 emitErrorFn("expected at least one index for spirv.CompositeExtract");
229 return nullptr;
230 }
231
232 for (auto index : indices) {
233 if (auto cType = dyn_cast<spirv::CompositeType>(type)) {
234 if (cType.hasCompileTimeKnownNumElements() &&
235 (index < 0 ||
236 static_cast<uint64_t>(index) >= cType.getNumElements())) {
237 emitErrorFn("index ") << index << " out of bounds for " << type;
238 return nullptr;
239 }
240 type = cType.getElementType(index);
241 } else {
242 emitErrorFn("cannot extract from non-composite type ")
243 << type << " with index " << index;
244 return nullptr;
245 }
246 }
247 return type;
248}
249
250static Type
252 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
253 auto indicesArrayAttr = dyn_cast<ArrayAttr>(indices);
254 if (!indicesArrayAttr) {
255 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
256 return nullptr;
257 }
258 if (indicesArrayAttr.empty()) {
259 emitErrorFn("expected at least one index for spirv.CompositeExtract");
260 return nullptr;
261 }
262
263 SmallVector<int32_t, 2> indexVals;
264 for (auto indexAttr : indicesArrayAttr) {
265 auto indexIntAttr = dyn_cast<IntegerAttr>(indexAttr);
266 if (!indexIntAttr) {
267 emitErrorFn("expected an 32-bit integer for index, but found '")
268 << indexAttr << "'";
269 return nullptr;
270 }
271 indexVals.push_back(indexIntAttr.getInt());
272 }
273 return getElementType(type, indexVals, emitErrorFn);
274}
275
277 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
278 return ::mlir::emitError(loc, err);
279 };
280 return getElementType(type, indices, errorFn);
281}
282
284 SMLoc loc) {
285 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
286 return parser.emitError(loc, err);
287 };
288 return getElementType(type, indices, errorFn);
289}
290
291template <typename ExtendedBinaryOp>
292static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
293 auto resultType = cast<spirv::StructType>(op.getType());
294 if (resultType.getNumElements() != 2)
295 return op.emitOpError("expected result struct type containing two members");
296
297 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
298 resultType.getElementType(0),
299 resultType.getElementType(1)}))
300 return op.emitOpError(
301 "expected all operand types and struct member types are the same");
302
303 return success();
304}
305
309 if (parser.parseOptionalAttrDict(result.attributes) ||
310 parser.parseOperandList(operands) || parser.parseColon())
311 return failure();
312
313 Type resultType;
314 SMLoc loc = parser.getCurrentLocation();
315 if (parser.parseType(resultType))
316 return failure();
317
318 auto structType = dyn_cast<spirv::StructType>(resultType);
319 if (!structType || structType.getNumElements() != 2)
320 return parser.emitError(loc, "expected spirv.struct type with two members");
321
322 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
323 if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
324 return failure();
325
326 result.addTypes(resultType);
327 return success();
328}
329
331 OpAsmPrinter &printer) {
332 printer << ' ';
333 printer.printOptionalAttrDict(op->getAttrs());
334 printer.printOperands(op->getOperands());
335 printer << " : " << op->getResultTypes().front();
336}
337
338static LogicalResult verifyShiftOp(Operation *op) {
339 if (op->getOperand(0).getType() != op->getResult(0).getType()) {
340 return op->emitError("expected the same type for the first operand and "
341 "result, but provided ")
342 << op->getOperand(0).getType() << " and "
343 << op->getResult(0).getType();
344 }
345 return success();
346}
347
348//===----------------------------------------------------------------------===//
349// spirv.mlir.addressof
350//===----------------------------------------------------------------------===//
351
352void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
353 spirv::GlobalVariableOp var) {
354 build(builder, state, var.getType(), SymbolRefAttr::get(var));
355}
356
357LogicalResult spirv::AddressOfOp::verify() {
358 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
359 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
360 getVariableAttr()));
361 if (!varOp) {
362 return emitOpError("expected spirv.GlobalVariable symbol");
363 }
364 if (getPointer().getType() != varOp.getType()) {
365 return emitOpError(
366 "result type mismatch with the referenced global variable's type");
367 }
368 return success();
369}
370
371//===----------------------------------------------------------------------===//
372// spirv.CompositeConstruct
373//===----------------------------------------------------------------------===//
374
375LogicalResult spirv::CompositeConstructOp::verify() {
376 operand_range constituents = this->getConstituents();
377
378 // There are 4 cases with varying verification rules:
379 // 1. Cooperative Matrices (1 constituent)
380 // 2. Structs (1 constituent for each member)
381 // 3. Arrays (1 constituent for each array element)
382 // 4. Vectors (1 constituent (sub-)element for each vector element)
383
384 auto coopElementType = llvm::TypeSwitch<Type, Type>(getType())
385 .Case([](spirv::CooperativeMatrixType coopType) {
386 return coopType.getElementType();
387 })
388 .Default(nullptr);
389
390 // Case 1. -- matrices.
391 if (coopElementType) {
392 if (constituents.size() != 1)
393 return emitOpError("has incorrect number of operands: expected ")
394 << "1, but provided " << constituents.size();
395 if (coopElementType != constituents.front().getType())
396 return emitOpError("operand type mismatch: expected operand type ")
397 << coopElementType << ", but provided "
398 << constituents.front().getType();
399 return success();
400 }
401
402 // Case 2./3./4. -- number of constituents matches the number of elements.
403 auto cType = cast<spirv::CompositeType>(getType());
404 if (constituents.size() == cType.getNumElements()) {
405 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
406 if (constituents[index].getType() != cType.getElementType(index)) {
407 return emitOpError("operand type mismatch: expected operand type ")
408 << cType.getElementType(index) << ", but provided "
409 << constituents[index].getType();
410 }
411 }
412 return success();
413 }
414
415 // Case 4. -- check that all constituents add up to the expected vector type.
416 auto resultType = dyn_cast<VectorType>(cType);
417 if (!resultType)
418 return emitOpError(
419 "expected to return a vector or cooperative matrix when the number of "
420 "constituents is less than what the result needs");
421
423 for (Value component : constituents) {
424 if (!isa<VectorType>(component.getType()) &&
425 !component.getType().isIntOrFloat())
426 return emitOpError("operand type mismatch: expected operand to have "
427 "a scalar or vector type, but provided ")
428 << component.getType();
429
430 Type elementType = component.getType();
431 if (auto vectorType = dyn_cast<VectorType>(component.getType())) {
432 sizes.push_back(vectorType.getNumElements());
433 elementType = vectorType.getElementType();
434 } else {
435 sizes.push_back(1);
436 }
437
438 if (elementType != resultType.getElementType())
439 return emitOpError("operand element type mismatch: expected to be ")
440 << resultType.getElementType() << ", but provided " << elementType;
441 }
442 unsigned totalCount = llvm::sum_of(sizes);
443 if (totalCount != cType.getNumElements())
444 return emitOpError("has incorrect number of operands: expected ")
445 << cType.getNumElements() << ", but provided " << totalCount;
446 return success();
447}
448
449//===----------------------------------------------------------------------===//
450// spirv.CompositeExtractOp
451//===----------------------------------------------------------------------===//
452
453void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
454 Value composite,
456 auto indexAttr = builder.getI32ArrayAttr(indices);
457 auto elementType =
458 getElementType(composite.getType(), indexAttr, state.location);
459 if (!elementType) {
460 return;
461 }
462 build(builder, state, elementType, composite, indexAttr);
463}
464
465ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
467 OpAsmParser::UnresolvedOperand compositeInfo;
468 Attribute indicesAttr;
469 StringRef indicesAttrName =
470 spirv::CompositeExtractOp::getIndicesAttrName(result.name);
471 Type compositeType;
472 SMLoc attrLocation;
473
474 if (parser.parseOperand(compositeInfo) ||
475 parser.getCurrentLocation(&attrLocation) ||
476 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
477 parser.parseColonType(compositeType) ||
478 parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
479 return failure();
480 }
481
482 Type resultType =
483 getElementType(compositeType, indicesAttr, parser, attrLocation);
484 if (!resultType) {
485 return failure();
486 }
487 result.addTypes(resultType);
488 return success();
489}
490
491void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
492 printer << ' ' << getComposite() << getIndices() << " : "
493 << getComposite().getType();
494}
495
496LogicalResult spirv::CompositeExtractOp::verify() {
497 auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
498 auto resultType =
499 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
500 if (!resultType)
501 return failure();
502
503 if (resultType != getType()) {
504 return emitOpError("invalid result type: expected ")
505 << resultType << " but provided " << getType();
506 }
507
508 return success();
509}
510
511//===----------------------------------------------------------------------===//
512// spirv.CompositeInsert
513//===----------------------------------------------------------------------===//
514
515void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
516 Value object, Value composite,
518 auto indexAttr = builder.getI32ArrayAttr(indices);
519 build(builder, state, composite.getType(), object, composite, indexAttr);
520}
521
522ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
525 Type objectType, compositeType;
526 Attribute indicesAttr;
527 StringRef indicesAttrName =
528 spirv::CompositeInsertOp::getIndicesAttrName(result.name);
529 auto loc = parser.getCurrentLocation();
530
531 return failure(
532 parser.parseOperandList(operands, 2) ||
533 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
534 parser.parseColonType(objectType) ||
535 parser.parseKeywordType("into", compositeType) ||
536 parser.resolveOperands(operands, {objectType, compositeType}, loc,
537 result.operands) ||
538 parser.addTypesToList(compositeType, result.types));
539}
540
541LogicalResult spirv::CompositeInsertOp::verify() {
542 auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
543 auto objectType =
544 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
545 if (!objectType)
546 return failure();
547
548 if (objectType != getObject().getType()) {
549 return emitOpError("object operand type should be ")
550 << objectType << ", but found " << getObject().getType();
551 }
552
553 if (getComposite().getType() != getType()) {
554 return emitOpError("result type should be the same as "
555 "the composite type, but found ")
556 << getComposite().getType() << " vs " << getType();
557 }
558
559 return success();
560}
561
562void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
563 printer << " " << getObject() << ", " << getComposite() << getIndices()
564 << " : " << getObject().getType() << " into "
565 << getComposite().getType();
566}
567
568//===----------------------------------------------------------------------===//
569// spirv.Constant
570//===----------------------------------------------------------------------===//
571
572ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
574 Attribute value;
575 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
576 if (parser.parseAttribute(value, valueAttrName, result.attributes))
577 return failure();
578
579 Type type = NoneType::get(parser.getContext());
580 if (auto typedAttr = dyn_cast<TypedAttr>(value))
581 type = typedAttr.getType();
582 if (isa<NoneType, TensorType>(type)) {
583 if (parser.parseColonType(type))
584 return failure();
585 }
586
587 if (isa<TensorArmType>(type)) {
588 if (parser.parseOptionalColon().succeeded())
589 if (parser.parseType(type))
590 return failure();
591 }
592
593 return parser.addTypeToList(type, result.types);
594}
595
596void spirv::ConstantOp::print(OpAsmPrinter &printer) {
597 printer << ' ' << getValue();
598 if (isa<spirv::ArrayType, spirv::StructType>(getType()))
599 printer << " : " << getType();
600}
601
602static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
603 Type opType) {
604 if (isa<spirv::CooperativeMatrixType>(opType)) {
605 auto denseAttr = dyn_cast<DenseElementsAttr>(value);
606 if (!denseAttr || !denseAttr.isSplat())
607 return op.emitOpError("expected a splat dense attribute for cooperative "
608 "matrix constant, but found ")
609 << denseAttr;
610 }
611 if (isa<IntegerAttr, FloatAttr>(value)) {
612 auto valueType = cast<TypedAttr>(value).getType();
613 if (valueType != opType)
614 return op.emitOpError("result type (")
615 << opType << ") does not match value type (" << valueType << ")";
616 return success();
617 }
618 if (isa<DenseTypedElementsAttr, SparseElementsAttr>(value)) {
619 auto valueType = cast<TypedAttr>(value).getType();
620 if (valueType == opType)
621 return success();
622 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
623 auto shapedType = dyn_cast<ShapedType>(valueType);
624 if (!arrayType)
625 return op.emitOpError("result or element type (")
626 << opType << ") does not match value type (" << valueType
627 << "), must be the same or spirv.array";
628
629 int numElements = arrayType.getNumElements();
630 auto opElemType = arrayType.getElementType();
631 while (auto t = dyn_cast<spirv::ArrayType>(opElemType)) {
632 numElements *= t.getNumElements();
633 opElemType = t.getElementType();
634 }
635 if (!opElemType.isIntOrFloat())
636 return op.emitOpError("only support nested array result type");
637
638 auto valueElemType = shapedType.getElementType();
639 if (valueElemType != opElemType) {
640 return op.emitOpError("result element type (")
641 << opElemType << ") does not match value element type ("
642 << valueElemType << ")";
643 }
644
645 if (numElements != shapedType.getNumElements()) {
646 return op.emitOpError("result number of elements (")
647 << numElements << ") does not match value number of elements ("
648 << shapedType.getNumElements() << ")";
649 }
650 return success();
651 }
652 if (auto arrayAttr = dyn_cast<ArrayAttr>(value)) {
653 if (auto structType = dyn_cast<spirv::StructType>(opType)) {
654 // Identified (possibly recursive) structs are not supported as constants.
655 if (structType.isIdentified())
656 return op.emitOpError(
657 "cannot have an identified struct as a constant type");
658 if (arrayAttr.size() != structType.getNumElements())
659 return op.emitOpError("number of constituents (")
660 << arrayAttr.size()
661 << ") does not match number of struct members ("
662 << structType.getNumElements() << ")";
663 for (auto [idx, element] : llvm::enumerate(arrayAttr.getValue())) {
664 if (failed(verifyConstantType(op, element,
665 structType.getElementType(idx))))
666 return failure();
667 }
668 return success();
669 }
670 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
671 if (!arrayType)
672 return op.emitOpError(
673 "must have spirv.array or spirv.struct result type for array value");
674 Type elemType = arrayType.getElementType();
675 for (Attribute element : arrayAttr.getValue()) {
676 // Verify array elements recursively.
677 if (failed(verifyConstantType(op, element, elemType)))
678 return failure();
679 }
680 return success();
681 }
682 return op.emitOpError("cannot have attribute: ") << value;
683}
684
685LogicalResult spirv::ConstantOp::verify() {
686 // ODS already generates checks to make sure the result type is valid. We just
687 // need to additionally check that the value's attribute type is consistent
688 // with the result type.
689 return verifyConstantType(*this, getValueAttr(), getType());
690}
691
692bool spirv::ConstantOp::isBuildableWith(Type type) {
693 // Must be valid SPIR-V type first.
694 if (!isa<spirv::SPIRVType>(type))
695 return false;
696
697 if (isa<SPIRVDialect>(type.getDialect())) {
698 if (auto structType = dyn_cast<spirv::StructType>(type))
699 return !structType.isIdentified();
700 return isa<spirv::ArrayType>(type);
701 }
702
703 return true;
704}
705
706spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
707 OpBuilder &builder) {
708 if (auto intType = dyn_cast<IntegerType>(type)) {
709 unsigned width = intType.getWidth();
710 if (width == 1)
711 return spirv::ConstantOp::create(builder, loc, type,
712 builder.getBoolAttr(false));
713 return spirv::ConstantOp::create(
714 builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
715 }
716 if (auto floatType = dyn_cast<FloatType>(type)) {
717 return spirv::ConstantOp::create(builder, loc, type,
718 builder.getFloatAttr(floatType, 0.0));
719 }
720 if (auto vectorType = dyn_cast<VectorType>(type)) {
721 Type elemType = vectorType.getElementType();
722 if (isa<IntegerType>(elemType)) {
723 return spirv::ConstantOp::create(
724 builder, loc, type,
725 DenseElementsAttr::get(vectorType,
726 IntegerAttr::get(elemType, 0).getValue()));
727 }
728 if (isa<FloatType>(elemType)) {
729 return spirv::ConstantOp::create(
730 builder, loc, type,
731 DenseFPElementsAttr::get(vectorType,
732 FloatAttr::get(elemType, 0.0).getValue()));
733 }
734 }
735
736 llvm_unreachable("unimplemented types for ConstantOp::getZero()");
737}
738
739spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
740 OpBuilder &builder) {
741 if (auto intType = dyn_cast<IntegerType>(type)) {
742 unsigned width = intType.getWidth();
743 if (width == 1)
744 return spirv::ConstantOp::create(builder, loc, type,
745 builder.getBoolAttr(true));
746 return spirv::ConstantOp::create(
747 builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
748 }
749 if (auto floatType = dyn_cast<FloatType>(type)) {
750 return spirv::ConstantOp::create(builder, loc, type,
751 builder.getFloatAttr(floatType, 1.0));
752 }
753 if (auto vectorType = dyn_cast<VectorType>(type)) {
754 Type elemType = vectorType.getElementType();
755 if (isa<IntegerType>(elemType)) {
756 return spirv::ConstantOp::create(
757 builder, loc, type,
758 DenseElementsAttr::get(vectorType,
759 IntegerAttr::get(elemType, 1).getValue()));
760 }
761 if (isa<FloatType>(elemType)) {
762 return spirv::ConstantOp::create(
763 builder, loc, type,
764 DenseFPElementsAttr::get(vectorType,
765 FloatAttr::get(elemType, 1.0).getValue()));
766 }
767 }
768
769 llvm_unreachable("unimplemented types for ConstantOp::getOne()");
770}
771
772void mlir::spirv::ConstantOp::getAsmResultNames(
773 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
774 Type type = getType();
775
776 SmallString<32> specialNameBuffer;
777 llvm::raw_svector_ostream specialName(specialNameBuffer);
778 specialName << "cst";
779
780 IntegerType intTy = dyn_cast<IntegerType>(type);
781
782 if (IntegerAttr intCst = dyn_cast<IntegerAttr>(getValue())) {
783 assert(intTy);
784
785 if (intTy.getWidth() == 1) {
786 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
787 }
788
789 if (intTy.isSignless()) {
790 specialName << intCst.getInt();
791 } else if (intTy.isUnsigned()) {
792 specialName << intCst.getUInt();
793 } else {
794 specialName << intCst.getSInt();
795 }
796 }
797
798 if (intTy || isa<FloatType>(type)) {
799 specialName << '_' << type;
800 }
801
802 if (auto vecType = dyn_cast<VectorType>(type)) {
803 specialName << "_vec_";
804 specialName << vecType.getDimSize(0);
805
806 Type elementType = vecType.getElementType();
807
808 if (isa<IntegerType>(elementType) || isa<FloatType>(elementType)) {
809 specialName << "x" << elementType;
810 }
811 }
812
813 setNameFn(getResult(), specialName.str());
814}
815
816void mlir::spirv::AddressOfOp::getAsmResultNames(
817 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
818 SmallString<32> specialNameBuffer;
819 llvm::raw_svector_ostream specialName(specialNameBuffer);
820 specialName << getVariable() << "_addr";
821 setNameFn(getResult(), specialName.str());
822}
823
824//===----------------------------------------------------------------------===//
825// spirv.EXTConstantCompositeReplicate
826//===----------------------------------------------------------------------===//
827
828// Returns type of attribute. In case of a TypedAttr this will simply return
829// the type. But for an ArrayAttr which is untyped and can be multidimensional
830// it creates the ArrayType recursively.
832 if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
833 return typedAttr.getType();
834 }
835
836 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
837 return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
838 }
839
840 return nullptr;
841}
842
843LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
844 Type valueType = getValueType(getValue());
845 if (!valueType)
846 return emitError("unknown value attribute type");
847
848 auto compositeType = dyn_cast<spirv::CompositeType>(getType());
849 if (!compositeType)
850 return emitError("result type is not a composite type");
851
852 Type compositeElementType = compositeType.getElementType(0);
853
854 SmallVector<Type, 3> possibleTypes = {compositeElementType};
855 while (auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
856 compositeElementType = type.getElementType(0);
857 possibleTypes.push_back(compositeElementType);
858 }
859
860 if (!is_contained(possibleTypes, valueType)) {
861 return emitError("expected value attribute type ")
862 << interleaved(possibleTypes, " or ") << ", but got: " << valueType;
863 }
864
865 return success();
866}
867
868//===----------------------------------------------------------------------===//
869// spirv.ControlBarrierOp
870//===----------------------------------------------------------------------===//
871
872LogicalResult spirv::ControlBarrierOp::verify() {
873 return verifyMemorySemantics(getOperation(), getMemorySemantics());
874}
875
876//===----------------------------------------------------------------------===//
877// spirv.EntryPoint
878//===----------------------------------------------------------------------===//
879
880void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
881 spirv::ExecutionModel executionModel,
882 spirv::FuncOp function,
883 ArrayRef<Attribute> interfaceVars) {
884 build(builder, state,
885 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
886 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
887}
888
889ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
891 spirv::ExecutionModel execModel;
892 SmallVector<Attribute, 4> interfaceVars;
893
895 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
896 parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
897 return failure();
898 }
899
900 if (!parser.parseOptionalComma()) {
901 // Parse the interface variables
902 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
903 // The name of the interface variable attribute isn't important
904 FlatSymbolRefAttr var;
905 NamedAttrList attrs;
906 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
907 return failure();
908 interfaceVars.push_back(var);
909 return success();
910 }))
911 return failure();
912 }
913 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
914 parser.getBuilder().getArrayAttr(interfaceVars));
915 return success();
916}
917
918void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
919 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
920 printer.printSymbolName(getFn());
921 auto interfaceVars = getInterface().getValue();
922 if (!interfaceVars.empty())
923 printer << ", " << llvm::interleaved(interfaceVars);
924}
925
926LogicalResult spirv::EntryPointOp::verify() {
927 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
928 // verification.
929 return success();
930}
931
932//===----------------------------------------------------------------------===//
933// spirv.ExecutionMode
934//===----------------------------------------------------------------------===//
935
936void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
937 spirv::FuncOp function,
938 spirv::ExecutionMode executionMode,
939 ArrayRef<int32_t> params) {
940 build(builder, state, SymbolRefAttr::get(function),
941 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
942 builder.getI32ArrayAttr(params));
943}
944
945ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
947 spirv::ExecutionMode execMode;
948 Attribute fn;
949 if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
951 return failure();
952 }
953
955 Type i32Type = parser.getBuilder().getIntegerType(32);
956 while (!parser.parseOptionalComma()) {
957 NamedAttrList attr;
958 Attribute value;
959 if (parser.parseAttribute(value, i32Type, "value", attr)) {
960 return failure();
961 }
962 values.push_back(cast<IntegerAttr>(value).getInt());
963 }
964 StringRef valuesAttrName =
965 spirv::ExecutionModeOp::getValuesAttrName(result.name);
966 result.addAttribute(valuesAttrName,
967 parser.getBuilder().getI32ArrayAttr(values));
968 return success();
969}
970
971void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
972 printer << " ";
973 printer.printSymbolName(getFn());
974 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
975 ArrayAttr values = this->getValues();
976 if (!values.empty())
977 printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
978}
979
980//===----------------------------------------------------------------------===//
981// spirv.ExecutionModeId
982//===----------------------------------------------------------------------===//
983
984ParseResult spirv::ExecutionModeIdOp::parse(OpAsmParser &parser,
986 ExecutionMode execMode;
987 if (Attribute fn;
988 parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
989 parseEnumStrAttr<ExecutionModeAttr>(execMode, parser, result)) {
990 return failure();
991 }
992
994 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
995 FlatSymbolRefAttr attr;
996 if (parser.parseAttribute(attr))
997 return failure();
998 values.push_back(attr);
999 return success();
1000 })) {
1001 return failure();
1002 }
1003
1004 StringRef valuesAttrName = getValuesAttrName(result.name);
1005 ArrayAttr valuesAttr = parser.getBuilder().getArrayAttr(values);
1006 result.addAttribute(valuesAttrName, valuesAttr);
1007 return success();
1008}
1009
1010void spirv::ExecutionModeIdOp::print(OpAsmPrinter &printer) {
1011 printer << " ";
1012 printer.printSymbolName(getFn());
1013 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\" ";
1014
1015 llvm::interleaveComma(
1016 getValues().getAsValueRange<FlatSymbolRefAttr>(), printer,
1017 [&](StringRef value) { printer.printSymbolName(value); });
1018}
1019
1020LogicalResult spirv::ExecutionModeIdOp::verify() {
1021 // Valid as of SPIRV 1.6
1022 switch (getExecutionMode()) {
1023 case ExecutionMode::SubgroupsPerWorkgroupId:
1024 case ExecutionMode::LocalSizeId:
1025 case ExecutionMode::LocalSizeHintId:
1026 break;
1027 default:
1028 return emitOpError("expected ExecutionMode that takes extra operands that "
1029 "are <id> operands, got: ")
1030 << stringifyExecutionMode(getExecutionMode());
1031 }
1032
1033 if (getValues().empty())
1034 return emitOpError("expected at least one value operand");
1035
1036 for (Attribute value : getValues()) {
1037 auto valueSymbol = dyn_cast<FlatSymbolRefAttr>(value);
1038 if (!valueSymbol)
1039 return emitOpError("expected value operands to be symbol reference");
1041 (*this)->getParentOp(), valueSymbol);
1042 if (!valueOp)
1043 return emitOpError("cannot find symbol referenced by value operand: ")
1044 << valueSymbol.getValue();
1045 }
1046
1047 return success();
1048}
1049
1050//===----------------------------------------------------------------------===//
1051// spirv.func
1052//===----------------------------------------------------------------------===//
1053
/// Parses a spirv.func: `@name(<args>) -> <results> "<FunctionControl>"`
/// followed by an optional `attributes {...}` dictionary and an optional body
/// region. The order here must match spirv::FuncOp::print.
ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  auto &builder = parser.getBuilder();

  // Parse the name as a symbol.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Parse the function signature. Variadic signatures are rejected.
  bool isVariadic = false;
  if (function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  // Rebuild the FunctionType from the parsed argument and result types.
  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto fnType = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(fnType));

  // Parse the function control string (e.g. "None"); it is stored as an
  // attribute by parseEnumStrAttr.
  spirv::FunctionControl fnControl;
  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
    return failure();

  // If additional attributes are present, parse them.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // Attach the per-argument and per-result attribute dictionaries.
  assert(resultAttrs.size() == resultTypes.size());
  call_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  // Parse the optional function body. A missing region means the function is
  // external; the entry block takes the parsed arguments.
  auto *body = result.addRegion();
  OptionalParseResult parseResult =
      parser.parseOptionalRegion(*body, entryArgs);
  return failure(parseResult.has_value() && failed(*parseResult));
}
1101
/// Prints a spirv.func in the form accepted by spirv::FuncOp::parse: symbol
/// name, signature, quoted function control, remaining attributes, and the
/// body region when the function is not external.
void spirv::FuncOp::print(OpAsmPrinter &printer) {
  // Print function name, signature, and control.
  printer << " ";
  printer.printSymbolName(getSymName());
  auto fnType = getFunctionType();
  function_interface_impl::printFunctionSignature(
      printer, *this, fnType.getInputs(),
      /*isVariadic=*/false, fnType.getResults());
  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
          << "\"";
  // Attributes already rendered above (control, type, arg/result attrs) are
  // elided from the trailing `attributes {...}` dictionary.
  function_interface_impl::printFunctionAttributes(
      printer, *this,
      {spirv::attributeName<spirv::FunctionControl>(),
       getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
       getFunctionControlAttrName()});

  // Print the body if this is not an external function.
  Region &body = this->getBody();
  if (!body.empty()) {
    printer << ' ';
    // Entry block arguments were already printed as part of the signature.
    printer.printRegion(body, /*printEntryBlockArgs=*/false,
                        /*printBlockTerminators=*/true);
  }
}
1126
1127LogicalResult spirv::FuncOp::verifyType() {
1128 FunctionType fnType = getFunctionType();
1129 if (fnType.getNumResults() > 1)
1130 return emitOpError("cannot have more than one result");
1131
1132 auto hasDecorationAttr = [&](spirv::Decoration decoration,
1133 unsigned argIndex) {
1134 auto func = cast<FunctionOpInterface>(getOperation());
1135 for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
1136 if (argAttr.getName() != spirv::DecorationAttr::name)
1137 continue;
1138 if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
1139 return decAttr.getValue() == decoration;
1140 }
1141 return false;
1142 };
1143
1144 for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1145 Type param = fnType.getInputs()[i];
1146 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1147 if (!inputPtrType)
1148 continue;
1149
1150 auto pointeePtrType =
1151 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1152 if (pointeePtrType) {
1153 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1154 // > If an OpFunctionParameter is a pointer (or contains a pointer)
1155 // > and the type it points to is a pointer in the PhysicalStorageBuffer
1156 // > storage class, the function parameter must be decorated with exactly
1157 // > one of AliasedPointer or RestrictPointer.
1158 if (pointeePtrType.getStorageClass() !=
1159 spirv::StorageClass::PhysicalStorageBuffer)
1160 continue;
1161
1162 bool hasAliasedPtr =
1163 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1164 bool hasRestrictPtr =
1165 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1166 if (!hasAliasedPtr && !hasRestrictPtr)
1167 return emitOpError()
1168 << "with a pointer points to a physical buffer pointer must "
1169 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1170 continue;
1171 }
1172 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1173 // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1174 // > the PhysicalStorageBuffer storage class, the function parameter must
1175 // > be decorated with exactly one of Aliased or Restrict.
1176 if (auto pointeeArrayType =
1177 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1178 pointeePtrType =
1179 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1180 } else {
1181 pointeePtrType = inputPtrType;
1182 }
1183
1184 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1185 spirv::StorageClass::PhysicalStorageBuffer)
1186 continue;
1187
1188 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1189 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1190 if (!hasAliased && !hasRestrict)
1191 return emitOpError() << "with physical buffer pointer must be decorated "
1192 "either 'Aliased' or 'Restrict'";
1193 }
1194
1195 return success();
1196}
1197
1198LogicalResult spirv::FuncOp::verifyBody() {
1199 FunctionType fnType = getFunctionType();
1200 if (!isExternal()) {
1201 Block &entryBlock = front();
1202
1203 unsigned numArguments = this->getNumArguments();
1204 if (entryBlock.getNumArguments() != numArguments)
1205 return emitOpError("entry block must have ")
1206 << numArguments << " arguments to match function signature";
1207
1208 for (auto [index, fnArgType, blockArgType] :
1209 llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1210 if (blockArgType != fnArgType) {
1211 return emitOpError("type of entry block argument #")
1212 << index << '(' << blockArgType
1213 << ") must match the type of the corresponding argument in "
1214 << "function signature(" << fnArgType << ')';
1215 }
1216 }
1217 }
1218
1219 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1220 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1221 if (fnType.getNumResults() != 0)
1222 return retOp.emitOpError("cannot be used in functions returning value");
1223 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1224 if (fnType.getNumResults() != 1)
1225 return retOp.emitOpError(
1226 "returns 1 value but enclosing function requires ")
1227 << fnType.getNumResults() << " results";
1228
1229 auto retOperandType = retOp.getValue().getType();
1230 auto fnResultType = fnType.getResult(0);
1231 if (retOperandType != fnResultType)
1232 return retOp.emitOpError(" return value's type (")
1233 << retOperandType << ") mismatch with function's result type ("
1234 << fnResultType << ")";
1235 }
1236 return WalkResult::advance();
1237 });
1238
1239 // TODO: verify other bits like linkage type.
1240
1241 return failure(walkResult.wasInterrupted());
1242}
1243
1244void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1245 StringRef name, FunctionType type,
1246 spirv::FunctionControl control,
1249 builder.getStringAttr(name));
1250 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1251 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1252 builder.getAttr<spirv::FunctionControlAttr>(control));
1253 state.attributes.append(attrs.begin(), attrs.end());
1254 state.addRegion();
1255}
1256
1257//===----------------------------------------------------------------------===//
1258// spirv.GLFClampOp
1259//===----------------------------------------------------------------------===//
1260
1261ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1264}
1265void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1266
1267//===----------------------------------------------------------------------===//
1268// spirv.GLUClampOp
1269//===----------------------------------------------------------------------===//
1270
1271ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1274}
1275void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1276
1277//===----------------------------------------------------------------------===//
1278// spirv.GLSClampOp
1279//===----------------------------------------------------------------------===//
1280
1281ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1284}
1285void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1286
1287//===----------------------------------------------------------------------===//
1288// spirv.GLFmaOp
1289//===----------------------------------------------------------------------===//
1290
1291ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1293}
1294void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1295
1296//===----------------------------------------------------------------------===//
1297// spirv.GlobalVariable
1298//===----------------------------------------------------------------------===//
1299
1300void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1301 Type type, StringRef name,
1302 unsigned descriptorSet, unsigned binding) {
1303 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1304 state.addAttribute(
1305 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1306 builder.getI32IntegerAttr(descriptorSet));
1307 state.addAttribute(
1308 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1309 builder.getI32IntegerAttr(binding));
1310}
1311
1312void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1313 Type type, StringRef name,
1314 spirv::BuiltIn builtin) {
1315 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1316 state.addAttribute(
1317 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1318 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1319}
1320
1321ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1323 // Parse variable name.
1324 StringAttr nameAttr;
1325 StringRef initializerAttrName =
1326 spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1327 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1328 result.attributes)) {
1329 return failure();
1330 }
1331
1332 // Parse optional initializer
1333 if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1334 FlatSymbolRefAttr initSymbol;
1335 if (parser.parseLParen() ||
1336 parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1337 result.attributes) ||
1338 parser.parseRParen())
1339 return failure();
1340 }
1341
1342 if (parseVariableDecorations(parser, result)) {
1343 return failure();
1344 }
1345
1346 Type type;
1347 StringRef typeAttrName =
1348 spirv::GlobalVariableOp::getTypeAttrName(result.name);
1349 auto loc = parser.getCurrentLocation();
1350 if (parser.parseColonType(type)) {
1351 return failure();
1352 }
1353 if (!isa<spirv::PointerType>(type)) {
1354 return parser.emitError(loc, "expected spirv.ptr type");
1355 }
1356 result.addAttribute(typeAttrName, TypeAttr::get(type));
1357
1358 return success();
1359}
1360
1361void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
1362 SmallVector<StringRef, 4> elidedAttrs{
1363 spirv::attributeName<spirv::StorageClass>()};
1364
1365 // Print variable name.
1366 printer << ' ';
1367 printer.printSymbolName(getSymName());
1368 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1369
1370 StringRef initializerAttrName = this->getInitializerAttrName();
1371 // Print optional initializer
1372 if (auto initializer = this->getInitializer()) {
1373 printer << " " << initializerAttrName << '(';
1374 printer.printSymbolName(*initializer);
1375 printer << ')';
1376 elidedAttrs.push_back(initializerAttrName);
1377 }
1378
1379 StringRef typeAttrName = this->getTypeAttrName();
1380 elidedAttrs.push_back(typeAttrName);
1381 spirv::printVariableDecorations(*this, printer, elidedAttrs);
1382 printer << " : " << getType();
1383}
1384
1385LogicalResult spirv::GlobalVariableOp::verify() {
1386 if (!isa<spirv::PointerType>(getType()))
1387 return emitOpError("result must be of a !spv.ptr type");
1388
1389 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1390 // object. It cannot be Generic. It must be the same as the Storage Class
1391 // operand of the Result Type."
1392 // Also, Function storage class is reserved by spirv.Variable.
1393 auto storageClass = this->storageClass();
1394 if (storageClass == spirv::StorageClass::Generic ||
1395 storageClass == spirv::StorageClass::Function) {
1396 return emitOpError("storage class cannot be '")
1397 << stringifyStorageClass(storageClass) << "'";
1398 }
1399
1400 // SPIR-V spec: "A module-scope OpVariable with an Initializer operand must
1401 // not be decorated with the Import Linkage Type."
1402 if (std::optional<spirv::LinkageAttributesAttr> linkage =
1403 getLinkageAttributes()) {
1404 if (linkage->getLinkageType().getValue() == spirv::LinkageType::Import &&
1405 getInitializer()) {
1406 return emitOpError(
1407 "with Import linkage type must not have an initializer");
1408 }
1409 }
1410
1411 if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1412 this->getInitializerAttrName())) {
1414 (*this)->getParentOp(), init.getAttr());
1415 // TODO: Currently only variable initialization with specialization
1416 // constants is supported. There could be normal constants in the module
1417 // scope as well.
1418 //
1419 // In the current setup we also cannot initialize one global variable with
1420 // another. The problem is that if we try to initialize pointer of type X
1421 // with another pointer type, the validator fails because it expects the
1422 // variable to be initialized to be type X, not pointer to X. Now
1423 // `spirv.GlobalVariable` only allows pointer type, so in the current design
1424 // we cannot initialize one `spirv.GlobalVariable` with another.
1425 if (!initOp ||
1426 !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
1427 return emitOpError("initializer must be result of a "
1428 "spirv.SpecConstant or "
1429 "spirv.SpecConstantCompositeOp op");
1430 }
1431 }
1432
1433 Type pointeeType = cast<spirv::PointerType>(getType()).getPointeeType();
1434 if (failed(
1435 verifyPhysicalStorageBufferDecorations(getOperation(), pointeeType)))
1436 return failure();
1437
1438 return success();
1439}
1440
1441//===----------------------------------------------------------------------===//
1442// spirv.INTEL.SubgroupBlockRead
1443//===----------------------------------------------------------------------===//
1444
1445LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
1446 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1447 return failure();
1448
1449 return success();
1450}
1451
1452//===----------------------------------------------------------------------===//
1453// spirv.INTEL.SubgroupBlockWrite
1454//===----------------------------------------------------------------------===//
1455
1456ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
1458 // Parse the storage class specification
1459 spirv::StorageClass storageClass;
1461 auto loc = parser.getCurrentLocation();
1462 Type elementType;
1463 if (parseEnumStrAttr(storageClass, parser) ||
1464 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1465 parser.parseType(elementType)) {
1466 return failure();
1467 }
1468
1469 auto ptrType = spirv::PointerType::get(elementType, storageClass);
1470 if (auto valVecTy = dyn_cast<VectorType>(elementType))
1471 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1472
1473 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1474 result.operands)) {
1475 return failure();
1476 }
1477 return success();
1478}
1479
1480void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
1481 printer << " " << getPtr() << ", " << getValue() << " : "
1482 << getValue().getType();
1483}
1484
1485LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
1486 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1487 return failure();
1488
1489 return success();
1490}
1491
1492//===----------------------------------------------------------------------===//
1493// spirv.IAddCarryOp
1494//===----------------------------------------------------------------------===//
1495
1496LogicalResult spirv::IAddCarryOp::verify() {
1497 return ::verifyArithmeticExtendedBinaryOp(*this);
1498}
1499
1500ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1502 return ::parseArithmeticExtendedBinaryOp(parser, result);
1503}
1504
1505void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1506 ::printArithmeticExtendedBinaryOp(*this, printer);
1507}
1508
1509//===----------------------------------------------------------------------===//
1510// spirv.ISubBorrowOp
1511//===----------------------------------------------------------------------===//
1512
1513LogicalResult spirv::ISubBorrowOp::verify() {
1514 return ::verifyArithmeticExtendedBinaryOp(*this);
1515}
1516
1517ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1519 return ::parseArithmeticExtendedBinaryOp(parser, result);
1520}
1521
1522void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
1523 ::printArithmeticExtendedBinaryOp(*this, printer);
1524}
1525
1526//===----------------------------------------------------------------------===//
1527// spirv.SMulExtended
1528//===----------------------------------------------------------------------===//
1529
1530LogicalResult spirv::SMulExtendedOp::verify() {
1531 return ::verifyArithmeticExtendedBinaryOp(*this);
1532}
1533
1534ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1536 return ::parseArithmeticExtendedBinaryOp(parser, result);
1537}
1538
1539void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
1540 ::printArithmeticExtendedBinaryOp(*this, printer);
1541}
1542
1543//===----------------------------------------------------------------------===//
1544// spirv.UMulExtended
1545//===----------------------------------------------------------------------===//
1546
1547LogicalResult spirv::UMulExtendedOp::verify() {
1548 return ::verifyArithmeticExtendedBinaryOp(*this);
1549}
1550
1551ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1553 return ::parseArithmeticExtendedBinaryOp(parser, result);
1554}
1555
1556void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
1557 ::printArithmeticExtendedBinaryOp(*this, printer);
1558}
1559
1560//===----------------------------------------------------------------------===//
1561// spirv.MemoryBarrierOp
1562//===----------------------------------------------------------------------===//
1563
1564LogicalResult spirv::MemoryBarrierOp::verify() {
1565 return verifyMemorySemantics(getOperation(), getMemorySemantics());
1566}
1567
1568//===----------------------------------------------------------------------===//
1569// spirv.MemoryNamedBarrierOp
1570//===----------------------------------------------------------------------===//
1571
1572LogicalResult spirv::MemoryNamedBarrierOp::verify() {
1573 return verifyMemorySemantics(getOperation(), getMemorySemantics());
1574}
1575
1576//===----------------------------------------------------------------------===//
1577// spirv.module
1578//===----------------------------------------------------------------------===//
1579
1580void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1581 std::optional<StringRef> name) {
1582 OpBuilder::InsertionGuard guard(builder);
1583 builder.createBlock(state.addRegion());
1584 if (name) {
1586 builder.getStringAttr(*name));
1587 }
1588}
1589
1590void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1591 spirv::AddressingModel addressingModel,
1592 spirv::MemoryModel memoryModel,
1593 std::optional<VerCapExtAttr> vceTriple,
1594 std::optional<StringRef> name) {
1595 state.addAttribute(
1596 "addressing_model",
1597 builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1598 state.addAttribute("memory_model",
1599 builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1600 OpBuilder::InsertionGuard guard(builder);
1601 builder.createBlock(state.addRegion());
1602 if (vceTriple)
1603 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1604 if (name)
1606 builder.getStringAttr(*name));
1607}
1608
/// Parses `spirv.module (@name)? <AddressingModel> <MemoryModel>
/// (requires #spirv.vce<...>)? (attributes {...})? <region>`.
ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  Region *body = result.addRegion();

  // If the name is present, parse it. Its absence is not an error, so the
  // result is intentionally ignored.
  StringAttr nameAttr;
  (void)parser.parseOptionalSymbolName(
      nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);

  // Parse the addressing and memory models (bare keywords, not strings).
  spirv::AddressingModel addrModel;
  spirv::MemoryModel memoryModel;
  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
                                                              result) ||
      spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
                                                          result))
    return failure();

  // Optional `requires <vce triple>` clause.
  if (succeeded(parser.parseOptionalKeyword("requires"))) {
    spirv::VerCapExtAttr vceTriple;
    if (parser.parseAttribute(vceTriple,
                              spirv::ModuleOp::getVCETripleAttrName(),
                              result.attributes))
      return failure();
  }

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
      parser.parseRegion(*body, /*arguments=*/{}))
    return failure();

  // Make sure we have at least one block.
  if (body->empty())
    body->push_back(new Block());

  return success();
}
1645
1646void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1647 if (std::optional<StringRef> name = getName()) {
1648 printer << ' ';
1649 printer.printSymbolName(*name);
1650 }
1651
1652 SmallVector<StringRef, 2> elidedAttrs;
1653
1654 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1655 << spirv::stringifyMemoryModel(getMemoryModel());
1656 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1657 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1658 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1660
1661 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1662 printer << " requires " << *triple;
1663 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1664 }
1665
1666 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1667 printer << ' ';
1668 printer.printRegion(getRegion());
1669}
1670
1671LogicalResult spirv::ModuleOp::verifyRegions() {
1672 Dialect *dialect = (*this)->getDialect();
1674 entryPoints;
1675 mlir::SymbolTable table(*this);
1676
1677 for (auto &op : *getBody()) {
1678 if (op.getDialect() != dialect)
1679 return op.emitError("'spirv.module' can only contain spirv.* ops");
1680
1681 // For EntryPoint op, check that the function and execution model is not
1682 // duplicated in EntryPointOps. Also verify that the interface specified
1683 // comes from globalVariables here to make this check cheaper.
1684 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1685 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1686 if (!funcOp) {
1687 return entryPointOp.emitError("function '")
1688 << entryPointOp.getFn() << "' not found in 'spirv.module'";
1689 }
1690 if (auto interface = entryPointOp.getInterface()) {
1691 for (Attribute varRef : interface) {
1692 auto varSymRef = dyn_cast<FlatSymbolRefAttr>(varRef);
1693 if (!varSymRef) {
1694 return entryPointOp.emitError(
1695 "expected symbol reference for interface "
1696 "specification instead of '")
1697 << varRef;
1698 }
1699 auto variableOp =
1700 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1701 if (!variableOp) {
1702 return entryPointOp.emitError("expected spirv.GlobalVariable "
1703 "symbol reference instead of'")
1704 << varSymRef << "'";
1705 }
1706 }
1707 }
1708
1709 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1710 funcOp, entryPointOp.getExecutionModel());
1711 if (!entryPoints.try_emplace(key, entryPointOp).second)
1712 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1713 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1714 // If the function is external and does not have 'Import'
1715 // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1716 // LinkageAttributes is used to import external functions.
1717 auto linkageAttr = funcOp.getLinkageAttributes();
1718 auto hasImportLinkage =
1719 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1720 spirv::LinkageType::Import);
1721 if (funcOp.isExternal() && !hasImportLinkage)
1722 return op.emitError(
1723 "'spirv.module' cannot contain external functions "
1724 "without 'Import' linkage_attributes (LinkageAttributes)");
1725
1726 // TODO: move this check to spirv.func.
1727 for (auto &block : funcOp)
1728 for (auto &op : block) {
1729 if (op.getDialect() != dialect)
1730 return op.emitError(
1731 "functions in 'spirv.module' can only contain spirv.* ops");
1732 }
1733 }
1734 }
1735
1736 return success();
1737}
1738
1739//===----------------------------------------------------------------------===//
1740// spirv.mlir.referenceof
1741//===----------------------------------------------------------------------===//
1742
1743LogicalResult spirv::ReferenceOfOp::verify() {
1744 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1745 (*this)->getParentOp(), getSpecConstAttr());
1746 Type constType;
1747
1748 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1749 if (specConstOp)
1750 constType = specConstOp.getDefaultValue().getType();
1751
1752 auto specConstCompositeOp =
1753 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1754 if (specConstCompositeOp)
1755 constType = specConstCompositeOp.getType();
1756
1757 if (!specConstOp && !specConstCompositeOp)
1758 return emitOpError(
1759 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1760
1761 if (getReference().getType() != constType)
1762 return emitOpError("result type mismatch with the referenced "
1763 "specialization constant's type");
1764
1765 return success();
1766}
1767
1768//===----------------------------------------------------------------------===//
1769// spirv.SpecConstant
1770//===----------------------------------------------------------------------===//
1771
1772ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1774 StringAttr nameAttr;
1775 Attribute valueAttr;
1776 StringRef defaultValueAttrName =
1777 spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1778
1779 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1780 result.attributes))
1781 return failure();
1782
1783 // Parse optional spec_id.
1784 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1785 IntegerAttr specIdAttr;
1786 if (parser.parseLParen() ||
1787 parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1788 parser.parseRParen())
1789 return failure();
1790 }
1791
1792 if (parser.parseEqual() ||
1793 parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1794 return failure();
1795
1796 return success();
1797}
1798
1799void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
1800 printer << ' ';
1801 printer.printSymbolName(getSymName());
1802 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1803 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1804 printer << " = " << getDefaultValue();
1805}
1806
1807LogicalResult spirv::SpecConstantOp::verify() {
1808 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1809 if (specID.getValue().isNegative())
1810 return emitOpError("SpecId cannot be negative");
1811
1812 auto value = getDefaultValue();
1813 if (isa<IntegerAttr, FloatAttr>(value)) {
1814 // Make sure bitwidth is allowed.
1815 if (!isa<spirv::SPIRVType>(value.getType()))
1816 return emitOpError("default value bitwidth disallowed");
1817 return success();
1818 }
1819 return emitOpError(
1820 "default value can only be a bool, integer, or float scalar");
1821}
1822
1823//===----------------------------------------------------------------------===//
1824// spirv.VectorShuffle
1825//===----------------------------------------------------------------------===//
1826
1827LogicalResult spirv::VectorShuffleOp::verify() {
1828 VectorType resultType = cast<VectorType>(getType());
1829
1830 size_t numResultElements = resultType.getNumElements();
1831 if (numResultElements != getComponents().size())
1832 return emitOpError("result type element count (")
1833 << numResultElements
1834 << ") mismatch with the number of component selectors ("
1835 << getComponents().size() << ")";
1836
1837 size_t totalSrcElements =
1838 cast<VectorType>(getVector1().getType()).getNumElements() +
1839 cast<VectorType>(getVector2().getType()).getNumElements();
1840
1841 for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1842 uint32_t index = selector.getZExtValue();
1843 if (index >= totalSrcElements &&
1844 index != std::numeric_limits<uint32_t>().max())
1845 return emitOpError("component selector ")
1846 << index << " out of range: expected to be in [0, "
1847 << totalSrcElements << ") or 0xffffffff";
1848 }
1849 return success();
1850}
1851
1852//===----------------------------------------------------------------------===//
1853// spirv.SpecConstantComposite
1854//===----------------------------------------------------------------------===//
1855
/// Custom parser for spirv.SpecConstantComposite. Accepted textual form:
///   @sym_name (@constituent0, @constituent1, ...) : <type>
/// At least one constituent symbol reference is required. The references are
/// stored as an ArrayAttr and the type as a TypeAttr.
/// NOTE(review): the listing skips source line 1857 (the remainder of the
/// signature, presumably `OperationState &result) {`) — confirm upstream.
1856ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
1858
1859 StringAttr compositeName;
1860 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1861 result.attributes))
1862 return failure();
1863
1864 if (parser.parseLParen())
1865 return failure();
1866
1867 SmallVector<Attribute, 4> constituents;
1868
 // Parse one flat symbol reference per iteration; keep going while a comma
 // follows (parseOptionalComma() converts to false on success).
1869 do {
1870 // The name of the constituent attribute isn't important: `attrs` is a
 // scratch list that is discarded, and the parsed reference is collected
 // in `constituents` instead.
1871 const char *attrName = "spec_const";
1872 FlatSymbolRefAttr specConstRef;
1873 NamedAttrList attrs;
1874
1875 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1876 return failure();
1877
1878 constituents.push_back(specConstRef);
1879 } while (!parser.parseOptionalComma());
1880
1881 if (parser.parseRParen())
1882 return failure();
1883
 // Store all constituent references as a single ArrayAttr on the op.
1884 StringAttr compositeSpecConstituentsName =
1885 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1886 result.addAttribute(compositeSpecConstituentsName,
1887 parser.getBuilder().getArrayAttr(constituents));
1888
 // `: <type>` is the composite's type; it is stored as a TypeAttr and the
 // op produces no SSA result here.
1889 Type type;
1890 if (parser.parseColonType(type))
1891 return failure();
1892
1893 StringAttr typeAttrName =
1894 spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1895 result.addAttribute(typeAttrName, TypeAttr::get(type));
1896
1897 return success();
1898}
1899
1900void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
1901 printer << " ";
1902 printer.printSymbolName(getSymName());
1903 printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1904 << ") : " << getType();
1905}
1906
/// Verifies spirv.SpecConstantComposite:
///  - the declared type must be a spirv::CompositeType other than a
///    cooperative matrix;
///  - the number of constituents must equal the composite's element count;
///  - each constituent must name a spirv.SpecConstant or
///    spirv.SpecConstantComposite whose type equals the matching element type.
1907LogicalResult spirv::SpecConstantCompositeOp::verify() {
1908 auto cType = dyn_cast<spirv::CompositeType>(getType());
1909 auto constituents = this->getConstituents().getValue();
1910
1911 if (!cType)
1912 return emitError("result type must be a composite type, but provided ")
1913 << getType();
1914
1915 if (isa<spirv::CooperativeMatrixType>(cType))
1916 return emitError("unsupported composite type ") << cType;
1917 if (constituents.size() != cType.getNumElements())
1918 return emitError("has incorrect number of operands: expected ")
1919 << cType.getNumElements() << ", but provided "
1920 << constituents.size();
1921
1922 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1923 auto constituent = cast<FlatSymbolRefAttr>(constituents[index]);
1924
 // NOTE(review): source line 1925 is missing from this listing; it
 // presumably begins the symbol lookup that defines `constituentOp`
 // (e.g. SymbolTable::lookupNearestSymbolFrom) — confirm upstream.
 // The lookup starts from this op's parent operation.
1926 (*this)->getParentOp(), constituent.getAttr());
1927
1928 if (!constituentOp)
1929 return emitError("unknown constituent symbol ") << constituent.getAttr();
1930
 // The constituent's type is the default value's type for a scalar spec
 // constant, or the declared type for a nested composite.
1931 Type constituentType;
1932 if (auto specConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp)) {
1933 constituentType = specConstOp.getDefaultValue().getType();
1934 } else if (auto specConstCompositeOp =
1935 dyn_cast<spirv::SpecConstantCompositeOp>(constituentOp)) {
1936 constituentType = specConstCompositeOp.getType();
1937 } else {
1938 return emitError("unsupported constituent ")
1939 << constituent.getAttr()
1940 << ": must reference a spirv.SpecConstant or "
1941 "spirv.SpecConstantComposite";
1942 }
1943
1944 if (constituentType != cType.getElementType(index))
1945 return emitError("has incorrect types of operands: expected ")
1946 << cType.getElementType(index) << ", but provided "
1947 << constituentType;
1948 }
1949
1950 return success();
1951}
1952
1953//===----------------------------------------------------------------------===//
1954// spirv.EXTSpecConstantCompositeReplicateOp
1955//===----------------------------------------------------------------------===//
1956
/// Custom parser for spirv.EXT.SpecConstantCompositeReplicate. Accepted
/// textual form:
///   @sym_name (@constituent) : <type>
/// The single constituent reference and the type are stored as attributes.
/// NOTE(review): the listing skips source line 1959 (the remainder of the
/// signature, presumably `OperationState &result) {`) — confirm upstream.
1957ParseResult
1958spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
1960 StringAttr compositeName;
1961 FlatSymbolRefAttr specConstRef;
 // `attrName`/`attrs` are scratch storage for parseAttribute; the parsed
 // reference is attached below under the op's constituent attribute name.
1962 const char *attrName = "spec_const";
1963 NamedAttrList attrs;
1964 Type type;
1965
1966 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1967 result.attributes) ||
1968 parser.parseLParen() ||
1969 parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
1970 parser.parseRParen() || parser.parseColonType(type))
1971 return failure();
1972
1973 StringAttr compositeSpecConstituentName =
1974 spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1975 result.name);
1976 result.addAttribute(compositeSpecConstituentName, specConstRef);
1977
1978 StringAttr typeAttrName =
1979 spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
1980 result.addAttribute(typeAttrName, TypeAttr::get(type));
1981
1982 return success();
1983}
1984
1985void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
1986 printer << " ";
1987 printer.printSymbolName(getSymName());
1988 printer << " (" << this->getConstituent() << ") : " << getType();
1989}
1990
/// Verifies spirv.EXT.SpecConstantCompositeReplicate:
///  - the declared type must be a spirv::CompositeType;
///  - the constituent symbol must resolve to a spirv.SpecConstant;
///  - that spec constant's default-value type must equal the composite's
///    element type (element 0 is used as the representative element type).
1991LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
1992 auto compositeType = dyn_cast<spirv::CompositeType>(getType());
1993 if (!compositeType)
1994 return emitError("result type must be a composite type, but provided ")
1995 << getType();
1996
 // NOTE(review): source line 1997 is missing from this listing; it
 // presumably begins the symbol lookup defining `constituentOp` (e.g.
 // SymbolTable::lookupNearestSymbolFrom), starting from the parent op.
1998 (*this)->getParentOp(), this->getConstituent());
1999 if (!constituentOp)
2000 return emitError(
2001 "splat spec constant reference defining constituent not found");
2002
2003 auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
2004 if (!constituentSpecConstOp)
2005 return emitError("constituent is not a spec constant");
2006
2007 Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
2008 Type compositeElementType = compositeType.getElementType(0);
2009 if (constituentType != compositeElementType)
2010 return emitError("constituent has incorrect type: expected ")
2011 << compositeElementType << ", but provided " << constituentType;
2012
2013 return success();
2014}
2015
2016//===----------------------------------------------------------------------===//
2017// spirv.SpecConstantOperation
2018//===----------------------------------------------------------------------===//
2019
/// Custom parser for spirv.SpecConstantOperation. Accepted textual form:
///   wraps <generic-form-op> [attr-dict]
/// The wrapped op is parsed into a fresh single-block region and followed by
/// a spirv.Yield of its first result. That result's type becomes this op's
/// result type.
/// NOTE(review): the listing skips source line 2021 (the remainder of the
/// signature, presumably `OperationState &result) {`) — confirm upstream.
2020ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
2022 Region *body = result.addRegion();
2023
2024 if (parser.parseKeyword("wraps"))
2025 return failure();
2026
2027 body->push_back(new Block);
2028 Block &block = body->back();
2029 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
2030
2031 if (!wrappedOp)
2032 return failure();
2033
 // Terminate the body with a yield of the wrapped op's value.
 // NOTE(review): getResult(0) assumes the wrapped op has at least one
 // result; a zero-result op would not be diagnosed here.
2034 OpBuilder builder(parser.getContext());
2035 builder.setInsertionPointToEnd(&block);
2036 spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0));
2037 result.location = wrappedOp->getLoc();
2038
2039 result.addTypes(wrappedOp->getResult(0).getType());
2040
 // NOTE(review): an attribute dictionary is accepted here, but the
 // SpecConstantOperationOp printer in this file does not emit one.
2041 if (parser.parseOptionalAttrDict(result.attributes))
2042 return failure();
2043
2044 return success();
2045}
2046
2047void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
2048 printer << " wraps ";
2049 printer.printGenericOp(&getBody().front().front());
2050}
2051
/// Verifies the body of spirv.SpecConstantOperation:
///  - the first block holds exactly two ops (the wrapped op plus one more,
///    which the custom parser creates as a spirv.Yield);
///  - the wrapped op must be allowed inside this op;
///  - every operand of the wrapped op must be produced by spirv.Constant,
///    spirv.mlir.referenceof, or another spirv.SpecConstantOperation.
2052LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
2053 Block &block = getRegion().getBlocks().front();
2054
2055 if (block.getOperations().size() != 2)
2056 return emitOpError("expected exactly 2 nested ops");
2057
2058 Operation &enclosedOp = block.getOperations().front();
2059
 // NOTE(review): source line 2060 is missing from this listing; it
 // presumably tests whether `enclosedOp` carries the trait marking ops
 // usable inside a SpecConstantOperation — confirm upstream.
2061 return emitOpError("invalid enclosed op");
2062
 // NOTE(review): getDefiningOp() is null for block arguments, and isa<> on
 // a null pointer asserts; isa_and_nonnull<> would reject such operands
 // with the diagnostic below instead.
2063 for (auto operand : enclosedOp.getOperands())
2064 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
2065 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
2066 return emitOpError(
2067 "invalid operand, must be defined by a constant operation");
2068
2069 return success();
2070}
2071
2072//===----------------------------------------------------------------------===//
2073// spirv.GL.FrexpStruct
2074//===----------------------------------------------------------------------===//
2075
2076LogicalResult spirv::GLFrexpStructOp::verify() {
2077 spirv::StructType structTy =
2078 dyn_cast<spirv::StructType>(getResult().getType());
2079
2080 if (structTy.getNumElements() != 2)
2081 return emitError("result type must be a struct type with two memebers");
2082
2083 Type significandTy = structTy.getElementType(0);
2084 Type exponentTy = structTy.getElementType(1);
2085 VectorType exponentVecTy = dyn_cast<VectorType>(exponentTy);
2086 IntegerType exponentIntTy = dyn_cast<IntegerType>(exponentTy);
2087
2088 Type operandTy = getOperand().getType();
2089 VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
2090 FloatType operandFTy = dyn_cast<FloatType>(operandTy);
2091
2092 if (significandTy != operandTy)
2093 return emitError("member zero of the resulting struct type must be the "
2094 "same type as the operand");
2095
2096 if (exponentVecTy) {
2097 IntegerType componentIntTy =
2098 dyn_cast<IntegerType>(exponentVecTy.getElementType());
2099 if (!componentIntTy || componentIntTy.getWidth() != 32)
2100 return emitError("member one of the resulting struct type must"
2101 "be a scalar or vector of 32 bit integer type");
2102 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2103 return emitError("member one of the resulting struct type "
2104 "must be a scalar or vector of 32 bit integer type");
2105 }
2106
2107 // Check that the two member types have the same number of components
2108 if (operandVecTy && exponentVecTy &&
2109 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2110 return success();
2111
2112 if (operandFTy && exponentIntTy)
2113 return success();
2114
2115 return emitError("member one of the resulting struct type must have the same "
2116 "number of components as the operand type");
2117}
2118
2119//===----------------------------------------------------------------------===//
2120// spirv.GL.Ldexp
2121//===----------------------------------------------------------------------===//
2122
2123static LogicalResult verifyFloatIntegerBuiltin(Operation *op, Type floatType,
2124 Type integerType) {
2125 if (isa<FloatType>(floatType) != isa<IntegerType>(integerType))
2126 return op->emitOpError("operands must both be scalars or vectors");
2127
2128 auto getNumElements = [](Type type) -> unsigned {
2129 if (auto vectorType = dyn_cast<VectorType>(type))
2130 return vectorType.getNumElements();
2131 return 1;
2132 };
2133
2134 if (getNumElements(floatType) != getNumElements(integerType))
2135 return op->emitOpError("operands must have the same number of elements");
2136
2137 return success();
2138}
2139
2140LogicalResult spirv::GLLdexpOp::verify() {
2141 return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
2142 getExp().getType());
2143}
2144
2145//===----------------------------------------------------------------------===//
2146// spirv.CL.ldexp
2147//===----------------------------------------------------------------------===//
2148
2149LogicalResult spirv::CLLdexpOp::verify() {
2150 return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
2151 getExp().getType());
2152}
2153
2154//===----------------------------------------------------------------------===//
2155// spirv.CL.pown
2156//===----------------------------------------------------------------------===//
2157
2158LogicalResult spirv::CLPownOp::verify() {
2159 return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
2160 getY().getType());
2161}
2162
2163//===----------------------------------------------------------------------===//
2164// spirv.CL.rootn
2165//===----------------------------------------------------------------------===//
2166
2167LogicalResult spirv::CLRootnOp::verify() {
2168 return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
2169 getN().getType());
2170}
2171
2172//===----------------------------------------------------------------------===//
2173// spirv.ShiftLeftLogicalOp
2174//===----------------------------------------------------------------------===//
2175
2176LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2177 return verifyShiftOp(*this);
2178}
2179
2180//===----------------------------------------------------------------------===//
2181// spirv.ShiftRightArithmeticOp
2182//===----------------------------------------------------------------------===//
2183
2184LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2185 return verifyShiftOp(*this);
2186}
2187
2188//===----------------------------------------------------------------------===//
2189// spirv.ShiftRightLogicalOp
2190//===----------------------------------------------------------------------===//
2191
2192LogicalResult spirv::ShiftRightLogicalOp::verify() {
2193 return verifyShiftOp(*this);
2194}
2195
2196//===----------------------------------------------------------------------===//
2197// spirv.VectorTimesScalarOp
2198//===----------------------------------------------------------------------===//
2199
2200LogicalResult spirv::VectorTimesScalarOp::verify() {
2201 if (getVector().getType() != getType())
2202 return emitOpError("vector operand and result type mismatch");
2203 auto scalarType = cast<VectorType>(getType()).getElementType();
2204 if (getScalar().getType() != scalarType)
2205 return emitOpError("scalar operand and result element type match");
2206 return success();
2207}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
ArrayAttr()
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition SPIRVOps.cpp:306
static Type getValueType(Attribute attr)
Definition SPIRVOps.cpp:831
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition SPIRVOps.cpp:602
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition SPIRVOps.cpp:160
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition SPIRVOps.cpp:292
static LogicalResult verifyFloatIntegerBuiltin(Operation *op, Type floatType, Type integerType)
static LogicalResult verifyShiftOp(Operation *op)
Definition SPIRVOps.cpp:338
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition SPIRVOps.cpp:209
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition SPIRVOps.cpp:225
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition SPIRVOps.cpp:189
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition SPIRVOps.cpp:330
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
iterator begin()
Definition Block.h:153
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:281
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:259
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:435
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:237
Value getOperand(unsigned idx)
Definition Operation.h:375
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:774
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:575
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:559
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
Definition Operation.h:422
result_type_range getResultTypes()
Definition Operation.h:453
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
void push_back(Block *block)
Definition Region.h:61
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
Type front()
Return first type in the range.
Definition TypeRange.h:164
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
SPIR-V struct type.
Definition SPIRVTypes.h:274
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
uint64_t getN(LevelType lt)
Definition Enums.h:442
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition SPIRVOps.cpp:69
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition SPIRVOps.cpp:133
LogicalResult verifyPhysicalStorageBufferDecorations(Operation *op, Type pointeeType)
Verifies the SPV_KHR_physical_storage_buffer rule that a variable whose pointee is a pointer (or arra...
Definition SPIRVOps.cpp:93
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition SPIRVOps.cpp:49
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.