MLIR 23.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
24#include "mlir/IR/Builders.h"
28#include "mlir/IR/Operation.h"
31#include "llvm/ADT/APFloat.h"
32#include "llvm/ADT/APInt.h"
33#include "llvm/ADT/ArrayRef.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/StringExtras.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/InterleavedRange.h"
38#include <cassert>
39#include <numeric>
40#include <optional>
41
42using namespace mlir;
43using namespace mlir::spirv::AttrNames;
44
45//===----------------------------------------------------------------------===//
46// Common utility functions
47//===----------------------------------------------------------------------===//
48
49LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
50 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
51 if (!constOp) {
52 return failure();
53 }
54 auto valueAttr = constOp.getValue();
55 auto integerValueAttr = dyn_cast<IntegerAttr>(valueAttr);
56 if (!integerValueAttr) {
57 return failure();
58 }
59
60 if (integerValueAttr.getType().isSignlessInteger())
61 value = integerValueAttr.getInt();
62 else
63 value = integerValueAttr.getSInt();
64
65 return success();
66}
67
// Verifies the ordering bits of a SPIR-V memory-semantics mask for `op`.
// Returns failure (with an error emitted on `op`) when more than one of
// Acquire, Release, AcquireRelease or SequentiallyConsistent is set; other
// bits in the mask are not inspected here.
68LogicalResult
70 spirv::MemorySemantics memorySemantics) {
71 // According to the SPIR-V specification:
72 // "Despite being a mask and allowing multiple bits to be combined, it is
73 // invalid for more than one of these four bits to be set: Acquire, Release,
74 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
75 // Release semantics is done by setting the AcquireRelease bit, not by setting
76 // two bits."
77 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
78 spirv::MemorySemantics::Release |
79 spirv::MemorySemantics::AcquireRelease |
80 spirv::MemorySemantics::SequentiallyConsistent;
81
 // Count how many of the four mutually exclusive ordering bits are present.
82 auto bitCount =
83 llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
84 if (bitCount > 1) {
85 return op->emitError(
86 "expected at most one of these four memory constraints "
87 "to be set: `Acquire`, `Release`,"
88 "`AcquireRelease` or `SequentiallyConsistent`");
89 }
90 return success();
91}
92
94 SmallVectorImpl<StringRef> &elidedAttrs) {
 // Prints variable decorations in custom form, followed by the remaining
 // attributes as an optional attribute dictionary. Each attribute printed in
 // custom form is appended to `elidedAttrs` so it is not printed twice.
 // NOTE(review): if convertToSnakeFromCamelCase returns an owning string (as
 // the `auto` bindings suggest), the StringRefs pushed into `elidedAttrs`
 // point at these locals and are only valid until this function returns.
95 // Print optional descriptor binding
96 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
97 stringifyDecoration(spirv::Decoration::DescriptorSet));
98 auto bindingName = llvm::convertToSnakeFromCamelCase(
99 stringifyDecoration(spirv::Decoration::Binding));
100 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
101 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
 // The ` bind(set, binding)` shorthand is used only when both attributes are
 // present; a lone one is left for the attribute dictionary below.
102 if (descriptorSet && binding) {
103 elidedAttrs.push_back(descriptorSetName);
104 elidedAttrs.push_back(bindingName);
105 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
106 << ")";
107 }
108
109 // Print BuiltIn attribute if present
 // Emitted as ` <snake_case name>("<value>")`.
110 auto builtInName = llvm::convertToSnakeFromCamelCase(
111 stringifyDecoration(spirv::Decoration::BuiltIn));
112 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
113 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
114 elidedAttrs.push_back(builtInName);
115 }
116
117 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
118}
119
123 Type type;
 // Parses a single-result op in one of two forms:
 //   generic: `(operands) {attrs} : (operand types) -> result types`
 //   custom:  `operands {attrs} : type`, where all operands and the result
 //            share `type`.
124 // If the operand list is in-between parentheses, then we have a generic form.
125 // (see the fallback in `printOneResultOp`).
126 SMLoc loc = parser.getCurrentLocation();
127 if (!parser.parseOptionalLParen()) {
128 if (parser.parseOperandList(ops) || parser.parseRParen() ||
129 parser.parseOptionalAttrDict(result.attributes) ||
130 parser.parseColon() || parser.parseType(type))
131 return failure();
 // The generic form must carry a full function type.
132 auto fnType = dyn_cast<FunctionType>(type);
133 if (!fnType) {
134 parser.emitError(loc, "expected function type");
135 return failure();
136 }
137 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
138 return failure();
139 result.addTypes(fnType.getResults());
140 return success();
141 }
 // Custom form: every operand is resolved to `type`, which is also added as
 // the single result type.
142 return failure(parser.parseOperandList(ops) ||
143 parser.parseOptionalAttrDict(result.attributes) ||
144 parser.parseColonType(type) ||
145 parser.resolveOperands(ops, type, result.operands) ||
146 parser.addTypeToList(type, result.types));
147}
148
 // Prints `operands : type` when every operand type equals the single result
 // type; otherwise falls back to the generic form (without the op name).
150 assert(op->getNumResults() == 1 && "op should have one result");
151
152 // If not all the operand and result types are the same, just use the
153 // generic assembly form to avoid omitting information in printing.
154 auto resultType = op->getResult(0).getType();
155 if (llvm::any_of(op->getOperandTypes(),
156 [&](Type type) { return type != resultType; })) {
157 p.printGenericOp(op, /*printOpName=*/false);
158 return;
159 }
160
161 p << ' ';
162 p.printOperands(op->getOperands());
164 // Now we can output only one type for all operands and the result.
165 p << " : " << resultType;
166}
167
168template <typename BlockReadWriteOpTy>
169static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
170 Value ptr, Value val) {
171 auto valType = val.getType();
172 if (auto valVecTy = dyn_cast<VectorType>(valType))
173 valType = valVecTy.getElementType();
174
175 if (valType != cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
176 return op.emitOpError("mismatch in result type and pointer type");
177 }
178 return success();
179}
180
181/// Walks the given type hierarchy with the given indices, potentially down
182/// to component granularity, to select an element type. Returns null type and
183/// emits errors with the given loc on failure.
/// Diagnostics are produced through `emitErrorFn`. Fails when `indices` is
/// empty, when an intermediate type is not a spirv::CompositeType, or when an
/// index is negative or not less than the element count of a composite whose
/// element count is known at compile time.
184static Type
186 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
187 if (indices.empty()) {
188 emitErrorFn("expected at least one index for spirv.CompositeExtract");
189 return nullptr;
190 }
191
 // Descend one composite level per index.
192 for (auto index : indices) {
193 if (auto cType = dyn_cast<spirv::CompositeType>(type)) {
 // Bounds are only checkable when the element count is static.
194 if (cType.hasCompileTimeKnownNumElements() &&
195 (index < 0 ||
196 static_cast<uint64_t>(index) >= cType.getNumElements())) {
197 emitErrorFn("index ") << index << " out of bounds for " << type;
198 return nullptr;
199 }
200 type = cType.getElementType(index);
201 } else {
202 emitErrorFn("cannot extract from non-composite type ")
203 << type << " with index " << index;
204 return nullptr;
205 }
206 }
207 return type;
208}
209
/// Variant that takes the indices as an attribute. The attribute must be a
/// non-empty ArrayAttr whose entries are all IntegerAttr; their values are
/// collected and forwarded to the index-list overload. Diagnostics go through
/// `emitErrorFn`; a null Type is returned on failure.
210static Type
212 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
213 auto indicesArrayAttr = dyn_cast<ArrayAttr>(indices);
214 if (!indicesArrayAttr) {
215 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
216 return nullptr;
217 }
218 if (indicesArrayAttr.empty()) {
219 emitErrorFn("expected at least one index for spirv.CompositeExtract");
220 return nullptr;
221 }
222
223 SmallVector<int32_t, 2> indexVals;
224 for (auto indexAttr : indicesArrayAttr) {
225 auto indexIntAttr = dyn_cast<IntegerAttr>(indexAttr);
226 if (!indexIntAttr) {
227 emitErrorFn("expected an 32-bit integer for index, but found '")
228 << indexAttr << "'";
229 return nullptr;
230 }
 // NOTE(review): the value is read with getInt() and narrowed to int32_t;
 // the attribute's bit width and signedness are not checked here.
231 indexVals.push_back(indexIntAttr.getInt());
232 }
233 return getElementType(type, indexVals, emitErrorFn);
234}
235
 // Adapter overload: diagnostics are reported at `loc` through
 // ::mlir::emitError, then the callback-based overload does the work.
