//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

namespace mlir {
#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// TODO: Below are uArch-dependent values; these should not be hardcoded.
static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};

// Offsets to individual fields of the 8xi32 payload of an nd tensor
// descriptor.
enum class NdTdescOffset : uint32_t {
  BasePtr = 0,    // Base pointer (i64)
  BaseShapeW = 2, // Base shape width (i32)
  BaseShapeH = 3, // Base shape height (i32)
  BasePitch = 4,  // Base pitch (i32)
};
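// A minimal sketch of the resulting payload layout (word indices are the
// enum values above; i64 lane 0 of the 4xi64 view aliases i32 words 0-1,
// and words 5-7 are not written by this lowering):
//
//   8xi32 view:  | w0   w1  | w2     | w3     | w4    | w5 w6 w7 |
//                | BasePtr  | ShapeW | ShapeH | Pitch | (unused) |
//   4xi64 view:  | BasePtr  |           ...                      |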

static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
  switch (xeGpuMemspace) {
  case xegpu::MemorySpace::Global:
    return static_cast<int>(xevm::AddrSpace::GLOBAL);
  case xegpu::MemorySpace::SLM:
    return static_cast<int>(xevm::AddrSpace::SHARED);
  }
  llvm_unreachable("Unknown XeGPU memory space");
}

/// Checks if the given MemRefType refers to shared memory.
static bool isSharedMemRef(const MemRefType &memrefTy) {
  Attribute attr = memrefTy.getMemorySpace();
  if (!attr)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}

// Get a flat vector type with the same total bit width but a new element
// type.
static VectorType encodeVectorTypeTo(VectorType currentVecType,
                                     Type toElemType) {
  auto elemType = currentVecType.getElementType();
  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
  const int size =
      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
  return VectorType::get(size, toElemType);
}
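// Example (hypothetical types): encoding vector<8xf16> to i32 yields
// vector<4xi32>, since 8 * 16 / 32 = 4; the total bit width (128) is
// preserved.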

static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                            std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::CACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::READ_INVALIDATE:
    return xevm::LoadCacheControl::INVALIDATE_READ;
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}
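// Note: XeGPU exposes only L1 and L3 hints, so L2 is left uncached (L2UC) in
// the mappings above and below; e.g. (L1 CACHED, L3 CACHED) translates to
// L1C_L2UC_L3C.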

static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                             std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_BACK:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_THROUGH:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}

//
// Note:
// Block operations on tiles of sub-byte element types are handled by
// emulation with larger element types.
// Tensor descriptors are kept intact; only the ops consuming them are
// emulated.
//
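// Worked example (hypothetical tile): a 16x64 tile of i4 elements loaded as
// the matrix A operand has subByteFactor = 8/4 = 2 and sub16BitFactor = 4,
// so with executionSize = 16 the tile is emulated as a 16x16 tile of i16
// (tileW 64 -> 16, wScaleFactor = 4); see LoadStorePrefetchNdToXeVMPattern
// below.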

class CreateNdDescToXeVMPattern
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op,
                  xegpu::CreateNdDescOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    if (mixedOffsets.size() != 0)
      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
    auto loc = op.getLoc();
    auto source = op.getSource();
    // The op is lowered to a code sequence that populates the payload.
    // The payload is an 8xi32 vector; offsets to individual fields are
    // defined in the NdTdescOffset enum.
    Type payloadElemTy = rewriter.getI32Type();
    VectorType payloadTy = VectorType::get(8, payloadElemTy);
    Type i64Ty = rewriter.getI64Type();
    // A 4xi64 view is used for inserting the base pointer.
    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
    // Initialize the payload to zero.
    Value payload = arith::ConstantOp::create(
        rewriter, loc,
        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));

    Value baseAddr;
    Value baseShapeW;
    Value baseShapeH;

    // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
    // The descriptor shape is expected to be 1D or 2D.
    int64_t rank = mixedSizes.size();
    auto sourceTy = source.getType();
    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
    // If the source is a memref, we need to extract the aligned pointer as
    // index. The pointer type is passed as i32 or i64 by the type converter.
    if (sourceMemrefTy) {
      if (!sourceMemrefTy.hasRank()) {
        return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
      }
      // Access the adaptor after the failure check to avoid rolling back
      // generated code for the materialization cast.
      baseAddr = adaptor.getSource();
    } else {
      baseAddr = adaptor.getSource();
      if (baseAddr.getType() != i64Ty) {
        // The pointer type may be i32. Cast to i64 if needed.
        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
      }
    }
    // A 1D tensor descriptor is just the base address.
    if (rank == 1) {
      rewriter.replaceOp(op, baseAddr);
      return success();
    }
    // Utility for creating offset values from an op fold result.
    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
                            unsigned idx) -> Value {
      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
      return val;
    };
    // Get shape values from op fold results.
    baseShapeW = createOffset(mixedSizes, 1);
    baseShapeH = createOffset(mixedSizes, 0);
    // Get the pitch value from op fold results.
    Value basePitch = createOffset(mixedStrides, 0);
    // Populate the payload.
    Value payLoadAsI64 =
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
    payLoadAsI64 =
        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
                                 static_cast<int>(NdTdescOffset::BasePtr));
    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeW));
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeH));
    payload =
        vector::InsertOp::create(rewriter, loc, basePitch, payload,
                                 static_cast<int>(NdTdescOffset::BasePitch));
    rewriter.replaceOp(op, payload);
    return success();
  }
};

template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto mixedOffsets = op.getMixedOffsets();
    int64_t opOffsetsSize = mixedOffsets.size();
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();

    auto tdesc = adaptor.getTensorDesc();
    auto tdescTy = op.getTensorDescType();
    auto tileRank = tdescTy.getRank();
    if (opOffsetsSize != tileRank)
      return rewriter.notifyMatchFailure(
          op, "Expected offset rank to match descriptor rank.");
    auto elemType = tdescTy.getElementType();
    auto elemBitSize = elemType.getIntOrFloatBitWidth();
    bool isSubByte = elemBitSize < 8;
    uint64_t wScaleFactor = 1;

    if (!isSubByte && (elemBitSize % 8 != 0))
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    auto tileW = tdescTy.getDimSize(tileRank - 1);
    // For sub-byte types, only 4-bit elements are currently supported.
    if (isSubByte) {
      if (elemBitSize != 4)
        return rewriter.notifyMatchFailure(
            op, "Only sub byte types of 4bits are supported.");
      if (tileRank != 2)
        return rewriter.notifyMatchFailure(
            op, "Sub byte types are only supported for 2D tensor descriptors.");
      auto subByteFactor = 8 / elemBitSize;
      auto tileH = tdescTy.getDimSize(0);
      // Handle the special case of packed load.
      if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
        if (op.getPacked().value_or(false)) {
          // Packed load is emulated with packed loads of 8-bit elements.
          if (tileH == systolicDepth * 4 &&
              tileW == executionSize * subByteFactor) {
            // Use case: loading as matrix B with pack request.
            // The source is assumed to be pre-packed into 8-bit elements.
            // Emulate with 8-bit loads with pack request.
            // scaled_tileW = executionSize
            elemType = rewriter.getIntegerType(8);
            tileW = executionSize;
            wScaleFactor = subByteFactor;
          }
        }
      }
      // If not handled by the packed load case above, handle other cases.
      if (wScaleFactor == 1) {
        auto sub16BitFactor = subByteFactor * 2;
        if (tileW == executionSize * sub16BitFactor) {
          // Use case: loading as the matrix A operand.
          // Emulate with 16-bit loads/stores.
          // scaled_tileW = executionSize
          elemType = rewriter.getIntegerType(16);
          tileW = executionSize;
          wScaleFactor = sub16BitFactor;
        } else {
          return rewriter.notifyMatchFailure(
              op, "Unsupported tile shape for sub byte types.");
        }
      }
      // Recompute the element bit size for emulation.
      elemBitSize = elemType.getIntOrFloatBitWidth();
    }

    // Get the address space from the tensor descriptor memory space.
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if (tileRank == 2) {
      // Compute element byte size.
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
      VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
      Value payLoadAsI64 =
          vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
      Value basePtr =
          vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
                                    static_cast<int>(NdTdescOffset::BasePtr));
      Value baseShapeW = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
      Value baseShapeH = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
      Value basePitch = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
      // Offsets are provided by the op; convert them to i32.
      Value offsetW =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetW);
      Value offsetH =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetH);
      // Convert the base pointer (i64) to an LLVM pointer type.
      Value basePtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
      // FIXME: width/pitch is not the same as baseShapeW; it should be the
      // stride of the second-to-last dimension in row-major layout.
      // Compute width in bytes.
      Value baseShapeWInBytes =
          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
      // Compute pitch in bytes.
      Value basePitchBytes =
          arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);

      if (wScaleFactor > 1) {
        // Scale offsetW and baseShapeWInBytes for sub-byte emulation.
        // Note: tileW is already scaled above.
        Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
            rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
        baseShapeWInBytes = arith::ShRSIOp::create(
            rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
        basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
                                                wScaleFactorValLog2);
        offsetW =
            arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
      }
      // Get the tile height from the tensor descriptor type.
      auto tileH = tdescTy.getDimSize(0);
      // Get vblocks from the tensor descriptor type.
      int32_t vblocks = tdescTy.getArrayLength();
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        // If the store value is a scalar, get the value from the op instead
        // of the adaptor. The adaptor may have optimized away a
        // single-element vector.
        if (src.getType().isIntOrFloat()) {
          src = op.getValue();
        }
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        // Get a flat vector type of integer type with matching element bit
        // size.
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        xevm::BlockStore2dOp::create(
            rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
            basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
        rewriter.eraseOp(op);
      } else {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
          xevm::BlockPrefetch2dOp::create(
              rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
              basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
              vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          rewriter.eraseOp(op);
        } else {
          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
          const bool vnni = op.getPacked().value_or(false);
          auto transposeValue = op.getTranspose();
          bool transpose =
              transposeValue.has_value() && transposeValue.value()[0] == 1;
          VectorType loadedTy = encodeVectorTypeTo(
              dstVecTy, vnni ? rewriter.getI32Type()
                             : rewriter.getIntegerType(elemBitSize));

          Value resultFlatVec = xevm::BlockLoad2dOp::create(
              rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
              baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
              tileH, vblocks, transpose, vnni,
              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          resultFlatVec = vector::BitCastOp::create(
              rewriter, loc,
              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
              resultFlatVec);
          rewriter.replaceOp(op, resultFlatVec);
        }
      }
    } else {
      // 1D tensor descriptor: `tdesc` represents the base address as i64.
      // The offset is in number of elements; multiply by the element byte
      // size to compute the byte offset:
      //   byteOffset = offset * elemByteSize
      Value offset =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
                                               rewriter.getI64Type(), offset);
      // Compute element byte size.
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
      Value byteOffset =
          rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
      // Final address = basePtr + byteOffset.
      Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
          loc, tdesc,
          getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
                                          byteOffset));
      // Convert the base pointer (i64) to an LLVM pointer type.
      Value finalPtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        // If the store value is a scalar, get the value from the op instead
        // of the adaptor. The adaptor may have optimized away a
        // single-element vector.
        if (src.getType().isIntOrFloat()) {
          src = op.getValue();
        }
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        // Get a flat vector type of integer type with matching element bit
        // size.
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
            op, finalPtrLLVM, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        VectorType resTy = cast<VectorType>(op.getValue().getType());
        VectorType loadedTy =
            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
        Value load = xevm::BlockLoadOp::create(
            rewriter, loc, loadedTy, finalPtrLLVM,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
        if (loadedTy != resTy)
          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
        rewriter.replaceOp(op, load);
      } else {
        return rewriter.notifyMatchFailure(
            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
                "descriptor rank == 1");
      }
    }
    return success();
  }
};

// Helper to compute: offset * elemByteSize + baseAddr.
static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
                                 Location loc, Value baseAddr, Value offset,
                                 int64_t elemByteSize) {
  Value byteSize = arith::ConstantIntOp::create(
      rewriter, loc, baseAddr.getType(), elemByteSize);
  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
  return newAddr;
}

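// Rough sketch (hypothetical types) of the IR produced by the pattern below
// for a masked gather load, once the offset and mask have been converted to
// scalars:
//
//   %res = scf.if %mask -> (f32) {
//     %v = llvm.load %ptr {cache_control = #xevm.load_cache_control<...>}
//         : !llvm.ptr<1> -> f32
//     scf.yield %v : f32
//   } else {
//     %zero = arith.constant 0.0 : f32
//     scf.yield %zero : f32
//   }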
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value offset = adaptor.getOffsets();
    if (!offset)
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64;
    // The load result or store value type can be a vector or a scalar.
    Type valOrResTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
      valOrResTy =
          this->getTypeConverter()->convertType(op.getResult().getType());
    else
      valOrResTy = adaptor.getValue().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
    bool hasScalarVal = !valOrResVecTy;
    int64_t elemBitWidth =
        hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
                     : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    // The element type bit width must be a multiple of 8.
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
    // The default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // If a tensor descriptor is available, use its memory space.
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // The base pointer can come from the source (load) or the dest (store).
    // If they are memrefs, use their memory space.
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      basePtrI64 = adaptor.getSource();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    } else {
      basePtrI64 = adaptor.getDest();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    }
    // The base pointer is passed as i32 or i64 by the adaptor; cast to i64 if
    // needed.
    if (basePtrI64.getType() != rewriter.getI64Type()) {
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    }
    Value mask = adaptor.getMask();
    if (dyn_cast<VectorType>(offset.getType())) {
      // The offset needs to be a scalar. A single-element vector is converted
      // to a scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
    } else {
      // The offset is in number of elements; multiply by the element byte
      // size and add the result to the base pointer.
      basePtrI64 =
          addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
    }
    // Convert the base pointer (i64) to an LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);

    Value maskForLane;
    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
    if (maskVecTy) {
      // The mask needs to be a scalar. A single-element vector is converted
      // to a scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
    } else
      maskForLane = mask;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                         maskForLane, true, true);
      // If the mask is true (then clause), load from memory and yield.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (!hasScalarVal)
        valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
                                     valOrResVecTy.getElementType());
      Value loaded =
          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
      // Set the cache control attribute on the load operation.
      loaded.getDefiningOp()->setAttr(
          "cache_control", xevm::LoadCacheControlAttr::get(
                               ctxt, translateLoadXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      // If the mask is false (else clause), yield zeros.
      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
      TypedAttr eVal;
      if (eTy.isFloat())
        eVal = FloatAttr::get(eTy, 0.0);
      else
        eVal = IntegerAttr::get(eTy, 0);
      if (hasScalarVal)
        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
      else
        loaded = arith::ConstantOp::create(
            rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.replaceOp(op, ifOp.getResult(0));
    } else {
      // If the mask is true, perform the store.
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
      auto body = ifOp.getBody();
      rewriter.setInsertionPointToStart(body);
      auto storeOp =
          LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
      // Set the cache control attribute on the store operation.
      storeOp.getOperation()->setAttr(
          "cache_control", xevm::StoreCacheControlAttr::get(
                               ctxt, translateStoreXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      rewriter.eraseOp(op);
    }
    return success();
  }
};

class CreateMemDescOpPattern final
    : public OpConversionPattern<xegpu::CreateMemDescOp> {
public:
  using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");

    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    Value baseAddr32 = adaptor.getMemDesc();
    Value mdescVal = op.getMemDesc();
    // The load result or store value type can be a vector or a scalar.
    Type dataTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      Type resType = op.getResult().getType();
      // Some transforms may leave unit dimensions in the 2D vector; adaptors
      // do not catch this for results.
      if (auto vecType = dyn_cast<VectorType>(resType)) {
        assert(llvm::count_if(vecType.getShape(),
                              [](int64_t d) { return d != 1; }) <= 1 &&
               "Expected either 1D vector or nD with unit dimensions");
        resType = VectorType::get({vecType.getNumElements()},
                                  vecType.getElementType());
      }
      dataTy = resType;
    } else
      dataTy = adaptor.getData().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
    if (!valOrResVecTy)
      valOrResVecTy = VectorType::get(1, dataTy);

    int64_t elemBitWidth =
        valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    // The element type bit width must be a multiple of 8.
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;

    // The default memory space is SLM.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));

    auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());

    Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
    linearOffset = arith::IndexCastUIOp::create(
        rewriter, loc, rewriter.getI32Type(), linearOffset);
    Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
                                           linearOffset, elemByteSize);

    // Convert the base pointer (i32) to an LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);

    if (op.getSubgroupBlockIoAttr()) {
      // If the 'subgroup_block_io' attribute is set, lower to
      // xevm.blockload / xevm.blockstore.
      Type intElemTy = rewriter.getIntegerType(elemBitWidth);
      VectorType intVecTy =
          VectorType::get(valOrResVecTy.getShape(), intElemTy);

      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
        Value loadOp =
            xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
        if (intVecTy != valOrResVecTy) {
          loadOp =
              vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
        }
        rewriter.replaceOp(op, loadOp);
      } else {
        Value dataToStore = adaptor.getData();
        if (valOrResVecTy != intVecTy) {
          dataToStore =
              vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
        }
        xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
                                   nullptr);
        rewriter.eraseOp(op);
      }
      return success();
    }

    if (valOrResVecTy.getNumElements() >= 1) {
      auto chipOpt = xegpu::getChipStr(op);
      if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
        // The lowering for chunk load only works for pvc and bmg.
        return rewriter.notifyMatchFailure(
            op, "The lowering is specific to pvc or bmg.");
      }
    }

    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      // If valOrResVecTy has a single element, lower to a scalar load.
      // LLVM load/store does not support vectors of size 1, so this case is
      // handled separately.
      auto scalarTy = valOrResVecTy.getElementType();
      LLVM::LoadOp loadOp;
      if (valOrResVecTy.getNumElements() == 1)
        loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
      else
        loadOp =
            LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
      rewriter.replaceOp(op, loadOp);
    } else {
      LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
      rewriter.eraseOp(op);
    }
    return success();
  }
};

class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64 = adaptor.getSource();
    // The base pointer is passed as i32 or i64 by the adaptor; cast to i64 if
    // needed.
    if (basePtrI64.getType() != rewriter.getI64Type())
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    Value offsets = adaptor.getOffsets();
    if (offsets) {
      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
      if (offsetsVecTy) {
        // The offset needs to be a scalar.
        return rewriter.notifyMatchFailure(op,
                                           "Expected offsets to be a scalar.");
      } else {
        int64_t elemBitWidth{0};
        int64_t elemByteSize;
        // The element byte size can come from three sources:
        if (tdescTy) {
          // If a tensor descriptor is available, use its element type to
          // determine the element byte size.
          elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
        } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
          // If a memref is available, use its element type to determine the
          // element byte size.
          elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
        } else {
          // Otherwise, use the provided offset byte alignment.
          elemByteSize = *op.getOffsetAlignByte();
        }
        if (elemBitWidth != 0) {
          if (elemBitWidth % 8 != 0)
            return rewriter.notifyMatchFailure(
                op, "Expected element type bit width to be multiple of 8.");
          elemByteSize = elemBitWidth / 8;
        }
        basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
                                         elemByteSize);
      }
    }
    // The default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // If a tensor descriptor is available, use its memory space.
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // If the source is a memref, use its memory space.
    if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
      auto addrSpace = memRefTy.getMemorySpaceAsInt();
      if (addrSpace != 0)
        ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
    }
    // Convert the base pointer (i64) to an LLVM pointer type.
    Value ptrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    // Create the prefetch op with a cache control attribute.
    xevm::PrefetchOp::create(
        rewriter, loc, ptrLLVM,
        xevm::LoadCacheControlAttr::get(
            ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
    rewriter.eraseOp(op);
    return success();
  }
};

class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
    switch (op.getFenceScope()) {
    case xegpu::FenceScope::Workgroup:
      memScope = xevm::MemScope::WORKGROUP;
      break;
    case xegpu::FenceScope::GPU:
      memScope = xevm::MemScope::DEVICE;
      break;
    }
    xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
    switch (op.getMemoryKind()) {
    case xegpu::MemorySpace::Global:
      addrSpace = xevm::AddrSpace::GLOBAL;
      break;
    case xegpu::MemorySpace::SLM:
      addrSpace = xevm::AddrSpace::SHARED;
      break;
    }
    xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
    rewriter.eraseOp(op);
    return success();
  }
};

class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto aTy = cast<VectorType>(op.getLhs().getType());
    auto bTy = cast<VectorType>(op.getRhs().getType());
    auto resultType = cast<VectorType>(op.getResultType());

    auto encodePrecision = [&](Type type) -> xevm::ElemType {
      if (type == rewriter.getBF16Type())
        return xevm::ElemType::BF16;
      else if (type == rewriter.getF16Type())
        return xevm::ElemType::F16;
      else if (type == rewriter.getTF32Type())
        return xevm::ElemType::TF32;
      else if (type.isInteger(8)) {
        if (type.isUnsignedInteger())
          return xevm::ElemType::U8;
        return xevm::ElemType::S8;
      } else if (type == rewriter.getF32Type())
        return xevm::ElemType::F32;
      else if (type.isInteger(32))
        return xevm::ElemType::S32;
      llvm_unreachable("add more support for ElemType");
    };
    xevm::ElemType precATy = encodePrecision(aTy.getElementType());
    xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
    Value c = op.getAcc();
    if (!c) {
      auto elementTy = resultType.getElementType();
      Attribute initValueAttr;
      if (isa<FloatType>(elementTy))
        initValueAttr = FloatAttr::get(elementTy, 0.0);
      else
        initValueAttr = IntegerAttr::get(elementTy, 0);
      c = arith::ConstantOp::create(
          rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
    }

    Value aVec = op.getLhs();
    Value bVec = op.getRhs();
    auto cvecty = cast<VectorType>(c.getType());
    xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
    xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
    VectorType cNty =
        VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
    if (cvecty != cNty)
      c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
    Value dpasRes = xevm::MMAOp::create(
        rewriter, loc, cNty, aVec, bVec, c,
        xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
                                systolicDepth *
                                    getNumOperandsPerDword(precATy)),
        xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
    if (cvecty != cNty)
      dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
    rewriter.replaceOp(op, dpasRes);
    return success();
  }

private:
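  // Number of operands packed per 32-bit dword for each element type; the
  // MMA K dimension above is systolicDepth * this value. For example, with
  // f16 operands (2 per dword), K = 8 * 2 = 16, giving an MMA of shape
  // M x 16 x 16, where M is the number of accumulator elements.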
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    switch (pTy) {
    case xevm::ElemType::TF32:
      return 1;
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
      return 2;
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
      return 4;
    default:
      llvm_unreachable("unsupported xevm::ElemType");
    }
  }
};

static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
  switch (arithKind) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
}

class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdesc = op.getTensorDesc().getType();
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
    Value basePtrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
    VectorType srcOrDstFlatVecTy = VectorType::get(
        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
    Value srcFlatVec = vector::ShapeCastOp::create(
        rewriter, loc, srcOrDstFlatVecTy, op.getValue());
    auto atomicKind = matchSimpleAtomicOp(op.getKind());
    assert(atomicKind.has_value());
    Value resVec = srcFlatVec;
    // Lower to one LLVM atomicrmw per element: for each lane, compute the
    // element address via GEP and apply the matched atomic binop.
    for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
      auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
      Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
                                           rewriter.getIndexAttr(i));
      Value currPtr =
          LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
                              srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
      Value newVal =
          LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
                                    val, LLVM::AtomicOrdering::seq_cst);
      resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
    }
    rewriter.replaceOp(op, resVec);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

struct ConvertXeGPUToXeVMPass
    : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
  using Base::Base;

  void runOnOperation() override {
    LLVMTypeConverter typeConverter(&getContext());
    typeConverter.addConversion([&](VectorType type) -> Type {
      unsigned rank = type.getRank();
      auto elemType = type.getElementType();
      // If the element type is index, convert it to i64.
      if (llvm::isa<IndexType>(elemType))
        elemType = IntegerType::get(&getContext(), 64);
      // If the vector has rank 0 or a single element, return the element
      // type.
      if (rank == 0 || type.getNumElements() == 1)
        return elemType;
      // Otherwise, convert the vector to a flat 1D vector type.
      int64_t sum = llvm::product_of(type.getShape());
      return VectorType::get(sum, elemType);
    });
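    // Examples (hypothetical types): vector<8x16xf16> -> vector<128xf16>;
    // vector<1xf32> -> f32; vector<4xindex> -> vector<4xi64>.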
    typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
      // Scattered descriptors are not supported in the XeVM lowering.
      if (type.isScattered())
        return {};
      if (type.getRank() == 1)
        return IntegerType::get(&getContext(), 64);
      auto i32Type = IntegerType::get(&getContext(), 32);
      return VectorType::get(8, i32Type);
    });
    // Convert MemDescType into i32 for SLM.
    typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
      return IntegerType::get(&getContext(), 32);
    });

    typeConverter.addConversion([&](MemRefType type) -> Type {
      return IntegerType::get(&getContext(), (isSharedMemRef(type) ? 32 : 64));
    });

    // The LLVM type converter inserts unrealized casts for the following
    // cases; add materialization casts to handle them.

    // Materialization to convert a memref to i64 or i32 depending on
    // global/SLM. Applies only to target materialization.
    // Note: an int-to-memref materialization is not required, as xegpu ops
    // currently do not produce memrefs as results.
    auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
                                             ValueRange inputs,
                                             Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
        unsigned rank = memrefTy.getRank();
        Type indexType = builder.getIndexType();

        int64_t intOffsets;
        SmallVector<int64_t> intStrides;
        Value addr;
        Value offset;
        if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
            ShapedType::isStatic(intOffsets)) {
          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
                                                                input);
          offset = arith::ConstantOp::create(builder, loc,
                                             builder.getIndexAttr(intOffsets));
        } else {
          // Result types: [base_memref, offset, stride0, stride1, ...,
          // strideN-1, size0, size1, ..., sizeN-1]
          SmallVector<Type> resultTypes{
              MemRefType::get({}, memrefTy.getElementType(),
                              MemRefLayoutAttrInterface(),
                              memrefTy.getMemorySpace()),
              indexType};
          // strides + sizes
          resultTypes.append(2 * rank, indexType);

          auto meta = memref::ExtractStridedMetadataOp::create(
              builder, loc, resultTypes, input);

          addr = memref::ExtractAlignedPointerAsIndexOp::create(
              builder, loc, meta.getBaseBuffer());
          offset = meta.getOffset();
        }

        auto addrCasted =
            arith::IndexCastUIOp::create(builder, loc, type, addr);
        auto offsetCasted =
            arith::IndexCastUIOp::create(builder, loc, type, offset);

        // Compute the final address: base address + byte offset.
        auto byteSize = arith::ConstantOp::create(
            builder, loc, type,
            builder.getIntegerAttr(type,
                                   memrefTy.getElementTypeBitWidth() / 8));
        auto byteOffset =
            arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
        auto addrWithOffset =
            arith::AddIOp::create(builder, loc, addrCasted, byteOffset);

        return addrWithOffset.getResult();
      }
      return {};
    };

    // Materialization to convert ui64 to i64.
    // Applies only to target materialization.
    // Note: an i64-to-ui64 materialization is not required, as xegpu ops
    // currently do not produce ui64 as a result.
    auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
                                           ValueRange inputs,
                                           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(64, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert ui32 to i32.
    // Applies only to target materialization.
    // Note: an i32-to-ui32 materialization is not required, as xegpu ops
    // currently do not produce ui32 as a result.
    auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
                                           ValueRange inputs,
                                           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(32, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert:
    // - a vector to a vector of the same rank via bitcast;
    // - a vector to a vector of different rank but the same element type via
    //   shape cast.
    // Applies to both source and target materialization.
    auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
                                                ValueRange inputs,
                                                Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        if (auto targetVecTy = dyn_cast<VectorType>(type)) {
          // If the target type is a vector of the same rank, bitcast to the
          // target type.
          if (targetVecTy.getRank() == vecTy.getRank())
            return vector::BitCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
          else if (targetVecTy.getElementType() == vecTy.getElementType()) {
            // If the target type is a vector of different rank but the same
            // element type, reshape to the target type.
            return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
          }
        }
      }
      return {};
    };

    // Materialization to convert a single-element vector to a single element
    // of the vector's element type.
    // Applies only to target materialization.
    auto vectorToSingleElementMaterializationCast =
        [](OpBuilder &builder, Type type, ValueRange inputs,
           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        if (type == vecTy.getElementType() ||
            ((vecTy.getElementType() == builder.getIndexType()) &&
             type.isInteger())) {
          // If the vector has rank 0 or a single element, extract a scalar of
          // the target type.
          auto rank = vecTy.getRank();
          Value cast;
          if (rank == 0) {
            cast =
                vector::ExtractOp::create(builder, loc, input, {}).getResult();
          } else {
            cast = vector::ExtractOp::create(builder, loc, input,
                                             SmallVector<int64_t>(rank, 0))
                       .getResult();
          }
          if (type != vecTy.getElementType())
            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                       .getResult();
          return cast;
        }
      }
      return {};
    };

    // Materialization to convert a single element of the vector element type
    // to a single-element vector. Used when the result type of the original
    // op is a single-element vector and the lowered type is a scalar; this
    // materialization cast creates a single-element vector by broadcasting
    // the scalar value.
    // Applies only to source materialization.
    auto singleElementToVectorMaterializationCast =
        [](OpBuilder &builder, Type type, ValueRange inputs,
           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      // If the target type is a vector of rank 0, or a single-element vector
      // with an element type matching the input type, broadcast the input to
      // the target type.
      if (auto vecTy = dyn_cast<VectorType>(type)) {
        if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
          if (input.getType() == vecTy.getElementType()) {
            return vector::BroadcastOp::create(builder, loc, vecTy, input)
                .getResult();
          } else if (vecTy.getElementType() == builder.getIndexType()) {
            Value cast = arith::IndexCastUIOp::create(
                             builder, loc, builder.getIndexType(), input)
                             .getResult();
            return vector::BroadcastOp::create(builder, loc, vecTy, cast)
                .getResult();
          }
        }
      }
      return {};
    };
    typeConverter.addSourceMaterialization(
        singleElementToVectorMaterializationCast);
    typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
    typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
    typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
    typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
    typeConverter.addTargetMaterialization(
        vectorToSingleElementMaterializationCast);
    typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
    ConversionTarget target(getContext());
    target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                           vector::VectorDialect, arith::ArithDialect,
                           memref::MemRefDialect, gpu::GPUDialect,
                           index::IndexDialect>();
    target.addIllegalDialect<xegpu::XeGPUDialect>();

    RewritePatternSet patterns(&getContext());
    populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
                                                         patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//

void mlir::populateXeGPUToXeVMConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<CreateNdDescToXeVMPattern,
               LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
      typeConverter, patterns.getContext());
  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
               LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
      typeConverter, patterns.getContext());
  patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
               LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
               CreateMemDescOpPattern>(typeConverter, patterns.getContext());
  patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
                                                      patterns.getContext());
}