MLIR 23.0.0git
XeGPUOps.cpp
Go to the documentation of this file.
1//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Builders.h"
16
17#include "llvm/Support/Debug.h"
18
19#define DEBUG_TYPE "xegpu"
20
21using namespace mlir;
22using namespace mlir::xegpu;
23
24template <typename T>
25static std::string makeString(T array, bool breakline = false) {
26 std::string buf;
27 buf.clear();
28 llvm::raw_string_ostream os(buf);
29 os << "[";
30 for (size_t i = 1; i < array.size(); i++) {
31 os << array[i - 1] << ", ";
32 if (breakline)
33 os << "\n\t\t";
34 }
35 os << array.back() << "]";
36 return buf;
37}
38
// NOTE(review): the opening lines of this definition (its signature and the
// declaration of `shape`) are not present in this listing; the comments below
// describe only the visible statements.
//
// If `type` is a ShapedType, its dimensions are copied into `shape`;
// any other type yields the single-element shape [1].
41 if (auto ty = llvm::dyn_cast<ShapedType>(type))
42 shape = SmallVector<int64_t>(ty.getShape());
43 else
44 shape.push_back(1);
45 return shape;
46}
47
48static bool isReadHintOrNone(const CachePolicyAttr &attr) {
49 if (!attr)
50 return true;
51 auto kind = attr.getValue();
52 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
53 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
54}
55
56static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
57 if (!attr)
58 return true;
59 auto kind = attr.getValue();
60 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
61 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
62}
63
// Shared checks for gather/scatter style accesses: validates the relation
// between the value type, the mask type, the offsets type and the chunk size.
// Emits a diagnostic through `emitError` and returns it on the first
// violation; returns success() otherwise.
//
// NOTE(review): parts of the signature (the function name line and the
// trailing parameter line) are missing from this listing. Judging from the
// call sites in LoadGatherOp::verify and StoreScatterOp::verify the arguments
// are (offsetsTy, maskTy, valueTy, chunkSize, emitError) - confirm against
// the original file.
64static LogicalResult
66 VectorType valueTy, int64_t chunkSize,
68
69 auto maskVecTy = dyn_cast<VectorType>(maskTy);
70 auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
// Scalar value: chunk size must be 1 and neither mask nor offsets may be
// vectors.
71 if (!valueTy) {
72 if (chunkSize > 1)
73 return emitError() << "Expecting chunk size == 1 for scalar result";
74 if (maskVecTy || offsetsVecTy)
75 return emitError() << "Expecting scalar mask and offsets.";
// NOTE(review): this branch looks unreachable - whenever both maskVecTy and
// offsetsVecTy are set, the preceding `||` condition has already returned.
76 else if (maskVecTy && offsetsVecTy)
77 return emitError() << "Expecting a vector type result.";
78 return success();
79 }
80
81 auto valueSize = valueTy.getNumElements();
82 // SIMT mode with scalar mask and offsets.
83 if (!maskVecTy && !offsetsVecTy) {
84 if (valueSize != chunkSize)
85 return emitError() << "value elements must match chunk size "
86 << chunkSize;
87 return success();
88 }
89 auto maskShape = getShapeOf(maskTy);
90 auto valueShape = getShapeOf(valueTy);
91
// From here on at least one of mask/offsets is a vector; the mask itself is
// required to be one.
92 if (!maskVecTy)
93 return emitError() << "Expecting a vector type mask.";
94 int64_t maskSize = maskVecTy.getNumElements();
95
// With a chunk size above 1, only a rank-1 value has its element count
// compared against the chunk size; otherwise value and mask must have the
// same number of elements.
96 if (chunkSize > 1) {
97 if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
98 return emitError() << "value elements must match chunk size "
99 << chunkSize;
100 } else {
101 if (valueSize != maskSize)
102 return emitError()
103 << "Mask should match value except the chunk size dim.";
104 }
105 llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
// A single-element mask is accepted without a shape comparison.
106 if (maskSize == 1)
107 return success();
// The expected mask shape is the value shape without its last dimension when
// chunking is used.
108 if (chunkSize > 1)
109 expectedMaskShape.pop_back();
110 if (expectedMaskShape != maskShape)
111 return emitError() << "Mask should match value except the chunk size dim.";
112
113 return success();
114}
115
// Shared checks for the matrix load/store ops: validates the data vector type
// against the memory descriptor type, the optional subgroup_block_io attribute
// and the optional layout. Emits a diagnostic through `emitError` and returns
// it on the first violation; returns success() otherwise.
//
// NOTE(review): the last line of the parameter list is missing from this
// listing; the callers pass a callable returning a diagnostic as the final
// argument.
116LogicalResult
117IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
118 UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
120
// A null data type is only rejected when subgroup_block_io is requested.
121 if (!dataTy) {
122 if (subgroup_block_io)
123 return emitError() << "subgroup_block_io "
124 "are only allowed when result is a VectorType.";
125 else
126 return success();
127 }
128
129 ArrayRef<int64_t> dataShape = dataTy.getShape();
130 ArrayRef<int64_t> mdescShape = mdescTy.getShape();
131
132 SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
133 ArrayAttr strideAttr = mdescTy.getStrideAttr();
// NOTE(review): assumes getStrideAttr() never returns a null attribute -
// confirm, since getValue() is called unconditionally.
134 SmallVector<int64_t> strides;
135 for (Attribute attr : strideAttr.getValue()) {
136 strides.push_back(cast<IntegerAttr>(attr).getInt());
137 }
138 if (subgroup_block_io && layout) {
139 auto laneData = layout.getEffectiveLaneDataAsInt();
140 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
141 if (!laneData.empty()) {
// Every lane-data entry except the last one must be 1.
142 bool isLaneDataContiguous =
143 std::all_of(laneData.begin(), std::prev(laneData.end()),
144 [](int x) { return x == 1; });
145 if (!isLaneDataContiguous)
146 return emitError() << "With subgroup_block_io, accessed data must be "
147 "contiguous and coalesced.";
// NOTE(review): laneLayout, blockShape and strides are indexed with the
// laneData index without a size check, and the "block attribute must be
// present" check only happens at the end of the function - an empty
// blockShape would be indexed here first. TODO confirm and reorder.
148 for (size_t i = 0; i < laneData.size(); ++i) {
149 if (laneLayout[i] != blockShape[i])
150 return emitError() << "With subgroup_block_io, the block shape must "
151 "match the lane layout.";
152 if (laneLayout[i] != 1 && strides[i] != 1)
153 return emitError() << "With subgroup_block_io, the distributed "
154 "dimensions must be contiguous.";
155 }
156 }
157 }
158
159 if (layout && !layout.isDistributable(
160 SmallVector<int64_t>(dataShape.begin(), dataShape.end())))
161 return emitError() << "Value shape is not distributable with the layout";
162
// The per-dimension bound check is only performed when both shapes have the
// same rank.
163 if (dataShape.size() == mdescShape.size()) {
164 if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
165 [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
166 return emitError() << "data shape must not exceed mem_desc shape.";
167 }
168 // if the subgroup_block_io attribute is set, mdescTy must have block
169 // attribute
170 if (subgroup_block_io && !blockShape.size())
171 return emitError() << "mem_desc must have block attribute when "
172 "subgroup_block_io is set.";
173 return success();
174}
175
176//===----------------------------------------------------------------------===//
177// XeGPU_CreateNdDescOp
178//===----------------------------------------------------------------------===//
179
180void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
181 Type tdesc, TypedValue<MemRefType> source) {
182 [[maybe_unused]] auto ty = source.getType();
183 assert(ty.hasStaticShape() && "expecting a memref with static shape");
184
185 build(builder, state, tdesc, source, ValueRange({}) /* empty dynamic shape */,
186 ValueRange({}) /* empty dynamic strides */,
187 DenseI64ArrayAttr({}) /* empty const shape*/,
188 DenseI64ArrayAttr({}) /* empty const strides*/);
189}
190
// Builder taking shape and strides as mixed static/dynamic values (they are
// passed to dispatchIndexOpFoldResults below). Splits them into dynamic
// operands and static attributes and forwards to the full builder.
//
// NOTE(review): the parameter lines declaring `shape` and `strides` are
// missing from this listing.
191void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
192 Type tdesc, Value source,
195 Type srcTy = source.getType();
196 assert((isa<IntegerType, MemRefType>(srcTy)) &&
197 "Source has to be either int or memref.");
198
199 llvm::SmallVector<Value> dynamicShape;
200 llvm::SmallVector<Value> dynamicStrides;
201
202 llvm::SmallVector<int64_t> staticShape;
203 llvm::SmallVector<int64_t> staticStrides;
204
// Separate each mixed list into its dynamic values and static integers.
205 dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
206 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
207
208 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
209 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
210
211 if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
212 auto memrefShape = memrefTy.getShape();
213 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
214
215 // if shape and strides are from Memref, we don't need attributes for them
216 // to keep the IR print clean (only do so for full-static case, otherwise
217 // printer would fail trying to print empty array-attr).
218 if (staticShape == memrefShape && staticStrides == memrefStrides &&
219 dynamicShape.empty() && dynamicStrides.empty()) {
// Replace both attributes with null attributes.
220 staticShapeAttr = DenseI64ArrayAttr();
221 staticStridesAttr = DenseI64ArrayAttr();
222 }
223 }
224
225 build(builder, state, tdesc, source, dynamicShape, dynamicStrides,
226 staticShapeAttr, staticStridesAttr);
227}
228
229LogicalResult CreateNdDescOp::verify() {
230 size_t rank = getMixedSizes().size();
231 bool invalidRank = rank != getMixedStrides().size();
232 bool invalidElemTy = false;
233
234 // Memory space of created TensorDesc should match with the source.
235 // Both source and TensorDesc are considered for global memory by default,
236 // if the memory scope attr is not specified. If source is an integer,
237 // it is considered as ptr to global memory.
238 auto srcMemorySpace = getSourceMemorySpace();
239 auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
240 if (srcMemorySpace != tdescMemorySpace)
241 return emitOpError("Memory space mismatch.")
242 << " Source: " << srcMemorySpace
243 << ", TensorDesc: " << tdescMemorySpace;
244
245 // check source type matches the rank if it is a memref.
246 // It also should have the same ElementType as TensorDesc.
247 if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
248 invalidElemTy |= memrefTy.getElementType() != getElementType();
249
250 if (llvm::isa<IntegerType>(getSourceType())) {
251 // strides and shape must present for integer source.
252 if (getMixedStrides().empty() || getMixedSizes().empty())
253 return emitOpError("expecting strides and shape to be present for "
254 "integer source.");
255 }
256
257 if (invalidRank)
258 return emitOpError(
259 "Expecting the rank of shape, strides, and source (if source "
260 "is a memref) should match with each other.");
261
262 // check result TensorDesc rank
263 if (getType().getRank() > (int64_t)rank)
264 return emitOpError("Expecting the TensorDesc rank is not greater than the "
265 "ranks of shape, strides or the memref source.");
266
267 if (invalidElemTy)
268 return emitOpError("TensorDesc should have the same element "
269 "type with the source if it is a memref.\n");
270
271 return success();
272}
273
274//===----------------------------------------------------------------------===//
275// XeGPU_PrefetchNdOp
276//===----------------------------------------------------------------------===//
277
278void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
279 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
280 xegpu::CachePolicyAttr l1_hint,
281 xegpu::CachePolicyAttr l2_hint,
282 xegpu::CachePolicyAttr l3_hint,
283 xegpu::DistributeLayoutAttr layout) {
284 SmallVector<Value> dynamicOffsets;
285 SmallVector<int64_t> staticOffsets;
286 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
287
288 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
289
290 build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
291 l2_hint, l3_hint, /*anchor_layout=*/layout);
292}
293
294LogicalResult PrefetchNdOp::verify() {
295 auto tdescTy = getTensorDescType();
296
297 if (!isReadHintOrNone(getL1HintAttr()))
298 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
299
300 if (!isReadHintOrNone(getL2HintAttr()))
301 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
302
303 if (!isReadHintOrNone(getL3HintAttr()))
304 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
305
306 int64_t tDescRank = tdescTy.getRank();
307 int64_t offsetSize = getMixedOffsets().size();
308 if (offsetSize != tDescRank)
309 return emitOpError(
310 "Mismatched ranks between offsets and tensor descriptor");
311
312 if (auto layout = getAnchorLayout()) {
313 if (!layout.isDistributable(getShapeOf(tdescTy)))
314 return emitOpError(
315 "TensorDesc shape is not distributable with the layout");
316 }
317
318 return success();
319}
320
321//===----------------------------------------------------------------------===//
322// XeGPU_LoadNdOp
323//===----------------------------------------------------------------------===//
324
325void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
326 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
327 UnitAttr packed, DenseI64ArrayAttr transpose,
328 xegpu::CachePolicyAttr l1_hint,
329 xegpu::CachePolicyAttr l2_hint,
330 xegpu::CachePolicyAttr l3_hint,
331 xegpu::DistributeLayoutAttr layout) {
332 SmallVector<Value> dynamicOffsets;
333 SmallVector<int64_t> staticOffsets;
334 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
335
336 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
337
338 build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
339 packed, transpose, l1_hint, l2_hint, l3_hint,
340 /*anchor_layout=*/layout);
341}
342
343LogicalResult LoadNdOp::verify() {
344 auto tdescTy = getTensorDescType();
345 auto valueTy = getType();
346
347 if (!valueTy)
348 return emitOpError("Invalid result, it should be a VectorType.\n");
349
350 if (!isReadHintOrNone(getL1HintAttr()))
351 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
352
353 if (!isReadHintOrNone(getL2HintAttr()))
354 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
355
356 if (!isReadHintOrNone(getL3HintAttr()))
357 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
358
359 int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
360 int valueElems = valueTy.getNumElements();
361
362 // If the result vector is 1D and has less elements than the tensor
363 // descriptor, it is supposed to be a SIMT op. The layout attribute in
364 // tensor_desc is not needed.
365 if (valueElems < tdescElems && valueTy.getRank() == 1) {
366 // SIMT mode doesn't need LayoutAttr.
367 if (tdescTy.getLayoutAttr())
368 return emitOpError()
369 << "TensorDesc doesn't need LayoutAttr for SIMT code";
370
371 // For SIMT code, the load is evenly distributed across all lanes in a
372 // subgroup. Since subgroup size is arch dependent, we only check even
373 // distribution here.
374 if (tdescElems % valueElems)
375 return emitOpError()
376 << "Result shape " << makeString(getShapeOf(valueTy))
377 << " is not a valid distribution for tensor descriptor "
378 << tdescTy;
379
380 return success();
381 }
382
383 // Check SIMD mode.
384 auto tdescShape = getShapeOf(tdescTy);
385 auto valueShape = getShapeOf(valueTy);
386
387 if (getTranspose()) {
388 auto trans = getTranspose().value();
389 // Make sure the transpose value is valid, and apply it
390 if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
391 tdescShape = applyPermutation(tdescShape, trans);
392 else
393 mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
394 }
395
396 if (getPacked()) {
397 if (tdescTy.getRank() == 2) {
398 const int axis = 0;
399 auto vnni_factor = valueShape.back();
400 tdescShape[axis] /= vnni_factor;
401 tdescShape.push_back(vnni_factor);
402 } else {
403 mlir::emitWarning(getLoc())
404 << "Invalid Packed Attr. It is ignored (available for 2D "
405 "TensorDesc only).";
406 }
407 }
408
409 // Handle array_length. Two result shape conventions are accepted:
410 // * 3D shape: leading array_length dimension prepended, e.g. descriptor
411 // 16x16 with array_length=2 -> [2, 16, 16].
412 // * Stacked 2D shape: array blocks stacked along the non-FCD (first)
413 // dimension, e.g. descriptor 16x16 with array_length=2 -> [32, 16].
414 auto array_len = tdescTy.getArrayLength();
415 SmallVector<int64_t> stacked2DShape(tdescShape);
416 SmallVector<int64_t> threeDShape(tdescShape);
417 if (array_len > 1 && !tdescShape.empty()) {
418 stacked2DShape[0] *= array_len;
419 threeDShape.insert(threeDShape.begin(), array_len);
420 }
421
422 if (valueShape != stacked2DShape && valueShape != threeDShape)
423 return emitOpError() << "Result shape " << makeString(valueShape)
424 << " is not consistent with tensor descriptor "
425 << tdescTy;
426
427 int64_t tDescRank = tdescTy.getRank();
428 int64_t offsetSize = getMixedOffsets().size();
429 if (offsetSize != tDescRank)
430 return emitOpError(
431 "Mismatched ranks between offsets and tensor descriptor");
432
433 if (auto layout = getAnchorLayout()) {
434 if (!layout.isDistributable(getShapeOf(tdescTy)))
435 return emitOpError(
436 "TensorDesc shape is not distributable with the layout");
437 }
438
439 return success();
440}
441
442//===----------------------------------------------------------------------===//
443// XeGPU_StoreNdOp
444//===----------------------------------------------------------------------===//
445
446void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
447 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
448 xegpu::CachePolicyAttr l1_hint,
449 xegpu::CachePolicyAttr l2_hint,
450 xegpu::CachePolicyAttr l3_hint,
451 xegpu::DistributeLayoutAttr layout) {
452 SmallVector<Value> dynamicOffsets;
453 SmallVector<int64_t> staticOffsets;
454 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
455
456 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
457
458 build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
459 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
460}
461
462LogicalResult StoreNdOp::verify() {
463 auto dstTy = getTensorDescType(); // Tile
464 auto valTy = getValueType(); // Vector
465
466 if (!valTy)
467 return emitOpError("Expecting a VectorType result.\n");
468
469 if (!isWriteHintOrNone(getL1HintAttr()))
470 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
471
472 if (!isWriteHintOrNone(getL2HintAttr()))
473 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
474
475 if (!isWriteHintOrNone(getL3HintAttr()))
476 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
477
478 auto array_len = dstTy.getArrayLength();
479 if (array_len > 1)
480 return emitOpError("array length is not supported by store_nd.\n");
481
482 auto tdescElems = dstTy.getNumElements();
483 auto valueElems = valTy.getNumElements();
484
485 // Similar to LoadNdOp, if the value vector is 1D and has less elements than
486 // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
487 // in tensor_desc is not needed.
488 if (valTy.getRank() == 1 && valueElems < tdescElems) {
489 // SIMT mode doesn't need LayoutAttr.
490 if (dstTy.getLayoutAttr())
491 return emitOpError()
492 << "TensorDesc doesn't need LayoutAttr for SIMT code";
493
494 if (tdescElems % valueElems)
495 return emitOpError()
496 << "Value shape " << makeString(getShapeOf(valTy))
497 << " is not a valid distribution for tensor descriptor " << dstTy;
498
499 return success();
500 }
501
502 // SIMD code should have the same shape as the tensor descriptor.
503 auto tdescShape = getShapeOf(dstTy);
504 auto valueShape = getShapeOf(valTy);
505 if (tdescShape != valueShape)
506 return emitOpError() << "Value shape " << makeString(valueShape)
507 << " is not consistent with tensor descriptor "
508 << dstTy;
509
510 int64_t tDescRank = dstTy.getRank();
511 int64_t offsetSize = getMixedOffsets().size();
512 if (offsetSize != tDescRank)
513 return emitOpError(
514 "Mismatched ranks between offsets and tensor descriptor");
515
516 if (auto layout = getAnchorLayout()) {
517 if (!layout.isDistributable(tdescShape))
518 return emitOpError(
519 "TensorDesc shape is not distributable with the layout");
520 }
521
522 return success();
523}
524
525//===----------------------------------------------------------------------===//
526// XeGPU_PrefetchOp
527//===----------------------------------------------------------------------===//
528LogicalResult PrefetchOp::verify() {
529 if (!isReadHintOrNone(getL1HintAttr()))
530 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
531
532 if (!isReadHintOrNone(getL2HintAttr()))
533 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
534
535 if (!isReadHintOrNone(getL3HintAttr()))
536 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
537
538 auto srcTy = getSourceType();
539 if (srcTy.isInteger() && !getOffsetAlignByteAttr())
540 return emitOpError("offset_align_byte is required with integer source.");
541
542 if (getOffsetAlignByteAttr() && !srcTy.isInteger())
543 return emitOpError("offset_align_byte only allowed with integer source.");
544
545 if (auto layout = getAnchorLayout()) {
546 // get the offset operand and its shape
547 auto offsetsTy = getOffsets().getType();
548 if (llvm::isa<VectorType>(offsetsTy) &&
549 !layout.isDistributable(getShapeOf(offsetsTy)))
550 return emitOpError("offset shape is not distributable with the layout");
551 }
552
553 return success();
554}
555
556//===----------------------------------------------------------------------===//
557// XeGPU_LoadGatherOp
558//===----------------------------------------------------------------------===//
559LogicalResult LoadGatherOp::verify() {
560 auto maskTy = getMaskType();
561 auto valueTy = getValueType();
562
563 if (!isReadHintOrNone(getL1HintAttr()))
564 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
565
566 if (!isReadHintOrNone(getL2HintAttr()))
567 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
568
569 if (!isReadHintOrNone(getL3HintAttr()))
570 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
571
572 auto srcTy = getSourceType();
573 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
574 auto memTy = dyn_cast<MemRefType>(srcTy);
575
576 if (memTy && (getElementType() != memTy.getElementType()))
577 return emitError() << "Value should have the same element type as MemRef.";
578
579 if (auto layout = getAnchorLayout()) {
580 if (!layout.isDistributable(getShapeOf(valueTy)))
581 return emitOpError("Value shape is not distributable with the layout");
582 }
583
584 auto offsetsTy = getOffsets().getType();
585 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
586 [&]() { return emitOpError(); });
587}
588
589void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
590 Type valueType, Value source,
591 ArrayRef<OpFoldResult> offsets, Value mask,
592 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
593 xegpu::CachePolicyAttr l2_hint,
594 xegpu::CachePolicyAttr l3_hint) {
595 auto loc = source.getLoc();
596 int64_t size = static_cast<int64_t>(offsets.size());
597 auto type = VectorType::get(size, builder.getIndexType());
598 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
599 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
600
601 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
602 l2_hint, l3_hint, /*anchor_layout=*/nullptr);
603}
604
605void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
606 Type valueType, Value source,
607 ArrayRef<OpFoldResult> offsets, Value mask,
608 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
609 xegpu::CachePolicyAttr l2_hint,
610 xegpu::CachePolicyAttr l3_hint,
611 DistributeLayoutAttr layout) {
612 auto loc = source.getLoc();
613 int64_t size = static_cast<int64_t>(offsets.size());
614 auto type = VectorType::get(size, builder.getIndexType());
615 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
616 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
617
618 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
619 l2_hint, l3_hint, layout);
620}
621
622//===----------------------------------------------------------------------===//
623// XeGPU_StoreScatterOp
624//===----------------------------------------------------------------------===//
625LogicalResult StoreScatterOp::verify() {
626 auto maskTy = getMaskType();
627 auto valueTy = getValueType();
628
629 if (!isWriteHintOrNone(getL1HintAttr()))
630 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
631
632 if (!isWriteHintOrNone(getL2HintAttr()))
633 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
634
635 if (!isWriteHintOrNone(getL3HintAttr()))
636 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
637
638 auto destTy = getDestType();
639 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
640 auto memTy = dyn_cast<MemRefType>(destTy);
641
642 if (memTy && (getElementType() != memTy.getElementType()))
643 return emitError() << "Value should have the same element type as MemRef.";
644
645 if (auto layout = getAnchorLayout()) {
646 if (!layout.isDistributable(getShapeOf(valueTy)))
647 return emitOpError("Value shape is not distributable with the layout");
648 }
649
650 auto offsetsTy = getOffsets().getType();
651 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
652 [&]() { return emitOpError(); });
653}
654
655void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
656 Value value, Value dest,
657 ArrayRef<OpFoldResult> offsets, Value mask,
658 IntegerAttr chunk_size,
659 xegpu::CachePolicyAttr l1_hint,
660 xegpu::CachePolicyAttr l2_hint,
661 xegpu::CachePolicyAttr l3_hint) {
662 auto loc = dest.getLoc();
663 int64_t size = static_cast<int64_t>(offsets.size());
664 auto type = VectorType::get(size, builder.getIndexType());
665 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
666 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
667
668 // Call the correct builder overload that does not expect result types.
669 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
670 l3_hint, /*anchor_layout=*/nullptr);
671}
672
673void StoreScatterOp::build(
674 OpBuilder &builder, OperationState &state, Value value, Value dest,
675 ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
676 xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
677 xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
678 auto loc = dest.getLoc();
679 int64_t size = static_cast<int64_t>(offsets.size());
680 auto type = VectorType::get(size, builder.getIndexType());
681 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
682 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
683
684 // Call the correct builder overload that does not expect result types.
685 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
686 l3_hint, layout);
687}
688
689//===----------------------------------------------------------------------===//
690// DPAS Common Verification Helpers
691//===----------------------------------------------------------------------===//
692
693// Helper to verify layout distributability for a value
// Emits "<operandName> shape is not distributable with the layout" on `op`
// when a layout is provided and it cannot distribute `shape`; returns
// success() when no layout is given or the shape is distributable.
//
// NOTE(review): the line carrying the function name and the first parameter
// (`op`) is missing from this listing.
694static LogicalResult
696 std::optional<DistributeLayoutAttr> layout,
697 ArrayRef<int64_t> shape, StringRef operandName) {
// The shape is copied into a SmallVector before being handed to
// isDistributable.
698 if (layout && !layout->isDistributable(
699 SmallVector<int64_t>(shape.begin(), shape.end())))
700 return op->emitOpError(operandName)
701 << " shape is not distributable with the layout";
702 return success();
703}
704
705// Helper to verify M, N, K dimensions match between A, B, and result matrices
706static LogicalResult verifyDpasDimensions(Operation *op,
707 ArrayRef<int64_t> aShape,
708 ArrayRef<int64_t> bShape,
709 ArrayRef<int64_t> resShape) {
710
711 auto aRank = aShape.size();
712 auto bRank = bShape.size();
713 auto resRank = resShape.size();
714 if (aRank == 1 && bRank == 1 && resRank == 1)
715 return success();
716
717 // A must be at least 2D, B must be 2D or 3D (innermost dims), result at
718 // least 2D.
719 if (aRank < 2)
720 return op->emitOpError("A operand must be at least a 2D vector.");
721 if (bRank < 2)
722 return op->emitOpError("B operand must be at least a 2D vector.");
723 if (resRank < 2)
724 return op->emitOpError("Result must be at least a 2D vector.");
725
726 // FIXME: B may have one extra trailing dim for VNNI packing
727 // (B[batch..., K/vnni, N, vnni]). We plan to drop VNNI packing support, so
728 // rather than properly verifying the packed dimensions, we simply accept
729 // the packed form here and skip the detailed verification. This branch
730 // should be removed once VNNI packing support is dropped.
731 if (bRank == aRank + 1)
732 return success();
733
734 // All operands have the same rank. They share the same batch dimensions,
735 // with the last two dims being the core matmul dims: A[batch..., M, K],
736 // B[batch..., K, N], result[batch..., M, N].
737 if (aRank != bRank || aRank != resRank)
738 return op->emitOpError("Rank mismatch among A, B, and result.");
739
740 int64_t batchRank = aRank - 2;
741
742 // Verify batch dimensions match.
743 for (int64_t i = 0; i < batchRank; ++i) {
744 if (aShape[i] != resShape[i])
745 return op->emitOpError("Batch dimension mismatch at dim ")
746 << i << ": A has " << aShape[i] << " but result has "
747 << resShape[i] << ".";
748 if (aShape[i] != bShape[i])
749 return op->emitOpError("Batch dimension mismatch at dim ")
750 << i << ": A has " << aShape[i] << " but B has " << bShape[i]
751 << ".";
752 }
753
754 // Core matmul dimensions (last two dims of each operand).
755 int64_t aM = aShape[batchRank];
756 int64_t aK = aShape[batchRank + 1];
757 int64_t bK = bShape[batchRank];
758 int64_t bN = bShape[batchRank + 1];
759 int64_t resM = resShape[batchRank];
760 int64_t resN = resShape[batchRank + 1];
761
762 // Verify K dimension match between A and B
763 if (bK != aK)
764 return op->emitOpError("K-dimension mismatch: A has K=")
765 << aK << " but B has K=" << bK << ".";
766
767 // Verify M dimension match between A and result
768 if (aM != resM)
769 return op->emitOpError("M-dimension mismatch: A has M=")
770 << aM << " but result has M=" << resM << ".";
771
772 // Verify N dimension match between B and result
773 if (bN != resN)
774 return op->emitOpError("N-dimension mismatch: B has N=")
775 << bN << " but result has N=" << resN << ".";
776
777 return success();
778}
779
780// Helper to verify accumulator matches result type
781static LogicalResult verifyDpasAccumulator(Operation *op, Type accType,
782 Type resultType) {
783 if (accType != resultType)
784 return op->emitOpError("Accumulator type must match result type.");
785 return success();
786}
787
788//===----------------------------------------------------------------------===//
789// XeGPU_DpasOp
790//===----------------------------------------------------------------------===//
791LogicalResult DpasOp::verify() {
792 auto lhsShape = getLhsType().getShape();
793 auto rhsShape = getRhsType().getShape();
794 auto resShape = getResultType().getShape();
795
796 // Verify layout distributability
797 if (failed(
798 verifyLayoutDistributable(*this, getLayoutCd(), resShape, "Result")))
799 return failure();
800 if (failed(verifyLayoutDistributable(*this, getLayoutA(), lhsShape, "A")))
801 return failure();
802 if (failed(verifyLayoutDistributable(*this, getLayoutB(), rhsShape, "B")))
803 return failure();
804
805 // Verify accumulator if present
806 if (getAcc() &&
807 failed(verifyDpasAccumulator(*this, getAcc().getType(), getResultType())))
808 return failure();
809
810 return verifyDpasDimensions(*this, lhsShape, rhsShape, resShape);
811}
812
813//===----------------------------------------------------------------------===//
814// XeGPU_ConvertLayoutOp
815//===----------------------------------------------------------------------===//
816LogicalResult ConvertLayoutOp::verify() {
817 auto srcLayout = getInputLayout();
818 auto resLayout = getTargetLayout();
819 if (!srcLayout)
820 return emitOpError("expected input layout.");
821 if (!resLayout)
822 return emitOpError("expected target layout.");
823
824 // both input and target layouts should be WgLayout or SgLayout at the same
825 // time.
826 if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
827 (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
828 return emitOpError("expected input layout and target layout be WgLayout or "
829 "SgLayout at the same time.");
830
831 Type srcType = getSource().getType();
832 if (llvm::isa<VectorType>(srcType)) {
833 SmallVector<int64_t> shape(llvm::cast<VectorType>(srcType).getShape());
834 if (!srcLayout.isDistributable(shape))
835 return emitOpError(
836 "invalid input layout, data cannot be evenly distributed.");
837
838 if (!resLayout.isDistributable(shape))
839 return emitOpError(
840 "invalid target layout, data cannot be evenly distributed.");
841 }
842 return mlir::success();
843}
844
845//===----------------------------------------------------------------------===//
846// XeGPU_LoadMatrixOp
847//===----------------------------------------------------------------------===//
// Builder accepting mixed static/dynamic offsets: splits them into dynamic
// operands and a dense i64 array attribute, then forwards to the generated
// builder without a subgroup_block_io attribute.
//
// NOTE(review): the parameter lines declaring `memDesc` and `offsets` are
// missing from this listing.
848void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
851 DistributeLayoutAttr layout) {
852 llvm::SmallVector<Value> dynamicOffsets;
853 llvm::SmallVector<int64_t> staticOffsets;
854 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
855 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
856 // Call the generated builder with all parameters (including optional ones as
857 // nullptr/empty)
858 build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
859 /*subgroup_block_io=*/nullptr, layout);
860}
861
862LogicalResult LoadMatrixOp::verify() {
863
864 auto resTy = dyn_cast<VectorType>(getRes().getType());
865 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
866 MemDescType mdescTy = getMemDesc().getType();
867
868 return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
869 getLayoutAttr(), [&]() { return emitError(); });
870}
871
872//===----------------------------------------------------------------------===//
873// XeGPU_StoreMatrixOp
874//===----------------------------------------------------------------------===//
// Convenience builder for StoreMatrixOp. It fills `dynamicOffsets` and
// `staticOffsets` from `offsets` via dispatchIndexOpFoldResults, wraps the
// static part in a DenseI64ArrayAttr, and forwards everything to the generated
// builder with the subgroup_block_io attribute left unset (nullptr).
// NOTE(review): the parameter lines that declare `memDesc` and `offsets` are
// absent from this listing (numbering jumps from 875 to 878); confirm their
// types against the header.
875void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
878 DistributeLayoutAttr layout) {
879 llvm::SmallVector<Value> dynamicOffsets;
880 llvm::SmallVector<int64_t> staticOffsets;
// Populate the two vectors above from the mixed `offsets` list.
881 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
882 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
// Forward to the generated builder; subgroup_block_io is not set here.
883 build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
884 /*subgroup_block_io=*/nullptr, layout);
885}
886
887LogicalResult StoreMatrixOp::verify() {
888
889 auto dataTy = dyn_cast<VectorType>(getData().getType());
890 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
891 MemDescType mdescTy = getMemDesc().getType();
892 return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
893 getLayoutAttr(), [&]() { return emitError(); });
894}
895
896//===----------------------------------------------------------------------===//
897// XeGPU_TruncfOp
898//===----------------------------------------------------------------------===//
899
900LogicalResult TruncfOp::verify() {
901 auto sourceVecType = dyn_cast<VectorType>(getSource().getType());
902 auto resultVecType = dyn_cast<VectorType>(getResult().getType());
903
904 if (sourceVecType.getElementTypeBitWidth() <=
905 resultVecType.getElementTypeBitWidth())
906 return emitOpError("input type must be wider than result type.");
907
908 return success();
909}
910
911//===----------------------------------------------------------------------===//
912// XeGPU_DpasMxOp
913//===----------------------------------------------------------------------===//
914
915LogicalResult DpasMxOp::verify() {
916 auto aShape = getAType().getShape();
917 auto bShape = getBType().getShape();
918 auto resShape = getResultType().getShape();
919
920 // Verify layout distributability for A, B, and result
921 if (failed(
922 verifyLayoutDistributable(*this, getLayoutCd(), resShape, "Result")))
923 return failure();
924 if (failed(verifyLayoutDistributable(*this, getLayoutA(), aShape, "A")))
925 return failure();
926 if (failed(verifyLayoutDistributable(*this, getLayoutB(), bShape, "B")))
927 return failure();
928
929 // Verify accumulator if present
930 if (getAcc() &&
931 failed(verifyDpasAccumulator(*this, getAcc().getType(), getResultType())))
932 return failure();
933
934 // Verify M, N, K dimensions
935 if (failed(verifyDpasDimensions(*this, aShape, bShape, resShape)))
936 return failure();
937
938 // Determine batch rank from A operand.
939 int64_t aBatchRank = aShape.size() - 2;
940
941 // Validate scale_a if present
942 if (getScaleA()) {
943 auto scaleAVecType = dyn_cast<VectorType>(getScaleAType());
944 // Only validate if scale is a vector (scalars are always valid)
945 if (scaleAVecType && scaleAVecType.getRank() > 1) {
946 auto scaleAShape = scaleAVecType.getShape();
947
948 if (scaleAVecType.getRank() < 2)
949 return emitOpError("Scale A must be at least a 2D vector when not a "
950 "scalar.");
951
952 // Verify layout distributability for scale_a
953 if (failed(verifyLayoutDistributable(*this, getLayoutAScale(),
954 scaleAShape, "ScaleA")))
955 return failure();
956
957 // Validate M dimension: scale_a's M must match A's M (last-1 dim)
958 if (scaleAShape[scaleAShape.size() - 2] != aShape[aBatchRank])
959 return emitOpError("Scale A M dimension [")
960 << scaleAShape[scaleAShape.size() - 2]
961 << "] must match A M dimension [" << aShape[aBatchRank] << "].";
962 }
963 }
964
965 // Validate scale_b if present
966 if (getScaleB()) {
967 auto scaleBVecType = dyn_cast<VectorType>(getScaleBType());
968 // Only validate if scale is a vector (scalars are always valid)
969 if (scaleBVecType && scaleBVecType.getRank() > 1) {
970 auto scaleBShape = scaleBVecType.getShape();
971
972 if (scaleBVecType.getRank() < 2)
973 return emitOpError("Scale B must be at least a 2D vector when not a "
974 "scalar.");
975
976 // Verify layout distributability for scale_b
977 if (failed(verifyLayoutDistributable(*this, getLayoutBScale(),
978 scaleBShape, "ScaleB")))
979 return failure();
980
981 // Validate N dimension: scale_b's N (last dim) must match B's N (last
982 // dim)
983 if (scaleBShape.back() != bShape.back())
984 return emitOpError("Scale B N dimension [")
985 << scaleBShape.back() << "] must match B N dimension ["
986 << bShape.back() << "].";
987 }
988 }
989
990 // Validate scale K dimension compatibility if both scales are present and
991 // vectors
992 if (getScaleA() && getScaleB()) {
993 auto scaleAVecType = dyn_cast<VectorType>(getScaleAType());
994 auto scaleBVecType = dyn_cast<VectorType>(getScaleBType());
995
996 if (scaleAVecType && scaleBVecType && scaleAVecType.getRank() > 1 &&
997 scaleBVecType.getRank() > 1) {
998 auto scaleAShape = scaleAVecType.getShape();
999 auto scaleBShape = scaleBVecType.getShape();
1000
1001 // Validate scale K dimension compatibility: scale_a's last dim must
1002 // match scale_b's second-to-last dim
1003 if (scaleAShape.back() != scaleBShape[scaleBShape.size() - 2])
1004 return emitOpError("Scale K dimension mismatch: scale_a has K=")
1005 << scaleAShape.back()
1006 << " but scale_b has K=" << scaleBShape[scaleBShape.size() - 2]
1007 << ".";
1008 }
1009 }
1010
1011 return success();
1012}
1013
1014namespace mlir {
1015#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
1016} // namespace mlir
1017#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
1018#define GET_OP_CLASSES
1019#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
ArrayAttr()
static Type getValueType(Attribute attr)
Definition SPIRVOps.cpp:831
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static SmallVector< int64_t > getShapeOf(Type type)
Definition XeGPUOps.cpp:39
static LogicalResult verifyDpasAccumulator(Operation *op, Type accType, Type resultType)
Definition XeGPUOps.cpp:781
LogicalResult IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io, DistributeLayoutAttr layout, function_ref< InFlightDiagnostic()> emitError)
Definition XeGPUOps.cpp:117
static std::string makeString(T array, bool breakline=false)
Definition XeGPUOps.cpp:25
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
Definition XeGPUOps.cpp:56
static bool isReadHintOrNone(const CachePolicyAttr &attr)
Definition XeGPUOps.cpp:48
static LogicalResult isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, VectorType valueTy, int64_t chunkSize, function_ref< InFlightDiagnostic()> emitError)
Definition XeGPUOps.cpp:65
static LogicalResult verifyDpasDimensions(Operation *op, ArrayRef< int64_t > aShape, ArrayRef< int64_t > bShape, ArrayRef< int64_t > resShape)
Definition XeGPUOps.cpp:706
static LogicalResult verifyLayoutDistributable(Operation *op, std::optional< DistributeLayoutAttr > layout, ArrayRef< int64_t > shape, StringRef operandName)
Definition XeGPUOps.cpp:695
Attributes are known-constant values of operations.
Definition Attributes.h:25
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IndexType getIndexType()
Definition Builders.cpp:55
This class represents a diagnostic that is inflight and set to be reported.
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:79
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult(OpFoldResult ofr) for a single OpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
This represents an operation in an abstracted form, suitable for use with the builder APIs.