1//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Builders.h"
16
17#include "llvm/Support/Debug.h"
18
19#define DEBUG_TYPE "xegpu"
20
21using namespace mlir;
22using namespace mlir::xegpu;
23
24template <typename T>
25static std::string makeString(T array, bool breakline = false) {
26 std::string buf;
27 buf.clear();
28 llvm::raw_string_ostream os(buf);
29 os << "[";
30 for (size_t i = 1; i < array.size(); i++) {
31 os << array[i - 1] << ", ";
32 if (breakline)
33 os << "\n\t\t";
34 }
35 os << array.back() << "]";
36 return buf;
37}
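// For illustration: makeString(SmallVector<int64_t>{8, 16}) yields "[8, 16]";
// with breakline=true a "\n\t\t" follows each comma-separated element.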
38
39static SmallVector<int64_t> getShapeOf(Type type) {
40 SmallVector<int64_t> shape;
41 if (auto ty = llvm::dyn_cast<ShapedType>(type))
42 shape = SmallVector<int64_t>(ty.getShape());
43 else
44 shape.push_back(1);
45 return shape;
46}
47
48static bool isReadHintOrNone(const CachePolicyAttr &attr) {
49 if (!attr)
50 return true;
51 auto kind = attr.getValue();
52 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
53 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
54}
55
56static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
57 if (!attr)
58 return true;
59 auto kind = attr.getValue();
60 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
61 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
62}
63
64static LogicalResult
65isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
66 VectorType valueTy, int64_t chunkSize,
67 function_ref<InFlightDiagnostic()> emitError) {
68
69 auto maskVecTy = dyn_cast<VectorType>(maskTy);
70 auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
71 if (!valueTy) {
72 if (chunkSize > 1)
73 return emitError() << "Expecting chunk size == 1 for scalar result";
74 if (maskVecTy || offsetsVecTy)
75 return emitError() << "Expecting scalar mask and offsets.";
76 else if (maskVecTy && offsetsVecTy)
77 return emitError() << "Expecting a vector type result.";
78 return success();
79 }
80
81 auto valueSize = valueTy.getNumElements();
82 // SIMT mode with scalar mask and offsets.
83 if (!maskVecTy && !offsetsVecTy) {
84 if (valueSize != chunkSize)
85 return emitError() << "value elements must match chunk size "
86 << chunkSize;
87 return success();
88 }
89 auto maskShape = getShapeOf(maskTy);
90 auto valueShape = getShapeOf(valueTy);
91
92 if (!maskVecTy)
93 return emitError() << "Expecting a vector type mask.";
94 int64_t maskSize = maskVecTy.getNumElements();
95
96 if (chunkSize > 1) {
97 if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
98 return emitError() << "value elements must match chunk size "
99 << chunkSize;
100 } else {
101 if (valueSize != maskSize)
102 return emitError()
103 << "Mask should match value except the chunk size dim.";
104 }
105 llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
106 if (maskSize == 1)
107 return success();
108 if (chunkSize > 1)
109 expectedMaskShape.pop_back();
110 if (expectedMaskShape != maskShape)
111 return emitError() << "Mask should match value except the chunk size dim.";
112
113 return success();
114}
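// Worked example of the checks above (hypothetical shapes): with chunkSize = 4,
// a value of vector<16x4xf32> expects a mask of vector<16xi1> -- the expected
// mask shape is the value shape with the trailing chunk-size dim dropped.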
115
116LogicalResult
117IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
118 UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
119 function_ref<InFlightDiagnostic()> emitError) {
120
121 if (!dataTy) {
122 if (subgroup_block_io)
123 return emitError() << "subgroup_block_io "
124 "are only allowed when result is a VectorType.";
125 else
126 return success();
127 }
128
129 if (mdescTy.getRank() < 2)
130 return emitError() << "mem_desc must be 2D or greater.";
131
132 ArrayRef<int64_t> dataShape = dataTy.getShape();
133 ArrayRef<int64_t> mdescShape = mdescTy.getShape();
134
135 SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
136 ArrayAttr strideAttr = mdescTy.getStrideAttr();
137 SmallVector<int64_t> strides;
138 for (Attribute attr : strideAttr.getValue()) {
139 strides.push_back(cast<IntegerAttr>(attr).getInt());
140 }
141 if (subgroup_block_io && layout) {
142 auto laneData = layout.getEffectiveLaneDataAsInt();
143 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
144 if (!laneData.empty()) {
145 bool isLaneDataContiguous =
146 std::all_of(laneData.begin(), std::prev(laneData.end()),
147 [](int x) { return x == 1; });
148 if (!isLaneDataContiguous)
149 return emitError() << "With subgroup_block_io, accessed data must be "
150 "contiguous and coalesced.";
151 for (size_t i = 0; i < laneData.size(); ++i) {
152 if (laneLayout[i] != blockShape[i])
153 return emitError() << "With subgroup_block_io, the block shape must "
154 "match the lane layout.";
155 if (laneLayout[i] != 1 && strides[i] != 1)
156 return emitError() << "With subgroup_block_io, the distributed "
157 "dimensions must be contiguous.";
158 }
159 }
160 }
161
162 if (layout && !layout.isDistributable(
163 SmallVector<int64_t>(dataShape.begin(), dataShape.end())))
164 return emitError() << "Value shape is not distributable with the layout";
165
166 if (dataShape.size() == 2) {
167 if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
168 [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
169 return emitError() << "data shape must not exceed mem_desc shape.";
170 } else {
171 // if the subgroup_block_io attribute is set, mdescTy must have block
172 // attribute
173 if (subgroup_block_io && !blockShape.size())
174 return emitError() << "mem_desc must have block attribute when "
175 "subgroup_block_io is set.";
176 // if the subgroup_block_io attribute is set, the memdesc should be row
177 // major
178 if (subgroup_block_io && mdescTy.isColMajor())
179 return emitError() << "mem_desc should be row major when "
180 "subgroup_block_io is set.";
181 }
182
183 return success();
184}
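// For instance (hypothetical shapes): without subgroup_block_io, loading a 2D
// vector<16x16xf32> from a mem_desc of shape 32x32 passes the shape check,
// while a 64x16 data shape would exceed the mem_desc and be rejected.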
185
186//===----------------------------------------------------------------------===//
187// XeGPU_CreateNdDescOp
188//===----------------------------------------------------------------------===//
189
190void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
191 Type tdesc, TypedValue<MemRefType> source) {
192 [[maybe_unused]] auto ty = source.getType();
193 assert(ty.hasStaticShape() && "expecting a memref with static shape");
194
195 build(builder, state, tdesc, source, ValueRange({}) /* empty dynamic shape */,
196 ValueRange({}) /* empty dynamic strides */,
197 DenseI64ArrayAttr({}) /* empty const shape*/,
198 DenseI64ArrayAttr({}) /* empty const strides*/);
199}
200
201void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
202 Type tdesc, Value source,
203 llvm::ArrayRef<OpFoldResult> shape,
204 llvm::ArrayRef<OpFoldResult> strides) {
205 Type srcTy = source.getType();
206 assert((isa<IntegerType, MemRefType>(srcTy)) &&
207 "Source has to be either int or memref.");
208
209 llvm::SmallVector<Value> dynamicShape;
210 llvm::SmallVector<Value> dynamicStrides;
211
212 llvm::SmallVector<int64_t> staticShape;
213 llvm::SmallVector<int64_t> staticStrides;
214
215 dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
216 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
217
218 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
219 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
220
221 if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
222 auto memrefShape = memrefTy.getShape();
223 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
224
225 // if shape and strides are from Memref, we don't need attributes for them
226 // to keep the IR print clean (only do so for full-static case, otherwise
227 // printer would fail trying to print empty array-attr).
228 if (staticShape == memrefShape && staticStrides == memrefStrides &&
229 dynamicShape.empty() && dynamicStrides.empty()) {
230 staticShapeAttr = DenseI64ArrayAttr();
231 staticStridesAttr = DenseI64ArrayAttr();
232 }
233 }
234
235 build(builder, state, tdesc, source, dynamicShape, dynamicStrides,
236 staticShapeAttr, staticStridesAttr);
237}
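// A minimal usage sketch (hypothetical names): for an integer (pointer) source,
// shape and strides must be supplied explicitly, e.g.
//   CreateNdDescOp::build(builder, state, tdescTy, basePtr,
//                         /*shape=*/{h, w}, /*strides=*/{w, one});
// where h, w, and one are index-typed Values or constant attributes.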
238
239LogicalResult CreateNdDescOp::verify() {
240 size_t rank = getMixedSizes().size();
241 bool invalidRank = rank != getMixedStrides().size();
242 bool invalidElemTy = false;
243
244 // Memory space of created TensorDesc should match with the source.
245 // Both source and TensorDesc are considered for global memory by default,
246 // if the memory scope attr is not specified. If source is an integer,
247 // it is considered as ptr to global memory.
248 auto srcMemorySpace = getSourceMemorySpace();
249 auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
250 if (srcMemorySpace != tdescMemorySpace)
251 return emitOpError("Memory space mismatch.")
252 << " Source: " << srcMemorySpace
253 << ", TensorDesc: " << tdescMemorySpace;
254
255 // check source type matches the rank if it is a memref.
256 // It also should have the same ElementType as TensorDesc.
257 if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
258 invalidElemTy |= memrefTy.getElementType() != getElementType();
259
260 if (llvm::isa<IntegerType>(getSourceType())) {
261 // strides and shape must present for integer source.
262 if (getMixedStrides().empty() || getMixedSizes().empty())
263 return emitOpError("expecting strides and shape to be present for "
264 "integer source.");
265 }
266
267 if (invalidRank)
268 return emitOpError(
269 "Expecting the rank of shape, strides, and source (if source "
270 "is a memref) should match with each other.");
271
272 // check result TensorDesc rank
273 if (getType().getRank() > (int64_t)rank)
274 return emitOpError("Expecting the TensorDesc rank is not greater than the "
275 "ranks of shape, strides or the memref source.");
276
277 if (invalidElemTy)
278 return emitOpError("TensorDesc should have the same element "
279 "type with the source if it is a memref.\n");
280
281 return success();
282}
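// Example of the element-type rule (hypothetical types): a tensor_desc with f32
// elements created from a memref<32x64xf32> verifies, whereas the same
// descriptor created from a memref<32x64xf16> trips the invalidElemTy check.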
283
284//===----------------------------------------------------------------------===//
285// XeGPU_PrefetchNdOp
286//===----------------------------------------------------------------------===//
287
288void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
289 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
290 xegpu::CachePolicyAttr l1_hint,
291 xegpu::CachePolicyAttr l2_hint,
292 xegpu::CachePolicyAttr l3_hint,
293 xegpu::DistributeLayoutAttr layout) {
294 SmallVector<Value> dynamicOffsets;
295 SmallVector<int64_t> staticOffsets;
296 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
297
298 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
299
300 build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
301 l2_hint, l3_hint, /*anchor_layout=*/layout);
302}
303
304LogicalResult PrefetchNdOp::verify() {
305 auto tdescTy = getTensorDescType();
306
307 if (!isReadHintOrNone(getL1HintAttr()))
308 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
309
310 if (!isReadHintOrNone(getL2HintAttr()))
311 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
312
313 if (!isReadHintOrNone(getL3HintAttr()))
314 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
315
316 int64_t tDescRank = tdescTy.getRank();
317 int64_t offsetSize = getMixedOffsets().size();
318 if (offsetSize != tDescRank)
319 return emitOpError(
320 "Mismatched ranks between offsets and tensor descriptor");
321
322 if (auto layout = getAnchorLayout()) {
323 if (!layout.isDistributable(getShapeOf(tdescTy)))
324 return emitOpError(
325 "TensorDesc shape is not distributable with the layout");
326 }
327
328 return success();
329}
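// For example (hypothetical op): prefetching a 2D tensor_desc with two offsets
// satisfies the rank check above; supplying a single offset would be reported
// as "Mismatched ranks between offsets and tensor descriptor".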
330
331//===----------------------------------------------------------------------===//
332// XeGPU_LoadNdOp
333//===----------------------------------------------------------------------===//
334
335void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
336 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
337 UnitAttr packed, DenseI64ArrayAttr transpose,
338 xegpu::CachePolicyAttr l1_hint,
339 xegpu::CachePolicyAttr l2_hint,
340 xegpu::CachePolicyAttr l3_hint,
341 xegpu::DistributeLayoutAttr layout) {
342 SmallVector<Value> dynamicOffsets;
343 SmallVector<int64_t> staticOffsets;
344 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
345
346 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
347
348 build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
349 packed, transpose, l1_hint, l2_hint, l3_hint,
350 /*anchor_layout=*/layout);
351}
352
353LogicalResult LoadNdOp::verify() {
354 auto tdescTy = getTensorDescType();
355 auto valueTy = getType();
356
357 if (tdescTy.getRank() > 2)
358 return emitOpError("Expects a 1D or 2D TensorDesc.\n");
359
360 if (!valueTy)
361 return emitOpError("Invalid result, it should be a VectorType.\n");
362
363 if (!isReadHintOrNone(getL1HintAttr()))
364 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
365
366 if (!isReadHintOrNone(getL2HintAttr()))
367 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
368
369 if (!isReadHintOrNone(getL3HintAttr()))
370 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
371
372 int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
373 int valueElems = valueTy.getNumElements();
374
375 // If the result vector is 1D and has less elements than the tensor
376 // descriptor, it is supposed to be a SIMT op. The layout attribute in
377 // tensor_desc is not needed.
378 if (valueElems < tdescElems && valueTy.getRank() == 1) {
379 // SIMT mode doesn't need LayoutAttr.
380 if (tdescTy.getLayoutAttr())
381 return emitOpError()
382 << "TensorDesc doesn't need LayoutAttr for SIMT code";
383
384 // For SIMT code, the load is evenly distributed across all lanes in a
385 // subgroup. Since subgroup size is arch dependent, we only check even
386 // distribution here.
387 if (tdescElems % valueElems)
388 return emitOpError()
389 << "Result shape " << makeString(getShapeOf(valueTy))
390 << " is not a valid distribution for tensor descriptor "
391 << tdescTy;
392
393 return success();
394 }
395
396 // Check SIMD mode.
397 auto tdescShape = getShapeOf(tdescTy);
398 auto valueShape = getShapeOf(valueTy);
399
400 if (getTranspose()) {
401 auto trans = getTranspose().value();
402 // Make sure the transpose value is valid, and apply it
403 if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
404 tdescShape = applyPermutation(tdescShape, trans);
405 else
406 mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
407 }
408
409 if (getPacked()) {
410 if (tdescTy.getRank() == 2) {
411 const int axis = 0;
412 auto vnni_factor = valueShape.back();
413 tdescShape[axis] /= vnni_factor;
414 tdescShape.push_back(vnni_factor);
415 } else {
416 mlir::emitWarning(getLoc())
417 << "Invalid Packed Attr. It is ignored (available for 2D "
418 "TensorDesc only).";
419 }
420 }
421
422 // Handle array_length. Two result shape conventions are accepted:
423 // * 3D shape: leading array_length dimension prepended, e.g. descriptor
424 // 16x16 with array_length=2 -> [2, 16, 16].
425 // * Stacked 2D shape: array blocks stacked along the non-FCD (first)
426 // dimension, e.g. descriptor 16x16 with array_length=2 -> [32, 16].
427 auto array_len = tdescTy.getArrayLength();
428 SmallVector<int64_t> stacked2DShape(tdescShape);
429 SmallVector<int64_t> threeDShape(tdescShape);
430 if (array_len > 1 && !tdescShape.empty()) {
431 stacked2DShape[0] *= array_len;
432 threeDShape.insert(threeDShape.begin(), array_len);
433 }
434
435 if (valueShape != stacked2DShape && valueShape != threeDShape)
436 return emitOpError() << "Result shape " << makeString(valueShape)
437 << " is not consistent with tensor descriptor "
438 << tdescTy;
439
440 int64_t tDescRank = tdescTy.getRank();
441 int64_t offsetSize = getMixedOffsets().size();
442 if (offsetSize != tDescRank)
443 return emitOpError(
444 "Mismatched ranks between offsets and tensor descriptor");
445
446 if (auto layout = getAnchorLayout()) {
447 if (!layout.isDistributable(getShapeOf(tdescTy)))
448 return emitOpError(
449 "TensorDesc shape is not distributable with the layout");
450 }
451
452 return success();
453}
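// Worked example of the packed/array_length handling (hypothetical shapes): a
// 16x16xf16 descriptor loaded with packed expects vector<8x16x2xf16> (the vnni
// factor 2 is taken from the trailing result dim); with array_length = 2 and no
// packing, both vector<2x16x16xf16> and vector<32x16xf16> are accepted.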
454
455//===----------------------------------------------------------------------===//
456// XeGPU_StoreNdOp
457//===----------------------------------------------------------------------===//
458
459void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
460 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
461 xegpu::CachePolicyAttr l1_hint,
462 xegpu::CachePolicyAttr l2_hint,
463 xegpu::CachePolicyAttr l3_hint,
464 xegpu::DistributeLayoutAttr layout) {
465 SmallVector<Value> dynamicOffsets;
466 SmallVector<int64_t> staticOffsets;
467 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
468
469 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
470
471 build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
472 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
473}
474
475LogicalResult StoreNdOp::verify() {
476 auto dstTy = getTensorDescType(); // Tile
477 auto valTy = getValueType(); // Vector
478
479 if (dstTy.getRank() > 2)
480 return emitOpError("Expects a 1D or 2D TensorDesc.\n");
481
482 if (!valTy)
483 return emitOpError("Expecting a VectorType result.\n");
484
485 if (!isWriteHintOrNone(getL1HintAttr()))
486 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
487
488 if (!isWriteHintOrNone(getL2HintAttr()))
489 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
490
491 if (!isWriteHintOrNone(getL3HintAttr()))
492 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
493
494 auto array_len = dstTy.getArrayLength();
495 if (array_len > 1)
496 return emitOpError("array length is not supported by store_nd.\n");
497
498 auto tdescElems = dstTy.getNumElements();
499 auto valueElems = valTy.getNumElements();
500
501 // Similar to LoadNdOp, if the value vector is 1D and has less elements than
502 // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
503 // in tensor_desc is not needed.
504 if (valTy.getRank() == 1 && valueElems < tdescElems) {
505 // SIMT mode doesn't need LayoutAttr.
506 if (dstTy.getLayoutAttr())
507 return emitOpError()
508 << "TensorDesc doesn't need LayoutAttr for SIMT code";
509
510 if (tdescElems % valueElems)
511 return emitOpError()
512 << "Value shape " << makeString(getShapeOf(valTy))
513 << " is not a valid distribution for tensor descriptor " << dstTy;
514
515 return success();
516 }
517
518 // SIMD code should have the same shape as the tensor descriptor.
519 auto tdescShape = getShapeOf(dstTy);
520 auto valueShape = getShapeOf(valTy);
521 if (tdescShape != valueShape)
522 return emitOpError() << "Value shape " << makeString(valueShape)
523 << " is not consistent with tensor descriptor "
524 << dstTy;
525
526 int64_t tDescRank = dstTy.getRank();
527 int64_t offsetSize = getMixedOffsets().size();
528 if (offsetSize != tDescRank)
529 return emitOpError(
530 "Mismatched ranks between offsets and tensor descriptor");
531
532 if (auto layout = getAnchorLayout()) {
533 if (!layout.isDistributable(tdescShape))
534 return emitOpError(
535 "TensorDesc shape is not distributable with the layout");
536 }
537
538 return success();
539}
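// For instance (hypothetical shapes): storing vector<8x16xf32> into an 8x16
// descriptor matches the SIMD path; storing a 1D vector<8xf32> into the same
// descriptor takes the SIMT path since 128 % 8 == 0.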
540
541//===----------------------------------------------------------------------===//
542// XeGPU_PrefetchOp
543//===----------------------------------------------------------------------===//
544LogicalResult PrefetchOp::verify() {
545 if (!isReadHintOrNone(getL1HintAttr()))
546 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
547
548 if (!isReadHintOrNone(getL2HintAttr()))
549 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
550
551 if (!isReadHintOrNone(getL3HintAttr()))
552 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
553
554 auto srcTy = getSourceType();
555 if (srcTy.isInteger() && !getOffsetAlignByteAttr())
556 return emitOpError("offset_align_byte is required with integer source.");
557
558 if (getOffsetAlignByteAttr() && !srcTy.isInteger())
559 return emitOpError("offset_align_byte only allowed with integer source.");
560
561 if (auto layout = getAnchorLayout()) {
562 // get the offset operand and its shape
563 auto offsetsTy = getOffsets().getType();
564 if (llvm::isa<VectorType>(offsetsTy) &&
565 !layout.isDistributable(getShapeOf(offsetsTy)))
566 return emitOpError("offset shape is not distributable with the layout");
567 }
568
569 return success();
570}
571
572//===----------------------------------------------------------------------===//
573// XeGPU_LoadGatherOp
574//===----------------------------------------------------------------------===//
575LogicalResult LoadGatherOp::verify() {
576 auto maskTy = getMaskType();
577 auto valueTy = getValueType();
578
579 if (!isReadHintOrNone(getL1HintAttr()))
580 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
581
582 if (!isReadHintOrNone(getL2HintAttr()))
583 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
584
585 if (!isReadHintOrNone(getL3HintAttr()))
586 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
587
588 auto srcTy = getSourceType();
589 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
590 auto memTy = dyn_cast<MemRefType>(srcTy);
591
592 if (memTy && (getElementType() != memTy.getElementType()))
593 return emitError() << "Value should have the same element type as MemRef.";
594
595 if (auto layout = getAnchorLayout()) {
596 if (!layout.isDistributable(getShapeOf(valueTy)))
597 return emitOpError("Value shape is not distributable with the layout");
598 }
599
600 auto offsetsTy = getOffsets().getType();
601 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
602 [&]() { return emitOpError(); });
603}
604
605void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
606 Type valueType, Value source,
607 ArrayRef<OpFoldResult> offsets, Value mask,
608 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
609 xegpu::CachePolicyAttr l2_hint,
610 xegpu::CachePolicyAttr l3_hint) {
611 auto loc = source.getLoc();
612 int64_t size = static_cast<int64_t>(offsets.size());
613 auto type = VectorType::get(size, builder.getIndexType());
614 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
615 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
616
617 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
618 l2_hint, l3_hint, /*anchor_layout=*/nullptr);
619}
620
621void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
622 Type valueType, Value source,
623 ArrayRef<OpFoldResult> offsets, Value mask,
624 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
625 xegpu::CachePolicyAttr l2_hint,
626 xegpu::CachePolicyAttr l3_hint,
627 DistributeLayoutAttr layout) {
628 auto loc = source.getLoc();
629 int64_t size = static_cast<int64_t>(offsets.size());
630 auto type = VectorType::get(size, builder.getIndexType());
631 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
632 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
633
634 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
635 l2_hint, l3_hint, layout);
636}
637
638//===----------------------------------------------------------------------===//
639// XeGPU_StoreScatterOp
640//===----------------------------------------------------------------------===//
641LogicalResult StoreScatterOp::verify() {
642 auto maskTy = getMaskType();
643 auto valueTy = getValueType();
644
645 if (!isWriteHintOrNone(getL1HintAttr()))
646 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
647
648 if (!isWriteHintOrNone(getL2HintAttr()))
649 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
650
651 if (!isWriteHintOrNone(getL3HintAttr()))
652 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
653
654 auto destTy = getDestType();
655 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
656 auto memTy = dyn_cast<MemRefType>(destTy);
657
658 if (memTy && (getElementType() != memTy.getElementType()))
659 return emitError() << "Value should have the same element type as MemRef.";
660
661 if (auto layout = getAnchorLayout()) {
662 if (!layout.isDistributable(getShapeOf(valueTy)))
663 return emitOpError("Value shape is not distributable with the layout");
664 }
665
666 auto offsetsTy = getOffsets().getType();
667 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
668 [&]() { return emitOpError(); });
669}
670
671void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
672 Value value, Value dest,
673 ArrayRef<OpFoldResult> offsets, Value mask,
674 IntegerAttr chunk_size,
675 xegpu::CachePolicyAttr l1_hint,
676 xegpu::CachePolicyAttr l2_hint,
677 xegpu::CachePolicyAttr l3_hint) {
678 auto loc = dest.getLoc();
679 int64_t size = static_cast<int64_t>(offsets.size());
680 auto type = VectorType::get(size, builder.getIndexType());
681 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
682 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
683
684 // Call the correct builder overload that does not expect result types.
685 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
686 l3_hint, /*anchor_layout=*/nullptr);
687}
688
689void StoreScatterOp::build(
690 OpBuilder &builder, OperationState &state, Value value, Value dest,
691 ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
692 xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
693 xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
694 auto loc = dest.getLoc();
695 int64_t size = static_cast<int64_t>(offsets.size());
696 auto type = VectorType::get(size, builder.getIndexType());
697 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
698 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
699
700 // Call the correct builder overload that does not expect result types.
701 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
702 l3_hint, layout);
703}
704
705//===----------------------------------------------------------------------===//
706// DPAS Common Verification Helpers
707//===----------------------------------------------------------------------===//
708
709// Helper to verify layout distributability for a value
710static LogicalResult
711verifyLayoutDistributable(Operation *op,
712 std::optional<DistributeLayoutAttr> layout,
713 ArrayRef<int64_t> shape, StringRef operandName) {
714 if (layout && !layout->isDistributable(
715 SmallVector<int64_t>(shape.begin(), shape.end())))
716 return op->emitOpError(operandName)
717 << " shape is not distributable with the layout";
718 return success();
719}
720
721// Helper to verify M, N, K dimensions match between A, B, and result matrices
722static LogicalResult verifyDpasDimensions(Operation *op,
723 ArrayRef<int64_t> aShape,
724 ArrayRef<int64_t> bShape,
725 ArrayRef<int64_t> resShape) {
726
727 auto aRank = aShape.size();
728 auto bRank = bShape.size();
729 auto resRank = resShape.size();
730 if (aRank == 1 && bRank == 1 && resRank == 1)
731 return success();
732
733 // Validate A and B are 2D
734 if (aRank != 2)
735 return op->emitOpError("A operand must be a 2D vector.");
736 if (bRank < 2 || bRank > 3)
737 return op->emitOpError("B operand must be a 2D or 3D vector.");
738 if (resRank != 2)
739 return op->emitOpError("Result must be a 2D vector.");
740
741 // Calculate effective K dimension for B (handle 3D packed case)
742 int64_t bK = bRank == 3 ? bShape[0] * bShape[2] : bShape[0];
743
744 // Verify K dimension match between A and B
745 if (bK != aShape[1])
746 return op->emitOpError("K-dimension mismatch: A has K=")
747 << aShape[1] << " but B has K=" << bK << ".";
748
749 // Verify M dimension match between A and result
750 if (aShape[0] != resShape[0])
751 return op->emitOpError("M-dimension mismatch: A has M=")
752 << aShape[0] << " but result has M=" << resShape[0] << ".";
753
754 // Verify N dimension match between B and result
755 if (bShape[1] != resShape[1])
756 return op->emitOpError("N-dimension mismatch: B has N=")
757 << bShape[1] << " but result has N=" << resShape[1] << ".";
758
759 return success();
760}
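// Example of the K handling (hypothetical shapes): A = vector<8x16>, packed
// B = vector<8x16x2> gives bK = 8 * 2 = 16, matching A's K; with result
// vector<8x16> the M and N checks also pass.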
761
762// Helper to verify accumulator matches result type
763static LogicalResult verifyDpasAccumulator(Operation *op, Type accType,
764 Type resultType) {
765 if (accType != resultType)
766 return op->emitOpError("Accumulator type must match result type.");
767 return success();
768}
769
770//===----------------------------------------------------------------------===//
771// XeGPU_DpasOp
772//===----------------------------------------------------------------------===//
773LogicalResult DpasOp::verify() {
774 auto lhsShape = getLhsType().getShape();
775 auto rhsShape = getRhsType().getShape();
776 auto resShape = getResultType().getShape();
777
778 // Verify layout distributability
779 if (failed(
780 verifyLayoutDistributable(*this, getLayoutCd(), resShape, "Result")))
781 return failure();
782 if (failed(verifyLayoutDistributable(*this, getLayoutA(), lhsShape, "A")))
783 return failure();
784 if (failed(verifyLayoutDistributable(*this, getLayoutB(), rhsShape, "B")))
785 return failure();
786
787 // Verify accumulator if present
788 if (getAcc() &&
789 failed(verifyDpasAccumulator(*this, getAcc().getType(), getResultType())))
790 return failure();
791
792 return verifyDpasDimensions(*this, lhsShape, rhsShape, resShape);
793}
794
795//===----------------------------------------------------------------------===//
796// XeGPU_ConvertLayoutOp
797//===----------------------------------------------------------------------===//
798LogicalResult ConvertLayoutOp::verify() {
799 auto srcLayout = getInputLayout();
800 auto resLayout = getTargetLayout();
801 if (!srcLayout)
802 return emitOpError("expected input layout.");
803 if (!resLayout)
804 return emitOpError("expected target layout.");
805
806 // both input and target layouts should be WgLayout or SgLayout at the same
807 // time.
808 if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
809 (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
810 return emitOpError("expected input layout and target layout be WgLayout or "
811 "SgLayout at the same time.");
812
813 Type srcType = getSource().getType();
814 if (llvm::isa<VectorType>(srcType)) {
815 SmallVector<int64_t> shape(llvm::cast<VectorType>(srcType).getShape());
816 if (!srcLayout.isDistributable(shape))
817 return emitOpError(
818 "invalid input layout, data cannot be evenly distributed.");
819
820 if (!resLayout.isDistributable(shape))
821 return emitOpError(
822 "invalid target layout, data cannot be evenly distributed.");
823 }
824 return mlir::success();
825}
826
827//===----------------------------------------------------------------------===//
828// XeGPU_LoadMatrixOp
829//===----------------------------------------------------------------------===//
830void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
831 TypedValue<MemDescType> memDesc,
832 llvm::ArrayRef<OpFoldResult> offsets,
833 DistributeLayoutAttr layout) {
834 llvm::SmallVector<Value> dynamicOffsets;
835 llvm::SmallVector<int64_t> staticOffsets;
836 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
837 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
838 // Call the generated builder with all parameters (including optional ones as
839 // nullptr/empty)
840 build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
841 /*subgroup_block_io=*/nullptr, layout);
842}
843
844LogicalResult LoadMatrixOp::verify() {
845
846 auto resTy = dyn_cast<VectorType>(getRes().getType());
847 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
848 MemDescType mdescTy = getMemDesc().getType();
849
850 return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
851 getLayoutAttr(), [&]() { return emitError(); });
852}
853
854//===----------------------------------------------------------------------===//
855// XeGPU_StoreMatrixOp
856//===----------------------------------------------------------------------===//
857void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
858 TypedValue<MemDescType> memDesc,
859 llvm::ArrayRef<OpFoldResult> offsets,
860 DistributeLayoutAttr layout) {
861 llvm::SmallVector<Value> dynamicOffsets;
862 llvm::SmallVector<int64_t> staticOffsets;
863 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
864 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
865 build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
866 /*subgroup_block_io=*/nullptr, layout);
867}
868
869LogicalResult StoreMatrixOp::verify() {
870
871 auto dataTy = dyn_cast<VectorType>(getData().getType());
872 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
873 MemDescType mdescTy = getMemDesc().getType();
874 return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
875 getLayoutAttr(), [&]() { return emitError(); });
876}
877
878//===----------------------------------------------------------------------===//
879// XeGPU_TruncfOp
880//===----------------------------------------------------------------------===//
881
882LogicalResult TruncfOp::verify() {
883 auto sourceVecType = dyn_cast<VectorType>(getSource().getType());
884 auto resultVecType = dyn_cast<VectorType>(getResult().getType());
885
886 if (sourceVecType.getElementTypeBitWidth() <=
887 resultVecType.getElementTypeBitWidth())
888 return emitOpError("input type must be wider than result type.");
889
890 return success();
891}
892
893//===----------------------------------------------------------------------===//
894// XeGPU_DpasMxOp
895//===----------------------------------------------------------------------===//
896
897LogicalResult DpasMxOp::verify() {
898 auto aShape = getAType().getShape();
899 auto bShape = getBType().getShape();
900 auto resShape = getResultType().getShape();
901
902 // Verify layout distributability for A, B, and result
903 if (failed(
904 verifyLayoutDistributable(*this, getLayoutCd(), resShape, "Result")))
905 return failure();
906 if (failed(verifyLayoutDistributable(*this, getLayoutA(), aShape, "A")))
907 return failure();
908 if (failed(verifyLayoutDistributable(*this, getLayoutB(), bShape, "B")))
909 return failure();
910
911 // Verify accumulator if present
912 if (getAcc() &&
913 failed(verifyDpasAccumulator(*this, getAcc().getType(), getResultType())))
914 return failure();
915
916 // Verify M, N, K dimensions
917 if (failed(verifyDpasDimensions(*this, aShape, bShape, resShape)))
918 return failure();
919
920 // Validate scale_a if present
921 if (getScaleA()) {
922 auto scaleAVecType = dyn_cast<VectorType>(getScaleAType());
923 // Only validate if scale is a vector (scalars are always valid)
924 if (scaleAVecType) {
925 auto scaleAShape = scaleAVecType.getShape();
926
927 if (scaleAVecType.getRank() != 2)
928 return emitOpError("Scale A must be a 2D vector when not a scalar.");
929
930 // Verify layout distributability for scale_a
931 if (failed(verifyLayoutDistributable(*this, getLayoutAScale(),
932 scaleAShape, "ScaleA")))
933 return failure();
934
935 // Validate M dimension: scale_a[0] must match a[0]
936 if (scaleAShape[0] != aShape[0])
937 return emitOpError("Scale A M dimension [")
938 << scaleAShape[0] << "] must match A M dimension [" << aShape[0]
939 << "].";
940 }
941 }
942
943 // Validate scale_b if present
944 if (getScaleB()) {
945 auto scaleBVecType = dyn_cast<VectorType>(getScaleBType());
946 // Only validate if scale is a vector (scalars are always valid)
947 if (scaleBVecType) {
948 auto scaleBShape = scaleBVecType.getShape();
949
950 if (scaleBVecType.getRank() != 2)
951 return emitOpError("Scale B must be a 2D vector when not a scalar.");
952
953 // Verify layout distributability for scale_b
954 if (failed(verifyLayoutDistributable(*this, getLayoutBScale(),
955 scaleBShape, "ScaleB")))
956 return failure();
957
958 // Validate N dimension: scale_b[1] must match b[1]
959 if (scaleBShape[1] != bShape[1])
960 return emitOpError("Scale B N dimension [")
961 << scaleBShape[1] << "] must match B N dimension [" << bShape[1]
962 << "].";
963 }
964 }
965
966 // Validate scale K dimension compatibility if both scales are present and
967 // vectors
968 if (getScaleA() && getScaleB()) {
969 auto scaleAVecType = dyn_cast<VectorType>(getScaleAType());
970 auto scaleBVecType = dyn_cast<VectorType>(getScaleBType());
971
972 if (scaleAVecType && scaleBVecType) {
973 auto scaleAShape = scaleAVecType.getShape();
974 auto scaleBShape = scaleBVecType.getShape();
975
976 // Validate scale K dimension compatibility: scale_a[1] must match
977 // scale_b[0]
978 if (scaleAShape[1] != scaleBShape[0])
979 return emitOpError("Scale K dimension mismatch: scale_a has K=")
980 << scaleAShape[1] << " but scale_b has K=" << scaleBShape[0]
981 << ".";
982 }
983 }
984
985 return success();
986}
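// Putting the scale checks together (hypothetical shapes): with A = vector<8x64>
// and B = vector<64x16>, scale_a = vector<8x2> and scale_b = vector<2x16>
// satisfy the M, N, and scale-K (2 == 2) constraints verified above.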
987
988namespace mlir {
989#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
990} // namespace mlir
991#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
992#define GET_OP_CLASSES
993#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>