1//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
10#include "mlir/Dialect/Arith/Utils/Utils.h"
11#include "mlir/Dialect/GPU/IR/GPUDialect.h"
12#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
13#include "mlir/Dialect/Utils/IndexingUtils.h"
14#include "mlir/Dialect/Utils/StaticValueUtils.h"
15#include "mlir/Dialect/Vector/IR/VectorOps.h"
16#include "mlir/IR/Builders.h"
18
19#include "llvm/Support/Debug.h"
20
21#define DEBUG_TYPE "xegpu"
22
23using namespace mlir;
24using namespace mlir::xegpu;
25
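// Returns true if the given memref is placed in shared local memory (SLM):
// the integer address space 3, the xegpu SLM memory space, the xevm SHARED
// address space, or the GPU dialect workgroup memory address space.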
26static bool isSharedMemory(const MemRefType &memrefTy) {
27 Attribute attr = memrefTy.getMemorySpace();
28 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
29 return intAttr.getInt() == 3;
30 if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
31 return memrefSpace.getValue() == MemorySpace::SLM;
32 if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
33 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
34 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
35}
36
37template <typename T>
38static std::string makeString(T array, bool breakline = false) {
39 std::string buf;
40 buf.clear();
41 llvm::raw_string_ostream os(buf);
42 os << "[";
43 for (size_t i = 1; i < array.size(); i++) {
44 os << array[i - 1] << ", ";
45 if (breakline)
46 os << "\n\t\t";
47 }
48 os << array.back() << "]";
49 return buf;
50}
51
52static SmallVector<int64_t> getShapeOf(Type type) {
53 SmallVector<int64_t> shape;
54 if (auto ty = llvm::dyn_cast<ShapedType>(type))
55 shape = SmallVector<int64_t>(ty.getShape());
56 else
57 shape.push_back(1);
58 return shape;
59}
60
61static bool isReadHintOrNone(const CachePolicyAttr &attr) {
62 if (!attr)
63 return true;
64 auto kind = attr.getValue();
65 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
66 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
67}
68
69static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
70 if (!attr)
71 return true;
72 auto kind = attr.getValue();
73 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
74 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
75}
76
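// Shared verifier for the TensorDesc-based gather/scatter ops. The TensorDesc
// must be scattered; the mask shape must equal the TensorDesc shape minus the
// trailing chunk-size dimension; the value must either match the TensorDesc
// shape (SIMD) or be a 1-D vector of chunk-size elements (SIMT). For example,
// a scattered 16x8 TensorDesc with chunk size 8 pairs with a 16-element mask
// and either a 16x8 value (SIMD) or an 8-element per-lane value (SIMT).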
77static LogicalResult
78isValidGatherScatterParams(Type maskTy, VectorType valueTy,
79 TensorDescType tdescTy,
80 function_ref<InFlightDiagnostic()> emitError) {
81
82 if (!tdescTy.isScattered())
83 return emitError() << "Expects a scattered TensorDesc.";
84
85 auto chunkSize = tdescTy.getChunkSizeAsInt();
86 if (!valueTy) {
87 if (chunkSize > 1)
88 return emitError() << "Expecting chunk size == 1 for scalar result";
89 if (dyn_cast<VectorType>(maskTy))
90 return emitError() << "Expecting a vector type result.";
91 return success();
92 }
93
94 auto maskShape = getShapeOf(maskTy);
95 auto valueShape = getShapeOf(valueTy);
96 auto tdescShape = getShapeOf(tdescTy);
97
98 if (valueTy.getElementType() != tdescTy.getElementType())
99 return emitError()
100 << "Value should have the same element type as TensorDesc.";
101
102 llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
103 if (chunkSize > 1)
104 expectedMaskShape.pop_back();
105 if (expectedMaskShape != maskShape)
106 return emitError()
107 << "Mask should match TensorDesc except the chunk size dim.";
108
109 // a valid shape for SIMT case
110 if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
111 if (tdescTy.getLayoutAttr())
112 return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
113 return success();
114 }
115
116 if (tdescShape != valueShape)
117 return emitError() << "Value shape " << makeString(valueShape)
118 << " is neither a valid distribution for SIMT nor "
119 "consistent with the tensor descriptor for SIMD "
120 << tdescTy;
121 return success();
122}
123
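// Shared verifier for gather/scatter ops that address a plain buffer instead
// of a TensorDesc. With scalar mask and offsets (SIMT) the value must have
// exactly chunk-size elements; with vector mask and offsets the mask shape
// must match the value shape minus the trailing chunk-size dimension. For
// example, with chunk size 4 a vector<16xi1> mask pairs with a 16x4 value,
// while the SIMT form pairs a scalar mask with a 4-element value.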
124static LogicalResult
125isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
126 VectorType valueTy, int64_t chunkSize,
127 function_ref<InFlightDiagnostic()> emitError) {
128
129 auto maskVecTy = dyn_cast<VectorType>(maskTy);
130 auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
131 if (!valueTy) {
132 if (chunkSize > 1)
133 return emitError() << "Expecting chunk size == 1 for scalar result";
134 if (maskVecTy || offsetsVecTy)
135 return emitError() << "Expecting scalar mask and offsets.";
136 else if (maskVecTy && offsetsVecTy)
137 return emitError() << "Expecting a vector type result.";
138 return success();
139 }
140
141 auto valueSize = valueTy.getNumElements();
142 // SIMT mode with scalar mask and offsets.
143 if (!maskVecTy && !offsetsVecTy) {
144 if (valueSize != chunkSize)
145 return emitError() << "value elements must match chunk size "
146 << chunkSize;
147 return success();
148 }
149 auto maskShape = getShapeOf(maskTy);
150 auto valueShape = getShapeOf(valueTy);
151
152 if (!maskVecTy)
153 return emitError() << "Expecting a vector type mask.";
154 int64_t maskSize = maskVecTy.getNumElements();
155
156 if (chunkSize > 1) {
157 if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
158 return emitError() << "value elements must match chunk size "
159 << chunkSize;
160 } else {
161 if (valueSize != maskSize)
162 return emitError()
163 << "Mask should match value except the chunk size dim.";
164 }
165 llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
166 if (maskSize == 1)
167 return success();
168 if (chunkSize > 1)
169 expectedMaskShape.pop_back();
170 if (expectedMaskShape != maskShape)
171 return emitError() << "Mask should match value except the chunk size dim.";
172
173 return success();
174}
175
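// Shared verifier for load_matrix/store_matrix. The mem_desc must be 2-D and
// 2-D data must not exceed its shape. subgroup_block_io is only allowed for
// vector data; combined with a layout it further requires contiguous,
// coalesced accesses (all but the innermost lane_data entries equal to 1, a
// block shape matching the lane layout, and unit strides on distributed
// dimensions), and for 1-D data a blocked, row-major mem_desc.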
176LogicalResult
177IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
178 UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
179 function_ref<InFlightDiagnostic()> emitError) {
180
181 if (!dataTy) {
182 if (subgroup_block_io)
183 return emitError() << "subgroup_block_io "
184 "are only allowed when result is a VectorType.";
185 else
186 return success();
187 }
188
189 if (mdescTy.getRank() != 2)
190 return emitError() << "mem_desc must be 2D.";
191
192 ArrayRef<int64_t> dataShape = dataTy.getShape();
193 ArrayRef<int64_t> mdescShape = mdescTy.getShape();
194
195 SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
196 ArrayAttr strideAttr = mdescTy.getStrideAttr();
197 SmallVector<int64_t> strides;
198 for (Attribute attr : strideAttr.getValue()) {
199 strides.push_back(cast<IntegerAttr>(attr).getInt());
200 }
201 if (subgroup_block_io && layout) {
202 auto laneData = layout.getEffectiveLaneDataAsInt();
203 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
204 if (!laneData.empty()) {
205 bool isLaneDataContiguous =
206 std::all_of(laneData.begin(), std::prev(laneData.end()),
207 [](int x) { return x == 1; });
208 if (!isLaneDataContiguous)
209 return emitError() << "With subgroup_block_io, accessed data must be "
210 "contiguous and coalesced.";
211 for (size_t i = 0; i < laneData.size(); ++i) {
212 if (laneLayout[i] != blockShape[i])
213 return emitError() << "With subgroup_block_io, the block shape must "
214 "match the lane layout.";
215 if (laneLayout[i] != 1 && strides[i] != 1)
216 return emitError() << "With subgroup_block_io, the distributed "
217 "dimensions must be contiguous.";
218 }
219 }
220 }
221 if (dataShape.size() == 2) {
222 if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
223 [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
224 return emitError() << "data shape must not exceed mem_desc shape.";
225 } else {
226 // if the subgroup_block_io attribute is set, mdescTy must have block
227 // attribute
228 if (subgroup_block_io && !blockShape.size())
229 return emitError() << "mem_desc must have block attribute when "
230 "subgroup_block_io is set.";
231 // if the subgroup_block_io attribute is set, the memdesc should be row
232 // major
233 if (subgroup_block_io && mdescTy.isColMajor())
234 return emitError() << "mem_desc should be row major when "
235 "subgroup_block_io is set.";
236 }
237
238 return success();
239}
240
241//===----------------------------------------------------------------------===//
242// XeGPU_CreateNdDescOp
243//===----------------------------------------------------------------------===//
244
245void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
246 Type tdesc, TypedValue<MemRefType> source) {
247 [[maybe_unused]] auto ty = source.getType();
248 assert(ty.hasStaticShape() && "expecting a memref with static shape");
249
250 build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
251 ValueRange({}) /* empty dynamic shape */,
252 ValueRange({}) /* empty dynamic strides */,
253 DenseI64ArrayAttr({}) /* const offsets */,
254 DenseI64ArrayAttr({}) /* empty const shape*/,
255 DenseI64ArrayAttr({}) /* empty const strides*/);
256}
257
258void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
259 Type tdesc, Value source,
260 llvm::ArrayRef<OpFoldResult> shape,
261 llvm::ArrayRef<OpFoldResult> strides) {
262 Type srcTy = source.getType();
263 assert((isa<IntegerType, MemRefType>(srcTy)) &&
264 "Source has to be either int or memref.");
265
266 llvm::SmallVector<Value> dynamicShape;
267 llvm::SmallVector<Value> dynamicStrides;
268
269 llvm::SmallVector<int64_t> staticShape;
270 llvm::SmallVector<int64_t> staticStrides;
271
272 dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
273 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
274
275 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
276 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
277
278 if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
279 auto memrefShape = memrefTy.getShape();
280 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
281
282 // If shape and strides come from the memref type, we omit their attributes
283 // to keep the printed IR clean. Only do so in the fully static case;
284 // otherwise the printer would fail trying to print an empty array attribute.
285 if (staticShape == memrefShape && staticStrides == memrefStrides &&
286 dynamicShape.empty() && dynamicStrides.empty()) {
287 staticShapeAttr = DenseI64ArrayAttr();
288 staticStridesAttr = DenseI64ArrayAttr();
289 }
290 }
291
292 build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
293 dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
294 staticStridesAttr);
295}
296
297void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
298 Type tdesc, TypedValue<MemRefType> source,
299 llvm::ArrayRef<OpFoldResult> offsets) {
300 [[maybe_unused]] auto ty = source.getType();
301 assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());
302
303 llvm::SmallVector<int64_t> staticOffsets;
304 llvm::SmallVector<Value> dynamicOffsets;
305 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
306
307 build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
308 ValueRange({}) /* empty dynamic shape */,
309 ValueRange({}) /* empty dynamic strides */,
310 builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
311 {} /* empty const shape*/, {} /* empty const strides*/);
312}
313
314void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
315 Type tdesc, Value source,
316 llvm::ArrayRef<OpFoldResult> offsets,
317 llvm::ArrayRef<OpFoldResult> shape,
318 llvm::ArrayRef<OpFoldResult> strides) {
319 assert(!shape.empty() && !offsets.empty() && !strides.empty() &&
320 shape.size() == strides.size() && shape.size() == offsets.size());
321
322 Type srcTy = source.getType();
323 assert((isa<IntegerType, MemRefType>(srcTy)) &&
324 "Source has to be either int or memref.");
325
326 llvm::SmallVector<Value> dynamicOffsets;
327 llvm::SmallVector<Value> dynamicShape;
328 llvm::SmallVector<Value> dynamicStrides;
329
330 llvm::SmallVector<int64_t> staticOffsets;
331 llvm::SmallVector<int64_t> staticShape;
332 llvm::SmallVector<int64_t> staticStrides;
333
334 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
335 dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
336 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
337
338 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
339 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
340 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
341
342 if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
343 auto memrefShape = memrefTy.getShape();
344 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
345
346 // If shape and strides come from the memref type, we omit their attributes
347 // to keep the printed IR clean. Only do so in the fully static case;
348 // otherwise the printer would fail trying to print an empty array attribute.
349 if (staticShape == memrefShape && staticStrides == memrefStrides &&
350 dynamicShape.empty() && dynamicStrides.empty()) {
351 staticShapeAttr = DenseI64ArrayAttr();
352 staticStridesAttr = DenseI64ArrayAttr();
353 }
354 }
355
356 build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
357 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
358}
359
360LogicalResult CreateNdDescOp::verify() {
361 size_t rank = getMixedSizes().size();
362 bool invalidRank = rank != getMixedStrides().size();
363 bool invalidElemTy = false;
364
365 // The memory space of the created TensorDesc should match that of the
366 // source. Both default to global memory if the memory space attribute is
367 // not specified. If the source is an integer, it is treated as a pointer
368 // to global memory.
369 auto srcMemorySpace = getSourceMemorySpace();
370 auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
371 if (srcMemorySpace != tdescMemorySpace)
372 return emitOpError("Memory space mismatch.")
373 << " Source: " << srcMemorySpace
374 << ", TensorDesc: " << tdescMemorySpace;
375
376 if (size_t offsetRank = getMixedOffsets().size())
377 invalidRank |= (offsetRank != rank);
378
379 // check source type matches the rank if it is a memref.
380 // It also should have the same ElementType as TensorDesc.
381 if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
382 invalidElemTy |= memrefTy.getElementType() != getElementType();
383
384 if (llvm::isa<IntegerType>(getSourceType())) {
385 // Strides and shape must be present for an integer source.
386 if (getMixedStrides().empty() || getMixedSizes().empty())
387 return emitOpError("expecting strides and shape to be present for "
388 "integer source.");
389 }
390
391 if (invalidRank)
392 return emitOpError(
393 "Expecting the rank of shape, strides, offsets, and source (if source "
394 "is a memref) should match with each other.");
395
396 // check result TensorDesc rank
397 if (getType().getRank() > (int64_t)rank)
398 return emitOpError(
399 "Expecting the TensorDesc rank is not greater than the "
400 "ranks of shape, strides, offsets or the memref source.");
401
402 if (invalidElemTy)
403 return emitOpError("TensorDesc should have the same element "
404 "type with the source if it is a memref.\n");
405
406 if (getType().isScattered())
407 return emitOpError("Expects a non-scattered TensorDesc.\n");
408
409 return success();
410}
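// Illustrative IR accepted by the verifier above (assembly syntax shown
// approximately, not taken from this file):
//   %td = xegpu.create_nd_tdesc %src : memref<24x32xf32>
//                                       -> !xegpu.tensor_desc<8x16xf32>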
411
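// Parses an optional square-bracketed list mixing SSA operands and integer
// literals, e.g. "[%off, 16]". Operands are appended to `values`, and every
// element is recorded in `integers`, with ShapedType::kDynamic marking the
// dynamic (SSA) entries.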
412static ParseResult parseOptionalDynamicIndexList(
413 OpAsmParser &parser,
414 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
415 DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
416 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
417
418 SmallVector<int64_t, 4> integerVals;
419 auto parseIntegerOrValue = [&]() {
420 OpAsmParser::UnresolvedOperand operand;
421 auto res = parser.parseOptionalOperand(operand);
422
423 if (res.has_value() && succeeded(res.value())) {
424 values.push_back(operand);
425 integerVals.push_back(ShapedType::kDynamic);
426 if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
427 return failure();
428 } else {
429 int64_t integer;
430 if (failed(parser.parseInteger(integer)))
431 return failure();
432 integerVals.push_back(integer);
433 }
434 return success();
435 };
436
437 // If the optional values are given, there must be a left bracket.
438 if (parser.parseOptionalLSquare().succeeded()) {
439 if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
440 parser.parseRSquare())
441 return parser.emitError(parser.getNameLoc())
442 << "expected a list of SSA values or integers";
443 integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
444 return success();
445 }
446
447 return success();
448}
449
450static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
451 OperandRange values,
452 DenseI64ArrayAttr integers) {
453 if (!integers || integers.empty())
454 return;
455 printDynamicIndexList(printer, op, values, integers,
456 /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
457}
458//===----------------------------------------------------------------------===//
459// XeGPU_PrefetchNdOp
460//===----------------------------------------------------------------------===//
461
462void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
463 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
464 xegpu::CachePolicyAttr l2_hint,
465 xegpu::CachePolicyAttr l3_hint) {
466
467 return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
468 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
469}
470
471void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
472 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
473 xegpu::CachePolicyAttr l1_hint,
474 xegpu::CachePolicyAttr l2_hint,
475 xegpu::CachePolicyAttr l3_hint) {
476 SmallVector<Value> dynamicOffsets;
477 SmallVector<int64_t> staticOffsets;
478 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
479
480 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
481
482 build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
483 l2_hint, l3_hint, /*anchor_layout=*/nullptr);
484}
485
486LogicalResult PrefetchNdOp::verify() {
487 auto tdescTy = getTensorDescType();
488 if (tdescTy.isScattered())
489 return emitOpError("Expects a non-scattered TensorDesc.\n");
490
491 if (!isReadHintOrNone(getL1HintAttr()))
492 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
493
494 if (!isReadHintOrNone(getL2HintAttr()))
495 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
496
497 if (!isReadHintOrNone(getL3HintAttr()))
498 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
499
500 int64_t tDescRank = tdescTy.getRank();
501 int64_t offsetSize = getMixedOffsets().size();
502 if (offsetSize != 0 && offsetSize != tDescRank)
503 return emitOpError(
504 "Mismatched ranks between offsets and tensor descriptor");
505
506 return success();
507}
508
509//===----------------------------------------------------------------------===//
510// XeGPU_LoadNdOp
511//===----------------------------------------------------------------------===//
512
513void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
514 Value tensorDesc, UnitAttr packed,
515 DenseI64ArrayAttr transpose,
516 xegpu::CachePolicyAttr l1_hint,
517 xegpu::CachePolicyAttr l2_hint,
518 xegpu::CachePolicyAttr l3_hint) {
519
520 return build(builder, state, retType, tensorDesc, ValueRange(),
521 DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
522 l3_hint, /*anchor_layout=*/nullptr);
523}
524
525void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
526 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
527 UnitAttr packed, DenseI64ArrayAttr transpose,
528 xegpu::CachePolicyAttr l1_hint,
529 xegpu::CachePolicyAttr l2_hint,
530 xegpu::CachePolicyAttr l3_hint) {
531 SmallVector<Value> dynamicOffsets;
532 SmallVector<int64_t> staticOffsets;
533 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
534
535 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
536
537 build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
538 packed, transpose, l1_hint, l2_hint, l3_hint,
539 /*anchor_layout=*/nullptr);
540}
541
542LogicalResult LoadNdOp::verify() {
543 auto tdescTy = getTensorDescType();
544 auto valueTy = getType();
545
546 if (tdescTy.isScattered())
547 return emitOpError("Expects a non-scattered TensorDesc.\n");
548
549 if (tdescTy.getRank() > 2)
550 return emitOpError("Expects a 1D or 2D TensorDesc.\n");
551
552 if (!valueTy)
553 return emitOpError("Invalid result, it should be a VectorType.\n");
554
555 if (!isReadHintOrNone(getL1HintAttr()))
556 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
557
558 if (!isReadHintOrNone(getL2HintAttr()))
559 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
560
561 if (!isReadHintOrNone(getL3HintAttr()))
562 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
563
564 int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
565 int valueElems = valueTy.getNumElements();
566
567 // If the result vector is 1D and has fewer elements than the tensor
568 // descriptor, it is supposed to be a SIMT op. The layout attribute in
569 // tensor_desc is not needed.
570 if (valueElems < tdescElems && valueTy.getRank() == 1) {
571 // SIMT mode doesn't need LayoutAttr.
572 if (tdescTy.getLayoutAttr())
573 return emitOpError()
574 << "TensorDesc doesn't need LayoutAttr for SIMT code";
575
576 // For SIMT code, the load is evenly distributed across all lanes in a
577 // subgroup. Since subgroup size is arch dependent, we only check even
578 // distribution here.
579 if (tdescElems % valueElems)
580 return emitOpError()
581 << "Result shape " << makeString(getShapeOf(valueTy))
582 << " is not a valid distribution for tensor descriptor "
583 << tdescTy;
584
585 return success();
586 }
587
588 // Check SIMD mode.
589 auto tdescShape = getShapeOf(tdescTy);
590 auto valueShape = getShapeOf(valueTy);
591
592 if (getTranspose()) {
593 auto trans = getTranspose().value();
594 // Make sure the transpose value is valid, and apply it
595 if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
596 tdescShape = applyPermutation(tdescShape, trans);
597 else
598 mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
599 }
600
601 if (getPacked()) {
602 if (tdescTy.getRank() == 2) {
603 const int axis = 0;
604 auto vnni_factor = valueShape.back();
605 tdescShape[axis] /= vnni_factor;
606 tdescShape.push_back(vnni_factor);
607 } else {
608 mlir::emitWarning(getLoc())
609 << "Invalid Packed Attr. It is ignored (available for 2D "
610 "TensorDesc only).";
611 }
612 }
613
614 auto array_len = tdescTy.getArrayLength();
615 if (array_len > 1)
616 tdescShape.insert(tdescShape.begin(), array_len);
617
618 if (tdescShape != valueShape)
619 return emitOpError() << "Result shape " << makeString(valueShape)
620 << " is not consistent with tensor descriptor "
621 << tdescTy;
622
623 int64_t tDescRank = tdescTy.getRank();
624 int64_t offsetSize = getMixedOffsets().size();
625 if (offsetSize != 0 && offsetSize != tDescRank)
626 return emitOpError(
627 "Mismatched ranks between offsets and tensor descriptor");
628
629 return success();
630}
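// Illustrative shapes accepted by the verifier above (assembly syntax
// approximate, not taken from this file):
//   SIMD: !xegpu.tensor_desc<8x16xf16> loaded as vector<8x16xf16>
//   SIMT: !xegpu.tensor_desc<8x16xf16> loaded as vector<8xf16>, i.e. the 128
//         elements are evenly distributed across a 16-lane subgroup.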
631
632//===----------------------------------------------------------------------===//
633// XeGPU_StoreNdOp
634//===----------------------------------------------------------------------===//
635
636void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
637 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
638 xegpu::CachePolicyAttr l2_hint,
639 xegpu::CachePolicyAttr l3_hint) {
640
641 return build(builder, state, value, tensorDesc, ValueRange(),
642 DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint,
643 /*anchor_layout=*/nullptr);
644}
645
646void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
647 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
648 xegpu::CachePolicyAttr l1_hint,
649 xegpu::CachePolicyAttr l2_hint,
650 xegpu::CachePolicyAttr l3_hint) {
651 SmallVector<Value> dynamicOffsets;
652 SmallVector<int64_t> staticOffsets;
653 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
654
655 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
656
657 build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
658 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
659}
660
661LogicalResult StoreNdOp::verify() {
662 auto dstTy = getTensorDescType(); // Tile
663 auto valTy = getValueType(); // Vector
664
665 if (dstTy.isScattered())
666 return emitOpError("Expects a non-scattered TensorDesc.\n");
667
668 if (dstTy.getRank() > 2)
669 return emitOpError("Expects a 1D or 2D TensorDesc.\n");
670
671 if (!valTy)
672 return emitOpError("Expecting a VectorType result.\n");
673
674 if (!isWriteHintOrNone(getL1HintAttr()))
675 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
676
677 if (!isWriteHintOrNone(getL2HintAttr()))
678 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
679
680 if (!isWriteHintOrNone(getL3HintAttr()))
681 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
682
683 auto array_len = dstTy.getArrayLength();
684 if (array_len > 1)
685 return emitOpError("array length is not supported by store_nd.\n");
686
687 auto tdescElems = dstTy.getNumElements();
688 auto valueElems = valTy.getNumElements();
689
690 // Similar to LoadNdOp, if the value vector is 1D and has fewer elements than
691 // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
692 // in tensor_desc is not needed.
693 if (valTy.getRank() == 1 && valueElems < tdescElems) {
694 // SIMT mode doesn't need LayoutAttr.
695 if (dstTy.getLayoutAttr())
696 return emitOpError()
697 << "TensorDesc doesn't need LayoutAttr for SIMT code";
698
699 if (tdescElems % valueElems)
700 return emitOpError()
701 << "Value shape " << makeString(getShapeOf(valTy))
702 << " is not a valid distribution for tensor descriptor " << dstTy;
703
704 return success();
705 }
706
707 // SIMD code should have the same shape as the tensor descriptor.
708 auto tdescShape = getShapeOf(dstTy);
709 auto valueShape = getShapeOf(valTy);
710 if (tdescShape != valueShape)
711 return emitOpError() << "Value shape " << makeString(valueShape)
712 << " is not consistent with tensor descriptor "
713 << dstTy;
714
715 int64_t tDescRank = dstTy.getRank();
716 int64_t offsetSize = getMixedOffsets().size();
717 if (offsetSize != 0 && offsetSize != tDescRank)
718 return emitOpError(
719 "Mismatched ranks between offsets and tensor descriptor");
720
721 return success();
722}
723
724//===----------------------------------------------------------------------===//
725// XeGPU_UpdateNDOffsetOp
726//===----------------------------------------------------------------------===//
727LogicalResult UpdateNdOffsetOp::verify() {
728 auto ty = getTensorDescType();
729 if (ty.isScattered())
730 return emitOpError("Expects a non-scattered TensorDesc.\n");
731
732 // number of offsets specified must match the rank of the tensor descriptor
733 if (ty.getRank() != (int64_t)getNumOffsets()) {
734 return emitOpError("Invalid number of offsets.");
735 }
736 return success();
737}
738
739//===----------------------------------------------------------------------===//
740// XeGPU_CreateDescOp
741//===----------------------------------------------------------------------===//
742
743void CreateDescOp::build(OpBuilder &builder, OperationState &state,
744 TensorDescType TensorDesc, Value source,
745 llvm::ArrayRef<OpFoldResult> offsets) {
746 auto loc = source.getLoc();
747 int64_t size = static_cast<int64_t>(offsets.size());
748 auto type = VectorType::get(size, builder.getIndexType());
749 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
750 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
751 build(builder, state, TensorDesc, source, offset);
752}
753
754void CreateDescOp::build(OpBuilder &builder, OperationState &state,
755 TensorDescType TensorDesc, Value source,
756 llvm::ArrayRef<int64_t> offsets) {
757 auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
758 build(builder, state, TensorDesc, source, ofrs);
759}
760
761LogicalResult CreateDescOp::verify() {
762 auto tdescTy = getTensorDescType();
763
764 if (!tdescTy.isScattered())
765 return emitOpError("Expects a scattered TensorDesc.\n");
766
767 // The memory space of the created TensorDesc should match that of the
768 // source. Both default to global memory if the memory space attribute is
769 // not specified. If the source is an integer, it is treated as a pointer
770 // to global memory.
771 auto srcMemorySpace = getSourceMemorySpace();
772 auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
773 if (srcMemorySpace != tdescMemorySpace)
774 return emitOpError("Memory space mismatch.")
775 << " Source: " << srcMemorySpace
776 << ", TensorDesc: " << tdescMemorySpace;
777
778 // check total size
779 auto chunkSize = tdescTy.getChunkSizeAsInt();
780 SmallVector<int64_t> shape(getOffsetsType().getShape());
781 if (chunkSize != 1)
782 shape.push_back(chunkSize);
783
784 auto tdescShape = getShapeOf(tdescTy);
785 if (shape != tdescShape)
786 return emitOpError("Incorrect TensorDesc shape. ")
787 << "Expected is " << makeString(shape) << "\n";
788
789 return success();
790}
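// For example, offsets of type vector<16xindex> with chunk size 8 require a
// TensorDesc of shape [16, 8]; with chunk size 1 the expected shape is [16].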
791
792//===----------------------------------------------------------------------===//
793// XeGPU_PrefetchOp
794//===----------------------------------------------------------------------===//
795LogicalResult PrefetchOp::verify() {
796 auto tdescTy = getTensorDescType();
797
798 if (!tdescTy && !getOffsets())
799 return emitOpError("Expects offsets.");
800
801 if (tdescTy && getOffsets())
802 return emitOpError("offsets not allowed.");
803
804 if (tdescTy && !tdescTy.isScattered())
805 return emitOpError("Expects a scattered TensorDesc.");
806
807 if (!isReadHintOrNone(getL1HintAttr()))
808 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
809
810 if (!isReadHintOrNone(getL2HintAttr()))
811 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
812
813 if (!isReadHintOrNone(getL3HintAttr()))
814 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
815
816 auto srcTy = getSourceType();
817 if (srcTy.isInteger() && !getOffsetAlignByteAttr())
818 return emitOpError("offset_align_byte is required with integer source.");
819
820 if (getOffsetAlignByteAttr() && !srcTy.isInteger())
821 return emitOpError("offset_align_byte only allowed with integer source.");
822
823 return success();
824}
825
826void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
827 xegpu::CachePolicyAttr l1_hint,
828 xegpu::CachePolicyAttr l2_hint,
829 xegpu::CachePolicyAttr l3_hint) {
830 build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
831 IntegerAttr{}, /*anchor_layout=*/nullptr);
832}
833
834//===----------------------------------------------------------------------===//
835// XeGPU_LoadGatherOp
836//===----------------------------------------------------------------------===//
837LogicalResult LoadGatherOp::verify() {
838 auto tdescTy = getTensorDescType();
839 auto maskTy = getMaskType();
840 auto valueTy = getValueType();
841
842 if (!tdescTy && !getOffsets())
843 return emitOpError("Expects offsets.");
844
845 if (tdescTy && getOffsets())
846 return emitOpError("offsets not allowed.");
847
848 if (tdescTy && !tdescTy.isScattered())
849 return emitOpError("Expects a scattered TensorDesc.");
850
851 if (!isReadHintOrNone(getL1HintAttr()))
852 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
853
854 if (!isReadHintOrNone(getL2HintAttr()))
855 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
856
857 if (!isReadHintOrNone(getL3HintAttr()))
858 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
859
860 if (tdescTy)
861 return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
862 [&]() { return emitOpError(); });
863 auto srcTy = getSourceType();
864 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
865 auto memTy = dyn_cast<MemRefType>(srcTy);
866
867 if (memTy && (getElementType() != memTy.getElementType()))
868 return emitError() << "Value should have the same element type as MemRef.";
869
870 auto offsetsTy = getOffsets().getType();
871 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
872 [&]() { return emitOpError(); });
873}
874
875void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
876 Type valueType, Value source, Value mask,
877 xegpu::CachePolicyAttr l1_hint,
878 xegpu::CachePolicyAttr l2_hint,
879 xegpu::CachePolicyAttr l3_hint) {
880 build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
881 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
882}
883
884void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
885 Type valueType, Value source,
886 ArrayRef<OpFoldResult> offsets, Value mask,
887 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
888 xegpu::CachePolicyAttr l2_hint,
889 xegpu::CachePolicyAttr l3_hint) {
890 auto loc = source.getLoc();
891 int64_t size = static_cast<int64_t>(offsets.size());
892 auto type = VectorType::get(size, builder.getIndexType());
893 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
894 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
895
896 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
897 l2_hint, l3_hint, /*anchor_layout=*/nullptr);
898}
899
900void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
901 Type valueType, Value source,
902 ArrayRef<OpFoldResult> offsets, Value mask,
903 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
904 xegpu::CachePolicyAttr l2_hint,
905 xegpu::CachePolicyAttr l3_hint,
906 DistributeLayoutAttr layout) {
907 auto loc = source.getLoc();
908 int64_t size = static_cast<int64_t>(offsets.size());
909 auto type = VectorType::get(size, builder.getIndexType());
910 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
911 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
912
913 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
914 l2_hint, l3_hint, layout);
915}
916
917//===----------------------------------------------------------------------===//
918// XeGPU_StoreScatterOp
919//===----------------------------------------------------------------------===//
920LogicalResult StoreScatterOp::verify() {
921 auto tdescTy = getTensorDescType();
922 auto maskTy = getMaskType();
923 auto valueTy = getValueType();
924
925 if (!tdescTy && !getOffsets())
926 return emitOpError("Expects offsets.");
927
928 if (tdescTy && getOffsets())
929 return emitOpError("offsets not allowed.");
930
931 if (tdescTy && !tdescTy.isScattered())
932 return emitOpError("Expects a scattered TensorDesc.");
933
934 if (!isWriteHintOrNone(getL1HintAttr()))
935 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
936
937 if (!isWriteHintOrNone(getL2HintAttr()))
938 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
939
940 if (!isWriteHintOrNone(getL3HintAttr()))
941 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
942
943 if (tdescTy)
944 return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
945 [&]() { return emitOpError(); });
946
947 auto destTy = getDestType();
948 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
949 auto memTy = dyn_cast<MemRefType>(destTy);
950
951 if (memTy && (getElementType() != memTy.getElementType()))
952 return emitError() << "Value should have the same element type as MemRef.";
953
954 auto offsetsTy = getOffsets().getType();
955 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
956 [&]() { return emitOpError(); });
957}
958
959void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
960 Value value, Value dest, Value mask,
961 xegpu::CachePolicyAttr l1_hint,
962 xegpu::CachePolicyAttr l2_hint,
963 xegpu::CachePolicyAttr l3_hint) {
964 build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
965 l2_hint, l3_hint, /*anchor_layout=*/nullptr);
966}
967
968void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
969 Value value, Value dest,
970 ArrayRef<OpFoldResult> offsets, Value mask,
971 IntegerAttr chunk_size,
972 xegpu::CachePolicyAttr l1_hint,
973 xegpu::CachePolicyAttr l2_hint,
974 xegpu::CachePolicyAttr l3_hint) {
975 auto loc = dest.getLoc();
976 int64_t size = static_cast<int64_t>(offsets.size());
977 auto type = VectorType::get(size, builder.getIndexType());
978 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
979 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
980
981 // Call the correct builder overload that does not expect result types.
982 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
983 l3_hint, /*anchor_layout=*/nullptr);
984}
985
986void StoreScatterOp::build(
987 OpBuilder &builder, OperationState &state, Value value, Value dest,
988 ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
989 xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
990 xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
991 auto loc = dest.getLoc();
992 int64_t size = static_cast<int64_t>(offsets.size());
993 auto type = VectorType::get(size, builder.getIndexType());
994 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
995 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
996
997 // Call the correct builder overload that does not expect result types.
998 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
999 l3_hint, layout);
1000}
1001
1002//===----------------------------------------------------------------------===//
1003// XeGPU_UpdateOffsetOp
1004//===----------------------------------------------------------------------===//
1005void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
1006 mlir::Value tensorDesc,
1007 llvm::ArrayRef<OpFoldResult> offsets) {
1008 auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
1009 assert(tdescTy && "Expecting the source is a TensorDescType value.");
1010 auto loc = tensorDesc.getLoc();
1011 int64_t size = static_cast<int64_t>(offsets.size());
1012 auto type = VectorType::get({size}, builder.getIndexType());
1013 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
1014 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
1015 build(builder, state, tdescTy, tensorDesc, offset);
1016}
1017
1018void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
1019 Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
1020 auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
1021 build(builder, state, tensorDesc, ofrs);
1022}
1023
1024LogicalResult UpdateOffsetOp::verify() {
1025 auto tdescTy = getTensorDescType();
1026 if (!tdescTy.isScattered())
1027 return emitOpError("Expects a scattered TensorDesc.\n");
1028
1029 SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
1030 SmallVector<int64_t> offsetShape = getShapeOf(getOffsetsType());
1031 if (tdescTy.getChunkSizeAsInt() > 1)
1032 expectedOffsetShape.pop_back();
1033
1034 if (expectedOffsetShape != offsetShape)
1035 return emitOpError(
1036 "Offsets should match TensorDesc except the chunk size dim.");
1037
1038 return success();
1039}
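// For example, a scattered TensorDesc of shape [16, 8] with chunk size 8
// expects offsets of shape [16]; with chunk size 1 the offsets shape must
// match the TensorDesc shape.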
1040
1041//===----------------------------------------------------------------------===//
1042// XeGPU_DpasOp
1043//===----------------------------------------------------------------------===//
1044LogicalResult DpasOp::verify() {
1045 int64_t lhsRank = getLhsType().getRank();
1046 int64_t rhsRank = getRhsType().getRank();
1047 int64_t resRank = getResultType().getRank();
1048 auto lhsShape = getLhsType().getShape();
1049 auto rhsShape = getRhsType().getShape();
1050 auto resShape = getResultType().getShape();
1051
1052 if (getAcc() && getAcc().getType() != getResultType())
1053 return emitOpError("Expecting the acc type to be the same as result.");
1054
1055 // SIMT code: the size of the B operand has to be a multiple of 32 bits.
1056 // The full semantic check is skipped here because architecture information
1057 // is not available; users need to ensure correctness.
1058 if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
1059 auto numElems = getRhsType().getNumElements();
1060 auto elemTy = getRhsType().getElementType();
1061 auto factor = 32 / elemTy.getIntOrFloatBitWidth();
1062 if (numElems % factor != 0)
1063 return emitOpError("Expecting B operand to be a multiple of 32 bits.");
1064 return success();
1065 }
1066
1067 // SIMD code
1068 if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
1069 return emitOpError(
1070 "expecting lhs and result to be a 2D vector, and rhs to be either "
1071 "2D or 3D (packed) vector.");
1072 auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
1073 if (bK != lhsShape[1])
1074 return emitOpError("K-dimension mismatch.");
1075 if (lhsShape[0] != resShape[0])
1076 return emitOpError("M-dimension mismatch.");
1077 if (rhsShape[1] != resShape[1])
1078 return emitOpError("N-dimension mismatch.");
1079
1080 return success();
1081}
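// For the 1-D (SIMT) form above, the B operand must occupy a multiple of 32
// bits: with f16 elements the factor is 32 / 16 = 2, so the element count
// must be even; with i8 elements the factor is 4.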
1082
1083//===----------------------------------------------------------------------===//
1084// XeGPU_ConvertLayoutOp
1085//===----------------------------------------------------------------------===//
1086LogicalResult ConvertLayoutOp::verify() {
1087 auto srcLayout = getInputLayout();
1088 auto resLayout = getTargetLayout();
1089 if (!srcLayout)
1090 return emitOpError("expected input layout.");
1091 if (!resLayout)
1092 return emitOpError("expected target layout.");
1093
1094 // The input and target layouts must both be workgroup layouts or both be
1095 // subgroup layouts.
1096 if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
1097 (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
1098 return emitOpError("expected input layout and target layout be WgLayout or "
1099 "SgLayout at the same time.");
1100
1101 auto shape = getSource().getType().getShape();
1102 if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
1103 return emitOpError(
1104 "invalid input layout, data cannot be evenly distributed.");
1105
1106 if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
1107 return emitOpError(
1108 "invalid target layout, data cannot be evenly distributed.");
1109
1110 return mlir::success();
1111}
1112
1113OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
1114 if (getInputLayout() == getTargetLayout())
1115 return getSource();
1116 return {};
1117}
1118
1119struct FoldConvertLayoutOp : public OpRewritePattern<xegpu::ConvertLayoutOp> {
1120 using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
1121 LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
1122 PatternRewriter &rewriter) const override {
1123 if (op.getInputLayout() == op.getTargetLayout()) {
1124 rewriter.replaceOp(op, op.getSource());
1125 return success();
1126 }
1127 return failure();
1128 }
1129};
1130
1131void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1132 MLIRContext *context) {
1133 patterns.add<FoldConvertLayoutOp>(context);
1134}
1135
1136//===----------------------------------------------------------------------===//
1137// XeGPU_LoadMatrixOp
1138//===----------------------------------------------------------------------===//
1139void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
1140 TypedValue<MemDescType> memDesc,
1141 llvm::ArrayRef<OpFoldResult> offsets,
1142 DistributeLayoutAttr layout) {
1143 llvm::SmallVector<Value> dynamicOffsets;
1144 llvm::SmallVector<int64_t> staticOffsets;
1145 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1146 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1147 // Call the generated builder with all parameters (including optional ones as
1148 // nullptr/empty)
1149 build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
1150 /*subgroup_block_io=*/nullptr, layout);
1151}
1152
1153LogicalResult LoadMatrixOp::verify() {
1154
1155 auto resTy = dyn_cast<VectorType>(getRes().getType());
1156 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1157 MemDescType mdescTy = getMemDesc().getType();
1158
1159 return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
1160 getLayoutAttr(), [&]() { return emitError(); });
1161}
1162
1163//===----------------------------------------------------------------------===//
1164// XeGPU_StoreMatrixOp
1165//===----------------------------------------------------------------------===//
1166void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
1167 TypedValue<MemDescType> memDesc,
1168 llvm::ArrayRef<OpFoldResult> offsets,
1169 DistributeLayoutAttr layout) {
1170 llvm::SmallVector<Value> dynamicOffsets;
1171 llvm::SmallVector<int64_t> staticOffsets;
1172 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1173 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1174 build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
1175 /*subgroup_block_io=*/nullptr, layout);
1176}
1177
1178LogicalResult StoreMatrixOp::verify() {
1179
1180 auto dataTy = dyn_cast<VectorType>(getData().getType());
1181 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1182 MemDescType mdescTy = getMemDesc().getType();
1183 return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
1184 getLayoutAttr(), [&]() { return emitError(); });
1185}
1186
1187namespace mlir {
1188#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
1189} // namespace mlir
1190#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
1191#define GET_OP_CLASSES
1192#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>