MLIR 22.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
24#include "mlir/IR/Builders.h"
28#include "mlir/IR/Operation.h"
31#include "llvm/ADT/APFloat.h"
32#include "llvm/ADT/APInt.h"
33#include "llvm/ADT/ArrayRef.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/StringExtras.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/InterleavedRange.h"
38#include <cassert>
39#include <numeric>
40#include <optional>
41
42using namespace mlir;
43using namespace mlir::spirv::AttrNames;
44
45//===----------------------------------------------------------------------===//
46// Common utility functions
47//===----------------------------------------------------------------------===//
48
49LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
50 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
51 if (!constOp) {
52 return failure();
53 }
54 auto valueAttr = constOp.getValue();
55 auto integerValueAttr = dyn_cast<IntegerAttr>(valueAttr);
56 if (!integerValueAttr) {
57 return failure();
58 }
59
60 if (integerValueAttr.getType().isSignlessInteger())
61 value = integerValueAttr.getInt();
62 else
63 value = integerValueAttr.getSInt();
64
65 return success();
66}
67
// Verifies the memory-semantics mask used by `op`: at most one of the Acquire,
// Release, AcquireRelease and SequentiallyConsistent bits may be set. Emits an
// error on `op` and returns failure otherwise.
// NOTE(review): the line carrying this function's name and its `op` parameter
// is not present in this listing.
68LogicalResult
70 spirv::MemorySemantics memorySemantics) {
71 // According to the SPIR-V specification:
72 // "Despite being a mask and allowing multiple bits to be combined, it is
73 // invalid for more than one of these four bits to be set: Acquire, Release,
74 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
75 // Release semantics is done by setting the AcquireRelease bit, not by setting
76 // two bits."
77 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
78 spirv::MemorySemantics::Release |
79 spirv::MemorySemantics::AcquireRelease |
80 spirv::MemorySemantics::SequentiallyConsistent;
81
// Keep only the four mutually exclusive ordering bits and count them.
82 auto bitCount =
83 llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
84 if (bitCount > 1) {
// NOTE(review): the adjacent string literals below concatenate to
// "...`Release`,`AcquireRelease`..." with no space after the comma.
85 return op->emitError(
86 "expected at most one of these four memory constraints "
87 "to be set: `Acquire`, `Release`,"
88 "`AcquireRelease` or `SequentiallyConsistent`");
89 }
90 return success();
91}
92
94 SmallVectorImpl<StringRef> &elidedAttrs) {
// Prints the decorations stored as attributes on `op` in custom form:
// ` bind(<set>, <binding>)` when both the descriptor-set and binding integer
// attributes are present, ` <built-in-name>("<value>")` for a BuiltIn string
// attribute (names are the snake_case spelling of the decoration), and then
// every remaining attribute as an optional attribute dictionary. Attributes
// printed in custom form are appended to `elidedAttrs`.
// NOTE(review): the three attribute names below are local std::strings, so
// the StringRefs pushed into the caller-owned `elidedAttrs` appear to dangle
// once this function returns. Callers must not read `elidedAttrs` afterwards.
95 // Print optional descriptor binding
96 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
97 stringifyDecoration(spirv::Decoration::DescriptorSet));
98 auto bindingName = llvm::convertToSnakeFromCamelCase(
99 stringifyDecoration(spirv::Decoration::Binding));
100 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
101 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
102 if (descriptorSet && binding) {
103 elidedAttrs.push_back(descriptorSetName);
104 elidedAttrs.push_back(bindingName);
105 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
106 << ")";
107 }
108
109 // Print BuiltIn attribute if present
110 auto builtInName = llvm::convertToSnakeFromCamelCase(
111 stringifyDecoration(spirv::Decoration::BuiltIn));
112 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
113 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
114 elidedAttrs.push_back(builtInName);
115 }
116
117 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
118}
119
// Parses a one-result op in either of two forms:
//  - generic: `(operands) attr-dict : function-type`. The function type
//    supplies the operand types and the result types.
//  - compact: `operands attr-dict : type`. The single type is used for every
//    operand and for the one result.
// NOTE(review): this function's signature and the declaration of `ops` are
// not present in this listing.
123 Type type;
124 // If the operand list is in-between parentheses, then we have a generic form.
125 // (see the fallback in `printOneResultOp`).
126 SMLoc loc = parser.getCurrentLocation();
127 if (!parser.parseOptionalLParen()) {
128 if (parser.parseOperandList(ops) || parser.parseRParen() ||
129 parser.parseOptionalAttrDict(result.attributes) ||
130 parser.parseColon() || parser.parseType(type))
131 return failure();
132 auto fnType = dyn_cast<FunctionType>(type);
133 if (!fnType) {
134 parser.emitError(loc, "expected function type");
135 return failure();
136 }
137 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
138 return failure();
139 result.addTypes(fnType.getResults());
140 return success();
141 }
// Compact form: one shared type for all operands and the result.
142 return failure(parser.parseOperandList(ops) ||
143 parser.parseOptionalAttrDict(result.attributes) ||
144 parser.parseColonType(type) ||
145 parser.resolveOperands(ops, type, result.operands) ||
146 parser.addTypeToList(type, result.types));
147}
148
// Prints a one-result op. When every operand type equals the result type, the
// operands are followed by that single type. Otherwise the generic form
// (without the op name) is printed so no type information is lost.
150 assert(op->getNumResults() == 1 && "op should have one result");
151
152 // If not all the operand and result types are the same, just use the
153 // generic assembly form to avoid omitting information in printing.
154 auto resultType = op->getResult(0).getType();
155 if (llvm::any_of(op->getOperandTypes(),
156 [&](Type type) { return type != resultType; })) {
157 p.printGenericOp(op, /*printOpName=*/false);
158 return;
159 }
160
161 p << ' ';
162 p.printOperands(op->getOperands());
// NOTE(review): a source line is missing here in this listing. The matching
// parser accepts an optional attribute dictionary after the operands, so
// confirm that attributes are printed at this point.
164 // Now we can output only one type for all operands and the result.
165 p << " : " << resultType;
166}
167
168template <typename BlockReadWriteOpTy>
169static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
170 Value ptr, Value val) {
171 auto valType = val.getType();
172 if (auto valVecTy = dyn_cast<VectorType>(valType))
173 valType = valVecTy.getElementType();
174
175 if (valType != cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
176 return op.emitOpError("mismatch in result type and pointer type");
177 }
178 return success();
179}
180
181/// Walks the given type hierarchy with the given indices, potentially down
182/// to component granularity, to select an element type. Returns null type and
183/// emits errors through `emitErrorFn` on failure.
/// Each index is bounds-checked whenever the current composite type has a
/// compile-time known number of elements.
/// NOTE(review): the "expected at least one index" diagnostic names
/// spirv.CompositeExtract even when this is reached from another op's verifier
/// (e.g. spirv.CompositeInsert).
184static Type
186 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
187 if (indices.empty()) {
188 emitErrorFn("expected at least one index for spirv.CompositeExtract");
189 return nullptr;
190 }
191
192 for (auto index : indices) {
193 if (auto cType = dyn_cast<spirv::CompositeType>(type)) {
194 if (cType.hasCompileTimeKnownNumElements() &&
195 (index < 0 ||
196 static_cast<uint64_t>(index) >= cType.getNumElements())) {
197 emitErrorFn("index ") << index << " out of bounds for " << type;
198 return nullptr;
199 }
// Descend one level into the composite.
200 type = cType.getElementType(index);
201 } else {
202 emitErrorFn("cannot extract from non-composite type ")
203 << type << " with index " << index;
204 return nullptr;
205 }
206 }
207 return type;
208}
209
// Converts an `indices` attribute into 32-bit index values and forwards them
// to the ArrayRef-based overload. Emits an error through `emitErrorFn` and
// returns a null type unless `indices` is a non-empty ArrayAttr whose elements
// are all IntegerAttrs.
// NOTE(review): `dyn_cast` asserts on a null attribute, so callers presumably
// must pass a non-null `indices`.
210static Type
212 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
213 auto indicesArrayAttr = dyn_cast<ArrayAttr>(indices);
214 if (!indicesArrayAttr) {
215 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
216 return nullptr;
217 }
218 if (indicesArrayAttr.empty()) {
219 emitErrorFn("expected at least one index for spirv.CompositeExtract");
220 return nullptr;
221 }
222
223 SmallVector<int32_t, 2> indexVals;
224 for (auto indexAttr : indicesArrayAttr) {
225 auto indexIntAttr = dyn_cast<IntegerAttr>(indexAttr);
226 if (!indexIntAttr) {
227 emitErrorFn("expected an 32-bit integer for index, but found '")
228 << indexAttr << "'";
229 return nullptr;
230 }
// NOTE(review): IntegerAttr::getInt() expects a signless (or index) type,
// and the value is narrowed to int32_t without checking its bit width.
231 indexVals.push_back(indexIntAttr.getInt());
232 }
233 return getElementType(type, indexVals, emitErrorFn);
234}
235
// Location-based overload: diagnostics are reported with mlir::emitError at
// `loc`. The function's signature line is not present in this listing.
237 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
238 return ::mlir::emitError(loc, err);
239 };
240 return getElementType(type, indices, errorFn);
241}
242
244 SMLoc loc) {
// Parser-facing overload: diagnostics are reported with `parser.emitError` at
// the source location `loc`.
245 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
246 return parser.emitError(loc, err);
247 };
248 return getElementType(type, indices, errorFn);
249}
250
251template <typename ExtendedBinaryOp>
252static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
253 auto resultType = cast<spirv::StructType>(op.getType());
254 if (resultType.getNumElements() != 2)
255 return op.emitOpError("expected result struct type containing two members");
256
257 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
258 resultType.getElementType(0),
259 resultType.getElementType(1)}))
260 return op.emitOpError(
261 "expected all operand types and struct member types are the same");
262
263 return success();
264}
265
// Parses `attr-dict operands : struct-type` for an extended arithmetic binary
// op. The result type must be a two-member spirv.struct. Both operands are
// resolved with the struct's first member type; full type agreement is
// presumably left to verifyArithmeticExtendedBinaryOp.
// NOTE(review): this function's signature and the declaration of `operands`
// are not present in this listing.
269 if (parser.parseOptionalAttrDict(result.attributes) ||
270 parser.parseOperandList(operands) || parser.parseColon())
271 return failure();
272
273 Type resultType;
274 SMLoc loc = parser.getCurrentLocation();
275 if (parser.parseType(resultType))
276 return failure();
277
278 auto structType = dyn_cast<spirv::StructType>(resultType);
279 if (!structType || structType.getNumElements() != 2)
280 return parser.emitError(loc, "expected spirv.struct type with two members");
281
282 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
283 if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
284 return failure();
285
286 result.addTypes(resultType);
287 return success();
288}
289
291 OpAsmPrinter &printer) {
// Prints `attr-dict operands : result-type`, the same order in which the
// extended binary op parser reads them.
292 printer << ' ';
293 printer.printOptionalAttrDict(op->getAttrs());
294 printer.printOperands(op->getOperands());
295 printer << " : " << op->getResultTypes().front();
296}
297
298static LogicalResult verifyShiftOp(Operation *op) {
299 if (op->getOperand(0).getType() != op->getResult(0).getType()) {
300 return op->emitError("expected the same type for the first operand and "
301 "result, but provided ")
302 << op->getOperand(0).getType() << " and "
303 << op->getResult(0).getType();
304 }
305 return success();
306}
307
308//===----------------------------------------------------------------------===//
309// spirv.mlir.addressof
310//===----------------------------------------------------------------------===//
311
312void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
313 spirv::GlobalVariableOp var) {
314 build(builder, state, var.getType(), SymbolRefAttr::get(var));
315}
316
317LogicalResult spirv::AddressOfOp::verify() {
318 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
319 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
320 getVariableAttr()));
321 if (!varOp) {
322 return emitOpError("expected spirv.GlobalVariable symbol");
323 }
324 if (getPointer().getType() != varOp.getType()) {
325 return emitOpError(
326 "result type mismatch with the referenced global variable's type");
327 }
328 return success();
329}
330
331//===----------------------------------------------------------------------===//
332// spirv.CompositeConstruct
333//===----------------------------------------------------------------------===//
334
// Verifies that the constituents of spirv.CompositeConstruct are consistent
// with its composite result type.
335LogicalResult spirv::CompositeConstructOp::verify() {
336 operand_range constituents = this->getConstituents();
337
338 // There are 4 cases with varying verification rules:
339 // 1. Cooperative Matrices (1 constituent)
340 // 2. Structs (1 constituent for each member)
341 // 3. Arrays (1 constituent for each array element)
342 // 4. Vectors (1 constituent (sub-)element for each vector element)
343
// Element type of a cooperative-matrix result, or null for other results.
// NOTE(review): part of this TypeSwitch expression is not present in this
// listing.
344 auto coopElementType =
347 [](auto coopType) { return coopType.getElementType(); })
348 .Default(nullptr);
349
350 // Case 1. -- matrices.
351 if (coopElementType) {
352 if (constituents.size() != 1)
353 return emitOpError("has incorrect number of operands: expected ")
354 << "1, but provided " << constituents.size();
355 if (coopElementType != constituents.front().getType())
356 return emitOpError("operand type mismatch: expected operand type ")
357 << coopElementType << ", but provided "
358 << constituents.front().getType();
359 return success();
360 }
361
362 // Case 2./3./4. -- number of constituents matches the number of elements.
363 auto cType = cast<spirv::CompositeType>(getType());
364 if (constituents.size() == cType.getNumElements()) {
365 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
366 if (constituents[index].getType() != cType.getElementType(index)) {
367 return emitOpError("operand type mismatch: expected operand type ")
368 << cType.getElementType(index) << ", but provided "
369 << constituents[index].getType();
370 }
371 }
372 return success();
373 }
374
375 // Case 4. -- check that all constituents add up to the expected vector type.
// NOTE(review): this point is also reached when there are MORE constituents
// than result elements, although the message below only mentions "less".
376 auto resultType = dyn_cast<VectorType>(cType);
377 if (!resultType)
378 return emitOpError(
379 "expected to return a vector or cooperative matrix when the number of "
380 "constituents is less than what the result needs");
381
// Each constituent is a scalar or vector of the result's element type; its
// element count is recorded in `sizes` and the total compared afterwards.
// NOTE(review): the declaration of `sizes` is not present in this listing.
383 for (Value component : constituents) {
384 if (!isa<VectorType>(component.getType()) &&
385 !component.getType().isIntOrFloat())
386 return emitOpError("operand type mismatch: expected operand to have "
387 "a scalar or vector type, but provided ")
388 << component.getType();
389
390 Type elementType = component.getType();
391 if (auto vectorType = dyn_cast<VectorType>(component.getType())) {
392 sizes.push_back(vectorType.getNumElements());
393 elementType = vectorType.getElementType();
394 } else {
395 sizes.push_back(1);
396 }
397
398 if (elementType != resultType.getElementType())
399 return emitOpError("operand element type mismatch: expected to be ")
400 << resultType.getElementType() << ", but provided " << elementType;
401 }
402 unsigned totalCount = llvm::sum_of(sizes);
403 if (totalCount != cType.getNumElements())
404 return emitOpError("has incorrect number of operands: expected ")
405 << cType.getNumElements() << ", but provided " << totalCount;
406 return success();
407}
408
409//===----------------------------------------------------------------------===//
410// spirv.CompositeExtractOp
411//===----------------------------------------------------------------------===//
412
// Builds a spirv.CompositeExtract whose result type is derived from the type
// of `composite` and the given `indices`.
// NOTE(review): if the indices are invalid, an error is emitted at
// `state.location` and this returns without populating `state`, so the caller
// ends up with an incompletely built op.
413void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
414 Value composite,
416 auto indexAttr = builder.getI32ArrayAttr(indices);
417 auto elementType =
418 getElementType(composite.getType(), indexAttr, state.location);
419 if (!elementType) {
420 return;
421 }
422 build(builder, state, elementType, composite, indexAttr);
423}
424
// Parses `%composite indices : composite-type`, where `indices` is an array
// attribute. The result type is not spelled out; it is computed from the
// composite type and the indices.
425ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
427 OpAsmParser::UnresolvedOperand compositeInfo;
428 Attribute indicesAttr;
429 StringRef indicesAttrName =
430 spirv::CompositeExtractOp::getIndicesAttrName(result.name);
431 Type compositeType;
432 SMLoc attrLocation;
433
// Remember where the indices start so index errors point at them.
434 if (parser.parseOperand(compositeInfo) ||
435 parser.getCurrentLocation(&attrLocation) ||
436 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
437 parser.parseColonType(compositeType) ||
438 parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
439 return failure();
440 }
441
442 Type resultType =
443 getElementType(compositeType, indicesAttr, parser, attrLocation);
444 if (!resultType) {
445 return failure();
446 }
447 result.addTypes(resultType);
448 return success();
449}
450
451void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
452 printer << ' ' << getComposite() << getIndices() << " : "
453 << getComposite().getType();
454}
455
456LogicalResult spirv::CompositeExtractOp::verify() {
457 auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
458 auto resultType =
459 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
460 if (!resultType)
461 return failure();
462
463 if (resultType != getType()) {
464 return emitOpError("invalid result type: expected ")
465 << resultType << " but provided " << getType();
466 }
467
468 return success();
469}
470
471//===----------------------------------------------------------------------===//
472// spirv.CompositeInsert
473//===----------------------------------------------------------------------===//
474
// Builds a spirv.CompositeInsert whose result type is the type of `composite`;
// `indices` is stored as an i32 array attribute.
475void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
476 Value object, Value composite,
478 auto indexAttr = builder.getI32ArrayAttr(indices);
479 build(builder, state, composite.getType(), object, composite, indexAttr);
480}
481
// Parses `%object, %composite indices : object-type into composite-type`.
// The result takes the composite type.
// NOTE(review): the declaration of `operands` is not present in this listing.
482ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
485 Type objectType, compositeType;
486 Attribute indicesAttr;
487 StringRef indicesAttrName =
488 spirv::CompositeInsertOp::getIndicesAttrName(result.name);
489 auto loc = parser.getCurrentLocation();
490
491 return failure(
492 parser.parseOperandList(operands, 2) ||
493 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
494 parser.parseColonType(objectType) ||
495 parser.parseKeywordType("into", compositeType) ||
496 parser.resolveOperands(operands, {objectType, compositeType}, loc,
497 result.operands) ||
498 parser.addTypesToList(compositeType, result.types));
499}
500
501LogicalResult spirv::CompositeInsertOp::verify() {
502 auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
503 auto objectType =
504 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
505 if (!objectType)
506 return failure();
507
508 if (objectType != getObject().getType()) {
509 return emitOpError("object operand type should be ")
510 << objectType << ", but found " << getObject().getType();
511 }
512
513 if (getComposite().getType() != getType()) {
514 return emitOpError("result type should be the same as "
515 "the composite type, but found ")
516 << getComposite().getType() << " vs " << getType();
517 }
518
519 return success();
520}
521
522void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
523 printer << " " << getObject() << ", " << getComposite() << getIndices()
524 << " : " << getObject().getType() << " into "
525 << getComposite().getType();
526}
527
528//===----------------------------------------------------------------------===//
529// spirv.Constant
530//===----------------------------------------------------------------------===//
531
// Parses a constant's value attribute, optionally followed by `: type`.
// - The result type defaults to the attribute's own type.
// - An untyped attribute (NoneType) or a tensor-typed one must be followed by
//   an explicit `: type`.
// - A TensorArmType value may optionally be followed by one.
532ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
534 Attribute value;
535 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
536 if (parser.parseAttribute(value, valueAttrName, result.attributes))
537 return failure();
538
539 Type type = NoneType::get(parser.getContext());
540 if (auto typedAttr = dyn_cast<TypedAttr>(value))
541 type = typedAttr.getType();
542 if (isa<NoneType, TensorType>(type)) {
543 if (parser.parseColonType(type))
544 return failure();
545 }
546
547 if (isa<TensorArmType>(type)) {
548 if (parser.parseOptionalColon().succeeded())
549 if (parser.parseType(type))
550 return failure();
551 }
552
553 return parser.addTypeToList(type, result.types);
554}
555
556void spirv::ConstantOp::print(OpAsmPrinter &printer) {
557 printer << ' ' << getValue();
558 if (isa<spirv::ArrayType>(getType()))
559 printer << " : " << getType();
560}
561
562static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
563 Type opType) {
564 if (isa<spirv::CooperativeMatrixType>(opType)) {
565 auto denseAttr = dyn_cast<DenseElementsAttr>(value);
566 if (!denseAttr || !denseAttr.isSplat())
567 return op.emitOpError("expected a splat dense attribute for cooperative "
568 "matrix constant, but found ")
569 << denseAttr;
570 }
571 if (isa<IntegerAttr, FloatAttr>(value)) {
572 auto valueType = cast<TypedAttr>(value).getType();
573 if (valueType != opType)
574 return op.emitOpError("result type (")
575 << opType << ") does not match value type (" << valueType << ")";
576 return success();
577 }
578 if (isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
579 auto valueType = cast<TypedAttr>(value).getType();
580 if (valueType == opType)
581 return success();
582 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
583 auto shapedType = dyn_cast<ShapedType>(valueType);
584 if (!arrayType)
585 return op.emitOpError("result or element type (")
586 << opType << ") does not match value type (" << valueType
587 << "), must be the same or spirv.array";
588
589 int numElements = arrayType.getNumElements();
590 auto opElemType = arrayType.getElementType();
591 while (auto t = dyn_cast<spirv::ArrayType>(opElemType)) {
592 numElements *= t.getNumElements();
593 opElemType = t.getElementType();
594 }
595 if (!opElemType.isIntOrFloat())
596 return op.emitOpError("only support nested array result type");
597
598 auto valueElemType = shapedType.getElementType();
599 if (valueElemType != opElemType) {
600 return op.emitOpError("result element type (")
601 << opElemType << ") does not match value element type ("
602 << valueElemType << ")";
603 }
604
605 if (numElements != shapedType.getNumElements()) {
606 return op.emitOpError("result number of elements (")
607 << numElements << ") does not match value number of elements ("
608 << shapedType.getNumElements() << ")";
609 }
610 return success();
611 }
612 if (auto arrayAttr = dyn_cast<ArrayAttr>(value)) {
613 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
614 if (!arrayType)
615 return op.emitOpError(
616 "must have spirv.array result type for array value");
617 Type elemType = arrayType.getElementType();
618 for (Attribute element : arrayAttr.getValue()) {
619 // Verify array elements recursively.
620 if (failed(verifyConstantType(op, element, elemType)))
621 return failure();
622 }
623 return success();
624 }
625 return op.emitOpError("cannot have attribute: ") << value;
626}
627
628LogicalResult spirv::ConstantOp::verify() {
629 // ODS already generates checks to make sure the result type is valid. We just
630 // need to additionally check that the value's attribute type is consistent
631 // with the result type.
632 return verifyConstantType(*this, getValueAttr(), getType());
633}
634
635bool spirv::ConstantOp::isBuildableWith(Type type) {
636 // Must be valid SPIR-V type first.
637 if (!isa<spirv::SPIRVType>(type))
638 return false;
639
640 if (isa<SPIRVDialect>(type.getDialect())) {
641 // TODO: support constant struct
642 return isa<spirv::ArrayType>(type);
643 }
644
645 return true;
646}
647
648spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
649 OpBuilder &builder) {
650 if (auto intType = dyn_cast<IntegerType>(type)) {
651 unsigned width = intType.getWidth();
652 if (width == 1)
653 return spirv::ConstantOp::create(builder, loc, type,
654 builder.getBoolAttr(false));
655 return spirv::ConstantOp::create(
656 builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
657 }
658 if (auto floatType = dyn_cast<FloatType>(type)) {
659 return spirv::ConstantOp::create(builder, loc, type,
660 builder.getFloatAttr(floatType, 0.0));
661 }
662 if (auto vectorType = dyn_cast<VectorType>(type)) {
663 Type elemType = vectorType.getElementType();
664 if (isa<IntegerType>(elemType)) {
665 return spirv::ConstantOp::create(
666 builder, loc, type,
667 DenseElementsAttr::get(vectorType,
668 IntegerAttr::get(elemType, 0).getValue()));
669 }
670 if (isa<FloatType>(elemType)) {
671 return spirv::ConstantOp::create(
672 builder, loc, type,
673 DenseFPElementsAttr::get(vectorType,
674 FloatAttr::get(elemType, 0.0).getValue()));
675 }
676 }
677
678 llvm_unreachable("unimplemented types for ConstantOp::getZero()");
679}
680
681spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
682 OpBuilder &builder) {
683 if (auto intType = dyn_cast<IntegerType>(type)) {
684 unsigned width = intType.getWidth();
685 if (width == 1)
686 return spirv::ConstantOp::create(builder, loc, type,
687 builder.getBoolAttr(true));
688 return spirv::ConstantOp::create(
689 builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
690 }
691 if (auto floatType = dyn_cast<FloatType>(type)) {
692 return spirv::ConstantOp::create(builder, loc, type,
693 builder.getFloatAttr(floatType, 1.0));
694 }
695 if (auto vectorType = dyn_cast<VectorType>(type)) {
696 Type elemType = vectorType.getElementType();
697 if (isa<IntegerType>(elemType)) {
698 return spirv::ConstantOp::create(
699 builder, loc, type,
700 DenseElementsAttr::get(vectorType,
701 IntegerAttr::get(elemType, 1).getValue()));
702 }
703 if (isa<FloatType>(elemType)) {
704 return spirv::ConstantOp::create(
705 builder, loc, type,
706 DenseFPElementsAttr::get(vectorType,
707 FloatAttr::get(elemType, 1.0).getValue()));
708 }
709 }
710
711 llvm_unreachable("unimplemented types for ConstantOp::getOne()");
712}
713
714void mlir::spirv::ConstantOp::getAsmResultNames(
715 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
716 Type type = getType();
717
718 SmallString<32> specialNameBuffer;
719 llvm::raw_svector_ostream specialName(specialNameBuffer);
720 specialName << "cst";
721
722 IntegerType intTy = dyn_cast<IntegerType>(type);
723
724 if (IntegerAttr intCst = dyn_cast<IntegerAttr>(getValue())) {
725 assert(intTy);
726
727 if (intTy.getWidth() == 1) {
728 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
729 }
730
731 if (intTy.isSignless()) {
732 specialName << intCst.getInt();
733 } else if (intTy.isUnsigned()) {
734 specialName << intCst.getUInt();
735 } else {
736 specialName << intCst.getSInt();
737 }
738 }
739
740 if (intTy || isa<FloatType>(type)) {
741 specialName << '_' << type;
742 }
743
744 if (auto vecType = dyn_cast<VectorType>(type)) {
745 specialName << "_vec_";
746 specialName << vecType.getDimSize(0);
747
748 Type elementType = vecType.getElementType();
749
750 if (isa<IntegerType>(elementType) || isa<FloatType>(elementType)) {
751 specialName << "x" << elementType;
752 }
753 }
754
755 setNameFn(getResult(), specialName.str());
756}
757
758void mlir::spirv::AddressOfOp::getAsmResultNames(
759 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
760 SmallString<32> specialNameBuffer;
761 llvm::raw_svector_ostream specialName(specialNameBuffer);
762 specialName << getVariable() << "_addr";
763 setNameFn(getResult(), specialName.str());
764}
765
766//===----------------------------------------------------------------------===//
767// spirv.EXTConstantCompositeReplicate
768//===----------------------------------------------------------------------===//
769
770// Returns type of attribute. In case of a TypedAttr this will simply return
771// the type. But for an ArrayAttr which is untyped and can be multidimensional
772// it creates the ArrayType recursively.
// Any other attribute kind yields a null type.
// NOTE(review): only the first element determines the array's element type.
// `arrayAttr[0]` is out of bounds for an empty ArrayAttr, and a null element
// type would be passed straight to spirv::ArrayType::get.
// NOTE(review): this function's signature line is not present in this listing.
774 if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
775 return typedAttr.getType();
776 }
777
778 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
779 return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
780 }
781
782 return nullptr;
783}
784
785LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
786 Type valueType = getValueType(getValue());
787 if (!valueType)
788 return emitError("unknown value attribute type");
789
790 auto compositeType = dyn_cast<spirv::CompositeType>(getType());
791 if (!compositeType)
792 return emitError("result type is not a composite type");
793
794 Type compositeElementType = compositeType.getElementType(0);
795
796 SmallVector<Type, 3> possibleTypes = {compositeElementType};
797 while (auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
798 compositeElementType = type.getElementType(0);
799 possibleTypes.push_back(compositeElementType);
800 }
801
802 if (!is_contained(possibleTypes, valueType)) {
803 return emitError("expected value attribute type ")
804 << interleaved(possibleTypes, " or ") << ", but got: " << valueType;
805 }
806
807 return success();
808}
809
810//===----------------------------------------------------------------------===//
811// spirv.ControlBarrierOp
812//===----------------------------------------------------------------------===//
813
814LogicalResult spirv::ControlBarrierOp::verify() {
815 return verifyMemorySemantics(getOperation(), getMemorySemantics());
816}
817
818//===----------------------------------------------------------------------===//
819// spirv.EntryPoint
820//===----------------------------------------------------------------------===//
821
822void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
823 spirv::ExecutionModel executionModel,
824 spirv::FuncOp function,
825 ArrayRef<Attribute> interfaceVars) {
826 build(builder, state,
827 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
828 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
829}
830
// Parses `"ExecutionModel" @fn[, @var, ...]`. The interface variables (possibly
// none) are stored as an array attribute of symbol references.
// NOTE(review): the declaration of `fn` is not present in this listing.
831ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
833 spirv::ExecutionModel execModel;
834 SmallVector<Attribute, 4> interfaceVars;
835
837 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
838 parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
839 return failure();
840 }
841
842 if (!parser.parseOptionalComma()) {
843 // Parse the interface variables
844 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
845 // The name of the interface variable attribute isn't important
846 FlatSymbolRefAttr var;
847 NamedAttrList attrs;
848 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
849 return failure();
850 interfaceVars.push_back(var);
851 return success();
852 }))
853 return failure();
854 }
// The interface attribute is always added, even when it is empty.
855 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
856 parser.getBuilder().getArrayAttr(interfaceVars));
857 return success();
858}
859
860void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
861 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
862 printer.printSymbolName(getFn());
863 auto interfaceVars = getInterface().getValue();
864 if (!interfaceVars.empty())
865 printer << ", " << llvm::interleaved(interfaceVars);
866}
867
868LogicalResult spirv::EntryPointOp::verify() {
869 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
870 // verification.
871 return success();
872}
873
874//===----------------------------------------------------------------------===//
875// spirv.ExecutionMode
876//===----------------------------------------------------------------------===//
877
878void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
879 spirv::FuncOp function,
880 spirv::ExecutionMode executionMode,
881 ArrayRef<int32_t> params) {
882 build(builder, state, SymbolRefAttr::get(function),
883 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
884 builder.getI32ArrayAttr(params));
885}
886
// Parses `fn-symbol "ExecutionMode"[, int, ...]`. Each trailing value is
// parsed as an i32 attribute, and all of them are stored as an i32 array
// attribute under the op's `values` attribute name.
// NOTE(review): part of the first condition and the declaration of `values`
// are not present in this listing.
887ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
889 spirv::ExecutionMode execMode;
890 Attribute fn;
891 if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
893 return failure();
894 }
895
897 Type i32Type = parser.getBuilder().getIntegerType(32);
898 while (!parser.parseOptionalComma()) {
899 NamedAttrList attr;
900 Attribute value;
901 if (parser.parseAttribute(value, i32Type, "value", attr)) {
902 return failure();
903 }
// NOTE(review): cast<IntegerAttr> asserts if the parsed attribute is not an
// integer (e.g. a string literal). A dyn_cast followed by a parser error
// would reject such input gracefully.
904 values.push_back(cast<IntegerAttr>(value).getInt());
905 }
906 StringRef valuesAttrName =
907 spirv::ExecutionModeOp::getValuesAttrName(result.name);
908 result.addAttribute(valuesAttrName,
909 parser.getBuilder().getI32ArrayAttr(values));
910 return success();
911}
912
913void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
914 printer << " ";
915 printer.printSymbolName(getFn());
916 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
917 ArrayAttr values = this->getValues();
918 if (!values.empty())
919 printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
920}
921
922//===----------------------------------------------------------------------===//
923// spirv.func
924//===----------------------------------------------------------------------===//
925
// Parses a spirv.func, in this order:
//  1. the symbol name;
//  2. the signature, read by a helper whose call is not fully visible in this
//     listing;
//  3. the function-control keyword;
//  4. an optional `attributes {...}` dictionary;
//  5. an optional body region.
// The function type and the per-argument/per-result attributes are recorded
// on `result`.
926ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
928 SmallVector<DictionaryAttr> resultAttrs;
929 SmallVector<Type> resultTypes;
930 auto &builder = parser.getBuilder();
931
932 // Parse the name as a symbol.
933 StringAttr nameAttr;
934 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
935 result.attributes))
936 return failure();
937
938 // Parse the function signature.
939 bool isVariadic = false;
941 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
942 resultAttrs))
943 return failure();
944
945 SmallVector<Type> argTypes;
946 for (auto &arg : entryArgs)
947 argTypes.push_back(arg.type);
948 auto fnType = builder.getFunctionType(argTypes, resultTypes);
949 result.addAttribute(getFunctionTypeAttrName(result.name),
950 TypeAttr::get(fnType));
951
952 // Parse the optional function control keyword.
953 spirv::FunctionControl fnControl;
955 return failure();
956
957 // If additional attributes are present, parse them.
958 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
959 return failure();
960
961 // Attach the parsed argument and result attributes to the operation state.
962 assert(resultAttrs.size() == resultTypes.size());
964 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
965 getResAttrsAttrName(result.name));
966
967 // Parse the optional function body.
968 auto *body = result.addRegion();
969 OptionalParseResult parseResult =
970 parser.parseOptionalRegion(*body, entryArgs);
971 return failure(parseResult.has_value() && failed(*parseResult));
972}
973
/// Custom printer for spirv.func, the counterpart of parse(). It prints the
/// symbol name, the signature, and the quoted function-control value. The
/// function-type, arg/result-attribute and function-control attributes are
/// listed so they are kept out of the trailing attribute dictionary. The body
/// is printed only when the function is not external.
974void spirv::FuncOp::print(OpAsmPrinter &printer) {
975 // Print function name, signature, and control.
976 printer << " ";
977 printer.printSymbolName(getSymName());
978 auto fnType = getFunctionType();
980 printer, *this, fnType.getInputs(),
981 /*isVariadic=*/false, fnType.getResults());
982 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
983 << "\"";
985 printer, *this,
987 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
988 getFunctionControlAttrName()});
989
990 // Print the body if this is not an external function.
// Entry-block arguments are suppressed here (printEntryBlockArgs=false);
// presumably the signature printed above already names them — TODO confirm.
991 Region &body = this->getBody();
992 if (!body.empty()) {
993 printer << ' ';
994 printer.printRegion(body, /*printEntryBlockArgs=*/false,
995 /*printBlockTerminators=*/true);
996 }
997}
998
999LogicalResult spirv::FuncOp::verifyType() {
1000 FunctionType fnType = getFunctionType();
1001 if (fnType.getNumResults() > 1)
1002 return emitOpError("cannot have more than one result");
1003
1004 auto hasDecorationAttr = [&](spirv::Decoration decoration,
1005 unsigned argIndex) {
1006 auto func = cast<FunctionOpInterface>(getOperation());
1007 for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
1008 if (argAttr.getName() != spirv::DecorationAttr::name)
1009 continue;
1010 if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
1011 return decAttr.getValue() == decoration;
1012 }
1013 return false;
1014 };
1015
1016 for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1017 Type param = fnType.getInputs()[i];
1018 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1019 if (!inputPtrType)
1020 continue;
1021
1022 auto pointeePtrType =
1023 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1024 if (pointeePtrType) {
1025 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1026 // > If an OpFunctionParameter is a pointer (or contains a pointer)
1027 // > and the type it points to is a pointer in the PhysicalStorageBuffer
1028 // > storage class, the function parameter must be decorated with exactly
1029 // > one of AliasedPointer or RestrictPointer.
1030 if (pointeePtrType.getStorageClass() !=
1031 spirv::StorageClass::PhysicalStorageBuffer)
1032 continue;
1033
1034 bool hasAliasedPtr =
1035 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1036 bool hasRestrictPtr =
1037 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1038 if (!hasAliasedPtr && !hasRestrictPtr)
1039 return emitOpError()
1040 << "with a pointer points to a physical buffer pointer must "
1041 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1042 continue;
1043 }
1044 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1045 // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1046 // > the PhysicalStorageBuffer storage class, the function parameter must
1047 // > be decorated with exactly one of Aliased or Restrict.
1048 if (auto pointeeArrayType =
1049 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1050 pointeePtrType =
1051 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1052 } else {
1053 pointeePtrType = inputPtrType;
1054 }
1055
1056 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1057 spirv::StorageClass::PhysicalStorageBuffer)
1058 continue;
1059
1060 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1061 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1062 if (!hasAliased && !hasRestrict)
1063 return emitOpError() << "with physical buffer pointer must be decorated "
1064 "either 'Aliased' or 'Restrict'";
1065 }
1066
1067 return success();
1068}
1069
1070LogicalResult spirv::FuncOp::verifyBody() {
1071 FunctionType fnType = getFunctionType();
1072 if (!isExternal()) {
1073 Block &entryBlock = front();
1074
1075 unsigned numArguments = this->getNumArguments();
1076 if (entryBlock.getNumArguments() != numArguments)
1077 return emitOpError("entry block must have ")
1078 << numArguments << " arguments to match function signature";
1079
1080 for (auto [index, fnArgType, blockArgType] :
1081 llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1082 if (blockArgType != fnArgType) {
1083 return emitOpError("type of entry block argument #")
1084 << index << '(' << blockArgType
1085 << ") must match the type of the corresponding argument in "
1086 << "function signature(" << fnArgType << ')';
1087 }
1088 }
1089 }
1090
1091 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1092 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1093 if (fnType.getNumResults() != 0)
1094 return retOp.emitOpError("cannot be used in functions returning value");
1095 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1096 if (fnType.getNumResults() != 1)
1097 return retOp.emitOpError(
1098 "returns 1 value but enclosing function requires ")
1099 << fnType.getNumResults() << " results";
1100
1101 auto retOperandType = retOp.getValue().getType();
1102 auto fnResultType = fnType.getResult(0);
1103 if (retOperandType != fnResultType)
1104 return retOp.emitOpError(" return value's type (")
1105 << retOperandType << ") mismatch with function's result type ("
1106 << fnResultType << ")";
1107 }
1108 return WalkResult::advance();
1109 });
1110
1111 // TODO: verify other bits like linkage type.
1112
1113 return failure(walkResult.wasInterrupted());
1114}
1115
/// Builds a spirv.func named `name` with signature `type` and function control
/// `control`. The caller-supplied `attrs` are appended after these standard
/// attributes. The body region is added empty: no entry block is created here.
1116void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1117 StringRef name, FunctionType type,
1118 spirv::FunctionControl control,
1121 builder.getStringAttr(name));
1122 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1124 builder.getAttr<spirv::FunctionControlAttr>(control));
1125 state.attributes.append(attrs.begin(), attrs.end());
1126 state.addRegion();
1127}
1128
1129//===----------------------------------------------------------------------===//
1130// spirv.GLFClampOp
1131//===----------------------------------------------------------------------===//
1132
/// Custom parser for GLFClampOp. It is presumably the counterpart of the
/// shared printOneResultOp format used by print() — TODO confirm.
1133ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1136}
1137void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1138
1139//===----------------------------------------------------------------------===//
1140// spirv.GLUClampOp
1141//===----------------------------------------------------------------------===//
1142
/// Custom parser for GLUClampOp. It is presumably the counterpart of the
/// shared printOneResultOp format used by print() — TODO confirm.
1143ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1146}
1147void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1148
1149//===----------------------------------------------------------------------===//
1150// spirv.GLSClampOp
1151//===----------------------------------------------------------------------===//
1152
/// Custom parser for GLSClampOp. It is presumably the counterpart of the
/// shared printOneResultOp format used by print() — TODO confirm.
1153ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1156}
1157void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1158
1159//===----------------------------------------------------------------------===//
1160// spirv.GLFmaOp
1161//===----------------------------------------------------------------------===//
1162
/// Custom parser for GLFmaOp. It is presumably the counterpart of the shared
/// printOneResultOp format used by print() — TODO confirm.
1163ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1165}
1166void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1167
1168//===----------------------------------------------------------------------===//
1169// spirv.GlobalVariable
1170//===----------------------------------------------------------------------===//
1171
1172void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1173 Type type, StringRef name,
1174 unsigned descriptorSet, unsigned binding) {
1175 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1176 state.addAttribute(
1177 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1178 builder.getI32IntegerAttr(descriptorSet));
1179 state.addAttribute(
1180 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1181 builder.getI32IntegerAttr(binding));
1182}
1183
1184void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1185 Type type, StringRef name,
1186 spirv::BuiltIn builtin) {
1187 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1188 state.addAttribute(
1189 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1190 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1191}
1192
/// Custom parser for spirv.GlobalVariable. It accepts
///   `@name [<initializer-attr-name>(@symbol)] <decorations> : <type>`.
/// The type must be a spirv.ptr type; it is stored as a TypeAttr.
1193ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1195 // Parse variable name.
1196 StringAttr nameAttr;
1197 StringRef initializerAttrName =
1198 spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1199 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1200 result.attributes)) {
1201 return failure();
1202 }
1203
1204 // Parse optional initializer. The keyword is the initializer attribute's
// own name, followed by a parenthesized flat symbol reference.
1205 if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1206 FlatSymbolRefAttr initSymbol;
1207 if (parser.parseLParen() ||
1208 parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1209 result.attributes) ||
1210 parser.parseRParen())
1211 return failure();
1212 }
1213
1214 if (parseVariableDecorations(parser, result)) {
1215 return failure();
1216 }
1217
1218 Type type;
1219 StringRef typeAttrName =
1220 spirv::GlobalVariableOp::getTypeAttrName(result.name);
// Location is captured before the colon so a non-pointer type is reported at
// the type position.
1221 auto loc = parser.getCurrentLocation();
1222 if (parser.parseColonType(type)) {
1223 return failure();
1224 }
1225 if (!isa<spirv::PointerType>(type)) {
1226 return parser.emitError(loc, "expected spirv.ptr type");
1227 }
1228 result.addAttribute(typeAttrName, TypeAttr::get(type));
1229
1230 return success();
1231}
1232
/// Custom printer for spirv.GlobalVariable, the counterpart of parse(). It
/// prints the symbol name, an optional `<initializer-attr-name>(@sym)` clause,
/// the variable decorations and ` : <type>`. Attributes already rendered
/// explicitly are collected in `elidedAttrs` so they are not printed twice.
1233void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
1234 SmallVector<StringRef, 4> elidedAttrs{
1236
1237 // Print variable name.
1238 printer << ' ';
1239 printer.printSymbolName(getSymName());
1240 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1241
1242 StringRef initializerAttrName = this->getInitializerAttrName();
1243 // Print optional initializer
1244 if (auto initializer = this->getInitializer()) {
1245 printer << " " << initializerAttrName << '(';
1246 printer.printSymbolName(*initializer);
1247 printer << ')';
1248 elidedAttrs.push_back(initializerAttrName);
1249 }
1250
1251 StringRef typeAttrName = this->getTypeAttrName();
1252 elidedAttrs.push_back(typeAttrName);
1253 spirv::printVariableDecorations(*this, printer, elidedAttrs);
1254 printer << " : " << getType();
1255}
1256
/// Verifies a spirv.GlobalVariable:
///  - its type must be a pointer;
///  - its storage class must be neither Generic nor Function;
///  - an initializer, if present, must resolve to a spirv.SpecConstant or
///    spirv.SpecConstantComposite op.
1257LogicalResult spirv::GlobalVariableOp::verify() {
// NOTE(review): this diagnostic still uses the legacy `!spv.ptr` spelling;
// the parser reports "expected spirv.ptr type".
1258 if (!isa<spirv::PointerType>(getType()))
1259 return emitOpError("result must be of a !spv.ptr type");
1260
1261 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1262 // object. It cannot be Generic. It must be the same as the Storage Class
1263 // operand of the Result Type."
1264 // Also, Function storage class is reserved by spirv.Variable.
1265 auto storageClass = this->storageClass();
1266 if (storageClass == spirv::StorageClass::Generic ||
1267 storageClass == spirv::StorageClass::Function) {
1268 return emitOpError("storage class cannot be '")
1269 << stringifyStorageClass(storageClass) << "'";
1270 }
1271
1272 if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1273 this->getInitializerAttrName())) {
1275 (*this)->getParentOp(), init.getAttr());
1276 // TODO: Currently only variable initialization with specialization
1277 // constants is supported. There could be normal constants in the module
1278 // scope as well.
1279 //
1280 // In the current setup we also cannot initialize one global variable with
1281 // another. The problem is that if we try to initialize pointer of type X
1282 // with another pointer type, the validator fails because it expects the
1283 // variable to be initialized to be type X, not pointer to X. Now
1284 // `spirv.GlobalVariable` only allows pointer type, so in the current design
1285 // we cannot initialize one `spirv.GlobalVariable` with another.
1286 if (!initOp ||
1287 !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
1288 return emitOpError("initializer must be result of a "
1289 "spirv.SpecConstant or "
1290 "spirv.SpecConstantCompositeOp op");
1291 }
1292 }
1293
1294 return success();
1295}
1296
1297//===----------------------------------------------------------------------===//
1298// spirv.INTEL.SubgroupBlockRead
1299//===----------------------------------------------------------------------===//
1300
1301LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
1302 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1303 return failure();
1304
1305 return success();
1306}
1307
1308//===----------------------------------------------------------------------===//
1309// spirv.INTEL.SubgroupBlockWrite
1310//===----------------------------------------------------------------------===//
1311
/// Custom parser for spirv.INTEL.SubgroupBlockWrite. It accepts
///   `"<StorageClass>" %ptr, %value : <value-type>`.
/// The pointer operand's type is reconstructed from the storage class and the
/// value type. For a vector value type, the pointee is the vector's element
/// type; otherwise it is the value type itself.
1312ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
1314 // Parse the storage class specification
1315 spirv::StorageClass storageClass;
1317 auto loc = parser.getCurrentLocation();
1318 Type elementType;
1319 if (parseEnumStrAttr(storageClass, parser) ||
1320 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1321 parser.parseType(elementType)) {
1322 return failure();
1323 }
1324
1325 auto ptrType = spirv::PointerType::get(elementType, storageClass);
1326 if (auto valVecTy = dyn_cast<VectorType>(elementType))
1327 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1328
1329 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1330 result.operands)) {
1331 return failure();
1332 }
1333 return success();
1334}
1335
1336void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
1337 printer << " " << getPtr() << ", " << getValue() << " : "
1338 << getValue().getType();
1339}
1340
1341LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
1342 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1343 return failure();
1344
1345 return success();
1346}
1347
1348//===----------------------------------------------------------------------===//
1349// spirv.IAddCarryOp
1350//===----------------------------------------------------------------------===//
1351
1352LogicalResult spirv::IAddCarryOp::verify() {
1353 return ::verifyArithmeticExtendedBinaryOp(*this);
1354}
1355
/// Parses the textual form shared by the extended-arithmetic binary ops
/// (IAddCarry, ISubBorrow, SMulExtended, UMulExtended).
1356ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1358 return ::parseArithmeticExtendedBinaryOp(parser, result);
1359}
1360
1361void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1362 ::printArithmeticExtendedBinaryOp(*this, printer);
1363}
1364
1365//===----------------------------------------------------------------------===//
1366// spirv.ISubBorrowOp
1367//===----------------------------------------------------------------------===//
1368
1369LogicalResult spirv::ISubBorrowOp::verify() {
1370 return ::verifyArithmeticExtendedBinaryOp(*this);
1371}
1372
/// Parses the textual form shared by the extended-arithmetic binary ops
/// (IAddCarry, ISubBorrow, SMulExtended, UMulExtended).
1373ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1375 return ::parseArithmeticExtendedBinaryOp(parser, result);
1376}
1377
1378void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
1379 ::printArithmeticExtendedBinaryOp(*this, printer);
1380}
1381
1382//===----------------------------------------------------------------------===//
1383// spirv.SMulExtended
1384//===----------------------------------------------------------------------===//
1385
1386LogicalResult spirv::SMulExtendedOp::verify() {
1387 return ::verifyArithmeticExtendedBinaryOp(*this);
1388}
1389
/// Parses the textual form shared by the extended-arithmetic binary ops
/// (IAddCarry, ISubBorrow, SMulExtended, UMulExtended).
1390ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1392 return ::parseArithmeticExtendedBinaryOp(parser, result);
1393}
1394
1395void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
1396 ::printArithmeticExtendedBinaryOp(*this, printer);
1397}
1398
1399//===----------------------------------------------------------------------===//
1400// spirv.UMulExtended
1401//===----------------------------------------------------------------------===//
1402
1403LogicalResult spirv::UMulExtendedOp::verify() {
1404 return ::verifyArithmeticExtendedBinaryOp(*this);
1405}
1406
/// Parses the textual form shared by the extended-arithmetic binary ops
/// (IAddCarry, ISubBorrow, SMulExtended, UMulExtended).
1407ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1409 return ::parseArithmeticExtendedBinaryOp(parser, result);
1410}
1411
1412void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
1413 ::printArithmeticExtendedBinaryOp(*this, printer);
1414}
1415
1416//===----------------------------------------------------------------------===//
1417// spirv.MemoryBarrierOp
1418//===----------------------------------------------------------------------===//
1419
1420LogicalResult spirv::MemoryBarrierOp::verify() {
1421 return verifyMemorySemantics(getOperation(), getMemorySemantics());
1422}
1423
1424//===----------------------------------------------------------------------===//
1425// spirv.module
1426//===----------------------------------------------------------------------===//
1427
/// Builds a spirv.module whose body region holds one empty block. The symbol
/// name is set only when `name` is provided. The InsertionGuard restores the
/// builder's insertion point after createBlock moves it into the new block.
1428void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1429 std::optional<StringRef> name) {
1430 OpBuilder::InsertionGuard guard(builder);
1431 builder.createBlock(state.addRegion());
1432 if (name) {
1434 builder.getStringAttr(*name));
1435 }
1436}
1437
/// Builds a spirv.module with the given addressing and memory models, an
/// optional version/capability/extension triple, an optional symbol name, and
/// a body region holding one empty block.
1438void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1439 spirv::AddressingModel addressingModel,
1440 spirv::MemoryModel memoryModel,
1441 std::optional<VerCapExtAttr> vceTriple,
1442 std::optional<StringRef> name) {
// NOTE(review): the attribute names are hard-coded here, while print() uses
// spirv::attributeName<...>(); presumably these agree — confirm.
1443 state.addAttribute(
1444 "addressing_model",
1445 builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1446 state.addAttribute("memory_model",
1447 builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1448 OpBuilder::InsertionGuard guard(builder);
1449 builder.createBlock(state.addRegion());
1450 if (vceTriple)
1451 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1452 if (name)
1454 builder.getStringAttr(*name));
1455}
1456
/// Custom parser for spirv.module. It accepts, in order:
///  - an optional symbol name;
///  - the addressing and memory models;
///  - an optional `requires <vce-triple>` clause;
///  - an optional `attributes {...}` dictionary;
///  - the body region.
/// An empty body gets one block appended.
1457ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1459 Region *body = result.addRegion();
1460
1461 // If the name is present, parse it.
1462 StringAttr nameAttr;
1464 nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1465
1466 // Parse attributes
1467 spirv::AddressingModel addrModel;
1468 spirv::MemoryModel memoryModel;
1470 result) ||
1472 result))
1473 return failure();
1474
1475 if (succeeded(parser.parseOptionalKeyword("requires"))) {
1476 spirv::VerCapExtAttr vceTriple;
1477 if (parser.parseAttribute(vceTriple,
1478 spirv::ModuleOp::getVCETripleAttrName(),
1479 result.attributes))
1480 return failure();
1481 }
1482
1483 if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1484 parser.parseRegion(*body, /*arguments=*/{}))
1485 return failure();
1486
1487 // Make sure we have at least one block.
1488 if (body->empty())
1489 body->push_back(new Block());
1490
1491 return success();
1492}
1493
/// Custom printer for spirv.module, the counterpart of parse(). It prints:
///  - the optional symbol name;
///  - the addressing and memory models;
///  - an optional `requires <vce-triple>` clause;
///  - the remaining attributes;
///  - the body region.
/// Attributes rendered explicitly are elided from the attribute dictionary.
1494void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1495 if (std::optional<StringRef> name = getName()) {
1496 printer << ' ';
1497 printer.printSymbolName(*name);
1498 }
1499
1500 SmallVector<StringRef, 2> elidedAttrs;
1501
1502 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1503 << spirv::stringifyMemoryModel(getMemoryModel());
1504 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1505 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1506 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1508
1509 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1510 printer << " requires " << *triple;
1511 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1512 }
1513
1514 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1515 printer << ' ';
1516 printer.printRegion(getRegion());
1517}
1518
/// Verifies the contents of a spirv.module:
///  - every top-level op belongs to the SPIR-V dialect;
///  - each EntryPoint names an existing spirv.func, its interface entries are
///    symbol references to spirv.GlobalVariable ops, and (function, execution
///    model) pairs are unique;
///  - external functions carry 'Import' linkage, and ops directly inside
///    function blocks belong to the SPIR-V dialect.
1519LogicalResult spirv::ModuleOp::verifyRegions() {
1520 Dialect *dialect = (*this)->getDialect();
1522 entryPoints;
1523 mlir::SymbolTable table(*this);
1524
1525 for (auto &op : *getBody()) {
1526 if (op.getDialect() != dialect)
1527 return op.emitError("'spirv.module' can only contain spirv.* ops");
1528
1529 // For EntryPoint op, check that the function and execution model is not
1530 // duplicated in EntryPointOps. Also verify that the interface specified
1531 // comes from globalVariables here to make this check cheaper.
1532 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1533 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1534 if (!funcOp) {
1535 return entryPointOp.emitError("function '")
1536 << entryPointOp.getFn() << "' not found in 'spirv.module'";
1537 }
1538 if (auto interface = entryPointOp.getInterface()) {
1539 for (Attribute varRef : interface) {
1540 auto varSymRef = dyn_cast<FlatSymbolRefAttr>(varRef);
// NOTE(review): this message opens a quote before `varRef` but never
// closes it.
1541 if (!varSymRef) {
1542 return entryPointOp.emitError(
1543 "expected symbol reference for interface "
1544 "specification instead of '")
1545 << varRef;
1546 }
1547 auto variableOp =
1548 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
// NOTE(review): the message below lacks a space before the opening quote
// ("instead of'").
1549 if (!variableOp) {
1550 return entryPointOp.emitError("expected spirv.GlobalVariable "
1551 "symbol reference instead of'")
1552 << varSymRef << "'";
1553 }
1554 }
1555 }
1556
1557 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1558 funcOp, entryPointOp.getExecutionModel());
1559 if (!entryPoints.try_emplace(key, entryPointOp).second)
1560 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1561 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1562 // If the function is external and does not have 'Import'
1563 // linkage_attributes(LinkageAttributes), emit an error. 'Import'
1564 // LinkageAttributes is used to import external functions.
1565 auto linkageAttr = funcOp.getLinkageAttributes();
1566 auto hasImportLinkage =
1567 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1568 spirv::LinkageType::Import);
1569 if (funcOp.isExternal() && !hasImportLinkage)
1570 return op.emitError(
1571 "'spirv.module' cannot contain external functions "
1572 "without 'Import' linkage_attributes (LinkageAttributes)");
1573
1574 // TODO: move this check to spirv.func.
// NOTE(review): only ops directly in the function's blocks are visited; ops
// nested in regions of those ops are not. The inner `op` also shadows the
// outer loop variable.
1575 for (auto &block : funcOp)
1576 for (auto &op : block) {
1577 if (op.getDialect() != dialect)
1578 return op.emitError(
1579 "functions in 'spirv.module' can only contain spirv.* ops");
1580 }
1581 }
1582 }
1583
1584 return success();
1585}
1586
1587//===----------------------------------------------------------------------===//
1588// spirv.mlir.referenceof
1589//===----------------------------------------------------------------------===//
1590
1591LogicalResult spirv::ReferenceOfOp::verify() {
1592 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1593 (*this)->getParentOp(), getSpecConstAttr());
1594 Type constType;
1595
1596 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1597 if (specConstOp)
1598 constType = specConstOp.getDefaultValue().getType();
1599
1600 auto specConstCompositeOp =
1601 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1602 if (specConstCompositeOp)
1603 constType = specConstCompositeOp.getType();
1604
1605 if (!specConstOp && !specConstCompositeOp)
1606 return emitOpError(
1607 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1608
1609 if (getReference().getType() != constType)
1610 return emitOpError("result type mismatch with the referenced "
1611 "specialization constant's type");
1612
1613 return success();
1614}
1615
1616//===----------------------------------------------------------------------===//
1617// spirv.SpecConstant
1618//===----------------------------------------------------------------------===//
1619
/// Custom parser for spirv.SpecConstant. It accepts
///   `@name [<kSpecIdAttrName>(<integer>)] = <default-value attribute>`,
/// storing the optional spec id and the default value as attributes.
1620ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1622 StringAttr nameAttr;
1623 Attribute valueAttr;
1624 StringRef defaultValueAttrName =
1625 spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1626
1627 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1628 result.attributes))
1629 return failure();
1630
1631 // Parse optional spec_id.
1632 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1633 IntegerAttr specIdAttr;
1634 if (parser.parseLParen() ||
1635 parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1636 parser.parseRParen())
1637 return failure();
1638 }
1639
1640 if (parser.parseEqual() ||
1641 parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1642 return failure();
1643
1644 return success();
1645}
1646
1647void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
1648 printer << ' ';
1649 printer.printSymbolName(getSymName());
1650 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1651 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1652 printer << " = " << getDefaultValue();
1653}
1654
1655LogicalResult spirv::SpecConstantOp::verify() {
1656 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1657 if (specID.getValue().isNegative())
1658 return emitOpError("SpecId cannot be negative");
1659
1660 auto value = getDefaultValue();
1661 if (isa<IntegerAttr, FloatAttr>(value)) {
1662 // Make sure bitwidth is allowed.
1663 if (!isa<spirv::SPIRVType>(value.getType()))
1664 return emitOpError("default value bitwidth disallowed");
1665 return success();
1666 }
1667 return emitOpError(
1668 "default value can only be a bool, integer, or float scalar");
1669}
1670
1671//===----------------------------------------------------------------------===//
1672// spirv.VectorShuffle
1673//===----------------------------------------------------------------------===//
1674
1675LogicalResult spirv::VectorShuffleOp::verify() {
1676 VectorType resultType = cast<VectorType>(getType());
1677
1678 size_t numResultElements = resultType.getNumElements();
1679 if (numResultElements != getComponents().size())
1680 return emitOpError("result type element count (")
1681 << numResultElements
1682 << ") mismatch with the number of component selectors ("
1683 << getComponents().size() << ")";
1684
1685 size_t totalSrcElements =
1686 cast<VectorType>(getVector1().getType()).getNumElements() +
1687 cast<VectorType>(getVector2().getType()).getNumElements();
1688
1689 for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1690 uint32_t index = selector.getZExtValue();
1691 if (index >= totalSrcElements &&
1692 index != std::numeric_limits<uint32_t>().max())
1693 return emitOpError("component selector ")
1694 << index << " out of range: expected to be in [0, "
1695 << totalSrcElements << ") or 0xffffffff";
1696 }
1697 return success();
1698}
1699
1700//===----------------------------------------------------------------------===//
1701// spirv.MatrixTimesScalar
1702//===----------------------------------------------------------------------===//
1703
/// Verifies that the scalar operand has the same type as the matrix operand's
/// element type.
1704LogicalResult spirv::MatrixTimesScalarOp::verify() {
1705 Type elementType =
1708 [](auto matrixType) { return matrixType.getElementType(); })
1709 .Default(nullptr);
1710
// With assertions disabled, an unhandled matrix type leaves elementType null.
// The comparison below then fails and reports the type-mismatch error.
1711 assert(elementType && "Unhandled type");
1712
1713 // Check that the scalar type is the same as the matrix element type.
1714 if (getScalar().getType() != elementType)
1715 return emitOpError("input matrix components' type and scaling value must "
1716 "have the same type");
1717
1718 return success();
1719}
1720
1721//===----------------------------------------------------------------------===//
1722// spirv.Transpose
1723//===----------------------------------------------------------------------===//
1724
1725LogicalResult spirv::TransposeOp::verify() {
1726 auto inputMatrix = cast<spirv::MatrixType>(getMatrix().getType());
1727 auto resultMatrix = cast<spirv::MatrixType>(getResult().getType());
1728
1729 // Verify that the input and output matrices have correct shapes.
1730 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1731 return emitError("input matrix rows count must be equal to "
1732 "output matrix columns count");
1733
1734 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1735 return emitError("input matrix columns count must be equal to "
1736 "output matrix rows count");
1737
1738 // Verify that the input and output matrices have the same component type
1739 if (inputMatrix.getElementType() != resultMatrix.getElementType())
1740 return emitError("input and output matrices must have the same "
1741 "component type");
1742
1743 return success();
1744}
1745
1746//===----------------------------------------------------------------------===//
1747// spirv.MatrixTimesVector
1748//===----------------------------------------------------------------------===//
1749
1750LogicalResult spirv::MatrixTimesVectorOp::verify() {
1751 auto matrixType = cast<spirv::MatrixType>(getMatrix().getType());
1752 auto vectorType = cast<VectorType>(getVector().getType());
1753 auto resultType = cast<VectorType>(getType());
1754
1755 if (matrixType.getNumColumns() != vectorType.getNumElements())
1756 return emitOpError("matrix columns (")
1757 << matrixType.getNumColumns() << ") must match vector operand size ("
1758 << vectorType.getNumElements() << ")";
1759
1760 if (resultType.getNumElements() != matrixType.getNumRows())
1761 return emitOpError("result size (")
1762 << resultType.getNumElements() << ") must match the matrix rows ("
1763 << matrixType.getNumRows() << ")";
1764
1765 if (matrixType.getElementType() != resultType.getElementType())
1766 return emitOpError("matrix and result element types must match");
1767
1768 return success();
1769}
1770
1771//===----------------------------------------------------------------------===//
1772// spirv.VectorTimesMatrix
1773//===----------------------------------------------------------------------===//
1774
1775LogicalResult spirv::VectorTimesMatrixOp::verify() {
1776 auto vectorType = cast<VectorType>(getVector().getType());
1777 auto matrixType = cast<spirv::MatrixType>(getMatrix().getType());
1778 auto resultType = cast<VectorType>(getType());
1779
1780 if (matrixType.getNumRows() != vectorType.getNumElements())
1781 return emitOpError("number of components in vector must equal the number "
1782 "of components in each column in matrix");
1783
1784 if (resultType.getNumElements() != matrixType.getNumColumns())
1785 return emitOpError("number of columns in matrix must equal the number of "
1786 "components in result");
1787
1788 if (matrixType.getElementType() != resultType.getElementType())
1789 return emitOpError("matrix must be a matrix with the same component type "
1790 "as the component type in result");
1791
1792 return success();
1793}
1794
1795//===----------------------------------------------------------------------===//
1796// spirv.MatrixTimesMatrix
1797//===----------------------------------------------------------------------===//
1798
1799LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1800 auto leftMatrix = cast<spirv::MatrixType>(getLeftmatrix().getType());
1801 auto rightMatrix = cast<spirv::MatrixType>(getRightmatrix().getType());
1802 auto resultMatrix = cast<spirv::MatrixType>(getResult().getType());
1803
1804 // left matrix columns' count and right matrix rows' count must be equal
1805 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1806 return emitError("left matrix columns' count must be equal to "
1807 "the right matrix rows' count");
1808
1809 // right and result matrices columns' count must be the same
1810 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1811 return emitError(
1812 "right and result matrices must have equal columns' count");
1813
1814 // right and result matrices component type must be the same
1815 if (rightMatrix.getElementType() != resultMatrix.getElementType())
1816 return emitError("right and result matrices' component type must"
1817 " be the same");
1818
1819 // left and result matrices component type must be the same
1820 if (leftMatrix.getElementType() != resultMatrix.getElementType())
1821 return emitError("left and result matrices' component type"
1822 " must be the same");
1823
1824 // left and result matrices rows count must be the same
1825 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1826 return emitError("left and result matrices must have equal rows' count");
1827
1828 return success();
1829}
1830
1831//===----------------------------------------------------------------------===//
1832// spirv.SpecConstantComposite
1833//===----------------------------------------------------------------------===//
1834
1835ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
1837
1838 StringAttr compositeName;
1839 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1840 result.attributes))
1841 return failure();
1842
1843 if (parser.parseLParen())
1844 return failure();
1845
1846 SmallVector<Attribute, 4> constituents;
1847
1848 do {
1849 // The name of the constituent attribute isn't important
1850 const char *attrName = "spec_const";
1851 FlatSymbolRefAttr specConstRef;
1852 NamedAttrList attrs;
1853
1854 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1855 return failure();
1856
1857 constituents.push_back(specConstRef);
1858 } while (!parser.parseOptionalComma());
1859
1860 if (parser.parseRParen())
1861 return failure();
1862
1863 StringAttr compositeSpecConstituentsName =
1864 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1865 result.addAttribute(compositeSpecConstituentsName,
1866 parser.getBuilder().getArrayAttr(constituents));
1867
1868 Type type;
1869 if (parser.parseColonType(type))
1870 return failure();
1871
1872 StringAttr typeAttrName =
1873 spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1874 result.addAttribute(typeAttrName, TypeAttr::get(type));
1875
1876 return success();
1877}
1878
1879void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
1880 printer << " ";
1881 printer.printSymbolName(getSymName());
1882 printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1883 << ") : " << getType();
1884}
1885
1886LogicalResult spirv::SpecConstantCompositeOp::verify() {
1887 auto cType = dyn_cast<spirv::CompositeType>(getType());
1888 auto constituents = this->getConstituents().getValue();
1889
1890 if (!cType)
1891 return emitError("result type must be a composite type, but provided ")
1892 << getType();
1893
1894 if (isa<spirv::CooperativeMatrixType>(cType))
1895 return emitError("unsupported composite type ") << cType;
1896 if (constituents.size() != cType.getNumElements())
1897 return emitError("has incorrect number of operands: expected ")
1898 << cType.getNumElements() << ", but provided "
1899 << constituents.size();
1900
1901 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1902 auto constituent = cast<FlatSymbolRefAttr>(constituents[index]);
1903
1904 auto constituentSpecConstOp =
1905 dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1906 (*this)->getParentOp(), constituent.getAttr()));
1907
1908 if (constituentSpecConstOp.getDefaultValue().getType() !=
1909 cType.getElementType(index))
1910 return emitError("has incorrect types of operands: expected ")
1911 << cType.getElementType(index) << ", but provided "
1912 << constituentSpecConstOp.getDefaultValue().getType();
1913 }
1914
1915 return success();
1916}
1917
1918//===----------------------------------------------------------------------===//
1919// spirv.EXTSpecConstantCompositeReplicateOp
1920//===----------------------------------------------------------------------===//
1921
1922ParseResult
1923spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
1925 StringAttr compositeName;
1926 FlatSymbolRefAttr specConstRef;
1927 const char *attrName = "spec_const";
1928 NamedAttrList attrs;
1929 Type type;
1930
1931 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1932 result.attributes) ||
1933 parser.parseLParen() ||
1934 parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
1935 parser.parseRParen() || parser.parseColonType(type))
1936 return failure();
1937
1938 StringAttr compositeSpecConstituentName =
1939 spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1940 result.name);
1941 result.addAttribute(compositeSpecConstituentName, specConstRef);
1942
1943 StringAttr typeAttrName =
1944 spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
1945 result.addAttribute(typeAttrName, TypeAttr::get(type));
1946
1947 return success();
1948}
1949
1950void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
1951 printer << " ";
1952 printer.printSymbolName(getSymName());
1953 printer << " (" << this->getConstituent() << ") : " << getType();
1954}
1955
1956LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
1957 auto compositeType = dyn_cast<spirv::CompositeType>(getType());
1958 if (!compositeType)
1959 return emitError("result type must be a composite type, but provided ")
1960 << getType();
1961
1963 (*this)->getParentOp(), this->getConstituent());
1964 if (!constituentOp)
1965 return emitError(
1966 "splat spec constant reference defining constituent not found");
1967
1968 auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
1969 if (!constituentSpecConstOp)
1970 return emitError("constituent is not a spec constant");
1971
1972 Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
1973 Type compositeElementType = compositeType.getElementType(0);
1974 if (constituentType != compositeElementType)
1975 return emitError("constituent has incorrect type: expected ")
1976 << compositeElementType << ", but provided " << constituentType;
1977
1978 return success();
1979}
1980
1981//===----------------------------------------------------------------------===//
1982// spirv.SpecConstantOperation
1983//===----------------------------------------------------------------------===//
1984
1985ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
1987 Region *body = result.addRegion();
1988
1989 if (parser.parseKeyword("wraps"))
1990 return failure();
1991
1992 body->push_back(new Block);
1993 Block &block = body->back();
1994 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1995
1996 if (!wrappedOp)
1997 return failure();
1998
1999 OpBuilder builder(parser.getContext());
2000 builder.setInsertionPointToEnd(&block);
2001 spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0));
2002 result.location = wrappedOp->getLoc();
2003
2004 result.addTypes(wrappedOp->getResult(0).getType());
2005
2006 if (parser.parseOptionalAttrDict(result.attributes))
2007 return failure();
2008
2009 return success();
2010}
2011
2012void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
2013 printer << " wraps ";
2014 printer.printGenericOp(&getBody().front().front());
2015}
2016
2017LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
2018 Block &block = getRegion().getBlocks().front();
2019
2020 if (block.getOperations().size() != 2)
2021 return emitOpError("expected exactly 2 nested ops");
2022
2023 Operation &enclosedOp = block.getOperations().front();
2024
2026 return emitOpError("invalid enclosed op");
2027
2028 for (auto operand : enclosedOp.getOperands())
2029 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
2030 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
2031 return emitOpError(
2032 "invalid operand, must be defined by a constant operation");
2033
2034 return success();
2035}
2036
2037//===----------------------------------------------------------------------===//
2038// spirv.GL.FrexpStruct
2039//===----------------------------------------------------------------------===//
2040
2041LogicalResult spirv::GLFrexpStructOp::verify() {
2042 spirv::StructType structTy =
2043 dyn_cast<spirv::StructType>(getResult().getType());
2044
2045 if (structTy.getNumElements() != 2)
2046 return emitError("result type must be a struct type with two memebers");
2047
2048 Type significandTy = structTy.getElementType(0);
2049 Type exponentTy = structTy.getElementType(1);
2050 VectorType exponentVecTy = dyn_cast<VectorType>(exponentTy);
2051 IntegerType exponentIntTy = dyn_cast<IntegerType>(exponentTy);
2052
2053 Type operandTy = getOperand().getType();
2054 VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
2055 FloatType operandFTy = dyn_cast<FloatType>(operandTy);
2056
2057 if (significandTy != operandTy)
2058 return emitError("member zero of the resulting struct type must be the "
2059 "same type as the operand");
2060
2061 if (exponentVecTy) {
2062 IntegerType componentIntTy =
2063 dyn_cast<IntegerType>(exponentVecTy.getElementType());
2064 if (!componentIntTy || componentIntTy.getWidth() != 32)
2065 return emitError("member one of the resulting struct type must"
2066 "be a scalar or vector of 32 bit integer type");
2067 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2068 return emitError("member one of the resulting struct type "
2069 "must be a scalar or vector of 32 bit integer type");
2070 }
2071
2072 // Check that the two member types have the same number of components
2073 if (operandVecTy && exponentVecTy &&
2074 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2075 return success();
2076
2077 if (operandFTy && exponentIntTy)
2078 return success();
2079
2080 return emitError("member one of the resulting struct type must have the same "
2081 "number of components as the operand type");
2082}
2083
2084//===----------------------------------------------------------------------===//
2085// spirv.GL.Ldexp
2086//===----------------------------------------------------------------------===//
2087
2088LogicalResult spirv::GLLdexpOp::verify() {
2089 Type significandType = getX().getType();
2090 Type exponentType = getExp().getType();
2091
2092 if (isa<FloatType>(significandType) != isa<IntegerType>(exponentType))
2093 return emitOpError("operands must both be scalars or vectors");
2094
2095 auto getNumElements = [](Type type) -> unsigned {
2096 if (auto vectorType = dyn_cast<VectorType>(type))
2097 return vectorType.getNumElements();
2098 return 1;
2099 };
2100
2101 if (getNumElements(significandType) != getNumElements(exponentType))
2102 return emitOpError("operands must have the same number of elements");
2103
2104 return success();
2105}
2106
2107//===----------------------------------------------------------------------===//
2108// spirv.ShiftLeftLogicalOp
2109//===----------------------------------------------------------------------===//
2110
2111LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2112 return verifyShiftOp(*this);
2113}
2114
2115//===----------------------------------------------------------------------===//
2116// spirv.ShiftRightArithmeticOp
2117//===----------------------------------------------------------------------===//
2118
2119LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2120 return verifyShiftOp(*this);
2121}
2122
2123//===----------------------------------------------------------------------===//
2124// spirv.ShiftRightLogicalOp
2125//===----------------------------------------------------------------------===//
2126
2127LogicalResult spirv::ShiftRightLogicalOp::verify() {
2128 return verifyShiftOp(*this);
2129}
2130
2131//===----------------------------------------------------------------------===//
2132// spirv.VectorTimesScalarOp
2133//===----------------------------------------------------------------------===//
2134
2135LogicalResult spirv::VectorTimesScalarOp::verify() {
2136 if (getVector().getType() != getType())
2137 return emitOpError("vector operand and result type mismatch");
2138 auto scalarType = cast<VectorType>(getType()).getElementType();
2139 if (getScalar().getType() != scalarType)
2140 return emitOpError("scalar operand and result element type match");
2141 return success();
2142}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
ArrayAttr()
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition SPIRVOps.cpp:266
static Type getValueType(Attribute attr)
Definition SPIRVOps.cpp:773
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition SPIRVOps.cpp:562
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition SPIRVOps.cpp:120
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition SPIRVOps.cpp:252
static LogicalResult verifyShiftOp(Operation *op)
Definition SPIRVOps.cpp:298
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition SPIRVOps.cpp:169
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition SPIRVOps.cpp:185
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition SPIRVOps.cpp:149
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition SPIRVOps.cpp:290
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
iterator begin()
Definition Block.h:153
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:276
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:98
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
Value getOperand(unsigned idx)
Definition Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
void push_back(Block *block)
Definition Region.h:61
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
Type front()
Return first type in the range.
Definition TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
SPIR-V struct type.
Definition SPIRVTypes.h:251
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition SPIRVOps.cpp:69
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition SPIRVOps.cpp:93
constexpr StringRef attributeName()
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition SPIRVOps.cpp:49
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.