1//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/IR/Builders.h"
18
19#include "llvm/Support/Debug.h"
20
21#define DEBUG_TYPE "xegpu"
22
23using namespace mlir;
24using namespace mlir::xegpu;
25
26static bool isSharedMemory(const MemRefType &memrefTy) {
27 Attribute attr = memrefTy.getMemorySpace();
28 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
29 return intAttr.getInt() == 3;
30 if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
31 return memrefSpace.getValue() == MemorySpace::SLM;
32 if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
33 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
34 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
35}
36
37template <typename T>
38static std::string makeString(T array, bool breakline = false) {
39 std::string buf;
40 buf.clear();
41 llvm::raw_string_ostream os(buf);
42 os << "[";
43 for (size_t i = 1; i < array.size(); i++) {
44 os << array[i - 1] << ", ";
45 if (breakline)
46 os << "\n\t\t";
47 }
48 os << array.back() << "]";
49 return buf;
50}
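// For illustration: makeString(ArrayRef<int64_t>{2, 4, 8}) produces "[2, 4, 8]";
// with breakline = true a newline and tabs are emitted after every element but
// the last. The helper assumes a non-empty array, since array.back() is read
// unconditionally.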
51
52static SmallVector<int64_t> getShapeOf(Type type) {
53 SmallVector<int64_t> shape;
54 if (auto ty = llvm::dyn_cast<ShapedType>(type))
55 shape = SmallVector<int64_t>(ty.getShape());
56 else
57 shape.push_back(1);
58 return shape;
59}
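// For illustration: a vector<8x16xf16> yields {8, 16}, while any non-shaped
// type (e.g. f32 or index) yields {1}.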
60
61static bool isReadHintOrNone(const CachePolicyAttr &attr) {
62 if (!attr)
63 return true;
64 auto kind = attr.getValue();
65 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
66 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
67}
68
69static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
70 if (!attr)
71 return true;
72 auto kind = attr.getValue();
73 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
74 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
75}
76
77static LogicalResult
78isValidGatherScatterParams(Type maskTy, VectorType valueTy,
79 TensorDescType tdescTy,
80 function_ref<InFlightDiagnostic()> emitError) {
81
82 if (!tdescTy.isScattered())
83 return emitError() << "Expects a scattered TensorDesc.";
84
85 auto chunkSize = tdescTy.getChunkSizeAsInt();
86 if (!valueTy) {
87 if (chunkSize > 1)
88 return emitError() << "Expecting chunk size == 1 for scalar result";
89 if (dyn_cast<VectorType>(maskTy))
90 return emitError() << "Expecting a vector type result.";
91 return success();
92 }
93
94 auto maskShape = getShapeOf(maskTy);
95 auto valueShape = getShapeOf(valueTy);
96 auto tdescShape = getShapeOf(tdescTy);
97
98 if (valueTy.getElementType() != tdescTy.getElementType())
99 return emitError()
100 << "Value should have the same element type as TensorDesc.";
101
102 llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
103 if (chunkSize > 1)
104 expectedMaskShape.pop_back();
105 if (expectedMaskShape != maskShape)
106 return emitError()
107 << "Mask should match TensorDesc except the chunk size dim.";
108
109 // a valid shape for SIMT case
110 if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
111 if (tdescTy.getLayoutAttr())
112 return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
113 return success();
114 }
115
116 if (tdescShape != valueShape)
117 return emitError() << "Value shape " << makeString(valueShape)
118 << " is neither a valid distribution for SIMT nor "
119 "consistent with the tensor descriptor for SIMD "
120 << tdescTy;
121 return success();
122}
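// Worked example for the checks above (illustrative): for a scattered
// tensor_desc of shape [16, 8] with chunk size 8, the mask is expected to have
// shape [16]; a SIMD value must have shape [16, 8], while a SIMT value must be
// a 1-D vector with exactly 8 (= chunk size) elements and the tensor_desc must
// carry no layout attribute.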
123
124static LogicalResult
125isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
126 VectorType valueTy, int64_t chunkSize,
127 function_ref<InFlightDiagnostic()> emitError) {
128
129 auto maskVecTy = dyn_cast<VectorType>(maskTy);
130 auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
131 if (!valueTy) {
132 if (chunkSize > 1)
133 return emitError() << "Expecting chunk size == 1 for scalar result";
134 if (maskVecTy || offsetsVecTy)
135 return emitError() << "Expecting scalar mask and offsets.";
136 else if (maskVecTy && offsetsVecTy)
137 return emitError() << "Expecting a vector type result.";
138 return success();
139 }
140
141 auto valueSize = valueTy.getNumElements();
142 // SIMT mode with scalar mask and offsets.
143 if (!maskVecTy && !offsetsVecTy) {
144 if (valueSize != chunkSize)
145 return emitError() << "value elements must match chunk size "
146 << chunkSize;
147 return success();
148 }
149 auto maskShape = getShapeOf(maskTy);
150 auto valueShape = getShapeOf(valueTy);
151
152 if (!maskVecTy)
153 return emitError() << "Expecting a vector type mask.";
154 int64_t maskSize = maskVecTy.getNumElements();
155
156 if (chunkSize > 1) {
157 if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
158 return emitError() << "value elements must match chunk size "
159 << chunkSize;
160 } else {
161 if (valueSize != maskSize)
162 return emitError()
163 << "Mask should match value except the chunk size dim.";
164 }
165 llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
166 if (maskSize == 1)
167 return success();
168 if (chunkSize > 1)
169 expectedMaskShape.pop_back();
170 if (expectedMaskShape != maskShape)
171 return emitError() << "Mask should match value except the chunk size dim.";
172
173 return success();
174}
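// Worked example for the checks above (illustrative): with chunk size 4, a
// scalar mask and scalar offset require the value to have exactly 4 (= chunk
// size) elements; with a vector mask of shape [16] and chunk size 4, a value of
// shape [16, 4] is accepted, since the mask must match the value shape with the
// trailing chunk-size dimension dropped.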
175
176LogicalResult
177IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
178 UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
179 function_ref<InFlightDiagnostic()> emitError) {
180
181 if (!dataTy) {
182 if (subgroup_block_io)
183 return emitError() << "subgroup_block_io "
184 "are only allowed when result is a VectorType.";
185 else
186 return success();
187 }
188
189 if (mdescTy.getRank() != 2)
190 return emitError() << "mem_desc must be 2D.";
191
192 ArrayRef<int64_t> dataShape = dataTy.getShape();
193 ArrayRef<int64_t> mdescShape = mdescTy.getShape();
194
195 SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
196 ArrayAttr strideAttr = mdescTy.getStrideAttr();
197 SmallVector<int64_t> strides;
198 for (Attribute attr : strideAttr.getValue()) {
199 strides.push_back(cast<IntegerAttr>(attr).getInt());
200 }
201 if (subgroup_block_io && layout) {
202 auto laneData = layout.getEffectiveLaneDataAsInt();
203 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
204 if (!laneData.empty()) {
205 bool isLaneDataContiguous =
206 std::all_of(laneData.begin(), std::prev(laneData.end()),
207 [](int x) { return x == 1; });
208 if (!isLaneDataContiguous)
209 return emitError() << "With subgroup_block_io, accessed data must be "
210 "contiguous and coalesced.";
211 for (size_t i = 0; i < laneData.size(); ++i) {
212 if (laneLayout[i] != blockShape[i])
213 return emitError() << "With subgroup_block_io, the block shape must "
214 "match the lane layout.";
215 if (laneLayout[i] != 1 && strides[i] != 1)
216 return emitError() << "With subgroup_block_io, the distributed "
217 "dimensions must be contiguous.";
218 }
219 }
220 }
221 if (dataShape.size() == 2) {
222 if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
223 [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
224 return emitError() << "data shape must not exceed mem_desc shape.";
225 } else {
226 // If the subgroup_block_io attribute is set, mdescTy must have a block
227 // attribute.
228 if (subgroup_block_io && !blockShape.size())
229 return emitError() << "mem_desc must have block attribute when "
230 "subgroup_block_io is set.";
231 // If the subgroup_block_io attribute is set, the mem_desc should be
232 // row major.
233 if (subgroup_block_io && mdescTy.isColMajor())
234 return emitError() << "mem_desc should be row major when "
235 "subgroup_block_io is set.";
236 }
237
238 return success();
239}
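// Illustrative example of the subgroup_block_io constraints above: with a
// lane_layout of [1, 16] and lane_data of [1, 1], the mem_desc block shape must
// also be [1, 16], and any dimension distributed across lanes (lane_layout != 1)
// must have stride 1, so the accessed data stays contiguous and coalesced.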
240
241//===----------------------------------------------------------------------===//
242// XeGPU_CreateNdDescOp
243//===----------------------------------------------------------------------===//
244
245void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
246 Type tdesc, TypedValue<MemRefType> source) {
247 [[maybe_unused]] auto ty = source.getType();
248 assert(ty.hasStaticShape() && "expecting a memref with static shape");
249
250 build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
251 ValueRange({}) /* empty dynamic shape */,
252 ValueRange({}) /* empty dynamic strides */,
253 DenseI64ArrayAttr({}) /* const offsets */,
254 DenseI64ArrayAttr({}) /* empty const shape*/,
255 DenseI64ArrayAttr({}) /* empty const strides*/);
256}
257
258void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
259 Type tdesc, Value source,
260 llvm::ArrayRef<OpFoldResult> shape,
261 llvm::ArrayRef<OpFoldResult> strides) {
262 Type srcTy = source.getType();
263 assert((isa<IntegerType, MemRefType>(srcTy)) &&
264 "Source has to be either int or memref.");
265
266 llvm::SmallVector<Value> dynamicShape;
267 llvm::SmallVector<Value> dynamicStrides;
268
269 llvm::SmallVector<int64_t> staticShape;
270 llvm::SmallVector<int64_t> staticStrides;
271
272 dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
273 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
274
275 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
276 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
277
278 if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
279 auto memrefShape = memrefTy.getShape();
280 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
281
282 // If the shape and strides can be taken from the memref type, drop the
283 // attributes to keep the printed IR clean (only do so for the fully static
284 // case, otherwise the printer would fail trying to print an empty array-attr).
285 if (staticShape == memrefShape && staticStrides == memrefStrides &&
286 dynamicShape.empty() && dynamicStrides.empty()) {
287 staticShapeAttr = DenseI64ArrayAttr();
288 staticStridesAttr = DenseI64ArrayAttr();
289 }
290 }
291
292 build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
293 dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
294 staticStridesAttr);
295}
296
297void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
298 Type tdesc, TypedValue<MemRefType> source,
299 llvm::ArrayRef<OpFoldResult> offsets) {
300 [[maybe_unused]] auto ty = source.getType();
301 assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());
302
303 llvm::SmallVector<int64_t> staticOffsets;
304 llvm::SmallVector<Value> dynamicOffsets;
305 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
306
307 build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
308 ValueRange({}) /* empty dynamic shape */,
309 ValueRange({}) /* empty dynamic strides */,
310 builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
311 {} /* empty const shape*/, {} /* empty const strides*/);
312}
313
314void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
315 Type tdesc, Value source,
316 llvm::ArrayRef<OpFoldResult> offsets,
317 llvm::ArrayRef<OpFoldResult> shape,
318 llvm::ArrayRef<OpFoldResult> strides) {
319 assert(!shape.empty() && !offsets.empty() && !strides.empty() &&
320 shape.size() == strides.size() && shape.size() == offsets.size());
321
322 Type srcTy = source.getType();
323 assert((isa<IntegerType, MemRefType>(srcTy)) &&
324 "Source has to be either int or memref.");
325
326 llvm::SmallVector<Value> dynamicOffsets;
327 llvm::SmallVector<Value> dynamicShape;
328 llvm::SmallVector<Value> dynamicStrides;
329
330 llvm::SmallVector<int64_t> staticOffsets;
331 llvm::SmallVector<int64_t> staticShape;
332 llvm::SmallVector<int64_t> staticStrides;
333
334 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
335 dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
336 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
337
338 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
339 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
340 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
341
342 if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
343 auto memrefShape = memrefTy.getShape();
344 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
345
346 // If the shape and strides can be taken from the memref type, drop the
347 // attributes to keep the printed IR clean (only do so for the fully static
348 // case, otherwise the printer would fail trying to print an empty array-attr).
349 if (staticShape == memrefShape && staticStrides == memrefStrides &&
350 dynamicShape.empty() && dynamicStrides.empty()) {
351 staticShapeAttr = DenseI64ArrayAttr();
352 staticStridesAttr = DenseI64ArrayAttr();
353 }
354 }
355
356 build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
357 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
358}
359
360LogicalResult CreateNdDescOp::verify() {
361 size_t rank = getMixedSizes().size();
362 bool invalidRank = rank != getMixedStrides().size();
363 bool invalidElemTy = false;
364
365 // The memory space of the created TensorDesc should match that of the source.
366 // Both the source and the TensorDesc default to global memory when the
367 // memory space attribute is not specified. If the source is an integer,
368 // it is treated as a pointer to global memory.
369 auto srcMemorySpace = getSourceMemorySpace();
370 auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
371 if (srcMemorySpace != tdescMemorySpace)
372 return emitOpError("Memory space mismatch.")
373 << " Source: " << srcMemorySpace
374 << ", TensorDesc: " << tdescMemorySpace;
375
376 if (size_t offsetRank = getMixedOffsets().size())
377 invalidRank |= (offsetRank != rank);
378
379 // Check that the source type matches the rank if it is a memref.
380 // It should also have the same element type as the TensorDesc.
381 if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
382 invalidElemTy |= memrefTy.getElementType() != getElementType();
383
384 if (llvm::isa<IntegerType>(getSourceType())) {
385 // strides and shape must be present for integer source.
386 if (getMixedStrides().empty() || getMixedSizes().empty())
387 return emitOpError("expecting strides and shape to be present for "
388 "integer source.");
389 }
390
391 if (invalidRank)
392 return emitOpError(
393 "Expecting the rank of shape, strides, offsets, and source (if source "
394 "is a memref) should match with each other.");
395
396 // check result TensorDesc rank
397 if (getType().getRank() > (int64_t)rank)
398 return emitOpError(
399 "Expecting the TensorDesc rank is not greater than the "
400 "ranks of shape, strides, offsets or the memref source.");
401
402 if (invalidElemTy)
403 return emitOpError("TensorDesc should have the same element "
404 "type with the source if it is a memref.\n");
405
406 if (getType().isScattered())
407 return emitOpError("Expects a non-scattered TensorDesc.\n");
408
409 return success();
410}
411
412static ParseResult parseOptionalDynamicIndexList(
413 OpAsmParser &parser,
414 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
415 DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
416 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
417
418 SmallVector<int64_t, 4> integerVals;
419 auto parseIntegerOrValue = [&]() {
420 OpAsmParser::UnresolvedOperand operand;
421 auto res = parser.parseOptionalOperand(operand);
422
423 if (res.has_value() && succeeded(res.value())) {
424 values.push_back(operand);
425 integerVals.push_back(ShapedType::kDynamic);
426 if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
427 return failure();
428 } else {
429 int64_t integer;
430 if (failed(parser.parseInteger(integer)))
431 return failure();
432 integerVals.push_back(integer);
433 }
434 return success();
435 };
436
437 // If the optional values are given, there must be a left bracket.
438 if (parser.parseOptionalLSquare().succeeded()) {
439 if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
440 parser.parseRSquare())
441 return parser.emitError(parser.getNameLoc())
442 << "expected a list of SSA values or integers";
443 integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
444 return success();
445 }
446
447 return success();
448}
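// For illustration: an input such as "[%a, 16, %b]" is parsed into values
// {%a, %b} and integers {kDynamic, 16, kDynamic}; when no '[' follows, the
// parser succeeds without consuming anything and `integers` stays unset.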
449
450static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
451 OperandRange values,
452 DenseI64ArrayAttr integers) {
453 if (!integers || integers.empty())
454 return;
455 printDynamicIndexList(printer, op, values, integers,
456 /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
457}
458//===----------------------------------------------------------------------===//
459// XeGPU_PrefetchNdOp
460//===----------------------------------------------------------------------===//
461
462void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
463 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
464 xegpu::CachePolicyAttr l2_hint,
465 xegpu::CachePolicyAttr l3_hint) {
466
467 return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
468 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
469}
470
471void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
472 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
473 xegpu::CachePolicyAttr l1_hint,
474 xegpu::CachePolicyAttr l2_hint,
475 xegpu::CachePolicyAttr l3_hint,
476 xegpu::DistributeLayoutAttr layout) {
477 SmallVector<Value> dynamicOffsets;
478 SmallVector<int64_t> staticOffsets;
479 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
480
481 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
482
483 build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
484 l2_hint, l3_hint, /*anchor_layout=*/layout);
485}
486
487LogicalResult PrefetchNdOp::verify() {
488 auto tdescTy = getTensorDescType();
489 if (tdescTy.isScattered())
490 return emitOpError("Expects a non-scattered TensorDesc.\n");
491
492 if (!isReadHintOrNone(getL1HintAttr()))
493 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
494
495 if (!isReadHintOrNone(getL2HintAttr()))
496 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
497
498 if (!isReadHintOrNone(getL3HintAttr()))
499 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
500
501 int64_t tDescRank = tdescTy.getRank();
502 int64_t offsetSize = getMixedOffsets().size();
503 if (offsetSize != 0 && offsetSize != tDescRank)
504 return emitOpError(
505 "Mismatched ranks between offsets and tensor descriptor");
506
507 return success();
508}
509
510//===----------------------------------------------------------------------===//
511// XeGPU_LoadNdOp
512//===----------------------------------------------------------------------===//
513
514void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
515 Value tensorDesc, UnitAttr packed,
516 DenseI64ArrayAttr transpose,
517 xegpu::CachePolicyAttr l1_hint,
518 xegpu::CachePolicyAttr l2_hint,
519 xegpu::CachePolicyAttr l3_hint) {
520
521 return build(builder, state, retType, tensorDesc, ValueRange(),
522 DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
523 l3_hint, /*anchor_layout=*/nullptr);
524}
525
526void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
527 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
528 UnitAttr packed, DenseI64ArrayAttr transpose,
529 xegpu::CachePolicyAttr l1_hint,
530 xegpu::CachePolicyAttr l2_hint,
531 xegpu::CachePolicyAttr l3_hint,
532 xegpu::DistributeLayoutAttr layout) {
533 SmallVector<Value> dynamicOffsets;
534 SmallVector<int64_t> staticOffsets;
535 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
536
537 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
538
539 build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
540 packed, transpose, l1_hint, l2_hint, l3_hint,
541 /*anchor_layout=*/layout);
542}
543
544LogicalResult LoadNdOp::verify() {
545 auto tdescTy = getTensorDescType();
546 auto valueTy = getType();
547
548 if (tdescTy.isScattered())
549 return emitOpError("Expects a non-scattered TensorDesc.\n");
550
551 if (tdescTy.getRank() > 2)
552 return emitOpError("Expects a 1D or 2D TensorDesc.\n");
553
554 if (!valueTy)
555 return emitOpError("Invalid result, it should be a VectorType.\n");
556
557 if (!isReadHintOrNone(getL1HintAttr()))
558 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
559
560 if (!isReadHintOrNone(getL2HintAttr()))
561 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
562
563 if (!isReadHintOrNone(getL3HintAttr()))
564 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
565
566 int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
567 int valueElems = valueTy.getNumElements();
568
569 // If the result vector is 1-D and has fewer elements than the tensor
570 // descriptor, it is assumed to be a SIMT op. The layout attribute in the
571 // tensor_desc is not needed.
572 if (valueElems < tdescElems && valueTy.getRank() == 1) {
573 // SIMT mode doesn't need LayoutAttr.
574 if (tdescTy.getLayoutAttr())
575 return emitOpError()
576 << "TensorDesc doesn't need LayoutAttr for SIMT code";
577
578 // For SIMT code, the load is evenly distributed across all lanes in a
579 // subgroup. Since subgroup size is arch dependent, we only check even
580 // distribution here.
581 if (tdescElems % valueElems)
582 return emitOpError()
583 << "Result shape " << makeString(getShapeOf(valueTy))
584 << " is not a valid distribution for tensor descriptor "
585 << tdescTy;
586
587 return success();
588 }
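// Illustrative SIMT example: a tensor_desc<8x16xf16> (array_length 1) holds
// 128 elements; a 1-D result such as vector<8xf16> is accepted, assuming no
// layout attribute, since 128 % 8 == 0. The exact per-lane share depends on
// the subgroup size, which is not known at verification time.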
589
590 // Check SIMD mode.
591 auto tdescShape = getShapeOf(tdescTy);
592 auto valueShape = getShapeOf(valueTy);
593
594 if (getTranspose()) {
595 auto trans = getTranspose().value();
596 // Make sure the transpose value is valid, and apply it
597 if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
598 tdescShape = applyPermutation(tdescShape, trans);
599 else
600 mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
601 }
602
603 if (getPacked()) {
604 if (tdescTy.getRank() == 2) {
605 const int axis = 0;
606 auto vnni_factor = valueShape.back();
607 tdescShape[axis] /= vnni_factor;
608 tdescShape.push_back(vnni_factor);
609 } else {
610 mlir::emitWarning(getLoc())
611 << "Invalid Packed Attr. It is ignored (available for 2D "
612 "TensorDesc only).";
613 }
614 }
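// Illustrative packed (VNNI) example: for a tensor_desc<16x16xf16> and a
// result vector<8x16x2xf16>, the vnni factor is the trailing result dim (2),
// so the expected shape becomes [16/2, 16, 2] = [8, 16, 2] and matches.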
615
616 auto array_len = tdescTy.getArrayLength();
617 if (array_len > 1)
618 tdescShape.insert(tdescShape.begin(), array_len);
619
620 if (tdescShape != valueShape)
621 return emitOpError() << "Result shape " << makeString(valueShape)
622 << " is not consistent with tensor descriptor "
623 << tdescTy;
624
625 int64_t tDescRank = tdescTy.getRank();
626 int64_t offsetSize = getMixedOffsets().size();
627 if (offsetSize != 0 && offsetSize != tDescRank)
628 return emitOpError(
629 "Mismatched ranks between offsets and tensor descriptor");
630
631 return success();
632}
633
634//===----------------------------------------------------------------------===//
635// XeGPU_StoreNdOp
636//===----------------------------------------------------------------------===//
637
638void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
639 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
640 xegpu::CachePolicyAttr l2_hint,
641 xegpu::CachePolicyAttr l3_hint) {
642
643 return build(builder, state, value, tensorDesc, ValueRange(),
644 DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint,
645 /*anchor_layout=*/nullptr);
646}
647
648void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
649 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
650 xegpu::CachePolicyAttr l1_hint,
651 xegpu::CachePolicyAttr l2_hint,
652 xegpu::CachePolicyAttr l3_hint,
653 xegpu::DistributeLayoutAttr layout) {
654 SmallVector<Value> dynamicOffsets;
655 SmallVector<int64_t> staticOffsets;
656 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
657
658 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
659
660 build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
661 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
662}
663
664LogicalResult StoreNdOp::verify() {
665 auto dstTy = getTensorDescType(); // Tile
666 auto valTy = getValueType(); // Vector
667
668 if (dstTy.isScattered())
669 return emitOpError("Expects a non-scattered TensorDesc.\n");
670
671 if (dstTy.getRank() > 2)
672 return emitOpError("Expects a 1D or 2D TensorDesc.\n");
673
674 if (!valTy)
675 return emitOpError("Expecting a VectorType result.\n");
676
677 if (!isWriteHintOrNone(getL1HintAttr()))
678 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
679
680 if (!isWriteHintOrNone(getL2HintAttr()))
681 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
682
683 if (!isWriteHintOrNone(getL3HintAttr()))
684 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
685
686 auto array_len = dstTy.getArrayLength();
687 if (array_len > 1)
688 return emitOpError("array length is not supported by store_nd.\n");
689
690 auto tdescElems = dstTy.getNumElements();
691 auto valueElems = valTy.getNumElements();
692
693 // Similar to LoadNdOp, if the value vector is 1-D and has fewer elements
694 // than the tensor descriptor, it is assumed to be a SIMT op. The layout
695 // attribute in the tensor_desc is not needed.
696 if (valTy.getRank() == 1 && valueElems < tdescElems) {
697 // SIMT mode doesn't need LayoutAttr.
698 if (dstTy.getLayoutAttr())
699 return emitOpError()
700 << "TensorDesc doesn't need LayoutAttr for SIMT code";
701
702 if (tdescElems % valueElems)
703 return emitOpError()
704 << "Value shape " << makeString(getShapeOf(valTy))
705 << " is not a valid distribution for tensor descriptor " << dstTy;
706
707 return success();
708 }
709
710 // SIMD code should have the same shape as the tensor descriptor.
711 auto tdescShape = getShapeOf(dstTy);
712 auto valueShape = getShapeOf(valTy);
713 if (tdescShape != valueShape)
714 return emitOpError() << "Value shape " << makeString(valueShape)
715 << " is not consistent with tensor descriptor "
716 << dstTy;
717
718 int64_t tDescRank = dstTy.getRank();
719 int64_t offsetSize = getMixedOffsets().size();
720 if (offsetSize != 0 && offsetSize != tDescRank)
721 return emitOpError(
722 "Mismatched ranks between offsets and tensor descriptor");
723
724 return success();
725}
726
727//===----------------------------------------------------------------------===//
728// XeGPU_UpdateNDOffsetOp
729//===----------------------------------------------------------------------===//
730LogicalResult UpdateNdOffsetOp::verify() {
731 auto ty = getTensorDescType();
732 if (ty.isScattered())
733 return emitOpError("Expects a non-scattered TensorDesc.\n");
734
735 // number of offsets specified must match the rank of the tensor descriptor
736 if (ty.getRank() != (int64_t)getNumOffsets()) {
737 return emitOpError("Invalid number of offsets.");
738 }
739 return success();
740}
741
742//===----------------------------------------------------------------------===//
743// XeGPU_CreateDescOp
744//===----------------------------------------------------------------------===//
745
746void CreateDescOp::build(OpBuilder &builder, OperationState &state,
747 TensorDescType TensorDesc, Value source,
748 llvm::ArrayRef<OpFoldResult> offsets) {
749 auto loc = source.getLoc();
750 int64_t size = static_cast<int64_t>(offsets.size());
751 auto type = VectorType::get(size, builder.getIndexType());
752 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
753 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
754 build(builder, state, TensorDesc, source, offset);
755}
756
757void CreateDescOp::build(OpBuilder &builder, OperationState &state,
758 TensorDescType TensorDesc, Value source,
759 llvm::ArrayRef<int64_t> offsets) {
760 auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
761 build(builder, state, TensorDesc, source, ofrs);
762}
763
764LogicalResult CreateDescOp::verify() {
765 auto tdescTy = getTensorDescType();
766
767 if (!tdescTy.isScattered())
768 return emitOpError("Expects a scattered TensorDesc.\n");
769
770 // The memory space of the created TensorDesc should match that of the source.
771 // Both the source and the TensorDesc default to global memory when the
772 // memory space attribute is not specified. If the source is an integer,
773 // it is treated as a pointer to global memory.
774 auto srcMemorySpace = getSourceMemorySpace();
775 auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
776 if (srcMemorySpace != tdescMemorySpace)
777 return emitOpError("Memory space mismatch.")
778 << " Source: " << srcMemorySpace
779 << ", TensorDesc: " << tdescMemorySpace;
780
781 // check total size
782 auto chunkSize = tdescTy.getChunkSizeAsInt();
783 SmallVector<int64_t> shape(getOffsetsType().getShape());
784 if (chunkSize != 1)
785 shape.push_back(chunkSize);
786
787 auto tdescShape = getShapeOf(tdescTy);
788 if (shape != tdescShape)
789 return emitOpError("Incorrect TensorDesc shape. ")
790 << "Expected is " << makeString(shape) << "\n";
791
792 return success();
793}
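// Worked example (illustrative): offsets of type vector<16xindex> and a chunk
// size of 4 require the created TensorDesc to have shape [16, 4]; with chunk
// size 1 the expected shape is just [16].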
794
795//===----------------------------------------------------------------------===//
796// XeGPU_PrefetchOp
797//===----------------------------------------------------------------------===//
798LogicalResult PrefetchOp::verify() {
799 auto tdescTy = getTensorDescType();
800
801 if (!tdescTy && !getOffsets())
802 return emitOpError("Expects offsets.");
803
804 if (tdescTy && getOffsets())
805 return emitOpError("offsets not allowed.");
806
807 if (tdescTy && !tdescTy.isScattered())
808 return emitOpError("Expects a scattered TensorDesc.");
809
810 if (!isReadHintOrNone(getL1HintAttr()))
811 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
812
813 if (!isReadHintOrNone(getL2HintAttr()))
814 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
815
816 if (!isReadHintOrNone(getL3HintAttr()))
817 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
818
819 auto srcTy = getSourceType();
820 if (srcTy.isInteger() && !getOffsetAlignByteAttr())
821 return emitOpError("offset_align_byte is required with integer source.");
822
823 if (getOffsetAlignByteAttr() && !srcTy.isInteger())
824 return emitOpError("offset_align_byte only allowed with integer source.");
825
826 return success();
827}
828
829void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
830 xegpu::CachePolicyAttr l1_hint,
831 xegpu::CachePolicyAttr l2_hint,
832 xegpu::CachePolicyAttr l3_hint) {
833 build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
834 IntegerAttr{}, /*anchor_layout=*/nullptr);
835}
836
837//===----------------------------------------------------------------------===//
838// XeGPU_LoadGatherOp
839//===----------------------------------------------------------------------===//
840LogicalResult LoadGatherOp::verify() {
841 auto tdescTy = getTensorDescType();
842 auto maskTy = getMaskType();
843 auto valueTy = getValueType();
844
845 if (!tdescTy && !getOffsets())
846 return emitOpError("Expects offsets.");
847
848 if (tdescTy && getOffsets())
849 return emitOpError("offsets not allowed.");
850
851 if (tdescTy && !tdescTy.isScattered())
852 return emitOpError("Expects a scattered TensorDesc.");
853
854 if (!isReadHintOrNone(getL1HintAttr()))
855 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
856
857 if (!isReadHintOrNone(getL2HintAttr()))
858 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
859
860 if (!isReadHintOrNone(getL3HintAttr()))
861 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
862
863 if (tdescTy)
864 return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
865 [&]() { return emitOpError(); });
866 auto srcTy = getSourceType();
867 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
868 auto memTy = dyn_cast<MemRefType>(srcTy);
869
870 if (memTy && (getElementType() != memTy.getElementType()))
871 return emitError() << "Value should have the same element type as MemRef.";
872
873 auto offsetsTy = getOffsets().getType();
874 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
875 [&]() { return emitOpError(); });
876}
877
878void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
879 Type valueType, Value source, Value mask,
880 xegpu::CachePolicyAttr l1_hint,
881 xegpu::CachePolicyAttr l2_hint,
882 xegpu::CachePolicyAttr l3_hint) {
883 build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
884 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
885}
886
887void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
888 Type valueType, Value source,
889 ArrayRef<OpFoldResult> offsets, Value mask,
890 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
891 xegpu::CachePolicyAttr l2_hint,
892 xegpu::CachePolicyAttr l3_hint) {
893 auto loc = source.getLoc();
894 int64_t size = static_cast<int64_t>(offsets.size());
895 auto type = VectorType::get(size, builder.getIndexType());
896 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
897 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
898
899 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
900 l2_hint, l3_hint, /*anchor_layout=*/nullptr);
901}
902
903void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
904 Type valueType, Value source,
905 ArrayRef<OpFoldResult> offsets, Value mask,
906 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
907 xegpu::CachePolicyAttr l2_hint,
908 xegpu::CachePolicyAttr l3_hint,
909 DistributeLayoutAttr layout) {
910 auto loc = source.getLoc();
911 int64_t size = static_cast<int64_t>(offsets.size());
912 auto type = VectorType::get(size, builder.getIndexType());
913 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
914 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
915
916 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
917 l2_hint, l3_hint, layout);
918}
919
920//===----------------------------------------------------------------------===//
921// XeGPU_StoreScatterOp
922//===----------------------------------------------------------------------===//
923LogicalResult StoreScatterOp::verify() {
924 auto tdescTy = getTensorDescType();
925 auto maskTy = getMaskType();
926 auto valueTy = getValueType();
927
928 if (!tdescTy && !getOffsets())
929 return emitOpError("Expects offsets.");
930
931 if (tdescTy && getOffsets())
932 return emitOpError("offsets not allowed.");
933
934 if (tdescTy && !tdescTy.isScattered())
935 return emitOpError("Expects a scattered TensorDesc.");
936
937 if (!isWriteHintOrNone(getL1HintAttr()))
938 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
939
940 if (!isWriteHintOrNone(getL2HintAttr()))
941 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
942
943 if (!isWriteHintOrNone(getL3HintAttr()))
944 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
945
946 if (tdescTy)
947 return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
948 [&]() { return emitOpError(); });
949
950 auto destTy = getDestType();
951 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
952 auto memTy = dyn_cast<MemRefType>(destTy);
953
954 if (memTy && (getElementType() != memTy.getElementType()))
955 return emitError() << "Value should have the same element type as MemRef.";
956
957 auto offsetsTy = getOffsets().getType();
958 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
959 [&]() { return emitOpError(); });
960}
961
962void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
963 Value value, Value dest, Value mask,
964 xegpu::CachePolicyAttr l1_hint,
965 xegpu::CachePolicyAttr l2_hint,
966 xegpu::CachePolicyAttr l3_hint) {
967 build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
968 l2_hint, l3_hint, /*anchor_layout=*/nullptr);
969}
970
971void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
972 Value value, Value dest,
973 ArrayRef<OpFoldResult> offsets, Value mask,
974 IntegerAttr chunk_size,
975 xegpu::CachePolicyAttr l1_hint,
976 xegpu::CachePolicyAttr l2_hint,
977 xegpu::CachePolicyAttr l3_hint) {
978 auto loc = dest.getLoc();
979 int64_t size = static_cast<int64_t>(offsets.size());
980 auto type = VectorType::get(size, builder.getIndexType());
981 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
982 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
983
984 // Call the correct builder overload that does not expect result types.
985 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
986 l3_hint, /*anchor_layout=*/nullptr);
987}
988
989void StoreScatterOp::build(
990 OpBuilder &builder, OperationState &state, Value value, Value dest,
991 ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
992 xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
993 xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
994 auto loc = dest.getLoc();
995 int64_t size = static_cast<int64_t>(offsets.size());
996 auto type = VectorType::get(size, builder.getIndexType());
997 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
998 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
999
1000 // Call the correct builder overload that does not expect result types.
1001 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
1002 l3_hint, layout);
1003}
1004
1005//===----------------------------------------------------------------------===//
1006// XeGPU_UpdateOffsetOp
1007//===----------------------------------------------------------------------===//
1008void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
1009 mlir::Value tensorDesc,
1010 llvm::ArrayRef<OpFoldResult> offsets) {
1011 auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
1012 assert(tdescTy && "Expecting the source is a TensorDescType value.");
1013 auto loc = tensorDesc.getLoc();
1014 int64_t size = static_cast<int64_t>(offsets.size());
1015 auto type = VectorType::get({size}, builder.getIndexType());
1016 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
1017 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
1018 build(builder, state, tdescTy, tensorDesc, offset);
1019}
1020
1021void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
1022 Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
1023 auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
1024 build(builder, state, tensorDesc, ofrs);
1025}
1026
1027LogicalResult UpdateOffsetOp::verify() {
1028 auto tdescTy = getTensorDescType();
1029 if (!tdescTy.isScattered())
1030 return emitOpError("Expects a scattered TensorDesc.\n");
1031
1032 SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
1033 SmallVector<int64_t> offsetShape = getShapeOf(getOffsetsType());
1034 if (tdescTy.getChunkSizeAsInt() > 1)
1035 expectedOffsetShape.pop_back();
1036
1037 if (expectedOffsetShape != offsetShape)
1038 return emitOpError(
1039 "Offsets should match TensorDesc except the chunk size dim.");
1040
1041 return success();
1042}
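// Illustrative example: for a scattered tensor_desc of shape [16, 8] with
// chunk size 8, the offsets operand is expected to have shape [16]; with chunk
// size 1 the offsets shape must match the tensor_desc shape exactly.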
1043
1044//===----------------------------------------------------------------------===//
1045// XeGPU_DpasOp
1046//===----------------------------------------------------------------------===//
1047LogicalResult DpasOp::verify() {
1048 int64_t lhsRank = getLhsType().getRank();
1049 int64_t rhsRank = getRhsType().getRank();
1050 int64_t resRank = getResultType().getRank();
1051 auto lhsShape = getLhsType().getShape();
1052 auto rhsShape = getRhsType().getShape();
1053 auto resShape = getResultType().getShape();
1054
1055 if (getAcc() && getAcc().getType() != getResultType())
1056 return emitOpError("Expecting the acc type to be the same as result.");
1057
1058 // SIMT code: the size of the B operand has to be a multiple of 32 bits.
1059 // The full semantic check is skipped due to the lack of architecture
1060 // information; users need to ensure correctness.
1061 if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
1062 auto numElems = getRhsType().getNumElements();
1063 auto elemTy = getRhsType().getElementType();
1064 auto factor = 32 / elemTy.getIntOrFloatBitWidth();
1065 if (numElems % factor != 0)
1066 return emitOpError("Expecting B operand to be a multiple of 32 bits.");
1067 return success();
1068 }
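// Illustrative SIMT example: with an i8 B operand, factor = 32 / 8 = 4, so the
// number of B elements must be a multiple of 4 (e.g. 8 is valid, 6 is
// rejected).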
1069
1070 // SIMD code
1071 if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
1072 return emitOpError(
1073 "expecting lhs and result to be a 2D vector, and rhs to be either "
1074 "2D or 3D (packed) vector.");
1075 auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
1076 if (bK != lhsShape[1])
1077 return emitOpError("K-dimension mismatch.");
1078 if (lhsShape[0] != resShape[0])
1079 return emitOpError("M-dimension mismatch.");
1080 if (rhsShape[1] != resShape[1])
1081 return emitOpError("N-dimension mismatch.");
1082
1083 return success();
1084}
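// Illustrative SIMD example: lhs vector<8x16xf16>, rhs vector<8x16x2xf16>
// (packed), result vector<8x16xf32>: bK = 8 * 2 = 16 equals lhsShape[1], and
// the M (8) and N (16) dimensions match the result, so verification succeeds.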
1085
1086//===----------------------------------------------------------------------===//
1087// XeGPU_ConvertLayoutOp
1088//===----------------------------------------------------------------------===//
1089LogicalResult ConvertLayoutOp::verify() {
1090 auto srcLayout = getInputLayout();
1091 auto resLayout = getTargetLayout();
1092 if (!srcLayout)
1093 return emitOpError("expected input layout.");
1094 if (!resLayout)
1095 return emitOpError("expected target layout.");
1096
1097 // Both the input and target layouts must be workgroup-level or
1098 // subgroup-level at the same time.
1099 if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
1100 (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
1101 return emitOpError("expected input layout and target layout be WgLayout or "
1102 "SgLayout at the same time.");
1103
1104 auto shape = getSource().getType().getShape();
1105 if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
1106 return emitOpError(
1107 "invalid input layout, data cannot be evenly distributed.");
1108
1109 if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
1110 return emitOpError(
1111 "invalid target layout, data cannot be evenly distributed.");
1112
1113 return mlir::success();
1114}
1115
1116OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
1117 if (getInputLayout() == getTargetLayout())
1118 return getSource();
1119 return {};
1120}
1121
1122struct FoldConvertLayoutOp : public OpRewritePattern<xegpu::ConvertLayoutOp> {
1123 using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
1124 LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
1125 PatternRewriter &rewriter) const override {
1126 if (op.getInputLayout() == op.getTargetLayout()) {
1127 rewriter.replaceOp(op, op.getSource());
1128 return success();
1129 }
1130 return failure();
1131 }
1132};
1133
1134void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1135 MLIRContext *context) {
1136 patterns.add<FoldConvertLayoutOp>(context);
1137}
1138
1139//===----------------------------------------------------------------------===//
1140// XeGPU_LoadMatrixOp
1141//===----------------------------------------------------------------------===//
1142void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
1143 TypedValue<MemDescType> memDesc,
1144 llvm::ArrayRef<OpFoldResult> offsets,
1145 DistributeLayoutAttr layout) {
1146 llvm::SmallVector<Value> dynamicOffsets;
1147 llvm::SmallVector<int64_t> staticOffsets;
1148 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1149 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1150 // Call the generated builder with all parameters (including optional ones as
1151 // nullptr/empty)
1152 build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
1153 /*subgroup_block_io=*/nullptr, layout);
1154}
1155
1156LogicalResult LoadMatrixOp::verify() {
1157
1158 auto resTy = dyn_cast<VectorType>(getRes().getType());
1159 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1160 MemDescType mdescTy = getMemDesc().getType();
1161
1162 return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
1163 getLayoutAttr(), [&]() { return emitError(); });
1164}
1165
1166//===----------------------------------------------------------------------===//
1167// XeGPU_StoreMatrixOp
1168//===----------------------------------------------------------------------===//
1169void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
1170 TypedValue<MemDescType> memDesc,
1171 llvm::ArrayRef<OpFoldResult> offsets,
1172 DistributeLayoutAttr layout) {
1173 llvm::SmallVector<Value> dynamicOffsets;
1174 llvm::SmallVector<int64_t> staticOffsets;
1175 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1176 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1177 build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
1178 /*subgroup_block_io=*/nullptr, layout);
1179}
1180
1181LogicalResult StoreMatrixOp::verify() {
1182
1183 auto dataTy = dyn_cast<VectorType>(getData().getType());
1184 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1185 MemDescType mdescTy = getMemDesc().getType();
1186 return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
1187 getLayoutAttr(), [&]() { return emitError(); });
1188}
1189
1190namespace mlir {
1191#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
1192} // namespace mlir
1193#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
1194#define GET_OP_CLASSES
1195#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>