237 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
238 return ::mlir::emitError(loc, err);
239 };
240 return getElementType(type, indices, errorFn);
241}
242
244 SMLoc loc) {
 // Parser adapter overload: diagnostics are reported at `loc` through
 // parser.emitError, then the callback-based overload does the work.
245 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
246 return parser.emitError(loc, err);
247 };
248 return getElementType(type, indices, errorFn);
249}
250
251template <typename ExtendedBinaryOp>
252static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
253 auto resultType = cast<spirv::StructType>(op.getType());
254 if (resultType.getNumElements() != 2)
255 return op.emitOpError("expected result struct type containing two members");
256
257 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
258 resultType.getElementType(0),
259 resultType.getElementType(1)}))
260 return op.emitOpError(
261 "expected all operand types and struct member types are the same");
262
263 return success();
264}
265
 // Parses `{attrs} %a, %b : !spirv.struct<(T, T)>`. The struct must have
 // exactly two members; it becomes the result type, and both operands are
 // resolved to the struct's first member type.
269 if (parser.parseOptionalAttrDict(result.attributes) ||
270 parser.parseOperandList(operands) || parser.parseColon())
271 return failure();
272
273 Type resultType;
274 SMLoc loc = parser.getCurrentLocation();
275 if (parser.parseType(resultType))
276 return failure();
277
278 auto structType = dyn_cast<spirv::StructType>(resultType);
279 if (!structType || structType.getNumElements() != 2)
280 return parser.emitError(loc, "expected spirv.struct type with two members");
281
 // Operand types come from member 0 only; agreement with member 1 is
 // presumably left to verifyArithmeticExtendedBinaryOp.
282 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
283 if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
284 return failure();
285
286 result.addTypes(resultType);
287 return success();
288}
289
291 OpAsmPrinter &printer) {
 // Prints ` {attrs} operands : resultType`, the inverse of the parser above:
 // attribute dictionary first, then operands, then the struct result type.
292 printer << ' ';
293 printer.printOptionalAttrDict(op->getAttrs());
294 printer.printOperands(op->getOperands());
295 printer << " : " << op->getResultTypes().front();
296}
297
298static LogicalResult verifyShiftOp(Operation *op) {
299 if (op->getOperand(0).getType() != op->getResult(0).getType()) {
300 return op->emitError("expected the same type for the first operand and "
301 "result, but provided ")
302 << op->getOperand(0).getType() << " and "
303 << op->getResult(0).getType();
304 }
305 return success();
306}
307
308//===----------------------------------------------------------------------===//
309// spirv.mlir.addressof
310//===----------------------------------------------------------------------===//
311
312void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
313 spirv::GlobalVariableOp var) {
314 build(builder, state, var.getType(), SymbolRefAttr::get(var));
315}
316
317LogicalResult spirv::AddressOfOp::verify() {
318 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
319 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
320 getVariableAttr()));
321 if (!varOp) {
322 return emitOpError("expected spirv.GlobalVariable symbol");
323 }
324 if (getPointer().getType() != varOp.getType()) {
325 return emitOpError(
326 "result type mismatch with the referenced global variable's type");
327 }
328 return success();
329}
330
331//===----------------------------------------------------------------------===//
332// spirv.CompositeConstruct
333//===----------------------------------------------------------------------===//
334
// Verifies the constituents of a composite construction against the result
// type. Cooperative matrices take exactly one constituent of the element type.
// Other composites either take one constituent per element (with exactly
// matching types), or, for vectors only, a mix of scalars and vectors whose
// element type matches and whose total component count equals the vector
// length.
335LogicalResult spirv::CompositeConstructOp::verify() {
336 operand_range constituents = this->getConstituents();
337
338 // There are 4 cases with varying verification rules:
339 // 1. Cooperative Matrices (1 constituent)
340 // 2. Structs (1 constituent for each member)
341 // 3. Arrays (1 constituent for each array element)
342 // 4. Vectors (1 constituent (sub-)element for each vector element)
343
344 auto coopElementType = llvm::TypeSwitch<Type, Type>(getType())
345 .Case([](spirv::CooperativeMatrixType coopType) {
346 return coopType.getElementType();
347 })
348 .Default(nullptr);
349
350 // Case 1. -- matrices.
351 if (coopElementType) {
352 if (constituents.size() != 1)
353 return emitOpError("has incorrect number of operands: expected ")
354 << "1, but provided " << constituents.size();
355 if (coopElementType != constituents.front().getType())
356 return emitOpError("operand type mismatch: expected operand type ")
357 << coopElementType << ", but provided "
358 << constituents.front().getType();
359 return success();
360 }
361
362 // Case 2./3./4. -- number of constituents matches the number of elements.
363 auto cType = cast<spirv::CompositeType>(getType());
364 if (constituents.size() == cType.getNumElements()) {
365 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
366 if (constituents[index].getType() != cType.getElementType(index)) {
367 return emitOpError("operand type mismatch: expected operand type ")
368 << cType.getElementType(index) << ", but provided "
369 << constituents[index].getType();
370 }
371 }
372 return success();
373 }
374
375 // Case 4. -- check that all constituents add up to the expected vector type.
 // NOTE(review): this path is taken whenever the counts differ (more or
 // fewer constituents), although the message below only mentions "less".
376 auto resultType = dyn_cast<VectorType>(cType);
377 if (!resultType)
378 return emitOpError(
379 "expected to return a vector or cooperative matrix when the number of "
380 "constituents is less than what the result needs");
381
 // `sizes` collects the component count contributed by each constituent.
383 for (Value component : constituents) {
384 if (!isa<VectorType>(component.getType()) &&
385 !component.getType().isIntOrFloat())
386 return emitOpError("operand type mismatch: expected operand to have "
387 "a scalar or vector type, but provided ")
388 << component.getType();
389
390 Type elementType = component.getType();
391 if (auto vectorType = dyn_cast<VectorType>(component.getType())) {
392 sizes.push_back(vectorType.getNumElements());
393 elementType = vectorType.getElementType();
394 } else {
395 sizes.push_back(1);
396 }
397
398 if (elementType != resultType.getElementType())
399 return emitOpError("operand element type mismatch: expected to be ")
400 << resultType.getElementType() << ", but provided " << elementType;
401 }
402 unsigned totalCount = llvm::sum_of(sizes);
403 if (totalCount != cType.getNumElements())
404 return emitOpError("has incorrect number of operands: expected ")
405 << cType.getNumElements() << ", but provided " << totalCount;
406 return success();
407}
408
409//===----------------------------------------------------------------------===//
410// spirv.CompositeExtractOp
411//===----------------------------------------------------------------------===//
412
// Builds a composite extract whose result type is derived from the composite
// type and `indices`. If the indices are invalid, an error is emitted at
// state.location and the function returns without populating `state`.
413void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
414 Value composite,
416 auto indexAttr = builder.getI32ArrayAttr(indices);
417 auto elementType =
418 getElementType(composite.getType(), indexAttr, state.location);
419 if (!elementType) {
420 return;
421 }
422 build(builder, state, elementType, composite, indexAttr);
423}
424
// Parses `%composite [indices] : composite-type`. The result type is not
// spelled out; it is computed from the composite type and the indices.
425ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
427 OpAsmParser::UnresolvedOperand compositeInfo;
428 Attribute indicesAttr;
429 StringRef indicesAttrName =
430 spirv::CompositeExtractOp::getIndicesAttrName(result.name);
431 Type compositeType;
432 SMLoc attrLocation;
433
 // attrLocation records where the indices begin so index errors point there.
