MLIR 23.0.0git
XeGPUOps.cpp
Go to the documentation of this file.
1//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/IR/Builders.h"
18
19#include "llvm/Support/Debug.h"
20
21#define DEBUG_TYPE "xegpu"
22
23using namespace mlir;
24using namespace mlir::xegpu;
25
26static bool isSharedMemory(const MemRefType &memrefTy) {
27 Attribute attr = memrefTy.getMemorySpace();
28 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
29 return intAttr.getInt() == 3;
30 if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
31 return memrefSpace.getValue() == MemorySpace::SLM;
32 if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
33 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
34 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
35}
36
37template <typename T>
38static std::string makeString(T array, bool breakline = false) {
39 std::string buf;
40 buf.clear();
41 llvm::raw_string_ostream os(buf);
42 os << "[";
43 for (size_t i = 1; i < array.size(); i++) {
44 os << array[i - 1] << ", ";
45 if (breakline)
46 os << "\n\t\t";
47 }
48 os << array.back() << "]";
49 return buf;
50}
51
  // Shaped types report their own shape; any other (scalar) type is modeled
  // as a single element, i.e. shape [1].
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1);
  return shape;
}
60
61static bool isReadHintOrNone(const CachePolicyAttr &attr) {
62 if (!attr)
63 return true;
64 auto kind = attr.getValue();
65 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
66 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
67}
68
69static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
70 if (!attr)
71 return true;
72 auto kind = attr.getValue();
73 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
74 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
75}
76
77static LogicalResult
79 VectorType valueTy, int64_t chunkSize,
81
82 auto maskVecTy = dyn_cast<VectorType>(maskTy);
83 auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
84 if (!valueTy) {
85 if (chunkSize > 1)
86 return emitError() << "Expecting chunk size == 1 for scalar result";
87 if (maskVecTy || offsetsVecTy)
88 return emitError() << "Expecting scalar mask and offsets.";
89 else if (maskVecTy && offsetsVecTy)
90 return emitError() << "Expecting a vector type result.";
91 return success();
92 }
93
94 auto valueSize = valueTy.getNumElements();
95 // SIMT mode with scalar mask and offsets.
96 if (!maskVecTy && !offsetsVecTy) {
97 if (valueSize != chunkSize)
98 return emitError() << "value elements must match chunk size "
99 << chunkSize;
100 return success();
101 }
102 auto maskShape = getShapeOf(maskTy);
103 auto valueShape = getShapeOf(valueTy);
104
105 if (!maskVecTy)
106 return emitError() << "Expecting a vector type mask.";
107 int64_t maskSize = maskVecTy.getNumElements();
108
109 if (chunkSize > 1) {
110 if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
111 return emitError() << "value elements must match chunk size "
112 << chunkSize;
113 } else {
114 if (valueSize != maskSize)
115 return emitError()
116 << "Mask should match value except the chunk size dim.";
117 }
118 llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
119 if (maskSize == 1)
120 return success();
121 if (chunkSize > 1)
122 expectedMaskShape.pop_back();
123 if (expectedMaskShape != maskShape)
124 return emitError() << "Mask should match value except the chunk size dim.";
125
126 return success();
127}
128
// Validates data/mem_desc/layout parameters for the matrix load/store ops.
// NOTE(review): the `emitError` callback parameter line is missing from this
// listing; the body below calls it for every diagnostic.
LogicalResult
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
                      UnitAttr subgroup_block_io, DistributeLayoutAttr layout,

  // A null data type means a scalar access; subgroup_block_io requires a
  // vector result.
  if (!dataTy) {
    if (subgroup_block_io)
      return emitError() << "subgroup_block_io "
                            "are only allowed when result is a VectorType.";
    else
      return success();
  }

  if (mdescTy.getRank() < 2)
    return emitError() << "mem_desc must be 2D or greater.";

  ArrayRef<int64_t> dataShape = dataTy.getShape();
  ArrayRef<int64_t> mdescShape = mdescTy.getShape();

  // Materialize the mem_desc strides (stored as an ArrayAttr of integers).
  SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
  ArrayAttr strideAttr = mdescTy.getStrideAttr();
  SmallVector<int64_t> strides;
  for (Attribute attr : strideAttr.getValue()) {
    strides.push_back(cast<IntegerAttr>(attr).getInt());
  }
  if (subgroup_block_io && layout) {
    auto laneData = layout.getEffectiveLaneDataAsInt();
    auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
    if (!laneData.empty()) {
      // All lane-data entries except the innermost must be 1 for the access
      // to be contiguous/coalesced.
      bool isLaneDataContiguous =
          std::all_of(laneData.begin(), std::prev(laneData.end()),
                      [](int x) { return x == 1; });
      if (!isLaneDataContiguous)
        return emitError() << "With subgroup_block_io, accessed data must be "
                              "contiguous and coalesced.";
      // NOTE(review): indexes laneLayout/blockShape/strides with laneData's
      // size — assumes all four have matching rank here; confirm that the
      // layout/mem_desc verifiers guarantee this.
      for (size_t i = 0; i < laneData.size(); ++i) {
        if (laneLayout[i] != blockShape[i])
          return emitError() << "With subgroup_block_io, the block shape must "
                                "match the lane layout.";
        if (laneLayout[i] != 1 && strides[i] != 1)
          return emitError() << "With subgroup_block_io, the distributed "
                                "dimensions must be contiguous.";
      }
    }
  }

  if (layout && !layout.isDistributable(
                    SmallVector<int64_t>(dataShape.begin(), dataShape.end())))
    return emitError() << "Value shape is not distributable with the layout";

  if (dataShape.size() == 2) {
    // 2-D data must fit inside the mem_desc in every dimension.
    if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                     [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
      return emitError() << "data shape must not exceed mem_desc shape.";
  } else {
    // if the subgroup_block_io attribute is set, mdescTy must have block
    // attribute
    if (subgroup_block_io && !blockShape.size())
      return emitError() << "mem_desc must have block attribute when "
                            "subgroup_block_io is set.";
    // if the subgroup_block_io attribute is set, the memdesc should be row
    // major
    if (subgroup_block_io && mdescTy.isColMajor())
      return emitError() << "mem_desc should be row major when "
                            "subgroup_block_io is set.";
  }

  return success();
}
198
199//===----------------------------------------------------------------------===//
200// XeGPU_CreateNdDescOp
201//===----------------------------------------------------------------------===//
202
203void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
204 Type tdesc, TypedValue<MemRefType> source) {
205 [[maybe_unused]] auto ty = source.getType();
206 assert(ty.hasStaticShape() && "expecting a memref with static shape");
207
208 build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
209 ValueRange({}) /* empty dynamic shape */,
210 ValueRange({}) /* empty dynamic strides */,
211 DenseI64ArrayAttr({}) /* const offsets */,
212 DenseI64ArrayAttr({}) /* empty const shape*/,
213 DenseI64ArrayAttr({}) /* empty const strides*/);
214}
215
// Builder taking mixed (static/dynamic) shape and strides for an integer
// (raw pointer) or memref source.
// NOTE(review): the `shape`/`strides` parameter lines are missing from this
// listing; the body below consumes both names.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;

  // Split each mixed list into SSA operands plus a static-constant attribute
  // (dynamic entries become ShapedType::kDynamic placeholders).
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // if shape and strides are from Memref, we don't need attributes for them
    // to keep the IR print clean (only do so for full-static case, otherwise
    // printer would fail trying to print empty array-attr).
    if (staticShape == memrefShape && staticStrides == memrefStrides &&
        dynamicShape.empty() && dynamicStrides.empty()) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
        staticStridesAttr);
}
254
// Builder for a statically-shaped memref source with explicit (mixed)
// offsets; shape and strides come from the memref type.
// NOTE(review): the `offsets` parameter line is missing from this listing;
// the body below consumes that name.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  // Split the mixed offsets into SSA operands plus a static attribute.
  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
        {} /* empty const shape*/, {} /* empty const strides*/);
}
271
// Builder taking mixed offsets, shape, and strides for an integer (raw
// pointer) or memref source.
// NOTE(review): the `offsets`/`shape`/`strides` parameter lines are missing
// from this listing; the body below consumes all three names.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
  assert(!shape.empty() && !offsets.empty() && !strides.empty() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;

  // Split each mixed list into SSA operands plus a static-constant attribute.
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // if shape and strides are from Memref, we don't need attributes for them
    // to keep the IR print clean (only do so for full-static case, otherwise
    // printer would fail trying to print empty array-attr).
    if (staticShape == memrefShape && staticStrides == memrefStrides &&
        dynamicShape.empty() && dynamicStrides.empty()) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}
