XeGPUOps.cpp
1//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/IR/Builders.h"
18
19#include "llvm/Support/Debug.h"
20
21#define DEBUG_TYPE "xegpu"
22
23using namespace mlir;
24using namespace mlir::xegpu;
25
26static bool isSharedMemory(const MemRefType &memrefTy) {
27 Attribute attr = memrefTy.getMemorySpace();
28 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
29 return intAttr.getInt() == 3;
30 if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
31 return memrefSpace.getValue() == MemorySpace::SLM;
32 if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
33 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
34 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
35}
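// Illustrative note (types below are examples, not taken from this file): a
// memref such as memref<16x16xf16, 3> (integer address space 3), or one tagged
// with an XeGPU SLM / XeVM shared address-space attribute, would be classified
// as shared memory by isSharedMemory; a plain memref<16x16xf16> would not.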
36
37template <typename T>
38static std::string makeString(T array, bool breakline = false) {
39 std::string buf;
40 buf.clear();
41 llvm::raw_string_ostream os(buf);
42 os << "[";
43 for (size_t i = 1; i < array.size(); i++) {
44 os << array[i - 1] << ", ";
45 if (breakline)
46 os << "\n\t\t";
47 }
48 os << array.back() << "]";
49 return buf;
50}
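// For example, makeString(ArrayRef<int64_t>{8, 16, 2}) produces "[8, 16, 2]";
// it is used below to render shapes in diagnostics.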
51
52static SmallVector<int64_t> getShapeOf(Type type) {
53 SmallVector<int64_t> shape;
54 if (auto ty = llvm::dyn_cast<ShapedType>(type))
55 shape = SmallVector<int64_t>(ty.getShape());
56 else
57 shape.push_back(1);
58 return shape;
59}
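// For instance, getShapeOf returns {8, 16} for vector<8x16xf16> and the
// singleton shape {1} for a scalar type such as f32.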
60
61static bool isReadHintOrNone(const CachePolicyAttr &attr) {
62 if (!attr)
63 return true;
64 auto kind = attr.getValue();
65 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
66 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
67}
68
69static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
70 if (!attr)
71 return true;
72 auto kind = attr.getValue();
73 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
74 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
75}
76
77static LogicalResult
78isValidGatherScatterParams(Type maskTy, VectorType valueTy,
79 TensorDescType tdescTy,
80 function_ref<InFlightDiagnostic()> emitError) {
81
82 if (!tdescTy.isScattered())
83 return emitError() << "Expects a scattered TensorDesc.";
84
85 auto chunkSize = tdescTy.getChunkSizeAsInt();
86 if (!valueTy) {
87 if (chunkSize > 1)
88 return emitError() << "Expecting chunk size == 1 for scalar result";
89 if (dyn_cast<VectorType>(maskTy))
90 return emitError() << "Expecting a vector type result.";
91 return success();
92 }
93
94 auto maskShape = getShapeOf(maskTy);
95 auto valueShape = getShapeOf(valueTy);
96 auto tdescShape = getShapeOf(tdescTy);
97
98 if (valueTy.getElementType() != tdescTy.getElementType())
99 return emitError()
100 << "Value should have the same element type as TensorDesc.";
101
102 llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
103 if (chunkSize > 1)
104 expectedMaskShape.pop_back();
105 if (expectedMaskShape != maskShape)
106 return emitError()
107 << "Mask should match TensorDesc except the chunk size dim.";
108
109 // a valid shape for SIMT case
110 if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
111 if (tdescTy.getLayoutAttr())
112 return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
113 return success();
114 }
115
116 if (tdescShape != valueShape)
117 return emitError() << "Value shape " << makeString(valueShape)
118 << " is neither a valid distribution for SIMT nor "
119 "consistent with the tensor descriptor for SIMD "
120 << tdescTy;
121 return success();
122}
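// Sketch of the contract checked above, with hypothetical shapes: for a
// scattered TensorDesc of shape 16x8 and chunk size 8, a SIMD value must have
// shape 16x8 and the mask shape 16 (the chunk-size dim dropped), while a 1-D
// value with exactly chunkSize (8) elements is accepted as the SIMT per-lane
// distribution, provided the TensorDesc carries no LayoutAttr.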
123
124static LogicalResult
125isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
126 VectorType valueTy, int64_t chunkSize,
127 function_ref<InFlightDiagnostic()> emitError) {
128
129 auto maskVecTy = dyn_cast<VectorType>(maskTy);
130 auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
131 if (!valueTy) {
132 if (chunkSize > 1)
133 return emitError() << "Expecting chunk size == 1 for scalar result";
134 if (maskVecTy || offsetsVecTy)
135 return emitError() << "Expecting scalar mask and offsets.";
136 else if (maskVecTy && offsetsVecTy)
137 return emitError() << "Expecting a vector type result.";
138 return success();
139 }
140
141 auto valueSize = valueTy.getNumElements();
142 // SIMT mode with scalar mask and offsets.
143 if (!maskVecTy && !offsetsVecTy) {
144 if (valueSize != chunkSize)
145 return emitError() << "value elements must match chunk size "
146 << chunkSize;
147 return success();
148 }
149 auto maskShape = getShapeOf(maskTy);
150 auto valueShape = getShapeOf(valueTy);
151
152 if (!maskVecTy)
153 return emitError() << "Expecting a vector type mask.";
154 int64_t maskSize = maskVecTy.getNumElements();
155
156 if (chunkSize > 1) {
157 if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
158 return emitError() << "value elements must match chunk size "
159 << chunkSize;
160 } else {
161 if (valueSize != maskSize)
162 return emitError()
163 << "Mask should match value except the chunk size dim.";
164 }
165 llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
166 if (maskSize == 1)
167 return success();
168 if (chunkSize > 1)
169 expectedMaskShape.pop_back();
170 if (expectedMaskShape != maskShape)
171 return emitError() << "Mask should match value except the chunk size dim.";
172
173 return success();
174}
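// Hypothetical example of the rules above: with chunkSize == 4 and scalar mask
// and offsets (SIMT mode), the value must have exactly 4 elements; with a
// 16-element mask vector and chunkSize == 1, the value must also have 16
// elements and a matching shape.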
175
176LogicalResult
177IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
178 UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
179 function_ref<InFlightDiagnostic()> emitError) {
180
181 if (!dataTy) {
182 if (subgroup_block_io)
183 return emitError() << "subgroup_block_io "
184 "are only allowed when result is a VectorType.";
185 else
186 return success();
187 }
188
189 if (mdescTy.getRank() != 2)
190 return emitError() << "mem_desc must be 2D.";
191
192 ArrayRef<int64_t> dataShape = dataTy.getShape();
193 ArrayRef<int64_t> mdescShape = mdescTy.getShape();
194
195 SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
196 ArrayAttr strideAttr = mdescTy.getStrideAttr();
197 SmallVector<int64_t> strides;
198 for (Attribute attr : strideAttr.getValue()) {
199 strides.push_back(cast<IntegerAttr>(attr).getInt());
200 }
201 if (subgroup_block_io && layout) {
202 auto laneData = layout.getEffectiveLaneDataAsInt();
203 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
204 if (!laneData.empty()) {
205 bool isLaneDataContiguous =
206 std::all_of(laneData.begin(), std::prev(laneData.end()),
207 [](int x) { return x == 1; });
208 if (!isLaneDataContiguous)
209 return emitError() << "With subgroup_block_io, accessed data must be "
210 "contiguous and coalesced.";
211 for (size_t i = 0; i < laneData.size(); ++i) {
212 if (laneLayout[i] != blockShape[i])
213 return emitError() << "With subgroup_block_io, the block shape must "
214 "match the lane layout.";
215 if (laneLayout[i] != 1 && strides[i] != 1)
216 return emitError() << "With subgroup_block_io, the distributed "
217 "dimensions must be contiguous.";
218 }
219 }
220 }
221 if (dataShape.size() == 2) {
222 if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
223 [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
224 return emitError() << "data shape must not exceed mem_desc shape.";
225 } else {
226 // if the subgroup_block_io attribute is set, mdescTy must have block
227 // attribute
228 if (subgroup_block_io && !blockShape.size())
229 return emitError() << "mem_desc must have block attribute when "
230 "subgroup_block_io is set.";
231 // if the subgroup_block_io attribute is set, the memdesc should be row
232 // major
233 if (subgroup_block_io && mdescTy.isColMajor())
234 return emitError() << "mem_desc should be row major when "
235 "subgroup_block_io is set.";
236 }
237
238 return success();
239}
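// Illustrative configuration that passes the subgroup_block_io checks above
// (values are examples only): a 2-D mem_desc with block shape [1, 16],
// lane_layout [1, 16], lane_data [1, 1], and row-major strides [16, 1]; the
// lane data is contiguous, the block shape matches the lane layout, and the
// distributed (non-unit) dimension has stride 1.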
240
241//===----------------------------------------------------------------------===//
242// XeGPU_CreateNdDescOp
243//===----------------------------------------------------------------------===//
244
245void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
246 Type tdesc, TypedValue<MemRefType> source) {
247 [[maybe_unused]] auto ty = source.getType();
248 assert(ty.hasStaticShape() && "expecting a memref with static shape");
249
250 build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
251 ValueRange({}) /* empty dynamic shape */,
252 ValueRange({}) /* empty dynamic strides */,
253 DenseI64ArrayAttr({}) /* const offsets */,
254 DenseI64ArrayAttr({}) /* empty const shape*/,
255 DenseI64ArrayAttr({}) /* empty const strides*/);
256}
257
258void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
259 Type tdesc, Value source,
260 llvm::ArrayRef<OpFoldResult> shape,
261 llvm::ArrayRef<OpFoldResult> strides) {
262 Type srcTy = source.getType();
263 assert((isa<IntegerType, MemRefType>(srcTy)) &&
264 "Source has to be either int or memref.");
265
266 llvm::SmallVector<Value> dynamicShape;
267 llvm::SmallVector<Value> dynamicStrides;
268
269 llvm::SmallVector<int64_t> staticShape;
270 llvm::SmallVector<int64_t> staticStrides;
271
272 dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
273 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
274
275 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
276 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
277
278 if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
279 auto memrefShape = memrefTy.getShape();
280 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
281
282 // if shape and strides are from Memref, we don't need attributes for them
283 // to keep the IR print clean (only do so for full-static case, otherwise
284 // printer would fail trying to print empty array-attr).
285 if (staticShape == memrefShape && staticStrides == memrefStrides &&
286 dynamicShape.empty() && dynamicStrides.empty()) {
287 staticShapeAttr = DenseI64ArrayAttr();
288 staticStridesAttr = DenseI64ArrayAttr();
289 }
290 }
291
292 build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
293 dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
294 staticStridesAttr);
295}
296
297void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
298 Type tdesc, TypedValue<MemRefType> source,
299 llvm::ArrayRef<OpFoldResult> offsets) {
300 [[maybe_unused]] auto ty = source.getType();
301 assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());
302
303 llvm::SmallVector<int64_t> staticOffsets;
304 llvm::SmallVector<Value> dynamicOffsets;
305 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
306
307 build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
308 ValueRange({}) /* empty dynamic shape */,
309 ValueRange({}) /* empty dynamic strides */,
310 builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
311 {} /* empty const shape*/, {} /* empty const strides*/);
312}
313
314void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
315 Type tdesc, Value source,
316 llvm::ArrayRef<OpFoldResult> offsets,
317 llvm::ArrayRef<OpFoldResult> shape,
318 llvm::ArrayRef<OpFoldResult> strides) {
319 assert(!shape.empty() && !offsets.empty() && !strides.empty() &&
320 shape.size() == strides.size() && shape.size() == offsets.size());
321
322 Type srcTy = source.getType();
323 assert((isa<IntegerType, MemRefType>(srcTy)) &&
324 "Source has to be either int or memref.");
325
326 llvm::SmallVector<Value> dynamicOffsets;
327 llvm::SmallVector<Value> dynamicShape;
328 llvm::SmallVector<Value> dynamicStrides;
329
330 llvm::SmallVector<int64_t> staticOffsets;
331 llvm::SmallVector<int64_t> staticShape;
332 llvm::SmallVector<int64_t> staticStrides;
333
334 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
335 dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
336 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
337
338 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
339 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
340 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
341
342 if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
343 auto memrefShape = memrefTy.getShape();
344 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
345
346 // if shape and strides are from Memref, we don't need attributes for them
347 // to keep the IR print clean (only do so for full-static case, otherwise
348 // printer would fail trying to print empty array-attr).
349 if (staticShape == memrefShape && staticStrides == memrefStrides &&
350 dynamicShape.empty() && dynamicStrides.empty()) {
351 staticShapeAttr = DenseI64ArrayAttr();
352 staticStridesAttr = DenseI64ArrayAttr();
353 }
354 }
355
356 build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
357 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
358}
359
360LogicalResult CreateNdDescOp::verify() {
361 size_t rank = getMixedSizes().size();
362 bool invalidRank = rank != getMixedStrides().size();
363 bool invalidElemTy = false;
364
365 // The memory space of the created TensorDesc should match that of the
366 // source. Both default to global memory when the memory space attribute is
367 // not specified. An integer source is treated as a pointer to global
368 // memory.
369 auto srcMemorySpace = getSourceMemorySpace();
370 auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
371 if (srcMemorySpace != tdescMemorySpace)
372 return emitOpError("Memory space mismatch.")
373 << " Source: " << srcMemorySpace
374 << ", TensorDesc: " << tdescMemorySpace;
375
376 if (size_t offsetRank = getMixedOffsets().size())
377 invalidRank |= (offsetRank != rank);
378
379 // check source type matches the rank if it is a memref.
380 // It also should have the same ElementType as TensorDesc.
381 if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
382 invalidElemTy |= memrefTy.getElementType() != getElementType();
383
384 if (llvm::isa<IntegerType>(getSourceType())) {
385 // strides and shape must be present for an integer source.
386 if (getMixedStrides().empty() || getMixedSizes().empty())
387 return emitOpError("expecting strides and shape to be present for "
388 "integer source.");
389 }
390
391 if (invalidRank)
392 return emitOpError(
393 "Expecting the rank of shape, strides, offsets, and source (if source "
394 "is a memref) should match with each other.");
395
396 // check result TensorDesc rank
397 if (getType().getRank() > (int64_t)rank)
398 return emitOpError(
399 "Expecting the TensorDesc rank is not greater than the "
400 "ranks of shape, strides, offsets or the memref source.");
401
402 if (invalidElemTy)
403 return emitOpError("TensorDesc should have the same element "
404 "type with the source if it is a memref.\n");
405
406 if (getType().isScattered())
407 return emitOpError("Expects a non-scattered TensorDesc.\n");
408
409 return success();
410}
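// A minimal usage sketch in XeGPU assembly (illustrative; see the dialect op
// documentation for the authoritative syntax):
//   %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32>
//                                         -> !xegpu.tensor_desc<8x16xf32>
// Here the memref supplies static shape and strides, both sides are in global
// memory, and the 2-D TensorDesc rank does not exceed that of the source, so
// the checks in verify() above succeed.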
411
412static ParseResult parseOptionalDynamicIndexList(
413 OpAsmParser &parser,
414 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
415 DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
416 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
417
418 SmallVector<int64_t, 4> integerVals;
419 auto parseIntegerOrValue = [&]() {
420 OpAsmParser::UnresolvedOperand operand;
421 auto res = parser.parseOptionalOperand(operand);
422
423 if (res.has_value() && succeeded(res.value())) {
424 values.push_back(operand);
425 integerVals.push_back(ShapedType::kDynamic);
426 if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
427 return failure();
428 } else {
429 int64_t integer;
430 if (failed(parser.parseInteger(integer)))
431 return failure();
432 integerVals.push_back(integer);
433 }
434 return success();
435 };
436
437 // If the optional values are given, there must be a left bracket.
438 if (parser.parseOptionalLSquare().succeeded()) {
439 if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
440 parser.parseRSquare())
441 return parser.emitError(parser.getNameLoc())
442 << "expected a list of SSA values or integers";
443 integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
444 return success();
445 }
446
447 return success();
448}
449
450static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
451 OperandRange values,
452 DenseI64ArrayAttr integers) {
453 if (!integers || integers.empty())
454 return;
455 printDynamicIndexList(printer, op, values, integers,
456 /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
457}
458//===----------------------------------------------------------------------===//
459// XeGPU_PrefetchNdOp
460//===----------------------------------------------------------------------===//
461
462void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
463 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
464 xegpu::CachePolicyAttr l2_hint,
465 xegpu::CachePolicyAttr l3_hint) {
466
467 return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
468 l1_hint, l2_hint, l3_hint);
469}
470
471void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
472 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
473 xegpu::CachePolicyAttr l1_hint,
474 xegpu::CachePolicyAttr l2_hint,
475 xegpu::CachePolicyAttr l3_hint) {
476 SmallVector<Value> dynamicOffsets;
477 SmallVector<int64_t> staticOffsets;
478 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
479
480 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
481
482 build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
483 l2_hint, l3_hint);
484}
485
486LogicalResult PrefetchNdOp::verify() {
487 auto tdescTy = getTensorDescType();
488 if (tdescTy.isScattered())
489 return emitOpError("Expects a non-scattered TensorDesc.\n");
490
491 if (!isReadHintOrNone(getL1HintAttr()))
492 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
493
494 if (!isReadHintOrNone(getL2HintAttr()))
495 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
496
497 if (!isReadHintOrNone(getL3HintAttr()))
498 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
499
500 int64_t tDescRank = tdescTy.getRank();
501 int64_t offsetSize = getMixedOffsets().size();
502 if (offsetSize != 0 && offsetSize != tDescRank)
503 return emitOpError(
504 "Mismatched ranks between offsets and tensor descriptor");
505
506 return success();
507}
508
509//===----------------------------------------------------------------------===//
510// XeGPU_LoadNdOp
511//===----------------------------------------------------------------------===//
512
513void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
514 Value tensorDesc, UnitAttr packed,
515 DenseI64ArrayAttr transpose,
516 xegpu::CachePolicyAttr l1_hint,
517 xegpu::CachePolicyAttr l2_hint,
518 xegpu::CachePolicyAttr l3_hint) {
519
520 return build(builder, state, retType, tensorDesc, ValueRange(),
521 DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
522 l3_hint);
523}
524
525void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
526 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
527 UnitAttr packed, DenseI64ArrayAttr transpose,
528 xegpu::CachePolicyAttr l1_hint,
529 xegpu::CachePolicyAttr l2_hint,
530 xegpu::CachePolicyAttr l3_hint) {
531 SmallVector<Value> dynamicOffsets;
532 SmallVector<int64_t> staticOffsets;
533 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
534
535 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
536
537 build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
538 packed, transpose, l1_hint, l2_hint, l3_hint);
539}
540
541LogicalResult LoadNdOp::verify() {
542 auto tdescTy = getTensorDescType();
543 auto valueTy = getType();
544
545 if (tdescTy.isScattered())
546 return emitOpError("Expects a non-scattered TensorDesc.\n");
547
548 if (tdescTy.getRank() > 2)
549 return emitOpError("Expects a 1D or 2D TensorDesc.\n");
550
551 if (!valueTy)
552 return emitOpError("Invalid result, it should be a VectorType.\n");
553
554 if (!isReadHintOrNone(getL1HintAttr()))
555 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
556
557 if (!isReadHintOrNone(getL2HintAttr()))
558 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
559
560 if (!isReadHintOrNone(getL3HintAttr()))
561 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
562
563 int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
564 int valueElems = valueTy.getNumElements();
565
566 // If the result vector is 1D and has fewer elements than the tensor
567 // descriptor, it is supposed to be a SIMT op. The layout attribute in
568 // tensor_desc is not needed.
569 if (valueElems < tdescElems && valueTy.getRank() == 1) {
570 // SIMT mode doesn't need LayoutAttr.
571 if (tdescTy.getLayoutAttr())
572 return emitOpError()
573 << "TensorDesc doesn't need LayoutAttr for SIMT code";
574
575 // For SIMT code, the load is evenly distributed across all lanes in a
576 // subgroup. Since subgroup size is arch dependent, we only check even
577 // distribution here.
578 if (tdescElems % valueElems)
579 return emitOpError()
580 << "Result shape " << makeString(getShapeOf(valueTy))
581 << " is not a valid distribution for tensor descriptor "
582 << tdescTy;
583
584 return success();
585 }
586
587 // Check SIMD mode.
588 auto tdescShape = getShapeOf(tdescTy);
589 auto valueShape = getShapeOf(valueTy);
590
591 if (getTranspose()) {
592 auto trans = getTranspose().value();
593 // Make sure the transpose value is valid, and apply it
594 if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
595 tdescShape = applyPermutation(tdescShape, trans);
596 else
597 mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
598 }
599
600 if (getPacked()) {
601 if (tdescTy.getRank() == 2) {
602 const int axis = 0;
603 auto vnni_factor = valueShape.back();
604 tdescShape[axis] /= vnni_factor;
605 tdescShape.push_back(vnni_factor);
606 } else {
607 mlir::emitWarning(getLoc())
608 << "Invalid Packed Attr. It is ignored (available for 2D "
609 "TensorDesc only).";
610 }
611 }
612
613 auto array_len = tdescTy.getArrayLength();
614 if (array_len > 1)
615 tdescShape.insert(tdescShape.begin(), array_len);
616
617 if (tdescShape != valueShape)
618 return emitOpError() << "Result shape " << makeString(valueShape)
619 << " is not consistent with tensor descriptor "
620 << tdescTy;
621
622 int64_t tDescRank = tdescTy.getRank();
623 int64_t offsetSize = getMixedOffsets().size();
624 if (offsetSize != 0 && offsetSize != tDescRank)
625 return emitOpError(
626 "Mismatched ranks between offsets and tensor descriptor");
627
628 return success();
629}
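// Worked example for the SIMT branch above (numbers are illustrative): a
// tensor_desc<8x16xf16> holds 128 elements, so a 1-D result vector<8xf16> is
// accepted because 128 % 8 == 0, i.e. the load distributes evenly across,
// e.g., a 16-lane subgroup; a vector<8x16xf16> result instead takes the SIMD
// path and must match the descriptor shape exactly.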
630
631//===----------------------------------------------------------------------===//
632// XeGPU_StoreNdOp
633//===----------------------------------------------------------------------===//
634
635void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
636 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
637 xegpu::CachePolicyAttr l2_hint,
638 xegpu::CachePolicyAttr l3_hint) {
639
640 return build(builder, state, value, tensorDesc, ValueRange(),
641 DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
642}
643
644void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
645 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
646 xegpu::CachePolicyAttr l1_hint,
647 xegpu::CachePolicyAttr l2_hint,
648 xegpu::CachePolicyAttr l3_hint) {
649 SmallVector<Value> dynamicOffsets;
650 SmallVector<int64_t> staticOffsets;
651 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
652
653 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
654
655 build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
656 l1_hint, l2_hint, l3_hint);
657}
658
659LogicalResult StoreNdOp::verify() {
660 auto dstTy = getTensorDescType(); // Tile
661 auto valTy = getValueType(); // Vector
662
663 if (dstTy.isScattered())
664 return emitOpError("Expects a non-scattered TensorDesc.\n");
665
666 if (dstTy.getRank() > 2)
667 return emitOpError("Expects a 1D or 2D TensorDesc.\n");
668
669 if (!valTy)
670 return emitOpError("Expecting a VectorType result.\n");
671
672 if (!isWriteHintOrNone(getL1HintAttr()))
673 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
674
675 if (!isWriteHintOrNone(getL2HintAttr()))
676 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
677
678 if (!isWriteHintOrNone(getL3HintAttr()))
679 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
680
681 auto array_len = dstTy.getArrayLength();
682 if (array_len > 1)
683 return emitOpError("array length is not supported by store_nd.\n");
684
685 auto tdescElems = dstTy.getNumElements();
686 auto valueElems = valTy.getNumElements();
687
688 // Similar to LoadNdOp, if the value vector is 1D and has fewer elements than
689 // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
690 // in tensor_desc is not needed.
691 if (valTy.getRank() == 1 && valueElems < tdescElems) {
692 // SIMT mode doesn't need LayoutAttr.
693 if (dstTy.getLayoutAttr())
694 return emitOpError()
695 << "TensorDesc doesn't need LayoutAttr for SIMT code";
696
697 if (tdescElems % valueElems)
698 return emitOpError()
699 << "Value shape " << makeString(getShapeOf(valTy))
700 << " is not a valid distribution for tensor descriptor " << dstTy;
701
702 return success();
703 }
704
705 // SIMD code should have the same shape as the tensor descriptor.
706 auto tdescShape = getShapeOf(dstTy);
707 auto valueShape = getShapeOf(valTy);
708 if (tdescShape != valueShape)
709 return emitOpError() << "Value shape " << makeString(valueShape)
710 << " is not consistent with tensor descriptor "
711 << dstTy;
712
713 int64_t tDescRank = dstTy.getRank();
714 int64_t offsetSize = getMixedOffsets().size();
715 if (offsetSize != 0 && offsetSize != tDescRank)
716 return emitOpError(
717 "Mismatched ranks between offsets and tensor descriptor");
718
719 return success();
720}
721
722//===----------------------------------------------------------------------===//
723// XeGPU_UpdateNDOffsetOp
724//===----------------------------------------------------------------------===//
725LogicalResult UpdateNdOffsetOp::verify() {
726 auto ty = getTensorDescType();
727 if (ty.isScattered())
728 return emitOpError("Expects a non-scattered TensorDesc.\n");
729
730 // number of offsets specified must match the rank of the tensor descriptor
731 if (ty.getRank() != (int64_t)getNumOffsets()) {
732 return emitOpError("Invalid number of offsets.");
733 }
734 return success();
735}
736
737//===----------------------------------------------------------------------===//
738// XeGPU_CreateDescOp
739//===----------------------------------------------------------------------===//
740
741void CreateDescOp::build(OpBuilder &builder, OperationState &state,
742 TensorDescType TensorDesc, Value source,
743 llvm::ArrayRef<OpFoldResult> offsets) {
744 auto loc = source.getLoc();
745 int64_t size = static_cast<int64_t>(offsets.size());
746 auto type = VectorType::get(size, builder.getIndexType());
747 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
748 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
749 build(builder, state, TensorDesc, source, offset);
750}
751
752void CreateDescOp::build(OpBuilder &builder, OperationState &state,
753 TensorDescType TensorDesc, Value source,
754 llvm::ArrayRef<int64_t> offsets) {
755 auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
756 build(builder, state, TensorDesc, source, ofrs);
757}
758
759LogicalResult CreateDescOp::verify() {
760 auto tdescTy = getTensorDescType();
761
762 if (!tdescTy.isScattered())
763 return emitOpError("Expects a scattered TensorDesc.\n");
764
765 // The memory space of the created TensorDesc should match that of the
766 // source. Both default to global memory when the memory space attribute is
767 // not specified. An integer source is treated as a pointer to global
768 // memory.
769 auto srcMemorySpace = getSourceMemorySpace();
770 auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
771 if (srcMemorySpace != tdescMemorySpace)
772 return emitOpError("Memory space mismatch.")
773 << " Source: " << srcMemorySpace
774 << ", TensorDesc: " << tdescMemorySpace;
775
776 // check total size
777 auto chunkSize = tdescTy.getChunkSizeAsInt();
778 SmallVector<int64_t> shape(getOffsetsType().getShape());
779 if (chunkSize != 1)
780 shape.push_back(chunkSize);
781
782 auto tdescShape = getShapeOf(tdescTy);
783 if (shape != tdescShape)
784 return emitOpError("Incorrect TensorDesc shape. ")
785 << "Expected is " << makeString(shape) << "\n";
786
787 return success();
788}
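// Shape-rule sketch with hypothetical values: offsets of type vector<16xindex>
// and chunk size 8 require a TensorDesc of shape 16x8; with the default chunk
// size of 1, the expected shape is simply 16.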
789
790//===----------------------------------------------------------------------===//
791// XeGPU_PrefetchOp
792//===----------------------------------------------------------------------===//
793LogicalResult PrefetchOp::verify() {
794 auto tdescTy = getTensorDescType();
795
796 if (!tdescTy && !getOffsets())
797 return emitOpError("Expects offsets.");
798
799 if (tdescTy && getOffsets())
800 return emitOpError("offsets not allowed.");
801
802 if (tdescTy && !tdescTy.isScattered())
803 return emitOpError("Expects a scattered TensorDesc.");
804
805 if (!isReadHintOrNone(getL1HintAttr()))
806 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
807
808 if (!isReadHintOrNone(getL2HintAttr()))
809 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
810
811 if (!isReadHintOrNone(getL3HintAttr()))
812 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
813
814 auto srcTy = getSourceType();
815 if (srcTy.isInteger() && !getOffsetAlignByteAttr())
816 return emitOpError("offset_align_byte is required with integer source.");
817
818 if (getOffsetAlignByteAttr() && !srcTy.isInteger())
819 return emitOpError("offset_align_byte only allowed with integer source.");
820
821 return success();
822}
823
824void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
825 xegpu::CachePolicyAttr l1_hint,
826 xegpu::CachePolicyAttr l2_hint,
827 xegpu::CachePolicyAttr l3_hint) {
828 build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
829 IntegerAttr{});
830}
831
832//===----------------------------------------------------------------------===//
833// XeGPU_LoadGatherOp
834//===----------------------------------------------------------------------===//
835LogicalResult LoadGatherOp::verify() {
836 auto tdescTy = getTensorDescType();
837 auto maskTy = getMaskType();
838 auto valueTy = getValueType();
839
840 if (!tdescTy && !getOffsets())
841 return emitOpError("Expects offsets.");
842
843 if (tdescTy && getOffsets())
844 return emitOpError("offsets not allowed.");
845
846 if (tdescTy && !tdescTy.isScattered())
847 return emitOpError("Expects a scattered TensorDesc.");
848
849 if (!isReadHintOrNone(getL1HintAttr()))
850 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
851
852 if (!isReadHintOrNone(getL2HintAttr()))
853 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
854
855 if (!isReadHintOrNone(getL3HintAttr()))
856 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
857
858 if (tdescTy)
859 return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
860 [&]() { return emitOpError(); });
861 auto srcTy = getSourceType();
862 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
863 auto memTy = dyn_cast<MemRefType>(srcTy);
864
865 if (memTy && (getElementType() != memTy.getElementType()))
866 return emitError() << "Value should have the same element type as MemRef.";
867
868 auto offsetsTy = getOffsets().getType();
869 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
870 [&]() { return emitOpError(); });
871}
872
873void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
874 Type valueType, Value source, Value mask,
875 xegpu::CachePolicyAttr l1_hint,
876 xegpu::CachePolicyAttr l2_hint,
877 xegpu::CachePolicyAttr l3_hint) {
878 build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
879 l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
880}
881
882void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
883 Type valueType, Value source,
884 ArrayRef<OpFoldResult> offsets, Value mask,
885 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
886 xegpu::CachePolicyAttr l2_hint,
887 xegpu::CachePolicyAttr l3_hint) {
888 auto loc = source.getLoc();
889 int64_t size = static_cast<int64_t>(offsets.size());
890 auto type = VectorType::get(size, builder.getIndexType());
891 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
892 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
893
894 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
895 l2_hint, l3_hint, /*layout=*/nullptr);
896}
897
898void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
899 Type valueType, Value source,
900 ArrayRef<OpFoldResult> offsets, Value mask,
901 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
902 xegpu::CachePolicyAttr l2_hint,
903 xegpu::CachePolicyAttr l3_hint,
904 xegpu::LayoutAttr layout) {
905 auto loc = source.getLoc();
906 int64_t size = static_cast<int64_t>(offsets.size());
907 auto type = VectorType::get(size, builder.getIndexType());
908 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
909 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
910
911 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
912 l2_hint, l3_hint, layout);
913}
914
915//===----------------------------------------------------------------------===//
916// XeGPU_StoreScatterOp
917//===----------------------------------------------------------------------===//
918LogicalResult StoreScatterOp::verify() {
919 auto tdescTy = getTensorDescType();
920 auto maskTy = getMaskType();
921 auto valueTy = getValueType();
922
923 if (!tdescTy && !getOffsets())
924 return emitOpError("Expects offsets.");
925
926 if (tdescTy && getOffsets())
927 return emitOpError("offsets not allowed.");
928
929 if (tdescTy && !tdescTy.isScattered())
930 return emitOpError("Expects a scattered TensorDesc.");
931
932 if (!isWriteHintOrNone(getL1HintAttr()))
933 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
934
935 if (!isWriteHintOrNone(getL2HintAttr()))
936 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
937
938 if (!isWriteHintOrNone(getL3HintAttr()))
939 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
940
941 if (tdescTy)
942 return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
943 [&]() { return emitOpError(); });
944
945 auto destTy = getDestType();
946 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
947 auto memTy = dyn_cast<MemRefType>(destTy);
948
949 if (memTy && (getElementType() != memTy.getElementType()))
950 return emitError() << "Value should have the same element type as MemRef.";
951
952 auto offsetsTy = getOffsets().getType();
953 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
954 [&]() { return emitOpError(); });
955}
956
957void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
958 Value value, Value dest, Value mask,
959 xegpu::CachePolicyAttr l1_hint,
960 xegpu::CachePolicyAttr l2_hint,
961 xegpu::CachePolicyAttr l3_hint) {
962 build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
963 l2_hint, l3_hint, /*layout=*/nullptr);
964}
965
966void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
967 Value value, Value dest,
968 ArrayRef<OpFoldResult> offsets, Value mask,
969 IntegerAttr chunk_size,
970 xegpu::CachePolicyAttr l1_hint,
971 xegpu::CachePolicyAttr l2_hint,
972 xegpu::CachePolicyAttr l3_hint) {
973 auto loc = dest.getLoc();
974 int64_t size = static_cast<int64_t>(offsets.size());
975 auto type = VectorType::get(size, builder.getIndexType());
976 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
977 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
978
979 // Call the correct builder overload that does not expect result types.
980 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
981 l3_hint, /*layout=*/nullptr);
982}
983
984void StoreScatterOp::build(
985 OpBuilder &builder, OperationState &state, Value value, Value dest,
986 ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
987 xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
988 xegpu::CachePolicyAttr l3_hint, xegpu::LayoutAttr layout) {
989 auto loc = dest.getLoc();
990 int64_t size = static_cast<int64_t>(offsets.size());
991 auto type = VectorType::get(size, builder.getIndexType());
992 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
993 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
994
995 // Call the correct builder overload that does not expect result types.
996 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
997 l3_hint, layout);
998}
999
1000//===----------------------------------------------------------------------===//
1001// XeGPU_UpdateOffsetOp
1002//===----------------------------------------------------------------------===//
1003void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
1004 mlir::Value tensorDesc,
1005 llvm::ArrayRef<OpFoldResult> offsets) {
1006 auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
1007 assert(tdescTy && "Expecting the source is a TensorDescType value.");
1008 auto loc = tensorDesc.getLoc();
1009 int64_t size = static_cast<int64_t>(offsets.size());
1010 auto type = VectorType::get({size}, builder.getIndexType());
1011 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
1012 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
1013 build(builder, state, tdescTy, tensorDesc, offset);
1014}
1015
1016void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
1017 Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
1018 auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
1019 build(builder, state, tensorDesc, ofrs);
1020}
1021
1022LogicalResult UpdateOffsetOp::verify() {
1023 auto tdescTy = getTensorDescType();
1024 if (!tdescTy.isScattered())
1025 return emitOpError("Expects a scattered TensorDesc.\n");
1026
1027 SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
1028 SmallVector<int64_t> offsetShape = getShapeOf(getOffsetsType());
1029 if (tdescTy.getChunkSizeAsInt() > 1)
1030 expectedOffsetShape.pop_back();
1031
1032 if (expectedOffsetShape != offsetShape)
1033 return emitOpError(
1034 "Offsets should match TensorDesc except the chunk size dim.");
1035
1036 return success();
1037}
1038
1039//===----------------------------------------------------------------------===//
1040// XeGPU_DpasOp
1041//===----------------------------------------------------------------------===//
1042LogicalResult DpasOp::verify() {
1043 int64_t lhsRank = getLhsType().getRank();
1044 int64_t rhsRank = getRhsType().getRank();
1045 int64_t resRank = getResultType().getRank();
1046 auto lhsShape = getLhsType().getShape();
1047 auto rhsShape = getRhsType().getShape();
1048 auto resShape = getResultType().getShape();
1049
1050 if (getAcc() && getAcc().getType() != getResultType())
1051 return emitOpError("Expecting the acc type to be the same as result.");
1052
1053 // SIMT code: the size of the B operand has to be a multiple of 32 bits.
1054 // The full semantic check is skipped because architecture information is
1055 // not available; users are responsible for correctness.
1056 if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
1057 auto numElems = getRhsType().getNumElements();
1058 auto elemTy = getRhsType().getElementType();
1059 auto factor = 32 / elemTy.getIntOrFloatBitWidth();
1060 if (numElems % factor != 0)
1061 return emitOpError("Expecting B operand to be a multiple of 32 bits.");
1062 return success();
1063 }
1064
1065 // SIMD code
1066 if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
1067 return emitOpError(
1068 "expecting lhs and result to be a 2D vector, and rhs to be either "
1069 "2D or 3D (packed) vector.");
1070 auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
1071 if (bK != lhsShape[1])
1072 return emitOpError("K-dimension mismatch.");
1073 if (lhsShape[0] != resShape[0])
1074 return emitOpError("M-dimension mismatch.");
1075 if (rhsShape[1] != resShape[1])
1076 return emitOpError("N-dimension mismatch.");
1077
1078 return success();
1079}
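// Worked SIMD example for the checks above (shapes are illustrative): with
// lhs = vector<8x16xf16>, rhs = vector<8x16x2xf16> (packed 3-D form), and
// result = vector<8x16xf32>, bK = 8 * 2 = 16 equals lhsShape[1], and the M (8)
// and N (16) dimensions agree, so verification succeeds.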
1080
1081//===----------------------------------------------------------------------===//
1082// XeGPU_ConvertLayoutOp
1083//===----------------------------------------------------------------------===//
1084LogicalResult ConvertLayoutOp::verify() {
1085 auto srcLayout = getInputLayout();
1086 auto resLayout = getTargetLayout();
1087 if (!srcLayout)
1088 return emitOpError("expected input layout.");
1089 if (!resLayout)
1090 return emitOpError("expected target layout.");
1091
1092 // Both the input and the target layouts must be workgroup-level or
1093 // subgroup-level at the same time.
1094 if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
1095 (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
1096 return emitOpError("expected input layout and target layout be WgLayout or "
1097 "SgLayout at the same time.");
1098
1099 auto shape = getSource().getType().getShape();
1100 if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
1101 return emitOpError(
1102 "invalid input layout, data cannot be evenly distributed.");
1103
1104 if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
1105 return emitOpError(
1106 "invalid target layout, data cannot be evenly distributed.");
1107
1108 return mlir::success();
1109}
1110
1111OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
1112 if (getInputLayout() == getTargetLayout())
1113 return getSource();
1114 return {};
1115}
1116
1117struct FoldConvertLayoutOp : public OpRewritePattern<xegpu::ConvertLayoutOp> {
1118 using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
1119 LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
1120 PatternRewriter &rewriter) const override {
1121 if (op.getInputLayout() == op.getTargetLayout()) {
1122 rewriter.replaceOp(op, op.getSource());
1123 return success();
1124 }
1125 return failure();
1126 }
1127};
1128
1129void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1130 MLIRContext *context) {
1131 patterns.add<FoldConvertLayoutOp>(context);
1132}
1133
1134//===----------------------------------------------------------------------===//
1135// XeGPU_LoadMatrixOp
1136//===----------------------------------------------------------------------===//
1137void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
1138 TypedValue<MemDescType> memDesc,
1139 llvm::ArrayRef<OpFoldResult> offsets,
1140 DistributeLayoutAttr layout) {
1141 llvm::SmallVector<Value> dynamicOffsets;
1142 llvm::SmallVector<int64_t> staticOffsets;
1143 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1144 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1145 // Call the generated builder with all parameters (including optional ones as
1146 // nullptr/empty)
1147 build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
1148 /*subgroup_block_io=*/nullptr, layout);
1149}
1150
1151LogicalResult LoadMatrixOp::verify() {
1152
1153 auto resTy = dyn_cast<VectorType>(getRes().getType());
1154 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1155 MemDescType mdescTy = getMemDesc().getType();
1156
1157 return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
1158 getLayoutAttr(), [&]() { return emitError(); });
1159}
1160
1161//===----------------------------------------------------------------------===//
1162// XeGPU_StoreMatrixOp
1163//===----------------------------------------------------------------------===//
1164void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
1165 TypedValue<MemDescType> memDesc,
1166 llvm::ArrayRef<OpFoldResult> offsets,
1167 DistributeLayoutAttr layout) {
1168 llvm::SmallVector<Value> dynamicOffsets;
1169 llvm::SmallVector<int64_t> staticOffsets;
1170 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1171 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1172 build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
1173 /*subgroup_block_io=*/nullptr, layout);
1174}
1175
1176LogicalResult StoreMatrixOp::verify() {
1177
1178 auto dataTy = dyn_cast<VectorType>(getData().getType());
1179 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1180 MemDescType mdescTy = getMemDesc().getType();
1181 return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
1182 getLayoutAttr(), [&]() { return emitError(); });
1183}
1184
1185namespace mlir {
1186#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
1187} // namespace mlir
1188#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
1189#define GET_OP_CLASSES
1190#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>