434 if (parser.parseOperand(compositeInfo) ||
435 parser.getCurrentLocation(&attrLocation) ||
436 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
437 parser.parseColonType(compositeType) ||
438 parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
439 return failure();
440 }
441
442 Type resultType =
443 getElementType(compositeType, indicesAttr, parser, attrLocation);
444 if (!resultType) {
445 return failure();
446 }
447 result.addTypes(resultType);
448 return success();
449}
450
451void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
452 printer << ' ' << getComposite() << getIndices() << " : "
453 << getComposite().getType();
454}
455
456LogicalResult spirv::CompositeExtractOp::verify() {
457 auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
458 auto resultType =
459 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
460 if (!resultType)
461 return failure();
462
463 if (resultType != getType()) {
464 return emitOpError("invalid result type: expected ")
465 << resultType << " but provided " << getType();
466 }
467
468 return success();
469}
470
471//===----------------------------------------------------------------------===//
472// spirv.CompositeInsert
473//===----------------------------------------------------------------------===//
474
// Builds a composite insert; the result type is the composite's own type.
// Index validity is not checked here (the op verifier checks it).
475void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
476 Value object, Value composite,
478 auto indexAttr = builder.getI32ArrayAttr(indices);
479 build(builder, state, composite.getType(), object, composite, indexAttr);
480}
481
// Parses `%object, %composite [indices] : object-type into composite-type`.
// The composite type is also used as the result type.
482ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
485 Type objectType, compositeType;
486 Attribute indicesAttr;
487 StringRef indicesAttrName =
488 spirv::CompositeInsertOp::getIndicesAttrName(result.name);
489 auto loc = parser.getCurrentLocation();
490
491 return failure(
492 parser.parseOperandList(operands, 2) ||
493 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
494 parser.parseColonType(objectType) ||
495 parser.parseKeywordType("into", compositeType) ||
496 parser.resolveOperands(operands, {objectType, compositeType}, loc,
497 result.operands) ||
498 parser.addTypesToList(compositeType, result.types));
499}
500
501LogicalResult spirv::CompositeInsertOp::verify() {
502 auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
503 auto objectType =
504 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
505 if (!objectType)
506 return failure();
507
508 if (objectType != getObject().getType()) {
509 return emitOpError("object operand type should be ")
510 << objectType << ", but found " << getObject().getType();
511 }
512
513 if (getComposite().getType() != getType()) {
514 return emitOpError("result type should be the same as "
515 "the composite type, but found ")
516 << getComposite().getType() << " vs " << getType();
517 }
518
519 return success();
520}
521
522void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
523 printer << " " << getObject() << ", " << getComposite() << getIndices()
524 << " : " << getObject().getType() << " into "
525 << getComposite().getType();
526}
527
528//===----------------------------------------------------------------------===//
529// spirv.Constant
530//===----------------------------------------------------------------------===//
531
// Parses `value-attr [: type]`. The result type comes from the attribute's
// type when it is typed. A trailing `: type` is required when the attribute
// is untyped (NoneType) or tensor-typed, and optional for TensorArmType.
532ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
534 Attribute value;
535 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
536 if (parser.parseAttribute(value, valueAttrName, result.attributes))
537 return failure();
538
539 Type type = NoneType::get(parser.getContext());
540 if (auto typedAttr = dyn_cast<TypedAttr>(value))
541 type = typedAttr.getType();
542 if (isa<NoneType, TensorType>(type)) {
543 if (parser.parseColonType(type))
544 return failure();
545 }
546
 // For TensorArmType the explicit result type may optionally override the
 // attribute's type.
547 if (isa<TensorArmType>(type)) {
548 if (parser.parseOptionalColon().succeeded())
549 if (parser.parseType(type))
550 return failure();
551 }
552
553 return parser.addTypeToList(type, result.types);
554}
555
556void spirv::ConstantOp::print(OpAsmPrinter &printer) {
557 printer << ' ' << getValue();
558 if (isa<spirv::ArrayType>(getType()))
559 printer << " : " << getType();
560}
561
562static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
563 Type opType) {
564 if (isa<spirv::CooperativeMatrixType>(opType)) {
565 auto denseAttr = dyn_cast<DenseElementsAttr>(value);
566 if (!denseAttr || !denseAttr.isSplat())
567 return op.emitOpError("expected a splat dense attribute for cooperative "
568 "matrix constant, but found ")
569 << denseAttr;
570 }
571 if (isa<IntegerAttr, FloatAttr>(value)) {
572 auto valueType = cast<TypedAttr>(value).getType();
573 if (valueType != opType)
574 return op.emitOpError("result type (")
575 << opType << ") does not match value type (" << valueType << ")";
576 return success();
577 }
578 if (isa<DenseTypedElementsAttr, SparseElementsAttr>(value)) {
579 auto valueType = cast<TypedAttr>(value).getType();
580 if (valueType == opType)
581 return success();
582 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
583 auto shapedType = dyn_cast<ShapedType>(valueType);
584 if (!arrayType)
585 return op.emitOpError("result or element type (")
586 << opType << ") does not match value type (" << valueType
587 << "), must be the same or spirv.array";
588
589 int numElements = arrayType.getNumElements();
590 auto opElemType = arrayType.getElementType();
591 while (auto t = dyn_cast<spirv::ArrayType>(opElemType)) {
592 numElements *= t.getNumElements();
593 opElemType = t.getElementType();
594 }
595 if (!opElemType.isIntOrFloat())
596 return op.emitOpError("only support nested array result type");
597
598 auto valueElemType = shapedType.getElementType();
599 if (valueElemType != opElemType) {
600 return op.emitOpError("result element type (")
601 << opElemType << ") does not match value element type ("
602 << valueElemType << ")";
603 }
604
605 if (numElements != shapedType.getNumElements()) {
606 return op.emitOpError("result number of elements (")
607 << numElements << ") does not match value number of elements ("
608 << shapedType.getNumElements() << ")";
609 }
610 return success();
611 }
612 if (auto arrayAttr = dyn_cast<ArrayAttr>(value)) {
613 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
614 if (!arrayType)
615 return op.emitOpError(
616 "must have spirv.array result type for array value");
617 Type elemType = arrayType.getElementType();
618 for (Attribute element : arrayAttr.getValue()) {
619 // Verify array elements recursively.
620 if (failed(verifyConstantType(op, element, elemType)))
621 return failure();
622 }
623 return success();
624 }
625 return op.emitOpError("cannot have attribute: ") << value;
626}
627
628LogicalResult spirv::ConstantOp::verify() {
629 // ODS already generates checks to make sure the result type is valid. We just
630 // need to additionally check that the value's attribute type is consistent
631 // with the result type.
632 return verifyConstantType(*this, getValueAttr(), getType());
633}
634
635bool spirv::ConstantOp::isBuildableWith(Type type) {
636 // Must be valid SPIR-V type first.
637 if (!isa<spirv::SPIRVType>(type))
638 return false;
639
640 if (isa<SPIRVDialect>(type.getDialect())) {
641 // TODO: support constant struct
642 return isa<spirv::ArrayType>(type);
643 }
644
645 return true;
646}
647
648spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
649 OpBuilder &builder) {
650 if (auto intType = dyn_cast<IntegerType>(type)) {
651 unsigned width = intType.getWidth();
652 if (width == 1)
653 return spirv::ConstantOp::create(builder, loc, type,
654 builder.getBoolAttr(false));
655 return spirv::ConstantOp::create(
656 builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
657 }
658 if (auto floatType = dyn_cast<FloatType>(type)) {
659 return spirv::ConstantOp::create(builder, loc, type,
660 builder.getFloatAttr(floatType, 0.0));
661 }
662 if (auto vectorType = dyn_cast<VectorType>(type)) {
663 Type elemType = vectorType.getElementType();
664 if (isa<IntegerType>(elemType)) {
665 return spirv::ConstantOp::create(
666 builder, loc, type,
667 DenseElementsAttr::get(vectorType,
668 IntegerAttr::get(elemType, 0).getValue()));
669 }
670 if (isa<FloatType>(elemType)) {
671 return spirv::ConstantOp::create(
672 builder, loc, type,
673 DenseFPElementsAttr::get(vectorType,
674 FloatAttr::get(elemType, 0.0).getValue()));
675 }
676 }
677
678 llvm_unreachable("unimplemented types for ConstantOp::getZero()");
679}
680
681spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
682 OpBuilder &builder) {
683 if (auto intType = dyn_cast<IntegerType>(type)) {
684 unsigned width = intType.getWidth();
685 if (width == 1)
686 return spirv::ConstantOp::create(builder, loc, type,
687 builder.getBoolAttr(true));
688 return spirv::ConstantOp::create(
689 builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
690 }
691 if (auto floatType = dyn_cast<FloatType>(type)) {
692 return spirv::ConstantOp::create(builder, loc, type,
693 builder.getFloatAttr(floatType, 1.0));
694 }
695 if (auto vectorType = dyn_cast<VectorType>(type)) {
696 Type elemType = vectorType.getElementType();
697 if (isa<IntegerType>(elemType)) {
698 return spirv::ConstantOp::create(
699 builder, loc, type,
700 DenseElementsAttr::get(vectorType,
701 IntegerAttr::get(elemType, 1).getValue()));
702 }
703 if (isa<FloatType>(elemType)) {
704 return spirv::ConstantOp::create(
705 builder, loc, type,
706 DenseFPElementsAttr::get(vectorType,
707 FloatAttr::get(elemType, 1.0).getValue()));
708 }
709 }
710
711 llvm_unreachable("unimplemented types for ConstantOp::getOne()");
712}
713
714void mlir::spirv::ConstantOp::getAsmResultNames(
715 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
716 Type type = getType();
717
718 SmallString<32> specialNameBuffer;
719 llvm::raw_svector_ostream specialName(specialNameBuffer);
720 specialName << "cst";
721
722 IntegerType intTy = dyn_cast<IntegerType>(type);
723
724 if (IntegerAttr intCst = dyn_cast<IntegerAttr>(getValue())) {
725 assert(intTy);
726
727 if (intTy.getWidth() == 1) {
728 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
729 }
730
731 if (intTy.isSignless()) {
732 specialName << intCst.getInt();
733 } else if (intTy.isUnsigned()) {
734 specialName << intCst.getUInt();
735 } else {
736 specialName << intCst.getSInt();
737 }
738 }
739
740 if (intTy || isa<FloatType>(type)) {
741 specialName << '_' << type;
742 }
743
744 if (auto vecType = dyn_cast<VectorType>(type)) {
745 specialName << "_vec_";
746 specialName << vecType.getDimSize(0);
747
748 Type elementType = vecType.getElementType();
749
750 if (isa<IntegerType>(elementType) || isa<FloatType>(elementType)) {
751 specialName << "x" << elementType;
752 }
753 }
754
755 setNameFn(getResult(), specialName.str());
756}
757
758void mlir::spirv::AddressOfOp::getAsmResultNames(
759 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
760 SmallString<32> specialNameBuffer;
761 llvm::raw_svector_ostream specialName(specialNameBuffer);
762 specialName << getVariable() << "_addr";
763 setNameFn(getResult(), specialName.str());
764}
765
766//===----------------------------------------------------------------------===//
767// spirv.EXTConstantCompositeReplicate
768//===----------------------------------------------------------------------===//
769
770// Returns type of attribute. In case of a TypedAttr this will simply return
771// the type. But for an ArrayAttr which is untyped and can be multidimensional
772// it creates the ArrayType recursively.
// Only the first element of an ArrayAttr is inspected to derive the element
// type. Any other kind of attribute yields a null Type.
// NOTE(review): `arrayAttr[0]` assumes a non-empty ArrayAttr; an empty one
// would be indexed out of bounds. Confirm callers never pass one.
774 if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
775 return typedAttr.getType();
776 }
777
778 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
779 return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
780 }
781
782 return nullptr;
783}
784
785LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
786 Type valueType = getValueType(getValue());
787 if (!valueType)
788 return emitError("unknown value attribute type");
789
790 auto compositeType = dyn_cast<spirv::CompositeType>(getType());
791 if (!compositeType)
792 return emitError("result type is not a composite type");
793
794 Type compositeElementType = compositeType.getElementType(0);
795
796 SmallVector<Type, 3> possibleTypes = {compositeElementType};
797 while (auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
798 compositeElementType = type.getElementType(0);
799 possibleTypes.push_back(compositeElementType);
800 }
801
802 if (!is_contained(possibleTypes, valueType)) {
803 return emitError("expected value attribute type ")
804 << interleaved(possibleTypes, " or ") << ", but got: " << valueType;
805 }
806
807 return success();
808}
809
810//===----------------------------------------------------------------------===//
811// spirv.ControlBarrierOp
812//===----------------------------------------------------------------------===//
813
814LogicalResult spirv::ControlBarrierOp::verify() {
815 return verifyMemorySemantics(getOperation(), getMemorySemantics());
816}
817
818//===----------------------------------------------------------------------===//
819// spirv.EntryPoint
820//===----------------------------------------------------------------------===//
821
822void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
823 spirv::ExecutionModel executionModel,
824 spirv::FuncOp function,
825 ArrayRef<Attribute> interfaceVars) {
826 build(builder, state,
827 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
828 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
829}
830
// Parses `"ExecutionModel" @fn [, @var0, @var1, ...]`. The interface
// variable symbols are stored as an ArrayAttr, which is empty when no comma
// follows the function symbol.
831ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
833 spirv::ExecutionModel execModel;
834 SmallVector<Attribute, 4> interfaceVars;
835
837 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
838 parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
839 return failure();
840 }
841
842 if (!parser.parseOptionalComma()) {
843 // Parse the interface variables
844 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
845 // The name of the interface variable attribute isn't important
846 FlatSymbolRefAttr var;
847 NamedAttrList attrs;
848 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
849 return failure();
850 interfaceVars.push_back(var);
851 return success();
852 }))
853 return failure();
854 }
855 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
856 parser.getBuilder().getArrayAttr(interfaceVars));
857 return success();
858}
859
860void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
861 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
862 printer.printSymbolName(getFn());
863 auto interfaceVars = getInterface().getValue();
864 if (!interfaceVars.empty())
865 printer << ", " << llvm::interleaved(interfaceVars);
866}
867
868LogicalResult spirv::EntryPointOp::verify() {
869 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
870 // verification.
871 return success();
872}
873
874//===----------------------------------------------------------------------===//
875// spirv.ExecutionMode
876//===----------------------------------------------------------------------===//
877
878void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
879 spirv::FuncOp function,
880 spirv::ExecutionMode executionMode,
881 ArrayRef<int32_t> params) {
882 build(builder, state, SymbolRefAttr::get(function),
883 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
884 builder.getI32ArrayAttr(params));
885}
886
// Parses `@fn "ExecutionMode" [, int, int, ...]`. Each trailing value is
// parsed as an i32 attribute and collected into an i32 array attribute.
887ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
889 spirv::ExecutionMode execMode;
890 Attribute fn;
891 if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
893 return failure();
894 }
895
897 Type i32Type = parser.getBuilder().getIntegerType(32);
 // Each leading comma introduces one more literal parameter.