317
318LogicalResult CreateNdDescOp::verify() {
319 size_t rank = getMixedSizes().size();
320 bool invalidRank = rank != getMixedStrides().size();
321 bool invalidElemTy = false;
322
323 // Memory space of created TensorDesc should match with the source.
324 // Both source and TensorDesc are considered for global memory by default,
325 // if the memory scope attr is not specified. If source is an integer,
326 // it is considered as ptr to global memory.
327 auto srcMemorySpace = getSourceMemorySpace();
328 auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
329 if (srcMemorySpace != tdescMemorySpace)
330 return emitOpError("Memory space mismatch.")
331 << " Source: " << srcMemorySpace
332 << ", TensorDesc: " << tdescMemorySpace;
333
334 if (size_t offsetRank = getMixedOffsets().size())
335 invalidRank |= (offsetRank != rank);
336
337 // check source type matches the rank if it is a memref.
338 // It also should have the same ElementType as TensorDesc.
339 if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
340 invalidElemTy |= memrefTy.getElementType() != getElementType();
341
342 if (llvm::isa<IntegerType>(getSourceType())) {
343 // strides and shape must present for integer source.
344 if (getMixedStrides().empty() || getMixedSizes().empty())
345 return emitOpError("expecting strides and shape to be present for "
346 "integer source.");
347 }
348
349 if (invalidRank)
350 return emitOpError(
351 "Expecting the rank of shape, strides, offsets, and source (if source "
352 "is a memref) should match with each other.");
353
354 // check result TensorDesc rank
355 if (getType().getRank() > (int64_t)rank)
356 return emitOpError(
357 "Expecting the TensorDesc rank is not greater than the "
358 "ranks of shape, strides, offsets or the memref source.");
359
360 if (invalidElemTy)
361 return emitOpError("TensorDesc should have the same element "
362 "type with the source if it is a memref.\n");
363
364 return success();
365}
366
// Parses an optional square-bracketed, comma-separated list whose entries are
// either SSA values or integer literals, e.g. `[%a : index, 4, %b : index]`.
// SSA entries are recorded as ShapedType::kDynamic in `integers`.
// NOTE(review): the function-name line and the `values` parameter line are
// missing from this listing; the body below consumes both, and `operand` is
// declared on a line also absent here.
    OpAsmParser &parser,
    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,

  SmallVector<int64_t, 4> integerVals;
  // Parses a single entry: an SSA operand (with optional `: type` when
  // valueTypes is requested) or an integer literal.
  auto parseIntegerOrValue = [&]() {
    auto res = parser.parseOptionalOperand(operand);

    if (res.has_value() && succeeded(res.value())) {
      values.push_back(operand);
      // Dynamic entries are marked with the kDynamic sentinel.
      integerVals.push_back(ShapedType::kDynamic);
      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
        return failure();
    } else {
      int64_t integer;
      if (failed(parser.parseInteger(integer)))
        return failure();
      integerVals.push_back(integer);
    }
    return success();
  };

  // If the optional values are given there must be left bracket
  if (parser.parseOptionalLSquare().succeeded()) {
    if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
        parser.parseRSquare())
      return parser.emitError(parser.getNameLoc())
             << "expected a list of SSA values or integers";
    integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
    return success();
  }

  // No bracket at all: the list is absent, which is valid.
  return success();
}
404
// Prints the mixed static/dynamic index list in square brackets; prints
// nothing when the list is absent or empty (matching the optional parse).
// NOTE(review): the function-name line is missing from this listing.
                                      OperandRange values,
                                      DenseI64ArrayAttr integers) {
  if (!integers || integers.empty())
    return;
  printDynamicIndexList(printer, op, values, integers,
                        /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
}
413//===----------------------------------------------------------------------===//
414// XeGPU_PrefetchNdOp
415//===----------------------------------------------------------------------===//
416
417void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
418 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
419 xegpu::CachePolicyAttr l2_hint,
420 xegpu::CachePolicyAttr l3_hint) {
421
422 return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
423 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
424}
425
426void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
427 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
428 xegpu::CachePolicyAttr l1_hint,
429 xegpu::CachePolicyAttr l2_hint,
430 xegpu::CachePolicyAttr l3_hint,
431 xegpu::DistributeLayoutAttr layout) {
432 SmallVector<Value> dynamicOffsets;
433 SmallVector<int64_t> staticOffsets;
434 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
435
436 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
437
438 build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
439 l2_hint, l3_hint, /*anchor_layout=*/layout);
440}
441
442LogicalResult PrefetchNdOp::verify() {
443 auto tdescTy = getTensorDescType();
444
445 if (!isReadHintOrNone(getL1HintAttr()))
446 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
447
448 if (!isReadHintOrNone(getL2HintAttr()))
449 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
450
451 if (!isReadHintOrNone(getL3HintAttr()))
452 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
453
454 int64_t tDescRank = tdescTy.getRank();
455 int64_t offsetSize = getMixedOffsets().size();
456 if (offsetSize != 0 && offsetSize != tDescRank)
457 return emitOpError(
458 "Mismatched ranks between offsets and tensor descriptor");
459
460 if (auto layout = getAnchorLayout()) {
461 if (!layout.isDistributable(getShapeOf(tdescTy)))
462 return emitOpError(
463 "TensorDesc shape is not distributable with the layout");
464 }
465
466 return success();
467}
468
469//===----------------------------------------------------------------------===//
470// XeGPU_LoadNdOp
471//===----------------------------------------------------------------------===//
472
473void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
474 Value tensorDesc, UnitAttr packed,
475 DenseI64ArrayAttr transpose,
476 xegpu::CachePolicyAttr l1_hint,
477 xegpu::CachePolicyAttr l2_hint,
478 xegpu::CachePolicyAttr l3_hint) {
479
480 return build(builder, state, retType, tensorDesc, ValueRange(),
481 DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
482 l3_hint, /*anchor_layout=*/nullptr);
483}
484
485void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
486 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
487 UnitAttr packed, DenseI64ArrayAttr transpose,
488 xegpu::CachePolicyAttr l1_hint,
489 xegpu::CachePolicyAttr l2_hint,
490 xegpu::CachePolicyAttr l3_hint,
491 xegpu::DistributeLayoutAttr layout) {
492 SmallVector<Value> dynamicOffsets;
493 SmallVector<int64_t> staticOffsets;
494 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
495
496 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
497
498 build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
499 packed, transpose, l1_hint, l2_hint, l3_hint,
500 /*anchor_layout=*/layout);
501}
502
// Verifies load_nd: result type, cache hints, and that the result shape is
// consistent with the tensor descriptor after accounting for the SIMT
// distribution case, transpose, packing (VNNI), and array_length.
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  // Load is a read; only read-style cache hints are valid.
  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // If the result vector is 1D and has less elements than the tensor
  // descriptor, it is supposed to be a SIMT op. The layout attribute in
  // tensor_desc is not needed.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    // SIMT mode doesn't need LayoutAttr.
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // For SIMT code, the load is evenly distributed across all lanes in a
    // subgroup. Since subgroup size is arch dependent, we only check even
    // distribution here.
    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;

    return success();
  }

  // Check SIMD mode.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();
    // Make sure the transpose value is valid, and apply it
    // (an out-of-range permutation only warns and is ignored).
    if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
      tdescShape = applyPermutation(tdescShape, trans);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      // VNNI packing: dim 0 shrinks by the vnni factor, which becomes a new
      // innermost dimension of the expected shape.
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  // array_length > 1 prepends a leading dimension to the expected shape.
  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1)
    tdescShape.insert(tdescShape.begin(), array_len);

  if (tdescShape != valueShape)
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;

  // Offsets are optional; when present they must match the tdesc rank.
  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = getMixedOffsets().size();
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  if (auto layout = getAnchorLayout()) {
    if (!layout.isDistributable(getShapeOf(tdescTy)))
      return emitOpError(
          "TensorDesc shape is not distributable with the layout");
  }

  return success();
}
595
596//===----------------------------------------------------------------------===//
597// XeGPU_StoreNdOp
598//===----------------------------------------------------------------------===//
599
600void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
601 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
602 xegpu::CachePolicyAttr l2_hint,
603 xegpu::CachePolicyAttr l3_hint) {
604
605 return build(builder, state, value, tensorDesc, ValueRange(),
606 DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint,
607 /*anchor_layout=*/nullptr);
608}
609
610void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
611 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
612 xegpu::CachePolicyAttr l1_hint,
613 xegpu::CachePolicyAttr l2_hint,
614 xegpu::CachePolicyAttr l3_hint,
615 xegpu::DistributeLayoutAttr layout) {
616 SmallVector<Value> dynamicOffsets;
617 SmallVector<int64_t> staticOffsets;
618 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
619
620 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
621
622 build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
623 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
624}
625
// Verifies store_nd: value type, cache hints, no array_length, and that the
// value shape is consistent with the tensor descriptor (or a valid SIMT
// distribution of it).
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // Tile
  auto valTy = getValueType();      // Vector

  if (dstTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  // Store is a write; only write-style cache hints are valid.
  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.\n");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // Similar to LoadNdOp, if the value vector is 1D and has less elements than
  // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
  // in tensor_desc is not needed.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    // SIMT mode doesn't need LayoutAttr.
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // Only even distribution across lanes can be checked here; the subgroup
    // size itself is architecture dependent.
    if (tdescElems % valueElems)
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;

    return success();
  }

  // SIMD code should have the same shape as the tensor descriptor.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape)
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;

  // Offsets are optional; when present they must match the tdesc rank.
  int64_t tDescRank = dstTy.getRank();
  int64_t offsetSize = getMixedOffsets().size();
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  if (auto layout = getAnchorLayout()) {
    if (!layout.isDistributable(tdescShape))
      return emitOpError(
          "TensorDesc shape is not distributable with the layout");
  }

  return success();
}
691
692//===----------------------------------------------------------------------===//
693// XeGPU_UpdateNDOffsetOp
694//===----------------------------------------------------------------------===//
695LogicalResult UpdateNdOffsetOp::verify() {
696 auto ty = getTensorDescType();
697
698 // number of offsets specified must match the rank of the tensor descriptor
699 if (ty.getRank() != (int64_t)getNumOffsets()) {
700 return emitOpError("Invalid number of offsets.");
701 }
702 return success();
703}
704
705//===----------------------------------------------------------------------===//
706// XeGPU_PrefetchOp
707//===----------------------------------------------------------------------===//
708LogicalResult PrefetchOp::verify() {
709 auto tdescTy = getTensorDescType();
710
711 if (!tdescTy && !getOffsets())
712 return emitOpError("Expects offsets.");
713
714 if (tdescTy && getOffsets())
715 return emitOpError("offsets not allowed.");
716
717 if (!isReadHintOrNone(getL1HintAttr()))
718 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
719
720 if (!isReadHintOrNone(getL2HintAttr()))
721 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
722
723 if (!isReadHintOrNone(getL3HintAttr()))
724 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
725
726 auto srcTy = getSourceType();
727 if (srcTy.isInteger() && !getOffsetAlignByteAttr())
728 return emitOpError("offset_align_byte is required with integer source.");
729
730 if (getOffsetAlignByteAttr() && !srcTy.isInteger())
731 return emitOpError("offset_align_byte only allowed with integer source.");
732
733 if (auto layout = getAnchorLayout()) {
734 // get the offset operand and its shape
735 if (auto offsets = getOffsets()) {
736 auto offsetsTy = offsets.getType();
737 if (llvm::isa<VectorType>(offsetsTy) &&
738 !layout.isDistributable(getShapeOf(offsetsTy)))
739 return emitOpError("offset shape is not distributable with the layout");
740 }
741 }
742
743 return success();
744}
745
746void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
747 xegpu::CachePolicyAttr l1_hint,
748 xegpu::CachePolicyAttr l2_hint,
749 xegpu::CachePolicyAttr l3_hint) {
750 build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
751 IntegerAttr{}, /*anchor_layout=*/nullptr);
752}
753
754//===----------------------------------------------------------------------===//
755// XeGPU_LoadGatherOp
756//===----------------------------------------------------------------------===//
757LogicalResult LoadGatherOp::verify() {
758 auto tdescTy = getTensorDescType();
759 auto maskTy = getMaskType();
760 auto valueTy = getValueType();
761
762 if (!tdescTy && !getOffsets())
763 return emitOpError("Expects offsets.");
764
765 if (tdescTy && getOffsets())
766 return emitOpError("offsets not allowed.");
767
768 if (!isReadHintOrNone(getL1HintAttr()))
769 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
770
771 if (!isReadHintOrNone(getL2HintAttr()))
772 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
773
774 if (!isReadHintOrNone(getL3HintAttr()))
775 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
776
777 auto srcTy = getSourceType();
778 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
779 auto memTy = dyn_cast<MemRefType>(srcTy);
780
781 if (memTy && (getElementType() != memTy.getElementType()))
782 return emitError() << "Value should have the same element type as MemRef.";
783
784 if (auto layout = getAnchorLayout()) {
785 if (!layout.isDistributable(getShapeOf(valueTy)))
786 return emitOpError("Value shape is not distributable with the layout");
787 }
788
789 auto offsetsTy = getOffsets().getType();
790 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
791 [&]() { return emitOpError(); });
792}
793
794void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
795 Type valueType, Value source, Value mask,
796 xegpu::CachePolicyAttr l1_hint,
797 xegpu::CachePolicyAttr l2_hint,
798 xegpu::CachePolicyAttr l3_hint) {
799 build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
800 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
801}
802
803void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
804 Type valueType, Value source,
805 ArrayRef<OpFoldResult> offsets, Value mask,
806 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
807 xegpu::CachePolicyAttr l2_hint,
808 xegpu::CachePolicyAttr l3_hint) {
809 auto loc = source.getLoc();
810 int64_t size = static_cast<int64_t>(offsets.size());
811 auto type = VectorType::get(size, builder.getIndexType());
812 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
813 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
814
815 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
816 l2_hint, l3_hint, /*anchor_layout=*/nullptr);
817}
818
819void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
820 Type valueType, Value source,
821 ArrayRef<OpFoldResult> offsets, Value mask,
822 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
823 xegpu::CachePolicyAttr l2_hint,
824 xegpu::CachePolicyAttr l3_hint,
825 DistributeLayoutAttr layout) {
826 auto loc = source.getLoc();
827 int64_t size = static_cast<int64_t>(offsets.size());
828 auto type = VectorType::get(size, builder.getIndexType());
829 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
830 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
831
832 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
833 l2_hint, l3_hint, layout);
834}
835
836//===----------------------------------------------------------------------===//
837// XeGPU_StoreScatterOp
838//===----------------------------------------------------------------------===//
839LogicalResult StoreScatterOp::verify() {
840 auto tdescTy = getTensorDescType();
841 auto maskTy = getMaskType();
842 auto valueTy = getValueType();
843
844 if (!tdescTy && !getOffsets())
845 return emitOpError("Expects offsets.");
846
847 if (tdescTy && getOffsets())
848 return emitOpError("offsets not allowed.");
849
850 if (!isWriteHintOrNone(getL1HintAttr()))
851 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
852
853 if (!isWriteHintOrNone(getL2HintAttr()))
854 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
855
856 if (!isWriteHintOrNone(getL3HintAttr()))
857 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
858
859 auto destTy = getDestType();
860 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
861 auto memTy = dyn_cast<MemRefType>(destTy);
862
863 if (memTy && (getElementType() != memTy.getElementType()))
864 return emitError() << "Value should have the same element type as MemRef.";
865
866 if (auto layout = getAnchorLayout()) {
867 if (!layout.isDistributable(getShapeOf(valueTy)))
868 return emitOpError("Value shape is not distributable with the layout");
869 }
870
871 auto offsetsTy = getOffsets().getType();
872 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
873 [&]() { return emitOpError(); });
874}
875
876void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
877 Value value, Value dest, Value mask,
878 xegpu::CachePolicyAttr l1_hint,
879 xegpu::CachePolicyAttr l2_hint,
880 xegpu::CachePolicyAttr l3_hint) {
881 build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
882 l2_hint, l3_hint, /*anchor_layout=*/nullptr);
883}
884
885void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
886 Value value, Value dest,
887 ArrayRef<OpFoldResult> offsets, Value mask,
888 IntegerAttr chunk_size,
889 xegpu::CachePolicyAttr l1_hint,
890 xegpu::CachePolicyAttr l2_hint,
891 xegpu::CachePolicyAttr l3_hint) {
892 auto loc = dest.getLoc();
893 int64_t size = static_cast<int64_t>(offsets.size());
894 auto type = VectorType::get(size, builder.getIndexType());
895 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
896 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
897
898 // Call the correct builder overload that does not expect result types.
899 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
900 l3_hint, /*anchor_layout=*/nullptr);
901}
902
903void StoreScatterOp::build(
904 OpBuilder &builder, OperationState &state, Value value, Value dest,
905 ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
906 xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
907 xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
908 auto loc = dest.getLoc();
909 int64_t size = static_cast<int64_t>(offsets.size());
910 auto type = VectorType::get(size, builder.getIndexType());
911 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
912 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
913
914 // Call the correct builder overload that does not expect result types.
915 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
916 l3_hint, layout);
917}
918
919//===----------------------------------------------------------------------===//
920// XeGPU_DpasOp
921//===----------------------------------------------------------------------===//
922LogicalResult DpasOp::verify() {
923 int64_t lhsRank = getLhsType().getRank();
924 int64_t rhsRank = getRhsType().getRank();
925 int64_t resRank = getResultType().getRank();
926 auto lhsShape = getLhsType().getShape();
927 auto rhsShape = getRhsType().getShape();
928 auto resShape = getResultType().getShape();
929
930 if (auto cdLayout = getLayoutCd())
931 if (!cdLayout->isDistributable(
932 SmallVector<int64_t>(resShape.begin(), resShape.end())))
933 return emitOpError("Value shape is not distributable with the layout");
934
935 if (auto aLayout = getLayoutA())
936 if (!aLayout->isDistributable(
937 SmallVector<int64_t>(lhsShape.begin(), lhsShape.end())))
938 return emitOpError("Value shape is not distributable with the layout");
939
940 if (auto bLayout = getLayoutB())
941 if (!bLayout->isDistributable(
942 SmallVector<int64_t>(rhsShape.begin(), rhsShape.end())))
943 return emitOpError("Value shape is not distributable with the layout");
944
945 if (getAcc() && getAcc().getType() != getResultType())
946 return emitOpError("Expecting the acc type to be the same as result.");
947
948 // SIMT code: the size of the B operand has to be a multiple of 32 bits.
949 // It skips the semantic check since lack of architecture information.
950 // Users need to ensure the correctness.
951 if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
952 auto numElems = getRhsType().getNumElements();
953 auto elemTy = getRhsType().getElementType();
954 auto factor = 32 / elemTy.getIntOrFloatBitWidth();
955 if (numElems % factor != 0)
956 return emitOpError("Expecting B operand to be a multiple of 32 bits.");
957 return success();
958 }
959
960 // SIMD code
961 if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
962 return emitOpError(
963 "expecting lhs and result to be a 2D vector, and rhs to be either "
964 "2D or 3D (packed) vector.");
965 auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
966 if (bK != lhsShape[1])
967 return emitOpError("K-dimension mismatch.");
968 if (lhsShape[0] != resShape[0])
969 return emitOpError("M-dimension mismatch.");
970 if (rhsShape[1] != resShape[1])
971 return emitOpError("N-dimension mismatch.");
972
973 return success();
974}
975
976//===----------------------------------------------------------------------===//
977// XeGPU_ConvertLayoutOp
978//===----------------------------------------------------------------------===//
979LogicalResult ConvertLayoutOp::verify() {
980 auto srcLayout = getInputLayout();
981 auto resLayout = getTargetLayout();
982 if (!srcLayout)
983 return emitOpError("expected input layout.");
984 if (!resLayout)
985 return emitOpError("expected target layout.");
986
987 // both input and target layouts should be WgLayout or SgLayout at the same
988 // time.
989 if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
990 (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
991 return emitOpError("expected input layout and target layout be WgLayout or "
992 "SgLayout at the same time.");
993
994 Type srcType = getSource().getType();
995 if (llvm::isa<VectorType>(srcType)) {
996 SmallVector<int64_t> shape(llvm::cast<VectorType>(srcType).getShape());
997 if (!srcLayout.isDistributable(shape))
998 return emitOpError(
999 "invalid input layout, data cannot be evenly distributed.");
1000
1001 if (!resLayout.isDistributable(shape))
1002 return emitOpError(
1003 "invalid target layout, data cannot be evenly distributed.");
1004 }
1005 return mlir::success();
1006}
1007
1008//===----------------------------------------------------------------------===//
1009// XeGPU_LoadMatrixOp
1010//===----------------------------------------------------------------------===//
1011void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
1014 DistributeLayoutAttr layout) {
1015 llvm::SmallVector<Value> dynamicOffsets;
1016 llvm::SmallVector<int64_t> staticOffsets;
1017 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1018 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1019 // Call the generated builder with all parameters (including optional ones as
1020 // nullptr/empty)
1021 build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
1022 /*subgroup_block_io=*/nullptr, layout);
1023}
1024
1025LogicalResult LoadMatrixOp::verify() {
1026
1027 auto resTy = dyn_cast<VectorType>(getRes().getType());
1028 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1029 MemDescType mdescTy = getMemDesc().getType();
1030
1031 return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
1032 getLayoutAttr(), [&]() { return emitError(); });
1033}
1034
1035//===----------------------------------------------------------------------===//
1036// XeGPU_StoreMatrixOp
1037//===----------------------------------------------------------------------===//
1038void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
1041 DistributeLayoutAttr layout) {
1042 llvm::SmallVector<Value> dynamicOffsets;
1043 llvm::SmallVector<int64_t> staticOffsets;
1044 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1045 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1046 build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
1047 /*subgroup_block_io=*/nullptr, layout);
1048}
1049
1050LogicalResult StoreMatrixOp::verify() {
1051
1052 auto dataTy = dyn_cast<VectorType>(getData().getType());
1053 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1054 MemDescType mdescTy = getMemDesc().getType();
1055 return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
1056 getLayoutAttr(), [&]() { return emitError(); });
1057}
1058
1059//===----------------------------------------------------------------------===//
1060// XeGPU_TruncfOp
1061//===----------------------------------------------------------------------===//
1062
1063LogicalResult TruncfOp::verify() {
1064 auto sourceVecType = dyn_cast<VectorType>(getSource().getType());
1065 auto resultVecType = dyn_cast<VectorType>(getResult().getType());
1066
1067 if (sourceVecType.getElementTypeBitWidth() <=
1068 resultVecType.getElementTypeBitWidth())
1069 return emitOpError("input type must be wider than result type.");
1070
1071 return success();
1072}
1073
1074//===----------------------------------------------------------------------===//
1075// XeGPU_DpasMxOp
1076//===----------------------------------------------------------------------===//
1077
1078LogicalResult DpasMxOp::verify() {
1079 if (getAcc() && getAcc().getType() != getResultType())
1080 return emitOpError("Expecting the acc type to be the same as result.");
1081
1082 return success();
1083}
1084
1085namespace mlir {
1086#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
1087} // namespace mlir
1088#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
1089#define GET_OP_CLASSES
1090#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
ArrayAttr()
static Type getValueType(Attribute attr)
Definition SPIRVOps.cpp:773
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static SmallVector< int64_t > getShapeOf(Type type)
Definition XeGPUOps.cpp:52
LogicalResult IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io, DistributeLayoutAttr layout, function_ref< InFlightDiagnostic()> emitError)
Definition XeGPUOps.cpp:130
static std::string makeString(T array, bool breakline=false)
Definition XeGPUOps.cpp:38
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
Definition XeGPUOps.cpp:69
static bool isReadHintOrNone(const CachePolicyAttr &attr)
Definition XeGPUOps.cpp:61
static LogicalResult isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, VectorType valueTy, int64_t chunkSize, function_ref< InFlightDiagnostic()> emitError)
Definition XeGPUOps.cpp:78
static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, DenseI64ArrayAttr integers)
Definition XeGPUOps.cpp:405
static bool isSharedMemory(const MemRefType &memrefTy)
Definition XeGPUOps.cpp:26
static ParseResult parseOptionalDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Definition XeGPUOps.cpp:367
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
Attributes are known-constant values of operations.
Definition Attributes.h:25
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IndexType getIndexType()
Definition Builders.cpp:55
This class represents a diagnostic that is inflight and set to be reported.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition Builders.h:209
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:79
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.