898 while (!parser.parseOptionalComma()) {
899 NamedAttrList attr;
900 Attribute value;
901 if (parser.parseAttribute(value, i32Type, "value", attr)) {
902 return failure();
903 }
 // NOTE(review): cast<> asserts if the parsed attribute is not an
 // IntegerAttr; confirm parseAttribute with an i32 type guarantees that.
904 values.push_back(cast<IntegerAttr>(value).getInt());
905 }
906 StringRef valuesAttrName =
907 spirv::ExecutionModeOp::getValuesAttrName(result.name);
908 result.addAttribute(valuesAttrName,
909 parser.getBuilder().getI32ArrayAttr(values));
910 return success();
911}
912
913void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
914 printer << " ";
915 printer.printSymbolName(getFn());
916 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
917 ArrayAttr values = this->getValues();
918 if (!values.empty())
919 printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
920}
921
922//===----------------------------------------------------------------------===//
923// spirv.ExecutionModeId
924//===----------------------------------------------------------------------===//
925
// Parses `@fn "ExecutionMode" @sym0, @sym1, ...`. The <id> operands are
// symbol references collected into an ArrayAttr; at least one entry is
// needed for the comma-separated list to parse.
926ParseResult spirv::ExecutionModeIdOp::parse(OpAsmParser &parser,
928 ExecutionMode execMode;
929 if (Attribute fn;
930 parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
931 parseEnumStrAttr<ExecutionModeAttr>(execMode, parser, result)) {
932 return failure();
933 }
934
936 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
937 FlatSymbolRefAttr attr;
938 if (parser.parseAttribute(attr))
939 return failure();
940 values.push_back(attr);
941 return success();
942 })) {
943 return failure();
944 }
945
946 StringRef valuesAttrName = getValuesAttrName(result.name);
947 ArrayAttr valuesAttr = parser.getBuilder().getArrayAttr(values);
948 result.addAttribute(valuesAttrName, valuesAttr);
949 return success();
950}
951
952void spirv::ExecutionModeIdOp::print(OpAsmPrinter &printer) {
953 printer << " ";
954 printer.printSymbolName(getFn());
955 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\" ";
956
957 llvm::interleaveComma(
958 getValues().getAsValueRange<FlatSymbolRefAttr>(), printer,
959 [&](StringRef value) { printer.printSymbolName(value); });
960}
961
962LogicalResult spirv::ExecutionModeIdOp::verify() {
963 // Valid as of SPIRV 1.6
964 switch (getExecutionMode()) {
965 case ExecutionMode::SubgroupsPerWorkgroupId:
966 case ExecutionMode::LocalSizeId:
967 case ExecutionMode::LocalSizeHintId:
968 break;
969 default:
970 return emitOpError("expected ExecutionMode that takes extra operands that "
971 "are <id> operands, got: ")
972 << stringifyExecutionMode(getExecutionMode());
973 }
974
975 if (getValues().empty())
976 return emitOpError("expected at least one value operand");
977
978 for (Attribute value : getValues()) {
979 auto valueSymbol = dyn_cast<FlatSymbolRefAttr>(value);
980 if (!valueSymbol)
981 return emitOpError("expected value operands to be symbol reference");
983 (*this)->getParentOp(), valueSymbol);
984 if (!valueOp)
985 return emitOpError("cannot find symbol referenced by value operand: ")
986 << valueSymbol.getValue();
987 }
988
989 return success();
990}
991
992//===----------------------------------------------------------------------===//
993// spirv.func
994//===----------------------------------------------------------------------===//
995
// Parses the custom form of `spirv.func`. The form consists of:
//  * the symbol name;
//  * the signature;
//  * the function-control value;
//  * an optional `attributes {...}` dictionary;
//  * an optional body region.
// NOTE(review): this listing lacks the `entryArgs` declaration and the helper
// calls that parse the signature, parse the function control, and attach the
// argument/result attributes. Confirm those lines against the upstream
// source.
996ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
998 SmallVector<DictionaryAttr> resultAttrs;
999 SmallVector<Type> resultTypes;
1000 auto &builder = parser.getBuilder();
1001
1002 // Parse the name as a symbol.
1003 StringAttr nameAttr;
1004 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1005 result.attributes))
1006 return failure();
1007
1008 // Parse the function signature.
1009 bool isVariadic = false;
1011 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1012 resultAttrs))
1013 return failure();
1014
// Build the FunctionType from the parsed argument and result types.
1015 SmallVector<Type> argTypes;
1016 for (auto &arg : entryArgs)
1017 argTypes.push_back(arg.type);
1018 auto fnType = builder.getFunctionType(argTypes, resultTypes);
1019 result.addAttribute(getFunctionTypeAttrName(result.name),
1020 TypeAttr::get(fnType));
1021
1022 // Parse the optional function control keyword.
1023 spirv::FunctionControl fnControl;
1025 return failure();
1026
1027 // If additional attributes are present, parse them.
1028 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1029 return failure();
1030
1031 // Add the attributes to the function arguments.
// Each parsed result type must have a matching attribute dictionary.
1032 assert(resultAttrs.size() == resultTypes.size());
1034 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1035 getResAttrsAttrName(result.name));
1036
1037 // Parse the optional function body.
1038 auto *body = result.addRegion();
1039 OptionalParseResult parseResult =
1040 parser.parseOptionalRegion(*body, entryArgs);
// An absent region is accepted (external function). Only a region that is
// present but fails to parse is an error.
1041 return failure(parseResult.has_value() && failed(*parseResult));
1042}
1043
1044void spirv::FuncOp::print(OpAsmPrinter &printer) {
1045 // Print function name, signature, and control.
1046 printer << " ";
1047 printer.printSymbolName(getSymName());
1048 auto fnType = getFunctionType();
1050 printer, *this, fnType.getInputs(),
1051 /*isVariadic=*/false, fnType.getResults());
1052 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
1053 << "\"";
1055 printer, *this,
1056 {spirv::attributeName<spirv::FunctionControl>(),
1057 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
1058 getFunctionControlAttrName()});
1059
1060 // Print the body if this is not an external function.
1061 Region &body = this->getBody();
1062 if (!body.empty()) {
1063 printer << ' ';
1064 printer.printRegion(body, /*printEntryBlockArgs=*/false,
1065 /*printBlockTerminators=*/true);
1066 }
1067}
1068
1069LogicalResult spirv::FuncOp::verifyType() {
1070 FunctionType fnType = getFunctionType();
1071 if (fnType.getNumResults() > 1)
1072 return emitOpError("cannot have more than one result");
1073
1074 auto hasDecorationAttr = [&](spirv::Decoration decoration,
1075 unsigned argIndex) {
1076 auto func = cast<FunctionOpInterface>(getOperation());
1077 for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
1078 if (argAttr.getName() != spirv::DecorationAttr::name)
1079 continue;
1080 if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
1081 return decAttr.getValue() == decoration;
1082 }
1083 return false;
1084 };
1085
1086 for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1087 Type param = fnType.getInputs()[i];
1088 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1089 if (!inputPtrType)
1090 continue;
1091
1092 auto pointeePtrType =
1093 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1094 if (pointeePtrType) {
1095 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1096 // > If an OpFunctionParameter is a pointer (or contains a pointer)
1097 // > and the type it points to is a pointer in the PhysicalStorageBuffer
1098 // > storage class, the function parameter must be decorated with exactly
1099 // > one of AliasedPointer or RestrictPointer.
1100 if (pointeePtrType.getStorageClass() !=
1101 spirv::StorageClass::PhysicalStorageBuffer)
1102 continue;
1103
1104 bool hasAliasedPtr =
1105 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1106 bool hasRestrictPtr =
1107 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1108 if (!hasAliasedPtr && !hasRestrictPtr)
1109 return emitOpError()
1110 << "with a pointer points to a physical buffer pointer must "
1111 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1112 continue;
1113 }
1114 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1115 // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1116 // > the PhysicalStorageBuffer storage class, the function parameter must
1117 // > be decorated with exactly one of Aliased or Restrict.
1118 if (auto pointeeArrayType =
1119 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1120 pointeePtrType =
1121 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1122 } else {
1123 pointeePtrType = inputPtrType;
1124 }
1125
1126 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1127 spirv::StorageClass::PhysicalStorageBuffer)
1128 continue;
1129
1130 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1131 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1132 if (!hasAliased && !hasRestrict)
1133 return emitOpError() << "with physical buffer pointer must be decorated "
1134 "either 'Aliased' or 'Restrict'";
1135 }
1136
1137 return success();
1138}
1139
1140LogicalResult spirv::FuncOp::verifyBody() {
1141 FunctionType fnType = getFunctionType();
1142 if (!isExternal()) {
1143 Block &entryBlock = front();
1144
1145 unsigned numArguments = this->getNumArguments();
1146 if (entryBlock.getNumArguments() != numArguments)
1147 return emitOpError("entry block must have ")
1148 << numArguments << " arguments to match function signature";
1149
1150 for (auto [index, fnArgType, blockArgType] :
1151 llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1152 if (blockArgType != fnArgType) {
1153 return emitOpError("type of entry block argument #")
1154 << index << '(' << blockArgType
1155 << ") must match the type of the corresponding argument in "
1156 << "function signature(" << fnArgType << ')';
1157 }
1158 }
1159 }
1160
1161 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1162 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1163 if (fnType.getNumResults() != 0)
1164 return retOp.emitOpError("cannot be used in functions returning value");
1165 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1166 if (fnType.getNumResults() != 1)
1167 return retOp.emitOpError(
1168 "returns 1 value but enclosing function requires ")
1169 << fnType.getNumResults() << " results";
1170
1171 auto retOperandType = retOp.getValue().getType();
1172 auto fnResultType = fnType.getResult(0);
1173 if (retOperandType != fnResultType)
1174 return retOp.emitOpError(" return value's type (")
1175 << retOperandType << ") mismatch with function's result type ("
1176 << fnResultType << ")";
1177 }
1178 return WalkResult::advance();
1179 });
1180
1181 // TODO: verify other bits like linkage type.
1182
1183 return failure(walkResult.wasInterrupted());
1184}
1185
1186void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1187 StringRef name, FunctionType type,
1188 spirv::FunctionControl control,
1191 builder.getStringAttr(name));
1192 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1193 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1194 builder.getAttr<spirv::FunctionControlAttr>(control));
1195 state.attributes.append(attrs.begin(), attrs.end());
1196 state.addRegion();
1197}
1198
1199//===----------------------------------------------------------------------===//
1200// spirv.GLFClampOp
1201//===----------------------------------------------------------------------===//
1202
1203ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1206}
1207void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1208
1209//===----------------------------------------------------------------------===//
1210// spirv.GLUClampOp
1211//===----------------------------------------------------------------------===//
1212
1213ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1216}
1217void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1218
1219//===----------------------------------------------------------------------===//
1220// spirv.GLSClampOp
1221//===----------------------------------------------------------------------===//
1222
1223ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1226}
1227void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1228
1229//===----------------------------------------------------------------------===//
1230// spirv.GLFmaOp
1231//===----------------------------------------------------------------------===//
1232
1233ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1235}
1236void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1237
1238//===----------------------------------------------------------------------===//
1239// spirv.GlobalVariable
1240//===----------------------------------------------------------------------===//
1241
1242void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1243 Type type, StringRef name,
1244 unsigned descriptorSet, unsigned binding) {
1245 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1246 state.addAttribute(
1247 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1248 builder.getI32IntegerAttr(descriptorSet));
1249 state.addAttribute(
1250 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1251 builder.getI32IntegerAttr(binding));
1252}
1253
1254void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1255 Type type, StringRef name,
1256 spirv::BuiltIn builtin) {
1257 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1258 state.addAttribute(
1259 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1260 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1261}
1262
1263ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1265 // Parse variable name.
1266 StringAttr nameAttr;
1267 StringRef initializerAttrName =
1268 spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1269 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1270 result.attributes)) {
1271 return failure();
1272 }
1273
1274 // Parse optional initializer
1275 if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1276 FlatSymbolRefAttr initSymbol;
1277 if (parser.parseLParen() ||
1278 parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1279 result.attributes) ||
1280 parser.parseRParen())
1281 return failure();
1282 }
1283
1284 if (parseVariableDecorations(parser, result)) {
1285 return failure();
1286 }
1287
1288 Type type;
1289 StringRef typeAttrName =
1290 spirv::GlobalVariableOp::getTypeAttrName(result.name);
1291 auto loc = parser.getCurrentLocation();
1292 if (parser.parseColonType(type)) {
1293 return failure();
1294 }
1295 if (!isa<spirv::PointerType>(type)) {
1296 return parser.emitError(loc, "expected spirv.ptr type");
1297 }
1298 result.addAttribute(typeAttrName, TypeAttr::get(type));
1299
1300 return success();
1301}
1302
1303void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
1304 SmallVector<StringRef, 4> elidedAttrs{
1305 spirv::attributeName<spirv::StorageClass>()};
1306
1307 // Print variable name.
1308 printer << ' ';
1309 printer.printSymbolName(getSymName());
1310 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1311
1312 StringRef initializerAttrName = this->getInitializerAttrName();
1313 // Print optional initializer
1314 if (auto initializer = this->getInitializer()) {
1315 printer << " " << initializerAttrName << '(';
1316 printer.printSymbolName(*initializer);
1317 printer << ')';
1318 elidedAttrs.push_back(initializerAttrName);
1319 }
1320
1321 StringRef typeAttrName = this->getTypeAttrName();
1322 elidedAttrs.push_back(typeAttrName);
1323 spirv::printVariableDecorations(*this, printer, elidedAttrs);
1324 printer << " : " << getType();
1325}
1326
1327LogicalResult spirv::GlobalVariableOp::verify() {
1328 if (!isa<spirv::PointerType>(getType()))
1329 return emitOpError("result must be of a !spv.ptr type");
1330
1331 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1332 // object. It cannot be Generic. It must be the same as the Storage Class
1333 // operand of the Result Type."
1334 // Also, Function storage class is reserved by spirv.Variable.
1335 auto storageClass = this->storageClass();
1336 if (storageClass == spirv::StorageClass::Generic ||
1337 storageClass == spirv::StorageClass::Function) {
1338 return emitOpError("storage class cannot be '")
1339 << stringifyStorageClass(storageClass) << "'";
1340 }
1341
1342 if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1343 this->getInitializerAttrName())) {
1345 (*this)->getParentOp(), init.getAttr());
1346 // TODO: Currently only variable initialization with specialization
1347 // constants is supported. There could be normal constants in the module
1348 // scope as well.
1349 //
1350 // In the current setup we also cannot initialize one global variable with
1351 // another. The problem is that if we try to initialize pointer of type X
1352 // with another pointer type, the validator fails because it expects the
1353 // variable to be initialized to be type X, not pointer to X. Now
1354 // `spirv.GlobalVariable` only allows pointer type, so in the current design
1355 // we cannot initialize one `spirv.GlobalVariable` with another.
1356 if (!initOp ||
1357 !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
1358 return emitOpError("initializer must be result of a "
1359 "spirv.SpecConstant or "
1360 "spirv.SpecConstantCompositeOp op");
1361 }
1362 }
1363
1364 return success();
1365}
1366
1367//===----------------------------------------------------------------------===//
1368// spirv.INTEL.SubgroupBlockRead
1369//===----------------------------------------------------------------------===//
1370
1371LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
1372 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1373 return failure();
1374
1375 return success();
1376}
1377
1378//===----------------------------------------------------------------------===//
1379// spirv.INTEL.SubgroupBlockWrite
1380//===----------------------------------------------------------------------===//
1381
1382ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
1384 // Parse the storage class specification
1385 spirv::StorageClass storageClass;
1387 auto loc = parser.getCurrentLocation();
1388 Type elementType;
1389 if (parseEnumStrAttr(storageClass, parser) ||
1390 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1391 parser.parseType(elementType)) {
1392 return failure();
1393 }
1394
1395 auto ptrType = spirv::PointerType::get(elementType, storageClass);
1396 if (auto valVecTy = dyn_cast<VectorType>(elementType))
1397 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1398
1399 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1400 result.operands)) {
1401 return failure();
1402 }
1403 return success();
1404}
1405
1406void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
1407 printer << " " << getPtr() << ", " << getValue() << " : "
1408 << getValue().getType();
1409}
1410
1411LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
1412 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1413 return failure();
1414
1415 return success();
1416}
1417
1418//===----------------------------------------------------------------------===//
1419// spirv.IAddCarryOp
1420//===----------------------------------------------------------------------===//
1421
1422LogicalResult spirv::IAddCarryOp::verify() {
1423 return ::verifyArithmeticExtendedBinaryOp(*this);
1424}
1425
1426ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1428 return ::parseArithmeticExtendedBinaryOp(parser, result);
1429}
1430
1431void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1432 ::printArithmeticExtendedBinaryOp(*this, printer);
1433}
1434
1435//===----------------------------------------------------------------------===//
1436// spirv.ISubBorrowOp
1437//===----------------------------------------------------------------------===//
1438
1439LogicalResult spirv::ISubBorrowOp::verify() {
1440 return ::verifyArithmeticExtendedBinaryOp(*this);
1441}
1442
1443ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1445 return ::parseArithmeticExtendedBinaryOp(parser, result);
1446}
1447
1448void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
1449 ::printArithmeticExtendedBinaryOp(*this, printer);
1450}
1451
1452//===----------------------------------------------------------------------===//
1453// spirv.SMulExtended
1454//===----------------------------------------------------------------------===//
1455
1456LogicalResult spirv::SMulExtendedOp::verify() {
1457 return ::verifyArithmeticExtendedBinaryOp(*this);
1458}
1459
1460ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1462 return ::parseArithmeticExtendedBinaryOp(parser, result);
1463}
1464
1465void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
1466 ::printArithmeticExtendedBinaryOp(*this, printer);
1467}
1468
1469//===----------------------------------------------------------------------===//
1470// spirv.UMulExtended
1471//===----------------------------------------------------------------------===//
1472
1473LogicalResult spirv::UMulExtendedOp::verify() {
1474 return ::verifyArithmeticExtendedBinaryOp(*this);
1475}
1476
1477ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1479 return ::parseArithmeticExtendedBinaryOp(parser, result);
1480}
1481
1482void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
1483 ::printArithmeticExtendedBinaryOp(*this, printer);
1484}
1485
1486//===----------------------------------------------------------------------===//
1487// spirv.MemoryBarrierOp
1488//===----------------------------------------------------------------------===//
1489
1490LogicalResult spirv::MemoryBarrierOp::verify() {
1491 return verifyMemorySemantics(getOperation(), getMemorySemantics());
1492}
1493
1494//===----------------------------------------------------------------------===//
1495// spirv.module
1496//===----------------------------------------------------------------------===//
1497
1498void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1499 std::optional<StringRef> name) {
1500 OpBuilder::InsertionGuard guard(builder);
1501 builder.createBlock(state.addRegion());
1502 if (name) {
1504 builder.getStringAttr(*name));
1505 }
1506}
1507
1508void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1509 spirv::AddressingModel addressingModel,
1510 spirv::MemoryModel memoryModel,
1511 std::optional<VerCapExtAttr> vceTriple,
1512 std::optional<StringRef> name) {
1513 state.addAttribute(
1514 "addressing_model",
1515 builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1516 state.addAttribute("memory_model",
1517 builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1518 OpBuilder::InsertionGuard guard(builder);
1519 builder.createBlock(state.addRegion());
1520 if (vceTriple)
1521 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1522 if (name)
1524 builder.getStringAttr(*name));
1525}
1526
// Parses the custom form of `spirv.module`. The form consists of:
//  * an optional symbol name;
//  * the addressing model and the memory model;
//  * an optional `requires <VerCapExtAttr>` clause;
//  * an optional `attributes {...}` dictionary;
//  * the body region.
// NOTE(review): this listing lacks the line declaring `result` and the lines
// that invoke the optional symbol-name parser and the two enum parsers.
// Confirm those lines against the upstream source.
1527ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1529 Region *body = result.addRegion();
1530
1531 // If the name is present, parse it.
1532 StringAttr nameAttr;
1534 nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1535
1536 // Parse attributes
1537 spirv::AddressingModel addrModel;
1538 spirv::MemoryModel memoryModel;
1540 result) ||
1542 result))
1543 return failure();
1544
// Optional `requires` clause, stored under the VCE-triple attribute name.
1545 if (succeeded(parser.parseOptionalKeyword("requires"))) {
1546 spirv::VerCapExtAttr vceTriple;
1547 if (parser.parseAttribute(vceTriple,
1548 spirv::ModuleOp::getVCETripleAttrName(),
1549 result.attributes))
1550 return failure();
1551 }
1552
1553 if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1554 parser.parseRegion(*body, /*arguments=*/{}))
1555 return failure();
1556
1557 // Make sure we have at least one block.
// An empty `{}` body yields a region with no blocks, so one empty block is
// appended here.
1558 if (body->empty())
1559 body->push_back(new Block());
1560
1561 return success();
1562}
1563
1564void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1565 if (std::optional<StringRef> name = getName()) {
1566 printer << ' ';
1567 printer.printSymbolName(*name);
1568 }
1569
1570 SmallVector<StringRef, 2> elidedAttrs;
1571
1572 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1573 << spirv::stringifyMemoryModel(getMemoryModel());
1574 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1575 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1576 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1578
1579 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1580 printer << " requires " << *triple;
1581 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1582 }
1583
1584 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1585 printer << ' ';
1586 printer.printRegion(getRegion());
1587}
1588
1589LogicalResult spirv::ModuleOp::verifyRegions() {
1590 Dialect *dialect = (*this)->getDialect();
1592 entryPoints;
1593 mlir::SymbolTable table(*this);
1594
1595 for (auto &op : *getBody()) {
1596 if (op.getDialect() != dialect)
1597 return op.emitError("'spirv.module' can only contain spirv.* ops");
1598
1599 // For EntryPoint op, check that the function and execution model is not
1600 // duplicated in EntryPointOps. Also verify that the interface specified
1601 // comes from globalVariables here to make this check cheaper.
1602 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1603 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1604 if (!funcOp) {
1605 return entryPointOp.emitError("function '")
1606 << entryPointOp.getFn() << "' not found in 'spirv.module'";
1607 }
1608 if (auto interface = entryPointOp.getInterface()) {
1609 for (Attribute varRef : interface) {
1610 auto varSymRef = dyn_cast<FlatSymbolRefAttr>(varRef);
1611 if (!varSymRef) {
1612 return entryPointOp.emitError(
1613 "expected symbol reference for interface "
1614 "specification instead of '")
1615 << varRef;
1616 }
1617 auto variableOp =
1618 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1619 if (!variableOp) {
1620 return entryPointOp.emitError("expected spirv.GlobalVariable "
1621 "symbol reference instead of'")
1622 << varSymRef << "'";
1623 }
1624 }
1625 }
1626
1627 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1628 funcOp, entryPointOp.getExecutionModel());
1629 if (!entryPoints.try_emplace(key, entryPointOp).second)
1630 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1631 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1632 // If the function is external and does not have 'Import'
1633 // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1634 // LinkageAttributes is used to import external functions.
1635 auto linkageAttr = funcOp.getLinkageAttributes();
1636 auto hasImportLinkage =
1637 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1638 spirv::LinkageType::Import);
1639 if (funcOp.isExternal() && !hasImportLinkage)
1640 return op.emitError(
1641 "'spirv.module' cannot contain external functions "
1642 "without 'Import' linkage_attributes (LinkageAttributes)");
1643
1644 // TODO: move this check to spirv.func.
1645 for (auto &block : funcOp)
1646 for (auto &op : block) {
1647 if (op.getDialect() != dialect)
1648 return op.emitError(
1649 "functions in 'spirv.module' can only contain spirv.* ops");
1650 }
1651 }
1652 }
1653
1654 return success();
1655}
1656
1657//===----------------------------------------------------------------------===//
1658// spirv.mlir.referenceof
1659//===----------------------------------------------------------------------===//
1660
1661LogicalResult spirv::ReferenceOfOp::verify() {
1662 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1663 (*this)->getParentOp(), getSpecConstAttr());
1664 Type constType;
1665
1666 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1667 if (specConstOp)
1668 constType = specConstOp.getDefaultValue().getType();
1669
1670 auto specConstCompositeOp =
1671 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1672 if (specConstCompositeOp)
1673 constType = specConstCompositeOp.getType();
1674
1675 if (!specConstOp && !specConstCompositeOp)
1676 return emitOpError(
1677 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1678
1679 if (getReference().getType() != constType)
1680 return emitOpError("result type mismatch with the referenced "
1681 "specialization constant's type");
1682
1683 return success();
1684}
1685
1686//===----------------------------------------------------------------------===//
1687// spirv.SpecConstant
1688//===----------------------------------------------------------------------===//
1689
1690ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1692 StringAttr nameAttr;
1693 Attribute valueAttr;
1694 StringRef defaultValueAttrName =
1695 spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1696
1697 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1698 result.attributes))
1699 return failure();
1700
1701 // Parse optional spec_id.
1702 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1703 IntegerAttr specIdAttr;
1704 if (parser.parseLParen() ||
1705 parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1706 parser.parseRParen())
1707 return failure();
1708 }
1709
1710 if (parser.parseEqual() ||
1711 parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1712 return failure();
1713
1714 return success();
1715}
1716
1717void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
1718 printer << ' ';
1719 printer.printSymbolName(getSymName());
1720 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1721 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1722 printer << " = " << getDefaultValue();
1723}
1724
1725LogicalResult spirv::SpecConstantOp::verify() {
1726 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1727 if (specID.getValue().isNegative())
1728 return emitOpError("SpecId cannot be negative");
1729
1730 auto value = getDefaultValue();
1731 if (isa<IntegerAttr, FloatAttr>(value)) {
1732 // Make sure bitwidth is allowed.
1733 if (!isa<spirv::SPIRVType>(value.getType()))
1734 return emitOpError("default value bitwidth disallowed");
1735 return success();
1736 }
1737 return emitOpError(
1738 "default value can only be a bool, integer, or float scalar");
1739}
1740
1741//===----------------------------------------------------------------------===//
1742// spirv.VectorShuffle
1743//===----------------------------------------------------------------------===//
1744
1745LogicalResult spirv::VectorShuffleOp::verify() {
1746 VectorType resultType = cast<VectorType>(getType());
1747
1748 size_t numResultElements = resultType.getNumElements();
1749 if (numResultElements != getComponents().size())
1750 return emitOpError("result type element count (")
1751 << numResultElements
1752 << ") mismatch with the number of component selectors ("
1753 << getComponents().size() << ")";
1754
1755 size_t totalSrcElements =
1756 cast<VectorType>(getVector1().getType()).getNumElements() +
1757 cast<VectorType>(getVector2().getType()).getNumElements();
1758
1759 for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1760 uint32_t index = selector.getZExtValue();
1761 if (index >= totalSrcElements &&
1762 index != std::numeric_limits<uint32_t>().max())
1763 return emitOpError("component selector ")
1764 << index << " out of range: expected to be in [0, "
1765 << totalSrcElements << ") or 0xffffffff";
1766 }
1767 return success();
1768}
1769
1770//===----------------------------------------------------------------------===//
1771// spirv.SpecConstantComposite
1772//===----------------------------------------------------------------------===//
1773
1774ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
1776
1777 StringAttr compositeName;
1778 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1779 result.attributes))
1780 return failure();
1781
1782 if (parser.parseLParen())
1783 return failure();
1784
1785 SmallVector<Attribute, 4> constituents;
1786
1787 do {
1788 // The name of the constituent attribute isn't important
1789 const char *attrName = "spec_const";
1790 FlatSymbolRefAttr specConstRef;
1791 NamedAttrList attrs;
1792
1793 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1794 return failure();
1795
1796 constituents.push_back(specConstRef);
1797 } while (!parser.parseOptionalComma());
1798
1799 if (parser.parseRParen())
1800 return failure();
1801
1802 StringAttr compositeSpecConstituentsName =
1803 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1804 result.addAttribute(compositeSpecConstituentsName,
1805 parser.getBuilder().getArrayAttr(constituents));
1806
1807 Type type;
1808 if (parser.parseColonType(type))
1809 return failure();
1810
1811 StringAttr typeAttrName =
1812 spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1813 result.addAttribute(typeAttrName, TypeAttr::get(type));
1814
1815 return success();
1816}
1817
1818void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
1819 printer << " ";
1820 printer.printSymbolName(getSymName());
1821 printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1822 << ") : " << getType();
1823}
1824
1825LogicalResult spirv::SpecConstantCompositeOp::verify() {
1826 auto cType = dyn_cast<spirv::CompositeType>(getType());
1827 auto constituents = this->getConstituents().getValue();
1828
1829 if (!cType)
1830 return emitError("result type must be a composite type, but provided ")
1831 << getType();
1832
1833 if (isa<spirv::CooperativeMatrixType>(cType))
1834 return emitError("unsupported composite type ") << cType;
1835 if (constituents.size() != cType.getNumElements())
1836 return emitError("has incorrect number of operands: expected ")
1837 << cType.getNumElements() << ", but provided "
1838 << constituents.size();
1839
1840 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1841 auto constituent = cast<FlatSymbolRefAttr>(constituents[index]);
1842
1843 auto constituentSpecConstOp =
1844 dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1845 (*this)->getParentOp(), constituent.getAttr()));
1846
1847 if (constituentSpecConstOp.getDefaultValue().getType() !=
1848 cType.getElementType(index))
1849 return emitError("has incorrect types of operands: expected ")
1850 << cType.getElementType(index) << ", but provided "
1851 << constituentSpecConstOp.getDefaultValue().getType();
1852 }
1853
1854 return success();
1855}
1856
1857//===----------------------------------------------------------------------===//
1858// spirv.EXTSpecConstantCompositeReplicateOp
1859//===----------------------------------------------------------------------===//
1860
1861ParseResult
1862spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
1864 StringAttr compositeName;
1865 FlatSymbolRefAttr specConstRef;
1866 const char *attrName = "spec_const";
1867 NamedAttrList attrs;
1868 Type type;
1869
1870 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1871 result.attributes) ||
1872 parser.parseLParen() ||
1873 parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
1874 parser.parseRParen() || parser.parseColonType(type))
1875 return failure();
1876
1877 StringAttr compositeSpecConstituentName =
1878 spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1879 result.name);
1880 result.addAttribute(compositeSpecConstituentName, specConstRef);
1881
1882 StringAttr typeAttrName =
1883 spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
1884 result.addAttribute(typeAttrName, TypeAttr::get(type));
1885
1886 return success();
1887}
1888
1889void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
1890 printer << " ";
1891 printer.printSymbolName(getSymName());
1892 printer << " (" << this->getConstituent() << ") : " << getType();
1893}
1894
1895LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
1896 auto compositeType = dyn_cast<spirv::CompositeType>(getType());
1897 if (!compositeType)
1898 return emitError("result type must be a composite type, but provided ")
1899 << getType();
1900
1902 (*this)->getParentOp(), this->getConstituent());
1903 if (!constituentOp)
1904 return emitError(
1905 "splat spec constant reference defining constituent not found");
1906
1907 auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
1908 if (!constituentSpecConstOp)
1909 return emitError("constituent is not a spec constant");
1910
1911 Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
1912 Type compositeElementType = compositeType.getElementType(0);
1913 if (constituentType != compositeElementType)
1914 return emitError("constituent has incorrect type: expected ")
1915 << compositeElementType << ", but provided " << constituentType;
1916
1917 return success();
1918}
1919
1920//===----------------------------------------------------------------------===//
1921// spirv.SpecConstantOperation
1922//===----------------------------------------------------------------------===//
1923
1924ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
1926 Region *body = result.addRegion();
1927
1928 if (parser.parseKeyword("wraps"))
1929 return failure();
1930
1931 body->push_back(new Block);
1932 Block &block = body->back();
1933 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1934
1935 if (!wrappedOp)
1936 return failure();
1937
1938 OpBuilder builder(parser.getContext());
1939 builder.setInsertionPointToEnd(&block);
1940 spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0));
1941 result.location = wrappedOp->getLoc();
1942
1943 result.addTypes(wrappedOp->getResult(0).getType());
1944
1945 if (parser.parseOptionalAttrDict(result.attributes))
1946 return failure();
1947
1948 return success();
1949}
1950
1951void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
1952 printer << " wraps ";
1953 printer.printGenericOp(&getBody().front().front());
1954}
1955
1956LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1957 Block &block = getRegion().getBlocks().front();
1958
1959 if (block.getOperations().size() != 2)
1960 return emitOpError("expected exactly 2 nested ops");
1961
1962 Operation &enclosedOp = block.getOperations().front();
1963
1965 return emitOpError("invalid enclosed op");
1966
1967 for (auto operand : enclosedOp.getOperands())
1968 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1969 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1970 return emitOpError(
1971 "invalid operand, must be defined by a constant operation");
1972
1973 return success();
1974}
1975
1976//===----------------------------------------------------------------------===//
1977// spirv.GL.FrexpStruct
1978//===----------------------------------------------------------------------===//
1979
1980LogicalResult spirv::GLFrexpStructOp::verify() {
1981 spirv::StructType structTy =
1982 dyn_cast<spirv::StructType>(getResult().getType());
1983
1984 if (structTy.getNumElements() != 2)
1985 return emitError("result type must be a struct type with two memebers");
1986
1987 Type significandTy = structTy.getElementType(0);
1988 Type exponentTy = structTy.getElementType(1);
1989 VectorType exponentVecTy = dyn_cast<VectorType>(exponentTy);
1990 IntegerType exponentIntTy = dyn_cast<IntegerType>(exponentTy);
1991
1992 Type operandTy = getOperand().getType();
1993 VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
1994 FloatType operandFTy = dyn_cast<FloatType>(operandTy);
1995
1996 if (significandTy != operandTy)
1997 return emitError("member zero of the resulting struct type must be the "
1998 "same type as the operand");
1999
2000 if (exponentVecTy) {
2001 IntegerType componentIntTy =
2002 dyn_cast<IntegerType>(exponentVecTy.getElementType());
2003 if (!componentIntTy || componentIntTy.getWidth() != 32)
2004 return emitError("member one of the resulting struct type must"
2005 "be a scalar or vector of 32 bit integer type");
2006 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2007 return emitError("member one of the resulting struct type "
2008 "must be a scalar or vector of 32 bit integer type");
2009 }
2010
2011 // Check that the two member types have the same number of components
2012 if (operandVecTy && exponentVecTy &&
2013 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2014 return success();
2015
2016 if (operandFTy && exponentIntTy)
2017 return success();
2018
2019 return emitError("member one of the resulting struct type must have the same "
2020 "number of components as the operand type");
2021}
2022
2023//===----------------------------------------------------------------------===//
2024// spirv.GL.Ldexp
2025//===----------------------------------------------------------------------===//
2026
2027LogicalResult spirv::GLLdexpOp::verify() {
2028 Type significandType = getX().getType();
2029 Type exponentType = getExp().getType();
2030
2031 if (isa<FloatType>(significandType) != isa<IntegerType>(exponentType))
2032 return emitOpError("operands must both be scalars or vectors");
2033
2034 auto getNumElements = [](Type type) -> unsigned {
2035 if (auto vectorType = dyn_cast<VectorType>(type))
2036 return vectorType.getNumElements();
2037 return 1;
2038 };
2039
2040 if (getNumElements(significandType) != getNumElements(exponentType))
2041 return emitOpError("operands must have the same number of elements");
2042
2043 return success();
2044}
2045
2046//===----------------------------------------------------------------------===//
2047// spirv.ShiftLeftLogicalOp
2048//===----------------------------------------------------------------------===//
2049
2050LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2051 return verifyShiftOp(*this);
2052}
2053
2054//===----------------------------------------------------------------------===//
2055// spirv.ShiftRightArithmeticOp
2056//===----------------------------------------------------------------------===//
2057
2058LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2059 return verifyShiftOp(*this);
2060}
2061
2062//===----------------------------------------------------------------------===//
2063// spirv.ShiftRightLogicalOp
2064//===----------------------------------------------------------------------===//
2065
2066LogicalResult spirv::ShiftRightLogicalOp::verify() {
2067 return verifyShiftOp(*this);
2068}
2069
2070//===----------------------------------------------------------------------===//
2071// spirv.VectorTimesScalarOp
2072//===----------------------------------------------------------------------===//
2073
2074LogicalResult spirv::VectorTimesScalarOp::verify() {
2075 if (getVector().getType() != getType())
2076 return emitOpError("vector operand and result type mismatch");
2077 auto scalarType = cast<VectorType>(getType()).getElementType();
2078 if (getScalar().getType() != scalarType)
2079 return emitOpError("scalar operand and result element type match");
2080 return success();
2081}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
ArrayAttr()
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition SPIRVOps.cpp:266
static Type getValueType(Attribute attr)
Definition SPIRVOps.cpp:773
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition SPIRVOps.cpp:562
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition SPIRVOps.cpp:120
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition SPIRVOps.cpp:252
static LogicalResult verifyShiftOp(Operation *op)
Definition SPIRVOps.cpp:298
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition SPIRVOps.cpp:169
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition SPIRVOps.cpp:185
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition SPIRVOps.cpp:149
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition SPIRVOps.cpp:290
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional -identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
iterator begin()
Definition Block.h:153
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:280
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:238
Value getOperand(unsigned idx)
Definition Operation.h:376
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:576
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
Definition Operation.h:423
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
void push_back(Block *block)
Definition Region.h:61
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
Type front()
Return first type in the range.
Definition TypeRange.h:164
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
SPIR-V struct type.
Definition SPIRVTypes.h:261
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition SPIRVOps.cpp:69
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition SPIRVOps.cpp:93
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition SPIRVOps.cpp:49
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.